In [1]:
#/home/wngys/lab/DeepFold/Code
import numpy as np
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as T
from model import *

import os

In [2]:
# --------------------------------------------------------------------------------------------- #
# Configure the GPUs.
# PCI_BUS_ID ordering makes the indices below refer to physical bus order; only
# the listed cards are exposed to this process, and CUDA renumbers them from 0,
# which is why device_ids is 0..3 rather than 0/5/6/7.
# NOTE(review): the device list contains spaces after the commas - confirm the
# CUDA runtime in use parses that form as intended.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0, 5, 6, 7",
})
device_ids = list(range(4))

In [3]:
# --------------------------------------------------------------------------------------------- #
# Test: load a previously trained model so that training can be continued.
#
# map_location="cpu" deserialises every tensor onto the CPU first, so the
# checkpoint can be opened on a host without (or with different) GPUs and does
# not pile all stored tensors onto the GPU they were saved from;
# load_state_dict() below copies the weights onto the model's own devices.
# NOTE: torch.load unpickles the file - only open checkpoints you trust.
chkp = torch.load(
    "/home/wngys/lab/DeepFold/new_model/new_model/model_2.pt",
    map_location="cpu",
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DFold_model = DeepFold(in_channel = 1)
# The model is wrapped in DataParallel *before* the weights are loaded, so the
# state dict is matched against the wrapped module's parameter names.
DFold_model = nn.DataParallel(DFold_model, device_ids).to(device)
DFold_model.load_state_dict(chkp["model_param"])

<All keys matched successfully>

In [4]:
# --------------------------------------------------------------------------------------------- #
# Custom Dataset: yields protein id, distance matrix and label.
class MatrixLabelDataset(Dataset):
    """Dataset over the (id, label) records listed for one query protein.

    The pair file ``pair_dir + protein_id + ".txt"`` holds one tab-separated
    ``id<TAB>label`` record per line. Each id is resolved to the array stored at
    ``matrix_dir + id + ".npy"``. Paths are built by plain string concatenation,
    so both directory arguments must end with a path separator.
    """

    def __init__(self, protein_id, pair_dir, matrix_dir, transform=None):
        """Read the pair file for ``protein_id``.

        Args:
            protein_id: Name of the pair file (without ``.txt``).
            pair_dir: Directory prefix of the pair files.
            matrix_dir: Directory prefix of the ``.npy`` matrices.
            transform: Optional callable applied to every loaded matrix tensor.
        """
        self.protein_id = protein_id
        self.matrix_dir = matrix_dir
        self.transform = transform
        pair_path = pair_dir + protein_id + ".txt"
        self.id_label_list = self.get_id_label_list(pair_path)

    def __len__(self):
        """Number of (id, label) records read from the pair file."""
        return len(self.id_label_list)

    def __getitem__(self, idx):
        """Return ``(id, matrix, label)`` for record ``idx``.

        ``matrix`` is the stored array with one leading axis added, converted to
        a float tensor and passed through ``transform`` when one was given.
        ``label`` is returned as the string read from the pair file.
        """
        sample_id, label = self.id_label_list[idx]
        matrix_path = self.matrix_dir + sample_id + ".npy"
        # NOTE(review): allow_pickle=True executes pickled payloads; it is kept
        # for compatibility with the existing files - confirm they are trusted.
        array = np.load(matrix_path, allow_pickle=True)
        matrix = torch.from_numpy(np.expand_dims(array, 0)).to(torch.float)
        if self.transform:
            matrix = self.transform(matrix)
        return sample_id, matrix, label

    def get_id_label_list(self, pair_path):
        """Parse ``pair_path`` into a list of ``(id, label)`` string tuples.

        Blank (or whitespace-only) lines are skipped, so a trailing empty line
        no longer aborts loading. A non-blank line without a tab still raises
        ``IndexError``.
        """
        id_label_list = []
        with open(pair_path, "r") as f_r:
            for line in f_r:
                if not line.strip():
                    continue
                fields = line.rstrip("\n").split("\t")
                id_label_list.append((fields[0], fields[1]))
        return id_label_list

In [5]:
# --------------------------------------------------------------------------------------------- #
# Compute the loss.
def MaxMarginLoss(vectors, num_pos=6, margin=0.1):
    """Max-margin ranking loss of one query against its positives and negatives.

    ``vectors`` holds one embedding per row: row 0 is the query, the next
    ``num_pos`` rows are positives and all remaining rows are negatives. For
    every (positive, negative) pair the term
    ``cos(query, negative) - cos(query, positive) + margin`` is formed, and the
    terms that are >= 0 are summed.

    The defaults reproduce the previous fixed 1 + 6 + 57 layout exactly, but the
    number of negatives is now taken from the input, so a batch that is not
    exactly 64 rows (e.g. a short final batch) no longer raises.

    Args:
        vectors: 2-D tensor of embeddings, rows ordered as described above.
        num_pos: Number of positive rows following the query.
        margin: Required similarity gap between positives and negatives.

    Returns:
        Scalar tensor with the summed margin violations.
    """
    query_vector = vectors[:1]
    pos_vectors = vectors[1:1 + num_pos]
    neg_vectors = vectors[1 + num_pos:]

    # The single query row broadcasts against every positive / negative row.
    pos_cos_simi = F.cosine_similarity(query_vector, pos_vectors, dim=1).unsqueeze(1)
    neg_cos_simi = F.cosine_similarity(query_vector, neg_vectors, dim=1).unsqueeze(0)

    # (num_pos, 1) against (1, num_neg) -> one entry per (positive, negative) pair.
    diff = neg_cos_simi - pos_cos_simi + margin
    loss = torch.sum(diff[diff >= 0])

    return loss

In [6]:
# --------------------------------------------------------------------------------------------- #
def by_simi(t):
    """Sort key for ``(id, label, similarity)`` records: the entry at index 2."""
    similarity = t[2]
    return similarity

In [7]:
# --------------------------------------------------------------------------------------------- #
# During training: measure the model's top-K accuracy on the validation set.
def ModelOnValidSet(K=10):
    """Top-K hit rate of ``DFold_model`` over ``validIDlist``.

    Every validation protein is embedded as a query and compared by cosine
    similarity with all candidates listed in its pair file. The query counts as
    a hit when at least one of its ``K`` most similar candidates carries the
    label ``'1'``.

    Reads the module-level ``DFold_model``, ``validIDlist``, ``transform``,
    ``device`` and ``BATCH_SIZE``, and leaves the model in eval mode.

    Args:
        K: Number of top-ranked candidates inspected per query.

    Returns:
        Tuple ``(accuracy, hits, total)``; accuracy is 0.0 for an empty ID list.
    """
    DFold_model.eval()
    valid_pair_dir = "/home/wngys/lab/DeepFold/pair/pair_bool_90/"
    valid_matrix_dir = "/home/wngys/lab/DeepFold/distance_matrix_r/distance_matrix_mine_r/"

    total = len(validIDlist)
    cntShot = 0
    # Pure inference: without no_grad every forward pass would record an
    # autograd graph that is never used and only costs memory.
    with torch.no_grad():
        for protein_id in validIDlist:
            query_matrix_path = valid_matrix_dir + protein_id + ".npy"
            query_array = np.load(query_matrix_path, allow_pickle=True)
            # Two leading singleton axes are added before the transform.
            query_matrix = torch.unsqueeze(
                torch.from_numpy(np.expand_dims(query_array, 0)).to(torch.float), 0
            )
            query_matrix = transform(query_matrix).to(device)
            query_vector = DFold_model(query_matrix)

            valid_dataset = MatrixLabelDataset(protein_id, valid_pair_dir, valid_matrix_dir, transform)
            valid_dataloader = DataLoader(valid_dataset, BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

            id_label_simi_list = []
            for ids, matrices, labels in valid_dataloader:
                vectors = DFold_model(matrices.to(device))
                query_vector_bs = query_vector.repeat(len(ids), 1)
                # One host transfer per batch instead of one per element.
                cos_simi = F.cosine_similarity(query_vector_bs, vectors, dim=1).tolist()
                id_label_simi_list.extend(zip(ids, labels, cos_simi))

            id_label_simi_list.sort(key=by_simi, reverse=True)
            if any(label == '1' for _, label, _ in id_label_simi_list[:K]):
                cntShot += 1

    # Guard the empty case instead of dividing by zero.
    acc = cntShot / total if total else 0.0
    print("Valid acc:", acc, "| shot:", cntShot, "| total:", total)
    return (acc, cntShot, total)

In [8]:
# --------------------------------------------------------------------------------------------- #
# During training: measure the model's top-K accuracy on the training set.
def ModelOnTrainSet(K=10, num_queries=100):
    """Top-K hit rate of ``DFold_model`` over the first training proteins.

    Each of the first ``num_queries`` entries of ``trainIDlist`` is embedded as
    a query and compared by cosine similarity with all candidates listed in its
    pair file. The query counts as a hit when at least one of its ``K`` most
    similar candidates carries the label ``'1'``.

    Reads the module-level ``DFold_model``, ``trainIDlist``, ``transform``,
    ``device`` and ``BATCH_SIZE``, and leaves the model in eval mode.

    Args:
        K: Number of top-ranked candidates inspected per query.
        num_queries: How many leading training IDs are evaluated.

    Returns:
        Tuple ``(accuracy, hits, total)``; accuracy is 0.0 for an empty ID list.
    """
    DFold_model.eval()
    # NOTE(review): pairs are read from pair_bool_90 here, whereas the training
    # loop reads new_train_pair_bool_90 - confirm this is intended.
    train_pair_dir = "/home/wngys/lab/DeepFold/pair/pair_bool_90/"
    train_matrix_dir = "/home/wngys/lab/DeepFold/distance_matrix_r/distance_matrix_mine_r/"

    eval_ids = trainIDlist[:num_queries]
    total = len(eval_ids)
    cntShot = 0
    # Pure inference: without no_grad every forward pass would record an
    # autograd graph that is never used and only costs memory.
    with torch.no_grad():
        for protein_id in eval_ids:
            query_matrix_path = train_matrix_dir + protein_id + ".npy"
            query_array = np.load(query_matrix_path, allow_pickle=True)
            # Two leading singleton axes are added before the transform.
            query_matrix = torch.unsqueeze(
                torch.from_numpy(np.expand_dims(query_array, 0)).to(torch.float), 0
            )
            query_matrix = transform(query_matrix).to(device)
            query_vector = DFold_model(query_matrix)

            train_dataset = MatrixLabelDataset(protein_id, train_pair_dir, train_matrix_dir, transform)
            train_dataloader = DataLoader(train_dataset, BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

            id_label_simi_list = []
            for ids, matrices, labels in train_dataloader:
                vectors = DFold_model(matrices.to(device))
                query_vector_bs = query_vector.repeat(len(ids), 1)
                # One host transfer per batch instead of one per element.
                cos_simi = F.cosine_similarity(query_vector_bs, vectors, dim=1).tolist()
                id_label_simi_list.extend(zip(ids, labels, cos_simi))

            id_label_simi_list.sort(key=by_simi, reverse=True)
            if any(label == '1' for _, label, _ in id_label_simi_list[:K]):
                cntShot += 1

    # Guard the empty case instead of dividing by zero.
    acc = cntShot / total if total else 0.0
    print("Train acc:", acc, "| shot:", cntShot, "| total:", total)
    return (acc, cntShot, total)

In [13]:
# Top-10 hit rate of the current DFold_model on the first 100 training IDs; the
# returned (acc, hits, total) tuple is shown as the cell result.
# NOTE(review): needs trainIDlist, transform and BATCH_SIZE, which are assigned
# in cells further down the notebook.
ModelOnTrainSet()

Train acc: 0.78 | shot: 78 | total: 100


(0.78, 78, 100)

In [16]:
# Top-10 hit rate of the current DFold_model on validIDlist; the returned
# (acc, hits, total) tuple is shown as the cell result.
# NOTE(review): needs validIDlist, transform and BATCH_SIZE, which are assigned
# in cells further down the notebook.
ModelOnValidSet()

Valid acc: 0.73 | shot: 73 | total: 100


(0.73, 73, 100)

In [8]:
# --------------------------------------------------------------------------------------------- #
# Train a model that has not been trained before.
# Rebinds `device` and `DFold_model`, so running this cell replaces any model
# that was bound to DFold_model earlier in the session.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
DFold_model = nn.DataParallel(DeepFold(in_channel=1), device_ids).to(device)

In [12]:
# First epoch index of the training loop's range(START_EPOCH, EPOCH).
START_EPOCH = 0
# Exclusive upper bound of that range (not a count of additional epochs).
EPOCH = 10
# Batch size handed to every DataLoader. MaxMarginLoss slices each batch as
# 1 query + 6 positives + 57 negatives, i.e. it is written for 64-row batches.
BATCH_SIZE = 64

In [15]:
# Training queries: the first 400 IDs of train.npy, in file order.
# (Disabled option: random.shuffle(trainIDlist) before taking the slice.)
trainIDlist = np.load("/home/wngys/lab/DeepFold/pair/train.npy", allow_pickle=True)[:400]

# Validation queries are taken from the loaded checkpoint, so this cell needs
# `chkp` to be bound already.
# Disabled alternative - draw 100 IDs from valid.npy instead:
# validIDlist = np.load("/home/wngys/lab/DeepFold/pair/valid.npy", allow_pickle=True)
# random.shuffle(validIDlist)
# validIDlist = validIDlist[:100]
validIDlist = chkp["valid_id_list"]

In [10]:
# Training inputs: the matrices on disk and the per-query pair lists naming them.
train_matrix_dir = "/home/wngys/lab/DeepFold/distance_matrix_r/distance_matrix_mine_r/"
train_pair_dir = "/home/wngys/lab/DeepFold/pair/new_train_pair_bool_90/"

# Preprocessing shared by training and evaluation: resize to 256x256, then
# normalise the single channel with a fixed mean / std
# (presumably dataset statistics - TODO confirm).
transform = T.Compose([T.Resize(size=(256, 256)), T.Normalize(mean=[0.0660], std=[0.0467])])

# SGD with momentum over the parameters of whichever model DFold_model is
# bound to when this cell runs.
optimizer = torch.optim.SGD(params=DFold_model.parameters(), lr=0.01, momentum=0.9)

In [12]:
# Accuracy history: one (acc, hits, total) tuple per evaluation, stored in every
# checkpoint written below.
train_acc_list = []
valid_acc_list = []

for epoch in range(START_EPOCH, EPOCH):
    for idx, protein_id in enumerate(trainIDlist):
        # The evaluation helpers switch the model to eval mode, so training
        # mode is re-enabled at the start of every protein.
        DFold_model.train()
        train_dataset = MatrixLabelDataset(protein_id, train_pair_dir, train_matrix_dir, transform)
        # shuffle=False keeps the file order, which MaxMarginLoss relies on
        # (query first, then positives, then negatives).
        train_dataloader = DataLoader(train_dataset, BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

        # Reset per protein: an empty pair file would otherwise report the
        # previous protein's loss (or raise NameError on the very first one).
        loss = None
        for ids, matrices, labels in train_dataloader:
            matrices = matrices.to(device)
            vectors = DFold_model(matrices)
            loss = MaxMarginLoss(vectors)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print("Epoch:", epoch, "| idx:", idx, "| id:", protein_id, "| last batch loss:",
              None if loss is None else loss.tolist())

        # Evaluate on both splits every 200 proteins (including idx 0).
        if idx % 200 == 0:
            train_t = ModelOnTrainSet()
            valid_t = ModelOnValidSet()
            train_acc_list.append(train_t)
            valid_acc_list.append(valid_t)

    # One checkpoint per finished epoch. This rebinds `chkp`, replacing any
    # checkpoint dict loaded earlier under the same name.
    chkp = {
        "epoch": epoch,
        "model_param": DFold_model.state_dict(),
        "optim_param": optimizer.state_dict(),
        "train_acc": train_acc_list,
        "valid_acc": valid_acc_list,
        "valid_id_list": validIDlist
    }
    torch.save(chkp, "/home/wngys/lab/DeepFold/new_model/new_model/" + f"model_{epoch}.pt")

Epoch: 0 | idx: 0 | id: d5azpa_ | last batch loss: 39.852821350097656
Train acc: 0.61 | shot: 61 | total: 100
Valid acc: 0.6 | shot: 60 | total: 100
Epoch: 0 | idx: 1 | id: d1xrsa_ | last batch loss: 31.024084091186523
Epoch: 0 | idx: 2 | id: d1u8sa2 | last batch loss: 16.157451629638672
Epoch: 0 | idx: 3 | id: d1f81a_ | last batch loss: 15.918352127075195
Epoch: 0 | idx: 4 | id: d3lhla_ | last batch loss: 11.380094528198242
Epoch: 0 | idx: 5 | id: d1vcaa2 | last batch loss: 15.257631301879883
Epoch: 0 | idx: 6 | id: d2cpha1 | last batch loss: 11.957173347473145
Epoch: 0 | idx: 7 | id: d1vaja1 | last batch loss: 43.931297302246094
Epoch: 0 | idx: 8 | id: d2q9qb2 | last batch loss: 43.82976531982422
Epoch: 0 | idx: 9 | id: d1uzka1 | last batch loss: 18.757997512817383
Epoch: 0 | idx: 10 | id: d1ngka_ | last batch loss: 16.36570930480957
Epoch: 0 | idx: 11 | id: d2dj0a1 | last batch loss: 16.466266632080078
Epoch: 0 | idx: 12 | id: d1nkpa1 | last batch loss: 13.507769584655762
Epoch: 0 |

KeyboardInterrupt: 

In [17]:
# Training-set accuracy history held by the current `chkp` dict: one
# (acc, hits, total) tuple per evaluation.
print(chkp['train_acc'])

[(0.61, 61, 100), (0.68, 68, 100), (0.71, 71, 100), (0.75, 75, 100), (0.71, 71, 100), (0.77, 77, 100)]


In [18]:
# Validation-set accuracy history held by the current `chkp` dict: one
# (acc, hits, total) tuple per evaluation.
print(chkp['valid_acc'])

[(0.63, 63, 100), (0.68, 68, 100), (0.76, 76, 100), (0.71, 71, 100), (0.75, 75, 100), (0.74, 74, 100)]


In [19]:
# List the fields stored in the checkpoint dict.
print(chkp.keys())

dict_keys(['epoch', 'model_param', 'optim_param', 'train_acc', 'valid_acc', 'valid_id_list'])
