In [1]:
drive = False  # False for Local
if drive:
    !pip install pretty_midi
    !pip install -U sparsemax
# locally, also compile torch (with CUDA enabled if available):
# conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch

In [2]:
import os
import numpy as np
import matplotlib.pyplot as plt
import pretty_midi
#this package is used to write it back into music.
from mido import Message, MidiFile, MidiTrack
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
import random
import torch.distributions
import sparsemax

In [3]:
if drive:
    # import the Colab module under another name so the boolean `drive` flag
    # from the setup cell is not overwritten by a module object
    from google.colab import drive as colab_drive
    colab_drive.mount('/content/drive')
    my_drive_path = '/content/drive/MyDrive/2022_Special_Studies_Hablutzel/ModelCopies/'
else: # local
    my_drive_path = "./"

In [4]:
# this is the model
class music_generator(nn.Module):
    """Beat-by-beat piano-roll generator.

    A one-layer LSTM, optionally followed by an attention step that mixes in the
    earlier beats of the piece, weighted by (sparsemaxed) rows of the piece's
    self-similarity matrix (SSM).

    Args:
        hidden_size: LSTM hidden size. The attention path combines the LSTM output
            with earlier input beats feature-by-feature, so it needs
            hidden_size == output_size.
        output_size: number of features per beat (128 piano-roll pitches here).
        base_lstm: if True, forward() skips attention and returns the LSTM output.
    """
    def __init__(self, hidden_size, output_size, base_lstm=False):
        super().__init__()
        self.hidden_size = hidden_size  # 128
        # output_size is num expected features (128)
        self.lstm = nn.LSTM(output_size, hidden_size, num_layers=1, bidirectional=False)
        # per feature: [attention-weighted history, LSTM output] -> 1 value
        self.attention = nn.Linear(2, 1)
        self.softmax = sparsemax.Sparsemax(dim=1)
        self.sigmoid = nn.Sigmoid()  # not used by forward()
        self.hidden = None
        self.base_lstm = base_lstm  # true to use lstm without attention

    def init_hidden(self, batch_size):
        """Reset the LSTM (hidden, cell) state to zeros for `batch_size` pieces."""
        hidden = torch.zeros(1, batch_size, self.hidden_size).float()  # [layers, batch_size, hidden_size]
        self.hidden = (hidden, hidden)  # hidden_state, cell_state

    def set_random_hidden(self, batch_size):
        """Draw a standard-normal LSTM state for `batch_size` pieces.

        The same random tensor is used as both the hidden and the cell state.
        """
        hidden = torch.randn(1, batch_size, self.hidden_size).float()
        self.hidden = (hidden, hidden)

    def forward(self, in_put, batch_size, prev_sequence, batched_ssm):
        """Predict the next beat for every piece in the batch.

        Args:
            in_put: [sequence_length, batch_size, features] beats fed to the LSTM
                (the whole starter sequence on the first call, then one beat).
            batch_size: number of pieces in the batch.
            prev_sequence: [beat_num, batch_size, features] every beat so far; its
                length is the index of the beat being predicted.
            batched_ssm: [batch_size * beats, beats] per-piece SSMs stacked
                vertically (as built by batch_SSM).

        Returns:
            (prediction, hidden): prediction is [1, batch_size, features]
            (float64 on the attention path); hidden is the updated
            (hidden_state, cell_state) tuple, also stored on self.hidden.
        """
        sequence_length = in_put.size()[0]

        # output: LSTM output for each input beat, [sequence_length, batch_size, hidden_size]
        output, self.hidden = self.lstm(in_put.float(), self.hidden)

        # keep only the output for the last input beat: [batch_size, 1, hidden_size]
        last_output = output.view(sequence_length, batch_size, self.hidden_size)[-1].unsqueeze(1)

        # return early (w/o attention) for base lstm
        if self.base_lstm:
            return last_output.transpose(0, 1), self.hidden

        #########################
        # attention starts here #
        #########################

        # index of the beat being predicted
        beat_num = prev_sequence.shape[0]

        # row for this beat in each piece's SSM; each piece owns batched_ssm.shape[1] rows
        inds_across_pieces = range(beat_num, batched_ssm.shape[0], batched_ssm.shape[1])

        # similarity of this beat to every earlier beat: [batch_size, beat_num]
        ssm_slice = batched_ssm[inds_across_pieces, :beat_num]

        # sparsemax makes the entries of each row add to 1: [batch_size, beat_num]
        weights = self.softmax(ssm_slice)

        # scale every earlier beat by its weight; weights become [beat_num, batch_size, 1]
        # and broadcast over features -> [beat_num, batch_size, features]
        weighted = prev_sequence * weights.transpose(0, 1).unsqueeze(-1)

        # attention-weighted sum of earlier beats: [batch_size, 1, features]
        weight_vec = torch.sum(weighted, dim=0).unsqueeze(1)

        # per feature, combine [weighted history, LSTM output] with the 2 -> 1 linear layer
        pt2 = torch.hstack((weight_vec, last_output)).transpose(1, 2)  # [batch_size, features, 2]
        attentioned = self.attention(pt2.float()).permute(2, 0, 1)  # [1, batch_size, features]

        return attentioned.double(), self.hidden  # hidden = hidden_state, cell_state

