In [1]:
!pip install -q miditok symusic tqdm

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast
import os
import random
import math
import glob
import json
import numpy as np
from pathlib import Path
from miditok import REMI, TokenizerConfig
from symusic import Score
from tqdm.notebook import tqdm

In [3]:
class Config:
    DATA_DIR = "/workspace/MAESTRO"
    
    FILE_LIMIT = None     
    STRIDE = 256         
    SEQ_LENGTH = 1024     
    
    EMBED_DIM = 512     
    N_HEADS = 8
    N_LAYERS = 8
    CNN_KERNEL = 65       
    DROPOUT = 0.3     
    
    BATCH_SIZE = 16
    ACCUMULATION_STEPS = 2
    LEARNING_RATE = 1e-3  
    EPOCHS = 50           
    PATIENCE = 5          
    
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

config = Config()
print(config.DEVICE)

cuda


In [4]:
def load_split_maestro(path, split_ratio=0.9):
    print(f"Scanning {path}...")
    # 1. Get all filenames first
    files = list(Path(path).glob("**/*.midi")) + list(Path(path).glob("**/*.mid"))
    random.shuffle(files)
    
    # 2. Split the FILES, not the tokens
    split_idx = int(len(files) * split_ratio)
    train_files = files[:split_idx]
    val_files = files[split_idx:]
    
    print(f"Train Files: {len(train_files)} | Val Files: {len(val_files)}")
    
    # 3. Helper to tokenize a specific list of files
    def tokenize_files(file_list, desc):
        tokenizer_config = TokenizerConfig(num_velocities=16, use_chords=True, use_tempos=True)
        tokenizer = REMI(tokenizer_config)
        tokens_list = []
        for f in tqdm(file_list, desc=desc):
            try:
                midi = Score(f)
                toks = tokenizer(midi)
                if len(toks) > 0: tokens_list.extend(toks[0].ids)
            except: pass
        return torch.tensor(tokens_list, dtype=torch.long), tokenizer

    print("Processing Training Set...")
    train_data, tokenizer = tokenize_files(train_files, "Train Tokenization")
    
    print("Processing Validation Set...")
    val_data, _ = tokenize_files(val_files, "Val Tokenization")
    
    return train_data, val_data, tokenizer

train_tensor, val_tensor, tokenizer = load_split_maestro(config.DATA_DIR)
vocab_size = tokenizer.vocab_size

Scanning /workspace/MAESTRO...
Train Files: 1148 | Val Files: 128
Processing Training Set...


Train Tokenization:   0%|          | 0/1148 [00:00<?, ?it/s]

Processing Validation Set...


Val Tokenization:   0%|          | 0/128 [00:00<?, ?it/s]

In [5]:
class MidiDataset(Dataset):
    def __init__(self, data_tensor, seq_len, stride):
        self.seq_len = seq_len
        self.data = data_tensor
        self.indices = range(0, len(data_tensor) - seq_len - 1, stride)

    def __len__(self): return len(self.indices)

    def __getitem__(self, idx):
        start_idx = self.indices[idx]
        # X: Context, Y: Target
        return (self.data[start_idx : start_idx + self.seq_len], 
                self.data[start_idx+1 : start_idx + self.seq_len + 1])

# Split 90/10
train_ds = MidiDataset(train_tensor, config.SEQ_LENGTH, config.STRIDE)
val_ds = MidiDataset(val_tensor, config.SEQ_LENGTH, config.STRIDE)

# Loaders
train_loader = DataLoader(train_ds, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_ds, batch_size=config.BATCH_SIZE, shuffle=False, num_workers=2)

print(f"Training Batches per Epoch: {len(train_loader)}")

Training Batches per Epoch: 5681


In [6]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div)
        pe[:, 1::2] = torch.cos(position * div)
        self.register_buffer('pe', pe.unsqueeze(0))
    def forward(self, x): return x + self.pe[:, :x.size(1), :]

