In [2]:
#Imports
import math
import random
from collections import Counter
from typing import List, Dict
import ast
import json

import numpy as np
from sklearn.metrics import accuracy_score, f1_score
import pandas as pd

import torch
import torch.nn as nn
from datasets import Dataset
import evaluate
from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
    EarlyStoppingCallback,
    set_seed,
)

In [3]:
#Set seeds
set_seed(42)
random.seed(42)

In [5]:
#Load data
data = pd.read_csv("data/synthetic.csv")

In [6]:
#Data formatting function

def process_row(row):
    text = row["note"]

    disorders = ast.literal_eval(row["extracted_disorders"])
    dates = ast.literal_eval(row["formatted_dates"])
    gold = ast.literal_eval(row["relationship_gold"])

    # Build lookup for gold relations: (disorder_pos, date_pos) -> relation_type
    gold_map = {}
    for g in gold:
        date_pos = g["date_position"]
        for diag in g.get("diagnoses", []):
            gold_map[(diag["position"], date_pos)] = "diagnosis_date"  # <-- adjust relation type if multiple types

    samples = []
    for d in disorders:
        d_start, d_end = d["start"], d["end"]
        disorder_text = text[d_start:d_end]

        for dt in dates:
            dt_start = dt.get("start", None)
            if dt_start is None:
                dt_start = text.find(dt["original"])
            dt_end = dt_start + len(dt["original"])
            date_text = text[dt_start:dt_end]

            # Label: check if pair is in gold_map
            key = (d_start, dt_start)
            label = gold_map.get(key, "no_relation")

            # Insert entity markers (insert later span first)
            marked = text
            for span, token1, token2, ent_text, span_end in sorted(
                [(d_start, "[E1]", "[/E1]", disorder_text, d_end),
                 (dt_start, "[E2]", "[/E2]", date_text, dt_end)],
                reverse=True
            ):
                marked = marked[:span] + f"{token1} {ent_text} {token2}" + marked[span_end:]

            samples.append({
                "text": text,
                "marked_text": marked,
                "ent1_start": d_start, "ent1_end": d_end,
                "ent2_start": dt_start, "ent2_end": dt_end,
                "label": label
            })

    return samples

In [7]:
# explode dataset into pairs
all_samples = []
for _, row in data.iterrows():
    all_samples.extend(process_row(row))

processed_df = pd.DataFrame(all_samples)
processed_df.head(5)

Unnamed: 0,text,marked_text,ent1_start,ent1_end,ent2_start,ent2_end,label
0,Ultrasound (30nd Jun 2024): no significant fin...,Ultrasound (30nd Jun 2024): no significant fin...,57,63,311,326,diagnosis_date
1,Ultrasound (30nd Jun 2024): no significant fin...,Ultrasound (30nd Jun 2024): no significant fin...,57,63,587,602,no_relation
2,Ultrasound (30nd Jun 2024): no significant fin...,Ultrasound (30nd Jun 2024): no significant fin...,410,427,311,326,no_relation
3,Ultrasound (30nd Jun 2024): no significant fin...,Ultrasound (30nd Jun 2024): no significant fin...,410,427,587,602,no_relation
4,Ultrasound (30nd Jun 2024): no significant fin...,Ultrasound (30nd Jun 2024): no significant fin...,491,511,311,326,no_relation


In [8]:
#Define labels
label_list = ["no_relation", "diagnosis_date"]
label2id = {lbl: i for i, lbl in enumerate(label_list)}
id2label = {i: lbl for lbl, i in label2id.items()}
num_labels = len(label_list)

In [9]:
#Apply labels
processed_df["label_id"] = processed_df["label"].map(label2id)
processed_df.head(5)

Unnamed: 0,text,marked_text,ent1_start,ent1_end,ent2_start,ent2_end,label,label_id
0,Ultrasound (30nd Jun 2024): no significant fin...,Ultrasound (30nd Jun 2024): no significant fin...,57,63,311,326,diagnosis_date,1
1,Ultrasound (30nd Jun 2024): no significant fin...,Ultrasound (30nd Jun 2024): no significant fin...,57,63,587,602,no_relation,0
2,Ultrasound (30nd Jun 2024): no significant fin...,Ultrasound (30nd Jun 2024): no significant fin...,410,427,311,326,no_relation,0
3,Ultrasound (30nd Jun 2024): no significant fin...,Ultrasound (30nd Jun 2024): no significant fin...,410,427,587,602,no_relation,0
4,Ultrasound (30nd Jun 2024): no significant fin...,Ultrasound (30nd Jun 2024): no significant fin...,491,511,311,326,no_relation,0


