In [1]:
LOAD_MODEL = False
MODEL_NAME = "V0_ZERO_SHOT"

EPOCHS = 15
NUM_LABELS = 5
MAX_LEN = 128
BATCH_SIZE = 16
LR = 1e-3
NUM_WORKERS = 4

In [2]:
import os
import json
import torch
import numpy as np
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertModel
from datasets import load_dataset
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)

Using device: cuda


In [4]:
BASE_PATH = "./models/"
version_dir = os.path.join(BASE_PATH, MODEL_NAME)

if LOAD_MODEL:
    if not os.path.exists(version_dir):
        raise RuntimeError(f"Model '{MODEL_NAME}' does not exist.")
else:
    if os.path.exists(version_dir):
        raise RuntimeError(f"Model '{MODEL_NAME}' already exists.")
    os.makedirs(version_dir)

In [5]:
combined_log_path = os.path.join(version_dir, "run_output.txt")
combined_log_file = open(combined_log_path, "w", encoding="utf-8")

def log(msg):
    print(msg)
    combined_log_file.write(msg + "\n")
    combined_log_file.flush()

In [6]:

dataset = load_dataset("SetFit/sst5")

Repo card metadata block was not found. Setting CardData to empty.


In [7]:
print("Labels:", set(dataset["train"]["label_text"]))

def print_dist(ds, name):
    counts = Counter(ds['label_text'])
    print(f"\n{name} distribution:")
    for k,v in counts.items():
        print(f"{k}: {v}")

print_dist(dataset["train"], "Train")
print_dist(dataset["validation"], "Val")
print_dist(dataset["test"], "Test")


Labels: {'neutral', 'very negative', 'negative', 'positive', 'very positive'}

Train distribution:
very positive: 1288
negative: 2218
neutral: 1624
positive: 2322
very negative: 1092

Val distribution:
neutral: 229
negative: 289
very negative: 139
positive: 279
very positive: 165

Test distribution:
negative: 633
very negative: 279
neutral: 389
very positive: 399
positive: 510


In [8]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

def tokenize(batch):
    return tokenizer(
        batch["text"],
        padding="max_length",
        truncation=True,
        max_length=MAX_LEN
    )

datasetMap = dataset.map(tokenize, batched=True)
datasetMap = datasetMap.rename_column("label", "labels")
datasetMap.set_format("torch", columns=["input_ids","attention_mask","labels"])

