In [262]:
import torch
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.optim import AdamW
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
from torch.utils.data import Dataset, DataLoader
import numpy as np
from datetime import datetime
from tqdm import tqdm


In [264]:
import pandas as pd
import ast

# Read CSV normally
df = pd.read_csv("preprocessed_medical_abstracts.csv")

# ✅ Convert stringified list back to real Python list
df['tokens_no_stopwords'] = df['tokens_no_stopwords'].apply(ast.literal_eval)

# 🔁 Recreate token_string
df['token_string'] = df['tokens_no_stopwords'].apply(
    lambda sents: ' '.join([word for sent in sents for word in sent])
)


In [265]:
print(df['token_string'].iloc[0][:200])

tissue change loose prosthesis canine model investigate effect antiinflammatory agent aseptically loosen prosthesis provide means investigate vivo vitro activity cell associate loosening process seven


In [266]:
# Encode labels
le = LabelEncoder()
df['label_enc'] = le.fit_transform(df['condition_label'])

# Split data
train_texts, val_texts, train_labels, val_labels = train_test_split(
    df['token_string'].tolist(), df['label_enc'].tolist(), test_size=0.2, random_state=42, stratify=df['label_enc'].tolist()
)

In [267]:
for i in range(3):
    print(f"Text: {train_texts[i]}")
    print(f"Label: {train_labels[i]}")
    print("---")

Text: comparison cardiac catheterization doppler echocardiography decision operate aortic mitral valve disease clinical decision utilize doppler echocardiographic cardiac catheterization datum compare adult patient isolated combine aortic mitral valve disease clinical decision operate operate remain uncertain experienced cardiologist doppler echocardiographic cardiac catheterization datum prospective evaluation perform consecutive patient mean age year valvular heart disease consider surgical treatment basis clinical information patient undergo cardiac catheterization detailed doppler echocardiographic examination set cardiologist decision maker know patient identity clinical information combination doppler echocardiographic cardiac catheterization datum combination doppler echocardiographic clinical datum consider inadequate clinical decision making patient aortic patient mitral valve disease combination cardiac catheterization clinical datum consider inadequate patient aortic patient

In [None]:
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
# tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

train_encodings = tokenizer(train_texts, truncation=True, padding=True, max_length=200)
val_encodings = tokenizer(val_texts, truncation=True, padding=True, max_length=200)


In [269]:
# View first example
print("Input IDs:", train_encodings['input_ids'][0])
print("Attention Mask:", train_encodings['attention_mask'][0])
print("Decoded back:", tokenizer.decode(train_encodings['input_ids'][0]))


Input IDs: [101, 7831, 15050, 4937, 27065, 11124, 9276, 2079, 9397, 3917, 9052, 11522, 26535, 3247, 5452, 20118, 28228, 2278, 10210, 7941, 10764, 4295, 6612, 3247, 16462, 2079, 9397, 3917, 9052, 11522, 3695, 14773, 15050, 4937, 27065, 11124, 9276, 23755, 2819, 12826, 4639, 5776, 7275, 11506, 20118, 28228, 2278, 10210, 7941, 10764, 4295, 6612, 3247, 5452, 5452, 3961, 9662, 5281, 4003, 20282, 22522, 2079, 9397, 3917, 9052, 11522, 3695, 14773, 15050, 4937, 27065, 11124, 9276, 23755, 2819, 17464, 9312, 4685, 5486, 5776, 2812, 2287, 2095, 11748, 19722, 8017, 2540, 4295, 5136, 11707, 3949, 3978, 6612, 2592, 5776, 13595, 15050, 4937, 27065, 11124, 9276, 6851, 2079, 9397, 3917, 9052, 11522, 3695, 14773, 7749, 2275, 4003, 20282, 22522, 3247, 9338, 2113, 5776, 4767, 6612, 2592, 5257, 2079, 9397, 3917, 9052, 11522, 3695, 14773, 15050, 4937, 27065, 11124, 9276, 23755, 2819, 5257, 2079, 9397, 3917, 9052, 11522, 3695, 14773, 6612, 23755, 2819, 5136, 14710, 6612, 3247, 2437, 5776, 20118, 28228, 2278,

In [270]:
class ClinicalDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        return {
            key: torch.tensor(val[idx]) for key, val in self.encodings.items()
        } | {"labels": torch.tensor(self.labels[idx])}

In [271]:
train_dataset = ClinicalDataset(train_encodings, train_labels)
val_dataset = ClinicalDataset(val_encodings, val_labels)

In [272]:
print(train_dataset)

<__main__.ClinicalDataset object at 0x7011a3e4c0d0>


In [None]:
from transformers import AutoModelForSequenceClassification
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=len(le.classes_)
).to(device)



# ✅ Replace the classifier to increase dropout
model.classifier = nn.Sequential(
    nn.Dropout(p=0.5),  # 👈 increased from 0.1 to 0.3
    nn.Linear(model.config.hidden_size, len(le.classes_))
)

model.to(device)

# ✅ Freeze the first 6 layers
for name, param in model.bert.encoder.layer[:6].named_parameters():
    param.requires_grad = True



Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:

optimizer = AdamW([
    {'params': model.bert.parameters(), 'lr': 5e-6},  # 👈 gentler LR for BERT base
    {'params': model.classifier.parameters(), 'lr': 2e-5}  # 👈 higher LR for new classifier head
])

