In [None]:
import os
import pandas as pd
from pathlib import Path

def split_reports(input_csv: str, output_dir: str):
    """Split a combined sentence CSV into one ``report.csv`` per source report.

    Rows are grouped by the stem (file name without directory or extension) of
    their "File Path" value. Each group is written to
    ``<output_dir>/<stem>/report.csv`` without the helper grouping column.

    Args:
        input_csv: Path of the combined CSV; must contain a "File Path" column.
        output_dir: Root directory for the per-report folders (created if missing).
    """
    sentences = pd.read_csv(input_csv)
    # Temporary grouping key derived from the path of the originating report.
    sentences["report_name"] = [Path(p).stem for p in sentences["File Path"]]

    out_root = Path(output_dir)
    out_root.mkdir(parents=True, exist_ok=True)

    for stem, rows in sentences.groupby("report_name"):
        target_dir = out_root / stem
        target_dir.mkdir(parents=True, exist_ok=True)
        per_report = rows.drop(columns=["report_name"])
        per_report.to_csv(target_dir / "report.csv", index=False)
# Cell entry point: split the combined sentence CSV into per-report folders
# under ml/CRF_individual (one report.csv per report).
if __name__ == "__main__":
    split_reports(
        input_csv="ml/brain_sentences_combined.csv", # use the filtered "combined" file, not the full one
        output_dir="ml/CRF_individual"
    )
    print("Finished splitting into individual report CSVs.")


In [None]:
#!pip install pandas scikit-learn matplotlib
#!pip install transformers==4.37.0
#!pip install torch==2.1.0+cu118 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
import sys
from sklearn.metrics import f1_score, accuracy_score
import os
import random
import pandas as pd
#!pip install torch
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
#from torchcrf import CRF
#!pip install pytorch-crf
#from TorchCRF import CRF
from torchcrf import CRF
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics import classification_report, accuracy_score, confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
from collections import Counter
import copy
from pathlib import Path
import json
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, roc_auc_score

import itertools
import numpy as np

# torchcrf does not expose marginal probabilities, so they are computed here with the
# forward-backward recursions (dynamic programming over the chain, no path enumeration).
def compute_log_alpha(crf, emissions, mask):
    """Forward pass of forward-backward, in log space.

    Args:
        crf: Object exposing ``start_transitions`` (num_tags,) and
            ``transitions`` (num_tags, num_tags).
        emissions: Tensor of shape (seq_length, batch_size, num_tags).
        mask: Boolean tensor of shape (seq_length, batch_size); where it is
            False the previous step's scores are carried forward unchanged.

    Returns:
        Tensor of shape (seq_length, batch_size, num_tags) with the forward
        log-scores at every step.
    """
    seq_length, _, _ = emissions.shape

    alpha = crf.start_transitions + emissions[0]  # (batch_size, num_tags)
    steps = [alpha]
    for t in range(1, seq_length):
        # candidate[b, prev, cur] = alpha[b, prev] + transitions[prev, cur] + emissions[t, b, cur]
        candidate = alpha.unsqueeze(2) + crf.transitions + emissions[t].unsqueeze(1)
        advanced = torch.logsumexp(candidate, dim=1)  # sum out the previous tag
        active = mask[t].unsqueeze(1)  # (batch_size, 1)
        alpha = torch.where(active, advanced, alpha)
        steps.append(alpha)
    return torch.stack(steps)  # (seq_length, batch_size, num_tags)


def compute_log_beta(crf, emissions, mask):
    """Backward pass of forward-backward, in log space.

    Args:
        crf: Object exposing ``end_transitions`` (num_tags,) and
            ``transitions`` (num_tags, num_tags).
        emissions: Tensor of shape (seq_len, batch_size, num_tags).
        mask: Boolean tensor of shape (seq_len, batch_size); when step ``t+1``
            is masked out, the score at ``t`` is copied from ``t+1``.

    Returns:
        Tensor of shape (seq_len, batch_size, num_tags) with the backward
        log-scores at every step (the emission of the step itself excluded).
    """
    seq_len, batch_size, num_tags = emissions.shape

    beta = crf.end_transitions.unsqueeze(0).expand(batch_size, num_tags)
    collected = [beta]  # filled from the last step towards the first, reversed below
    for nxt in range(seq_len - 1, 0, -1):
        # paths[b, cur, nxt_tag] = beta[b, nxt_tag] + emissions[nxt, b, nxt_tag] + transitions[cur, nxt_tag]
        paths = beta.unsqueeze(1) + emissions[nxt].unsqueeze(1) + crf.transitions
        stepped = torch.logsumexp(paths, dim=2)  # sum out the next tag
        active_next = mask[nxt].unsqueeze(1)  # (batch_size, 1)
        beta = torch.where(active_next, stepped, beta)
        collected.append(beta)

    collected.reverse()
    return torch.stack(collected)  # (seq_len, batch_size, num_tags)

def crf_marginals(crf, emissions, mask):
    """Per-step tag marginals of a linear-chain CRF via forward-backward.

    Args:
        crf: Object exposing ``start_transitions``, ``end_transitions`` and
            ``transitions`` (as read by compute_log_alpha / compute_log_beta).
        emissions: Tensor of shape (seq_len, batch_size, num_tags).
        mask: Boolean tensor of shape (seq_len, batch_size).

    Returns:
        Tensor of shape (seq_len, batch_size, num_tags). Rows at unmasked steps
        are renormalised to sum to 1; rows at masked-out steps are all zeros.
    """
    seq_len, batch_size, num_tags = emissions.shape
    log_alpha = compute_log_alpha(crf, emissions, mask)
    log_beta = compute_log_beta(crf, emissions, mask)

    # Index of the last unmasked step of every sequence (clamped for all-masked rows).
    lengths = mask.sum(dim=0)
    last_indices = (lengths - 1).clamp(min=0)
    # Build the batch index on the same device as its sibling index tensor
    # instead of the process-wide default device.
    batch_indices = torch.arange(batch_size, device=last_indices.device)
    last_log_alpha = log_alpha[last_indices, batch_indices, :]
    # Log partition function: forward score at the final step plus the end transition.
    log_Z = torch.logsumexp(last_log_alpha + crf.end_transitions, dim=1, keepdim=True)

    log_marg = log_alpha + log_beta - log_Z.unsqueeze(0)

    # Masked-out steps get -inf so that exp() maps them to exactly 0.
    log_marg = log_marg.masked_fill(~mask.unsqueeze(2), float("-inf"))
    marginals = log_marg.exp()

    # Renormalise against numerical drift; all-zero (masked) rows divide by 1.
    sums = marginals.sum(dim=2, keepdim=True)
    sums = torch.where(sums > 0, sums, torch.ones_like(sums))
    return marginals / sums

# Encoder selection. BERT_SIZE must equal the hidden size of MODEL_NAME because it
# sizes the projection layer in NeuralCRF.
USE_BIG_BERT = True
if USE_BIG_BERT:
    MODEL_NAME = "emilyalsentzer/Bio_ClinicalBERT"
    #MODEL_NAME = "bert-base-uncased"
    BERT_SIZE = 768
else:
    MODEL_NAME = "prajjwal1/bert-tiny"  # lower-memory model
    BERT_SIZE = 128

# Resolve the device used for the encoder, the cached tensors and the CRF model.
# NOTE(review): when torch.get_default_device exists, the default device is used
# as-is, so CUDA is only picked up if it was set as the default — confirm intended.
if hasattr(torch, "get_default_device"):
    DEVICE = torch.get_default_device()
else:
    # Device type strings are lowercase; torch.device("CPU") raises a RuntimeError.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Data / training configuration -------------------------------------------
