# In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import time
import pprint
from torchsummary import summary


class PerformanceLSTM(nn.Module):
    """Bidirectional LSTM followed by a bias-free linear layer and a sigmoid.

    The LSTM is created with ``batch_first=True``, so inputs are laid out as
    ``(batch, seq_len, input_size)``. The linear layer maps the concatenated
    forward/backward LSTM features (``2 * hidden_size``) back to
    ``input_size`` and the sigmoid squashes every value into ``[0, 1]``.
    """

    def __init__(self,
                 input_size: int,
                 hidden_size: int,
                 num_layers: int,
                 device: str):
        """Build the network and move its trainable layers to ``device``.

        Args:
            input_size: Number of features per frame (also the output width).
            hidden_size: Hidden units per direction of the LSTM.
            num_layers: Number of stacked LSTM layers.
            device: Torch device string the layers and states are placed on.
        """
        super().__init__()

        self.device = device

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.model = nn.LSTM(input_size, hidden_size, num_layers,
                             batch_first=True, bidirectional=True).to(self.device)
        # Fully connected layer; the input is twice the hidden size because
        # the LSTM is bidirectional.
        self.fc = nn.Linear(hidden_size * 2, input_size, bias=False).to(self.device)

        # NOTE: ``relu`` and ``softplus`` are not used by ``forward``; they are
        # kept so code that accesses these attributes keeps working.
        self.relu = nn.ReLU()
        self.softplus = nn.Softplus(beta=500, threshold=0)
        self.sigmoid = nn.Sigmoid()

        print('= MODELO = ')
        print(f'\n{self.model}\n')
        print(f'\n{self.fc}\n')
        print('=' * 42)

    def forward(self, input_seq, hidden, cell):
        """Run one pass through the LSTM, the linear layer and the sigmoid.

        Args:
            input_seq: Float tensor of shape ``(batch, seq_len, input_size)``.
            hidden: Hidden state of shape
                ``(num_layers * 2, batch, hidden_size)``.
            cell: Cell state with the same shape as ``hidden``.

        Returns:
            ``(output_seq, (hidden, cell))`` where ``output_seq`` has shape
            ``(batch * seq_len, input_size)`` with values in ``[0, 1]`` and
            ``hidden``/``cell`` are the updated LSTM states.
        """
        input_seq = input_seq.to(self.device)

        # Passing in the input and hidden state into the model and obtaining outputs
        output_seq, (hidden, cell) = self.model(input_seq, (hidden, cell))

        # Flatten batch and time so every frame is one row for the linear layer.
        output_seq = output_seq.contiguous().view(-1, self.hidden_size * 2)

        # The absolute value of the LSTM features is what the linear layer sees.
        output_seq = torch.abs(output_seq)
        output_seq = self.fc(output_seq)
        output_seq = self.sigmoid(output_seq)

        return output_seq, (hidden, cell)

    def init_hidden(self, batch_size):
        """Return zero-filled ``(hidden, cell)`` states on ``self.device``.

        Both tensors have shape ``(num_layers * 2, batch_size, hidden_size)``;
        the factor of two accounts for the two LSTM directions.
        """
        shape = (self.num_layers * 2, batch_size, self.hidden_size)
        hidden = torch.zeros(shape, dtype=torch.float32, device=self.device)
        cell = torch.zeros(shape, dtype=torch.float32, device=self.device)

        return (hidden, cell)

    def plot_loss_update(self, i, n, mb, train_loss):
        """Redraw the loss curve on a progress-bar object.

        Expects epochs to start from 1.

        Args:
            i: Current epoch; the x axis covers ``1..i``.
            n: Total number of epochs, used for the x-axis bounds.
            mb: Object exposing a ``names`` attribute and an
                ``update_graph(graphs, x_bounds, y_bounds)`` method
                (presumably a fastprogress master bar - confirm with caller).
            train_loss: Sequence of loss values plotted against ``1..i``.
        """
        mb.names = ['Loss']
        x = range(1, i + 1)
        graphs = [[x, train_loss]]
        x_margin = 0.2
        y_margin = 0.05
        x_bounds = [1 - x_margin, n + x_margin]
        y_bounds = [np.min(train_loss) - y_margin, np.max(train_loss) + y_margin]

        mb.update_graph(graphs, x_bounds, y_bounds)

    @torch.no_grad()
    def generate(self, context, resolution, predict_amount=1, temperature=0.5, batch_size=1):
        """Prime the network with ``context`` and generate new beats.

        The context is fed through the LSTM in sliding windows of
        ``resolution`` frames to warm up the recurrent state. Starting from
        the last ``resolution`` context frames, each new beat is obtained by
        calling the network ``resolution`` times, binarizing the output and
        feeding it back between calls.

        Args:
            context: 2-D matrix of frames with shape
                ``(frame_amount, input_size)``, as a ``torch.Tensor`` or
                anything ``numpy.asarray`` accepts.
            resolution: Number of frames per beat (``>= 1``).
            predict_amount: Number of beats to generate.
            temperature: Values greater than or equal to ``1 - temperature``
                are turned on; everything else is turned off.
            batch_size: Batch dimension used for the recurrent state. The
                context is repeated along it.

        Returns:
            Boolean ``numpy`` array of shape
            ``(predict_amount, resolution, input_size)`` when ``batch_size``
            is 1, otherwise
            ``(predict_amount, batch_size, resolution, input_size)``.

        Raises:
            ValueError: If ``resolution`` is smaller than 1, ``context`` is
                not 2-D, or ``context`` has fewer than ``resolution`` frames.
        """
        if resolution < 1:
            raise ValueError('resolution must be at least 1')

        hidden, cell = self.init_hidden(batch_size)

        # Send context data to the model device and cast to float
        if isinstance(context, torch.Tensor):
            context = context.to(self.device).float()
        else:
            context = torch.from_numpy(
                np.asarray(context).astype(float)).float().to(self.device)

        if context.dim() != 2:
            raise ValueError('context must be a 2-D (frames, keys) matrix')
        if len(context) < resolution:
            raise ValueError('context must hold at least `resolution` frames')

        def as_batch(frames):
            # ``forward`` flattens batch and time into one axis; the LSTM
            # needs them back as (batch, seq_len, features).
            return frames.reshape(batch_size, resolution, -1)

        def window(start):
            # The LSTM states are batched (3-D), so the 2-D slice of the
            # context needs an explicit batch dimension as well.
            frames = context[start:start + resolution]
            return frames.unsqueeze(0).repeat(batch_size, 1, 1)

        # Number of sliding windows that only warm up the recurrent state
        context_beat_amount = len(context) - resolution

        # Feed context to model. The output is discarded on purpose, so the
        # bare LSTM is enough here (no linear layer / sigmoid needed).
        for i in range(context_beat_amount):
            _, (hidden, cell) = self.model(window(i), (hidden, cell))

        # The last context beat is the first input whose output we keep
        previous_beat = window(context_beat_amount)

        threshold = 1 - temperature

        output_beats = []
        for _ in range(predict_amount):
            # After this call only the tail of the beat is treated as newly
            # generated; the loop below repeats the call until every frame
            # of the beat has been regenerated.
            output_seq, (hidden, cell) = self(previous_beat, hidden, cell)

            for _ in range(resolution - 1):
                # The network outputs probabilities but takes 0/1 inputs,
                # so binarize before feeding the beat back in.
                fed_back = as_batch((output_seq >= threshold).float())
                output_seq, (hidden, cell) = self(fed_back, hidden, cell)

            beat = as_batch(output_seq >= threshold)

            # The finished (binarized) beat seeds the next one. This is done
            # once per beat so it also happens when resolution == 1.
            previous_beat = beat.float()

            beat = beat.cpu().numpy()
            output_beats.append(beat[0] if batch_size == 1 else beat)

        return np.array(output_beats, dtype=bool)