In [1]:
import os
import pandas as pd
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

# set seed
seed = 24
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
torch.manual_seed(42)
np.random.seed(0)

In [3]:
%cd /home/trishad2/PyTrial/

/home/trishad2/PyTrial


In [4]:
if torch.cuda.is_available(): 
    dev = "cuda:0" 
else:
    dev = "cpu" 
device = torch.device(dev)

In [5]:
data_path = 'lung/data/train_test_valid/'

In [6]:
data_path1 = 'lung/data/'

In [7]:
model_path = 'lung/model/train_test_valid/with_biobert/'

In [8]:
data= pd.read_csv(data_path1+'full_datav3.csv',  index_col = 0)

ae_cols = [i for i in data.columns if i.startswith('AE_')]
med_cols = [i for i in data.columns if i.startswith('CM_')]
treatment_cols = [i for i in data.columns if i.startswith('Treatment_')]
feature_cols =    treatment_cols + med_cols + ae_cols
vocab_size = [len(treatment_cols), len(med_cols), len(ae_cols)]
vocab_size

[3, 13, 83]

In [9]:
treatment_id_dict={}
for i in range(len(treatment_cols)):
    treatment_id_dict[i]=treatment_cols[i].split('_')[1]
    
med_id_dict={}
for i in range(len(med_cols)):
    med_id_dict[i]=med_cols[i].split('_')[1]
    
ae_id_dict={}
for i in range(len(ae_cols)):
    ae_id_dict[i]=ae_cols[i].split('_')[1]

In [10]:
ae_emb_dict = pd.read_pickle(data_path1+'ae_emb_dict.pickle')

med_emb_dict = pd.read_pickle(data_path1+'med_emb_dict.pickle')

treatment_emb_dict = pd.read_pickle(data_path1+'treatment_emb_dict.pickle')

In [11]:
visits = []
treatment_cols =  [col for col in data if col.startswith('Treatment_')]
medication_cols = [col for col in data if col.startswith('CM_')]
ae_cols = [col for col in data if col.startswith('AE_')]
for i in data.People.unique():
  sample=[]
  temp = data[data['People']==i]
  for index, row in temp.iterrows():
    visit=[]
    visit.append(np.nonzero(row[treatment_cols].to_list())[0].tolist())
    visit.append(np.nonzero(row[medication_cols].to_list())[0].tolist())
    visit.append(np.nonzero(row[ae_cols].to_list())[0].tolist())
    sample.append(visit)
  visits.append(sample)

In [12]:
visits_biobert=[]
#i for a patient
for i in range(len(visits)):
    
    visits_per_patient=[]
    #j for a visit of that patient
    for j in range(len(visits[i])):
        visit=[]
        #k for an event type of that visit
        for k in range(len(visits[i][j])):
            
            #l for an event of that event type
            for l in visits[i][j][k]:
                if k==0:
                    #print(treatment_id_dict[l])
                    visit.append(treatment_emb_dict[treatment_id_dict[l]])
                if k==1:
                    #print(med_code_dict[med_id_dict[l]])
                    visit.append(med_emb_dict[l])
                if k ==2:
                    #print(ae_code_dict[ae_id_dict[l]])
                    visit.append(ae_emb_dict[l])
        visit = torch.cat(visit)
        visit = visit.mean(dim = 0).detach().numpy()
        #print(visit.shape)
        visits_per_patient.append(visit)
    visits_biobert.append(visits_per_patient)

In [13]:
len(visits_biobert), len(visits_biobert[1]), len(visits_biobert[0][0])

(527, 9, 768)

In [14]:
vocab_size = 768

In [15]:
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    
    def __init__(self, seqs):
        # your code here
        self.x = seqs
        #self.y = labels
        #self.baseline = baseline
    
    def __len__(self):

        return len(self.x)
    
    def __getitem__(self, index):

        return self.x[index]#, self.y[index], self.baseline[index]
        

train_dataset = CustomDataset(visits_biobert)

In [16]:
def collate_fn_med(data):

    sequences = data

    #y = torch.tensor(labels, dtype=torch.float)
    #baseline = torch.tensor(np.vstack(baselines), dtype=torch.float)
    num_patients = len(sequences)
    num_visits = [len(patient) for patient in sequences]
    max_num_visits = max(num_visits)
    lengths = [len(x) for x in sequences]

    x = torch.zeros((num_patients, max_num_visits, vocab_size), dtype=torch.float)    
    for i_patient, patient in enumerate(sequences):
        for j_visit, visit in enumerate(patient):
                x[i_patient][j_visit] = torch.from_numpy(np.array(visit))
                
    #print(torch.sum(x, dim=-1)!=0)
    masks = torch.sum(x, dim=-1) != 0
    
    return x, masks, lengths

