In [None]:
import inspect
import random

import matplotlib.pyplot as plt
import numpy as np
import pretty_midi
import sparsemax
import torch
import torch.distributions
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from mido import Message, MetaMessage, MidiFile, MidiTrack
from tqdm import tqdm


In [None]:
# Seed every RNG the notebook uses before anything stochastic runs:
# - torch: sampling and the initial LSTM state.
# - stdlib `random`: make_batches' shuffle.
SEED = 2022
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

# Despite the .csv extension, torch.load only reads torch-serialized files.
# Later cells use data[i][0] as a piano roll and data[i][1] as its tempo, so this is
# presumably a list of (roll, tempo) pairs — TODO confirm.
# PyTorch >= 2.6 defaults torch.load to weights_only=True, which can reject such
# pickles, so opt out explicitly when the argument exists.
# Only do this for trusted files: unpickling can execute arbitrary code.
_load_kwargs = (
    {"weights_only": False}
    if "weights_only" in inspect.signature(torch.load).parameters
    else {}
)
data = torch.load("test_tempo.csv", **_load_kwargs)

In [None]:
# Number of (roll, tempo) examples loaded; shown via the cell's rich output.
len(data)

In [None]:
def get_chroma(roll, length):
    """Fold a piano roll into 12 pitch-class (chroma) bins.

    Column ``c`` of the result is the row-wise sum of every ``roll`` column
    whose index is congruent to ``c`` modulo 12.

    Args:
        roll: 2-D tensor indexed (frames, pitches).
        length: unused; kept for call compatibility.

    Returns:
        A (frames, 12) tensor of the default float dtype.
    """
    n_frames = roll.size()[0]
    chroma = torch.zeros((n_frames, 12))
    for pitch_class in range(12):
        # Every 12th column starting at `pitch_class` is the same pitch class.
        chroma[:, pitch_class] = roll[:, pitch_class::12].sum(dim=1)
    return chroma

In [None]:
def batch_SSM(seq, batch_size):
  """Stack the self-similarity matrix of every batch item into one tall matrix.

  ``seq`` is indexed (frames, batch, pitches). Each item's square SSM (from
  ``SSM``) is concatenated along rows, so item ``b`` occupies the ``b``-th
  block of ``frames`` rows.
  """
  return torch.vstack([SSM(seq[:, b, :]) for b in range(batch_size)])

In [None]:
def SSM(sequence):
  """Self-similarity matrix of a piano roll, measured on its chroma features.

  ``sequence`` is indexed (frames, pitches). Entry ``[i, j]`` of the returned
  (frames, frames) tensor is the cosine similarity between the chroma vectors
  of frames ``i`` and ``j``.
  """
  chroma = get_chroma(sequence, sequence.size()[0])
  n_frames = chroma.size()[0]
  similarity = nn.CosineSimilarity(dim=1)
  out = torch.zeros((n_frames, n_frames))
  for row, frame in enumerate(chroma):
    # Compare one frame (as a 1 x 12 batch) against every frame at once.
    out[row] = similarity(frame.view(1, -1), chroma)
  return out

In [None]:
def make_batches(batch_size, data, seq_len=400, n_pitches=128):
  """Shuffle the examples and group their piano rolls into fixed-size batches.

  Args:
    batch_size: number of examples per batch.
    data: sequence of items whose element ``[0]`` is a roll tensor with
      ``seq_len * n_pitches`` elements. Other fields (e.g. a tempo at ``[1]``)
      are ignored.
    seq_len: number of frames each roll is reshaped to.
    n_pitches: number of pitch columns each roll is reshaped to.

  Returns:
    A list of ``len(data) // batch_size`` tensors shaped
    (batch_size, seq_len, n_pitches). Trailing items that do not fill a whole
    batch are dropped.

  The caller's ``data`` is left in its original order; a shuffled copy is
  used. The shuffle uses the stdlib ``random`` module, so seed it for
  reproducible batches.
  """
  # Shuffle a copy so the caller's list (e.g. the global `data`) is not reordered.
  # random.shuffle on a same-length copy consumes the RNG exactly as before.
  shuffled = list(data)
  random.shuffle(shuffled)
  num_batches = len(shuffled) // batch_size
  batches = []
  for i in range(num_batches):
    chunk = shuffled[i * batch_size:(i + 1) * batch_size]
    # Take each roll directly. The previous np.array(data) round-trip rebuilt
    # an array of the whole dataset for every batch, and raises on
    # NumPy >= 1.24 for ragged (tensor, tempo) items.
    rolls = [item[0] for item in chunk]
    batches.append(torch.cat(rolls).view(batch_size, seq_len, n_pitches))
  return batches

