# Matrix Factorisation with Drug and Target similarity
Currently using cosine similarity for the target similarity matrix but this is being changed to a normalised smith waterman score

In [None]:
# import modules
import numpy as np
import matplotlib.pyplot as plt
import torch
from tdc.multi_pred import DTI
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import CountVectorizer
import logging
from rdkit import Chem

# load in the three datasets
data_Kd = DTI(name = 'BindingDB_Kd')
data_Kd.convert_to_log(form = 'binding')

# data_ic50 = DTI(name = 'BindingDB_IC50')
# data_ic50.convert_to_log(form = 'binding')

# data_Ki = DTI(name = 'BindingDB_Ki')
# data_Ki.convert_to_log(form = 'binding')

In [2]:
def data_split(data):
    # split data and get ID dicts
    split = data.get_split(seed = 42,  frac = [0.6, 0.05, 0.35])
    train = split['train']
    test = split['test']

    train = train[['Drug_ID', 'Drug', 'Target', 'Y']].dropna()
    train = train.reset_index(drop=True)

    ID_to_Drug = dict(enumerate(list(dict.fromkeys(train['Drug_ID']))))
    ID_to_Target = dict(enumerate(list(dict.fromkeys(train['Target']))))
    Drug_to_ID = dict((v,k) for k,v in ID_to_Drug.items())
    Target_to_ID = dict((v,k) for k,v in ID_to_Target.items())
    
    return train, test, Drug_to_ID, Target_to_ID

def data_loader(data, drug_dict, target_dict):
    # load data into correct format
    data["Target_ID2"] = data["Target"].apply(lambda x:target_dict.get(x))
    data["Drug_ID2"] = data["Drug_ID"].apply(lambda x:drug_dict.get(x))
    data = data.dropna()

    drug_ID = data["Drug_ID2"].to_numpy()
    target_ID = data["Target_ID2"].to_numpy()
    features = np.vstack((drug_ID, target_ID)).T
    label = data['Y'].to_numpy()
    return features, label

class RatingDataset(Dataset):
    def __init__(self, train, label):
        self.feature_= train
        self.label_= label
    def __len__(self):
    #return size of dataset
        return len(self.feature_)
    def __getitem__(self, idx):
        return  torch.tensor(self.feature_[idx], dtype=torch.long),torch.tensor(self.label_[idx], dtype=torch.float)
    
    
class MatrixFactorization(torch.nn.Module):
    
    def __init__(self, n_users, n_items, n_factors, drug_sim_mat, target_sim_mat):
        super().__init__()
        self.user_factors = torch.nn.Embedding(n_users, n_factors)
        self.item_factors = torch.nn.Embedding(n_items, n_factors)
        torch.nn.init.xavier_uniform_(self.user_factors.weight)
        torch.nn.init.xavier_uniform_(self.item_factors.weight)
        
        self.user_biases = torch.nn.Embedding(n_users, 1)
        self.item_biases = torch.nn.Embedding(n_items,1)
        self.user_biases.weight.data.fill_(0.)
        self.item_biases.weight.data.fill_(0.)
        
        # NEW WEIGHTS FOR THE SIMILARITY MATRIX
        self.drug_sim = drug_sim_mat
        self.user_sim = torch.nn.Embedding(n_users, 1)
        self.target_sim = target_sim_mat
        self.item_sim = torch.nn.Embedding(n_items, 1)
        torch.nn.init.xavier_uniform_(self.user_sim.weight)
        torch.nn.init.xavier_uniform_(self.item_sim.weight)

        
    def forward(self, user, item):
        user_len = len(user)
        item_len = len(item)
        
        AAT_list = [torch.dot(self.user_factors(user)[i,:], self.user_factors(user)[i,:]) for i in range(self.user_factors(user).shape[0])]
        AAT = torch.tensor(AAT_list)
        
        BBT_list = [torch.dot(self.item_factors(item)[i,:], self.item_factors(item)[i,:]) for i in range(self.item_factors(item).shape[0])]
        BBT = torch.tensor(BBT_list)
        
        pred = self.user_biases(user) + self.item_biases(item)
        pred += (self.user_factors(user) * self.item_factors(item)).sum(1, keepdim=True)
        
        # Sd = A*AT
        pred += ((self.drug_sim[user][:,0] * self.user_sim(user).double().reshape(user_len)) - AAT).reshape(user_len,1)
        # St = B*BT
        pred += ((self.target_sim[item][:,0] * self.item_sim(item).double().reshape(item_len)) - BBT).reshape(item_len,1)
        
        return pred.squeeze()

