In [64]:
import argparse
import logging
import pathlib
import pprint
import subprocess
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.utils.data
import tqdm
from torch.xpu import device_of

import dataset
import music_x_transformers
import representation
import utils

import torch
import representation
from music_x_transformers import MusicXTransformer

device = torch.device("cuda")

In [65]:
import torch
import torch.nn as nn

In [66]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [67]:
# import another midi file and encode
import representation
import muspy
import matplotlib.pyplot as plt
import os

# Paths
encoding_path = "../data/sod/processed/notes/encoding.json"
midi_in_path = "../data/sod/SOD/Kunstderfuge/0/1.mid"

# Set the output base to your Desktop
desktop = os.path.expanduser("~/Desktop")
out_base = os.path.join(desktop, "1_out")

# Load encoding
encoding = representation.load_encoding(encoding_path)

In [68]:
# Read MIDI as MusPy object
music = muspy.read_midi(midi_in_path)

# Fix resolution mismatch if needed
expected_resolution = encoding["resolution"]
if music.resolution != expected_resolution:
    music.adjust_resolution(expected_resolution)


In [69]:
codes = representation.encode(music, encoding)  # shape: (seq_len, 6)

In [70]:
codes.shape

(14550, 6)

In [71]:
entire_len = codes.shape[0]

In [72]:
entire_len

14550

In [73]:
import numpy as np

instruments = codes[:, 5]
pitches = codes[:, 3]

# Find min/max pitch per instrument
pitch_ranges = {}
for inst in np.unique(instruments):
    inst_pitches = pitches[instruments == inst]
    pitch_ranges[inst] = (inst_pitches.min(), inst_pitches.max())

# For each event, sample a random pitch within the range of its instrument
rand_pitches = np.zeros_like(pitches)
for idx, inst in enumerate(instruments):
    lo, hi = pitch_ranges[inst]
    rand_pitches[idx] = np.random.randint(lo, hi + 1)  # inclusive

In [74]:
rand_pitches.min()

0

In [75]:
# Exclude zeros
nonzero_pitches = pitches[pitches != 0]
min_pitch = nonzero_pitches.min()
print(min_pitch)

21


In [76]:
orig_durations = codes[:, 4]
log_durs = np.log(orig_durations + 1e-8)  # avoid log(0)
mean_log_dur = log_durs.mean()
std_log_dur = log_durs.std()
nonzero_durations = orig_durations[orig_durations > 0]
min_dur = nonzero_durations.min()
max_dur = nonzero_durations.max()

mean_log_dur_tensor = torch.tensor(mean_log_dur)

min_dur_tensor = torch.tensor(min_dur, dtype=mean_log_dur_tensor.dtype)
max_dur_tensor = torch.tensor(max_dur, dtype=mean_log_dur_tensor.dtype)

min_dur_tensor = min_dur_tensor.to(device)
max_dur_tensor = max_dur_tensor.to(device)

max_seq_len = 256


rand_log_dur = np.random.normal(mean_log_dur, std_log_dur, size=orig_durations.shape)
rand_durations = np.exp(rand_log_dur)

In [77]:
rand_durations.max()

224.43119936902576

In [78]:
max_dur

32

In [83]:
# 2. Recreate model with training hyperparameters
model = MusicXTransformer(
    dim=64,
    encoding=encoding,
    depth=3,
    heads=4,
    max_seq_len=256,
    max_beat=64,
    rotary_pos_emb=False,
    use_abs_pos_emb=True,
    emb_dropout=0.2,
    attn_dropout=0.2,
    ff_dropout=0.2,
)

In [84]:
model = model.to(device)

In [85]:
model.load_state_dict(torch.load("../exp/sod/ape/checkpoints/model_200.pt", map_location=device))

<All keys matched successfully>

In [86]:
# class SuperModel(nn.Module):
#     def __init__(self, , max_seq_len=32):
#         super(SuperModel, self).__init__()
#
#         self.pit = nn.Parameter(torch.tensor(rand_pitches, dtype=torch.float32).reshape(1, -1, 1))  # shape (1, 11302, 1)
#         self.dur = nn.Parameter(torch.tensor(rand_durations, dtype=torch.float32).reshape(1, -1, 1))
#
#         self.lower_bound = torch.max(torch.exp(mean_log_dur)/8, min_dur_tensor).item()
#         self.upper_bound = torch.min(torch.exp(mean_log_dur)*128, torch.tensor(32).to(device)).item()
#
#     def forward(self, ix):  #pitches must just have 1 batch. This is to just work for 1 batch, to overfit is good.
#         quantized_pit = torch.round(self.pit).clamp(21, 100)
#         clamped_dur = self.dur.clamp(self.lower_bound, self.upper_bound)
#
#
#
#         x = model(seq)
#


