In [70]:
from music21 import converter, note, harmony, stream,meter, key
from glob import glob
from collections import defaultdict
import random
import torch
import torch.nn as nn
import math
from miditok import REMI, TokenizerConfig
from symusic import Score
from torch.utils.data import Dataset
import torch
import os
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.nn.functional as F
from miditok import TokSequence
from torch.optim.lr_scheduler import CosineAnnealingLR

In [35]:
abc_files = glob('./nottingham-dataset-master/ABC_cleaned/*.abc')
len(abc_files)

14

In [114]:
def extract_separate_tokens(score):
    chords = []
    pitches = []
    durations = []
    current_chord = "N"  

    #M = str(score.recurse().getElementsByClass(meter.TimeSignature).first())
    ts = score.recurse().getElementsByClass(meter.TimeSignature).first()
    M = ts.ratioString 
    K = str(score.recurse().getElementsByClass(key.KeySignature).first())

    

    for el in score.flatten().notesAndRests:
        if isinstance(el, note.Note):
            pitches.append(el.nameWithOctave)
            durations.append(round(el.quarterLength, 2))
            chords.append(current_chord)
        elif isinstance(el, note.Rest):
            pitches.append("Rest")
            durations.append(round(el.quarterLength, 2))
            chords.append(current_chord)
        elif isinstance(el, harmony.ChordSymbol):
            current_chord = el.figure

    return chords, pitches, durations, M, K

### baseline(Markov)

In [115]:
def num_count(item,items_count):
    if item not in items_count:
        items_count[item] = 1
    else:
        items_count[item] += 1
    return items_count

In [116]:
def unigram_probability(items_count):
    unigramProbabilities = {}
    
    total = sum(items_count.values())
    for item, value in items_count.items():
        unigramProbabilities[item] = value / total
        
    
    return unigramProbabilities

In [117]:
def bigram(prios, posts,bi):
    for prio, post in zip(prios, posts):
        bi[(prio,post)] += 1

    return bi

In [118]:
def trigram(prios1,prios2, posts,tri):
    for prio1,prio2, post in zip(prios1,prios2, posts):
        tri[(prio1,prio2, post)] += 1

    return tri

In [119]:
def cal_bigram_Probability(bigrams):
    temp_counts = defaultdict(list)
    bigramTransitions = defaultdict(list)
    bigramTransitionProbabilities = defaultdict(list)
    for item, value in bigrams.items():
        bigramTransitions[item[0]].append(item[1])
        temp_counts[item[0]].append(value)
    
    for item, counts in temp_counts.items():
        bigramTransitionProbabilities[item] = [i/sum(counts) for i in counts]
    
    return bigramTransitions, bigramTransitionProbabilities
    

In [120]:
def cal_trigram_Probability(trigrams):
    trigramTransitions = defaultdict(list)
    trigramTransitionProbabilities = defaultdict(list)
    temp_counts = defaultdict(list)
    for item, value in trigrams.items():
        trigramTransitions[(item[0],item[1])].append(item[2])
        temp_counts[(item[0],item[1])].append(value)
    
    for item, counts in temp_counts.items():
        trigramTransitionProbabilities[item] = [i/sum(counts) for i in counts]
    
    return trigramTransitions, trigramTransitionProbabilities

In [121]:
def buildTransitionProbabilities(abc_files):
    chords_count = {}
    meters_count  = {}
    keys_count = {}
    chord_bigrams = defaultdict(int)
    chord_pitch_bigrams = defaultdict(int)
    chord_pitch_dur_trigrams = defaultdict(int)
    for abc_file in abc_files:
        opus = converter.parse(abc_file, format='abc')
        scores = opus.scores if isinstance(opus, stream.Opus) else [opus]
        for i, score in enumerate(scores):
            score = score.parts[0] if hasattr(score, "parts") else score
            chords, pitches, durations,M, K = extract_separate_tokens(score)
            meters_count = num_count(M,meters_count)
            keys_count = num_count(K,keys_count)
            chord_bigrams = bigram(chords[:-1],chords[1:],chord_bigrams)
            chord_pitch_bigrams = bigram(chords,pitches,chord_pitch_bigrams)
            chord_pitch_dur_trigrams = trigram(chords,pitches, durations, chord_pitch_dur_trigrams)
            for chord in chords:
                chords_count = num_count(chord,chords_count)
                
    unigramChordProbabilities = unigram_probability(chords_count)
    unigramKeyProbabilities = unigram_probability(keys_count)
    unigramMeterProbabilities = unigram_probability(meters_count) 
    bigramChordTransitions, bigramChordTransitionProbabilities = cal_bigram_Probability(chord_bigrams)
    bigramChordPitchTransitions, bigramChordPitchTransitionProbabilities = cal_bigram_Probability(chord_pitch_bigrams)
    trigramChordPitchDurTransitions, trigramChordPitchDurTransitionProbabilities = cal_trigram_Probability(chord_pitch_dur_trigrams)

    return (unigramChordProbabilities, unigramKeyProbabilities, unigramMeterProbabilities, bigramChordTransitions, bigramChordTransitionProbabilities,
    bigramChordPitchTransitions, bigramChordPitchTransitionProbabilities, trigramChordPitchDurTransitions, trigramChordPitchDurTransitionProbabilities)

