In [None]:
# Upgrade transformers, accelerate and datasets in the notebook environment
# (IPython shell magic; `--quiet` suppresses most of pip's output).
!pip install -U transformers accelerate datasets --quiet

# Print the transformers version that is importable after the upgrade.
import transformers
print(transformers.__version__)


In [None]:
# Install the `evaluate` package, which the training cell imports for metrics.
!pip install evaluate

In [None]:
# ==== Install deps (safe to run multiple times) ====
!pip install -U transformers datasets accelerate evaluate scikit-learn matplotlib --quiet

# ==== Imports ====
import os
import pandas as pd
import numpy as np
import evaluate
import matplotlib.pyplot as plt
from datasets import Dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer
)

# ==== Disable external logging (W&B etc.) ====
# Set both flags before any training object is created.
os.environ.update(
    {
        "WANDB_DISABLED": "true",
        "HF_HUB_DISABLE_TELEMETRY": "1",
    }
)

# ==== Paths (chunked BlueScrubs for BERT) ====
# One CSV per split, resolved relative to the current working directory.
TRAIN, VAL, TEST = (
    "./bluescrubs_train_chunked_bert.csv",
    "./bluescrubs_val_chunked_bert.csv",
    "./bluescrubs_test_chunked_bert.csv",
)

# ==== Load CSVs -> Hugging Face Datasets ====
def to_hfds(path):
    """Read one CSV split and return it as a Hugging Face ``Dataset``.

    The CSV must contain ``text`` and ``label`` columns.  ``label`` is cast
    to ``int`` and exposed under the name ``labels`` (the column name the
    Trainer expects); every other column is dropped.

    Args:
        path: Location of the CSV file, passed straight to ``pd.read_csv``.

    Returns:
        A ``Dataset`` with exactly the columns ``text`` and ``labels``.
    """
    frame = pd.read_csv(path)
    frame = frame.assign(labels=frame["label"].astype(int))
    kept_columns = ["text", "labels"]
    return Dataset.from_pandas(frame[kept_columns])

# Build the three splits in a fixed order: train, validation, test.
ds = DatasetDict(
    {
        split_name: to_hfds(csv_path)
        for split_name, csv_path in (
            ("train", TRAIN),
            ("validation", VAL),
            ("test", TEST),
        )
    }
)

# ==== Tokenizer & Preprocessing ====
# The same checkpoint name is reused later when the model is loaded.
MODEL_NAME = "emilyalsentzer/Bio_ClinicalBERT"
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    use_fast=True,
)

def preprocess(batch):
    """Tokenize the ``text`` entries of a batch.

    Sequences longer than 512 tokens are truncated.  No padding is requested
    here, so the encodings keep their individual lengths.

    Args:
        batch: Mapping with a ``text`` entry, as supplied by a batched
            ``Dataset.map`` call.

    Returns:
        The tokenizer output for the batch.
    """
    texts = batch["text"]
    encoded = tokenizer(texts, truncation=True, max_length=512)
    return encoded

# Tokenize every split in batches; the raw text column is dropped afterwards.
tokenized = ds.map(
    preprocess,
    batched=True,
    remove_columns=["text"],
)

# ==== Model ====
# Sequence-classification model on top of the same checkpoint, two classes.
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=2,
)

# ==== Metrics ====
# Loaded once here so compute_metrics can reuse them on every evaluation.
accuracy, precision, recall, f1 = (
    evaluate.load(metric_name)
    for metric_name in ("accuracy", "precision", "recall", "f1")
)

def compute_metrics(eval_pred):
    """Turn raw evaluation output into accuracy, precision, recall and F1.

    Args:
        eval_pred: A ``(logits, labels)`` pair.  The predicted class is the
            argmax over the last axis of ``logits``.

    Returns:
        Dict with the keys ``accuracy``, ``precision``, ``recall`` and ``f1``;
        the last three use ``average="binary"``.
    """
    scores, references = eval_pred
    predicted = np.argmax(scores, axis=-1)

    results = {
        "accuracy": accuracy.compute(
            predictions=predicted, references=references
        )["accuracy"],
    }
    # The remaining metrics share one call signature, so compute them in turn.
    for key, metric in (("precision", precision), ("recall", recall), ("f1", f1)):
        results[key] = metric.compute(
            predictions=predicted, references=references, average="binary"
        )[key]
    return results

# ==== Fine-tuning Training Arguments (no evaluation_strategy) ====
import torch  # local to this cell: only used to probe for a CUDA device

