In [1]:
!pip install -q miditok symusic tqdm

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast
import os
import random
import math
import glob
import json
import numpy as np
from pathlib import Path
from miditok import REMI, TokenizerConfig
from symusic import Score
from tqdm.notebook import tqdm

In [3]:
class Config:
    """Hyperparameters and runtime settings shared by the cells of this notebook.

    Every value is a plain class attribute, so it can be read from the class
    itself or from the ``config`` instance created right below.
    """

    # Dataset root; the loader looks for ``train/`` and ``valid/`` sub-folders in it.
    DATA_DIR = "BACH_MIDI"

    STRIDE = 128           # offset between the starts of consecutive dataset windows
    SEQ_LENGTH = 1024      # number of tokens in each input window
    
    # Transformer size.
    EMBED_DIM = 256
    N_HEADS = 8
    N_LAYERS = 6
    DROPOUT = 0.3          # used for attention dropout and the feed-forward dropout layer
    
    # Optimisation schedule.
    BATCH_SIZE = 16
    ACCUMULATION_STEPS = 1 # the optimizer steps once every this many batches
    LEARNING_RATE = 5e-4   # AdamW learning rate, also the OneCycleLR peak
    EPOCHS = 50
    PATIENCE = 5           # consecutive epochs without a better val loss before stopping
    
    # Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Single shared settings object used by the rest of the notebook.
config = Config()
print(f"Device: {config.DEVICE}")

Device: cuda


In [4]:
def load_bach_remi(path):
    """Tokenize the ``train`` and ``valid`` MIDI folders under ``path`` with REMI.

    All files of a split are tokenized one by one and their token ids are
    concatenated into a single 1-D stream (no separator is inserted between
    pieces).

    Args:
        path: Dataset root containing ``train/*.mid`` and ``valid/*.mid``.

    Returns:
        ``(train_tensor, val_tensor, tokenizer)`` where both tensors are 1-D
        ``torch.long`` token-id streams and ``tokenizer`` is the REMI tokenizer
        that produced them.
    """
    path = Path(path)
    # Sort so the concatenated stream does not depend on directory listing
    # order, which differs between filesystems.
    train_files = sorted((path / "train").glob("*.mid"))
    val_files = sorted((path / "valid").glob("*.mid"))

    print(f"Found {len(train_files)} Train | {len(val_files)} Valid files")

    tokenizer_config = TokenizerConfig(
        num_velocities=1,
        use_chords=True,
        use_tempos=False,
        use_velocities=False,
        use_durations=False,
        use_programs=False,
        beat_res={(0, 4): 4}
    )
    tokenizer = REMI(tokenizer_config)

    def tokenize_batch(file_list, desc):
        """Tokenize every file in ``file_list`` and return one flat id tensor."""
        tokens_list = []
        skipped = []
        for f in tqdm(file_list, desc=desc):
            try:
                midi = Score(f)
                toks = tokenizer(midi)
                # NOTE(review): only the first sequence returned by the
                # tokenizer is kept; if it returns one sequence per track the
                # remaining tracks are dropped - confirm this is intended.
                if len(toks) > 0: tokens_list.extend(toks[0].ids)
            except Exception as exc:
                # Stay best-effort (one bad file must not abort the run) but
                # keep a record instead of hiding the failure.
                skipped.append((f, exc))
        if skipped:
            print(f"{desc}: skipped {len(skipped)} of {len(file_list)} file(s)")
            for bad_file, exc in skipped[:5]:
                print(f"  {bad_file}: {type(exc).__name__}: {exc}")
        return torch.tensor(tokens_list, dtype=torch.long)

    train_tensor = tokenize_batch(train_files, "Tokenizing Train")
    val_tensor = tokenize_batch(val_files, "Tokenizing Valid")

    return train_tensor, val_tensor, tokenizer

train_tensor, val_tensor, tokenizer = load_bach_remi(config.DATA_DIR)
vocab_size = tokenizer.vocab_size
print(f"Vocab Size: {vocab_size}")

Found 229 Train | 76 Valid files


Tokenizing Train:   0%|          | 0/229 [00:00<?, ?it/s]

Tokenizing Valid:   0%|          | 0/76 [00:00<?, ?it/s]

