In [None]:
#import statements 
import glob
import os
import random
from typing import List
from collections import defaultdict

import numpy as np
from numpy.random import choice

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from miditok.pytorch_data import DatasetMIDI, DataCollator

from symusic import Score
from miditok import REMI, TokenizerConfig
from midi2audio import FluidSynth # Import library
from IPython.display import Audio, display

Task 1 <br>
This assignment focuses on symbolic music modeling. The goal is to train a model that learns a distribution $p(x)$ over symbolic music data (e.g., MIDI), specifically within the classical genre. In addition, the model is capable of sampling new sequences from this learned distribution unconditionally. We will be using an LSTM model for this task. <br>

In [None]:

# glob.glob returns paths in arbitrary, filesystem-dependent order. Sorting keeps
# tokenizer training and dataset indexing reproducible across runs and machines.
train_files = sorted(glob.glob("./train/*.midi"))
test_files = sorted(glob.glob("./test/*.midi"))

LSTM Model<br>

In [None]:
class MusicRNN(torch.nn.Module):
    """Next-token model: token + position embeddings feeding a stack of LSTM layers.

    Each LSTM layer is followed by LayerNorm and dropout; layers after the first add a
    residual connection when the input and output widths match. A two-layer MLP head
    maps the final features to vocabulary logits.

    Args:
        vocab_size: Number of token ids (embedding rows and output logits).
        embedding_dim: Width of the token and position embeddings.
        hidden_dim: Hidden size of every LSTM layer (per direction).
        num_layers: Number of stacked single-layer LSTMs.
        dropout: Dropout probability applied after the embeddings, after every LSTM
            layer and inside the output head.
        bidirectional: Whether each LSTM runs in both directions. When True, every
            layer emits ``2 * hidden_dim`` features.
        max_position_embeddings: Number of rows in the position embedding table;
            position ids passed to ``forward`` must be smaller than this.
    """

    def __init__(self, vocab_size, embedding_dim=768, hidden_dim=1024, num_layers=4,
                 dropout=0.3, bidirectional=False, max_position_embeddings=1024):
        super(MusicRNN, self).__init__()

        # Feature width produced by every LSTM layer. A bidirectional LSTM concatenates
        # both directions, so LayerNorm and the next layer's input must use this width
        # rather than hidden_dim (otherwise bidirectional=True fails with a shape error).
        rnn_output_dim = hidden_dim * 2 if bidirectional else hidden_dim

        # Larger embeddings
        self.token_embedding = torch.nn.Embedding(vocab_size, embedding_dim)
        self.position_embedding = torch.nn.Embedding(max_position_embeddings, embedding_dim)

        # Deeper LSTM with residual connections: one single-layer LSTM per entry so
        # that normalization/residuals can be applied between layers in forward().
        self.rnn_layers = torch.nn.ModuleList([
            torch.nn.LSTM(
                input_size=embedding_dim if i == 0 else rnn_output_dim,
                hidden_size=hidden_dim,
                num_layers=1,
                dropout=0,
                batch_first=True,
                bidirectional=bidirectional
            ) for i in range(num_layers)
        ])

        # Enhanced output processing
        self.layer_norms = torch.nn.ModuleList([
            torch.nn.LayerNorm(rnn_output_dim) for _ in range(num_layers)
        ])
        self.dropout = torch.nn.Dropout(dropout)

        # Multi-layer output head
        self.output_projection = torch.nn.Sequential(
            torch.nn.Linear(rnn_output_dim, rnn_output_dim // 2),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(rnn_output_dim // 2, vocab_size)
        )

        # Initialize weights properly
        self._init_weights()

    def _init_weights(self):
        """Re-initialize Linear (Xavier uniform, zero bias) and Embedding (N(0, 0.02)) weights.

        LSTM and LayerNorm parameters keep their PyTorch defaults.
        """
        for module in self.modules():
            if isinstance(module, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)
            elif isinstance(module, torch.nn.Embedding):
                torch.nn.init.normal_(module.weight, mean=0, std=0.02)

    def forward(self, x, position_ids, hidden_states=None):
        """Run the model on a batch of token ids.

        Args:
            x: Integer token ids, batch-first (the LSTMs use ``batch_first=True``).
            position_ids: Integer position ids whose embeddings are added to the token
                embeddings; must be broadcastable to the shape of ``x``.
            hidden_states: Optional list with one LSTM state per layer, as returned by
                a previous call. ``None`` (or an empty list) starts from zero states.

        Returns:
            Tuple ``(logits, new_hidden_states)`` where ``logits`` has ``vocab_size`` as
            its last dimension and ``new_hidden_states`` is a list of per-layer LSTM
            states.
        """
        # Token and position embeddings
        tok_emb = self.token_embedding(x)
        pos_emb = self.position_embedding(position_ids)
        x = tok_emb + pos_emb
        x = self.dropout(x)

        # Pass through LSTM layers with residual connections
        new_hidden_states = []
        for i, (rnn_layer, layer_norm) in enumerate(zip(self.rnn_layers, self.layer_norms)):
            hidden = hidden_states[i] if hidden_states else None
            out, new_hidden = rnn_layer(x, hidden)
            out = layer_norm(out)

            # Residual connection (when dimensions match); the first layer is skipped
            # because its input is the embedding, not a previous LSTM output.
            if i > 0 and out.size(-1) == x.size(-1):
                out = out + x

            x = self.dropout(out)
            new_hidden_states.append(new_hidden)

        # Output projection
        output = self.output_projection(x)
        return output, new_hidden_states

In [None]:
class EarlyStopping:
    """Stop training when the validation loss has stopped improving.

    Call the instance once per epoch with the current validation loss and the model.
    It returns True when ``patience`` consecutive calls have failed to improve on the
    best loss by more than ``min_delta``; in that case the best weights seen so far
    are loaded back into the model if ``restore_best_weights`` is set.
    """

    def __init__(self, patience=7, min_delta=0.001, restore_best_weights=True):
        self.patience = patience  # Number of epochs to wait for improvement before stopping
        self.min_delta = min_delta  # Min change in validation loss to qualify as an improvement
        self.restore_best_weights = restore_best_weights

        # Internal variables for tracking state
        self.best_loss = None
        self.counter = 0
        self.best_weights = None

    def __call__(self, val_loss, model):
        """Record ``val_loss`` and report whether training should stop.

        Args:
            val_loss: Validation loss for the epoch that just finished.
            model: Model whose weights are snapshotted on improvement and restored
                when stopping is triggered.

        Returns:
            True if training should stop, otherwise False.
        """
        if self.best_loss is None:
            # First observation always becomes the baseline.
            self.best_loss = val_loss
            self.save_checkpoint(model)
        elif val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            self.save_checkpoint(model)
        else:
            self.counter += 1

        # If no improvement then early stopping will be triggered
        if self.counter >= self.patience:
            if self.restore_best_weights:
                model.load_state_dict(self.best_weights)
            return True
        return False

    def save_checkpoint(self, model):
        """Snapshot the model's current weights as the best seen so far.

        ``state_dict()`` returns tensors that share storage with the live parameters,
        and ``dict.copy()`` is shallow, so every tensor is cloned here. Without the
        clone the snapshot would silently track later optimizer updates and restoring
        it would be a no-op.
        """
        self.best_weights = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }

Check whether a GPU is available; otherwise the CPU is used instead.

In [None]:
def check_gpu_setup():
    """Print a short CUDA report and return the device to run on.

    Returns:
        ``torch.device('cuda')`` when CUDA is available, otherwise
        ``torch.device('cpu')``.
    """
    cuda_ok = torch.cuda.is_available()

    print("=== GPU Setup Check ===")
    print(f"CUDA Available: {cuda_ok}")

    if not cuda_ok:
        # No GPU visible to PyTorch: fall back to the CPU.
        print("CUDA not available!")
        chosen = torch.device('cpu')
        print("Using CPU")
    else:
        # Describe the GPU environment before selecting it.
        print(f"CUDA Device Count: {torch.cuda.device_count()}")
        print(f"Current CUDA Device: {torch.cuda.current_device()}")
        print(f"CUDA Device Name: {torch.cuda.get_device_name(0)}")
        chosen = torch.device('cuda')
        print("Using GPU")

    print("=" * 30)
    return chosen

Training<br>

In [None]:
def _next_token_loss(model, criterion, batch, vocab_size, device):
    """Compute the next-token loss for one collated batch.

    The batch's ``'input_ids'`` tensor is shifted by one step: positions ``[:-1]`` are
    the inputs and positions ``[1:]`` are the targets. Position ids are the index of
    each input token within the sequence.

    Returns:
        The scalar loss tensor, or None when the batch has fewer than two time steps
        (nothing to predict).
    """
    tokens = batch['input_ids'].to(device)
    if tokens.size(1) < 2:
        return None

    input_ids = tokens[:, :-1]
    target_ids = tokens[:, 1:]
    position_ids = torch.arange(input_ids.size(1), device=device).unsqueeze(0).expand_as(input_ids)

    outputs, _ = model(input_ids, position_ids)
    return criterion(outputs.reshape(-1, vocab_size), target_ids.reshape(-1))


