In [None]:
from pathlib import Path
import numpy as np
import torch.nn as nn
import torch
from dataclasses import dataclass
from pathlib import Path

from miditok import REMI, TokenizerConfig
from miditok.pytorch_data import DataCollator, DatasetMIDI
from torch.utils.data import DataLoader
from miditok import REMI, TokenizerConfig
from midi_player import MIDIPlayer
from model import ModelConfig, DatasetMIDI, PopTransformer


  from .autonotebook import tqdm as notebook_tqdm


In [45]:
config = ModelConfig()
tokenizer = REMI(TokenizerConfig())
config.vocab_size = tokenizer.vocab_size
midi_dir = Path("pop1k7/midi_analyzed")

dataset = DatasetMIDI(
    files_paths=list(midi_dir.glob("**/*.mid")),
    tokenizer=tokenizer,
    max_seq_len=config.max_seq_length,
    bos_token_id=tokenizer["BOS_None"],
    eos_token_id=tokenizer["EOS_None"],
)

collator = DataCollator(tokenizer.pad_token_id)
dataloader = DataLoader(dataset, batch_size=config.batch_size, collate_fn=collator)

bos_token = tokenizer['BOS_None']
eos_token = tokenizer['EOS_None']
pad_token = tokenizer['PAD_None']

In [None]:
def sample(tokenizer, model, max_len=2000):
    sample_size = 1
    bos_token = tokenizer['BOS_None']
    eos_token = tokenizer['EOS_None']
    pad_token = tokenizer['PAD_None']
    tokens = torch.tensor([bos_token]).reshape(sample_size, 1) # (batch_n, seq_len)

    for _ in range(max_len):
        padding_mask = torch.tensor([pad_token] * tokens.shape[1]).reshape(sample_size, -1)
        logits = model(tokens, padding_mask) # (batch_n, seq_len, vocab_size)
        prob = torch.nn.functional.softmax(logits, dim=2)
        next_token = torch.multinomial(prob[:, -1], 1) # (batch_n, 1)
        tokens = torch.cat((tokens, next_token), dim=1)

        if next_token[0, -1] == eos_token:
            break

    return tokens

# tokens = sample(tokenizer, model)
# len(tokens)

In [7]:
model = PopTransformer(
    vocab_size=config.vocab_size,
    d_model=config.d_model,
    nhead=config.n_head,
    num_encoder_layers=config.num_encoder_layers,
    dim_feedforward=config.dim_feedforward,
    max_seq_length=config.max_seq_length
)

optim = torch.optim.Adam(model.parameters())
ctriterion = nn.CrossEntropyLoss(ignore_index=tokenizer["PAD_None"])
data = next(iter(dataloader))




In [274]:
for i in range(100):
    model.train()
    optim.zero_grad()

    tokens = data["input_ids"]
    mask = data["attention_mask"]

    logits = model(tokens, mask)
    n_batch, n_seq, vocab_size = logits.shape
    logits = logits.view(n_batch * n_seq, vocab_size)
    tokens = tokens.flatten()
    loss = ctriterion(logits, tokens)

    loss.backward()
    optim.step()

    if i % 20 == 0:
        print(loss.item())

4.30021333694458
1.7786098718643188
0.6422502994537354
0.24322421848773956
0.11996977031230927


In [34]:
tokens = sample(tokenizer, model)
print(tokens.shape)

score = tokenizer.decode(tokens)
score.dump_midi("sample.mid")
MIDIPlayer("/Users/stevenkao/workspace/music-hw-3/sample.mid", 100)  

torch.Size([1, 101])


In [None]:
data = next(iter(dataloader))
tokens = data['input_ids']

tokens = tokens[0:1, :]
score = tokenizer.decode(tokens)
score.dump_midi("sample.mid")

MIDIPlayer("/Users/stevenkao/workspace/music-hw-3/sample.mid", 400)  

In [23]:
from transformers import GPT2Config, GPT2LMHeadModel

bos_token = tokenizer['BOS_None']
eos_token = tokenizer['EOS_None']
pad_token = tokenizer['PAD_None']

gpt_config = GPT2Config(
    vocab_size=config.vocab_size,
    n_positions=config.max_seq_length,
    n_embd=config.d_model,
    n_layer=config.num_encoder_layers,
    n_head=config.n_head,
    bos_token_id=bos_token,
    eos_token_id=eos_token,
)

gpt2 = GPT2LMHeadModel(gpt_config)

In [None]:

data = next(iter(dataloader))
tokens = data['input_ids']
attention_mask = data['attention_mask']
tokens.shape

out = gpt2(tokens, attention_mask=attention_mask)
out.logits.shape

torch.Size([32, 511, 282])

In [36]:
optim = torch.optim.Adam(gpt2.parameters())
ctriterion = nn.CrossEntropyLoss(ignore_index=tokenizer["PAD_None"])
data = next(iter(dataloader))

for i in range(100):
    gpt2.train()
    optim.zero_grad()

    tokens = data["input_ids"]
    mask = data["attention_mask"]
    out = gpt2(tokens, attention_mask=mask)
    logits = out.logits

    n_batch, n_seq, vocab_size = logits.shape
    logits = logits.view(n_batch * n_seq, vocab_size)
    tokens = tokens.flatten()
    loss = ctriterion(logits, tokens)

    loss.backward()
    optim.step()

    if i % 20 == 0:
        print(loss.item())

4.6560187339782715
1.5040385723114014
0.3334512412548065


KeyboardInterrupt: 

In [44]:
sample_size = 1
bos_token = tokenizer['BOS_None']
eos_token = tokenizer['EOS_None']
pad_token = tokenizer['PAD_None']
tokens = torch.tensor([bos_token]).reshape(sample_size, 1) # (batch_n, seq_len)

sample = gpt2.generate(tokens, do_sample=True, max_length=100)
sample
    # for _ in range(max_len):
    #     padding_mask = torch.tensor([pad_token] * tokens.shape[1]).reshape(sample_size, -1)
    #     logits = model(tokens, padding_mask) # (batch_n, seq_len, vocab_size)
    #     prob = torch.nn.functional.softmax(logits, dim=2)
    #     next_token = torch.multinomial(prob[:, -1], 1) # (batch_n, 1)
    #     tokens = torch.cat((tokens, next_token), dim=1)

    #     if next_token[0, -1] == eos_token:
    #         break

    # return tokens

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


tensor([[  1,   1,   1,   1,   1,   1, 130, 130, 130, 130, 130, 130, 130,  46,
          46,  46,  46,  46,  46,  46,  46,  46, 160, 160,  43,  43,  43,  43,
          43,  43,  43,  43,  43, 197, 197, 197, 197, 105, 104,  38,  38,  38,
          38,  38,  38, 105, 105, 105, 105, 105, 105, 105, 105, 143, 143, 143,
         143, 143, 143, 143, 145, 145, 145, 145, 145, 145, 145, 145, 145, 145,
         145,  54,  54,  54,  54,  54,  54,  54,  54,  54,  54,  54,  54,  54,
          54,  54,  54,  54,  54,  54,  54,  54,  54,  54,  54, 165, 165, 165,
         165, 165]])