In [1]:
import pandas as pd
import numpy as np
import torch
import json
import dill as pickle
from snomed_graph import *
from sentence_transformers import InputExample, SentenceTransformer, losses, models, evaluation
from itertools import combinations
from torch.utils.data import DataLoader
from gensim.models.keyedvectors import KeyedVectors

In [2]:
SG = SnomedGraph.from_serialized('full_concept_graph.gml')

SNOMED graph has 361179 vertices and 1179749 edges


In [3]:
procedures = SG.get_descendants(71388002)
procedures.add(SG.get_concept_details(71388002))

body_structures = SG.get_descendants(123037004)
body_structures.add(SG.get_concept_details(123037004))

clinical_finding = SG.get_descendants(404684003)
clinical_finding.add(SG.get_concept_details(404684003))

all_concepts = procedures.union(body_structures)
all_concepts.update(clinical_finding)
print(len(all_concepts))
# all_concepts
concept_dict = {}
concept_dict['procedures'] = procedures
concept_dict['body_structures'] = body_structures
concept_dict['clinical_finding'] = clinical_finding


219172


In [4]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
kb_embedding_model_id = ("sentence-transformers/all-MiniLM-L6-v2")

for i in list(concept_dict.keys()):
    kb_model = SentenceTransformer(kb_embedding_model_id,device=device)

    kb_sft_examples = []
    sentences1 = []
    sentences2 = []
    labels = []
    for concept in tqdm(concept_dict[i]):
        #Add synonym pairings
        for syn1, syn2 in combinations(SG.get_concept_details(concept.sctid).synonyms, 2):
            kb_sft_examples.append(InputExample(texts=[syn1, syn2], label=1))
            sentences1.append(syn1)
            sentences2.append(syn2)
            labels.append(1)

        # #Add parent pairings
        # for p in SG.get_parents(concept.sctid):
        #     kb_sft_examples.append(InputExample(texts=[p.fsn.split('(')[0], SG.get_concept_details(concept.sctid).fsn.split('(')[0]], label=1))
        
        # #Add ancestor pairings
        # for a in SG.get_ancestors(concept.sctid):
        #     kb_sft_examples.append(InputExample(texts=[a.fsn.split('(')[0], SG.get_concept_details(concept.sctid).fsn.split('(')[0]], label=1))

    train_dataloader = DataLoader(kb_sft_examples, shuffle=True, batch_size=32)
    train_loss = losses.ContrastiveLoss(kb_model)
    kb_model.fit(
        [(train_dataloader, train_loss)],
        epochs=2,
    )

    kb_model.save("kb_model_"+i)

  0%|          | 0/59091 [00:00<?, ?it/s]

Epoch:   0%|          | 0/2 [00:00<?, ?it/s]

Iteration:   0%|          | 0/2130 [00:00<?, ?it/s]

Iteration:   0%|          | 0/2130 [00:00<?, ?it/s]

  0%|          | 0/41109 [00:00<?, ?it/s]

Epoch:   0%|          | 0/2 [00:00<?, ?it/s]

Iteration:   0%|          | 0/1457 [00:00<?, ?it/s]

Iteration:   0%|          | 0/1457 [00:00<?, ?it/s]

  0%|          | 0/118972 [00:00<?, ?it/s]

Epoch:   0%|          | 0/2 [00:00<?, ?it/s]

Iteration:   0%|          | 0/5001 [00:00<?, ?it/s]

Iteration:   0%|          | 0/5001 [00:00<?, ?it/s]

In [5]:
procedures_kb_model = SentenceTransformer('kb_model_procedures')
body_structures_kb_model = SentenceTransformer('kb_model_body_structures')
clinical_finding_kb_model = SentenceTransformer('kb_model_clinical_finding')