class StandardBlock(nn.Module):
    def __init__(self, d_model, n_heads, dropout):
        super().__init__()
        
        # 1. Multi-Head Self Attention
        self.attn_norm = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        
        # 2. Feed-Forward Network (Standard 4x expansion)
        self.ffn_norm = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(), # Standard Activation
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask):
        # 1. Attention (Pre-Norm)
        res = x
        x_norm = self.attn_norm(x)
        attn_out, _ = self.attn(x_norm, x_norm, x_norm, attn_mask=mask, is_causal=True)
        x = res + attn_out
        
        # 2. Feed-Forward (Pre-Norm)
        res = x
        ffn_out = self.ffn(self.ffn_norm(x))
        x = res + ffn_out
        
        return x

class MusicTransformer(nn.Module):
    def __init__(self, vocab_size, cfg):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, cfg.EMBED_DIM)
        self.pos = PositionalEncoding(cfg.EMBED_DIM)
        
        # Standard Blocks only
        self.layers = nn.ModuleList([
            StandardBlock(cfg.EMBED_DIM, cfg.N_HEADS, cfg.DROPOUT) 
            for _ in range(cfg.N_LAYERS)
        ])
        
        self.final_norm = nn.LayerNorm(cfg.EMBED_DIM)
        self.head = nn.Linear(cfg.EMBED_DIM, vocab_size)

    def forward(self, x):
        # Causal Mask
        mask = torch.triu(torch.ones(x.size(1), x.size(1)) * float('-inf'), diagonal=1).to(x.device)
        
        x = self.pos(self.embed(x))
        for layer in self.layers: 
            x = layer(x, mask)
        x = self.final_norm(x)
        return self.head(x)

In [7]:
print("Initializing...")
model = MusicTransformer(vocab_size, config).to(config.DEVICE)

optimizer = optim.AdamW(model.parameters(), lr=config.LEARNING_RATE, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()
scaler = GradScaler()

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=config.LEARNING_RATE,
    steps_per_epoch=len(train_loader) // config.ACCUMULATION_STEPS,
    epochs=config.EPOCHS,
    pct_start=0.1
)

print(f"Model Parameters: {sum(p.numel() for p in model.parameters()):,}")

Initializing...
Model Parameters: 25,541,946


  scaler = GradScaler()


In [8]:
def train_epoch(model, loader, opt, scaler, sched):
    model.train()
    total_loss = 0
    total_correct = 0
    total_samples = 0
    
    loop = tqdm(loader, desc="Train", leave=False)
    opt.zero_grad(set_to_none=True)
    
    for i, (x, y) in enumerate(loop):
        x, y = x.to(config.DEVICE), y.to(config.DEVICE)
        
        with autocast():
            logits = model(x)
            loss = criterion(logits.view(-1, vocab_size), y.view(-1))
            loss = loss / config.ACCUMULATION_STEPS 
        
        scaler.scale(loss).backward()
        
        predictions = logits.argmax(dim=-1) 
        total_correct += (predictions == y).sum().item()
        total_samples += y.numel()
        
        if (i + 1) % config.ACCUMULATION_STEPS == 0:
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(opt)
            scaler.update()
            sched.step()
            opt.zero_grad(set_to_none=True)
        
        current_loss = loss.item() * config.ACCUMULATION_STEPS
        total_loss += current_loss
        loop.set_postfix(loss=f"{current_loss:.4f}", acc=f"{total_correct/total_samples:.2%}")
        
    return total_loss / len(loader), total_correct / total_samples

def validate(model, loader):
    model.eval()
    total_loss = 0
    total_correct = 0
    total_samples = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(config.DEVICE), y.to(config.DEVICE)
            with autocast():
                logits = model(x)
                loss = criterion(logits.view(-1, vocab_size), y.view(-1))
            
            predictions = logits.argmax(dim=-1)
            total_correct += (predictions == y).sum().item()
            total_samples += y.numel()
            total_loss += loss.item()
    return total_loss / len(loader), total_correct / total_samples

In [9]:
best_loss = float('inf')
patience_counter = 0

print("Training...")

