In [1]:
!pip install -q miditok symusic tqdm

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast
import os
import random
import math
import glob
import json
import numpy as np
from pathlib import Path
from miditok import REMI, TokenizerConfig
from symusic import Score
from tqdm.notebook import tqdm

In [3]:
class Config:
    DATA_DIR = "/workspace/POP909"
    
    FILE_LIMIT = None     
    STRIDE = 256         
    SEQ_LENGTH = 1024     
    
    EMBED_DIM = 256    
    N_HEADS = 8
    N_LAYERS = 6
    CNN_KERNEL = 31       
    DROPOUT = 0.3      
    
    BATCH_SIZE = 16
    ACCUMULATION_STEPS = 2
    LEARNING_RATE = 5e-4 
    EPOCHS = 50           
    PATIENCE = 5          
    
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

config = Config()
print(config.DEVICE)

cuda


In [5]:
def load_pop909(path, split_ratio=0.9):
    print(f"Scanning {path}...")
    
    # 1. File Finding
    files = list(Path(path).glob("**/*.mid")) + list(Path(path).glob("**/*.midi"))
    random.shuffle(files)
    
    # 2. THE FIX: Quantization Config
    tokenizer_config = TokenizerConfig(
        num_velocities=16, 
        use_chords=True, 
        use_tempos=True,
        use_programs=False,
        beat_res={(0, 4): 8, (4, 12): 4}
    )
    tokenizer = REMI(tokenizer_config)
    
    # 3. Helper to Tokenize
    def tokenize_batch(file_list, desc):
        tokens_list = []
        for f in tqdm(file_list, desc=desc):
            try:
                midi = Score(f)
                midi = midi.resample(tpq=960) 
                
                toks = tokenizer(midi)
                if len(toks) > 0: tokens_list.extend(toks[0].ids)
            except: 
                pass 
        return torch.tensor(tokens_list, dtype=torch.long)

    # 4. Split & Process
    split_idx = int(len(files) * split_ratio)
    train_files = files[:split_idx]
    val_files = files[split_idx:]
    
    print(f"Found {len(files)} files.")
    print("Processing Training Data...")
    train_data = tokenize_batch(train_files, "Tokenizing Train")
    
    print("Processing Validation Data...")
    val_data = tokenize_batch(val_files, "Tokenizing Val")
    
    return train_data, val_data, tokenizer

train_tensor, val_tensor, tokenizer = load_pop909(config.DATA_DIR)
vocab_size = tokenizer.vocab_size
print(f"New Vocab Size: {vocab_size}")

Scanning /workspace/POP909...
Found 909 files.
Processing Training Data...


Tokenizing Train:   0%|          | 0/818 [00:00<?, ?it/s]

Processing Validation Data...


Tokenizing Val:   0%|          | 0/91 [00:00<?, ?it/s]

New Vocab Size: 314


In [6]:
class MidiDataset(Dataset):
    def __init__(self, data_tensor, seq_len, stride):
        self.seq_len = seq_len
        self.data = data_tensor
        self.indices = range(0, len(data_tensor) - seq_len - 1, stride)
    def __len__(self): return len(self.indices)
    def __getitem__(self, idx):
        start_idx = self.indices[idx]
        return (self.data[start_idx : start_idx + self.seq_len], 
                self.data[start_idx+1 : start_idx + self.seq_len + 1])

train_ds = MidiDataset(train_tensor, config.SEQ_LENGTH, config.STRIDE)
val_ds = MidiDataset(val_tensor, config.SEQ_LENGTH, config.STRIDE)

train_loader = DataLoader(train_ds, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_ds, batch_size=config.BATCH_SIZE, shuffle=False, num_workers=2)

In [7]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div)
        pe[:, 1::2] = torch.cos(position * div)
        self.register_buffer('pe', pe.unsqueeze(0))
    def forward(self, x): return x + self.pe[:, :x.size(1), :]

class StandardBlock(nn.Module):
    def __init__(self, d_model, n_heads, dropout):
        super().__init__()
        
        # 1. Multi-Head Self Attention
        self.attn_norm = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        
        # 2. Feed-Forward Network (Standard 4x expansion)
        self.ffn_norm = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(), # Standard Activation
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask):
        # 1. Attention (Pre-Norm)
        res = x
        x_norm = self.attn_norm(x)
        attn_out, _ = self.attn(x_norm, x_norm, x_norm, attn_mask=mask, is_causal=True)
        x = res + attn_out
        
        # 2. Feed-Forward (Pre-Norm)
        res = x
        ffn_out = self.ffn(self.ffn_norm(x))
        x = res + ffn_out
        
        return x

