## TODOS

1. How to handle multiple notes at same time-step? (e.g. chords, drumming)
2. How to handle note loudness (MIDI velocity)?
3. How to handle silence (rests)?
4. How to seed generation to get different predictions?
5. How to select (or detect) the song's `key`?
6. How to handle song sections (e.g. verse, chorus)?
	- Requires pattern analysis to detect repeated sections.

In [1]:
import platform

# Sanity-check the host: expected macOS 12.3 or newer on 'arm64'
# (presumably so PyTorch's MPS backend is available — confirm).
# Left as the final expression so the notebook displays the tuple.
platform.mac_ver()

('14.5', ('', '', ''), 'arm64')

In [2]:
# External Imports
import matplotlib.pyplot as plt
import torch as torch
import numpy as np
import pypianoroll as pr

In [3]:
# Internal Imports
import sys, os
# Make the project package importable from the notebook.
# NOTE(review): os.path.abspath('/src') is already absolute, so this appends the
# filesystem-root directory '/src', not a project-relative path. The
# `from src.…` imports below need the directory that *contains* `src` on
# sys.path; they presumably resolve because the notebook's working directory is
# the project root — confirm whether this line is needed or should point elsewhere.
sys.path.append(os.path.abspath('/src'))

from src.util.types import Song, PianoState, NoteSample, PianoStateSamples
from src.util.globals import resolution, beats_per_bar, num_pitches, DEVICE
from src.util.convert import (
	output_piannoroll_to_midi
)
from src.util.plot import plot_pianoroll, plot_piano_states, plot_note_sample_probs, plot_track
from src.models import MusicRNN, MusicRNN_Batched, MusicLSTM
from src.models.train import train, train_batched
from src.models.infer import sample_notes

from src.dataset.dataset import InstrumentDataset, get_dataloader
from src.dataset.load import (
    get_songs,
    load_multi_track,
    get_track_by_instrument,
    get_samples
)

In [4]:
# TEST: plot one specific multitrack file (blended view), then plot the track
# for a single instrument from it. Set PLOT_SAMPLE_TRACK = True to run.
PLOT_SAMPLE_TRACK = False
if PLOT_SAMPLE_TRACK:
    desired_instrument = 'Guitar'
    # Plain string literal: the original f-string had no placeholders.
    multi_track = load_multi_track(
        'A/A/A/TRAAAGR128F425B14B/b97c529ab9ef783a849b896816001748.npz'
    )
    pr.plot_multitrack(multi_track, axs=None, mode='blended')

    track = get_track_by_instrument(multi_track, desired_instrument)

    # Falsy result is treated as "instrument not present in this file".
    if track:
        # Positional args (True, 4) passed through unchanged; their meaning is
        # defined by plot_track in src.util.plot.
        plot_track(track, desired_instrument, True, 4)
    else:
        print('No track found')

## Create the model

In [5]:
# Unbatched RNN baseline. Disabled; flip the condition to build it.
if False:
    basic_model = MusicRNN(
        hidden_size=128,
        num_layers=2,
        dropout=0.1,
        num_pitches=129,  # 0-128 notes (including silence at 0)
    )
    basic_model = basic_model.to(DEVICE)

In [6]:
# Batched/packed RNN variant. Disabled; flip the condition to build it.
if False:
    batched_model = MusicRNN_Batched(
        hidden_size=128,
        num_pitches=129,  # 0-128 notes (including silence at 0)
        num_layers=2,
        dropout=0.1,
    ).to(DEVICE)

In [17]:
# Active model: a 2-layer LSTM (the one used by the real-data cells below).
if True:
    lstm_model = MusicLSTM(
        hidden_size=256,
        num_layers=2,
        dropout=0.1,
        # 0-128 notes (including silence at 0).
        # NOTE(review): 0-128 is 129 values, but the model summary printed
        # further down shows note_head out_features=130 — confirm whether the
        # "+ 1" is intended or double-counts the silence class.
        num_pitches=num_pitches + 1,
    )
    lstm_model = lstm_model.to(DEVICE)

## Testing (Toy Data)

In [8]:
# Toy sequences for sanity-checking training and sampling.
# Each row is one event. Given the model's 2-feature input and its
# note/duration heads, presumably each row is [pitch, duration] — TODO confirm.
# Built with torch.tensor(..., dtype=torch.float32) instead of the legacy
# torch.Tensor(...) constructor plus a redundant .float(); values and dtype
# (float32) are identical to before.
_TOY_DURATION = 10

