In [1]:
import pickle
import timeit
import os
import random
import glob
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from sklearn.metrics import roc_auc_score
from sklearn.model_selection import KFold

from sklearn.metrics import roc_curve, auc

In [3]:

### Check if GPU is available ###
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

### Define k-folds ###
num_kfolds = 5
kfold      = KFold(n_splits=num_kfolds, shuffle=True, random_state=1)  # Updated

In [4]:

def load_tensor(file_name, dtype):
    return [dtype(d).to(device) for d in np.load(file_name + '.npy', allow_pickle=True)]


def load_pickle(file_name):
    with open(file_name, 'rb') as f:
        return pickle.load(f)


def shuffle_dataset(dataset, seed):
    np.random.seed(seed)
    np.random.shuffle(dataset)
    return dataset


def split_dataset(dataset, ratio):
    n = int(ratio * len(dataset))
    dataset_1, dataset_2 = dataset[:n], dataset[n:]
    return dataset_1, dataset_2


def file_ind(index):
    st_ind, in_ind = divmod(index,10)
    return 10*st_ind, in_ind

def load_pickle(file_name):
    with open(file_name, 'rb') as f:
        return pickle.load(f)

def calculate_performace(test_num, pred_y,  labels):
    tp =0
    fp = 0
    tn = 0
    fn = 0
    for index in range(test_num):
        if labels[index] ==1:
            if labels[index] == pred_y[index]:
                tp = tp +1
            else:
                fn = fn + 1
        else:
            if labels[index] == pred_y[index]:
                tn = tn +1
            else:
                fp = fp + 1        
                
                
    if (tp+fn) == 0:
        q9 = float(tn-fp)/(tn+fp + 1e-06)
    if (tn+fp) == 0:
        q9 = float(tp-fn)/(tp+fn + 1e-06)
    if  (tp+fn) != 0 and (tn+fp) !=0:
        q9 = 1- float(np.sqrt(2))*np.sqrt(float(fn*fn)/((tp+fn)*(tp+fn))+float(fp*fp)/((tn+fp)*(tn+fp)))
        
    Q9 = (float)(1+q9)/2
    accuracy = float(tp + tn)/test_num
    precision = float(tp)/(tp+ fp + 1e-06)
    sensitivity = float(tp)/ (tp + fn + 1e-06)
    recall = float(tp)/ (tp + fn + 1e-06)
    specificity = float(tn)/(tn + fp + 1e-06)
    ppv = float(tp)/(tp + fp + 1e-06)
    npv = float(tn)/(tn + fn + 1e-06)
    F1_score = float(2*tp)/(2*tp + fp + fn + 1e-06)
    MCC = float(tp*tn-fp*fn)/(np.sqrt((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn)))
    
    return tp,fp,tn,fn,accuracy, precision, sensitivity, recall, specificity, MCC, F1_score, Q9, ppv, npv


In [5]:

