In [2]:
import numpy as np
import pandas as pd
import torch
from torch.utils import data
from transformers import AutoTokenizer
from torch.utils.data import DataLoader

In [3]:
# Pull in your dataset here. It must have a column labeled 'text' containing the full
# radiology report text (not just the impression). If narrative reports are stored
# separately from the impressions, concatenate each impression onto the end of its narrative.
inference_input = pd.read_csv('./synthetic_example_imaging_data.csv')

# Fail early with a clear message instead of deep inside tokenization later on.
if 'text' not in inference_input.columns:
    raise KeyError("input CSV must contain a 'text' column with the report text")
if inference_input['text'].isna().any():
    raise ValueError("the 'text' column contains missing values; every row needs report text")

# Lower-case the reports and turn newlines into spaces.
# regex=False makes the newline a literal match regardless of the pandas version's default.
inference_input['text'] = inference_input['text'].str.lower().str.replace("\n", " ", regex=False)

# Remove any column whose name contains 'Unnamed' or 'outcome'.
# Assigning the result (rather than inplace=True) keeps the cell easy to reason about on re-run.
inference_input = inference_input.drop(columns=inference_input.filter(regex='Unnamed|outcome').columns)


In [4]:
class UnLabeledDataset(data.Dataset):
    """Torch dataset that tokenizes the 'text' column of a DataFrame for inference.

    Every row of the DataFrame is one sample. Rows are addressed by position, so the
    dataset length always equals the number of rows and sample i is row i, even when
    the DataFrame index contains repeated or non-consecutive labels.
    """

    def __init__(self, pandas_dataset):
        """Copy the input frame and load the tokenizer.

        Args:
            pandas_dataset: DataFrame with a 'text' column holding the report text.

        Raises:
            KeyError: if the 'text' column is missing.
        """
        if 'text' not in pandas_dataset.columns:
            raise KeyError("pandas_dataset must contain a 'text' column")
        self.data = pandas_dataset.copy()
        # Retained for backward compatibility only; sample lookup below is positional.
        self.indices = self.data.index.unique()
        # truncation_side='left' asks the tokenizer to cut over-long reports from the start,
        # which keeps the end of the report text.
        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', truncation_side='left')

    def __len__(self):
        """Return the number of notes (rows) in the dataset."""
        return len(self.data)

    def __getitem__(self, index):
        """Tokenize the note at positional row `index`.

        Returns:
            (input_ids, attention_mask): two torch.long tensors built from the
            tokenizer output for that note.
        """
        # Positional access: a repeated index label can never return several rows at once.
        note_text = self.data['text'].iloc[index]

        # padding='max_length' together with truncation=True is meant to give every
        # sample the same length so that batches can be stacked.
        encoded = self.tokenizer(note_text, padding='max_length', truncation=True)

        x_text_tensor = torch.tensor(encoded.input_ids, dtype=torch.long)
        x_attention_mask = torch.tensor(encoded.attention_mask, dtype=torch.long)
        return x_text_tensor, x_attention_mask
        

In [5]:
from transformers import AutoModel
from torch.nn import functional as F
import torch.nn as nn
from torch.nn import Linear, Sequential, ReLU

   
class LabeledModel(nn.Module):
    """BERT encoder followed by ten independent single-logit heads.

    Each head is Linear(768, 128) -> ReLU -> Linear(128, 1) and reads the encoder
    output at the first token position.
    """

    # Attribute names of the heads, in the order forward() returns their outputs.
    _HEAD_NAMES = (
        'any_cancer_head',
        'response_head',
        'progression_head',
        'brain_head',
        'bone_head',
        'adrenal_head',
        'liver_head',
        'lung_head',
        'node_head',
        'peritoneal_head',
    )

    def __init__(self):
        super().__init__()

        self.bert = AutoModel.from_pretrained('bert-base-uncased')

        # Register the heads under their original attribute names so that
        # state_dict keys stay the same.
        for head_name in self._HEAD_NAMES:
            setattr(self, head_name, Sequential(Linear(768, 128), ReLU(), Linear(128, 1)))

    def forward(self, x_text_tensor, x_attention_mask):
        """Encode the batch and apply every head to the first-token features.

        Returns:
            A tuple of ten tensors, one per head, ordered as: any_cancer, response,
            progression, brain, bone, adrenal, liver, lung, node, peritoneum.
        """
        encoder_output = self.bert(x_text_tensor, x_attention_mask)
        # Take the hidden state at sequence position 0 for every item in the batch.
        first_token_features = encoder_output.last_hidden_state[:, 0, :].squeeze(1)

        return tuple(
            getattr(self, head_name)(first_token_features)
            for head_name in self._HEAD_NAMES
        )
        