In [6]:
class Linker:
    def __init__(self, encoder, context_window_width=0):
        self.encoder = encoder
        self.entity_index = KeyedVectors(self.encoder[1].word_embedding_dimension)
        self.context_index = dict()
        self.history = dict()
        self.context_window_width = context_window_width

    def add_context(self, row):
        window_start = max(0, row.start - self.context_window_width)
        window_end = min(row.end + self.context_window_width, len(row.text))
        return row.text[window_start:window_end]

    def add_entity(self, row):
        return row.text[row.start : row.end]

    def fit(self, df=None, snomed_concepts=None):
        # Create a map from the entities to the concepts and contexts in which they appear
        if df is not None:
            for row in df.itertuples():
                entity = self.add_entity(row)
                context = self.add_context(row)
                map_ = self.history.get(entity, dict())
                contexts = map_.get(row.concept_id, list())
                contexts.append(context)
                map_[row.concept_id] = contexts
                self.history[entity] = map_

        # Add SNOMED CT codes for lookup
        if snomed_concepts is not None:
            for c in snomed_concepts:
                # for syn in c.synonyms:
                for syn in SG.get_concept_details(c.sctid).synonyms:
                    map_ = self.history.get(syn, dict())
                    contexts = map_.get(c.sctid, list())
                    contexts.append(syn)
                    map_[c.sctid] = contexts
                    self.history[syn] = map_

        # Create indexes to help disambiguate entities by their contexts
        for entity, map_ in tqdm(self.history.items()):
            keys = [
                (concept_id, occurance)
                for concept_id, contexts in map_.items()
                for occurance, context in enumerate(contexts)
            ]
            contexts = [context for contexts in map_.values() for context in contexts]
            vectors = self.encoder.encode(contexts)
            index = KeyedVectors(self.encoder[1].word_embedding_dimension)
            index.add_vectors(keys, vectors)
            self.context_index[entity] = index

        # Now create the top-level entity index
        keys = list(self.history.keys())
        vectors = self.encoder.encode(keys)
        self.entity_index.add_vectors(keys, vectors)

    def link(self, row):
        entity = self.add_entity(row)
        context = self.add_context(row)
        vec = self.encoder.encode(entity)
        #Map to known entity
        nearest_entity = self.entity_index.most_similar(vec, topn=1)[0][0]
        index = self.context_index.get(nearest_entity, None)

        #When would it ever not return index? If nearest_entity not found in train set?
        if index:
            vec = self.encoder.encode(context)
            #Within givin known entity, if multiple SCTIDs associated, then get SCTID with closest context
            key, score = index.most_similar(vec, topn=1)[0]
            sctid, _ = key
            return sctid
        else:
            return None

In [7]:
all_notes = pd.read_csv('mimic-iv_notes_training_set.csv',index_col='note_id')
all_annotations = pd.read_csv('train_annotations.csv',index_col='note_id')

rng = np.random.default_rng(seed=42)
shuffled_indices = rng.permutation(len(all_notes))

train_notes = all_notes.iloc[shuffled_indices[:184],:] #~90%
train_notes_with_annotations = pd.merge(left=train_notes,right=all_annotations,how='left',left_index=True,right_index=True)

print('Train notes:',len(train_notes),': # of Annotations:',train_notes_with_annotations.shape)

Train notes: 184 : # of Annotations: (46955, 4)


In [8]:
train_notes_with_annotations['concept_fsn'] = train_notes_with_annotations['concept_id'].map(lambda x: SG.get_concept_details(x).fsn)
train_notes_with_annotations = train_notes_with_annotations.reset_index()

main_concepts = ['body structure','procedure','finding']

for index,row in train_notes_with_annotations.iterrows():
    ancestors = SG.get_ancestors(row['concept_id'])
    ancestors.add(SG.get_full_concept(row['concept_id']))
    
    for a in ancestors:
        for c in main_concepts:
            if re.search(r'\(([\w\s]+)\)',a.fsn) and c == re.search(r'\(([\w\s]+)\)',a.fsn).groups()[0]:
                train_notes_with_annotations.loc[index,'snomed_base'] = re.search(r'\(([\w\s]+)\)',a.fsn).groups()[0]
                
train_notes_with_annotations = train_notes_with_annotations.set_index('note_id')
train_notes_with_annotations.loc[(train_notes_with_annotations.index == '10513485-DS-7') & (train_notes_with_annotations['snomed_base'].isna()),:]

Unnamed: 0_level_0,text,start,end,concept_id,concept_fsn,snomed_base
note_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1


In [9]:
clinical_finding_linker_training_df = train_notes_with_annotations.loc[train_notes_with_annotations.snomed_base == 'finding',:]
body_structures_linker_training_df = train_notes_with_annotations.loc[train_notes_with_annotations.snomed_base == 'body structure',:]
procedures_linker_training_df = train_notes_with_annotations.loc[train_notes_with_annotations.snomed_base == 'procedure',:]