In [87]:
class SuperModel(nn.Module):
    def __init__(self, mean_log_dur_tensor, min_dur_tensor, device, max_seq_len=1024):
        super(SuperModel, self).__init__()
        self.pit = nn.Parameter(torch.tensor(rand_pitches, dtype=torch.float32, device=device).reshape(1, -1, 1))
        self.dur = nn.Parameter(torch.tensor(rand_durations, dtype=torch.float32, device=device).reshape(1, -1, 1))
        self.lower_bound = torch.max(torch.exp(mean_log_dur_tensor)/8, min_dur_tensor).item()
        self.upper_bound = torch.min(torch.exp(mean_log_dur_tensor)*128, torch.tensor(32, device=device)).item()
        self.max_seq_len = max_seq_len

    def forward(self, codes, ix):
        # codes: numpy array (N, 6)
        # ix: int (start index)
        seq = torch.tensor(codes[ix:ix+self.max_seq_len], dtype=torch.long, device=self.pit.device)

        quantized_pit = torch.round(self.pit).clamp(21, 100)  # (1, N, 1)
        clamped_dur = self.dur.clamp(self.lower_bound, self.upper_bound)  # (1, N, 1)

        # Extract pitch and duration chunk
        pitch_chunk = quantized_pit[0, ix:ix+self.max_seq_len, 0]
        dur_chunk = clamped_dur[0, ix:ix+self.max_seq_len, 0]

        # Round duration, set any positive value <1 to 1
        rounded_dur = torch.round(dur_chunk)
        rounded_dur = torch.where((rounded_dur < 1) & (dur_chunk > 0), torch.ones_like(rounded_dur), rounded_dur)

        # Replace pitch and duration fields
        seq[:, 3] = pitch_chunk.long()
        seq[:, 4] = rounded_dur.long()

         # Post-process the beat field
        orig_first = seq[0, 1].item()
        seq[:, 1] = seq[:, 1] - orig_first + 1
        seq[:, 1] = torch.clamp(seq[:, 1], min=0)

        seq = seq.unsqueeze(0)




        # Pass to your model
        x = model(seq)

        return x

In [88]:
seq = torch.tensor(codes[10000:10000+32], dtype=torch.long)


In [89]:
seq.shape

torch.Size([32, 6])

In [90]:
supermodel = SuperModel(


    mean_log_dur_tensor,
    min_dur_tensor,
    device,
    max_seq_len
)

In [91]:
 codes[0:32]

