In [None]:
from torch_geometric.data import DataLoader
import torch.nn.functional as F
import torch, data_loader
from model import protein_feature_extractor, predictor, go_feature_extractor, att
from torch.autograd import Variable
from sklearn import metrics
import argparse, utils, os, evaluation_matrix, pickle, warnings, metric
import numpy as np
# Compute device plus data / checkpoint roots. The roots default to the original
# cluster locations but can be overridden with environment variables so the
# notebook runs elsewhere without editing this cell. Both values are later
# concatenated with relative sub-paths, so overrides must end with '/'.
# NOTE(review): "cuda:1" assumes a second GPU exists whenever CUDA is available — confirm.
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
print(device)
DATA_PATH = os.environ.get(
    "GRAPH_GO_DATA_PATH",
    '/home/nas2/biod/zq/ZQ/Protein Function Prediction/graph_go/dataset/',
)
MODEL_PATH = os.environ.get(
    "GRAPH_GO_MODEL_PATH",
    "/home/nas2/biod/zq/ZQ/Protein Function Prediction/BIBM_journal/model_checkpoint/A_model_copy5/",
)
from tqdm import tqdm

In [2]:
# Training hyper-parameters: mini-batch size, upper bound on epochs used by the
# training loop, and the Adam learning rate.
batch_size, epoches, learning_rate = 64, 50, 1e-4

In [3]:
# All splits read features from one directory and differ only in their chain list.
SEQ_FEATURE_ROOT = DATA_PATH + 'data/seq_features/'
SPLIT_LIST_DIR = DATA_PATH + "data/data_splits/"


def _load_split(split_name):
    """Return the Protein_Gnn_data dataset for the chains listed in `<split_name>.pdb.txt`."""
    return data_loader.Protein_Gnn_data(root=SEQ_FEATURE_ROOT, chain_list=SPLIT_LIST_DIR + split_name + ".pdb.txt")


