In [None]:
# This notebook scores MIMIC imaging reports with a previously trained teacher model, trains a student prognostic model on those teacher risk scores, and evaluates the student on the DFCI test imaging reports.

In [1]:
import pandas as pd
import numpy as np
from transformers import AutoTokenizer
import torch
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load the MIMIC-IV-Note radiology notes table from its CSV export.
MIMIC_RADIOLOGY_CSV = '/data/clin_notes_outcomes/mimic-iv-note-deidentified-free-text-clinical-notes-2.2/mimic-iv-note-deidentified-free-text-clinical-notes-2.2/note/radiology.csv'
mimic = pd.read_csv(MIMIC_RADIOLOGY_CSV)

In [4]:
mimic.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2321355 entries, 0 to 2321354
Data columns (total 8 columns):
 #   Column      Dtype  
---  ------      -----  
 0   note_id     object 
 1   subject_id  int64  
 2   hadm_id     float64
 3   note_type   object 
 4   note_seq    int64  
 5   charttime   object 
 6   storetime   object 
 7   text        object 
dtypes: float64(1), int64(2), object(5)
memory usage: 141.7+ MB


In [5]:
# Keep reports whose (lower-cased) text mentions cancer/restaging/malignancy ...
cancer_mask = mimic['text'].str.lower().str.contains('cancer|restaging|malignan', na=False)
mimic_cancer = mimic[cancer_mask]

# ... and, among those, reports that also mention one of the modality keywords.
# NOTE(review): these are plain substring patterns, so e.g. 'ct ' also matches
# words that merely end in "ct" (such as "effect "); left unchanged on purpose
# so the selected cohort stays the same.
modality_mask = mimic_cancer['text'].str.lower().str.contains('ct |mr |pet |nm |mammo', na=False)

# .copy() makes mimic_cancer an independent frame, so later column assignments
# on it do not trigger SettingWithCopyWarning or depend on view/copy semantics.
mimic_cancer = mimic_cancer[modality_mask].copy()

In [6]:
mimic_cancer.info()

<class 'pandas.core.frame.DataFrame'>
Index: 217642 entries, 9 to 2321304
Data columns (total 8 columns):
 #   Column      Non-Null Count   Dtype  
---  ------      --------------   -----  
 0   note_id     217642 non-null  object 
 1   subject_id  217642 non-null  int64  
 2   hadm_id     55804 non-null   float64
 3   note_type   217642 non-null  object 
 4   note_seq    217642 non-null  int64  
 5   charttime   217642 non-null  object 
 6   storetime   217642 non-null  object 
 7   text        217642 non-null  object 
dtypes: float64(1), int64(2), object(5)
memory usage: 14.9+ MB


In [7]:
# Normalise the report text: lower-case it, then turn newlines into spaces.
lowercased_text = mimic_cancer['text'].str.lower()
mimic_cancer['text'] = lowercased_text.str.replace("\n", " ")

In [None]:
# Load the labelled imaging reports and derive the survival outcome columns
# (died, death_date, time_to_death) used for evaluation.
phi_data = pd.read_csv('/data/clin_notes_outcomes/profile_3-2023/derived_data/labeled_imaging_prissmm.csv')

# Follow-up cut-off: death dates after this day are capped to it and the
# corresponding deaths are not counted as events.
FOLLOWUP_CUTOFF = pd.to_datetime('2022-12-31')

# Normalise the report text once: lower-case it and turn newlines into spaces.
phi_data['text'] = phi_data.text.str.lower().str.replace("\n", " ")

# Define mixed response to be progression
# (presumably class_status == 3 encodes "mixed response" -- confirm against the label schema).
phi_data['progression'] = np.where(phi_data.class_status==3, 1, phi_data.progression)

# Keep only reports dated after the genomics date, restricted to the columns
# used downstream. .copy() gives an independent frame so the column
# assignments below are not chained assignments on a filtered slice.
after_genomics = pd.to_datetime(phi_data.date) > pd.to_datetime(phi_data.genomics_date)
kept_columns = ['dfci_mrn','date','hybrid_death_dt','hybrid_death_ind','text','split','any_cancer','progression','response','brain_met','bone_met','adrenal_met','liver_met','lung_met','node_met','peritoneal_met']
phi_data = phi_data.loc[after_genomics, kept_columns].copy()

phi_data['date'] = pd.to_datetime(phi_data.date)
phi_data['hybrid_death_dt'] = pd.to_datetime(phi_data.hybrid_death_dt)

# Event indicator: 1.0 only when death is flagged "Y" and dated on or before the cut-off.
phi_data['died'] = np.where((phi_data.hybrid_death_ind == "Y") & (phi_data.hybrid_death_dt <= FOLLOWUP_CUTOFF), 1., 0.)

