<a href="https://colab.research.google.com/github/ruskstoic/Chopin_Music_Gen/blob/main/Chopin_Music_Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Install Dependencies and Imports

In [None]:
# --- Install dependencies ---
!apt-get -y install fluidsynth
!pip install midi2audio mido
!pip install pretty_midi miditok torch torchaudio tqdm
!pip install pyfluidsynth

# --- Imports ---
import mido
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from collections import deque
from midi2audio import FluidSynth
from IPython.display import Audio

## Mount Google Drive

In [None]:
## Mount Google Drive
from google.colab import drive

# Remount even if a previous session left Drive attached, so file listings are fresh.
drive.mount(mountpoint='/content/gdrive', force_remount=True)

Mounted at /content/gdrive


### Play and Visualize MIDI File

In [None]:
# Load the MIDI file
midi_path = "/content/gdrive/MyDrive/Chopin Music Transformer/nocturne_op_9_1_2979_r_(nc)smythe.mid"
mid = mido.MidiFile(midi_path)

# Extract Note and Time Data.
# Iterating a MidiFile yields messages whose `.time` is the delta (in seconds) since the
# previous message, so advance the clock BEFORE recording the time of this message.
note_data = []
time_data = []

current_time = 0
for msg in mid:
    current_time += msg.time
    if not msg.is_meta and msg.type in ('note_on', 'note_off'):
        note_data.append(msg.note)
        time_data.append(current_time)

# Static Plot of MIDI Events (both note_on and note_off events are plotted)
plt.figure(figsize=(12, 5))
plt.plot(time_data, note_data, 'bo-', lw=1)
plt.title('MIDI Note Events Over Time')
plt.xlabel('Time (seconds)')
plt.ylabel('Note Number')
plt.grid(True)
plt.show()

# Render the MIDI to WAV with FluidSynth (default soundfont and gain)
fs = FluidSynth()
fs.midi_to_audio(midi_path, "output.wav")

# Play Audio
Audio("output.wav")

## Parse and Load MIDI File Function

In [None]:
### Parse only 1 MIDI File

import pretty_midi
import os
from glob import glob

# Load MIDI files from the top level of the folder only (the non-recursive glob does not
# descend into subfolders such as `test/`). Sort the paths so the concatenated training
# stream is built in the same order on every run; glob's order is unspecified.
midi_folder = '/content/gdrive/MyDrive/Chopin Music Transformer'  # Put your MIDIs here
midi_files = sorted(glob(os.path.join(midi_folder, '*.mid')))

def extract_notes_from_1_file(midi_path):
    """Read one MIDI file and return its pitched notes ordered by onset.

    Drum tracks are skipped. Each note is returned as a
    (start_seconds, end_seconds, pitch, velocity) tuple; notes with equal start
    times keep their instrument/track order (the sort is stable).
    """
    pm = pretty_midi.PrettyMIDI(midi_path)
    collected = [
        (n.start, n.end, n.pitch, n.velocity)
        for inst in pm.instruments
        if not inst.is_drum
        for n in inst.notes
    ]
    collected.sort(key=lambda item: item[0])
    return collected



## Create Token Vocabulary and Encode All MIDI Files

In [None]:
# Initialize Hyperparameters
max_shift = 501  # number of time_shift tokens: time_shift_0 .. time_shift_{max_shift - 1}

# Token vocabulary: a note_on/note_off pair for every piano key (MIDI 21-108),
# then every time shift, then a single end-of-song marker.
token_vocab = [f"{event}_{pitch}" for pitch in range(21, 109) for event in ("note_on", "note_off")]
token_vocab.extend(f"time_shift_{step}" for step in range(max_shift))
token_vocab.append("song_end")

# Bidirectional lookups between token strings and integer ids (ids are list positions).
token2idx = {token: position for position, token in enumerate(token_vocab)}
idx2token = dict(enumerate(token_vocab))

print('Size of Token Vocabulary:', len(token_vocab))