In [10]:
procedures_linker = Linker(procedures_kb_model, 12)
procedures_linker.fit(procedures_linker_training_df, procedures)

body_structures_linker = Linker(body_structures_kb_model, 12)
body_structures_linker.fit(body_structures_linker_training_df, body_structures)

clinical_finding_linker = Linker(clinical_finding_kb_model, 12)
clinical_finding_linker.fit(clinical_finding_linker_training_df, clinical_finding)

  0%|          | 0/101550 [00:00<?, ?it/s]

  0%|          | 0/73171 [00:00<?, ?it/s]

  0%|          | 0/212378 [00:00<?, ?it/s]

In [11]:
with open("procedures_linker.pickle", "wb") as f:
    pickle.dump(procedures_linker, f)

with open("body_structures_linker.pickle", "wb") as f:
    pickle.dump(body_structures_linker, f)

with open("clinical_finding_linker.pickle", "wb") as f:
    pickle.dump(clinical_finding_linker, f)

In [4]:
# with open("procedures_linker.pickle", "rb") as f:
#     procedures_linker = pickle.load(f)

# with open("body_structures_linker.pickle", "rb") as f:
#     body_structures_linker = pickle.load(f)

# with open("clinical_finding_linker.pickle", "rb") as f:
#     clinical_finding_linker = pickle.load(f)

In [12]:
with open('7label_pred.json') as f:
    data = f.read()
# print(data)

all_notes = pd.read_csv('mimic-iv_notes_training_set.csv',index_col='note_id')
all_annotations = pd.read_csv('train_annotations.csv',index_col='note_id')

annotations_7label = json.loads(json.loads(data))

df_7label = pd.DataFrame(columns=['note_id','start','end','main_concept','concept_id'])
for note in annotations_7label:
    for annotation in annotations_7label[note]:
#         print(annotation,list(annotation.values())[0][0])
        df_7label.loc[len(df_7label),:] = [note,list(annotation.values())[0][1][0],list(annotation.values())[0][1][1],list(annotation.values())[0][0],-1]

df_7label = pd.merge(left=df_7label,right=all_notes[['text']],how='left',left_on='note_id',right_index=True)

for _,row in tqdm(df_7label.iterrows()):
#     print(row)
    if row.main_concept in [1,2]:
        df_7label.loc[_,'concept_id'] = procedures_linker.link(row)
    elif row.main_concept in [3,4]:
        df_7label.loc[_,'concept_id'] = clinical_finding_linker.link(row)
    elif row.main_concept in [5,6]:
        df_7label.loc[_,'concept_id'] = body_structures_linker.link(row)

df_7label

0it [00:00, ?it/s]

Unnamed: 0,note_id,start,end,main_concept,concept_id,text
0,14652764-DS-17,178,180,3,13920009,\nName: ___ Unit No: ___\n...
1,14652764-DS-17,199,221,4,419511003,\nName: ___ Unit No: ___\n...
2,14652764-DS-17,259,277,4,64766004,\nName: ___ Unit No: ___\n...
3,14652764-DS-17,322,340,2,47092002,\nName: ___ Unit No: ___\n...
4,14652764-DS-17,406,425,2,43075005,\nName: ___ Unit No: ___\n...
...,...,...,...,...,...,...
2018,15906604-DS-2,5401,5413,3,128139000,\nName: ___ Unit No: _...
2019,15906604-DS-2,5421,5428,5,71616004,\nName: ___ Unit No: _...
2020,15906604-DS-2,5433,5442,5,771314001,\nName: ___ Unit No: _...
2021,15906604-DS-2,5511,5523,3,128139000,\nName: ___ Unit No: _...


In [16]:
# df_3label.set_index('note_id')[['start','end','concept_id']].to_csv('3label_res_embedding_with_ancestors.csv')
df_7label.set_index('note_id')[['start','end','concept_id']].to_csv('7label_res_embedding.csv')

In [15]:
df_7label['note_id'].unique()

array(['14652764-DS-17', '16441224-DS-19', '18914188-DS-20',
       '10797747-DS-20', '16464652-DS-17', '19476699-DS-25',
       '11436844-DS-4', '13397956-DS-5', '19442119-DS-15',
       '15906604-DS-2'], dtype=object)