In [1]:
from miditoolkit.midi import parser as mid_parser  
from miditoolkit.midi import containers as ct
from transformers import BertConfig
from sklearn.model_selection import train_test_split
import numpy as np
import torch
import pickle

from MidiBERT.model import MidiBertSeq2Seq
from MidiBERT.modelLM import MidiBertSeq2SeqComplete

Matplotlib created a temporary config/cache directory at /tmp/matplotlib-8jc7uidv because the default path (/uac/ascstd/wkwong/.cache/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.


In [2]:
# Fall back to CPU so the notebook can at least run (slowly) without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
skyline_max_len = 90
hs = 768          # hidden size of the BERT encoder/decoder
seq_len = 512     # tokens per sequence
token_len = 6     # fields per compound token

# Token <-> id dictionaries. pickle.load can execute arbitrary code: only load
# files from a trusted source.
with open('dict/CP_program.pkl', 'rb') as f:
    e2w, w2e = pickle.load(f)

X = np.load('data/processed/String.npy', allow_pickle=True)
y = np.load('data/processed/String_ans.npy', allow_pickle=True)
X, y = torch.tensor(X), torch.tensor(y)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.15, random_state=42
)
print(X_test.shape, y_test.shape)

torch.Size([90, 512, 6]) torch.Size([90, 513, 6])


In [3]:
# Encoder and decoder share the same BERT hyper-parameters; the decoder config
# is additionally flagged as a decoder with cross-attention enabled.
bert_kwargs = dict(
    max_position_embeddings=seq_len,
    position_embedding_type="relative_key_query",
    hidden_size=hs,
)
config_en = BertConfig(**bert_kwargs)
config_de = BertConfig(**bert_kwargs)
config_de.is_decoder = True
config_de.add_cross_attention = True
midibert = MidiBertSeq2Seq(config_en, config_de, '', e2w, w2e)

model = MidiBertSeq2SeqComplete(midibert).to(device)
model.eval()

# map_location lets the checkpoint load regardless of the device it was saved from.
checkpoint = torch.load('result/seq2seq/MidiBert/model_best.ckpt', map_location=device)

# Keys presumably carry a leading "module." from a DataParallel-style wrapper.
# Strip only that leading prefix: str.replace would also rewrite "module."
# appearing inside a key (e.g. "submodule.weight" -> "subweight").
state_dict = checkpoint["state_dict"]
prefix = "module."
for key in list(state_dict):
    if key.startswith(prefix):
        state_dict[key[len(prefix):]] = state_dict.pop(key)
model.load_state_dict(state_dict)

Some weights of BertLMHeadModel were not initialized from the model checkpoint at ./s2s_decoder_model/ and are newly initialized: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


<All keys matched successfully>

In [4]:
def _special_token_ids(template):
    """Return one id per token field for a special token.

    ``template`` is a %-format string filled with each field name of
    ``midibert.e2w`` (e.g. ``"%s <BOS>"``); the result is a numpy array of the
    matching ids, in ``midibert.e2w`` iteration order.
    """
    field_ids = [midibert.e2w[field][template % field] for field in midibert.e2w]
    return np.array(field_ids)


BOS = _special_token_ids("%s <BOS>")
PAD = _special_token_ids("%s <PAD>")
EOS = _special_token_ids("%s <EOS>")
ABS = _special_token_ids("%s <ABS>")

