In [None]:
# this notebook trains and evaluates a simple teacher prognostic model on DFCI imaging reports

In [1]:
import pandas as pd
import numpy as np
from transformers import AutoTokenizer
import torch
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Path of the labeled imaging-report extract that the rest of the notebook works from.
labeled_reports_path = '/data/clin_notes_outcomes/profile_3-2023/derived_data/labeled_imaging_prissmm.csv'
reports = pd.read_csv(labeled_reports_path)

In [3]:
# Rows annotated with class_status == 3 (mixed response) are counted as progression;
# every other row keeps its existing progression flag.
reports.loc[reports['class_status'] == 3, 'progression'] = 1

In [4]:

# Column-level overview (non-null counts and dtypes) of the labeled reports frame.
reports.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 37274 entries, 0 to 37273
Data columns (total 35 columns):
 #   Column                    Non-Null Count  Dtype  
---  ------                    --------------  -----  
 0   Unnamed: 0.1              37274 non-null  int64  
 1   dfci_mrn                  37274 non-null  float64
 2   cancer_type               37274 non-null  object 
 3   image_scan_type           37274 non-null  float64
 4   date                      37274 non-null  object 
 5   head_imaged               37274 non-null  float64
 6   neck_imaged               37274 non-null  float64
 7   spine_imaged              37274 non-null  float64
 8   chest_imaged              37274 non-null  float64
 9   abdomen_imaged            37274 non-null  float64
 10  pelvis_imaged             37274 non-null  float64
 11  any_cancer                37274 non-null  int64  
 12  progression               37274 non-null  int64  
 13  response                  37274 non-null  int64  
 14  class_

In [5]:
# Keep only the rows whose report date falls strictly after that row's genomics_date.
report_dates = pd.to_datetime(reports['date'])
genomics_dates = pd.to_datetime(reports['genomics_date'])
reports = reports[report_dates > genomics_dates]

In [6]:
# Number of remaining reports per image_scan_type code.
reports['image_scan_type'].value_counts()

image_scan_type
1.0     14986
3.0      4864
7.0      1614
5.0      1553
11.0      187
Name: count, dtype: int64

In [7]:
# Training-split rows only, used below to inspect label and cancer-type balance.
check_distribution = reports.loc[reports['split'] == 'train']

In [8]:
# Share of training reports per cancer type (counts divided by the number of training rows).
check_distribution['cancer_type'].value_counts().div(len(check_distribution))

cancer_type
nsclc_phase2_existing    0.287033
crc                      0.170681
breast                   0.154446
prostate                 0.115238
bladder                  0.097411
pancreas                 0.093697
rcc_barkouny             0.081494
Name: count, dtype: float64

In [9]:
# Share of training reports per class_status code (counts divided by the number of training rows).
check_distribution['class_status'].value_counts().div(len(check_distribution))

class_status
0.0    0.451454
4.0    0.230794
2.0    0.163094
5.0    0.071360
1.0    0.062765
3.0    0.020320
Name: count, dtype: float64

In [None]:
# Administrative censoring date: a death recorded after it is treated as not (yet) observed,
# and follow-up for every patient without an earlier death ends here.
censor_date = pd.to_datetime('2022-12-31')

outcome_columns = ['dfci_mrn','date','hybrid_death_dt','hybrid_death_ind','text','split','any_cancer','progression','response','brain_met','bone_met','adrenal_met','liver_met','lung_met','node_met','peritoneal_met']

# .copy() gives an independent frame, so the column assignments below do not raise
# SettingWithCopyWarning and do not depend on whether the selection is a view of `reports`.
to_train = reports[outcome_columns].copy()
to_train['text'] = to_train['text'].str.lower()
to_train['date'] = pd.to_datetime(to_train['date'])
to_train['hybrid_death_dt'] = pd.to_datetime(to_train['hybrid_death_dt'])

# False for deaths after the censoring date and for missing death dates (NaT compares False).
death_on_or_before_censor = to_train['hybrid_death_dt'] <= censor_date