In [None]:
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)


class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(train_labels), y=train_labels)
class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)

loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights)


## early stopping

In [278]:
from sklearn.metrics import f1_score
from datetime import datetime
from tqdm import tqdm

# EarlyStopping class (if you haven't defined it yet)
class EarlyStopping:
    def __init__(self, patience=3):
        self.patience = patience
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, val_score, model):
        if self.best_score is None or val_score > self.best_score:
            self.best_score = val_score
            self.counter = 0
            torch.save(model.state_dict(), "checkpoint.pt")  # ✅ Save best model
            print("✅ F1 improved. Saving model...")
        else:
            self.counter += 1
            print(f"❗ EarlyStopping counter: {self.counter}/{self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True

# --------------------- Training Loop ---------------------

early_stopping = EarlyStopping(patience=3)
model.train()
epochs = 20
print("Starting training...")
start = datetime.now()

for epoch in range(epochs):
    total_loss = 0
    model.train()

    for batch in tqdm(train_loader):
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(input_ids, attention_mask=attention_mask)
        loss = loss_fn(outputs.logits, labels)

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_loader)
    print(f"\nEpoch {epoch+1} | Train Loss: {avg_train_loss:.4f}")

    # ------------------ Validation ------------------
    model.eval()
    val_preds, val_true = [], []
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids, attention_mask=attention_mask)
            preds = torch.argmax(outputs.logits, dim=1)

            val_preds.extend(preds.cpu().numpy())
            val_true.extend(labels.cpu().numpy())

    macro_f1 = f1_score(val_true, val_preds, average='macro')
    print(f"Macro F1: {macro_f1:.4f}")

    early_stopping(macro_f1, model)
    if early_stopping.early_stop:
        print("🛑 Early stopping triggered!")
        break

print("Training time:", datetime.now() - start)


Starting training...


100%|██████████| 289/289 [01:09<00:00,  4.18it/s]



Epoch 1 | Train Loss: 1.4513
Macro F1: 0.5205
✅ F1 improved. Saving model...


100%|██████████| 289/289 [01:09<00:00,  4.19it/s]



Epoch 2 | Train Loss: 0.9857
Macro F1: 0.5703
✅ F1 improved. Saving model...


100%|██████████| 289/289 [01:09<00:00,  4.18it/s]



Epoch 3 | Train Loss: 0.8628
Macro F1: 0.5878
✅ F1 improved. Saving model...


100%|██████████| 289/289 [01:09<00:00,  4.18it/s]



Epoch 4 | Train Loss: 0.8048
Macro F1: 0.5945
✅ F1 improved. Saving model...


100%|██████████| 289/289 [01:09<00:00,  4.18it/s]



Epoch 5 | Train Loss: 0.7533
Macro F1: 0.6017
✅ F1 improved. Saving model...


100%|██████████| 289/289 [01:09<00:00,  4.18it/s]



Epoch 6 | Train Loss: 0.7075
Macro F1: 0.6098
✅ F1 improved. Saving model...


100%|██████████| 289/289 [01:09<00:00,  4.18it/s]



Epoch 7 | Train Loss: 0.6701
Macro F1: 0.5968
❗ EarlyStopping counter: 1/3


100%|██████████| 289/289 [01:09<00:00,  4.18it/s]



Epoch 8 | Train Loss: 0.6391
Macro F1: 0.6126
✅ F1 improved. Saving model...


100%|██████████| 289/289 [01:09<00:00,  4.18it/s]



Epoch 9 | Train Loss: 0.6049
Macro F1: 0.6028
❗ EarlyStopping counter: 1/3


100%|██████████| 289/289 [01:09<00:00,  4.18it/s]



Epoch 10 | Train Loss: 0.5747
Macro F1: 0.5947
❗ EarlyStopping counter: 2/3


100%|██████████| 289/289 [01:09<00:00,  4.18it/s]



Epoch 11 | Train Loss: 0.5544
Macro F1: 0.5946
❗ EarlyStopping counter: 3/3
🛑 Early stopping triggered!
Training time: 0:13:44.205721


In [279]:
model.load_state_dict(torch.load('checkpoint.pt'))
model.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [280]:
from sklearn.metrics import accuracy_score, classification_report

val_preds, val_true = [], []

with torch.no_grad():
    for batch in val_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(input_ids, attention_mask=attention_mask)
        preds = torch.argmax(outputs.logits, dim=1)

        val_preds.extend(preds.cpu().numpy())
        val_true.extend(labels.cpu().numpy())

# Compute final accuracy
accuracy = accuracy_score(val_true, val_preds)
print(f"✅ Validation Accuracy (best model): {accuracy:.4f}")

# Optional: F1, Precision, Recall
print(classification_report(val_true, val_preds))


✅ Validation Accuracy (best model): 0.6134
              precision    recall  f1-score   support

           0       0.70      0.79      0.74       506
           1       0.49      0.74      0.59       239
           2       0.49      0.71      0.58       308
           3       0.67      0.76      0.71       488
           4       0.66      0.33      0.44       769

    accuracy                           0.61      2310
   macro avg       0.60      0.66      0.61      2310
weighted avg       0.63      0.61      0.60      2310