In [122]:
unigramChordProbabilities, unigramKeyProbabilities, unigramMeterProbabilities, bigramChordTransitions, bigramChordTransitionProbabilities,bigramChordPitchTransitions, bigramChordPitchTransitionProbabilities, trigramChordPitchDurTransitions, trigramChordPitchDurTransitionProbabilities = buildTransitionProbabilities(abc_files)

In [124]:
unigramMeterProbabilities

{'4/4': 0.5193423597678917,
 '2/4': 0.039651837524177946,
 '2/2': 0.006769825918762089,
 '3/4': 0.058027079303675046,
 '6/8': 0.3578336557059961,
 '9/8': 0.013539651837524178,
 '3/2': 0.0029013539651837525,
 '6/4': 0.0019342359767891683}

In [125]:
unigramKeyProbabilities

{'D major': 0.3452611218568665,
 'G major': 0.3491295938104449,
 'A major': 0.12088974854932302,
 'F major': 0.025145067698259187,
 'e minor': 0.02804642166344294,
 'C major': 0.05222437137330754,
 'a minor': 0.037717601547388784,
 'B major': 0.0009671179883945841,
 'd minor': 0.013539651837524178,
 'B- major': 0.01160541586073501,
 'E major': 0.0019342359767891683,
 'g minor': 0.010638297872340425,
 'b minor': 0.0029013539651837525}

In [126]:
def duration_to_abc(duration, L="1/8"):

    l_num, l_den = map(int, L.split("/"))
    base_unit = l_num / l_den  # L 对应的 quarterLength，例如 1/8 = 0.125

    ratio = round(duration / base_unit, 3)

    if ratio == 1.0:
        return ""
    elif ratio == 0.5:
        return "/"
    elif ratio == 0.25:
        return "//"
    elif ratio.is_integer():
        return str(int(ratio))
    else:
        return str(ratio).replace(".", "/")

def pitch_to_abc(pitch_name):

    step_map = {'C': 'C', 'D': 'D', 'E': 'E', 'F': 'F', 'G': 'G', 'A': 'A', 'B': 'B'}
    accidental_map = {'#': '^', '-': '_', 'b': '_'}
    match = re.match(r"([A-Ga-g])([#b♯♭]?)(\d)", pitch_name)
    if not match:
        return "C"  
    step, acc, octave = match.groups()
    octave = int(octave)
    acc_symbol = accidental_map.get(acc, "")

    base = step.upper()
    if octave > 4:
        return acc_symbol + base.lower() + "'" * (octave - 5)
    elif octave < 4:
        return acc_symbol + base + "," * (3 - octave)
    else:
        return acc_symbol + base.lower()

In [127]:
def format_key_for_abc(key_str):
    key_str = key_str.strip().lower()
    if "major" in key_str:
        return key_str.replace(" major", "").upper()
    elif "minor" in key_str:
        return key_str.replace(" minor", "").capitalize() + "m"
    else:
        return key_str

