In [1]:
import numpy as np 
import pandas as pd 

import torch
from torch import nn
from torch.functional import F

In [2]:
device = 'cuda'

In [54]:
class LSTM(nn.Module):
    def __init__(self, hidden_size, input_size):
        super().__init__()

        self.hidden_size = hidden_size
        self.input_size = input_size

        # forget gate
        self.W_fh = nn.Parameter(torch.rand((hidden_size, hidden_size)), requires_grad=True)
        self.W_fx = nn.Parameter(torch.rand((hidden_size, input_size)), requires_grad=True)
        self.b_f = nn.Parameter(torch.rand((hidden_size, 1)), requires_grad=True)
        
        # update gate (aka input gate)
        self.W_uh = nn.Parameter(torch.rand((hidden_size, hidden_size)), requires_grad=True)
        self.W_ux = nn.Parameter(torch.rand((hidden_size, input_size)), requires_grad=True)
        self.b_u = nn.Parameter(torch.rand((hidden_size, 1)), requires_grad=True)

        # output gate
        self.W_oh = nn.Parameter(torch.rand((hidden_size, hidden_size)), requires_grad=True)
        self.W_ox = nn.Parameter(torch.rand((hidden_size, input_size)), requires_grad=True)
        self.b_o = nn.Parameter(torch.rand((hidden_size, 1)), requires_grad=True)

        # candidate for new c
        self.W_ch = nn.Parameter(torch.rand((hidden_size, hidden_size)), requires_grad=True)
        self.W_cx = nn.Parameter(torch.rand((hidden_size, input_size)), requires_grad=True)
        self.b_c = nn.Parameter(torch.rand((hidden_size, 1)), requires_grad=True)


    # x should be of shape (input_size, 1)
    def forward(self, c_prev, h_prev, x):
        if x.shape != (self.input_size, 1):
            raise ValueError('x.shape != (input_size, 1)')


        f = torch.matmul(self.W_fh, h_prev) + torch.matmul(self.W_fx, x) + self.b_f
        f = F.sigmoid(f)
        # print(f'f shape: {f.shape}')

        u = torch.matmul(self.W_uh, h_prev) + torch.matmul(self.W_ux, x) + self.b_u
        u = F.sigmoid(u)
        # print(f'u shape: {u.shape}')

        o = torch.matmul(self.W_oh, h_prev) + torch.matmul(self.W_ox, x) + self.b_o 
        o = F.sigmoid(o)
        # print(f'o shape: {o.shape}')

        c_cand = torch.matmul(self.W_ch, h_prev) + torch.matmul(self.W_cx, x) + self.b_c
        c_cand = F.tanh(c_cand)
        # print(f'c_cand shape: {c_cand.shape}')

        c = (u * c_cand) + (f * c_prev)
        # print(f'c shape: {c.shape}')

        h = o * F.tanh(c)

        return c, h
        

In [55]:
lstm_model = LSTM(hidden_size=3, input_size=5).to(device)

In [56]:
c_prev = torch.zeros((3, 1), dtype=torch.float32).to(device)
h_prev = torch.zeros((3, 1), dtype=torch.float32).to(device)

x = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.float32).T.to(device)

c, h = lstm_model(c_prev=c_prev, h_prev=h_prev, x=x)

In [57]:
c, h

(tensor([[0.9972],
         [1.0000],
         [0.9994]], device='cuda:0', grad_fn=<AddBackward0>),
 tensor([[0.7604],
         [0.7611],
         [0.7614]], device='cuda:0', grad_fn=<MulBackward0>))

In [48]:
X = [[[1]], [[2]], [[3]], [[4]], [[5]]]
y = [[[6]]]