class PPI(nn.Module):
    def __init__(self):
        super(PPI, self).__init__()
        self.embed_fingerprint = nn.Embedding(n_fingerprint, dim)
        self.W_gnn             = nn.ModuleList([nn.Linear(dim, dim)
                                    for _ in range(layer_gnn)])
        self.W1_attention      = nn.Linear(dim, dim)
        self.W2_attention      = nn.Linear(dim, dim)
        self.w                 = nn.Parameter(torch.zeros(dim))
        
        self.W_out             = nn.Linear(2*dim, 2)
        
    def gnn(self, xs1, A1, xs2, A2):
        for i in range(layer_gnn):
            hs1 = torch.relu(self.W_gnn[i](xs1))            
            hs2 = torch.relu(self.W_gnn[i](xs2))
            
            xs1 = torch.matmul(A1, hs1)
            xs2 = torch.matmul(A2, hs2)
        
        return xs1, xs2
    
    def mutual_attention(self, h1, h2):
        x1 = self.W1_attention(h1)
        x2 = self.W2_attention(h2)
        
        m1 = x1.size()[0]
        m2 = x2.size()[0]
        
        c1 = x1.repeat(1,m2).view(m1, m2, dim)
        c2 = x2.repeat(m1,1).view(m1, m2, dim)

        d = torch.tanh(c1 + c2)
        alpha = torch.matmul(d,self.w).view(m1,m2)
        
        b1 = torch.mean(alpha,1)
        p1 = torch.softmax(b1,0)
        s1 = torch.matmul(torch.t(x1),p1).view(-1,1)
        
        b2 = torch.mean(alpha,0)
        p2 = torch.softmax(b2,0)
        s2 = torch.matmul(torch.t(x2),p2).view(-1,1)
        
        return torch.cat((s1,s2),0).view(1,-1), p1, p2
    
    def forward(self, inputs):

        fingerprints1, adjacency1, fingerprints2, adjacency2 = inputs
        
        """Protein vector with GNN."""
        x_fingerprints1        = self.embed_fingerprint(fingerprints1)
        x_fingerprints2        = self.embed_fingerprint(fingerprints2)
        
        x_protein1, x_protein2 = self.gnn(x_fingerprints1, adjacency1, x_fingerprints2, adjacency2)
        
        """Protein vector with mutual-attention."""
        y, p1, p2     = self.mutual_attention(x_protein1, x_protein2)
        z_interaction = self.W_out(y)

        return z_interaction, p1, p2, y
    
    def __call__(self, data, train=True):
        
        inputs, t_interaction = data[:-1], data[-1]
        z_interaction, p1, p2, y = self.forward(inputs)
        
        if train:
            loss = F.cross_entropy(z_interaction, t_interaction)
            return loss
        else:
            z = F.softmax(z_interaction, 1).to('cpu').data[0].numpy()
            t = int(t_interaction.to('cpu').data[0].numpy())
            return z, t, p1, p2, y

In [6]:
class PPIMPredict(nn.Module):
    def __init__(self):
        super(PPIMPredict, self).__init__()
        self.embed_fingerprint = nn.Embedding(nmod_fingerprint, dim)
        self.W_gnn             = nn.ModuleList([nn.Linear(dim, dim)
                                    for _ in range(layer_gnn)])
        self.W1_attention      = nn.Linear(dim, dim)
        self.W2_attention      = nn.Linear(2*dim, dim)
        self.w                 = nn.Parameter(torch.zeros(dim)) #attention between prots
        self.w2                 = nn.Parameter(torch.zeros(dim)) #attention ppi + mod        
        self.W_out             = nn.Linear(2*dim, 2)
        
    def gnn(self, xs1, A1):
        for i in range(layer_gnn):
            hs1 = torch.relu(self.W_gnn[i](xs1))            
            
            xs1 = torch.matmul(A1, hs1)
        
        return xs1
    
    def mutual_attention(self, h1, h2):
        x1 = self.W1_attention(h1)
        x2 = self.W2_attention(h2)
        
        m1 = x1.size()[0]
        m2 = x2.size()[0]
        
        c1 = x1.repeat(1,m2).view(m1, m2, dim)
        c2 = x2.repeat(m1,1).view(m1, m2, dim)

        d = torch.tanh(c1 + c2)
        alpha = torch.matmul(d,self.w).view(m1,m2)
        
        b1 = torch.mean(alpha,1)
        p1 = torch.softmax(b1,0)
        s1 = torch.matmul(torch.t(x1),p1).view(-1,1)
        
        b2 = torch.mean(alpha,0)
        p2 = torch.softmax(b2,0)
        s2 = torch.matmul(torch.t(x2),p2).view(-1,1)
        
        return torch.cat((s1,s2),0).view(1,-1), p1, p2
    
    def forward(self, inputs):

        fingerprints1, adjacency1, prot_embed = inputs
        
        """Protein vector with GNN."""
        x_fingerprints1        = self.embed_fingerprint(fingerprints1)
        # print(x_fingerprints1.shape)
        # print(adjacency1.shape)
        x_protein1 = self.gnn(x_fingerprints1, adjacency1)
        
        """Protein vector with mutual-attention."""
        y, p1, p2     = self.mutual_attention(x_protein1, prot_embed)
        z_interaction = self.W_out(y)

        return z_interaction, p1, p2
    
    def __call__(self, data, train=True):
        
        inputs, t_interaction = data[:-1], data[-1]
        z_interaction, p1, p2 = self.forward(inputs)
        
        if train:
            loss = F.cross_entropy(z_interaction, t_interaction)
            return loss
        else:
            z = F.softmax(z_interaction, 1).to('cpu').data[0].numpy()
            t = int(t_interaction.to('cpu').data[0].numpy())
            return z, t, p1, p2

