In [1]:
!pip install -q miditok symusic tqdm

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast
import os
import random
import math
import glob
import json
import numpy as np
from pathlib import Path
from miditok import REMI, TokenizerConfig
from symusic import Score
from tqdm.notebook import tqdm

In [3]:
class Config:
    DATA_DIR = "/workspace/POP909"
    
    FILE_LIMIT = None     
    STRIDE = 256         
    SEQ_LENGTH = 1024     
    
    EMBED_DIM = 256    
    N_HEADS = 8
    N_LAYERS = 6
    CNN_KERNEL = 31       
    DROPOUT = 0.3      
    
    BATCH_SIZE = 16
    ACCUMULATION_STEPS = 2
    LEARNING_RATE = 5e-4 
    EPOCHS = 50           
    PATIENCE = 5          
    
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

config = Config()
print(config.DEVICE)

cuda


In [4]:
def load_pop909(path, split_ratio=0.9):
    print(f"Scanning {path}...")
    
    # 1. File Finding
    files = list(Path(path).glob("**/*.mid")) + list(Path(path).glob("**/*.midi"))
    random.shuffle(files)
    
    # 2. THE FIX: Quantization Config
    tokenizer_config = TokenizerConfig(
        num_velocities=16, 
        use_chords=True, 
        use_tempos=True,
        use_programs=False,
        beat_res={(0, 4): 8, (4, 12): 4}
    )
    tokenizer = REMI(tokenizer_config)
    
    # 3. Helper to Tokenize
    def tokenize_batch(file_list, desc):
        tokens_list = []
        for f in tqdm(file_list, desc=desc):
            try:
                midi = Score(f)
                midi = midi.resample(tpq=960) 
                
                toks = tokenizer(midi)
                if len(toks) > 0: tokens_list.extend(toks[0].ids)
            except: 
                pass 
        return torch.tensor(tokens_list, dtype=torch.long)

    # 4. Split & Process
    split_idx = int(len(files) * split_ratio)
    train_files = files[:split_idx]
    val_files = files[split_idx:]
    
    print(f"Found {len(files)} files.")
    print("Processing Training Data...")
    train_data = tokenize_batch(train_files, "Tokenizing Train")
    
    print("Processing Validation Data...")
    val_data = tokenize_batch(val_files, "Tokenizing Val")
    
    return train_data, val_data, tokenizer

train_tensor, val_tensor, tokenizer = load_pop909(config.DATA_DIR)
vocab_size = tokenizer.vocab_size
print(f"New Vocab Size: {vocab_size}")

Scanning /workspace/POP909...
Found 909 files.
Processing Training Data...


Tokenizing Train:   0%|          | 0/818 [00:00<?, ?it/s]

Processing Validation Data...


Tokenizing Val:   0%|          | 0/91 [00:00<?, ?it/s]

New Vocab Size: 314


In [5]:
class AugmentedMidiDataset(Dataset):
    def __init__(self, data_tensor, seq_len, stride):
        self.seq_len = seq_len
        self.data = data_tensor
        self.indices = range(0, len(data_tensor) - seq_len - 1, stride)

    def __len__(self): return len(self.indices)

    def __getitem__(self, idx):
        start_idx = self.indices[idx]
        seq = self.data[start_idx : start_idx + self.seq_len + 1].clone()

        # PITCH AUGMENTATION
        shift = random.randint(-5, 6) 
        if shift != 0:
            mask = (seq >= 21) & (seq <= 108) 
            seq[mask] = torch.clamp(seq[mask] + shift, 21, 108)

        return seq[:-1], seq[1:]

train_ds = AugmentedMidiDataset(train_tensor, config.SEQ_LENGTH, config.STRIDE)
val_ds = AugmentedMidiDataset(val_tensor, config.SEQ_LENGTH, config.STRIDE)

train_loader = DataLoader(train_ds, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_ds, batch_size=config.BATCH_SIZE, shuffle=False, num_workers=2)