def train_model(model, train_loader, val_loader, vocab_size, num_epochs=50,
                initial_lr=1e-3, device='cpu', save_path='best_model.pth',
                pad_token_id=0):
    """Train ``model`` with AdamW, plateau LR scheduling and early stopping.

    Args:
        model: Module called as ``model(input_ids, position_ids)`` and returning
            ``(logits, hidden_states)``.
        train_loader: Iterable of batches; each batch is a mapping with an
            ``'input_ids'`` tensor.
        val_loader: Same format as ``train_loader``; used for the validation loss that
            drives the scheduler and early stopping.
        vocab_size: Size of the last logits dimension (used to flatten the logits).
        num_epochs: Maximum number of epochs.
        initial_lr: Initial AdamW learning rate.
        device: Device the model and batches are moved to.
        save_path: Where the final model state and metric histories are written.
        pad_token_id: Target id excluded from the loss. Defaults to 0, the value that
            was previously hard-coded.

    Returns:
        Tuple ``(train_losses, val_losses, learning_rates)`` with one entry per
        completed epoch.

    Side effects:
        Writes ``checkpoint_epoch_<N>.pth`` every 10 epochs and ``save_path`` at the
        end, and prints progress.
    """
    model = model.to(device)
    criterion = nn.CrossEntropyLoss(ignore_index=pad_token_id)  # Ignore padding tokens
    optimizer = optim.AdamW(model.parameters(), lr=initial_lr, weight_decay=0.01)

    # Halve the learning rate when the validation loss plateaus
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-6
    )

    # Early stopping to prevent overfitting
    early_stopping = EarlyStopping(patience=10, min_delta=0.001)

    train_losses = []
    val_losses = []
    learning_rates = []

    print(f"Starting training with {len(train_loader)} batches per epoch")
    print(f"Model has {sum(p.numel() for p in model.parameters())} parameters")
    print(f"Training on: {device}")

    for epoch in range(num_epochs):
        # ---- Training phase ----
        model.train()
        total_train_loss = 0
        num_batches = 0

        for batch_idx, batch in enumerate(train_loader):
            optimizer.zero_grad()
            loss = _next_token_loss(model, criterion, batch, vocab_size, device)
            if loss is None:
                continue

            loss.backward()
            # Clip the global gradient norm to keep LSTM updates stable.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_train_loss += loss.item()
            num_batches += 1

            if batch_idx % 50 == 0:
                print(f"Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}")

        avg_train_loss = total_train_loss / num_batches if num_batches > 0 else 0

        # ---- Validation phase ----
        model.eval()
        total_val_loss = 0
        val_batches = 0

        with torch.no_grad():
            for batch in val_loader:
                loss = _next_token_loss(model, criterion, batch, vocab_size, device)
                if loss is None:
                    continue
                total_val_loss += loss.item()
                val_batches += 1

        avg_val_loss = total_val_loss / val_batches if val_batches > 0 else 0
        current_lr = optimizer.param_groups[0]['lr']

        # Record metrics
        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)
        learning_rates.append(current_lr)

        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | LR: {current_lr:.2e}")

        # GPU memory info
        if torch.cuda.is_available():
            print(f"GPU Memory: {torch.cuda.memory_allocated()/1e9:.2f}GB allocated")
        print("-" * 60)

        # Learning rate scheduling
        scheduler.step(avg_val_loss)

        # Early stopping check
        if early_stopping(avg_val_loss, model):
            print(f"Early stopping triggered at epoch {epoch+1}")
            break

        # Checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': avg_train_loss,
                'val_loss': avg_val_loss,
                'vocab_size': vocab_size
            }, f'checkpoint_epoch_{epoch+1}.pth')

    # Save final model together with the metric histories
    torch.save({
        'model_state_dict': model.state_dict(),
        'vocab_size': vocab_size,
        'train_losses': train_losses,
        'val_losses': val_losses,
        'learning_rates': learning_rates
    }, save_path)

    return train_losses, val_losses, learning_rates

In [None]:
def setup_training():
    """Build the tokenizer, datasets, loaders and model, then run training.

    Reads the module-level ``train_files`` and ``test_files`` lists, trains a REMI
    tokenizer on the training files, saves it to ``tokenizer.json`` and trains a
    ``MusicRNN`` via ``train_model``.

    NOTE(review): the test split is passed as the validation loader, so early stopping
    and LR scheduling are driven by the test data.

    Returns:
        Tuple ``(model, tokenizer, train_losses, val_losses, learning_rates)``.
    """
    # Pick the compute device first; the batch size depends on it.
    device = check_gpu_setup()

    # Tokenizer: REMI configured with chords, rests, tempos and time signatures.
    tokenizer = REMI(TokenizerConfig(
        num_velocities=32,
        use_chords=True,
        use_programs=False,
        use_time_signatures=True,
        use_rests=True,
        use_tempos=True,
    ))
    tokenizer.train(vocab_size=1500, files_paths=train_files)
    tokenizer.save("tokenizer.json")
    print(f"Tokenizer vocabulary size: {tokenizer.vocab_size}")

    # Both splits share every dataset option except the list of files.
    def build_dataset(paths):
        return DatasetMIDI(
            files_paths=paths,
            tokenizer=tokenizer,
            max_seq_len=1024,
            bos_token_id=tokenizer["BOS_None"],
            eos_token_id=tokenizer["EOS_None"],
        )

    train_dataset = build_dataset(train_files)
    test_dataset = build_dataset(test_files)

    pad_collator = DataCollator(tokenizer.pad_token_id)
    # Smaller batch size on CPU; 16 on GPU.
    batch_size = 8 if device.type != 'cuda' else 16
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, collate_fn=pad_collator
    )
    test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=pad_collator)

    model = MusicRNN(
        vocab_size=tokenizer.vocab_size,
        embedding_dim=768,
        hidden_dim=1024,
        num_layers=4,
        dropout=0.3,
        max_position_embeddings=1024,
    )
    param_count = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {param_count:,}")

    history = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=test_loader,
        vocab_size=tokenizer.vocab_size,
        num_epochs=50,
        initial_lr=1e-3,
        device=device,
        save_path='improved_music_model.pth',
    )
    train_losses, val_losses, learning_rates = history

    return model, tokenizer, train_losses, val_losses, learning_rates