In [142]:
def music_generate(length,abc_filename):
    M = random.choices(
        population=list(unigramMeterProbabilities.keys()),
        weights=list(unigramMeterProbabilities.values()),
        k=1
    )[0]

    
    
    
    raw_key = random.choices(
        population=list(unigramKeyProbabilities.keys()),
        weights=list(unigramKeyProbabilities.values()),
        k=1
    )[0]

    K = format_key_for_abc(raw_key)
    sampled_chords = []
    sampled_pitches = []
    sampled_durations = []
    last_chord = None
    
    for i in range(length):
        if not last_chord:
            chord = random.choices(
                population=list(unigramChordProbabilities.keys()),
                weights=list(unigramChordProbabilities.values()),
                k=1
                )[0]
        else:
            chord = random.choices(
                population=bigramChordTransitions[last_chord],
                weights=bigramChordTransitionProbabilities[last_chord],
                k=1
            )[0]
        pitch = random.choices(
                population=bigramChordPitchTransitions[chord],
                weights=bigramChordPitchTransitionProbabilities[chord],
                k=1
            )[0]
        duration = random.choices(
                population=trigramChordPitchDurTransitions[(chord,pitch)],
                weights=trigramChordPitchDurTransitionProbabilities[(chord,pitch)],
                k=1
            )[0]
        sampled_chords.append(chord)
        sampled_pitches.append(pitch)
        sampled_durations.append(duration)
        last_chord = chord

    # save the generated music as a abc file
    abc_lines = [
        "X:1",
        "T:Generated Tune",
        f"M:{M}",
        "L:1/8",  
        f"K:{K}"
    ]

    print(M,K)
    body_lines = []
    measure = []
    current_len = 0

    for ch, pitch, dur in zip(sampled_chords, sampled_pitches, sampled_durations):
        dur_str = duration_to_abc(dur,"1/8")
        if pitch == "Rest":
            pitch_token = "z" + dur_str
        else:
            pitch_token = pitch_to_abc(pitch) + dur_str

        measure.append(f'"{ch}"{pitch_token}')
        current_len += dur

        if current_len >= 1.0:  
            body_lines.append(" ".join(measure) + " |")
            measure = []
            current_len = 0


    if measure:
        body_lines.append(" ".join(measure))

    abc_lines.append("\n".join(body_lines))

    with open(abc_filename, "w", encoding="utf-8") as f:
        f.write("\n".join(abc_lines))

    print(f" ABC file written to: {abc_filename}")
   

In [144]:
music_generate(20,"output.abc")

4/4 G
 ABC file written to: output.abc


In [145]:
score = converter.parse('output.abc', format='abc')
if isinstance(score, stream.Opus):
    score = score.scores[0]


ts_found = False
ks_found = False

for el in score.recurse():
    if isinstance(el, meter.TimeSignature):
        if not ts_found:
            ts_found = True
        else:
            container = el.activeSite
            if container is not None:
                container.remove(el)

    elif isinstance(el, key.KeySignature):
        if not ks_found:
            ks_found = True
        else:
            container = el.activeSite
            if container is not None:
                container.remove(el)

score.write('midi', fp='output.mid')

✅ Successfully wrote MIDI to output.mid


### Llama

In [6]:
config = TokenizerConfig(
    num_velocities=32,
    use_chords=True,
    use_programs=True,
    use_rests=True,
    use_tempos=True,
    use_controls=False
)
tokenizer = REMI(config)

In [21]:
def tokenize_all_midi(midi_folder):
    token_seqs = []
    for file in os.listdir(midi_folder):
        if file.endswith(".mid") or file.endswith(".midi"):
            score = Score(os.path.join(midi_folder, file))
            token_list = tokenizer(score) 
            token_seqs.append(token_list.ids)
            # for tokseq in token_list:
            #     token_seqs.append(tokseq)
    
    return token_seqs

In [22]:
token_seqs = tokenize_all_midi("./nottingham-dataset-master/MIDI")

In [32]:
Lengths = [len(i) for i in token_seqs]
sorted_lengths = sorted(Lengths, reverse=True)
for p in [90, 95, 98, 99]:
    val = int(np.percentile(Lengths, p))
    print(f"{p} percentile: max_len = {val}")

90 percentile: max_len = 2427
95 percentile: max_len = 3239
98 percentile: max_len = 4453
99 percentile: max_len = 5954


In [33]:
class MIDITokenDataset(Dataset):
    def __init__(self, token_seqs, seq_len=1024):
        self.samples = []
        for tokens in token_seqs:
            for i in range(0, len(tokens) - seq_len, seq_len):
                chunk = tokens[i:i+seq_len]
                self.samples.append(torch.tensor(chunk, dtype=torch.long))
        
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        x = self.samples[idx][:-1]   
        y = self.samples[idx][1:]    
        return x, y

In [34]:
vocab_size = tokenizer.vocab_size

In [39]:
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        norm = x.norm(2, dim=-1, keepdim=True)
        return self.weight * x / (norm + self.eps)

class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)

class SwiGLUFeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim)
        self.w2 = nn.Linear(dim, hidden_dim)
        self.activation = Swish()
        self.proj = nn.Linear(hidden_dim, dim)

    def forward(self, x):
        return self.proj(self.w1(x) * self.activation(self.w2(x)))

