In [1]:
import os

# Pin this notebook to GPU index 1. This has to happen before `import torch` below,
# because CUDA device visibility is fixed once CUDA is initialised.
os.environ.update({'CUDA_VISIBLE_DEVICES': '1'})

import csv
import shutil
from datetime import datetime

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from torch.optim import Adam
import torch.nn.functional as F

from dataset.e_piano import create_epiano_datasets, create_pop909_datasets

from model.music_transformer import MusicTransformer

from model.discriminator import MusicDiscriminator

from model.loss import SmoothCrossEntropyLoss

from utilities.constants import *
from utilities.WGAN_GP import WassersteinLoss
from utilities.device import get_device, use_cuda
from utilities.lr_scheduling import LrStepTracker, get_lr
from utilities.argument_funcs import parse_train_args, print_train_args, write_model_params
from utilities.run_model import train_epoch, eval_model

In [2]:
import argparse

In [3]:
# Training configuration. The options are declared with argparse so the same names work
# from a script, but they are parsed from an empty list here. The notebook therefore always
# starts from the defaults below; override individual fields in later cells instead.
parser = argparse.ArgumentParser()

# --- Data / output locations ----------------------------------------------------------
parser.add_argument("-classic_input_dir", type=str, default="./dataset/e_piano", help="Folder of preprocessed and pickled e-piano (classical) midi files")
parser.add_argument("-pop_input_dir", type=str, default="./dataset/pop_pickle", help="Folder of preprocessed and pickled POP909 midi files")
parser.add_argument("-output_dir", type=str, default="./saved_models", help="Folder to save model weights. Saves one every epoch")
parser.add_argument("-weight_modulus", type=int, default=10, help="How often to save epoch weights (ex: value of 10 means save every 10 epochs)")
parser.add_argument("-print_modulus", type=int, default=50, help="How often to print train results for a batch (batch loss, learn rate, etc.)")

# --- Runtime --------------------------------------------------------------------------
parser.add_argument("-n_workers", type=int, default=4, help="Number of worker processes for the dataloader")
parser.add_argument("--force_cpu", action="store_true", help="Forces model to run on a cpu even when gpu is available")
parser.add_argument("--no_tensorboard", action="store_true", help="Turns off tensorboard result reporting")

# --- Training mode flags --------------------------------------------------------------
parser.add_argument("--gan", action="store_true", help="use generative adversarial training")
parser.add_argument("--creative", action="store_true", help="creative learning")

# --- Resuming -------------------------------------------------------------------------
parser.add_argument("-continue_weights", type=str, default=None, help="Model weights to continue training based on")
parser.add_argument("-continue_epoch", type=int, default=None, help="Epoch the continue_weights model was at")

# --- Optimisation ---------------------------------------------------------------------
parser.add_argument("-lr", type=float, default=None, help="Constant learn rate. Leave as None for a custom scheduler.")
parser.add_argument("-ce_smoothing", type=float, default=0.1, help="Smoothing parameter for smoothed cross entropy loss (0.0 disables smoothing)")
parser.add_argument("-batch_size", type=int, default=32, help="Batch size to use")
parser.add_argument("-epochs", type=int, default=100, help="Number of epochs to use")

# --- Model architecture ---------------------------------------------------------------
parser.add_argument("--rpr", action="store_true", help="Use a modified Transformer for Relative Position Representations")
parser.add_argument("-max_sequence", type=int, default=1536, help="Maximum midi sequence to consider")
parser.add_argument("-n_layers", type=int, default=6, help="Number of decoder layers to use")
parser.add_argument("-num_heads", type=int, default=8, help="Number of heads to use for multi-head attention")
parser.add_argument("-d_model", type=int, default=512, help="Dimension of the model (output dim of embedding layers, etc.)")
parser.add_argument("-dim_feedforward", type=int, default=1024, help="Dimension of the feedforward layer")
parser.add_argument("-dropout", type=float, default=0.1, help="Dropout rate")

# Parse an empty argv so Jupyter's own command-line arguments are ignored.
args = parser.parse_args(args=[])

In [4]:
# Notebook override: force the Relative Position Representation transformer variant on,
# regardless of the parsed default for --rpr.
setattr(args, "rpr", True)

