In [1]:
import pandas as pd
import numpy as np
import torch
import json
import dill as pickle
from snomed_graph import *
from sentence_transformers import InputExample, SentenceTransformer, losses, models, evaluation
from itertools import combinations
from torch.utils.data import DataLoader
from gensim.models.keyedvectors import KeyedVectors

In [2]:
SG = SnomedGraph.from_serialized('full_concept_graph.gml')

SNOMED graph has 361179 vertices and 1179749 edges


In [3]:
procedures = SG.get_descendants(71388002)
procedures.add(SG.get_concept_details(71388002))

body_structures = SG.get_descendants(123037004)
body_structures.add(SG.get_concept_details(123037004))

clinical_finding = SG.get_descendants(404684003)
clinical_finding.add(SG.get_concept_details(404684003))

all_concepts = procedures.union(body_structures)
all_concepts.update(clinical_finding)
print(len(all_concepts))


219172


In [4]:
#What is label=1?
device = "cuda:0" if torch.cuda.is_available() else "cpu"

kb_embedding_model_id = ("sentence-transformers/all-MiniLM-L6-v2")
kb_model = SentenceTransformer(kb_embedding_model_id,device=device)
# kb_model = kb_model.to(device)

sentences1 = []
sentences2 = []
labels = []
kb_sft_examples = []
for concept in tqdm(all_concepts):
    #Add synonym pairs
    for syn1, syn2 in combinations(SG.get_concept_details(concept.sctid).synonyms, 2):
        kb_sft_examples.append(InputExample(texts=[syn1, syn2], label=1))
        sentences1.append(syn1)
        sentences2.append(syn2)
        labels.append(1)
    
    # #add parent pairs
    # for p in SG.get_parents(concept.sctid):
    #     kb_sft_examples.append(InputExample(texts=[p.fsn.split('(')[0], SG.get_concept_details(concept.sctid).fsn.split('(')[0]], label=1))

    # #Add ancestor pairings
    # for a in SG.get_ancestors(concept.sctid):
    #     kb_sft_examples.append(InputExample(texts=[a.fsn.split('(')[0], SG.get_concept_details(concept.sctid).fsn.split('(')[0]], label=1))

kb_sft_dataloader = DataLoader(kb_sft_examples, shuffle=True, batch_size=32)
kb_sft_loss = losses.ContrastiveLoss(kb_model)

kb_model.fit(
    train_objectives=[(kb_sft_dataloader, kb_sft_loss)],
    epochs=2,
)

# kb_model.save("kb_model_3label")

  0%|          | 0/219172 [00:00<?, ?it/s]

Epoch:   0%|          | 0/2 [00:00<?, ?it/s]

Iteration:   0%|          | 0/8586 [00:00<?, ?it/s]

Iteration:   0%|          | 0/8586 [00:00<?, ?it/s]

In [4]:
kb_model = SentenceTransformer('kb_model_3label')

In [5]:
class Linker:
    def __init__(self, encoder, context_window_width=0):
        self.encoder = encoder
        self.entity_index = KeyedVectors(self.encoder[1].word_embedding_dimension)
        self.context_index = dict()
        self.history = dict()
        self.context_window_width = context_window_width

    def add_context(self, row):
        window_start = max(0, row.start - self.context_window_width)
        window_end = min(row.end + self.context_window_width, len(row.text))
        return row.text[window_start:window_end]

    def add_entity(self, row):
        return row.text[row.start : row.end]

    def fit(self, df=None, snomed_concepts=None):
        # Create a map from the entities to the concepts and contexts in which they appear
        if df is not None:
            for row in df.itertuples():
                entity = self.add_entity(row)
                context = self.add_context(row)
                map_ = self.history.get(entity, dict())
                contexts = map_.get(row.concept_id, list())
                contexts.append(context)
                map_[row.concept_id] = contexts
                self.history[entity] = map_

        # Add SNOMED CT codes for lookup
        if snomed_concepts is not None:
            for c in snomed_concepts:
                # for syn in c.synonyms:
                for syn in SG.get_concept_details(c.sctid).synonyms:
                    map_ = self.history.get(syn, dict())
                    contexts = map_.get(c.sctid, list())
                    contexts.append(syn)
                    map_[c.sctid] = contexts
                    self.history[syn] = map_

        # Create indexes to help disambiguate entities by their contexts
        for entity, map_ in tqdm(self.history.items()):
            keys = [
                (concept_id, occurance)
                for concept_id, contexts in map_.items()
                for occurance, context in enumerate(contexts)
            ]
            contexts = [context for contexts in map_.values() for context in contexts]
            vectors = self.encoder.encode(contexts)
            index = KeyedVectors(self.encoder[1].word_embedding_dimension)
            index.add_vectors(keys, vectors)
            self.context_index[entity] = index

        # Now create the top-level entity index
        keys = list(self.history.keys())
        vectors = self.encoder.encode(keys)
        self.entity_index.add_vectors(keys, vectors)

    def link(self, row):
        entity = self.add_entity(row)
        context = self.add_context(row)
        vec = self.encoder.encode(entity)
        #Map to known entity
        nearest_entity = self.entity_index.most_similar(vec, topn=1)[0][0]
        index = self.context_index.get(nearest_entity, None)

        #When would it ever not return index? If nearest_entity not found in train set?
        if index:
            vec = self.encoder.encode(context)
            #Within givin known entity, if multiple SCTIDs associated, then get SCTID with closest context
            key, score = index.most_similar(vec, topn=1)[0]
            sctid, _ = key
            return sctid
        else:
            return None

