In [8]:
import pickle

# Read the preprocessed sequence bundle and the chord vocabulary from disk.
# NOTE: pickle.load can execute arbitrary code embedded in the file, so only
# open pickles that were produced by a trusted preprocessing run.
with open("processed_sequences.pkl", "rb") as f:
    data = pickle.load(f)

with open("chord_mappings.pkl", "rb") as f:
    mappings = pickle.load(f)

# Parallel per-example lists: model input, prediction target, genre label.
input_sequences = data["input_sequences"]
target_sequences = data["target_sequences"]
genres_for_sequences = data["genres_for_sequences"]

# Chord vocabulary in both directions.
chord_to_id, id_to_chord = mappings["chord_to_id"], mappings["id_to_chord"]

# Genre vocabulary: ids follow the sorted order of the distinct labels, so the
# numbering is the same on every run for the same set of genres.
genres = sorted(set(genres_for_sequences))
genre_to_id = {name: idx for idx, name in enumerate(genres)}
id_to_genre = dict(enumerate(genres))

# Integer genre id for every example, aligned with input_sequences.
encoded_genres = list(map(genre_to_id.__getitem__, genres_for_sequences))

# tokenizer
def tokenize_chords(chord_seq, vocab=None):
    """Convert a sequence of chord symbols into vocabulary ids.

    Args:
        chord_seq: iterable of chord symbols (keys of the vocabulary).
        vocab: mapping from chord symbol to id; must contain an "UNK" entry.
            Defaults to the notebook-level ``chord_to_id``.

    Returns:
        A list with one id per chord; chords missing from the vocabulary are
        replaced by the id stored under "UNK".

    Raises:
        KeyError: if the vocabulary has no "UNK" entry.
    """
    if vocab is None:
        vocab = chord_to_id
    # Look the fallback id up once instead of once per chord; this also makes a
    # vocabulary without "UNK" fail immediately rather than on the first chord.
    unk_id = vocab["UNK"]
    return [vocab.get(chord, unk_id) for chord in chord_seq]

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader

MAX_SEQ_LEN = 200  # Maximum sequence length; collate_fn pads or truncates every sequence to this length