print(f"Batches per Epoch: {len(train_loader)}")

Batches per Epoch: 290


In [6]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div)
        pe[:, 1::2] = torch.cos(position * div)
        self.register_buffer('pe', pe.unsqueeze(0))
    def forward(self, x): return x + self.pe[:, :x.size(1), :]

class CausalMacaronBlock(nn.Module):
    def __init__(self, d_model, n_heads, cnn_k, dropout):
        super().__init__()
        self.cnn_k = cnn_k
        
        # 1. First Feed-Forward
        self.ff1_norm = nn.LayerNorm(d_model)
        self.ff1 = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.SiLU(), 
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout)
        )
        
        # 2. Multi-Head Self Attention
        self.attn_norm = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        
        # 3. Causal Convolution Module
        self.conv_norm = nn.LayerNorm(d_model)
        
        # Pointwise 1
        self.conv_pointwise1 = nn.Conv1d(d_model, d_model * 2, 1)
        self.glu = nn.GLU(dim=1)
        
        # Depthwise Conv (NO PADDING defined here; we do it manually)
        self.conv_depthwise = nn.Conv1d(d_model, d_model, cnn_k, padding=0, groups=d_model)
        
        # Norm & Pointwise 2
        self.conv_batchnorm = nn.BatchNorm1d(d_model)
        self.conv_act = nn.SiLU()
        self.conv_pointwise2 = nn.Conv1d(d_model, d_model, 1)
        self.conv_dropout = nn.Dropout(dropout)
        
        # 4. Second Feed-Forward
        self.ff2_norm = nn.LayerNorm(d_model)
        self.ff2 = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.SiLU(),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout)
        )
        
        self.final_norm = nn.LayerNorm(d_model)

    def forward(self, x, mask):
        # 1. First FFN
        x = x + 0.5 * self.ff1(self.ff1_norm(x))
        
        # 2. Attention
        res = x
        x_norm = self.attn_norm(x)
        attn_out, _ = self.attn(x_norm, x_norm, x_norm, attn_mask=mask, is_causal=True)
        x = res + attn_out
        
        # 3. Causal Convolution
        res = x
        x_cnn = self.conv_norm(x).permute(0, 2, 1) # [B, Dim, Seq]
        
        # Expansion
        x_cnn = self.conv_pointwise1(x_cnn)
        x_cnn = self.glu(x_cnn)
        
        x_cnn = F.pad(x_cnn, (self.cnn_k - 1, 0))
        
        x_cnn = self.conv_depthwise(x_cnn)
        x_cnn = self.conv_batchnorm(x_cnn)
        x_cnn = self.conv_act(x_cnn)
        x_cnn = self.conv_pointwise2(x_cnn)
        x_cnn = self.conv_dropout(x_cnn)
        
        x = res + x_cnn.permute(0, 2, 1)
        
        # 4. Second FFN
        x = x + 0.5 * self.ff2(self.ff2_norm(x))
        
        return self.final_norm(x)

class MusicConformer(nn.Module):
    def __init__(self, vocab_size, cfg):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, cfg.EMBED_DIM)
        self.pos = PositionalEncoding(cfg.EMBED_DIM)
        self.layers = nn.ModuleList([
            CausalMacaronBlock(cfg.EMBED_DIM, cfg.N_HEADS, cfg.CNN_KERNEL, cfg.DROPOUT) 
            for _ in range(cfg.N_LAYERS)
        ])
        self.head = nn.Linear(cfg.EMBED_DIM, vocab_size)

    def forward(self, x):
        mask = torch.triu(torch.ones(x.size(1), x.size(1)) * float('-inf'), diagonal=1).to(x.device)
        x = self.pos(self.embed(x))
        for layer in self.layers: 
            x = layer(x, mask)
        return self.head(x)