class MusicTransformer(nn.Module):
    def __init__(self, vocab_size, cfg):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, cfg.EMBED_DIM)
        self.pos = PositionalEncoding(cfg.EMBED_DIM)
        
        # Standard Blocks only
        self.layers = nn.ModuleList([
            StandardBlock(cfg.EMBED_DIM, cfg.N_HEADS, cfg.DROPOUT) 
            for _ in range(cfg.N_LAYERS)
        ])
        
        self.final_norm = nn.LayerNorm(cfg.EMBED_DIM)
        self.head = nn.Linear(cfg.EMBED_DIM, vocab_size)

    def forward(self, x):
        # Causal Mask
        mask = torch.triu(torch.ones(x.size(1), x.size(1)) * float('-inf'), diagonal=1).to(x.device)
        
        x = self.pos(self.embed(x))
        for layer in self.layers: 
            x = layer(x, mask)
        x = self.final_norm(x)
        return self.head(x)

In [8]:
print("Initializing...")
model = MusicTransformer(vocab_size, config).to(config.DEVICE)

optimizer = optim.AdamW(model.parameters(), lr=config.LEARNING_RATE, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()
scaler = GradScaler()

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=config.LEARNING_RATE,
    steps_per_epoch=len(train_loader) // config.ACCUMULATION_STEPS,
    epochs=config.EPOCHS,
    pct_start=0.1
)

print(f"Model Parameters: {sum(p.numel() for p in model.parameters()):,}")

Initializing...
Model Parameters: 4,900,154


  scaler = GradScaler()


In [9]:
def train_epoch(model, loader, opt, scaler, sched):
    model.train()
    total_loss = 0
    total_correct = 0
    total_samples = 0
    
    loop = tqdm(loader, desc="Train", leave=False)
    opt.zero_grad(set_to_none=True)
    
    for i, (x, y) in enumerate(loop):
        x, y = x.to(config.DEVICE), y.to(config.DEVICE)
        
        with autocast():
            logits = model(x)
            loss = criterion(logits.view(-1, vocab_size), y.view(-1))
            loss = loss / config.ACCUMULATION_STEPS 
        
        scaler.scale(loss).backward()
        
        predictions = logits.argmax(dim=-1) 
        total_correct += (predictions == y).sum().item()
        total_samples += y.numel()
        
        if (i + 1) % config.ACCUMULATION_STEPS == 0:
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(opt)
            scaler.update()
            sched.step()
            opt.zero_grad(set_to_none=True)
        
        current_loss = loss.item() * config.ACCUMULATION_STEPS
        total_loss += current_loss
        loop.set_postfix(loss=f"{current_loss:.4f}", acc=f"{total_correct/total_samples:.2%}")
        
    return total_loss / len(loader), total_correct / total_samples

def validate(model, loader):
    model.eval()
    total_loss = 0
    total_correct = 0
    total_samples = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(config.DEVICE), y.to(config.DEVICE)
            with autocast():
                logits = model(x)
                loss = criterion(logits.view(-1, vocab_size), y.view(-1))
            
            predictions = logits.argmax(dim=-1)
            total_correct += (predictions == y).sum().item()
            total_samples += y.numel()
            total_loss += loss.item()
    return total_loss / len(loader), total_correct / total_samples

In [11]:
best_loss = float('inf')
patience_counter = 0

print("Training...")

for epoch in range(config.EPOCHS):
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, scaler, scheduler)
    val_loss, val_acc = validate(model, val_loader)
    
    try: ppl = math.exp(val_loss)
    except: ppl = float('inf')
    
    print(
        f"Epoch {epoch+1}/{config.EPOCHS} | "
        f"Train: {train_loss:.4f} ({train_acc:.1%}) | "
        f"Val: {val_loss:.4f} ({val_acc:.1%}) | "
        f"PPL: {ppl:.2f}"
    )
    
    if val_loss < best_loss:
        best_loss = val_loss
        patience_counter = 0
        torch.save(model.state_dict(), "best_transformer_model.pth")
        print("Saved New Best Model")
    else:
        patience_counter += 1
        print(f"No improvement ({patience_counter}/{config.PATIENCE})")
        if patience_counter >= config.PATIENCE:
            print("Early stopping triggered.")
            break

