In [1]:
import json
import re

import numpy as np
import pandas as pd
import torch
from datasets import Dataset
# NOTE: sklearn's f1_score (imported below) rebinds the seqeval f1_score name.
from seqeval.metrics import f1_score, precision_score, recall_score, classification_report
from seqeval.scheme import IOB1, IOB2
from sklearn.metrics import f1_score
from sklearn.metrics import confusion_matrix
from transformers import AutoTokenizer, AutoModelForTokenClassification, Trainer, TrainingArguments
# from sklearn.model_selection import KFold

from snomed_graph import *
from helpers import *

In [2]:
# Load the serialized SNOMED concept graph; it is used later to resolve each
# annotated concept to its PROC / STRUCT / FIND sub-hierarchy.
SNOMED_GRAPH_PATH = 'full_concept_graph.gml'
SG = SnomedGraph.from_serialized(SNOMED_GRAPH_PATH)

SNOMED graph has 361179 vertices and 1179749 edges


In [3]:
# Load the clinical notes and their span annotations; both tables are indexed
# by note_id so they can be joined on the index later.
NOTES_CSV = 'mimic-iv_notes_training_set.csv'
ANNOTATIONS_CSV = 'train_annotations.csv'

all_notes = pd.read_csv(NOTES_CSV, index_col='note_id')
all_annotations = pd.read_csv(ANNOTATIONS_CSV, index_col='note_id')

for caption, frame in (("# of Notes:", all_notes), ("# of Annotations:", all_annotations)):
    print(caption, len(frame))

# of Notes: 204
# of Annotations: 51574


In [4]:
# Seeded shuffle so the train/eval/test partition is reproducible.
rng = np.random.default_rng(seed=42)
shuffled_indices = rng.permutation(len(all_notes))

# Fixed cut points into the shuffled order: 184 train / 10 eval / remainder test.
TRAIN_END = 184
EVAL_END = 194
train_notes = all_notes.iloc[shuffled_indices[:TRAIN_END], :]  # ~90%
eval_notes = all_notes.iloc[shuffled_indices[TRAIN_END:EVAL_END], :]  # ~5%
test_notes = all_notes.iloc[shuffled_indices[EVAL_END:], :]  # ~5%


def _attach_annotations(notes):
    """Left-join ``all_annotations`` onto ``notes`` and add the annotated text.

    The new ``annotation`` column holds ``text[start:end]`` of the note each
    annotation row belongs to.
    """
    merged = pd.merge(left=notes, right=all_annotations, how='left',
                      left_index=True, right_index=True)
    merged['annotation'] = merged.apply(
        lambda row: notes.loc[row.name, 'text'][row['start']:row['end']],
        axis=1,
    )
    return merged


train_notes_with_annotations = _attach_annotations(train_notes)
eval_notes_with_annotations = _attach_annotations(eval_notes)
test_notes_with_annotations = _attach_annotations(test_notes)

# Report the number of notes and the shape of the joined frame for each split.
for caption, notes, annotated in (
    ('Train notes:', train_notes, train_notes_with_annotations),
    ('Eval notes:', eval_notes, eval_notes_with_annotations),
    ('Test notes:', test_notes, test_notes_with_annotations),
):
    print(caption, len(notes), ': # of Annotations:', annotated.shape)

Train notes: 184 : # of Annotations: (46955, 5)
Eval notes: 10 : # of Annotations: (2709, 5)
Test notes: 10 : # of Annotations: (1910, 5)


In [5]:
# Sub-hierarchies of interest, identified by the parenthesised tag in a concept's FSN.
main_concepts = ['body structure','procedure','finding']

# First parenthesised group of word/space characters in an FSN, e.g. "(finding)".
# NOTE(review): this is not anchored to the end of the FSN, so an earlier
# parenthetical would be picked up instead of the trailing tag - confirm this
# is acceptable for the FSNs in the graph.
_FSN_TAG = re.compile(r'\(([\w\s]+)\)')

# concept_id -> resolved sub-hierarchy. The same concept is annotated many
# times, so the graph is walked once per distinct id instead of once per row.
_snomed_base_by_concept = {}


def _snomed_base(concept_id):
    """Return the main sub-hierarchy tag for ``concept_id``, or NaN if none matches.

    The concept itself and all of its ancestors are inspected; a tag counts
    when it is one of ``main_concepts``. If several tags match, the earliest
    entry of ``main_concepts`` wins, which keeps the result deterministic
    (the previous row-wise loop kept whichever match was visited last while
    iterating a set).
    """
    if concept_id not in _snomed_base_by_concept:
        # Build a fresh list instead of calling .add() on the returned set, so
        # the object handed out by the graph is never mutated.
        lineage = [*SG.get_ancestors(concept_id), SG.get_full_concept(concept_id)]
        tags = set()
        for concept in lineage:
            match = _FSN_TAG.search(concept.fsn)
            if match:
                tags.add(match.group(1))
        _snomed_base_by_concept[concept_id] = next(
            (tag for tag in main_concepts if tag in tags), np.nan
        )
    return _snomed_base_by_concept[concept_id]


def _add_snomed_columns(annotated):
    """Return ``annotated`` with ``concept_fsn`` and ``snomed_base`` columns added.

    ``snomed_base`` is always created (NaN where no main sub-hierarchy is
    found), so downstream filters on it cannot fail with a missing column.
    """
    concept_ids = annotated['concept_id']
    return annotated.assign(
        # Get SNOMED concept names for each annotation
        concept_fsn=concept_ids.map(lambda cid: SG.get_concept_details(cid).fsn),
        # Get main sub-hierarchy label for each annotation
        snomed_base=concept_ids.map(_snomed_base),
    )


train_notes_with_annotations = _add_snomed_columns(train_notes_with_annotations)
eval_notes_with_annotations = _add_snomed_columns(eval_notes_with_annotations)
test_notes_with_annotations = _add_snomed_columns(test_notes_with_annotations)

# Sanity check: annotations of one note that could not be mapped to a main sub-hierarchy.
train_notes_with_annotations.loc[(train_notes_with_annotations.index == '10513485-DS-7') & (train_notes_with_annotations['snomed_base'].isna()),:]

Unnamed: 0_level_0,text,start,end,concept_id,annotation,concept_fsn,snomed_base
note_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1


In [6]:
# Checkpoints that can be compared; only the last list is active.
# model_names = ['bert-base-cased',"dmis-lab/biobert-large-cased-v1.1","microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext",
#                "cambridgeltl/SapBERT-from-PubMedBERT-fulltext"]
model_names = ["microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"]

# Empty per-model score table, indexed by model name.
score_columns = ['model','Train_P(O)_Actual','P(O)_Actual','P(B)_Actual','P(O)','P(B)','f1','f1_strict','char_f1','char_f1_2','accuracy_2']
scores = pd.DataFrame(columns=score_columns).set_index('model')

# One slot per model; filled with its classification report after evaluation.
score_dict = dict.fromkeys(model_names, 0)

# Maximum sequence length passed to the tokenizer.
MAX_LEN = 512

evaluation_range = len(train_notes)
def _iter_label_spans(labels):
    """Yield ``(first_token, last_token, closing_label)`` for each labelled run.

    A run opens at the first token whose label is > 0 and closes at the last
    token of the sequence or just before a token labelled 'O' (0) or any B-*
    label (1, 3, 5). ``closing_label`` is the label of the run's last token.
    ``labels`` may be a list of ints or a 1-D tensor.
    """
    first_token = -1
    last_index = len(labels) - 1
    for j, label in enumerate(labels):
        if label > 0:
            if first_token == -1:
                first_token = j
            if j == last_index or labels[j + 1] in [0, 1, 3, 5]:
                yield first_token, j, label
                first_token = -1


def _paint_span(token_b_array, char_b_array, offsets, first_token, last_token, label):
    """Write one span's label into two character-level label arrays.

    Both arrays receive ``label`` over the span's character range. When
    ``label`` is an I-* label (2, 4, 6) the matching B-* label (``label - 1``)
    is written over the first token's characters in ``token_b_array`` and over
    only the first character in ``char_b_array``.
    """
    char_start = offsets[first_token][0]
    char_end = offsets[last_token][1]
    token_b_array[char_start:char_end] = label
    char_b_array[char_start:char_end] = label
    if label in [2, 4, 6]:
        char_b_array[char_start] = label - 1
        token_b_array[char_start:offsets[first_token][1]] = label - 1


# Fine-tune every candidate checkpoint at every learning rate, then score it
# on the held-out test notes with character-level, strict IOB2 metrics.
for i in model_names:
    print('--------------------------')
    print(i)
    tokenizer = AutoTokenizer.from_pretrained(i, model_max_length=MAX_LEN)

    # One row per tokenized chunk; columns include 'labels' and 'offset_mapping'.
    train_tokens, train_token_array, train_map_token_to_char, train_orig_char_array = tokenize_and_label_7label(train_notes,train_notes_with_annotations,tokenizer,use_overflow=True)
    eval_tokens, eval_token_array, eval_map_token_to_char, eval_orig_char_array = tokenize_and_label_7label(eval_notes,eval_notes_with_annotations,tokenizer,use_overflow=True)
    test_tokens, test_token_array, test_map_token_to_char, test_orig_char_array = tokenize_and_label_7label(test_notes,test_notes_with_annotations,tokenizer,use_overflow=True)

    label_list = ['O','B-PROC','I-PROC','B-FIND','I-FIND','B-STRUCT','I-STRUCT']
    label_to_num = {label: idx for idx, label in enumerate(label_list)}
    num_to_label = {idx: label for idx, label in enumerate(label_list)}
    NUM_LABELS = len(label_list)

    # learning_rates = [1e-4,5e-5,5e-6]
    learning_rates = [5e-5]
    for lr in learning_rates:
        print('---------',str(lr),'----------')
        training_args = TrainingArguments(
            output_dir="./results/7label",
            evaluation_strategy="epoch",
            logging_strategy="epoch",
            num_train_epochs=10,
            learning_rate=lr,
            save_strategy='epoch',
            load_best_model_at_end=True,
            metric_for_best_model='loss'
        )

        device = "cuda:0" if torch.cuda.is_available() else "cpu"

        model = AutoModelForTokenClassification.from_pretrained(i, num_labels=NUM_LABELS)
        model = model.to(device)

        # Character offsets are only needed for decoding, not for training.
        cols = list(train_tokens.columns)
        cols.remove('offset_mapping')
        train_data = Dataset.from_dict(train_tokens[cols])
        eval_data = Dataset.from_dict(eval_tokens[cols])
        test_data = Dataset.from_dict(test_tokens[cols])

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_data,
            eval_dataset=eval_data,
            tokenizer=tokenizer
        )

        print('Training...')
        trainer.train()

        print('Evaluating...')
        # Inference only: switch off dropout and gradient tracking explicitly
        # instead of relying on whatever mode the Trainer left the model in.
        model.eval()
        test_note_ids = list(test_notes.index)
        test_texts = list(test_notes['text'].values)
        y_pred = []            # predictions for the first chunk of every note
        y_pred_overflow = []   # predictions for every overflow chunk, in note order
        all_predicitions = {note_id:[] for note_id in test_note_ids}
        with torch.no_grad():
            for j in range(len(test_notes)):
                one_input = tokenizer(test_texts[j], padding='max_length',
                                        truncation=True, return_tensors="pt").to(device)
                # Move to CPU once; the decoding below touches single elements.
                res = model(**one_input).logits.argmax(-1)[0].cpu()
                y_pred.append(res)
                all_predicitions[test_note_ids[j]].extend(res)
                # Text beyond the first chunk is read from the encoding's overflow list.
                for o in one_input[0].overflowing:
                    overflow_input = {}
                    overflow_input['input_ids'] = torch.as_tensor([o.ids]).to(device)
                    if 'token_type_ids' in list(one_input.keys()):
                        overflow_input['token_type_ids'] = torch.as_tensor([o.type_ids]).to(device)
                    overflow_input['attention_mask'] = torch.as_tensor([o.attention_mask]).to(device)
                    res = model(**overflow_input).logits.argmax(-1)[0].cpu()
                    y_pred_overflow.append(res)
                    all_predicitions[test_note_ids[j]].extend(res)

        # NOTE(review): rows of test_tokens are consumed positionally against
        # this list below - assumes the helper emits all first chunks before
        # all overflow chunks, in the same note order; confirm in helpers.
        combined_model_res = y_pred + y_pred_overflow

        # Gold spans of the training notes, as {span_text: [char_start, char_end]}.
        train_annotations = {note_id:[] for note_id in list(train_notes.index)}
        for note_id, row in train_tokens.iterrows():
            note_text = train_notes.loc[note_id,'text']
            offsets = row['offset_mapping']
            train_annotations[note_id].extend(
                {note_text[offsets[first][0]:offsets[last][1]]: [offsets[first][0], offsets[last][1]]}
                for first, last, _ in _iter_label_spans(row['labels'])
            )

        true_annotations = {note_id:[] for note_id in test_note_ids}
        pred_annotations = {note_id:[] for note_id in test_note_ids}
        # One label id per character of each test note (0 == 'O').
        true_char_array = {note_id:np.zeros(len(test_notes.loc[note_id,'text']),dtype=np.int32) for note_id in test_note_ids}
        true_char_array_2 = {note_id:np.zeros(len(test_notes.loc[note_id,'text']),dtype=np.int32) for note_id in test_note_ids}
        pred_char_array = {note_id:np.zeros(len(test_notes.loc[note_id,'text']),dtype=np.int32) for note_id in test_note_ids}
        pred_char_array_2 = {note_id:np.zeros(len(test_notes.loc[note_id,'text']),dtype=np.int32) for note_id in test_note_ids}
        num = 0
        char_token_count = {k:0 for k in label_list}

        for note_id, row in test_tokens.iterrows():
            note_text = test_notes.loc[note_id,'text']
            offsets = row['offset_mapping']

            ## Get TRUE Annotations
            one_note_annotations = []
            for first, last, label in _iter_label_spans(row['labels']):
                one_note_annotations.append(
                    {note_text[offsets[first][0]:offsets[last][1]]: [offsets[first][0], offsets[last][1]]}
                )
                _paint_span(true_char_array[note_id], true_char_array_2[note_id],
                            offsets, first, last, label)
            true_annotations[note_id].extend(one_note_annotations)

            ##Get PREDICTED Annotations
            one_note_annotations = []
            for first, last, label in _iter_label_spans(combined_model_res[num]):
                # Kept as a 0-dim tensor; a later cell converts it to a plain int.
                label_cpu = label.cpu()
                one_note_annotations.append(
                    {note_text[offsets[first][0]:offsets[last][1]]:
                        [label_cpu, [offsets[first][0], offsets[last][1]]]}
                )
                _paint_span(pred_char_array[note_id], pred_char_array_2[note_id],
                            offsets, first, last, label_cpu)
            pred_annotations[note_id].extend(one_note_annotations)

            num += 1


        #Evaluation
        # Character-level label sequences, one list per test note.
        pred_char_f1 = [[num_to_label[y] for y in pred_char_array[x]] for x in pred_char_array.keys()]
        pred_char_f1_2 = [[num_to_label[y] for y in pred_char_array_2[x]] for x in pred_char_array_2.keys()]
        true_char_f1 = [[num_to_label[y] for y in true_char_array[x]] for x in true_char_array.keys()]
        true_char_f1_2 = [[num_to_label[y] for y in true_char_array_2[x]] for x in true_char_array_2.keys()]

        # Token-level label sequences, one list per chunk.
        pred_label_array = [[num_to_label[y] for y in x.tolist()] for x in y_pred + y_pred_overflow]
        true_label_array = [[num_to_label[y] for y in x] for x in test_tokens['labels']]

        # NOTE: if re-enabled, f1_score is currently sklearn's (it is imported
        # after seqeval's) and accuracy_score is not imported in this notebook.
#         scores.loc[i+'('+str(lr)+')','char_f1'] = f1_score(true_char_f1,pred_char_f1)
#         scores.loc[i+'('+str(lr)+')','char_f1_2'] = f1_score(true_char_f1_2,pred_char_f1_2)
#         scores.loc[i+'('+str(lr)+')','accuracy'] = accuracy_score(true_char_f1,pred_char_f1)
#         scores.loc[i+'('+str(lr)+')','accuracy_2'] = accuracy_score(true_char_f1_2,pred_char_f1_2)

        # Compute the first report once; it is both printed and stored.
        char_report = classification_report(true_char_f1,pred_char_f1,mode='strict',scheme=IOB2)
        print(char_report)
        print(classification_report(true_char_f1_2,pred_char_f1_2,mode='strict',scheme=IOB2))
        # Keyed by model name only, so a later learning rate overwrites an earlier one.
        score_dict[i] = char_report
    
    
# scores

--------------------------
microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext
--------- 5e-05 ----------


Some weights of BertForTokenClassification were not initialized from the model checkpoint at microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Training...


Epoch,Training Loss,Validation Loss
1,0.3681,0.201536
2,0.1828,0.186419
3,0.1454,0.177469
4,0.1196,0.17721
5,0.099,0.181831
6,0.0834,0.197943
7,0.0706,0.206821
8,0.0621,0.214734
9,0.0547,0.218029
10,0.0506,0.223237


Evaluating...
              precision    recall  f1-score   support

        FIND       0.79      0.84      0.82      5958
        PROC       0.77      0.75      0.76      3352
      STRUCT       0.60      0.75      0.66       829

   micro avg       0.77      0.80      0.78     10139
   macro avg       0.72      0.78      0.75     10139
weighted avg       0.77      0.80      0.78     10139

              precision    recall  f1-score   support

        FIND       0.75      0.81      0.78      3310
        PROC       0.76      0.74      0.75      2227
      STRUCT       0.56      0.76      0.65       580

   micro avg       0.73      0.78      0.76      6117
   macro avg       0.69      0.77      0.73      6117
weighted avg       0.74      0.78      0.76      6117



Unnamed: 0_level_0,Train_P(O)_Actual,P(O)_Actual,P(B)_Actual,P(O),P(B),f1,f1_strict,char_f1,char_f1_2,accuracy_2
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1


In [7]:
# Replace the predicted label stored with every span (a 0-dim tensor when it
# comes out of the training cell) by a plain int so the structure can be
# serialized. int() accepts both a 0-dim tensor and an int, so re-running this
# cell is harmless; the previous .item() call failed on the second run.
for note_id, note_spans in pred_annotations.items():
    for span in note_spans:
        # Each span is a single-entry dict: {span_text: [label, [char_start, char_end]]}.
        span_text = next(iter(span))
        span[span_text][0] = int(span[span_text][0])
pred_annotations

{'14652764-DS-17': [{'No': [3, [178, 180]]},
  {'Adverse Drug Reactions': [4, [199, 221]]},
  {'ulcerative colitis': [4, [259, 277]]},
  {'ileostomy takedown': [2, [322, 340]]},
  {'abdominal colectomy': [2, [406, 425]]},
  {'laparoscopic': [1, [427, 439]]},
  {'proctectomy': [2, [440, 451]]},
  {'diverting\nloop ileostomy': [2, [457, 481]]},
  {'infection': [3, [579, 588]]},
  {'bleeding': [3, [590, 598]]},
  {'leak': [5, [600, 604]]},
  {'procedures': [1, [620, 630]]},
  {'Ulcerative Colitis': [4, [712, 730]]},
  {'Lap colectomy': [2, [737, 750]]},
  {'ileostomy': [2, [758, 767]]},
  {'CAD': [3, [815, 818]]},
  {'HLD': [4, [820, 823]]},
  {'RA': [3, [837, 839]]},
  {'DM': [3, [840, 842]]},
  {'NAD': [3, [888, 891]]},
  {'CV': [1, [892, 894]]},
  {'RRR': [4, [896, 899]]},
  {'Resp': [1, [900, 904]]},
  {'GI': [1, [926, 928]]},
  {'inc': [3, [930, 933]]},
  {'ND': [3, [941, 943]]},
  {'NT': [3, [945, 947]]},
  {'soft': [3, [949, 953]]},
  {'ileostomy takedown': [2, [1066, 1084]]},
  {'

In [12]:
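# Optional export of the predicted spans (run the int-conversion cell first so
# labels are plain ints). json.dump is called directly on the dict: encoding
# with json.dumps first and then dumping that string would double-encode it.
# with open("7label_pred.json", "w") as file:
#     json.dump(pred_annotations, file)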

In [10]:
# Show the classification report stored for the fine-tuned checkpoint.
# NOTE(review): the saved output below (support 3310 / 2227 / 580) matches the
# second report printed by the training cell, not the first one that the cell
# stores in score_dict - probably left over from an out-of-order run; re-run
# the notebook top to bottom to confirm.
report_key = 'microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext'
print(score_dict[report_key])

              precision    recall  f1-score   support

        FIND       0.75      0.81      0.78      3310
        PROC       0.76      0.74      0.75      2227
      STRUCT       0.56      0.76      0.65       580

   micro avg       0.73      0.78      0.76      6117
   macro avg       0.69      0.77      0.73      6117
weighted avg       0.74      0.78      0.76      6117



In [11]:
# model.save_pretrained('7label_NER_Final')