In [None]:
# this notebook trains a student model on MIMIC data to predict labels assigned
# by the DFCI teacher that was overfit to small training set

In [1]:
import pandas as pd
import numpy as np
from transformers import AutoTokenizer
import torch
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
mimic0 = pd.read_csv('../data/first_half_discharges_overfit_small_train.csv')
mimic1 = pd.read_csv('../data/second_half_discharges_overfit_small_train.csv')
mimic = pd.concat([mimic0,mimic1], axis=0)



In [3]:
mimic.info()

<class 'pandas.core.frame.DataFrame'>
Index: 141377 entries, 0 to 70688
Data columns (total 12 columns):
 #   Column           Non-Null Count   Dtype  
---  ------           --------------   -----  
 0   Unnamed: 0       141377 non-null  int64  
 1   note_id          141377 non-null  object 
 2   subject_id       141377 non-null  int64  
 3   hadm_id          141377 non-null  int64  
 4   note_type        141377 non-null  object 
 5   note_seq         141377 non-null  int64  
 6   charttime        141377 non-null  object 
 7   storetime        141375 non-null  object 
 8   text             141377 non-null  object 
 9   outcome_0_logit  141377 non-null  float64
 10  outcome_1_logit  141377 non-null  float64
 11  outcome_2_logit  141377 non-null  float64
dtypes: float64(3), int64(4), object(5)
memory usage: 14.0+ MB


In [4]:
mimic = mimic.rename(columns={'outcome_0_logit':'any_cancer_logit',
                              'outcome_1_logit':'response_logit',
                              'outcome_2_logit':'progression_logit'})

In [5]:
phi_data = pd.read_csv('/data/clin_notes_outcomes/profile_3-2023/derived_data/labeled_medonc_prissmm_mixedisprog.csv')


In [None]:
validation = phi_data[phi_data.split=='validation']

validation.head()


In [7]:

from transformers import AutoModel



from torch.nn import functional as F
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.nn import LSTM, Linear, Embedding, Conv1d, MaxPool1d, GRU, LSTMCell, Dropout, Module, Sequential, ReLU

   
class LabeledModel(nn.Module):

    def __init__(self):
        super(LabeledModel, self).__init__()
        
        self.longformer = AutoModel.from_pretrained('yikuan8/Clinical-Longformer')
        
        self.any_cancer_head = Sequential(Linear(768, 128), ReLU(), Linear(128,1))
        self.response_head = Sequential(Linear(768, 128), ReLU(), Linear(128,1))
        self.progression_head = Sequential(Linear(768, 128), ReLU(), Linear(128,1))


        
    def forward(self, x_text_tensor, x_attention_mask):
        # x should be tuple of input IDs, then attention mask
        global_attention_mask = torch.zeros_like(x_text_tensor).to('cuda:1')
        # global attention on cls token
        global_attention_mask[:, 0] = 1
        main = self.longformer(x_text_tensor, x_attention_mask, global_attention_mask)
        main = main.last_hidden_state[:,0,:].squeeze(1)

                                          
        any_cancer_out = self.any_cancer_head(main)
        response_out = self.response_head(main)
        progression_out = self.progression_head(main)



        
        return any_cancer_out, response_out, progression_out
        




In [8]:
from torch.utils import data

class PseudoLabeledDataset(data.Dataset):
    def __init__(self, pandas_dataset):
        self.data = pandas_dataset.copy()
        self.indices = self.data.index.unique()
        self.tokenizer = AutoTokenizer.from_pretrained("yikuan8/Clinical-Longformer", truncation_side='right')        
        
        
    def __len__(self):
        # how many notes in the dataset
        return len(self.indices)
    
    def __getitem__(self, index):
        # get data for notes corresponding to indices passed
        this_index = self.indices[index]
        pand = self.data.loc[this_index, :]
        #label = torch.tensor(pand.progression, dtype=torch.float32)
    
        encoded = self.tokenizer(pand['text'], padding='max_length', truncation=True)

        x_text_tensor = torch.tensor(encoded.input_ids, dtype=torch.long)
        x_attention_mask = torch.tensor(encoded.attention_mask, dtype=torch.long)
        
        #y_class_status = torch.tensor(pand.class_status, dtype=torch.long)

        outcome_vars = [pand.any_cancer_logit, pand.response_logit, pand.progression_logit]
        return x_text_tensor, x_attention_mask, *tuple(outcome_vars)
        
        

