In [1]:
!pip install -q miditok symusic tqdm

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast
import os
import random
import math
import glob
import json
import numpy as np
from pathlib import Path
from miditok import REMI, TokenizerConfig
from tqdm.notebook import tqdm
import csv
from symusic import Score, Track, Note
import shutil

In [3]:
class Config:
    DATA_DIR = "BACH_MIDI"
    
    STRIDE = 128           
    SEQ_LENGTH = 1024  
    
    EMBED_DIM = 256
    N_HEADS = 8
    N_LAYERS = 6
    CNN_KERNEL = 31        
    DROPOUT = 0.3          
    
    BATCH_SIZE = 16
    ACCUMULATION_STEPS = 1 
    LEARNING_RATE = 5e-4
    EPOCHS = 50
    PATIENCE = 5
    
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

config = Config()

In [4]:
def load_bach_remi(path):
    path = Path(path)
    train_files = list((path / "train").glob("*.mid"))
    val_files = list((path / "valid").glob("*.mid"))
    
    print(f"ðŸ“‚ Found {len(train_files)} Train | {len(val_files)} Valid files")

    tokenizer_config = TokenizerConfig(
        num_velocities=1,      
        use_chords=True,      
        use_tempos=False,    
        use_velocities=False,
        use_durations=False,
        use_programs=False,
        beat_res={(0, 4): 4}   
    )
    tokenizer = REMI(tokenizer_config)
    
    # B. Tokenize
    def tokenize_batch(file_list, desc):
        tokens_list = []
        for f in tqdm(file_list, desc=desc):
            try:
                midi = Score(f)
                toks = tokenizer(midi)
                if len(toks) > 0: tokens_list.extend(toks[0].ids)
            except: pass
        return torch.tensor(tokens_list, dtype=torch.long)

    train_tensor = tokenize_batch(train_files, "Tokenizing Train")
    val_tensor = tokenize_batch(val_files, "Tokenizing Valid")
    
    return train_tensor, val_tensor, tokenizer

train_tensor, val_tensor, tokenizer = load_bach_remi(config.DATA_DIR)
vocab_size = tokenizer.vocab_size
print(f"Vocab Size: {vocab_size}")

ðŸ“‚ Found 229 Train | 76 Valid files


Tokenizing Train:   0%|          | 0/229 [00:00<?, ?it/s]

Tokenizing Valid:   0%|          | 0/76 [00:00<?, ?it/s]

Vocab Size: 202


In [5]:
class MidiDataset(Dataset):
    def __init__(self, data_tensor, seq_len, stride):
        self.seq_len = seq_len
        self.data = data_tensor
        self.indices = range(0, len(data_tensor) - seq_len - 1, stride)
    def __len__(self): return len(self.indices)
    def __getitem__(self, idx):
        start_idx = self.indices[idx]
        return (self.data[start_idx : start_idx + self.seq_len], 
                self.data[start_idx+1 : start_idx + self.seq_len + 1])

train_ds = MidiDataset(train_tensor, config.SEQ_LENGTH, config.STRIDE)
val_ds = MidiDataset(val_tensor, config.SEQ_LENGTH, config.STRIDE)

train_loader = DataLoader(train_ds, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_ds, batch_size=config.BATCH_SIZE, shuffle=False, num_workers=2)

In [6]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div)
        pe[:, 1::2] = torch.cos(position * div)
        self.register_buffer('pe', pe.unsqueeze(0))
    def forward(self, x): return x + self.pe[:, :x.size(1), :]