# End of follow-up: the death date when it is on/before the censoring date, otherwise the
# censoring date itself. Series.where keeps the column as datetime64, so no re-parsing is needed.
to_train['death_date'] = to_train['hybrid_death_dt'].where(death_on_or_before_censor, censor_date)

# Event indicator: 1.0 only when death is flagged "Y" and happened on/before the censoring date.
to_train['died'] = ((to_train['hybrid_death_ind'] == "Y") & death_on_or_before_censor).astype(float)

# Follow-up time from the report date, expressed in 30-day units.
to_train['time_to_death'] = (to_train['death_date'] - to_train['date']).dt.days / 30
to_train[['date','death_date','time_to_death']]

In [11]:
# Rows assigned to the training split.
training = to_train.loc[to_train['split'] == 'train']
#training['length'] = training.text.str.count(' ')

In [12]:
# Number of reports in the training split.
len(training)

18848

In [13]:
# Number of distinct dfci_mrn values in the training split.
training['dfci_mrn'].nunique()

2156

In [14]:
# Rows assigned to the validation split.
validation = to_train.loc[to_train['split'] == 'validation']
#validation['length'] = validation.text.str.count(' ')

In [15]:
# Rows assigned to the held-out test split.
test = to_train.loc[to_train['split'] == 'test']


In [16]:
from torch.utils import data

class LabeledDataset(data.Dataset):
    """Torch dataset that serves one tokenized report plus its survival labels per item.

    Rows are addressed by position, so the frame's index does not need to be unique or
    reset: a frame with repeated index labels still yields exactly one item per row.

    Args:
        pandas_dataset: DataFrame with at least the columns ``text``, ``time_to_death``
            and ``died``. It is copied, so later changes to the caller's frame do not
            affect the dataset.
    """

    def __init__(self, pandas_dataset):
        self.data = pandas_dataset.copy()
        # Retained for callers that inspect it; item access below is positional and
        # does not rely on it.
        self.indices = self.data.index.unique()
        # truncation_side='left' makes the tokenizer drop tokens from the start of an
        # over-long report, so the end of the text is what gets kept.
        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', truncation_side='left')

    def __len__(self):
        """Return the number of rows (reports) in the dataset."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(input_ids, attention_mask, time_to_death, died)`` for the row at ``index``.

        ``input_ids`` and ``attention_mask`` are long tensors produced by the tokenizer with
        padding to its maximum length; the two labels are float32 scalar tensors.
        """
        row = self.data.iloc[index]

        encoded = self.tokenizer(row['text'], padding='max_length', truncation=True)

        x_text_tensor = torch.tensor(encoded.input_ids, dtype=torch.long)
        x_attention_mask = torch.tensor(encoded.attention_mask, dtype=torch.long)

        y_time_to_death = torch.tensor(row['time_to_death'], dtype=torch.float32)
        y_died = torch.tensor(row['died'], dtype=torch.float32)

        return x_text_tensor, x_attention_mask, y_time_to_death, y_died
        

In [17]:
from transformers import AutoModel



