In [1]:
# set cuda visible device
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"  # or "1", "0,1", etc.

### Clean Path and Import

In [2]:
def switch_to_mamba_only():
    import sys

    desired_path = "/home/yihsin/simba-ldm-midi"
    unwanted_path = "/home/yihsin/mamba"

    # Remove unwanted path
    if unwanted_path in sys.path:
        sys.path.remove(unwanted_path)

    # Remove any cached modules from unwanted path
    for name, module in list(sys.modules.items()):
        if hasattr(module, '__file__') and module.__file__ and unwanted_path in module.__file__:
            del sys.modules[name]

    # Add desired path if not already
    if desired_path not in sys.path:
        sys.path.insert(0, desired_path)

switch_to_mamba_only()

In [3]:
# python 
import json
import random
from tqdm import tqdm

# pytorch
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# custom model

from pl_model import Text_Mmamba_pl

### Util Functions (pre-defined)

In [12]:
import os
import pickle

def write_pkl(obj, filename):
    with open(filename, 'wb') as f:
        pickle.dump(obj, f)

def read_pkl(filename):
    with open(filename, 'rb') as f:
        return pickle.load(f)

def get_dict(dict_path="/home/yihsin/mamba/vocab/skyline2midi_vocab.pkl"):
    return read_pkl(dict_path)

In [13]:
# define tokenizer
from miditok import REMI, TokenizerConfig  # here we choose to use REMI

# Our parameters <- use default tokens & remi+ for first trial
TOKENIZER_PARAMS = {
    "pitch_range": (21, 109),
    "beat_res": {(0, 4): 8, (4, 12): 4},
    "num_velocities": 32,
    "special_tokens": ["PAD", "BOS", "EOS", "MASK"],
    "use_chords": True,
    "use_rests": False,
    "use_tempos": True,
    "use_time_signatures": True, #remi+
    "use_programs": True,        #remi+
    "num_tempos": 32,  # number of tempo bins
    "tempo_range": (40, 250),  # (min, max)
    "one_token_stream_for_programs": True
}
config = TokenizerConfig(**TOKENIZER_PARAMS)

# Creates the tokenizer
tokenizer = REMI(config)
event2idx = tokenizer.vocab
# idx2event[i] = event corresponding to index i
idx2event = [event for event, idx in sorted(event2idx.items(), key=lambda x: x[1])]

  super().__init__(tokenizer_config, params)


### Util Functions (Custom)

In [14]:
# model set up
import torch
from tqdm import tqdm

def top_k_logits(logits, k):
    """Keep only top k logits."""
    v, ix = torch.topk(logits, k)
    mask = logits < v[..., -1, None]
    logits[mask] = -float('Inf')
    return logits

def top_p_logits(logits, p=0.9):
    """Nucleus (top-p) filtering"""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

    sorted_mask = cumulative_probs > p
    # Shift mask to include the first token above p
    sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
    sorted_mask[..., 0] = 0

    # Scatter back to original ordering
    mask = torch.zeros_like(logits, dtype=torch.bool).scatter_(-1, sorted_indices, sorted_mask)
    logits[mask] = -float('Inf')
    return logits

@torch.no_grad()
def generate_autoregressively(
    model,
    seq_len=100,
    codec_layer=1,
    temperature=1.0,
    sample=True,
    prompt=None,
    top_k=None,
    top_p=None
):
    device = next(model.parameters()).device

    if prompt is None:
        # Default start: token 0, shape [1, codec_layer, 10]
        generated = torch.zeros((1, codec_layer, 10), dtype=torch.long, device=device)
    else:
        generated = torch.tensor(prompt, device=device).unsqueeze(0).unsqueeze(0)

    for _ in tqdm(range(seq_len), desc="Generating"):
        # output = model(generated, [], [])  # shape: [1, codec_layer, seq_len, vocab_size]
        output = model.forward(generated.squeeze(1)).unsqueeze(1)
        logits = output[:, :, -1:, :]  # shape: [1, codec_layer, 1, vocab_size]
        logits = logits / temperature

        if top_k is not None:
            logits = top_k_logits(logits, top_k)
        if top_p is not None:
            logits = top_p_logits(logits, top_p)

        probs = torch.softmax(logits, dim=-1)

        if sample:
            next_token = torch.multinomial(probs.view(-1, probs.shape[-1]), num_samples=1)
        else:
            next_token = torch.argmax(probs, dim=-1).view(-1, 1)

        next_token = next_token.unsqueeze(0)  # shape: [1, codec_layer, 1]
        generated = torch.cat([generated, next_token], dim=2)

    return generated