In [10]:
#Create dataset
dataset = Dataset.from_pandas(processed_df[["marked_text", "label_id"]])
dataset

Dataset({
    features: ['marked_text', 'label_id'],
    num_rows: 1242
})

In [11]:
#Train/test split
dataset = dataset.train_test_split(test_size=0.2, seed=42)

In [12]:
#Tokenizer with special tokens (+ model name)
model_name = "emilyalsentzer/Bio_ClinicalBERT"
tokenizer = AutoTokenizer.from_pretrained(model_name)
special_tokens = {"additional_special_tokens": ["[E1]", "[/E1]", "[E2]", "[/E2]"]}
tokenizer.add_special_tokens(special_tokens)

4

In [13]:
#Tokenization
def tokenize_fn(batch):
    return tokenizer(batch["marked_text"], truncation=True, padding="max_length", max_length=256)

tokenized = dataset.map(tokenize_fn, batched=True)

tokenized = tokenized.rename_column("label_id", "labels")
tokenized.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

Map:   0%|          | 0/993 [00:00<?, ? examples/s]

Map:   0%|          | 0/249 [00:00<?, ? examples/s]

In [14]:
#Look at class distribution
unique, counts = np.unique(tokenized["test"]["labels"], return_counts=True)
print(dict(zip(unique, counts)))

{0: 214, 1: 35}


In [15]:
#Load model
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id
)

# Resize embeddings so new tokens are usable
model.resize_token_embeddings(len(tokenizer))

ValueError: Due to a serious vulnerability issue in `torch.load`, even with `weights_only=True`, we now require users to upgrade torch to at least v2.6 in order to use the function. This version restriction does not apply when loading files with safetensors.
See the vulnerability report here https://nvd.nist.gov/vuln/detail/CVE-2025-32434

In [None]:
#Run training
accuracy = evaluate.load("accuracy")
f1 = evaluate.load("f1")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = logits.argmax(-1)
    return {
        "accuracy": accuracy.compute(predictions=preds, references=labels)["accuracy"],
        "f1": f1.compute(predictions=preds, references=labels, average="macro")["f1"]
    }