X_train = _load_split("train")
X_test = _load_split("test")
X_valid = _load_split("valid")
# Only training benefits from shuffling. Evaluation concatenates predictions and
# labels in matching order, so shuffling there only adds nondeterminism.
train_loader = DataLoader(X_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(X_test, batch_size=batch_size, shuffle=False)
valid_loader = DataLoader(X_valid, batch_size=batch_size, shuffle=False)
# NOTE(review): utils.pickle_load is presumably pickle-based, which can execute
# code from the file — only point it at trusted data.
go_graph = utils.pickle_load(DATA_PATH + "PDB-cdhit/go_graph.pickle").to(device)



In [4]:
# Compose the full model from its four sub-modules and move it to `device`.
# NOTE(review): entries look like GO sub-ontology selectors; only ont[0] ("all") is used here.
ont = ["all", "bp", "mf", "cc"]
Protein_feature_extraction = protein_feature_extractor.Protein_feature_extraction()
Label_feature_extraction = go_feature_extractor.Label_feature_extraction(ont[0])
Protein_Label_att = att.Protein_Label_att()
GnnPF_Model = predictor.information_exchange_2()
model = predictor.TotalModel(
    Protein_feature_extraction,
    Label_feature_extraction,
    Protein_Label_att,
    GnnPF_Model,
).to(device)

In [5]:
# Adam with L2 weight decay over all model parameters.
optim = torch.optim.Adam(params=model.parameters(), lr=learning_rate, weight_decay=0.0001)
# reduction='none' returns the per-element loss, so train() can multiply it by its
# per-class weights before averaging. With the default 'mean' the loss is already a
# scalar and the weights collapse into a single per-batch scale factor.
# valid() calls .mean() on the result itself, so its reported loss is unchanged.
loss_function = torch.nn.BCELoss(reduction='none')

In [6]:
# Mutable training state. The checkpoint-loading code may overwrite
# current_epoch and min_val_loss when resuming.
current_epoch = 1      # first epoch to run
min_val_loss = np.inf  # best validation loss so far; np.Inf was removed in NumPy 2.0
# NOTE(review): seq[5] selects the checkpoint sub-directory; the values look like
# sequence-identity cut-offs (25 appears twice) — confirm.
seq = [15, 25, 25, 45, 55, 95]

In [None]:
# Checkpoint directory for this run; seq[5] picks the sub-directory.
model_dir = MODEL_PATH + '{}/'.format(seq[5]) + "model_checkpoint/"

ckp_current = model_dir + "current_checkpoint.pt"  # latest checkpoint, resumed from below
ckp_best = model_dir + "best_model.pt"             # best-validation-loss checkpoint

# exist_ok makes this idempotent, so no separate existence check is needed.
os.makedirs(model_dir, exist_ok=True)
if os.path.exists(ckp_current):
    # Resume: restore weights, optimizer state, epoch counter and best loss.
    print("Loading model checkpoint @ {}".format(ckp_current))
    model, optim, current_epoch, min_val_loss = utils.load_ckp(ckp_current, model, optim, device = device)
print("Training model on epoch {}".format(current_epoch))

In [8]:
def train(train_loader, model, loss_function):
    """Run one training epoch and return the loss of every batch.

    Each batch is moved to the module-level ``device`` and fed to ``model``
    together with the module-level ``go_graph``. Parameters are updated with the
    module-level optimiser ``optim``.

    Args:
        train_loader: iterable of batches exposing ``x``, ``seq``, ``edge_index``,
            ``seq_embed``, ``label`` and ``batch``.
        model: network called with the keyword arguments used below; its output
            must be probabilities suitable for ``loss_function``.
        loss_function: criterion applied to ``(model_pred, label)``. Its result is
            multiplied by per-class weights and then averaged. The weighting acts
            per class only if the criterion returns per-element losses (e.g.
            ``BCELoss(reduction='none')``). A mean-reduced criterion turns the
            weights into one scale factor per batch.

    Returns:
        list of float: weighted training loss for each batch, in loader order.
    """
    model.train()  # the mode never changes inside the loop, so set it once
    train_loss = []
    for data in tqdm(train_loader):
        esm_rep = data.x.to(device).float()
        seq = data.seq.T.unsqueeze(0).to(device)
        contact = data.edge_index.to(device)
        seq_embed = data.seq_embed.to(device)
        batch_idx = data.batch.to(device)
        label = data.label.float().to(device)
        model_pred = model(seq=seq, edge_index=contact, esm_token=esm_rep,
                           esm_representation=seq_embed, batch=batch_idx,
                           label_relationship=go_graph).to(torch.float32)

        # Inverse-frequency class weights from this batch's labels (summed over dim 0).
        class_count = label.sum(dim=0)
        class_weights = class_count.sum() / class_count
        # A class with no positives gives x/0 = inf. An all-negative batch gives
        # 0/0 = NaN, which isinf() would miss. Zero both so neither poisons the loss.
        class_weights = torch.where(torch.isfinite(class_weights), class_weights, torch.zeros_like(class_weights))
        loss = (loss_function(model_pred, label) * class_weights).mean()

        # Store a plain float; keeping model outputs alive would hold GPU memory all epoch.
        train_loss.append(loss.item())
        optim.zero_grad()
        loss.backward()
        optim.step()
    return train_loss

In [9]:
def valid(es, current_epoch, valid_loader, model, min_val_loss, loss_function):
    """Evaluate ``model``, save checkpoints and update the early-stopping state.

    Uses the module-level ``device``, ``go_graph``, ``optim``, ``ckp_current`` and
    ``ckp_best``.

    Args:
        es: consecutive epochs without validation improvement so far.
        current_epoch: epoch number written into the checkpoint.
        valid_loader: iterable of batches with the same fields ``train`` uses.
        model: network to evaluate.
        min_val_loss: best validation loss seen before this call.
        loss_function: criterion applied to the flattened predictions and labels.
            Its result is averaged, so reduced and unreduced criteria both work.

    Returns:
        tuple ``(es, min_val_loss, eval_loss, result)``:
            - the updated patience counter;
            - the best validation loss including this epoch;
            - this epoch's loss as a 0-d tensor;
            - the output of ``metric.count_evaluation_TALE``.
    """
    model.eval()
    pred_batches = []
    label_batches = []
    with torch.no_grad():
        for data in tqdm(valid_loader):
            esm_rep = data.x.to(device).float()
            seq = data.seq.T.unsqueeze(0).to(device)
            contact = data.edge_index.to(device)
            seq_embed = data.seq_embed.to(device)
            batch_idx = data.batch.to(device)
            model_pred = model(seq=seq, edge_index=contact, esm_token=esm_rep,
                               esm_representation=seq_embed, batch=batch_idx,
                               label_relationship=go_graph).to(torch.float32)
            # Move to CPU per batch so predictions do not accumulate on the GPU.
            pred_batches.append(model_pred.cpu())
            label_batches.append(data.label.float())
        y_pred_all = torch.cat(pred_batches, dim=0)
        y_true_all = torch.cat(label_batches, dim=0)
        eval_loss = loss_function(y_pred_all.reshape(-1), y_true_all.reshape(-1)).mean()

    result = metric.count_evaluation_TALE(y_pred_all, y_true_all)

    # Update early-stopping state before building the checkpoint, so the stored
    # valid_loss_min is the true running minimum. Storing this epoch's loss instead
    # would let a resumed run overwrite the best model with a worse one.
    is_best = eval_loss <= min_val_loss
    if is_best:
        min_val_loss = eval_loss
        es = 0
    else:
        es += 1

    checkpoint = {
        'epoch': current_epoch,
        # Kept as a tensor to match the checkpoint format produced previously.
        'valid_loss_min': torch.as_tensor(min_val_loss),
        'state_dict': model.state_dict(),
        'optimizer': optim.state_dict(),
    }
    utils.save_ckp(checkpoint, False, ckp_current, ckp_best)
    if is_best:
        utils.save_ckp(checkpoint, True, ckp_current, ckp_best)
    return es, min_val_loss, eval_loss, result

In [None]:
# Train for up to `epoches` epochs. Stop early once validation loss has not
# improved for `patience` consecutive epochs.
patience = 5
es = 0  # epochs since the last validation-loss improvement
# NOTE(review): on resume, current_epoch comes from utils.load_ckp, and valid()
# stores the epoch it just finished. Unless load_ckp adds 1, that epoch is re-run.
# The early-stopping counter also restarts at 0 — confirm this is intended.
# `epoches + 1` makes the bound inclusive, so a fresh run (current_epoch == 1)
# trains all `epoches` epochs rather than epoches - 1.
for epoch in range(current_epoch, epoches + 1):
    Train_loss = train(train_loader, model, loss_function)
    es, min_val_loss, eval_loss, result = valid(es, epoch, valid_loader, model, min_val_loss, loss_function)
    print("Epoch:%s TALE_mf_Fmax:%.3f TALE_mf_AUPRC:%.3f TALE_bp_Fmax:%.3f TALE_bp_AUPRC:%.3f TALE_cc_Fmax:%.3f TALE_cc_AUPRC:%.3f" %(epoch, result[0],result[1],result[2], result[3], result[4], result[5]))
    if es >= patience:
        print("Counter {} of {}".format(es, patience))
        break