### Start Generation

In [4]:
# path and dictionary setups
# event2idx, idx2event = get_dict()
dataset_path = "/home/yihsin/midicaps-mini-parsed"

In [5]:
# model setup
model_path = "/home/yihsin/simba-ldm-midi/midicaps-gen/r9ka35jh/checkpoints/epoch=83-step=11000.ckpt"
config_path = "/home/yihsin/simba-ldm-midi/0530-simple-trial/new_project/config.json"
generation_root = "./generation_0507_step11000_test"

In [6]:
with open(config_path) as f:
    config = json.load(f)

model = Text_Mmamba_pl.load_from_checkpoint(model_path, config)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)
model.eval()
model.freeze()

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


[simba] not in-context, use cross attention


In [7]:
from dataloader import *
train_data = MIDICaps_Dataset(
    root_path = dataset_path, 
    trv = "valid",
    codec_layer=1, 
    is_incontext = False
)
loader = DataLoader(dataset=train_data, batch_size = 1, shuffle=True, num_workers=4, pin_memory=True)

[dataloader] valid set initialization done with 168 files.


In [8]:
import random

# Convert DataLoader to a list to access individual samples (careful with large datasets)
samples = list(loader)

# Randomly choose 3 samples
chosen_samples = random.sample(samples, 3)

# Iterate through them

In [9]:
L = 2588//3
prompt_seqs = []
with torch.autocast(device_type="cuda", dtype=torch.float32):
    with torch.no_grad():
        device = 'cuda'
        for i, sample in enumerate(chosen_samples):
            print(f"Sample {i + 1}:")
            # assuming sample is (input, label)
            x, mask, y, des = sample
            print(des["description"])
            prompt_seq = model(description=des["description"], length=L, g_scale=3)
            prompt_seqs.append(prompt_seq)

Sample 1:
['A melodic and relaxing jazz song featuring acoustic guitar, piano, acoustic bass, and drums, perfect for the Christmas season or as background music for a documentary. Set in the key of F major with a 4/4 time signature, it maintains an Allegro tempo throughout its duration.']


100%|██████████| 872/872 [01:33<00:00,  9.31it/s]


Sample 2:
['A melodic rock song with electronic elements, featuring a clean electric guitar playing the lead, accompanied by a string ensemble, fretless bass, and voice oohs. Set in the key of F minor with a 4/4 time signature, it moves at an Andante tempo, creating a relaxing and somewhat dark atmosphere. The chord progression of C#, Eb, and Fm is prominent throughout the piece, adding to its emotive quality.']


100%|██████████| 872/872 [01:24<00:00, 10.29it/s]


Sample 3:
["This lengthy pop composition exudes a joyful and melodic energy, driven by a dynamic blend of drums, piano, brass section, alto saxophone, and electric bass. Set in A minor and maintaining a moderate 4/4 tempo, it's an uplifting piece that's well-suited for corporate or background settings."]


100%|██████████| 872/872 [01:28<00:00,  9.83it/s]


In [19]:
tokens = prompt_seqs[1].reshape(-1)
tokens = tokens[tokens != 530]
generated_midi = tokenizer(tokens)  # MidiTok can handle PyTorch/Numpy/Tensorflow tensors
generated_midi.dump_midi("./decoded_midi.mid")

In [None]:
# test different configurations
test_tempatures = [0.9,1.2,1.5]
test_p = [0.9]
test_k = [5]

In [None]:
if not os.path.exists(generation_root):
    os.makedirs(generation_root)

