In [3]:
import torch
import numpy as np
import torch.nn as nn
from click.termui import hidden_prompt_func
from openpyxl.styles.builtins import output

In [4]:
class RNN(nn.Module):
    """Single-step recurrent cell with an output projection.

    One call advances the recurrence by one time step::

        h_new  = tanh(i2h(x) + w2h(h_prev))
        output = h2o(h_new)

    Args:
        input_size: Number of features in each input vector.
        hidden_layer: Number of features in the hidden state.
        output_size: Number of features in the produced output.
    """

    def __init__(self,input_size,hidden_layer,output_size):
        super(RNN,self).__init__()
        self.i2h = nn.Linear(input_size,hidden_layer)    # input  -> hidden
        self.w2h = nn.Linear(hidden_layer,hidden_layer)  # hidden -> hidden
        self.h2o = nn.Linear(hidden_layer,output_size)   # hidden -> output
        self.hidden_size = hidden_layer

    def forward(self,x,hidden=None):
        """Advance the cell by one step.

        Args:
            x: Input tensor whose last dimension is ``input_size``.
            hidden: Previous hidden state, broadcastable against ``i2h(x)``.
                When omitted, the zero state from :meth:`init_hidden` is used.

        Returns:
            Tuple ``(output, hidden)`` holding the projected output and the
            new hidden state.
        """
        # Identity test: ``== None`` on a tensor goes through Tensor.__eq__
        # and only works by accident of its fallback behaviour.
        if hidden is None:
            hidden = self.init_hidden()
        hidden = torch.tanh(self.i2h(x) + self.w2h(hidden))
        output = self.h2o(hidden)
        return output,hidden

    def init_hidden(self):
        """Return a zero hidden state of shape ``(1, hidden_size)``.

        The state is created with the dtype and device of the module's own
        weights so that it stays usable after ``.double()``, ``.to(device)``
        and similar casts. The leading dimension of 1 broadcasts over the
        batch dimension inside :meth:`forward`.
        """
        weight = self.w2h.weight
        return torch.zeros(1,self.hidden_size,dtype=weight.dtype,device=weight.device)
    

In [5]:
# Smoke test for the hand-written RNN cell: a 128 x 10 tensor of zeros in,
# 10 input features -> 30 hidden features -> 20 output features.
a = torch.zeros(128,10)

model = RNN(10,30,20)
# No hidden state is passed, so the cell starts from its zero initial state.
# NOTE: `output` rebinds the name imported from openpyxl in the import cell.
output,hidden = model(a)
# Not echoed by the notebook: only the last expression of a cell is displayed.
output.shape

# Same tensor through torch's built-in nn.RNN. nn.RNN treats a 2-D input as one
# unbatched sequence (seq_len=128, input_size=10) and returns the hidden state
# for every step, so this is a shape check only, not a like-for-like comparison
# with the single-step cell above.
model = nn.RNN(input_size=10,hidden_size=30)
model(a)[0].shape

torch.Size([128, 30])

In [6]:
# Sequence-level model: unrolls the RNN cell over the time dimension
class RNNModel(nn.Module):
    """Unrolls a single :class:`RNN` cell over a whole input sequence.

    Args:
        input_size: Number of features in each input vector.
        hidden_layer: Number of features in the hidden state.
        output_size: Number of features in each per-step output.
    """

    def __init__(self,input_size,hidden_layer,output_size):
        super(RNNModel,self).__init__()
        self.rnncell = RNN(input_size,hidden_layer,output_size)

    def forward(self,x,hidden=None):
        """Run the cell over every time step of ``x``.

        Args:
            x: Input of shape ``(seq_len, batch, input_size)``.
            hidden: Optional initial hidden state handed to the cell on the
                first step. ``None`` lets the cell pick its own zero state.

        Returns:
            Tuple ``(outputs, hidden)`` where ``outputs`` stacks the per-step
            outputs along dim 0 and ``hidden`` is the state after the last
            step.
        """
        # The three-way unpack also enforces that x is 3-D.
        seq_len, _, _ = x.size()
        outputs = []
        for step in range(seq_len):
            # Feed the previous step's hidden state back in; without this the
            # cell restarts from zeros at every step and nothing is recurrent.
            output, hidden = self.rnncell(x[step], hidden)
            outputs.append(output)
        outputs = torch.stack(outputs,dim=0)
        return outputs,hidden
    