In [None]:
temp = PseudoLabeledDataset(mimic.reset_index(drop=True))

temp_loader = data.DataLoader(temp, batch_size=4, shuffle=True)
temp_iter = iter(temp_loader)
a = next(temp_iter)


In [10]:
print([x.shape for x in a])

[torch.Size([4, 4096]), torch.Size([4, 4096]), torch.Size([4]), torch.Size([4]), torch.Size([4])]


In [11]:
from torch.utils import data

class LabeledDataset(data.Dataset):
    def __init__(self, pandas_dataset):
        self.data = pandas_dataset.copy()
        self.indices = self.data.index.unique()
        self.tokenizer = AutoTokenizer.from_pretrained("yikuan8/Clinical-Longformer", truncation_side='left')        
        
        
    def __len__(self):
        # how many notes in the dataset
        return len(self.indices)
    
    def __getitem__(self, index):
        # get data for notes corresponding to indices passed
        this_index = self.indices[index]
        pand = self.data.loc[this_index, :]
        #label = torch.tensor(pand.progression, dtype=torch.float32)
    
        encoded = self.tokenizer(pand['text'], padding='max_length', truncation=True)

        x_text_tensor = torch.tensor(encoded.input_ids, dtype=torch.long)
        x_attention_mask = torch.tensor(encoded.attention_mask, dtype=torch.long)
        
        y_any_cancer = torch.tensor(pand.any_cancer, dtype=torch.float32)
        y_response = torch.tensor(pand.response, dtype=torch.float32)
        y_progression = torch.tensor(pand.progression, dtype=torch.float32)
        


        return x_text_tensor, x_attention_mask, y_any_cancer, y_response, y_progression
        
        

In [12]:
device = 'cuda:1'

In [13]:
# train loop
from transformers import get_scheduler
from torch.optim import AdamW, Adam
#, get_linear_schedule_with_warmup


def train_model(model, num_epochs, trainloader, validloader=None, device='cuda:1'):
    
    

    optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
    num_training_steps = num_epochs * len(trainloader)
    lr_scheduler = get_scheduler(
        name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)

    model.to(device)
    
    for epoch in range(num_epochs):  # loop over the dataset multiple times
        running_train_losses = [0.0 for i in range(3)]
        mean_train_losses = [0.0 for i in range(3)]
        
        running_valid_losses = [0.0 for i in range(3)]
        mean_valid_losses = [0.0 for i in range(3)]

        num_train_batches = len(trainloader)
                
        model.train()
        
        for i, batch in enumerate(trainloader, 0):
            input_ids = batch[0].to(device)
            input_masks = batch[1].to(device)
            
            optimizer.zero_grad()
            
            outputs_true = [x.to(device) for x in batch[2:]]
            
            outputs_pred = model(input_ids, input_masks)
            
           
            losses = [F.binary_cross_entropy_with_logits(outputs_pred[x].squeeze(1), torch.sigmoid(outputs_true[x])) for x in range(3)]
            
            total_loss = 0.0
            for j in range(3):
                total_loss = total_loss + losses[j]

                
            total_loss.backward()
            optimizer.step()
            lr_scheduler.step()
            
            
            for j in range(3):
                running_train_losses[j] += losses[j].detach().cpu().numpy()
                mean_train_losses[j] = running_train_losses[j] / (i+1)
                
            print('Training Epoch: ' + str(epoch+1) + ', batch: ' + str(i + 1) + '/' + str(num_train_batches) + ' this_loss:' + str(total_loss.detach().cpu().numpy()) +', train losses: ' + str([str(x) + ': ' + str(mean_train_losses[x]) + ", " for x in range(3)]), end='\r', flush=True)
        
        print('\n')
        # eval on valid

        torch.save(model.state_dict(), 'dfci_mimic_note_longformer_overfit_small_train.pt')
        
        if validloader is not None:
            num_valid_batches = len(validloader)
            model.eval()
            
            for i, batch in enumerate(validloader, 0):
                input_ids = batch[0].to(device)
                input_masks = batch[1].to(device)


                outputs_true = [x.to(device) for x in batch[2:]]

                outputs_pred = model(input_ids, input_masks)

                losses = [F.binary_cross_entropy_with_logits(outputs_pred[x].squeeze(1), torch.sigmoid(outputs_true[x])) for x in range(3)]

                total_loss = 0.0
                for j in range(3):
                    total_loss = total_loss + losses[j]
                


                for j in range(3):
                    running_valid_losses[j] += losses[j].detach().cpu().numpy()

            
            for j in range(3):
                mean_valid_losses[j] = running_valid_losses[j] / (i+1)
            

            
            print('Validation Epoch: ' + str(epoch+1) + ', batch: ' + str(i + 1) + '/' + str(num_valid_batches) + ', valid losses: ' + str([str(x) + ': ' + str(mean_valid_losses[x]) + ", " for x in range(10)]), end='\r', flush=True)
            print('\n')

    