In [5]:
class model_trainer():
  """Trains a music_generator and generates example pieces with it.

  Args:
    generator: the music_generator to train / sample from.
    optimizer: optimizer over the generator's parameters.
    data: list of per-piece entries; entry[0] is a [beats, 128] piano-roll
      tensor and entry[2] is its number of beats (as used by make_batches /
      make_variable_size_batches).
    hidden_size: stored on the trainer; not used by its methods.
    batch_size: pieces per batch for fixed-size batching.
  """
  def __init__(self, generator, optimizer, data, hidden_size=128, batch_size=50):
    self.generator = generator
    self.optimizer = optimizer
    self.batch_size = batch_size  # play with this
    self.hidden_size = hidden_size  # 128
    self.data = data
    # beats per piece, passed to make_batches; assumes piece length doesn't vary
    self.data_length = data[0][0].shape[0]

  def train_epochs(self, num_epochs=50, full_training=False, variable_size_batches=False, save_name="model"):
    """Train for `num_epochs` epochs, saving the generator after each one.

    Args:
      num_epochs: number of passes over the data.
      full_training: if True, train on full-length pieces. If False, every
        batch is cut to its first 105 beats and (unless variable_size_batches)
        only the first 100 pieces of the data are used.
      variable_size_batches: group equal-length pieces into batches with
        make_variable_size_batches instead of fixed-size batches.
      save_name: prefix of the checkpoint written after every epoch to
        f"{my_drive_path}trained/{save_name}-epoch-{epoch}-loss-{loss:.5f}.txt"
        (the whole generator object is saved with torch.save).

    Returns:
      (losslist, piclist): the summed training loss of each epoch, and one
      generated example sequence per epoch (to see how generation evolves).
    """
    losslist = []
    piclist = []

    for epoch in tqdm(range(num_epochs)):
      # start training the generator
      self.generator.train()

      if variable_size_batches:
        # use all data, and group batches by piece size
        batches = make_variable_size_batches(self.data)
      elif full_training:
        # use all data (truncating data doesn't work w/ variable size batches currently)
        batches = make_batches(self.data, self.batch_size, self.data_length)
      else:
        # use the first 100 pieces - overfitting a small dataset shows the model can learn
        batches = make_batches(self.data[:100], self.batch_size, self.data_length)

      cum_loss = 0
      for batch in tqdm(batches):
        if full_training:
          # train on full-length pieces
          loss = self.train(batch)
        else:
          # train on first 105 beats of each piece
          loss = self.train(batch[:, :105, :])  # [batch, beats, 128]
        cum_loss += loss

      # print loss for early stopping
      print(cum_loss)

      # save generator after each epoch
      curr_file = f"{my_drive_path}trained/{save_name}-epoch-{str(epoch)}-loss-{cum_loss:.5f}.txt"
      torch.save(self.generator, curr_file)

      # example piece for piclist: 10 starter beats + 95 generated = 105 beats
      snap = self.generate_n_examples(n=1, length=95, starter_notes=10)

      losslist.append(cum_loss)
      piclist.append(snap)

      # TODO: early stopping - after each epoch compute validation loss and stop
      # if it goes up for ~5 epochs in a row
    return losslist, piclist

  def train(self, batch, starter_notes=10, teacher_forcing_ratio=0.2):
    """Run one optimization step on one batch and return its loss.

    The first `starter_notes` beats of every piece seed the generator, which
    then predicts the remaining beats one at a time. After each prediction,
    with probability `teacher_forcing_ratio` the true beat at that position is
    fed back as the next input (teacher forcing); otherwise a beat sampled
    from the prediction is.

    Args:
      batch: [batch_size, beats, 128] tensor of pieces.
      starter_notes: number of opening beats given to the generator.
      teacher_forcing_ratio: probability of feeding back the true beat.

    Returns:
      the batch loss as a detached CPU tensor.
    """
    batch_size = batch.shape[0]
    beats_first = batch.transpose(0, 1)  # [beats, batch_size, 128]
    self_sim = batch_SSM(beats_first, batch_size)  # use variable batch size
    sequence = beats_first[0:starter_notes]  # beats fed to the generator so far
    generated = beats_first[0:starter_notes]  # starter beats + predictions, for the loss

    # reset hidden to zeros for each batch
    self.generator.init_hidden(batch_size)
    # zero the gradients before training for each batch
    self.optimizer.zero_grad()

    # the first forward call sees the whole starter sequence, later calls one beat
    next_element = sequence.to("cpu")

    for i in range(batch.shape[1] - starter_notes):
      # step i predicts beat starter_notes + i
      val = torch.rand(1)
      # generate even when teacher forcing, so the loss covers every beat
      output, _ = self.generator.forward(next_element, batch_size, sequence, self_sim)

      if val > 1 - teacher_forcing_ratio:
        # teacher forcing: feed back the true beat at the position just predicted
        next_element = batch[:, starter_notes + i, :].unsqueeze(0)  # [1, batch_size, 128]
      else:
        # 3 samples from the 5 most likely notes (so at most 3 notes this beat)
        next_element = topk_batch_sample(output, 5)

      sequence = torch.vstack((sequence, next_element.to("cpu")))
      generated = torch.vstack((generated, output))

    # loss over all predicted beats (starter beats excluded)
    single_loss = custom_loss(generated[starter_notes:, :, :], beats_first[starter_notes:, :, :])
    single_loss.backward()
    # update the parameters of the LSTM after running on the full batch
    self.optimizer.step()

    return single_loss.detach().to('cpu')

  def generate_n_pieces(self, initial_vectors, n_pieces, length, batched_ssm):
    """Generate a batch of n new pieces, continuing from their opening beats.

    Args:
      initial_vectors: [n_pieces, starter_beats, 128] opening beats.
      n_pieces: number of pieces (batch size).
      length: number of beats to generate after the opening.
      batched_ssm: [n_pieces * beats, beats] stacked per-piece SSMs (batch_SSM);
        at most beats - starter_beats new beats are generated.

    Returns:
      [starter_beats + generated_beats, n_pieces, 128] tensor (opening included).
    """
    # eval mode; gradients are disabled below with torch.no_grad
    self.generator.eval()
    # start generator on random hidden states and cell states
    self.generator.set_random_hidden(n_pieces)

    sequence = initial_vectors.transpose(0, 1)  # [starter_beats, n_pieces, 128]
    next_element = sequence.to("cpu")

    # each piece's SSM has batched_ssm.shape[1] rows; going further would make
    # forward() read rows that belong to the next piece in the stack
    max_notes = batched_ssm.shape[1] - sequence.shape[0]

    for _step in range(min(length, max_notes)):
      with torch.no_grad():
        output, _ = self.generator.forward(next_element.float(), n_pieces, sequence, batched_ssm)
        # 3 samples from the 5 most likely notes at this beat
        next_element = topk_batch_sample(output, 5)
      sequence = torch.vstack((sequence, next_element.to("cpu")))

    return sequence

  def generate_n_examples(self, n=1, length=390, starter_notes=10, piece_inds=(0,), random_source_pieces=False):
    """Generate `n` pieces seeded with the opening beats of pieces from self.data.

    Args:
      n: number of pieces; must equal len(piece_inds).
      length: beats to generate after the opening.
      starter_notes: opening beats taken from each source piece.
      piece_inds: indices into self.data of the source pieces (equal lengths).
      random_source_pieces: currently unused.

    Returns:
      [starter_notes + generated_beats, n, 128] tensor.

    Raises:
      ValueError: if n != len(piece_inds).
    """
    if n != len(piece_inds):
      raise ValueError(f"n ({n}) must equal the number of source pieces ({len(piece_inds)})")

    # get just the notes of each piece, stacked: [n, beats, 128]
    pieces = torch.vstack([self.data[i][0].unsqueeze(0) for i in piece_inds])

    # opening beats in format [n, starter_notes, 128]
    first_vecs = pieces[:, 0:starter_notes, :]

    # create batched SSMs for each piece
    batched_ssms = batch_SSM(pieces.transpose(0, 1), n)

    return self.generate_n_pieces(first_vecs, n, length, batched_ssms)


In [6]:
#this function takes in the piece of music and returns the chroma vectors
def get_chroma(roll, length):
    """Collapse a piano roll into per-beat pitch-class (chroma) totals.

    Args:
        roll: [beats, pitches] tensor; column p is folded into pitch class p % 12.
        length: unused.

    Returns:
        [beats, 12] tensor of the default float dtype; entry (b, c) is the sum of
        roll[b, p] over every pitch p with p % 12 == c.
    """
    pitch_class_sums = [roll[:, pitch_class::12].sum(dim=1) for pitch_class in range(12)]
    chroma = torch.zeros((roll.size()[0], 12))
    # copy_ converts to chroma's dtype, like element-wise assignment would
    chroma.copy_(torch.stack(pitch_class_sums, dim=1))
    return chroma

In [7]:
#this takes in the sequence and creates a self-similarity matrix (it calls chroma function inside)
def SSM(sequence):
  """Chroma-based self-similarity matrix of one piece.

  Args:
    sequence: [beats, 128] piano roll (or per-beat scores) of a single piece.

  Returns:
    [beats, beats] tensor; entry (i, j) is the cosine similarity of the chroma
    vectors of beats i and j, and 0 when either chroma vector is all zeros.
  """
  eps = 1e-8  # same default epsilon as nn.CosineSimilarity
  chrom = get_chroma(sequence, sequence.size()[0])  # [beats, 12]
  # normalize each row once; all pairwise cosines are then one matrix product
  # (replaces a Python loop computing one row per beat)
  unit = chrom / chrom.norm(dim=1, keepdim=True).clamp_min(eps)
  return unit @ unit.T

In [8]:
#this bundles the SSM function.
def batch_SSM(seq, batch_size):
  """Compute one SSM per piece and stack them vertically.

  Args:
    seq: [beats, batch_size, 128] batch of pieces, beats first.
    batch_size: number of pieces in `seq` to process.

  Returns:
    [batch_size * beats, beats] tensor: piece p's SSM occupies rows
    p * beats .. (p + 1) * beats - 1.
  """
  per_piece = [SSM(seq[:, piece, :]) for piece in range(batch_size)]
  return torch.vstack(per_piece)