class Attention(nn.Module):
    def __init__(self, dim, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.qkv = nn.Linear(dim, dim * 3)
        self.out = nn.Linear(dim, dim)

    def forward(self, x, rope_embed):
        B, T, C = x.shape
        qkv = self.qkv(x).view(B, T, 3, self.n_heads, self.head_dim).transpose(1, 3)
        q, k, v = qkv.unbind(dim=2)  # (B, n_heads, T, head_dim)

        q, k = apply_rope(q, rope_embed), apply_rope(k, rope_embed)

        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        att = att.masked_fill(torch.triu(torch.ones(T, T, device=x.device), 1) == 1, float('-inf'))
        att = torch.softmax(att, dim=-1)
        out = att @ v  # (B, n_heads, T, head_dim)
        out = out.transpose(1, 2).reshape(B, T, C)
        return self.out(out)

class RotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def get_embed(self, seq_len, device):
        pos = torch.arange(seq_len, device=device).type_as(self.inv_freq)
        freqs = torch.einsum("i,j->ij", pos, self.inv_freq)
        emb = torch.cat([freqs.sin(), freqs.cos()], dim=-1)
        return emb

def apply_rope(x, rope_embed):
    # x: (B, n_heads, T, head_dim)
    sin, cos = rope_embed[..., ::2], rope_embed[..., 1::2]
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

class DecoderBlock(nn.Module):
    def __init__(self, dim, n_heads):
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.attn = Attention(dim, n_heads)
        self.norm2 = RMSNorm(dim)
        self.ff = SwiGLUFeedForward(dim, 4 * dim)

    def forward(self, x, rope_embed):
        x = x + self.attn(self.norm1(x), rope_embed)
        x = x + self.ff(self.norm2(x))
        return x

In [40]:
class MiniLLaMA(nn.Module):
    def __init__(self, vocab_size, dim=256, n_heads=4, n_layers=4, max_seq_len=1024):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, dim)
        self.pos_emb = RotaryEmbedding(dim // n_heads)
        self.blocks = nn.ModuleList([DecoderBlock(dim, n_heads) for _ in range(n_layers)])
        self.norm = RMSNorm(dim)
        self.output = nn.Linear(dim, vocab_size)

    def forward(self, idx):
        B, T = idx.shape
        x = self.token_emb(idx)
        rope_embed = self.pos_emb.get_embed(T, idx.device)
        for block in self.blocks:
            x = block(x, rope_embed)
        x = self.norm(x)
        return self.output(x)

In [76]:
class Pipeline:
    def __init__(self, model, dataset, vocab_size,epochs, batch_size=32, lr=1e-3, device=None):
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device)
        self.dataset = dataset
        self.vocab_size = vocab_size
    
        self.dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        self.total_step = len(self.dataloader) * epochs
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr, weight_decay = 1e-5)
        self.schedule = CosineAnnealingLR(self.optimizer, T_max=self.total_step)
        self.criterion = nn.CrossEntropyLoss()

    def train(self, epochs):
        self.model.train()
        best_loss = float('inf')
        patience = 5
        count = 0
        for epoch in range(epochs):
            total_loss = 0
            pbar = tqdm(self.dataloader, desc=f"Epoch {epoch+1}/{epochs}")
            for x, y in pbar:
                x = x.to(self.device)
                y = y.to(self.device)

                logits = self.model(x) 
                loss = self.criterion(logits.view(-1, self.vocab_size), y.view(-1))

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.schedule.step()

                total_loss += loss.item()
                pbar.set_postfix(loss=loss.item())

            avg_loss = total_loss / len(self.dataloader)
            print(f"[Epoch {epoch+1}] Average Loss: {avg_loss:.4f}")
            if avg_loss < best_loss:
                best_loss = avg_loss
                torch.save(self.model.state_dict(), 'bestmusicLLaMA.pt')
                count = 0
            else:
                count += 1
            if count >= patience:
                print("Early stopping!")
                break
                
                


        

    

In [77]:
dataset = MIDITokenDataset(token_seqs)
model = MiniLLaMA(vocab_size)
trainer = Pipeline(model, dataset, vocab_size,epochs=100)
trainer.train(epochs=100)

Epoch 1/100: 100%|██████████| 36/36 [00:09<00:00,  3.66it/s, loss=5.15]


[Epoch 1] Average Loss: 5.6877


Epoch 2/100: 100%|██████████| 36/36 [00:09<00:00,  3.69it/s, loss=4.17]