class ChordPredictionDataset(Dataset):
    """Map-style dataset that pairs each input sequence with its target and genre id.

    Sequences are returned exactly as stored (no padding or tensor conversion);
    batching is left to the DataLoader's collate function.

    Args:
        inputs: indexable collection of input sequences (presumably lists of
            chord ids -- verify against the preprocessing step).
        targets: collection of target sequences, parallel to ``inputs``.
        genres: iterable of genre labels, one per example, parallel to ``inputs``.
        chord_to_id: chord vocabulary; only its "PAD" entry is read here.
        genre_to_id: mapping from genre label to integer id.

    Raises:
        KeyError: if "PAD" is missing from ``chord_to_id`` or a genre label is
            missing from ``genre_to_id``.
        ValueError: if ``inputs``, ``targets`` and ``genres`` differ in length.
    """

    def __init__(self, inputs, targets, genres, chord_to_id, genre_to_id):
        self.inputs = inputs
        self.targets = targets
        # Encode the labels once up front so __getitem__ is a plain lookup.
        self.genres = [genre_to_id[g] for g in genres]
        # Kept as an attribute for consumers; not used inside this class.
        self.pad_id = chord_to_id["PAD"]

        # The three collections are indexed in lockstep; a length mismatch would
        # otherwise surface as an IndexError mid-epoch or silently drop examples.
        if not (len(self.inputs) == len(self.targets) == len(self.genres)):
            raise ValueError(
                "inputs, targets and genres must have the same length, got "
                f"{len(self.inputs)}, {len(self.targets)} and {len(self.genres)}"
            )

    def __len__(self):
        """Return the number of examples."""
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return the example at ``idx`` as a dict of raw sequences plus genre id."""
        return {
            "input_ids": self.inputs[idx],
            "target_ids": self.targets[idx],
            "genre_id": self.genres[idx]
        }


In [6]:
def collate_fn(batch, max_len=None, pad_id=None):
    """Collate dataset items into fixed-length tensors for one batch.

    Every sequence is truncated to ``max_len`` tokens or right-padded with
    ``pad_id`` up to ``max_len``, so all batches share the same width.

    Args:
        batch: list of dicts with keys "input_ids", "target_ids" (sliceable
            sequences of ints) and "genre_id" (int).
        max_len: output sequence length. Defaults to the notebook-level
            ``MAX_SEQ_LEN``.
        pad_id: id used for padding. Defaults to ``chord_to_id["PAD"]``.

    Returns:
        dict of ``torch.long`` tensors:
            "input_ids":      (batch, max_len) padded/truncated inputs
            "target_ids":     (batch, max_len) padded/truncated targets
            "attention_mask": (batch, max_len) 1 for real input tokens, 0 for padding
            "genre_id":       (batch,)
    """
    if max_len is None:
        max_len = MAX_SEQ_LEN
    if pad_id is None:
        pad_id = chord_to_id["PAD"]

    def _fit(seq):
        # Slice first so long sequences are cut; list() lets tuples and other
        # sliceable sequences through, not only lists.
        clipped = list(seq[:max_len])
        return clipped + [pad_id] * (max_len - len(clipped))

    def _mask(seq):
        # Number of real (non-padding) positions after truncation.
        n_real = min(len(seq), max_len)
        return [1] * n_real + [0] * (max_len - n_real)

    input_seqs = [item["input_ids"] for item in batch]
    target_seqs = [item["target_ids"] for item in batch]
    genre_ids = [item["genre_id"] for item in batch]

    return {
        "input_ids": torch.tensor([_fit(seq) for seq in input_seqs], dtype=torch.long),
        "target_ids": torch.tensor([_fit(seq) for seq in target_seqs], dtype=torch.long),
        # The mask is derived from the inputs only; targets share the same layout.
        "attention_mask": torch.tensor([_mask(seq) for seq in input_seqs], dtype=torch.long),
        "genre_id": torch.tensor(genre_ids, dtype=torch.long)
    }

In [7]:
from torch.utils.data import Subset, DataLoader

# Wrap the parallel sequence/genre lists in the dataset class.
dataset = ChordPredictionDataset(
    inputs=input_sequences,
    targets=target_sequences,
    genres=genres_for_sequences,
    chord_to_id=chord_to_id,
    genre_to_id=genre_to_id,
)

# Train on at most the first 10000 examples.
small_dataset = Subset(dataset, indices=[*range(min(len(dataset), 10000))])

# Shuffled mini-batches of 8; collate_fn pads/truncates and builds the tensors.
dataloader = DataLoader(
    dataset=small_dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=collate_fn,
)

In [None]:
import torch.nn.functional as F
from torch.optim import Adam
from genre_aware_model import build_transformer
import torch.nn as nn
from tqdm import tqdm  # Import tqdm for the progress bar


# Pick the best available accelerator and fall back to CPU, so the notebook
# also runs on machines that have neither CUDA nor Apple MPS.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(device)

# Sequence lengths follow MAX_SEQ_LEN because collate_fn pads or truncates
# every batch to exactly that length.
model = build_transformer(
    src_vocab_size=len(chord_to_id),
    tgt_vocab_size=len(chord_to_id),
    src_seq_len=MAX_SEQ_LEN,
    tgt_seq_len=MAX_SEQ_LEN,
    genre_len=len(genre_to_id),
    d_model=256,
    ffn_size=256,
    dropout=0.2
).to(device)

optimizer = Adam(model.parameters(), lr=1e-4)
# NLLLoss expects log-probabilities as input.
# NOTE(review): this presumes model.project ends in log_softmax -- confirm in
# genre_aware_model. Positions whose target is PAD do not contribute to the loss.
loss_fn = nn.NLLLoss(ignore_index=chord_to_id["PAD"])

# Training loop
for epoch in range(10):
    model.train()
    total_loss = 0

    # Shapes per batch (seq_len is the fixed width produced by collate_fn):
    #   input_ids / target_ids / attention_mask: [batch_size, seq_len]
    #   genre_id:                                [batch_size]
    #   src_mask:                                [batch_size, 1, 1, seq_len]
    #   tgt_mask:                                [1, 1, seq_len, seq_len]
    for i, batch in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1}", ncols=100, leave=False)):

        src = batch["input_ids"].to(device)
        # The decoder is fed the same ids as the encoder.
        # NOTE(review): presumably target_ids is the next-step shift of
        # input_ids -- confirm against the preprocessing that built the pickles.
        tgt_input = src
        tgt_output = batch["target_ids"].to(device)
        genre = batch["genre_id"].to(device)
        # Padding mask over source positions, with two singleton dims inserted.
        src_mask = batch["attention_mask"].unsqueeze(1).unsqueeze(2).to(device)
        # Lower-triangular mask: position t may only attend to positions <= t.
        tgt_mask = torch.tril(torch.ones((src.size(1), src.size(1)))).unsqueeze(0).unsqueeze(0).to(device)

        enc_out = model.encode(src, genre, src_mask=src_mask)
        dec_out = model.decode(tgt_input, enc_out, genre, src_mask=src_mask, tgt_mask=tgt_mask)
        logits = model.project(dec_out)

        # Flatten to [batch*seq_len, vocab] vs [batch*seq_len] for the loss.
        loss = loss_fn(logits.view(-1, logits.size(-1)), tgt_output.view(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once; each .item() call forces a device sync.
        batch_loss = loss.item()
        total_loss += batch_loss

        tqdm.write(f"Batch {i+1}, Loss: {batch_loss:.4f}")

    # Releasing the CUDA cache after every batch only slows training (the
    # caching allocator reuses those blocks anyway); do it once per epoch and
    # only when actually running on CUDA.
    if device.type == "cuda":
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

    print(f"Epoch {epoch+1}, Total Loss: {total_loss:.4f}")

    # Checkpoint after every epoch so a crash does not lose earlier progress.
    torch.save(model.state_dict(), f"transformer_chord_predictor_epoch{epoch+1}.pt")

torch.save(model.state_dict(), "transformer_chord_predictor_final.pt")