# Assignment 4 - Vanilla RNN

## Libraries

In [107]:
import math

import halo
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from tqdm.notebook import trange

## Data

In [108]:
book_fname = 'data/goblet_book.txt'
# Explicit encoding so the character set does not depend on the platform's
# default locale. The `with` block closes the file, so no explicit close().
with open(book_fname, 'r', encoding='utf-8') as f:
    book_data = f.read()

# Sort the characters so the char <-> index mapping is identical across runs:
# iteration order of a set of str depends on per-process hash randomization,
# which would silently scramble the one-hot encoding between sessions.
book_chars = sorted(set(book_data))
K = len(book_chars)
print(f"Number of unique characters: {K}")
char2ind = {c: i for i, c in enumerate(book_chars)}
ind2char = {i: c for i, c in enumerate(book_chars)}

Number of unique characters: 80


## Model

In [139]:
class RNN:
    """Character-level vanilla RNN with a hand-written single-step backward pass.

    Relies on the module-level vocabulary globals ``K``, ``char2ind`` and
    ``ind2char``.

    Args:
        m: hidden-state size.
        seq_length: number of characters per sequence used by ``check_grads``.
        eta: learning rate (stored only; not used by the methods below).
        sig: standard deviation of the Gaussian weight initialisation.
    """

    def __init__(self, m=100, seq_length=25, eta=.1, sig=.01):
        self.m = m
        self.seq_length = seq_length
        self.eta = eta

        # Scale *before* enabling requires_grad so each parameter is a leaf
        # tensor. Leaves get .grad populated by autograd and can be updated in
        # place. `randn(..., requires_grad=True) * sig` would instead produce a
        # non-leaf that needs retain_grad() and cannot be optimised directly.
        self.U = (torch.randn(m, K) * sig).requires_grad_()
        self.W = (torch.randn(m, m) * sig).requires_grad_()
        self.V = (torch.randn(K, m) * sig).requires_grad_()
        self.b = torch.zeros(m, 1, requires_grad=True)
        self.c = torch.zeros(K, 1, requires_grad=True)

        self.training_loss = []
        self.validation_loss = []

        # Each call to backward() appends [dV, dc, dW, db, dU] here.
        self.grads = []
        # Same order as the gradient lists stored in self.grads.
        self.params = [self.V, self.c, self.W, self.b, self.U]

    def convert_to_text(self):
        """Decode the one-hot matrix ``self.Y`` (one column per character) into a string.

        ``self.Y`` is not created in ``__init__``. The caller must allocate it
        (presumably ``torch.zeros(K, n)``; TODO confirm) before synthesizing.
        """
        indices = torch.argmax(self.Y, dim=0).tolist()
        return ''.join(ind2char[i] for i in indices)

    def synthesize(self, h, x, n):
        """Run one recurrent step from ``(h, x)``, sample a character and record it.

        Args:
            h: previous hidden state, shape (m, 1).
            x: one-hot input column, shape (K, 1).
            n: column of ``self.Y`` in which the sampled one-hot entry is set.

        Returns:
            (new hidden state, softmax probabilities, sampled index as a 0-d tensor).
        """
        h = torch.tanh(self.W @ h + self.U @ x + self.b)
        y = self.V @ h + self.c
        p = torch.softmax(y, dim=0)
        # Inverse-CDF sampling: pick the first index whose cumulative
        # probability exceeds a uniform draw a in [0, 1).
        cp = torch.cumsum(p, dim=0)
        a = torch.rand(1)
        ixs = torch.where(cp - a > 0)[0]
        # Float round-off can leave cp[-1] slightly below 1, so a draw close
        # to 1 may match nothing. Fall back to the last index instead of
        # raising IndexError.
        ix = ixs[0] if ixs.numel() > 0 else torch.tensor(p.shape[0] - 1)
        self.Y[ix, n] = 1
        return h, p, ix

    def forward(self, h_prev, x):
        """One recurrent step: return (new hidden state, softmax probabilities, h_prev)."""
        h = torch.tanh(self.W @ h_prev + self.U @ x + self.b)
        y = self.V @ h + self.c
        p = torch.softmax(y, dim=0)
        return h, p, h_prev

    def backward(self, h, p, h_prev, x, y):
        """Hand-derived gradients of the cross-entropy loss for a single time step.

        Gradients flowing in from later time steps are not included. The
        computation runs under ``torch.no_grad()`` so the returned tensors,
        which are also appended to ``self.grads``, carry no autograd graph.
        Otherwise every stored entry would keep its graph alive.

        Returns:
            (dV, dc, dW, db, dU), the same order as ``self.params``.
        """
        with torch.no_grad():
            do = p - y                           # dL/d(logits) for softmax + cross-entropy
            dV = do @ h.T
            dc = do
            dh = (self.V.T @ do) * (1 - h**2)    # back through tanh
            dW = dh @ h_prev.T
            db = dh
            dU = dh @ x.T
        self.grads.append([dV, dc, dW, db, dU])
        return dV, dc, dW, db, dU

    def loss(self, p, y):
        """Cross-entropy ``-sum(y * log p)`` for one-hot targets ``y``.

        ``p`` and ``y`` have the same shape, (K, 1) or (K, n). The element-wise
        product avoids the off-diagonal cross terms that ``y.T @ log(p)`` would
        add for more than one column.
        """
        return -torch.sum(y * torch.log(p))

    def check_grads(self, book_data):
        """Compare ``backward()`` against autograd on the first character of ``book_data``.

        Prints one relative error per parameter, in ``self.params`` order.
        """
        X_chars = book_data[:self.seq_length]
        Y_chars = book_data[1:self.seq_length+1]
        X = torch.zeros((K, self.seq_length))
        Y = torch.zeros((K, self.seq_length))
        h = torch.zeros((self.m, 1))
        for i in range(self.seq_length):
            X[char2ind[X_chars[i]], i] = 1
            Y[char2ind[Y_chars[i]], i] = 1

        # Only the first time step is checked.
        x0 = X[:, 0].reshape(K, 1)
        y0 = Y[:, 0].reshape(K, 1)
        h, p, h_prev = self.forward(h, x0)
        first_new = len(self.grads)
        self.backward(h, p, h_prev, x0, y0)

        loss = self.loss(p, y0)

        # .grad accumulates across backward() calls, so clear it first.
        # retain_grad() is a no-op for leaves and covers any non-leaf parameter.
        for param in self.params:
            param.grad = None
            param.retain_grad()
        loss.backward()

        print("Checking gradients")
        with torch.no_grad():
            # Compare only the entries produced by this call. Older entries in
            # self.grads belong to different forward passes.
            for i in range(first_new, len(self.grads)):
                for j, (manual, param) in enumerate(zip(self.grads[i], self.params)):
                    auto = param.grad
                    diff = torch.norm(manual - auto)
                    rel_err = diff / (torch.norm(manual) + torch.norm(auto) + 1e-6)
                    print(f"Relative error iter {i}, param {j}: {rel_err}")


In [145]:
# Build a model with the default hyper-parameters and verify the hand-written
# backward pass against autograd on the first character of the book.
rnn = RNN(m=100, seq_length=25, eta=.1, sig=.01)
rnn.check_grads(book_data=book_data)

Checking gradients
Relative error iter 0, param 0: 0.0
Relative error iter 0, param 1: 0.0
Relative error iter 0, param 2: 0.0
Relative error iter 0, param 3: 0.0
Relative error iter 0, param 4: 0.0