In [7]:
from tqdm import tqdm
class Trainer(object):
    def __init__(self, model,prot_model):
        self.model = model
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.prot_model = prot_model

    def train(self, dataset):
        
        loss_total = 0
        for data in tqdm(dataset):
            try:
                mod,p1,p2,interaction,family = data
                A1 = np.load(
                        prot_data[prot_data['protein'] == p1].iloc[0]['adj_path'],
                        allow_pickle=True
                    )
                A2 = np.load(
                    prot_data[prot_data['protein'] == p2].iloc[0]['adj_path'],
                    allow_pickle=True
                )

                P1 = np.load(
                    prot_data[prot_data['protein'] == p1].iloc[0]['fp_path'],
                    allow_pickle=True
                )
                P2 = np.load(
                    prot_data[prot_data['protein'] == p2].iloc[0]['fp_path'],
                    allow_pickle=True
                )
                mod_fp = np.load(
                    mod_data[mod_data['inchikey'] == mod].iloc[0]['fp_path'],
                    allow_pickle=True
                )
                mod_adj = np.load(
                    mod_data[mod_data['inchikey'] == mod].iloc[0]['adj_path'],
                    allow_pickle=True
                )
            except Exception as e:
                continue
            protein1 = torch.LongTensor(P1.astype(np.float16))
            protein2 = torch.LongTensor(P2.astype(np.float16))
            adjacency1 = torch.FloatTensor(A1.astype(np.float16))
            adjacency2 = torch.FloatTensor(A2.astype(np.float16))
            mod_fp = torch.LongTensor(mod_fp.astype(np.float16))
            mod_adj = torch.FloatTensor(mod_adj.astype(np.float16))
            interaction = torch.LongTensor([interaction.astype(int)])
            # print("hi")

            comb = (protein1.to(device), adjacency1.to(device), protein2.to(device), adjacency2.to(device), interaction.to(device))
            _,_,_,_,prot_embed = self.prot_model(comb,train=False)

            # print(mod_fp.shape)
            # print(mod_adj.shape)
            
            comb = (mod_fp.to(device),mod_adj.to(device),prot_embed.to(device),interaction.to(device))
            loss = self.model(comb)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            loss_total += loss.to('cpu').data.numpy()
        return loss_total


In [10]:
import torch
import numpy as np
from sklearn.metrics import roc_auc_score, roc_curve, auc
import wandb