# Run the improved training when this cell/script is executed directly. The results
# are bound to module-level names (model, tokenizer and the three metric histories).
if __name__ == "__main__":
    model, tokenizer, train_losses, val_losses, learning_rates = setup_training()

In [None]:
def load_trained_model_fixed():
    """Load the trained ``MusicRNN`` and its tokenizer from the working directory.

    Reads ``improved_music_model.pth``. The tokenizer is loaded from the current
    directory when ``tokenizer.json`` exists; if that fails or the file is missing, a
    new tokenizer is trained on ``./train/*.midi`` and saved.

    Returns:
        ``(model, tokenizer)`` on success, or ``(None, None)`` when the checkpoint or
        training files are missing or any step raises.
    """
    failure = (None, None)
    try:
        model_path = 'improved_music_model.pth'
        print(f"Loading model from {model_path}")

        if not os.path.exists(model_path):
            print(f" Model file {model_path} not found!")
            return failure

        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        checkpoint = torch.load(model_path, map_location='cpu')

        # First choice: reuse the tokenizer saved next to the notebook.
        tokenizer = None
        if os.path.exists("tokenizer.json"):
            print("Loading existing tokenizer...")
            try:
                tokenizer = REMI.from_pretrained(".")  # Load from current directory
            except Exception as e:
                print(f"Failed to load existing tokenizer: {e}")
            else:
                print(" Loaded existing tokenizer")

        # Fallback: train a fresh tokenizer with the training-time configuration.
        if tokenizer is None:
            print("Creating new tokenizer...")
            midi_paths = glob.glob("./train/*.midi")
            if len(midi_paths) == 0:
                print("No training files found in ./train/*.midi")
                return failure

            tokenizer = REMI(TokenizerConfig(
                num_velocities=32,
                use_chords=True,
                use_programs=False,
                use_time_signatures=True,
                use_rests=True,
                use_tempos=True,
            ))
            tokenizer.train(vocab_size=1500, files_paths=midi_paths)
            tokenizer.save("tokenizer.json")
            print(" Created and saved new tokenizer")

        # Rebuild the architecture used for training, preferring the vocabulary size
        # stored in the checkpoint over the tokenizer's.
        model = MusicRNN(
            vocab_size=checkpoint.get('vocab_size', tokenizer.vocab_size),
            embedding_dim=768,
            hidden_dim=1024,
            num_layers=4,
            dropout=0.3,
            max_position_embeddings=1024,
        )

        # Load the trained weights
        model.load_state_dict(checkpoint['model_state_dict'])
        print(" Successfully loaded your trained model!")

        # Show training info if available
        if 'train_losses' in checkpoint:
            recorded_train = checkpoint['train_losses']
            print(f"Model was trained for {len(recorded_train)} epochs")
            print(f"Final training loss: {recorded_train[-1]:.4f}")
        if 'val_losses' in checkpoint:
            print(f"Final validation loss: {checkpoint['val_losses'][-1]:.4f}")

        return model, tokenizer

    except Exception as e:
        print(f"Error loading model: {e}")
        return failure
    
def plot_training_progress(train_losses, val_losses, learning_rates):
    """Plot the loss curves and the learning-rate schedule side by side.

    matplotlib is imported lazily so the rest of the notebook works without it; when
    it is missing, an install hint is printed and nothing is plotted.

    Args:
        train_losses: Per-epoch training losses.
        val_losses: Per-epoch validation losses.
        learning_rates: Per-epoch learning rates (drawn on a log-scaled axis).

    Returns:
        None.
    """
    # The import must come before the first use of ``plt``: importing it later in the
    # function body would make ``plt`` a local name and fail with UnboundLocalError.
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("Install matplotlib to see training plots: pip install matplotlib")
        return

    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    # Loss plot
    axes[0].plot(train_losses, label='Training Loss', color='blue')
    axes[0].plot(val_losses, label='Validation Loss', color='red')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('Training Progress')
    axes[0].legend()
    axes[0].grid(True)

    # Learning rate plot
    axes[1].plot(learning_rates, color='green')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Learning Rate')
    axes[1].set_title('Learning Rate Schedule')
    axes[1].set_yscale('log')
    axes[1].grid(True)

    plt.tight_layout()
    plt.show()