# death_date = hybrid_death_dt capped at the cut-off; missing dates (NaT) also
# become the cut-off because NaT fails the <= comparison. Series.where keeps
# the column as a pandas datetime instead of going through np.where, which
# returns a plain NumPy array that then has to be re-parsed.
phi_data['death_date'] = phi_data.hybrid_death_dt.where(phi_data.hybrid_death_dt <= FOLLOWUP_CUTOFF, FOLLOWUP_CUTOFF)

# Time from report date to death_date, in 30-day months.
phi_data['time_to_death'] = (phi_data['death_date'] - phi_data['date']).dt.days / 30
phi_data[['date','death_date','time_to_death']]
#phi_data.class_status.value_counts()

In [3]:
# Rows assigned to the validation split.
validation = phi_data.loc[phi_data['split'] == 'validation']

#validation.head()
#validation['length'] = validation.text.str.count(' ')

In [4]:
# Rows assigned to the held-out test split.
test = phi_data.loc[phi_data['split'] == 'test']

#test.head()

In [5]:
from torch.utils import data

class UnLabeledDataset(data.Dataset):
    """Torch dataset that tokenizes the ``text`` column of a DataFrame.

    Items are ``(input_ids, attention_mask)`` pairs of ``torch.long`` tensors;
    no labels are returned.
    """

    def __init__(self, pandas_dataset):
        # Work on a private copy so the caller's frame is never touched.
        self.data = pandas_dataset.copy()
        # Items are addressed through the frame's unique index labels.
        self.indices = self.data.index.unique()
        # truncation_side='left' drops tokens from the start of over-long texts.
        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', truncation_side='left')

    def __len__(self):
        # One item per unique index label.
        return len(self.indices)

    def __getitem__(self, index):
        # Map the positional index to its label, then fetch that row.
        row = self.data.loc[self.indices[index]]

        tokens = self.tokenizer(row['text'], padding='max_length', truncation=True)

        return (
            torch.tensor(tokens.input_ids, dtype=torch.long),
            torch.tensor(tokens.attention_mask, dtype=torch.long),
        )
        

In [14]:
from transformers import AutoModel



from torch.nn import functional as F
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.nn import LSTM, Linear, Embedding, Conv1d, MaxPool1d, GRU, LSTMCell, Dropout, Module, Sequential, ReLU

   
class PrognosisModel(nn.Module):
    """BERT encoder followed by a small MLP head that emits one risk score in (0, 1).

    NOTE: a second definition of PrognosisModel exists in another cell of this
    notebook; whichever cell ran last wins, so keep them behaviourally in sync.
    """

    def __init__(self):
        super().__init__()

        # Pretrained text encoder.
        self.bert = AutoModel.from_pretrained('bert-base-uncased')

        # 768 -> 128 -> 1 head applied to the first-token embedding.
        self.risk_head = Sequential(
            Linear(768, 128),
            ReLU(),
            Linear(128, 1),
        )

    def forward(self, x_text_tensor, x_attention_mask):
        """Return sigmoid risk scores, one row per input sequence.

        Args:
            x_text_tensor: token ids, passed to the encoder as its first argument.
            x_attention_mask: attention mask, passed as its second argument.
        """
        encoder_out = self.bert(x_text_tensor, x_attention_mask)

        # Embedding at sequence position 0; the slice is already 2-D
        # (batch, hidden), so no squeeze is needed before the head.
        first_token = encoder_out.last_hidden_state[:, 0, :]

        return torch.sigmoid(self.risk_head(first_token))


In [16]:
# Score every selected MIMIC report with the saved teacher weights; the
# resulting risk scores become the pseudo-labels stored in output_mimic.
themodel = PrognosisModel()
themodel.load_state_dict(torch.load('./imaging_prognosis_model_teacher.pt'))
themodel.to('cuda')
themodel.eval()

# shuffle=False keeps predictions in the same row order as mimic_cancer.
no_shuffle_valid_dataset = data.DataLoader(UnLabeledDataset(mimic_cancer), batch_size=16, shuffle=False, num_workers=0)

output_prediction_list = []
with torch.no_grad():
    for x_text_ids, x_attention_mask in no_shuffle_valid_dataset:
        predictions = themodel(x_text_ids.to('cuda'), x_attention_mask.to('cuda'))
        output_prediction_list.append(predictions.detach().cpu().numpy())

# Batches are (batch, 1); stack them and drop the trailing singleton axis.
output_predictions = np.concatenate(output_prediction_list).squeeze(1)

