In [4]:
# === Imports ===
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel, get_scheduler
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
from sklearn.model_selection import ParameterGrid
from torch.optim import AdamW
from tqdm import tqdm
import pickle
import csv
import os

# === Load Data ===
# Pre-split train / validation tables; later code reads their "text" and
# "label" columns.
train_df, val_df = pd.read_csv("splits/train.csv"), pd.read_csv("splits/val.csv")

# === Load Label Encoder ===
# NOTE(security): pickle.load can execute arbitrary code while deserializing.
# Only load this file if it was produced by a trusted run of this project.
# NOTE(review): unpickling presumably also needs a scikit-learn version
# compatible with the one that wrote the file — confirm.
with open("splits/label_encoder.pkl", "rb") as f:
    label_encoder = pickle.load(f)

# One output unit per class known to the fitted encoder.
num_classes = len(label_encoder.classes_)

# === Device / Tokenizer ===
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-base-cased-v1.1")


# === Dataset ===
class IntentDataset(Dataset):
    """Map-style dataset that tokenizes one text per item on access.

    Args:
        texts: Iterable of raw texts (list, NumPy array, pandas Series, ...).
            Every item is converted with ``str()`` before tokenization.
        labels: Iterable of integer class ids, aligned with ``texts``.
        tokenizer: Callable invoked as
            ``tokenizer(text, truncation=True, padding='max_length',
            max_length=..., return_tensors="pt")`` that returns a mapping with
            ``"input_ids"`` and ``"attention_mask"`` entries.
        max_len: Length every sequence is padded / truncated to.

    Raises:
        ValueError: If ``texts`` and ``labels`` have different lengths.
    """

    def __init__(self, texts, labels, tokenizer, max_len=192):
        # Materialize to plain lists so that ``idx`` in __getitem__ is always
        # positional.  Indexing a pandas Series with ``series[idx]`` is
        # label-based and fails (or returns the wrong row) when the Series
        # index is not exactly 0..n-1, e.g. after filtering or splitting.
        self.texts = list(texts)
        self.labels = list(labels)
        if len(self.texts) != len(self.labels):
            raise ValueError(
                f"texts and labels must have the same length, "
                f"got {len(self.texts)} and {len(self.labels)}"
            )
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the tokenized sample at position ``idx``.

        Returns:
            dict with ``"input_ids"`` and ``"attention_mask"`` (the tokenizer
            outputs with their leading dimension squeezed away) and
            ``"label"`` as a ``torch.long`` scalar tensor.
        """
        text = str(self.texts[idx])
        label = self.labels[idx]
        encodings = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_len,
            return_tensors="pt",
        )
        return {
            # The tokenizer is called with a single text, so dim 0 has size 1;
            # drop it so the DataLoader can stack items into a batch.
            "input_ids": encodings["input_ids"].squeeze(0),
            "attention_mask": encodings["attention_mask"].squeeze(0),
            "label": torch.tensor(label, dtype=torch.long),
        }

# === Model ===
class BERTClassifier(nn.Module):
    """Pretrained transformer encoder with a dropout + linear classification head.

    Args:
        num_classes: Number of output logits produced by the head.
        dropout: Dropout probability applied to the pooled encoder output.
        model_name: Hugging Face model id or local path of the encoder.
            Defaults to the BioBERT checkpoint this script was written for.
    """

    def __init__(self, num_classes, dropout=0.1, model_name="dmis-lab/biobert-base-cased-v1.1"):
        super().__init__()
        # ``trust_remote_code`` is intentionally left at its default (False):
        # enabling it lets code shipped in the Hub repository run locally,
        # which a standard BERT-architecture checkpoint does not need.
        # ``use_safetensors=True`` keeps weight loading off pickle-based files.
        self.bert = AutoModel.from_pretrained(model_name, use_safetensors=True)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return unnormalized class logits (last dimension = ``num_classes``).

        Args:
            input_ids: Token ids passed to the encoder
                (presumably shaped (batch, seq_len) — confirm with caller).
            attention_mask: Mask passed to the encoder alongside ``input_ids``.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Relies on the encoder exposing ``pooler_output``.
        pooled_output = outputs.pooler_output
        return self.fc(self.dropout(pooled_output))

# === Train / Eval Functions ===
def train_epoch(model, loader, loss_fn, optimizer, scheduler, device):
    """Run one optimization pass over ``loader`` and report loss / accuracy.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` that
            returns per-class scores with classes on dim 1.
        loader: Iterable of batches (dicts with ``"input_ids"``,
            ``"attention_mask"`` and ``"label"`` tensors).
        loss_fn: Callable ``loss_fn(outputs, labels)`` returning a scalar
            tensor.  Assumed to reduce to a per-batch mean (the default of
            ``nn.CrossEntropyLoss``) — see the accumulation note below.
        optimizer: Optimizer stepped once per batch.
        scheduler: LR scheduler stepped once per batch, after the optimizer.
        device: Device the batch tensors are moved to.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen.  ``(0.0, 0.0)`` when
        the loader yields no samples.
    """
    model.train()
    total_loss, correct, total = 0.0, 0, 0
    for batch in tqdm(loader, desc="Training"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)
        batch_size = labels.size(0)

        outputs = model(input_ids, attention_mask)
        loss = loss_fn(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        # Clip the global gradient norm to 1.0 before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()

        # ``loss`` is a per-batch mean, so weight it by the batch size before
        # summing; dividing the plain sum of batch means by the sample count
        # would understate the loss by roughly a factor of the batch size and
        # make runs with different batch sizes incomparable.
        total_loss += loss.item() * batch_size
        preds = torch.argmax(outputs, dim=1)
        correct += (preds == labels).sum().item()
        total += batch_size

    if total == 0:
        # Empty loader: nothing to average, avoid ZeroDivisionError.
        return 0.0, 0.0
    return total_loss / total, correct / total

def eval_epoch(model, loader, loss_fn, device):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` that
            returns per-class scores with classes on dim 1.
        loader: Iterable of batches (dicts with ``"input_ids"``,
            ``"attention_mask"`` and ``"label"`` tensors).
        loss_fn: Callable ``loss_fn(outputs, labels)`` returning a scalar
            tensor.  Assumed to reduce to a per-batch mean (the default of
            ``nn.CrossEntropyLoss``) — see the accumulation note below.
        device: Device the batch tensors are moved to.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen.  ``(0.0, 0.0)`` when
        the loader yields no samples.
    """
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)
            batch_size = labels.size(0)

            outputs = model(input_ids, attention_mask)
            loss = loss_fn(outputs, labels)

            # ``loss`` is a per-batch mean, so weight it by the batch size
            # before summing; dividing the plain sum of batch means by the
            # sample count would understate the loss by roughly a factor of
            # the batch size.
            total_loss += loss.item() * batch_size
            preds = torch.argmax(outputs, dim=1)
            correct += (preds == labels).sum().item()
            total += batch_size

    if total == 0:
        # Empty loader: nothing to average, avoid ZeroDivisionError.
        return 0.0, 0.0
    return total_loss / total, correct / total


# === Grid Search Setup ===
# Exhaustive search space.  ParameterGrid expands it to the Cartesian product
# of all value lists (3 * 2 * 3 * 2 * 3 * 4 = 432 configurations).
param_grid = {
    "learning_rate": [3e-5, 1e-5, 2e-5],
    "batch_size": [16, 32],
    "dropout": [0.15, 0.1, 0.2],
    "weight_decay": [0.01, 0.05],
    "warmup_proportion": [0.05, 0.1, 0.15],
    "boost_factor": [1.0, 1.3, 1.5, 2.0]
}

all_params = list(ParameterGrid(param_grid))

# Column order of the results file follows the insertion order of param_grid.
param_names = list(param_grid.keys())

csv_file = "gridsearch_results.csv"
# NOTE: mode 'w' truncates the file, so results of a previous run are lost.
with open(csv_file, mode='w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(param_names + ["val_accuracy"])

# --- Loop-invariant setup (identical for every configuration) ---
train_dataset = IntentDataset(train_df["text"], train_df["label"], tokenizer)
val_dataset = IntentDataset(val_df["text"], val_df["label"], tokenizer)

# 'balanced' class weights computed from the training labels.
class_weights = compute_class_weight('balanced', classes=np.unique(train_df['label']), y=train_df['label'])
if len(class_weights) != num_classes:
    # The weight vector is indexed by encoded class id below and handed to
    # CrossEntropyLoss, so it must have exactly one entry per encoder class.
    raise ValueError(
        f"Training data contains {len(class_weights)} distinct labels but the "
        f"label encoder defines {num_classes} classes"
    )
base_class_weights = torch.tensor(class_weights, dtype=torch.float)

# Classes whose loss weight is multiplied by the per-config "boost_factor".
boost_labels = ["treatment", "treatment method", "symptoms", "disease manifestations"]
boost_indices = label_encoder.transform(boost_labels)

num_epochs = 10
patience = 2  # epochs without val-accuracy improvement before stopping

best_overall_acc = 0
best_overall_params = None

for params in all_params:
    print(f"\n🚀 Trying config: {params}")

    # Drop references to the previous configuration's objects *before*
    # clearing the cache; otherwise their memory cannot be released and the
    # old and new model coexist while the new one is being built.
    model = optimizer = scheduler = loss_fn = None

    # Clear GPU.  Guarded because the script falls back to CPU and
    # torch.cuda.ipc_collect() may raise on hosts without CUDA.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

    # Only the loaders depend on the configuration (batch size).
    train_loader = DataLoader(train_dataset, batch_size=params["batch_size"], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=params["batch_size"])

    # Boost underperforming classes on a fresh copy so the boost does not
    # accumulate across configurations.
    class_weights_tensor = base_class_weights.clone().to(device)
    for idx in boost_indices:
        class_weights_tensor[idx] *= params["boost_factor"]

    model = BERTClassifier(num_classes, dropout=params["dropout"]).to(device)
    loss_fn = nn.CrossEntropyLoss(weight=class_weights_tensor)
    optimizer = AdamW(model.parameters(), lr=params["learning_rate"], weight_decay=params["weight_decay"])

    # Linear schedule with warmup, sized for the full epoch budget (early
    # stopping may end training before the schedule completes).
    total_steps = len(train_loader) * num_epochs
    warmup_steps = int(params["warmup_proportion"] * total_steps)
    scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)

    best_val_acc = 0
    patience_counter = 0

    for epoch in range(num_epochs):
        train_loss, train_acc = train_epoch(model, train_loader, loss_fn, optimizer, scheduler, device)
        val_loss, val_acc = eval_epoch(model, val_loader, loss_fn, device)
        print(f"Epoch {epoch+1}: train_acc={train_acc:.4f}, val_acc={val_acc:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("⛔ Early stopping")
                break

    # Append immediately so finished configurations survive an interrupted run.
    with open(csv_file, mode='a', newline='') as f:
        writer = csv.writer(f)
        writer.writerow([params[k] for k in param_names] + [round(best_val_acc, 4)])

    if best_val_acc > best_overall_acc:
        best_overall_acc = best_val_acc
        best_overall_params = params

print("\n✅ Grid search complete. Results saved to:", csv_file)
print(f"\n🎉 Best overall val_accuracy: {best_overall_acc:.4f}")
print("Best hyperparameters:")
if best_overall_params is None:
    # Empty grid, or no configuration scored above the initial 0.
    print("  (none recorded)")
else:
    for k, v in best_overall_params.items():
        print(f"  {k}: {v}")


https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations



🚀 Trying config: {'batch_size': 16, 'boost_factor': 1.0, 'dropout': 0.15, 'learning_rate': 3e-05, 'warmup_proportion': 0.05, 'weight_decay': 0.01}


Training: 100%|██████████| 754/754 [03:26<00:00,  3.66it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00, 10.04it/s]


Epoch 1: train_acc=0.3888, val_acc=0.5236


Training: 100%|██████████| 754/754 [03:37<00:00,  3.47it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00, 10.02it/s]


Epoch 2: train_acc=0.6197, val_acc=0.6284


Training: 100%|██████████| 754/754 [03:40<00:00,  3.42it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00,  9.95it/s]


Epoch 3: train_acc=0.7521, val_acc=0.6835


Training: 100%|██████████| 754/754 [03:33<00:00,  3.54it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00, 10.11it/s]


Epoch 4: train_acc=0.8526, val_acc=0.7213


Training: 100%|██████████| 754/754 [03:31<00:00,  3.56it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00, 10.10it/s]


Epoch 5: train_acc=0.9177, val_acc=0.7691


Training: 100%|██████████| 754/754 [03:31<00:00,  3.56it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00, 10.09it/s]


Epoch 6: train_acc=0.9485, val_acc=0.7870


Training: 100%|██████████| 754/754 [03:31<00:00,  3.56it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00, 10.10it/s]


Epoch 7: train_acc=0.9655, val_acc=0.7936


Training: 100%|██████████| 754/754 [03:32<00:00,  3.56it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00, 10.06it/s]


Epoch 8: train_acc=0.9771, val_acc=0.7996


Training: 100%|██████████| 754/754 [03:32<00:00,  3.55it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00, 10.07it/s]


Epoch 9: train_acc=0.9834, val_acc=0.8056


Training: 100%|██████████| 754/754 [03:31<00:00,  3.57it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00, 10.05it/s]


Epoch 10: train_acc=0.9858, val_acc=0.8042

🚀 Trying config: {'batch_size': 16, 'boost_factor': 1.0, 'dropout': 0.15, 'learning_rate': 3e-05, 'warmup_proportion': 0.05, 'weight_decay': 0.05}


Training: 100%|██████████| 754/754 [03:32<00:00,  3.54it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00, 10.11it/s]


Epoch 1: train_acc=0.3880, val_acc=0.5421


Training: 100%|██████████| 754/754 [03:32<00:00,  3.54it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00, 10.08it/s]


Epoch 2: train_acc=0.6216, val_acc=0.6131


Training: 100%|██████████| 754/754 [03:32<00:00,  3.55it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00, 10.09it/s]


Epoch 3: train_acc=0.7596, val_acc=0.7014


Training: 100%|██████████| 754/754 [03:32<00:00,  3.55it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00, 10.10it/s]


Epoch 4: train_acc=0.8596, val_acc=0.7419


Training: 100%|██████████| 754/754 [03:31<00:00,  3.56it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00, 10.08it/s]


Epoch 5: train_acc=0.9162, val_acc=0.7525


Training: 100%|██████████| 754/754 [03:32<00:00,  3.55it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00, 10.10it/s]


Epoch 6: train_acc=0.9465, val_acc=0.7790


Training: 100%|██████████| 754/754 [03:32<00:00,  3.55it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00, 10.10it/s]


Epoch 7: train_acc=0.9677, val_acc=0.7877


Training: 100%|██████████| 754/754 [03:32<00:00,  3.55it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00, 10.05it/s]


Epoch 8: train_acc=0.9749, val_acc=0.7923


Training: 100%|██████████| 754/754 [03:32<00:00,  3.55it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00, 10.09it/s]


Epoch 9: train_acc=0.9813, val_acc=0.8076


Training: 100%|██████████| 754/754 [03:32<00:00,  3.56it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00, 10.11it/s]


Epoch 10: train_acc=0.9868, val_acc=0.8049

🚀 Trying config: {'batch_size': 16, 'boost_factor': 1.0, 'dropout': 0.15, 'learning_rate': 3e-05, 'warmup_proportion': 0.1, 'weight_decay': 0.01}


Training: 100%|██████████| 754/754 [03:38<00:00,  3.46it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00, 10.16it/s]


Epoch 1: train_acc=0.3516, val_acc=0.5149


Training: 100%|██████████| 754/754 [03:33<00:00,  3.53it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00, 10.21it/s]


Epoch 2: train_acc=0.5991, val_acc=0.6238


Training: 100%|██████████| 754/754 [03:33<00:00,  3.53it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00, 10.09it/s]


Epoch 3: train_acc=0.7430, val_acc=0.6423


Training: 100%|██████████| 754/754 [03:31<00:00,  3.56it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00, 10.12it/s]


Epoch 4: train_acc=0.8429, val_acc=0.7419


Training: 100%|██████████| 754/754 [03:36<00:00,  3.48it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00,  9.80it/s]


Epoch 5: train_acc=0.9105, val_acc=0.7711


Training: 100%|██████████| 754/754 [03:46<00:00,  3.33it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00,  9.89it/s]


Epoch 6: train_acc=0.9465, val_acc=0.7817


Training: 100%|██████████| 754/754 [03:46<00:00,  3.33it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00,  9.86it/s]


Epoch 7: train_acc=0.9643, val_acc=0.7936


Training: 100%|██████████| 754/754 [03:49<00:00,  3.28it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00,  9.84it/s]


Epoch 8: train_acc=0.9762, val_acc=0.8009


Training: 100%|██████████| 754/754 [03:49<00:00,  3.28it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00,  9.83it/s]


Epoch 9: train_acc=0.9812, val_acc=0.8009


Training: 100%|██████████| 754/754 [03:46<00:00,  3.32it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00,  9.84it/s]


Epoch 10: train_acc=0.9865, val_acc=0.8062

🚀 Trying config: {'batch_size': 16, 'boost_factor': 1.0, 'dropout': 0.15, 'learning_rate': 3e-05, 'warmup_proportion': 0.1, 'weight_decay': 0.05}


Training: 100%|██████████| 754/754 [03:50<00:00,  3.27it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00,  9.72it/s]


Epoch 1: train_acc=0.3494, val_acc=0.5156


Training: 100%|██████████| 754/754 [03:49<00:00,  3.28it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00,  9.73it/s]


Epoch 2: train_acc=0.5777, val_acc=0.6025


Training: 100%|██████████| 754/754 [03:52<00:00,  3.24it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00,  9.75it/s]


Epoch 3: train_acc=0.7234, val_acc=0.6821


Training: 100%|██████████| 754/754 [03:51<00:00,  3.25it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00,  9.75it/s]


Epoch 4: train_acc=0.8324, val_acc=0.7253


Training: 100%|██████████| 754/754 [03:50<00:00,  3.28it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00,  9.73it/s]


Epoch 5: train_acc=0.9019, val_acc=0.7631


Training: 100%|██████████| 754/754 [03:52<00:00,  3.24it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00,  9.73it/s]


Epoch 6: train_acc=0.9422, val_acc=0.7764


Training: 100%|██████████| 754/754 [03:52<00:00,  3.24it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00,  9.70it/s]


Epoch 7: train_acc=0.9648, val_acc=0.7810


Training: 100%|██████████| 754/754 [03:52<00:00,  3.24it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00,  9.70it/s]


Epoch 8: train_acc=0.9727, val_acc=0.7910


Training: 100%|██████████| 754/754 [03:51<00:00,  3.26it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00,  9.81it/s]


Epoch 9: train_acc=0.9829, val_acc=0.8049


Training: 100%|██████████| 754/754 [03:53<00:00,  3.22it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00,  9.68it/s]


Epoch 10: train_acc=0.9873, val_acc=0.8009

🚀 Trying config: {'batch_size': 16, 'boost_factor': 1.0, 'dropout': 0.15, 'learning_rate': 3e-05, 'warmup_proportion': 0.15, 'weight_decay': 0.01}


Training: 100%|██████████| 754/754 [03:53<00:00,  3.22it/s]
Evaluating: 100%|██████████| 95/95 [00:09<00:00,  9.70it/s]


Epoch 1: train_acc=0.3096, val_acc=0.4891


Training:  88%|████████▊ | 662/754 [03:25<00:28,  3.25it/s]

: 