[Epoch 2] Average Loss: 4.6257


Epoch 3/100: 100%|██████████| 36/36 [00:09<00:00,  3.69it/s, loss=3.4] 


[Epoch 3] Average Loss: 3.7535


Epoch 4/100: 100%|██████████| 36/36 [00:09<00:00,  3.69it/s, loss=2.81]


[Epoch 4] Average Loss: 3.0768


Epoch 5/100: 100%|██████████| 36/36 [00:09<00:00,  3.68it/s, loss=2.35]


[Epoch 5] Average Loss: 2.5633


Epoch 6/100: 100%|██████████| 36/36 [00:09<00:00,  3.67it/s, loss=2.02]


[Epoch 6] Average Loss: 2.1987


Epoch 7/100: 100%|██████████| 36/36 [00:09<00:00,  3.65it/s, loss=1.85]


[Epoch 7] Average Loss: 1.9545


Epoch 8/100: 100%|██████████| 36/36 [00:09<00:00,  3.64it/s, loss=1.7] 


[Epoch 8] Average Loss: 1.7944


Epoch 9/100: 100%|██████████| 36/36 [00:09<00:00,  3.63it/s, loss=1.65]


[Epoch 9] Average Loss: 1.6791


Epoch 10/100: 100%|██████████| 36/36 [00:09<00:00,  3.61it/s, loss=1.47]


[Epoch 10] Average Loss: 1.5631


Epoch 11/100: 100%|██████████| 36/36 [00:10<00:00,  3.55it/s, loss=1.39]


[Epoch 11] Average Loss: 1.4427


Epoch 12/100: 100%|██████████| 36/36 [00:10<00:00,  3.54it/s, loss=1.28]


[Epoch 12] Average Loss: 1.3358


Epoch 13/100: 100%|██████████| 36/36 [00:10<00:00,  3.54it/s, loss=1.23]


[Epoch 13] Average Loss: 1.2495


Epoch 14/100: 100%|██████████| 36/36 [00:10<00:00,  3.52it/s, loss=1.07]


[Epoch 14] Average Loss: 1.1693


Epoch 15/100: 100%|██████████| 36/36 [00:10<00:00,  3.51it/s, loss=0.991]


[Epoch 15] Average Loss: 1.0994


Epoch 16/100: 100%|██████████| 36/36 [00:10<00:00,  3.53it/s, loss=0.98] 


[Epoch 16] Average Loss: 1.0332


Epoch 17/100: 100%|██████████| 36/36 [00:10<00:00,  3.52it/s, loss=0.995]


[Epoch 17] Average Loss: 0.9700


Epoch 18/100: 100%|██████████| 36/36 [00:10<00:00,  3.52it/s, loss=0.888]


[Epoch 18] Average Loss: 0.9059


Epoch 19/100: 100%|██████████| 36/36 [00:10<00:00,  3.52it/s, loss=0.715]


[Epoch 19] Average Loss: 0.8350


Epoch 20/100: 100%|██████████| 36/36 [00:10<00:00,  3.53it/s, loss=0.809]


[Epoch 20] Average Loss: 0.7668


Epoch 21/100: 100%|██████████| 36/36 [00:10<00:00,  3.53it/s, loss=0.711]


[Epoch 21] Average Loss: 0.7110


Epoch 22/100: 100%|██████████| 36/36 [00:10<00:00,  3.52it/s, loss=0.659]


[Epoch 22] Average Loss: 0.6741


Epoch 23/100: 100%|██████████| 36/36 [00:10<00:00,  3.52it/s, loss=0.576]


[Epoch 23] Average Loss: 0.6448


Epoch 24/100: 100%|██████████| 36/36 [00:10<00:00,  3.52it/s, loss=0.6]  


[Epoch 24] Average Loss: 0.6201


Epoch 25/100: 100%|██████████| 36/36 [00:10<00:00,  3.52it/s, loss=0.658]


[Epoch 25] Average Loss: 0.6015


Epoch 26/100: 100%|██████████| 36/36 [00:10<00:00,  3.53it/s, loss=0.573]


[Epoch 26] Average Loss: 0.5812


Epoch 27/100: 100%|██████████| 36/36 [00:10<00:00,  3.53it/s, loss=0.504]


[Epoch 27] Average Loss: 0.5647


Epoch 28/100: 100%|██████████| 36/36 [00:10<00:00,  3.52it/s, loss=0.494]


[Epoch 28] Average Loss: 0.5502


