In [11]:
import pandas as pd
import numpy as np
from transformers import AutoTokenizer
import torch
import os
# Turn off Hugging Face tokenizers' internal parallelism (presumably to silence the
# fork/parallelism warning when used alongside a DataLoader — confirm if it matters).
os.environ.update(TOKENIZERS_PARALLELISM="false")

In [12]:
# Purpose: score an external dataset with a previously trained student model.
# The CSV is expected to provide a 'text' column; that column is what gets tokenized.
external_data = pd.read_csv(filepath_or_buffer='./synthetic_example_medonc_data.csv')
# Use the GPU when one is visible; otherwise fall back to CPU so the notebook
# still runs end-to-end (a hard-coded 'cuda' crashes on CPU-only machines).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [13]:

from transformers import AutoModel



from torch.nn import functional as F
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.nn import LSTM, Linear, Embedding, Conv1d, MaxPool1d, GRU, LSTMCell, Dropout, Module, Sequential, ReLU

   
class LabeledModel(nn.Module):
    """Longformer encoder with three scalar-logit heads.

    The final-layer embedding at position 0 (the CLS token) is passed to three
    independent two-layer MLP heads: ``any_cancer_head``, ``response_head`` and
    ``progression_head``. Each head returns raw logits of shape ``(batch, 1)``;
    no sigmoid is applied here.

    Args:
        pretrained_name: Hugging Face model id or local path of the encoder.
            Defaults to ``'yikuan8/Clinical-Longformer'``, the encoder the
            saved checkpoint was trained with.
        head_hidden_dim: width of the hidden layer in each head (128 matches
            the saved checkpoint).
    """

    def __init__(self, pretrained_name='yikuan8/Clinical-Longformer', head_hidden_dim=128):
        super().__init__()

        self.longformer = AutoModel.from_pretrained(pretrained_name)
        # Size the heads from the encoder config rather than hard-coding 768.
        hidden_size = self.longformer.config.hidden_size

        # Attribute names (and the Sequential layout) define the state_dict keys,
        # so they must stay as-is for saved checkpoints to load.
        self.any_cancer_head = self._make_head(hidden_size, head_hidden_dim)
        self.response_head = self._make_head(hidden_size, head_hidden_dim)
        self.progression_head = self._make_head(hidden_size, head_hidden_dim)

    @staticmethod
    def _make_head(in_dim, hidden_dim):
        """Build one Linear -> ReLU -> Linear head producing a single logit."""
        return Sequential(Linear(in_dim, hidden_dim), ReLU(), Linear(hidden_dim, 1))

    def forward(self, x_text_tensor, x_attention_mask):
        """Encode a batch of token ids and return the three head logits.

        Args:
            x_text_tensor: token-id tensor indexed as ``(batch, seq_len)``.
            x_attention_mask: padding mask with the same shape as ``x_text_tensor``.

        Returns:
            Tuple ``(any_cancer_out, response_out, progression_out)``, each of
            shape ``(batch, 1)``.
        """
        # Global attention on the CLS token only. zeros_like already puts the mask
        # on the same device/dtype as the ids, so no notebook-global `device` is needed.
        global_attention_mask = torch.zeros_like(x_text_tensor)
        global_attention_mask[:, 0] = 1

        encoder_out = self.longformer(
            input_ids=x_text_tensor,
            attention_mask=x_attention_mask,
            global_attention_mask=global_attention_mask,
        )
        # Final-layer embedding at position 0: (batch, hidden_size).
        cls_embedding = encoder_out.last_hidden_state[:, 0, :]

        any_cancer_out = self.any_cancer_head(cls_embedding)
        response_out = self.response_head(cls_embedding)
        progression_out = self.progression_head(cls_embedding)

        return any_cancer_out, response_out, progression_out
        




In [14]:
from torch.utils import data

