In [None]:
# Sampling function: autoregressive generation with nucleus / top-k filtering and bar-position tracking
import torch
import torch.nn.functional as F
import random

def _filter_logits(logits, top_k, top_p):
    """Mask logits outside the nucleus (top-p) and outside the top-k with -inf.

    Args:
        logits: 2-D tensor ``(batch, vocab)`` of temperature-scaled logits.
        top_k: keep at most this many highest logits per row; ``<= 0`` disables
            the top-k filter. Values larger than the vocabulary are clamped.
        top_p: nucleus threshold. Tokens are removed once the cumulative
            probability of the higher-ranked tokens exceeds ``top_p``; the
            highest-probability token is always kept.

    Returns:
        A new tensor of the same shape with filtered entries set to ``-inf``.
    """
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # Shift the mask right by one so the token that crosses the threshold is kept.
    sorted_to_remove = cumulative_probs > top_p
    sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
    sorted_to_remove[..., 0] = False
    # Map the mask from sorted order back to vocabulary order in one vectorised op.
    to_remove = sorted_to_remove.scatter(1, sorted_indices, sorted_to_remove)
    logits = logits.masked_fill(to_remove, float("-inf"))

    if top_k > 0:
        k = min(top_k, logits.size(-1))  # torch.topk raises if k > vocab size
        kth_largest = torch.topk(logits, k, dim=-1).values[..., -1:]
        logits = logits.masked_fill(logits < kth_largest, float("-inf"))
    return logits


def _advance_position(token_str, current_position, num_positions):
    """Return the position index to feed the model alongside the next token.

    ``Bar*`` tokens reset the position to 0. ``Position_<n>`` tokens set it to
    ``n`` as parsed (not wrapped). Any other token, or a malformed
    ``Position_`` token, advances the position by one modulo ``num_positions``.
    """
    if token_str and token_str.startswith("Bar"):
        return 0
    if token_str and token_str.startswith("Position_"):
        try:
            return int(token_str.split("_")[1])
        except (ValueError, IndexError):
            pass  # malformed Position_ token: fall back to advancing by one
    return (current_position + 1) % num_positions