Epoch 29/100: 100%|██████████| 36/36 [00:10<00:00,  3.52it/s, loss=0.525]


[Epoch 29] Average Loss: 0.5386


Epoch 30/100: 100%|██████████| 36/36 [00:10<00:00,  3.51it/s, loss=0.457]


[Epoch 30] Average Loss: 0.5251


Epoch 31/100: 100%|██████████| 36/36 [00:10<00:00,  3.50it/s, loss=0.518]


[Epoch 31] Average Loss: 0.5139


Epoch 32/100: 100%|██████████| 36/36 [00:10<00:00,  3.52it/s, loss=0.552]


[Epoch 32] Average Loss: 0.5015


Epoch 33/100: 100%|██████████| 36/36 [00:10<00:00,  3.51it/s, loss=0.452]


[Epoch 33] Average Loss: 0.4901


Epoch 34/100: 100%|██████████| 36/36 [00:10<00:00,  3.53it/s, loss=0.547]


[Epoch 34] Average Loss: 0.4815


Epoch 35/100: 100%|██████████| 36/36 [00:10<00:00,  3.51it/s, loss=0.523]


[Epoch 35] Average Loss: 0.4756


Epoch 36/100: 100%|██████████| 36/36 [00:10<00:00,  3.55it/s, loss=0.449]


[Epoch 36] Average Loss: 0.4679


Epoch 37/100: 100%|██████████| 36/36 [00:10<00:00,  3.53it/s, loss=0.46] 


[Epoch 37] Average Loss: 0.4616


Epoch 38/100: 100%|██████████| 36/36 [00:10<00:00,  3.53it/s, loss=0.444]


[Epoch 38] Average Loss: 0.4560


Epoch 39/100: 100%|██████████| 36/36 [00:10<00:00,  3.52it/s, loss=0.397]


[Epoch 39] Average Loss: 0.4495


Epoch 40/100: 100%|██████████| 36/36 [00:10<00:00,  3.52it/s, loss=0.46] 


[Epoch 40] Average Loss: 0.4464


Epoch 41/100: 100%|██████████| 36/36 [00:10<00:00,  3.54it/s, loss=0.533]


[Epoch 41] Average Loss: 0.4444


Epoch 42/100: 100%|██████████| 36/36 [00:10<00:00,  3.54it/s, loss=0.456]


[Epoch 42] Average Loss: 0.4381


Epoch 43/100: 100%|██████████| 36/36 [00:10<00:00,  3.52it/s, loss=0.451]


[Epoch 43] Average Loss: 0.4339


Epoch 44/100: 100%|██████████| 36/36 [00:10<00:00,  3.54it/s, loss=0.394]


[Epoch 44] Average Loss: 0.4292


Epoch 45/100: 100%|██████████| 36/36 [00:10<00:00,  3.53it/s, loss=0.426]


[Epoch 45] Average Loss: 0.4271


Epoch 46/100: 100%|██████████| 36/36 [00:10<00:00,  3.51it/s, loss=0.463]


[Epoch 46] Average Loss: 0.4239


Epoch 47/100: 100%|██████████| 36/36 [00:10<00:00,  3.50it/s, loss=0.412]


[Epoch 47] Average Loss: 0.4210


Epoch 48/100: 100%|██████████| 36/36 [00:10<00:00,  3.52it/s, loss=0.449]


[Epoch 48] Average Loss: 0.4179


Epoch 49/100: 100%|██████████| 36/36 [00:10<00:00,  3.53it/s, loss=0.395]


[Epoch 49] Average Loss: 0.4143


Epoch 50/100: 100%|██████████| 36/36 [00:10<00:00,  3.52it/s, loss=0.382]


[Epoch 50] Average Loss: 0.4117


Epoch 51/100: 100%|██████████| 36/36 [00:10<00:00,  3.52it/s, loss=0.455]


[Epoch 51] Average Loss: 0.4127


Epoch 52/100: 100%|██████████| 36/36 [00:10<00:00,  3.53it/s, loss=0.358]


[Epoch 52] Average Loss: 0.4063


Epoch 53/100: 100%|██████████| 36/36 [00:10<00:00,  3.53it/s, loss=0.388]


[Epoch 53] Average Loss: 0.4042


Epoch 54/100: 100%|██████████| 36/36 [00:10<00:00,  3.51it/s, loss=0.39] 


[Epoch 54] Average Loss: 0.4021


Epoch 55/100: 100%|██████████| 36/36 [00:10<00:00,  3.53it/s, loss=0.39] 