class Tester(object):
    def __init__(self, model, prot_model, wandb_project_name=None):
        self.model = model
        self.prot_model = prot_model

        # Initialize wandb if a project name is provided
        if wandb_project_name:
            wandb.init(project=wandb_project_name)
            self.use_wandb = True
        else:
            self.use_wandb = False

    def test(self, dataset, epoch=None):
        sampling = dataset

        z_list, t_list = [], []
        for data in sampling:
            try:
                mod, p1, p2, interaction, _ = data
                A1 = np.load(
                    prot_data[prot_data['protein'] == p1].iloc[0]['adj_path'],
                    allow_pickle=True
                )
                A2 = np.load(
                    prot_data[prot_data['protein'] == p2].iloc[0]['adj_path'],
                    allow_pickle=True
                )
                P1 = np.load(
                    prot_data[prot_data['protein'] == p1].iloc[0]['fp_path'],
                    allow_pickle=True
                )
                P2 = np.load(
                    prot_data[prot_data['protein'] == p2].iloc[0]['fp_path'],
                    allow_pickle=True
                )
                mod_fp = np.load(
                    mod_data[mod_data['inchikey'] == mod].iloc[0]['fp_path'],
                    allow_pickle=True
                )
                mod_adj = np.load(
                    mod_data[mod_data['inchikey'] == mod].iloc[0]['adj_path'],
                    allow_pickle=True
                )
            except Exception as e:
                print(f"failed for {mod}, {p1}, and {p2}: {e}")
                continue

            protein1 = torch.LongTensor(P1.astype(np.float16))
            protein2 = torch.LongTensor(P2.astype(np.float16))
            adjacency1 = torch.FloatTensor(A1.astype(np.float16))
            adjacency2 = torch.FloatTensor(A2.astype(np.float16))
            mod_fp = torch.LongTensor(mod_fp.astype(np.float16))
            mod_adj = torch.FloatTensor(mod_adj.astype(np.float16))
            interaction = torch.LongTensor([interaction.astype(int)])

            comb = (protein1.to(device), adjacency1.to(device), protein2.to(device), adjacency2.to(device), interaction.to(device))
            _, _, _, _, prot_embed = self.prot_model(comb, train=False)

            comb = (mod_fp.to(device), mod_adj.to(device), prot_embed.to(device), interaction.to(device))
            z, _, _, _ = self.model(comb, train=False)
            z_list.append(z)
            t_list.append(interaction)

        score_list, label_list = [], []
        for z in z_list:
            score_list.append(z[1].item())
            label_list.append(torch.argmax(z).item())

        labels = np.array(label_list)
        y_true = np.array([t.item() for t in t_list])
        y_pred = np.array(score_list)

        (
            tp,
            fp,
            tn,
            fn,
            accuracy,
            precision,
            sensitivity,
            recall,
            specificity,
            MCC,
            F1_score,
            Q9,
            ppv,
            npv,
        ) = calculate_performance(len(sampling), labels, y_true)
        roc_auc_val = roc_auc_score(y_true, y_pred)
        fpr, tpr, thresholds = roc_curve(y_true, y_pred)
        auc_val = auc(fpr, tpr)

        # Log results to wandb
        if self.use_wandb:
            wandb.log({
                "Epoch": epoch,
                "Accuracy": accuracy,
                "Precision": precision,
                "Recall": recall,
                "Sensitivity": sensitivity,
                "Specificity": specificity,
                "MCC": MCC,
                "F1 Score": F1_score,
                "ROC AUC": roc_auc_val,
                "AUC": auc_val,
                "TP": tp,
                "FP": fp,
                "TN": tn,
                "FN": fn,
                "PPV": ppv,
                "NPV": npv,
            })

        return (
            accuracy,
            precision,
            recall,
            sensitivity,
            specificity,
            MCC,
            F1_score,
            roc_auc_val,
            auc_val,
            Q9,
            ppv,
            npv,
            tp,
            fp,
            tn,
            fn,
        )

    def result(
        self,
        epoch,
        time,
        loss,
        accuracy,
        precision,
        recall,
        sensitivity,
        specificity,
        MCC,
        F1_score,
        roc_auc_val,
        auc_val,
        Q9,
        ppv,
        npv,
        tp,
        fp,
        tn,
        fn,
        file_name,
    ):
        with open(file_name, "a") as f:
            result = map(
                str,
                [
                    epoch,
                    time,
                    loss,
                    accuracy,
                    precision,
                    recall,
                    sensitivity,
                    specificity,
                    MCC,
                    F1_score,
                    roc_auc_val,
                    auc_val,
                    Q9,
                    ppv,
                    npv,
                    tp,
                    fp,
                    tn,
                    fn,
                ],
            )
            f.write("\t".join(result) + "\n")

    def save_model(self, model, file_name):
        torch.save(model.state_dict(), file_name)

# Note: Make sure to define calculate_performance() and other missing components if not already done.


In [8]:
# get data
prot_data = pd.read_csv("prot_data.csv")
mod_data = pd.read_csv("mod_data.csv")
train_data = pd.read_csv("interaction_data.csv")
examples = np.array(train_data.values.tolist())
# setup folders
prot_fp_folder = "protein_fingerprints"
mod_fp_folder = "mod_fingerprints"
prot_fp_dict = np.load("protein_fingerprints/prot_fingerprint_dict.pickle",allow_pickle=True)
mod_fp_dict = np.load("mod_fingerprints/mod_fingerprint_dict.pickle",allow_pickle=True)


n_fingerprint = len(prot_fp_dict) + 100
nmod_fingerprint = len(mod_fp_dict) + 100


In [None]:

### Hyperparameters ###

radius         = 1
dim        = 20
layer_gnn      = 2
lr             = 1e-3
lr_decay       = 0.5
decay_interval = 1
iteration      = 20


In [13]:
import wandb
import timeit
import torch

# Initialize wandb run
wandb.init(project="promisegat4")  # Replace with your actual entity name

fold_count = 1

