In [1]:
!pip install -q miditok symusic tqdm

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast
import os
import random
import math
import glob
import json
import numpy as np
from pathlib import Path
from miditok import REMI, TokenizerConfig
from symusic import Score
from tqdm.notebook import tqdm

In [3]:
# Report how many CUDA devices PyTorch can see (0 when CUDA is unavailable).
print(f"GPUs available: {torch.cuda.device_count()}")

GPUs available: 2


In [4]:
class Config:
    """Hyperparameters, paths and the target device shared by the cells below."""

    # Root directory that load_split_maestro scans recursively for .midi/.mid files.
    DATA_DIR = "/workspace/MAESTRO"
    
    # NOTE(review): FILE_LIMIT is not referenced by any visible cell.
    FILE_LIMIT = None     
    # Step between consecutive window start positions in MidiDataset.
    STRIDE = 256         
    # Number of tokens in each input window (targets are the same window shifted by one).
    SEQ_LENGTH = 1024     
    
    # Model width, attention heads, number of stacked blocks.
    EMBED_DIM = 512     
    N_HEADS = 8
    N_LAYERS = 8
    # Kernel size of the depthwise convolution in each block.
    CNN_KERNEL = 65       
    # Dropout probability passed to every block.
    DROPOUT = 0.3      
    
    # Batches per DataLoader step; the optimizer steps once every ACCUMULATION_STEPS batches.
    BATCH_SIZE = 64
    ACCUMULATION_STEPS = 2
    # Used both as the AdamW lr and as OneCycleLR's max_lr.
    LEARNING_RATE = 1e-3  
    EPOCHS = 50           
    # Epochs without validation-loss improvement before the training loop stops.
    PATIENCE = 5          
    
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

config = Config()
print(config.DEVICE)

cuda


In [5]:
def load_split_maestro(path, split_ratio=0.9, seed=None):
    """Tokenize every MIDI file under ``path`` into train/validation token streams.

    The split is done on files (not on tokens) so that no piece contributes
    tokens to both sets. Each set is the concatenation of the token ids of its
    files, in shuffled file order.

    Args:
        path: Directory searched recursively for ``*.midi`` and ``*.mid`` files.
        split_ratio: Fraction of files assigned to the training set.
        seed: Optional seed for the file shuffle. ``None`` (default) uses the
            global ``random`` state, as before; an int makes the split
            reproducible across kernel restarts.

    Returns:
        ``(train_data, val_data, tokenizer)`` where the two data items are 1-D
        ``torch.long`` tensors and ``tokenizer`` is the REMI tokenizer used
        for both sets.
    """
    print(f"Scanning {path}...")
    # 1. Get all filenames first. Sorting removes the dependence on filesystem
    #    enumeration order, so a seeded shuffle yields the same split everywhere.
    root = Path(path)
    files = sorted(list(root.glob("**/*.midi")) + list(root.glob("**/*.mid")))
    if seed is None:
        random.shuffle(files)
    else:
        random.Random(seed).shuffle(files)

    # 2. Split the FILES, not the tokens
    split_idx = int(len(files) * split_ratio)
    train_files = files[:split_idx]
    val_files = files[split_idx:]

    print(f"Train Files: {len(train_files)} | Val Files: {len(val_files)}")

    # One tokenizer for both sets: same configuration, hence same vocabulary.
    tokenizer_config = TokenizerConfig(num_velocities=16, use_chords=True, use_tempos=True)
    tokenizer = REMI(tokenizer_config)

    # 3. Helper to tokenize a specific list of files
    def tokenize_files(file_list, desc):
        tokens_list = []
        failures = []
        for f in tqdm(file_list, desc=desc):
            try:
                midi = Score(f)
                toks = tokenizer(midi)
                if len(toks) > 0:
                    tokens_list.extend(toks[0].ids)
            except Exception as exc:
                # Still best-effort (one bad file must not abort the run),
                # but no longer silent: remember what was skipped and why.
                failures.append((f, exc))
        if failures:
            first_path, first_exc = failures[0]
            print(
                f"{desc}: skipped {len(failures)} of {len(file_list)} file(s); "
                f"first failure: {first_path} ({first_exc!r})"
            )
        return torch.tensor(tokens_list, dtype=torch.long)

    print("Processing Training Set...")
    train_data = tokenize_files(train_files, "Train Tokenization")

    print("Processing Validation Set...")
    val_data = tokenize_files(val_files, "Val Tokenization")

    return train_data, val_data, tokenizer