In [7]:
# Smoke test for RNNModel: input laid out as (seq_len=10, batch=128, input_size=5),
# with 20 hidden features and 15 output features per step.
b = torch.zeros(10,128,5)
model = RNNModel(5,20,15)
# Element 0 of the returned tuple is the stacked per-step outputs.
model(b)[0].shape

torch.Size([10, 128, 15])

In [21]:
class LSTM(nn.Module):
    """Single-step LSTM cell.

    Every gate is a linear layer applied to the concatenation ``[x, h_prev]``::

        f = sigmoid(W_f [x, h_prev])        # forget gate
        i = sigmoid(W_i [x, h_prev])        # input gate
        g = tanh(W_c [x, h_prev])           # candidate cell values
        o = sigmoid(W_o [x, h_prev])        # output gate
        c_new = f * c_prev + i * g
        h_new = o * tanh(c_new)

    Args:
        input_size: Number of features in each input vector.
        hidden_layer: Number of features in the hidden and cell states.
    """

    def __init__(self,input_size,hidden_layer):
        super(LSTM,self).__init__()
        self.f = nn.Linear(input_size+hidden_layer,hidden_layer)
        self.i = nn.Linear(input_size+hidden_layer,hidden_layer)
        self.c = nn.Linear(input_size+hidden_layer,hidden_layer)
        self.o = nn.Linear(input_size+hidden_layer,hidden_layer)
        self.hidden_size = hidden_layer

    def forward(self,x,hidden=None,cell=None):
        """Advance the cell by one step.

        Args:
            x: Input whose first dimension is the batch size and whose last
                dimension is ``input_size``.
            hidden: Previous hidden state; zeros when omitted.
            cell: Previous cell state; zeros when omitted.

        Returns:
            Tuple ``(hidden, cell)`` with the updated states.
        """
        # Default each state on its own: callers may supply only one of them,
        # and a supplied state must never be overwritten by the defaults.
        if hidden is None:
            hidden = self.init_hidden(x)
        if cell is None:
            cell = self.init_hidden(x)
        joint = torch.cat((x,hidden),dim=-1)
        forget_gate = torch.sigmoid(self.f(joint))
        input_gate = torch.sigmoid(self.i(joint))
        candidate = torch.tanh(self.c(joint))
        output_gate = torch.sigmoid(self.o(joint))
        cell = forget_gate*cell + input_gate*candidate
        hidden = output_gate*torch.tanh(cell)
        return hidden,cell

    def init_hidden(self,x):
        """Return a zero state of shape ``(x.shape[0], hidden_size)``.

        The state is created with the dtype and device of ``x`` so that it can
        be combined with the input without a dtype or device mismatch.
        """
        return torch.zeros(x.shape[0],self.hidden_size,dtype=x.dtype,device=x.device)
    
class LSTMModel(nn.Module):
    """Unrolls a single :class:`LSTM` cell over a whole input sequence.

    Args:
        input_size: Number of features in each input vector.
        hidden_layer: Number of features in the hidden and cell states.
    """

    def __init__(self,input_size,hidden_layer):
        super(LSTMModel,self).__init__()
        self.lstmcell = LSTM(input_size,hidden_layer)

    def forward(self,x,state=None):
        """Run the cell over every time step of ``x``.

        Args:
            x: Input of shape ``(seq_len, batch, input_size)``.
            state: Optional ``(hidden, cell)`` tuple used as the initial
                state. ``None`` lets the cell start from its zero state.

        Returns:
            Tuple ``(outputs, (hidden, cell))`` where ``outputs`` stacks the
            per-step hidden states along dim 0 and ``(hidden, cell)`` is the
            state after the last step.
        """
        # The three-way unpack also enforces that x is 3-D.
        seq_len, _, _ = x.size()
        hidden, cell = (None, None) if state is None else state
        outputs = []
        for step in range(seq_len):
            # Carry both states forward; without this every step restarts
            # from zeros and the model has no memory across time.
            hidden,cell = self.lstmcell(x[step],hidden,cell)
            outputs.append(hidden)
        outputs = torch.stack(outputs,dim=0)
        return outputs,(hidden,cell)

In [24]:
# Smoke test for LSTMModel: input laid out as (seq_len=10, batch=128, input_size=5),
# with 20 hidden features.
c = torch.zeros(10,128,5)
model = LSTMModel(5,20)
# [1] is the final (hidden, cell) tuple; [0] of that is the final hidden state.
model(c)[1][0].shape

torch.Size([128, 20])