In [114]:
# Runtime switch: set to True when running on Google Colab, False for a local run.
# Later cells read this flag to decide whether to mount Google Drive.
drive = False  # False for Local
if drive:
    # On Colab, install the third-party packages imported in the next cell.
    !pip install pretty_midi
    !pip install -U sparsemax
# locally, also compile torch (with CUDA enabled if available):
# conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch

In [115]:
import os
import numpy as np
import matplotlib.pyplot as plt
import pretty_midi
#this package is used to write it back into music.
from mido import Message, MidiFile, MidiTrack
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
import random
import torch.distributions
import sparsemax

In [116]:
# Choose the root folder that every later path (models, data) is built from.
if drive:
    # NOTE: this import rebinds the name `drive` from the boolean flag to the
    # google.colab.drive module. A module object is truthy, so re-running this
    # cell still takes this branch.
    from google.colab import drive
    drive.mount('/content/drive')
    # Paths are joined by plain string concatenation, so keep the trailing slash.
    my_drive_path = '/content/drive/MyDrive/2022_Special_Studies_Hablutzel/ModelCopies/'
else: # local
    my_drive_path = "./"

In [117]:
#this is the model
class music_generator(nn.Module):
    """Single-layer LSTM beat generator with an optional SSM-weighted attention step.

    The LSTM state is kept on the instance (``self.hidden``) and carried from
    one ``forward`` call to the next; call ``init_hidden`` or
    ``set_random_hidden`` before starting a new batch.

    Args:
        hidden_size: size of the LSTM hidden state (128 in this notebook).
        output_size: number of input features per beat (128 in this notebook).
        base_lstm: when True, ``forward`` returns the raw LSTM output and skips
            the attention step.

    Note:
        With attention enabled, the LSTM output is stacked against the
        note-wise weighted history, so ``hidden_size`` must equal the feature
        width of ``prev_sequence``.
    """

    def __init__(self, hidden_size, output_size, base_lstm=False):
        super().__init__()
        self.hidden_size = hidden_size  # 128
        # output_size is the number of expected input features (128)
        self.lstm = nn.LSTM(output_size, hidden_size, num_layers=1, bidirectional=False)
        # mixes, per note, the attention-weighted history with the LSTM output
        self.attention = nn.Linear(2, 1)
        self.softmax = sparsemax.Sparsemax(dim=1)
        self.sigmoid = nn.Sigmoid()
        self.hidden = None
        self.base_lstm = base_lstm  # True to use the LSTM without attention

    def init_hidden(self, batch_size):
        """Reset the hidden and cell state to zeros for ``batch_size`` pieces."""
        # [layers, batch_size, hidden_size]
        hidden = torch.zeros(1, batch_size, self.hidden_size).float()
        self.hidden = (hidden, hidden)  # hidden_state, cell_state
        return

    def set_random_hidden(self, batch_size):
        """Set the hidden and cell state to one shared random normal tensor."""
        hidden = torch.randn(1, batch_size, self.hidden_size).float()
        self.hidden = (hidden, hidden)
        return

    def forward(self, input, batch_size, prev_sequence, batched_ssm):
        """Run the LSTM on ``input`` and, unless ``base_lstm``, apply attention.

        Args:
            input: beats to feed the LSTM, ``[sequence_length, batch_size, features]``.
            batch_size: number of pieces in the batch.
            prev_sequence: every beat produced so far, ``[beats, batch_size, notes]``
                (unused when ``base_lstm`` is True).
            batched_ssm: per-piece self-similarity matrices stacked vertically,
                ``[batch_size * beats, beats]`` (unused when ``base_lstm`` is True).

        Returns:
            ``(output, hidden)``. For the base LSTM ``output`` is the LSTM output
            ``[sequence_length, batch_size, hidden_size]``. With attention it is
            a double tensor ``[1, batch_size, notes]``. ``hidden`` is the
            ``(hidden_state, cell_state)`` pair, also stored on ``self.hidden``.
        """
        sequence_length = input.size()[0]

        # output: hidden states for every input beat; self.hidden: latest state
        output, self.hidden = self.lstm(input.float(), self.hidden)
        output_1 = output.to('cpu')

        # Reshape with the configured hidden size instead of a hard-coded 128,
        # so the base LSTM works for any hidden_size.
        lstm_output = output_1.view(sequence_length, batch_size, self.hidden_size)

        # return early (w/o attention) for base lstm
        if self.base_lstm:
            return lstm_output, self.hidden

        #########################
        # attention starts here #
        #########################

        # Row `beat_num` of each piece's SSM, i.e. the similarity of the beat
        # being generated to every earlier beat. Pieces are stacked vertically,
        # so the same beat recurs every `beats` rows (e.g. 11, 116, 221, ...).
        beat_num = prev_sequence.shape[0]
        inds_across_pieces = list(range(beat_num, batched_ssm.shape[0], batched_ssm.shape[1]))
        ssm_slice = batched_ssm[inds_across_pieces, :beat_num]  # [batch, beat_num]

        # sparsemax makes each piece's weights add up to 1
        weights = self.softmax(ssm_slice)

        # Scale every earlier beat by its weight. permute(2, 1, 0) reverses the
        # axes exactly like the deprecated `.T` on a 3-D tensor:
        # [beats, batch, notes] -> [notes, batch, beats], which broadcasts
        # against the [batch, beats] weights.
        weighted = (prev_sequence.permute(2, 1, 0) * weights).permute(2, 1, 0)

        # sum over beats -> one attention value per note, [batch, notes]
        weight_vec = torch.sum(weighted, axis=0)

        # Pair each note's attention value with the LSTM output of the FIRST
        # timestep of this call, [batch, notes, 2], and mix them linearly.
        pt2 = torch.hstack((weight_vec.unsqueeze(1), lstm_output[0, :, :].unsqueeze(1))).transpose(1, 2)
        attentioned = self.attention(pt2.float()).permute(2, 0, 1)  # [1, batch, notes]

        return attentioned.double(), self.hidden  # hidden = hidden_state, cell_state