# Tokenize the corpus once; vocab_size is read by the model head and the loss reshape below.
train_tensor, val_tensor, tokenizer = load_split_maestro(config.DATA_DIR)
vocab_size = tokenizer.vocab_size

Scanning /workspace/MAESTRO...
Train Files: 1148 | Val Files: 128
Processing Training Set...


Train Tokenization:   0%|          | 0/1148 [00:00<?, ?it/s]

Processing Validation Set...


Val Tokenization:   0%|          | 0/128 [00:00<?, ?it/s]

In [6]:
class MidiDataset(Dataset):
    """Sliding-window next-token dataset over a single flat token tensor.

    Item ``i`` is the pair ``(x, y)`` where ``x`` is ``seq_len`` consecutive
    tokens starting at ``i * stride`` and ``y`` is the same window shifted one
    position to the right (the next-token targets).

    Args:
        data_tensor: 1-D tensor of token ids.
        seq_len: Number of tokens in each input window.
        stride: Distance between the starts of consecutive windows.
    """

    def __init__(self, data_tensor, seq_len, stride):
        self.seq_len = seq_len
        self.data = data_tensor
        # A window starting at s reads data[s : s + seq_len + 1], so the last
        # valid start is len - seq_len - 1. range() excludes its stop value,
        # hence the stop is len - seq_len (previously one valid start was lost).
        self.indices = range(0, len(data_tensor) - seq_len, stride)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        start_idx = self.indices[idx]
        end_idx = start_idx + self.seq_len
        # X: Context, Y: Target (context shifted by one token)
        return (self.data[start_idx:end_idx],
                self.data[start_idx + 1:end_idx + 1])

# Sliding-window datasets over the two token streams (same window length and stride).
train_ds = MidiDataset(train_tensor, config.SEQ_LENGTH, config.STRIDE)
val_ds = MidiDataset(val_tensor, config.SEQ_LENGTH, config.STRIDE)

# Loaders

import multiprocessing
num_cpu = multiprocessing.cpu_count()
workers = min(24, num_cpu)  # never ask for more than 24 worker processes
print(f"Using {workers} workers for data loading.")

# Settings shared by both loaders; only shuffling differs between them.
_loader_kwargs = dict(
    batch_size=config.BATCH_SIZE,
    num_workers=workers,
    pin_memory=True,
    persistent_workers=True,
)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_kwargs)

print(f"Training Batches per Epoch: {len(train_loader)}")

Using 24 workers for data loading.
Training Batches per Epoch: 1413