# Plot the results of the training run, if one has been executed in this session.
# This call used to sit inside the function body, where it recursed unconditionally.
if all(name in globals() for name in ("train_losses", "val_losses", "learning_rates")):
    plot_training_progress(train_losses, val_losses, learning_rates)

Functions to improve the musicality of the generated MIDI file

In [None]:
def get_token_string(tokenizer, token_id):
    """Return the token name for ``token_id``, or None when it cannot be resolved.

    Only the first vocabulary attribute that exists on the tokenizer is consulted, in
    this order: ``_vocab_base`` (name -> id mapping, reverse lookup), ``vocab``
    (indexed by id), ``_vocab`` (name -> id mapping, reverse lookup).

    Lookup errors are treated as "unknown token" and yield None. Only ``Exception`` is
    caught, so KeyboardInterrupt/SystemExit still propagate and a long sampling loop
    can be interrupted.
    """
    try:
        if hasattr(tokenizer, '_vocab_base'):
            for name, idx in tokenizer._vocab_base.items():
                if idx == token_id:
                    return name
        elif hasattr(tokenizer, 'vocab'):
            if token_id < len(tokenizer.vocab):
                return tokenizer.vocab[token_id]
        elif hasattr(tokenizer, '_vocab'):
            for name, idx in tokenizer._vocab.items():
                if idx == token_id:
                    return name
        return None
    except Exception:
        return None

In [None]:
def get_musical_context(tokenizer, recent_tokens, max_lookback=20):
    """Summarize the musical state implied by the most recent tokens.

    Scans up to ``max_lookback`` trailing tokens in order, so later tokens overwrite
    earlier ones.

    Args:
        tokenizer: Tokenizer passed through to ``get_token_string``.
        recent_tokens: Sequence of token ids generated so far.
        max_lookback: Number of trailing tokens to inspect.

    Returns:
        Dict with keys ``current_tempo`` (number or None), ``current_chord`` (token
        string or None), ``recent_pitches`` (at most the last 8 pitches seen),
        ``current_velocity`` (int or None), ``bar_position`` (int, default 0) and
        ``time_signature`` (token string or None).
    """
    context = {
        'current_tempo': None,
        'current_chord': None,
        'recent_pitches': [],
        'current_velocity': None,
        'bar_position': 0,
        'time_signature': None
    }

    # Look at recent tokens for context
    lookback_tokens = recent_tokens[-max_lookback:] if len(recent_tokens) > max_lookback else recent_tokens

    for token in lookback_tokens:
        token_str = get_token_string(tokenizer, token)
        if not token_str:
            continue

        # Extract tempo
        if token_str.startswith("Tempo_"):
            try:
                # Tempo values may be fractional (e.g. "Tempo_121.29"), which int()
                # rejects; parse as float and keep an int when the value is integral.
                tempo = float(token_str.split("_")[1])
                context['current_tempo'] = int(tempo) if tempo.is_integer() else tempo
            except (ValueError, IndexError):
                pass  # Malformed tempo token: keep the previous value

        # Extract chord
        elif token_str.startswith("Chord_"):
            context['current_chord'] = token_str

        # Extract pitch info for harmony
        elif token_str.startswith("Pitch_"):
            try:
                pitch = int(token_str.split("_")[1])
                context['recent_pitches'].append(pitch)
                # Keep only the 8 most recent pitches
                context['recent_pitches'] = context['recent_pitches'][-8:]
            except (ValueError, IndexError):
                pass  # Non-numeric pitch token: ignore it

        # Extract velocity for dynamics
        elif token_str.startswith("Velocity_"):
            try:
                context['current_velocity'] = int(token_str.split("_")[1])
            except (ValueError, IndexError):
                pass  # Non-numeric velocity token: keep the previous value

        # Extract time signature
        elif token_str.startswith("TimeSig_"):
            context['time_signature'] = token_str

        # Track bar position
        elif token_str.startswith("Position_"):
            try:
                context['bar_position'] = int(token_str.split("_")[1])
            except (ValueError, IndexError):
                pass  # Non-numeric position token: keep the previous value

    return context