# Encode Notes from MIDI
def notes_to_encoded(notes, sequence_length=128, time_division=0.05, max_shift_val=None, seed=False):
    """Encode (start, end, pitch, velocity) notes as a flat list of event-token strings.

    Each note becomes a ``note_on_<pitch>`` and a ``note_off_<pitch>`` event. Events are
    sorted by time, and every event is preceded by a ``time_shift_<n>`` token giving the
    number of ``time_division``-second grid steps since the previous event.

    Args:
        notes: iterable of (start, end, pitch, velocity) tuples; velocity is ignored.
        sequence_length: token budget, only used when ``seed`` is True.
        time_division: seconds per time_shift step.
        max_shift_val: largest time_shift value emitted; longer gaps are clipped so the
            token stays in the vocabulary. Defaults to ``max_shift - 1``.
        seed: if True, stop once ``sequence_length`` tokens have been produced and do not
            append the trailing pause / ``song_end`` marker, so the result can be used as
            a prefix for continuation.

    Returns:
        list[str] of tokens. Outside seed mode the list ends with ``time_shift_5`` and
        ``song_end``.
    """
    if max_shift_val is None:
        max_shift_val = max_shift - 1

    events = []
    for start, end, pitch, _ in notes:
        events.append(('note_on', start, pitch))
        events.append(('note_off', end, pitch))

    # Stable sort by time: events at equal times keep their insertion order
    events.sort(key=lambda event: event[1])

    encoded_sequence = []
    prev_step = 0
    active_notes = {}  # pitch -> None; a dict used as an insertion-ordered set of sounding pitches

    for event_type, event_time, pitch in events:
        if seed and len(encoded_sequence) >= sequence_length:  # limit seed length
            break

        # Quantize the absolute event time onto the grid and take the difference, so
        # rounding errors do not accumulate across many closely spaced events.
        step = round(event_time / time_division)
        shift = min(max(0, step - prev_step), max_shift_val)  # clip to stay in-vocab
        prev_step = step

        encoded_sequence.append(f"time_shift_{shift}")
        encoded_sequence.append(f"{event_type}_{pitch}")

        if event_type == 'note_on':
            active_notes[pitch] = None
        else:
            active_notes.pop(pitch, None)

    # Turn off any notes still sounding (e.g. when a seed was truncated mid-note)
    for pitch in active_notes:
        encoded_sequence.append("time_shift_0")
        encoded_sequence.append(f"note_off_{pitch}")

    if not seed:
        # 5 steps of silence (0.25 s at the default 0.05 s division), then the end marker
        encoded_sequence.append("time_shift_5")
        encoded_sequence.append("song_end")

    return encoded_sequence

# Extract notes from every MIDI file and concatenate their token ids into one training stream
all_encoded_tokens = []
files_encoded = 0
for midi_file_path in midi_files:
    notes_for_file = extract_notes_from_1_file(midi_file_path)
    if not notes_for_file:
        continue
    encoded_sequence = notes_to_encoded(notes_for_file, time_division=0.05, max_shift_val=max_shift-1)
    # Skip tokens outside the vocabulary (e.g. pitches outside MIDI 21-108) instead of raising KeyError
    all_encoded_tokens.extend(token2idx[tok] for tok in encoded_sequence if tok in token2idx)
    files_encoded += 1

# Later cells build the dataset from `tokened_sequence`, so expose the full stream under that name
tokened_sequence = all_encoded_tokens
print('Total MIDI Files Extracted:', files_encoded)
print('Length of Tokened Sequence', len(tokened_sequence))

Length of Encoded 128388


## Dataset and DataLoader

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader

class MIDIDataset(Dataset):
    """Next-token-prediction windows over one flat list of token ids.

    Item ``i`` is ``(x, y)`` with ``x = tokens[i : i + seq_len]`` and ``y`` the same
    window shifted one position right (``tokens[i + 1 : i + seq_len + 1]``).
    """

    def __init__(self, token_list, seq_len=128):
        # Keep one tensor and slice windows on demand instead of materialising every
        # overlapping window (which costs roughly seq_len times the memory of the stream).
        self.seq_len = seq_len
        self.tokens = torch.tensor(token_list, dtype=torch.long)

    def __len__(self):
        # A window can start at every position that still leaves seq_len + 1 tokens.
        return max(0, len(self.tokens) - self.seq_len)

    def __getitem__(self, idx):
        n = len(self)
        if idx < 0:
            idx += n
        if not 0 <= idx < n:
            raise IndexError(f"index {idx} out of range for dataset of length {n}")
        x = self.tokens[idx:idx + self.seq_len]
        y = self.tokens[idx + 1:idx + self.seq_len + 1]
        return x, y

# Wrap the concatenated token stream in overlapping next-token windows; reshuffle every epoch
dataset = MIDIDataset(tokened_sequence)
loader = DataLoader(dataset, shuffle=True, batch_size=128)


## Lightweight Transformer Model

In [None]:
import torch
import torch.nn as nn