for sample in samples:
    sample_idx = sample.split(".")[0]
    print(f"--- sample {sample_idx} ---")
    sample_save_dir = os.path.join(generation_root,sample_idx)
    if not os.path.exists(sample_save_dir):
        os.makedirs(sample_save_dir)

    # save prompt for reference
    test_sample = os.path.join(dataset_path,sample)
    idxs = read_pkl(test_sample)[:150]
    events = [idx2event[i] for i in idxs if i<366]
    event_to_midi(events,'full',output_midi_path=os.path.join(sample_save_dir,f"{sample_idx}_orig.mid"))
    prompt = idxs

    # test generation of different configurations
    for tt in test_tempatures:
        for tp in test_p:
            for tk in test_k:
                print(f"start generation of {sample_idx}_t={tt}_k={tk}_p={tp}.mid")
                gen = generate_autoregressively(
                    model.music_model, 
                    seq_len=2000, 
                    codec_layer=1, 
                    temperature=tt, 
                    top_k = tk,
                    top_p = tp,
                    sample=True, 
                    prompt=prompt
                )
                flat_list = gen.view(-1).tolist()  # or tensor.squeeze().tolist()
                events = [idx2event[i] for i in flat_list if i<366]
                event_to_midi(events,'full',output_midi_path=os.path.join(sample_save_dir,f"{sample_idx}_t={tt}_k={tk}_p={tp}.mid"))

    # test generation arg_max
    gen = generate_autoregressively(model.music_model, seq_len=2000, sample=False, prompt=prompt)
    flat_list = gen.view(-1).tolist()  
    events = [idx2event[i] for i in flat_list if i<366]
    event_to_midi(events,'full',output_midi_path=os.path.join(sample_save_dir,f"{sample_idx}_argmax.mid"))


--- sample 004_740 ---
[Event(name: Bar | value: None), Event(name: Beat | value: 0), Event(name: Tempo | value: 143), Event(name: Chord | value: None_None), Event(name: Beat | value: 8), Event(name: Tempo | value: Conti), Event(name: Chord | value: Conti_Conti), Event(name: Beat | value: 12), Event(name: Tempo | value: Conti), Event(name: Chord | value: Conti_Conti), Event(name: Bar | value: None), Event(name: Beat | value: 0), Event(name: Tempo | value: 146), Event(name: Chord | value: 11_sus2), Event(name: Note_Pitch | value: 71), Event(name: Note_Duration | value: 1080), Event(name: Note_Velocity | value: 67), Event(name: Beat | value: 4), Event(name: Tempo | value: 143), Event(name: Chord | value: Conti_Conti)]
# tempo changes: 19 | # notes: 21
start generation of 004_740_t=0.9_k=5_p=0.9.mid


Generating:   0%|          | 0/2000 [00:00<?, ?it/s]

Generating: 100%|██████████| 2000/2000 [01:06<00:00, 29.88it/s]


[Event(name: Bar | value: None), Event(name: Beat | value: 0), Event(name: Tempo | value: 143), Event(name: Chord | value: None_None), Event(name: Beat | value: 8), Event(name: Tempo | value: Conti), Event(name: Chord | value: Conti_Conti), Event(name: Beat | value: 12), Event(name: Tempo | value: Conti), Event(name: Chord | value: Conti_Conti), Event(name: Bar | value: None), Event(name: Beat | value: 0), Event(name: Tempo | value: 146), Event(name: Chord | value: 11_sus2), Event(name: Note_Pitch | value: 71), Event(name: Note_Duration | value: 1080), Event(name: Note_Velocity | value: 67), Event(name: Beat | value: 4), Event(name: Tempo | value: 143), Event(name: Chord | value: Conti_Conti)]
# tempo changes: 53 | # notes: 377
start generation of 004_740_t=1.2_k=5_p=0.9.mid


Generating: 100%|██████████| 2000/2000 [01:05<00:00, 30.35it/s]