In [None]:
def apply_musical_biasing(logits, tokenizer, context, step, bias_strength=2.0):
    """Build an additive logit bias that favors musically plausible tokens.

    Every token id below ``min(logits.shape[-1], tokenizer.vocab_size)`` is resolved
    to its name and nudged according to its type: common or moderate tempos, common
    chord qualities, pitches consonant with the last three pitches in
    ``context['recent_pitches']`` and within 21..108, moderate velocities, and
    positions divisible by 24 or 12.

    Args:
        logits: Logits tensor; the bias is written to row 0 only, so a 2-D tensor
            with a single batch row is expected.
        tokenizer: Tokenizer passed through to ``get_token_string``.
        context: Dict as produced by ``get_musical_context``; only
            ``'recent_pitches'`` is read.
        step: Current sampling step (unused; kept for a signature parallel to
            ``encourage_musical_structure``).
        bias_strength: Global multiplier for all bias terms.

    Returns:
        Tensor shaped like ``logits`` holding the bias to add to them.
    """
    vocab_size = logits.shape[-1]
    bias = torch.zeros_like(logits)

    # Reference values used by the heuristics below
    common_tempos = [60, 72, 80, 90, 100, 110, 120, 132, 144]  # Common classical tempos
    consonant_intervals = [0, 3, 4, 7, 12]  # Unison, minor 3rd, major 3rd, perfect 5th, octave

    for token_id in range(min(vocab_size, tokenizer.vocab_size)):
        token_str = get_token_string(tokenizer, token_id)
        if not token_str:
            continue

        # Tempo biasing - favor stable, musical tempos
        if token_str.startswith("Tempo_"):
            try:
                # Parse as float: tempo values may be fractional (e.g. "Tempo_121.29"),
                # which int() would reject and leave every tempo token unbiased.
                tempo = float(token_str.split("_")[1])
            except (ValueError, IndexError):
                continue  # Malformed tempo token: leave it unbiased
            if tempo in common_tempos:
                bias[0, token_id] += bias_strength * 0.3
            elif 60 <= tempo <= 160:  # Reasonable tempo range
                bias[0, token_id] += bias_strength * 0.1
            else:
                bias[0, token_id] -= bias_strength * 0.2  # Discourage extreme tempos

        # Chord progression biasing
        elif token_str.startswith("Chord_"):
            # Favor common chord types
            if any(chord_type in token_str for chord_type in ["_M", "_m", "_dim", "_7"]):
                bias[0, token_id] += bias_strength * 0.2

        # Harmonic biasing for pitches
        elif token_str.startswith("Pitch_"):
            try:
                pitch = int(token_str.split("_")[1])
            except (ValueError, IndexError):
                continue  # Non-numeric pitch token: leave it unbiased

            # If we have recent pitches, favor consonant intervals
            if context['recent_pitches']:
                for recent_pitch in context['recent_pitches'][-3:]:  # Check last 3 pitches
                    interval = abs(pitch - recent_pitch) % 12
                    if interval in consonant_intervals:
                        bias[0, token_id] += bias_strength * 0.15
                    elif interval in [1, 2, 10, 11]:  # Dissonant intervals
                        bias[0, token_id] -= bias_strength * 0.1

            # Favor pitches in reasonable range (piano range roughly)
            if 21 <= pitch <= 108:  # A0 to C8
                bias[0, token_id] += bias_strength * 0.05

        # Velocity biasing for natural dynamics
        elif token_str.startswith("Velocity_"):
            try:
                velocity = int(token_str.split("_")[1])
            except (ValueError, IndexError):
                continue  # Non-numeric velocity token: leave it unbiased
            # Favor moderate velocities, avoid extremes
            if 40 <= velocity <= 100:
                bias[0, token_id] += bias_strength * 0.1
            elif velocity < 20 or velocity > 120:
                bias[0, token_id] -= bias_strength * 0.15

        # Structural biasing
        elif token_str.startswith("Position_"):
            try:
                position = int(token_str.split("_")[1])
            except (ValueError, IndexError):
                continue  # Non-numeric position token: leave it unbiased
            # Favor positions that align with musical structure (beats)
            if position % 24 == 0:  # Strong beats (assuming 24 ticks per quarter)
                bias[0, token_id] += bias_strength * 0.1
            elif position % 12 == 0:  # Half beats
                bias[0, token_id] += bias_strength * 0.05

    return bias

In [None]:
def encourage_musical_structure(logits, tokenizer, context, step, structure_strength=1.5):
    """Build an additive logit bias that marks phrase boundaries.

    Steps are grouped into 64-step phrases. During the first 8 steps of a phrase,
    chord and tempo tokens are boosted; during the last 7 steps (phase > 56), chord
    and rest tokens are boosted. Elsewhere the bias is all zeros. ``context`` is
    accepted but not read.

    Returns:
        Tensor shaped like ``logits``; only row 0 is ever modified.
    """
    bias = torch.zeros_like(logits)
    phase = step % 64

    # Choose which token families to boost, and by how much, for this phrase phase.
    if phase < 8:
        # Phrase opening: encourage a new chord/tempo
        favored_prefixes, weight = ("Chord_", "Tempo_"), 0.2
    elif phase > 56:
        # Phrase ending: encourage cadential patterns like chord changes & rests
        favored_prefixes, weight = ("Chord_", "Rest_"), 0.3
    else:
        return bias

    candidate_count = min(logits.shape[-1], tokenizer.vocab_size)
    for token_id in range(candidate_count):
        name = get_token_string(tokenizer, token_id)
        if name and name.startswith(favored_prefixes):
            bias[0, token_id] += structure_strength * weight

    return bias

