# __Neural network models__

In [34]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torchvision import models
from collections import OrderedDict

## Recurrent Neural Network

In [35]:
class LSTM_regular(nn.Module):
    """Single-layer LSTM implemented gate by gate.

    The gate blocks inside ``W``, ``U`` and ``bias`` are laid out in the order
    (input, forget, cell, output), the same order ``nn.LSTM`` uses.

    Args:
        embedding_dims: Number of input features per time step.
        hidden_size: Number of hidden / cell-state features.
        drop: Dropout probability applied to the per-step outputs returned by
            ``forward`` (the recurrent state itself is not dropped).
    """

    def __init__(self, embedding_dims, hidden_size, drop = 0):
        super().__init__()
        torch.set_default_dtype(torch.float32)

        # Input-to-gates, hidden-to-gates and bias parameters for all four
        # gates, concatenated along the last dimension.
        self.W = nn.Parameter(torch.empty(embedding_dims, hidden_size * 4))
        self.U = nn.Parameter(torch.empty(hidden_size, hidden_size * 4))
        self.bias = nn.Parameter(torch.empty(4 * hidden_size))

        # Dropout applied to each step's output.
        self.dropout = nn.Dropout(drop)

        self.init_weights(hidden_size)

    def init_weights(self, hidden_size):
        """Fill every parameter from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size))."""
        stdv = 1.0 / math.sqrt(hidden_size)
        with torch.no_grad():
            for weight in self.parameters():
                weight.uniform_(-stdv, stdv)

    def forward(self, input, state=None):
        """Run the cell over every time step of ``input``.

        Args:
            input: Tensor of shape (batch, seq, embedding_dims).
            state: Optional ``(h, c)`` tuple, each of shape (1, batch, hidden_size).
                ``None`` starts from zero states.

        Returns:
            ``(output, (h, c))``: ``output`` has shape (batch, seq, hidden_size);
            ``h`` and ``c`` are the final states, each (1, batch, hidden_size).
        """
        batch_size, seq_length = input.size(0), input.size(1)
        hidden_size = self.U.size(0)

        if state is None:
            zeros = input.new_zeros(1, batch_size, hidden_size)
            state = (zeros, zeros)

        # torch.cat cannot concatenate an empty list.
        if seq_length == 0:
            return input.new_zeros(batch_size, 0, hidden_size), state

        outputs = []
        for i in range(seq_length):
            x, state = self.forward_cell(input[:, i, :], state)
            outputs.append(x)

        # Each step output is (1, batch, hidden). Concatenating gives
        # (seq, batch, hidden), which is then transposed to batch-first.
        output = torch.cat(outputs, dim=0).transpose(0, 1).contiguous()
        return output, state

    def forward_cell(self, x_t, cell_states):
        """Compute one time step.

        Args:
            x_t: Input for this step, shape (batch, embedding_dims).
            cell_states: ``(h, c)`` from the previous step, each (1, batch, hidden).

        Returns:
            ``(dropped_h, (h, c))`` with ``h``/``c`` of shape (1, batch, hidden).
        """
        h_t, c_t = cell_states

        # Drop the leading layer dimension of the stored states.
        h_t = h_t.squeeze(dim=0)
        c_t = c_t.squeeze(dim=0)

        # (batch, 4 * hidden). Deliberately not squeezed: with batch == 1 a
        # squeeze would remove the batch dim and break the chunk below.
        gates = x_t @ self.W + h_t @ self.U + self.bias

        # Split into the four gate pre-activations.
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, dim=-1)

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        # New cell state and hidden state.
        c_t = forgetgate * c_t + ingate * cellgate
        h_t = outgate * torch.tanh(c_t)

        c_t = c_t.unsqueeze(0)
        h_t = h_t.unsqueeze(0)

        return self.dropout(h_t), (h_t, c_t)

