<a href="https://colab.research.google.com/github/yaniv92648/OpenU_DL_Mamans/blob/main/Maman_14_DL_OpenU_Yaniv.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Installations

In [None]:
!pip install datasets

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


# Imports

In [None]:
import datasets as ds
import torch
from torch import nn
from torchtext.vocab import GloVe
from torchtext.data.utils import get_tokenizer
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D
import pandas as pd

# GPU

In [None]:
# Prefer the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Question 1

## Prepare data

In [None]:
# Load the SST-2 task of GLUE and pull the training sentences and their labels.
dataset = ds.load_dataset("glue", "sst2")
train_split = dataset["train"]
sentences = train_split["sentence"]
# Labels are stored as floats on the compute device (the loss used later is a BCE loss).
labels = torch.tensor(train_split["label"], dtype=torch.float).to(device)

# Pre-trained GloVe vectors (6B corpus) with `glove_dim` components per token.
glove_dim = 50
glove_embedder = GloVe(name='6B', dim=glove_dim)

# Tokenise every training sentence once, up front.
tokenizer = get_tokenizer("basic_english")
sentences_tokens = [tokenizer(sentence) for sentence in sentences]

Reusing dataset glue (/root/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


  0%|          | 0/3 [00:00<?, ?it/s]

## RNN Classifier

In [None]:
class RNNClassifier(nn.Module):
    """RNN sentence classifier with truncated backpropagation through time.

    The whole sentence is fed through the RNN, but only the last ``t`` tokens
    take part in gradient computation: the earlier prefix is processed under
    ``torch.no_grad()`` and its final hidden state seeds the trainable suffix.
    The hidden state after the last token is mapped to two class scores.
    """

    def __init__(self, input_dim, hidden_dim, t, vocab):
        """Build the recurrent feature extractor and the linear classifier head.

        Args:
            input_dim: Size of each token embedding fed to the RNN.
            hidden_dim: Size of the RNN hidden state.
            t: Number of trailing tokens to backpropagate through; must be >= 1.
            vocab: Accepted for interface compatibility; not used by the model.

        Raises:
            ValueError: If ``t`` is smaller than 1. With ``t == 0`` the slice
                ``[:-t]`` would be empty, and a negative ``t`` would swap the
                frozen and trainable parts of the sentence.
        """
        super().__init__()
        if t < 1:
            raise ValueError(f"t must be a positive number of tokens, got {t}")
        self.t = t
        self.rnn = nn.RNN(input_dim, hidden_dim, batch_first=True)
        self.hidden_dim = hidden_dim
        self.linear = nn.Linear(hidden_dim, 2)

    def forward(self, sentence_embeddings):
        """Compute the two class scores for a single sentence.

        Args:
            sentence_embeddings: Token embeddings of one sentence, indexed by
                token along the first dimension (assumed to be an unbatched
                ``(tokens, input_dim)`` tensor, as produced in this notebook).

        Returns:
            The output of the linear layer applied to the last hidden state.
        """
        n_tokens = len(sentence_embeddings)
        if self.t < n_tokens:
            # Run the prefix without building a graph: gradients (BP) are only
            # calculated for the last t tokens.
            with torch.no_grad():
                _, prefix_hidden_state = self.rnn(sentence_embeddings[:-self.t])
            hidden_state_history, _ = self.rnn(sentence_embeddings[-self.t:], prefix_hidden_state)
        else:
            # The sentence has at most t tokens: backpropagate through all of it.
            hidden_state_history, _ = self.rnn(sentence_embeddings)
        # The hidden state after the final token summarises the sentence.
        feature_extractor_output = hidden_state_history[-1]
        class_scores = self.linear(feature_extractor_output)
        return class_scores

## Training Parameters

In [None]:
# Classifier over GloVe embeddings; gradients flow through the last 5 tokens only.
model = RNNClassifier(
    input_dim=glove_dim,
    hidden_dim=2,
    t=5,
    vocab=glove_embedder.stoi,
)
model = model.to(device)

# Number of passes over the training sentences.
n_epochs = 1
# Sentence index referenced by the (currently disabled) gradient-flow plot.
index_to_plot = 2

# Objective and optimiser.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

## Training

In [None]:
def train_sentence(i):
    """Run one optimisation step of the global ``model`` on the i-th training sentence.

    Args:
        i: Index into ``sentences_tokens`` / ``labels``.

    Returns:
        The loss of this sentence as a detached 0-d tensor.
    """
    sentence_embeddings = glove_embedder.get_vecs_by_tokens(sentences_tokens[i]).to(device)
    # Clear gradients left over from previous steps; otherwise backward() adds to
    # them and every update would mix in the gradients of earlier sentences.
    optimizer.zero_grad()
    class_scores = model(sentence_embeddings)
    # The model returns one tensor of class scores (not an (output, hidden) pair).
    # The first score is used as the logit for the binary loss.
    # NOTE(review): the remaining score(s) are not part of the loss - confirm this is intended.
    logit = class_scores[0]
    loss = criterion(logit, labels[i])
    loss.backward() # Backpropagation: calculate gradients
    optimizer.step() # Update weights accordingly
    # Detach so callers holding on to the loss do not keep the autograd graph alive.
    return loss.detach()

In [None]:
def train_RNN():
    """Train the global ``model`` for ``n_epochs`` passes over all training sentences.

    Every sentence is one optimisation step, delegated to ``train_sentence``.
    After each epoch the mean per-sentence loss of that epoch is printed.
    """
    n_sentences = len(sentences)
    for epoch in range(1, n_epochs+1):
        epoch_loss = 0.0
        for i in range(n_sentences):
            # Reset gradients (not weights) before every step so that each update
            # only uses the gradient of the current sentence.
            optimizer.zero_grad()
            # Accumulate as a detached tensor to avoid keeping autograd graphs alive.
            epoch_loss = epoch_loss + train_sentence(i).detach()
        # Guard against an empty training set.
        mean_loss = epoch_loss / max(n_sentences, 1)
        print(f'Epoch: {epoch}/{n_epochs}.............', end=' ')
        print('Loss: {:.2f}'.format(float(mean_loss)))

In [None]:
# def plot_grad_flow():
#     sentence_tokens = sentences_tokens[index_to_plot]
#     grads_per_token = [0.0] * len(sentence_tokens)
#     final_grad = None
#     for n, p in model.named_parameters():
#         if(p.requires_grad) and ("bias" not in n):
#             print(f'p.grad = {p.grad}')
#             final_grad = p.grad
#     final_grad = final_grad.abs().mean(dim=1)
#     for grad_i in range(len(final_grad)):
#         grads_per_token[-grad_i-1] = final_grad[-grad_i-1].item()
#     plt.plot(grads_per_token)
#     plt.xlabel("Token location in sentence")
#     plt.ylabel("Gradient")
#     plt.show()

In [None]:
# train_RNN()

In [None]:
def backward_hook(module, grad_input, grad_output):
    """Backward hook that records a module's output gradients.

    Every non-None entry of ``grad_output`` is appended to the module-level
    ``grads`` list, which must exist when the hook fires.

    Args:
        module: The module the hook is registered on (unused).
        grad_input: Gradients with respect to the module inputs (unused).
        grad_output: Gradients with respect to the module outputs.

    Returns:
        None.
    """
    # Extend the shared list in place; assigning to `grads` here would only
    # create a local variable and the captured gradients would be lost.
    grads.extend(grad for grad in grad_output if grad is not None)

In [None]:
# Start from an empty list so re-running this cell does not mix in old gradients.
grads = []
idx = 2
# Record the gradients reaching the RNN outputs during one training step on sentence `idx`.
handle = model.rnn.register_full_backward_hook(backward_hook)
try:
    train_sentence(idx)
finally:
    # Always unregister the hook, even if the step fails; otherwise it stays
    # attached to the model and fires again on every later backward pass.
    handle.remove()
# NOTE(review): this keeps every other captured entry and then appends the last
# entry once more - confirm this selection matches what the hook records.
relevant_grads = grads[::2] + grads[-1::]

In [None]:
# Label the captured gradients with token positions, counting down from the
# last token of the inspected sentence.
last_position = len(sentences_tokens[idx]) - 1
for offset, grad in enumerate(relevant_grads):
    print(f"Time : {last_position - offset}\n{grad}")