model = MusicConformer(vocab_size, config).to(config.DEVICE)
optimizer = optim.AdamW(model.parameters(), lr=config.LEARNING_RATE, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()
scaler = GradScaler()

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=config.LEARNING_RATE,
    steps_per_epoch=len(train_loader) // config.ACCUMULATION_STEPS,
    epochs=config.EPOCHS,
    pct_start=0.1
)

print(f"Model Parameters: {sum(p.numel() for p in model.parameters()):,}")

Model Parameters: 9,298,746


  scaler = GradScaler()


In [7]:
def train_epoch(model, loader, opt, scaler, sched):
    model.train()
    total_loss = 0
    total_correct = 0
    total_samples = 0
    
    loop = tqdm(loader, desc="Train", leave=False)
    opt.zero_grad(set_to_none=True)
    
    for i, (x, y) in enumerate(loop):
        x, y = x.to(config.DEVICE), y.to(config.DEVICE)
        
        with autocast():
            logits = model(x)
            loss = criterion(logits.view(-1, vocab_size), y.view(-1))
            loss = loss / config.ACCUMULATION_STEPS 
        
        scaler.scale(loss).backward()
        
        # Accuracy
        predictions = logits.argmax(dim=-1) 
        total_correct += (predictions == y).sum().item()
        total_samples += y.numel()
        
        # Step only every N batches
        if (i + 1) % config.ACCUMULATION_STEPS == 0:
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(opt)
            scaler.update()
            sched.step()
            opt.zero_grad(set_to_none=True)
        
        # Metrics for display
        current_loss = loss.item() * config.ACCUMULATION_STEPS
        total_loss += current_loss
        loop.set_postfix(loss=f"{current_loss:.4f}", acc=f"{total_correct/total_samples:.2%}")
        
    return total_loss / len(loader), total_correct / total_samples

def validate(model, loader):
    model.eval()
    total_loss = 0
    total_correct = 0
    total_samples = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(config.DEVICE), y.to(config.DEVICE)
            with autocast():
                logits = model(x)
                loss = criterion(logits.view(-1, vocab_size), y.view(-1))
            
            predictions = logits.argmax(dim=-1)
            total_correct += (predictions == y).sum().item()
            total_samples += y.numel()
            total_loss += loss.item()
    return total_loss / len(loader), total_correct / total_samples

In [8]:
best_loss = float('inf')
patience_counter = 0

print("Training...")

for epoch in range(config.EPOCHS):
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, scaler, scheduler)
    val_loss, val_acc = validate(model, val_loader)
    
    try: ppl = math.exp(val_loss)
    except: ppl = float('inf')
    
    print(
        f"Epoch {epoch+1}/{config.EPOCHS} | "
        f"Train: {train_loss:.4f} ({train_acc:.1%}) | "
        f"Val: {val_loss:.4f} ({val_acc:.1%}) | "
        f"PPL: {ppl:.2f}"
    )
    
    if val_loss < best_loss:
        best_loss = val_loss
        patience_counter = 0
        torch.save(model.state_dict(), "best_conformer_model.pth")
        print("Saved New Best Model")
    else:
        patience_counter += 1
        print(f"No improvement ({patience_counter}/{config.PATIENCE})")
        if patience_counter >= config.PATIENCE:
            print("Early stopping triggered.")
            break

Training...


Train:   0%|          | 0/290 [00:00<?, ?it/s]

  with autocast():
  with autocast():


Epoch 1/50 | Train: 4.4029 (15.7%) | Val: 3.1508 (27.0%) | PPL: 23.35
Saved New Best Model


Train:   0%|          | 0/290 [00:00<?, ?it/s]

Epoch 2/50 | Train: 2.5157 (34.1%) | Val: 1.9574 (44.4%) | PPL: 7.08
Saved New Best Model


Train:   0%|          | 0/290 [00:00<?, ?it/s]

Epoch 3/50 | Train: 1.6775 (48.3%) | Val: 1.5008 (52.3%) | PPL: 4.49
Saved New Best Model


Train:   0%|          | 0/290 [00:00<?, ?it/s]