# Two up-and-down sweeps over pitches 1..5, ending back on 1 (17 events).
test_seq_1 = torch.tensor(
    [[pitch, _TOY_DURATION] for pitch in [1, 2, 3, 4, 5, 4, 3, 2] * 2 + [1]],
    dtype=torch.float32,
)

# A single ascending run over pitches 1..5 (5 events).
test_seq_2 = torch.tensor(
    [[pitch, _TOY_DURATION] for pitch in range(1, 6)],
    dtype=torch.float32,
)

In [9]:
# TEST: overfit the model on one toy sequence, then sample from it and plot the
# real vs. generated piano states. Disabled by default.
if False:
    model = lstm_model
    seq = test_seq_1
    max_len = 100
    start_notes = seq[0]  # seed sampling with the sequence's first event

    if True:  # inner toggle: train, sample, and plot
        train_batched(model, [seq], num_epochs=5000, lr=0.0001)

        # Test sampling a sequence from the seed event.
        piano_state_samples = sample_notes(model, start_notes, max_len)

        for states, title in (
            (seq, 'Real Sequence'),
            (piano_state_samples.piano_states, 'Generated Sequence'),
        ):
            plot_piano_states(states, None, title)
        plot_note_sample_probs(piano_state_samples.note_samples)

In [10]:
# Test the packed (batched) model on two toy sequences of different lengths.
seqs = [test_seq_1, test_seq_2]
max_len = 100
start_notes = seqs[0][0]  # first event of the first sequence

# Training the batched model is disabled by default.
if False:
    train_batched(
        batched_model,
        seqs,
        batch_size=2,
        num_epochs=1000,
        lr=0.0001,
    )

In [11]:
# Sample from the packed model, seeded with start_notes, and plot the result.
# Disabled by default.
if False:
    piano_state_samples = sample_notes(
        batched_model,
        start_notes,
        max_len,
        temperature=0.3,
    )
    plot_piano_states(piano_state_samples.piano_states, None, 'Generated Sequence')
    plot_note_sample_probs(piano_state_samples.note_samples)

## Testing (Real Data)

In [18]:
# Load a small real-data subset: at most 20 sequences for a single instrument.
dataset = InstrumentDataset(instrument='Bass', max_samples=20)

# Second argument is presumably the batch size (20 samples / 5 = the "n = 4"
# printed during training) — confirm in src.dataset.dataset.
trainloader = get_dataloader(dataset, 5)

Found 21425 total files
Got 20 total sequences for instrument "Bass"


In [19]:
# Choose which of the models built above to train and sample with,
# then print its architecture and total parameter count.
model = lstm_model
print(model)
print(f"# Parameters: {sum(param.numel() for param in model.parameters())}")

MusicLSTM(
  (rnn): LSTM(2, 256, num_layers=2, bias=False, batch_first=True, dropout=0.1)
  (note_head): Linear(in_features=256, out_features=130, bias=True)
  (duration_head): Linear(in_features=256, out_features=1, bias=True)
)
# Parameters: 822147


In [None]:
# Train the selected model on the real-data loader.
if True:
    train_batched(model, trainloader, num_epochs=200, lr=0.001)

Training on data set with n = 4


In [None]:
# Generate a short continuation seeded with the first event of the first
# training sequence, plot per-step note probabilities, and export it to MIDI.
seq = dataset[0]
if True:
    num_bars = 1
    predictions = sample_notes(
        model,
        start_event=torch.Tensor(seq[0]),
        # Presumably one bar of events per num_bars — confirm the units of
        # `length` in src.models.infer.sample_notes.
        length=beats_per_bar * num_bars,
        temperature=0.3,
    )

    plot_note_sample_probs(predictions.note_samples)
    # The model was trained on the 'Bass' dataset above, so label the export as
    # bass instead of the stale 'Guitar' / 'generated_guitar2' settings.
    # Assumes output_piannoroll_to_midi accepts the same instrument names as
    # InstrumentDataset — confirm.
    output_piannoroll_to_midi(
        predictions.piano_states,
        instrument='Bass',
        name='generated_bass',
    )