In [118]:
class model_trainer():
  """Trains a generator on batches of pieces and samples example pieces from it.

  Args:
    generator: model exposing ``init_hidden``, ``set_random_hidden`` and
      ``forward(input, batch_size, prev_sequence, batched_ssm)``.
    optimizer: optimizer over the generator's parameters.
    data: list of rows whose element 0 is the piece tensor ``[beats, 128]``.
    hidden_size: stored on the trainer; not otherwise used by its methods.
    batch_size: pieces per batch for fixed-size batching.
  """

  def __init__(self, generator, optimizer, data, hidden_size=128, batch_size=50):
    self.generator = generator
    self.optimizer = optimizer
    self.batch_size = batch_size  # play with this
    self.hidden_size = hidden_size  # 128
    self.data = data
    self.data_length = data[0][0].shape[0]  # as long as piece length doesn't vary

  def train_epochs(self, num_epochs=50, full_training=False, variable_size_batches=False, save_name="model"):
    """Train for ``num_epochs`` epochs, saving the generator after each one.

    Args:
      num_epochs: number of passes over the data.
      full_training: True to use all data and full-length pieces; False to use
        the first 100 pieces, cut to their first 105 beats.
      variable_size_batches: True to batch pieces by equal length instead of
        by ``self.batch_size``.
      save_name: file-name prefix for the per-epoch checkpoints.

    Returns:
      ``(losslist, piclist)``: the summed loss of every epoch, and one
      generated example piece per epoch.
    """
    losslist = []
    # useful when you want to see the progression of the SSM over time
    piclist = []

    for epoch in tqdm(range(0, num_epochs)):
      # Back to training mode; generate_n_examples below switches to eval.
      # Uses the trainer's own generator, not a global named `generator`.
      self.generator.train()

      # can we overfit on a small dataset? if so, it shows the model can learn
      data = self.data if full_training else self.data[:100]
      if variable_size_batches:
        # group batches by piece size (marked buggy upstream for the 100-piece subset)
        batches = make_variable_size_batches(data)
      else:
        batches = make_batches(data, self.batch_size, self.data_length)

      cum_loss = 0
      for batch in tqdm(batches):
        if full_training:
          # train on full-length pieces
          cum_loss += self.train(batch)
        else:
          # train on first 105 beats of each piece
          cum_loss += self.train(batch[:,:105,:])  # [batch, beats, 128]
      del batches

      # print loss for early stopping
      print(cum_loss)

      # Save the generator after each epoch. The target directory is created
      # first, because torch.save does not create missing directories.
      curr_file = f"{my_drive_path}trained/{save_name}-epoch-{epoch}-loss-{cum_loss:.5f}.txt"
      os.makedirs(os.path.dirname(curr_file), exist_ok=True)
      torch.save(self.generator, curr_file)

      # generate example piece for piclist
      snap = self.generate_n_examples(n=1, length=95, starter_notes=10)

      losslist.append(cum_loss)
      piclist.append(snap)

      # early stopping (not implemented):
      # after each epoch, run w/ validation; if devset (validation) loss goes
      # up for ~5 epochs in a row, stop early
    return losslist, piclist

  def train(self, batch, starter_notes=10):
    """Run one optimisation step on ``batch`` and return its detached loss.

    Args:
      batch: pieces as ``[batch, beats, 128]``.
      starter_notes: number of leading ground-truth beats used as the seed.

    Returns:
      The loss for this batch as a detached CPU tensor.
    """
    batch_size = batch.shape[0]
    beats_first = batch.transpose(0,1)  # [beats, batch, 128]
    self_sim = batch_SSM(beats_first, batch_size)  # use variable batch size
    sequence = batch[:,0:starter_notes,:].transpose(0,1)  # seed: start of each piece
    generated = batch[:,0:starter_notes,:].transpose(0,1)

    # reset hidden to zeros for each batch
    self.generator.init_hidden(batch_size)

    # zero the gradients before training for each batch
    self.optimizer.zero_grad()

    # the first forward sees the whole seed; later ones see one beat at a time
    next_element = sequence

    for i in range(0, batch.shape[1]-starter_notes):  # for each beat to generate
      # drawn every step, also when unused, so the random stream stays the same
      use_teacher = torch.rand(1).item() > .8

      # Generate a beat for each piece in the batch. This is needed even with
      # teacher forcing, because the loss is computed on the raw outputs.
      output, _ = self.generator.forward(next_element, batch_size, sequence, self_sim)

      if use_teacher:
        # Teacher forcing, 20% of the time: feed the ground-truth beat for the
        # position just generated. That position is starter_notes + i (the seed
        # fills 0..starter_notes-1), not i + 1.
        next_element = batch[:,starter_notes+i,:].unsqueeze(0)  # [batch, 128] -> [1, batch, 128]
      else:
        # 80% of the time we keep (a sample of) the output
        next_element = topk_batch_sample(output, 1)

      # add next_element (either generated or teacher) to sequence
      sequence = torch.vstack((sequence, next_element.to("cpu")))
      # append output (generated - not teacher forced) for loss
      generated = torch.vstack((generated, output))

    # run loss after training on whole length of the pieces in the batch
    single_loss = custom_loss(generated[starter_notes:,:,:], beats_first[starter_notes:,:,:])
    single_loss.backward()

    # update the parameters after running on the full batch
    self.optimizer.step()

    return single_loss.detach().to('cpu')

  def generate_n_pieces(self, initial_vectors, n_pieces, length, batched_ssm):
    """Generate ``length`` more beats for a batch of ``n_pieces`` pieces.

    Args:
      initial_vectors: seed beats as ``[batch_size, num_notes, 128]``.
      n_pieces: batch size used for the random initial hidden state.
      length: number of beats to generate after the seed.
      batched_ssm: stacked self-similarity matrices passed to the generator.

    Returns:
      The seed plus generated beats as ``[num_notes + length, batch_size, 128]``.
    """
    # freeze generator so it doesn't train anymore
    self.generator.eval()
    # start generator on random hidden states and cell states
    self.generator.set_random_hidden(n_pieces)

    # change sequence to [num_notes, batch_size, 128]
    sequence = initial_vectors.transpose(0,1)

    # generate [length] more beats for the piece, one at a time
    for _ in range(0, length):
      with torch.no_grad():
        # the whole sequence so far is fed back in on every step
        output, _hidden = self.generator.forward(sequence.float(), n_pieces, sequence, batched_ssm)
        next_element = topk_batch_sample(output, 1)
      # add element to sequence
      sequence = torch.vstack((sequence, next_element.to("cpu")))

    # return sequence of beats
    return sequence

  def generate_n_examples(self, n=1, length=390, starter_notes=10, source_piece=0):
    """Generate example beats seeded from one piece of the training data.

    Args:
      n: batch size handed to the SSM builder and the generator.
        NOTE(review): only one source piece is loaded, so values other than 1
        look unsupported - confirm before relying on n > 1.
      length: number of beats to generate after the seed.
      starter_notes: number of leading beats of the source piece used as seed.
      source_piece: index into ``self.data`` of the piece to seed from.

    Returns:
      The generated sequence as returned by ``generate_n_pieces``.
    """
    # get piece from the data, in format [1, beats, 128]
    piece = self.data[source_piece][0].unsqueeze(0)

    # take the first starter_notes beats, in format [1, starter_notes, 128]
    first_vec = piece[:,0:starter_notes,:]

    # create batched SSMs for the piece
    batched_ssms = batch_SSM(piece.transpose(0,1), n)

    # generate and return pieces
    return self.generate_n_pieces(first_vec, n, length, batched_ssms)

