# In [None]:
# Named Entity Recognition for German MRI Reports (Synthetic Example)

from transformers import AutoTokenizer, AutoModelForTokenClassification, TrainingArguments, Trainer
from transformers import DataCollatorForTokenClassification
from datasets import Dataset, load_metric
import numpy as np
import pandas as pd
import random
import torch

# ----------------------------
# 1. Generate Synthetic Reports
# ----------------------------
# Synthetic German procedure reports. Each report is split on whitespace
# further down, so punctuation stays attached to the preceding word
# (e.g. "Retriever." is a single token).
synthetic_reports = [
    "Patient erhielt intraarteriell 10 mg rtPA unter Vollnarkose mittels Stent Retriever.",
    "Durchgeführt wurde eine Aspiration ohne Komplikationen, rtPA nicht gegeben.",
    "Urokinase intraarteriell in Dosis von 5 mg verwendet, gefolgt von Ballonangioplastie.",
    "Kein Stent, keine Anästhesie, keine EVT-Komplikationen.",
    "Thrombektomie mit Distal Retriever durchgeführt."
]

# One BIO tag per whitespace-separated token of the matching report above.
# The tag lists must be exactly as long as `report.split()`; otherwise the
# label alignment step indexes past the end of the tag list.
labels = [
    # Patient erhielt intraarteriell 10 mg rtPA unter Vollnarkose mittels Stent Retriever.
    ["O", "O", "B-Intraarterial_rtPA", "B-rtPA_dose", "I-rtPA_dose", "I-rtPA_dose", "O", "B-Anesthesia", "O", "B-Stent_Retriever", "O"],
    # Durchgeführt wurde eine Aspiration ohne Komplikationen, rtPA nicht gegeben.
    ["O", "O", "O", "B-Mechanical_Treatment", "O", "O", "B-Intraarterial_rtPA", "O", "O"],
    # Urokinase intraarteriell in Dosis von 5 mg verwendet, gefolgt von Ballonangioplastie.
    ["B-Intraarterial_Urokinase", "O", "O", "O", "O", "B-Urokinase_dose", "I-Urokinase_dose", "O", "O", "O", "B-Balloon_Angioplasty"],
    # Kein Stent, keine Anästhesie, keine EVT-Komplikationen.
    ["O", "O", "O", "B-Anesthesia", "O", "B-EVT_Complications"],
    # Thrombektomie mit Distal Retriever durchgeführt.
    ["B-Mechanical_Treatment", "O", "B-Distal_Retriever", "O", "O"]
]

# Fail fast with a readable message if reports and tags drift apart again.
if len(synthetic_reports) != len(labels):
    raise ValueError(
        f"{len(synthetic_reports)} reports but {len(labels)} tag lists"
    )
_mismatched = [
    idx
    for idx, (report, tags) in enumerate(zip(synthetic_reports, labels))
    if len(report.split()) != len(tags)
]
if _mismatched:
    raise ValueError(
        f"Token/tag count mismatch in synthetic report(s) at index {_mismatched}"
    )

# ----------------------------
# 2. Tokenizer & Label Mapping
# ----------------------------
# Pretrained German BERT checkpoint used to load the tokenizer.
model_checkpoint = "bert-base-german-cased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# 'O' is pinned to id 0; every other tag follows in alphabetical order.
label_list = ['O', *sorted({tag for sent in labels for tag in sent} - {'O'})]
# Forward (tag -> id) and reverse (id -> tag) lookups over the same ordering.
label2id = dict(zip(label_list, range(len(label_list))))
id2label = dict(enumerate(label_list))