In [10]:
def train_model(train_loader, test_loader, model, num_epochs=100):
#     dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    # need to change dtype to float to reduce resources required
    # maybe just one csim_drug.to(dev)
    dev = torch.device("cpu")
    loss_func = torch.nn.MSELoss()
    
    model.to(dev)
    
    train_losses = []
    test_losses = []
    for epoch in range(0,num_epochs):
        count = 0
        cum_loss = 0.
        for i, (train_batch, label_batch) in enumerate(train_loader):
            count = 1 + i
            # Predict and calculate loss for user factor and bias
            #### WEIGHTS FOR USER SIMILARITY MATRIX ADDED INTO OPTIMISER HERE
            optimizer = torch.optim.SGD([model.user_biases.weight,model.user_factors.weight,
                                         model.user_sim.weight], lr=0.05, weight_decay=1e-3)
            prediction = model(train_batch[:,0].to(dev), train_batch[:,1].to(dev))
            loss = loss_func(prediction, label_batch.to(dev)).float()    
            # Backpropagate
            loss.backward()

            # Update the parameters
            optimizer.step()
            optimizer.zero_grad()

            # predict and calculate loss for item factor and bias
            #### WEIGHTS FOR ITEM SIMILARITY MATRIX ADDED INTO OPTIMISER HERE
            optimizer = torch.optim.SGD([model.item_biases.weight,model.item_factors.weight,
                                    model.item_sim.weight], lr=0.05, weight_decay=1e-3)              
            prediction = model(train_batch[:,0].to(dev), train_batch[:,1].to(dev))
            loss = loss_func(prediction, label_batch.to(dev))
            loss_item = loss.item()
            cum_loss += loss_item


            # Backpropagate
            loss.backward()

            # Update the parameters
            optimizer.step()
            optimizer.zero_grad()
        train_loss = cum_loss/count
        train_losses.append(train_loss)

        cum_loss =0.
        count = 0
        for i, (test_batch, label_batch) in enumerate(test_loader):
            count = 1 + i
            with torch.no_grad():
                prediction = model(test_batch[:,0].to(dev), test_batch[:,1].to(dev))
                loss = loss_func(prediction, label_batch.to(dev))
                cum_loss += loss.item()

        test_loss = cum_loss/count
        test_losses.append(test_loss)
        if epoch % 1 == 0:
            print('epoch: ', epoch ,' avg training loss: ', train_loss, ' avg test loss: ',test_loss)
    return train_losses, test_losses

In [4]:
def full_model(data, csim_drug, img_name, n_factors=100, bs=100, num_epochs=100):
    train, test, drug_dict, target_dict = data_split(data)
    x_train, y_train = data_loader(train, drug_dict, target_dict)
    x_test, y_test = data_loader(test, drug_dict, target_dict)

    train_dataloader = DataLoader(RatingDataset(x_train, y_train), batch_size=bs, shuffle=True)
    test_dataloader = DataLoader(RatingDataset(x_test, y_test), batch_size=bs)
    
    csim_target_np = cos_matrix(train, 'Target', False)
    csim_target = torch.from_numpy(csim_target_np)

    
    model = MatrixFactorization(len(drug_dict), len(target_dict), n_factors, csim_drug, csim_target)

    train_losses, test_losses = train_model(train_dataloader, test_dataloader, model, num_epochs)

    epochs = range(1, num_epochs+1)
    plt.plot(epochs, train_losses, label='train')
    plt.plot(epochs, test_losses, label='test')
    plt.xlabel('epoch')
    plt.ylabel('mse loss')
    plt.legend()
    plt.title(img_name)
#     plt.savefig(img_name)
    plt.show()

In [5]:
Kd_drug_sim_np = np.loadtxt('../sim_matrix/drug_sim.txt', delimiter=',')
Kd_drug_sim = torch.from_numpy(Kd_drug_sim_np)

In [6]:
CHARPROTSET = {
    "A": 1,
    "C": 2,
    "B": 3,
    "E": 4,
    "D": 5,
    "G": 6,
    "F": 7,
    "I": 8,
    "H": 9,
    "K": 10,
    "M": 11,
    "L": 12,
    "O": 13,
    "N": 14,
    "Q": 15,
    "P": 16,
    "S": 17,
    "R": 18,
    "U": 19,
    "T": 20,
    "W": 21,
    "V": 22,
    "Y": 23,
    "X": 24,
    "Z": 25,
}

CHARPROTLEN = 25

CHARISOSMISET = {
    "#": 29,
    "%": 30,
    ")": 31,
    "(": 1,
    "+": 32,
    "-": 33,
    "/": 34,
    ".": 2,
    "1": 35,
    "0": 3,
    "3": 36,
    "2": 4,
    "5": 37,
    "4": 5,
    "7": 38,
    "6": 6,
    "9": 39,
    "8": 7,
    "=": 40,
    "A": 41,
    "@": 8,
    "C": 42,
    "B": 9,
    "E": 43,
    "D": 10,
    "G": 44,
    "F": 11,
    "I": 45,
    "H": 12,
    "K": 46,
    "M": 47,
    "L": 13,
    "O": 48,
    "N": 14,
    "P": 15,
    "S": 49,
    "R": 16,
    "U": 50,
    "T": 17,
    "W": 51,
    "V": 18,
    "Y": 52,
    "[": 53,
    "Z": 19,
    "]": 54,
    "\\": 20,
    "a": 55,
    "c": 56,
    "b": 21,
    "e": 57,
    "d": 22,
    "g": 58,
    "f": 23,
    "i": 59,
    "h": 24,
    "m": 60,
    "l": 25,
    "o": 61,
    "n": 26,
    "s": 62,
    "r": 27,
    "u": 63,
    "t": 28,
    "y": 64,
}