train_loader = DataLoader(datasetMap["train"], batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
val_loader   = DataLoader(datasetMap["validation"], batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)
test_loader  = DataLoader(datasetMap["test"], batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)


In [9]:
class CustomBertClassifier(nn.Module):
    def __init__(self, num_labels=NUM_LABELS):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")

        # Freeze BERT completely
        for p in self.bert.parameters():
            p.requires_grad = False

        hidden = self.bert.config.hidden_size
        self.classifier = nn.Linear(hidden, num_labels)

    def forward(self, input_ids, attention_mask):
        with torch.no_grad():
            out = self.bert(
                input_ids=input_ids,
                attention_mask=attention_mask
            )

        cls = out.last_hidden_state[:, 0, :]
        return self.classifier(cls)

In [10]:
model = CustomBertClassifier(NUM_LABELS).to(DEVICE)

if LOAD_MODEL:
    ckpt = torch.load(os.path.join(version_dir, "model.pt"), map_location=DEVICE)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()
else:
    model.train()

optimizer = torch.optim.AdamW(model.classifier.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()

In [11]:
def train_one_epoch(model, loader, opt, crit, device):
    model.train()
    total_loss, correct, total = 0, 0, 0

    for batch in loader:
        opt.zero_grad()

        ids = batch["input_ids"].to(device)
        mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        logits = model(ids, mask)
        loss = crit(logits, labels)
        loss.backward()
        opt.step()

        total_loss += loss.item()
        preds = logits.argmax(1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    return total_loss / len(loader), correct / total

def validate(model, loader, crit, device):
    model.eval()
    total_loss, correct, total = 0, 0, 0

    with torch.no_grad():
        for batch in loader:
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            logits = model(ids, mask)
            loss = crit(logits, labels)

            total_loss += loss.item()
            preds = logits.argmax(1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    return total_loss / len(loader), correct / total

In [12]:
# Track Metrics
train_losses = []
val_losses = []
train_accs = []
val_accs = []

# Training loop
for epoch in range(EPOCHS):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, DEVICE)
    val_loss, val_acc     = validate(model, val_loader, criterion, DEVICE)

    # <-- Add these lines
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accs.append(train_acc)
    val_accs.append(val_acc)

    print(f"Epoch {epoch+1}: Train Acc={train_acc:.4f} | Val Acc={val_acc:.4f}")

Epoch 1: Train Acc=0.3978 | Val Acc=0.4296
Epoch 2: Train Acc=0.4566 | Val Acc=0.4269
Epoch 3: Train Acc=0.4696 | Val Acc=0.4360
Epoch 4: Train Acc=0.4760 | Val Acc=0.4559
Epoch 5: Train Acc=0.4775 | Val Acc=0.4687
Epoch 6: Train Acc=0.4875 | Val Acc=0.4578
Epoch 7: Train Acc=0.4774 | Val Acc=0.4450
Epoch 8: Train Acc=0.4886 | Val Acc=0.4687
Epoch 9: Train Acc=0.4918 | Val Acc=0.4578
Epoch 10: Train Acc=0.4886 | Val Acc=0.4614
Epoch 11: Train Acc=0.4943 | Val Acc=0.4668
Epoch 12: Train Acc=0.4905 | Val Acc=0.4777
Epoch 13: Train Acc=0.4930 | Val Acc=0.4587
Epoch 14: Train Acc=0.4938 | Val Acc=0.4659
Epoch 15: Train Acc=0.4901 | Val Acc=0.4460


In [None]:
def test(model, loader, crit, device):
    model.eval()
    preds, labels_list = [], []
    total_loss = 0
    with torch.no_grad():
        for batch in loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            logits = model(input_ids, attention_mask)
            loss = crit(logits, labels)
            total_loss += loss.item()

            preds.extend(logits.argmax(1).cpu().tolist())
            labels_list.extend(labels.cpu().tolist())

    return total_loss/len(loader), preds, labels_list

test_loss, preds, labels_list = test(model, test_loader, criterion, DEVICE)
test_acc = accuracy_score(labels_list, preds)
report = classification_report(labels_list, preds, digits=4)
cm = confusion_matrix(labels_list, preds)

# Build FINAL RESULTS block
final_results_text = (
    "========== FINAL RESULTS ==========\n"
    f"Model Version: {MODEL_NAME}\n\n"
    f"Final Train Accuracy: {train_acc:.4f}\n"
    f"Final Validation Accuracy: {val_acc:.4f}\n\n"
    f"Test Loss: {test_loss:.4f}\n"
    f"Test Accuracy: {test_acc:.4f}\n\n"
    "Classification Report:\n"
    f"{report}\n"
    "====================================\n\n"
)

# Path to your output file
out_path = os.path.join(version_dir, "run_output.txt")

# Read the old content
try:
    with open(out_path, "r") as f:
        old_content = f.read()
except FileNotFoundError:
    old_content = ""

# Write FINAL RESULTS at top, followed by original content
with open(out_path, "w") as f:
    f.write(final_results_text + old_content)

print("Test Accuracy:", test_acc)
print(report)


Test Accuracy: 0.47692307692307695
              precision    recall  f1-score   support

           0     0.4186    0.4516    0.4345       279
           1     0.5047    0.5972    0.5470       633
           2     0.3732    0.1362    0.1996       389
           3     0.4373    0.5196    0.4749       510
           4     0.5631    0.5815    0.5721       399

    accuracy                         0.4769      2210
   macro avg     0.4594    0.4572    0.4456      2210
weighted avg     0.4657    0.4769    0.4596      2210



In [14]:
# === PLOT & SAVE TRAINING CURVES ===

# epochs list
epochs = list(range(1, len(train_losses) + 1))

# --- create a combined figure ---
fig, ax = plt.subplots(1, 2, figsize=(14, 5))

# --- LOSS PLOT ---
ax[0].plot(epochs, train_losses, label="Train Loss")
ax[0].plot(epochs, val_losses, label="Validation Loss")
ax[0].set_title("Training and Validation Loss")
ax[0].set_xlabel("Epoch")
ax[0].set_ylabel("Loss")
ax[0].legend()

# --- ACCURACY PLOT ---
ax[1].plot(epochs, train_accs, label="Train Accuracy")
ax[1].plot(epochs, val_accs, label="Validation Accuracy")
ax[1].set_title("Training and Validation Accuracy")
ax[1].set_xlabel("Epoch")
ax[1].set_ylabel("Accuracy")
ax[1].legend()

plt.tight_layout()

# === SAVE to version folder ===
curve_path = os.path.join(version_dir, "training_curves.png")
fig.savefig(curve_path, dpi=150, bbox_inches="tight")

plt.close()


In [17]:
plt.figure(figsize=(6,5))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Confusion Matrix")
plt.tight_layout()
plt.savefig(os.path.join(version_dir, "confusion_matrix.png"))
plt.close()


In [15]:
torch.save({"model_state_dict": model.state_dict()},
os.path.join(version_dir, "model.pt"))

with open(os.path.join(version_dir, "config.json"), "w") as f:
    json.dump({
    "MODEL_NAME": MODEL_NAME,
    "EPOCHS": EPOCHS,
    "LR": LR,
    "MAX_LEN": MAX_LEN,
    "BATCH_SIZE": BATCH_SIZE
    }, f, indent=4)