training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["test"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

trainer.train()

In [None]:
# Evaluate on test set and print F1
metrics = trainer.evaluate(tokenized["test"])
print("Test F1 (macro):", metrics.get("f1", metrics.get("f1_macro")))
print("All metrics:", metrics)

Additional

In [None]:
train_dataset = tokenized["train"]
eval_dataset = tokenized["test"]

In [None]:
#Function for computing class weights for weighted CE loss (to handle imbalance)
def compute_class_weights(ds: Dataset, n_labels: int):
    if len(ds) == 0:
        return torch.ones(n_labels)  # neutral
    counts = Counter([int(x) for x in ds["labels"]])
    total = sum(counts.values())
    weights = []
    for i in range(n_labels):
        # Inverse frequency (scaled): total / (n_labels * count_i)
        # Clamp to avoid inf if a class is missing in train set.
        c = max(1, counts.get(i, 0))
        w = total / (n_labels * c)
        weights.append(w)
    # Normalize so mean weight ~= 1
    mean_w = sum(weights) / len(weights)
    weights = [w / mean_w for w in weights]
    return torch.tensor(weights, dtype=torch.float)

In [None]:
#Compute class weights
class_weights = compute_class_weights(train_dataset, num_labels)

In [None]:
#Custom model: Bio_ClinicalBERT + span pooling between markers
class BertRC(nn.Module):
    """
    Forward expected inputs (from Trainer):
        input_ids:      [B, L]   torch.long
        attention_mask: [B, L]   torch.long
        labels:         [B]      torch.long

    Internals:
        last_hidden_state: [B, L, H]
        span pooling: mask tokens strictly between [E1]...[/E1] and [E2]...[/E2]
                       -> e1_emb, e2_emb: [B, H]
        concat: [B, 2H] -> classifier -> logits: [B, num_labels]
    """
    def __init__(self, model_name: str, tokenizer, num_labels: int, class_weights: torch.Tensor = None):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(model_name)
        self.backbone.resize_token_embeddings(len(tokenizer))

        self.hidden_size = self.backbone.config.hidden_size  # e.g., 768
        self.dropout = nn.Dropout(self.backbone.config.hidden_dropout_prob)
        self.classifier = nn.Linear(2 * self.hidden_size, num_labels)

        # Cache token IDs for markers
        self.e1_open_id = tokenizer.convert_tokens_to_ids("[E1]")
        self.e1_close_id = tokenizer.convert_tokens_to_ids("[/E1]")
        self.e2_open_id = tokenizer.convert_tokens_to_ids("[E2]")
        self.e2_close_id = tokenizer.convert_tokens_to_ids("[/E2]")

        # Class weights for imbalance
        if class_weights is not None:
            self.register_buffer("class_weights", class_weights)
        else:
            self.class_weights = None

    @staticmethod
    def _first_index(mask: torch.Tensor) -> torch.Tensor:
        """
        mask: [B, L] bool
        returns: [B] first True index (0 if none)
        """
        # Convert to float and argmax: if no True, argmax returns 0 (handled later)
        return mask.float().argmax(dim=1)

    def _span_mean(
        self,
        hidden: torch.Tensor,      # [B, L, H]
        input_ids: torch.Tensor,   # [B, L]
        open_id: int,
        close_id: int,
    ) -> torch.Tensor:
        """
        Mean-pool tokens strictly between open and close markers.
        Fallback: if span is empty or markers missing (e.g., truncation), use the open-marker embedding.

        returns: [B, H]
        """
        B, L, H = hidden.shape
        pos = torch.arange(L, device=hidden.device).unsqueeze(0).expand(B, L)  # [B, L]

        open_mask = (input_ids == open_id)    # [B, L]
        close_mask = (input_ids == close_id)  # [B, L]

        open_idx = self._first_index(open_mask)   # [B]
        close_idx = self._first_index(close_mask) # [B]

        # span_mask[b, t] = True iff open_idx[b] < t < close_idx[b]
        span_mask = (pos > open_idx.unsqueeze(1)) & (pos < close_idx.unsqueeze(1))  # [B, L]

        # Pool
        denom = span_mask.sum(dim=1, keepdim=True).clamp_min(1)  # [B, 1]
        span_sum = (hidden * span_mask.unsqueeze(-1)).sum(dim=1)  # [B, H]
        span_mean = span_sum / denom  # [B, H]

        # Fallback to open marker embedding if span empty or markers missing
        has_tokens = span_mask.any(dim=1, keepdim=True)  # [B, 1] bool
        open_emb = (hidden * open_mask.unsqueeze(-1)).sum(dim=1)  # [B, H]
        e_emb = torch.where(has_tokens, span_mean, open_emb)      # [B, H]
        return e_emb

    def forward(self, input_ids, attention_mask=None, labels=None):
        outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        last_hidden = outputs.last_hidden_state  # [B, L, H]

        # Pool entity and date spans
        e1_emb = self._span_mean(last_hidden, input_ids, self.e1_open_id, self.e1_close_id)  # [B, H]
        e2_emb = self._span_mean(last_hidden, input_ids, self.e2_open_id, self.e2_close_id)  # [B, H]

        # Concatenate -> classify
        x = torch.cat([e1_emb, e2_emb], dim=-1)  # [B, 2H]
        x = self.dropout(x)
        logits = self.classifier(x)  # [B, num_labels]

        loss = None
        if labels is not None:
            if hasattr(self, "class_weights") and self.class_weights is not None:
                loss_fn = nn.CrossEntropyLoss(weight=self.class_weights)
            else:
                loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)

        return {"loss": loss, "logits": logits}

In [None]:
#Metrics: accuracy + F1s
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    return {
        "accuracy": accuracy_score(labels, preds),
        "f1_macro": f1_score(labels, preds, average="macro", zero_division=0),
        "f1_micro": f1_score(labels, preds, average="micro", zero_division=0),
        "f1_weighted": f1_score(labels, preds, average="weighted", zero_division=0),
    }

In [None]:
#Model
model = BertRC(model_name, tokenizer, num_labels=num_labels, class_weights=class_weights)

In [None]:
# Define training arguments
training_args = TrainingArguments(
    output_dir="./rc_results",
    evaluation_strategy="epoch" if len(eval_dataset) > 0 else "no",
    save_strategy="epoch" if len(eval_dataset) > 0 else "no",
    load_best_model_at_end=True if len(eval_dataset) > 0 else False,
    metric_for_best_model="f1_macro",
    greater_is_better=True,
    num_train_epochs=5,
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    warmup_ratio=0.06,
    weight_decay=0.01,
    logging_steps=50,
    fp16=torch.cuda.is_available(),
    report_to=[],  # turn off W&B/MLflow by default
    seed=42,
)

In [None]:
#Define trainer
trainer = Trainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=train_dataset if len(train_dataset) > 0 else None,
    eval_dataset=eval_dataset if len(eval_dataset) > 0 else None,
    compute_metrics=compute_metrics if len(eval_dataset) > 0 else None,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)] if len(eval_dataset) > 0 else None,
)

In [None]:
#Train model
if len(train_dataset) > 0:
    trainer.train()
    if len(eval_dataset) > 0:
        metrics = trainer.evaluate()
        print(metrics)