In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.models as models
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
import cv2
from tqdm import tqdm_notebook as tqdm
from sklearn.model_selection import train_test_split
import torchvision
from torchvision import transforms
import math
from PIL import Image
from torchsummary import summary 
from tqdm import trange 

import midi
import utils
import os
import numpy as np
import random

np.random.seed(0)
random.seed(0)

batch_size = 512
num_epochs = 4000
learning_rate = 0.001
hidden_size = 120
step_size = 800
gamma = 0.5
dataset_path = 'Music/'

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#torch.backends.cudnn.benchmark = True

print(device)

cuda:0


In [2]:
train_samples = []
train_lengths = []

for file in tqdm(os.listdir(dataset_path)):
    try:
        samples = midi.midi_to_samples(dataset_path + file)
    except:
        print ("ERROR ", dataset_path + file)
        continue
                
    if(len(samples) >= 8):
        samples, lengths = utils.generate_add_centered_transpose(samples)
        train_samples.extend(samples)
        train_lengths.extend(lengths)
        
y_samples = np.array(train_samples)
y_lengths = np.array(train_lengths)

y_test_song = np.expand_dims(np.copy(y_samples[0 : 16]), axis = 0)

y_samples = y_samples[2 * y_lengths[0] : ]
y_lengths = y_lengths[2 : ]

print(y_samples.shape)
print(y_lengths.shape)

num_samples = y_samples.shape[0]
num_songs = y_lengths.shape[0]

HBox(children=(IntProgress(value=0, max=3286), HTML(value='')))

ERROR  Music/Pokemon Mystery Dungeon Explorers of Sky - Defend Globe.mid
ERROR  Music/Sonic Unleashed - Windmill Isle Day.mid
ERROR  Music/Double Dragon II The Revenge - Undersea Base.mid
ERROR  Music/Super Mario Galaxy - End Title.mid
ERROR  Music/Lufia  The Fortress of Doom - Ending.mid
ERROR  Music/Kirby Super Star - Floria.mid
ERROR  Music/Terranigma - Evergreen Forest.mid
ERROR  Music/Mario Kart Wii - Mushroom Gorge.mid
ERROR  Music/Golden Axe II - The Tower.mid
ERROR  Music/Pokemon Diamond Version  Pokemon Pearl Version - Route 205 Day.mid
ERROR  Music/Dance Dance Revolution Hottest Party - Love Shine.mid
ERROR  Music/The Legend of Zelda Breath of the Wild - Spirit Orb Obtained.mid

(328434, 96, 96)
(5352,)


In [5]:
x_shape = (y_lengths.shape[0], 1)
y_shape = (y_lengths.shape[0], 16) + y_samples.shape[1:]

x_orig = np.expand_dims(np.arange(x_shape[0]), axis=-1)
y_orig = np.zeros(y_shape, dtype=y_samples.dtype)

cur_ix = 0

for i in trange(num_songs):
    ix = i
    end_ix = cur_ix + y_lengths[i]
        
    for j in range(16):
        k = j % (end_ix - cur_ix)
        y_orig[ix,j] = y_samples[cur_ix + k]
            
    cur_ix = end_ix

print(end_ix)
print(num_samples)
assert(end_ix == num_samples)

x_train = np.copy(x_orig)
y_train = np.copy(y_orig)

np.save('samples.npy', y_orig)
y_orig_tensor = torch.from_numpy(y_orig).to(device)

print(x_train.shape)
print(y_train.shape)

100%|██████████| 5352/5352 [00:00<00:00, 14242.88it/s]


328434
328434
(5352, 1)
(5352, 16, 96, 96)


In [None]:
# Defining our autoencoder model

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        
        self.fc1 = nn.Linear(96 * 96, 2000)
        self.fc2 = nn.Linear(2000, 200)
        self.fc3 = nn.Linear(16 * 200, 1600)
        self.fc4 = nn.Linear(1600, hidden_size)
        self.bn1 = nn.BatchNorm1d(hidden_size)

    def forward(self, x):
        x = x.view(x.size(0), x.size(1), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc3(x))
        x = self.bn1(self.fc4(x))
        
        return x
        
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        
        self.fc1 = nn.Linear(hidden_size, 1600)
        self.fc2 = nn.Linear(1600, 16 * 200)
        self.fc3 = nn.Linear(200, 2000)
        self.fc4 = nn.Linear(2000, 96 * 96)
        self.bn1 = nn.BatchNorm1d(1600)
        self.bn2 = nn.BatchNorm1d(16)
        self.bn3 = nn.BatchNorm1d(16)
        

    def forward(self, x):
        x = F.relu(self.bn1(self.fc1(x)))
        x = self.fc2(x)
        x = x.view(-1, 16, 200)
        x = F.relu(self.bn2(x))
        x = F.relu(self.bn3(self.fc3(x)))
        x = F.sigmoid(self.fc4(x))
        
        return x.view(-1, 16, 96, 96)
    
