In [1]:
#/home/wngys/lab/DeepFold/Code
import numpy as np
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as T
from model import *

import os

In [2]:
# --------------------------------------------------------------------------------------------- #
# 设置显卡
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1, 2, 3"
device_ids = [0, 1, 2, 3]

In [3]:
# --------------------------------------------------------------------------------------------- #
# 加载模型的参数进行测试
# model_path = "/home/wngys/lab/DeepFold/new_model/new_model_2/model_17.pt"
# model_path = "/home/wngys/lab/DeepFold/model/model_lossF/best_model.pt"
# model_path = "/home/wngys/lab/DeepFold/new_model/new_model/best_model.pt"
model_path = "/home/wngys/lab/DeepFold/new_model/new_model_3/model_9.pt"

chkp = torch.load(model_path)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# DFold_model = DeepFold(in_channel = 1)
DFold_model = DeepFold(in_channel = 3)
DFold_model = nn.DataParallel(DFold_model, device_ids).to(device)
# DFold_model.load_state_dict(chkp)
DFold_model.load_state_dict(chkp["model_param"])

In [4]:
# --------------------------------------------------------------------------------------------- #
# 自定义Dataset：加载蛋白质id、距离矩阵、标签
class allMatrixLabelDataset(Dataset):
    def __init__(self, matrix_dir, transform=None):
        IDListPath = "/home/wngys/lab/DeepFold/protein_infor/IDArray.npy"
        self.IDlist = np.load(IDListPath, allow_pickle=True).tolist()
        self.matrix_dir = matrix_dir
        self.transform = transform

    def __len__(self):
        return len(self.IDlist)

    def __getitem__(self, idx):
        id = self.IDlist[idx]
        matrix_path = self.matrix_dir + id + ".npy"
        matrix = torch.from_numpy(np.load(matrix_path, allow_pickle=True)).to(torch.float)
        # matrix = torch.from_numpy(np.expand_dims(np.load(matrix_path, allow_pickle=True), 0)).to(torch.float)
        if self.transform:
            matrix = self.transform(matrix)
        return id, matrix

In [5]:
# all_matrix_dir = "/home/wngys/lab/DeepFold/distance_matrix_r/distance_matrix_mine_r/"
all_matrix_dir = "/home/wngys/lab/DeepFold/distance_matrix_r/distance_matrix_mine_r_3/"
transform = T.Compose([
    T.Resize((256, 256)),
    # T.Normalize(mean=[0.0660], std=[0.0467])
    T.Normalize(mean=[0.0068, 0.0003, 2.3069e-05], std=[0.0140, 0.0015, 0.0002])
])

In [6]:
BATCH_SIZE = 64
all_dataset = allMatrixLabelDataset(all_matrix_dir, transform)
all_dataloader = DataLoader(all_dataset, BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

In [7]:
vector_dict = {}

DFold_model.eval()

for ids, matrices in all_dataloader:
    matrices = matrices.to(device)
    with torch.no_grad():
        vectors = DFold_model(matrices)
    for idx in range(len(ids)):
        vector_dict[ids[idx]] = vectors[idx:idx+1]

In [8]:
all_ids = list(vector_dict.keys())
print(len(all_ids))

14274


In [9]:
def by_simi(t):
    return t[1]

In [10]:
K = 10

In [11]:
def getTruePairSetByID(protein_id):
    pair_dir = "/home/wngys/lab/DeepFold/pair/pair_bool_90/"
    pair_path = pair_dir + protein_id + ".txt"
    true_pair_list = []
    with open(pair_path, "r") as f_r:
        while True:
            lines = f_r.readline()
            if not lines:
                break
            id= lines.split('\t')[0]
            label = lines.split('\t')[1].split("\n")[0]
            if label == '1':
                true_pair_list.append(id)
    true_pair_set = set(true_pair_list)
    return true_pair_set

In [12]:
testIDList = np.load("/home/wngys/lab/DeepFold/pair/test.npy", allow_pickle=True).tolist()
random.shuffle(testIDList)
testIDList = testIDList[:100]

In [13]:
validIDList = np.load("/home/wngys/lab/DeepFold/pair/valid.npy", allow_pickle=True).tolist()
random.shuffle(validIDList)
validIDList = validIDList[:100]

In [14]:
# query_list = testIDList
query_list = validIDList


In [15]:
cntShot = 0

for query_id in query_list:
    # print(query_id)
    
    query_vector = vector_dict[query_id]

    id_simi_list = []
    for id in all_ids:
        vector = vector_dict[id]
        cos_simi = F.cosine_similarity(query_vector, vector, dim=1).tolist()[0]
        id_simi_list.append((id, cos_simi))
    id_simi_list = sorted(id_simi_list, key=by_simi, reverse=True)

    top_K_id_list = []
    for idx in range(K+1):
        top_K_id_list.append(id_simi_list[idx][0])
    top_K_id_set = set(top_K_id_list)

    true_pair_id_set = getTruePairSetByID(query_id)
    shot_set = top_K_id_set & true_pair_id_set
    if len(shot_set)>0:
        cntShot += 1
        # print(query_id, "shot.")
    
    # print(top_K_id_set)
    # print(true_pair_id_set)
    # print(shot_set)

In [16]:
acc = cntShot / len(query_list)
print(acc, cntShot, len(query_list))

0.31 31 100