In [None]:

themodel = LabeledModel().to('cuda:1')
trainloader = data.DataLoader(PseudoLabeledDataset(mimic.reset_index(drop=True)), batch_size=2, num_workers=8, shuffle=True)
#validloader = data.DataLoader(LabeledDataset(validation.reset_index(drop=True)), batch_size=4, num_workers=8, shuffle=True)
train_model(themodel,8, trainloader, device='cuda:1')

Some weights of LongformerModel were not initialized from the model checkpoint at yikuan8/Clinical-Longformer and are newly initialized: ['longformer.pooler.dense.bias', 'longformer.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Training Epoch: 1, batch: 70689/70689 this_loss:0.9337675783971154, train losses: ['0: 0.3815686339567896, ', '1: 0.37347575843524855, ', '2: 0.13372587011603312, ']]]

Training Epoch: 2, batch: 70689/70689 this_loss:0.6156848652391111, train losses: ['0: 0.38605176772610045, ', '1: 0.40659981530350175, ', '2: 0.13526865986753986, ']

Training Epoch: 3, batch: 70689/70689 this_loss:0.5826248970501404, train losses: ['0: 0.386020821962349, ', '1: 0.4065184379992952, ', '2: 0.13523275849027333, ']]]]]

Training Epoch: 4, batch: 70689/70689 this_loss:0.5040451727156279, train losses: ['0: 0.38600184463521275, ', '1: 0.40643915479617976, ', '2: 0.13522169835464565, ']

Training Epoch: 5, batch: 70689/70689 this_loss:0.9459959376164059, train losses: ['0: 0.3859587396688456, ', '1: 0.40643653810791114, ', '2: 0.13520631531019814, ']]

Training Epoch: 6, batch: 70689/70689 this_loss:0.8748391733308205, train losses: ['0: 0.38596153224769325, ', '1: 0.40636751726761383, ', '2: 0.1351891766690

In [None]:
torch.save(themodel.state_dict(), 'dfci_mimic_note_longformer_overfit_small_train.pt')

In [None]:
# write out actual validation dataset (not relevant for overfit teacher labels)
themodel = LabeledModel()
themodel.load_state_dict(torch.load('dfci_mimic_note_longformer_overfit_small_train.pt'))
themodel.to('cuda')

themodel.eval()

no_shuffle_valid_dataset = data.DataLoader(LabeledDataset(validation), batch_size=8, shuffle=False, num_workers=0)

output_true_lists = [[] for x in range(3)]
output_prediction_lists = [[] for x in range(3)]
for batch in no_shuffle_valid_dataset:
    #thisframe = pd.DataFrame()
    x_text_ids = batch[0].to('cuda')
    x_attention_mask = batch[1].to('cuda')
    with torch.no_grad():
        predictions = themodel(x_text_ids, x_attention_mask)
    for j in range(3):
        output_true_lists[j].append(batch[2+j].detach().cpu().numpy())
        output_prediction_lists[j].append(predictions[j].detach().cpu().numpy())

output_true_lists = [np.concatenate(x) for x in output_true_lists]        
output_prediction_lists = [np.concatenate(x) for x in output_prediction_lists]


output_validation = validation.copy()
for x in range(3):
    output_validation['outcome_' + str(x) + '_logit'] = output_prediction_lists[x]

