In [None]:
!pip install music21



In [None]:
from google.colab import drive
import os
import sys


# Mount Google Drive (Colab only) at /content/drive so that the project files
# under /content/drive/MyDrive/... used by PROJECT_ROOT are reachable.
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
from pathlib import Path
from tqdm import tqdm
import sys
from dataclasses import dataclass





# Locations of project artefacts on the mounted Google Drive.
PROJECT_ROOT = Path("/content/drive/MyDrive/MusicScalingProject")
DATA_DIR = PROJECT_ROOT / "data/v3"
# Trained weights loaded by the __main__ block.
CHECKPOINT_PATH = PROJECT_ROOT / "data/checkpoints/BEST_Model.pt"
# BPE tokenizer file; the name suggests a 4096-token vocab (cf. BestConfig.vocab_size) — confirm.
TOKENIZER_PATH = DATA_DIR / "tokenizer_bpe_4096.json"
# Held-out token ids, read as a flat uint16 memmap by calculate_perplexity.
TEST_BIN_PATH = DATA_DIR / "test.bin"


@dataclass
class BestConfig:
    """Architecture hyper-parameters of the checkpoint being evaluated.

    The values must describe the same architecture the checkpoint was trained
    with; otherwise loading its state dict into ``GPT(BestConfig())`` will not
    line up with the saved tensors.
    """

    n_layer: int = 16       # number of transformer blocks
    n_head: int = 16        # attention heads per block
    n_embd: int = 1024      # embedding / residual-stream width
    block_size: int = 1024  # maximum context length (size of the position table)
    vocab_size: int = 4096  # number of token ids
    dropout: float = 0.1    # dropout probability (no effect once model.eval() is called)