# Tokenize and align labels
def tokenize_and_align_labels(examples):
    """Tokenize pre-split sentences and expand word-level tags to sub-word tokens.

    Args:
        examples: Batch mapping with "tokens" (one list of words per sentence)
            and "ner_tags" (one list of tag strings per sentence, parallel to
            the words).

    Returns:
        The tokenizer output for the batch with an added "labels" entry: one
        list of label ids per sentence, aligned with the sub-word tokens.
        Positions without a source word (special tokens) get -100; every
        sub-word piece of a word gets that word's label id.

    Raises:
        ValueError: If a sentence has a different number of tags than words.
    """
    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)
    labels_aligned = []
    for i, (words, label) in enumerate(zip(examples["tokens"], examples["ner_tags"])):
        # Validate up front: a short tag list would otherwise surface as an
        # opaque IndexError deep inside the alignment below.
        if len(words) != len(label):
            raise ValueError(
                f"Example {i}: {len(words)} tokens but {len(label)} tags"
            )
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        # NOTE(review): continuation pieces of a word repeat the word's tag
        # verbatim, so a "B-" tag is repeated on every piece of that word.
        # The original first-piece / continuation branches produced the same
        # value, so they are merged here without changing the labels.
        label_ids = [
            -100 if word_idx is None else label2id[label[word_idx]]
            for word_idx in word_ids
        ]
        labels_aligned.append(label_ids)
    tokenized_inputs["labels"] = labels_aligned
    return tokenized_inputs

# ----------------------------
# 3. Prepare Dataset
# ----------------------------
# One row per report: its whitespace-split words next to its tag list.
df = pd.DataFrame(
    {
        "tokens": list(map(str.split, synthetic_reports)),
        "ner_tags": labels,
    }
)
dataset = Dataset.from_pandas(df)
# Run the alignment function over the dataset in batched mode.
tokenized_dataset = dataset.map(tokenize_and_align_labels, batched=True)

# ----------------------------
# 4. Load Model & Train
# ----------------------------
# Token-classification model loaded from the checkpoint, sized to the tag
# vocabulary and given both label lookup tables.
model = AutoModelForTokenClassification.from_pretrained(
    model_checkpoint,
    num_labels=len(label_list),
    id2label=id2label,
    label2id=label2id,
)
# Collator used by the trainer to assemble batches for token classification.
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

# Hyperparameters for a short fine-tuning run; outputs go to "ner-mri-german".
# No evaluation is configured. The former `evaluation_strategy="no"` argument
# is omitted on purpose: "no" is the default, and the keyword is deprecated in
# newer transformers releases, so leaving it out behaves identically while
# staying compatible across versions.
args = TrainingArguments(
    output_dir="ner-mri-german",
    learning_rate=2e-5,
    per_device_train_batch_size=2,
    num_train_epochs=5,
    weight_decay=0.01,
    logging_dir="./logs",
    report_to="none",  # do not report to any experiment tracker
)

# NOTE(review): newer transformers releases appear to rename the `tokenizer`
# argument of Trainer to `processing_class` — confirm against the installed
# version before upgrading.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

# Fine-tune the model on the tokenized synthetic reports.
trainer.train()

# ----------------------------
# 5. Predict on New Text
# ----------------------------
text = "Der Patient erhielt intraarteriell 10 mg rtPA und eine Aspiration wurde durchgeführt."

# Switch off dropout so predictions are deterministic after training.
model.eval()
# Keep the inputs on the same device as the model (the trainer may have moved
# the model to an accelerator).
inputs = tokenizer(text, return_tensors="pt").to(model.device)
with torch.no_grad():  # inference only, no gradients needed
    outputs = model(**inputs).logits
predictions = torch.argmax(outputs, dim=2)

# Read the tokens straight from the ids that were fed to the model, so token i
# always corresponds to prediction i. The previous encode -> decode ->
# tokenize round trip kept "[CLS]" at position 0 while the predictions were
# sliced from position 1, pairing every token with its neighbour's label.
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())

print("\nPredicted Entities:")
# Drop the first and last position ([CLS] and [SEP]) from both sequences.
for token, pred_id in zip(tokens[1:-1], predictions[0][1:-1]):
    label = id2label[pred_id.item()]
    if label != "O":
        print(f"{token} -> {label}")