# Attach the scores to a copy of the report table.
output_mimic = mimic_cancer.assign(risk_scores=output_predictions)

In [17]:
# Persist the scored reports; the DataFrame index is written as the first column.
output_mimic.to_csv('./data/mimic_imaging_risk_scores.csv', index=True)

In [6]:
from torch.utils import data

class PseudoLabeledDataset(data.Dataset):
    """Torch dataset of report texts paired with a soft ``risk_scores`` target.

    Items are ``(input_ids, attention_mask, risk_score)``: two ``torch.long``
    tensors plus a 0-dim ``torch.float32`` tensor.
    """

    def __init__(self, pandas_dataset):
        # Work on a private copy so the caller's frame is never touched.
        self.data = pandas_dataset.copy()
        # Items are addressed through the frame's unique index labels.
        self.indices = self.data.index.unique()
        # truncation_side='left' drops tokens from the start of over-long texts.
        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', truncation_side='left')

    def __len__(self):
        # how many notes in the dataset
        return len(self.indices)

    def __getitem__(self, index):
        # get data for the note corresponding to the positional index passed
        this_index = self.indices[index]
        pand = self.data.loc[this_index, :]

        encoded = self.tokenizer(pand['text'], padding='max_length', truncation=True)

        x_text_tensor = torch.tensor(encoded.input_ids, dtype=torch.long)
        x_attention_mask = torch.tensor(encoded.attention_mask, dtype=torch.long)

        # Always hand back float32, like LabeledDataset does for its targets.
        # Returning the raw row value made the batch dtype depend on how the
        # column was produced (float64 after a CSV round-trip), and
        # binary_cross_entropy rejects a float64 target for float32 predictions.
        risk_score = torch.tensor(pand['risk_scores'], dtype=torch.float32)
        return x_text_tensor, x_attention_mask, risk_score
        
        

In [7]:
from torch.utils import data

class LabeledDataset(data.Dataset):
    """Torch dataset of report texts with survival labels.

    Items are ``(input_ids, attention_mask, time_to_death, died)``: two
    ``torch.long`` tensors followed by two 0-dim ``torch.float32`` tensors.
    """

    def __init__(self, pandas_dataset):
        # Work on a private copy so the caller's frame is never touched.
        self.data = pandas_dataset.copy()
        # Items are addressed through the frame's unique index labels.
        self.indices = self.data.index.unique()
        # truncation_side='left' drops tokens from the start of over-long texts.
        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', truncation_side='left')

    def __len__(self):
        # One item per unique index label.
        return len(self.indices)

    def __getitem__(self, index):
        # Map the positional index to its label, then fetch that row.
        row = self.data.loc[self.indices[index]]

        tokens = self.tokenizer(row['text'], padding='max_length', truncation=True)

        return (
            torch.tensor(tokens.input_ids, dtype=torch.long),
            torch.tensor(tokens.attention_mask, dtype=torch.long),
            torch.tensor(row['time_to_death'], dtype=torch.float32),
            torch.tensor(row['died'], dtype=torch.float32),
        )
        

In [8]:
from transformers import AutoModel



from torch.nn import functional as F
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.nn import LSTM, Linear, Embedding, Conv1d, MaxPool1d, GRU, LSTMCell, Dropout, Module, Sequential, ReLU

   
class PrognosisModel(nn.Module):
    """BERT encoder followed by a small MLP head that emits one risk score in (0, 1).

    NOTE: a second definition of PrognosisModel exists in another cell of this
    notebook; whichever cell ran last wins, so keep them behaviourally in sync.
    """

    def __init__(self):
        super().__init__()

        # Pretrained text encoder.
        self.bert = AutoModel.from_pretrained('bert-base-uncased')

        # 768 -> 128 -> 1 head applied to the first-token embedding.
        self.risk_head = Sequential(
            Linear(768, 128),
            ReLU(),
            Linear(128, 1),
        )

    def forward(self, x_text_tensor, x_attention_mask):
        """Return sigmoid risk scores, one row per input sequence.

        Args:
            x_text_tensor: token ids, passed to the encoder as its first argument.
            x_attention_mask: attention mask, passed as its second argument.
        """
        encoder_out = self.bert(x_text_tensor, x_attention_mask)

        # Embedding at sequence position 0; the slice is already 2-D
        # (batch, hidden), so no squeeze is needed before the head.
        first_token = encoder_out.last_hidden_state[:, 0, :]

        return torch.sigmoid(self.risk_head(first_token))


In [9]:
# train loop
from transformers import get_scheduler
from torch.optim import AdamW, Adam
#, get_linear_schedule_with_warmup