In [5]:
# Learning-rate setup.
# - No constant -lr: use LR_DEFAULT_START plus an LrStepTracker warm-up schedule.
# - Resuming from continue_epoch: the tracker starts at init_step, i.e. the number of
#   batches already seen (presumably so the schedule resumes where it left off).
if args.lr is None:
    if args.continue_epoch is None:
        init_step = 0
    else:
        # `train_loader` is only created by the dataset/DataLoader cell further down.
        # Fail with an actionable message instead of a bare NameError on a fresh kernel.
        if "train_loader" not in globals():
            raise RuntimeError(
                "args.continue_epoch is set but train_loader does not exist yet; "
                "run the dataset/DataLoader cell before this one."
            )
        init_step = args.continue_epoch * len(train_loader)

    lr = LR_DEFAULT_START
    lr_stepper = LrStepTracker(args.d_model, SCHEDULER_WARMUP_STEPS, init_step)
else:
    lr = args.lr

In [6]:
# Build the combined e-piano + POP909 train/val/test sets, their DataLoaders, and the
# classifier.
# - e-piano comes with its own split.
# - POP909 is split 80/10/10 here. The split is seeded, so the validation set used for
#   best-checkpoint selection is identical on every run.
POP909_SPLIT_SEED = 0  # change to draw a different (still reproducible) POP909 split

train_dataset, val_dataset, test_dataset = create_epiano_datasets(args.classic_input_dir, args.max_sequence)

# Use the configured directory rather than a hard-coded path.
pop909_dataset = create_pop909_datasets(args.pop_input_dir, args.max_sequence)

n_pop = len(pop909_dataset)
n_pop_train = int(n_pop * 0.8)
n_pop_val = int(n_pop * 0.1)
n_pop_test = n_pop - n_pop_train - n_pop_val  # remainder, so the sizes always sum to n_pop
train_set, val_set, test_set = torch.utils.data.random_split(
    pop909_dataset,
    [n_pop_train, n_pop_val, n_pop_test],
    generator=torch.Generator().manual_seed(POP909_SPLIT_SEED),
)

train_dataset = torch.utils.data.ConcatDataset([train_dataset, train_set])
val_dataset = torch.utils.data.ConcatDataset([val_dataset, val_set])
# Include the held-out POP909 split too, so the test set mixes both sources like train/val.
test_dataset = torch.utils.data.ConcatDataset([test_dataset, test_set])

train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.n_workers, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=args.n_workers)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.n_workers)

# The discriminator uses half of the configured layers / heads / d_model / feedforward width.
classifier = MusicDiscriminator(n_layers=args.n_layers // 2, num_heads=args.num_heads // 2,
                                d_model=args.d_model // 2, dim_feedforward=args.dim_feedforward // 2, dropout=args.dropout,
                                max_sequence=args.max_sequence, rpr=args.rpr).to(get_device())

In [7]:
import time

In [8]:
# Mean-squared error between the classifier output and the float-cast label.
# NOTE(review): the label is treated as binary downstream (thresholded at 0.5). If
# MusicDiscriminator outputs probabilities, BCE would be the conventional choice;
# confirm the output range before switching.
classifier_loss = nn.MSELoss(reduction="mean")

In [9]:
# Adam optimizer for the classifier. The betas and epsilon come from utilities.constants.
classifier_opt = Adam(classifier.parameters(), lr=lr, betas=(ADAM_BETA_1, ADAM_BETA_2), eps=ADAM_EPSILON)

# Per-step LR schedule when no constant -lr was given.
# The training loop tests `classifier_lr_scheduler is not None`, so this name must be
# bound in both branches; otherwise setting -lr raises NameError there.
if args.lr is None:
    classifier_lr_scheduler = LambdaLR(classifier_opt, lr_stepper.step)
else:
    classifier_lr_scheduler = None

In [10]:
# Notebook override of the parsed -print_modulus default (50).
# NOTE(review): the classifier training loop in this notebook does not read
# args.print_modulus, so this currently has no visible effect.
setattr(args, "print_modulus", 200)

In [None]:
# Number of classifier epochs.
# NOTE(review): this is independent of args.epochs (default 100); confirm which is intended.
CLASSIFIER_EPOCHS = 300