In [36]:
class RNN(nn.Module):
    def __init__(self, rnn_type, output_size, vocab_size, embed_dims, params: dict):
        super().__init__()
        torch.set_default_dtype(torch.float32)
        
        self.rnn_type = rnn_type
        assert self.rnn_type in ['test', 'lstm', 'lstm_M', 'lstm_A'], f"RNN type '{self.rnn_type} 'is NOT supported."

        # Config parameters
        try:
            drop = params['embedding_dropout']
            layer_count = params['lstm_layers']
            dir2 = params['bidirectional']
            pad_idx = params['padding_index']
            lstm_hidden = params['lstm_features']
            drop_lstm = params['lstm_dropout'] 
        except KeyError as e:
            raise Exception(f'Parameter "{e.args[0]}" NOT found!')
        
        # Encoder layer = encodes indices of words to embedding vectors
        self.encoder = nn.Embedding(
            num_embeddings = vocab_size
            , embedding_dim = embed_dims
            , padding_idx = pad_idx
            )

        # Dropout layer = drops embedding features
        self.dropout = nn.Dropout(drop)

        # Initlize LSTM layers
        self.rnns = nn.ModuleList()

        match self.rnn_type:
            # LSTM - pytorch
            case "test":
                l = nn.LSTM(
                    input_size = embed_dims,
                    hidden_size = lstm_hidden, 
                    num_layers=layer_count, 
                    bidirectional=dir2, 
                    dropout=drop_lstm, 
                    batch_first= True 
                )
                for p in l.parameters():
                    # TODO init weights
                    pass

                self.rnns.append(l)

            # LSTM - regular
            case "lstm":
                for i in range(lstm_hidden):
                    if (i == lstm_hidden - 1):
                        self.rnns.append(LSTM_regular(embed_dims, lstm_hidden))
                    else:
                        self.rnns.append(LSTM_regular(embed_dims, lstm_hidden, drop_lstm))
                    
                
            # LSTM - momentum
            case "lstm_M":
                pass
            
            # LSTM - momentum ADAM
            case "lstm_A":
                pass
        
        # Decoder layer = output layer for network
        self.decoder = nn.Linear(lstm_hidden, output_size)


    def forward(self, train: bool, input, lstm_state = None):
        # Embeddings
        input = self.encoder(input)

        # Dropout 
        if train:
            input = self.dropout(input)

        # Initialize first state
        if lstm_state is None:
            hx = 0
            cx = 0
            lstm_state = (hx, cx)

       # Compute LSTM forward for each layer
        for lstm in self.rnns:
            input, lstm_state = lstm.forward(input, lstm_state)

        # TODO check
        input = self.decoder(input.view(input.size(0)*input.size(1), input.size(2)))
        return input.view(input.size(0), input.size(1), input.size(1)), lstm_state

In [37]:
# Smoke-test cell: build an RNN of type "test" from the project configuration.
# NOTE(review): config_to_dict and config_NN presumably come from this star
# import; they are not defined in this notebook section — confirm in net_config.
from net_config import *

vocab_size = 100
output_size = 10

# Input tensor
batch_size = 3
seq_length = 10
embed_features = 5

# NOTE(review): RNN.forward passes its input through nn.Embedding (self.encoder),
# which needs integer word indices (e.g. torch.randint(0, vocab_size,
# (batch_size, seq_length))). This float tensor only illustrates the shape and
# would fail if fed to net.
test_input = torch.randn(batch_size, seq_length, embed_features)
print(test_input.shape)

# embed_features is used as the embedding dimension of the model.
net = RNN("test", output_size, vocab_size, embed_features, config_to_dict(config_NN))

torch.Size([3, 10, 5])


In [38]:
# Shape check: a (1 x 2) tensor times a (2 x 3) tensor gives a (1 x 3) tensor.
A = torch.randn(1, 2)
B = torch.randn(2, 3)

C = A @ B  # same as torch.matmul(A, B) for 2-D operands
for shown in (C.shape, C):
    print(shown)


torch.Size([1, 3])
tensor([[-0.1391, -1.2906,  1.2185]])
