## Arousal style transfer code demo
This notebook demonstrates how to perform arousal style transfer.
The pre-trained model used here was trained on single-bar segments (3-6 seconds).

### Installing dependencies

In [None]:
# Install dependencies for new version of Magenta
!pip install torch
!pip install tensorflow==1.15.0
!pip install sklearn, numpy, matplotlib
!pip install pretty_midi
!pip install pypianoroll
!pip install music21
!pip install pygtrie
!pip install tensor2tensor
!pip install pyfluidsynth

In [None]:
# Clone fork of gudgud96 for Magenta
!git clone https://github.com/gudgud96/magenta.git

# Setup Magenta environment
!cd magenta
!pip install -e .
!cd ..

# Rearrange Magenta folder.
!mv magenta/magenta magenta_core
!rm -r magenta
!mv magenta_core magenta

In [None]:
# Download pre-processed melody chunks of VGMIDI dataset
!mkdir data
!mkdir data/filtered_songs_disambiguate
!wget https://github.com/gudgud96/music-fader-nets/releases/download/1.0.0/arousal_lst.npy
!wget https://github.com/gudgud96/music-fader-nets/releases/download/1.0.0/chroma_lst.npy
!wget https://github.com/gudgud96/music-fader-nets/releases/download/1.0.0/note_lst.npy
!wget https://github.com/gudgud96/music-fader-nets/releases/download/1.0.0/rhythm_lst.npy
!wget https://github.com/gudgud96/music-fader-nets/releases/download/1.0.0/song_tokens.npy
!wget https://github.com/gudgud96/music-fader-nets/releases/download/1.0.0/valence_lst.npy
!mv *.npy data/filtered_songs_disambiguate

### Importing libraries

In [1]:
import json
import torch
from gmm_model import *
import os
from sklearn.model_selection import train_test_split
from ptb_v2 import *
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pretty_midi
from IPython.display import Audio
from tqdm import tqdm
from polyphonic_event_based_v2 import *
from collections import Counter
import matplotlib.pyplot as plt
from polyphonic_event_based_v2 import parse_pretty_midi

### Load dataset and models

In [3]:
# Initialization: read the model configuration and record which device to use
with open('gmm_model_config.json') as f:
    args = json.load(f)
