In [None]:
from pathlib import Path
import numpy as np
import torch.nn as nn
import torch
from dataclasses import dataclass

In [161]:
@dataclass
class ModelConfig:
    vocab_size: int = -1
    d_model: int = 128
    n_head: int = 2
    num_encoder_layers: int = 2
    dim_feedforward: int = 128
    max_seq_length: int = 512

config = ModelConfig()

In [2]:
midi_dir = Path("pop1k7/midi_analyzed")

In [3]:
from midi_player import MIDIPlayer
from midi_player.stylers import basic, cifka_advanced

# midi_path = "pop1k7/midi_transcribed/src_001/3.midi"
midi_path = "tutorial/000.mid"
MIDIPlayer(midi_path, 400)  


In [166]:
from miditok import REMI, TokenizerConfig
tokenizer = REMI(TokenizerConfig())

config.vocab_size = tokenizer.vocab_size

In [167]:
len(tokenizer.vocab), tokenizer["BOS_None"], tokenizer["EOS_None"], tokenizer['PAD_None']

(282, 1, 2, 0)

In [168]:
sample_path = "pop1k7/midi_transcribed/src_001/0.midi"
tokens = tokenizer(Path(sample_path))
len(tokens[0][0:10])
tokens[0][0:10]

TokSequence(tokens=['Bar_None', 'Position_24', 'Pitch_45', 'Velocity_59', 'Duration_6.1.4', 'Bar_None', 'Position_2', 'Pitch_52', 'Velocity_63', 'Duration_2.6.8'], ids=[4, 213, 29, 107, 165, 4, 191, 36, 108, 146], bytes='', events=[Event(type=Bar, value=None, time=0, desc=0), Event(type=Position, value=24, time=24, desc=24), Event(type=Pitch, value=45, time=24, desc=74), Event(type=Velocity, value=59, time=24, desc=59), Event(type=Duration, value=6.1.4, time=24, desc=50 ticks), Event(type=Bar, value=None, time=32, desc=0), Event(type=Position, value=2, time=34, desc=34), Event(type=Pitch, value=52, time=34, desc=56), Event(type=Velocity, value=63, time=34, desc=63), Event(type=Duration, value=2.6.8, time=34, desc=22 ticks)], are_ids_encoded=False, _ticks_bars=[0, 32, 64, 96, 128, 160, 192, 224, 256, 288, 320, 352, 384, 416, 448, 480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800, 832, 864, 896, 928, 960, 992, 1024, 1056, 1088, 1120, 1152, 1184, 1216, 1248, 1280, 1312, 1344, 1376, 14

In [169]:
from pathlib import Path

from miditok import REMI, TokenizerConfig
from miditok.pytorch_data import DataCollator, DatasetMIDI
from torch.utils.data import DataLoader

midi_dir = Path("pop1k7/midi_analyzed")
dataset = DatasetMIDI(
    files_paths=list(midi_dir.glob("**/*.mid")),
    tokenizer=tokenizer,
    max_seq_len=config.max_seq_length,
    bos_token_id=tokenizer["BOS_None"],
    eos_token_id=tokenizer["EOS_None"],
)

collator = DataCollator(tokenizer.pad_token_id)
dataloader = DataLoader(dataset, batch_size=3, collate_fn=collator)

u = next(iter(dataloader))
tokens = u['input_ids'][1]
tokens

tensor([  1,   4,   4, 189,  38, 107, 141,  41, 107, 140,  45, 108, 140,  50,
        109, 140, 205,  41, 107, 133,  45, 107, 132,  48, 107, 133,  53, 108,
        132, 213,  45, 108, 133,  33, 107, 133,  36, 106, 133,  40, 107, 132,
          4, 189,  34, 107, 141,  38, 108, 141,  41, 106, 141,  46, 107, 141,
        205,  33, 106, 132,  40, 107, 132,  45, 107, 132,  36, 105, 132, 213,
         33, 108, 133,  36, 107, 132,  40, 108, 133,  45, 109, 132,   4, 189,
         38, 108, 141,  41, 107, 140,  45, 109, 140,  50, 108, 140, 205,  41,
        107, 133,  45, 107, 132,  48, 107, 134,  53, 107, 132, 213,  33, 107,
        133,  40, 107, 132,  45, 107, 133,  36, 107, 133,   4, 189,  34, 107,
        141,  38, 108, 141,  41, 107, 142,  46, 107, 141, 205,  33, 107, 132,
         40, 107, 132,  45, 107, 132,  36, 105, 132, 213,  40, 108, 133,  45,
        109, 132,  33, 108, 156,  36, 107, 133,  48, 107, 125,   4, 189,  38,
        108, 141,  41, 107, 140,  45, 108, 140,  50, 108, 141, 2

In [None]:
from transformers.models.gpt2 import GPT2Config, GPT2Model

config = GPT2Config(
    vocab_size=tokenizer.vocab_size,
    n_positions=1024,
    n_embd=768,
    n_layer=6,        # Number of decoder layers
    n_head=6          # Number of attention heads
)

model = GPT2Model(config)


In [None]:
y = next(iter(dataloader))
ids, casual_mask = y['input_ids'], y['attention_mask']
y = model(ids, attention_mask=casual_mask)

In [176]:


class DummyDecoder(nn.Module):
    def __init__(self):
        super(DummyDecoder, self).__init__()

    def forward(self, tgt, memory, **kwarg):
        return memory


class PositionalEmbedding(nn.Module):
    def __init__(self, max_seq_length, d_model):
        super(PositionalEmbedding, self).__init__()
        self.position_embedding = nn.Embedding(max_seq_length, d_model)
        
    def forward(self, x):
        seq_length = x.size(1)
        position_ids = torch.arange(seq_length, device=x.device).unsqueeze(0)
        positional_embeddings = self.position_embedding(position_ids)

        return positional_embeddings



class PopTransformer(nn.Module):
    def __init__(
        self,
        vocab_size,
        max_seq_length,
        d_model,
        nhead,
        num_encoder_layers,
        dim_feedforward,
    ):
        super().__init__()

        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=0,
            dim_feedforward=dim_feedforward,
            custom_decoder=DummyDecoder(),
        )

        self.linear = nn.Linear(d_model, vocab_size)
        nn.init.zeros_(self.linear.bias)

        self.position_embedding = PositionalEmbedding(max_seq_length, d_model)
        self.token_embedding = nn.Embedding(vocab_size, d_model)


    def forward(self, tokens):
        ids, attn_mask = tokens["input_ids"], tokens["attention_mask"]
        n_batch, n_seq = ids.shape
        casual_mask = self.transformer.generate_square_subsequent_mask(n_seq)

        # n_batch, n_seq, n_embed
        y = self.token_embedding(ids) + self.position_embedding(ids)
        y = y.permute(1, 0, 2)  # n_seq, n_batch, n_embed
        y = self.transformer(
            src=y,
            tgt=y,
            src_mask=casual_mask,
            src_key_padding_mask=attn_mask.float(),
            src_is_causal=True,
            memory_is_causal=True,
        )
        y = y.permute(1, 0, 2)  # n_batch, n_seq, d_model
        y = self.linear(y)

        return y

        # logits = y.view(n_batch * n_seq, vocab_size)
        # target = ids.flatten()
        # loss = nn.CrossEntropyLoss(ignore_index=tokenizer["PAD_None"])
        # loss = loss(logits, target)
        # loss


model = PopTransformer(
    vocab_size=config.vocab_size,
    d_model=config.d_model,
    nhead=config.n_head,
    num_encoder_layers=config.num_encoder_layers,
    dim_feedforward=config.dim_feedforward,
    max_seq_length=config.max_seq_length
)

u = next(iter(dataloader))
model(u).shape



torch.Size([3, 511, 282])

In [178]:
model = PopTransformer(
    vocab_size=config.vocab_size,
    d_model=config.d_model,
    nhead=config.n_head,
    num_encoder_layers=config.num_encoder_layers,
    dim_feedforward=config.dim_feedforward,
    max_seq_length=config.max_seq_length
)
optim = torch.optim.Adam(model.parameters())
ctriterion = nn.CrossEntropyLoss(ignore_index=tokenizer["PAD_None"])
data = next(iter(dataloader))


In [180]:
for _ in range(1000):
    model.train()
    optim.zero_grad()

    x = data["input_ids"]
    y_pred = model(data)

    n_batch, n_seq, vocab_size = y_pred.shape
    y_pred = y_pred.view(n_batch * n_seq, vocab_size)
    x = x.flatten()
    loss = ctriterion(y_pred, x)

    loss.backward()
    optim.step()

    print(loss.item())

3.512178659439087
3.3456530570983887
3.1785552501678467
3.0197675228118896
2.8517987728118896
2.7026479244232178
2.5597283840179443
2.4318881034851074
2.2887723445892334
2.158266305923462
2.048886775970459
1.925934910774231
1.815224289894104
1.7157306671142578
1.6132076978683472
1.5271167755126953
1.4288290739059448
1.345755696296692
1.2672566175460815
1.182913899421692
1.1164321899414062
1.0528889894485474
0.9845254421234131
0.9337214827537537
0.8644852042198181
0.8142988085746765
0.7659730315208435
0.7150060534477234
0.6733663082122803
0.626871645450592
0.5946667194366455
0.5508633255958557
0.517102062702179
0.4909314513206482
0.4616050720214844
0.43234673142433167
0.3973710834980011
0.3805658221244812
0.35750722885131836
0.33327537775039673
0.3133848011493683
0.2891451418399811
0.2765183448791504
0.25594601035118103
0.2440302073955536
0.229014590382576
0.21349020302295685
0.20331361889839172
0.19222809374332428
0.182834655046463
0.17245422303676605
0.16269183158874512
0.155989497900

KeyboardInterrupt: 