array([[ 0,  0,  0,  0,  0,  0],
       [ 1,  0,  0,  0,  0,  1],
       [ 1,  0,  0,  0,  0,  6],
       [ 1,  0,  0,  0,  0, 25],
       [ 1,  0,  0,  0,  0, 26],
       [ 1,  0,  0,  0,  0, 27],
       [ 1,  0,  0,  0,  0, 28],
       [ 1,  0,  0,  0,  0, 29],
       [ 1,  0,  0,  0,  0, 31],
       [ 1,  0,  0,  0,  0, 35],
       [ 1,  0,  0,  0,  0, 38],
       [ 1,  0,  0,  0,  0, 45],
       [ 1,  0,  0,  0,  0, 46],
       [ 1,  0,  0,  0,  0, 47],
       [ 1,  0,  0,  0,  0, 48],
       [ 1,  0,  0,  0,  0, 50],
       [ 2,  0,  0,  0,  0,  0],
       [ 3,  7,  1, 50, 12, 31],
       [ 3,  7,  1, 50, 25, 47],
       [ 3,  7,  1, 57, 12, 31],
       [ 3,  7,  1, 57, 25, 47],
       [ 3,  7,  1, 58, 23, 48],
       [ 3,  7,  1, 62, 23, 48],
       [ 3,  7,  1, 69, 23, 46],
       [ 3,  8,  1, 31, 12, 31],
       [ 3,  8,  1, 43, 12, 31],
       [ 3,  8,  1, 62, 12, 31],
       [ 3,  8,  1, 69, 12, 31],
       [ 3,  9,  1, 74, 18, 31],
       [ 3,  9,  1, 81, 18, 31],
       [ 3

In [92]:
codes.shape

(14550, 6)

In [93]:
supermodel(codes,14000)

tensor(0.0116, device='cuda:0', grad_fn=<MseLossBackward0>)

In [95]:
import torch.optim as optim

optimizer = optim.Adam(supermodel.parameters(), lr=1e-3)

loss_total = 0.0

supermodel.train()  # Set model to training mode
for ix in range(1, 1400):  # inclusive of 14000
    loss = supermodel(codes, ix)
    # If your loss is a tensor, make sure to use .item() or .sum() as appropriate
    loss_total += loss if isinstance(loss, float) else loss.sum()

# If you want to backpropagate:
optimizer.zero_grad()
loss_total.backward()
optimizer.step()

In [96]:
num_steps = 2000

for step in range(num_steps):
    supermodel.train()
    loss_total = 0.0

    for ix in range(1, 1400):  # You probably meant 14000 for full range
        loss = supermodel(codes, ix)
        loss_total += loss if isinstance(loss, float) else loss.sum()

    optimizer.zero_grad()
    loss_total.backward()
    optimizer.step()

    print(f"Step {step+1}/{num_steps} - Loss: {loss_total.item() if hasattr(loss_total, 'item') else loss_total}")

Step 1/2000 - Loss: 27.806901931762695
Step 2/2000 - Loss: 27.38190269470215
Step 3/2000 - Loss: 27.271635055541992
Step 4/2000 - Loss: 28.607372283935547


KeyboardInterrupt: 

In [100]:
supermodel = SuperModel(


    mean_log_dur_tensor,
    min_dur_tensor,
    device,
    max_seq_len
)

In [98]:
num_steps = 500

for step in range(num_steps):
    supermodel.train()
    loss_total = 0.0

    for ix in range(1, 700):  # You probably meant 14000 for full range
        loss = supermodel(codes, ix)
        loss_total += loss if isinstance(loss, float) else loss.sum()

    optimizer.zero_grad()
    loss_total.backward()
    optimizer.step()

    print(f"Step {step+1}/{num_steps} - Loss: {loss_total.item() if hasattr(loss_total, 'item') else loss_total}")

Step 1/500 - Loss: 14.627222061157227
Step 2/500 - Loss: 14.74473762512207
Step 3/500 - Loss: 13.444860458374023
Step 4/500 - Loss: 13.211654663085938
Step 5/500 - Loss: 13.355987548828125
Step 6/500 - Loss: 12.792304992675781
Step 7/500 - Loss: 13.47061824798584
Step 8/500 - Loss: 12.651981353759766
Step 9/500 - Loss: 13.394171714782715
Step 10/500 - Loss: 13.922688484191895
Step 11/500 - Loss: 14.595026969909668
Step 12/500 - Loss: 14.697186470031738
Step 13/500 - Loss: 13.850632667541504
Step 14/500 - Loss: 13.1456298828125
Step 15/500 - Loss: 12.221887588500977
Step 16/500 - Loss: 13.107438087463379
Step 17/500 - Loss: 14.181408882141113
Step 18/500 - Loss: 12.432570457458496
Step 19/500 - Loss: 12.96461009979248
Step 20/500 - Loss: 14.853802680969238
Step 21/500 - Loss: 13.6891508102417
Step 22/500 - Loss: 14.524425506591797
Step 23/500 - Loss: 14.052000045776367
Step 24/500 - Loss: 12.678755760192871
Step 25/500 - Loss: 14.360322952270508
Step 26/500 - Loss: 13.136216163635254
St

KeyboardInterrupt: 

In [101]:
import copy

num_steps = 500
best_loss = float('inf')
best_model_state = None   # To hold the best model's state_dict

for step in range(num_steps):
    supermodel.train()
    loss_total = 0.0

    for ix in range(1, 1400):  # Or 14001 for full range
        loss = supermodel(codes, ix)
        loss_total += loss if isinstance(loss, float) else loss.sum()

    optimizer.zero_grad()
    loss_total.backward()
    optimizer.step()

    loss_scalar = loss_total.item() if hasattr(loss_total, 'item') else float(loss_total)
    print(f"Step {step+1}/{num_steps} - Loss: {loss_scalar}")

    # Checkpoint the best model
    if loss_scalar < best_loss:
        best_loss = loss_scalar
        best_model_state = copy.deepcopy(supermodel.state_dict())
        # Optionally, save to disk
        torch.save(best_model_state, "best_supermodel.pt")
        print(f"New best model found at step {step+1} with loss {best_loss}")

# After training, you can restore the best model:
# supermodel.load_state_dict(best_model_state)

Step 1/500 - Loss: 26.40045738220215
New best model found at step 1 with loss 26.40045738220215
Step 2/500 - Loss: 26.350906372070312
New best model found at step 2 with loss 26.350906372070312
Step 3/500 - Loss: 26.426298141479492
Step 4/500 - Loss: 26.94700813293457
Step 5/500 - Loss: 27.481698989868164
Step 6/500 - Loss: 26.522672653198242
Step 7/500 - Loss: 25.663589477539062
New best model found at step 7 with loss 25.663589477539062
Step 8/500 - Loss: 27.199718475341797
Step 9/500 - Loss: 29.32910919189453
Step 10/500 - Loss: 24.1258602142334
New best model found at step 10 with loss 24.1258602142334
Step 11/500 - Loss: 29.485454559326172
Step 12/500 - Loss: 26.24025535583496
Step 13/500 - Loss: 27.39120864868164
Step 14/500 - Loss: 26.175060272216797
Step 15/500 - Loss: 27.111841201782227
Step 16/500 - Loss: 25.47618865966797
Step 17/500 - Loss: 26.4117374420166
Step 18/500 - Loss: 27.417335510253906
Step 19/500 - Loss: 27.10292625427246
Step 20/500 - Loss: 28.680923461914062
St

KeyboardInterrupt: 

In [105]:


# If you saved best_model_state as a file
supermodel.load_state_dict(torch.load("best_supermodel.pt"))

# Now retrieve pit and dur (detach and move to cpu if needed)
pit = supermodel.pit.detach().cpu()
dur = supermodel.dur.detach().cpu()

print("pit shape:", pit.shape)
print("dur shape:", dur.shape)

pit shape: torch.Size([1, 14550, 1])
dur shape: torch.Size([1, 14550, 1])


In [106]:
quantized_pit = torch.round(pit).clamp(21, 100)  # (1, N, 1)
clamped_dur = dur.clamp(torch.max(torch.exp(mean_log_dur_tensor)/8, min_dur_tensor).item(),torch.min(torch.exp(mean_log_dur_tensor)*128, torch.tensor(32, device=device)).item())  # (1, N, 1)

In [108]:
quantized_pit.shape

torch.Size([1, 14550, 1])

In [109]:
codes.shape

(14550, 6)

In [110]:
quantized_pit_np = quantized_pit.squeeze().cpu().numpy()
clamped_dur_np = clamped_dur.squeeze().cpu().numpy()

# Replace in codes array
codes[:, 3] = quantized_pit_np
codes[:, 4] = clamped_dur_np


In [113]:
codes[:1400, :]

array([[  0,   0,   0,  21,   9,   0],
       [  1,   0,   0,  71,   2,   1],
       [  1,   0,   0,  21,   2,   6],
       ...,
       [  3, 110,   1,  27,  13,  29],
       [  3, 110,   1,  68,   2,  29],
       [  3, 110,   1,  90,  12,  29]])

In [112]:
import representation
import muspy
import matplotlib.pyplot as plt
import os

# Paths
encoding_path = "../data/sod/processed/notes/encoding.json"


# Set the output base to your Desktop
desktop = os.path.expanduser("~/Desktop")
out_base = os.path.join(desktop, "somethingnew")

# Load encoding
encoding = representation.load_encoding(encoding_path)

In [114]:
music_decoded = representation.decode(codes[:1400,:], encoding)

In [None]:
# Save as MIDI
music_decoded.write(f"{out_base}.mid")

# Save as MusPy JSON
music_decoded.save(f"{out_base}.json")

# Save as piano roll PNG
music_decoded.show_pianoroll(track_label="program")
plt.savefig(f"{out_base}.png")
plt.close()