In [None]:
class music_generator(nn.Module):
    """Bidirectional-LSTM piano-roll model that predicts the next 5 frames.

    The LSTM encodes the input frames. Then, for each of the 5 new frames:

    - a sparsemax over that frame's row of the song's self-similarity matrix
      (SSM) weights every frame seen so far (input plus frames generated earlier
      in this call);
    - a 2 -> 1 linear layer mixes, per pitch, that SSM-weighted context with the
      LSTM output at the matching position.

    NOTE: trained instances are pickled whole and reloaded with ``torch.load``,
    so the attribute names set in ``__init__`` must not change.
    """

    def __init__(self, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(output_size, hidden_size, num_layers=1, bidirectional=True)
        # Per pitch: [SSM-weighted context, LSTM feature] -> one value.
        self.attention = nn.Linear(2, 1)
        # Despite the attribute name, this is sparsemax rather than softmax.
        self.softmax = sparsemax.Sparsemax(dim=1)
        # Not used by forward; kept so saved checkpoints keep the same attributes.
        self.sigmoid = nn.Sigmoid()

    def forward(self, input, hidden, batch, sequence_length, ssm):
        """Predict the next 5 frames.

        Args:
            input: roll indexed (seq_len, batch, pitches). The weighted context is
                built from these frames, so ``pitches`` must equal ``hidden_size``.
            hidden: ``(h0, c0)`` LSTM state, each (2, batch, hidden_size).
            batch: batch size.
            sequence_length: unused; kept for call compatibility.
            ssm: per-song SSMs stacked vertically, as ``batch_SSM`` builds them.
                Presumably shaped (batch * L, L) with L >= seq_len + 5 — TODO confirm.

        Returns:
            ``(frames, (h_n, c_n))``. ``frames`` is a float64 CPU tensor
            (5, batch, pitches); the LSTM state is returned on the CPU.
        """
        n_steps = 5  # frames predicted per call
        # Run on whatever device the weights live on. Hard-coding "cuda:0"
        # made the model unusable on CPU-only machines.
        device = next(self.parameters()).device
        output, (h_n, c_n) = self.lstm(
            input.to(device).float(),
            (hidden[0].to(device).float(), hidden[1].to(device).float()),
        )
        # Take the last n_steps outputs and split (.., 2 * hidden) into the two
        # directions, then add them.
        # Result: (n_steps, batch, hidden).
        lstm_feat = (
            output.to('cpu')[-n_steps:, :, :]
            .view(n_steps, batch, 2, self.hidden_size)
            .sum(dim=2)
        )

        generated = []
        for step in range(n_steps):
            index = input.shape[0] + step
            # Row `index` of each song's SSM (one row per stacked block),
            # restricted to the `index` frames available so far.
            rows = range(index, ssm.shape[0], ssm.shape[1])
            weights = self.softmax(ssm[rows, :index])  # (batch, index)
            frames = torch.vstack((input, *generated)) if generated else input
            # Weighted sum over time:
            # (index, batch, 1) * (index, batch, pitch) -> (batch, pitch).
            context = (weights.T.unsqueeze(-1) * frames).sum(dim=0)
            # (batch, pitch, 2): the last axis holds [context, LSTM feature].
            pair = torch.stack((context, lstm_feat[step]), dim=-1)
            mixed = self.attention(pair.float().to(device)).to('cpu')  # (batch, pitch, 1)
            generated.append(mixed.permute(2, 0, 1))  # (1, batch, pitch)

        return torch.vstack(generated).double(), (h_n.to('cpu'), c_n.to('cpu'))

In [None]:
# The checkpoint is a whole pickled model object (presumably a `music_generator`),
# not a state_dict, so that class must be defined in this kernel before loading.
# PyTorch >= 2.6 defaults torch.load to weights_only=True, which refuses such pickles,
# so opt out explicitly when the argument exists.
# Only load trusted files: unpickling can execute arbitrary code.
_load_kwargs = (
    {"weights_only": False}
    if "weights_only" in inspect.signature(torch.load).parameters
    else {}
)
att_mod = torch.load("trained/best_model/attention_model.txt", **_load_kwargs)


In [None]:
def topk_sample_one(sequence, k):
  """Draw one pitch per (frame, batch) position from the top-k scores.

  ``sequence`` is indexed (frames, batch, pitches). Only pitch columns 20..107
  are candidates. Of those, the ``k`` highest scores are turned into a
  probability distribution with sparsemax, and one of them is sampled.

  Returns:
    A one-hot int64 tensor (frames, batch, sequence.shape[2]) with a single 1
    per position.
  """
  first_pitch, stop_pitch = 20, 108
  top_scores, top_pitches = sequence[:, :, first_pitch:stop_pitch].topk(k)
  # topk indexes into the slice; shift back to absolute pitch columns.
  top_pitches = top_pitches + first_pitch
  probs = sparsemax.Sparsemax(dim=2)(top_scores.float())
  drawn = torch.distributions.Categorical(probs).sample()  # (frames, batch), in [0, k)
  chosen = top_pitches.gather(-1, drawn.unsqueeze(-1))      # (frames, batch, 1)
  return F.one_hot(chosen, num_classes=sequence.shape[2]).squeeze(dim=2)

  


In [None]:
def topk_batch_sample(sequence, k, n_samples=3):
  """Multi-hot frame sampling: the union of ``n_samples`` top-k draws.

  Calls ``topk_sample_one`` ``n_samples`` times and marks every pitch drawn at
  least once. Each (frame, batch) position therefore gets between 1 and
  ``n_samples`` active pitches.

  Returns:
    An int64 tensor of 0/1 values with the same shape as a single draw.

  Raises:
    ValueError: if ``n_samples`` < 1.
  """
  if n_samples < 1:
    raise ValueError("n_samples must be at least 1")
  # Accumulate hit counts under a name that does not shadow the builtin `sum`.
  hits = topk_sample_one(sequence, k)
  for _ in range(n_samples - 1):
    hits = hits + topk_sample_one(sequence, k)
  return torch.where(hits > 0, 1, 0)


In [None]:
def generate(generator, initial_vectors, batch_size, length, hidden_shape, batched_ssm):
  """Autoregressively extend ``initial_vectors`` by ``length`` frames.

  Args:
    generator: model whose ``forward(sequence, hidden, batch, seq_len, ssm)``
      returns ``(next 5 frames, hidden)``, e.g. ``music_generator``. It is put
      in eval mode.
    initial_vectors: priming roll indexed (batch, frames, pitches). It is
      transposed to (frames, batch, pitches) before being fed to the model.
    batch_size: number of sequences in the batch.
    length: number of frames to generate.
    hidden_shape: LSTM hidden size.
    batched_ssm: SSMs stacked as ``batch_SSM`` does. The model indexes rows up
      to the last generated frame, so they presumably need at least
      frames + length rows per song — TODO confirm.

  Returns:
    Tensor (frames + length, batch, pitches): the priming frames followed by
    the sampled ones.
  """
  frames_per_step = 5  # music_generator.forward emits 5 frames per call
  # LSTM state is (num_layers * directions, batch, hidden). The batch axis must
  # be batch_size; it was previously hard-coded to 1.
  h0 = torch.randn(2, batch_size, hidden_shape).float()
  hidden = (h0, h0)
  generator.eval()
  sequence = initial_vectors.transpose(0, 1)
  target_len = sequence.size()[0] + length
  with torch.no_grad():
    for _ in range(0, length, frames_per_step):
      output, hidden = generator.forward(
          sequence.float(), hidden, batch_size, sequence.size()[0], batched_ssm)
      next_frames = topk_batch_sample(output[-frames_per_step:, :, :], 5)
      sequence = torch.vstack((sequence, next_frames.to("cpu")))
  # The model emits whole chunks of 5; drop any overshoot when length % 5 != 0.
  return sequence[:target_len]

In [None]:
# Prime the model with the first PRIME_FRAMES frames of song `index` and let it
# generate GEN_FRAMES more, guided by that song's own SSM.
# The model indexes SSM rows up to the last generated frame, so
# PRIME_FRAMES + GEN_FRAMES (10 + 390 = 400) must not exceed the song's length.
index = 45
PRIME_FRAMES = 10
GEN_FRAMES = 390
first_vec = data[index][0].unsqueeze(0)[:, 0:PRIME_FRAMES, :]
print(first_vec.shape)
new_gen_att = generate(att_mod, first_vec, 1, GEN_FRAMES, att_mod.hidden_size, SSM(data[index][0]))
fig, ax = plt.subplots()
ax.imshow(SSM(new_gen_att.squeeze()))
ax.set(title=f"SSM of generated sequence (primed on song {index})", xlabel="frame", ylabel="frame");


In [None]:
# Reference: the self-similarity structure of the original song, for comparison.
fig, ax = plt.subplots()
ax.imshow(SSM(data[index][0]))
ax.set(title=f"SSM of original song {index}", xlabel="frame", ylabel="frame");

In [None]:
# Generated piano roll: rows are frames, columns are MIDI pitch numbers.
fig, ax = plt.subplots()
ax.imshow(new_gen_att.squeeze(), aspect="auto")
ax.set(title="Generated piano roll", xlabel="MIDI pitch", ylabel="frame");

In [None]:
def stop_note(note, time):
    """Build a mido ``note_off`` (velocity 0) for ``note`` with delta time ``time``."""
    return Message('note_off', time=time, note=note, velocity=0)

def start_note(note, time):
    """Build a mido ``note_on`` at velocity 120 for ``note`` with delta time ``time``."""
    return Message('note_on', velocity=120, note=note, time=time)

def roll_to_track(roll, tempo):
    """Yield mido note_on / note_off messages for a piano roll.

    Each row of ``roll`` is one time step, and column ``i`` is MIDI note ``i``.
    A note starts when its cell goes from zero to non-zero and stops when it
    returns to zero, so a run of non-zero cells is held as one note. Notes
    still sounding on the last row are not stopped; append an all-zero row to
    close them.

    Args:
        roll: 2-D array-like (rows = time steps, columns = MIDI notes).
        tempo: presumably beats per minute. Each row lasts 60000 / tempo ms.

    Yields:
        mido ``Message`` objects whose ``time`` is the delay since the previous
        message in milliseconds, rounded to an int.
    """
    if len(roll) == 0:
        return
    # Same duration for every row, so compute it once (ms per row).
    ms_per_row = int(np.round((1 / (tempo / 60)) * 1000))
    delta = 0
    # sounding[i] is True while note i has a note_on without a matching note_off.
    # The previous version stored the pitch number here, so note 0 never
    # counted as sounding and was never stopped.
    sounding = [False] * len(roll[0])
    for row in roll:
        for note, value in enumerate(row):
            if value != 0 and not sounding[note]:
                yield start_note(note, delta)
                delta = 0
                sounding[note] = True
            elif value == 0 and sounding[note]:
                yield stop_note(note, delta)
                delta = 0
                sounding[note] = False
        delta += ms_per_row



In [None]:
# Append an all-zero frame so every note still sounding at the end gets a note_off.
new_roll_final = np.vstack((new_gen_att.squeeze(), np.zeros(128)))
midi = MidiFile(type = 1)
track = MidiTrack(roll_to_track(new_roll_final, data[index][1]))
# roll_to_track writes delta times in milliseconds, but MIDI delta times are ticks
# lasting tempo / ticks_per_beat microseconds. Setting tempo = ticks_per_beat * 1000
# makes one tick exactly 1 ms. Without it, mido's default 500000 us/beat at
# 480 ticks/beat plays about 4% slow.
track.insert(0, MetaMessage('set_tempo', tempo=midi.ticks_per_beat * 1000))
midi.tracks.append(track)
midi.save('blank.midi')