# accValid.py
import pdb

import torch
from sklearn.metrics import confusion_matrix, roc_auc_score


def calcAccuracy(modelClassifier, acc_loader, device, logger, num_neg=1, evaluator=None):
    modelClassifier.eval()
    preds = []
    labels = []
    # Gradients are not needed for evaluation; no_grad avoids building
    # autograd graphs across batches.
    with torch.no_grad():
        for batchBA, batchLTL, label, node_num, edge_num, start_node, gid, lSys, lSpec in acc_loader:
            batchBA = batchBA.to(device)
            batchLTL = batchLTL.to(device)
            # for link prediction only:
            # pred = modelClassifier(batchBA, batchLTL)
            pred = modelClassifier(batchBA, node_num, edge_num, start_node, gid, False)
            preds.append(pred)
            labels.append(label)
    # for node voting:
    # preds = torch.cat(preds).squeeze()
    # for the MLP factor, map logits to probabilities:
    preds = torch.cat(preds).squeeze().sigmoid()
    labels = torch.cat(labels)
    acc = ((preds > 0.5) == labels.to(device)).sum().item() / len(labels)
    auc = roc_auc_score(labels, preds.cpu().numpy())
    predBinary = preds > 0.5
    # Extract tn, fp, fn, tp for precision/recall computation.
    tn, fp, fn, tp = confusion_matrix(labels, predBinary.cpu().numpy()).ravel()
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    # Share of predicted negatives that are actually positive (false omission rate).
    fnPercent = fn / (tn + fn)
    # print(f"ACC: {acc:.4f} AUC: {auc:.4f} PRECISION: {precision:.4f} RECALL: {recall:.4f}")
    print(f"ACC: {acc:.4f}")
    logger.info(f"Accuracy: {acc} AUC: {auc} Precision: {precision} Recall: {recall} False Negatives: {fnPercent}")
    if num_neg > 1:
        pdb.set_trace()  # debugging breakpoint for inspecting the multi-negative path
        num_pos = labels.sum().item()
        # Slicing assumes positives come first in the loader, followed by
        # num_neg negatives per positive example.
        pred_pos, pred_neg = preds[:num_pos], preds[num_pos:]
        result_dict = evaluator.eval({"y_pred_pos": pred_pos, "y_pred_neg": pred_neg.view(num_pos, -1)})
        for key, value in result_dict.items():
            print(key, value.mean().item())
    return acc
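

# A minimal usage sketch (commented out): the classifier, the loader factory,
# and the logging setup below are hypothetical stand-ins, assuming acc_loader
# yields (batchBA, batchLTL, label, node_num, edge_num, start_node, gid,
# lSys, lSpec) tuples as calcAccuracy expects.
#
# if __name__ == "__main__":
#     import logging
#
#     logging.basicConfig(level=logging.INFO)
#     logger = logging.getLogger("accValid")
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     model = BAClassifier().to(device)      # hypothetical model class
#     val_loader = buildAccLoader("valid")   # hypothetical loader factory
#     acc = calcAccuracy(model, val_loader, device, logger)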