class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        
        self.enc = Encoder()
        self.dec = Decoder()
        
    def forward(self, x): 
        x = self.enc(x)
        x = self.dec(x)
        
        return x
    
    def encoder_forward(self, x):
        x = self.enc(x)
        
        return x
    
    def decoder_forward(self, x):
        x = self.dec(x)
        
        return x
    
model = Autoencoder().to(device)
summary(model, (16, 96, 96))

error = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

In [None]:
class MIDI_data(Dataset):
    def __init__(self, tX):
            self.X = tX
            self.y = tX
            
    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        return self.X[idx], self.X[idx]

In [None]:
ofs = 0

def make_rand_songs(write_dir, rand_vecs, thresh):
    for i in range(rand_vecs.shape[0]):
        x_rand = torch.from_numpy(rand_vecs[i : i + 1]).to(device).float()
        with torch.no_grad(): 
            y_song = model.decoder_forward(x_rand).cpu().numpy()
            midi.samples_to_midi(y_song[0], write_dir + 'rand' + str(i) + '.mid', 16, thresh)
            
def train(epoch):
    global ofs
    
    model.train()
    
    cur_ix = 0

    for i in range(num_songs):
        end_ix = cur_ix + y_lengths[i]
            
        for j in range(16):
            k = (j + ofs) % (end_ix - cur_ix)
            y_train[i,j] = y_samples[cur_ix + k]
                
        cur_ix = end_ix
    
    assert(end_ix == num_samples)
    ofs += 1
    
    
    train_dataset = MIDI_data(y_train)
    train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle = False, pin_memory = True)
    
    running_loss = 0.0 

    for i, (x, y) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()

        outputs = model(x.float())

        loss = error(outputs, y.float())

        loss.backward()

        optimizer.step()

        running_loss += loss.item()

        if (i + 1) % 5 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, (i + 1) * len(x), len(train_loader.dataset),
                    100. * (i + 1) / len(train_loader), running_loss / 5))
            
            running_loss = 0.0
            
def evaluate(thresh):
    model.eval()

    y_test_song_tensor = torch.from_numpy(y_test_song).to(device).float()


    with torch.no_grad(): 
        y_song = model(y_test_song_tensor).cpu().numpy()[0]

    midi.samples_to_midi(y_song, 'Output/' + 'test.mid', 16)        
    midi.samples_to_midi(y_test_song[0], 'Output/' + 'target.mid', 16)

    rand_vecs = np.random.normal(0.0, 1.0, (10, hidden_size))

    with torch.no_grad():        
        x_enc = np.squeeze(model.encoder_forward(y_orig_tensor.float()).cpu().numpy())

    x_mean = np.mean(x_enc, axis=0)
    x_stds = np.std(x_enc, axis=0)
    x_cov = np.cov((x_enc - x_mean).T)
    u, s, v = np.linalg.svd(x_cov)
    e = np.sqrt(s)

    x_vecs = x_mean + np.dot(rand_vecs * e, v)
    make_rand_songs('Output/', x_vecs, thresh)

In [None]:
for epoch in range(num_epochs):
    lr_scheduler.step()
    train(epoch)
    if (epoch + 1) % 10 == 0:
        evaluate(0.25)

In [None]:
evaluate(0.25)

In [None]:
torch.save(model.state_dict(), 'States/model_state_MSE_config')

def save_config():
    with open('MSE_config.txt', 'w') as fout:
        fout.write('BATCH_SIZE:  ' + str(batch_size) + '\n')
        fout.write('NUM_EPOCHS:  ' + str(num_epochs) + '\n')
        fout.write('LR:  ' + str(learning_rate) + '\n')
        fout.write('HIDDEN_SIZE:  ' + str(hidden_size) + '\n')
        fout.write('STEP_SIZE:  ' + str(step_size) + '\n')
        fout.write('GAMMA:  ' + str(gamma) + '\n')
        
save_config()