for epoch in range(config.EPOCHS):
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, scaler, scheduler)
    val_loss, val_acc = validate(model, val_loader)
    
    try: ppl = math.exp(val_loss)
    except: ppl = float('inf')
    
    print(
        f"Epoch {epoch+1}/{config.EPOCHS} | "
        f"Train: {train_loss:.4f} ({train_acc:.1%}) | "
        f"Val: {val_loss:.4f} ({val_acc:.1%}) | "
        f"PPL: {ppl:.2f}"
    )
    
    if val_loss < best_loss:
        best_loss = val_loss
        patience_counter = 0
        torch.save(model.state_dict(), "best_transformer_model.pth")
        print("Saved New Best Model")
    else:
        patience_counter += 1
        print(f"No improvement ({patience_counter}/{config.PATIENCE})")
        if patience_counter >= config.PATIENCE:
            print("Early stopping triggered.")
            break

Training...


Train:   0%|          | 0/5681 [00:00<?, ?it/s]

  with autocast():
  with autocast():


Epoch 1/50 | Train: 2.4227 (30.8%) | Val: 2.0529 (36.4%) | PPL: 7.79
Saved New Best Model


Train:   0%|          | 0/5681 [00:00<?, ?it/s]

Epoch 2/50 | Train: 1.8998 (40.2%) | Val: 1.6703 (45.9%) | PPL: 5.31
Saved New Best Model


Train:   0%|          | 0/5681 [00:00<?, ?it/s]

Epoch 3/50 | Train: 1.6481 (46.3%) | Val: 1.5318 (49.3%) | PPL: 4.63
Saved New Best Model


Train:   0%|          | 0/5681 [00:00<?, ?it/s]

Epoch 4/50 | Train: 1.5225 (49.3%) | Val: 1.4624 (50.9%) | PPL: 4.32
Saved New Best Model


Train:   0%|          | 0/5681 [00:00<?, ?it/s]

Epoch 5/50 | Train: 1.4395 (51.3%) | Val: 1.4122 (52.2%) | PPL: 4.10
Saved New Best Model


Train:   0%|          | 0/5681 [00:00<?, ?it/s]

Epoch 6/50 | Train: 1.3753 (52.9%) | Val: 1.3814 (53.0%) | PPL: 3.98
Saved New Best Model


Train:   0%|          | 0/5681 [00:00<?, ?it/s]

Epoch 7/50 | Train: 1.3271 (54.1%) | Val: 1.3660 (53.4%) | PPL: 3.92
Saved New Best Model


Train:   0%|          | 0/5681 [00:00<?, ?it/s]

Epoch 8/50 | Train: 1.2901 (55.0%) | Val: 1.3541 (53.8%) | PPL: 3.87
Saved New Best Model


Train:   0%|          | 0/5681 [00:00<?, ?it/s]

Epoch 9/50 | Train: 1.2602 (55.8%) | Val: 1.3495 (54.0%) | PPL: 3.86
Saved New Best Model


Train:   0%|          | 0/5681 [00:00<?, ?it/s]

Epoch 10/50 | Train: 1.2358 (56.4%) | Val: 1.3433 (54.1%) | PPL: 3.83
Saved New Best Model


Train:   0%|          | 0/5681 [00:00<?, ?it/s]

Epoch 11/50 | Train: 1.2151 (57.0%) | Val: 1.3395 (54.3%) | PPL: 3.82
Saved New Best Model


Train:   0%|          | 0/5681 [00:00<?, ?it/s]

Epoch 12/50 | Train: 1.1969 (57.4%) | Val: 1.3386 (54.4%) | PPL: 3.81
Saved New Best Model


Train:   0%|          | 0/5681 [00:00<?, ?it/s]

Epoch 13/50 | Train: 1.1809 (57.9%) | Val: 1.3371 (54.5%) | PPL: 3.81
Saved New Best Model


Train:   0%|          | 0/5681 [00:00<?, ?it/s]

Epoch 14/50 | Train: 1.1667 (58.2%) | Val: 1.3371 (54.5%) | PPL: 3.81
Saved New Best Model


Train:   0%|          | 0/5681 [00:00<?, ?it/s]

Epoch 15/50 | Train: 1.1536 (58.6%) | Val: 1.3397 (54.5%) | PPL: 3.82
No improvement (1/5)


Train:   0%|          | 0/5681 [00:00<?, ?it/s]

Epoch 16/50 | Train: 1.1417 (58.9%) | Val: 1.3421 (54.5%) | PPL: 3.83
No improvement (2/5)


Train:   0%|          | 0/5681 [00:00<?, ?it/s]

