In [7]:
from itertools import combinations

import dill as pickle
import evaluate
import numpy as np
import pandas as pd
import scipy.sparse as sp
import torch
from datasets import Dataset
from gensim.models.keyedvectors import KeyedVectors
from ipymarkup import show_span_line_markup
from more_itertools import chunked
from peft import LoraConfig, PeftConfig, PeftModel, get_peft_model
from sentence_transformers import InputExample, SentenceTransformer, losses, models
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from transformers import (
    AutoTokenizer,
    DataCollatorForTokenClassification,
    DebertaV2ForTokenClassification,
    Trainer,
    TrainingArguments,
    AutoModel,
    pipeline,
)
from utils import load_notes, load_annotations
from snomed_graph import *

In [2]:
# Load the serialized SNOMED CT concept graph from a path relative to the notebook.
SG = SnomedGraph.from_serialized("./data/full_concept_graph.gml")

SNOMED graph has 361179 vertices and 1179749 edges


In [3]:
# In-scope concepts: everything below the three root concepts
# (71388002, 123037004, 404684003), merged into a single set.
concepts_in_scope = set().union(
    *(
        SG.get_descendants(root_sctid)
        for root_sctid in (71388002, 123037004, 404684003)
    )
)
print(f"{len(concepts_in_scope)} concepts have been selected.")

219169 concepts have been selected.


In [5]:
# Sanity check: how many concepts are currently in scope.
len(concepts_in_scope)

219169

In [10]:
# Load the annotation and note tables through the project helpers from `utils`.
annotations_df = load_annotations()
notes_df = load_notes()

In [20]:
# Map every annotated concept id to the set of surface forms (note substrings)
# that were annotated with it.
cid_to_synonyms = dict()
# iterrows() is a generator without a length, so tqdm needs an explicit total=
# to render a real progress bar instead of a bare iteration counter.
for note_id, annot in tqdm(annotations_df.iterrows(), total=len(annotations_df)):
    # The annotation's index value is used to look up the note it belongs to.
    note = notes_df.loc[note_id].text
    entity = note[annot["start"]:annot["end"]]
    cid = annot["concept_id"]
    cid_to_synonyms.setdefault(cid, set()).add(entity)


0it [00:00, ?it/s]

In [21]:
# Number of distinct concept ids that occur in the annotations.
len(cid_to_synonyms)

5336

In [26]:
# Prepend the annotated surface forms to each in-scope concept's synonym list.
# Only forms that are not already present are added, so re-running this cell
# does not keep growing `concept.synonyms` with repeated entries (which would
# later produce identical-text pairs when synonyms are combined).
for concept in tqdm(concepts_in_scope):
    existing = concept.synonyms
    seen = set(existing)
    fresh = [
        mention
        for mention in cid_to_synonyms.get(concept.sctid, ())
        if mention not in seen
    ]
    concept.synonyms = fresh + existing

  0%|          | 0/219169 [00:00<?, ?it/s]

In [27]:
# Fine-tune a sentence encoder so that synonyms of the same concept embed close
# together, then save it to "kb_model".
# Alternative base model: "sentence-transformers/all-mpnet-base-v2".
kb_embedding_model_id = ("sentence-transformers/all-MiniLM-L6-v2")
kb_model = SentenceTransformer(kb_embedding_model_id)

# Every unordered pair of synonyms of one concept becomes a positive example.
kb_sft_examples = [
    InputExample(texts=[syn1, syn2], label=1)
    for concept in tqdm(concepts_in_scope)
    for syn1, syn2 in combinations(concept.synonyms, 2)
]

kb_sft_dataloader = DataLoader(kb_sft_examples, shuffle=True, batch_size=48)

# The training data contains positive pairs only (every label is 1).
# ContrastiveLoss relies on label-0 pairs to push unrelated texts apart; with
# none, its objective only pulls embeddings together. MultipleNegativesRankingLoss
# is designed for positive-pair data: the other pairs in each batch act as
# negatives. The `label` field of the examples is not used by this loss.
kb_sft_loss = losses.MultipleNegativesRankingLoss(kb_model)

kb_model.fit(
    train_objectives=[(kb_sft_dataloader, kb_sft_loss)],
    epochs=2,
    warmup_steps=100,
    checkpoint_path="temp/ke_encoder",
)

kb_model.save("kb_model")

  return self.fget.__get__(instance, owner)()


  0%|          | 0/219169 [00:00<?, ?it/s]

