In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [2]:
from easydict import EasyDict
dims = EasyDict(interval = 100,
                velocity = 32,
                note_on = 128,
                note_off = 128,
                pedal_on = 1,
                pedal_off = 1)

offsets = EasyDict(interval = 100,
                   velocity = dims.interval,
                   note_on = dims.interval + dims.velocity,
                   note_off = dims.interval + dims.velocity + dims.note_on,
                   pedal_on = dims.interval + dims.velocity + dims.note_on + dims.note_off,
                   pedal_off = dims.interval + dims.velocity + dims.note_on + dims.note_off + dims.pedal_on)

dataset_hparams = EasyDict(root_dir = 'dataset/',
                           max_note_duration = 2, # seconds)
                           token_length = 1024,
                           dims = dims,
                           offsets = offsets
                          )
model_hparams = EasyDict(n_tokens = offsets.pedal_off,
                         embedding_dim = 512,
                         hidden_dim = 1024
                        )      

In [16]:
class Model(nn.Module):
    def __init__(self, model_hparams):
        super().__init__()
        self.hp = model_hparams
        self.step = nn.Parameter(torch.zeros(1).long(), requires_grad=False)
        self.embedding = nn.Embedding(self.hp.n_tokens, self.hp.embedding_dim)
        self.rnn = nn.LSTM(input_size=self.hp.embedding_dim, hidden_size=self.hp.hidden_dim,
                        num_layers=3, batch_first=True, dropout=0.1)
        self.out_layer = nn.Linear(self.hp.hidden_dim, self.hp.n_tokens)
        
    def forward(self, x):
        # x : (batch, length)
        
        # (batch, length, model_dim)
        x = self.embedding(x)
        # (batch, length, hidden_dim)
        x, _ = self.rnn(x)
        # (batch, length, n_tokens)
        x = self.out_layer(x)
        return x
        

In [17]:
model = Model(model_hparams)
print(model)

Model(
  (embedding): Embedding(389, 512)
  (rnn): LSTM(512, 1024, num_layers=3, batch_first=True, dropout=0.1)
  (out_layer): Linear(in_features=1024, out_features=389, bias=True)
)


In [18]:
t = torch.randint(high=200, size=(2, 100))
t = model(t)
print(t.shape)

torch.Size([2, 100, 389])


In [25]:
model.step[0] = 1

In [None]:
tprc