In [119]:
#this function takes in the piece of music and returns the chroma vectors
def get_chroma(roll, length):
    """Fold a piano roll into 12 pitch-class (chroma) columns.

    Args:
        roll: 2-D tensor ``[beats, notes]``; column ``n`` belongs to pitch
            class ``n % 12``.
        length: unused; kept so existing callers keep working.

    Returns:
        Float tensor ``[beats, 12]`` where column ``c`` is the sum of every
        12th note column starting at ``c``.
    """
    n_beats = roll.size()[0]
    chroma = torch.zeros((n_beats, 12))
    for pitch_class in range(12):
        # notes pitch_class, pitch_class + 12, ... share one pitch class
        chroma[:, pitch_class] = roll[:, pitch_class::12].sum(dim=1)
    return chroma

In [120]:
#this takes in the sequence and creates a self-similarity matrix (it calls chroma function inside)
def SSM(sequence):
  """Build the self-similarity matrix of one piece from its chroma vectors.

  Args:
    sequence: one piece as ``[beats, notes]`` (notes is 128 in this notebook).

  Returns:
    Tensor ``[beats, beats]`` whose entry ``(i, j)`` is the cosine similarity
    between the chroma vectors of beats ``i`` and ``j``.
  """
  cosine = nn.CosineSimilarity(dim=1)
  chroma = get_chroma(sequence, sequence.size()[0])
  n_beats = chroma.size()[0]
  similarity = torch.zeros((n_beats, n_beats))
  for beat, chroma_row in enumerate(chroma):
    # compare this beat against every beat at once (the single row broadcasts)
    similarity[beat] = cosine(chroma_row.view(1, -1), chroma)
  return similarity

In [121]:
#this bundles the SSM function.
def batch_SSM(seq, batch_size):
  """Compute one SSM per piece and stack them vertically.

  Args:
    seq: pieces as ``[beats, batch_size, 128]``.
    batch_size: number of pieces to read from axis 1 of ``seq``.

  Returns:
    Tensor ``[batch_size * beats, beats]``: the per-piece SSMs on top of each
    other, in batch order.
  """
  per_piece = [SSM(seq[:, piece, :]) for piece in range(0, batch_size)]
  return torch.vstack(per_piece)

In [122]:
#sampling function 
def topk_sample_one(sequence, k):
  """Sample one note per (step, piece) from the k highest-scoring notes.

  Only note columns 20..107 are candidates (presumably the piano range - TODO
  confirm). The k best scores are turned into probabilities with sparsemax and
  one of them is drawn at random.

  Args:
    sequence: scores as ``[sequence_length, batch_size, notes]``.
    k: number of top candidates to sample from.

  Returns:
    One-hot integer tensor with the same shape as ``sequence`` marking the
    sampled note.
  """
  low, high = 20, 108
  top_vals, top_idx = torch.topk(sequence[:, :, low:high], k)
  # topk indexes into the slice; shift back to full note numbers
  top_idx = top_idx + low
  probs = sparsemax.Sparsemax(dim=2)(top_vals.float())
  picks = torch.distributions.Categorical(probs).sample()
  # translate "which of the k candidates" into "which note"
  chosen = torch.gather(top_idx, -1, picks.unsqueeze(-1))
  return F.one_hot(chosen, num_classes = sequence.shape[2]).squeeze(dim=2)

In [123]:
#samples multiple times for the time-step
def topk_batch_sample(sequence, k):
  """Draw three top-k samples per step and return the union of the notes hit.

  Args:
    sequence: scores as ``[sequence_length, batch_size, notes]``.
    k: number of top candidates each draw samples from.

  Returns:
    Integer 0/1 tensor shaped like ``sequence``; an entry is 1 when at least
    one of the three draws selected that note.
  """
  draws = [topk_sample_one(sequence, k) for _ in range(0, 3)]
  hits = draws[0]
  for extra in draws[1:]:
    hits += extra
  return torch.where(hits > 0, 1, 0)


In [124]:
def custom_loss(output, target):
  """Binary cross-entropy on the notes plus a per-piece SSM mismatch penalty.

  Args:
    output: generated beats as ``[beats, batch_size, notes]``, passed to
      ``BCEWithLogitsLoss`` as logits.
    target: ground-truth beats with the same shape.

  Returns:
    Scalar tensor: the BCE-with-logits loss plus, for every piece, the mean
    squared difference between the SSM of the output and the SSM of the target.
  """
  note_loss = nn.BCEWithLogitsLoss()(output.double(), target.double())

  ssm_err = 0
  for piece in range(0, output.size()[1]):
    generated_ssm = SSM(output[:, piece, :])
    target_ssm = SSM(target[:, piece, :])
    n_beats = target_ssm.size(0)
    # mean squared difference over all beats x beats entries
    ssm_err += torch.sum((generated_ssm - target_ssm) ** 2) / (n_beats ** 2)

  return torch.sum(note_loss) + ssm_err

In [125]:
# returns batches where piece size is constant within the batch
# but piece size is different across batches
# and batches are in random order
def make_variable_size_batches(data, min_batch_size=10):
  """Group pieces of equal length into batches and return them shuffled.

  Args:
    data: list of rows where element 0 is the piece tensor ``[beats, 128]``
      and element 2 is its number of beats. The list is sorted IN PLACE.
    min_batch_size: groups with fewer pieces than this are dropped.

  Returns:
    List of tensors ``[pieces, beats, 128]``, one per kept group, in random
    order.
  """
  # sort ascending by number of beats so equal-length pieces are adjacent
  data.sort(key = lambda row: row[2], reverse=False)

  batches = []
  total = len(data)
  start = 0
  while start < total:
    beats = data[start][2]  # num beats shared by this group

    # collect every piece with this number of beats, reshaped to [1, beats, 128]
    group = []
    stop = start
    while stop < total and data[stop][2] == beats:
      tensor = data[stop][0]
      group.append(tensor.view(1, tensor.shape[0], 128))
      stop += 1

    # only keep large enough groups, merged into [batch size, beats, 128]
    if len(group) >= min_batch_size:
      batches.append(torch.cat(group, dim=0))

    start = stop

  # randomize batches order
  random.shuffle(batches)
  return batches