def _run_classifier_epoch(loader, train):
    """Run the classifier over every batch of ``loader`` once.

    With ``train=True`` the classifier is in train mode. It is updated with
    ``classifier_opt`` after each batch, and ``classifier_lr_scheduler`` is stepped
    per batch when it is not None. With ``train=False`` it runs in eval mode with
    autograd disabled, so no computation graph is built.

    Args:
        loader: DataLoader whose batches hold token ids at index 1 and the label at
            index 2 (index 0 is not used here).
        train: whether to update the classifier's weights.

    Returns:
        ``(mean_loss, mean_accuracy)``, both averaged over batches.

    Raises:
        ValueError: if ``loader`` has no batches (averages would be undefined).
    """
    num_batches = len(loader)
    if num_batches == 0:
        raise ValueError("loader is empty; cannot compute per-epoch averages")

    device = get_device()
    classifier.train(train)  # train(False) is equivalent to eval()

    loss_sum = 0.0
    acc_sum = 0.0
    with torch.set_grad_enabled(train):
        for batch in loader:
            tgt = batch[1].to(device)
            label = batch[2].to(device)

            # The classifier consumes one-hot encoded token sequences.
            tgt_one_hot = F.one_hot(tgt, num_classes=VOCAB_SIZE).float()
            pred = classifier(tgt_one_hot)
            loss = classifier_loss(pred, label.float())

            if train:
                classifier_opt.zero_grad()
                loss.backward()
                classifier_opt.step()
                if classifier_lr_scheduler is not None:
                    classifier_lr_scheduler.step()

            loss_sum += loss.item()
            # Threshold at 0.5 to get a predicted class, then take the per-batch mean accuracy.
            # NOTE(review): assumes pred and label have the same shape. A (B, 1) vs (B,)
            # mismatch would broadcast to (B, B) here; confirm MusicDiscriminator's output.
            acc_sum += ((pred > 0.5).float() == label).float().mean().item()

    return loss_sum / num_batches, acc_sum / num_batches


best_acc = 0

for epoch in range(CLASSIFIER_EPOCHS):
    train_loss, train_acc = _run_classifier_epoch(train_loader, train=True)
    val_loss, val_acc = _run_classifier_epoch(val_loader, train=False)

    # Save a checkpoint whenever validation accuracy improves. The accuracy is part of
    # the filename, so each improvement writes a new file in the working directory.
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(classifier.state_dict(), f'best_classifier_acc_{best_acc:.4f}.pickle')

    print(SEPERATOR)
    print(f"Epoch {epoch}")
    print(
        f"Classifier LR: {get_lr(classifier_opt)}")
    print(f"Classifier Train Loss: {train_loss:.5f}, Val Loss: {val_loss:.5f}")
    print(f"Classifier Train Accuracy: {train_acc:.5f}, Val Accuracy: {val_acc:.5f}")
    print(SEPERATOR)
    print("")

Epoch 0
Classifier LR: 9.258718969335066e-06
Classifier Train Loss: 0.25494, Val Loss: 0.26370
Classifier Train Accuracy: 0.51973, Val Accuracy: 0.53516

Epoch 1
Classifier LR: 1.8517437938670133e-05
Classifier Train Loss: 0.25079, Val Loss: 0.25167
Classifier Train Accuracy: 0.54603, Val Accuracy: 0.53516

Epoch 2
Classifier LR: 2.77761569080052e-05
Classifier Train Loss: 0.25048, Val Loss: 0.25861
Classifier Train Accuracy: 0.54784, Val Accuracy: 0.53516

Epoch 3
Classifier LR: 3.7034875877340266e-05
Classifier Train Loss: 0.24870, Val Loss: 0.24893
Classifier Train Accuracy: 0.56427, Val Accuracy: 0.53516

Epoch 4
Classifier LR: 4.6293594846675334e-05
Classifier Train Loss: 0.25024, Val Loss: 0.24907
Classifier Train Accuracy: 0.56899, Val Accuracy: 0.53516

Epoch 5
Classifier LR: 5.55523138160104e-05
Classifier Train Loss: 0.24669, Val Loss: 0.24763
Classifier Train Accuracy: 0.56144, Val Accuracy: 0.53516

Epoch 6
Classifier LR: 6.481103278534547e-05
Classifier Train Loss: 0.24844