In [7]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first input.

    A table of shape ``(1, max_len, d_model)`` is precomputed and stored as the
    non-trainable buffer ``pe``: even feature indices hold sines, odd indices
    hold cosines.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        inv_freq = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        angles = positions * inv_freq

        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions."""
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len, :]

In [8]:
class CausalMacaronBlock(nn.Module):
    def __init__(self, d_model, n_heads, cnn_k, dropout):
        super().__init__()
        self.cnn_k = cnn_k
        
        # 1. First Feed-Forward
        self.ff1_norm = nn.LayerNorm(d_model)
        self.ff1 = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.SiLU(), 
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout)
        )
        
        # 2. Multi-Head Self Attention
        self.attn_norm = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        
        # 3. Causal Convolution Module
        self.conv_norm = nn.LayerNorm(d_model)
        
        # Pointwise 1
        self.conv_pointwise1 = nn.Conv1d(d_model, d_model * 2, 1)
        self.glu = nn.GLU(dim=1)
        
        # Depthwise Conv (NO PADDING defined here; we do it manually)
        self.conv_depthwise = nn.Conv1d(d_model, d_model, cnn_k, padding=0, groups=d_model)
        
        # Norm & Pointwise 2
        self.conv_batchnorm = nn.BatchNorm1d(d_model)
        self.conv_act = nn.SiLU()
        self.conv_pointwise2 = nn.Conv1d(d_model, d_model, 1)
        self.conv_dropout = nn.Dropout(dropout)
        
        # 4. Second Feed-Forward
        self.ff2_norm = nn.LayerNorm(d_model)
        self.ff2 = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.SiLU(),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout)
        )
        
        self.final_norm = nn.LayerNorm(d_model)

    def forward(self, x, mask):
        # 1. First FFN
        x = x + 0.5 * self.ff1(self.ff1_norm(x))
        
        # 2. Attention
        res = x
        x_norm = self.attn_norm(x)
        attn_out, _ = self.attn(x_norm, x_norm, x_norm, attn_mask=mask, is_causal=True)
        x = res + attn_out
        
        # 3. Causal Convolution
        res = x
        x_cnn = self.conv_norm(x).permute(0, 2, 1) # [B, Dim, Seq]
        
        # Expansion
        x_cnn = self.conv_pointwise1(x_cnn)
        x_cnn = self.glu(x_cnn)
        
        x_cnn = F.pad(x_cnn, (self.cnn_k - 1, 0))
        
        x_cnn = self.conv_depthwise(x_cnn)
        x_cnn = self.conv_batchnorm(x_cnn)
        x_cnn = self.conv_act(x_cnn)
        x_cnn = self.conv_pointwise2(x_cnn)
        x_cnn = self.conv_dropout(x_cnn)
        
        x = res + x_cnn.permute(0, 2, 1)
        
        # 4. Second FFN
        x = x + 0.5 * self.ff2(self.ff2_norm(x))
        
        return self.final_norm(x)

In [9]:
class MusicConformer(nn.Module):
    """Token-level language model: embedding -> positional encoding -> blocks -> head.

    Args:
        vocab_size: Number of token ids; sizes both the embedding and the output head.
        cfg: Object providing ``EMBED_DIM``, ``N_HEADS``, ``N_LAYERS``,
            ``CNN_KERNEL`` and ``DROPOUT``.
    """

    def __init__(self, vocab_size, cfg):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, cfg.EMBED_DIM)
        self.pos = PositionalEncoding(cfg.EMBED_DIM)
        self.layers = nn.ModuleList([
            CausalMacaronBlock(cfg.EMBED_DIM, cfg.N_HEADS, cfg.CNN_KERNEL, cfg.DROPOUT)
            for _ in range(cfg.N_LAYERS)
        ])
        self.head = nn.Linear(cfg.EMBED_DIM, vocab_size)

    def forward(self, x):
        """Map token ids ``[B, Seq]`` to next-token logits ``[B, Seq, vocab_size]``."""
        seq_len = x.size(1)
        # Additive causal mask: -inf strictly above the diagonal, 0 elsewhere.
        # Built directly on x's device, which avoids a host allocation and a
        # host-to-device copy on every forward pass.
        mask = torch.triu(
            torch.ones(seq_len, seq_len, device=x.device) * float('-inf'),
            diagonal=1,
        )
        x = self.pos(self.embed(x))
        for layer in self.layers:
            x = layer(x, mask)
        return self.head(x)

# Build the model from the tokenizer's vocabulary size and the shared Config.
model = MusicConformer(vocab_size, config)

# With more than one visible GPU, `model` becomes an nn.DataParallel wrapper
# (the underlying network is then reachable as model.module).
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs!")
    model = nn.DataParallel(model)

model = model.to(config.DEVICE)

optimizer = optim.AdamW(model.parameters(), lr=config.LEARNING_RATE, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()
# NOTE(review): the recorded output echoes this line, which looks like a
# deprecation warning for torch.cuda.amp.GradScaler — confirm against the installed torch.
scaler = GradScaler()

# steps_per_epoch counts optimizer steps, not batches: train_epoch steps the
# optimizer and scheduler once every ACCUMULATION_STEPS batches.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=config.LEARNING_RATE,
    steps_per_epoch=len(train_loader) // config.ACCUMULATION_STEPS,
    epochs=config.EPOCHS,
    pct_start=0.1
)

print(f"Model Parameters: {sum(p.numel() for p in model.parameters()):,}")


Using 2 GPUs!
Model Parameters: 48,945,466


  scaler = GradScaler()


In [10]:
# NOTE(review): EarlyStopping, train_epoch and validate are defined in two
# adjacent cells; the later definitions win on Run All — consider deleting one copy.
class EarlyStopping:
    """Checkpoint on validation-loss improvement and flag when patience runs out.

    Args:
        patience: Consecutive non-improving calls after which ``early_stop`` is set.
        min_delta: Minimum decrease of the loss that counts as an improvement.
        path: File the best weights are written to.

    Attributes set by calls: ``best_loss``, ``counter`` and ``early_stop``.
    """

    def __init__(self, patience=5, min_delta=0.001, path='best_conformer.pth'):
        self.patience = patience
        self.min_delta = min_delta
        self.path = path
        self.counter = 0
        self.best_loss = np.inf
        self.early_stop = False

    def __call__(self, val_loss, model):
        """Record ``val_loss``; save ``model`` if it improved, else count a miss."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            # Save the underlying network's weights so the keys carry no
            # "module." prefix, the same way the training loop checkpoints;
            # the file then loads into a plain or a wrapped model's .module.
            target = model.module if isinstance(model, nn.DataParallel) else model
            torch.save(target.state_dict(), self.path)
            print(f"Saved (Loss: {val_loss:.4f})")
        else:
            self.counter += 1
            print(f"No improvement ({self.counter}/{self.patience})")
            if self.counter >= self.patience:
                self.early_stop = True

