forked from Yuejiang-li/info-diff
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ip_model.py
83 lines (72 loc) · 3.24 KB
/
ip_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import argparse
import torch
from loader import DataLoader
from model import IPModel
import os
import pickle
from utils.metrics import metrics, tune_thres, get_preds, tune_thres_new
def _str2bool(value):
    """Parse a command-line boolean: 'true'/'1'/'yes'/'y' -> True, else False.

    argparse's ``type=bool`` is a known trap: ``bool('False')`` is ``True``,
    so any non-empty value — including the literal string ``False`` — would
    have enabled CUDA. This parser handles the value correctly.
    """
    if isinstance(value, bool):
        return value
    return value.strip().lower() in ('true', '1', 'yes', 'y')


parser = argparse.ArgumentParser()
# BUG FIX: was ``type=bool``, which mapped '--cuda False' to True; see
# _str2bool above. Flag name and default are unchanged.
parser.add_argument('--cuda', type=_str2bool, default=torch.cuda.is_available())
parser.add_argument('--model', type=str, default='IP')
parser.add_argument('--data_dir', type=str, default='./data/final')
parser.add_argument('--emb_file', type=str, default='./data/content_dict.pkl')
parser.add_argument('--model_save_dir', type=str, default='./data/saved_model/ip/')
parser.add_argument('--idx_dict', type=str, default='./data/final/idx_dict.pkl')
parser.add_argument('--followee_count_file', type=str, default='./data/followee_count.pkl')
parser.add_argument('--window_size', type=int, default=10)
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--seed', type=int, default=99)
args = parser.parse_args()
opt = vars(args)  # plain-dict view of the namespace, shared with loaders/model below
torch.manual_seed(args.seed)
# numpy/random seeding was left disabled in the original; enable if those
# RNGs are used downstream.
# np.random.seed(args.seed)
# random.seed(args.seed)
# Load the weibo-id -> embedding-id mapping produced by preprocessing.
with open(opt['idx_dict'], 'rb') as fin:
    weibo2embid = pickle.load(fin)


def _build_loader(split_name, is_eval):
    """Construct a DataLoader over ``<data_dir>/<split_name>.csv``."""
    csv_path = os.path.join(opt['data_dir'], split_name + '.csv')
    return DataLoader(csv_path,
                      opt['batch_size'],
                      opt,
                      weibo2embid=weibo2embid,
                      evaluation=is_eval)


train_batch = _build_loader('train', False)
dev_batch = _build_loader('dev', True)
test_batch = _build_loader('test', True)

# The IP model is built directly from the retweet probabilities observed in
# the training split (no gradient training loop in this script).
model = IPModel(train_batch.retw_prob, opt)
# ---------------------------------------------------------------------------
# Tune the decision threshold on dev, then evaluate on test.
# ---------------------------------------------------------------------------

# Collect raw prediction probabilities on the dev split (thres=0.0 so no
# thresholding is applied inside predict()).
all_probs = []
for b in dev_batch:
    _, probs = model.predict(b, thres=0.0)
    all_probs += probs
print('max prob: ', max(all_probs))

# Pick the threshold that maximizes the dev-set metric.
_, _, _, _, best_thres = tune_thres_new(dev_batch.gold(), all_probs)  # , start=0.0, end=0.002, fold=1001)
print('Best thres (dev): %.8f' % best_thres)

# Re-collect raw probabilities, this time on the test split.
all_probs = []
for b in test_batch:
    _, probs = model.predict(b, thres=0.0)
    all_probs += probs

# ROBUSTNESS: makedirs(..., exist_ok=True) replaces os.mkdir, which raises
# when the parent directory is missing or when the directory already exists.
os.makedirs(opt['model_save_dir'], exist_ok=True)

preds = get_preds(all_probs, best_thres)
accuracy, precision, recall, f1 = metrics(test_batch.gold(), preds)
# BUG FIX: the AUC call originally passed dev_batch.gold() while all_probs
# holds TEST probabilities at this point (mismatched label/score lists, and
# dev/test generally differ in length); score against the test gold labels.
auc, _, _, _, _ = tune_thres_new(test_batch.gold(), all_probs, opt)
print('Accuracy: %.4f, Precision: %.4f, Recall: %.4f, F1: %.4f' % (accuracy, precision, recall, f1))
# thres_to_test = [0.0, 0.00001, 0.0005, 0.001]
# for thres in thres_to_test:
# preds = get_preds(all_probs, thres)
# accuracy, precision, recall, f1 = metrics(test_batch.gold(), preds)
#
# print('Accuracy: %.4f, Precision: %.4f, Recall: %.4f, F1: %.4f' % (accuracy, precision, recall, f1))
# print('Tuning on test...')
# print('max prob: ', max(all_probs))
# _, _, _, _, best_thres = tune_thres(test_batch.gold(), all_probs, start=0.0, end=0.002, fold=1001)
# print('Best thres (dev): %.8f' % best_thres)
#
# preds = get_preds(all_probs, best_thres)
# accuracy, precision, recall, f1 = metrics(test_batch.gold(), preds)
# print('Accuracy: %.4f, Precision: %.4f, Recall: %.4f, F1: %.4f' % (accuracy, precision, recall, f1))