# Generating Music using VAEs

In [None]:
import os
import shutil
import glob
import numpy as np 
import pandas as pd
import pretty_midi
import pypianoroll
import tables
from music21 import converter, instrument, note, chord, stream
import music21
import librosa
import librosa.display
import matplotlib.pyplot as plt
from keras.utils import np_utils
import json
import IPython.display
from datetime import datetime
import random

import torch
import torch.nn as nn
from torch.nn import functional as F
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

from tqdm.notebook import tqdm, trange

import random
import itertools
# Project folder on Google Drive, and the LPD-5 "cleansed" pianoroll collection inside it.
root_dir = 'drive/MyDrive/ProjectMusic'
data_dir = f'{root_dir}/Lakh Piano Dataset/LPD-5/lpd_5/lpd_5_cleansed'

In [None]:
!apt-get update -qq && apt-get install -qq libfluidsynth1 fluid-soundfont-gm build-essential libasound2-dev libjack-dev
!pip install -qU pyfluidsynth pretty_midi
!pip install music21
!pip install pypianoroll

In [None]:
# Colab only: mount Google Drive at /content/drive. `root_dir` ('drive/MyDrive/ProjectMusic')
# is a relative path, presumably resolved from Colab's default working directory /content.
from google.colab import drive
drive.mount('/content/drive')

**Getting MIDI and Song Metadata**

In [None]:
# Root of the LMD metadata; msd_id_to_h5 looks for the 'lmd_matched_h5' tree below it.
RESULTS_PATH = os.path.join(root_dir, 'Lakh Piano Dataset', 'Metadata')

# Utility functions for retrieving paths
def msd_id_to_dirs(msd_id):
    """Return the nested path prefix used for an MSD ID in the LMD/LPD trees.

    The 3rd, 4th and 5th characters of the ID become three directory levels,
    followed by the ID itself, e.g. TRABCD12345678 -> A/B/C/TRABCD12345678.
    """
    first, second, third = msd_id[2], msd_id[3], msd_id[4]
    return os.path.join(first, second, third, msd_id)


def msd_id_to_h5(msd_id):
    """Return the path of the LMD-matched HDF5 metadata file for ``msd_id``."""
    relative_path = msd_id_to_dirs(msd_id) + '.h5'
    return os.path.join(RESULTS_PATH, 'lmd_matched_h5', relative_path)

# Load the midi npz file from the LMD cleansed folder
def get_midi_npz_path(msd_id, midi_md5):
    """Return the path of a song's .npz pianoroll inside the LPD-5 cleansed folder.

    Args:
        msd_id: Million Song Dataset ID, used to build the nested directory prefix.
        midi_md5: LPD file ID (MIDI hash); it is the file's base name.
    """
    file_name = midi_md5 + '.npz'
    song_dir = msd_id_to_dirs(msd_id)
    return os.path.join(data_dir, song_dir, file_name)

In [None]:
# Open the cleansed ids file: column 0 = LPD (cleansed) file id, column 1 = MSD id.
# sep=r'\s+' splits on any run of whitespace and keeps pandas on its C parser; a literal
# multi-character delimiter is treated as a regex, forcing the Python engine with a ParserWarning.
cleansed_ids = pd.read_csv(os.path.join(root_dir, 'Lakh Piano Dataset', 'cleansed_ids.txt'), sep=r'\s+', header = None)
lpd_to_msd_ids = dict(zip(cleansed_ids[0], cleansed_ids[1]))
msd_to_lpd_ids = dict(zip(cleansed_ids[1], cleansed_ids[0]))

In [None]:
# Read the tagtraum genre annotations. Non-comment lines are tab-separated:
# "<track id>\t<genre>" or "<track id>\t<genre>\t<second genre>". Each genre becomes its own
# (track id, genre) row. Lines with any other number of fields are ignored.
genre_file_dir = os.path.join(root_dir, 'Lakh Piano Dataset', 'Genre', 'msd_tagtraum_cd1.cls')
ids = []
genres = []
with open(genre_file_dir) as genre_file:
    for raw_line in genre_file:
        if raw_line[0] == '#':
            continue
        fields = raw_line.strip().split("\t")
        if len(fields) not in (2, 3):
            continue
        for genre in fields[1:]:
            ids.append(fields[0])
            genres.append(genre)
genre_df = pd.DataFrame(data={"TrackID": ids, "Genre": genres})

# TrackID -> list of all genres annotated for that track
genre_dict = genre_df.groupby('TrackID')['Genre'].apply(lambda x: x.tolist()).to_dict()

**Objects that we need**

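- `cleansed_ids`: DataFrame read from `cleansed_ids.txt` (column 0 = LPD file ID, column 1 = MSD ID). `lpd_to_msd_ids` and `msd_to_lpd_ids` are the dictionaries built from it, one per direction
- `lmd_metadata`: dictionary of MSD ID → metadata record (each record also carries its own `msd_id` field)
- `get_midi_npz_path(msd_id, lpd_id)`: gives the actual path of a song's `.npz` pianoroll file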

In [None]:
# Load the processed metadata (a JSON list of records, each carrying an 'msd_id') and
# re-key it as a dictionary of MSD_ID: metadata record.
metadata_path = os.path.join(root_dir, 'Lakh Piano Dataset', 'processed_metadata.json')
with open(metadata_path, 'r') as metadata_file:
  metadata_records = json.load(metadata_file)

lmd_metadata = {record['msd_id']: record for record in metadata_records}

In [None]:
# Get the MSD IDs of all songs whose artist terms include 'rock'.
metal_song_msd_ids = [k for k, v in lmd_metadata.items() if 'rock' in v.get('artist_terms', [])]

# Draw up to 2000 *distinct* candidate songs. random.choices sampled with replacement, which
# produced duplicate songs. The extraction cell keeps the first 1000 usable ones.
N_CANDIDATE_SONGS = 2000
train_ids = random.sample(metal_song_msd_ids, k = min(N_CANDIDATE_SONGS, len(metal_song_msd_ids)))

In [None]:
# Build one [5, time, pitch] tensor per song with the tracks stacked in the order
# Piano, Guitar, Bass, Strings, Drums. Only songs whose five parts are all present and
# non-empty are kept, until N_SONGS_TO_KEEP songs are collected. Values are the pianoroll
# entries as loaded (velocities); they can be standardized to 0/1 later to compress.
N_SONGS_TO_KEEP = 1000
TRACK_ORDER = ['Piano', 'Guitar', 'Bass', 'Strings', 'Drums']

combined_pianorolls = []
for msd_file_name in tqdm(train_ids, desc='Loading songs'):
  lpd_file_name = msd_to_lpd_ids.get(msd_file_name)
  if lpd_file_name is None:
    # Not part of the LPD cleansed subset, so there is no .npz to load.
    continue

  npz_path = get_midi_npz_path(msd_file_name, lpd_file_name)
  multitrack = pypianoroll.load(npz_path)
  multitrack.set_resolution(2).pad_to_same()

  # Track name -> pianoroll (a later track with a duplicate name wins).
  rolls_by_name = {track.name: track.pianoroll for track in multitrack.tracks}
  part_rolls = [rolls_by_name.get(name) for name in TRACK_ORDER]

  # Skip songs with a missing (None) or zero-length part; such songs were never kept.
  # NOTE(review): if pad_to_same() zero-pads empty tracks to the song length, a silent part
  # passes this check; `not roll.any()` would catch it. Confirm against pypianoroll.
  if any(roll is None or roll.shape[0] == 0 for roll in part_rolls):
    continue

  # np.stack + from_numpy avoids torch.tensor(list of ndarrays), which is very slow.
  combined_pianorolls.append(torch.from_numpy(np.stack(part_rolls)))
  if len(combined_pianorolls) == N_SONGS_TO_KEEP:
    break

In [None]:
# Remember each song's length along time (dim 1) so the songs can be split apart again,
# then join every song along time into one [5, total_time, pitch] tensor.
pianoroll_lengths = [song.shape[1] for song in combined_pianorolls]
combined_pianorolls = torch.cat(combined_pianorolls, dim=1)

In [None]:
# Persist the concatenated pianorolls and the per-song lengths needed to split them again.
# NOTE(review): these are saved as 'metal_1000_*', but the loading cell reads
# 'rock_pianorolls_*res2_part1.pt', so the two cells do not round-trip as written.
lakh_dir = os.path.join(root_dir, 'Lakh Piano Dataset')
pianoroll_lengths = torch.tensor(pianoroll_lengths)
torch.save(combined_pianorolls, os.path.join(lakh_dir, 'metal_1000_pianorolls.pt'))
torch.save(pianoroll_lengths, os.path.join(lakh_dir, 'metal_1000_pianorolls_lengths.pt'))

**Loading Pianoroll Data**

In [None]:
# Load the concatenated pianorolls and per-song lengths. Dividing by 127.0 rescales the
# stored values (presumably MIDI velocities, 0-127) to floats in [0, 1].
lakh_dir = os.path.join(root_dir, 'Lakh Piano Dataset')
combined_pianorolls = torch.load(os.path.join(lakh_dir, 'rock_pianorolls_res2_part1.pt')) / 127.0
pianoroll_lengths = torch.load(os.path.join(lakh_dir, 'rock_pianorolls_lengths_res2_part1.pt')).numpy()
pianoroll_cum_lengths = np.cumsum(pianoroll_lengths)

In [None]:
# Split the concatenated tensor back into one tensor per song, trimming each song down
# to the largest multiple of 32 time steps that fits.
song_starts = pianoroll_cum_lengths - pianoroll_lengths
pianorolls_list = [
  combined_pianorolls[:, start:(start + (length - length % 32)), :]
  for start, length in zip(song_starts, pianoroll_lengths)
]

**Creating Music Dataset**

In [None]:
# Creating dataset and dataloader
from torch.utils.data import Dataset, DataLoader

In [None]:
# Dataset of consecutive, non-overlapping fixed-length windows of one instrument track
class CombinedDataset(Dataset):
  """Non-overlapping windows of a single instrument channel.

  Item ``index`` is ``pianorolls[instrument_id, index*w:(index+1)*w, :]``, where ``w`` is
  ``window_length``. Trailing time steps that do not fill a whole window are dropped.
  If ``pianorolls`` is a concatenation of several songs, a window may straddle a song
  boundary.

  Args:
    pianorolls: tensor indexed as [instrument, time, pitch].
    instrument_id: channel along dim 0 to read.
    window_length: time steps per item (default 32).
  """
  def __init__(self, pianorolls, instrument_id, window_length=32):
    self.data = pianorolls
    self.instrument_id = instrument_id
    self.window_length = window_length
    # Integer division: exact for any length, unlike int(a / b) through a float.
    self.length = pianorolls.size(1) // window_length

  def __getitem__(self, index):
    start = index * self.window_length
    return self.data[self.instrument_id, start:(start + self.window_length), :]

  def __len__(self):
    return self.length

In [None]:
# Melody-conditional dataset: returns the current melody, the previous harmony and the
# current harmony. Windows where any of the three is silent are usually rejected.
class ConditionalDataset(Dataset):
  """Random (piano, past harmony, target harmony) windows for conditional training.

  ``index`` is ignored: each item picks a random song long enough for two windows, snaps a
  random start time down to a multiple of 32, and returns three ``seq_length``-long slices:
  ``piano_sequence`` (channel 0, second window), ``past_sequence`` (the instrument's
  channel, first window) and ``target_sequence`` (the instrument's channel, second window).
  A window in which any slice sums to zero is rejected, except with probability 0.1.

  Args:
    list_of_sequences: list of tensors indexed as [channel, time, pitch].
    dataset_length: value reported by ``__len__``.
    seq_length: time steps per window.
    instrument: one of 'guitar', 'bass', 'strings', 'drums'.

  Raises:
    ValueError: in ``__init__`` for an unknown instrument; in ``__getitem__`` when no song
      is long enough to sample from (previously this looped forever).
  """

  # Channel of each harmony instrument in the stacked [Piano, Guitar, Bass, Strings, Drums] tensor.
  INSTRUMENT_CHANNELS = {'guitar': 1, 'bass': 2, 'strings': 3, 'drums': 4}

  def __init__(self, list_of_sequences, dataset_length = 32 * 1000, seq_length = 50, instrument = 'guitar'):
    if instrument not in self.INSTRUMENT_CHANNELS:
      raise ValueError('Unknown instrument {!r}; expected one of {}'.format(
          instrument, sorted(self.INSTRUMENT_CHANNELS)))

    self.data = list_of_sequences
    self.n_songs = len(list_of_sequences)
    self.seq_length = seq_length
    self.length = dataset_length
    self.instrument = instrument
    self.channel = self.INSTRUMENT_CHANNELS[instrument]
    # Songs long enough to hold two windows (same threshold as the sampling loop used).
    self.eligible_song_ids = [song_id for song_id, song in enumerate(list_of_sequences)
                              if song.size(1) > seq_length * 2 + 5]

  def __getitem__(self, index):
    if not self.eligible_song_ids:
      raise ValueError('No song is longer than {} time steps; cannot sample a window.'.format(
          self.seq_length * 2 + 5))

    while True:
      song = self.data[random.choice(self.eligible_song_ids)]
      song_length = song.size(1)

      # Choose a random start window aligned to a multiple of 32
      start_time = random.randint(0, song_length - self.seq_length * 2 - 2)
      start_time -= start_time % 32
      split_time = start_time + self.seq_length
      end_time = start_time + self.seq_length * 2

      # Each slice is seq_length x pitch
      piano_sequence = song[0, split_time:end_time, :]
      past_sequence = song[self.channel, start_time:split_time, :]
      target_sequence = song[self.channel, split_time:end_time, :]

      all_non_empty = (piano_sequence.sum() != 0 and past_sequence.sum() != 0
                       and target_sequence.sum() != 0)
      if all_non_empty or random.random() < 0.1:
        return piano_sequence, past_sequence, target_sequence

  def __len__(self):
    return self.length

In [None]:
# Melody prediction dataset - predict the next melody given the current melody
class MelodyDataset(Dataset):
  """Random pairs of consecutive piano windows: (train_sequence, target_sequence).

  ``index`` is ignored: each item picks a random song long enough for two windows, snaps a
  random start time down to a multiple of 32, and returns channel 0 (piano) for the window
  starting there and for the window immediately after it (each seq_length x pitch).

  Args:
    list_of_sequences: list of tensors indexed as [channel, time, pitch].
    dataset_length: value reported by ``__len__``.
    seq_length: time steps per window.

  Raises:
    ValueError: in ``__getitem__`` when no song is long enough (an empty list used to fail
      with an opaque randint error; a list of only short songs looped forever).
  """
  def __init__(self, list_of_sequences, dataset_length = 32 * 10000, seq_length = 50):

    # Don't normalize anymore since it was done earlier
    self.data = list_of_sequences
    self.n_songs = len(list_of_sequences)
    self.seq_length = seq_length
    self.length = dataset_length
    self.eligible_song_ids = [song_id for song_id, song in enumerate(list_of_sequences)
                              if song.size(1) > seq_length * 2 + 5]

  def __getitem__(self, index):
    if not self.eligible_song_ids:
      raise ValueError('No song is longer than {} time steps; cannot sample a window.'.format(
          self.seq_length * 2 + 5))

    song = self.data[random.choice(self.eligible_song_ids)]
    # Choose a random start window aligned to a multiple of 32
    start_time = random.randint(0, song.size(1) - self.seq_length * 2 - 2)
    start_time -= start_time % 32
    split_time = start_time + self.seq_length

    train_sequence = song[0, start_time:split_time, :]
    target_sequence = song[0, split_time:(split_time + self.seq_length), :]
    return train_sequence, target_sequence

  def __len__(self):
    return self.length

### VAE Execution

In [None]:
# @title VAE Helper Functions

def kl_q_p(zs, phi):
    """Monte-Carlo estimate of KL(q || p) from samples ``zs`` drawn from q.

    Args:
        zs: samples of shape [b, n, k].
        phi: posterior parameters of shape [b, k+1]; the first k columns are the means
            of q and the last column is its shared log standard deviation.

    The prior p is a standard normal (mu_p = 0, sigma_p = 1), so log p(z) reduces to
    -z**2 / 2; constants common to both log densities are dropped.

    Returns:
        Scalar tensor: log q - log p summed over k and averaged over b and n.
    """
    batch, _, latent_dim = zs.size()
    mean_q = phi[:, :-1].view(batch, 1, latent_dim)
    log_std_q = phi[:, -1].view(batch, 1, 1)

    log_prior = -0.5 * zs ** 2
    log_posterior = -0.5 * (zs - mean_q) ** 2 / log_std_q.exp() ** 2 - log_std_q
    return (log_posterior - log_prior).sum(dim=2).mean(dim=(0, 1))

def log_p_x(x, mu_xs, sig_x):
    """Pixel-wise Gaussian log-likelihood of ``x`` under reconstructions ``mu_xs``.

    Args:
        x: inputs of shape [batch, ...]; flattened to [batch, 1, pixels] so that they
            broadcast against every sample.
        mu_xs: reconstruction means of shape [batch, n, ...].
        sig_x: observation standard deviation (scalar tensor) shared by all pixels.

    Returns:
        Scalar tensor: log probability summed over pixels (additive constant dropped),
        averaged over samples and over the batch.
    """
    batch, n_samples = mu_xs.size()[:2]
    flat_x = x.reshape(batch, 1, -1)
    flat_mu = mu_xs.view(batch, n_samples, -1)

    # Negative log density per pixel, without the log(sqrt(2*pi)) constant
    neg_log_density = (flat_x - flat_mu) ** 2 / (2 * sig_x ** 2) + torch.log(sig_x)
    return -neg_log_density.sum(dim=2).mean(dim=(0, 1))

def rsample(phi, n_samples):
    """Draw reparameterized samples z ~ q(z; phi).

    Args:
        phi: tensor of shape [b, K+1]; the first K columns are the means of q and the last
            column is the shared log standard deviation.
        n_samples: number of samples per batch row.

    Returns:
        Tensor of shape [b, n_samples, K] equal to mean + std * eps with eps ~ N(0, I),
        so gradients flow back into phi.
    """
    batch = phi.size(0)
    latent_dim = phi.size(1) - 1
    means = phi[:, :-1].view(batch, 1, latent_dim)
    stds = phi[:, -1].exp().view(batch, 1, 1)
    noise = torch.randn(batch, n_samples, latent_dim, device=phi.device)
    return noise * stds + means

In [None]:
class ConvVAE(nn.Module):
    """Convolutional VAE over single-instrument pianoroll windows.

    Shape walk-through (channels x time x pitch), for a [batch, 32, 128] input:
        encoder: 1x32x128 -> 64x8x32 -> 128x2x8 -> 256x1x1 -> Linear -> K+1 (= phi)
        decoder: K -> Linear -> 256x1x1 -> 128x2x8 -> 64x8x32 -> 1x32x128
    ``phi`` holds K posterior means followed by one shared log standard deviation.

    Args:
        K: latent dimensionality.
        num_filters, filter_size: accepted for interface compatibility; not used by the
            current architecture.
    """

    def __init__(self, K, num_filters=32, filter_size=5):
        super().__init__()

        # Recognition model q(z | x). Attribute names are the keys of saved state_dicts.
        self.q_conv_1 = nn.Conv2d(1, 64, kernel_size=(4, 4), stride=(4, 4))
        self.q_conv_2 = nn.Conv2d(64, 128, kernel_size=(4, 4), stride=(4, 4))
        self.q_conv_3 = nn.Conv2d(128, 256, kernel_size=(2, 8), stride=(2, 8))
        self.q_fc_phi = nn.Linear(256, K + 1)

        # Generative model p(x | z): the mirror image of the encoder.
        self.p_fc_upsample = nn.Linear(K, 256)
        self.p_deconv_1 = nn.ConvTranspose2d(256, 128, kernel_size=(2, 8), stride=(2, 8))
        self.p_deconv_2 = nn.ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(4, 4))
        self.p_deconv_3 = nn.ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(4, 4))

        # One learned log observation-noise scale shared by every pixel.
        self.log_sig_x = nn.Parameter(torch.zeros(()))

    def infer(self, x):
        """Encode ``x`` [batch, time, pitch] into ``phi`` [batch, K+1] for rsample."""
        hidden = x.unsqueeze(1)  # add the single input channel
        for conv in (self.q_conv_1, self.q_conv_2, self.q_conv_3):
            hidden = F.relu(conv(hidden))
        return self.q_fc_phi(hidden.view(hidden.size(0), -1))

    def generate(self, zs):
        """Decode latent samples [b, n, K] into flattened reconstruction means [b, n, pixels]."""
        b, n, _ = zs.size()
        # The decoder works on a flat batch of b*n latent vectors.
        hidden = F.relu(self.p_fc_upsample(zs.view(b * n, -1)))
        hidden = hidden[:, :, None, None]  # [b*n, 256] -> [b*n, 256, 1, 1]
        for deconv in (self.p_deconv_1, self.p_deconv_2):
            hidden = F.relu(deconv(hidden))
        return self.p_deconv_3(hidden).view(b, n, -1)

    def forward(self, x):
        """Autoencoder-style pass with a single z sample; not used for training."""
        single_sample = rsample(self.infer(x), 1)
        return self.generate(single_sample).view(x.size())

    def elbo(self, x, n=1):
        """Estimate the ELBO of ``x`` with ``n`` samples of z per input."""
        phi = self.infer(x)
        zs = rsample(phi, n)
        reconstruction = log_p_x(x, self.generate(zs), self.log_sig_x.exp())
        return reconstruction - kl_q_p(zs, phi)

In [None]:
def train_vae(vae, dataloader, epochs=10, lr=0.001, weight_decay=1e-6):
    """Fit ``vae`` by maximizing its ELBO with Adam and return the per-batch ELBO history.

    The model is moved to the global ``device`` for training, then moved back to the CPU
    and put in eval mode when training finishes.

    Args:
        vae: a model exposing ``elbo(batch)`` (e.g. ConvVAE).
        dataloader: sized iterable of input batches; ``len(dataloader)`` sizes the progress bar.
        epochs: number of passes over ``dataloader``.
        lr, weight_decay: Adam hyper-parameters (defaults match the previous hard-coded values).

    Returns:
        list of float: the ELBO of every training batch, in order.
    """
    # Move the model before building the optimizer so it wraps parameters on their final device.
    vae.to(device)
    vae.train()
    opt = torch.optim.Adam(vae.parameters(), lr=lr, weight_decay=weight_decay)
    elbo_vals = []
    for _ in trange(epochs, desc='Epochs'):
        # The progress-bar size comes from the loader itself, not from a global `dataset`.
        for sequence in tqdm(dataloader, total=len(dataloader), desc='Batches', leave=False):
            sequence = sequence.to(device)
            opt.zero_grad()
            loss = -vae.elbo(sequence)
            loss.backward()
            opt.step()

            elbo_vals.append(-loss.item())
    vae.to('cpu')
    vae.eval()
    return elbo_vals

In [None]:
# Quick single-model run: a K=64 VAE on the piano track (channel 0).
# The loader is built here so this cell does not depend on `piano_loader` from a later cell.
dataset = CombinedDataset(combined_pianorolls, instrument_id = 0)
piano_loader = DataLoader(dataset, batch_size = 32, drop_last=True)
vae = ConvVAE(K = 64)
elbo_vals = train_vae(vae, piano_loader, epochs = 20)

In [None]:
# Train and save one VAE per (latent size K, instrument). Channel i of combined_pianorolls
# holds instruments[i].
instruments = ['piano', 'guitar', 'bass', 'strings', 'drums']
for K in [8, 16, 32]:
  for i, instrument_name in enumerate(instruments):
    # Print the instrument *name*; the bare `instrument` is the music21 module.
    print(K, instrument_name)
    dataset = CombinedDataset(combined_pianorolls, instrument_id = i)
    piano_loader = DataLoader(dataset, batch_size = 32, drop_last=True)

    vae = ConvVAE(K = K)
    elbo_vals = train_vae(vae, piano_loader, epochs = 25)
    model_name = 'VAE_{}_{}'.format(instrument_name, K)
    save_path = os.path.join(root_dir, 'Saved Models', 'VAE', 'Large Data', model_name)
    torch.save(vae.state_dict(), save_path)

    print(f'Learned sigma_x is {torch.exp(vae.log_sig_x)}')
    fig, ax = plt.subplots()
    ax.plot(elbo_vals)
    ax.set(xlabel='Batch #', ylabel='ELBO', title='VAE {} (K={})'.format(instrument_name, K))
    plt.show()
    # Free the figure: this loop creates 15 of them.
    plt.close(fig)

### Training Melody and Conditional NNs

In [None]:
# @title Load trained VAEs (Run me)
# Specify dimensionality of VAEs you want (K = 8, 16, 32, 64)
K = 16

large_vae_dir = os.path.join(root_dir, 'Saved Models', 'VAE', 'Large Data')
loaded_vaes = {}
for instrument_name in ['piano', 'guitar', 'bass', 'strings', 'drums']:
  loaded_vae = ConvVAE(K=K).to(device)
  # map_location lets weights saved in a GPU session load on a CPU-only runtime.
  state_dict = torch.load(os.path.join(large_vae_dir, 'VAE_{}_{}'.format(instrument_name, K)), map_location=device)
  loaded_vae.load_state_dict(state_dict)
  loaded_vaes[instrument_name] = loaded_vae.eval()

piano_vae = loaded_vaes['piano']
guitar_vae = loaded_vaes['guitar']
bass_vae = loaded_vaes['bass']
strings_vae = loaded_vaes['strings']
drums_vae = loaded_vaes['drums']

In [None]:
# @title Conditional and Melody NN Architectures

# Conditional NN - uses current melody and previous harmony's LATENT vectors to predict next harmony's LATENT vectors
class ConditionalNN(nn.Module):
    """MLP predicting the next harmony latent from the previous harmony and current melody latents.

    forward: concat([prev_harmony, melody], dim=1) -> Linear(2K, hidden) -> ReLU -> Linear(hidden, K)

    Args:
        K: latent size of each input (both inputs are [batch, K]) and of the output.
        hidden_size: width of the hidden layer; the default 128 matches existing checkpoints.
    """
    def __init__(self, K, hidden_size=128):
        super().__init__()

        self.fc1 = nn.Linear(2 * K, hidden_size)
        self.fc2 = nn.Linear(hidden_size, K)

    def forward(self, prev_harmony, melody):
        # `dim` is torch's canonical keyword for the concatenation axis.
        x = torch.cat((prev_harmony, melody), dim=1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

# Melody NN - uses previous melody's LATENT vectors to predict next melody's LATENT VECTORS
class MelodyNN(nn.Module):
    """MLP mapping the current melody latent [batch, K] to a prediction of the next one.

    forward: Linear(K, 128) -> ReLU -> Dropout(0.2) -> Linear(128, 128) -> ReLU -> Linear(128, K)
    """
    def __init__(self, K):
        super().__init__()
        hidden = 128
        self.fc1 = nn.Linear(K, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, K)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        h = self.dropout(F.relu(self.fc1(x)))
        h = F.relu(self.fc2(h))
        return self.fc3(h)

In [None]:
# @title Melody NN Training Functions

def run_epoch_MelodyNN(dataloader, piano_vae, melody_nn, optimizer, criterion, is_train = True):
    """Run one pass of the melody NN over ``dataloader``.

    Each (train_seq, target_seq) batch is encoded with the frozen ``piano_vae``. The latent is
    ``infer(...)[:, :-1]``, i.e. phi without its last (log-std) column. ``melody_nn`` predicts
    the target latent from the train latent, and ``criterion`` scores the prediction.

    Args:
        dataloader: yields (train_seq, target_seq) batches.
        piano_vae: encoder exposing ``infer``; never updated here.
        melody_nn: model being trained or evaluated.
        optimizer: stepped only when ``is_train`` is true.
        criterion: loss on (predicted latent, target latent).
        is_train: when false, no gradients are tracked and no update is made.

    Returns:
        float: sum of per-batch losses / number of samples * 100. For equal-size batches this
        is the mean batch loss scaled by 100 / batch_size; the scale is kept so curves stay
        comparable with earlier runs. Returns 0.0 for an empty loader.
    """
    training = bool(is_train)
    running_loss = 0.0
    n_obs = 0

    for train_seq, target_seq in dataloader:
        train_seq = train_seq.to(device)
        target_seq = target_seq.to(device)

        # The VAE is a frozen feature extractor: never build a graph through it.
        with torch.no_grad():
            train_latent = piano_vae.infer(train_seq)[:, :-1]
            target_latent = piano_vae.infer(target_seq)[:, :-1]

        # Track gradients for the predictor only when we will backpropagate.
        with torch.set_grad_enabled(training):
            output_latent = melody_nn(train_latent)
            loss = criterion(output_latent, target_latent)

        if training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        running_loss += loss.item()
        n_obs += train_seq.size(0)

    if n_obs == 0:
        return 0.0
    return running_loss / n_obs * 100

def training_loop_MelodyNN(piano_vae, melody_nn, optimizer, scheduler, criterion, train_dataloader, test_dataloader, n_epochs = 50):
  """Train ``melody_nn`` for ``n_epochs``, evaluating on ``test_dataloader`` after each epoch.

  Prints the current learning rate before each epoch and a summary line after it, and steps
  ``scheduler`` once per epoch.

  Returns:
    (train_losses, test_losses): one value per epoch, as returned by run_epoch_MelodyNN.
  """
  train_losses, test_losses = [], []
  for epoch in range(1, n_epochs + 1):
    print(scheduler.get_last_lr())
    melody_nn.train()
    train_loss = run_epoch_MelodyNN(train_dataloader, piano_vae, melody_nn, optimizer, criterion,
                                    is_train = True)
    scheduler.step()

    melody_nn.eval()
    test_loss = run_epoch_MelodyNN(test_dataloader, piano_vae, melody_nn, optimizer, criterion,
                                   is_train = False)

    train_losses.append(train_loss)
    test_losses.append(test_loss)
    print('Epoch {}, Train Loss: {}, Test Loss: {}, Time: {}'.format(epoch, train_loss, test_loss, datetime.now()))
  return train_losses, test_losses

In [None]:
# @title Melody NN Training Loop

n_epochs = 50
lr = 0.001
lr_lambda = 0.95
K = 16

# Hold out the last songs (up to 500) as a test set that never appears in training. The
# previous test slice, pianorolls_list[4000:4000], was always empty, so evaluation crashed.
n_test_songs = min(500, max(1, len(pianorolls_list) // 10))
melody_train_dataset = MelodyDataset(pianorolls_list[:-n_test_songs], dataset_length = 32 * 8000, seq_length = 32)
melody_train_loader = DataLoader(melody_train_dataset, batch_size = 32, drop_last=True)
melody_test_dataset = MelodyDataset(pianorolls_list[-n_test_songs:], dataset_length = 32 * 1000, seq_length = 32)
melody_test_loader = DataLoader(melody_test_dataset, batch_size = 32, drop_last=True)

# Load the frozen piano VAE (map_location allows loading GPU-saved weights on CPU)
model_name = 'VAE_piano_{}'.format(K)
save_path = os.path.join(root_dir, 'Saved Models', 'VAE', model_name)
piano_vae = ConvVAE(K=K).to(device)
piano_vae.load_state_dict(torch.load(save_path, map_location=device))
piano_vae.eval()

melody_nn = MelodyNN(K = K).to(device)
optimizer = torch.optim.Adam(melody_nn.parameters(), lr = lr)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda = lambda epoch: lr_lambda ** epoch)
criterion = nn.MSELoss()
train_losses, test_losses = training_loop_MelodyNN(piano_vae, melody_nn, optimizer, scheduler, criterion, melody_train_loader, melody_test_loader, n_epochs = n_epochs)

model_name = 'VAE_NN_piano_{}'.format(K)
save_path = os.path.join(root_dir, 'Saved Models', 'VAE', model_name)
torch.save(melody_nn.state_dict(), save_path)

fig, ax = plt.subplots()
ax.plot(train_losses, label = 'Train Loss')
ax.plot(test_losses, label = 'Test Loss')
ax.set(xlabel='Epoch', ylabel='Loss', title='Melody NN (K={})'.format(K))
ax.legend()
plt.show()

In [None]:
# @title Conditional NN Training Functions

def run_epoch_VAENN(dataloader, piano_vae, harmony_vae, conditional_nn, optimizer, criterion, is_train = True):
    """Run one pass of a conditional (harmony) NN over ``dataloader``.

    Each (piano_seq, past_seq, target_seq) batch is encoded with the frozen VAEs. The latent is
    ``infer(...)[:, :-1]``, i.e. phi without its last (log-std) column. ``conditional_nn``
    predicts the target harmony latent from (past harmony latent, piano latent).

    Args:
        dataloader: yields (piano_seq, past_seq, target_seq) batches.
        piano_vae: encoder for the piano track; never updated here.
        harmony_vae: encoder for the harmony instrument; never updated here.
        conditional_nn: model being trained or evaluated.
        optimizer: stepped only when ``is_train`` is true.
        criterion: loss on (predicted latent, target latent).
        is_train: when false, no gradients are tracked and no update is made.

    Returns:
        float: sum of per-batch losses / number of samples * 100 (the mean batch loss scaled
        by 100 / batch_size for equal-size batches; kept for comparability). Returns 0.0 for
        an empty loader.
    """
    training = bool(is_train)
    running_loss = 0.0
    n_obs = 0

    for piano_seq, past_seq, target_seq in dataloader:
        piano_seq = piano_seq.to(device)
        past_seq = past_seq.to(device)
        target_seq = target_seq.to(device)

        # Both VAEs are frozen feature extractors: never build a graph through them.
        with torch.no_grad():
            piano_latent = piano_vae.infer(piano_seq)[:, :-1]
            past_latent = harmony_vae.infer(past_seq)[:, :-1]
            target_latent = harmony_vae.infer(target_seq)[:, :-1]

        # Track gradients for the predictor only when we will backpropagate.
        with torch.set_grad_enabled(training):
            output_latent = conditional_nn(past_latent, piano_latent)
            loss = criterion(output_latent, target_latent)

        if training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        running_loss += loss.item()
        n_obs += piano_seq.size(0)

    if n_obs == 0:
        return 0.0
    return running_loss / n_obs * 100

def training_loop_VAENN(piano_vae, harmony_vae, conditional_nn, optimizer, scheduler, criterion, train_dataloader, test_dataloader, n_epochs = 50):
  """Train ``conditional_nn`` for ``n_epochs``, evaluating on ``test_dataloader`` after each epoch.

  Prints the current learning rate before each epoch and a summary line after it, and steps
  ``scheduler`` once per epoch.

  Returns:
    (train_losses, test_losses): one value per epoch, as returned by run_epoch_VAENN.
  """
  train_losses, test_losses = [], []
  for epoch in range(1, n_epochs + 1):
    print(scheduler.get_last_lr())
    conditional_nn.train()
    train_loss = run_epoch_VAENN(train_dataloader, piano_vae, harmony_vae, conditional_nn,
                                 optimizer, criterion, is_train = True)
    scheduler.step()

    conditional_nn.eval()
    test_loss = run_epoch_VAENN(test_dataloader, piano_vae, harmony_vae, conditional_nn,
                                optimizer, criterion, is_train = False)

    train_losses.append(train_loss)
    test_losses.append(test_loss)
    print('Epoch {}, Train Loss: {}, Test Loss: {}, Time: {}'.format(epoch, train_loss, test_loss, datetime.now()))
  return train_losses, test_losses

In [None]:
# @title Conditional NN Training Loop

n_epochs = 30
lr = 0.001
lr_lambda = 0.95
K = 8

# Disjoint train/test songs. The train slice used to be [0:1000], which contained every
# test song in [900:1000].
cond_train_songs = pianorolls_list[0:900]
cond_test_songs = pianorolls_list[900:1000]
vae_dir = os.path.join(root_dir, 'Saved Models', 'VAE')

# ConditionalNN(K) concatenates a K-dim piano latent with a K-dim harmony latent, so the
# piano VAE must use the same K. Load it here instead of reusing a `piano_vae` left behind
# by another cell, which may have a different K.
piano_vae = ConvVAE(K=K).to(device)
piano_vae.load_state_dict(torch.load(os.path.join(vae_dir, 'VAE_piano_{}'.format(K)), map_location=device))
piano_vae.eval()

for instrument_name in ['guitar', 'bass', 'strings', 'drums']:
  print(instrument_name)

  cond_train_dataset = ConditionalDataset(cond_train_songs, dataset_length = 32 * 4000, seq_length = 32, instrument = instrument_name)
  cond_train_loader = DataLoader(cond_train_dataset, batch_size = 32, drop_last=True)
  cond_test_dataset = ConditionalDataset(cond_test_songs, dataset_length = 32 * 1000, seq_length = 32, instrument = instrument_name)
  cond_test_loader = DataLoader(cond_test_dataset, batch_size = 32, drop_last=True)

  # Load the frozen harmony VAE for this instrument
  harmony_vae = ConvVAE(K=K).to(device)
  harmony_vae.load_state_dict(torch.load(os.path.join(vae_dir, 'VAE_{}_{}'.format(instrument_name, K)), map_location=device))
  harmony_vae.eval()

  conditional_nn = ConditionalNN(K = K).to(device)
  optimizer = torch.optim.Adam(conditional_nn.parameters(), lr = lr)
  scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda = lambda epoch: lr_lambda ** epoch)
  criterion = nn.MSELoss()
  train_losses, test_losses = training_loop_VAENN(piano_vae, harmony_vae, conditional_nn, optimizer, scheduler, criterion, cond_train_loader, cond_test_loader, n_epochs = n_epochs)

  torch.save(conditional_nn.state_dict(), os.path.join(vae_dir, 'VAE_NN_{}_{}'.format(instrument_name, K)))

  fig, ax = plt.subplots()
  ax.plot(train_losses, label = 'Train Loss')
  ax.plot(test_losses, label = 'Test Loss')
  ax.set(xlabel='Epoch', ylabel='Loss', title='Conditional NN: {} (K={})'.format(instrument_name, K))
  ax.legend()
  plt.show()
  plt.close(fig)

In [None]:
# @title Melody + Conditional NN Training Loop #2 (Large Data, K=16)

n_epochs = 50
lr = 0.001
lr_lambda = 0.95
K = 16

# Hold out the last songs (up to 500) as a test set that never appears in training.
# The test sets used to be pianorolls_list[0:500] while training used the whole list,
# so "test" loss was measured on training songs.
n_test_songs = min(500, max(1, len(pianorolls_list) // 10))
train_songs = pianorolls_list[:-n_test_songs]
test_songs = pianorolls_list[-n_test_songs:]
large_data_dir = os.path.join(root_dir, 'Saved Models', 'VAE', 'Large Data')

# ---- Melody NN: predict the next piano latent from the current one ----
melody_train_dataset = MelodyDataset(train_songs, dataset_length = 32 * 8000, seq_length = 32)
melody_train_loader = DataLoader(melody_train_dataset, batch_size = 32, drop_last=True)
melody_test_dataset = MelodyDataset(test_songs, dataset_length = 32 * 1000, seq_length = 32)
melody_test_loader = DataLoader(melody_test_dataset, batch_size = 32, drop_last=True)

# Load the frozen piano VAE (map_location allows loading GPU-saved weights on CPU)
piano_vae = ConvVAE(K=K).to(device)
piano_vae.load_state_dict(torch.load(os.path.join(large_data_dir, 'VAE_piano_{}'.format(K)), map_location=device))
piano_vae.eval()

melody_nn = MelodyNN(K = K).to(device)
optimizer = torch.optim.Adam(melody_nn.parameters(), lr = lr)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda = lambda epoch: lr_lambda ** epoch)
criterion = nn.MSELoss()
train_losses, test_losses = training_loop_MelodyNN(piano_vae, melody_nn, optimizer, scheduler, criterion, melody_train_loader, melody_test_loader, n_epochs = n_epochs)

torch.save(melody_nn.state_dict(), os.path.join(large_data_dir, 'VAE_NN_piano_{}'.format(K)))

fig, ax = plt.subplots()
ax.plot(train_losses, label = 'Train Loss')
ax.plot(test_losses, label = 'Test Loss')
ax.set(xlabel='Epoch', ylabel='Loss', title='Melody NN (K={})'.format(K))
ax.legend()
plt.show()
plt.close(fig)

# ---- Conditional NNs: predict each harmony latent from piano + previous harmony ----
for instrument_name in ['guitar', 'bass', 'strings', 'drums']:
  print(instrument_name)

  cond_train_dataset = ConditionalDataset(train_songs, dataset_length = 32 * 8000, seq_length = 32, instrument = instrument_name)
  cond_train_loader = DataLoader(cond_train_dataset, batch_size = 32, drop_last=True)
  cond_test_dataset = ConditionalDataset(test_songs, dataset_length = 32 * 1000, seq_length = 32, instrument = instrument_name)
  cond_test_loader = DataLoader(cond_test_dataset, batch_size = 32, drop_last=True)

  # Load the frozen harmony VAE for this instrument
  harmony_vae = ConvVAE(K=K).to(device)
  harmony_vae.load_state_dict(torch.load(os.path.join(large_data_dir, 'VAE_{}_{}'.format(instrument_name, K)), map_location=device))
  harmony_vae.eval()

  conditional_nn = ConditionalNN(K = K).to(device)
  optimizer = torch.optim.Adam(conditional_nn.parameters(), lr = lr)
  scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda = lambda epoch: lr_lambda ** epoch)
  criterion = nn.MSELoss()
  train_losses, test_losses = training_loop_VAENN(piano_vae, harmony_vae, conditional_nn, optimizer, scheduler, criterion, cond_train_loader, cond_test_loader, n_epochs = n_epochs)

  torch.save(conditional_nn.state_dict(), os.path.join(large_data_dir, 'VAE_NN_{}_{}'.format(instrument_name, K)))

  fig, ax = plt.subplots()
  ax.plot(train_losses, label = 'Train Loss')
  ax.plot(test_losses, label = 'Test Loss')
  ax.set(xlabel='Epoch', ylabel='Loss', title='Conditional NN: {} (K={})'.format(instrument_name, K))
  ax.legend()
  plt.show()
  plt.close(fig)

### Using Trained VAEs and NNs to Generate Music

In [None]:
# @title Load trained VAEs and NNs (Run me)

# Specify dimensionality of VAEs you want (K = 8, 16, 32, 64)
K = 8
model_dir = os.path.join(root_dir, 'Saved Models', 'VAE')

# Load one VAE per instrument. drums_vae is loaded here too, so it always has the same K as
# drums_nn below instead of depending on whatever an earlier cell left in memory.
# map_location lets weights saved in a GPU session load on a CPU-only runtime.
trained_vaes = {}
for instrument_name in ['piano', 'guitar', 'bass', 'strings', 'drums']:
  trained_vae = ConvVAE(K=K).to(device)
  trained_vae.load_state_dict(torch.load(os.path.join(model_dir, 'VAE_{}_{}'.format(instrument_name, K)), map_location=device))
  trained_vaes[instrument_name] = trained_vae.eval()

piano_vae = trained_vaes['piano']
guitar_vae = trained_vaes['guitar']
bass_vae = trained_vaes['bass']
strings_vae = trained_vaes['strings']
drums_vae = trained_vaes['drums']

# Load Melody NN
melody_nn = MelodyNN(K = K).to(device)
melody_nn.load_state_dict(torch.load(os.path.join(model_dir, 'VAE_NN_piano_{}'.format(K)), map_location=device))
melody_nn.eval()

# Load Conditional NNs; every one is switched to eval mode (bass_nn previously was not).
trained_nns = {}
for instrument_name in ['guitar', 'bass', 'strings', 'drums']:
  conditional_model = ConditionalNN(K = K).to(device)
  conditional_model.load_state_dict(torch.load(os.path.join(model_dir, 'VAE_NN_{}_{}'.format(instrument_name, K)), map_location=device))
  trained_nns[instrument_name] = conditional_model.eval()

guitar_nn = trained_nns['guitar']
bass_nn = trained_nns['bass']
strings_nn = trained_nns['strings']
drums_nn = trained_nns['drums']

In [None]:
# @title Generation Dataset

# all fields cannot be blank
class GenerationDataset(Dataset):
  """Randomly sampled (context, continuation) window pairs from multitrack pianorolls.

  Every item is drawn at random (the `index` argument is ignored): a song is picked, a
  start time is chosen and rounded down to a multiple of `bar_length`, and the window
  [start, start + seq_length) is returned with the window that immediately follows it,
  [start + seq_length, start + 2 * seq_length).

  Candidate windows in which any of the first five tracks (piano, guitar, bass, strings,
  drums) is all-zero are rejected, except that each rejected candidate is still accepted
  with probability 0.1.

  Args:
    list_of_sequences: list of tensors indexed as track x time x pitch (already normalized).
    dataset_length: value returned by __len__ (items per epoch).
    seq_length: number of time steps in each returned window.
    bar_length: start times are rounded down to a multiple of this (default 32, the
      value that used to be hard-coded).
  """

  def __init__(self, list_of_sequences, dataset_length = 32 * 10000, seq_length = 50, bar_length = 32):

    # Don't normalize anymore since it was done earlier
    self.data = list_of_sequences
    self.n_songs = len(list_of_sequences)
    self.seq_length = seq_length
    self.length = dataset_length
    self.bar_length = bar_length
    # Only songs long enough for a context window plus a target window can be sampled.
    # Shorter songs used to make random.randint(0, negative) raise ValueError at random.
    self.min_song_length = 2 * seq_length + 2
    self.eligible_song_ids = [song_id for song_id, song in enumerate(list_of_sequences)
                              if song.size(1) >= self.min_song_length]

  def __getitem__(self, index):
    if not self.eligible_song_ids:
      raise ValueError('no song is at least {} time steps long'.format(self.min_song_length))

    while True:
      song = self.data[random.choice(self.eligible_song_ids)]
      # Choose a random start window, aligned down to a bar boundary
      start_time = random.randint(0, song.size(1) - self.seq_length * 2 - 2)
      start_time -= start_time % self.bar_length
      context = song[:, start_time:(start_time + self.seq_length), :]
      # Prefer windows where every instrument plays; occasionally accept others anyway
      # (random.random() is only drawn when some track is empty).
      if all(context[track].sum() != 0 for track in range(5)) or random.random() < 0.1:
        break

    target = song[:, (start_time + self.seq_length):(start_time + self.seq_length * 2), :]
    return context, target

  def __len__(self):
    return self.length

In [None]:
# Dataset/loader of whole multitrack chunks: each item is a (context, continuation) pair of 32-step windows
gen_dataset = GenerationDataset(
    pianorolls_list,
    dataset_length = 32 * 8000,
    seq_length = 32,
)
gen_loader = DataLoader(
    gen_dataset,
    batch_size = 5,
    drop_last=True,
)

In [None]:
# Draw one batch; `train` holds the context windows and `test` the windows that follow them.
# The first context window becomes the generation seed: 5 x 32 x 128 (tracks x time x pitch).
train, test = next(iter(gen_loader))
sample = train[0]

In [None]:
def generate_music_vae(sample, vae_models, nn_models, noise_sd = 0, threshold = 0.3, binarize = True):
  """Predict the next 5-instrument chunk from the previous one using the latent-space models.

  Each track of `sample` is encoded with its VAE. The melody NN maps the previous piano
  latent to the next piano latent. Each conditional NN maps an instrument's previous latent
  plus that new piano latent to the instrument's next latent. Gaussian noise with standard
  deviation `noise_sd` is added to every next latent before decoding.

  Args:
    sample: tensor of shape 5 x T x P, tracks ordered piano, guitar, bass, strings, drums
      (5 x 32 x 128 in this notebook).
    vae_models: 5-tuple of trained VAEs in the same track order.
    nn_models: 5-tuple (melody NN, guitar NN, bass NN, strings NN, drums NN).
    noise_sd: standard deviation of the noise added to each predicted latent.
    threshold: absolute cut-off; decoded values strictly below it are set to 0.
    binarize: if True, every remaining non-zero value is set to 0.8 and the strings track
      (index 3) is then scaled by 0.75.

  Returns:
    Tensor of shape 5 x T x P, computed without autograd tracking.
  """
  piano_vae, guitar_vae, bass_vae, strings_vae, drums_vae = vae_models
  melody_nn, guitar_nn, bass_nn, strings_nn, drums_nn = nn_models
  frame_shape = tuple(sample.shape[1:])  # T x P, e.g. 32 x 128

  def add_noise(latent):
    # Gaussian perturbation with standard deviation noise_sd
    return latent + torch.randn_like(latent) * noise_sd

  def encode(vae, track):
    # Image space -> latent (batch_size x K); the last column of infer()'s output is dropped
    return vae.infer(track.unsqueeze(0).to(device))[:, :-1]

  def decode(vae, latent):
    # Latent -> image space, reshaped to 1 x T x P
    return vae.generate(latent.unsqueeze(0)).view(1, *frame_shape)

  # Generation needs no gradients. Without no_grad, feeding the output back in recurrently
  # chains every step's autograd graph together and memory keeps growing.
  with torch.no_grad():
    piano_latent = encode(piano_vae, sample[0])
    guitar_latent = encode(guitar_vae, sample[1])
    bass_latent = encode(bass_vae, sample[2])
    strings_latent = encode(strings_vae, sample[3])
    drums_latent = encode(drums_vae, sample[4])

    # Melody NN predicts the next piano latent; the other instruments are conditioned on it
    piano_next_latent = add_noise(melody_nn(piano_latent))
    guitar_next_latent = add_noise(guitar_nn(guitar_latent, piano_next_latent))
    bass_next_latent = add_noise(bass_nn(bass_latent, piano_next_latent))
    strings_next_latent = add_noise(strings_nn(strings_latent, piano_next_latent))
    drums_next_latent = add_noise(drums_nn(drums_latent, piano_next_latent))

    creation = torch.cat((decode(piano_vae, piano_next_latent),
                          decode(guitar_vae, guitar_next_latent),
                          decode(bass_vae, bass_next_latent),
                          decode(strings_vae, strings_next_latent),
                          decode(drums_vae, drums_next_latent)), dim = 0)

  creation[creation < threshold] = 0

  if binarize:
    creation[creation > 0] = 0.8
    # Quieten the strings relative to the other binarized tracks
    creation[3] *= 0.75

  return creation

In [None]:
vae_models = (piano_vae, guitar_vae, bass_vae, strings_vae, drums_vae)
nn_models = (melody_nn, guitar_nn, bass_nn, strings_nn, drums_nn)

# Recurrently generate music: each predicted chunk is fed back in as the context for the next.
# `sample` (the seed drawn from the dataset) is left untouched, so re-running this cell
# restarts from the same seed instead of continuing from the last generated chunk.
prediction_steps = 10
n_tracks, chunk_length, n_pitches = sample.shape  # 5 x 32 x 128
generated_track = torch.zeros((n_tracks, chunk_length * (prediction_steps + 1), n_pitches), device=device)
generated_track[:, :chunk_length, :] = sample

current_chunk = sample
with torch.no_grad():
  for i in range(1, prediction_steps + 1):
    current_chunk = generate_music_vae(current_chunk, vae_models, nn_models, noise_sd = 1, threshold = 0.3, binarize = True)
    generated_track[:, chunk_length * i:chunk_length * (i + 1), :] = current_chunk

In [None]:
# Un-normalize only here: scale [0, 1] intensities back to the MIDI velocity range
generated_track_out = generated_track * 127
generated_rolls = generated_track_out.detach().cpu().numpy()

# (name, General MIDI program) for the melodic tracks, in track-index order 0..3
melodic_track_specs = [('Piano', 0), ('Guitar', 24), ('Bass', 32), ('Strings', 48)]
piano_track, guitar_track, bass_track, strings_track = [
    pypianoroll.StandardTrack(name = track_name, program = program, is_drum = False, pianoroll = generated_rolls[track_idx])
    for track_idx, (track_name, program) in enumerate(melodic_track_specs)
]
drums_track = pypianoroll.StandardTrack(name = 'Drums', is_drum = True, pianoroll = generated_rolls[4])

generated_multitrack = pypianoroll.Multitrack(
    name = 'Generated',
    resolution = 2,
    tracks = [piano_track, guitar_track, bass_track, strings_track, drums_track],
)
generated_multitrack.plot()

In [None]:
# Synthesize the generated piece with FluidSynth and play it inline.
# The synthesis rate and the playback rate must agree.
AUDIO_SAMPLE_RATE = 44100
generated_pm = pypianoroll.to_pretty_midi(generated_multitrack)
generated_midi_audio = generated_pm.fluidsynth(fs = AUDIO_SAMPLE_RATE)
IPython.display.Audio(generated_midi_audio, rate = AUDIO_SAMPLE_RATE)

In [None]:
# Save the generated piece as MIDI, creating the output folder if it does not exist yet.
# Note: re-running overwrites this fixed filename.
generated_dir = os.path.join(root_dir, 'Generated MIDIs')
os.makedirs(generated_dir, exist_ok=True)
generated_path = os.path.join(generated_dir, 'vae_28apr_good2.mid')
pypianoroll.write(generated_path, generated_multitrack)

In [None]:
# Generate VAE (old function)

def generate_samples_vae(samples, instrument_vae, noise_sd):
  """Perturb samples in a single VAE's latent space and decode them back (older helper).

  Args:
    samples: tensor of shape batch_size x T x P (batch_size x 32 x 128 in this notebook).
    instrument_vae: trained VAE exposing `infer` and `generate`.
    noise_sd: standard deviation of the Gaussian noise added to each latent vector.

  Returns:
    Decoded tensor of shape batch_size x T x P, computed without autograd tracking.
  """
  batch_size = samples.size(0)
  with torch.no_grad():
    # Convert samples into latent vectors; the last column of infer()'s output is dropped,
    # matching how infer() is used in generate_music_vae
    latent = instrument_vae.infer(samples.to(device))[:, :-1]
    latent_after_noise = latent + torch.randn_like(latent) * noise_sd
    decoded = instrument_vae.generate(latent_after_noise.unsqueeze(0)).squeeze(0)
  # Reshape to the input's frame shape instead of a hard-coded 32 x 128
  return decoded.view(batch_size, *samples.shape[1:])