def train_epoch(model, loader, opt, scaler, sched):
    """Run one training pass over ``loader`` with autocast and gradient accumulation.

    Reads ``config``, ``criterion`` and ``vocab_size`` from module scope. The
    optimizer, scaler and scheduler advance once every
    ``config.ACCUMULATION_STEPS`` batches; gradients of a trailing partial
    accumulation window are not applied within this call.

    Returns:
        ``(mean batch loss, token accuracy)`` over the whole pass.
    """
    model.train()
    loss_sum = 0
    n_correct = 0
    n_tokens = 0

    progress = tqdm(loader, desc="Train", leave=False)
    opt.zero_grad(set_to_none=True)

    for step, (x, y) in enumerate(progress, start=1):
        x = x.to(config.DEVICE)
        y = y.to(config.DEVICE)

        with autocast():
            logits = model(x)
            batch_loss = criterion(logits.view(-1, vocab_size), y.view(-1))
            # Scale down so the accumulated gradient is an average over the window.
            scaled_loss = batch_loss / config.ACCUMULATION_STEPS

        scaler.scale(scaled_loss).backward()

        # Accuracy
        hits = logits.argmax(dim=-1) == y
        n_correct += hits.sum().item()
        n_tokens += y.numel()

        # Step only every N batches
        if step % config.ACCUMULATION_STEPS == 0:
            scaler.unscale_(opt)  # clip on true (unscaled) gradient norms
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(opt)
            scaler.update()
            sched.step()
            opt.zero_grad(set_to_none=True)

        # Metrics for display (undo the accumulation scaling)
        shown_loss = scaled_loss.item() * config.ACCUMULATION_STEPS
        loss_sum += shown_loss
        progress.set_postfix(loss=f"{shown_loss:.4f}", acc=f"{n_correct/n_tokens:.2%}")

    mean_loss = loss_sum / len(loader)
    accuracy = n_correct / n_tokens
    return mean_loss, accuracy