Sampling using the functions defined above<br>

In [None]:
def sample_with_structure(
    model, tokenizer, start_token, max_length=2048,
    temperature=1.0, top_k=10, top_p=0.9,
    device='cuda', musical_bias=True, bias_strength=2.0
):
    """Autoregressively sample a token sequence, optionally with musical biasing.

    One token is fed per step and the LSTM state is carried between steps. Logits are
    temperature-scaled, optionally biased by ``apply_musical_biasing`` and
    ``encourage_musical_structure``, filtered with top-p and then top-k, and sampled
    with ``torch.multinomial``.

    Args:
        model: Model called as ``model(input_ids, position_ids, hidden)`` returning
            ``(logits, hidden)``.
        tokenizer: Tokenizer providing ``vocab_size`` and token-name lookups.
        start_token: Id of the first token (also the first element of the result).
        max_length: Maximum number of sampling steps.
        temperature: Divisor applied to the logits before filtering.
        top_k: Keep only the ``top_k`` highest logits (disabled when <= 0).
        top_p: Nucleus threshold (disabled when >= 1.0).
        device: Device to run the model on.
        musical_bias: Whether to add the heuristic musical/structural biases.
        bias_strength: Strength passed to the bias functions.

    Returns:
        List of token ids starting with ``start_token``; an empty list when the model
        or tokenizer is None. Sampling stops early at an EOS/PAD token or when a step
        raises (the error is printed and the tokens generated so far are returned).
    """
    if model is None or tokenizer is None:
        print("Model or tokenizer is None")
        return []

    model = model.to(device)
    model.eval()

    vocab_size = tokenizer.vocab_size

    # Resolve the stop tokens once, independently of each other. Previously both
    # lookups shared one bare try/except inside the loop, so a failure in either one
    # silently disabled the end-of-sequence check on every step.
    stop_tokens = set()
    try:
        stop_tokens.add(tokenizer["EOS_None"])
    except Exception:
        pass  # Tokenizer has no EOS token: rely on PAD / max_length
    stop_tokens.add(getattr(tokenizer, "pad_token_id", 0))

    # Position ids must follow the training convention (index of the token within the
    # sequence, see train_model). They are clamped to the embedding table so sequences
    # longer than the table do not index out of range.
    position_table = getattr(model, "position_embedding", None)
    max_positions = getattr(position_table, "num_embeddings", None)

    generated = [start_token]
    input_token = torch.tensor([[start_token]], device=device)
    hidden = None

    print(f"Starting enhanced sampling with vocab_size: {vocab_size}")
    print(f"Musical biasing: {'enabled' if musical_bias else 'disabled'}")

    for step in range(max_length):
        with torch.no_grad():
            try:
                position = step if max_positions is None else min(step, max_positions - 1)
                input_pos = torch.tensor([[position]], device=device)

                output, hidden = model(input_token, input_pos, hidden)
                logits = output[:, -1, :] / temperature
                logits = logits[:, :vocab_size]  # Constrain to valid vocab

                # Apply musical biasing if enabled
                if musical_bias:
                    context = get_musical_context(tokenizer, generated)

                    # Apply musical knowledge biasing
                    musical_bias_tensor = apply_musical_biasing(
                        logits, tokenizer, context, step, bias_strength
                    )

                    # Apply structural biasing
                    structure_bias = encourage_musical_structure(
                        logits, tokenizer, context, step, bias_strength * 0.75
                    )

                    # Combine biases
                    logits = logits + musical_bias_tensor + structure_bias

                # Top-p (nucleus) sampling
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                    # Mask tokens beyond the nucleus; shift right by one so the token
                    # that crosses the threshold is still kept.
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0  # Never remove the top token

                    # Apply mask
                    for b in range(logits.shape[0]):
                        logits[b, sorted_indices[b][sorted_indices_to_remove[b]]] = -float("Inf")

                # Top-k sampling
                if top_k > 0:
                    top_k_values, _ = torch.topk(logits, min(top_k, vocab_size))
                    min_top_k = top_k_values[:, -1].unsqueeze(-1)
                    logits[logits < min_top_k] = -float("Inf")

                # Sample from the filtered distribution. The logits were sliced to
                # vocab_size above, so the sampled id is always a valid token id.
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).item()
                generated.append(next_token)

                # Check for end tokens
                if next_token in stop_tokens:
                    break

                # Prepare for next iteration
                input_token = torch.tensor([[next_token]], device=device)

                # Print progress occasionally with musical context
                if step % 100 == 0:
                    context = get_musical_context(tokenizer, generated)
                    print(f"Step {step}: Generated {len(generated)} tokens, "
                          f"Tempo: {context['current_tempo']}, "
                          f"Chord: {context['current_chord']}")

            except Exception as e:
                # Best effort: report the failure and return what was generated so far.
                print(f"Error during sampling at step {step}: {e}")
                break

    print(f"Generated {len(generated)} tokens total")
    return generated

