# Peephole LSTM

Given an implementation of an LSTM module:
\begin{align}
i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
g_t = tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
c_t = f_t \odot c_{t-1} + i_t \odot g_t \\
h_t = o_t \odot tanh(c_t)
\end{align}


Your task is to modify the implementation to add [peephole connections](https://en.wikipedia.org/wiki/Long_short-term_memory#Peephole_LSTM) according to:

\begin{align}
i_t = \sigma(W_{ii} x_t + b_{ii} + W_{ci} c_{t-1} + b_{ci}) \\
f_t = \sigma(W_{if} x_t + b_{if} + W_{cf} c_{t-1} + b_{cf}) \\
o_t = \sigma(W_{io} x_t + b_{io} + W_{co} c_{t-1} + b_{co}) \\
c_t = f_t \odot c_{t-1} + i_t \odot tanh(W_{ic} x_t + b_{ic}) \\
h_t = o_t \odot c_t
\end{align}

In [None]:
import typing
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Seed the NumPy, PyTorch and standard-library RNGs with the same fixed value (0).
np.random.seed(0)
torch.manual_seed(0)
random.seed(0)

In [None]:
class LSTM(nn.Module):
    """Single-layer LSTM unrolled explicitly over the time dimension.

    For every time step ``t`` the cell computes::

        i_t = sigmoid(x_t W_ii + b_ii + h_{t-1} W_hi + b_hi)   # input gate
        f_t = sigmoid(x_t W_if + b_if + h_{t-1} W_hf + b_hf)   # forget gate
        g_t = tanh   (x_t W_ig + b_ig + h_{t-1} W_hg + b_hg)   # cell candidate
        o_t = sigmoid(x_t W_io + b_io + h_{t-1} W_ho + b_ho)   # output gate
        c_t = f_t * c_{t-1} + i_t * g_t
        h_t = o_t * tanh(c_t)

    Weights are stored as ``(in_features, hidden_size)`` so they are applied
    on the right (``x @ W``). All parameters are drawn from a standard normal
    distribution.

    Args:
        input_size: Number of features in each input step.
        hidden_size: Number of features in the hidden and cell states.
        batch_first: If ``True`` inputs and outputs are laid out as
            ``(batch, seq, feature)``; otherwise as ``(seq, batch, feature)``.
    """

    def __init__(self, input_size: int, hidden_size: int, batch_first: bool):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.batch_first = batch_first

        # NOTE: the registration order below determines the order in which
        # _init_parameters consumes random numbers, so it is kept stable.

        # input gate (i_t)
        self.W_ii = nn.Parameter(torch.empty(input_size, hidden_size))
        self.W_hi = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.b_ii = nn.Parameter(torch.empty(hidden_size))
        self.b_hi = nn.Parameter(torch.empty(hidden_size))

        # forget gate (f_t)
        self.W_if = nn.Parameter(torch.empty(input_size, hidden_size))
        self.W_hf = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.b_if = nn.Parameter(torch.empty(hidden_size))
        self.b_hf = nn.Parameter(torch.empty(hidden_size))

        # cell candidate (g_t), the new content written into c_t
        self.W_ig = nn.Parameter(torch.empty(input_size, hidden_size))
        self.W_hg = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.b_ig = nn.Parameter(torch.empty(hidden_size))
        self.b_hg = nn.Parameter(torch.empty(hidden_size))

        # output gate (o_t), controls how much of c_t is exposed as h_t
        self.W_io = nn.Parameter(torch.empty(input_size, hidden_size))
        self.W_ho = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.b_io = nn.Parameter(torch.empty(hidden_size))
        self.b_ho = nn.Parameter(torch.empty(hidden_size))

        self._init_parameters()

    def _init_parameters(self):
        """Initialise every parameter in place from N(0, 1)."""
        for param in self.parameters():
            torch.nn.init.normal_(param)

    def forward(self, x: torch.Tensor, hx: typing.Optional[typing.Tuple[torch.Tensor, torch.Tensor]] = None) -> typing.Tuple[torch.Tensor, typing.Tuple[torch.Tensor, torch.Tensor]]:
        """Run the LSTM over a whole sequence.

        Args:
            x: Input of shape ``(batch, seq, input_size)`` when
                ``batch_first`` is true, else ``(seq, batch, input_size)``.
            hx: Optional initial ``(h_0, c_0)``, each of shape
                ``(batch, hidden_size)``. Zeros are used when omitted.

        Returns:
            ``(output, (h_n, c_n))`` where ``output`` holds ``h_t`` for every
            step in the same layout as ``x`` and ``h_n`` / ``c_n`` are the
            states after the last step, each ``(batch, hidden_size)``.
        """
        # Work batch-major internally regardless of the caller's layout.
        if not self.batch_first:
            x = x.permute(1, 0, 2).contiguous()

        batch_size = x.size(0)
        sequence_length = x.size(1)

        if hx is None:
            h_t = torch.zeros(batch_size, self.hidden_size, dtype=x.dtype, device=x.device)
            c_t = torch.zeros_like(h_t)
        else:
            h_t, c_t = hx

        hidden_states = []

        for t in range(sequence_length):
            x_t = x[:, t, :]
            # input gate
            i_t = torch.sigmoid(x_t @ self.W_ii + self.b_ii + h_t @ self.W_hi + self.b_hi)
            # forget gate
            f_t = torch.sigmoid(x_t @ self.W_if + self.b_if + h_t @ self.W_hf + self.b_hf)
            # cell candidate
            g_t = torch.tanh(x_t @ self.W_ig + self.b_ig + h_t @ self.W_hg + self.b_hg)
            # output gate
            o_t = torch.sigmoid(x_t @ self.W_io + self.b_io + h_t @ self.W_ho + self.b_ho)

            # state update
            c_t = f_t * c_t + i_t * g_t
            h_t = o_t * torch.tanh(c_t)

            hidden_states.append(h_t)

        if hidden_states:
            # Stack along dim 1 so the result is (batch, seq, hidden), matching
            # the batch-major layout used inside this method.
            output = torch.stack(hidden_states, dim=1)
        else:
            # Zero-length sequence: nothing to stack, return an empty output.
            output = x.new_zeros(batch_size, 0, self.hidden_size)

        # Convert back to (seq, batch, hidden) only for sequence-major callers.
        if not self.batch_first:
            output = output.permute(1, 0, 2).contiguous()

        return output, (h_t, c_t)


In [None]:
# Smoke test: push one random batch through the reference LSTM and show
# the tensor sizes followed by the full (output, (h_n, c_n)) result.
torch.manual_seed(0)  # reseed so the sampled input and weights are reproducible
a = torch.randn((5, 10, 3))
lstm = LSTM(input_size=3, hidden_size=7, batch_first=True)
result = lstm(a)  # the forward pass is deterministic, so one call is enough
print(result[0].size(), result[1][0].size(), result[1][1].size())
print(result)

In [None]:
class LSTMPiphole(nn.Module):
    """Single-layer peephole LSTM unrolled explicitly over the time dimension.

    The gates look at the previous cell state ``c_{t-1}`` instead of the
    previous hidden state. For every time step ``t``::

        i_t = sigmoid(x_t W_ii + b_ii + c_{t-1} W_ci + b_ci)   # input gate
        f_t = sigmoid(x_t W_if + b_if + c_{t-1} W_cf + b_cf)   # forget gate
        o_t = sigmoid(x_t W_io + b_io + c_{t-1} W_co + b_co)   # output gate
        c_t = f_t * c_{t-1} + i_t * tanh(x_t W_ic + b_ic)
        h_t = o_t * c_t

    Weights are stored as ``(in_features, hidden_size)`` so they are applied
    on the right (``x @ W``). All parameters are drawn from a standard normal
    distribution.

    Args:
        input_size: Number of features in each input step.
        hidden_size: Number of features in the hidden and cell states.
        batch_first: If ``True`` inputs and outputs are laid out as
            ``(batch, seq, feature)``; otherwise as ``(seq, batch, feature)``.
    """

    def __init__(self, input_size: int, hidden_size: int, batch_first: bool):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.batch_first = batch_first

        # NOTE: the registration order below determines the order in which
        # _init_parameters consumes random numbers, so it is kept stable.

        # input gate (i_t)
        self.W_ii = nn.Parameter(torch.empty(input_size, hidden_size))
        self.W_ci = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.b_ii = nn.Parameter(torch.empty(hidden_size))
        self.b_ci = nn.Parameter(torch.empty(hidden_size))

        # forget gate (f_t)
        self.W_if = nn.Parameter(torch.empty(input_size, hidden_size))
        self.W_cf = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.b_if = nn.Parameter(torch.empty(hidden_size))
        self.b_cf = nn.Parameter(torch.empty(hidden_size))

        # cell candidate, computed from the input only (no recurrent term)
        self.W_ic = nn.Parameter(torch.empty(input_size, hidden_size))
        self.b_ic = nn.Parameter(torch.empty(hidden_size))

        # output gate (o_t), controls how much of c_t is exposed as h_t
        self.W_io = nn.Parameter(torch.empty(input_size, hidden_size))
        self.W_co = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.b_io = nn.Parameter(torch.empty(hidden_size))
        self.b_co = nn.Parameter(torch.empty(hidden_size))

        self._init_parameters()

    def _init_parameters(self):
        """Initialise every parameter in place from N(0, 1)."""
        for param in self.parameters():
            torch.nn.init.normal_(param)

    def forward(self, x: torch.Tensor, hx: typing.Optional[typing.Tuple[torch.Tensor, torch.Tensor]] = None) -> typing.Tuple[torch.Tensor, typing.Tuple[torch.Tensor, torch.Tensor]]:
        """Run the peephole LSTM over a whole sequence.

        Args:
            x: Input of shape ``(batch, seq, input_size)`` when
                ``batch_first`` is true, else ``(seq, batch, input_size)``.
            hx: Optional initial ``(h_0, c_0)``, each of shape
                ``(batch, hidden_size)``. Zeros are used when omitted. Only
                ``c_0`` feeds the recurrence; ``h_0`` is returned unchanged
                when the sequence is empty.

        Returns:
            ``(output, (h_n, c_n))`` where ``output`` holds ``h_t`` for every
            step in the same layout as ``x`` and ``h_n`` / ``c_n`` are the
            states after the last step, each ``(batch, hidden_size)``.
        """
        # Work batch-major internally regardless of the caller's layout.
        if not self.batch_first:
            x = x.permute(1, 0, 2).contiguous()

        batch_size = x.size(0)
        sequence_length = x.size(1)

        if hx is None:
            h_t = torch.zeros(batch_size, self.hidden_size, dtype=x.dtype, device=x.device)
            c_t = torch.zeros_like(h_t)
        else:
            h_t, c_t = hx

        hidden_states = []

        for t in range(sequence_length):
            x_t = x[:, t, :]
            # All three gates peek at the previous cell state; c_t still
            # holds c_{t-1} until it is reassigned below.
            i_t = torch.sigmoid(x_t @ self.W_ii + self.b_ii + c_t @ self.W_ci + self.b_ci)
            f_t = torch.sigmoid(x_t @ self.W_if + self.b_if + c_t @ self.W_cf + self.b_cf)
            o_t = torch.sigmoid(x_t @ self.W_io + self.b_io + c_t @ self.W_co + self.b_co)

            # state update
            c_t = f_t * c_t + i_t * torch.tanh(x_t @ self.W_ic + self.b_ic)
            h_t = o_t * c_t

            hidden_states.append(h_t)

        if hidden_states:
            # Stack along dim 1 so the result is (batch, seq, hidden), matching
            # the batch-major layout used inside this method.
            output = torch.stack(hidden_states, dim=1)
        else:
            # Zero-length sequence: nothing to stack, return an empty output.
            output = x.new_zeros(batch_size, 0, self.hidden_size)

        # Convert back to (seq, batch, hidden) only for sequence-major callers.
        if not self.batch_first:
            output = output.permute(1, 0, 2).contiguous()

        return output, (h_t, c_t)

In [None]:
# Smoke test: push one random batch through the peephole LSTM and show
# the tensor sizes followed by the full value returned by the module.
torch.manual_seed(0)  # reseed so the sampled input and weights are reproducible
a = torch.randn((5, 10, 3))
lstm = LSTMPiphole(input_size=3, hidden_size=7, batch_first=True)
result = lstm(a)  # evaluate once and reuse the result for both prints
print(result[0].size(), result[1][0].size(), result[1][1].size())
print(result)