Training...


Train:   0%|          | 0/291 [00:00<?, ?it/s]

  with autocast():
  with autocast():


Epoch 1/50 | Train: 4.4755 (12.5%) | Val: 3.2751 (21.2%) | PPL: 26.44
Saved New Best Model


Train:   0%|          | 0/291 [00:00<?, ?it/s]

Epoch 2/50 | Train: 2.7188 (26.0%) | Val: 2.3046 (30.0%) | PPL: 10.02
Saved New Best Model


Train:   0%|          | 0/291 [00:00<?, ?it/s]

Epoch 3/50 | Train: 2.0842 (33.1%) | Val: 1.9286 (36.6%) | PPL: 6.88
Saved New Best Model


Train:   0%|          | 0/291 [00:00<?, ?it/s]

Epoch 4/50 | Train: 1.7456 (41.9%) | Val: 1.5288 (49.0%) | PPL: 4.61
Saved New Best Model


Train:   0%|          | 0/291 [00:00<?, ?it/s]

Epoch 5/50 | Train: 1.4798 (49.9%) | Val: 1.4334 (51.5%) | PPL: 4.19
Saved New Best Model


Train:   0%|          | 0/291 [00:00<?, ?it/s]

Epoch 6/50 | Train: 1.3939 (52.3%) | Val: 1.3972 (52.5%) | PPL: 4.04
Saved New Best Model


Train:   0%|          | 0/291 [00:00<?, ?it/s]

Epoch 7/50 | Train: 1.3485 (53.6%) | Val: 1.3668 (53.3%) | PPL: 3.92
Saved New Best Model


Train:   0%|          | 0/291 [00:00<?, ?it/s]

Epoch 8/50 | Train: 1.2978 (55.1%) | Val: 1.3223 (54.7%) | PPL: 3.75
Saved New Best Model


Train:   0%|          | 0/291 [00:00<?, ?it/s]

Epoch 9/50 | Train: 1.2573 (56.4%) | Val: 1.3032 (55.2%) | PPL: 3.68
Saved New Best Model


Train:   0%|          | 0/291 [00:00<?, ?it/s]

Epoch 10/50 | Train: 1.2243 (57.5%) | Val: 1.2836 (56.0%) | PPL: 3.61
Saved New Best Model


Train:   0%|          | 0/291 [00:00<?, ?it/s]

Epoch 11/50 | Train: 1.1876 (58.9%) | Val: 1.2456 (58.0%) | PPL: 3.47
Saved New Best Model


Train:   0%|          | 0/291 [00:00<?, ?it/s]

Epoch 12/50 | Train: 1.1426 (60.9%) | Val: 1.1790 (61.1%) | PPL: 3.25
Saved New Best Model


Train:   0%|          | 0/291 [00:00<?, ?it/s]

Epoch 13/50 | Train: 1.0873 (63.3%) | Val: 1.1003 (64.9%) | PPL: 3.01
Saved New Best Model


Train:   0%|          | 0/291 [00:00<?, ?it/s]

Epoch 14/50 | Train: 1.0296 (65.8%) | Val: 1.0446 (67.1%) | PPL: 2.84
Saved New Best Model


Train:   0%|          | 0/291 [00:00<?, ?it/s]

Epoch 15/50 | Train: 0.9794 (67.7%) | Val: 1.0081 (68.5%) | PPL: 2.74
Saved New Best Model


Train:   0%|          | 0/291 [00:00<?, ?it/s]

Epoch 16/50 | Train: 0.9370 (69.2%) | Val: 0.9869 (69.3%) | PPL: 2.68
Saved New Best Model


Train:   0%|          | 0/291 [00:00<?, ?it/s]

Epoch 17/50 | Train: 0.9030 (70.3%) | Val: 0.9801 (69.6%) | PPL: 2.66
Saved New Best Model


Train:   0%|          | 0/291 [00:00<?, ?it/s]

Epoch 18/50 | Train: 0.8733 (71.2%) | Val: 0.9751 (69.9%) | PPL: 2.65
Saved New Best Model


Train:   0%|          | 0/291 [00:00<?, ?it/s]

Epoch 19/50 | Train: 0.8496 (71.9%) | Val: 0.9704 (70.0%) | PPL: 2.64
Saved New Best Model