args['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'

# Create the `log` and `params` folders when they are missing
for folder in ('log', 'params'):
    if not os.path.isdir(folder):
        os.mkdir(folder)

# Timestamped parameter path built from the configured model name
from datetime import datetime
timestamp = str(datetime.now())
save_path_timing = 'params/{}.pt'.format(args['name'] + "_" + timestamp)


# Model dimensions
EVENT_DIMS = 342
RHYTHM_DIMS = 3
NOTE_DIMS = 16
CHROMA_DIMS = 24

# Build the model with the sizes above and the hyper-parameters from the config file
model = MusicAttrRegGMVAE(roll_dims=EVENT_DIMS, rhythm_dims=RHYTHM_DIMS, note_dims=NOTE_DIMS, 
                        chroma_dims=CHROMA_DIMS,
                        hidden_dims=args['hidden_dim'], z_dims=args['z_dim'], 
                        n_step=args['time_step'],
                        n_component=2,
                        device=args['device'])  

# Restore the pre-trained weights. `map_location` remaps the stored tensors onto the
# selected device, so a checkpoint saved from GPU tensors can still be loaded on CPU.
print("Loading params/music_attr_vae_reg_gmm.pt...")
model.load_state_dict(torch.load("params/music_attr_vae_reg_gmm.pt",
                                 map_location=args['device']))
    
if torch.cuda.is_available():
    print('Using: ', torch.cuda.get_device_name(torch.cuda.current_device()))
    model.cuda()
else:
    print('CPU mode')
    
step, pre_epoch = 0, 0
batch_size = args["batch_size"]
is_shuffle = False

# ================ In this example, we will load only the examples from VGMIDI dataset ========== #
# print("Loading Yamaha...")
# data_lst, rhythm_lst, note_density_lst, chroma_lst = get_classic_piano()
# tlen, vlen = int(0.8 * len(data_lst)), int(0.9 * len(data_lst))
# train_ds_dist = YamahaDataset(data_lst, rhythm_lst, note_density_lst, 
#                                 chroma_lst, mode="train")
# train_dl_dist = DataLoader(train_ds_dist, batch_size=batch_size, shuffle=is_shuffle, num_workers=0)
# val_ds_dist = YamahaDataset(data_lst, rhythm_lst, note_density_lst, 
#                                 chroma_lst, mode="val")
# val_dl_dist = DataLoader(val_ds_dist, batch_size=batch_size, shuffle=is_shuffle, num_workers=0)
# test_ds_dist = YamahaDataset(data_lst, rhythm_lst, note_density_lst, 
#                                 chroma_lst, mode="test")
# test_dl_dist = DataLoader(test_ds_dist, batch_size=batch_size, shuffle=is_shuffle, num_workers=0)
# dl = train_dl_dist
# print(len(train_ds_dist), len(val_ds_dist), len(test_ds_dist))

# vgmidi dataloaders
print("Loading VGMIDI...")
data_lst, rhythm_lst, note_density_lst, chroma_lst, arousal_lst, valence_lst = get_vgmidi()


def _vgm_split(mode):
    """Build the VGMIDI dataset for the split `mode` and a DataLoader over it."""
    split_ds = VGMIDIDataset(data_lst, rhythm_lst, note_density_lst,
                             chroma_lst, arousal_lst, valence_lst, mode=mode)
    split_dl = DataLoader(split_ds, batch_size=32, shuffle=is_shuffle, num_workers=0)
    return split_ds, split_dl


# One (dataset, dataloader) pair per split, built in train / val / test order
vgm_train_ds_dist, vgm_train_dl_dist = _vgm_split("train")
vgm_val_ds_dist, vgm_val_dl_dist = _vgm_split("val")
vgm_test_ds_dist, vgm_test_dl_dist = _vgm_split("test")

cpu
Loading params/music_attr_vae_reg_gmm.pt...
CPU mode
Loading VGMIDI...
Shapes for: Data, Rhythm Density, Note Density, Chroma
(1013,) (1013,) (1013,) (1013, 24)
Shapes for: Arousal, Valence
(1013,) (1013,)


In [4]:
def convert_to_one_hot(input, dims):
    """Turn a tensor of integer indices into one-hot vectors.

    Args:
        input: int64 index tensor of any rank; every value must lie in [0, dims).
        dims: length of the one-hot axis that is appended as the last dimension.

    Returns:
        A tensor of shape `input.shape + (dims,)` that holds 1.0 at the position
        given by each index and 0.0 everywhere else. It is created on the same
        device as `input`, so the function does not rely on any global setting.
    """
    input_oh = torch.zeros(*input.shape, dims, device=input.device)
    # Write a 1 along the last axis at the position named by each index
    return input_oh.scatter_(-1, input.unsqueeze(-1), 1.)

def clean_output(out):
    """Convert decoder scores into a trimmed sequence of token indices.

    The most likely token is taken along the last axis, size-1 axes are squeezed
    away, leading and trailing zeros are removed, and the sequence is cut just
    before the first token equal to 1 (when such a token is present).
    """
    token_ids = torch.argmax(out, dim=-1).cpu().detach().numpy().squeeze()
    token_ids = np.trim_zeros(token_ids)
    one_positions = np.flatnonzero(token_ids == 1)
    if one_positions.size == 0:
        return token_ids
    # Everything from the first 1 onwards is dropped
    return token_ids[:one_positions[0]]

def repar(mu, stddev, sigma=1):
    """Draw a sample `mu + stddev * eps` using the reparameterization trick.

    `eps` is sampled from `Normal(0, sigma)` with the same shape as `stddev`
    and moved to the device stored in `args['device']`.
    """
    noise_dist = Normal(0, sigma)
    noise = noise_dist.sample(sample_shape=stddev.size())
    noise = noise.to(device=args['device'])
    return mu + stddev * noise

### Obtain "shifting vectors"

In [5]:
# Shifting vectors are differences between the pre-trained mean vectors of the
# low arousal cluster (index 0) and the high arousal cluster (index 1):
#   low --> high: mean_high - mean_low, and the opposite sign for high --> low.
# They are needed for both the rhythm space and the note space.

mu_r_lst, var_r_lst = [], []
mu_n_lst, var_n_lst = [], []
for k_i in torch.arange(0, 2):
    cluster = k_i.to(device=args['device'])

    # Rhythm space: mean and variance (exp of the stored log-variance) of this cluster
    mu_r_lst.append(model.mu_r_lookup(cluster).cpu().detach())
    var_r_lst.append(model.logvar_r_lookup(cluster).exp_().cpu().detach())

    # Note space: same two quantities
    mu_n_lst.append(model.mu_n_lookup(cluster).cpu().detach())
    var_n_lst.append(model.logvar_n_lookup(cluster).exp_().cpu().detach())

mu_r_low, mu_r_high = mu_r_lst[0], mu_r_lst[1]
mu_n_low, mu_n_high = mu_n_lst[0], mu_n_lst[1]

r_low_to_high = mu_r_high - mu_r_low
r_high_to_low = mu_r_low - mu_r_high
n_low_to_high = mu_n_high - mu_n_low
n_high_to_low = mu_n_low - mu_n_high

### Load base melody

In [6]:
# The base melody is a segment taken from the VGMIDI test set.
# Choose any number between 0 - 51 for the `idx` variable.
# Alternatively, encode your own melody segment with `magenta_encode_midi` in `ptb_v2.py`
# and use the resulting token sequence as `d` here.
idx = 5
d, r, n, a, v, c, r_density, n_density = vgm_test_ds_dist[idx]
c = torch.Tensor(c).to(device=args['device']).unsqueeze(0)

# Keep the event tokens that come before the first token equal to 1, and print them
eos_index = np.where(d==1)[0][0]
input_tokens = d.int().numpy()[:eos_index]
print("Input tokens:", input_tokens)

# One-hot encode the whole token sequence for the model
dim = torch.Tensor(d).to(device=args['device']).long()
d_oh = convert_to_one_hot(dim, EVENT_DIMS)

# Decode the tokens into MIDI and listen to the segment
# Note: fluidsynth must be installed beforehand (apt-get on linux), as well as pyfluidsynth (pip)
pm = magenta_decode_midi(input_tokens)
a_1 = pm.fluidsynth()
Audio(a_1, rate=44100)

Input tokens: [325  34  50 201 122 138 178  33  46 309  50 201 121 134 138 178 325  33
 201 121 178  30  46  50 201 118 134 138 178  30 201 118 178  29  46  50
 201 117 134 138 178  29 202  46  50 201 117 134 138 178  29 189 117 178
  30 188 118 178  29  46  50 189 117 178  30 188 134 138 118 178  29 201
 117 178  30  46  50 201 118 134 138 178  29 202  46  50 201 134 138 202
 117 178  30  46  50 201 118 134]


### Low arousal --> high arousal

In [13]:
model.eval()

# lmbda is a parameter for you to control `how much` is the extent of transfer
# if you think the transferred arousal of output is not high enough, increase lmbda (and vice versa)
lmbda = 5

# Inference only, so no autograd graph is needed
with torch.no_grad():
    # Encode the base melody and sample its rhythm / note latents
    dis_r, dis_n = model.encode(d_oh.unsqueeze(0))
    z_r = dis_r.rsample()
    z_n = dis_n.rsample()

    # Shift both latents along the low --> high direction
    r_lowhi = r_low_to_high.to(device=args['device'])
    n_lowhi = n_low_to_high.to(device=args['device'])
    z_r_new = z_r + lmbda*r_lowhi
    z_n_new = z_n + lmbda*n_lowhi

    # Decode from the shifted latents together with `c`
    z = torch.cat([z_r_new, z_n_new, c], dim=1)
    out = model.global_decoder(z, steps=300)

# Convert the decoder output to tokens once and reuse it for printing and playback
tokens = clean_output(out)
print("Tokens:", tokens)

# Listen to the transferred output
pm = magenta_decode_midi(tokens)
a_1 = pm.fluidsynth()
Audio(a_1, rate=44100)

Tokens: [180 180 178 309  42 192 130 178 303  48 188 136 179 319  31 179 119 178
 299  54 188 142 178  50 179 138 178 315  38 188 126 178 315  55 188 139
 178  55 189 143 178 317  27 190 115 178 315  31 192 119 178 315  34 192
 122 178 303  55 189 125 178 305  31 178 143 180 119 178 315  55 189 143
 178 317  55 178 143 180 309  50 178 315  58 188 138 178 303  48 180 146
 178  50 180 136 178  46 180 134 178 341  31 178 315  55 188 119 178 337
  34 317  50 189 122 178 303  31 180 138 178 119 180 143 178 315  31 303
  46 188 134 178 119 178 317  31 305  46 188 134 178 119 178  47 179 135
 178 315  46 192 134 178 303  41 192 129 178  31 317  50 191 119 178 138
 178 317  48 178 315  46 192 134 178 303  46 192 134 178 136 178  41 315
  50 192 129 178 303  46 180 134 178 138 178 317  50 192 138 178 315  46
 190 134 178 303  50 192 138 178  46 315  50 189 134 178 138 178 315  46
 303  50 192 134 178 303  46 180 138 178  50 188 134 178 138 180 309  48
 190 136 178 315  46  50 189 134 138 178 31

### High arousal --> low arousal

In [17]:
model.eval()

# lmbda is a parameter for you to control `how much` is the extent of transfer
# if you think the transferred arousal of output is not low enough, increase lmbda (and vice versa)
lmbda = 5

# Inference only, so no autograd graph is needed
with torch.no_grad():
    # Encode the base melody and sample its rhythm / note latents
    dis_r, dis_n = model.encode(d_oh.unsqueeze(0))
    z_r = dis_r.rsample()
    z_n = dis_n.rsample()

    # Shift both latents along the high --> low direction
    r_hilow = r_high_to_low.to(device=args['device'])
    n_hilow = n_high_to_low.to(device=args['device'])
    z_r_new = z_r + lmbda*r_hilow
    z_n_new = z_n + lmbda*n_hilow

    # Decode from the shifted latents together with `c`
    z = torch.cat([z_r_new, z_n_new, c], dim=1)
    out = model.global_decoder(z, steps=300)

# Convert the decoder output to tokens once and reuse it for printing and playback
tokens = clean_output(out)
print("Tokens:", tokens)

# Listen to the transferred output
pm = magenta_decode_midi(tokens)
a_1 = pm.fluidsynth()
Audio(a_1, rate=44100)

Tokens: [277 252 332  22  50  52  55 277 143 208 131 139 136 233  35 208]