Epoch:   0%|          | 0/2 [00:00<?, ?it/s]

Iteration:   0%|          | 0/7015 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [18]:
# Descendants of concept 71388002 with steps_removed=2
# (presumably restricts the result by graph distance — TODO confirm in snomed_graph).
SG.get_descendants(71388002, steps_removed=2)

{78635008 | Repositioning of aberrant renal vessels (procedure),
 710144001 | Promotion of diversional therapy (procedure),
 409088002 | Tube care: endotracheal (procedure),
 30208000 | Implantation of electronic stimulator to ureter (procedure),
 171520005 | Insertion of antisyphon device into ventricular shunt (procedure),
 107733003 | Introduction procedure (procedure),
 410538000 | Scheduling (procedure),
 89002000 | Therapeutic ultrasound (regime/therapy),
 448938001 | Preparation of smear (procedure),
 243114000 | Support (regime/therapy),
 7082004 | Reduction of fracture of hand with internal fixation (procedure),
 171434007 | Health promotion (regime/therapy),
 359551000 | Angioplasty of vein (procedure),
 431530008 | Percutaneous brachytherapy of hepatic artery using fluoroscopic guidance with contrast (procedure),
 765055004 | Implantation of intracardiac electrode (procedure),
 31359006 | Special blood bank procedure, explain by report (procedure),
 410751007 | Procedure on 

In [41]:
# Ancestors of concept 71388002; the second positional argument is presumably
# steps_removed — TODO confirm against the snomed_graph signature.
SG.get_ancestors(71388002,1)

set()

In [40]:
# Fetch the concept record for 71388002 and display its synonym list.
deets = SG.get_concept_details(71388002)
deets.synonyms

['Procedure']

In [25]:
# Number of children the graph reports for concept 71388002.
len(SG.get_children(71388002))

70

In [27]:
# Descendants of concept 71388002 with steps_removed=1
# (presumably the direct children — TODO confirm in snomed_graph).
SG.get_descendants(71388002, steps_removed=1)

{10226009 | Disposal of radioactive source (procedure),
 102986000 | Straight line walking test (procedure),
 108252007 | Laboratory procedure (procedure),
 1263452006 | Anesthesia and/or sedation procedure (procedure),
 127777001 | Provider-specific procedure (procedure),
 128927009 | Procedure by method (procedure),
 133858001 | Preoperative procedure (procedure),
 14734007 | Administrative procedure (procedure),
 164773008 | Sleep studies (procedure),
 176837007 | Intrauterine contraceptive device procedure (procedure),
 185316007 | Indirect encounter (procedure),
 223490009 | Appliance procedures (procedure),
 225190004 | Stoma appliance procedure (procedure),
 225205005 | Procedures relating to eating and drinking (procedure),
 225214000 | Procedures relating to control, restraint, seclusion and segregation (procedure),
 225288009 | Environmental care procedure (procedure),
 225299006 | Equipment-related procedure (procedure),
 225424005 | Protection procedure (procedure),
 225430

In [44]:
# Materialize the in-scope collection as a list so items can be indexed by position.
concepts_list = list(concepts_in_scope)

In [54]:
# Children of the first concept in the list (spot check of the graph API).
SG.get_children(concepts_list[0].sctid)

[245006005 | Entire fourth dorsal interosseous of hand (body structure)]

In [53]:
# `fsn` attribute of the first concept in the list.
concepts_list[0].fsn

'Structure of fourth dorsal interosseous muscle of hand (body structure)'

In [49]:
# All ancestors of the first concept in the list.
SG.get_ancestors(concepts_list[0].sctid)

{110540008 | Muscle structure of hand (body structure),
 113343008 | Body organ structure (body structure),
 118497009 | Structure of intrinsic muscle of hand (body structure),
 120573002 | Extremity part (body structure),
 120574008 | Upper extremity part (body structure),
 120577001 | Hand part (body structure),
 123037004 | Body structure (body structure),
 127948008 | Structure of region of upper extremity (body structure),
 127954009 | Skeletal muscle structure (body structure),
 128262006 | Upper body structure (body structure),
 26107004 | Structure of musculoskeletal system (body structure),
 265803009 | Structure of muscle and/or tendon within hand (body structure),
 280440001 | Regional skeletal muscle structure (body structure),
 281242000 | Musculoskeletal structure of limb (body structure),
 281243005 | Musculoskeletal structure of upper limb (body structure),
 302158005 | Musculoskeletal structure of hand (body structure),
 303756009 | Structure of muscle acting on metaca

In [31]:
# Per-concept fan-out statistics: collect the number of children and the number
# of descendants of every in-scope concept, then average each list.
child_counts = []
descendant_counts = []
for concept in concepts_in_scope:
    child_counts.append(len(SG.get_children(concept.sctid)))
    descendant_counts.append(len(SG.get_descendants(concept.sctid)))

child_mean = np.mean(child_counts)
descd_mean = np.mean(descendant_counts)

In [55]:
from utils import load_notes, load_annotations

In [56]:
# Reload the note and annotation tables.
# NOTE(review): the same two loaders are already called in an earlier cell;
# this cell looks redundant — confirm before removing.
notes_df = load_notes()
annotations_df = load_annotations()

In [61]:
# Print the concept record for every annotation row, in row order.
# NOTE(review): this emits one line per annotation (repeats included) and leaves
# `cid` bound to the last concept id at notebook scope.
for cid in annotations_df["concept_id"]:
    print(SG.get_concept_details(cid))

91936005 | Allergy to penicillin (finding)
95563007 | Gallstone pancreatitis (disorder)
45595009 | Laparoscopic cholecystectomy (procedure)
95563007 | Gallstone pancreatitis (disorder)
1835003 | Necrosis of pancreas (disorder)
310244003 | Nasojejunal feeding (regime/therapy)
19387007 | Ectopic pancreas (disorder)
57653000 | Multiple organ failure (disorder)
268910001 | Patient's condition improved (finding)
13467000 | Pseudocyst (morphologic abnormality)
56783008 | Incision AND drainage (procedure)
122865005 | Gastrointestinal tract structure (body structure)
737492002 | Outpatient care management (procedure)
38102005 | Cholecystectomy (procedure)
223482009 | Discussion (procedure)
38102005 | Cholecystectomy (procedure)
84089009 | Hiatal hernia (disorder)
32849002 | Esophageal structure (body structure)
30811009 | Ulcer of esophagus (disorder)
48694002 | Anxiety (finding)
161891005 | Backache (finding)
94391008 | Metastatic malignant neoplasm to lung (disorder)
438949009 | Alive (findi

In [169]:
# Collect every annotated mention of one concept id: print each mention with
# `cw` characters of surrounding note text, and tally how often each exact
# surface form was annotated (`texts`) plus the total number of mentions (`count`).

cid = 281789004
cw = 20
print(SG.get_concept_details(cid))
texts = dict()
count = 0
for index, ann in annotations_df.iterrows():
    if cid != ann["concept_id"]:
        continue
    count += 1
    # Fetch the note that corresponds to this annotation's index value.
    note = notes_df.loc[index]
    if len(note) != 1:
        print('SOMETHINGWRNG')
    note = note.iloc[0]
    start = max(ann["start"] - cw, 0)
    # Slice ends are exclusive, so clamp to len(note); clamping to len(note) - 1
    # would drop the final character whenever the window reaches the note's end.
    end = min(ann["end"] + cw, len(note))
    note_text = note[start:end]
    note_item = note[ann["start"]:ann["end"]]
    print(note_text)
    texts[note_item] = texts.get(note_item, 0) + 1
    print("*"*10)

281789004 | Antibiotic therapy (procedure)
my. She had been on ciprofloxacin 250 mg PO BID 
prio
**********
ransitioned back to ciprofloxacin 250 mg PO 
BID. She
**********
 previous course of ciprofloxacin 
after discharge.


**********
piration) but these antibiotics were 
discontinued 
**********
dary to 
infection, antibiotics (Bactrim or meropen
**********
ropenem), or other 
medication effect. OSH records
**********
You were started on antibiotics and will continue t
**********
d with diuresis and antibiotic 
course.  Her leuko
**********
  She improved with antibiotics and 
defervesced.  
**********
with management 
of antibiotics, diuresis, and BP/r
**********
 after diuresis and antibiotic treatment.  
Pulmon
**********
with two course of 
antibiotics which improved her 
**********
ted you with lasix, antibiotics, and inhalers.  You
**********
d determination of 
antibiotic duration.  Patient'
**********
on of metabolism by antibiotics. Temporarily 
disco
**********
nd determi

In [170]:
# Compare the concept's synonyms from the graph with the surface forms tallied
# from the annotations (`count` mentions, `texts` form -> frequency).
print("Synonyms: ",SG.get_concept_details(cid).synonyms)
count, texts

Synonyms:  ['Antibiotic therapy']


(215,
 {'ciprofloxacin': 4,
  'antibiotics': 87,
  'medication': 4,
  'antibiotic': 19,
  'Abx': 2,
  'antibiotic therapy': 2,
  'azithromycin': 10,
  'azithromcyin': 1,
  'Vancomycin': 10,
  'Levofloxacin': 2,
  'Vanc': 1,
  'Ciprofloxacin': 1,
  'Augmentin': 3,
  'levofloxacin': 15,
  'treated with antibiotics': 5,
  'vanc': 5,
  'CIPROFLOXACIN': 1,
  'LEVOFLOXACIN': 1,
  'on Levo/Flagyl': 1,
  'Azithromycin': 1,
  'vancomycin': 19,
  'antibiotic \ntherapy': 1,
  'vanco': 2,
  'Antibiotic': 1,
  'antiobiotics': 1,
  'levofloxavin': 1,
  'augmentin': 1,
  'Antibiotics': 3,
  'treated \nwith antibiotics': 1,
  'Bactrim': 1,
  'Ampicillin': 1,
  'ANTIBIOTICS': 1,
  'ABX': 1,
  'medicatio': 1,
  's \nfor 4 d': 1,
  'abx': 2,
  'Keflex': 2})

In [62]:
# Get what the main categories are for this labelling exercise
# NOTE(review): this rebinds `concepts_in_scope` to a tuple of integer concept
# ids, while other cells in this notebook iterate `concepts_in_scope` and read
# `.sctid` / `.synonyms` from its items. Those cells will fail if run after this
# one — consider giving the root ids their own name.
concepts_in_scope = (
    71388002,
    123037004,
    404684003,
)

In [67]:
# Look up the concept record for id 1162928000.
SG.get_concept_details(1162928000)

1162928000 | Acute myeloid leukemia (morphologic abnormality)

In [179]:
# All ancestors of concept 1163439000.
SG.get_ancestors(1163439000)

{108369006 | Neoplasm (morphologic abnormality),
 1162768007 | Leukemia (morphologic abnormality),
 118956008 | Body structure, altered from its original anatomical structure (morphologic abnormality),
 123037004 | Body structure (body structure),
 1240414004 | Malignant neoplasm (morphologic abnormality),
 30217000 | Proliferation (morphologic abnormality),
 400177003 | Neoplasm and/or hamartoma (morphologic abnormality),
 414388001 | Hematopoietic neoplasm (morphologic abnormality),
 414644002 | Malignant hematopoietic neoplasm (morphologic abnormality),
 4147007 | Mass (morphologic abnormality),
 414792005 | Myeloid neoplasm (morphologic abnormality),
 414794006 | Myeloid proliferation (morphologic abnormality),
 415181008 | Proliferation of hematopoietic cell type (morphologic abnormality),
 416939005 | Proliferative mass (morphologic abnormality),
 49755003 | Morphologically abnormal structure (morphologic abnormality),
 52988006 | Lesion (morphologic abnormality),
 57697001 | Gro

In [68]:
# All ancestors of concept 108369006.
SG.get_ancestors(108369006)

{118956008 | Body structure, altered from its original anatomical structure (morphologic abnormality),
 123037004 | Body structure (body structure),
 30217000 | Proliferation (morphologic abnormality),
 400177003 | Neoplasm and/or hamartoma (morphologic abnormality),
 4147007 | Mass (morphologic abnormality),
 416939005 | Proliferative mass (morphologic abnormality),
 49755003 | Morphologically abnormal structure (morphologic abnormality),
 52988006 | Lesion (morphologic abnormality),
 57697001 | Growth alteration (morphologic abnormality)}

In [69]:
# All ancestors of concept 118956008.
SG.get_ancestors(118956008)

{123037004 | Body structure (body structure)}

In [64]:
# Print the concept record for each item of `concepts_in_scope`.
# NOTE(review): the items are passed straight to get_concept_details, so this
# presumably expects `concepts_in_scope` to hold concept ids here — confirm the
# cell order.
for con in concepts_in_scope:
    print(SG.get_concept_details(con))

71388002 | Procedure (procedure)
123037004 | Body structure (body structure)
404684003 | Clinical finding (finding)


In [184]:
class Linker:
    """Two-stage nearest-neighbour linker from text spans to concept ids.

    Stage 1 finds the known entity string closest to the span's own text.
    Stage 2 looks inside that entity's private index for the stored context
    closest to the span's surrounding text and returns the concept id that
    context was recorded under.

    Attributes:
        encoder: Object with ``encode(texts)``; ``encoder[1]`` must expose
            ``word_embedding_dimension``, which is used as the vector size
            (presumably a sentence-transformers model — TODO confirm).
        entity_index: KeyedVectors keyed by entity string.
        context_index: Maps entity string -> KeyedVectors keyed by
            ``(concept_id, occurrence)`` tuples.
        history: Maps entity string -> {concept_id: [context, ...]}.
        context_window_width: Number of characters taken on each side of a
            span when building its context.
    """

    def __init__(self, encoder, context_window_width=0):
        self.encoder = encoder
        self.entity_index = KeyedVectors(self.encoder[1].word_embedding_dimension)
        self.context_index = dict()
        self.history = dict()
        self.context_window_width = context_window_width

    def add_context(self, row):
        """Return the span of ``row.text`` widened by the context window.

        The window is clamped to the bounds of the text. ``row`` must provide
        ``text``, ``start`` and ``end``.
        """
        window_start = max(0, row.start - self.context_window_width)
        window_end = min(row.end + self.context_window_width, len(row.text))
        return row.text[window_start:window_end]

    def add_entity(self, row):
        """Return the exact span ``row.text[row.start:row.end]``."""
        return row.text[row.start : row.end]

    def _remember(self, entity, concept_id, context):
        """Record one context in which ``entity`` was linked to ``concept_id``."""
        self.history.setdefault(entity, dict()).setdefault(concept_id, list()).append(
            context
        )

    def fit(self, df=None, snomed_concepts=None):
        """Build the entity index and the per-entity context indexes.

        Args:
            df: Optional frame whose rows (via ``itertuples``) provide ``text``,
                ``start``, ``end`` and ``concept_id``.
            snomed_concepts: Optional iterable of objects with ``sctid`` and
                ``synonyms``; each synonym is stored as an entity whose context
                is the synonym itself.
        """
        # Create a map from the entities to the concepts and contexts in which they appear
        if df is not None:
            for row in df.itertuples():
                self._remember(
                    self.add_entity(row), row.concept_id, self.add_context(row)
                )

        # Add SNOMED CT codes for lookup
        if snomed_concepts is not None:
            for c in snomed_concepts:
                for syn in c.synonyms:
                    self._remember(syn, c.sctid, syn)

        # Nothing was recorded: there is nothing to encode or index.
        if not self.history:
            return

        # Create indexes to help disambiguate entities by their contexts
        for entity, map_ in tqdm(self.history.items()):
            # Keys and contexts are produced in the same iteration order, so
            # the i-th key labels the i-th encoded context.
            keys = [
                (concept_id, occurrence)
                for concept_id, contexts in map_.items()
                for occurrence, _ in enumerate(contexts)
            ]
            contexts = [context for contexts in map_.values() for context in contexts]
            vectors = self.encoder.encode(contexts)
            index = KeyedVectors(self.encoder[1].word_embedding_dimension)
            index.add_vectors(keys, vectors)
            self.context_index[entity] = index

        # Now create the top-level entity index
        keys = list(self.history.keys())
        vectors = self.encoder.encode(keys)
        self.entity_index.add_vectors(keys, vectors)

    def link(self, row):
        """Return the concept id predicted for the span in ``row``.

        Returns ``None`` when the nearest known entity has no context index.
        """
        entity = self.add_entity(row)
        context = self.add_context(row)
        vec = self.encoder.encode(entity)
        nearest_entity = self.entity_index.most_similar(vec, topn=1)[0][0]
        index = self.context_index.get(nearest_entity)

        # Explicit None check: do not rely on the truthiness of the index object.
        if index is None:
            return None

        vec = self.encoder.encode(context)
        key, _score = index.most_similar(vec, topn=1)[0]
        sctid, _ = key
        return sctid

In [None]:
# Combine each split's notes with its annotations via `.join` to build the
# frames handed to the linker.
# NOTE(review): training_notes_df, training_annotations_df, test_notes_df and
# test_annotations_df are not defined in the visible part of this notebook —
# confirm the train/test split cell runs before this one.
linker_training_df = training_notes_df.join(training_annotations_df)
linker_test_df = test_notes_df.join(test_annotations_df)