In [4]:
import nltk
nltk.download("wordnet")
nltk.download("omw-1.4")


[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\EG\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to
[nltk_data]     C:\Users\EG\AppData\Roaming\nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


True

In [5]:
!pip install --upgrade pip --quiet

# Core libraries
!pip install torch --quiet
!pip install transformers --quiet
!pip install datasets --quiet
!pip install tqdm --quiet

# Augmentation dependency
!pip install nltk --quiet

# Download WordNet resources
import nltk
nltk.download("wordnet")
nltk.download("omw-1.4")


[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\EG\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to
[nltk_data]     C:\Users\EG\AppData\Roaming\nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


True

In [6]:
!pip install transformers nltk datasets torch tqdm --quiet

In [8]:
# ======================================
# BERT + BiLSTM + Synonym Augmentation
# Correct, full, working version
# ======================================

import nltk
nltk.download('wordnet')
nltk.download('omw-1.4')

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
from tqdm import tqdm
from nltk.corpus import wordnet as wn
import pandas as pd
import random

from datasets import load_dataset


# ----------------------------------------
# Settings
# ----------------------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_LEN = 64
BATCH = 16
EPOCHS = 3
AUG = True


# ----------------------------------------
# Synonym augmentation
# ----------------------------------------
def synonym_augment(text, p=0.1):
    words = text.split()
    new_words = []

    for w in words:
        if random.random() < p:
            syns = wn.synsets(w)
            if syns:
                lemmas = syns[0].lemma_names()
                cand = [l.replace("_", " ") for l in lemmas if l.lower() != w.lower()]
                if cand:
                    new_words.append(random.choice(cand))
                    continue
        new_words.append(w)
    return " ".join(new_words)


# ----------------------------------------
# Load SST-5 from HuggingFace
# ----------------------------------------
dataset = load_dataset("SetFit/sst5")

train_df = pd.DataFrame(dataset["train"])[["text", "label"]]
dev_df   = pd.DataFrame(dataset["validation"])[["text", "label"]]
test_df  = pd.DataFrame(dataset["test"])[["text", "label"]]


# ----------------------------------------
# Dataset class
# ----------------------------------------
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

class SST5Dataset(Dataset):
    def __init__(self, df, augment=False):
        self.df = df
        self.augment = augment

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        text = self.df.iloc[idx].text
        label = self.df.iloc[idx].label

        # Best practice: ONLY augment some samples, not all
        if self.augment and random.random() < 0.3:
            text = synonym_augment(text, p=0.1)

        enc = tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=MAX_LEN,
            return_tensors="pt"
        )

        return {
            "input_ids": enc["input_ids"].squeeze(),
            "mask": enc["attention_mask"].squeeze(),
            "label": torch.tensor(label)
        }


train_loader = DataLoader(SST5Dataset(train_df, augment=True), batch_size=BATCH, shuffle=True)
dev_loader   = DataLoader(SST5Dataset(dev_df, augment=False), batch_size=32)
test_loader  = DataLoader(SST5Dataset(test_df, augment=False), batch_size=32)


# ----------------------------------------
# Model: BERT → BiLSTM → FC
# ----------------------------------------
class BERT_BiLSTM(nn.Module):
    def __init__(self, hidden=256, layers=1, classes=5, dropout=0.3):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        self.lstm = nn.LSTM(
            input_size=768,
            hidden_size=hidden,
            num_layers=layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout if layers > 1 else 0
        )
        self.fc = nn.Linear(hidden * 2, classes)

    def forward(self, ids, mask):
        x = self.bert(ids, attention_mask=mask).last_hidden_state
        lstm_out, (h_n, _) = self.lstm(x)
        final = torch.cat((h_n[-2], h_n[-1]), dim=1)
        return self.fc(final)


model = BERT_BiLSTM().to(DEVICE)

# UNFREEZE BERT explicitly
for p in model.bert.parameters():
    p.requires_grad = True


# ----------------------------------------
# Optimizer (very important!)
# Use SEPARATE learning rates
# ----------------------------------------
optimizer = torch.optim.AdamW([
    {"params": model.bert.parameters(), "lr": 1e-5},
    {"params": model.lstm.parameters(), "lr": 1e-4},
    {"params": model.fc.parameters(),   "lr": 1e-4},
])

criterion = nn.CrossEntropyLoss()


# ----------------------------------------
# Evaluation
# ----------------------------------------
def evaluate(loader):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in loader:
            ids = batch["input_ids"].to(DEVICE)
            mask = batch["mask"].to(DEVICE)
            labels = batch["label"].to(DEVICE)

            logits = model(ids, mask)
            preds = logits.argmax(1)

            correct += (preds == labels).sum().item()
            total += len(labels)

    return correct / total


# ----------------------------------------
# Training loop
# ----------------------------------------
for epoch in range(EPOCHS):
    model.train()
    model.bert.train()

    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        ids = batch["input_ids"].to(DEVICE)
        mask = batch["mask"].to(DEVICE)
        labels = batch["label"].to(DEVICE)

        optimizer.zero_grad()
        logits = model(ids, mask)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

    print("Dev accuracy:", evaluate(dev_loader))

print("TEST accuracy:", evaluate(test_loader))


[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\EG\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to
[nltk_data]     C:\Users\EG\AppData\Roaming\nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
Repo card metadata block was not found. Setting CardData to empty.
Epoch 1: 100%|██████████| 534/534 [01:02<00:00,  8.56it/s]


Dev accuracy: 0.4904632152588556


Epoch 2: 100%|██████████| 534/534 [01:03<00:00,  8.37it/s]


Dev accuracy: 0.5095367847411444


Epoch 3: 100%|██████████| 534/534 [01:02<00:00,  8.59it/s]


Dev accuracy: 0.5340599455040872
TEST accuracy: 0.5466063348416289