Train:   0%|          | 0/291 [00:00<?, ?it/s]

Epoch 20/50 | Train: 0.8264 (72.6%) | Val: 0.9687 (70.1%) | PPL: 2.63
Saved New Best Model


Train:   0%|          | 0/291 [00:00<?, ?it/s]

Epoch 21/50 | Train: 0.8065 (73.2%) | Val: 0.9756 (70.1%) | PPL: 2.65
No improvement (1/5)


Train:   0%|          | 0/291 [00:00<?, ?it/s]

Epoch 22/50 | Train: 0.7879 (73.7%) | Val: 0.9795 (70.1%) | PPL: 2.66
No improvement (2/5)


Train:   0%|          | 0/291 [00:00<?, ?it/s]

Epoch 23/50 | Train: 0.7697 (74.3%) | Val: 0.9857 (70.0%) | PPL: 2.68
No improvement (3/5)


Train:   0%|          | 0/291 [00:00<?, ?it/s]

Epoch 24/50 | Train: 0.7526 (74.8%) | Val: 0.9951 (70.0%) | PPL: 2.70
No improvement (4/5)


Train:   0%|          | 0/291 [00:00<?, ?it/s]

Epoch 25/50 | Train: 0.7364 (75.3%) | Val: 1.0126 (69.9%) | PPL: 2.75
No improvement (5/5)
Early stopping triggered.


In [12]:
from miditok import TokSequence
import torch.nn.functional as F
from datetime import datetime
import numpy as np

GEN_CONF = {
    'num_examples': 5,
    'max_len': 64,     
    'temperature': 1.1,    
    'top_k': 40            
}


def generate_sequence(model, tokenizer, length, temp=1.0, top_k=20):
    model.eval()
    
    # 1. Start with a random valid token to "seed" the song
    start_token = torch.randint(0, tokenizer.vocab_size, (1, 1)).to(config.DEVICE)
    generated = start_token
    
    print(f"Generating... (Target: {length} tokens)")
    
    with torch.no_grad():
        for _ in tqdm(range(length), desc="Gen"):
            # Forward pass
            logits = model(generated)
            
            # Focus on the last token's prediction
            last_logits = logits[:, -1, :] / temp
            
            # Top-K Filtering (Keep only the likely notes)
            v, _ = torch.topk(last_logits, top_k)
            last_logits[last_logits < v[:, [-1]]] = -float('Inf')
            
            # Probability Distribution
            probs = F.softmax(last_logits, dim=-1)
            
            # Sample next token
            next_token = torch.multinomial(probs, num_samples=1)
            
            # Append to sequence
            generated = torch.cat((generated, next_token), dim=1)
            
    return generated[0].cpu().numpy()

for i in range(GEN_CONF['num_examples']):
    try:
        # A. Generate Tokens
        raw_tokens = generate_sequence(
            model, 
            tokenizer, 
            GEN_CONF['max_len'], 
            GEN_CONF['temperature'], 
            GEN_CONF['top_k']
        )
        
        # B. Convert to Standard List
        token_list = raw_tokens.tolist()
        
        # C. Wrap in TokSequence (Required for MidiTok)
        seq = TokSequence(ids=token_list)
        
        # D. Decode & Save
        generated_score = tokenizer.decode([seq])
        
        timestamp = datetime.now().strftime("%H%M%S")
        filename = f"pop909_gen_{i+1}_{timestamp}.mid"
        generated_score.dump_midi(filename)
        print(f"Saved: {filename}")
        
    except Exception as e:
        print(f"Error generating file {i+1}: {e}")

Generating... (Target: 64 tokens)


Gen:   0%|          | 0/64 [00:00<?, ?it/s]

Saved: pop909_gen_1_154214.mid
Generating... (Target: 64 tokens)


Gen:   0%|          | 0/64 [00:00<?, ?it/s]

Saved: pop909_gen_2_154215.mid
Generating... (Target: 64 tokens)


Gen:   0%|          | 0/64 [00:00<?, ?it/s]

Saved: pop909_gen_3_154217.mid
Generating... (Target: 64 tokens)


Gen:   0%|          | 0/64 [00:00<?, ?it/s]

Saved: pop909_gen_4_154220.mid
Generating... (Target: 64 tokens)


Gen:   0%|          | 0/64 [00:00<?, ?it/s]

Saved: pop909_gen_5_154222.mid