Generating the output MIDI files<br>

In [None]:
def generate_midi(tokenizer, generated_sequence, output_filename="rnn.mid"):
    """Decode a generated token sequence and write it to a MIDI file.

    Tokens outside ``[0, tokenizer.vocab_size)`` are dropped before decoding.

    Args:
        tokenizer: Object exposing ``vocab_size`` and ``decode``.
        generated_sequence: Sequence of token ids.
        output_filename: Path of the MIDI file to write.

    Returns:
        The decoded score (after ``dump_midi`` has written it), or None when the
        sequence is empty or too short, the score has no tracks, or an error occurs.
    """
    # Stage 1: validate and filter the sequence.
    try:
        if not generated_sequence:
            print("empty generated sequence")
            return None

        limit = tokenizer.vocab_size
        valid_sequence = []
        for token in generated_sequence:
            if 0 <= token < limit:
                valid_sequence.append(token)

        print(f"Original sequence length: {len(generated_sequence)}")
        print(f"Valid sequence length: {len(valid_sequence)}")

        if len(valid_sequence) < 2:
            print("sequence too short or no valid tokens found")
            return None

    except Exception as e:
        print(f"Error during MIDI generation: {e}")
        print(f"Sequence sample: {generated_sequence[:20] if generated_sequence else 'Empty'}...")
        return None

    # Stage 2: decode tokens to MIDI and write the file.
    try:
        decoded = tokenizer.decode([valid_sequence])
        # decode may return a single score or a list of scores; use the first one.
        output_score = decoded[0] if isinstance(decoded, list) else decoded

        if len(output_score.tracks) == 0:
            print("MIDI generated has no tracks")
            return None

        output_score.dump_midi(output_filename)
        print(f"successfully generated {output_filename}")
        return output_score

    except Exception as decode_error:
        print(f"Error during token decoding: {decode_error}")
        return None

In [None]:
def enhanced_generation():
    """Load the trained model, sample a sequence with musical biasing and write a MIDI file.

    Prints progress along the way and a short analysis of the generated tokens.
    Returns None in every case; failures are reported with a printed message.
    """
    # Single source of truth for the output path, so the printed message always
    # matches the file that was actually written.
    output_filename = "enhanced_music2.mid"

    print("=== Testing Enhanced Generation ===")

    # Load existing trained model
    model, tokenizer = load_trained_model_fixed()

    if model is None or tokenizer is None:
        print(" model and tokenizer load failed")
        return

    print(f" Model loaded successfully")
    print(f"Tokenizer loaded with vocab size: {tokenizer.vocab_size}")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    try:
        start_token = tokenizer["BOS_None"]
    except Exception:
        # Tokenizer has no BOS token: fall back to id 1.
        print("Using fallback start token")
        start_token = 1

    print("\n=== Generating with Musical Enhancements ===")

    # Generate with enhanced musical biasing
    generated_sequence = sample_with_structure(
        model=model,
        tokenizer=tokenizer,
        start_token=start_token,
        max_length=1024,
        temperature=0.825,
        top_k=20,
        top_p=0.95,
        device=device,
        musical_bias=True,
        bias_strength=2.0
    )

    if not generated_sequence:
        print(" Generation failed")
        return

    print(f"Generated sequence of length {len(generated_sequence)}")

    # Generate MIDI
    print("\n=== Creating Enhanced MIDI ===")
    output_score = generate_midi(tokenizer, generated_sequence, output_filename)

    if output_score:
        print("midi generated")
        print(f"File saved as: {output_filename}")

        # Show musical analysis over the whole generated sequence
        context = get_musical_context(tokenizer, generated_sequence, max_lookback=len(generated_sequence))
        print(f"\n=== Musical Analysis ===")
        print(f"Final tempo: {context['current_tempo']}")
        print(f"Final chord: {context['current_chord']}")
        # NOTE: get_musical_context keeps only the last 8 pitches, so this count is
        # over those pitches rather than the whole piece.
        print(f"Unique pitches used: {len(set(context['recent_pitches']))}")
        print(f"Final velocity: {context['current_velocity']}")
    else:
        print("MIDI generation failed")

In [None]:
# Run the generation demo: load the saved model, sample a sequence and write the MIDI file.
enhanced_generation()