Vocab Size: 202


In [5]:
class MidiDataset(Dataset):
    """Sliding-window next-token dataset over one flat token stream.

    Item ``i`` is a pair ``(inputs, targets)`` of length ``seq_len`` where
    ``targets`` is ``inputs`` shifted one position to the right in the stream.

    Args:
        data_tensor: 1-D sequence of token ids (indexable and sliceable).
        seq_len: Number of tokens per window.
        stride: Offset between the starts of consecutive windows.
    """
    def __init__(self, data_tensor, seq_len, stride):
        self.seq_len = seq_len
        self.data = data_tensor
        # A window starting at s reads data[s : s + seq_len + 1], so the last
        # valid start is len - seq_len - 1; range() excludes its end, hence
        # the end is len - seq_len. Streams that are too short give an empty
        # range and therefore an empty dataset.
        self.indices = range(0, len(data_tensor) - seq_len, stride)
    def __len__(self): return len(self.indices)
    def __getitem__(self, idx):
        start_idx = self.indices[idx]
        return (self.data[start_idx : start_idx + self.seq_len], 
                self.data[start_idx+1 : start_idx + self.seq_len + 1])

# Build one sliding-window dataset per split with identical window settings.
train_ds, val_ds = (
    MidiDataset(split_tokens, config.SEQ_LENGTH, config.STRIDE)
    for split_tokens in (train_tensor, val_tensor)
)

# Only the training batches are shuffled; validation keeps stream order.
train_loader = DataLoader(
    train_ds,
    batch_size=config.BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)
val_loader = DataLoader(
    val_ds,
    batch_size=config.BATCH_SIZE,
    shuffle=False,
    num_workers=2,
)

In [6]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position information to a batch of embeddings.

    Even feature columns hold sines and odd columns hold cosines of the
    position multiplied by a geometric series of frequencies. The table is
    stored as a buffer (not a parameter) of shape ``(1, max_len, d_model)``.

    Args:
        d_model: Embedding width; may be even or odd.
        max_len: Number of positions to precompute.
    """
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one cosine column fewer than sine columns,
        # so trim the frequency axis to the number of odd-indexed columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions."""
        return x + self.pe[:, :x.size(1), :]

class StandardBlock(nn.Module):
    """Pre-norm transformer block: causal self-attention, then a feed-forward net.

    Each sub-layer normalises its input, and its output is added back to the
    un-normalised input (residual connection).

    Args:
        d_model: Feature width of the block's input and output.
        n_heads: Number of attention heads.
        dropout: Probability used for attention dropout and for the dropout
            layer at the end of the feed-forward net.
    """
    def __init__(self, d_model, n_heads, dropout):
        super().__init__()
        # 1. Self Attention
        self.attn_norm = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        
        # 2. Feed-Forward
        self.ffn_norm = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask):
        """Apply the block to ``x``.

        Args:
            x: Batch-first input, since the attention layer is built with
                ``batch_first=True``.
            mask: Causal attention mask; required because ``is_causal=True``
                is only a hint that ``mask`` is causal.

        Returns:
            Tensor with the same shape as ``x``.
        """
        res = x
        x_norm = self.attn_norm(x)
        # The attention weights are never used, so skip computing them; this
        # also lets the causal hint be forwarded to the fused attention kernel.
        attn_out, _ = self.attn(
            x_norm, x_norm, x_norm,
            attn_mask=mask, is_causal=True, need_weights=False,
        )
        x = res + attn_out
        
        res = x
        ffn_out = self.ffn(self.ffn_norm(x))
        x = res + ffn_out
        return x

