In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from torch.utils.data import DataLoader
from dataset import MaestroDataset
from hparams import dataset_hparams, model_hparams
from model import Model

In [None]:
import os

# Restrict this kernel to GPU 0. PyTorch initialises CUDA lazily, so this only
# takes effect if no CUDA call has happened yet in this kernel; re-running this
# cell after the model has been moved to the GPU changes nothing.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

### Init Dataset

In [None]:
# Build the MAESTRO training set from its hyper-parameters and show its repr.
maestro_dataset = MaestroDataset(dataset_hparams)
print(maestro_dataset)

# Wrap it in a batching loader (16 samples per batch). DataLoader's default
# shuffle=False applies, so samples are served in dataset order.
# NOTE(review): consider shuffle=True if MaestroDataset is a map-style dataset.
train_loader = DataLoader(
    dataset=maestro_dataset,
    batch_size=16,
)
print(train_loader)

### Init Model

In [None]:
# Instantiate the model from its hyper-parameters and move it to the GPU
# before creating the optimizer, so Adam is built over the CUDA parameters.
model = Model(model_hparams)
model = model.cuda()
print(model)

# Adam with a fixed learning rate of 1e-4.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)
print('done')

### Create directory for checkpoints

In [None]:
# Create the directory the training loop saves checkpoints into (no-op if it
# already exists) and show what it contains. Pure Python instead of the
# `!mkdir -p` / `!ls` shell magics so the cell also works without a POSIX shell.
os.makedirs('checkpoints', exist_ok=True)
sorted(os.listdir('checkpoints'))

### Load a checkpoint (optional — only when resuming training)

In [None]:
# Optionally resume from a saved checkpoint. Leave RESUME_CHECKPOINT as None to
# train from scratch, or set it to a file written by the training loop, which
# saves to 'checkpoints/checkpoint_<step>' (e.g. 'checkpoints/checkpoint_20000').
RESUME_CHECKPOINT = None

if RESUME_CHECKPOINT is not None:
    # Load onto the CPU first; load_state_dict then copies the tensors onto the
    # device where the model/optimizer parameters already live.
    checkpoint = torch.load(RESUME_CHECKPOINT, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # NOTE(review): the step counter is restored only if `model.step` is part of
    # the model's state_dict (e.g. a registered buffer) — TODO confirm.
    print('done')

### Train Loop

In [None]:
import librosa.display
from IPython import display
import matplotlib.pyplot as plt

# Next-token cross-entropy; built once rather than on every step.
criterion = nn.CrossEntropyLoss()

# Train indefinitely, cycling over the loader; interrupt the kernel to stop.
while True:
    for batch in train_loader:

        # Advance the training-step counter kept on the model (`model.step`).
        model.step[0] += 1
        step = model.step.item()

        # Put the model in training mode.
        model.train()

        # Move the batch to the GPU.
        # Assumes batch is (batch, seq_len) token ids — TODO confirm.
        x = batch.cuda()

        # Reset all accumulated gradients.
        model.zero_grad()

        # Feed tokens [0 .. L-2] and predict tokens [1 .. L-1].
        # y : (batch, seq_len - 1, n_tokens) — scores for the next token.
        # NOTE(review): CrossEntropyLoss expects unnormalised logits; confirm
        # Model does not already apply a softmax.
        y = model(x[:, :-1])

        # Compare predictions with the shifted targets.
        loss = criterion(y.reshape(-1, model_hparams.n_tokens), x[:, 1:].reshape(-1))

        # Backpropagate from the loss to compute gradients.
        loss.backward()

        # Apply the gradients to the weights.
        optimizer.step()

        # Clear before printing so the latest loss line survives the wipe
        # (printing first would erase the step-100/200/... line immediately).
        if step % 100 == 0:
            display.clear_output()

        if step % 10 == 0:
            print('step :', step, 'loss : %0.4f' % loss.item())

        if step % 1000 == 0:
            save_path = 'checkpoints/checkpoint_' + str(step)
            torch.save({'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict()}, save_path)