Epoch 4/50 | Train: 1.4288 (53.2%) | Val: 1.3982 (54.5%) | PPL: 4.05
Saved New Best Model


Train:   0%|          | 0/290 [00:00<?, ?it/s]

Epoch 5/50 | Train: 1.3366 (55.3%) | Val: 1.3428 (55.4%) | PPL: 3.83
Saved New Best Model


Train:   0%|          | 0/290 [00:00<?, ?it/s]

Epoch 6/50 | Train: 1.2831 (56.6%) | Val: 1.3029 (56.5%) | PPL: 3.68
Saved New Best Model


Train:   0%|          | 0/290 [00:00<?, ?it/s]

Epoch 7/50 | Train: 1.2365 (57.9%) | Val: 1.2645 (57.6%) | PPL: 3.54
Saved New Best Model


Train:   0%|          | 0/290 [00:00<?, ?it/s]

Epoch 8/50 | Train: 1.1869 (59.6%) | Val: 1.1949 (60.5%) | PPL: 3.30
Saved New Best Model


Train:   0%|          | 0/290 [00:00<?, ?it/s]

Epoch 9/50 | Train: 1.1081 (63.2%) | Val: 1.0671 (66.1%) | PPL: 2.91
Saved New Best Model


Train:   0%|          | 0/290 [00:00<?, ?it/s]

Epoch 10/50 | Train: 1.0067 (67.5%) | Val: 0.9735 (70.0%) | PPL: 2.65
Saved New Best Model


Train:   0%|          | 0/290 [00:00<?, ?it/s]

Epoch 11/50 | Train: 0.9331 (70.3%) | Val: 0.9329 (71.4%) | PPL: 2.54
Saved New Best Model


Train:   0%|          | 0/290 [00:00<?, ?it/s]

Epoch 12/50 | Train: 0.8864 (71.7%) | Val: 0.8863 (72.5%) | PPL: 2.43
Saved New Best Model


Train:   0%|          | 0/290 [00:00<?, ?it/s]

Epoch 13/50 | Train: 0.8491 (72.7%) | Val: 0.8769 (72.7%) | PPL: 2.40
Saved New Best Model


Train:   0%|          | 0/290 [00:00<?, ?it/s]

Epoch 14/50 | Train: 0.8251 (73.3%) | Val: 0.8619 (73.1%) | PPL: 2.37
Saved New Best Model


Train:   0%|          | 0/290 [00:00<?, ?it/s]

Epoch 15/50 | Train: 0.8024 (74.0%) | Val: 0.8615 (73.2%) | PPL: 2.37
Saved New Best Model


Train:   0%|          | 0/290 [00:00<?, ?it/s]

Epoch 16/50 | Train: 0.7872 (74.3%) | Val: 0.8494 (73.5%) | PPL: 2.34
Saved New Best Model


Train:   0%|          | 0/290 [00:00<?, ?it/s]

Epoch 17/50 | Train: 0.7708 (74.8%) | Val: 0.8608 (73.5%) | PPL: 2.36
No improvement (1/5)


Train:   0%|          | 0/290 [00:00<?, ?it/s]

Epoch 18/50 | Train: 0.7588 (75.1%) | Val: 0.8600 (73.7%) | PPL: 2.36
No improvement (2/5)


Train:   0%|          | 0/290 [00:00<?, ?it/s]

Epoch 19/50 | Train: 0.7470 (75.4%) | Val: 0.8575 (73.5%) | PPL: 2.36
No improvement (3/5)


Train:   0%|          | 0/290 [00:00<?, ?it/s]

Epoch 20/50 | Train: 0.7328 (75.8%) | Val: 0.8641 (73.4%) | PPL: 2.37
No improvement (4/5)


Train:   0%|          | 0/290 [00:00<?, ?it/s]

Epoch 21/50 | Train: 0.7231 (76.0%) | Val: 0.8699 (73.6%) | PPL: 2.39
No improvement (5/5)
Early stopping triggered.