# SEED seeds NumPy inside train() before the subsample and the validation split.
SEED         = 42
# Number of tags: size of the emission layer output and of the CRF.
NUM_LABELS   = 2
# Tokenizer max_length (inputs are padded/truncated to this many tokens).
MAX_SENT_LEN = 128
# Root folder holding one sub-directory (with report.csv) per report.
DATA_ROOT    = "ml/CRF_individual"
# Root folder of the embedding cache (sub-folders per pooling strategy and model).
CACHE_BASE   = "ml/CRF_cache"
# Upper bound on training epochs; train() also applies early stopping.
EPOCHS       = 305
# AdamW learning rate.
LR = 5e-4
# Dropout probability of the single nn.Dropout shared across NeuralCRF.
DROPOUTP = 0.6
# AdamW weight_decay.
L2_REG = 0.05
# When True, an ELU is applied after the projection and attention steps.
USE_EXTRA_NONLINEARITIES = True

# --- Model architecture switches ---------------------------------------------
PROJ_DIM     = 768 # equal to BERT_SIZE -> nn.Identity is used instead of a learned projection
USE_ATTENTION = True
ATTN_DIM = 256 # 1 head
CONCAT_NOW_WITH_ATTENTION = True  # recommended
USE_VALUE_UP_DOWN_TO_SAVE_ATT_PARAMETERS = False  # if doing attention, use a low-rank (down/up) value projection to reduce the number of parameters
VALUE_BOTTLENECK_DIM = 64 # only used when the low-rank value projection flag above is True
USE_MLP_INTERACTION = True
MLP_OUTPUT_DIM = 128 # ignored when USE_MLP_INTERACTION is False

# NOTE(review): Python's `random` and torch RNGs are left unseeded here, so weight
# initialisation, dropout and DataLoader shuffling vary from run to run.
#random.seed(SEED)
#torch.manual_seed(SEED)
#torch.cuda.manual_seed_all(SEED)

import joblib
import re
from scipy.sparse import csr_matrix
USE_BOW_FOR_NOW = False  # Set to True for BoW, False for BERT for "now" sentence
BOW_VECTOR_PATH = "ml/no_interactions_for_crf_bow_local_vectorizer.pkl"
BOW_EMBEDDING_SIZE = 768  # pad/truncate BoW to match BERT_SIZE

# The vectorizer is only needed by get_bow_embedding. Load it when BoW is enabled
# (a missing file then still raises) or when the file happens to be present, so the
# default BERT-only configuration no longer fails on a missing, unused pickle.
# NOTE(security): joblib.load unpickles arbitrary objects — only load trusted files.
if USE_BOW_FOR_NOW or os.path.exists(BOW_VECTOR_PATH):
    bow_vectorizer = joblib.load(BOW_VECTOR_PATH)
else:
    bow_vectorizer = None

def clean_text(text):
    """Lowercase ``text`` and drop every character that is neither a word
    character nor whitespace."""
    lowered = text.lower()
    return re.sub(r'[^\w\s]', '', lowered)

# Optional bag-of-words encoding of a sentence; not used in the final version.
def get_bow_embedding(text):
    """Return a float32 tensor of length BOW_EMBEDDING_SIZE for ``text``.

    The cleaned text is transformed by the module-level ``bow_vectorizer``; the
    resulting vector is truncated, or right-padded with zeros, to the fixed size.
    """
    counts = bow_vectorizer.transform([clean_text(text)]).toarray()[0]
    n_copied = min(len(counts), BOW_EMBEDDING_SIZE)
    padded = np.zeros(BOW_EMBEDDING_SIZE, dtype=np.float32)
    padded[:n_copied] = counts[:n_copied]
    return torch.tensor(padded)


# Context mode for the "before"/"after" passages embedded in preprocess_embeddings:
# True  -> only the nearest neighbouring sentences, via get_dynamic_context;
# False -> all preceding / all following sentences of the report.
USE_DYNAMIC_CONTEXT = False   # True = dynamic context, False = full context
def get_dynamic_context(sents, idx, min_words=4, direction="before"):
    """Collect neighbouring sentences of ``sents[idx]`` until enough words are gathered.

    Sentences are taken one at a time, moving away from ``idx`` in the requested
    direction, until their combined whitespace-separated word count reaches
    ``min_words`` or the report boundary is hit.

    Args:
        sents: List of sentence strings.
        idx: Position of the current sentence (not included in the result).
        min_words: Word count at which collection stops.
        direction: "before" or "after".

    Returns:
        The collected sentences joined by single spaces, in document order
        (empty string when there is no neighbour on that side).

    Raises:
        ValueError: If ``direction`` is neither "before" nor "after".
    """
    if direction == "before":
        positions = range(idx - 1, -1, -1)   # nearest preceding sentence first
    elif direction == "after":
        positions = range(idx + 1, len(sents))
    else:
        raise ValueError("direction must be 'before' or 'after'")

    picked = []
    words_so_far = 0
    for j in positions:
        picked.append(sents[j])
        words_so_far += len(sents[j].split())
        if words_so_far >= min_words:
            break

    if direction == "before":
        picked.reverse()  # restore document order
    return " ".join(picked)

def model_name_to_folder(model_name):
    """Turn a model identifier into a single path component by replacing
    every "/" with "_"."""
    return "_".join(model_name.split("/"))

# NOTE(review): AORB is not read anywhere in the visible part of this notebook.
AORB            = False
# How token states are reduced to one vector in get_pooled_embedding.
POOLING_STRATEGY = "cls" # "cls", "mean", or "max" embeddings

# The embedding cache is keyed by pooling strategy and model name:
# <CACHE_BASE>/<POOLING_STRATEGY>/<model folder>/<report>.pt
MODEL_FOLDER = model_name_to_folder(MODEL_NAME)
CACHE_DIR = os.path.join(CACHE_BASE, POOLING_STRATEGY, MODEL_FOLDER)
os.makedirs(CACHE_DIR, exist_ok=True)

def get_pooled_embedding(model, tokenizer, text, strategy):
    """Encode ``text`` and pool the last hidden states into a single CPU vector.

    The text is tokenized with padding/truncation to MAX_SENT_LEN, run through
    ``model`` on DEVICE without gradients, and pooled according to ``strategy``:

    * "cls"  - hidden state of the first token;
    * "mean" - attention-mask-weighted mean over tokens;
    * "max"  - element-wise max over tokens, padded positions set to -1e9 first.

    Raises:
        ValueError: For any other ``strategy`` (raised after the forward pass).
    """
    encoded = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=MAX_SENT_LEN,
        return_tensors="pt"
    )
    encoded = {name: tensor.to(DEVICE) for name, tensor in encoded.items()}
    with torch.no_grad():
        hidden = model(**encoded).last_hidden_state
        token_mask = encoded["attention_mask"].unsqueeze(-1)

    if strategy not in ("cls", "mean", "max"):
        raise ValueError(f"Unknown pooling strategy: {strategy}")

    if strategy == "cls":
        pooled = hidden[:, 0, :]
    elif strategy == "mean":
        token_sum = (hidden * token_mask).sum(dim=1)
        token_count = token_mask.sum(dim=1)
        pooled = token_sum / token_count
    else:  # "max"
        # Push padded positions far below any real activation before the max.
        hidden[token_mask.expand_as(hidden) == 0] = -1e9
        pooled = hidden.max(dim=1).values
    return pooled.squeeze(0).cpu()

