In [None]:
from torch import nn
from torchvision import transforms, datasets
from torch.utils.data import random_split, DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence,pack_padded_sequence
import pandas as pd
import numpy as np
import torch
import tiktoken

In [None]:
import matplotlib.pyplot as plt
from tqdm import tqdm

import os
import csv

In [None]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention.

    Projects the input to queries / keys / values of width ``out_dim``, splits them into
    ``num_heads`` heads of ``out_dim // num_heads`` features, applies scaled dot-product
    attention under an upper-triangular (causal) mask, and merges the heads through a
    final linear projection.

    Args:
        in_dim: feature size of the input.
        out_dim: feature size of the output; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        context_length: maximum sequence length the causal mask supports.
        qkv_bias: whether the Q/K/V projections have a bias term.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, in_dim, out_dim, num_heads, context_length, qkv_bias=False, dropout=0.5):
        super().__init__()
        assert out_dim % num_heads == 0, "out_dim must be divisible by num_heads"

        self.out_dim = out_dim
        self.num_heads = num_heads
        self.head_dim = out_dim // num_heads
        self.W_query = nn.Linear(in_dim, out_dim, bias=qkv_bias)
        self.W_key = nn.Linear(in_dim, out_dim, bias=qkv_bias)
        self.W_value = nn.Linear(in_dim, out_dim, bias=qkv_bias)
        self.out_proj = nn.Linear(out_dim, out_dim)
        self.dropout = nn.Dropout(dropout)
        # 1s strictly above the diagonal mark "future" positions a token may not attend to.
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        """Attend over ``x`` of shape (batch, num_tokens, in_dim).

        Returns:
            Tensor of shape (batch, num_tokens, out_dim).

        Raises:
            ValueError: if ``num_tokens`` exceeds the ``context_length`` given at construction.
        """
        batch_size, num_tokens, _ = x.shape
        if num_tokens > self.mask.shape[0]:
            raise ValueError(
                f"sequence length {num_tokens} exceeds context_length {self.mask.shape[0]}"
            )

        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        # (batch, tokens, out_dim) -> (batch, heads, tokens, head_dim)
        keys = keys.view(batch_size, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        queries = queries.view(batch_size, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        values = values.view(batch_size, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)

        attn_scores = queries @ keys.transpose(2, 3)
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        # Scale by sqrt(head_dim) before the softmax.
        attn_weights = torch.softmax(attn_scores / self.head_dim ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # (batch, heads, tokens, head_dim) -> (batch, tokens, heads, head_dim) -> merge heads
        context_vec = (attn_weights @ values).transpose(1, 2)
        context_vec = context_vec.contiguous().view(batch_size, num_tokens, self.out_dim)
        context_vec = self.out_proj(context_vec)

        return context_vec

In [None]:
# Smoke test: push a tiny random batch (2 sequences x 4 tokens x 4 features) through the
# attention layer and show the resulting (batch, tokens, out_dim) tensor.
batch_size, num_tokens, d_in = 2, 4, 4
batch = torch.randn(batch_size, num_tokens, d_in)
attn = MultiHeadAttention(d_in, 8, 2, num_tokens)
out = attn(batch)

print(out.shape)
print(out)

In [None]:
def collate_fn_imdb(data, pad_value=50257):
    """Collate a list of ``(token_ids, score, label)`` samples into padded batch tensors.

    Samples are ordered by descending sequence length, which is the order
    ``pack_padded_sequence(..., enforce_sorted=True)`` expects. A sorted copy is used,
    so the caller's list is not mutated.

    Args:
        data: list of ``(token_ids, score, label)`` tuples, e.g. items of ``IMDBDataset``,
            where ``token_ids`` is a 1-D tensor.
        pad_value: id written into positions past the end of shorter sequences.

    Returns:
        ``(padded, lengths, scores, labels)``:

        - ``padded``: (batch, max_len) tensor, right-padded with ``pad_value``.
        - ``lengths``: int64 vector of each sequence's original length.
        - ``scores`` and ``labels``: float32 vectors.
    """
    batch = sorted(data, key=lambda sample: len(sample[0]), reverse=True)
    sequences = [sample[0] for sample in batch]
    scores = torch.tensor([sample[1] for sample in batch], dtype=torch.float32)
    labels = torch.tensor([sample[2] for sample in batch], dtype=torch.float32)

    original_seq_lengths = torch.tensor([len(s) for s in sequences], dtype=torch.long)
    padded_seqs_long = pad_sequence(sequences, batch_first=True, padding_value=pad_value)

    return padded_seqs_long, original_seq_lengths, scores, labels

def tokenize_text(text_list, tokenizer):
    """Encode each string in ``text_list`` with ``tokenizer``.

    Args:
        text_list: iterable of strings.
        tokenizer: object with an ``encode(str) -> list[int]`` method (e.g. a tiktoken encoding).

    Returns:
        List of 1-D int64 tensors, one per input text.
    """
    # dtype is pinned because torch.tensor([]) (an empty text) would otherwise default to
    # float32, unlike the int64 tensors produced for non-empty token lists; token ids must
    # stay integer for nn.Embedding lookups.
    return [torch.tensor(tokenizer.encode(text), dtype=torch.long) for text in text_list]

def detokenize_text(token_ids, tokenizer):
    """Decode each token-id tensor back into a string.

    Args:
        token_ids: iterable of tensors (anything with ``.tolist()``) holding token ids.
        tokenizer: object with a ``decode(list[int]) -> str`` method.

    Returns:
        List of decoded strings, in the same order as ``token_ids``.
    """
    return [tokenizer.decode(ids.tolist()) for ids in token_ids]

class IMDBDataset(Dataset):
    """Map-style dataset over parallel lists of tokenized comments, sentiments and scores.

    Item ``i`` is the tuple ``(comments_token_ids[i], scores[i], sentiments[i])``. Note that
    the score comes before the sentiment, matching the ``(tokens, score, label)`` layout
    ``collate_fn_imdb`` unpacks.
    """

    def __init__(self, comments_token_ids, sentiments, scores):
        self.comments_token_ids, self.sentiments, self.scores = comments_token_ids, sentiments, scores

    def __len__(self):
        # The comment list defines the dataset size.
        return len(self.comments_token_ids)

    def __getitem__(self, idx):
        token_ids = self.comments_token_ids[idx]
        score = self.scores[idx]
        sentiment = self.sentiments[idx]
        return token_ids, score, sentiment
    
def create_IMDB_dataloader(dataset, batch_size=32, shuffle=True, num_workers=0, pad_value=50257):
    """Wrap ``dataset`` in a DataLoader whose batches are padded by ``collate_fn_imdb``.

    Args:
        dataset: dataset yielding ``(token_ids, score, label)`` samples.
        batch_size: samples per batch.
        shuffle: whether to reshuffle the data every epoch.
        num_workers: number of DataLoader worker processes.
        pad_value: padding id forwarded to ``collate_fn_imdb``.

    Returns:
        A ``torch.utils.data.DataLoader``.
    """
    from functools import partial

    # functools.partial (unlike a lambda) is picklable, so worker processes started with the
    # "spawn" method (the default on Windows and macOS) can receive the collate function
    # when num_workers > 0.
    collate_wrapper = partial(collate_fn_imdb, pad_value=pad_value)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_wrapper,
    )

In [None]:
# Load the pre-split IMDB CSVs and pull out the three columns used downstream
# (comment text, sentiment label, score) as plain Python lists.
IMDB_train = pd.read_csv("data/IMDB_train.csv")
IMDB_test = pd.read_csv("data/IMDB_test.csv")

train_comments, train_sentiments, train_scores = (
    IMDB_train[column].to_list() for column in ("preprocessed_comments", "sentiment", "score")
)
test_comments, test_sentiments, test_scores = (
    IMDB_test[column].to_list() for column in ("preprocessed_comments", "sentiment", "score")
)
tokenizer = tiktoken.get_encoding("gpt2")

tokenized_train_comments = tokenize_text(train_comments, tokenizer)
tokenized_test_comments = tokenize_text(test_comments, tokenizer)

# Round-trip the first training example through the tokenizer as a sanity check.
sample, sample_sentiment, sample_score = train_comments[0], train_sentiments[0], train_scores[0]
token_ids = tokenized_train_comments[0]
reconstructed = tokenizer.decode(token_ids.tolist())

print(f"Sample: {sample}")
print(f"Sentiment: {sample_sentiment}")
print(f"Score: {sample_score}")
print(f"Token IDs: {token_ids}")
print(f"Reconstructed: {reconstructed}")

In [None]:
class GRUAttention(nn.Module):
    """Regressor: token + learned position embeddings -> causal multi-head self-attention
    -> 3-layer unidirectional GRU -> linear head on one time step per sequence.

    Args:
        in_dim: embedding width fed to the attention layer.
        out_dim: attention output width; the GRU hidden size is ``out_dim // 2``.
        num_heads: number of attention heads (must divide ``out_dim``).
        context_length: maximum tokens per sequence; longer inputs are truncated in ``forward``.
        qkv_bias: whether the Q/K/V projections have a bias term.
        dropout: dropout probability applied to the attention weights.
        vocab_size: number of real token ids; one extra embedding row is allocated so the
            default ``padding`` id (== ``vocab_size``) is a valid index.
        padding: token id used for padding (``padding_idx`` of the token embedding).
    """

    def __init__(self, in_dim, out_dim, num_heads, context_length, qkv_bias=False, dropout=0.5, vocab_size=50257, padding=50257):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size + 1, in_dim, padding_idx=padding)
        self.pos_embedding = nn.Embedding(context_length, in_dim)
        self.attention = MultiHeadAttention(in_dim, out_dim, num_heads, context_length, qkv_bias, dropout)
        self.gru = nn.GRU(out_dim, out_dim // 2, num_layers=3, batch_first=True)
        self.fc = nn.Linear(out_dim // 2, 1)

    def forward(self, x, lengths=None):
        """Predict one scalar per sequence.

        Args:
            x: (batch, num_tokens) tensor of token ids, right-padded.
            lengths: optional (batch,) tensor of true (unpadded) sequence lengths. When given,
                the prediction is read from the GRU output at position ``lengths - 1``, i.e.
                the last real token. Attention is causal and the GRU is unidirectional, so
                trailing padding cannot influence that position. When None, the last
                position is used (the original behaviour).

        Returns:
            (batch,) tensor of predictions.
        """
        max_len = self.pos_embedding.num_embeddings
        if x.shape[1] > max_len:
            # The position table and the attention mask only cover context_length positions;
            # keep the first context_length tokens instead of failing on an out-of-range index.
            x = x[:, :max_len]
        batch_size, num_tokens = x.shape

        positions = torch.arange(num_tokens, device=x.device)
        # (batch, tokens, in_dim) + (tokens, in_dim) broadcasts over the batch.
        hidden = self.token_embedding(x) + self.pos_embedding(positions)

        hidden = self.attention(hidden)
        gru_out, _ = self.gru(hidden)

        if lengths is None:
            last_hidden = gru_out[:, -1, :]
        else:
            last_idx = lengths.to(device=gru_out.device, dtype=torch.long).clamp(1, num_tokens) - 1
            last_hidden = gru_out[torch.arange(batch_size, device=gru_out.device), last_idx]

        return self.fc(last_hidden).squeeze(-1)


def train(model, train_dataset, val_dataset, lr=1e-3, epochs=10, batch_size=64):
    """Fit ``model`` with Adam and MSE loss, printing train/validation loss every epoch.

    The regression target is the label field (third element) of each collated batch. The true
    sequence lengths are passed to the model as ``model(inputs, lengths)`` so predictions are
    taken at each sequence's last real token rather than at padding.

    Args:
        model: module called as ``model(inputs, lengths)`` returning a (batch,) tensor.
        train_dataset: dataset of ``(token_ids, score, label)`` samples used for training.
        val_dataset: dataset of the same form used for validation.
        lr: Adam learning rate.
        epochs: number of passes over ``train_dataset``.
        batch_size: samples per batch for both loaders.

    Returns:
        List with one dict per epoch: ``{"epoch", "train_loss", "val_loss"}``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # NOTE(review): MSE regression on the sentiment label. If labels are binary 0/1,
    # nn.BCEWithLogitsLoss is the usual choice; confirm the intended target.
    loss_fn = nn.MSELoss()

    train_loader = create_IMDB_dataloader(train_dataset, batch_size=batch_size)
    # Validation loss does not depend on order, so skip the shuffle.
    val_loader = create_IMDB_dataloader(val_dataset, batch_size=batch_size, shuffle=False)
    model.to(device)

    history = []
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        for inputs, seq_lengths, _scores, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs, seq_lengths)
            loss = loss_fn(outputs, labels)
            total_loss += loss.item()
            loss.backward()
            optimizer.step()
        avg_loss = total_loss / max(len(train_loader), 1)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for inputs, seq_lengths, _scores, labels in tqdm(val_loader, desc="Validation"):
                inputs = inputs.to(device)
                labels = labels.to(device)

                outputs = model(inputs, seq_lengths)
                loss = loss_fn(outputs, labels)
                val_loss += loss.item()
        avg_val_loss = val_loss / max(len(val_loader), 1)
        print(f"Validation Loss: {avg_val_loss:.4f}")

        history.append({"epoch": epoch + 1, "train_loss": avg_loss, "val_loss": avg_val_loss})

    return history

In [None]:
# Wrap the tokenized comments in datasets. The full training set keeps its own name so that
# `train_dataset` always refers to the training *subset* after the split.
full_train_dataset = IMDBDataset(tokenized_train_comments, train_sentiments, train_scores)
test_dataset = IMDBDataset(tokenized_test_comments, test_sentiments, test_scores)

# 90/10 train/validation split with a dedicated, fixed-seed generator so the split is
# identical on every Restart & Run All.
SPLIT_SEED = 42
train_size = int(0.9 * len(full_train_dataset))
val_size = len(full_train_dataset) - train_size
train_dataset, val_dataset = random_split(
    full_train_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [None]:
# Seed torch's global RNG so weight initialisation and the random ops during training
# (dropout, DataLoader shuffling) are reproducible across runs.
MODEL_SEED = 0
torch.manual_seed(MODEL_SEED)

model = GRUAttention(in_dim=64, out_dim=256, num_heads=8, context_length=2048, qkv_bias=True, dropout=0.5)

# Capture train's return value (if any) rather than leaving it as the cell's displayed output.
history = train(model, train_dataset, val_dataset, lr=1e-3, epochs=5, batch_size=32)