In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from torch.utils.data import DataLoader
from dataset import dataset_hparams, MaestroDataset
from model import model_hparams, Model

### Init Dataset

In [2]:
maestro_dataset = MaestroDataset(dataset_hparams)
print(maestro_dataset)

train_loader = DataLoader(maestro_dataset, batch_size=16)
print(train_loader)

<dataset.MaestroDataset object at 0x7fc274aadd90>
<torch.utils.data.dataloader.DataLoader object at 0x7fc2741fbfd0>


### Init Model

In [3]:
model = Model(model_hparams).cuda()
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
print('done')

Model(
  (embedding): Embedding(390, 512)
  (rnn): LSTM(512, 1024, num_layers=3, batch_first=True, dropout=0.1)
  (out_layer): Linear(in_features=1024, out_features=390, bias=True)
)
done


In [4]:
!mkdir checkpoints
!ls checkpoints

mkdir: cannot create directory ‘checkpoints’: File exists
checkpoint_1000   checkpoint_14000  checkpoint_19000  checkpoint_5000
checkpoint_10000  checkpoint_15000  checkpoint_2000   checkpoint_6000
checkpoint_11000  checkpoint_16000  checkpoint_20000  checkpoint_7000
checkpoint_12000  checkpoint_17000  checkpoint_3000   checkpoint_8000
checkpoint_13000  checkpoint_18000  checkpoint_4000   checkpoint_9000


In [6]:
# checkpoint = torch.load('checkpoints/checkpoint_20000', map_location=torch.device('cpu'))    
# model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# print('done')

done


### Train Loop

In [None]:
import librosa.display
from IPython import display
import matplotlib.pyplot as plt

while True:
    for i, batch in enumerate(train_loader):
        model.step[0] += 1
        step = model.step.item()
        
        model.train()
        x = batch.cuda()
        model.zero_grad()
        
        y = model(x[:, :-1])
        loss = nn.CrossEntropyLoss()(y.reshape(-1, model_hparams.n_tokens), x[:, 1:].reshape(-1))
        loss.backward()
        optimizer.step()
        
        if step % 10 == 0:
            print(step, loss.item())
        
        if step % 100 == 0:
            display.clear_output()
            
        if step % 1000 == 0:
            save_path = 'checkpoints/checkpoint_' + str(step)
            torch.save({'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict()}, save_path)
