# Import libraries

In [1]:
import os
import sys

# Parent directory of the kernel's current working directory.
# NOTE(review): despite the name, this is the notebook's own folder only if the
# notebook sits one level below the cwd — confirm the intended directory layout.
notebook_path = os.path.abspath(os.path.dirname(os.getcwd()))

# One level above `notebook_path` (two levels above the cwd). It is added to the
# import path, presumably so project modules such as `trainer`, `gan` and `data`
# can be imported — confirm against the repository layout.
project_root = os.path.abspath(os.path.join(notebook_path, '..'))
# Guard so re-running this cell does not append the same entry again.
if project_root not in sys.path:
    sys.path.append(project_root)

from trainer import Trainer

import numpy as np
from matplotlib import pyplot as plt
import torch
from torch import nn

# Seed the RNGs so weight initialisation, the random test inputs, DataLoader
# shuffling and sampling are reproducible under Restart & Run All.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

cpu


# Models

In [2]:
from gan import TemporalNetwork, BarGenerator, MuseCritic, MuseGenerator

# TemporalNetwork

In [3]:
# Build the temporal network and push a random batch through it to inspect
# the output shape.
tempnet = TemporalNetwork()
temporal_input = torch.rand(10, 32)  # batch of 10 random vectors of length 32
tempnet(temporal_input).shape

torch.Size([10, 32, 2])

In [4]:
param_count = sum(param.numel() for param in tempnet.parameters())
print('Number of parameters:', param_count)

Number of parameters: 101472


# BarGenerator

In [5]:
# Build a single-bar generator and check the shape it produces for a random batch.
bargenerator = BarGenerator()
bar_input = torch.rand(10, 128)  # batch of 10 random vectors of length 128
bargenerator(bar_input).shape

torch.Size([10, 1, 1, 16, 84])

In [6]:
param_count = sum(param.numel() for param in bargenerator.parameters())
print('Number of parameters:', param_count)

Number of parameters: 1517313


# MuseGenerator

In [7]:
# Build the full generator with default settings and check its output shape for
# random inputs: chords/style of shape (10, 32), melody/groove of shape (10, 4, 32).
generator = MuseGenerator()

chords = torch.rand(10, 32)  # named consistently with the sampling cell further down
style = torch.rand(10, 32)
melody = torch.rand(10, 4, 32)
groove = torch.rand(10, 4, 32)

generator(chords, style, melody, groove).shape

torch.Size([10, 4, 2, 16, 84])

In [8]:
param_count = sum(param.numel() for param in generator.parameters())
print('Number of parameters:', param_count)

Number of parameters: 6576612


# MuseCritic

In [9]:
# Build the critic and score a random batch shaped like the MuseGenerator
# output (10, 4, 2, 16, 84).
critic = MuseCritic()
critic_input = torch.rand(10, 4, 2, 16, 84)
critic(critic_input).shape

torch.Size([10, 1])

In [10]:
param_count = sum(param.numel() for param in critic.parameters())
print('Number of parameters:', param_count)

Number of parameters: 1446401


# Dataset Loader

In [11]:
from torch.utils.data import DataLoader
from data.utils import MidiDataset

DATA_PATH = 'data/chorales/Jsb16thSeparated.npz'  # relative to the kernel's working directory
BATCH_SIZE = 64

dataset = MidiDataset(path=DATA_PATH)
# drop_last discards a final incomplete batch, so every batch has BATCH_SIZE samples.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, drop_last=True, shuffle=True)

# Define Models

### Generator

In [12]:
# Re-create the generator with explicit hyperparameters for training; this
# replaces the default-configured instance used for the shape check.
generator = MuseGenerator(
    z_dimension=32,
    hid_channels=1024,
    hid_features=1024,
    out_channels=1,
)
g_optimizer = torch.optim.Adam(
    generator.parameters(),
    lr=0.001,
    betas=(0.5, 0.9),
)

### Critic

In [13]:
# Critic used for training, with its own Adam optimizer (same settings as the generator's).
critic = MuseCritic(hid_channels=128, hid_features=1024, out_features=1)
c_optimizer = torch.optim.Adam(
    critic.parameters(),
    lr=0.001,
    betas=(0.5, 0.9),
)

In [14]:
from gan import initialize_weights

# Apply `initialize_weights` to every submodule of both networks, in place.
# `Module.apply` returns the module itself, so no reassignment is needed.
for net in (generator, critic):
    net.apply(initialize_weights)

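# Training

Note: in the saved output of the training cell, the losses grow to the order of 1e11 within 60 epochs, which suggests the run is diverging. Treat the loss plot and generated samples below with caution until the hyperparameters have been revisited.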

In [15]:

from trainer import Trainer

In [16]:
trainer = Trainer(
    generator,    # network being trained to produce samples
    critic,       # network scoring real vs. generated samples
    g_optimizer,  # Adam optimizer for the generator
    c_optimizer,  # Adam optimizer for the critic
)

In [None]:
# NOTE(review): in the saved output of this cell the losses reach ~1e11 by epoch 60,
# i.e. the run appears to diverge; the learning rates (0.001 for both optimizers)
# may need revisiting before the later results are meaningful.
EPOCHS = 1000
trainer.train(dataloader, epochs=EPOCHS)

  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 0/1000 | Generator loss: -891147.205 | Critic loss: -296101.683
(fake: -135871.495, real: -202826.390, penalty: 42596.218)
Epoch 10/1000 | Generator loss: -38551791616.000 | Critic loss: 8955142963.200
(fake: -39086785331.200, real: -39089809681.067, penalty: 87131734589.333)
Epoch 20/1000 | Generator loss: -99637873322.667 | Critic loss: -79694769902.933
(fake: -101680841796.267, real: -101678271692.800, penalty: 123664343637.333)
Epoch 30/1000 | Generator loss: -172607485269.333 | Critic loss: -124251720908.800
(fake: -166231367133.867, real: -166224548113.067, penalty: 208204196682.667)
Epoch 40/1000 | Generator loss: -242862019925.333 | Critic loss: 60053957290.667
(fake: -248698150365.867, real: -248682942737.067, penalty: 557435065765.333)
Epoch 50/1000 | Generator loss: -380094797141.333 | Critic loss: -383333892096.000
(fake: -377828042342.400, real: -377830239982.933, penalty: 372324397440.000)
Epoch 60/1000 | Generator loss: -482621139626.667 | Critic loss: 129627912055

In [None]:
import copy

# Snapshot the loss history recorded by the trainer. A deep copy is used so the
# per-loss sequences are not shared with `trainer.data`. With a shallow
# `dict.copy()`, calling `trainer.train` again would also change `losses` if the
# trainer appends to those sequences in place.
losses = copy.deepcopy(trainer.data)

### Save losses

In [None]:
import pandas as pd

# Tabulate the loss history (one column per key of `losses`) and optionally save it.
SAVE_LOSSES = False  # set to True to write the table to LOSSES_CSV
LOSSES_CSV = 'results.csv'

df = pd.DataFrame.from_dict(losses)
if SAVE_LOSSES:
    df.to_csv(LOSSES_CSV, index=False)

### Save models

In [None]:
# Optionally checkpoint both networks. Whole modules are pickled (not only their
# state_dicts), so they can be restored directly with `torch.load`.
SAVE_MODELS = False  # set to True to write the checkpoints

if SAVE_MODELS:
    generator = generator.eval().cpu()
    critic = critic.eval().cpu()
    torch.save(generator, 'generator_e1000.pt')
    torch.save(critic, 'critic_e1000.pt')

### Plot losses

In [None]:
# Plot the first PLOT_EPOCHS entries of each loss series recorded by the trainer.
PLOT_EPOCHS = 500
SAVE_FIGURE = False  # set to True to also write the figure to LOSS_FIGURE
LOSS_FIGURE = 'losses.png'

LOSS_SERIES = [
    # (key in `losses`, matplotlib format/colour, legend label)
    ('gloss', 'orange', 'generator'),
    ('cfloss', 'r', 'critic fake'),
    ('crloss', 'g', 'critic real'),
    ('cploss', 'b', 'critic penalty'),
    ('closs', 'm', 'critic'),
]

fig, ax = plt.subplots(figsize=(12, 8))
for loss_key, loss_fmt, loss_label in LOSS_SERIES:
    ax.plot(losses[loss_key][:PLOT_EPOCHS], loss_fmt, label=loss_label)
ax.set_xlabel('epoch', fontsize=12)
ax.set_ylabel('loss', fontsize=12)
ax.set_title('Training losses')
ax.grid()
ax.legend()
# Save before `plt.show()`: once the figure has been shown, a later
# `plt.savefig` typically writes a blank image.
if SAVE_FIGURE:
    fig.savefig(LOSS_FIGURE)
plt.show()

In [None]:
# Optionally replace the in-memory generator with a saved checkpoint.
LOAD_GENERATOR = False  # set to True to restore 'generator_e1000.pt'

if LOAD_GENERATOR:
    # NOTE: torch.load unpickles arbitrary objects; only load trusted files.
    # Whole-module checkpoints may also need `weights_only=False` on newer
    # PyTorch releases, where that argument defaults to True.
    generator = torch.load('generator_e1000.pt', map_location='cpu')

# Switch to evaluation mode and move to the CPU for sampling, whichever
# generator (trained or loaded) is in use.
generator = generator.eval().cpu()

## Make prediction

In [None]:
# Sample one piece from random inputs (batch size 1; per-sample shapes match
# the MuseGenerator shape-check cell).
chords = torch.rand(1, 32)
style = torch.rand(1, 32)
melody = torch.rand(1, 4, 32)
groove = torch.rand(1, 4, 32)

# Inference only: skip building the autograd graph instead of detaching afterwards.
with torch.no_grad():
    preds = generator(chords, style, melody, groove)

### Get music data

In [None]:
from data.utils import postProcess

# Convert to NumPy under a new name so the cell can be re-run. Rebinding
# `preds` to the array would make a second run fail, because ndarrays have no `.numpy()`.
preds_np = preds.numpy()
music_data = postProcess(preds_np)

### Save file

In [None]:
# Write the generated music as MIDI, relative to the kernel's working directory.
OUTPUT_MIDI = 'myexample.midi'
music_data.write('midi', fp=OUTPUT_MIDI)