def preprocess_embeddings(report_dirs, tokenizer, bert):
    """Compute and cache the before/now/after embeddings of every report.

    For each directory in ``report_dirs`` whose cache file
    ``<CACHE_DIR>/<dir name>.pt`` does not exist yet, ``report.csv`` is read
    ("Sentence" and "Brain Related" columns) and, per sentence, three vectors
    are embedded: the preceding context, the sentence itself, and the following
    context. The stacked tensors and the labels are saved as a dict with keys
    "before", "now", "after" and "labels". Already-cached reports are skipped.
    """
    def embed(text):
        return get_pooled_embedding(bert, tokenizer, text, POOLING_STRATEGY)

    for report_dir in report_dirs:
        name = os.path.basename(report_dir)
        cache_path = os.path.join(CACHE_DIR, f"{name}.pt")
        if os.path.exists(cache_path):
            continue

        report = pd.read_csv(os.path.join(report_dir, "report.csv"))
        sents = report["Sentence"].tolist()
        labels = torch.tensor(report["Brain Related"].tolist(), dtype=torch.long)

        n_sents = len(sents)
        befores, nows, afters = [], [], []
        for i, current in enumerate(sents):
            if USE_DYNAMIC_CONTEXT:
                text_before = get_dynamic_context(sents, i, min_words=4, direction="before")
                text_after = get_dynamic_context(sents, i, min_words=4, direction="after")
            else:
                text_before = " ".join(sents[:i]) if i > 0 else ""
                text_after = " ".join(sents[i+1:]) if i < n_sents - 1 else ""

            befores.append(embed(text_before))
            # The current sentence is the only slot that can use the BoW encoder.
            nows.append(get_bow_embedding(current) if USE_BOW_FOR_NOW else embed(current))
            afters.append(embed(text_after))

        payload = {
            "before": torch.stack(befores),
            "now":    torch.stack(nows),
            "after":  torch.stack(afters),
            "labels": labels
        }
        torch.save(payload, cache_path)
        print(f"Cached {name} → {cache_path}")