In [9]:
#sampling function 
def topk_sample_one(sequence, k):
  """Sample one note per (beat, piece) from the k highest-scoring pitches.

  Only pitch indices 20..107 are searched. The k best scores are turned into
  probabilities with sparsemax, one candidate is sampled, and the choice is
  returned one-hot over all sequence.shape[2] pitches.
  NOTE(review): the MIDI piano range is 21..108; this window looks shifted by
  one - confirm intent before changing (trained models used it as is).

  Args:
    sequence: [sequence_length, batch_size, pitches] scores.
    k: number of candidate pitches.

  Returns:
    int64 one-hot tensor [sequence_length, batch_size, pitches].
  """
  lowest_pitch, highest_pitch = 20, 108  # searched window is [20, 108)
  top_scores, top_pitches = torch.topk(sequence[:, :, lowest_pitch:highest_pitch], k)
  top_pitches = top_pitches + lowest_pitch  # window offsets -> absolute pitch indices
  probs = sparsemax.Sparsemax(dim=2)(top_scores.float())
  choice = torch.distributions.Categorical(probs).sample()  # [sequence_length, batch_size]
  picked = torch.gather(top_pitches, -1, choice.unsqueeze(-1))  # [sequence_length, batch_size, 1]
  return F.one_hot(picked, num_classes=sequence.shape[2]).squeeze(dim=2)

In [10]:
#samples multiple times for the time-step
def topk_batch_sample(sequence, k):
  """Merge three top-k draws per (beat, piece) into one multi-hot beat.

  Returns a tensor of 1 wherever any of the three draws from topk_sample_one
  picked a pitch and 0 elsewhere (so at most three notes per beat).
  """
  # builtin sum starts from 0: 0 + draw_1 + draw_2 + draw_3
  draw_counts = sum(topk_sample_one(sequence, k) for _ in range(3))
  return torch.where(draw_counts > 0, 1, 0)


In [11]:
def custom_loss(output, target):
  """BCE-with-logits on the notes plus a self-similarity penalty.

  Args:
    output: [beats, batch, 128] predicted logits.
    target: same shape, the true piano rolls.

  Returns:
    mean BCE over all entries (float64) plus, summed over the pieces in the
    batch, the mean squared difference between the SSMs of output and target.
  """
  bce = nn.BCEWithLogitsLoss()(output.double(), target.double())

  def piece_penalty(piece):
    generated_ssm = SSM(output[:, piece, :])
    target_ssm = SSM(target[:, piece, :])
    return torch.sum((generated_ssm - target_ssm) ** 2) / (target_ssm.size(0) ** 2)

  # builtin sum starts from 0, matching an accumulator initialized to 0
  ssm_err = sum(piece_penalty(piece) for piece in range(output.size()[1]))
  return torch.sum(bce) + ssm_err

In [12]:
# returns batches where piece size is constant within the batch
# but piece size is different across batches
# and batches are in random order
def make_variable_size_batches(data, min_batch_size=10):
  """Group pieces into batches whose pieces all have the same number of beats.

  Args:
    data: sequence of per-piece entries; entry[0] is a [beats, features]
      tensor and entry[2] its number of beats. `data` itself is not modified.
    min_batch_size: lengths shared by fewer pieces than this are dropped.

  Returns:
    list of [pieces, beats, features] tensors in random order (Python `random`).
  """
  from itertools import groupby

  # sort ascending by num beats (element at index 2); sorted() copies, so the
  # caller's list keeps its order
  by_length = sorted(data, key=lambda entry: entry[2])

  batches = []
  for _, group in groupby(by_length, key=lambda entry: entry[2]):
    pieces = [entry[0].unsqueeze(0) for entry in group]  # each [1, beats, features]
    # only save large enough batches
    if len(pieces) >= min_batch_size:
      batches.append(torch.cat(pieces, dim=0))

  # randomize batches order
  random.shuffle(batches)
  return batches

