In [3]:
from tdc.multi_pred import DrugRes
import numpy as np
import matplotlib.pyplot as plot
import torch
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm


In [65]:
class LogisticRegression(torch.nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim)
    def forward(self, x):
        return self.linear(x)
    

def train_loop(dataloader, model, loss_fn, optimizer):
    train_loss = 0
    for batch, (X, y) in enumerate(dataloader):
        pred = model(X.float())
        loss = loss_fn(pred, y.float())
        train_loss += loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return train_loss / len(dataloader)

def test_loop(dataloader, model, loss_fn):
    test_loss = 0
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X.float())
            test_loss += loss_fn(pred, y.float()).item()
    return test_loss / len(dataloader)
    

In [114]:
class Drug():
    def __init__(self, drug_id, drug_smiles):
        self.id = drug_id
        self.smiles = drug_smiles
#         self.train_dataloader = train_dataloader
#         self.test_dataloader = test_dataloader
    
class Model():
    def __init__(self):
        data = DrugRes(name = 'GDSC1')
        split = data.get_split()
        self.train = split['train']
        self.test = split['test']
        self.valid = split['valid']
        self.drug_IDs = self.train['Drug_ID'].unique()
        assert(set(self.test['Drug_ID'].unique().flatten()) == \
               set(self.drug_IDs.flatten())) # make sure train, test set involves same set of drugs
        
        self.drugs = {}
        self.submodels = {}
        self.batch_size = 256
            
        
    def train_all_submodels(self):
        epochs = 1e5
        lr = 1e-7
        loss_fn = torch.nn.MSELoss()
        print("Initializing and training submodel for each drug")
        for ID in tqdm(self.drug_IDs):
            train_X = torch.from_numpy(np.vstack(self.train[self.train['Drug_ID']==ID]['Cell Line'].to_numpy()))
            train_Y = torch.from_numpy(np.vstack(self.train[self.train['Drug_ID']==ID]['Y'].to_numpy()))
            test_X = torch.from_numpy(np.vstack(self.test[self.test['Drug_ID']==ID]['Cell Line'].to_numpy()))
            test_Y = torch.from_numpy(np.vstack(self.test[self.test['Drug_ID']==ID]['Y'].to_numpy()))
            self.drugs[ID] = Drug(ID, None)  
            self.submodels[ID] = LogisticRegression(train_X.shape[1], 1).to('cpu')
            optimizer = torch.optim.Adam(self.submodels[ID].parameters(), lr=lr)
            train_dataloader = DataLoader(TensorDataset(train_X, train_Y),\
                                               self.batch_size, shuffle = True)
            test_dataloader = DataLoader(TensorDataset(test_X, test_Y),\
                                              self.batch_size, shuffle = True)
            train_loss, test_loss = [100], [100]
            counter = 0
            for t in range(epochs):
                train_l = train_loop(train_dataloader, self.submodels[ID], loss_fn, optimizer)
                test_l = test_loop(test_dataloader, self.submodels[ID], loss_fn)
                if (t + 1) % 100 == 0:
                    print(f"Epoch {t+1}\n-------------------------------")
                    train_loss.append(train_l)
                    test_loss.append(test_l)
                    print(f"Avg train loss: {train_loss[-1]:.8f}")
                    print(f"Avg test loss: {test_loss[-1]:.8f}")
                    if test_loss[-1] >= test_loss[-2]: # simple early stopping
                        counter+=1
                    else:
                        counter=max(0, counter-1)
                    if counter>=3:
                        break
    def evaluate(criterion = 'MSE'):
        if criterion == 'MAE':
            loss_fn = torch.nn.L1Loss()
        elif criterion == 'MSE':
            loss_fn = torch.nn.MSELoss()
        
        total_loss = 0
        valid_X = self.valid['Cell Line'].to_numpy()
        valid_Y = self.valid['Y'].to_numpy()
        valid_IDs = self.valid['Drug_ID'].to_numpy()
        for X, y, ID in zip(valid_X, valid_Y, valid_IDs):
            pred = self.submodels[ID](torch.from_numpy(X).float)
            total_loss += loss_fn(pred, torch.from_numpy(y).float()).item()
        return total_loss / len(valid_X)

In [None]:
model = Model()
model.train_all_submodels()

In [140]:
loss_fn = torch.nn.MSELoss()
valid_X = model.valid['Cell Line'].to_numpy()
valid_Y = model.valid['Y'].to_numpy()
valid_IDs = model.valid['Drug_ID'].to_numpy()
counter = 0
total_loss = 0
for X, y, ID in zip(valid_X, valid_Y, valid_IDs):
    if ID in model.submodels:
        counter += 1
        pred = model.submodels[ID](torch.from_numpy(X).float())
        total_loss += loss_fn(pred, torch.from_numpy(np.asarray(y)).float()).item()

avg_loss = total_loss / counter
print(avg_loss)





1.8911321830559649