[Epoch 55] Average Loss: 0.4001


Epoch 56/100: 100%|██████████| 36/36 [00:10<00:00,  3.51it/s, loss=0.455]


[Epoch 56] Average Loss: 0.3995


Epoch 57/100: 100%|██████████| 36/36 [00:10<00:00,  3.52it/s, loss=0.427]


[Epoch 57] Average Loss: 0.3975


Epoch 58/100: 100%|██████████| 36/36 [00:10<00:00,  3.53it/s, loss=0.391]


[Epoch 58] Average Loss: 0.3945


Epoch 59/100: 100%|██████████| 36/36 [00:10<00:00,  3.52it/s, loss=0.412]


[Epoch 59] Average Loss: 0.3936


Epoch 60/100: 100%|██████████| 36/36 [00:10<00:00,  3.52it/s, loss=0.407]


[Epoch 60] Average Loss: 0.3913


Epoch 61/100: 100%|██████████| 36/36 [00:10<00:00,  3.52it/s, loss=0.376]


[Epoch 61] Average Loss: 0.3896


Epoch 62/100: 100%|██████████| 36/36 [00:10<00:00,  3.53it/s, loss=0.393]


[Epoch 62] Average Loss: 0.3885


Epoch 63/100: 100%|██████████| 36/36 [00:10<00:00,  3.52it/s, loss=0.359]


[Epoch 63] Average Loss: 0.3867


Epoch 64/100: 100%|██████████| 36/36 [00:10<00:00,  3.53it/s, loss=0.401]


[Epoch 64] Average Loss: 0.3860


Epoch 65/100: 100%|██████████| 36/36 [00:10<00:00,  3.51it/s, loss=0.454]


[Epoch 65] Average Loss: 0.3856


Epoch 66/100: 100%|██████████| 36/36 [00:10<00:00,  3.51it/s, loss=0.395]


[Epoch 66] Average Loss: 0.3832


Epoch 67/100: 100%|██████████| 36/36 [00:10<00:00,  3.52it/s, loss=0.376]


[Epoch 67] Average Loss: 0.3815


Epoch 68/100: 100%|██████████| 36/36 [00:10<00:00,  3.51it/s, loss=0.378]


[Epoch 68] Average Loss: 0.3805


Epoch 69/100: 100%|██████████| 36/36 [00:10<00:00,  3.53it/s, loss=0.328]


[Epoch 69] Average Loss: 0.3787


Epoch 70/100: 100%|██████████| 36/36 [00:10<00:00,  3.54it/s, loss=0.352]


[Epoch 70] Average Loss: 0.3781


Epoch 71/100: 100%|██████████| 36/36 [00:10<00:00,  3.53it/s, loss=0.371]


[Epoch 71] Average Loss: 0.3774


Epoch 72/100: 100%|██████████| 36/36 [00:10<00:00,  3.53it/s, loss=0.346]


[Epoch 72] Average Loss: 0.3760


Epoch 73/100: 100%|██████████| 36/36 [00:10<00:00,  3.50it/s, loss=0.392]


[Epoch 73] Average Loss: 0.3760


Epoch 74/100: 100%|██████████| 36/36 [00:10<00:00,  3.53it/s, loss=0.378]


[Epoch 74] Average Loss: 0.3751


Epoch 75/100: 100%|██████████| 36/36 [00:10<00:00,  3.52it/s, loss=0.375]


[Epoch 75] Average Loss: 0.3741


Epoch 76/100: 100%|██████████| 36/36 [00:10<00:00,  3.53it/s, loss=0.363]


[Epoch 76] Average Loss: 0.3732


Epoch 77/100: 100%|██████████| 36/36 [00:10<00:00,  3.53it/s, loss=0.405]


[Epoch 77] Average Loss: 0.3736


Epoch 78/100: 100%|██████████| 36/36 [00:10<00:00,  3.52it/s, loss=0.381]


[Epoch 78] Average Loss: 0.3725


Epoch 79/100: 100%|██████████| 36/36 [00:10<00:00,  3.53it/s, loss=0.376]


[Epoch 79] Average Loss: 0.3715


Epoch 80/100: 100%|██████████| 36/36 [00:10<00:00,  3.53it/s, loss=0.367]


[Epoch 80] Average Loss: 0.3710


Epoch 81/100: 100%|██████████| 36/36 [00:10<00:00,  3.52it/s, loss=0.381]


[Epoch 81] Average Loss: 0.3707


