In [None]:
# ============================
# BERT + BiLSTM for SST-5
# ============================

!pip install transformers nlpaug datasets torch tqdm --quiet

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
import nlpaug.augmenter.word as naw
from tqdm import tqdm
import pandas as pd
import numpy as np

# Run on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# -----------------------------
# Load SST-5 data
# Expecting: label<TAB>sentence
# -----------------------------
def load_sst5(path):
    """Read an SST-5 split stored as headerless ``label<TAB>sentence`` lines.

    Args:
        path: path to the tab-separated file.

    Returns:
        DataFrame with an int ``label`` column and a str ``text`` column.
    """
    import csv  # local import keeps this block self-contained

    df = pd.read_csv(
        path,
        sep="\t",
        header=None,
        names=["label", "text"],
        # Sentences are raw text: never treat '"' as a CSV quote character,
        # otherwise quoted phrases are stripped or merged with later fields.
        quoting=csv.QUOTE_NONE,
        # Keep sentences such as "NA" or "null" as text instead of NaN.
        keep_default_na=False,
        dtype={"text": str},
        encoding="utf-8",
    )
    df["label"] = df["label"].astype(int)
    return df

# Load the three splits in train / dev / test order.
train_df, dev_df, test_df = (
    load_sst5(split_path) for split_path in ("train.tsv", "dev.tsv", "test.tsv")
)

# -----------------------------
# Optional NLPAUG augmentation
# -----------------------------
AUG = True   # set to False to disable augmentation

# WordNet-based synonym replacement with aug_p=0.1.
# NOTE(review): nlpaug's wordnet source relies on NLTK data (wordnet and a POS
# tagger) being available locally — confirm it is downloaded in a fresh runtime.
# ``aug`` is always bound (None when disabled) so references to it never raise
# NameError.
aug = naw.SynonymAug(aug_src="wordnet", aug_p=0.1) if AUG else None


def maybe_augment(t):
    """Return a synonym-augmented version of sentence ``t``, or ``t`` itself.

    Augmentation happens only when ``AUG`` is true and an augmenter exists.
    Newer nlpaug releases return a *list* of strings from ``augment``; this
    always returns a single string so the tokenizer sees one sentence, not a
    batch. An empty result falls back to the original sentence.
    """
    if not AUG or aug is None:
        return t
    augmented = aug.augment(t)
    if isinstance(augmented, list):
        return augmented[0] if augmented else t
    return augmented

# -----------------------------
# Dataset wrapper
# -----------------------------
# Uncased BERT tokenizer; uses the same checkpoint name as the model's encoder.
tokenizer = BertTokenizer.from_pretrained(
    pretrained_model_name_or_path="bert-base-uncased",
)

class SST5Dataset(Dataset):
    """Map-style dataset yielding tokenized SST-5 examples.

    Args:
        df: DataFrame with ``text`` and ``label`` columns.
        augment: if true, each sentence goes through ``maybe_augment`` every
            time it is fetched.
        tokenizer: callable with the Hugging Face tokenizer call signature;
            defaults to the module-level ``tokenizer``.
        max_length: every example is padded/truncated to this many tokens.

    Each item is a dict with 1-D ``input_ids`` and ``mask`` tensors of length
    ``max_length`` and a 0-d int64 ``label`` tensor.
    """

    def __init__(self, df, augment=False, tokenizer=None, max_length=64):
        self.df = df
        self.augment = augment
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Snapshot the columns once: ``df.iloc[idx]`` builds a fresh Series on
        # every access, which is needlessly slow inside ``__getitem__``.
        self._texts = df["text"].tolist()
        self._labels = df["label"].tolist()

    def __len__(self):
        return len(self._texts)

    def __getitem__(self, idx):
        sentence = self._texts[idx]
        label = self._labels[idx]

        if self.augment:
            sentence = maybe_augment(sentence)

        # Resolved lazily so the module-level tokenizer is only needed when no
        # explicit one was supplied.
        tok = self.tokenizer if self.tokenizer is not None else tokenizer
        enc = tok(
            sentence,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )

        return {
            # Drop only the leading batch axis of size 1, so the result stays
            # 1-D even when max_length == 1.
            "input_ids": enc["input_ids"].squeeze(0),
            "mask": enc["attention_mask"].squeeze(0),
            "label": torch.tensor(label, dtype=torch.long),
        }

# Only the training split is augmented.
train_ds, dev_ds, test_ds = (
    SST5Dataset(train_df, augment=True),
    SST5Dataset(dev_df, augment=False),
    SST5Dataset(test_df, augment=False),
)

# Training batches are shuffled (batch size 16); dev/test keep file order
# with batch size 32.
train_loader = DataLoader(train_ds, batch_size=16, shuffle=True)
dev_loader, test_loader = (
    DataLoader(split_ds, batch_size=32) for split_ds in (dev_ds, test_ds)
)