In [17]:
x, masks, lengths = collate_fn_med(train_dataset)

In [18]:
masks[0]

tensor([ True,  True,  True,  True,  True, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False])

In [19]:
class MLMEncoder(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(MLMEncoder, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.embedding = nn.Linear(input_size, hidden_size)

        self.pos_encoder = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size)
        )

        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(hidden_size, nhead=1, dim_feedforward = hidden_size,),
            num_layers
        )

        self.mlm_head = nn.Linear(hidden_size, input_size)
        
    def forward(self, x, mask):

        x = self.embedding(x)
        x = self.pos_encoder(x)


        x = x.permute(1, 0, 2)
        output = self.transformer_encoder(x, src_key_padding_mask=mask)
        output = output.permute(1, 0, 2)
        
        logits = self.mlm_head(output)
        return output, logits 

In [20]:
pretrained_encoder = MLMEncoder(input_size = vocab_size, hidden_size = 256, num_layers = 1 )

In [21]:
# Pretraining hyperparameters
pretrain_epochs = 100
pretrain_optimizer = torch.optim.Adam(pretrained_encoder.parameters(), lr=0.0001)

In [22]:
batch_size = len(train_dataset)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn_med)

# Pretraining loop
for epoch in range(pretrain_epochs):
    total_loss = 0.0

    for x, batch_mask, lengths in train_dataloader:  # Assuming you have a DataLoader for pretraining data
        
        pretrain_optimizer.zero_grad()
        number_of_visits = x.shape[1]
        

        # Randomly select visits to mask out
        mask_indices = random.sample(range(x.size(1)), k=1)  # Select 2 random visits to mask out

        # Generate the mask tensor
        mask = torch.ones(x.size(0), x.size(1)).bool()
        mask[:, mask_indices] = False
        
        
        # Replace visits with all zeroes based on the mask
        masked_input_data = torch.where(mask.unsqueeze(-1), x, torch.full_like(x, float(-10.0)))#torch.zeros_like(x[0]))

        and_mask = torch.logical_and(mask, batch_mask)

        # Forward pass
        output, logits = pretrained_encoder(masked_input_data, and_mask)

        # Compute MLM loss
        mlm_loss = F.mse_loss(logits, x.float(), reduction='none')
        mlm_loss = mlm_loss.masked_select(batch_mask.unsqueeze(-1)).mean()

        # Backward pass
        mlm_loss.backward(retain_graph=True)
        pretrain_optimizer.step()

        total_loss += mlm_loss.item()

    # Print average MLM loss for the epoch
    print(f"Pretraining Epoch {epoch+1} MLM Loss: {total_loss / len(train_dataloader)}")

Pretraining Epoch 1 MLM Loss: 0.46994689106941223
Pretraining Epoch 2 MLM Loss: 0.3977401852607727
Pretraining Epoch 3 MLM Loss: 0.3380074203014374
Pretraining Epoch 4 MLM Loss: 0.3105923533439636
Pretraining Epoch 5 MLM Loss: 0.27812159061431885
Pretraining Epoch 6 MLM Loss: 0.25639307498931885
Pretraining Epoch 7 MLM Loss: 0.23466560244560242
Pretraining Epoch 8 MLM Loss: 0.21940608322620392
Pretraining Epoch 9 MLM Loss: 0.19098100066184998
Pretraining Epoch 10 MLM Loss: 0.17713876068592072
Pretraining Epoch 11 MLM Loss: 0.18379372358322144
Pretraining Epoch 12 MLM Loss: 0.15356020629405975
Pretraining Epoch 13 MLM Loss: 0.14176160097122192
Pretraining Epoch 14 MLM Loss: 0.14829622209072113
Pretraining Epoch 15 MLM Loss: 0.12318902462720871
Pretraining Epoch 16 MLM Loss: 0.1150459423661232
Pretraining Epoch 17 MLM Loss: 0.11476486176252365
Pretraining Epoch 18 MLM Loss: 0.09963613003492355
Pretraining Epoch 19 MLM Loss: 0.09363079816102982
Pretraining Epoch 20 MLM Loss: 0.11118687689

In [23]:
# Save the pretrained encoder
torch.save(pretrained_encoder.state_dict(), model_path+"pretrained_encoder_MLM_biobert_e100.pkl")