In [6]:
all_notes = pd.read_csv('mimic-iv_notes_training_set.csv',index_col='note_id')
all_annotations = pd.read_csv('train_annotations.csv',index_col='note_id')

rng = np.random.default_rng(seed=42)
shuffled_indices = rng.permutation(len(all_notes))

train_notes = all_notes.iloc[shuffled_indices[:184],:] #~90%
train_notes_with_annotations = pd.merge(left=train_notes,right=all_annotations,how='left',left_index=True,right_index=True)

print('Train notes:',len(train_notes),': # of Annotations:',train_notes_with_annotations.shape)

Train notes: 184 : # of Annotations: (46955, 4)


In [7]:
linker_training_df = train_notes_with_annotations

In [8]:
linker = Linker(kb_model, 12)
linker.fit(linker_training_df, all_concepts)

  0%|          | 0/385329 [00:00<?, ?it/s]

In [9]:
with open("3label_linker.pickle", "wb") as f:
    pickle.dump(linker, f)

In [None]:
# with open("3label_linker.pickle", "rb") as f:
#     linker = pickle.load(f)

In [10]:
with open('3label_pred.json') as f:
    data = f.read()

annotations_3label = json.loads(json.loads(data))

df_3label = pd.DataFrame(columns=['note_id','start','end','concept_id'])
for note in annotations_3label:
    for annotation in annotations_3label[note]:
        df_3label.loc[len(df_3label),:] = [note,list(annotation.values())[0][0],list(annotation.values())[0][1],-1]

df_3label = pd.merge(left=df_3label,right=all_notes[['text']],how='left',left_on='note_id',right_index=True)

for _,row in tqdm(df_3label.iterrows()):
    df_3label.loc[_,'concept_id'] = linker.link(row)

df_3label

0it [00:00, ?it/s]

Unnamed: 0,note_id,start,end,concept_id,text
0,14652764-DS-17,178,186,281900007,\nName: ___ Unit No: ___\n...
1,14652764-DS-17,199,221,419511003,\nName: ___ Unit No: ___\n...
2,14652764-DS-17,259,277,64766004,\nName: ___ Unit No: ___\n...
3,14652764-DS-17,322,340,47092002,\nName: ___ Unit No: ___\n...
4,14652764-DS-17,406,425,26925005,\nName: ___ Unit No: ___\n...
...,...,...,...,...,...
2175,15906604-DS-2,5421,5428,71616004,\nName: ___ Unit No: _...
2176,15906604-DS-2,5433,5442,4303006,\nName: ___ Unit No: _...
2177,15906604-DS-2,5480,5484,22253000,\nName: ___ Unit No: _...
2178,15906604-DS-2,5511,5523,103622007,\nName: ___ Unit No: _...


In [11]:
df_3label.set_index('note_id')[['start','end','concept_id']].to_csv('3label_res_embedding.csv')

In [12]:
df_3label['note_id'].unique()

array(['14652764-DS-17', '16441224-DS-19', '18914188-DS-20',
       '10797747-DS-20', '16464652-DS-17', '19476699-DS-25',
       '11436844-DS-4', '13397956-DS-5', '19442119-DS-15',
       '15906604-DS-2'], dtype=object)