class CachedReportDataset(Dataset):
    """Dataset over cached report files; one item is one whole report.

    Each file is expected to hold a dict with the keys "before", "now",
    "after" and "labels" (as written by preprocess_embeddings).
    """

    _KEYS = ("before", "now", "after", "labels")

    def __init__(self, files):
        # List of paths to the cached .pt files, one per report.
        self.files = files

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Load report ``idx`` and return (before, now, after, labels) on DEVICE."""
        cached = torch.load(self.files[idx])
        return tuple(cached[key].to(DEVICE) for key in self._KEYS)

def collate_fn(batch):
    """Unwrap a one-element batch: the loaders use batch_size=1, so the single
    dataset item is returned as-is instead of being stacked."""
    only_item = batch[0]
    return only_item

class NeuralCRF(nn.Module):
    """Sentence tagger: context embeddings -> emission scores -> linear-chain CRF.

    One forward call processes one report: ``S`` sentences, each described by
    three vectors (context before, the sentence itself, context after). The
    architecture is controlled by the module-level configuration constants
    (PROJ_DIM, USE_ATTENTION, CONCAT_NOW_WITH_ATTENTION, USE_MLP_INTERACTION, ...).

    The emission computation lives in ``get_emissions`` only; ``forward`` reuses
    it so that training, decoding and marginal computation cannot drift apart.
    """

    def __init__(self):
        super().__init__()
        self.use_attention = USE_ATTENTION
        self.use_mlp_interaction = USE_MLP_INTERACTION
        self.concat_now_with_attention = CONCAT_NOW_WITH_ATTENTION
        # A single dropout module is shared by every place that applies dropout.
        self.dropout = nn.Dropout(DROPOUTP)

        # No projection is learned when the target size equals the encoder size.
        if PROJ_DIM == BERT_SIZE:
            self.proj = nn.Identity()
        else:
            self.proj = nn.Linear(BERT_SIZE, PROJ_DIM)

        if self.use_attention:
            self.attn_dim = ATTN_DIM
            # Single-head attention: the current sentence is the query; the
            # before/now/after vectors provide keys and values.
            self.attn_q = nn.Linear(PROJ_DIM, ATTN_DIM)
            self.attn_k = nn.Linear(PROJ_DIM, ATTN_DIM)

            if not USE_VALUE_UP_DOWN_TO_SAVE_ATT_PARAMETERS:
                self.attn_v = nn.Linear(PROJ_DIM, ATTN_DIM)
            else:
                # Low-rank value projection: PROJ_DIM -> bottleneck -> ATTN_DIM.
                self.attn_v_down = nn.Linear(PROJ_DIM, VALUE_BOTTLENECK_DIM)
                self.attn_v_up   = nn.Linear(VALUE_BOTTLENECK_DIM, ATTN_DIM)
        else:
            self.attn_dim = None

        # Width of the feature vector handed to the interaction / emission layers.
        if self.use_attention and self.concat_now_with_attention:
            input_dim = self.attn_dim + PROJ_DIM
        elif self.use_attention:
            input_dim = self.attn_dim
        else:
            input_dim = 3 * PROJ_DIM

        if self.use_mlp_interaction:
            self.interaction = nn.Sequential(
                nn.Linear(input_dim, MLP_OUTPUT_DIM),
                nn.ELU(alpha=1.0), # nonlinearity
                self.dropout
            )
            self.fc = nn.Linear(MLP_OUTPUT_DIM, NUM_LABELS)
        else:
            self.interaction = nn.Identity()
            self.fc = nn.Linear(input_dim, NUM_LABELS)

        self.crf  = CRF(NUM_LABELS)

    def get_emissions(self, b, n, a):
        """Compute the per-sentence emission scores fed to the CRF.

        Args:
            b, n, a: Embeddings of the context before, the sentence itself and
                the context after; one row per sentence (S rows each).

        Returns:
            Tensor of shape (S, 1, NUM_LABELS): the report is treated as one
            sequence of length S with batch size 1.
        """
        if USE_EXTRA_NONLINEARITIES:
            extra_act_fn = nn.ELU(alpha=1.0)
        else:
            extra_act_fn = nn.Identity()

        pb = self.dropout(extra_act_fn(self.proj(b)))
        pn = self.dropout(extra_act_fn(self.proj(n)))
        pa = self.dropout(extra_act_fn(self.proj(a)))

        if self.use_attention:
            # stack contexts: (S, 3, D)
            contexts = torch.stack([pb, pn, pa], dim=1)
            # keys / values for all three contexts: (S, 3, ATTN_DIM)
            keys = self.attn_k(contexts)
            if not USE_VALUE_UP_DOWN_TO_SAVE_ATT_PARAMETERS:
                values = self.attn_v(contexts)
            else:
                v_down = extra_act_fn(self.attn_v_down(contexts))  # (S, 3, BOTTLENECK)
                values = self.attn_v_up(v_down)

            # only the current sentence acts as query: (S, 1, ATTN_DIM)
            query = self.attn_q(pn).unsqueeze(1)

            # scaled dot-product attention over the three contexts: (S, 1, 3)
            attn_scores = torch.matmul(query, keys.transpose(1, 2)) / (ATTN_DIM ** 0.5)
            attn_weights = torch.softmax(attn_scores, dim=-1)

            # weighted sum over values: (S, ATTN_DIM)
            attn_vec = torch.matmul(attn_weights, values).squeeze(1)
            attn_vec = self.dropout(extra_act_fn(attn_vec))

            if self.concat_now_with_attention:
                h = torch.cat([attn_vec, pn], dim=1)  # (S, ATTN_DIM + PROJ_DIM)
            else:
                h = attn_vec  # (S, ATTN_DIM)
        else:
            h = torch.cat([pb, pn, pa], dim=1)  # (S, 3 * PROJ_DIM)

        h = self.interaction(h)  # MLP with its own nonlinearity, or identity
        emissions = self.fc(h).unsqueeze(1)
        return emissions  # (seq_len, 1, num_labels)

    def forward(self, x_b, x_n, x_a, labels=None):
        """Train-time loss or inference-time decoding for one report.

        Args:
            x_b, x_n, x_a: Before / now / after embeddings, one row per sentence.
            labels: Optional tensor with one tag per sentence.

        Returns:
            With ``labels``: the negated value returned by the CRF for the
            report (used as the training loss). Without ``labels``: the decoded
            tag sequence of the report as a list of ints.
        """
        emis = self.get_emissions(x_b, x_n, x_a)  # (S, 1, num_labels)
        mask = torch.ones(emis.shape[:2], dtype=torch.bool, device=emis.device)

        if labels is not None:  # loss is 1 scalar number per report
            lbl = labels.unsqueeze(1)
            return -self.crf(emis, lbl, mask=mask)
        # inference: decode() returns one path per batch element; batch size is 1
        return self.crf.decode(emis, mask=mask)[0]


def _embed_full_context(sents, tokenizer, bert):
    """Embed each sentence together with all preceding / all following text.

    Returns a tuple (before, now, after) of stacked tensors moved to DEVICE, one
    row per sentence, every row produced by get_pooled_embedding with
    POOLING_STRATEGY.
    """
    befores, nows, afters = [], [], []
    n_sents = len(sents)
    for i in range(n_sents):
        tb = " ".join(sents[:i])     if i > 0 else ""
        tn = sents[i]
        ta = " ".join(sents[i+1:])   if i < n_sents - 1 else ""
        befores.append(get_pooled_embedding(bert, tokenizer, tb, POOLING_STRATEGY))
        nows.append(get_pooled_embedding(bert, tokenizer, tn, POOLING_STRATEGY))
        afters.append(get_pooled_embedding(bert, tokenizer, ta, POOLING_STRATEGY))
    return (
        torch.stack(befores).to(DEVICE),
        torch.stack(nows).to(DEVICE),
        torch.stack(afters).to(DEVICE),
    )


def train(p=1.0):
    """Train the NeuralCRF on cached report embeddings and evaluate it.

    Steps: read the train/test report-name lists, subsample a fraction ``p`` of
    the training reports, hold out a validation split, cache the embeddings,
    train with AdamW and early stopping on validation loss, then report metrics
    on the train split, the test split, an optional "abdomenetc.csv" report, a
    threshold sweep over CRF marginals (written to ml/), and an optional
    directory of synthetic reports.

    Args:
        p: Fraction of the available training reports to use (at least one
            report is always kept).

    Returns:
        Tuple ``(test_accuracy, abdomen_accuracy_or_None, precision_0,
        recall_0, precision_1, recall_1)`` computed from the Viterbi-decoded
        test predictions.

    Raises:
        ValueError: If no report is left for training after the validation split.
    """
    with open("ml/from_classical_standardized_train_report_names.json") as f:
        train_report_names = json.load(f)
    with open("ml/from_classical_standardized_test_report_names.json") as f:
        test_report_names = json.load(f)
    # Sets give constant-time membership tests in the filters below.
    train_name_set = set(train_report_names)
    test_name_set = set(test_report_names)

    all_report_dirs = sorted(
        os.path.join(DATA_ROOT, d)
        for d in os.listdir(DATA_ROOT)
        if os.path.isdir(os.path.join(DATA_ROOT, d))
    )

    train_dirs_all = [d for d in all_report_dirs if os.path.basename(d) in train_name_set]
    test_dirs      = [d for d in all_report_dirs if os.path.basename(d) in test_name_set]

    # Optionally train on a fraction p of the reports to study data efficiency.
    np.random.seed(SEED)
    sample_size = max(1, int(p * len(train_dirs_all)))
    train_dirs_all = list(np.random.choice(train_dirs_all, size=sample_size, replace=False))

    print("Number of train_dirs:", len(train_dirs_all))
    print("Example train_dirs:", train_dirs_all[:5])
    missing = [d for d in train_dirs_all if not os.path.exists(d)]
    print("Missing train report directories:", missing)
    print(train_report_names[:5])
    print(os.listdir(DATA_ROOT)[:5])

    # Validation split: between 5 (p -> 0) and 9 (p = 1) reports.
    np.random.seed(SEED)
    perm = np.random.permutation(len(train_dirs_all))
    #val_size = max(1, int((0.05 + (0.25 - p / 4)) * len(train_dirs_all)))
    val_size = max(1, int(((1 - p) * 5 + p * 9)))
    val_dirs   = [train_dirs_all[i] for i in perm[:val_size]]
    train_dirs = [train_dirs_all[i] for i in perm[val_size:]]
    if not train_dirs:
        raise ValueError(
            f"No training reports left after holding out {len(val_dirs)} validation report(s); "
            "increase p or add more training data."
        )

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    bert      = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE).eval()
    preprocess_embeddings(train_dirs + val_dirs + test_dirs, tokenizer, bert)

    def get_filelist_from_dirs(dirs):
        return [os.path.join(CACHE_DIR, f"{os.path.basename(d)}.pt") for d in dirs]

    train_files = get_filelist_from_dirs(train_dirs)
    val_files   = get_filelist_from_dirs(val_dirs)
    test_files  = get_filelist_from_dirs(test_dirs)

    print(f"Num train files: {len(train_files)}")
    print(f"Num val files:   {len(val_files)}")
    print(f"Num test files:  {len(test_files)}")

    ds_train = CachedReportDataset(train_files)
    ds_val   = CachedReportDataset(val_files)
    ds_test  = CachedReportDataset(test_files)

    # batch_size=1: one report per step (collate_fn unwraps the single item).
    dl_train = DataLoader(ds_train, batch_size=1, shuffle=True,  collate_fn=collate_fn)
    dl_val   = DataLoader(ds_val, batch_size=1, shuffle=False, collate_fn=collate_fn)
    dl_test  = DataLoader(ds_test, batch_size=1, shuffle=False, collate_fn=collate_fn)

    model = NeuralCRF().to(DEVICE)
    opt   = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay = L2_REG)

    best_val_loss = float('inf')
    best_state = None
    best_epoch = -1
    patience = 12 # stop after this many consecutive epochs without a new best validation loss
    bad_epochs = 0
    ckpt_path = "ml/best_model_weights.pt"
    for epoch in range(1, EPOCHS+1):
        model.train()
        loss_acc = 0.0
        for b, n, a, y in dl_train:
            opt.zero_grad()
            l = model(b, n, a, y)
            l.backward()
            opt.step()
            loss_acc += l.item()
        avg_train_loss = loss_acc/len(dl_train)

        model.eval()
        val_loss_acc = 0.0
        with torch.no_grad():
            for b, n, a, y in dl_val:
                val_loss_acc += model(b, n, a, y).item()
        avg_val_loss = val_loss_acc/len(dl_val)
        print(f"Epoch {epoch}/{EPOCHS} train_loss={avg_train_loss:.4f}  val_loss={avg_val_loss:.4f}")

        # early stopping + save best
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_state = copy.deepcopy(model.state_dict())
            torch.save(best_state, ckpt_path)  # backup to disk
            best_epoch = epoch
            bad_epochs = 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                print(f"Early stopping at epoch {epoch}. Best validation loss: {best_val_loss:.4f} (epoch {best_epoch})")
                break

    # restore best model weights
    if best_state is not None:
        model.load_state_dict(best_state)
    else:
        model.load_state_dict(torch.load(ckpt_path, map_location=DEVICE))

    # ---- evaluation: train split ----
    model.eval()

    train_preds, train_labs = [], []
    with torch.no_grad():
        for b, n, a, y in dl_train:
            decoded = model(b, n, a)  # distinct name: must not shadow the parameter p
            train_preds.extend(decoded)
            train_labs.extend(y.tolist())
    print("Train label distribution:", Counter(train_labs))
    print(f"\nTrain Accuracy: {accuracy_score(train_labs, train_preds):.4f}")
    cm_train = confusion_matrix(train_labs, train_preds)
    disp_train = ConfusionMatrixDisplay(confusion_matrix=cm_train)
    disp_train.plot(cmap="Blues")
    plt.title("Train Confusion Matrix")
    plt.show()

    # ---- evaluation: test split (Viterbi decoding) ----
    # preds / labs feed the return value; later sections use their own names.
    preds, labs = [], []
    with torch.no_grad():
        for b, n, a, y in dl_test:
            decoded = model(b, n, a)
            preds.extend(decoded)
            labs.extend(y.tolist())
    print("Test label distribution:", Counter(labs))
    print(f"\nTest Accuracy: {accuracy_score(labs, preds):.4f}")
    print(classification_report(labs, preds, digits=4))
    cm_test = confusion_matrix(labs, preds)
    disp_test = ConfusionMatrixDisplay(confusion_matrix=cm_test)
    disp_test.plot(cmap="Blues")
    plt.title("Test Confusion Matrix")
    plt.show()
    if cm_test.shape == (2, 2):
        TN = cm_test[0, 0]
        FP = cm_test[0, 1]
        FN = cm_test[1, 0]
        TP = cm_test[1, 1]

        recall_1 = TP / (TP + FN) if (TP + FN) > 0 else 0.0
        recall_0 = TN / (TN + FP) if (TN + FP) > 0 else 0.0

        total = TN + FP + FN + TP
        actual_prop_0 = (TN + FP) / total
        actual_prop_1 = (TP + FN) / total
        print(f"Test set distribution: class 0 = {actual_prop_0:.2%}, class 1 = {actual_prop_1:.2%}")

        # Accuracy re-weighted to hypothetical class priors.
        for target_prop_0, target_prop_1 in [(0.10, 0.90), (0.15, 0.85), (0.50, 0.50), (actual_prop_0, actual_prop_1)]:
            weighted_acc = target_prop_1 * recall_1 + target_prop_0 * recall_0
            print(f"Weighted accuracy ({target_prop_1:.0%} class 1, {target_prop_0:.0%} class 0): {weighted_acc:.4f}")
    else:
        print("\nWeighted accuracy calculation skipped: non-binary confusion matrix shape:", cm_test.shape)

    # ---- evaluation: special report ----
    ABDOMEN_REPORT = "abdomenetc.csv"
    acc_abd = None
    if os.path.exists(ABDOMEN_REPORT):
        df_abd = pd.read_csv(ABDOMEN_REPORT)
        sents_abd = df_abd["Sentence"].tolist()
        befores, nows, afters = _embed_full_context(sents_abd, tokenizer, bert)

        with torch.no_grad():
            y_pred_abd = model(befores, nows, afters)

        if "Brain Related" in df_abd.columns:
            y_true_abd = df_abd["Brain Related"].values.astype(int)
            known = y_true_abd != -1  # skip unknown labels
            acc_abd = accuracy_score(y_true_abd[known], np.array(y_pred_abd)[known])
            print(f"\nAccuracy on abdomenetc.csv: {acc_abd:.4f}")
        else:
            print("\nNo 'Brain Related' column in abdomenetc.csv for accuracy scoring.")
    else:
        print("\nabdomenetc.csv not found for evaluation.")

    # ---- threshold sweep over CRF marginals on the test split ----
    THRESH_STEP = 0.001  # very fine grid
    OUT_THRESHOLD_CSV = "ml/neuralcrf_threshold_metrics.csv"
    OUT_THRESHOLD_NPZ = "ml/neuralcrf_test_marginals.npz"

    test_probs, test_labels = [], []
    with torch.no_grad():
        for b, n, a, y in dl_test:
            emissions = model.get_emissions(b, n, a)
            seq_mask = torch.ones(emissions.shape[:2], dtype=torch.bool, device=emissions.device)
            marg = crf_marginals(model.crf, emissions, seq_mask).squeeze(1)
            test_probs.append(marg[:, 1].detach().cpu().numpy())  # P(tag == 1)
            test_labels.append(y.detach().cpu().numpy())

    probs_flat  = np.concatenate(test_probs).astype(np.float64)
    labels_flat = np.concatenate(test_labels).astype(np.int64)

    valid = np.isin(labels_flat, [0, 1])
    probs_flat  = probs_flat[valid]
    labels_flat = labels_flat[valid]

    # roc_auc_score raises when only one class is present.
    if np.unique(labels_flat).size > 1:
        roc_auc = roc_auc_score(labels_flat, probs_flat)
        print(f"ROC AUC (Neural CRF marginals on test): {roc_auc:.4f}")
    else:
        print("ROC AUC (Neural CRF marginals on test) skipped: only one class present.")

    thresholds = np.arange(0.0, 1.0 + THRESH_STEP, THRESH_STEP)
    rows = []
    for t in thresholds:
        # Separate name: reusing `preds` here would corrupt the returned metrics.
        thr_preds = (probs_flat >= t).astype(int)
        thr_acc = accuracy_score(labels_flat, thr_preds)
        f1_cls1 = f1_score(labels_flat, thr_preds, pos_label=1, zero_division=0)
        f1_macro = f1_score(labels_flat, thr_preds, average="macro", zero_division=0)
        rows.append((t, f1_macro, f1_cls1, thr_acc))

    th_df = pd.DataFrame(rows, columns=["threshold", "MacroF1", "F1_class1", "accuracy"])
    os.makedirs(os.path.dirname(OUT_THRESHOLD_CSV), exist_ok=True)
    th_df.to_csv(OUT_THRESHOLD_CSV, index=False)
    np.savez_compressed(OUT_THRESHOLD_NPZ, probs=probs_flat, labels=labels_flat)

    print(f"Saved CRF marginal threshold sweep to {OUT_THRESHOLD_CSV}")
    print(f"Saved raw probs/labels to {OUT_THRESHOLD_NPZ}")

    # ---- evaluation: synthetic reports (optional directory) ----
    from sklearn.metrics import roc_curve, auc

    #SYNTHETIC_DIR = "MIMIC/real_individual_reports"
    SYNTHETIC_DIR = "MIMIC/synthetic_bundled_reports_deterministic"
    if os.path.isdir(SYNTHETIC_DIR):
        synthetic_csvs = [os.path.join(SYNTHETIC_DIR, f) for f in os.listdir(SYNTHETIC_DIR) if f.endswith('.csv')]

        syn_probs, syn_labels, syn_preds = [], [], []

        for csv_path in synthetic_csvs:
            df = pd.read_csv(csv_path)
            sents = df["Sentence"].tolist()
            befores, nows, afters = _embed_full_context(sents, tokenizer, bert)

            with torch.no_grad():
                emissions = model.get_emissions(befores, nows, afters)
                seq_mask = torch.ones((len(sents), 1), dtype=torch.bool, device=DEVICE)
                marginals = crf_marginals(model.crf, emissions, seq_mask).squeeze(1)
                y_prob = marginals[:, 1].cpu().numpy()
                # hard predictions taken from the marginals, not from Viterbi
                y_pred = y_prob >= 0.5

            if "Brain Related" in df.columns:
                y_true = df["Brain Related"].values.astype(int)
                known = y_true != -1
                if known.sum() > 0:
                    file_acc = accuracy_score(y_true[known], y_pred[known])
                    print(f"File: {os.path.basename(csv_path):45s} | Accuracy: {file_acc:.4f}")
                    syn_labels.append(y_true[known])
                    syn_preds.append(y_pred[known])
                    syn_probs.append(y_prob[known])
            else:
                print(f"File: {os.path.basename(csv_path):45s} | No 'Brain Related' labels found.")

        if syn_labels:
            syn_labels = np.concatenate(syn_labels)
            syn_probs = np.concatenate(syn_probs)
            syn_preds = np.concatenate(syn_preds)

            np.save("neuralcrf_mimic_probabilities.npy", syn_probs)

            accuracy = accuracy_score(syn_labels, syn_preds)
            print(f"Overall accuracy: {accuracy:.4f}")
            if np.unique(syn_labels).size > 1:
                fpr, tpr, _ = roc_curve(syn_labels, syn_probs)
                syn_auc = auc(fpr, tpr)
                plt.figure(figsize=(7,7))
                plt.plot(fpr, tpr, label=f'Neural CRF (AUC = {syn_auc:.4f})', linewidth=2)
                plt.plot([0,1], [0,1], 'k--', label='Random')
                plt.xlabel('False Positive Rate')
                plt.ylabel('True Positive Rate')
                plt.title('ROC Curve: Synthetic Reports')
                plt.legend(loc='lower right')
                plt.grid(True)
                plt.tight_layout()
                plt.show()
            else:
                print("Not enough class diversity to plot ROC.")
        else:
            print("No ground truth labels found in any synthetic report for ROC.")
    else:
        print(f"\n{SYNTHETIC_DIR} not found; skipping synthetic report evaluation.")

    # Summary metrics from the Viterbi-decoded test predictions.
    report = classification_report(labs, preds, output_dict=True, zero_division=0)
    per_class = [
        report[c][m]
        for c, m in [("0", "precision"), ("0", "recall"), ("1", "precision"), ("1", "recall")]
    ]
    return (accuracy_score(labs, preds), acc_abd, *per_class)

# Cell entry point: run the full train/evaluate pipeline on all training reports
# (p defaults to 1.0) and keep the returned metrics tuple in `x`.
if __name__ == "__main__":
    x = train()


In [None]:
import os
import torch
import pandas as pd
import numpy as np
import json
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics import classification_report, accuracy_score, confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import random
from sklearn.metrics import roc_curve, auc


# Repeat the BERT classifier experiment once per seed (0..9). Everything
# indented under this loop belongs to a single seed's run.
for SEED in range(10):
    print(f"\n==== Running seed {SEED} ====")
    # Seed Python, NumPy and torch (CPU + all CUDA devices) with the run's seed.
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    # Dedicated torch generator seeded with SEED -- presumably intended for
    # DataLoader shuffling; TODO confirm it is actually consumed downstream.
    g = torch.Generator()
    g.manual_seed(SEED)
    
    
    # Hyperparameters and paths for this run.
    MODEL_NAME = "emilyalsentzer/Bio_ClinicalBERT"
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    DATA_PATH = "ml/CRF_individual"
    MAX_LEN = 128        # tokenizer pad/truncate length
    BATCH_SIZE = 8
    LR = 5e-5            # Adam learning rate
    NUM_EPOCHS = 8
    PATIENCE = 3         # non-improving validation epochs tolerated before early stopping
    NUM_CLASSES = 2
    
    # Report names assigned to the train and test partitions.
    with open("ml/standardized_train_report_names.json") as f:
        train_reports = set(json.load(f))
    with open("ml/standardized_test_report_names.json") as f:
        test_reports = set(json.load(f))
    
    # Only folders under DATA_PATH whose name appears in one of the two lists
    # are used; sorting makes the folder order independent of os.listdir().
    all_folders = sorted(os.listdir(DATA_PATH))
    train_folders = [f for f in all_folders if f in train_reports]
    test_folders  = [f for f in all_folders if f in test_reports]
    
    # Re-seed NumPy right before the permutation so the validation split
    # depends only on SEED. 5% of the training folders (at least one) are
    # held out for validation; the rest stay in the training split.
    np.random.seed(SEED)
    perm = np.random.permutation(len(train_folders))
    val_size = max(1, int(0.05 * len(train_folders)))
    val_folders = [train_folders[i] for i in perm[:val_size]]
    train_folders = [train_folders[i] for i in perm[val_size:]]
    
    # (split name, folder list) pairs consumed when the datasets are built.
    splits = [
        ('train', train_folders),
        ('val',   val_folders),
        ('test',  test_folders)
    ]
    
    class SentenceDataset(Dataset):
        """Sentence-level dataset assembled from per-report ``report.csv`` files.

        Every row of every CSV contributes one ``(sentence, label)`` sample,
        read from the ``Sentence`` and ``Brain Related`` columns. Tokenization
        is done lazily in ``__getitem__``.

        Args:
            folders: Report folder names located under ``data_path``.
            data_path: Directory holding one sub-folder per report.
            tokenizer: Callable invoked as
                ``tokenizer(sentence, padding=..., truncation=..., max_length=..., return_tensors='pt')``
                and returning ``input_ids`` / ``attention_mask`` tensors with a
                leading batch dimension of 1.
            max_len: Length each sentence is padded/truncated to.
        """

        def __init__(self, folders, data_path, tokenizer, max_len):
            self.tokenizer = tokenizer
            self.max_len = max_len
            self.samples = []
            for folder in folders:
                path = os.path.join(data_path, folder, "report.csv")
                # Folders without a report.csv are skipped silently.
                if not os.path.exists(path):
                    continue
                df = pd.read_csv(path)
                # Iterate the two needed columns directly; iterrows() would
                # build a full Series object for every row.
                self.samples.extend(
                    (sentence, int(label))
                    for sentence, label in zip(df["Sentence"], df["Brain Related"])
                )

        def __len__(self):
            """Number of sentences across all loaded reports."""
            return len(self.samples)

        def __getitem__(self, idx):
            """Tokenize sample ``idx`` and return its tensors plus the label."""
            sentence, label = self.samples[idx]
            tokens = self.tokenizer(
                sentence, padding='max_length', truncation=True,
                max_length=self.max_len, return_tensors='pt'
            )
            # The tokenizer returns a batch of one; drop that leading dimension
            # so the DataLoader can stack samples itself.
            return {
                "input_ids": tokens["input_ids"].squeeze(0),
                "attention_mask": tokens["attention_mask"].squeeze(0),
                "label": torch.tensor(label, dtype=torch.long)
            }
    
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # One SentenceDataset per split, keyed by split name.
    datasets = {}
    print("Loading data")
    for split, folders in splits:
        datasets[split] = SentenceDataset(folders, DATA_PATH, tokenizer, MAX_LEN)

    # Only the training loader shuffles. It is handed the generator `g` that was
    # seeded with SEED at the top of this run, so the batch order depends on the
    # seed alone rather than on how much of the global torch RNG was consumed
    # beforehand (e.g. by weight initialisation).
    loaders = {
        split: DataLoader(
            ds,
            batch_size=BATCH_SIZE,
            shuffle=(split == "train"),
            generator=g if split == "train" else None,
        )
        for split, ds in datasets.items()
    }
    
    class BERTClassifier(nn.Module):
        """Pretrained encoder followed by one linear classification layer.

        Args:
            model_name: Identifier handed to ``AutoModel.from_pretrained``.
            num_classes: Number of logits produced per input sequence.
        """

        def __init__(self, model_name, num_classes):
            super().__init__()
            self.bert = AutoModel.from_pretrained(model_name)
            hidden_dim = self.bert.config.hidden_size
            self.linear = nn.Linear(hidden_dim, num_classes)

        def forward(self, input_ids, attention_mask):
            """Return classification logits for a batch of token id sequences."""
            encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
            # Sequence position 0 carries the [CLS] token representation.
            first_token = encoded.last_hidden_state[:, 0, :]
            return self.linear(first_token)
    
    # Fresh classifier, loss and optimizer for this seed.
    model = BERTClassifier(MODEL_NAME, NUM_CLASSES).to(DEVICE)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LR)

    best_val_loss = float('inf')
    bad_epochs = 0
    best_state = None
    print("Starting training")
    for epoch in range(1, NUM_EPOCHS+1):
        # ---- one pass over the training split ----
        model.train()
        for batch in loaders['train']:
            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)
            labels = batch["label"].to(DEVICE)
            optimizer.zero_grad()
            logits = model(input_ids, attention_mask)
            loss = loss_fn(logits, labels)
            loss.backward()
            optimizer.step()

        # ---- validation loss, averaged per sample ----
        model.eval()
        total_val_loss = 0.0
        with torch.no_grad():
            for batch in loaders['val']:
                input_ids = batch["input_ids"].to(DEVICE)
                attention_mask = batch["attention_mask"].to(DEVICE)
                labels = batch["label"].to(DEVICE)
                logits = model(input_ids, attention_mask)
                loss = loss_fn(logits, labels)
                # loss is a batch mean; weight by batch size so the final
                # average is exact even when the last batch is smaller.
                total_val_loss += loss.item() * input_ids.size(0)
        avg_val_loss = total_val_loss / len(loaders['val'].dataset)
        print(f"Epoch {epoch}, val loss: {avg_val_loss:.4f}")

        # ---- early stopping on validation loss ----
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # state_dict() returns references to the live parameter tensors, so
            # keeping it as-is would let later optimizer steps overwrite the
            # "best" weights. Snapshot independent copies instead (on CPU, to
            # avoid holding a second copy of the model on the GPU).
            best_state = {
                name: tensor.detach().cpu().clone()
                for name, tensor in model.state_dict().items()
            }
            bad_epochs = 0
        else:
            bad_epochs += 1
            if bad_epochs >= PATIENCE:
                print(f"Early stopping at epoch {epoch}")
                break

    # Restore the weights from the epoch with the lowest validation loss.
    if best_state is not None:
        model.load_state_dict(best_state)

    torch.save(model.state_dict(), "unfrozen_bert_classifier_best.pt")
    
    
    # Evaluate the restored model on every split; the test split additionally
    # gets ROC-AUC, saved scores/labels and a saved ROC figure.
    for split in ["train", "val", "test"]:
        model.eval()
        all_preds, all_labels = [], []
        all_probs2 = []
        with torch.no_grad():
            for batch in loaders[split]:
                input_ids = batch["input_ids"].to(DEVICE)
                attention_mask = batch["attention_mask"].to(DEVICE)
                labels = batch["label"].cpu().numpy()
                logits = model(input_ids, attention_mask)
                preds = torch.argmax(logits, dim=1).cpu().numpy()
                # Probability assigned to class 1.
                probs = torch.softmax(logits, dim=1)[:, 1].cpu().numpy()
                all_preds.extend(preds)
                all_labels.extend(labels)
                all_probs2.extend(probs)
        print(f"\n[{split.upper()}] BERT classifier")
        print("Accuracy:", accuracy_score(all_labels, all_preds))
        print(classification_report(all_labels, all_preds, digits=4))
        if split == "test":
            # ROC-AUC is undefined unless both classes are present.
            if len(np.unique(all_labels)) > 1:
                # Use roc_curve + auc, both imported by this cell, instead of
                # roc_auc_score, which this cell never imports. For binary
                # labels the value is the same, and the curve is reused for
                # the plot below.
                fpr, tpr, _ = roc_curve(all_labels, all_probs2)
                test_auc = auc(fpr, tpr)
                print(f"Test ROC-AUC: {test_auc:.4f}")

                # The test loader is not shuffled, so the label order is the
                # same for every seed; store it once.
                if SEED == 0:
                    test_labels_once = np.array([lbl for _, lbl in datasets['test'].samples], dtype=np.int32)
                    np.save("ml/bert_test_labels.npy", test_labels_once)
                    print("Saved test labels to ml/bert_test_labels.npy")

                np.savez(f"ml/bert_test_scores_labels_seed{SEED}.npz",
                         scores=np.asarray(all_probs2, dtype=np.float32),
                         labels=np.asarray(all_labels, dtype=np.int32))

                plt.figure(figsize=(6, 6))
                plt.plot(fpr, tpr, lw=2, label=f"ROC (AUC={test_auc:.4f})")
                plt.plot([0, 1], [0, 1], "k--", lw=1)
                plt.xlabel("False Positive Rate"); plt.ylabel("True Positive Rate")
                plt.title("BERT Test ROC")
                plt.legend(loc="lower right"); plt.grid(True, alpha=0.3); plt.tight_layout()
                plt.savefig(f"ml/bert_test_roc_seed{SEED}.png", dpi=300, bbox_inches="tight")
                # Close the figure so ten seeds do not accumulate open figures.
                plt.close()
            else:
                print("Test AUC undefined (only one class present).")
        
    from sklearn.metrics import accuracy_score, classification_report, roc_curve, auc, ConfusionMatrixDisplay
    import matplotlib.pyplot as plt
    import os
    
    # ---- Evaluation on the synthetic MIMIC report CSVs for this seed's model ----
    SYNTHETIC_DIR = "MIMIC/synthetic_bundled_reports_deterministic"
    synthetic_csvs = [os.path.join(SYNTHETIC_DIR, f) for f in os.listdir(SYNTHETIC_DIR) if f.endswith('.csv')]

    # Per-file arrays, concatenated after the loop.
    all_probs = []
    all_labels = []
    all_preds = []

    # Make inference mode explicit instead of relying on the state left behind
    # by whichever loop ran last.
    model.eval()
    for csv_path in synthetic_csvs:
        df = pd.read_csv(csv_path)
        sents = df["Sentence"].tolist()

        # All sentences of one file are scored in a single forward pass.
        # Truncate to MAX_LEN, the same length used to build the training data.
        batch = tokenizer(
            sents,
            padding=True,
            truncation=True,
            max_length=MAX_LEN,
            return_tensors='pt'
        )
        input_ids = batch["input_ids"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)

        with torch.no_grad():
            logits = model(input_ids, attention_mask)
            probs = torch.softmax(logits, dim=1)[:, 1].cpu().numpy()
            preds = (probs >= 0.5).astype(int)

        if "Brain Related" in df.columns:
            y_true = df["Brain Related"].values.astype(int)
            # Rows labelled -1 are excluded from scoring.
            mask = y_true != -1
            if mask.sum() > 0:
                acc = accuracy_score(y_true[mask], preds[mask])
                print(f"File: {os.path.basename(csv_path):45s} | Accuracy: {acc:.4f}")
                all_labels.append(y_true[mask])
                all_preds.append(preds[mask])
                all_probs.append(probs[mask])
        else:
            print(f"File: {os.path.basename(csv_path):45s} | No 'Brain Related' labels found.")

    if all_labels:
        all_labels = np.concatenate(all_labels)
        all_preds = np.concatenate(all_preds)
        all_probs = np.concatenate(all_probs)

        acc = accuracy_score(all_labels, all_preds)
        print(f"\nOverall accuracy: {acc:.4f}")
        print(classification_report(all_labels, all_preds, digits=4))

        # A ROC curve needs both classes among the pooled labels.
        if np.unique(all_labels).size > 1:
            fpr, tpr, _ = roc_curve(all_labels, all_probs)
            roc_auc = auc(fpr, tpr)
            plt.figure(figsize=(7,7))
            plt.plot(fpr, tpr, label=f'BERT (AUC = {roc_auc:.4f})', linewidth=2)
            plt.plot([0,1], [0,1], 'k--', label='Random')
            plt.xlabel('False Positive Rate')
            plt.ylabel('True Positive Rate')
            plt.title('ROC Curve: Synthetic MIMIC Reports (BERT)')
            plt.legend(loc='lower right')
            plt.grid(True)
            plt.tight_layout()
            plt.show()
    else:
        print("No ground truth labels found in any synthetic report for ROC.")

    # NOTE: these file names do not include SEED, so every seed overwrites the
    # previous seed's output.
    np.save("sota_mimic_probabilities.npy", all_probs)
    np.save("mimic_labels.npy", all_labels)

In [None]:
import os
import torch
import pandas as pd
import numpy as np
import json
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from peft import LoraConfig, get_peft_model, TaskType
from sklearn.metrics import classification_report, accuracy_score, confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt

# ---- Configuration for the LoRA fine-tuning run ----
MODEL_NAME = "emilyalsentzer/Bio_ClinicalBERT"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DATA_PATH = "ml/CRF_individual"
MAX_LEN = 128        # tokenizer pad/truncate length
BATCH_SIZE = 8
LR = 5e-5            # Adam learning rate
NUM_EPOCHS = 8
PATIENCE = 3         # non-improving validation epochs tolerated before early stopping
NUM_CLASSES = 2
LORA_SEED = 42

# Seed torch as well as NumPy: adapter/classifier-head initialisation and
# DataLoader shuffling draw from torch's RNG, so seeding NumPy alone leaves
# the run non-reproducible. manual_seed_all is a no-op without CUDA.
torch.manual_seed(LORA_SEED)
torch.cuda.manual_seed_all(LORA_SEED)

# Report names assigned to the train and test partitions.
with open("ml/standardized_train_report_names.json") as f:
    train_reports = set(json.load(f))
with open("ml/standardized_test_report_names.json") as f:
    test_reports = set(json.load(f))

# Only folders whose name appears in one of the two lists are used.
all_folders = sorted(os.listdir(DATA_PATH))
train_folders = [f for f in all_folders if f in train_reports]
test_folders  = [f for f in all_folders if f in test_reports]

# Hold out 5% of the training folders (at least one) for validation, picked
# by a seeded permutation.
np.random.seed(LORA_SEED)
perm = np.random.permutation(len(train_folders))
val_size = max(1, int(0.05 * len(train_folders)))
val_folders = [train_folders[i] for i in perm[:val_size]]
train_folders = [train_folders[i] for i in perm[val_size:]]

splits = [
    ('train', train_folders),
    ('val',   val_folders),
    ('test',  test_folders)
]

class SentenceDataset(Dataset):
    """Sentence-level dataset assembled from per-report ``report.csv`` files.

    Every row of every CSV contributes one ``(sentence, label)`` sample, read
    from the ``Sentence`` and ``Brain Related`` columns. Tokenization is done
    lazily in ``__getitem__``; the label is returned under the ``labels`` key.

    Args:
        folders: Report folder names located under ``data_path``.
        data_path: Directory holding one sub-folder per report.
        tokenizer: Callable invoked as
            ``tokenizer(sentence, padding=..., truncation=..., max_length=..., return_tensors='pt')``
            and returning ``input_ids`` / ``attention_mask`` tensors with a
            leading batch dimension of 1.
        max_len: Length each sentence is padded/truncated to.
    """

    def __init__(self, folders, data_path, tokenizer, max_len):
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.samples = []
        for folder in folders:
            path = os.path.join(data_path, folder, "report.csv")
            # Folders without a report.csv are skipped silently.
            if not os.path.exists(path):
                continue
            df = pd.read_csv(path)
            # Iterate the two needed columns directly; iterrows() would build
            # a full Series object for every row.
            self.samples.extend(
                (sentence, int(label))
                for sentence, label in zip(df["Sentence"], df["Brain Related"])
            )

    def __len__(self):
        """Number of sentences across all loaded reports."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Tokenize sample ``idx`` and return its tensors plus the label."""
        sentence, label = self.samples[idx]
        tokens = self.tokenizer(
            sentence, padding='max_length', truncation=True,
            max_length=self.max_len, return_tensors='pt'
        )
        # The tokenizer returns a batch of one; drop that leading dimension so
        # the DataLoader can stack samples itself.
        return {
            "input_ids": tokens["input_ids"].squeeze(0),
            "attention_mask": tokens["attention_mask"].squeeze(0),
            "labels": torch.tensor(label, dtype=torch.long)
        }

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# One SentenceDataset per split, keyed by split name.
print("Loading data")
datasets = {
    split: SentenceDataset(folders, DATA_PATH, tokenizer, MAX_LEN)
    for split, folders in splits
}