In [6]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [7]:
# Run the model over the whole inference dataset and collect one logit column per output head.
themodel = LabeledModel()
# map_location='cpu' lets a checkpoint that was saved on a GPU machine load on a CPU-only
# one; the model is moved to the selected device right afterwards.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
themodel.load_state_dict(torch.load('./dfci_mimic_imaging_bert.pt', map_location='cpu'))
themodel.to(device)

themodel.eval()

# shuffle=False keeps the predictions in the same row order as inference_input,
# which the column assignment at the end of this cell relies on.
dataset = data.DataLoader(UnLabeledDataset(inference_input), batch_size=16, shuffle=False, num_workers=0)

# LabeledModel.forward returns ten outputs; keep one list of per-batch arrays for each.
num_outputs = 10
output_prediction_lists = [[] for _ in range(num_outputs)]
with torch.no_grad():
    for batch in dataset:
        x_text_ids = batch[0].to(device)
        x_attention_mask = batch[1].to(device)
        predictions = themodel(x_text_ids, x_attention_mask)
        for j in range(num_outputs):
            output_prediction_lists[j].append(predictions[j].detach().cpu().numpy())

# Each head emits shape (batch, 1); join the batches and flatten to one value per row.
# An empty input yields no batches, so fall back to an empty array instead of letting
# np.concatenate raise on an empty list.
output_prediction_lists = [
    np.concatenate(chunks).reshape(-1) if chunks else np.empty(0, dtype=np.float32)
    for chunks in output_prediction_lists
]


output_dataset = inference_input.copy()
for j in range(num_outputs):
    output_dataset['outcome_' + str(j) + '_logit'] = output_prediction_lists[j]


In [8]:
# Replace the positional 'outcome_<i>_logit' column names with descriptive ones.
# The list order is the order of the model outputs (position i -> label i).
outcome_labels = [
    'any_cancer',
    'response',
    'progression_or_mixed',
    'brain_met',
    'bone_met',
    'adrenal_met',
    'liver_met',
    'lung_met',
    'node_met',
    'peritoneum_met',
]
logit_column_names = {
    'outcome_' + str(position) + '_logit': label + '_logit'
    for position, label in enumerate(outcome_labels)
}
output_dataset = output_dataset.rename(columns=logit_column_names)

In [9]:
# Summarize the result frame (column names, non-null counts, dtypes) as a quick sanity check.
output_dataset.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5 entries, 0 to 4
Data columns (total 12 columns):
 #   Column                      Non-Null Count  Dtype  
---  ------                      --------------  -----  
 0   gpt4_turbo_prompt           5 non-null      object 
 1   text                        5 non-null      object 
 2   any_cancer_logit            5 non-null      float32
 3   response_logit              5 non-null      float32
 4   progression_or_mixed_logit  5 non-null      float32
 5   brain_met_logit             5 non-null      float32
 6   bone_met_logit              5 non-null      float32
 7   adrenal_met_logit           5 non-null      float32
 8   liver_met_logit             5 non-null      float32
 9   lung_met_logit              5 non-null      float32
 10  node_met_logit              5 non-null      float32
 11  peritoneum_met_logit        5 non-null      float32
dtypes: float32(10), object(2)
memory usage: 408.0+ bytes


In [10]:
# Save the reports together with their logit columns.
# The index argument is left at its default, so the DataFrame index is written as well
# and appears as an unnamed first column in the CSV.
output_dataset.to_csv('./synthetic_example_imaging_inference_result.csv')