class MusicTransformer(nn.Module):
    """Decoder-only transformer that predicts the next token id at each position.

    Args:
        vocab_size: Number of distinct token ids; sizes the embedding table
            and the output projection.
        cfg: Settings object providing ``EMBED_DIM``, ``N_HEADS``,
            ``N_LAYERS`` and ``DROPOUT``.
    """
    def __init__(self, vocab_size, cfg):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, cfg.EMBED_DIM)
        self.pos = PositionalEncoding(cfg.EMBED_DIM)
        
        self.layers = nn.ModuleList([
            StandardBlock(cfg.EMBED_DIM, cfg.N_HEADS, cfg.DROPOUT) 
            for _ in range(cfg.N_LAYERS)
        ])
        
        self.final_norm = nn.LayerNorm(cfg.EMBED_DIM)
        self.head = nn.Linear(cfg.EMBED_DIM, vocab_size)

    def forward(self, x):
        """Return logits over the vocabulary for every position of ``x``.

        Args:
            x: Integer token ids, indexed as ``(batch, sequence)``.

        Returns:
            Tensor of logits with a trailing dimension of ``vocab_size``.
        """
        # Causal mask: 0 on and below the diagonal, -inf above it. Allocate it
        # directly on the input's device so no host tensor is built and copied
        # on every forward pass.
        seq_len = x.size(1)
        mask = torch.triu(
            torch.full((seq_len, seq_len), float('-inf'), device=x.device),
            diagonal=1,
        )
        
        x = self.pos(self.embed(x))
        for layer in self.layers: 
            x = layer(x, mask)
        x = self.final_norm(x)
        return self.head(x)

In [7]:
model = MusicTransformer(vocab_size, config).to(config.DEVICE)
optimizer = optim.AdamW(model.parameters(), lr=config.LEARNING_RATE, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()
# Loss scaling only applies to CUDA mixed precision; disable it explicitly on
# other devices instead of relying on the scaler's own fallback warning.
scaler = torch.amp.GradScaler('cuda', enabled=config.DEVICE.type == 'cuda')

# The scheduler advances once per optimizer step, and the optimizer steps once
# per ACCUMULATION_STEPS batches, so size the schedule in optimizer steps
# rather than in batches.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=config.LEARNING_RATE,
    steps_per_epoch=math.ceil(len(train_loader) / max(1, config.ACCUMULATION_STEPS)),
    epochs=config.EPOCHS, pct_start=0.1
)

def train_epoch(model, loader):
    """Run one training pass over ``loader`` and return its summary metrics.

    Uses the notebook-level ``optimizer``, ``criterion``, ``scaler``,
    ``scheduler``, ``config`` and ``vocab_size``.

    Args:
        model: Network mapping a batch of token ids to per-position logits.
        loader: Iterable of ``(inputs, targets)`` batches with a ``len()``.

    Returns:
        ``(mean_loss, accuracy)``: the unscaled loss averaged over batches and
        the fraction of target tokens whose arg-max prediction was correct.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    num_batches = len(loader)
    if num_batches == 0:
        raise ValueError("train_epoch received an empty loader; is the token stream shorter than SEQ_LENGTH?")

    model.train()
    total_loss = 0
    total_correct = 0
    total_samples = 0

    accum_steps = max(1, config.ACCUMULATION_STEPS)
    # Mixed precision is only used on CUDA, matching the GradScaler.
    use_amp = config.DEVICE.type == "cuda"

    # Start from clean gradients so nothing left over from an interrupted
    # run leaks into the first step of this epoch.
    optimizer.zero_grad(set_to_none=True)

    loop = tqdm(loader, desc="Train", leave=False)

    for i, (x, y) in enumerate(loop):
        x, y = x.to(config.DEVICE), y.to(config.DEVICE)

        with torch.amp.autocast(device_type=config.DEVICE.type, enabled=use_amp):
            logits = model(x)
            loss = criterion(logits.view(-1, vocab_size), y.view(-1))

        # Divide so that gradients summed over an accumulation group have the
        # scale of a single averaged batch.
        scaler.scale(loss / accum_steps).backward()

        # Accuracy Metrics
        predictions = logits.argmax(dim=-1)
        total_correct += (predictions == y).sum().item()
        total_samples += y.numel()

        # Step at the end of every accumulation group, and also on the final
        # batch so a trailing partial group is applied instead of carried
        # into the next epoch.
        if (i + 1) % accum_steps == 0 or (i + 1) == num_batches:
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)

        total_loss += loss.item()
        loop.set_postfix(loss=f"{loss.item():.4f}")

    return total_loss / num_batches, total_correct / total_samples

def validate(model, loader):
    """Evaluate ``model`` on ``loader`` without updating any weights.

    Uses the notebook-level ``criterion``, ``config`` and ``vocab_size``.

    Args:
        model: Network mapping a batch of token ids to per-position logits.
        loader: Iterable of ``(inputs, targets)`` batches with a ``len()``.

    Returns:
        ``(mean_loss, accuracy)``: the loss averaged over batches and the
        fraction of target tokens whose arg-max prediction was correct.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    num_batches = len(loader)
    if num_batches == 0:
        raise ValueError("validate received an empty loader; is the token stream shorter than SEQ_LENGTH?")

    model.eval()
    total_loss = 0
    total_correct = 0
    total_samples = 0

    # Mixed precision is only used on CUDA.
    use_amp = config.DEVICE.type == "cuda"

    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(config.DEVICE), y.to(config.DEVICE)
            with torch.amp.autocast(device_type=config.DEVICE.type, enabled=use_amp):
                logits = model(x)
                loss = criterion(logits.view(-1, vocab_size), y.view(-1))

            predictions = logits.argmax(dim=-1)
            total_correct += (predictions == y).sum().item()
            total_samples += y.numel()
            total_loss += loss.item()

    return total_loss / num_batches, total_correct / total_samples