for train, test in kfold.split(examples):
    dataset_train = examples[train]  # mod, prot1, prot2, int, int_family
    dataset_test = examples[test]

    prot_model = PPI().to(device)
    prot_model.load_state_dict(torch.load("output/model/one/model_fold_1"))
    start = timeit.default_timer()

    model = PPIMPredict().to(device)
    trainer = Trainer(model, prot_model)
    file_model = "ppim/model/" + "model_fold_" + str(fold_count)
    file_result = "ppim/result/" + "results_fold_" + str(fold_count) + ".txt"

    for epoch in range(iteration):
        loss = trainer.train(dataset_train)
        print(f"finished with loss {loss}")

        # Log training loss and GPU usage to wandb
        gpu_usage = torch.cuda.memory_allocated(device=device) / 1024 ** 3  # in GB
        wandb.log({"epoch": epoch, "loss": loss, "gpu_usage_gb": gpu_usage, "fold": fold_count})

        tester = Tester(model, prot_model)
        (
            accuracy,
            precision,
            recall,
            sensitivity,
            specificity,
            MCC,
            F1_score,
            roc_auc_val,
            auc_val,
            Q9,
            ppv,
            npv,
            tp,
            fp,
            tn,
            fn,
        ) = tester.test(dataset_test, epoch=epoch)

        end = timeit.default_timer()
        time = end - start

        # Log results to wandb
        wandb.log({
            "epoch": epoch,
            "time": time,
            "accuracy": accuracy,
            "precision": precision,
            "recall": recall,
            "sensitivity": sensitivity,
            "specificity": specificity,
            "MCC": MCC,
            "F1_score": F1_score,
            "ROC_AUC": roc_auc_val,
            "AUC": auc_val,
            "Q9": Q9,
            "PPV": ppv,
            "NPV": npv,
            "TP": tp,
            "FP": fp,
            "TN": tn,
            "FN": fn,
            "fold": fold_count
        })

        tester.result(
            epoch,
            time,
            loss,
            accuracy,
            precision,
            recall,
            sensitivity,
            specificity,
            MCC,
            F1_score,
            roc_auc_val,
            auc_val,
            Q9,
            ppv,
            npv,
            tp,
            fp,
            tn,
            fn,
            file_result,
        )
        tester.save_model(model, file_model)

        print("Epoch: " + str(epoch))
        print("Accuracy: " + str(accuracy))
        print("Precision: " + str(precision))
        print("Recall: " + str(recall))
        print("Sensitivity: " + str(sensitivity))
        print("Specificity: " + str(specificity))
        print("MCC: " + str(MCC))
        print("F1-score: " + str(F1_score))
        print("ROC-AUC: " + str(roc_auc_val))
        print("AUC: " + str(auc_val))
        print("Q9: " + str(Q9))
        print("PPV: " + str(ppv))
        print("NPV: " + str(npv))
        print("TP: " + str(tp))
        print("FP: " + str(fp))
        print("TN: " + str(tn))
        print("FN: " + str(fn))
        print("\n")

        torch.manual_seed(1234)
    
    fold_count += 1

# Finish the wandb run
wandb.finish()


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011161762966852014, max=1.0…

CommError: Run initialization has timed out after 90.0 sec. 
Please refer to the documentation for additional information: https://docs.wandb.ai/guides/track/tracking-faq#initstarterror-error-communicating-with-wandb-process-

In [None]:
nmod_fingerprint

143

In [None]:
# try out struct2grap
prot_model = PPI().to(device)
prot_model.load_state_dict(torch.load("output/model/one/model_fold_1"))

  prot_model.load_state_dict(torch.load("output/model/one/model_fold_1"))


<All keys matched successfully>

In [None]:
fp1 = torch.LongTensor(np.load("protein_fingerprints/fingerprint/E9PNT5.npy",allow_pickle=True).astype(np.float16))
fp2 = torch.LongTensor(np.load("protein_fingerprints/fingerprint/P33681.npy",allow_pickle=True).astype(np.float16))
a1 = torch.FloatTensor(np.load("protein_fingerprints/adj/E9PNT5.npy",allow_pickle=True).astype(np.float16))
a2 = torch.FloatTensor(np.load("protein_fingerprints/adj/P33681.npy",allow_pickle=True).astype(np.float16))
val = torch.LongTensor(1)
comb = (fp1,a1,fp2,a2,val)

In [None]:
z, t, p1, p2, y= prot_model(comb,train=False)

In [None]:
y.shape

torch.Size([1, 40])