class CausalMacaronBlock(nn.Module):
    def __init__(self, d_model, n_heads, cnn_k, dropout):
        super().__init__()
        self.cnn_k = cnn_k
        
        # FF 1
        self.ff1_norm = nn.LayerNorm(d_model)
        self.ff1 = nn.Sequential(
            nn.Linear(d_model, d_model * 4), nn.SiLU(), 
            nn.Linear(d_model * 4, d_model), nn.Dropout(dropout)
        )
        
        # Attention
        self.attn_norm = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        
        # Conv
        self.conv_norm = nn.LayerNorm(d_model)
        self.conv_pointwise1 = nn.Conv1d(d_model, d_model * 2, 1)
        self.glu = nn.GLU(dim=1)
        self.conv_depthwise = nn.Conv1d(d_model, d_model, cnn_k, padding=0, groups=d_model)
        self.conv_batchnorm = nn.BatchNorm1d(d_model)
        self.conv_act = nn.SiLU()
        self.conv_pointwise2 = nn.Conv1d(d_model, d_model, 1)
        self.conv_dropout = nn.Dropout(dropout)
        
        # FF 2
        self.ff2_norm = nn.LayerNorm(d_model)
        self.ff2 = nn.Sequential(
            nn.Linear(d_model, d_model * 4), nn.SiLU(),
            nn.Linear(d_model * 4, d_model), nn.Dropout(dropout)
        )
        self.final_norm = nn.LayerNorm(d_model)

    def forward(self, x, mask):
        # Macaron Sandwich: 1/2 FF -> Attn -> Conv -> 1/2 FF
        x = x + 0.5 * self.ff1(self.ff1_norm(x))
        
        res = x
        x_norm = self.attn_norm(x)
        attn_out, _ = self.attn(x_norm, x_norm, x_norm, attn_mask=mask, is_causal=True)
        x = res + attn_out
        
        res = x
        x_cnn = self.conv_norm(x).permute(0, 2, 1) 
        x_cnn = self.conv_pointwise1(x_cnn)
        x_cnn = self.glu(x_cnn)
        x_cnn = F.pad(x_cnn, (self.cnn_k - 1, 0)) # Causal Padding
        x_cnn = self.conv_depthwise(x_cnn)
        x_cnn = self.conv_batchnorm(x_cnn)
        x_cnn = self.conv_act(x_cnn)
        x_cnn = self.conv_pointwise2(x_cnn)
        x_cnn = self.conv_dropout(x_cnn)
        x = res + x_cnn.permute(0, 2, 1)
        
        x = x + 0.5 * self.ff2(self.ff2_norm(x))
        return self.final_norm(x)

class MusicConformer(nn.Module):
    def __init__(self, vocab_size, cfg):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, cfg.EMBED_DIM)
        self.pos = PositionalEncoding(cfg.EMBED_DIM)
        self.layers = nn.ModuleList([
            CausalMacaronBlock(cfg.EMBED_DIM, cfg.N_HEADS, cfg.CNN_KERNEL, cfg.DROPOUT) 
            for _ in range(cfg.N_LAYERS)
        ])
        self.head = nn.Linear(cfg.EMBED_DIM, vocab_size)

    def forward(self, x):
        mask = torch.triu(torch.ones(x.size(1), x.size(1)) * float('-inf'), diagonal=1).to(x.device)
        x = self.pos(self.embed(x))
        for layer in self.layers: 
            x = layer(x, mask)
        return self.head(x)

model = MusicConformer(vocab_size, config).to(config.DEVICE)
optimizer = optim.AdamW(model.parameters(), lr=config.LEARNING_RATE, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()
scaler = GradScaler() 

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=config.LEARNING_RATE,
    steps_per_epoch=len(train_loader), epochs=config.EPOCHS, pct_start=0.1
)

  scaler = GradScaler()


In [7]:
def train_epoch(model, loader):
    model.train()
    total_loss = 0
    total_correct = 0
    total_samples = 0
    
    loop = tqdm(loader, desc="Train", leave=False)
    
    for i, (x, y) in enumerate(loop):
        x, y = x.to(config.DEVICE), y.to(config.DEVICE)
        
        with autocast():
            logits = model(x)
            loss = criterion(logits.view(-1, vocab_size), y.view(-1))
        
        scaler.scale(loss).backward()
        
        # Accuracy Calculation
        predictions = logits.argmax(dim=-1) 
        total_correct += (predictions == y).sum().item()
        total_samples += y.numel()
        
        if (i + 1) % config.ACCUMULATION_STEPS == 0:
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)
            
        total_loss += loss.item()
        loop.set_postfix(loss=f"{loss.item():.4f}")
        
    return total_loss / len(loader), total_correct / total_samples

def validate(model, loader):
    model.eval()
    total_loss = 0
    total_correct = 0
    total_samples = 0
    
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(config.DEVICE), y.to(config.DEVICE)
            with autocast():
                logits = model(x)
                loss = criterion(logits.view(-1, vocab_size), y.view(-1))
            
            # Accuracy Calculation
            predictions = logits.argmax(dim=-1)
            total_correct += (predictions == y).sum().item()
            total_samples += y.numel()
            total_loss += loss.item()
            
    return total_loss / len(loader), total_correct / total_samples

In [8]:
best_loss = float('inf')
patience = 0

for epoch in range(config.EPOCHS):
    t_loss, t_acc = train_epoch(model, train_loader)
    v_loss, v_acc = validate(model, val_loader)
    
    try: ppl = math.exp(v_loss)
    except: ppl = float('inf')
    
    print(
        f"Epoch {epoch+1}/{config.EPOCHS} | "
        f"Train: {t_loss:.4f} ({t_acc:.1%}) | "
        f"Val: {v_loss:.4f} ({v_acc:.1%}) | "
        f"PPL: {ppl:.2f}"
    )
    
    if v_loss < best_loss:
        best_loss = v_loss
        patience = 0
        torch.save(model.state_dict(), "best_bach_conformer.pth")
        print("Saved New Best Model")
    else:
        patience += 1
        if patience >= config.PATIENCE:
            print("Early Stopping")
            break