def sample_with_structure(
    model, tokenizer, start_token, max_length=1024,
    temperature=1.0, top_k=10, top_p=0.9,
    device='cuda', num_positions=96,
):
    """Autoregressively sample token ids from a recurrent model, tracking a position index.

    At each step the model is called as ``model(token, position, hidden)``. The
    last-step logits are truncated to ``tokenizer.vocab_size``, scaled by
    ``temperature``, filtered with nucleus then top-k, and sampled. The position
    fed on the next step comes from the sampled token (see ``_advance_position``).

    Args:
        model: callable returning ``(output, hidden)``; ``output[:, -1, :]`` is
            used as logits (assumes ``(batch, seq, vocab)`` — TODO confirm).
        tokenizer: exposes ``vocab_size``, a ``_vocab_base`` name->id dict and
            ``tokenizer[name]`` id lookup (needs ``EOS_None`` and ``PAD_None``).
        start_token: id of the first token fed to the model.
        max_length: maximum number of tokens sampled after ``start_token``.
        temperature: logit divisor; must be > 0.
        top_k: top-k cutoff (``<= 0`` disables; clamped to the vocabulary size).
        top_p: nucleus threshold.
        device: torch device to run on.
        num_positions: wrap-around for positions advanced by one step.

    Returns:
        List of token ids: ``start_token`` followed by the sampled ids, including
        the terminating ``EOS``/``PAD`` id when one was sampled.

    Raises:
        ValueError: if ``temperature <= 0``.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    model = model.to(device)
    model.eval()

    vocab_size = tokenizer.vocab_size
    # Invert the vocabulary once instead of scanning it every step; setdefault
    # keeps the first name seen for an id, matching a first-match search.
    id_to_token = {}
    for name, idx in tokenizer._vocab_base.items():
        id_to_token.setdefault(idx, name)
    stop_ids = {tokenizer["EOS_None"], tokenizer["PAD_None"]}

    generated = [start_token]
    input_token = torch.tensor([[start_token]], device=device)
    input_pos = torch.tensor([[0]], device=device)
    hidden = None
    current_position = 0

    with torch.no_grad():
        for _ in range(max_length):
            output, hidden = model(input_token, input_pos, hidden)
            logits = output[:, -1, :vocab_size] / temperature
            logits = _filter_logits(logits, top_k, top_p)

            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()
            generated.append(next_token)

            current_position = _advance_position(
                id_to_token.get(next_token), current_position, num_positions
            )

            # The stop token is kept in the output before terminating.
            if next_token in stop_ids:
                break

            input_token = torch.tensor([[next_token]], device=device)
            input_pos = torch.tensor([[current_position]], device=device)

    return generated

# Sample one sequence starting from the BOS token, on the GPU when one is available.
start_token = tokenizer["BOS_None"]
generated_sequence = sample_with_structure(
    model,
    tokenizer,
    start_token,
    max_length=1024,
    temperature=0.9,
    top_k=10,
    top_p=0.95,
    device="cuda" if torch.cuda.is_available() else "cpu",
)


ImportError: cannot import name 'nn' from partially initialized module 'torch' (most likely due to a circular import) (C:\Users\sammy\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\torch\__init__.py)

In [None]:
# Add Chord Annotations to Your Training Data

import pretty_midi

def add_chord_markers_to_midi(midi_path, out_path, chords):
    """Copy a MIDI file, appending one marker per chord annotation.

    Args:
        midi_path: path of the source MIDI file.
        out_path: path the annotated MIDI file is written to.
        chords: iterable of ``(time, label)`` pairs, e.g. the list returned by
            ``extract_chords``.

    NOTE(review): confirm the installed pretty_midi version provides
    ``pretty_midi.Marker`` and a ``markers`` list on ``PrettyMIDI``.
    """
    midi = pretty_midi.PrettyMIDI(midi_path)
    for when, label in chords:
        marker = pretty_midi.Marker(label, when)
        midi.markers.append(marker)
    midi.write(out_path)


import pretty_midi

# Pitch-class names in the same spelling pretty_midi uses for note names.
_PITCH_CLASS_NAMES = ("C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B")


def extract_chords(pm, beat_times):
    """Assign a crude chord label to each inter-beat interval.

    For each interval ``[beat_times[i], beat_times[i + 1])`` the lowest pitch of
    all notes starting in that interval (across all instruments) is taken as
    the root. The label is ``"Chord_<pitch class>maj"`` without octave, so the
    same root gets the same label regardless of register. Intervals with no
    note onsets are labelled ``"Chord_None"``. This is a placeholder heuristic
    (always "maj"); a real chord-recognition method would be better.

    Args:
        pm: PrettyMIDI-like object exposing ``instruments[*].notes[*]`` with
            ``.start`` and ``.pitch``.
        beat_times: ascending beat times, in the same unit as ``note.start``.

    Returns:
        List of ``(interval_start, label)`` tuples, one per interval. The last
        beat opens no interval, so it gets no label.
    """
    from bisect import bisect_left

    # Sort onsets once so each interval is found by binary search instead of
    # rescanning every note for every beat.
    onsets = sorted(
        (note.start, note.pitch) for inst in pm.instruments for note in inst.notes
    )
    onset_times = [start for start, _ in onsets]

    chords = []
    for start, end in zip(beat_times, beat_times[1:]):
        lo = bisect_left(onset_times, start)
        hi = bisect_left(onset_times, end)  # half-open: onsets at `end` excluded
        if lo < hi:
            lowest = min(pitch for _, pitch in onsets[lo:hi])  # crude root guess
            chord_label = f"Chord_{_PITCH_CLASS_NAMES[lowest % 12]}maj"
        else:
            chord_label = "Chord_None"
        chords.append((start, chord_label))
    return chords



ModuleNotFoundError: No module named 'pretty_midi'