## Import Things so the Code Runs
(Adjust your drive path!)

In [1]:
drive = False  # False = Local
if drive:
    !pip install pretty_midi
    !pip install -U sparsemax



In [2]:
import numpy as np
import matplotlib.pyplot as plt
import pretty_midi
#this package is used to write it back into music.
from mido import Message, MidiFile, MidiTrack
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
import random
import torch.distributions
import sparsemax

In [3]:
if drive:
    from google.colab import drive
    drive.mount('/content/drive')
    my_drive_path = '/content/drive/MyDrive/2022_Special_Studies_Hablutzel/ModelCopies/'
else: # local
    my_drive_path = "./"

## The Model

In [4]:
#this is the model
class music_generator(nn.Module):
    def __init__(self, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size  # 128
        # output_size is num expected features (128)
        self.lstm = nn.LSTM(output_size, hidden_size, num_layers=1, bidirectional=False)
        self.attention = nn.Linear(2, 1)
        self.softmax = sparsemax.Sparsemax(dim=1)
        self.sigmoid = nn.Sigmoid()
        self.hidden = None

    def init_hidden(self, batch_size):
      # set hidden state to zeros after each batch #
      hidden = (torch.zeros(1, batch_size, self.hidden_size))  # [layers, batch_size, hidden_size/features]
      hidden = (hidden, hidden)  # hidden_state, cell_state
      return hidden

    def forward(self, input, batch_size, ssm):
        # look at tensor things - view vs. reshape vs. permute, and unsqueeze and squeeze
        # try looking at the LSTM equations
        # .to('cpu')  # returns a copy of the tensor in CPU memory
        # .to('cuda:0')  # returns copu in CUDA memory, 0 indicates first GPU device
        # https://pytorch.org/docs/stable/tensors.html#torch.Tensor.to

        # sequence length
        # size of input (10 or 1)
        sequence_length = input.size()[0]

        # Run the LSTM

        #########
        ###  is this setting a new lstm each time - no
        ######## but check how to pass hidden states
        #########

        # output - sequence of all the hidden states
        # hidden - most recent hidden state
        # input dimensions: [sequence_length, batch_size, 128]
        output, self.hidden = self.lstm(input.to("cuda:0").float(), self.hidden) #0].to("cuda:0").float(), hidden[1].to("cuda:0").float()))
        # output dimensions: [sequence_length, batch_size, 128]
        # outputs as many beats (sequence_length) as there were beats in the input
        # hidden: last hidden states from last beat

        # Get the output
        output_1 = output.to('cpu') # [-5:,:,:]  # don't just take last 5
        # take out this line after finish LSTM (other than reshaping) - unsqueeze to get right reshaping


        #########################
        # attention starts here #
        #########################
        
        # output without attention
        new_output = output_1.view(sequence_length, batch_size, 1, 128)  # reshape
        avg_output = torch.sum(new_output, 2)  # return this for base LSTM (w/o attention)
        
        #this variable holds the output after the attention has been applied.
        seqs = []

        # generate one step at a time from the previous 5 steps, hidden state, and cell state
        # generate x_t from hidden state, cell state from x_(t-1), inputs {x_(t-1)...(t-5)} --> get hidden state, cell state, output for next

        ######
        # ADJUST to use all of the SSM up to this beat
        ######
        # Find the right place in the self-similarity matrix.
        index = input.shape[0]
        weights = self.softmax(ssm[range(index, ssm.shape[0], ssm.shape[1]), :index])  # check SSM shape???

        #this is the sparsemaxed SSM multiplied by the entire previous sequence
        # replace .T - see which dimensions we're switching
        weighted = (input.T*weights).T
        #then it's summed to provide weights for each note.
        weight_vec = (torch.sum(weighted, axis=0))
        #This concatenates the weights for each note with the output for that note, which is then run through the linear layer to get the final output.
        pt2 = torch.hstack((weight_vec.unsqueeze(1), avg_output[0,:,:].unsqueeze(1))).transpose(1,2)
        attentioned = self.attention(pt2.float().to("cuda:0")).to('cpu').permute(2,0,1)

        # final added new output to seqs.
        seqs.append(attentioned)

        # seqs is combined into one tensor and returned, along with the hidden and cell states.
        newoutput = torch.vstack(seqs)

        return newoutput.double(), self.hidden  # hidden = hidden_state, cell_state

## For Training the Model and Generating Sequences

In [5]:
class model_trainer():
  def __init__(self, generator, optimizer, data, hidden_size=128, batch_size=50):
    self.generator = generator
    self.optimizer = optimizer
    self.batch_size = batch_size  # play with this
    self.hidden_size = hidden_size  # 128
    self.data = data
    self.data_length = 800

  def train_epochs(self, num_epochs=50, full_training=False):
    #trains each epoch
    losslist = []
    #useful when you want to see the progression of the SSM over time
    piclist = []

    for iter in tqdm(range(0, num_epochs)):
      # start training the generator
      generator.train()
      # for each epoch, zero the gradients
      self.generator.zero_grad()  # was optimizer


      if full_training:
        # use all data
        batches = make_batches(self.batch_size, data)
      else:
        # use first 100 pieces
        # can we overfit on a small dataset? if so, cna be a good thing b/c shows the model can learn
        batches = make_batches(self.batch_size, data[:100])

      cum_loss = 0
      for batch in batches:
        if full_training:
          # train on full-length pieces
          loss = self.train(batch)
        else:
          # train on first 105 beats of each piece
          loss = self.train(batch[:,:105,:])  # [batch, beats, 128]
        cum_loss+=loss
        del loss
      del batches

      # generate example piece for piclist
      snap = self.generate_n_examples(n=1, length=95, starter_notes=10)

      losslist.append(cum_loss) 
      piclist.append(snap)
      del snap

      # after each epoch,
      # run w/ validation
      # if devset (validation) loss goes up for ~5 epochs in a row, early stopping
    return losslist, piclist

  #one round of training
  def train(self, batch):
    #seed vectors for the beginning:
    self_sim = batch_SSM(batch.transpose(0,1), self.batch_size)
    sequence = batch[:,0:10,:].transpose(0,1)  # start w/ some amount of the piece - 10 might be a bit much
    generated = batch[:,0:10,:].transpose(0,1)

    # reset hidden to zeros for each batch
    generator.hidden = generator.init_hidden(self.batch_size)

    # for accumulating loss
    loss = 0

    # first .forward on sequence of 10
    # then loop from there to generate one more element
    next_element = sequence
    print(sequence.shape)

    # take
    for i in range(0,batch.shape[1]-10):  # for each beat
      # iterate through beats 10-400, generating for each piece in the batch as you go
      val = torch.rand(1)  # probability it uses original - teacher forcing

      # teacher forcing - 20% of the time, don't generate, use original from piece instead
      if (val > .8):
        next_element = batch[:,i+1,:].unsqueeze(0)  # [1, 0/deleted, 128] to [1, 1, 128]
      else:  # 80% of the time it adds the output
        output, hidden = self.generator.forward(next_element, hidden, self.batch_size, self_sim)  # should take next_element
        # take last output for each batch
        next_element = topk_batch_sample(output, 1)
      
      sequence = torch.vstack((sequence, next_element.to("cpu"))) # .unsqueeze(0)
      generated = torch.vstack((generated, output))  # used for loss
      del output

    
    # run loss after training on whole length of the pieces in the batches
    single_loss = custom_loss(generated[10:,:,:], batch.transpose(0,1)[10:,:,:])
    single_loss.backward()

    # update the parameters of the LSTM after running on full batch
    self.optimizer.step()

    # maybe move zero_grad() here
    loss += single_loss.detach().to('cpu')
    del next_element
    del self_sim
    del sequence
    del generated
    del hidden
    return (loss)

  def generate_n_pieces(self, initial_vectors, n_pieces, length, batched_ssm):
    #generates a new piece of music
    # createnew hidden layer
    hidden = (torch.randn(1, 1, self.hidden_size)).float()
    hidden = (hidden, hidden)

    # freeze generator so it doesn't train anymore
    self.generator.eval()  
  
    # initial vectors in format [batch_size, num_notes=10, 128]
    # change sequence to [10, batch_size, 128]
    sequence = initial_vectors.transpose(0,1)
    
    # generate [length] more beats for the piece
    for i in range(0, length):  # one at a time
      with torch.no_grad():
        # use n_pieces to generate as the batch size
        output, hidden = self.generator.forward(sequence.float(), hidden, n_pieces, batched_ssm)
        next_element = topk_batch_sample(output, 1)
      # add element to sequence
      sequence = torch.vstack((sequence, next_element.to("cpu")))

    # return sequence of beats
    return sequence
  
  def generate_n_examples(self, n=1, length=390, starter_notes=10, source_piece=0):
    # get piece from the data
    piece = self.data[source_piece][0].unsqueeze(0)  # in format [1, 400, 128]

    # take first 10 notes in format [1, 10, 128]
    first_vec = piece[:,0:starter_notes,:]

    # create batched SSMs for each piece
    batched_ssms = batch_SSM(piece.transpose(0,1), n)

    # generate pieces
    new_gen = self.generate_n_pieces(first_vec, n, length, batched_ssms)

    # clean up variables
    del piece
    del first_vec
    del batched_ssms

    # return pieces
    return new_gen



## Helper Functions

In [6]:
#this function takes in the piece of music and returns the chroma vectors
def get_chroma(roll, length):
    chroma_matrix = torch.zeros((roll.size()[0],12))
    for note in range(0, 12):
        chroma_matrix[:, note] = torch.sum(roll[:, note::12], axis=1)
    return chroma_matrix

In [7]:
#this takes in the sequence and creates a self-similarity matrix (it calls chroma function inside)
def SSM(sequence):
  #tensor will be in form length, hidden_size (128)
  cos = nn.CosineSimilarity(dim=1)
  chrom = get_chroma(sequence, sequence.size()[0])
  len = chrom.size()[0]
  SSM=torch.zeros((len, len))
  for i in range(0, len):
    SSM[i] = cos(chrom[i].view(1, -1),chrom)
  return (SSM)

In [8]:
#this bundles the SSM function.
def batch_SSM(seq, batch_size):
  # takes sequence in format
  # [beats=400, batch_size, 128]
  # print("SSM\tsequence_shape", seq.shape)
  SSMs = []
  for i in range(0, batch_size):
    # print("SSM\tsequence", seq[:,i,:].shape)
    ssm = SSM(seq[:,i,:])  # [beats, batch, 128]
    # print("SSM\tssm", ssm.shape)
    SSMs.append(ssm)  
  return torch.vstack(SSMs)

In [9]:
#Takes in the batch size and data and returns batches of the batch size
def make_batches(batch_size, data):
  random.shuffle(data)
  batches = []
  num_batches = len(data)//batch_size
  for i in range(0, num_batches):
    batch = torch.cat(list(np.array(data)[i*batch_size: (i+1)*(batch_size)][:, 0])).view(batch_size, target_size, 128)
    batches.append(batch)
  return batches

In [10]:
#sampling function 
def topk_sample_one(sequence, k):
  #takes in size sequence length, batch size, values
  softmax = sparsemax.Sparsemax(dim=2)
  vals, indices = torch.topk(sequence[:, :, 20:108],k)
  indices+=20
  seq = torch.distributions.Categorical(softmax(vals.float()))
  samples = seq.sample()
  onehot = F.one_hot(torch.gather(indices, -1, samples.unsqueeze(-1)), num_classes = sequence.shape[2]).squeeze(dim=2)
  return(onehot)

In [11]:
#samples multiple times for the time-step
def topk_batch_sample(sequence, k):
  for i in range(0, 3):
    new= topk_sample_one(sequence, k)
    if i ==0:
      sum = new
    else:
      sum+=new
  return(torch.where(sum>0, 1, 0))


In [12]:
def custom_loss(output, target):
  #custom loss function
  criterion = nn.BCEWithLogitsLoss()
  weighted_mse = criterion(output.double(), target.double())
  batch_size = output.size()[1]
  ssm_err = 0
  for i in range(0, batch_size):
    SSM1 = SSM(output[:,i,:])
    SSM2 = SSM(target[:,i, :])
    ssm_err += (torch.sum((SSM1-SSM2)**2)/(SSM2.size(0)**2))


  return torch.sum(weighted_mse)+ssm_err

## Train Model Here

In [13]:
#load the data; in these small-scale tests I usually loaded the val data because it didn't take as long to load. When you are training, load the train data.
data = torch.load(my_drive_path + "usable_data/train_tempo_800.csv")  # make train for real training
# val = torch.load(my_drive_path + "usable_data/validation_tempo.csv")
torch.manual_seed(2022)

<torch._C.Generator at 0x7fa7d0f7cc90>

In [14]:
# create model and optimizer
generator = music_generator(128,128).to("cuda:0")
optimizer = torch.optim.Adam(generator.parameters(), lr=0.005)

# parameters
hidden_size = 128  # maybe don't touch?
batch_size = 25    # play with this
target_size = 800  # make this not global?

# model trainer
trainer = model_trainer(generator, optimizer, data, hidden_size, batch_size)

AssertionError: Torch not compiled with CUDA enabled

In [None]:
# train the model
losslist, piclist = trainer.train_epochs(num_epochs=3, full_training=False)

In [None]:
#this code can save a model if you need it
#!touch my_drive_path+"model.txt"
#torch.save(generator, my_drive_path + "model.txt" )

## View Results and Generate Example

In [None]:
# view snapshots of pieces at each epoch
plt.imshow(SSM(piclist[-1].squeeze()))

In [None]:
# show loss at each epoch
plt.scatter(range(0,len(losslist)), losslist)

In [None]:
# create example
index = 0  # which source piece
new_gen = trainer.generate_n_examples(n=1, length=390, starter_notes=10, source_piece=index)

In [None]:
# show SSM for first piece
plt.imshow(SSM(data[index][0]))

In [None]:
# show SSM for new piece
plt.imshow(SSM(new_gen.squeeze()))

In [None]:
# show new piece
plt.imshow(new_gen.squeeze())

## Save Example Sequence to Audio

In [None]:
#this code translates the generated sequence to audio
def stop_note(note, time):
    return Message('note_off', note = note,
                   velocity = 0, time = time)

def start_note(note, time):
    return Message('note_on', note = note,
                   velocity = 120, time = time)

def roll_to_track(roll, tempo):
    delta = 0

    
    # MIDI note for first column.
    midi_base = 0
    notes = [0] * len(roll[0])
    for row in roll:
        for i, col in enumerate(row):
            note = i
            if col>notes[i] and col!=0: 
                if notes[i]!=0:
                    yield stop_note(note, delta)
                    delta = 0
                yield start_note(i, delta)
                delta = 0
                notes[i] = note
            elif col == 0:
                if notes[i]!=0:
                    # Stop the ringing note
                    yield stop_note(note, delta)
                    delta = 0
                notes[i] = 0
        # ms per row
        delta += int(np.round((1/(tempo/60))*1000))



In [None]:
new_roll_final = np.vstack((new_gen.squeeze(), np.zeros(128)))
midi = MidiFile(type = 1)
midi.tracks.append(MidiTrack(roll_to_track(new_roll_final, data[index][1])))
midi.save(my_drive_path + 'audio_outputs/blank.midi')