# Report the total number of scalar entries across all of the model's parameter tensors.
print(f"Model Params: {sum(p.numel() for p in model.parameters()):,}")

Model Params: 4,842,698


In [8]:
# Train with early stopping on validation loss, checkpointing every improvement.
CHECKPOINT_PATH = "best_bach_transformer.pth"

best_loss = float('inf')
patience = 0  # consecutive epochs without a new best validation loss

for epoch in range(config.EPOCHS):
    t_loss, t_acc = train_epoch(model, train_loader)
    v_loss, v_acc = validate(model, val_loader)
    
    # math.exp raises OverflowError for very large losses; report that as inf.
    try: ppl = math.exp(v_loss)
    except OverflowError: ppl = float('inf')
    
    print(
        f"Epoch {epoch+1}/{config.EPOCHS} | "
        f"Train: {t_loss:.4f} ({t_acc:.1%}) | "
        f"Val: {v_loss:.4f} ({v_acc:.1%}) | "
        f"PPL: {ppl:.2f}"
    )
    
    if v_loss < best_loss:
        best_loss = v_loss
        patience = 0
        torch.save(model.state_dict(), CHECKPOINT_PATH)
        print("Saved New Best Model")
    else:
        patience += 1
        if patience >= config.PATIENCE:
            print("Early Stopping")
            break

# The loop ends with the weights of the LAST epoch, which by construction are
# worse on validation than the checkpoint whenever early stopping fired.
# Restore the best weights so later cells use the model that was selected.
if math.isfinite(best_loss):
    model.load_state_dict(
        torch.load(CHECKPOINT_PATH, map_location=config.DEVICE, weights_only=True)
    )
    print(f"Restored best checkpoint (val loss {best_loss:.4f})")

Train:   0%|          | 0/244 [00:00<?, ?it/s]

  with autocast():
  with autocast():


Epoch 1/50 | Train: 2.6500 (46.0%) | Val: 1.9901 (50.0%) | PPL: 7.32
Saved New Best Model


Train:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch 2/50 | Train: 1.8660 (51.6%) | Val: 1.6767 (55.4%) | PPL: 5.35
Saved New Best Model


Train:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch 3/50 | Train: 1.4803 (59.0%) | Val: 1.1930 (62.9%) | PPL: 3.30
Saved New Best Model


Train:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch 4/50 | Train: 0.8020 (75.7%) | Val: 0.3494 (90.1%) | PPL: 1.42
Saved New Best Model


Train:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch 5/50 | Train: 0.3322 (90.2%) | Val: 0.2610 (92.2%) | PPL: 1.30
Saved New Best Model


Train:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch 6/50 | Train: 0.2578 (92.1%) | Val: 0.2406 (92.8%) | PPL: 1.27
Saved New Best Model


Train:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch 7/50 | Train: 0.2267 (93.0%) | Val: 0.2303 (93.1%) | PPL: 1.26
Saved New Best Model


Train:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch 8/50 | Train: 0.2064 (93.5%) | Val: 0.2247 (93.3%) | PPL: 1.25
Saved New Best Model


