# Assignment 4 - Vanilla RNN

## Libraries

In [1]:
import math

import halo
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from tqdm.notebook import trange

## Data

In [2]:
# Load the whole training corpus into memory as a single string.
book_fname = 'data/goblet_book.txt'
# The context manager closes the file on exit, so no explicit close() is
# needed. The encoding is spelled out so the decoded text does not depend on
# the platform default (assumes the corpus is UTF-8 encoded - TODO confirm).
with open(book_fname, 'r', encoding='utf-8') as f:
    book_data = f.read()

# Sort the alphabet so the character <-> index mapping is the same on every
# run. Iterating a bare set of strings follows hash order, which can change
# between interpreter sessions and would silently re-label every character.
book_chars = sorted(set(book_data))
K = len(book_chars)
print(f"Number of unique characters: {K}")

# Lookup tables in both directions between characters and integer indices.
char2ind = {c: i for i, c in enumerate(book_chars)}
ind2char = dict(enumerate(book_chars))

Number of unique characters: 80


## Model

In [None]:
class RNN:

    def __init__(self, m=100, seq_length=25, eta=.1, sig=.01):
        """Store the hyper-parameters and allocate weights and bookkeeping.

        Args:
            m: Size of the hidden state (rows of U, W and b).
            seq_length: Number of columns allocated for ``self.Y``.
            eta: Stored as ``self.eta`` (presumably the learning rate used
                by training code outside this view).
            sig: Scale applied to the standard-normal draws for U, W and V.

        The alphabet size is read from the module-level ``K``.
        """
        self.m, self.seq_length, self.eta = m, seq_length, eta

        # Weights are Gaussian noise scaled by ``sig``. They are drawn in the
        # order U, W, V so a seeded generator yields the same values as before.
        self.U = sig * torch.randn(m, K)
        self.W = sig * torch.randn(m, m)
        self.V = sig * torch.randn(K, m)

        # Both bias columns start at zero.
        self.b = torch.zeros((m, 1))
        self.c = torch.zeros((K, 1))

        # Buffer that ``synthesize`` writes into, one column per position.
        self.Y = torch.zeros(K, seq_length)

        # Loss histories, initially empty.
        self.training_loss, self.validation_loss = [], []
    
    def convert_to_text(self):
        """Decode ``self.Y`` column by column into a string.

        For every column, the row index holding the column's largest value
        is looked up in the module-level ``ind2char`` table; the resulting
        characters are concatenated in column order.

        Returns:
            The decoded text, one character per column of ``self.Y``.
        """
        n_cols = self.Y.shape[1]
        return ''.join(
            ind2char[self.Y[:, col].argmax().item()] for col in range(n_cols)
        )

    def synthesize(self, h, x, n):
        h = torch.tanh(self.W @ h + self.U @ x + self.b)
        y = self.V @ h + self.c
        p = torch.softmax(y, dim=0)
        cp = torch.cumsum(p, dim=0)
        a = torch.rand(1)
        ixs = torch.where(cp - a > 0)[0]
        self.Y[ixs[0], n] = 1
        return h, p, ixs[0]

    def forward(self, h, x):
        h = torch.tanh(self.W @ h + self.U @ x + self.b)
        y = self.V @ h + self.c
        p = torch.softmax(y, dim=0)
        return h, p
    
    def backward(self, h, p, x, y):
        dy = p - y
        dV = dy @ h.T
        dc = dy
        dh = self.V.T @ dy
        dh = dh * (1 - h * h)
        dU = dh @ x.T
        db = dh
        dW = dh @ h.T
        return dV, dc, dU, db, dW
    
    def loss(self, p, y):
        return - torch.sum(y.T @ torch.log(p))