### Data extraction and feature extraction 

In [1]:
import pretty_midi
import pandas as pd
from pathlib import Path
from collections import defaultdict

In [19]:
def extract_note_features_with_instrument(midi_file):
    try:
        pm = pretty_midi.PrettyMIDI(midi_file)
        notes = defaultdict(list)
        
        if not pm.instruments:
            raise ValueError("No instruments found in MIDI file.")

        for instrument in pm.instruments:
            program_num = instrument.program if not instrument.is_drum else 127  # Assign drum as 128，now change to 127
            #program_num = instrument.program  # <- Extract instrument ID
            for note in instrument.notes:
                notes["pitch"].append(note.pitch)
                notes["velocity"].append(note.velocity)
                notes["note_name"].append(pretty_midi.note_number_to_name(note.pitch))  # e.g., 'C#4'
                notes["octave"].append(note.pitch // 12 - 1)  # Convert MIDI pitch to octave number
                notes["start"].append(note.start)
                notes["end"].append(note.end)
                notes["duration"].append(note.end - note.start)
                notes["instrument"].append(program_num)  # <- Add this line

        return pd.DataFrame(notes)
    
    except Exception as e:
        print(f"Failed to parse {midi_file} due to error: {e}")
        return pd.DataFrame()

In [20]:
def extract_advanced_note_features(midi_file):
    try:
        pm = pretty_midi.PrettyMIDI(midi_file)
        instrument = pm.instruments[0]  # Assuming single instrument for now

        notes = defaultdict(list)
        for note in instrument.notes:
            notes["pitch"].append(note.pitch)
            notes["velocity"].append(note.velocity)  # Extract actual velocity (1-127)
            notes["note_name"].append(pretty_midi.note_number_to_name(note.pitch))  # e.g., 'C#4'
            notes["octave"].append(note.pitch // 12 - 1)  # Convert MIDI pitch to octave number
            notes["start"].append(note.start)
            notes["end"].append(note.end)
            notes["duration"].append(note.end - note.start)            

        return pd.DataFrame(notes)

    except Exception as e:
        print(f"Failed to process {midi_file}: {e}")
        return pd.DataFrame()  # Return empty if failed
    
    
def extract_all_midi_files(folder, max_files=100):
    paths = list(Path(folder).rglob("*.mid"))[:max_files]
    print(f"Found {len(paths)} MIDI files")

    all_dfs = []
    for path in paths:
        df = extract_advanced_note_features(str(path))
        if not df.empty:
            df["filename"] = path.name
            all_dfs.append(df)

    if all_dfs:
        return pd.concat(all_dfs, ignore_index=True)
    else:
        return pd.DataFrame()

In [21]:
# read data and extract pitch, velocity, note_name, octave, start_time, end, and duration 
data_dir = "/Users/yang/Desktop/Yale Spring 2025/CPSC 552 Deep learning theory and applications /DeepL project - music generation /Data set/lmd_matched"
all_notes_df = extract_all_midi_files(data_dir)
all_notes_df.head()

#all_notes_df.to_csv("/Users/yang/Documents/processed_notes_lakh.csv", index=False)



Found 100 MIDI files




Unnamed: 0,pitch,velocity,note_name,octave,start,end,duration,filename
0,46,90,A#2,2,6.315776,6.776301,0.460525,2740bc2a1cd9bae5dfb5ddc40f2aefb9.mid
1,41,92,F2,2,7.103192,7.487648,0.384456,2740bc2a1cd9bae5dfb5ddc40f2aefb9.mid
2,44,91,G#2,2,7.493816,7.845378,0.351562,2740bc2a1cd9bae5dfb5ddc40f2aefb9.mid
3,46,83,A#2,2,7.880329,8.402531,0.522203,2740bc2a1cd9bae5dfb5ddc40f2aefb9.mid
4,41,74,F2,2,8.673912,9.08304,0.409127,2740bc2a1cd9bae5dfb5ddc40f2aefb9.mid


In [5]:
all_notes_df["filename"].nunique() # check number of unique files in our code 

100

## Version 2: try to extract instrument information correctly

In [47]:
import os

def extract_all_midi_files_2(root_dir, max_files=None):
    from tqdm import tqdm
    import pandas as pd

    all_notes = []
    count = 0

    for dirpath, _, filenames in os.walk(root_dir):
        for file in filenames:
            if not file.lower().endswith(".mid"):
                continue

            full_path = os.path.join(dirpath, file)

            df = extract_note_features_with_instrument(full_path)
            if not df.empty:
                df["filename"] = file
                all_notes.append(df)
                count += 1

            if max_files and count >= max_files:
                break

    return pd.concat(all_notes, ignore_index=True) if all_notes else pd.DataFrame()

def extract_all_midi_files_3(folder, max_files=500): # this works 
    paths = list(Path(folder).rglob("*.mid"))[:max_files]
    print(f"Found {len(paths)} MIDI files")

    all_dfs = []
    for path in paths:
        df = extract_note_features_with_instrument(str(path))
        if not df.empty:
            df["filename"] = path.name
            all_dfs.append(df)

    if all_dfs:
        return pd.concat(all_dfs, ignore_index=True)
    else:
        return pd.DataFrame()

**Successfully extract the instrument information now**

In [48]:
# second version that also extract information related to instrument used 

data_dir = "/Users/yang/Desktop/Yale Spring 2025/CPSC 552 Deep learning theory and applications /DeepL project - music generation /Data set/lmd_matched"
extracted_lakh_df = extract_all_midi_files_3(data_dir)
extracted_lakh_df.head()

Found 500 MIDI files




Failed to parse /Users/yang/Desktop/Yale Spring 2025/CPSC 552 Deep learning theory and applications /DeepL project - music generation /Data set/lmd_matched/R/U/I/TRRUIKN128E078218F/c03a94962031205a7656296e967b92f7.mid due to error: data byte must be in range 0..127




Unnamed: 0,pitch,velocity,note_name,octave,start,end,duration,instrument,filename
0,46,90,A#2,2,6.315776,6.776301,0.460525,0,2740bc2a1cd9bae5dfb5ddc40f2aefb9.mid
1,41,92,F2,2,7.103192,7.487648,0.384456,0,2740bc2a1cd9bae5dfb5ddc40f2aefb9.mid
2,44,91,G#2,2,7.493816,7.845378,0.351562,0,2740bc2a1cd9bae5dfb5ddc40f2aefb9.mid
3,46,83,A#2,2,7.880329,8.402531,0.522203,0,2740bc2a1cd9bae5dfb5ddc40f2aefb9.mid
4,41,74,F2,2,8.673912,9.08304,0.409127,0,2740bc2a1cd9bae5dfb5ddc40f2aefb9.mid


In [None]:
extracted_lakh_df.to_csv("/Users/yang/Documents/processed_notes_lakh_instru.csv", index=False) 

### Data preparation & Preprocessing 


In [49]:
# Full pipeline for symbolic MusicGen-style training using extracted note features (with chord conditioning)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import numpy as np
from collections import Counter

# Configuration
D_MODEL = 256
NUM_PITCHES = 128
NUM_VELOCITIES = 32
NUM_DURATIONS = 32
NUM_CHORDS = 12  # 12 pitch classes (C, C#, D, ..., B)
NUM_INSTRUMENTS = 128
SEQ_LEN = 64

In [50]:
# Chord Estimation from Notes (chromagram-inspired)
def estimate_chords(df):
    df = df.copy()
    df["chord"] = -1
    filenames = df["filename"].unique()
    for fname in filenames:
        song = df[df["filename"] == fname].copy()
        song = song.sort_values("start")
        chords = []
        for i in range(0, len(song), SEQ_LEN):
            segment = song.iloc[i:i+SEQ_LEN]
            pitch_classes = [p % 12 for p in segment["pitch"]]
            if len(pitch_classes) == 0:
                chord_id = 0
            else:
                chord_id = Counter(pitch_classes).most_common(1)[0][0]
            chords += [chord_id] * len(segment)
        df.loc[df["filename"] == fname, "chord"] = chords
    return df

In [51]:
# data preparation 
def discretize_velocity(velocity):
    return min(int(velocity // 4), NUM_VELOCITIES - 1)

def discretize_duration(duration):
    
    idx = np.floor(duration / 0.1).astype(int)
    # return min(int(duration / 0.1), NUM_DURATIONS - 1) # original version, change to avoid overflow
    return np.clip(idx, 0, NUM_DURATIONS - 1)

def build_sequence_tensor(df, max_seq_len=SEQ_LEN):
    sequences = []
    grouped = df.groupby("filename")
    
    # For debug: record all values
    all_durations = []
    all_velocities = []

    for _, group in grouped:
        group = group.sort_values("start")
        for i in range(0, len(group) - max_seq_len, max_seq_len):
            chunk = group.iloc[i:i+max_seq_len]
            pitch = torch.tensor(chunk["pitch"].values, dtype=torch.long)
            velocity = torch.tensor(chunk["velocity"].apply(discretize_velocity).values, dtype=torch.long)
            duration = torch.tensor(chunk["duration"].apply(discretize_duration).values, dtype=torch.long)
            chord = torch.tensor(chunk["chord"].values, dtype=torch.long)
            instrument = torch.tensor(chunk["instrument"].values, dtype=torch.long)
            sequences.append((pitch, velocity, duration, chord, instrument)) # cannot successfully obtain instru
            
            # save for debug
            all_durations.extend(duration.tolist())
            all_velocities.extend(velocity.tolist())
            
            
    # After discretization inside build_sequence_tensor
    print("Duration max:", max(all_durations))
    print("Duration min:", min(all_durations))
    print("Velocity max:", max(all_velocities))
    print("Velocity min:", min(all_velocities))

    
    return sequences

In [52]:

# Dataset Wrapper
class SequenceDataset(Dataset):
    def __init__(self, sequences):
        self.sequences = sequences

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return self.sequences[idx]

In [53]:
#df = estimate_chords(all_notes_df)
df = estimate_chords(extracted_lakh_df)
sequences = build_sequence_tensor(df)
dataset = SequenceDataset(sequences)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

Duration max: 31
Duration min: 0
Velocity max: 31
Velocity min: 0


In [54]:
df.head()

Unnamed: 0,pitch,velocity,note_name,octave,start,end,duration,instrument,filename,chord
0,46,90,A#2,2,6.315776,6.776301,0.460525,0,2740bc2a1cd9bae5dfb5ddc40f2aefb9.mid,10
1,41,92,F2,2,7.103192,7.487648,0.384456,0,2740bc2a1cd9bae5dfb5ddc40f2aefb9.mid,10
2,44,91,G#2,2,7.493816,7.845378,0.351562,0,2740bc2a1cd9bae5dfb5ddc40f2aefb9.mid,10
3,46,83,A#2,2,7.880329,8.402531,0.522203,0,2740bc2a1cd9bae5dfb5ddc40f2aefb9.mid,10
4,41,74,F2,2,8.673912,9.08304,0.409127,0,2740bc2a1cd9bae5dfb5ddc40f2aefb9.mid,10


### Model construction 

In [55]:
# Model Definition (Chord Conditioning)
class SymbolicMusicTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.pitch_embed = nn.Embedding(NUM_PITCHES, D_MODEL)
        self.velocity_embed = nn.Embedding(NUM_VELOCITIES, D_MODEL)
        self.duration_embed = nn.Embedding(NUM_DURATIONS, D_MODEL)
        self.chord_embed = nn.Embedding(NUM_CHORDS, D_MODEL)
        self.instrument_embed = nn.Embedding(NUM_INSTRUMENTS, D_MODEL)
        self.pos_embed = nn.Embedding(SEQ_LEN, D_MODEL)

        encoder_layer = nn.TransformerEncoderLayer(d_model=D_MODEL, nhead=8, dim_feedforward=512)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=4)

        self.pitch_out = nn.Linear(D_MODEL, NUM_PITCHES)
        self.velocity_out = nn.Linear(D_MODEL, NUM_VELOCITIES)
        self.duration_out = nn.Linear(D_MODEL, NUM_DURATIONS)
        self.instrument_out = nn.Linear(D_MODEL, NUM_INSTRUMENTS)

    def forward(self, pitch, velocity, duration, chord, instrument):
        B, T = pitch.shape
        pos = torch.arange(T, device=pitch.device).unsqueeze(0).expand(B, T)

        x = self.pitch_embed(pitch) + \
            self.velocity_embed(velocity) + \
            self.duration_embed(duration) + \
            self.chord_embed(chord) + \
            self.instrument_embed(instrument) + \
            self.pos_embed(pos)

        x = x.transpose(0, 1)
        x = self.transformer(x)
        x = x.transpose(0, 1)

        return self.pitch_out(x), self.velocity_out(x), self.duration_out(x), self.instrument_out(x)

In [56]:
# Training Loop
def train_model(model, dataloader, epochs=7):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    model.train()

    for epoch in range(epochs):
        total_loss = 0
        for pitch, velocity, duration, chord, instrument in dataloader:
            optimizer.zero_grad()
            pitch_logits, vel_logits, dur_logits, instr_logits = model(pitch, velocity, duration, chord, instrument)
            
            loss = F.cross_entropy(pitch_logits.view(-1, NUM_PITCHES), pitch.view(-1)) + \
                   F.cross_entropy(vel_logits.view(-1, NUM_VELOCITIES), velocity.view(-1)) + \
                   F.cross_entropy(dur_logits.view(-1, NUM_DURATIONS), duration.view(-1)) + \
                   F.cross_entropy(instr_logits.view(-1, NUM_INSTRUMENTS), instrument.view(-1))
        
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch+1}, Loss: {total_loss / len(dataloader):.4f}")


In [57]:
print("Pitch range:", df["pitch"].max(), df["pitch"].min())
print("Velocity range:", df["velocity"].max(), df["velocity"].min())
print("Duration range:", df["duration"].max(), df["duration"].min())
print("Chord range:", df["chord"].max(), df["chord"].min())
print("Instrument range:", df["instrument"].max(), df["instrument"].min())

Pitch range: 123 0
Velocity range: 127 1
Duration range: 219.42774187500004 4.844958333194427e-05
Chord range: 11 0
Instrument range: 127 0


In [58]:
model = SymbolicMusicTransformer()
train_model(model, dataloader)



Epoch 1, Loss: 0.4242
Epoch 2, Loss: 0.0031
Epoch 3, Loss: 0.0010
Epoch 4, Loss: 0.0004
Epoch 5, Loss: 0.0002
Epoch 6, Loss: 0.0002
Epoch 7, Loss: 7.4614


## Test model generation

In [35]:
import torch

# Randomly pick a chord condition between 0 and NUM_CHORDS-1
random_chord_condition = torch.randint(0, NUM_CHORDS, (1,)).item()
print("Randomly selected chord condition:", random_chord_condition)


Randomly selected chord condition: 0


Built initial input to feed to the model 

In [None]:
SEQ_LEN = 64  # (or whatever your model uses)

pitch_input = torch.zeros((1, SEQ_LEN), dtype=torch.long)
velocity_input = torch.zeros((1, SEQ_LEN), dtype=torch.long)
duration_input = torch.zeros((1, SEQ_LEN), dtype=torch.long)
instrument_input = torch.zeros((1, SEQ_LEN), dtype=torch.long)
chord_input = torch.full((1, SEQ_LEN), random_chord_condition, dtype=torch.long)  # filled with the chosen chord


Run the model to generate outputs 

In [37]:
model.eval()
with torch.no_grad():
    pitch_logits, velocity_logits, duration_logits, instrument_logits = model(
        pitch_input, velocity_input, duration_input, chord_input, instrument_input
    )

# Take the argmax (highest probability) for each time step
generated_pitch = pitch_logits.argmax(-1).squeeze(0).cpu().numpy()
generated_velocity = velocity_logits.argmax(-1).squeeze(0).cpu().numpy()
generated_duration = duration_logits.argmax(-1).squeeze(0).cpu().numpy()
generated_instrument = instrument_logits.argmax(-1).squeeze(0).cpu().numpy()


Convert the generated sequence into MIDI

In [39]:
import pretty_midi

def save_generated_midi(pitches, velocities, durations, instruments, output_path="./generated_sample.mid"):
    midi = pretty_midi.PrettyMIDI()
    instrument_map = {}

    time = 0.0
    for p, v, d, i in zip(pitches, velocities, durations, instruments):
        start_time = time
        dur_sec = (d + 1) * 0.1  # Each discrete duration bucket is 0.1 second
        end_time = start_time + dur_sec

        velocity = int(min(max(v * 4 + 30, 1), 127))  # Map velocity bucket to actual value

        if i not in instrument_map:
            instrument_map[i] = pretty_midi.Instrument(program=int(i))

        note = pretty_midi.Note(
            velocity=velocity,
            pitch=int(p),
            start=start_time,
            end=end_time
        )
        instrument_map[i].notes.append(note)

        time += dur_sec  # move forward in time

    # Add instruments into MIDI
    for inst in instrument_map.values():
        midi.instruments.append(inst)

    midi.write(output_path)
    print(f"MIDI file saved to: {output_path}")

# Save generated music
#output_path = "./generated_test_sample_v1.mid"
output_path = "/Users/yang/Documents/generated_sample.mid"
save_generated_midi(generated_pitch, generated_velocity, generated_duration, generated_instrument, output_path)


MIDI file saved to: /Users/yang/Documents/generated_sample.mid


Multiple-chord generation

In [59]:
# Helper function to generate symbolic music conditioned on multiple chords
import torch
import pretty_midi

def generate_music_with_chord_sequence_mul(model, chord_list, output_path="generated_multi_chord_sample.mid"):
    model.eval()
    with torch.no_grad():
        T = SEQ_LEN

        # Build the chord sequence (repeat each chord equally)
        frames_per_chord = T // len(chord_list)
        chord_sequence = []
        for chord in chord_list:
            chord_sequence += [chord] * frames_per_chord

        # If not perfectly divisible, pad with the last chord
        if len(chord_sequence) < T:
            chord_sequence += [chord_list[-1]] * (T - len(chord_sequence))

        chord_input = torch.tensor(chord_sequence, dtype=torch.long).unsqueeze(0)

        # Initialize zero inputs for pitch, velocity, duration, instrument
        pitch_input = torch.zeros((1, T), dtype=torch.long)
        velocity_input = torch.zeros((1, T), dtype=torch.long)
        duration_input = torch.zeros((1, T), dtype=torch.long)
        instrument_input = torch.zeros((1, T), dtype=torch.long)

        # Model forward
        pitch_logits, velocity_logits, duration_logits, instrument_logits = model(
            pitch_input, velocity_input, duration_input, chord_input, instrument_input
        )

        # Take argmax to get predicted notes
        generated_pitch = pitch_logits.argmax(-1).squeeze(0).tolist()
        generated_velocity = velocity_logits.argmax(-1).squeeze(0).tolist()
        generated_duration = duration_logits.argmax(-1).squeeze(0).tolist()
        generated_instrument = instrument_logits.argmax(-1).squeeze(0).tolist()

        # Save to MIDI
        save_generated_midi(
            generated_pitch,
            generated_velocity,
            generated_duration,
            generated_instrument,
            output_path=output_path
        )



In [60]:
def save_generated_midi(pitch_seq, velocity_seq, duration_seq, instrument_seq, output_path="generated_output.mid"):
    midi = pretty_midi.PrettyMIDI()
    instrument_map = {}

    time = 0.0
    for p, v, d, i in zip(pitch_seq, velocity_seq, duration_seq, instrument_seq):
        dur_sec = (d + 1) * 0.1  # duration bucket to seconds
        velocity_clipped = min(max(int(v * 4 + 30), 1), 127)

        start_time = time
        end_time = time + dur_sec

        if i not in instrument_map:
            program_num = int(i) if i < 127 else 127
            instrument_map[i] = pretty_midi.Instrument(program=program_num)

        note = pretty_midi.Note(velocity=velocity_clipped, pitch=p, start=start_time, end=end_time)
        instrument_map[i].notes.append(note)
        time += dur_sec

    for inst in instrument_map.values():
        midi.instruments.append(inst)

    midi.write(output_path)
    print(f"MIDI file saved to: {output_path}")

In [61]:
# Example Usage
# Assuming model is already trained or initialized
chord_list = [0, 5, 7, 2]  # C major, F major, G major, D minor
output_path = "/Users/yang/Documents/generated_sample3.mid"
generate_music_with_chord_sequence_mul(model, chord_list, output_path)

MIDI file saved to: /Users/yang/Documents/generated_sample3.mid


In [40]:
# Suppose you want to switch chord every 16 frames
chords = [0, 5, 7, 9]  # C, F, G, A (example)

# Repeat each chord for 16 frames
chord_sequence = []
for c in chords:
    chord_sequence += [c] * 16

# Ensure length matches SEQ_LEN
chord_sequence = chord_sequence[:SEQ_LEN]  

# Convert to tensor
chord_input = torch.tensor(chord_sequence, dtype=torch.long).unsqueeze(0)  # shape (1, SEQ_LEN)


In [46]:
print(set(generated_instrument))

{0}


# Miscellaneous

In [None]:
# Symbolic Music Transformer (MIDI-like model)
# Pitch, Velocity, Duration, Instrument embeddings + Chord conditioning

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import random
import numpy as np
import pandas as pd
import pretty_midi
import os
from pathlib import Path

# -----------------------------
# Configuration
# -----------------------------
D_MODEL = 256  # Embedding dimension
NUM_PITCHES = 128
NUM_VELOCITIES = 32
NUM_DURATIONS = 64
NUM_CHORDS = 50
NUM_INSTRUMENTS = 128  # Standard MIDI program count
SEQ_LEN = 64

# -----------------------------
# Dataset Class with Instrument Awareness
# -----------------------------
class SymbolicMusicDataset(Dataset):
    def __init__(self, midi_folder, max_files=100):
        self.samples = []
        self._collect_samples(midi_folder, max_files)

    def _collect_samples(self, midi_folder, max_files):
        midi_paths = list(Path(midi_folder).rglob("*.mid"))[:max_files]
        for midi_path in midi_paths:
            try:
                midi = pretty_midi.PrettyMIDI(str(midi_path))
                for inst in midi.instruments:
                    for note in inst.notes:
                        pitch = note.pitch
                        velocity = min(note.velocity // 4, NUM_VELOCITIES - 1)
                        duration = min(int((note.end - note.start) / 0.1), NUM_DURATIONS - 1)
                        chord = random.randint(0, NUM_CHORDS - 1)  # placeholder
                        instrument = inst.program if not inst.is_drum else 0
                        self.samples.append((pitch, velocity, duration, chord, instrument))
            except Exception as e:
                print(f"Error reading {midi_path}: {e}")

    def __len__(self):
        return len(self.samples) // SEQ_LEN

    def __getitem__(self, idx):
        idx = idx * SEQ_LEN
        s = self.samples[idx:idx+SEQ_LEN]
        pitch = torch.tensor([x[0] for x in s])
        velocity = torch.tensor([x[1] for x in s])
        duration = torch.tensor([x[2] for x in s])
        chord = torch.tensor([x[3] for x in s])
        instrument = torch.tensor([x[4] for x in s])
        return pitch, velocity, duration, chord, instrument

# -----------------------------
# Model Definition
# -----------------------------
class SymbolicTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        # Embeddings
        self.pitch_embed = nn.Embedding(NUM_PITCHES, D_MODEL)
        self.velocity_embed = nn.Embedding(NUM_VELOCITIES, D_MODEL)
        self.duration_embed = nn.Embedding(NUM_DURATIONS, D_MODEL)
        self.chord_embed = nn.Embedding(NUM_CHORDS, D_MODEL)
        self.instrument_embed = nn.Embedding(NUM_INSTRUMENTS, D_MODEL)
        self.pos_embed = nn.Embedding(SEQ_LEN, D_MODEL)

        encoder_layer = nn.TransformerEncoderLayer(d_model=D_MODEL, nhead=8, dim_feedforward=512)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=4)

        self.pitch_out = nn.Linear(D_MODEL, NUM_PITCHES)
        self.velocity_out = nn.Linear(D_MODEL, NUM_VELOCITIES)
        self.duration_out = nn.Linear(D_MODEL, NUM_DURATIONS)
        self.instrument_out = nn.Linear(D_MODEL, NUM_INSTRUMENTS)

    def forward(self, pitch, velocity, duration, chord, instrument):
        B, T = pitch.size()
        pos = torch.arange(T, device=pitch.device).unsqueeze(0).expand(B, T)

        x = self.pitch_embed(pitch) + \
            self.velocity_embed(velocity) + \
            self.duration_embed(duration) + \
            self.chord_embed(chord) + \
            self.instrument_embed(instrument) + \
            self.pos_embed(pos)

        x = x.transpose(0, 1)
        x = self.transformer(x)
        x = x.transpose(0, 1)

        return self.pitch_out(x), self.velocity_out(x), self.duration_out(x), self.instrument_out(x)

# -----------------------------
# Training Loop
# -----------------------------
def train_model(data_dir):
    dataset = SymbolicMusicDataset(midi_folder=data_dir)
    dataloader = DataLoader(dataset, batch_size=16, shuffle=True)
    model = SymbolicTransformer()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(5):
        total_loss = 0
        for pitch, vel, dur, chord, instr in dataloader:
            optimizer.zero_grad()
            pitch_logits, vel_logits, dur_logits, instr_logits = model(pitch, vel, dur, chord, instr)

            loss = F.cross_entropy(pitch_logits.view(-1, NUM_PITCHES), pitch.view(-1)) + \
                   F.cross_entropy(vel_logits.view(-1, NUM_VELOCITIES), vel.view(-1)) + \
                   F.cross_entropy(dur_logits.view(-1, NUM_DURATIONS), dur.view(-1)) + \
                   F.cross_entropy(instr_logits.view(-1, NUM_INSTRUMENTS), instr.view(-1))

            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch+1}, Loss: {total_loss/len(dataloader):.4f}")

if __name__ == "__main__":
    train_model("/path/to/your/lakh/dataset")