def validate(model, loader):
    """Evaluate ``model`` on ``loader`` without gradients.

    Reads ``config``, ``criterion`` and ``vocab_size`` from module scope.

    Returns:
        ``(mean batch loss, token accuracy)`` over the whole loader.
    """
    model.eval()
    loss_sum = 0
    n_correct = 0
    n_tokens = 0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(config.DEVICE)
            y = y.to(config.DEVICE)
            with autocast():
                logits = model(x)
                batch_loss = criterion(logits.view(-1, vocab_size), y.view(-1))

            hits = logits.argmax(dim=-1) == y
            n_correct += hits.sum().item()
            n_tokens += y.numel()
            loss_sum += batch_loss.item()
    mean_loss = loss_sum / len(loader)
    accuracy = n_correct / n_tokens
    return mean_loss, accuracy

In [11]:
# NOTE(review): EarlyStopping, train_epoch and validate are defined in two
# adjacent cells; the later definitions win on Run All — consider deleting one copy.
class EarlyStopping:
    """Checkpoint on validation-loss improvement and flag when patience runs out.

    Args:
        patience: Consecutive non-improving calls after which ``early_stop`` is set.
        min_delta: Minimum decrease of the loss that counts as an improvement.
        path: File the best weights are written to.

    Attributes set by calls: ``best_loss``, ``counter`` and ``early_stop``.
    """

    def __init__(self, patience=5, min_delta=0.001, path='best_conformer.pth'):
        self.patience = patience
        self.min_delta = min_delta
        self.path = path
        self.counter = 0
        self.best_loss = np.inf
        self.early_stop = False

    def __call__(self, val_loss, model):
        """Record ``val_loss``; save ``model`` if it improved, else count a miss."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            # Save the underlying network's weights so the keys carry no
            # "module." prefix, the same way the training loop checkpoints;
            # the file then loads into a plain or a wrapped model's .module.
            target = model.module if isinstance(model, nn.DataParallel) else model
            torch.save(target.state_dict(), self.path)
            print(f"Saved (Loss: {val_loss:.4f})")
        else:
            self.counter += 1
            print(f"No improvement ({self.counter}/{self.patience})")
            if self.counter >= self.patience:
                self.early_stop = True

def train_epoch(model, loader, opt, scaler, sched):
    """Run one training pass over ``loader`` with autocast and gradient accumulation.

    Reads ``config``, ``criterion`` and ``vocab_size`` from module scope. The
    optimizer, scaler and scheduler advance once every
    ``config.ACCUMULATION_STEPS`` batches; gradients of a trailing partial
    accumulation window are not applied within this call.

    Returns:
        ``(mean batch loss, token accuracy)`` over the whole pass.
    """
    model.train()
    loss_sum = 0
    n_correct = 0
    n_tokens = 0

    progress = tqdm(loader, desc="Train", leave=False)
    opt.zero_grad(set_to_none=True)

    for step, (x, y) in enumerate(progress, start=1):
        x = x.to(config.DEVICE)
        y = y.to(config.DEVICE)

        with autocast():
            logits = model(x)
            batch_loss = criterion(logits.view(-1, vocab_size), y.view(-1))
            # Scale down so the accumulated gradient is an average over the window.
            scaled_loss = batch_loss / config.ACCUMULATION_STEPS

        scaler.scale(scaled_loss).backward()

        # Accuracy
        hits = logits.argmax(dim=-1) == y
        n_correct += hits.sum().item()
        n_tokens += y.numel()

        # Step only every N batches
        if step % config.ACCUMULATION_STEPS == 0:
            scaler.unscale_(opt)  # clip on true (unscaled) gradient norms
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(opt)
            scaler.update()
            sched.step()
            opt.zero_grad(set_to_none=True)

        # Metrics for display (undo the accumulation scaling)
        shown_loss = scaled_loss.item() * config.ACCUMULATION_STEPS
        loss_sum += shown_loss
        progress.set_postfix(loss=f"{shown_loss:.4f}", acc=f"{n_correct/n_tokens:.2%}")

    mean_loss = loss_sum / len(loader)
    accuracy = n_correct / n_tokens
    return mean_loss, accuracy

def validate(model, loader):
    """Evaluate ``model`` on ``loader`` without gradients.

    Reads ``config``, ``criterion`` and ``vocab_size`` from module scope.

    Returns:
        ``(mean batch loss, token accuracy)`` over the whole loader.
    """
    model.eval()
    loss_sum = 0
    n_correct = 0
    n_tokens = 0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(config.DEVICE)
            y = y.to(config.DEVICE)
            with autocast():
                logits = model(x)
                batch_loss = criterion(logits.view(-1, vocab_size), y.view(-1))

            hits = logits.argmax(dim=-1) == y
            n_correct += hits.sum().item()
            n_tokens += y.numel()
            loss_sum += batch_loss.item()
    mean_loss = loss_sum / len(loader)
    accuracy = n_correct / n_tokens
    return mean_loss, accuracy

In [12]:
# Single source of truth for the checkpoint file: the loop saves to it and the
# reload after the loop reads from it. (The reload previously pointed at a
# different file name that this loop never writes.)
CHECKPOINT_PATH = "conformer_0.3.pth"

# Checkpoints hold the underlying network's weights (no "module." key prefix),
# so they must be saved from and loaded into the unwrapped module.
base_model = model.module if isinstance(model, nn.DataParallel) else model

best_loss = float('inf')
patience_counter = 0

print("Starting Training...")

for epoch in range(config.EPOCHS):
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, scaler, scheduler)
    val_loss, val_acc = validate(model, val_loader)

    # math.exp raises OverflowError for very large losses; report inf instead.
    try:
        ppl = math.exp(val_loss)
    except OverflowError:
        ppl = float('inf')

    print(
        f"Epoch {epoch+1}/{config.EPOCHS} | "
        f"Train Loss: {train_loss:.4f} (Acc: {train_acc:.2%}) | "
        f"Val Loss: {val_loss:.4f} (Acc: {val_acc:.2%}) | "
        f"PPL: {ppl:.2f}"
    )

    if val_loss < best_loss:
        best_loss = val_loss
        patience_counter = 0
        torch.save(base_model.state_dict(), CHECKPOINT_PATH)
        print("Saved New Best Model")
    else:
        patience_counter += 1
        print(f"No improvement ({patience_counter}/{config.PATIENCE})")
        if patience_counter >= config.PATIENCE:
            print("Early stopping triggered.")
            break

# Load best weights. best_loss only leaves its initial value when a checkpoint
# was written during this run, so a stale file from an earlier run is never loaded.
if best_loss != float('inf'):
    base_model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=config.DEVICE))

Starting Training...


Train:   0%|          | 0/1413 [00:00<?, ?it/s]

  with autocast():
  with autocast():


Epoch 1/50 | Train Loss: 2.6813 (Acc: 28.20%) | Val Loss: 2.2577 (Acc: 32.69%) | PPL: 9.56
Saved New Best Model


Train:   0%|          | 0/1413 [00:00<?, ?it/s]

Epoch 2/50 | Train Loss: 2.0862 (Acc: 36.17%) | Val Loss: 1.9361 (Acc: 38.98%) | PPL: 6.93
Saved New Best Model


Train:   0%|          | 0/1413 [00:00<?, ?it/s]

Epoch 3/50 | Train Loss: 1.7998 (Acc: 42.64%) | Val Loss: 1.6914 (Acc: 45.02%) | PPL: 5.43
Saved New Best Model


Train:   0%|          | 0/1413 [00:00<?, ?it/s]

Epoch 4/50 | Train Loss: 1.6111 (Acc: 47.24%) | Val Loss: 1.5629 (Acc: 48.37%) | PPL: 4.77
Saved New Best Model


Train:   0%|          | 0/1413 [00:00<?, ?it/s]

Epoch 5/50 | Train Loss: 1.4751 (Acc: 50.59%) | Val Loss: 1.4688 (Acc: 50.50%) | PPL: 4.34
Saved New Best Model


Train:   0%|          | 0/1413 [00:00<?, ?it/s]

Epoch 6/50 | Train Loss: 1.3711 (Acc: 53.12%) | Val Loss: 1.4205 (Acc: 51.68%) | PPL: 4.14
Saved New Best Model


Train:   0%|          | 0/1413 [00:00<?, ?it/s]

Epoch 7/50 | Train Loss: 1.2947 (Acc: 54.99%) | Val Loss: 1.3962 (Acc: 52.46%) | PPL: 4.04
Saved New Best Model


Train:   0%|          | 0/1413 [00:00<?, ?it/s]

Epoch 8/50 | Train Loss: 1.2367 (Acc: 56.44%) | Val Loss: 1.3737 (Acc: 53.11%) | PPL: 3.95
Saved New Best Model


Train:   0%|          | 0/1413 [00:00<?, ?it/s]

Epoch 9/50 | Train Loss: 1.1906 (Acc: 57.62%) | Val Loss: 1.3655 (Acc: 53.37%) | PPL: 3.92
Saved New Best Model


Train:   0%|          | 0/1413 [00:00<?, ?it/s]

Epoch 10/50 | Train Loss: 1.1521 (Acc: 58.63%) | Val Loss: 1.3641 (Acc: 53.51%) | PPL: 3.91
Saved New Best Model


Train:   0%|          | 0/1413 [00:00<?, ?it/s]

Epoch 11/50 | Train Loss: 1.1191 (Acc: 59.52%) | Val Loss: 1.3739 (Acc: 53.66%) | PPL: 3.95
No improvement (1/5)


Train:   0%|          | 0/1413 [00:00<?, ?it/s]

Epoch 12/50 | Train Loss: 1.0901 (Acc: 60.31%) | Val Loss: 1.3685 (Acc: 53.71%) | PPL: 3.93
No improvement (2/5)


Train:   0%|          | 0/1413 [00:00<?, ?it/s]

Epoch 13/50 | Train Loss: 1.0638 (Acc: 61.06%) | Val Loss: 1.3754 (Acc: 53.80%) | PPL: 3.96
No improvement (3/5)


Train:   0%|          | 0/1413 [00:00<?, ?it/s]

Epoch 14/50 | Train Loss: 1.0400 (Acc: 61.75%) | Val Loss: 1.3797 (Acc: 53.82%) | PPL: 3.97
No improvement (4/5)


Train:   0%|          | 0/1413 [00:00<?, ?it/s]

Epoch 15/50 | Train Loss: 1.0184 (Acc: 62.39%) | Val Loss: 1.3805 (Acc: 53.83%) | PPL: 3.98
No improvement (5/5)
Early stopping triggered.


In [13]:
from miditok import TokSequence
import torch.nn.functional as F
from datetime import datetime
import numpy as np

# 1. Configuration
# Sampling settings consumed by the generation loop below.
GEN_CONF = {
    'num_examples': 5,     # how many sequences to generate and save
    'max_len': 1024,       # tokens sampled per sequence (passed as `length`)
    'temperature': 0.8,    # divisor applied to the logits before sampling
    'top_k': 20            # number of highest-scoring tokens kept for sampling
}

# 2. Generation Function
def generate_sequence(model, tokenizer, length, temp=1.0, top_k=20):
    """Autoregressively sample ``length`` tokens with temperature and top-k filtering.

    The sequence is seeded with one token drawn uniformly from the vocabulary;
    at each step the whole sequence so far is fed to ``model`` and the next
    token is sampled from the filtered distribution of the last position.
    Reads ``config.DEVICE`` from module scope.

    Args:
        model: Callable mapping token ids ``[1, T]`` to logits ``[1, T, vocab]``.
        tokenizer: Object exposing ``vocab_size``.
        length: Number of tokens to sample after the seed token.
        temp: Sampling temperature; must be positive.
        top_k: Number of candidates kept per step; clamped to ``[1, vocab]``.

    Returns:
        1-D NumPy array of ``length + 1`` token ids (seed included).

    Raises:
        ValueError: If ``temp`` is not positive.
    """
    if temp <= 0:
        # Dividing by zero or a negative temperature yields inf/NaN or an
        # inverted distribution; fail early with a clear message instead.
        raise ValueError(f"temp must be positive, got {temp}")

    model.eval()
    # Seed with a random token from the vocab
    start_token = torch.randint(0, tokenizer.vocab_size, (1, 1)).to(config.DEVICE)
    generated = start_token

    print(f"Generating... (Target: {length} tokens)")

    with torch.no_grad():
        for _ in tqdm(range(length), desc="Gen"):
            logits = model(generated)
            # Focus on the last token
            last_logits = logits[:, -1, :] / temp

            # Top-K Sampling (Filters out bad notes). torch.topk fails when k
            # exceeds the vocabulary and k <= 0 leaves no threshold to index,
            # so keep k within [1, vocab].
            k = max(1, min(top_k, last_logits.size(-1)))
            top_values, _ = torch.topk(last_logits, k)
            last_logits[last_logits < top_values[:, [-1]]] = -float('Inf')
            probs = F.softmax(last_logits, dim=-1)

            # Sample
            next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat((generated, next_token), dim=1)

    return generated[0].cpu().numpy()

# 3. Main Loop
print("Starting Generation Loop...")

for i in range(GEN_CONF['num_examples']):
    # Reset per iteration so the error handler never reports tokens left over
    # from a previous example.
    token_list = None
    try:
        # A. Generate
        raw_tokens = generate_sequence(
            model,
            tokenizer,
            GEN_CONF['max_len'],
            GEN_CONF['temperature'],
            GEN_CONF['top_k']
        )

        # B. Convert to Python List
        token_list = raw_tokens.tolist()

        # C. Wrap in TokSequence
        seq = TokSequence(ids=token_list)

        # D. Decode and Save
        generated_score = tokenizer.decode([seq])

        # The wall-clock time in the name keeps repeated runs from overwriting files.
        timestamp = datetime.now().strftime("%H%M%S")
        filename = f"conformer_0.3_{i+1}_{timestamp}.mid"
        generated_score.dump_midi(filename)
        print(f"Saved: {filename}")

    except Exception as e:
        # Best-effort: report the failure and continue with the next example.
        print(f"Error on file {i+1}: {e}")
        if token_list is not None:
            print(f"First 10 tokens: {token_list[:10]}")

Starting Generation Loop...
Generating... (Target: 1024 tokens)


Gen:   0%|          | 0/1024 [00:00<?, ?it/s]

Saved: conformer_0.3_1_065600.mid
Generating... (Target: 1024 tokens)


Gen:   0%|          | 0/1024 [00:00<?, ?it/s]

Saved: conformer_0.3_2_065608.mid
Generating... (Target: 1024 tokens)


Gen:   0%|          | 0/1024 [00:00<?, ?it/s]

Saved: conformer_0.3_3_065616.mid
Generating... (Target: 1024 tokens)


Gen:   0%|          | 0/1024 [00:00<?, ?it/s]

Saved: conformer_0.3_4_065625.mid
Generating... (Target: 1024 tokens)


Gen:   0%|          | 0/1024 [00:00<?, ?it/s]

Saved: conformer_0.3_5_065633.mid