# Batches are shuffled for the training split only.
loaders = {}
for split, ds in datasets.items():
    loaders[split] = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=(split=="train"))

# Rank-16 LoRA adapters on the attention query/value projections of a
# sequence-classification model.
lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=16,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["query", "value"],
)
base_model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=NUM_CLASSES,
)
# Wrap the base model with the adapters, then move the result to the target device.
model = get_peft_model(base_model, lora_config).to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=LR)

best_val_loss = float('inf')
bad_epochs = 0
best_state = None
print("Starting training with LoRA")
for epoch in range(1, NUM_EPOCHS+1):
    # ---- one pass over the training split ----
    model.train()
    for batch in loaders['train']:
        input_ids = batch["input_ids"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)
        optimizer.zero_grad()
        # Passing labels makes the model compute the loss itself.
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

    # ---- validation loss, averaged per sample ----
    model.eval()
    total_val_loss = 0.0
    with torch.no_grad():
        for batch in loaders['val']:
            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)
            labels = batch["labels"].to(DEVICE)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            # loss is a batch mean; weight by batch size so the final average
            # is exact even when the last batch is smaller.
            total_val_loss += loss.item() * input_ids.size(0)
    avg_val_loss = total_val_loss / len(loaders['val'].dataset)
    print(f"Epoch {epoch}, val loss: {avg_val_loss:.4f}")

    # ---- early stopping on validation loss ----
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        # state_dict() returns references to the live tensors, so keeping it
        # as-is would let later optimizer steps overwrite the "best" weights.
        # Snapshot independent copies instead (on CPU, to avoid holding a
        # second copy of the model on the GPU).
        best_state = {
            name: tensor.detach().cpu().clone()
            for name, tensor in model.state_dict().items()
        }
        bad_epochs = 0
    else:
        bad_epochs += 1
        if bad_epochs >= PATIENCE:
            print(f"Early stopping at epoch {epoch}")
            break

# Restore the weights from the epoch with the lowest validation loss.
if best_state is not None:
    model.load_state_dict(best_state)

# Report accuracy and a classification report on every split; the test split
# additionally gets a confusion-matrix figure.
for split in ("train", "val", "test"):
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch in loaders[split]:
            outputs = model(
                input_ids=batch["input_ids"].to(DEVICE),
                attention_mask=batch["attention_mask"].to(DEVICE),
            )
            # Predicted class is the index of the largest logit.
            all_preds += outputs.logits.argmax(dim=1).tolist()
            all_labels += batch["labels"].tolist()

    print(f"\n[{split.upper()}] BERT classifier (LoRA)")
    print("Accuracy:", accuracy_score(all_labels, all_preds))
    print(classification_report(all_labels, all_preds, digits=4))

    if split != "test":
        continue

    cm = confusion_matrix(all_labels, all_preds)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["not brain", "brain"])
    fig, ax = plt.subplots(figsize=(5,5))
    disp.plot(ax=ax, cmap="Blues", values_format='d')
    plt.title("BERT Classifier Test Confusion Matrix (LoRA)")
    plt.grid(False)
    plt.show()