Train:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch 9/50 | Train: 0.1892 (93.9%) | Val: 0.2231 (93.3%) | PPL: 1.25
Saved New Best Model


Train:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch 10/50 | Train: 0.1728 (94.4%) | Val: 0.2248 (93.3%) | PPL: 1.25


Train:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch 11/50 | Train: 0.1579 (94.8%) | Val: 0.2314 (93.3%) | PPL: 1.26


Train:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch 12/50 | Train: 0.1424 (95.3%) | Val: 0.2408 (93.2%) | PPL: 1.27


Train:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch 13/50 | Train: 0.1266 (95.8%) | Val: 0.2534 (93.1%) | PPL: 1.29


Train:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch 14/50 | Train: 0.1115 (96.2%) | Val: 0.2692 (92.8%) | PPL: 1.31
Early Stopping


In [9]:
from miditok import TokSequence
from datetime import datetime

# Sampling settings consumed by the generation loop at the end of this cell.
GEN_CONF = {
    'num_examples': 2,    # how many MIDI files to generate
    'max_len': 1024,      # number of tokens to sample per file
    'temperature': 0.9,   # logits are divided by this before sampling
    'top_k': 10           # sample only among the k highest-scoring tokens
}

def generate_sequence(model, tokenizer, length, temp=1.0, top_k=20):
    """Sample ``length`` new tokens autoregressively with temperature and top-k.

    Generation starts from one random token id and appends one sampled token
    per step, so the result holds ``length + 1`` ids.

    Args:
        model: Network mapping a ``(1, T)`` batch of ids to per-position logits.
        tokenizer: Object exposing ``vocab_size``.
        length: Number of tokens to sample.
        temp: Softmax temperature; must be positive.
        top_k: Number of highest-logit candidates kept at each step. Values
            above the vocabulary size are clamped to it.

    Returns:
        1-D NumPy array of token ids (seed token followed by the samples).

    Raises:
        ValueError: If ``temp`` is not positive.
    """
    if temp <= 0:
        raise ValueError(f"temp must be positive, got {temp}")

    model.eval()
    # NOTE(review): the seed is drawn uniformly from the whole vocabulary, so
    # it may be a token that never starts a real piece - consider seeding
    # with a fixed start token instead.
    start_token = torch.randint(0, tokenizer.vocab_size, (1, 1)).to(config.DEVICE)
    generated = start_token
    
    print(f"Generating... (Target: {length})")
    with torch.no_grad():
        for _ in tqdm(range(length), desc="Gen"):
            # Feed at most SEQ_LENGTH tokens: the model never saw longer
            # inputs during training, and unbounded growth would eventually
            # run past the positional-encoding table.
            context = generated[:, -config.SEQ_LENGTH:]
            logits = model(context)
            last_logits = logits[:, -1, :] / temp
            # torch.topk fails when k exceeds the number of candidates.
            k = max(1, min(top_k, last_logits.size(-1)))
            v, _ = torch.topk(last_logits, k)
            # Mask out everything below the k-th best logit.
            last_logits[last_logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(last_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat((generated, next_token), dim=1)
    return generated[0].cpu().numpy()

# Sample the configured number of pieces and write each to its own MIDI file.
# A failure in one piece is reported and does not stop the remaining ones.
for i in range(GEN_CONF['num_examples']):
    try:
        raw_tokens = generate_sequence(
            model,
            tokenizer,
            GEN_CONF['max_len'],
            temp=GEN_CONF['temperature'],
            top_k=GEN_CONF['top_k'],
        )
        # Wrap the raw ids so the tokenizer can decode them back into a score.
        seq = TokSequence(ids=raw_tokens.tolist())
        score = tokenizer.decode([seq])
        filename = f"bach_transformer_{i+1}.mid"
        score.dump_midi(filename)
        print(f"Saved: {filename}")
    except Exception as e:
        print(f"Error: {e}")

Generating... (Target: 1024)


Gen:   0%|          | 0/1024 [00:00<?, ?it/s]

Saved: bach_transformer_1.mid
Generating... (Target: 1024)


Gen:   0%|          | 0/1024 [00:00<?, ?it/s]

Saved: bach_transformer_2.mid