class MusicTransformer(nn.Module):
    """Autoregressive token model: a causally masked Transformer encoder stack.

    Args:
        vocab_size: number of token ids.
        d_model: embedding / hidden width.
        nhead: attention heads per layer.
        num_layers: number of encoder layers.
        max_len: longest sequence supported by the learned positional embeddings.
    """

    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=4, max_len=1000):
        super().__init__()
        self.max_len = max_len
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Learned absolute positional embeddings, one per position up to max_len
        self.pos_encoder = nn.Parameter(torch.randn(1, max_len, d_model))
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.output = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Map (batch, seq) token ids to (batch, seq, vocab_size) next-token logits."""
        seq_len = x.size(1)
        if seq_len > self.max_len:
            raise ValueError(f"sequence length {seq_len} exceeds max_len {self.max_len}")
        h = self.embedding(x) + self.pos_encoder[:, :seq_len]
        # `is_causal` alone is only a hint about a supplied mask; pass an explicit
        # upper-triangular -inf mask so position t can never attend to positions > t.
        causal_mask = torch.triu(
            torch.full((seq_len, seq_len), float('-inf'), device=x.device, dtype=h.dtype),
            diagonal=1,
        )
        h = self.transformer(h, mask=causal_mask)
        return self.output(h)

# One output logit per vocabulary token; place the model on the GPU when one is available.
model = MusicTransformer(vocab_size=len(token2idx))
model = model.to("cuda" if torch.cuda.is_available() else "cpu")




## Training Loop

In [None]:
### Training Loop that Saves Best Model

import torch.optim as optim

device = "cuda" if torch.cuda.is_available() else "cpu"
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-5)

# Lowest average epoch loss seen so far; any finite loss beats infinity.
best_loss = float('inf')

for epoch in range(5):  # Increase later
    model.train()
    batch_losses = []

    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        logits = model(inputs)
        # Flatten logits to (tokens, vocab) and targets to (tokens,) for token-level cross-entropy
        vocab_size = logits.size(-1)
        loss = criterion(logits.view(-1, vocab_size), targets.view(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    avg_loss = sum(batch_losses) / len(loader)
    print(f"Epoch {epoch}: Loss {avg_loss:.4f}")

    # Checkpoint only on improvement
    if not (avg_loss < best_loss):
        continue
    best_loss = avg_loss
    torch.save(model.state_dict(), "/content/gdrive/MyDrive/Chopin Music Transformer/best_model.pth")
    print(f"✅ Saved new best model at epoch {epoch} with loss {avg_loss:.4f}")


Epoch 0: Loss 2.7113
✅ Saved new best model at epoch 0 with loss 2.7113
Epoch 1: Loss 2.4327
✅ Saved new best model at epoch 1 with loss 2.4327
Epoch 2: Loss 2.4171
✅ Saved new best model at epoch 2 with loss 2.4171
Epoch 3: Loss 2.4112
✅ Saved new best model at epoch 3 with loss 2.4112
Epoch 4: Loss 2.4083
✅ Saved new best model at epoch 4 with loss 2.4083


## Resume Training

In [None]:
# Resume Training Loop

# Fresh best-loss tracker for this run, which checkpoints to its own file (best_model1.pth).
best_loss = float('inf')

for epoch in range(5, 10):  # Increase later
    model.train()
    batch_losses = []

    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        logits = model(inputs)
        # Flatten logits to (tokens, vocab) and targets to (tokens,) for token-level cross-entropy
        vocab_size = logits.size(-1)
        loss = criterion(logits.view(-1, vocab_size), targets.view(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    avg_loss = sum(batch_losses) / len(loader)
    print(f"Epoch {epoch}: Loss {avg_loss:.4f}")

    # Checkpoint only on improvement
    if not (avg_loss < best_loss):
        continue
    best_loss = avg_loss
    torch.save(model.state_dict(), "/content/gdrive/MyDrive/Chopin Music Transformer/best_model1.pth")
    print(f"✅ Saved new best model at epoch {epoch} with loss {avg_loss:.4f}")

Epoch 5: Loss 2.4069
✅ Saved new best model at epoch 5 with loss 2.4069


## Generate Seed

In [None]:
import pretty_midi

# Generate seed sequence from specific MIDI file
def generate_seed_from_specific_midi(midi_file, sequence_length, tokenizer):
    """Encode the opening of `midi_file` as token ids to prime generation.

    Args:
        midi_file: path to a MIDI file.
        sequence_length: approximate token budget for the seed (passed to notes_to_encoded).
        tokenizer: mapping from token string to id (e.g. token2idx).

    Returns:
        list[int] of ids; tokens missing from `tokenizer` are dropped.
    """
    notes = extract_notes_from_1_file(midi_file)
    encoded = notes_to_encoded(notes, sequence_length=sequence_length, seed=True)
    # Filter against the same mapping used for the lookup, not a global vocabulary
    return [tokenizer[tok] for tok in encoded if tok in tokenizer]

# Convert tokens to MIDI
def tokens_to_midi(tokens, time_division=0.05, velocity=80, program=0):
    """Decode event-token strings into a single-instrument PrettyMIDI object.

    ``time_shift_<n>`` advances the clock by ``n * time_division`` seconds,
    ``note_on_<p>`` opens pitch ``p`` and ``note_off_<p>`` closes it; any other token
    (e.g. ``song_end``) is ignored. Zero-length notes are dropped.

    Args:
        tokens: iterable of token strings.
        time_division: seconds per time_shift step (must match the encoder).
        velocity: velocity assigned to every decoded note.
        program: General MIDI program number for the single instrument.

    Returns:
        pretty_midi.PrettyMIDI containing one instrument with the decoded notes.
    """
    midi = pretty_midi.PrettyMIDI()
    piano = pretty_midi.Instrument(program=program)
    current_time = 0.0
    note_on_dict = {}  # pitch -> start time of the note currently sounding on that pitch

    def close_note(pitch, end_time):
        # Finish the sounding note on `pitch`; keep it only if it has positive length.
        start_time = note_on_dict.pop(pitch)
        if end_time > start_time:
            piano.notes.append(
                pretty_midi.Note(velocity=velocity, pitch=pitch, start=start_time, end=end_time)
            )

    for token in tokens:
        if token.startswith("time_shift"):
            current_time += int(token.split("_")[-1]) * time_division

        elif token.startswith("note_on"):
            pitch = int(token.split("_")[-1])
            # Re-striking a pitch that is still sounding ends the earlier note here
            # instead of silently discarding it.
            if pitch in note_on_dict:
                close_note(pitch, current_time)
            note_on_dict[pitch] = current_time

        elif token.startswith("note_off"):
            pitch = int(token.split("_")[-1])
            if pitch in note_on_dict:
                close_note(pitch, current_time)

    # Notes never switched off (common at the end of generated sequences) would
    # otherwise be dropped; end them at the final clock position.
    for pitch in list(note_on_dict):
        close_note(pitch, current_time)

    midi.instruments.append(piano)
    return midi

specific_midi = '/content/gdrive/MyDrive/Chopin Music Transformer/test/2_Eb_op9no2_inoue.mid'
output_dir = '/content/gdrive/MyDrive/Chopin Music Transformer/output'
os.makedirs(output_dir, exist_ok=True)  # the MIDI/WAV writes below fail if the folder is missing
seed_midi_path = os.path.join(output_dir, 'seed.mid')
seed_wav_path = os.path.join(output_dir, 'seed_output.wav')

seed = generate_seed_from_specific_midi(specific_midi, 128, token2idx)
seed_tokens = [idx2token[i] for i in seed]

seed_midi = tokens_to_midi(seed_tokens)
seed_midi.write(seed_midi_path)
print('Seed Tokens', seed_tokens)

# Render the seed MIDI to WAV with FluidSynth and play it inline
fs = FluidSynth()
fs.midi_to_audio(seed_midi_path, seed_wav_path)
Audio(seed_wav_path)

Seed Tokens ['time_shift_55', 'note_on_70', 'time_shift_15', 'note_off_70', 'time_shift_1', 'note_on_79', 'time_shift_0', 'note_on_39', 'time_shift_7', 'note_off_39', 'time_shift_7', 'note_on_63', 'time_shift_0', 'note_on_55', 'time_shift_10', 'note_off_63', 'time_shift_0', 'note_off_55', 'time_shift_0', 'note_on_67', 'time_shift_0', 'note_on_63', 'time_shift_0', 'note_on_58', 'time_shift_8', 'note_off_67', 'time_shift_0', 'note_off_63', 'time_shift_0', 'note_off_58', 'time_shift_1', 'note_on_51', 'time_shift_4', 'note_off_51', 'time_shift_4', 'note_off_79', 'time_shift_0', 'note_on_77', 'time_shift_0', 'note_on_62', 'time_shift_0', 'note_on_56', 'time_shift_10', 'note_off_77', 'time_shift_0', 'note_off_62', 'time_shift_0', 'note_off_56', 'time_shift_0', 'note_on_79', 'time_shift_0', 'note_on_68', 'time_shift_0', 'note_on_62', 'time_shift_0', 'note_on_59', 'time_shift_11', 'note_off_68', 'time_shift_0', 'note_off_62', 'time_shift_0', 'note_off_59', 'time_shift_1', 'note_off_79', 'time_

## Generate Output and Play

In [None]:
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
# map_location lets a checkpoint saved in a GPU session load on a CPU-only runtime;
# weights_only restricts unpickling to tensors/containers (sufficient for a state_dict).
state_dict = torch.load(
    "/content/gdrive/MyDrive/Chopin Music Transformer/best_model1.pth",
    map_location=device,
    weights_only=True,
)
model.load_state_dict(state_dict)
model.to(device)
model.eval()

def generate(model, seed_seq, length=500, temperature=1.0, top_k=10, context_len=128, time_shift_0_penalty=1.0):
    """Autoregressively extend `seed_seq` by `length` token ids using top-k sampling.

    Args:
        model: network mapping a (1, seq) id tensor to (1, seq, vocab) logits.
        seed_seq: non-empty list of token ids to continue from (not modified).
        length: number of new tokens to sample.
        temperature: logits are divided by this before sampling (lower = greedier).
        top_k: sample only among the k highest-scoring tokens (clamped to vocab size).
        context_len: number of most recent tokens fed to the model at each step.
        time_shift_0_penalty: amount subtracted from the `time_shift_0` logit to
            discourage piling up simultaneous events; 0 disables it.

    Returns:
        list[int]: the seed followed by the sampled ids.

    Raises:
        ValueError: if `seed_seq` is empty.
    """
    if not seed_seq:
        raise ValueError("seed_seq must contain at least one token")

    model.eval()
    device = next(model.parameters()).device  # sample on whatever device the model lives on
    penalized_idx = token2idx['time_shift_0']
    generated = list(seed_seq)

    with torch.no_grad():
        for _ in range(length):
            input_seq = torch.tensor(generated[-context_len:], dtype=torch.long, device=device).unsqueeze(0)
            logits = model(input_seq)[0, -1] / temperature  # temperature scaling

            # The penalty must be applied to the logits before they become probabilities
            logits[penalized_idx] -= time_shift_0_penalty

            # Top-k sampling: softmax over the k best logits renormalizes within the top-k
            top_logits, top_indices = torch.topk(logits, min(top_k, logits.size(-1)))
            probs = F.softmax(top_logits, dim=-1)
            next_token = top_indices[torch.multinomial(probs, 1)].item()

            generated.append(next_token)

    return generated


gen_ids = generate(model, seed)
gen_tokens = [idx2token[i] for i in gen_ids]
# gen_ids starts with the seed, so skip past it to show the first newly generated tokens
print(gen_tokens[len(seed):len(seed) + 30])

# Decode the generated tokens to MIDI, render to WAV with FluidSynth and play it inline
output_dir = '/content/gdrive/MyDrive/Chopin Music Transformer/output'
os.makedirs(output_dir, exist_ok=True)  # the writes below fail if the folder is missing
gen_midi_path = os.path.join(output_dir, 'generated.mid')
gen_wav_path = os.path.join(output_dir, 'gen_output.wav')

gen_midi = tokens_to_midi(gen_tokens)
gen_midi.write(gen_midi_path)

fs = FluidSynth()
fs.midi_to_audio(gen_midi_path, gen_wav_path)
Audio(gen_wav_path)

## Plot Piano Roll

In [None]:
import librosa.display
midi_path = '/content/gdrive/MyDrive/Chopin Music Transformer/output/generated.mid'
mid = pretty_midi.PrettyMIDI(midi_path)

# get_piano_roll returns rows = MIDI pitches 0-127 and columns = frames at `roll_fs` per
# second. specshow's defaults (sr=22050, hop_length=512, fmin=C1) would mislabel both axes,
# so pass sr=roll_fs with hop_length=1 (one column = 1/roll_fs s) and fmin = pitch 0's frequency.
roll_fs = 100
piano_roll = mid.get_piano_roll(fs=roll_fs)
fig, ax = plt.subplots(figsize=(12, 5))
librosa.display.specshow(
    piano_roll,
    sr=roll_fs,
    hop_length=1,
    x_axis="time",
    y_axis="cqt_note",
    fmin=pretty_midi.note_number_to_hz(0),
    ax=ax,
)
ax.set_title("Piano Roll of Generated MIDI")
plt.show()

## Plot Chromagram

In [None]:
import matplotlib.pyplot as plt
import librosa

file_path = '/content/gdrive/MyDrive/Chopin Music Transformer/output/gen_output.wav'
y, sr = librosa.load(file_path, sr=22050)

# Chromagram: energy per pitch class (C..B) over time, from the STFT of the rendered audio
chroma = librosa.feature.chroma_stft(y=y, sr=sr)
fig, ax = plt.subplots(figsize=(10, 4))
img = librosa.display.specshow(chroma, sr=sr, x_axis='time', y_axis='chroma', ax=ax)
fig.colorbar(img, ax=ax)
ax.set_title("Chromagram")
plt.show()