def train_model(model, num_epochs, trainloader, validloader=None, device='cuda'):
    """Train ``model`` to reproduce soft risk-score targets with binary cross-entropy.

    Uses AdamW (lr=5e-5) with a linear learning-rate decay and no warm-up,
    stepped once per batch. A running mean of the training loss is printed
    after every batch.

    Args:
        model: module called as ``model(input_ids, attention_mask)`` whose
            output has a size-1 second dimension holding values in [0, 1].
        num_epochs: number of passes over ``trainloader``.
        trainloader: iterable of ``(input_ids, attention_mask, risk_score)``
            tensor batches.
        validloader: accepted for interface compatibility but currently
            unused; no validation pass is run.
        device: device the model and each batch are moved to.

    Returns:
        None. ``model`` is updated in place.
    """
    optimizer = AdamW(model.parameters(), lr=5e-5)
    num_train_batches = len(trainloader)
    num_training_steps = num_epochs * num_train_batches
    lr_scheduler = get_scheduler(
        name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)

    model.to(device)

    for epoch in range(num_epochs):
        running_train_loss = 0.0

        model.train()

        for i, batch in enumerate(trainloader, 0):
            input_ids, input_masks, risk_score = [x.to(device) for x in batch]

            optimizer.zero_grad()

            pred_h_t = model(input_ids, input_masks).squeeze(1)

            # Match the target dtype to the predictions: binary_cross_entropy
            # raises when a float64 target meets float32 predictions, which
            # happens when the pseudo-labels arrive as float64.
            risk_score = risk_score.to(pred_h_t.dtype)

            prognosis_loss = F.binary_cross_entropy(pred_h_t, risk_score)
            prognosis_loss.backward()
            optimizer.step()
            lr_scheduler.step()

            # Pull the scalar loss to the host once and reuse it below.
            batch_loss = prognosis_loss.detach().cpu().numpy()
            running_train_loss += batch_loss
            mean_train_loss = running_train_loss / (i+1)

            # '\r' keeps the progress report on a single, continuously updated line.
            print('Training Epoch: ' + str(epoch+1) + ', batch: ' + str(i + 1) + '/' + str(num_train_batches) + ' this_loss:' + str(batch_loss)
                  +', train loss: ' + str(mean_train_loss), end='\r', flush=True)

        print('\n')


    

In [34]:
# actual student model training, commented out after model trained.

#themodel = PrognosisModel().to('cuda')
#trainloader = data.DataLoader(PseudoLabeledDataset(output_mimic.reset_index(drop=True)), batch_size=16, num_workers=8, shuffle=True)
#validloader = data.DataLoader(LabeledDataset(validation.reset_index(drop=True)), batch_size=16, num_workers=8, shuffle=True)
#train_model(themodel,2, trainloader)

Training Epoch: 1, batch: 13603/13603 this_loss:0.14987479, train loss: 0.089875888120483539

Training Epoch: 2, batch: 13603/13603 this_loss:0.07460602, train loss: 0.088597042857833821



In [35]:
#torch.save(themodel.state_dict(), 'imaging_prognosis_model_student.pt')

In [10]:
# Score the test split with the saved student weights, collecting the
# follow-up times and event indicators alongside the predicted risks.
themodel = PrognosisModel()
themodel.load_state_dict(torch.load('imaging_prognosis_model_student.pt'))
themodel.to('cuda')
themodel.eval()

# shuffle=False keeps predictions aligned with the row order of `test`.
no_shuffle_valid_dataset = data.DataLoader(LabeledDataset(test), batch_size=16, shuffle=False, num_workers=0)

event_times, events, output_prediction_list = [], [], []
with torch.no_grad():
    for x_text_ids, x_attention_mask, y_time_to_death, y_died in no_shuffle_valid_dataset:
        event_times.append(y_time_to_death.detach().cpu().numpy())
        events.append(y_died.detach().cpu().numpy())
        predictions = themodel(x_text_ids.to('cuda'), x_attention_mask.to('cuda'))
        output_prediction_list.append(predictions.detach().cpu().numpy())

# Predictions come back as (batch, 1); stack and drop the trailing singleton axis.
output_predictions = np.concatenate(output_prediction_list).squeeze(1)
event_times = np.concatenate(event_times)
events = np.concatenate(events)



In [11]:
from sksurv.metrics import concordance_index_censored

In [12]:
# Concordance of predicted risk against observed survival; an observation is
# an event when its `events` value equals 1.
event_observed = events == 1
concordance_index_censored(event_observed, event_times, output_predictions)

(0.7578195176372534, 1450243, 463462, 0, 181)