class UnLabeledDataset(data.Dataset):
    """Map-style dataset of tokenized notes (no labels) for inference.

    Items are keyed by the unique index labels of the wrapped DataFrame, in order
    of first appearance. Each item is the Clinical-Longformer encoding of that
    row's ``text`` column, padded to ``max_length`` and truncated from the left
    (so the end of an over-long note is what is kept).

    NOTE(review): if the index has duplicate labels, ``.loc`` yields several rows
    and the tokenizer receives a Series instead of a string — confirm the index
    is unique before using this dataset.
    """

    def __init__(self, pandas_dataset):
        # Work on a private copy so later edits to the caller's frame don't leak in.
        self.data = pandas_dataset.copy()
        self.indices = self.data.index.unique()
        # truncation_side='left' drops tokens from the start of long notes.
        self.tokenizer = AutoTokenizer.from_pretrained(
            "yikuan8/Clinical-Longformer",
            truncation_side='left',
        )

    def __len__(self):
        """Number of unique index labels (one note each)."""
        return len(self.indices)

    def __getitem__(self, index):
        """Return ``(input_ids, attention_mask)`` as long tensors for the ``index``-th note."""
        label = self.indices[index]
        note_text = self.data.loc[label, 'text']
        encoding = self.tokenizer(note_text, padding='max_length', truncation=True)
        return (
            torch.tensor(encoding.input_ids, dtype=torch.long),
            torch.tensor(encoding.attention_mask, dtype=torch.long),
        )
        
        

In [15]:
# Run the trained student model over the external notes and attach one logit column per head.
themodel = LabeledModel()
# map_location lets a checkpoint saved on GPU load on whichever `device` was selected earlier.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source
# (on torch >= 1.13, passing weights_only=True restricts this for a plain state_dict).
themodel.load_state_dict(torch.load('dfci_mimic_note_longformer.pt', map_location=device))
themodel.to(device)
themodel.eval()

# shuffle=False keeps prediction order aligned with the row order of `external_data`.
no_shuffle_valid_dataset = data.DataLoader(UnLabeledDataset(external_data), batch_size=8, shuffle=False, num_workers=0)

# One list of per-batch arrays for each of the model's three heads.
output_prediction_lists = [[] for _ in range(3)]
with torch.no_grad():
    for x_text_ids, x_attention_mask in no_shuffle_valid_dataset:
        predictions = themodel(x_text_ids.to(device), x_attention_mask.to(device))
        for head_batches, head_logits in zip(output_prediction_lists, predictions):
            head_batches.append(head_logits.cpu().numpy())

# Each head emits (batch, 1) logits; flatten to 1-D so each becomes a plain column.
output_prediction_lists = [np.concatenate(x).reshape(-1) for x in output_prediction_lists]

output_external = external_data.copy()
for x, logits in enumerate(output_prediction_lists):
    assert len(logits) == len(output_external), 'expected one prediction per row; is the index unique?'
    output_external['outcome_' + str(x) + '_logit'] = logits



Some weights of LongformerModel were not initialized from the model checkpoint at yikuan8/Clinical-Longformer and are newly initialized: ['longformer.pooler.dense.bias', 'longformer.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [16]:
# Replace the positional head names with task names; the order follows the
# (any_cancer, response, progression) tuple returned by the model.
output_external = output_external.rename(columns=dict(
    outcome_0_logit='any_cancer_logit',
    outcome_1_logit='response_logit',
    outcome_2_logit='progression_logit',
))


In [17]:
# Persist the scored frame, index included (pandas' default).
# NOTE(review): the input CSV already carried its own index column ('Unnamed: 0'),
# so this file ends up with two index-like columns; consider index=False if unwanted.
output_external.to_csv(path_or_buf='./synthetic_example_medonc_inference_result.csv', index=True)

In [18]:
# Preview the first rows with the three logit columns.
output_external.head(n=5)

Unnamed: 0,claude2_sonnet_prompt,text,any_cancer_logit,response_logit,progression_logit
0,A patient with a history of lung cancer is see...,PROGRESS NOTE\n\nPatient: John Doe\nDate of Vi...,-5.174407,-6.148384,-5.649448
1,A patient with metastatic breast cancer with m...,Patient Name: Jane Doe\r\nMRN: 123456\r\nDate ...,5.674859,-5.19949,2.578383
2,A patient with a history of resected colorecta...,PROGRESS NOTE\r\n\r\nPatient: John Doe\r\nDOB:...,5.230693,-6.157811,2.323241
3,A patient with metastatic bladder cancer to ly...,SUBJECTIVE:\r\npatient is a 68-year-old female...,4.673669,1.951149,-3.411746
4,A patient with clear cell renal cell carcinoma...,Patient: John Doe\r\nDOB: 01/01/1970\r\nMRN: 1...,6.016739,-5.540566,1.989165