Epoch 17/50 | Train: 1.1304 (59.2%) | Val: 1.3364 (54.6%) | PPL: 3.81
Saved New Best Model


Train:   0%|          | 0/5681 [00:00<?, ?it/s]

Epoch 18/50 | Train: 1.1200 (59.5%) | Val: 1.3502 (54.5%) | PPL: 3.86
No improvement (1/5)


Train:   0%|          | 0/5681 [00:00<?, ?it/s]

Epoch 19/50 | Train: 1.1099 (59.8%) | Val: 1.3501 (54.5%) | PPL: 3.86
No improvement (2/5)


Train:   0%|          | 0/5681 [00:00<?, ?it/s]

Epoch 20/50 | Train: 1.1004 (60.1%) | Val: 1.3494 (54.6%) | PPL: 3.86
No improvement (3/5)


Train:   0%|          | 0/5681 [00:00<?, ?it/s]

Epoch 21/50 | Train: 1.0912 (60.3%) | Val: 1.3519 (54.6%) | PPL: 3.86
No improvement (4/5)


Train:   0%|          | 0/5681 [00:00<?, ?it/s]

Epoch 22/50 | Train: 1.0823 (60.6%) | Val: 1.3639 (54.5%) | PPL: 3.91
No improvement (5/5)
Early stopping triggered.


In [11]:
from miditok import TokSequence
import torch.nn.functional as F
from datetime import datetime
import numpy as np

# 1. Configuration
GEN_CONF = {
    'num_examples': 5,
    'max_len': 1024,       
    'temperature': 0.8,
    'top_k': 20
}

# 2. Generation Function
def generate_sequence(model, tokenizer, length, temp=1.0, top_k=20):
    model.eval()
    # Seed with a random token from the vocab
    start_token = torch.randint(0, tokenizer.vocab_size, (1, 1)).to(config.DEVICE)
    generated = start_token
    
    print(f"Generating... (Target: {length} tokens)")
    
    with torch.no_grad():
        for _ in tqdm(range(length), desc="Gen"):
            logits = model(generated)
            # Focus on the last token
            last_logits = logits[:, -1, :] / temp
            
            # Top-K Sampling (Filters out bad notes)
            v, _ = torch.topk(last_logits, top_k)
            last_logits[last_logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(last_logits, dim=-1)
            
            # Sample
            next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat((generated, next_token), dim=1)
            
    return generated[0].cpu().numpy()

# 3. Main Loop
print("Starting Generation Loop...")

for i in range(GEN_CONF['num_examples']):
    try:
        # A. Generate
        raw_tokens = generate_sequence(
            model, 
            tokenizer, 
            GEN_CONF['max_len'], 
            GEN_CONF['temperature'], 
            GEN_CONF['top_k']
        )
        
        # B. Convert to Python List
        token_list = raw_tokens.tolist()
        
        # C. Wrap in TokSequence
        seq = TokSequence(ids=token_list)
        
        # D. Decode and Save
        generated_score = tokenizer.decode([seq])
        
        timestamp = datetime.now().strftime("%H%M%S")
        filename = f"transformer_0.3_{i+1}_{timestamp}.mid"
        generated_score.dump_midi(filename)
        print(f"Saved: {filename}")
        
    except Exception as e:
        print(f"Error on file {i+1}: {e}")
        if 'token_list' in locals():
            print(f"First 10 tokens: {token_list[:10]}")

Starting Generation Loop...
Generating... (Target: 1024 tokens)


Gen:   0%|          | 0/1024 [00:00<?, ?it/s]

Saved: transformer_0.3_1_021639.mid
Generating... (Target: 1024 tokens)


Gen:   0%|          | 0/1024 [00:00<?, ?it/s]

Saved: transformer_0.3_2_021642.mid
Generating... (Target: 1024 tokens)


Gen:   0%|          | 0/1024 [00:00<?, ?it/s]

Saved: transformer_0.3_3_021646.mid
Generating... (Target: 1024 tokens)


Gen:   0%|          | 0/1024 [00:00<?, ?it/s]

Saved: transformer_0.3_4_021649.mid
Generating... (Target: 1024 tokens)


Gen:   0%|          | 0/1024 [00:00<?, ?it/s]

Saved: transformer_0.3_5_021653.mid