[Event(name: Bar | value: None), Event(name: Beat | value: 0), Event(name: Tempo | value: 143), Event(name: Chord | value: None_None), Event(name: Beat | value: 8), Event(name: Tempo | value: Conti), Event(name: Chord | value: Conti_Conti), Event(name: Beat | value: 12), Event(name: Tempo | value: Conti), Event(name: Chord | value: Conti_Conti), Event(name: Bar | value: None), Event(name: Beat | value: 0), Event(name: Tempo | value: 146), Event(name: Chord | value: 11_sus2), Event(name: Note_Pitch | value: 71), Event(name: Note_Duration | value: 1080), Event(name: Note_Velocity | value: 67), Event(name: Beat | value: 4), Event(name: Tempo | value: 143), Event(name: Chord | value: Conti_Conti)]
# tempo changes: 248 | # notes: 331
start generation of 004_740_t=1.5_k=5_p=0.9.mid


Generating: 100%|██████████| 2000/2000 [01:04<00:00, 30.89it/s]


[Event(name: Bar | value: None), Event(name: Beat | value: 0), Event(name: Tempo | value: 143), Event(name: Chord | value: None_None), Event(name: Beat | value: 8), Event(name: Tempo | value: Conti), Event(name: Chord | value: Conti_Conti), Event(name: Beat | value: 12), Event(name: Tempo | value: Conti), Event(name: Chord | value: Conti_Conti), Event(name: Bar | value: None), Event(name: Beat | value: 0), Event(name: Tempo | value: 146), Event(name: Chord | value: 11_sus2), Event(name: Note_Pitch | value: 71), Event(name: Note_Duration | value: 1080), Event(name: Note_Velocity | value: 67), Event(name: Beat | value: 4), Event(name: Tempo | value: 143), Event(name: Chord | value: Conti_Conti)]
# tempo changes: 227 | # notes: 334


Generating: 100%|██████████| 2000/2000 [01:04<00:00, 31.06it/s]