In [13]:
# Takes in the batch size and data and returns batches of the batch size
def make_batches(data, batch_size, piece_size=800, slice_data=False):
  """Shuffle pieces and group them into batches.

  Args:
    data: sequence of per-piece entries; entry[0] is a [beats, features]
      tensor. `data` itself is not reordered.
    batch_size: pieces per batch. If > 1, leftover pieces that do not fill a
      whole batch are dropped. If <= 1, each piece becomes its own batch at
      its own length (piece_size is ignored).
    piece_size: beats every piece must have when batch_size > 1.
    slice_data: first cut every piece to piece_size beats, dropping shorter ones.

  Returns:
    list of [batch_size, piece_size, features] tensors (or [1, beats, features]
    when batch_size <= 1).

  Raises:
    ValueError: if batch_size > 1 and a batch contains a piece whose length
      is not piece_size.
  """
  if slice_data:
    data = [(roll[0][:piece_size], roll[1], roll[2]) for roll in data if len(roll[0]) >= piece_size]
  else:
    data = list(data)  # shuffle a copy so the caller's list keeps its order
  random.shuffle(data)

  if batch_size <= 1:
    # each piece is its own batch; removes tempo info from data
    return [roll[0].unsqueeze(0) for roll in data]

  batches = []
  full_span = (len(data) // batch_size) * batch_size
  for start in range(0, full_span, batch_size):
    pieces = [roll[0] for roll in data[start:start + batch_size]]
    lengths = {piece.shape[0] for piece in pieces}
    if lengths != {piece_size}:
      # cat + view used to silently mix pieces here when lengths differed
      raise ValueError(f"expected every piece to have {piece_size} beats, got lengths {sorted(lengths)}")
    batches.append(torch.stack(pieces))  # [batch_size, piece_size, features]
  return batches

In [14]:
def get_models_in_directory(filepath):
    """Load every saved model in `my_drive_path + filepath`, ordered by epoch.

    Only file names containing 'model' are loaded. The epoch is parsed from the
    third '-'-separated field of the name, matching the
    "{save_name}-epoch-{N}-loss-{L}.txt" checkpoints written by
    model_trainer.train_epochs (assumes save_name contains no '-').
    `filepath` is concatenated as-is, so it should end with '/'.

    NOTE: torch.load unpickles whole objects - only load trusted files.
    """
    directory = my_drive_path + filepath
    model_files = [name for name in os.listdir(directory) if 'model' in name]
    model_files.sort(key=lambda name: int(name.split("-")[2]))  # by epoch
    return [torch.load(directory + name) for name in model_files]

In [15]:
def generate_val_pieces(trainer, val_batch, starter_notes=10):
    """Generate continuations for a validation batch and score them.

    Args:
        trainer: model_trainer whose generator is used.
        val_batch: [n_pieces, beats, 128] batch of equal-length pieces.
        starter_notes: opening beats given to the generator.

    Returns:
        (generated, loss): generated is [beats, n_pieces, 128] (opening
        included); loss is custom_loss over the beats after the opening, as a
        Python float.
        NOTE(review): generated beats are the sampled 0/1 notes, which
        custom_loss treats as BCE logits - confirm that is intended.
    """
    n_pieces, piece_len = val_batch.shape[0], val_batch.shape[1]
    seeds = val_batch[:, 0:starter_notes, :]  # [n_pieces, starter_notes, 128]
    beats_first = val_batch.transpose(0, 1)  # [beats, n_pieces, 128]
    ssms = batch_SSM(beats_first, n_pieces)

    generated = trainer.generate_n_pieces(seeds, n_pieces, piece_len - starter_notes, ssms)

    loss = custom_loss(generated[starter_notes:, :, :], beats_first[starter_notes:, :, :]).detach()
    return generated, float(loss)

In [16]:
def get_loss_on_data(trainer, data, val_size=450, variable_batches=False):
    """Get loss when generating on given data, generating with a given model

    With `variable_batches`, pieces are grouped by length (every length kept);
    otherwise pieces shorter than `val_size` are dropped, the rest are cut to
    `val_size` beats and put into batches of 50. Returns the sum of the
    per-batch losses from generate_val_pieces.
    """
    batches = (
        make_variable_size_batches(data, min_batch_size=0)
        if variable_batches
        else make_batches(data, 50, piece_size=val_size, slice_data=True)
    )

    total_loss = 0
    for val_batch in batches:
        _, batch_loss = generate_val_pieces(trainer, val_batch)
        total_loss += batch_loss
    return total_loss

In [17]:
# for each model, get validation loss
def val_loss_for_all_models(trainers, val_data, val_size=450, variable_batches=False):
    """Return the validation loss of each trainer's model, in order.

    The running list is printed after every model so progress is visible
    during long runs.
    """
    losses = []
    for trainer in tqdm(trainers):
        losses.append(get_loss_on_data(trainer, val_data, val_size, variable_batches))
        print(losses)
    return losses

In [18]:
# every saved checkpoint of each architecture, ordered by epoch
models_att, models_lstm = (
    get_models_in_directory("trained/attention_model_v3/"),
    get_models_in_directory("trained/lstm_v3/"),
)

In [19]:
# show the loaded attention-model checkpoints (one per saved epoch)
models_att

[music_generator(
   (lstm): LSTM(128, 128)
   (attention): Linear(in_features=2, out_features=1, bias=True)
   (softmax): Sparsemax(dim=1)
   (sigmoid): Sigmoid()
 ),
 music_generator(
   (lstm): LSTM(128, 128)
   (attention): Linear(in_features=2, out_features=1, bias=True)
   (softmax): Sparsemax(dim=1)
   (sigmoid): Sigmoid()
 ),
 music_generator(
   (lstm): LSTM(128, 128)
   (attention): Linear(in_features=2, out_features=1, bias=True)
   (softmax): Sparsemax(dim=1)
   (sigmoid): Sigmoid()
 ),
 music_generator(
   (lstm): LSTM(128, 128)
   (attention): Linear(in_features=2, out_features=1, bias=True)
   (softmax): Sparsemax(dim=1)
   (sigmoid): Sigmoid()
 ),
 music_generator(
   (lstm): LSTM(128, 128)
   (attention): Linear(in_features=2, out_features=1, bias=True)
   (softmax): Sparsemax(dim=1)
   (sigmoid): Sigmoid()
 ),
 music_generator(
   (lstm): LSTM(128, 128)
   (attention): Linear(in_features=2, out_features=1, bias=True)
   (softmax): Sparsemax(dim=1)
   (sigmoid): Sigmoi

In [20]:
# show the loaded base-LSTM checkpoints (one per saved epoch)
models_lstm

[music_generator(
   (lstm): LSTM(128, 128)
   (attention): Linear(in_features=2, out_features=1, bias=True)
   (softmax): Sparsemax(dim=1)
   (sigmoid): Sigmoid()
 ),
 music_generator(
   (lstm): LSTM(128, 128)
   (attention): Linear(in_features=2, out_features=1, bias=True)
   (softmax): Sparsemax(dim=1)
   (sigmoid): Sigmoid()
 ),
 music_generator(
   (lstm): LSTM(128, 128)
   (attention): Linear(in_features=2, out_features=1, bias=True)
   (softmax): Sparsemax(dim=1)
   (sigmoid): Sigmoid()
 ),
 music_generator(
   (lstm): LSTM(128, 128)
   (attention): Linear(in_features=2, out_features=1, bias=True)
   (softmax): Sparsemax(dim=1)
   (sigmoid): Sigmoid()
 ),
 music_generator(
   (lstm): LSTM(128, 128)
   (attention): Linear(in_features=2, out_features=1, bias=True)
   (softmax): Sparsemax(dim=1)
   (sigmoid): Sigmoid()
 ),
 music_generator(
   (lstm): LSTM(128, 128)
   (attention): Linear(in_features=2, out_features=1, bias=True)
   (softmax): Sparsemax(dim=1)
   (sigmoid): Sigmoi

In [21]:
# load validation data
# val_data = torch.load(my_drive_path + "usable_data/validation_tempo_all_w_beats_30.csv") 
# NOTE: despite the .csv extension this is a torch-serialized object; torch.load
# unpickles it, so only load files you trust
val_data = torch.load(my_drive_path + "usable_data/validation_tempo_round_down_30.csv")
# seed both RNGs used downstream: torch (note sampling, random hidden states)
# and Python's `random` (batch shuffling in make_batches / make_variable_size_batches)
torch.manual_seed(2022)
random.seed(2022)

<torch._C.Generator at 0x7fcf39086c70>

In [22]:
hidden_size = 128
# model_trainer requires an optimizer per model (Adam, lr 0.005); the
# validation-loss helpers in this notebook never call optimizer.step()
optimizers_att = [torch.optim.Adam(model.parameters(), lr=0.005) for model in models_att]
optimizers_lstm = [torch.optim.Adam(model.parameters(), lr=0.005) for model in models_lstm]

In [23]:
# wrap each model in a trainer so the validation helpers can generate with it
trainers_att = [
    model_trainer(model, opt, val_data, hidden_size)
    for model, opt in zip(models_att, optimizers_att)
]
trainers_lstm = [
    model_trainer(model, opt, val_data, hidden_size)
    for model, opt in zip(models_lstm, optimizers_lstm)
]

In [24]:
# one trainer per loaded attention-model checkpoint
trainers_att

[<__main__.model_trainer at 0x7fcf3b742640>,
 <__main__.model_trainer at 0x7fcf3b742bb0>,
 <__main__.model_trainer at 0x7fcf3b742a90>,
 <__main__.model_trainer at 0x7fcf3b742a60>,
 <__main__.model_trainer at 0x7fcf3b742280>,
 <__main__.model_trainer at 0x7fcf3b7429a0>,
 <__main__.model_trainer at 0x7fcf3b742490>,
 <__main__.model_trainer at 0x7fcf3b742eb0>,
 <__main__.model_trainer at 0x7fcf3d703f10>,
 <__main__.model_trainer at 0x7fcf3d703970>,
 <__main__.model_trainer at 0x7fcf3d703550>,
 <__main__.model_trainer at 0x7fcf3d7036d0>,
 <__main__.model_trainer at 0x7fcf3d703310>,
 <__main__.model_trainer at 0x7fcf3d703760>,
 <__main__.model_trainer at 0x7fcf3d7030a0>,
 <__main__.model_trainer at 0x7fcf3d703670>,
 <__main__.model_trainer at 0x7fcf3d72fc70>,
 <__main__.model_trainer at 0x7fcf3d72f790>,
 <__main__.model_trainer at 0x7fcf3574d130>,
 <__main__.model_trainer at 0x7fcf3574d370>,
 <__main__.model_trainer at 0x7fcf3574d250>,
 <__main__.model_trainer at 0x7fcf3574d670>,
 <__main__

In [25]:
# one trainer per loaded base-LSTM checkpoint
trainers_lstm

[<__main__.model_trainer at 0x7fcf3b7424c0>,
 <__main__.model_trainer at 0x7fcf3574d5b0>,
 <__main__.model_trainer at 0x7fcf3574d640>,
 <__main__.model_trainer at 0x7fcf3574d3d0>,
 <__main__.model_trainer at 0x7fcf3574d310>,
 <__main__.model_trainer at 0x7fcf3574d1c0>,
 <__main__.model_trainer at 0x7fcf3574d1f0>,
 <__main__.model_trainer at 0x7fcf3574d070>,
 <__main__.model_trainer at 0x7fcf3574d280>,
 <__main__.model_trainer at 0x7fcf3c9f7ee0>,
 <__main__.model_trainer at 0x7fcf3c9f7670>,
 <__main__.model_trainer at 0x7fcf3c9f7f40>,
 <__main__.model_trainer at 0x7fcf3c9f7400>,
 <__main__.model_trainer at 0x7fcf3c9f78e0>,
 <__main__.model_trainer at 0x7fcf3b734790>,
 <__main__.model_trainer at 0x7fcf3b734bb0>,
 <__main__.model_trainer at 0x7fcf3b7345e0>,
 <__main__.model_trainer at 0x7fcf3b734940>,
 <__main__.model_trainer at 0x7fcf3b734be0>,
 <__main__.model_trainer at 0x7fcf3b734880>,
 <__main__.model_trainer at 0x7fcf3b7349d0>,
 <__main__.model_trainer at 0x7fcf3b734430>,
 <__main__

In [26]:
# validation loss for every saved epoch of the attention model (printed as it accumulates).
# NOTE: get_loss_on_data only uses val_size when variable_batches=False, so val_size has no effect here
loss_list_att = val_loss_for_all_models(trainers_att, val_data, val_size=450, variable_batches=True)
print(loss_list_att)

# first run: [58.79781428050684, 60.99458131763841, 61.036112768777, 61.05596775895589, 60.97728356595162, 60.76163521445571, 65.72370909473803, 65.29740455011577, 53.31530709361719, 43.10948994384802, 42.501780608520875, 41.81713659329537, 34.23370181352132, 37.12999854703072, 34.86240460698944, 35.05370937217055, 34.59301298739383, 30.543632000399867, 29.509237587058692, 27.812329880537916]

  weighted = (prev_sequence.T*weights).T  # [batch_size, beat_num]
  3%|▎         | 1/30 [10:03<4:51:51, 603.86s/it]

[92.74938249124502]


  7%|▋         | 2/30 [20:26<4:46:54, 614.79s/it]

[92.74938249124502, 93.23906300503248]


 10%|█         | 3/30 [30:50<4:38:33, 619.03s/it]

[92.74938249124502, 93.23906300503248, 92.19413086440157]


 13%|█▎        | 4/30 [41:18<4:29:45, 622.54s/it]

[92.74938249124502, 93.23906300503248, 92.19413086440157, 92.49665917698461]


 17%|█▋        | 5/30 [51:44<4:19:57, 623.92s/it]

[92.74938249124502, 93.23906300503248, 92.19413086440157, 92.49665917698461, 92.28206422503503]


 20%|██        | 6/30 [1:02:14<4:10:20, 625.84s/it]

[92.74938249124502, 93.23906300503248, 92.19413086440157, 92.49665917698461, 92.28206422503503, 92.0982585402278]


 23%|██▎       | 7/30 [1:12:41<4:00:01, 626.16s/it]

[92.74938249124502, 93.23906300503248, 92.19413086440157, 92.49665917698461, 92.28206422503503, 92.0982585402278, 92.35477148240524]


 27%|██▋       | 8/30 [1:23:04<3:49:16, 625.29s/it]

[92.74938249124502, 93.23906300503248, 92.19413086440157, 92.49665917698461, 92.28206422503503, 92.0982585402278, 92.35477148240524, 92.6356850499109]


 30%|███       | 9/30 [1:33:34<3:39:24, 626.88s/it]

[92.74938249124502, 93.23906300503248, 92.19413086440157, 92.49665917698461, 92.28206422503503, 92.0982585402278, 92.35477148240524, 92.6356850499109, 92.83110828213175]


 33%|███▎      | 10/30 [1:44:02<3:29:03, 627.18s/it]

[92.74938249124502, 93.23906300503248, 92.19413086440157, 92.49665917698461, 92.28206422503503, 92.0982585402278, 92.35477148240524, 92.6356850499109, 92.83110828213175, 82.2993480917559]


 37%|███▋      | 11/30 [1:54:30<3:18:38, 627.28s/it]

[92.74938249124502, 93.23906300503248, 92.19413086440157, 92.49665917698461, 92.28206422503503, 92.0982585402278, 92.35477148240524, 92.6356850499109, 92.83110828213175, 82.2993480917559, 82.29158622038304]


 40%|████      | 12/30 [2:05:05<3:08:55, 629.74s/it]

[92.74938249124502, 93.23906300503248, 92.19413086440157, 92.49665917698461, 92.28206422503503, 92.0982585402278, 92.35477148240524, 92.6356850499109, 92.83110828213175, 82.2993480917559, 82.29158622038304, 82.25085880798747]


 43%|████▎     | 13/30 [2:15:33<2:58:15, 629.12s/it]

[92.74938249124502, 93.23906300503248, 92.19413086440157, 92.49665917698461, 92.28206422503503, 92.0982585402278, 92.35477148240524, 92.6356850499109, 92.83110828213175, 82.2993480917559, 82.29158622038304, 82.25085880798747, 82.24439529033926]


 47%|████▋     | 14/30 [2:26:07<2:48:08, 630.55s/it]

[92.74938249124502, 93.23906300503248, 92.19413086440157, 92.49665917698461, 92.28206422503503, 92.0982585402278, 92.35477148240524, 92.6356850499109, 92.83110828213175, 82.2993480917559, 82.29158622038304, 82.25085880798747, 82.24439529033926, 81.94511380516032]


 50%|█████     | 15/30 [2:36:41<2:37:57, 631.85s/it]

[92.74938249124502, 93.23906300503248, 92.19413086440157, 92.49665917698461, 92.28206422503503, 92.0982585402278, 92.35477148240524, 92.6356850499109, 92.83110828213175, 82.2993480917559, 82.29158622038304, 82.25085880798747, 82.24439529033926, 81.94511380516032, 79.27676476112968]


 53%|█████▎    | 16/30 [2:47:14<2:27:27, 631.99s/it]

[92.74938249124502, 93.23906300503248, 92.19413086440157, 92.49665917698461, 92.28206422503503, 92.0982585402278, 92.35477148240524, 92.6356850499109, 92.83110828213175, 82.2993480917559, 82.29158622038304, 82.25085880798747, 82.24439529033926, 81.94511380516032, 79.27676476112968, 78.66303585904595]


 57%|█████▋    | 17/30 [2:57:49<2:17:09, 633.08s/it]

[92.74938249124502, 93.23906300503248, 92.19413086440157, 92.49665917698461, 92.28206422503503, 92.0982585402278, 92.35477148240524, 92.6356850499109, 92.83110828213175, 82.2993480917559, 82.29158622038304, 82.25085880798747, 82.24439529033926, 81.94511380516032, 79.27676476112968, 78.66303585904595, 77.27290864457066]


 60%|██████    | 18/30 [3:08:17<2:06:17, 631.47s/it]

[92.74938249124502, 93.23906300503248, 92.19413086440157, 92.49665917698461, 92.28206422503503, 92.0982585402278, 92.35477148240524, 92.6356850499109, 92.83110828213175, 82.2993480917559, 82.29158622038304, 82.25085880798747, 82.24439529033926, 81.94511380516032, 79.27676476112968, 78.66303585904595, 77.27290864457066, 76.6681886348328]


 63%|██████▎   | 19/30 [3:18:48<1:55:44, 631.35s/it]

[92.74938249124502, 93.23906300503248, 92.19413086440157, 92.49665917698461, 92.28206422503503, 92.0982585402278, 92.35477148240524, 92.6356850499109, 92.83110828213175, 82.2993480917559, 82.29158622038304, 82.25085880798747, 82.24439529033926, 81.94511380516032, 79.27676476112968, 78.66303585904595, 77.27290864457066, 76.6681886348328, 145.89837107659523]


 67%|██████▋   | 20/30 [3:29:25<1:45:28, 632.85s/it]

[92.74938249124502, 93.23906300503248, 92.19413086440157, 92.49665917698461, 92.28206422503503, 92.0982585402278, 92.35477148240524, 92.6356850499109, 92.83110828213175, 82.2993480917559, 82.29158622038304, 82.25085880798747, 82.24439529033926, 81.94511380516032, 79.27676476112968, 78.66303585904595, 77.27290864457066, 76.6681886348328, 145.89837107659523, 109.87096343681237]


 70%|███████   | 21/30 [3:40:02<1:35:09, 634.34s/it]

[92.74938249124502, 93.23906300503248, 92.19413086440157, 92.49665917698461, 92.28206422503503, 92.0982585402278, 92.35477148240524, 92.6356850499109, 92.83110828213175, 82.2993480917559, 82.29158622038304, 82.25085880798747, 82.24439529033926, 81.94511380516032, 79.27676476112968, 78.66303585904595, 77.27290864457066, 76.6681886348328, 145.89837107659523, 109.87096343681237, 95.10892224077863]


 73%|███████▎  | 22/30 [3:50:38<1:24:37, 634.72s/it]

[92.74938249124502, 93.23906300503248, 92.19413086440157, 92.49665917698461, 92.28206422503503, 92.0982585402278, 92.35477148240524, 92.6356850499109, 92.83110828213175, 82.2993480917559, 82.29158622038304, 82.25085880798747, 82.24439529033926, 81.94511380516032, 79.27676476112968, 78.66303585904595, 77.27290864457066, 76.6681886348328, 145.89837107659523, 109.87096343681237, 95.10892224077863, 115.45783016736293]


 77%|███████▋  | 23/30 [4:01:11<1:13:59, 634.23s/it]

[92.74938249124502, 93.23906300503248, 92.19413086440157, 92.49665917698461, 92.28206422503503, 92.0982585402278, 92.35477148240524, 92.6356850499109, 92.83110828213175, 82.2993480917559, 82.29158622038304, 82.25085880798747, 82.24439529033926, 81.94511380516032, 79.27676476112968, 78.66303585904595, 77.27290864457066, 76.6681886348328, 145.89837107659523, 109.87096343681237, 95.10892224077863, 115.45783016736293, 99.17165258961307]


 80%|████████  | 24/30 [4:11:50<1:03:34, 635.68s/it]

[92.74938249124502, 93.23906300503248, 92.19413086440157, 92.49665917698461, 92.28206422503503, 92.0982585402278, 92.35477148240524, 92.6356850499109, 92.83110828213175, 82.2993480917559, 82.29158622038304, 82.25085880798747, 82.24439529033926, 81.94511380516032, 79.27676476112968, 78.66303585904595, 77.27290864457066, 76.6681886348328, 145.89837107659523, 109.87096343681237, 95.10892224077863, 115.45783016736293, 99.17165258961307, 122.0534783240037]


 83%|████████▎ | 25/30 [4:22:34<53:10, 638.10s/it]  

[92.74938249124502, 93.23906300503248, 92.19413086440157, 92.49665917698461, 92.28206422503503, 92.0982585402278, 92.35477148240524, 92.6356850499109, 92.83110828213175, 82.2993480917559, 82.29158622038304, 82.25085880798747, 82.24439529033926, 81.94511380516032, 79.27676476112968, 78.66303585904595, 77.27290864457066, 76.6681886348328, 145.89837107659523, 109.87096343681237, 95.10892224077863, 115.45783016736293, 99.17165258961307, 122.0534783240037, 124.3212703035463]


 87%|████████▋ | 26/30 [4:33:16<42:37, 639.36s/it]

[92.74938249124502, 93.23906300503248, 92.19413086440157, 92.49665917698461, 92.28206422503503, 92.0982585402278, 92.35477148240524, 92.6356850499109, 92.83110828213175, 82.2993480917559, 82.29158622038304, 82.25085880798747, 82.24439529033926, 81.94511380516032, 79.27676476112968, 78.66303585904595, 77.27290864457066, 76.6681886348328, 145.89837107659523, 109.87096343681237, 95.10892224077863, 115.45783016736293, 99.17165258961307, 122.0534783240037, 124.3212703035463, 125.018022455069]


 90%|█████████ | 27/30 [4:43:50<31:53, 637.79s/it]

[92.74938249124502, 93.23906300503248, 92.19413086440157, 92.49665917698461, 92.28206422503503, 92.0982585402278, 92.35477148240524, 92.6356850499109, 92.83110828213175, 82.2993480917559, 82.29158622038304, 82.25085880798747, 82.24439529033926, 81.94511380516032, 79.27676476112968, 78.66303585904595, 77.27290864457066, 76.6681886348328, 145.89837107659523, 109.87096343681237, 95.10892224077863, 115.45783016736293, 99.17165258961307, 122.0534783240037, 124.3212703035463, 125.018022455069, 125.28704306668445]


 93%|█████████▎| 28/30 [4:54:31<21:17, 638.56s/it]

[92.74938249124502, 93.23906300503248, 92.19413086440157, 92.49665917698461, 92.28206422503503, 92.0982585402278, 92.35477148240524, 92.6356850499109, 92.83110828213175, 82.2993480917559, 82.29158622038304, 82.25085880798747, 82.24439529033926, 81.94511380516032, 79.27676476112968, 78.66303585904595, 77.27290864457066, 76.6681886348328, 145.89837107659523, 109.87096343681237, 95.10892224077863, 115.45783016736293, 99.17165258961307, 122.0534783240037, 124.3212703035463, 125.018022455069, 125.28704306668445, 103.47229886477332]


 97%|█████████▋| 29/30 [5:05:12<10:39, 639.51s/it]

[92.74938249124502, 93.23906300503248, 92.19413086440157, 92.49665917698461, 92.28206422503503, 92.0982585402278, 92.35477148240524, 92.6356850499109, 92.83110828213175, 82.2993480917559, 82.29158622038304, 82.25085880798747, 82.24439529033926, 81.94511380516032, 79.27676476112968, 78.66303585904595, 77.27290864457066, 76.6681886348328, 145.89837107659523, 109.87096343681237, 95.10892224077863, 115.45783016736293, 99.17165258961307, 122.0534783240037, 124.3212703035463, 125.018022455069, 125.28704306668445, 103.47229886477332, 103.14463104576137]


100%|██████████| 30/30 [5:15:55<00:00, 631.85s/it]

[92.74938249124502, 93.23906300503248, 92.19413086440157, 92.49665917698461, 92.28206422503503, 92.0982585402278, 92.35477148240524, 92.6356850499109, 92.83110828213175, 82.2993480917559, 82.29158622038304, 82.25085880798747, 82.24439529033926, 81.94511380516032, 79.27676476112968, 78.66303585904595, 77.27290864457066, 76.6681886348328, 145.89837107659523, 109.87096343681237, 95.10892224077863, 115.45783016736293, 99.17165258961307, 122.0534783240037, 124.3212703035463, 125.018022455069, 125.28704306668445, 103.47229886477332, 103.14463104576137, 81.2268667119662]
[92.74938249124502, 93.23906300503248, 92.19413086440157, 92.49665917698461, 92.28206422503503, 92.0982585402278, 92.35477148240524, 92.6356850499109, 92.83110828213175, 82.2993480917559, 82.29158622038304, 82.25085880798747, 82.24439529033926, 81.94511380516032, 79.27676476112968, 78.66303585904595, 77.27290864457066, 76.6681886348328, 145.89837107659523, 109.87096343681237, 95.10892224077863, 115.45783016736293, 99.17165258




In [27]:
# Validation loss for every saved plain-LSTM model in trainers_lstm
# (the progress bar below shows 30 of them). val_size=450 presumably caps the
# number of validation examples used — confirm against val_loss_for_all_models.
loss_list_lstm = val_loss_for_all_models(
    trainers_lstm,
    val_data,
    val_size=450,
    variable_batches=True,
)
print(loss_list_lstm)

# NOTE(review): results from an earlier run, kept for reference only. That run has
# 20 entries (not 30) and a much lower loss scale (~23-27 vs ~78-84 below), so its
# settings almost certainly differed; do not compare directly without confirming.
# first run: [27.304298203291822, 26.805993336154263, 26.629937693592783, 25.75794317912745, 25.926942683910386, 26.299691983393252, 26.071323804765548, 23.759693231146052, 22.978566764915136, 23.125428292791558, 24.846454579523623, 24.912110084617286, 24.860141947743173, 24.374138179515857, 24.19659813282396, 24.10343893510335, 24.259611301158923, 23.4498316783007, 23.99073953810295, 23.319251953208592]

  3%|▎         | 1/30 [09:05<4:23:49, 545.84s/it]

[79.67979485297654]


  7%|▋         | 2/30 [18:14<4:15:31, 547.55s/it]

[79.67979485297654, 79.52978831340694]


 10%|█         | 3/30 [27:21<4:06:13, 547.18s/it]

[79.67979485297654, 79.52978831340694, 79.59269432046497]


 13%|█▎        | 4/30 [36:25<3:56:34, 545.94s/it]

[79.67979485297654, 79.52978831340694, 79.59269432046497, 79.5220790124467]


 17%|█▋        | 5/30 [45:33<3:47:48, 546.74s/it]

[79.67979485297654, 79.52978831340694, 79.59269432046497, 79.5220790124467, 78.45980846308666]


 20%|██        | 6/30 [54:41<3:38:47, 547.00s/it]

[79.67979485297654, 79.52978831340694, 79.59269432046497, 79.5220790124467, 78.45980846308666, 78.53818885547793]


 23%|██▎       | 7/30 [1:03:44<3:29:14, 545.84s/it]

[79.67979485297654, 79.52978831340694, 79.59269432046497, 79.5220790124467, 78.45980846308666, 78.53818885547793, 79.78754508523244]


 27%|██▋       | 8/30 [1:12:55<3:20:46, 547.55s/it]

[79.67979485297654, 79.52978831340694, 79.59269432046497, 79.5220790124467, 78.45980846308666, 78.53818885547793, 79.78754508523244, 80.09970942381703]


 30%|███       | 9/30 [1:22:00<3:11:21, 546.74s/it]

[79.67979485297654, 79.52978831340694, 79.59269432046497, 79.5220790124467, 78.45980846308666, 78.53818885547793, 79.78754508523244, 80.09970942381703, 79.91776205714892]


 33%|███▎      | 10/30 [1:31:09<3:02:26, 547.34s/it]

[79.67979485297654, 79.52978831340694, 79.59269432046497, 79.5220790124467, 78.45980846308666, 78.53818885547793, 79.78754508523244, 80.09970942381703, 79.91776205714892, 80.18426216631792]


 37%|███▋      | 11/30 [1:40:20<2:53:42, 548.57s/it]

[79.67979485297654, 79.52978831340694, 79.59269432046497, 79.5220790124467, 78.45980846308666, 78.53818885547793, 79.78754508523244, 80.09970942381703, 79.91776205714892, 80.18426216631792, 80.446114885401]


 40%|████      | 12/30 [1:49:27<2:44:22, 547.91s/it]

[79.67979485297654, 79.52978831340694, 79.59269432046497, 79.5220790124467, 78.45980846308666, 78.53818885547793, 79.78754508523244, 80.09970942381703, 79.91776205714892, 80.18426216631792, 80.446114885401, 79.74755281186549]


 43%|████▎     | 13/30 [1:58:32<2:35:02, 547.21s/it]

[79.67979485297654, 79.52978831340694, 79.59269432046497, 79.5220790124467, 78.45980846308666, 78.53818885547793, 79.78754508523244, 80.09970942381703, 79.91776205714892, 80.18426216631792, 80.446114885401, 79.74755281186549, 79.73755187844932]


 47%|████▋     | 14/30 [2:07:41<2:26:01, 547.60s/it]

[79.67979485297654, 79.52978831340694, 79.59269432046497, 79.5220790124467, 78.45980846308666, 78.53818885547793, 79.78754508523244, 80.09970942381703, 79.91776205714892, 80.18426216631792, 80.446114885401, 79.74755281186549, 79.73755187844932, 80.6935416543814]


 50%|█████     | 15/30 [2:16:47<2:16:49, 547.28s/it]

[79.67979485297654, 79.52978831340694, 79.59269432046497, 79.5220790124467, 78.45980846308666, 78.53818885547793, 79.78754508523244, 80.09970942381703, 79.91776205714892, 80.18426216631792, 80.446114885401, 79.74755281186549, 79.73755187844932, 80.6935416543814, 84.20387256269817]


 53%|█████▎    | 16/30 [2:25:54<2:07:40, 547.21s/it]

[79.67979485297654, 79.52978831340694, 79.59269432046497, 79.5220790124467, 78.45980846308666, 78.53818885547793, 79.78754508523244, 80.09970942381703, 79.91776205714892, 80.18426216631792, 80.446114885401, 79.74755281186549, 79.73755187844932, 80.6935416543814, 84.20387256269817, 80.50794088880728]


 57%|█████▋    | 17/30 [2:35:01<1:58:30, 546.97s/it]

[79.67979485297654, 79.52978831340694, 79.59269432046497, 79.5220790124467, 78.45980846308666, 78.53818885547793, 79.78754508523244, 80.09970942381703, 79.91776205714892, 80.18426216631792, 80.446114885401, 79.74755281186549, 79.73755187844932, 80.6935416543814, 84.20387256269817, 80.50794088880728, 81.51241459117321]


 60%|██████    | 18/30 [2:44:04<1:49:09, 545.77s/it]

[79.67979485297654, 79.52978831340694, 79.59269432046497, 79.5220790124467, 78.45980846308666, 78.53818885547793, 79.78754508523244, 80.09970942381703, 79.91776205714892, 80.18426216631792, 80.446114885401, 79.74755281186549, 79.73755187844932, 80.6935416543814, 84.20387256269817, 80.50794088880728, 81.51241459117321, 78.87793919587338]


 63%|██████▎   | 19/30 [2:53:12<1:40:10, 546.41s/it]

[79.67979485297654, 79.52978831340694, 79.59269432046497, 79.5220790124467, 78.45980846308666, 78.53818885547793, 79.78754508523244, 80.09970942381703, 79.91776205714892, 80.18426216631792, 80.446114885401, 79.74755281186549, 79.73755187844932, 80.6935416543814, 84.20387256269817, 80.50794088880728, 81.51241459117321, 78.87793919587338, 81.19087807337999]


 67%|██████▋   | 20/30 [3:02:22<1:31:14, 547.47s/it]

[79.67979485297654, 79.52978831340694, 79.59269432046497, 79.5220790124467, 78.45980846308666, 78.53818885547793, 79.78754508523244, 80.09970942381703, 79.91776205714892, 80.18426216631792, 80.446114885401, 79.74755281186549, 79.73755187844932, 80.6935416543814, 84.20387256269817, 80.50794088880728, 81.51241459117321, 78.87793919587338, 81.19087807337999, 78.72395049273821]


 70%|███████   | 21/30 [3:11:24<1:21:53, 545.94s/it]

[79.67979485297654, 79.52978831340694, 79.59269432046497, 79.5220790124467, 78.45980846308666, 78.53818885547793, 79.78754508523244, 80.09970942381703, 79.91776205714892, 80.18426216631792, 80.446114885401, 79.74755281186549, 79.73755187844932, 80.6935416543814, 84.20387256269817, 80.50794088880728, 81.51241459117321, 78.87793919587338, 81.19087807337999, 78.72395049273821, 80.22350105492471]


 73%|███████▎  | 22/30 [3:20:36<1:13:02, 547.77s/it]

[79.67979485297654, 79.52978831340694, 79.59269432046497, 79.5220790124467, 78.45980846308666, 78.53818885547793, 79.78754508523244, 80.09970942381703, 79.91776205714892, 80.18426216631792, 80.446114885401, 79.74755281186549, 79.73755187844932, 80.6935416543814, 84.20387256269817, 80.50794088880728, 81.51241459117321, 78.87793919587338, 81.19087807337999, 78.72395049273821, 80.22350105492471, 80.33601728868184]


 77%|███████▋  | 23/30 [3:29:45<1:03:57, 548.27s/it]

[79.67979485297654, 79.52978831340694, 79.59269432046497, 79.5220790124467, 78.45980846308666, 78.53818885547793, 79.78754508523244, 80.09970942381703, 79.91776205714892, 80.18426216631792, 80.446114885401, 79.74755281186549, 79.73755187844932, 80.6935416543814, 84.20387256269817, 80.50794088880728, 81.51241459117321, 78.87793919587338, 81.19087807337999, 78.72395049273821, 80.22350105492471, 80.33601728868184, 78.01704524961455]


 80%|████████  | 24/30 [3:38:59<54:58, 549.74s/it]  

[79.67979485297654, 79.52978831340694, 79.59269432046497, 79.5220790124467, 78.45980846308666, 78.53818885547793, 79.78754508523244, 80.09970942381703, 79.91776205714892, 80.18426216631792, 80.446114885401, 79.74755281186549, 79.73755187844932, 80.6935416543814, 84.20387256269817, 80.50794088880728, 81.51241459117321, 78.87793919587338, 81.19087807337999, 78.72395049273821, 80.22350105492471, 80.33601728868184, 78.01704524961455, 81.48363650063115]


 83%|████████▎ | 25/30 [3:48:13<45:56, 551.25s/it]

[79.67979485297654, 79.52978831340694, 79.59269432046497, 79.5220790124467, 78.45980846308666, 78.53818885547793, 79.78754508523244, 80.09970942381703, 79.91776205714892, 80.18426216631792, 80.446114885401, 79.74755281186549, 79.73755187844932, 80.6935416543814, 84.20387256269817, 80.50794088880728, 81.51241459117321, 78.87793919587338, 81.19087807337999, 78.72395049273821, 80.22350105492471, 80.33601728868184, 78.01704524961455, 81.48363650063115, 79.03450320838881]


 87%|████████▋ | 26/30 [3:57:25<36:45, 551.44s/it]

[79.67979485297654, 79.52978831340694, 79.59269432046497, 79.5220790124467, 78.45980846308666, 78.53818885547793, 79.78754508523244, 80.09970942381703, 79.91776205714892, 80.18426216631792, 80.446114885401, 79.74755281186549, 79.73755187844932, 80.6935416543814, 84.20387256269817, 80.50794088880728, 81.51241459117321, 78.87793919587338, 81.19087807337999, 78.72395049273821, 80.22350105492471, 80.33601728868184, 78.01704524961455, 81.48363650063115, 79.03450320838881, 81.15257212344942]


 90%|█████████ | 27/30 [4:06:39<27:36, 552.25s/it]

[79.67979485297654, 79.52978831340694, 79.59269432046497, 79.5220790124467, 78.45980846308666, 78.53818885547793, 79.78754508523244, 80.09970942381703, 79.91776205714892, 80.18426216631792, 80.446114885401, 79.74755281186549, 79.73755187844932, 80.6935416543814, 84.20387256269817, 80.50794088880728, 81.51241459117321, 78.87793919587338, 81.19087807337999, 78.72395049273821, 80.22350105492471, 80.33601728868184, 78.01704524961455, 81.48363650063115, 79.03450320838881, 81.15257212344942, 79.17349770738058]


 93%|█████████▎| 28/30 [4:15:52<18:24, 552.42s/it]

[79.67979485297654, 79.52978831340694, 79.59269432046497, 79.5220790124467, 78.45980846308666, 78.53818885547793, 79.78754508523244, 80.09970942381703, 79.91776205714892, 80.18426216631792, 80.446114885401, 79.74755281186549, 79.73755187844932, 80.6935416543814, 84.20387256269817, 80.50794088880728, 81.51241459117321, 78.87793919587338, 81.19087807337999, 78.72395049273821, 80.22350105492471, 80.33601728868184, 78.01704524961455, 81.48363650063115, 79.03450320838881, 81.15257212344942, 79.17349770738058, 78.16059535447968]


 97%|█████████▋| 29/30 [4:25:03<09:12, 552.07s/it]

[79.67979485297654, 79.52978831340694, 79.59269432046497, 79.5220790124467, 78.45980846308666, 78.53818885547793, 79.78754508523244, 80.09970942381703, 79.91776205714892, 80.18426216631792, 80.446114885401, 79.74755281186549, 79.73755187844932, 80.6935416543814, 84.20387256269817, 80.50794088880728, 81.51241459117321, 78.87793919587338, 81.19087807337999, 78.72395049273821, 80.22350105492471, 80.33601728868184, 78.01704524961455, 81.48363650063115, 79.03450320838881, 81.15257212344942, 79.17349770738058, 78.16059535447968, 77.83317769923954]


100%|██████████| 30/30 [4:34:19<00:00, 548.66s/it]

[79.67979485297654, 79.52978831340694, 79.59269432046497, 79.5220790124467, 78.45980846308666, 78.53818885547793, 79.78754508523244, 80.09970942381703, 79.91776205714892, 80.18426216631792, 80.446114885401, 79.74755281186549, 79.73755187844932, 80.6935416543814, 84.20387256269817, 80.50794088880728, 81.51241459117321, 78.87793919587338, 81.19087807337999, 78.72395049273821, 80.22350105492471, 80.33601728868184, 78.01704524961455, 81.48363650063115, 79.03450320838881, 81.15257212344942, 79.17349770738058, 78.16059535447968, 77.83317769923954, 77.71568402086015]
[79.67979485297654, 79.52978831340694, 79.59269432046497, 79.5220790124467, 78.45980846308666, 78.53818885547793, 79.78754508523244, 80.09970942381703, 79.91776205714892, 80.18426216631792, 80.446114885401, 79.74755281186549, 79.73755187844932, 80.6935416543814, 84.20387256269817, 80.50794088880728, 81.51241459117321, 78.87793919587338, 81.19087807337999, 78.72395049273821, 80.22350105492471, 80.33601728868184, 78.01704524961455,




In [28]:
print(loss_list_att)

# Select the attention checkpoint with the lowest validation loss (0-based index
# into loss_list_att). It is computed and printed rather than hand-noted in a comment,
# so it stays correct when the notebook is re-run. In the saved run this was index 17.
# NOTE(review): the loss jumps from ~76.7 (index 17) to ~145.9 (index 18). It never
# returns below the index-17 value through the last checkpoint (~81.2), which is
# possibly training instability; worth investigating.
best_att_idx = int(np.argmin(loss_list_att))
print(f"lowest attention val loss: {loss_list_att[best_att_idx]:.4f} at index {best_att_idx}")
best_att_idx

[92.74938249124502, 93.23906300503248, 92.19413086440157, 92.49665917698461, 92.28206422503503, 92.0982585402278, 92.35477148240524, 92.6356850499109, 92.83110828213175, 82.2993480917559, 82.29158622038304, 82.25085880798747, 82.24439529033926, 81.94511380516032, 79.27676476112968, 78.66303585904595, 77.27290864457066, 76.6681886348328, 145.89837107659523, 109.87096343681237, 95.10892224077863, 115.45783016736293, 99.17165258961307, 122.0534783240037, 124.3212703035463, 125.018022455069, 125.28704306668445, 103.47229886477332, 103.14463104576137, 81.2268667119662]


17

In [29]:
print(loss_list_lstm)

# Select the LSTM checkpoint with the lowest validation loss (0-based index into
# loss_list_lstm). It is computed and printed rather than hand-noted in a comment.
# In the saved run argmin is 29 (~77.72), the final checkpoint, with index 28
# (~77.83) a close second.
# NOTE(review): the best model is the last one trained, so the LSTM may not have
# finished improving; more epochs could help — confirm.
best_lstm_idx = int(np.argmin(loss_list_lstm))
print(f"lowest LSTM val loss: {loss_list_lstm[best_lstm_idx]:.4f} at index {best_lstm_idx}")
best_lstm_idx

[79.67979485297654, 79.52978831340694, 79.59269432046497, 79.5220790124467, 78.45980846308666, 78.53818885547793, 79.78754508523244, 80.09970942381703, 79.91776205714892, 80.18426216631792, 80.446114885401, 79.74755281186549, 79.73755187844932, 80.6935416543814, 84.20387256269817, 80.50794088880728, 81.51241459117321, 78.87793919587338, 81.19087807337999, 78.72395049273821, 80.22350105492471, 80.33601728868184, 78.01704524961455, 81.48363650063115, 79.03450320838881, 81.15257212344942, 79.17349770738058, 78.16059535447968, 77.83317769923954, 77.71568402086015]


29