In [5]:
@torch.no_grad()
def inference(token):
    """Greedily decode an output token sequence for a single encoder input.

    Args:
        token: tensor with ``seq_len * token_len`` elements; it is reshaped to
            ``(1, seq_len, token_len)`` and used as the encoder input.

    Returns:
        numpy array of shape ``(n_generated, token_len)`` with the generated
        tokens, without the leading BOS. Decoding stops as soon as an all-EOS
        or all-PAD token is predicted (that token is not kept), or after
        ``seq_len`` steps.

    Relies on the notebook globals ``model``, ``midibert``, ``device``,
    ``seq_len``, ``token_len``, ``BOS``, ``EOS`` and ``PAD``.
    """
    token = token.reshape((1, seq_len, token_len)).to(device)
    # 1.0 where the bar field is not the bar PAD id, 0.0 on padding.
    attn_mask_encoder = (token[:, :, 0] != midibert.bar_pad_word).float()  # (batch, seq_len)

    n_fields = len(midibert.e2w)
    outputs = np.array([BOS])
    for i in range(seq_len):
        # BOS + tokens generated so far, right-padded with PAD rows to seq_len rows.
        padding = np.tile(midibert.pad_word_np, (seq_len - 1 - i, 1))
        decoder_input_ids = torch.from_numpy(np.vstack((outputs, padding))).unsqueeze(0).to(device)
        attn_mask_decoder = (decoder_input_ids[:, :, 0] != midibert.bar_pad_word).float()  # (batch, seq_len)

        # One logits tensor per token field, each (batch, seq_len, n_tokens).
        predicted_word = model(token, decoder_input_ids, attn_mask_encoder, attn_mask_decoder)

        # Greedy pick for position i only: argmax on the device and copy back
        # just n_fields ids, instead of every field's full logits tensor.
        temp = torch.stack(
            [predicted_word[j][0, i].argmax() for j in range(n_fields)]
        ).cpu().numpy()

        # stop generating when EOS or PAD is generated
        is_end = (temp == EOS).all() or (temp == PAD).all()
        print(f'Generated {i} notes', end="\n" if is_end else "\r")
        if is_end:
            break
        outputs = np.vstack((outputs, temp))

    outputs = outputs[1:]  # drop BOS

    # Post-processing heuristic: a token with bar field 0 whose position
    # (field 1) is not smaller than the previous token's position gets its bar
    # field rewritten to 1 (presumably "same bar" rather than "new bar").
    last_pos = 999
    changed = 0
    for row in outputs:
        if row[1] >= last_pos and row[0] == 0:
            row[0] = 1  # row is a view, so this edits `outputs` in place
            changed += 1
        last_pos = row[1]
    print(f"Changed {changed} tokens")
    return outputs

In [6]:
def token2mid(page, out_path):
    """Render a sequence of compound tokens to a MIDI file at ``out_path``.

    Fields of each token row used here:
      [0] bar flag (0 starts a new bar), [1] onset within the bar in 1/12-beat
      units, [2] pitch (MIDI pitch = value + 22), [3] duration minus one in
      1/12-beat units, [4] MIDI program, [5] time-signature numerator minus 2
      (denominator is always 4).

    Rendering stops at the first all-EOS or all-PAD row. A row whose first five
    fields all equal the ABS ids only advances the bar. A note is emitted only
    when none of its first five fields equals the corresponding ABS id.
    If no tokens remain, an empty MIDI file is still written.

    Relies on the notebook globals ``EOS``, ``PAD`` and ``ABS``.
    """
    ticks_per_beat = 480
    ticks_per_unit = ticks_per_beat / 12  # onset/duration resolution: 1/12 beat

    out = mid_parser.MidiFile()
    out.ticks_per_beat = ticks_per_beat

    # Keep only the rows before the first EOS/PAD so that neither an empty page
    # nor a leading stop token is read as the first time signature.
    n_tokens = len(page)
    for idx, n in enumerate(page):
        if (n == EOS).all() or (n == PAD).all():
            n_tokens = idx
            break
    page = page[:n_tokens]
    if n_tokens == 0:
        out.dump(out_path)
        return

    # First time signature; start one bar before tick 0 so the first bar
    # advance lands on tick 0.
    last_ts = int(page[0][5] + 2)
    current_beat = -last_ts * ticks_per_beat
    out.time_signature_changes.append(ct.TimeSignature(last_ts, 4, 0))

    instruments = {}  # program -> Instrument, in order of first appearance
    for n in page:
        ts = int(n[5] + 2)  # time signature for THIS token

        # Bar moves forward by the length of the previous bar.
        if n[0] == 0 or (n[:-1] == ABS[:-1]).all():
            current_beat += last_ts * ticks_per_beat

        if ts != last_ts:
            last_ts = ts
            out.time_signature_changes.append(ct.TimeSignature(ts, 4, current_beat))

        if (n[:-1] != ABS[:-1]).all():
            program = int(n[4])
            instrument = instruments.get(program)
            if instrument is None:
                instrument = ct.Instrument(program=program, is_drum=False, name='reduction')
                out.instruments.append(instrument)
                instruments[program] = instrument
            instrument.notes.append(
                ct.Note(
                    start=int(current_beat + n[1] * ticks_per_unit),
                    end=int(current_beat + (n[1] + n[3] + 1) * ticks_per_unit),
                    pitch=int(n[2]) + 22,
                    velocity=90,
                )
            )

    out.dump(out_path)

In [7]:
# Get one held-out sample. New names (instead of overwriting X_test / y_test)
# keep this cell idempotent: re-running it does not re-slice a sliced tensor.
x_sample = X_test[0]
# Drop the target's first row (presumably BOS) to line up with inference output.
y_sample = y_test[0, 1:]

token2mid(y_sample.cpu().numpy(), "./test_ans.mid")
token2mid(x_sample.cpu().numpy(), "./test_input.mid")

all_tokens = inference(x_sample)
token2mid(all_tokens, "./test_gen.mid")

Generated 423 notes
Changed 1 tokens