[Event(name: Bar | value: None), Event(name: Beat | value: 0), Event(name: Tempo | value: 143), Event(name: Chord | value: None_None), Event(name: Beat | value: 8), Event(name: Tempo | value: Conti), Event(name: Chord | value: Conti_Conti), Event(name: Beat | value: 12), Event(name: Tempo | value: Conti), Event(name: Chord | value: Conti_Conti), Event(name: Bar | value: None), Event(name: Beat | value: 0), Event(name: Tempo | value: 146), Event(name: Chord | value: 11_sus2), Event(name: Note_Pitch | value: 71), Event(name: Note_Duration | value: 1080), Event(name: Note_Velocity | value: 67), Event(name: Beat | value: 4), Event(name: Tempo | value: 143), Event(name: Chord | value: Conti_Conti)]
# tempo changes: 142 | # notes: 342
--- sample 004_878 ---
[Event(name: Bar | value: None), Event(name: Beat | value: 0), Event(name: Tempo | value: 89), Event(name: Chord | value: None_None), Event(name: Beat | value: 8), Event(name: Tempo | value: Conti), Event(name: Chord | value: Conti_Conti)

Generating: 100%|██████████| 2000/2000 [01:05<00:00, 30.53it/s]


[Event(name: Bar | value: None), Event(name: Beat | value: 0), Event(name: Tempo | value: 89), Event(name: Chord | value: None_None), Event(name: Beat | value: 8), Event(name: Tempo | value: Conti), Event(name: Chord | value: Conti_Conti), Event(name: Beat | value: 12), Event(name: Tempo | value: Conti), Event(name: Chord | value: Conti_Conti), Event(name: Bar | value: None), Event(name: Beat | value: 0), Event(name: Tempo | value: 92), Event(name: Chord | value: 0_M), Event(name: Note_Pitch | value: 60), Event(name: Note_Duration | value: 600), Event(name: Note_Velocity | value: 52), Event(name: Note_Pitch | value: 55), Event(name: Note_Duration | value: 720), Event(name: Note_Velocity | value: 46)]
# tempo changes: 17 | # notes: 596
start generation of 004_878_t=1.2_k=5_p=0.9.mid


Generating: 100%|██████████| 2000/2000 [01:05<00:00, 30.38it/s]


[Event(name: Bar | value: None), Event(name: Beat | value: 0), Event(name: Tempo | value: 89), Event(name: Chord | value: None_None), Event(name: Beat | value: 8), Event(name: Tempo | value: Conti), Event(name: Chord | value: Conti_Conti), Event(name: Beat | value: 12), Event(name: Tempo | value: Conti), Event(name: Chord | value: Conti_Conti), Event(name: Bar | value: None), Event(name: Beat | value: 0), Event(name: Tempo | value: 92), Event(name: Chord | value: 0_M), Event(name: Note_Pitch | value: 60), Event(name: Note_Duration | value: 600), Event(name: Note_Velocity | value: 52), Event(name: Note_Pitch | value: 55), Event(name: Note_Duration | value: 720), Event(name: Note_Velocity | value: 46)]
# tempo changes: 87 | # notes: 543
start generation of 004_878_t=1.5_k=5_p=0.9.mid


Generating: 100%|██████████| 2000/2000 [01:05<00:00, 30.59it/s]


[Event(name: Bar | value: None), Event(name: Beat | value: 0), Event(name: Tempo | value: 89), Event(name: Chord | value: None_None), Event(name: Beat | value: 8), Event(name: Tempo | value: Conti), Event(name: Chord | value: Conti_Conti), Event(name: Beat | value: 12), Event(name: Tempo | value: Conti), Event(name: Chord | value: Conti_Conti), Event(name: Bar | value: None), Event(name: Beat | value: 0), Event(name: Tempo | value: 92), Event(name: Chord | value: 0_M), Event(name: Note_Pitch | value: 60), Event(name: Note_Duration | value: 600), Event(name: Note_Velocity | value: 52), Event(name: Note_Pitch | value: 55), Event(name: Note_Duration | value: 720), Event(name: Note_Velocity | value: 46)]
# tempo changes: 52 | # notes: 620


Generating: 100%|██████████| 2000/2000 [01:04<00:00, 31.17it/s]


[Event(name: Bar | value: None), Event(name: Beat | value: 0), Event(name: Tempo | value: 89), Event(name: Chord | value: None_None), Event(name: Beat | value: 8), Event(name: Tempo | value: Conti), Event(name: Chord | value: Conti_Conti), Event(name: Beat | value: 12), Event(name: Tempo | value: Conti), Event(name: Chord | value: Conti_Conti), Event(name: Bar | value: None), Event(name: Beat | value: 0), Event(name: Tempo | value: 92), Event(name: Chord | value: 0_M), Event(name: Note_Pitch | value: 60), Event(name: Note_Duration | value: 600), Event(name: Note_Velocity | value: 52), Event(name: Note_Pitch | value: 55), Event(name: Note_Duration | value: 720), Event(name: Note_Velocity | value: 46)]
# tempo changes: 23 | # notes: 682
--- sample 002_253 ---
[Event(name: Bar | value: None), Event(name: Beat | value: 0), Event(name: Tempo | value: 92), Event(name: Chord | value: None_None), Event(name: Beat | value: 8), Event(name: Tempo | value: Conti), Event(name: Chord | value: Conti_

Generating: 100%|██████████| 2000/2000 [01:05<00:00, 30.31it/s]


[Event(name: Bar | value: None), Event(name: Beat | value: 0), Event(name: Tempo | value: 92), Event(name: Chord | value: None_None), Event(name: Beat | value: 8), Event(name: Tempo | value: Conti), Event(name: Chord | value: Conti_Conti), Event(name: Beat | value: 12), Event(name: Tempo | value: Conti), Event(name: Chord | value: Conti_Conti), Event(name: Bar | value: None), Event(name: Beat | value: 0), Event(name: Tempo | value: 89), Event(name: Chord | value: 10_m7), Event(name: Note_Pitch | value: 58), Event(name: Note_Duration | value: 240), Event(name: Note_Velocity | value: 43), Event(name: Note_Pitch | value: 53), Event(name: Note_Duration | value: 360), Event(name: Note_Velocity | value: 37)]
# tempo changes: 6 | # notes: 507
start generation of 002_253_t=1.2_k=5_p=0.9.mid


Generating: 100%|██████████| 2000/2000 [01:02<00:00, 32.05it/s]


[Event(name: Bar | value: None), Event(name: Beat | value: 0), Event(name: Tempo | value: 92), Event(name: Chord | value: None_None), Event(name: Beat | value: 8), Event(name: Tempo | value: Conti), Event(name: Chord | value: Conti_Conti), Event(name: Beat | value: 12), Event(name: Tempo | value: Conti), Event(name: Chord | value: Conti_Conti), Event(name: Bar | value: None), Event(name: Beat | value: 0), Event(name: Tempo | value: 89), Event(name: Chord | value: 10_m7), Event(name: Note_Pitch | value: 58), Event(name: Note_Duration | value: 240), Event(name: Note_Velocity | value: 43), Event(name: Note_Pitch | value: 53), Event(name: Note_Duration | value: 360), Event(name: Note_Velocity | value: 37)]
# tempo changes: 97 | # notes: 519
start generation of 002_253_t=1.5_k=5_p=0.9.mid


Generating: 100%|██████████| 2000/2000 [01:05<00:00, 30.67it/s]


[Event(name: Bar | value: None), Event(name: Beat | value: 0), Event(name: Tempo | value: 92), Event(name: Chord | value: None_None), Event(name: Beat | value: 8), Event(name: Tempo | value: Conti), Event(name: Chord | value: Conti_Conti), Event(name: Beat | value: 12), Event(name: Tempo | value: Conti), Event(name: Chord | value: Conti_Conti), Event(name: Bar | value: None), Event(name: Beat | value: 0), Event(name: Tempo | value: 89), Event(name: Chord | value: 10_m7), Event(name: Note_Pitch | value: 58), Event(name: Note_Duration | value: 240), Event(name: Note_Velocity | value: 43), Event(name: Note_Pitch | value: 53), Event(name: Note_Duration | value: 360), Event(name: Note_Velocity | value: 37)]
# tempo changes: 105 | # notes: 509


Generating: 100%|██████████| 2000/2000 [01:04<00:00, 31.17it/s]


[Event(name: Bar | value: None), Event(name: Beat | value: 0), Event(name: Tempo | value: 92), Event(name: Chord | value: None_None), Event(name: Beat | value: 8), Event(name: Tempo | value: Conti), Event(name: Chord | value: Conti_Conti), Event(name: Beat | value: 12), Event(name: Tempo | value: Conti), Event(name: Chord | value: Conti_Conti), Event(name: Bar | value: None), Event(name: Beat | value: 0), Event(name: Tempo | value: 89), Event(name: Chord | value: 10_m7), Event(name: Note_Pitch | value: 58), Event(name: Note_Duration | value: 240), Event(name: Note_Velocity | value: 43), Event(name: Note_Pitch | value: 53), Event(name: Note_Duration | value: 360), Event(name: Note_Velocity | value: 37)]
# tempo changes: 6 | # notes: 698


In [None]:
test_sample = dataset_path + "002_049.pkl"
idxs = read_pkl(prompt_path)[:300]

In [None]:
gen = generate_autoregressively(model.music_model, seq_len=100, codec_layer=1, temperature=1.5, sample=False, prompt=prompt)
flat_list = gen.view(-1).tolist()  # or tensor.squeeze().tolist()
events = [idx2event[i] for i in idxs if i<366]
event_to_midi(events,'full',output_midi_path="test_orig_033.mid")

In [None]:
directory = "/home/yihsin/mamba/notebooks/generation_0511_ep400_sa"
subdirs = [d for d in os.listdir(directory) if os.path.isdir(os.path.join(directory, d))]
print(subdirs)

['004_740', '004_878', '002_253', '004_602', '004_223', '003_261', '004_768', '004_167', '003_170', '002_361', '004_103', '004_540', '004_447', '003_251', '003_068', '004_880', '004_146', '003_035', '004_893', '002_171', '004_412', '004_939', '002_153']


In [None]:
import shutil
for subdir in subdirs:
    path = os.path.join(directory, subdir)
    shutil.move(os.path.join(path, f"{subdir}_orig.mid"), "/home/yihsin/mamba/notebooks/test_samples_0511/original")
    shutil.move(os.path.join(path, f"{subdir}_t=1.2_k=5_p=0.9.mid"), "/home/yihsin/mamba/notebooks/test_samples_0511/generated")