from torch.nn import functional as F
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.nn import LSTM, Linear, Embedding, Conv1d, MaxPool1d, GRU, LSTMCell, Dropout, Module, Sequential, ReLU

   
class PrognosisModel(nn.Module):
    """BERT encoder followed by a small MLP that emits one sigmoid-squashed score per report."""

    def __init__(self):
        super().__init__()

        # Pretrained text encoder.
        self.bert = AutoModel.from_pretrained('bert-base-uncased')

        # 768 -> 128 -> 1 head applied to the encoder's first-token state.
        self.risk_head = nn.Sequential(
            nn.Linear(768, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

    def forward(self, x_text_tensor, x_attention_mask):
        """Score a batch of tokenized reports.

        Args:
            x_text_tensor: token ids, passed to the encoder as its first argument.
            x_attention_mask: attention mask, passed to the encoder as its second argument.

        Returns:
            Tensor with one value in (0, 1) per report, shaped (batch, 1).
        """
        encoder_output = self.bert(x_text_tensor, x_attention_mask)
        # Hidden state of the first token position stands in for the whole report.
        first_token_state = encoder_output.last_hidden_state[:, 0, :]
        return torch.sigmoid(self.risk_head(first_token_state))


In [18]:
#activation = nn.Softplus()
def survival_loss(linear_h, end_times, events, eps=1e-7):
    """Negative log-likelihood of a constant-hazard (exponential) survival model, summed over rows.

    For each row the log-likelihood is ``events * log(hazard) - hazard * end_times``.

    Args:
        linear_h: predicted hazard per row; used as-is (no activation is applied here),
            so it is expected to be non-negative.
        end_times: follow-up time per row, in the same time unit the hazard refers to.
        events: 1.0 where the event was observed, 0.0 where the row is censored.
        eps: lower bound applied to the hazard inside the logarithm only. Without it a
            hazard of exactly 0 gives ``log(0) = -inf``, and ``-inf * 0`` for a censored
            row turns the whole loss into NaN.

    Returns:
        Scalar tensor holding the summed negative log-likelihood.
    """
    hazard = linear_h

    cum_hazard = hazard * end_times

    # Clamp only for the log term; the cumulative hazard keeps the raw prediction.
    log_hazard = torch.log(torch.clamp(hazard, min=eps))

    loss = -torch.sum(log_hazard * events - cum_hazard)

    return loss

In [19]:
# train loop
from transformers import get_scheduler
from torch.optim import AdamW, Adam
#, get_linear_schedule_with_warmup


def train_model(model, num_epochs, trainloader, validloader=None, device='cuda'):
    """Fine-tune ``model`` with ``survival_loss`` and optionally report a validation loss per epoch.

    Batches are pooled until at least three deaths have been seen since the last update;
    the loss is then computed over all pooled predictions and one optimizer step is taken.
    Batches still pooled when an epoch ends (fewer than three deaths) are discarded.
    Validation uses the same pooling rule, without gradient tracking.

    Args:
        model: module called as ``model(input_ids, attention_mask)`` returning a (batch, 1) tensor.
        num_epochs: number of passes over ``trainloader``.
        trainloader: iterable of ``(input_ids, attention_mask, time_to_death, died)`` tensor batches.
        validloader: optional loader with the same batch layout; skipped when ``None``.
        device: device the model and every batch are moved to.

    Returns:
        None. ``model`` is updated in place and progress is printed.
    """
    optimizer = AdamW(model.parameters(), lr=5e-5)
    num_training_steps = num_epochs * len(trainloader)
    # NOTE(review): the scheduler is sized for one step per batch, but it is only stepped
    # on pooled updates (see below), i.e. less often than num_training_steps assumes.
    lr_scheduler = get_scheduler(
        name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)

    model.to(device)

    for epoch in range(num_epochs):
        running_train_loss = 0.0
        mean_train_loss = 0.0
        # Loss of the most recent update. It stays NaN until the first update of the epoch,
        # so the progress line can be printed even when the first batches hold < 3 deaths.
        last_train_loss = float('nan')

        running_valid_loss = 0.0
        mean_valid_loss = 0.0

        num_train_batches = len(trainloader)

        preds_h_list = []
        times_list = []
        events_list = []

        num_dead = 0

        model.train()
        for i, batch in enumerate(trainloader, 0):
            input_ids, input_masks, time_to_death, died = [x.to(device) for x in batch]

            optimizer.zero_grad()

            pred_h_t = model(input_ids, input_masks)
            pred_h_t = pred_h_t.squeeze(1)

            # Keep the predictions (with their graphs) until enough events are pooled.
            preds_h_list.append(pred_h_t)
            times_list.append(time_to_death)
            events_list.append(died)

            num_dead += died.sum()

            if num_dead >= 3:
                pred_h_t = torch.cat(preds_h_list)
                times = torch.cat(times_list)
                events = torch.cat(events_list)

                prognosis_loss = survival_loss(pred_h_t, times, events)
                prognosis_loss.backward()
                optimizer.step()
                lr_scheduler.step()

                last_train_loss = prognosis_loss.detach().cpu().numpy()
                running_train_loss += last_train_loss
                # Summed loss so far divided by the number of batches seen so far.
                mean_train_loss = running_train_loss / (i+1)

                preds_h_list = []
                times_list = []
                events_list = []

                num_dead = 0

            print('Training Epoch: ' + str(epoch+1) + ', batch: ' + str(i + 1) + '/' + str(num_train_batches) + ' this_loss:' + str(last_train_loss)
                  +', train loss: ' + str(mean_train_loss), end='\r', flush=True)

        print('\n')

        # eval on valid
        if validloader is not None:
            model.eval()
            preds_h_list = []
            times_list = []
            events_list = []

            num_dead = 0
            # Counted explicitly so an empty loader cannot reuse the training loop's index.
            num_valid_batches = 0

            for i, batch in enumerate(validloader, 0):
                num_valid_batches = i + 1
                input_ids, input_masks, time_to_death, died = [x.to(device) for x in batch]

                with torch.no_grad():
                    pred_h_t = model(input_ids, input_masks)
                    pred_h_t = pred_h_t.squeeze(1)

                preds_h_list.append(pred_h_t)
                times_list.append(time_to_death)
                events_list.append(died)

                num_dead += died.sum()

                if num_dead >= 3:
                    pred_h_t = torch.cat(preds_h_list)
                    times = torch.cat(times_list)
                    events = torch.cat(events_list)

                    prognosis_loss = survival_loss(pred_h_t, times, events)

                    running_valid_loss += prognosis_loss.detach().cpu().numpy()

                    preds_h_list = []
                    times_list = []
                    events_list = []

                    num_dead = 0

            if num_valid_batches:
                mean_valid_loss = running_valid_loss / num_valid_batches

            # Ends with a newline (not '\r') so the next epoch's progress line does not
            # overwrite the validation result.
            print('Validation Epoch: ' + str(epoch+1) +', validation loss: ' + str(mean_valid_loss), flush=True)
            


    

In [20]:
# # actual training code, commented out after model was trained

# training_small = training.sample(100)
# validation_small = validation.sample(100)

# themodel = PrognosisModel().to('cuda')
# trainloader = data.DataLoader(LabeledDataset(training.reset_index(drop=True)), batch_size=16, num_workers=8, shuffle=True)
# validloader = data.DataLoader(LabeledDataset(validation.reset_index(drop=True)), batch_size=16, num_workers=8, shuffle=True)
# train_model(themodel,3, trainloader, validloader)

In [21]:
#torch.save(themodel.state_dict(), './imaging_prognosis_model_teacher.pt')

In [23]:
# Restore the trained teacher weights and score the held-out test split in row order.
themodel = PrognosisModel()
themodel.load_state_dict(torch.load('./imaging_prognosis_model_teacher.pt'))
themodel.to('cuda')
themodel.eval()

# Despite the "valid" in its name, this loader iterates the test split (unshuffled).
no_shuffle_valid_dataset = data.DataLoader(LabeledDataset(test), batch_size=16, shuffle=False, num_workers=0)

event_times, events, output_prediction_list = [], [], []

# Gradients are not needed for scoring, so the whole loop runs without autograd.
with torch.no_grad():
    for batch in no_shuffle_valid_dataset:
        x_text_ids, x_attention_mask = batch[0].to('cuda'), batch[1].to('cuda')
        event_times.append(batch[2].detach().cpu().numpy())
        events.append(batch[3].detach().cpu().numpy())
        predictions = themodel(x_text_ids, x_attention_mask).detach().cpu().numpy()
        output_prediction_list.append(predictions)

# Flatten the per-batch pieces into one array each; predictions go from (n, 1) to (n,).
output_predictions = np.concatenate(output_prediction_list).squeeze(1)
event_times = np.concatenate(event_times)
events = np.concatenate(events)



In [24]:
from sksurv.metrics import concordance_index_censored

In [25]:
# Concordance between the predicted scores and observed survival on the test split;
# the event indicator is passed as a boolean array.
concordance_index_censored(events == 1, event_times, output_predictions)

(0.7586501054237722, 1451832, 461872, 1, 181)