args = TrainingArguments(
    output_dir="./bioclinicalbert_bluescrubs_finetuned",

    # Fine-tuning changes vs baseline
    learning_rate=1e-5,             # smaller LR than baseline (2e-5)
    num_train_epochs=4,             # more epochs than baseline (2)
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=2,  # effective batch size 16
    weight_decay=0.01,
    # fp16 mixed precision needs a CUDA GPU; requesting it on a runtime
    # without one makes TrainingArguments raise, so enable it only when a
    # CUDA device is actually available.
    fp16=torch.cuda.is_available(),

    # Logging & checkpoints (simple, version-safe)
    logging_dir="./logs",
    logging_steps=500,
    save_steps=1000,
    save_total_limit=2,             # keep only the two most recent checkpoints
    report_to=[],                   # no W&B / TB / MLflow
)

# ==== Trainer ====
import inspect  # local to this cell: used to probe the installed Trainer API

# Trainer's `tokenizer` keyword is deprecated in favour of `processing_class`
# (transformers >= 4.46) and scheduled for removal in 5.0.  Because the
# notebook installs transformers with `-U`, choose whichever keyword the
# installed version accepts instead of hard-coding one.
_trainer_params = inspect.signature(Trainer.__init__).parameters
_tokenizer_kwarg = (
    "processing_class" if "processing_class" in _trainer_params else "tokenizer"
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["validation"],
    compute_metrics=compute_metrics,
    **{_tokenizer_kwarg: tokenizer},
)

# ==== Train (fine-tuning) ====
def _print_metrics(metrics):
    """Print each metric on its own line, numbers with four decimals.

    Values that cannot take the ``.4f`` format are printed unchanged.
    ``None`` raises ``TypeError`` and strings raise ``ValueError``, so both
    are caught.
    """
    for key, value in metrics.items():
        try:
            print(f"{key}: {value:.4f}")
        except (TypeError, ValueError):
            print(f"{key}: {value}")


trainer.train()

# ==== Evaluate on Validation & Test Sets ====
print("\n===== Validation Results (Fine-tuned BioClinicalBERT) =====")
val_results = trainer.evaluate(tokenized["validation"])
_print_metrics(val_results)

print("\n===== Test Results (Fine-tuned BioClinicalBERT) =====")
# A distinct key prefix keeps test metrics apart from the validation ones.
test_results = trainer.evaluate(tokenized["test"], metric_key_prefix="test")
_print_metrics(test_results)

# ==== Training Logs & Plots (if available) ====
logs = pd.DataFrame(trainer.state.log_history)
print("\n===== Log History (tail) =====")
print(logs.tail())


def _logged_series(frame, column):
    """Return ``(x, y)`` for the rows in which ``column`` was recorded.

    Each log entry carries only some metrics, so every metric column is NaN
    in the other rows.  Matplotlib breaks a line at NaN values, which would
    leave disconnected markers; dropping those rows gives a continuous curve.
    The x values are the ``step`` column when present, otherwise a running
    index over the kept rows.
    """
    recorded = frame.dropna(subset=[column])
    if "step" in recorded.columns:
        return recorded["step"], recorded[column]
    return range(len(recorded)), recorded[column]


# Plot Loss if possible
if "loss" in logs.columns:
    plt.figure(figsize=(8, 5))
    steps, values = _logged_series(logs, "loss")
    plt.plot(steps, values, label="Training Loss", marker="o")
    if "eval_loss" in logs.columns:
        steps, values = _logged_series(logs, "eval_loss")
        plt.plot(steps, values, label="Validation Loss", marker="o")
    plt.xlabel("Step")
    plt.ylabel("Loss")
    plt.title("Training vs Validation Loss")
    plt.legend()
    plt.grid(True)
    plt.show()

# Plot Accuracy and F1 if available
if "eval_accuracy" in logs.columns and "eval_f1" in logs.columns:
    plt.figure(figsize=(8, 5))
    steps, values = _logged_series(logs, "eval_accuracy")
    plt.plot(steps, values, label="Validation Accuracy", marker="o")
    steps, values = _logged_series(logs, "eval_f1")
    plt.plot(steps, values, label="Validation F1", marker="o")
    plt.xlabel("Step")
    plt.ylabel("Score")
    plt.title("Validation Accuracy and F1 Score")
    plt.legend()
    plt.grid(True)
    plt.show()
