/
test.py
91 lines (77 loc) · 3.51 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# Adapted from https://github.com/PaddlePaddle/PaddleHelix/blob/dev/apps/drug_target_interaction/sign/test.py
"""
Testing code for Curvature-based Adaptive Graph Neural Networks (CurvAGN).
"""
import os
import time
import math
import argparse
import random
import numpy as np
import paddle
import paddle.nn.functional as F
from pgl.utils.data import Dataloader
from cdataset import ComplexDataset, collate_fn
from cmodel import SIGN
from utils import rmse, mae, sd, pearson
from tqdm import tqdm
# Seed Paddle's global random number generator at import time with a fixed
# value (123), before any command-line arguments are parsed.
paddle.seed(123)
def setup_seed(seed):
    """Seed every random number generator used by this script.

    Args:
        seed: Integer seed applied to Paddle, NumPy and Python's ``random``.
    """
    # Previously this call was commented out, so the function silently left
    # Paddle's generator untouched despite its name.
    paddle.seed(seed)
    np.random.seed(seed)
    random.seed(seed)
@paddle.no_grad()
def evaluate(model, loader):
    """Run ``model`` over every batch of ``loader`` and score its predictions.

    The model is put into eval mode, and gradient tracking is disabled by the
    ``paddle.no_grad`` decorator. Each batch is unpacked as
    ``(a2a_g, b2a_g, b2b_gl, feats, types, counts, y)``. ``feats`` is not
    forwarded to the model.

    Returns:
        tuple: ``(rmse, mae, sd, pearson)``, computed by the helpers imported
        from ``utils`` on the flattened targets and predictions.
    """
    model.eval()
    predictions, targets = [], []
    for a2a_g, b2a_g, b2b_gl, _feats, types, counts, labels in loader:
        # The model returns a pair; only its second element is scored.
        _, batch_pred = model(a2a_g, b2a_g, b2b_gl, types, counts)
        predictions.extend(batch_pred.tolist())
        targets.extend(labels.tolist())
    # Flatten so per-sample values line up regardless of trailing dims.
    y_pred = np.array(predictions).reshape(-1)
    y_true = np.array(targets).reshape(-1)
    return (
        rmse(y_true, y_pred),
        mae(y_true, y_pred),
        sd(y_true, y_pred),
        pearson(y_true, y_pred),
    )
if __name__ == "__main__":
    # Command-line options. Most training hyper-parameters below are not used
    # directly by this script; they are kept because the whole namespace is
    # handed to ``SIGN(args)`` (presumably so the architecture matches the
    # saved checkpoint -- confirm against cmodel.SIGN).
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, default='/data/wujq/')
    parser.add_argument('--dataset', type=str, default='g2016')
    parser.add_argument('--model_dir', type=str, default='/data/wujq/output/sign1')
    # '-1' selects the CPU; any other value is used as the GPU index.
    parser.add_argument('--cuda', type=str, default='0')
    parser.add_argument('--seed', type=int, default=123)
    # NOTE(review): store_true combined with default=True means this flag is
    # always True and cannot be disabled; it is not read by this script.
    parser.add_argument("--save_model", action="store_true", default=True)
    parser.add_argument("--lambda_", type=float, default=1.75)
    parser.add_argument("--feat_drop", type=float, default=0.2)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--weight_decay", type=float, default=0.)
    parser.add_argument("--lr_dec_rate", type=float, default=0.5)
    parser.add_argument("--dec_step", type=int, default=8000)
    parser.add_argument('--stop_epoch', type=int, default=100)
    parser.add_argument('--epochs', type=int, default=300)
    parser.add_argument("--num_convs", type=int, default=2)
    parser.add_argument("--hidden_dim", type=int, default=128)
    parser.add_argument("--infeat_dim", type=int, default=36)
    parser.add_argument("--dense_dims", type=str, default='128*4,128*2,128')
    parser.add_argument('--num_heads', type=int, default=4)
    parser.add_argument('--cut_dist', type=float, default=5.)
    parser.add_argument('--num_angle', type=int, default=6)
    parser.add_argument('--merge_b2b', type=str, default='cat')
    parser.add_argument('--merge_b2a', type=str, default='mean')
    parser.add_argument('--num_flt', type=int, default=50)
    args = parser.parse_args()
    args.activation = F.relu
    # NOTE(review): eval() executes arbitrary Python taken from the command
    # line. It is kept so expressions such as '128*4' keep working; only pass
    # trusted values to --dense_dims.
    args.dense_dims = [eval(dim) for dim in args.dense_dims.split(',')]

    if int(args.cuda) == -1:
        paddle.set_device('cpu')
    else:
        paddle.set_device('gpu:%s' % args.cuda)

    # Apply --seed; previously it was parsed but never used.
    setup_seed(args.seed)

    # The test split is selected by the name "<dataset>_test".
    tst_complex = ComplexDataset(args.data_dir, "%s_test" % args.dataset, args.cut_dist, args.num_angle, args.num_flt)
    tst_loader = Dataloader(tst_complex, args.batch_size, shuffle=False, num_workers=1, collate_fn=collate_fn)

    model = SIGN(args)
    path = os.path.join(args.model_dir, 'saved_model')
    checkpoint = paddle.load(path)
    # The checkpoint is expected to be a dict holding the weights under 'model'.
    model.set_state_dict(checkpoint['model'])

    rmse_tst, mae_tst, sd_tst, r_tst = evaluate(model, tst_loader)
    print('Test - RMSE: %.6f, MAE: %.6f, SD: %.6f, R: %.6f.\n' % (rmse_tst, mae_tst, sd_tst, r_tst))