class CausalSelfAttention(nn.Module):
    """Multi-head masked (causal) self-attention.

    A single fused linear layer produces queries, keys and values. Each
    position may attend only to itself and earlier positions, and the
    concatenated heads are projected back to ``n_embd``.

    The attribute names ``c_attn``, ``c_proj`` and the buffer name ``bias`` are
    kept as-is because they are the keys stored in saved state dicts.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        width = config.n_embd
        self.c_attn = nn.Linear(width, 3 * width, bias=False)
        self.c_proj = nn.Linear(width, width, bias=False)
        self.n_head = config.n_head
        self.n_embd = width
        # Lower-triangular 0/1 mask shaped (1, 1, block_size, block_size) so it
        # broadcasts over batch and heads; registered as a regular buffer, so it
        # is saved in / loaded from the state dict under the key "bias".
        size = config.block_size
        causal_mask = torch.tril(torch.ones(size, size)).view(1, 1, size, size)
        self.register_buffer("bias", causal_mask)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        batch, seq_len, width = x.size()
        head_dim = width // self.n_head

        def to_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        query, key, value = (to_heads(t) for t in self.c_attn(x).split(self.n_embd, dim=2))
        scale = 1.0 / math.sqrt(key.size(-1))
        scores = (query @ key.transpose(-2, -1)) * scale
        # Hide future positions: mask entries that are 0 become -inf before softmax.
        visible = self.bias[:, :, :seq_len, :seq_len]
        scores = scores.masked_fill(visible == 0, float('-inf'))
        weights = self.dropout(F.softmax(scores, dim=-1))
        # (B, n_head, T, head_dim) -> (B, T, C)
        context = (weights @ value).transpose(1, 2).contiguous()
        return self.c_proj(context.view(batch, seq_len, width))

class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear -> Dropout.

    The hidden layer is four times ``n_embd`` wide and neither projection has
    a bias term.
    """

    def __init__(self, config):
        super().__init__()
        hidden = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden, bias=False)
        self.c_proj = nn.Linear(hidden, config.n_embd, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        return self.dropout(self.c_proj(F.gelu(self.c_fc(x))))

class Block(nn.Module):
    """One pre-LayerNorm transformer layer with two residual branches.

    Computes ``h = x + attn(ln_1(x))`` and then returns ``h + mlp(ln_2(h))``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        # Submodules are created in the original order, so parameter
        # initialisation draws from the RNG in the same sequence.
        self.ln_1 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(width)
        self.mlp = MLP(config)

    def forward(self, x):
        hidden = x + self.attn(self.ln_1(x))
        return hidden + self.mlp(self.ln_2(hidden))

class GPT(nn.Module):
    """Decoder-only transformer language model (pre-LayerNorm, GPT-2 style).

    Token embeddings plus learned position embeddings pass through
    ``config.n_layer`` ``Block``s and a final LayerNorm, then a bias-free LM
    head whose weight matrix is shared with the token embedding.

    ``config`` must provide ``vocab_size``, ``block_size``, ``n_embd``,
    ``n_layer``, ``n_head`` and ``dropout`` (e.g. ``BestConfig``).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            wpe=nn.Embedding(config.block_size, config.n_embd),
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying: input embedding and output projection share one matrix.
        self.transformer.wte.weight = self.lm_head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Draw Linear/Embedding weights from N(0, 0.02) and zero any Linear bias."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Run the model on a batch of token ids.

        Args:
            idx: integer tensor of token ids, shape (batch, seq_len), with
                ``seq_len <= config.block_size``.
            targets: optional integer tensor of next-token ids, same shape as ``idx``.

        Returns:
            ``(logits, loss)``. ``logits`` has shape (batch, seq_len, vocab_size)
            for every position. ``loss`` is the mean cross-entropy over all
            positions, or ``None`` when ``targets`` is not given.
        """
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)
        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(pos)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Append ``max_new_tokens`` tokens to ``idx`` autoregressively.

        Args:
            idx: integer tensor (batch, seq_len) of prompt token ids.
            max_new_tokens: number of tokens to append.
            temperature: divisor applied to the last-position logits before
                sampling. ``0`` selects greedy (argmax) decoding; dividing by
                zero would otherwise produce NaN probabilities.
            top_k: if given, sample only among the ``top_k`` most likely tokens
                (ignored for greedy decoding).

        Returns:
            ``idx`` with the new tokens concatenated, shape
            (batch, seq_len + max_new_tokens).

        Dropout is not switched off here; call ``model.eval()`` first. The full
        context (cropped to ``block_size``) is re-run every step, since there is
        no KV cache.
        """
        block_size = self.config.block_size
        for _ in range(max_new_tokens):
            idx_cond = idx if idx.size(1) <= block_size else idx[:, -block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]
            if temperature == 0:
                idx_next = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                # Division creates a fresh tensor, so the in-place masking below is safe.
                logits = logits / temperature
                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('Inf')
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx




def calculate_perplexity(model, data_path, device, block_size, batch_size=4):
    """Compute token-level perplexity of ``model`` on a flat binary token file.

    The file is memory-mapped as a flat array of ``np.uint16`` token ids and cut
    into consecutive, non-overlapping windows of ``block_size`` tokens. Each
    window is scored on its own, with targets shifted one token to the right.
    Up to ``batch_size`` windows go through the model per forward pass. A tail
    too short to form a full window plus one target token is ignored.

    Args:
        model: module whose call ``model(x)`` returns ``(logits, _)`` with
            logits of shape (windows, block_size, vocab).
        data_path: path (``Path`` or ``str``) to the token file.
        device: device the input batches are moved to.
        block_size: window length in tokens.
        batch_size: windows per forward pass (default 4).

    Returns:
        ``exp(total NLL / scored tokens)``. Returns ``float('inf')`` if the file
        is missing or too short for a single window.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    data_path = Path(data_path)
    if not data_path.exists():
        print(f"ERROR: Test file not found at {data_path}")
        return float('inf')

    print(f"Loading test data from {data_path}...")
    data = np.memmap(data_path, dtype=np.uint16, mode='r')
    stride = batch_size * block_size
    total_nll = 0.0
    total_tokens = 0

    model.eval()
    with torch.no_grad():
        # Every step needs block_size inputs plus one extra token for the last target.
        starts = range(0, len(data) - block_size, stride)
        for i in tqdm(starts, desc="Calculating Perplexity"):
            n_windows = min(stride, len(data) - i - 1) // block_size
            if n_windows == 0:
                break
            n_tokens = n_windows * block_size
            chunk = torch.from_numpy(data[i:i + n_tokens + 1].astype(np.int64)).to(device)
            x = chunk[:-1].view(n_windows, block_size)
            y = chunk[1:].view(n_windows, block_size)

            logits, _ = model(x)
            # Sum rather than mean, so batches are weighted by their token count.
            nll = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1), reduction='sum')
            total_nll += nll.item()
            total_tokens += n_tokens

    return math.exp(total_nll / total_tokens) if total_tokens > 0 else float('inf')

def evaluate_validity(model, tokenizer_path, device, num_samples=500):
    """Sample tunes from ``model`` and measure how many are usable ABC notation.

    Generates ``num_samples`` sequences of 256 new tokens in a single batch.
    Every sequence starts from the tokenizer's ``<|endoftext|>`` id, or id 0 if
    that token is absent. When a decoded sample contains an ``"X:"`` header,
    the text is trimmed to start there and cut at the first blank line after
    it. Each sample is then:

    * counted as syntactically valid if ``music21`` parses it as ABC, and
    * counted as MIDI-convertible if the parsed score can also be written to MIDI.

    Args:
        model: model exposing ``generate(idx, max_new_tokens=...)``.
        tokenizer_path: path (``Path`` or ``str``) to a ``tokenizers`` JSON file.
        device: device for the start tokens.
        num_samples: number of sequences to generate.

    Returns:
        ``(syntax_pct, midi_pct)``: percentages in [0, 100]. Returns
        ``(0.0, 0.0)`` if the tokenizer file is missing or ``num_samples`` is
        not positive.
    """
    import contextlib

    tokenizer_path = Path(tokenizer_path)
    if not tokenizer_path.exists():
        print("ERROR: Tokenizer not found.")
        return 0.0, 0.0
    if num_samples <= 0:
        return 0.0, 0.0

    # Imported lazily so the early-exit paths above do not need these packages.
    from tokenizers import Tokenizer
    import music21

    tokenizer = Tokenizer.from_file(str(tokenizer_path))
    print(f"\nGenerating {num_samples} samples for validity check...")
    model.eval()

    start_id = tokenizer.token_to_id("<|endoftext|>")
    if start_id is None:
        start_id = 0

    start_tokens = torch.full((num_samples, 1), start_id, dtype=torch.long, device=device)
    generated = model.generate(start_tokens, max_new_tokens=256)
    valid_syntax = 0
    valid_midi = 0

    # One device->host transfer for the whole batch instead of one per row.
    for row in generated.tolist():
        text = tokenizer.decode(row)
        if "X:" in text:
            text = text[text.find("X:"):]
            if "\n\n" in text:
                text = text.split("\n\n")[0]

        # Best-effort scoring: any failure just marks the sample invalid, but
        # KeyboardInterrupt/SystemExit are no longer swallowed.
        try:
            score = music21.converter.parse(text, format='abc')
        except Exception:
            continue
        valid_syntax += 1

        try:
            midi_path = score.write('midi')
        except Exception:
            continue
        valid_midi += 1
        # With no ``fp``, ``write`` stores the MIDI in a temporary file and
        # returns its path; delete it so the run does not leave stray files.
        with contextlib.suppress(OSError, TypeError):
            Path(midi_path).unlink()

    return (valid_syntax / num_samples) * 100, (valid_midi / num_samples) * 100




if __name__ == "__main__":
    # Install runtime dependencies (the first notebook cell only installs music21).
    import subprocess
    subprocess.check_call([sys.executable, "-m", "pip", "install", "music21", "tokenizers"])

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    # Seed once so the sampled validity metrics are reproducible between runs.
    torch.manual_seed(1337)

    config = BestConfig()
    model = GPT(config)

    # Without the trained weights every metric below would describe a randomly
    # initialised model, so stop instead of printing misleading "final" numbers.
    if not CHECKPOINT_PATH.exists():
        raise FileNotFoundError(f"CRITICAL: Checkpoint not found at {CHECKPOINT_PATH}")

    print(f"Loading checkpoint: {CHECKPOINT_PATH}")
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
    # Accept either a dict holding the weights under 'model' or a bare state dict.
    state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint

    # Strip the '_orig_mod.' key prefix (e.g. from saving a torch.compile-wrapped model).
    unwanted_prefix = '_orig_mod.'
    for k in list(state_dict.keys()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

    model.load_state_dict(state_dict)
    model.to(device)

    syn, midi = evaluate_validity(model, TOKENIZER_PATH, device)
    ppx = calculate_perplexity(model, TEST_BIN_PATH, device, config.block_size)

    print("\n" + "="*40)
    print("FINAL QUANTITATIVE METRICS")
    print("="*40)
    print(f"Test Perplexity:      {ppx:.4f}")
    print(f"Syntactic Validity:   {syn:.2f}%")
    print(f"MIDI Conversion:      {midi:.2f}%")
    print("="*40)

Using device: cuda
Loading checkpoint: /content/drive/MyDrive/MusicScalingProject/data/checkpoints/BEST_Model.pt

Generating 500 samples for validity check...




Loading test data from /content/drive/MyDrive/MusicScalingProject/data/v3/test.bin...


Calculating Perplexity: 100%|██████████| 336/336 [00:48<00:00,  6.90it/s]


FINAL QUANTITATIVE METRICS
Test Perplexity:      1.9336
Syntactic Validity:   68.00%
MIDI Conversion:      58.80%



