In [6]:
from itertools import combinations

import dill as pickle
import evaluate
import numpy as np
import pandas as pd
import scipy.sparse as sp
import torch
from datasets import Dataset
from gensim.models.keyedvectors import KeyedVectors
from ipymarkup import show_span_line_markup
from more_itertools import chunked
from peft import LoraConfig, PeftConfig, PeftModel, get_peft_model
from sentence_transformers import InputExample, SentenceTransformer, losses, models
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    pipeline,
)

from snomed_graph import *
from constants import id2label, label2id


In [13]:
random_seed = 10  # For reproducibility
max_seq_len = 512  # Maximum sequence length for (BERT-based) encoders
cer_model_path = "best_KISTI-AI-scideberta-cs-noLoRA"  # checkpoint loaded as the CER model
kb_embedding_model_id = "sentence-transformers/all-MiniLM-L6-v2"  # base model for concept encoder
use_LoRA = False  # True if cer_model_path holds a LoRA adapter to load on top of its base model

# Seed every global RNG the notebook may draw from: torch.manual_seed alone
# leaves numpy's global RNG (used by numpy/scikit-learn whenever no explicit
# random_state is passed) unseeded.
np.random.seed(random_seed)
torch.manual_seed(random_seed)

assert torch.cuda.is_available(), "CUDA GPU not available"

In [10]:
# Load the clinical entity recognition (CER) token-classification model and its
# tokenizer. `num_labels` is derived from `id2label` so the label count and the
# label mappings can never disagree.
label_kwargs = dict(
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)

if use_LoRA:
    # `cer_model_path` holds a PEFT/LoRA adapter: load the base model named in the
    # adapter config first, then wrap it with the adapter weights.
    config = PeftConfig.from_pretrained(cer_model_path)

    cer_model = AutoModelForTokenClassification.from_pretrained(
        pretrained_model_name_or_path=config.base_model_name_or_path,
        **label_kwargs,
    )
    cer_model = PeftModel.from_pretrained(cer_model, cer_model_path)
else:
    # `cer_model_path` holds the complete fine-tuned model.
    cer_model = AutoModelForTokenClassification.from_pretrained(
        pretrained_model_name_or_path=cer_model_path,
        **label_kwargs,
    )

# NOTE(review): the tokenizer is always read from `cer_model_path`; in the LoRA
# case this assumes the tokenizer files were saved alongside the adapter — confirm.
cer_tokenizer = AutoTokenizer.from_pretrained(cer_model_path)

In [11]:
# Wrap the CER model and tokenizer in a token-classification pipeline.
# With the LoRA adapter, transformers warns:
#   "The model 'PeftModelForTokenClassification' is not supported for token-classification."
# The warning is harmless: the PEFT-wrapped model works fine inside the pipeline.
# aggregation_strategy="first" merges sub-word tokens into word-level spans
# (per the transformers docs, each word takes the tag of its first sub-token).
# N.B. running on the CPU makes inference slower, but lets us feed the pipeline
# directly with raw strings.
cer_pipeline = pipeline(
    "token-classification",
    tokenizer=cer_tokenizer,
    model=cer_model,
    device="cpu",
    aggregation_strategy="first",
)

In [14]:
# Load the notes and their gold span annotations (both indexed by note_id), then
# hold out 32 notes as a test set.
notes_df = pd.read_csv("data/training_notes.csv").set_index("note_id")
annotations_df = pd.read_csv("data/train_annotations.csv").set_index("note_id")

training_notes_df, test_notes_df = train_test_split(
    notes_df, test_size=32, random_state=random_seed
)

# Select the annotations of the held-out notes, in held-out order. Restricting
# the labels to notes that actually have annotations avoids the KeyError `.loc`
# raises for missing labels if a held-out note has none; when every held-out
# note is annotated, the result is identical to `annotations_df.loc[test_notes_df.index]`.
annotated_test_ids = test_notes_df.index[
    test_notes_df.index.isin(annotations_df.index)
]
test_annotations_df = annotations_df.loc[annotated_test_ids]

In [16]:
# Inspect the held-out annotations: one row per annotated span, indexed by
# note_id, with its `start`/`end` offsets into the note text and its `concept_id`.
test_annotations_df

Unnamed: 0_level_0,start,end,concept_id
note_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
12204158-DS-10,182,193,91936005
12204158-DS-10,196,201,1003755004
12204158-DS-10,239,258,267036007
12204158-DS-10,299,312,91602002
12204158-DS-10,318,328,264957007
...,...,...,...
12986424-DS-6,4697,4705,368208006
12986424-DS-6,4707,4714,439470001
12986424-DS-6,4747,4757,247347003
12986424-DS-6,4823,4828,72670004


In [None]:
# Visualise predicted vs. ground-truth spans for the part of the note that
# starts at character `char_offset` (the full-note view only carries
# predictions near the start of the note).
note_id = "12986424-DS-6"
char_offset = 512  # N.B. a character offset into the note text, not a token count
text = test_notes_df.loc[note_id].text[char_offset:]

# NOTE(review): the pipeline returns character offsets into `text`. The +1 was
# originally justified as offsetting the [CLS] token, but special tokens do not
# shift character offsets. It may instead compensate for a leading space
# included in sub-word token spans — confirm before relying on it.
predicted_annotations = [
    (span["start"] + 1, span["end"], "PRED") for span in cer_pipeline(text)
]

# `[[note_id]]` keeps the selection a DataFrame even when the note has a single
# annotation (`.loc[note_id]` would return a Series, which has no itertuples()).
# Keep annotations that begin inside `text` (>= so a span starting exactly at
# `char_offset` is kept), shift them to `text`-relative offsets, and drop
# duplicate spans while preserving their order.
note_annotations = test_annotations_df.loc[[note_id]]
gt_annotations = list(
    dict.fromkeys(
        (row.start - char_offset, row.end - char_offset, "GT")
        for row in note_annotations.itertuples()
        if row.start >= char_offset
    )
)

show_span_line_markup(text, predicted_annotations + gt_annotations)

In [24]:
# Visualise the predicted clinical entities against the actual annotated entities.
# N.B. only the first 512 tokens of the note will contain predicted spans.
# Output omitted due to the sensitivity of MIMIC-IV notes.


note_id = "12986424-DS-6"
text = test_notes_df.loc[note_id].text

# NOTE(review): the pipeline returns character offsets into `text`. The +1 was
# originally justified as offsetting the [CLS] token, but special tokens do not
# shift character offsets. It may instead compensate for a leading space
# included in sub-word token spans — confirm before relying on it.
predicted_annotations = [
    (span["start"] + 1, span["end"], "PRED") for span in cer_pipeline(text)
]

# `[[note_id]]` (a list selector) keeps the result a DataFrame even when the
# note has exactly one annotation; `.loc[note_id]` would return a Series,
# which has no itertuples().
gt_annotations = [
    (row.start, row.end, "GT")
    for row in test_annotations_df.loc[[note_id]].itertuples()
]

show_span_line_markup(text, predicted_annotations + gt_annotations)