### Learning the GRU layer

### Math formulas:
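- **update gate $z_t$:**  
$z_t = \sigma(W_z x_t + U_z h_{\text{prev}})$

- **reset gate $r_t$:**  
$r_t = \sigma(W_r x_t + U_r h_{\text{prev}})$

- **candidate hidden state $\tilde{h}_t$:**  
$\tilde{h}_t = \tanh\big(W_h x_t + U_h (r_t \odot h_{\text{prev}})\big)$

- **final hidden state $h_t$:**  
$h_t = (1 - z_t) \odot h_{\text{prev}} + z_t \odot \tilde{h}_t$

Here $\sigma$ is the logistic sigmoid and $\odot$ is element-wise multiplication (`*` in the code). Biases are omitted. With batched row vectors, $W x_t$ is written `x_t @ W.T` in the code.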
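The next example runs the four formulas once with d_x = d_h = 1, so every matrix becomes a single number. The weights and inputs are arbitrary values picked for illustration. They do not come from a trained model.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

# Arbitrary scalar weights (d_x = d_h = 1), chosen only for illustration
W_z, U_z = 0.5, -1.0
W_r, U_r = 1.0, 0.5
W_h, U_h = 2.0, 1.5

x_t, h_prev = 1.0, 0.5

z_t = sigmoid(W_z * x_t + U_z * h_prev)               # update gate; its pre-activation is exactly 0 here
r_t = sigmoid(W_r * x_t + U_r * h_prev)               # reset gate
h_cand = math.tanh(W_h * x_t + U_h * (r_t * h_prev))  # candidate hidden state
h_t = (1 - z_t) * h_prev + z_t * h_cand               # blend of the old state and the candidate

print(f"z_t={z_t:.4f}  r_t={r_t:.4f}  h_cand={h_cand:.4f}  h_t={h_t:.4f}")
```

z_t always lies between 0 and 1. That makes h_t a weighted average of h_prev and the candidate, so h_t always falls between the two. In this example z_t is exactly 0.5, so h_t is their midpoint.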

Begin by understanding what kinds of inputs go into a GRU layer:  

The input x_t is a vector of n_features values (written d_x below). It represents a single *time* step t.

Stacking the vectors for all time steps gives a matrix of shape (T, n_features). T is the length of the whole sequence. Lowercase t indexes one particular time step, and uppercase T is the total number of steps.

Sequences are normally processed in batches of size B, which gives an input of shape (B, T, n_features). This tensor is what goes into the GRU layer.
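The shapes above can also be built step by step. This is a minimal sketch with random data, where d_x stands for n_features:

```python
import torch

B, T, d_x = 2, 5, 4   # batch size, sequence length, features per step (n_features)

x_single = torch.randn(d_x)                                    # one time step: (d_x,)
sequence = torch.stack([torch.randn(d_x) for _ in range(T)])   # one sequence: (T, d_x)
X = torch.stack([torch.randn(T, d_x) for _ in range(B)])       # a batch of sequences: (B, T, d_x)

# Time step t taken from every sequence in the batch: (B, d_x)
t = 0
x_t = X[:, t, :]

print(x_single.shape, sequence.shape, X.shape, x_t.shape)
```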


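The GRU cell does not process the time steps independently. At step t, it takes x_t for the whole batch, with shape (B, n_features), together with the previous hidden state h_prev. It produces h_t, which is then fed into step t + 1. A GRU layer therefore does not contain T separate cells. It applies one cell, with a single shared set of weights, T times in sequence.

Because the loop runs over time, it is natural to put the time axis first: (T, B, n_features). This is the default input layout of PyTorch's `nn.GRU`. Passing `batch_first=True` switches the layout to (B, T, n_features), which is why B often appears first in PyTorch code.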

Also, going forward, n_features is usually denoted as d_x to represent the *dimension of the input vector*.
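This sketch uses randomly initialized `torch.nn.GRU` layers to compare the two layouts. The only difference between the layers is the `batch_first` flag:

```python
import torch

B, T, d_x, d_h = 2, 5, 4, 3
X = torch.randn(B, T, d_x)   # data stored batch-first

# Default layout: the input is (T, B, d_x)
gru_seq_first = torch.nn.GRU(input_size=d_x, hidden_size=d_h)
out_sf, h_n_sf = gru_seq_first(X.transpose(0, 1))   # swap the first two axes to get (T, B, d_x)

# batch_first=True: the input is (B, T, d_x)
gru_batch_first = torch.nn.GRU(input_size=d_x, hidden_size=d_h, batch_first=True)
out_bf, h_n_bf = gru_batch_first(X)

print(out_sf.shape)  # torch.Size([5, 2, 3]), i.e. (T, B, d_h)
print(out_bf.shape)  # torch.Size([2, 5, 3]), i.e. (B, T, d_h)
print(h_n_bf.shape)  # torch.Size([1, 2, 3]), i.e. (num_layers, B, d_h); batch_first does not affect it
```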

The GRU also has a hidden size d_h. This hyperparameter sets the length of the hidden state vector h_t, and so it controls how much memory/context the model can hold. The hidden state is a compressed summary of all the information seen up to x_t. It carries that information forward as the basis for the next step (and the next prediction).
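The next sketch uses PyTorch's `nn.GRU` (one layer, one direction, random weights) to show that the hidden state is what carries the context forward. The output at the last time step is the same vector as the final hidden state. Also, passing that state back in as the initial state lets a sequence be processed in two pieces with the same result as processing it in one go.

```python
import torch

torch.manual_seed(0)
B, T, d_x, d_h = 2, 5, 4, 3
gru = torch.nn.GRU(d_x, d_h, batch_first=True)

X = torch.randn(B, T, d_x)
with torch.no_grad():
    out, h_n = gru(X)   # out: (B, T, d_h), h_t at every step; h_n: (1, B, d_h), the last h_t

    # The last output and the final hidden state are the same vectors
    same_last = torch.allclose(out[:, -1, :], h_n[0])

    # Continue with more steps by passing h_n as the initial hidden state
    X_more = torch.randn(B, 3, d_x)
    out_more, _ = gru(X_more, h_n)
    out_full, _ = gru(torch.cat([X, X_more], dim=1))
    same_continuation = torch.allclose(out_more, out_full[:, T:, :], atol=1e-5)

print(same_last, same_continuation)  # True True
```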

In [None]:
import torch
import math

torch.manual_seed(44)

# a small example input to work with
# the sizes are kept tiny so the shapes are easy to follow

B, T = 1, 3          # batch size, timesteps
d_x, d_h = 4, 3      # input features per step, hidden size

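# manual sigmoid (for reference only)
# NOTE: math.exp works on single numbers, so this can't be applied to whole tensors;
# torch.sigmoid is used below instead
def sigmoid(x):
    return 1 / (1 + math.exp(-x))

# manual tanh (for reference only)
# NOTE: also limited to single numbers; torch.tanh is used below instead
def tanh(x):
    return (math.exp(x) - math.exp(-x)) / (math.exp(x) + math.exp(-x))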


# random sample of data
X = torch.randn(B, T, d_x)

t = 0 # pick the initial t just for example
x_t = X[:, t, :] # an input at a single time step 

# init h_prev
h_prev = torch.zeros(B, d_h)

# the W matrices have shape (d_h, d_x): they project the d_x input features onto the d_h hidden units,
# i.e., they bring the new input information into the hidden state
# in math notation this is W @ x for a column vector x; with batched row vectors of shape (B, d_x)
# it becomes x @ W.T, because W is stored as (out, in), the same layout nn.Linear uses for its weight
# so remember to transpose: (B, d_x) @ (d_x, d_h) -> (B, d_h)
W_z = torch.randn(d_h, d_x) 

# U_z acts on h_prev, which has shape (B, d_h), so U_z must be (d_h, d_h)
# this makes sense conceptually: the hidden state lives in R^d_h, and using d_x here
# would treat the hidden state as if it were another input
# the U matrices project the previous hidden state into the current step,
# 're-processing' the memory of the past before it is combined with the new input
U_z = torch.randn(d_h, d_h)

W_r = torch.randn(d_h, d_x)
U_r = torch.randn(d_h, d_h)

W_h = torch.randn(d_h, d_x)
U_h = torch.randn(d_h, d_h)


# need four formulas coded up
# (1) update gate
z_t = torch.sigmoid(x_t @ W_z.T + h_prev @ U_z.T)

# (2) reset gate
r_t = torch.sigmoid(x_t @ W_r.T + h_prev @ U_r.T)

# (3) candidate hidden state
h_t_tilda = torch.tanh(x_t @ W_h.T + (r_t * h_prev) @ U_h.T)

# (4) final hidden state
h_t = (1 - z_t) * h_prev + z_t * h_t_tilda

h_prev = h_t # set h_prev to the new h_t for the next iteration
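The cell above computes only t = 0. The sketch below runs the same four formulas over all T steps. It wraps one step in a helper `gru_step` (a name introduced here only for illustration), feeds each new hidden state into the next step, and stacks the results. The seed is reused, but the random draws happen in a different order, so the numbers will not match the cell above.

```python
import torch

def gru_step(x_t, h_prev, W_z, U_z, W_r, U_r, W_h, U_h):
    """One GRU step using the formulas at the top (reset gate applied to h_prev before U_h)."""
    z_t = torch.sigmoid(x_t @ W_z.T + h_prev @ U_z.T)                 # update gate
    r_t = torch.sigmoid(x_t @ W_r.T + h_prev @ U_r.T)                 # reset gate
    h_cand = torch.tanh(x_t @ W_h.T + (r_t * h_prev) @ U_h.T)         # candidate hidden state
    return (1 - z_t) * h_prev + z_t * h_cand                          # final hidden state

torch.manual_seed(44)
B, T, d_x, d_h = 1, 3, 4, 3
X = torch.randn(B, T, d_x)
W_z, W_r, W_h = (torch.randn(d_h, d_x) for _ in range(3))
U_z, U_r, U_h = (torch.randn(d_h, d_h) for _ in range(3))

h = torch.zeros(B, d_h)   # initial hidden state
states = []
for t in range(T):
    h = gru_step(X[:, t, :], h, W_z, U_z, W_r, U_r, W_h, U_h)   # same weights at every step
    states.append(h)

H = torch.stack(states, dim=1)   # (B, T, d_h)
print(H.shape)  # torch.Size([1, 3, 3])
```

The state starts at zero, and each new state is a weighted average of the previous state and a tanh output. Every entry of H therefore stays within [-1, 1].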




In [None]:
import torch.nn.functional as F

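# a GRU cell with its weights stacked the way PyTorch's nn.GRUCell stores them (no bias)
# the candidate uses PyTorch's 'reset after' form; the final update uses the formula from the top (see notes in forward)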
class GRUCell(torch.nn.Module):
    def __init__(self, d_h, d_x):
        super().__init__()

        self.d_h, self.d_x = d_h, d_x

        # stacked weights for input-to-hidden, rows ordered [reset; update; candidate] as in PyTorch
        self.weights_ih = torch.nn.Parameter(torch.randn(3 * d_h, d_x)) # -> (3 * d_h, d_x) -> (9, 4)

        # stacked weights for hidden-to-hidden
        self.weights_hh = torch.nn.Parameter(torch.randn(3 * d_h, d_h)) # (9, 3)
    
        

    def forward(self, x_t, h_prev):

        # F.linear(x, W, b) computes x @ W.T + b (the bias b is optional and omitted here)

        # gates from input
        gi = F.linear(x_t, self.weights_ih) # (B, d_x) @ (d_x, 3 * d_h) -> (B, 3 * d_h)

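        # gates from hidden
        # NOTE: this computes h_prev @ U_h.T for the candidate *before* the reset gate is applied,
        # so the candidate below is tanh(i_n + r_t * (h_prev @ U_h.T)) rather than
        # tanh(i_n + (r_t * h_prev) @ U_h.T) as in the first cell; the two are NOT equal in general
        # this 'reset after' form is the variant PyTorch (and cuDNN) use
        gh = F.linear(h_prev, self.weights_hh) # (B, d_h) @ (d_h, 3 * d_h) -> (B, 3 * d_h)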

        i_r, i_z, i_n = gi.chunk(3, dim=1) # 3 * (B, d_h)
        h_r, h_z, h_n = gh.chunk(3, dim=1) # 3 * (B, d_h)

        z_t = torch.sigmoid(i_z + h_z) # (B, d_h) + (B, d_h) 
        r_t = torch.sigmoid(i_r + h_r)
        n_t = torch.tanh(i_n + r_t * h_n)

        h_t = (1 - z_t) * h_prev + z_t * n_t # formula from the top of the notebook
        # NOTE: PyTorch uses the mirrored convention h_t = (1 - z_t) * n_t + z_t * h_prev,
        # which is also the form in the original GRU paper (Cho et al., 2014)
        # in PyTorch's form, a large z_t keeps more of h_prev and less of the new candidate;
        # in the form above, a large z_t lets in more of the candidate
        # for training the choice does not matter: since 1 - sigmoid(a) = sigmoid(-a), flipping the sign
        # of the update-gate weights (and bias) turns one convention into the other, so the weights adapt
        # a likely reason for PyTorch's choice is consistency with cuDNN, whose GRU uses the same form,
        # so nn.GRU can dispatch to those optimized CUDA kernels

        return h_t
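The two cells place the reset gate differently. The first cell computes `(r_t * h_prev) @ U_h.T`, which resets h_prev and then projects it. The `GRUCell` class computes `r_t * (h_prev @ U_h.T)`, which projects first and then resets. The small check below uses hand-picked numbers with d_h = 2 to show that the two forms give different results:

```python
import torch

# Hand-picked values (B = 1, d_h = 2) so the arithmetic is easy to follow
h_prev = torch.tensor([[1.0, 2.0]])        # (B, d_h)
U_h = torch.tensor([[1.0, 1.0],
                    [0.0, 1.0]])            # (d_h, d_h)
r_t = torch.tensor([[0.0, 1.0]])           # reset gate: wipe dimension 0, keep dimension 1

reset_before = (r_t * h_prev) @ U_h.T      # first cell: reset h_prev, then project
reset_after = r_t * (h_prev @ U_h.T)       # GRUCell class / PyTorch: project, then reset

print(reset_before)  # tensor([[2., 2.]])
print(reset_after)   # tensor([[0., 2.]])
```

In the reset-before form, output dimension 0 still receives the kept h_prev[1] through U_h[0, 1]. In the reset-after form, the gate zeroes output dimension 0 directly. The two forms agree when every entry of r_t has the same value, but not in general. They are therefore slightly different models, although both are trainable GRU variants.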





        

In [None]:
B, T, d_h, d_x = 1, 3, 3, 4
X = torch.randn(B, T, d_x)
cell = GRUCell(d_h, d_x)
h_prev = torch.zeros(B, d_h)

hs = []
for t in range(T):
    x_t = X[:, t, :]
    h_prev = cell(x_t, h_prev)
    hs.append(h_prev)


H = torch.stack(hs, dim=1) # (B, T, d_h)
# stack creates new dim; concat adds along existing dim
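This short example shows how the two calls from the comment above differ in shape. It uses T random hidden states of shape (B, d_h):

```python
import torch

B, d_h, T = 1, 3, 4
hs = [torch.randn(B, d_h) for _ in range(T)]   # T hidden states, each (B, d_h)

stacked = torch.stack(hs, dim=1)      # new time axis at dim 1: (B, T, d_h)
concatenated = torch.cat(hs, dim=1)   # joined along the existing feature axis: (B, T * d_h)

print(stacked.shape)       # torch.Size([1, 4, 3])
print(concatenated.shape)  # torch.Size([1, 12])
```

Both calls contain the same numbers in the same order, so reshaping `stacked` to (B, T * d_h) reproduces `concatenated`. Only `stack` keeps time as its own axis, which is the (B, T, d_h) layout a GRU layer returns.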

    

In [27]:
H[0, 0, 0]

tensor(-0.6569, grad_fn=<SelectBackward0>)
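The class can also be checked against PyTorch's own cell. This sketch assumes the only difference is the update convention described in `forward()`. It copies the stacked weights into `torch.nn.GRUCell(bias=False)` and negates the update-gate rows. Because sigmoid(-a) = 1 - sigmoid(a), negating those rows swaps z_t and 1 - z_t. The helper `notebook_step` is a hypothetical standalone re-implementation of `GRUCell.forward`, included so the block runs on its own.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, d_x, d_h = 2, 4, 3

# Stacked weights with rows in PyTorch's order: reset (r), update (z), candidate (n)
W_ih = torch.randn(3 * d_h, d_x)
W_hh = torch.randn(3 * d_h, d_h)

def notebook_step(x_t, h_prev):
    """Same math as the GRUCell class: reset after U_h, update as (1 - z) * h_prev + z * n."""
    i_r, i_z, i_n = F.linear(x_t, W_ih).chunk(3, dim=1)
    h_r, h_z, h_n = F.linear(h_prev, W_hh).chunk(3, dim=1)
    r = torch.sigmoid(i_r + h_r)
    z = torch.sigmoid(i_z + h_z)
    n = torch.tanh(i_n + r * h_n)
    return (1 - z) * h_prev + z * n

# Negate only the update-gate rows so PyTorch's z becomes 1 - z of the notebook's convention
flip = torch.ones(3 * d_h, 1)
flip[d_h:2 * d_h] = -1.0

ref = torch.nn.GRUCell(d_x, d_h, bias=False)
with torch.no_grad():
    ref.weight_ih.copy_(W_ih * flip)
    ref.weight_hh.copy_(W_hh * flip)

x_t = torch.randn(B, d_x)
h_prev = torch.randn(B, d_h)
with torch.no_grad():
    ours = notebook_step(x_t, h_prev)
    theirs = ref(x_t, h_prev)

print(torch.allclose(ours, theirs, atol=1e-5))  # True
```

If the rows are copied without flipping the sign, the outputs generally disagree. That shows the two update conventions really are different parameterizations of the same model family.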