In [126]:
# Takes in the batch size and data and returns batches of the batch size
def make_batches(data, batch_size, piece_size=800, slice_data=False):
  """Shuffle the pieces and group them into batches.

  Args:
    data: list of rows whose element 0 is the piece tensor ``[beats, 128]``.
      Shuffled IN PLACE unless ``slice_data`` is True (then a new list is
      built and shuffled).
    batch_size: pieces per batch; values of 1 or less put every piece in its
      own batch.
    piece_size: number of beats per piece when ``batch_size > 1``; with
      ``slice_data`` also the length pieces are cut to.
    slice_data: True to drop pieces shorter than ``piece_size`` and cut the
      rest to ``piece_size`` beats.

  Returns:
    List of tensors ``[batch_size, piece_size, 128]``; leftover pieces that do
    not fill a batch are dropped. For ``batch_size <= 1`` the result is one
    ``[1, beats, 128]`` tensor per piece, ignoring ``piece_size``.
  """
  if slice_data:
    # keep only pieces that are long enough, cut to piece_size beats
    data = [(roll[0][:piece_size], roll[1], roll[2]) for roll in data if len(roll[0]) >= piece_size]
    print(f"sliced pieces to length {piece_size}; total pieces is now {len(data)}")

  random.shuffle(data)

  if batch_size > 1:
    batches = []
    full_batches_end = (len(data) // batch_size) * batch_size
    for start in range(0, full_batches_end, batch_size):
      # take just the tensors, concatenate along beats, then split per piece
      tensors = [roll[0] for roll in data[start:start + batch_size]]
      batches.append(torch.cat(tensors).view(batch_size, piece_size, 128))
    return batches

  # each piece is its own batch, using the piece's own length
  return [roll[0].view(1, roll[0].shape[0], 128) for roll in data]

In [127]:
def get_models_in_directory(filepath):
    """Load every saved model in a directory, ordered by epoch.

    Args:
        filepath: directory relative to ``my_drive_path``, including the
            trailing slash (paths are joined by string concatenation).

    Returns:
        List of loaded models, sorted by the integer found in the third
        ``-``-separated field of each file name. Only names containing
        ``'model'`` are considered.
    """
    directory = my_drive_path + filepath

    # (epoch, name) for every checkpoint file; the epoch is parsed from the name
    checkpoints = []
    for name in os.listdir(directory):
        if 'model' in name:
            checkpoints.append((int(name.split("-")[2]), name))
    checkpoints.sort(key = lambda pair: pair[0])  # sort by epoch

    # SECURITY: torch.load unpickles the file and can run arbitrary code;
    # only load checkpoints from a trusted source.
    return [torch.load(directory + name) for _, name in checkpoints]

In [128]:
# Load the per-epoch checkpoints of both variants, sorted by epoch:
# the attention model and the plain-LSTM baseline.
models_att = get_models_in_directory("trained/attention_model_v2/")
models_lstm = get_models_in_directory("trained/lstm_v2/")

In [129]:
# Display the loaded attention checkpoints (one module repr per epoch).
models_att

[music_generator(
   (lstm): LSTM(128, 128)
   (attention): Linear(in_features=2, out_features=1, bias=True)
   (softmax): Sparsemax(dim=1)
   (sigmoid): Sigmoid()
 ),
 music_generator(
   (lstm): LSTM(128, 128)
   (attention): Linear(in_features=2, out_features=1, bias=True)
   (softmax): Sparsemax(dim=1)
   (sigmoid): Sigmoid()
 ),
 music_generator(
   (lstm): LSTM(128, 128)
   (attention): Linear(in_features=2, out_features=1, bias=True)
   (softmax): Sparsemax(dim=1)
   (sigmoid): Sigmoid()
 ),
 music_generator(
   (lstm): LSTM(128, 128)
   (attention): Linear(in_features=2, out_features=1, bias=True)
   (softmax): Sparsemax(dim=1)
   (sigmoid): Sigmoid()
 ),
 music_generator(
   (lstm): LSTM(128, 128)
   (attention): Linear(in_features=2, out_features=1, bias=True)
   (softmax): Sparsemax(dim=1)
   (sigmoid): Sigmoid()
 ),
 music_generator(
   (lstm): LSTM(128, 128)
   (attention): Linear(in_features=2, out_features=1, bias=True)
   (softmax): Sparsemax(dim=1)
   (sigmoid): Sigmoi

In [130]:
# Display the loaded baseline-LSTM checkpoints (one module repr per epoch).
models_lstm

[music_generator(
   (lstm): LSTM(128, 128)
   (attention): Linear(in_features=2, out_features=1, bias=True)
   (softmax): Sparsemax(dim=1)
   (sigmoid): Sigmoid()
 ),
 music_generator(
   (lstm): LSTM(128, 128)
   (attention): Linear(in_features=2, out_features=1, bias=True)
   (softmax): Sparsemax(dim=1)
   (sigmoid): Sigmoid()
 ),
 music_generator(
   (lstm): LSTM(128, 128)
   (attention): Linear(in_features=2, out_features=1, bias=True)
   (softmax): Sparsemax(dim=1)
   (sigmoid): Sigmoid()
 ),
 music_generator(
   (lstm): LSTM(128, 128)
   (attention): Linear(in_features=2, out_features=1, bias=True)
   (softmax): Sparsemax(dim=1)
   (sigmoid): Sigmoid()
 ),
 music_generator(
   (lstm): LSTM(128, 128)
   (attention): Linear(in_features=2, out_features=1, bias=True)
   (softmax): Sparsemax(dim=1)
   (sigmoid): Sigmoid()
 ),
 music_generator(
   (lstm): LSTM(128, 128)
   (attention): Linear(in_features=2, out_features=1, bias=True)
   (softmax): Sparsemax(dim=1)
   (sigmoid): Sigmoi

In [132]:
def generate_n_pieces_revised(trained_model, initial_vectors, n_pieces, length, batched_ssm):
    """Generate up to ``length`` more beats for a batch of pieces.

    Args:
        trained_model: generator exposing ``eval``, ``set_random_hidden`` and
            ``forward(input, batch_size, prev_sequence, batched_ssm)``.
        initial_vectors: seed beats as ``[batch_size, num_notes, 128]``.
        n_pieces: batch size used for the random initial hidden state.
        length: requested number of beats to generate after the seed.
        batched_ssm: per-piece SSMs stacked vertically,
            ``[batch_size * beats, beats]``.

    Returns:
        The seed plus generated beats as ``[num_notes + generated, batch_size, 128]``.
        Generation stops early once the sequence is as long as one SSM.
    """
    print(f"\tgenerating {n_pieces} piece(s) of length {length}")

    # freeze generator so it doesn't train anymore
    trained_model.eval()
    # start generator on random hidden states and cell states
    trained_model.set_random_hidden(n_pieces)

    # change sequence to [num_notes, batch_size, 128]
    sequence = initial_vectors.transpose(0, 1)

    # Each new beat reads its own SSM row, so a piece can hold at most as many
    # beats as one SSM has columns. shape[0] would be batch_size * beats, which
    # never limits anything and would let the model read another piece's rows.
    beats_available = batched_ssm.shape[1] - sequence.shape[0]

    # generate the beats one at a time
    for _ in range(0, min(length, beats_available)):
        with torch.no_grad():
            # the whole sequence so far is fed back in on every step
            output, _hidden = trained_model.forward(sequence.float(), n_pieces, sequence, batched_ssm)
            next_element = topk_batch_sample(output, 5)
        # add element to sequence
        sequence = torch.vstack((sequence, next_element.to("cpu")))

    # return sequence of beats
    return sequence

In [133]:
def generate_val_pieces(trained_model, val_batch, starter_notes=10):
    """Regenerate a validation batch from its opening beats and score it.

    Args:
        trained_model: generator used to produce the continuation.
        val_batch: ground-truth pieces as ``[batch, beats, 128]``.
        starter_notes: number of leading ground-truth beats used as the seed.

    Returns:
        ``(generated, loss)``: the generated sequences and the loss as a
        Python float, computed on the beats after the seed.
    """
    n_pieces, n_beats = val_batch.shape[0], val_batch.shape[1]

    # seed: the first starter_notes beats of every piece
    seed_beats = val_batch[:, 0:starter_notes, :]

    # beats-first view, reused for the SSMs and as the loss target
    beats_first = val_batch.transpose(0, 1)
    batched_ssms = batch_SSM(beats_first, n_pieces)

    # generate the remainder of each piece
    generated = generate_n_pieces_revised(
        trained_model, seed_beats, n_pieces, n_beats - starter_notes, batched_ssms
    )

    # score only the generated part, not the seed
    loss = custom_loss(generated[starter_notes:, :, :], beats_first[starter_notes:, :, :]).detach()

    return generated, float(loss)

In [134]:
def get_loss_on_data(trained_model, data, val_size=450, variable_batches=False):
    """Sum the generation loss of ``trained_model`` over all batches of ``data``.

    Args:
        trained_model: generator to evaluate.
        data: list of piece rows, as accepted by the batching helpers.
        val_size: piece length used for fixed-size batches (shorter pieces are
            dropped, longer ones cut).
        variable_batches: True to batch by equal piece length instead, keeping
            groups of every size.

    Returns:
        The per-batch losses added up.
    """
    if variable_batches:
        batches = make_variable_size_batches(data, min_batch_size=0)
    else:
        # fixed batches of 50 pieces, for speed
        batches = make_batches(data, 50, piece_size=val_size, slice_data=True)

    # accumulate loss for each batch
    cum_loss = 0
    for val_batch in batches:
        _generated, loss = generate_val_pieces(trained_model, val_batch)
        print("\tloss:", loss)
        cum_loss += loss

    return cum_loss

In [135]:
# for each model, get validation loss
def val_loss_for_all_models(models, val_data, val_size=450, variable_batches=False):
    """Compute the validation loss of every model in ``models``.

    Args:
        models: list of generators; the list index is reported as the epoch.
        val_data: validation pieces handed to ``get_loss_on_data``.
        val_size: piece length for fixed-size batching.
        variable_batches: True to batch by equal piece length.

    Returns:
        List with one summed loss per model, in the order of ``models``.
    """
    loss_list = []
    for epoch in tqdm(range(len(models))):
        print(f"calculating loss for epoch {epoch}:")
        epoch_loss = get_loss_on_data(models[epoch], val_data, val_size, variable_batches)
        loss_list.append(epoch_loss)
        # running list, so partial results survive an interrupted run
        print(loss_list)

    return loss_list

In [131]:
# load validation data
# val_data = torch.load(my_drive_path + "usable_data/validation_tempo_all_w_beats_30.csv") 
val_data = torch.load(my_drive_path + "usable_data/validation_tempo_round_down_30.csv")
torch.manual_seed(2022)

<torch._C.Generator at 0x7fbab0985c70>

In [136]:
# Validation loss per epoch for the attention model, on fixed batches of
# pieces cut to 450 beats.
loss_list_att = val_loss_for_all_models(models_att, val_data, val_size=450, variable_batches=False)
print(loss_list_att)

# first run: [58.79781428050684, 60.99458131763841, 61.036112768777, 61.05596775895589, 60.97728356595162, 60.76163521445571, 65.72370909473803, 65.29740455011577, 53.31530709361719, 43.10948994384802, 42.501780608520875, 41.81713659329537, 34.23370181352132, 37.12999854703072, 34.86240460698944, 35.05370937217055, 34.59301298739383, 30.543632000399867, 29.509237587058692, 27.812329880537916]

  0%|          | 0/20 [00:00<?, ?it/s]

calculating loss for epoch 0:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 29.409865432434263
	generating 50 piece(s) of length 440


  5%|▌         | 1/20 [01:29<28:24, 89.71s/it]

	loss: 29.387948848072583
[58.79781428050684]
calculating loss for epoch 1:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 31.032432347772954
	generating 50 piece(s) of length 440


 10%|█         | 2/20 [03:01<27:11, 90.65s/it]

	loss: 29.96214896986546
[58.79781428050684, 60.99458131763841]
calculating loss for epoch 2:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 29.349624417086698
	generating 50 piece(s) of length 440


 15%|█▌        | 3/20 [04:45<27:31, 97.14s/it]

	loss: 31.6864883516903
[58.79781428050684, 60.99458131763841, 61.036112768777]
calculating loss for epoch 3:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 30.8112998751278
	generating 50 piece(s) of length 440


 20%|██        | 4/20 [06:55<29:22, 110.14s/it]

	loss: 30.244667883828086
[58.79781428050684, 60.99458131763841, 61.036112768777, 61.05596775895589]
calculating loss for epoch 4:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 31.800373418329595
	generating 50 piece(s) of length 440


 25%|██▌       | 5/20 [08:34<26:31, 106.11s/it]

	loss: 29.17691014762203
[58.79781428050684, 60.99458131763841, 61.036112768777, 61.05596775895589, 60.97728356595162]
calculating loss for epoch 5:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 30.10678407673247
	generating 50 piece(s) of length 440


 30%|███       | 6/20 [10:04<23:28, 100.58s/it]

	loss: 30.654851137723238
[58.79781428050684, 60.99458131763841, 61.036112768777, 61.05596775895589, 60.97728356595162, 60.76163521445571]
calculating loss for epoch 6:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 32.06947941281424
	generating 50 piece(s) of length 440


 35%|███▌      | 7/20 [11:35<21:07, 97.47s/it] 

	loss: 33.65422968192379
[58.79781428050684, 60.99458131763841, 61.036112768777, 61.05596775895589, 60.97728356595162, 60.76163521445571, 65.72370909473803]
calculating loss for epoch 7:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 31.85992269034664
	generating 50 piece(s) of length 440


 40%|████      | 8/20 [13:08<19:12, 96.02s/it]

	loss: 33.437481859769136
[58.79781428050684, 60.99458131763841, 61.036112768777, 61.05596775895589, 60.97728356595162, 60.76163521445571, 65.72370909473803, 65.29740455011577]
calculating loss for epoch 8:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 26.159607185752357
	generating 50 piece(s) of length 440


 45%|████▌     | 9/20 [14:40<17:23, 94.84s/it]

	loss: 27.15569990786484
[58.79781428050684, 60.99458131763841, 61.036112768777, 61.05596775895589, 60.97728356595162, 60.76163521445571, 65.72370909473803, 65.29740455011577, 53.31530709361719]
calculating loss for epoch 9:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 22.090793969976435
	generating 50 piece(s) of length 440


 50%|█████     | 10/20 [16:13<15:40, 94.07s/it]

	loss: 21.018695973871587
[58.79781428050684, 60.99458131763841, 61.036112768777, 61.05596775895589, 60.97728356595162, 60.76163521445571, 65.72370909473803, 65.29740455011577, 53.31530709361719, 43.10948994384802]
calculating loss for epoch 10:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 20.459994113963916
	generating 50 piece(s) of length 440


 55%|█████▌    | 11/20 [17:46<14:03, 93.68s/it]

	loss: 22.041786494556955
[58.79781428050684, 60.99458131763841, 61.036112768777, 61.05596775895589, 60.97728356595162, 60.76163521445571, 65.72370909473803, 65.29740455011577, 53.31530709361719, 43.10948994384802, 42.501780608520875]
calculating loss for epoch 11:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 20.918785212818676
	generating 50 piece(s) of length 440


 60%|██████    | 12/20 [19:19<12:27, 93.44s/it]

	loss: 20.8983513804767
[58.79781428050684, 60.99458131763841, 61.036112768777, 61.05596775895589, 60.97728356595162, 60.76163521445571, 65.72370909473803, 65.29740455011577, 53.31530709361719, 43.10948994384802, 42.501780608520875, 41.81713659329537]
calculating loss for epoch 12:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 17.051150351479453
	generating 50 piece(s) of length 440


 65%|██████▌   | 13/20 [20:51<10:51, 93.14s/it]

	loss: 17.182551462041864
[58.79781428050684, 60.99458131763841, 61.036112768777, 61.05596775895589, 60.97728356595162, 60.76163521445571, 65.72370909473803, 65.29740455011577, 53.31530709361719, 43.10948994384802, 42.501780608520875, 41.81713659329537, 34.23370181352132]
calculating loss for epoch 13:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 17.93105380079808
	generating 50 piece(s) of length 440


 70%|███████   | 14/20 [22:23<09:16, 92.79s/it]

	loss: 19.198944746232648
[58.79781428050684, 60.99458131763841, 61.036112768777, 61.05596775895589, 60.97728356595162, 60.76163521445571, 65.72370909473803, 65.29740455011577, 53.31530709361719, 43.10948994384802, 42.501780608520875, 41.81713659329537, 34.23370181352132, 37.12999854703072]
calculating loss for epoch 14:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 17.76330489388224
	generating 50 piece(s) of length 440


 75%|███████▌  | 15/20 [23:57<07:45, 93.17s/it]

	loss: 17.099099713107204
[58.79781428050684, 60.99458131763841, 61.036112768777, 61.05596775895589, 60.97728356595162, 60.76163521445571, 65.72370909473803, 65.29740455011577, 53.31530709361719, 43.10948994384802, 42.501780608520875, 41.81713659329537, 34.23370181352132, 37.12999854703072, 34.86240460698944]
calculating loss for epoch 15:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 17.98697329403982
	generating 50 piece(s) of length 440


 80%|████████  | 16/20 [25:28<06:10, 92.53s/it]

	loss: 17.06673607813073
[58.79781428050684, 60.99458131763841, 61.036112768777, 61.05596775895589, 60.97728356595162, 60.76163521445571, 65.72370909473803, 65.29740455011577, 53.31530709361719, 43.10948994384802, 42.501780608520875, 41.81713659329537, 34.23370181352132, 37.12999854703072, 34.86240460698944, 35.05370937217055]
calculating loss for epoch 16:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 17.42579578022368
	generating 50 piece(s) of length 440


 85%|████████▌ | 17/20 [27:00<04:36, 92.27s/it]

	loss: 17.167217207170147
[58.79781428050684, 60.99458131763841, 61.036112768777, 61.05596775895589, 60.97728356595162, 60.76163521445571, 65.72370909473803, 65.29740455011577, 53.31530709361719, 43.10948994384802, 42.501780608520875, 41.81713659329537, 34.23370181352132, 37.12999854703072, 34.86240460698944, 35.05370937217055, 34.59301298739383]
calculating loss for epoch 17:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 15.182947635692432
	generating 50 piece(s) of length 440


 90%|█████████ | 18/20 [28:32<03:04, 92.20s/it]

	loss: 15.360684364707435
[58.79781428050684, 60.99458131763841, 61.036112768777, 61.05596775895589, 60.97728356595162, 60.76163521445571, 65.72370909473803, 65.29740455011577, 53.31530709361719, 43.10948994384802, 42.501780608520875, 41.81713659329537, 34.23370181352132, 37.12999854703072, 34.86240460698944, 35.05370937217055, 34.59301298739383, 30.543632000399867]
calculating loss for epoch 18:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 14.659039483198868
	generating 50 piece(s) of length 440


 95%|█████████▌| 19/20 [30:05<01:32, 92.42s/it]

	loss: 14.850198103859825
[58.79781428050684, 60.99458131763841, 61.036112768777, 61.05596775895589, 60.97728356595162, 60.76163521445571, 65.72370909473803, 65.29740455011577, 53.31530709361719, 43.10948994384802, 42.501780608520875, 41.81713659329537, 34.23370181352132, 37.12999854703072, 34.86240460698944, 35.05370937217055, 34.59301298739383, 30.543632000399867, 29.509237587058692]
calculating loss for epoch 19:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 14.048383687494633
	generating 50 piece(s) of length 440


100%|██████████| 20/20 [31:37<00:00, 94.86s/it]

	loss: 13.763946193043283
[58.79781428050684, 60.99458131763841, 61.036112768777, 61.05596775895589, 60.97728356595162, 60.76163521445571, 65.72370909473803, 65.29740455011577, 53.31530709361719, 43.10948994384802, 42.501780608520875, 41.81713659329537, 34.23370181352132, 37.12999854703072, 34.86240460698944, 35.05370937217055, 34.59301298739383, 30.543632000399867, 29.509237587058692, 27.812329880537916]
[58.79781428050684, 60.99458131763841, 61.036112768777, 61.05596775895589, 60.97728356595162, 60.76163521445571, 65.72370909473803, 65.29740455011577, 53.31530709361719, 43.10948994384802, 42.501780608520875, 41.81713659329537, 34.23370181352132, 37.12999854703072, 34.86240460698944, 35.05370937217055, 34.59301298739383, 30.543632000399867, 29.509237587058692, 27.812329880537916]





In [137]:
# Evaluate every saved LSTM checkpoint in `models_lstm` on the validation
# data and collect one loss value per checkpoint.
# NOTE(review): in the saved output of this cell, each recorded value equals
# the sum of the two per-batch "loss:" lines printed for that epoch
# (e.g. 13.7865 + 13.5177 ~= 27.3043 for epoch 0) -- confirm inside
# val_loss_for_all_models whether a sum (rather than a mean) is intended.
loss_list_lstm = val_loss_for_all_models(
    models_lstm,
    val_data,
    val_size=450,            # the run log reports pieces sliced to length 450
    variable_batches=False,
)
print(loss_list_lstm)

# Reference values recorded from the first run of this cell (one entry per epoch 0-19):
# first run: [27.304298203291822, 26.805993336154263, 26.629937693592783, 25.75794317912745, 25.926942683910386, 26.299691983393252, 26.071323804765548, 23.759693231146052, 22.978566764915136, 23.125428292791558, 24.846454579523623, 24.912110084617286, 24.860141947743173, 24.374138179515857, 24.19659813282396, 24.10343893510335, 24.259611301158923, 23.4498316783007, 23.99073953810295, 23.319251953208592]

  0%|          | 0/20 [00:00<?, ?it/s]

calculating loss for epoch 0:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 13.78654850461111
	generating 50 piece(s) of length 440


  5%|▌         | 1/20 [01:33<29:40, 93.69s/it]

	loss: 13.517749698680714
[27.304298203291822]
calculating loss for epoch 1:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 13.480033661970841
	generating 50 piece(s) of length 440


 10%|█         | 2/20 [03:07<28:09, 93.88s/it]

	loss: 13.325959674183421
[27.304298203291822, 26.805993336154263]
calculating loss for epoch 2:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 13.307864302590293
	generating 50 piece(s) of length 440


 15%|█▌        | 3/20 [04:43<26:53, 94.93s/it]

	loss: 13.32207339100249
[27.304298203291822, 26.805993336154263, 26.629937693592783]
calculating loss for epoch 3:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 12.929111645480251
	generating 50 piece(s) of length 440


 20%|██        | 4/20 [06:14<24:48, 93.04s/it]

	loss: 12.828831533647199
[27.304298203291822, 26.805993336154263, 26.629937693592783, 25.75794317912745]
calculating loss for epoch 4:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 12.829754607849303
	generating 50 piece(s) of length 440


 25%|██▌       | 5/20 [13:14<52:45, 211.05s/it]

	loss: 13.097188076061084
[27.304298203291822, 26.805993336154263, 26.629937693592783, 25.75794317912745, 25.926942683910386]
calculating loss for epoch 5:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 13.132048445656698
	generating 50 piece(s) of length 440


 30%|███       | 6/20 [14:54<40:25, 173.26s/it]

	loss: 13.167643537736554
[27.304298203291822, 26.805993336154263, 26.629937693592783, 25.75794317912745, 25.926942683910386, 26.299691983393252]
calculating loss for epoch 6:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 13.090982450006841
	generating 50 piece(s) of length 440


 35%|███▌      | 7/20 [16:31<32:10, 148.47s/it]

	loss: 12.980341354758705
[27.304298203291822, 26.805993336154263, 26.629937693592783, 25.75794317912745, 25.926942683910386, 26.299691983393252, 26.071323804765548]
calculating loss for epoch 7:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 12.112324711581326
	generating 50 piece(s) of length 440


 40%|████      | 8/20 [18:06<26:15, 131.29s/it]

	loss: 11.647368519564724
[27.304298203291822, 26.805993336154263, 26.629937693592783, 25.75794317912745, 25.926942683910386, 26.299691983393252, 26.071323804765548, 23.759693231146052]
calculating loss for epoch 8:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 10.851440169983093
	generating 50 piece(s) of length 440


 45%|████▌     | 9/20 [19:40<21:56, 119.67s/it]

	loss: 12.127126594932045
[27.304298203291822, 26.805993336154263, 26.629937693592783, 25.75794317912745, 25.926942683910386, 26.299691983393252, 26.071323804765548, 23.759693231146052, 22.978566764915136]
calculating loss for epoch 9:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 11.61172355551998
	generating 50 piece(s) of length 440


 50%|█████     | 10/20 [21:17<18:48, 112.83s/it]

	loss: 11.513704737271578
[27.304298203291822, 26.805993336154263, 26.629937693592783, 25.75794317912745, 25.926942683910386, 26.299691983393252, 26.071323804765548, 23.759693231146052, 22.978566764915136, 23.125428292791558]
calculating loss for epoch 10:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 12.44284809359482
	generating 50 piece(s) of length 440


 55%|█████▌    | 11/20 [22:56<16:17, 108.59s/it]

	loss: 12.403606485928805
[27.304298203291822, 26.805993336154263, 26.629937693592783, 25.75794317912745, 25.926942683910386, 26.299691983393252, 26.071323804765548, 23.759693231146052, 22.978566764915136, 23.125428292791558, 24.846454579523623]
calculating loss for epoch 11:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 12.415569334245344
	generating 50 piece(s) of length 440


 60%|██████    | 12/20 [24:34<14:03, 105.44s/it]

	loss: 12.496540750371942
[27.304298203291822, 26.805993336154263, 26.629937693592783, 25.75794317912745, 25.926942683910386, 26.299691983393252, 26.071323804765548, 23.759693231146052, 22.978566764915136, 23.125428292791558, 24.846454579523623, 24.912110084617286]
calculating loss for epoch 12:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 12.647692336037558
	generating 50 piece(s) of length 440


 65%|██████▌   | 13/20 [26:14<12:06, 103.75s/it]

	loss: 12.212449611705615
[27.304298203291822, 26.805993336154263, 26.629937693592783, 25.75794317912745, 25.926942683910386, 26.299691983393252, 26.071323804765548, 23.759693231146052, 22.978566764915136, 23.125428292791558, 24.846454579523623, 24.912110084617286, 24.860141947743173]
calculating loss for epoch 13:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 12.38969069363699
	generating 50 piece(s) of length 440


 70%|███████   | 14/20 [28:02<10:30, 105.02s/it]

	loss: 11.984447485878867
[27.304298203291822, 26.805993336154263, 26.629937693592783, 25.75794317912745, 25.926942683910386, 26.299691983393252, 26.071323804765548, 23.759693231146052, 22.978566764915136, 23.125428292791558, 24.846454579523623, 24.912110084617286, 24.860141947743173, 24.374138179515857]
calculating loss for epoch 14:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 11.981786048584166
	generating 50 piece(s) of length 440


 75%|███████▌  | 15/20 [31:01<10:36, 127.26s/it]

	loss: 12.214812084239796
[27.304298203291822, 26.805993336154263, 26.629937693592783, 25.75794317912745, 25.926942683910386, 26.299691983393252, 26.071323804765548, 23.759693231146052, 22.978566764915136, 23.125428292791558, 24.846454579523623, 24.912110084617286, 24.860141947743173, 24.374138179515857, 24.19659813282396]
calculating loss for epoch 15:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 12.076395664256886
	generating 50 piece(s) of length 440


 80%|████████  | 16/20 [33:56<09:26, 141.52s/it]

	loss: 12.027043270846463
[27.304298203291822, 26.805993336154263, 26.629937693592783, 25.75794317912745, 25.926942683910386, 26.299691983393252, 26.071323804765548, 23.759693231146052, 22.978566764915136, 23.125428292791558, 24.846454579523623, 24.912110084617286, 24.860141947743173, 24.374138179515857, 24.19659813282396, 24.10343893510335]
calculating loss for epoch 16:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 12.21629368976958
	generating 50 piece(s) of length 440


 85%|████████▌ | 17/20 [37:02<07:44, 154.90s/it]

	loss: 12.043317611389343
[27.304298203291822, 26.805993336154263, 26.629937693592783, 25.75794317912745, 25.926942683910386, 26.299691983393252, 26.071323804765548, 23.759693231146052, 22.978566764915136, 23.125428292791558, 24.846454579523623, 24.912110084617286, 24.860141947743173, 24.374138179515857, 24.19659813282396, 24.10343893510335, 24.259611301158923]
calculating loss for epoch 17:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 11.812045615151327
	generating 50 piece(s) of length 440


 90%|█████████ | 18/20 [40:51<05:54, 177.22s/it]

	loss: 11.637786063149374
[27.304298203291822, 26.805993336154263, 26.629937693592783, 25.75794317912745, 25.926942683910386, 26.299691983393252, 26.071323804765548, 23.759693231146052, 22.978566764915136, 23.125428292791558, 24.846454579523623, 24.912110084617286, 24.860141947743173, 24.374138179515857, 24.19659813282396, 24.10343893510335, 24.259611301158923, 23.4498316783007]
calculating loss for epoch 18:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 12.257182109874561
	generating 50 piece(s) of length 440


 95%|█████████▌| 19/20 [42:31<02:34, 154.09s/it]

	loss: 11.733557428228387
[27.304298203291822, 26.805993336154263, 26.629937693592783, 25.75794317912745, 25.926942683910386, 26.299691983393252, 26.071323804765548, 23.759693231146052, 22.978566764915136, 23.125428292791558, 24.846454579523623, 24.912110084617286, 24.860141947743173, 24.374138179515857, 24.19659813282396, 24.10343893510335, 24.259611301158923, 23.4498316783007, 23.99073953810295]
calculating loss for epoch 19:
sliced pieces to length 450; total pieces is now 127
	generating 50 piece(s) of length 440
	loss: 11.45040432986191
	generating 50 piece(s) of length 440


100%|██████████| 20/20 [44:00<00:00, 132.04s/it]

	loss: 11.868847623346685
[27.304298203291822, 26.805993336154263, 26.629937693592783, 25.75794317912745, 25.926942683910386, 26.299691983393252, 26.071323804765548, 23.759693231146052, 22.978566764915136, 23.125428292791558, 24.846454579523623, 24.912110084617286, 24.860141947743173, 24.374138179515857, 24.19659813282396, 24.10343893510335, 24.259611301158923, 23.4498316783007, 23.99073953810295, 23.319251953208592]
[27.304298203291822, 26.805993336154263, 26.629937693592783, 25.75794317912745, 25.926942683910386, 26.299691983393252, 26.071323804765548, 23.759693231146052, 22.978566764915136, 23.125428292791558, 24.846454579523623, 24.912110084617286, 24.860141947743173, 24.374138179515857, 24.19659813282396, 24.10343893510335, 24.259611301158923, 23.4498316783007, 23.99073953810295, 23.319251953208592]