Epoch 82/100: 100%|██████████| 36/36 [00:10<00:00,  3.52it/s, loss=0.379]


[Epoch 82] Average Loss: 0.3702


Epoch 83/100: 100%|██████████| 36/36 [00:10<00:00,  3.53it/s, loss=0.34] 


[Epoch 83] Average Loss: 0.3691


Epoch 84/100: 100%|██████████| 36/36 [00:10<00:00,  3.53it/s, loss=0.333]


[Epoch 84] Average Loss: 0.3685


Epoch 85/100: 100%|██████████| 36/36 [00:10<00:00,  3.53it/s, loss=0.357]


[Epoch 85] Average Loss: 0.3687


Epoch 86/100: 100%|██████████| 36/36 [00:10<00:00,  3.52it/s, loss=0.378]


[Epoch 86] Average Loss: 0.3687


Epoch 87/100: 100%|██████████| 36/36 [00:10<00:00,  3.53it/s, loss=0.418]


[Epoch 87] Average Loss: 0.3693


Epoch 88/100: 100%|██████████| 36/36 [00:10<00:00,  3.53it/s, loss=0.324]


[Epoch 88] Average Loss: 0.3672


Epoch 89/100: 100%|██████████| 36/36 [00:10<00:00,  3.54it/s, loss=0.381]


[Epoch 89] Average Loss: 0.3681


Epoch 90/100: 100%|██████████| 36/36 [00:10<00:00,  3.54it/s, loss=0.398]


[Epoch 90] Average Loss: 0.3682


Epoch 91/100: 100%|██████████| 36/36 [00:10<00:00,  3.53it/s, loss=0.384]


[Epoch 91] Average Loss: 0.3677


Epoch 92/100: 100%|██████████| 36/36 [00:10<00:00,  3.52it/s, loss=0.376]


[Epoch 92] Average Loss: 0.3674


Epoch 93/100: 100%|██████████| 36/36 [00:10<00:00,  3.54it/s, loss=0.398]

[Epoch 93] Average Loss: 0.3677
Early stopping!





In [78]:
model = MiniLLaMA(vocab_size)
model.load_state_dict(torch.load("bestmusicLLaMA.pt", map_location='cpu'))
model.eval()

MiniLLaMA(
  (token_emb): Embedding(489, 256)
  (pos_emb): RotaryEmbedding()
  (blocks): ModuleList(
    (0-3): 4 x DecoderBlock(
      (norm1): RMSNorm()
      (attn): Attention(
        (qkv): Linear(in_features=256, out_features=768, bias=True)
        (out): Linear(in_features=256, out_features=256, bias=True)
      )
      (norm2): RMSNorm()
      (ff): SwiGLUFeedForward(
        (w1): Linear(in_features=256, out_features=1024, bias=True)
        (w2): Linear(in_features=256, out_features=1024, bias=True)
        (activation): Swish()
        (proj): Linear(in_features=1024, out_features=256, bias=True)
      )
    )
  )
  (norm): RMSNorm()
  (output): Linear(in_features=256, out_features=489, bias=True)
)

In [52]:
def generate_with_temperature(model, start_token, max_length=100, temperature=1.0, eos_token=None):
    model.eval()
    tokens = [start_token]
    input_ids = torch.tensor([tokens]).long() 
    
    for _ in range(max_length):
        with torch.no_grad():
            logits = model(input_ids)  
            next_token_logits = logits[0, -1] / temperature
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()

        tokens.append(next_token)
        input_ids = torch.tensor([tokens]).long()
        
        if eos_token is not None and next_token == eos_token:
            break

    return tokens

In [60]:
start_token = tokenizer.vocab["Bar_None"]
tokens = generate_with_temperature(
    model,
    start_token=start_token,  
    max_length=1024,
    temperature=1.0,
    eos_token=None  
)

In [68]:
temperatures = [0.75, 1, 1.25]
for temperature in temperatures:
    tokens = generate_with_temperature(
    model,
    start_token=start_token,  
    max_length=512,
    temperature=temperature,
    eos_token=None  
    )
    tok_seq = TokSequence(ids=tokens)

    score = tokenizer.decode(tok_seq)
    score.dump_midi(f"generated_remi_{temperature:.2f}.mid")

In [80]:
tokens = generate_with_temperature(
    model,
    start_token=start_token,  
    max_length=512,
    temperature=1.0,
    eos_token=None  
    )
tok_seq = TokSequence(ids=tokens)

score = tokenizer.decode(tok_seq)
score.dump_midi(f"generated.mid")