CHARISOSMILEN = 64

CHARATOMSET = [
    "C",
    "N",
    "O",
    "S",
    "F",
    "Si",
    "P",
    "Cl",
    "Br",
    "Mg",
    "Na",
    "Ca",
    "Fe",
    "As",
    "Al",
    "I",
    "B",
    "V",
    "K",
    "Tl",
    "Yb",
    "Sb",
    "Sn",
    "Ag",
    "Pd",
    "Co",
    "Se",
    "Ti",
    "Zn",
    "H",
    "Li",
    "Ge",
    "Cu",
    "Au",
    "Ni",
    "Cd",
    "In",
    "Mn",
    "Zr",
    "Cr",
    "Pt",
    "Hg",
    "Pb",
    "Unknown",
]

CHARATOMLEN = 44


def integer_label_smiles(smiles, max_length=85, isomeric=False):
    """
    Integer encoding for SMILES string sequence.
    Args:
        smiles (str): Simplified molecular-input line-entry system, which is a specification in the form of a line
        notation for describing the structure of chemical species using short ASCII strings.
        max_length (int): Maximum encoding length of input SMILES string. (default: 85)
        isomeric (bool): Whether the input SMILES string includes isomeric information (default: False).
    """
    if not isomeric:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            logging.warning(f"rdkit cannot find this SMILES {smiles}.")
            return None
        smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles), isomericSmiles=True)
    encoding = np.zeros(max_length)
    for idx, letter in enumerate(smiles[:max_length]):
        try:
            encoding[idx] = CHARISOSMISET[letter]
        except KeyError:
            logging.warning(
                f"character {letter} does not exists in default SMILE category encoding, skip and treat as " f"padding."
            )

    return encoding


def integer_label_protein(sequence, max_length=1200):
    """
    Integer encoding for protein string sequence.
    Args:
        sequence (str): Protein string sequence.
        max_length: Maximum encoding length of input protein string. (default: 1200)
    """
    encoding = np.zeros(max_length)
    for idx, letter in enumerate(sequence[:max_length]):
        try:
            encoding[idx] = CHARPROTSET[letter]
        except KeyError:
            logging.warning(
                f"character {letter} does not exists in sequence category encoding, skip and treat as " f"padding."
            )
    return encoding

def cos_matrix(train, col, Drug):
#     sentences = []
    sentences_comb = []
    for i in range(len(train[col])):
        if Drug == True:
            num_rep = integer_label_smiles(train[col][i]).astype(str)
            num_rep_trim = np.trim_zeros(num_rep)
        else:
            num_rep = integer_label_protein(train[col][i]).astype(str)
            num_rep_trim = np.trim_zeros(num_rep)

#         sentences.append(num_rep)

        str_rep = " ".join(num_rep_trim)
        sentences_comb.append(str_rep)
    vectorizer = CountVectorizer().fit_transform(sentences_comb)
    vectors = vectorizer.toarray()
    csim = cosine_similarity(vectors)
    return csim

In [None]:
full_model(data_Kd, Kd_drug_sim, 'Kd', n_factors=30, bs=200, num_epochs=200)

epoch:  0  avg training loss:  28.34370138083294  avg test loss:  18.83685381753104
epoch:  1  avg training loss:  18.08625157471675  avg test loss:  11.842932510375977
epoch:  2  avg training loss:  13.270707974767989  avg test loss:  8.463146441323417
epoch:  3  avg training loss:  10.685226458652764  avg test loss:  6.646307563781738
epoch:  4  avg training loss:  9.141301947794142  avg test loss:  5.581969445092337
epoch:  5  avg training loss:  8.134556414974723  avg test loss:  4.909772838865008
epoch:  6  avg training loss:  7.428863048553467  avg test loss:  4.454018531526838
epoch:  7  avg training loss:  6.895751418581434  avg test loss:  4.125686373029437
epoch:  8  avg training loss:  6.474533287582884  avg test loss:  3.874986226218087
epoch:  9  avg training loss:  6.128798679181725  avg test loss:  3.675952236992972
epoch:  10  avg training loss:  5.839500393837121  avg test loss:  3.5115288870675223
epoch:  11  avg training loss:  5.590548703624944  avg test loss:  3.37