Train:   0%|          | 0/244 [00:00<?, ?it/s]

  with autocast():
  with autocast():


Epoch 1/50 | Train: 2.2366 (56.1%) | Val: 0.9625 (78.4%) | PPL: 2.62
Saved New Best Model


Train:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch 2/50 | Train: 0.5802 (86.7%) | Val: 0.3254 (90.8%) | PPL: 1.38
Saved New Best Model


Train:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch 3/50 | Train: 0.2890 (91.8%) | Val: 0.2566 (92.5%) | PPL: 1.29
Saved New Best Model


Train:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch 4/50 | Train: 0.2321 (93.0%) | Val: 0.2421 (92.8%) | PPL: 1.27
Saved New Best Model


Train:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch 5/50 | Train: 0.1978 (93.8%) | Val: 0.2372 (93.0%) | PPL: 1.27
Saved New Best Model


Train:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch 6/50 | Train: 0.1675 (94.7%) | Val: 0.2431 (93.1%) | PPL: 1.28


Train:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch 7/50 | Train: 0.1390 (95.5%) | Val: 0.2528 (93.0%) | PPL: 1.29


Train:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch 8/50 | Train: 0.1136 (96.3%) | Val: 0.2760 (92.9%) | PPL: 1.32


Train:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch 9/50 | Train: 0.0917 (97.0%) | Val: 0.3018 (92.9%) | PPL: 1.35


Train:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch 10/50 | Train: 0.0741 (97.5%) | Val: 0.3256 (92.8%) | PPL: 1.38


Train:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch 11/50 | Train: 0.0607 (98.0%) | Val: 0.3526 (92.8%) | PPL: 1.42


Train:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch 12/50 | Train: 0.0514 (98.3%) | Val: 0.3769 (92.7%) | PPL: 1.46


Train:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch 13/50 | Train: 0.0444 (98.5%) | Val: 0.3906 (92.7%) | PPL: 1.48


Train:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch 14/50 | Train: 0.0391 (98.7%) | Val: 0.4138 (92.7%) | PPL: 1.51


Train:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch 15/50 | Train: 0.0353 (98.8%) | Val: 0.4188 (92.8%) | PPL: 1.52
Early Stopping


In [12]:
from miditok import TokSequence
import torch.nn.functional as F
from datetime import datetime
import numpy as np
from tqdm.notebook import tqdm

GEN_CONF = {
    'num_examples': 2,
    'max_len': 1024,        
    'temperature': 0.9,   
    'top_k': 10            
}

def generate_sequence(model, tokenizer, length, temp=1.0, top_k=20):
    model.eval()
    
    # Seed with a random token
    start_token = torch.randint(0, tokenizer.vocab_size, (1, 1)).to(config.DEVICE)
    generated = start_token
    
    print(f"Generating Chorale... (Target: {length} tokens)")
    
    with torch.no_grad():
        for _ in tqdm(range(length), desc="Gen"):
            logits = model(generated)
            last_logits = logits[:, -1, :] / temp
            
            # Top-K Sampling
            v, _ = torch.topk(last_logits, top_k)
            last_logits[last_logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(last_logits, dim=-1)
            
            # Sample
            next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat((generated, next_token), dim=1)
            
    return generated[0].cpu().numpy()

for i in range(GEN_CONF['num_examples']):
    try:
        raw_tokens = generate_sequence(
            model, 
            tokenizer, 
            GEN_CONF['max_len'], 
            GEN_CONF['temperature'], 
            GEN_CONF['top_k']
        )
        
        # Decode
        token_list = raw_tokens.tolist()
        seq = TokSequence(ids=token_list)
        generated_score = tokenizer.decode([seq])
        
        timestamp = datetime.now().strftime("%H%M%S")
        filename = f"bach_conformer_1_{i+1}_{timestamp}.mid"
        generated_score.dump_midi(filename)
        print(f"Saved: {filename}")
        
    except Exception as e:
        print(f"Error: {e}")

Generating Chorale... (Target: 1024 tokens)


Gen:   0%|          | 0/1024 [00:00<?, ?it/s]

Saved: bach_conformer_1_1_162832.mid
Generating Chorale... (Target: 1024 tokens)


Gen:   0%|          | 0/1024 [00:00<?, ?it/s]

Saved: bach_conformer_1_2_163021.mid