# -----------------------------
# Model: BERT → BiLSTM → linear
# -----------------------------
class BERT_BiLSTM(nn.Module):
    """BERT encoder -> BiLSTM over token states -> linear classifier.

    Args:
        lstm_hidden: hidden size of each LSTM direction.
        layers: number of stacked LSTM layers.
        num_classes: number of output logits.
        dropout: dropout probability, applied between stacked LSTM layers
            (only when ``layers > 1``) and to the pooled BiLSTM state before
            the classifier.
        bert_name: checkpoint passed to ``BertModel.from_pretrained``.
    """

    def __init__(self, lstm_hidden=256, layers=1, num_classes=5, dropout=0.3,
                 bert_name="bert-base-uncased"):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_name)
        self.lstm = nn.LSTM(
            # Read the width from the checkpoint rather than hard-coding 768.
            input_size=self.bert.config.hidden_size,
            hidden_size=lstm_hidden,
            num_layers=layers,
            batch_first=True,
            bidirectional=True,
            # nn.LSTM only applies its own dropout between stacked layers.
            dropout=dropout if layers > 1 else 0,
        )
        # Without this, ``dropout`` would have no effect at the default layers=1.
        self.drop = nn.Dropout(dropout)
        self.out = nn.Linear(lstm_hidden * 2, num_classes)

    def forward(self, ids, mask):
        """Return logits of shape (batch, num_classes).

        ``ids``/``mask`` are (batch, seq) tensors where ``mask`` is 1 for real
        tokens. Assumes right padding (real tokens form a prefix of each row).
        """
        x = self.bert(input_ids=ids, attention_mask=mask).last_hidden_state
        # Pack so the LSTM stops at each sentence's true length. Otherwise the
        # forward direction's final state is taken after reading every pad
        # position, and the backward direction starts from them.
        lengths = mask.sum(dim=1).clamp(min=1).to("cpu")
        packed = nn.utils.rnn.pack_padded_sequence(
            x, lengths, batch_first=True, enforce_sorted=False
        )
        # With enforce_sorted=False, h_n is returned in the original batch order.
        _, (h_n, _) = self.lstm(packed)
        # h_n[-2] / h_n[-1]: the last layer's forward / backward final states.
        hidden = torch.cat((h_n[-2], h_n[-1]), dim=1)
        return self.out(self.drop(hidden))

model = BERT_BiLSTM()
model.to(DEVICE)  # nn.Module.to moves parameters in place and returns self

# -----------------------------
# Training setup
# -----------------------------
criterion = nn.CrossEntropyLoss()
# NOTE(review): a single learning rate is used for every parameter, including
# the freshly initialised BiLSTM and linear head; consider a larger lr there.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=2e-5)

# -----------------------------
# Eval helper
# -----------------------------
def evaluate(loader, net=None, device=None):
    """Return the classification accuracy of a model over ``loader``.

    Args:
        loader: iterable of batch dicts with ``input_ids``, ``mask`` and
            ``label`` tensors.
        net: model to evaluate; defaults to the module-level ``model``.
        device: device batches are moved to; defaults to ``DEVICE``.

    Returns:
        Fraction of argmax predictions equal to the labels, or ``nan`` when
        ``loader`` yields no examples.

    The model's train/eval mode is restored before returning.
    """
    net = model if net is None else net
    device = DEVICE if device is None else device

    was_training = net.training
    net.eval()
    correct, total = 0, 0
    try:
        with torch.no_grad():
            for batch in loader:
                ids = batch["input_ids"].to(device)
                mask = batch["mask"].to(device)
                labels = batch["label"].to(device)

                preds = net(ids, mask).argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
    finally:
        net.train(was_training)
    # An empty split would otherwise raise ZeroDivisionError.
    return correct / total if total else float("nan")

# -----------------------------
# Training loop
# -----------------------------
EPOCHS = 3
MAX_GRAD_NORM = 1.0  # gradient-norm clip; 1.0 is a common choice for BERT fine-tuning

# Track the best dev checkpoint so the final test uses the model selected on
# dev, not whatever the last epoch produced.
best_dev_acc = float("-inf")
best_state = None

for epoch in range(EPOCHS):
    model.train()
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        ids = batch["input_ids"].to(DEVICE)
        mask = batch["mask"].to(DEVICE)
        labels = batch["label"].to(DEVICE)

        optimizer.zero_grad()
        logits = model(ids, mask)
        loss = criterion(logits, labels)
        loss.backward()
        # Bound the update size; the stacked BERT + LSTM can produce gradient spikes.
        nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
        optimizer.step()

    dev_acc = evaluate(dev_loader)
    print(f"Dev accuracy: {dev_acc:.4f}")
    if dev_acc > best_dev_acc:
        best_dev_acc = dev_acc
        # Clone to CPU so later optimizer steps don't overwrite the snapshot.
        best_state = {
            k: v.detach().cpu().clone() for k, v in model.state_dict().items()
        }

if best_state is not None:
    model.load_state_dict(best_state)

# -----------------------------
# Held-out test evaluation
# -----------------------------
test_acc = evaluate(loader=test_loader)
print("Test accuracy:", test_acc)
