In [1]:
import datetime
import torch

from torch_geometric.loader import DataLoader
from gnnexplainer import GNNExplainer
from pgexplainer import PGExplainer

from load_data import load_data
from gnn_dig import load_gnn_model
from train_and_evaluate_dig import train, test
from explanation_analysis import generate_explanation_dig, save_edge_masks_dig
from utils import set_seed, save_model, load_model, manipulate_dataset, parse_args

import warnings
# Install a blanket "ignore" filter so that no warning text (from any category
# or module) is shown in the cell outputs below.
# NOTE(review): this also hides deprecation and numerical warnings raised by
# the libraries used here -- consider narrowing the filter when debugging.
warnings.filterwarnings('ignore')

In [2]:
# Start from the options returned by parse_args(), then pin the values used
# for this particular run.
args = parse_args()
# args.fast_mode = True  # uncomment to enable the shortened explainer schedule
args.gpu_id = 0
args.dataset = "BA_2motifs"
args.manipulate_ratio = 1.
args.gnn_model = "GIN_3l_BN"
args.explainer_model = "PGExplainer"

# In fast mode, PGExplainer is limited to 10 explainer epochs.
if args.explainer_model == "PGExplainer" and args.fast_mode:
    args.ex_epoch = 10

# Explanations are always saved when the selected dataset is Mutagenicity
# (the dataset name is compared case-insensitively).
if "Mutagenicity".lower() == args.dataset.lower():
    args.save_explanation = True

# Use the configured GPU when CUDA is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:%s' % args.gpu_id)
else:
    device = torch.device('cpu')

# Seed before any data loading or model construction happens below.
set_seed(args.seed)

print(args)

Namespace(dataset='BA_2motifs', ex_epoch=100, ex_lr=0.001, explainer_model='PGExplainer', fast_mode=False, gnn_epoch=200, gnn_hid_dim=128, gnn_lr=0.001, gnn_model='GIN_3l_BN', gpu_id=0, manipulate_ratio=1.0, save_explanation=False, seed=0)


In [3]:
# Load the full dataset together with its train / validation / test splits.
dataset, train_dataset, val_dataset, test_dataset = load_data(args=args)

# With a non-zero ratio, pass the training and validation splits through
# manipulate_dataset (label reversal). The test split is left as loaded.
if args.manipulate_ratio != 0:
    train_dataset = manipulate_dataset(
        dataset=train_dataset,
        ratio=args.manipulate_ratio,
    )
    val_dataset = manipulate_dataset(
        dataset=val_dataset,
        ratio=args.manipulate_ratio,
    )

# Only the training loader shuffles; all loaders use batches of 128 graphs.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=128)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=128)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=128)


Dataset: BA2MotifDataset(1000):
Number of graphs: 1000
Number of node features: 10
Number of edge features: 0
Number of classes: 2
N. train: 700, N. valid: 100, N. test: 200


In [4]:
# Build the GNN selected in `args` for this dataset and show its layout.
model = load_gnn_model(dataset=dataset, args=args)
print(model)

# Classification loss, plus Adam with the configured learning rate and a
# fixed weight decay of 5e-4.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=args.gnn_lr,
    weight_decay=5e-4,
)

GIN_3l_BN(
  (conv1): GINConv(nn=Sequential(
    (0): Linear(in_features=10, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): ReLU()
  ))
  (convs): ModuleList(
    (0-1): 2 x GINConv(nn=Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): ReLU()
    ))
  )
  (relu1): Sequential(
    (0): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): ReLU()
  )
  (relus): ModuleList(
    (0-1): 2 x Sequential(
      (0): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): ReLU()
    )
  )
  (readout): GlobalMeanPool()
  (ffn): Sequential(
    (0): Linear(in_features=128, out_features=128, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=128, out_features=2, bias=True)
  )
  (dropout): Dropout(p=0

In [5]:
# Train the GNN for `args.gnn_epoch` epochs, checkpoint the model whenever the
# validation accuracy improves, then reload the best checkpoint and report
# accuracy on the train / validation / test splits.

# Start below any achievable accuracy so that the first epoch always writes a
# checkpoint. With the previous initial value of 0 and the strict `>` test
# below, a run whose validation accuracy stayed at exactly 0.0 never saved a
# model, and load_model() was then asked for a file that was never created.
best_val_acc = -1.0
test_acc = 0

# The checkpoint name is derived from the wall-clock time at which this cell
# starts, so separate runs normally write to separate files.
now = datetime.datetime.now()
best_model_filename = f"best_model_{now.hour}_{now.minute}_{now.second}.pth"

for epoch in range(1, args.gnn_epoch+1):
    loss = train(
        model=model, train_loader=train_loader, optimizer=optimizer, device=device
    )
    train_acc = test(model=model, loader=train_loader, device=device)
    val_acc = test(model=model, loader=val_loader, device=device)
    # Strict improvement only: on ties the earliest best epoch is kept.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        save_model(model=model, filename=best_model_filename)  # Save the current model
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f} Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}')

# Load the best model and evaluate on the test set
load_model(model=model, filename=best_model_filename)
train_acc = test(model=model, loader=train_loader, device=device)
val_acc = test(model=model, loader=val_loader, device=device)
test_acc = test(model=model, loader=test_loader, device=device)
print(f'Train Acc: {train_acc:.4f}, Valid Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')

Epoch: 001, Loss: 0.6936 Train Acc: 0.5029, Val Acc: 0.4900
Epoch: 002, Loss: 0.6939 Train Acc: 0.5029, Val Acc: 0.4900
Epoch: 003, Loss: 0.6942 Train Acc: 0.5029, Val Acc: 0.4900
Epoch: 004, Loss: 0.6931 Train Acc: 0.5857, Val Acc: 0.4800
Epoch: 005, Loss: 0.6926 Train Acc: 0.5029, Val Acc: 0.4900
Epoch: 006, Loss: 0.6913 Train Acc: 0.5029, Val Acc: 0.4900
Epoch: 007, Loss: 0.6919 Train Acc: 0.5729, Val Acc: 0.4700
Epoch: 008, Loss: 0.6926 Train Acc: 0.6257, Val Acc: 0.6000
Epoch: 009, Loss: 0.6881 Train Acc: 0.6400, Val Acc: 0.5300
Epoch: 010, Loss: 0.6870 Train Acc: 0.6114, Val Acc: 0.6100
Epoch: 011, Loss: 0.6828 Train Acc: 0.5043, Val Acc: 0.4900
Epoch: 012, Loss: 0.6773 Train Acc: 0.6543, Val Acc: 0.5800
Epoch: 013, Loss: 0.6677 Train Acc: 0.6814, Val Acc: 0.6100
Epoch: 014, Loss: 0.6341 Train Acc: 0.4971, Val Acc: 0.5100
Epoch: 015, Loss: 0.6323 Train Acc: 0.6657, Val Acc: 0.6000
Epoch: 016, Loss: 0.5713 Train Acc: 0.8243, Val Acc: 0.7900
Epoch: 017, Loss: 0.4766 Train Acc: 0.88

In [6]:
# Put the trained GNN in evaluation mode before building the explainer.
model.eval()

# Reject unknown explainer names up front.
if args.explainer_model not in ("GNNExplainer", "PGExplainer"):
    raise ValueError("%s is not a available explainer model" % args.explainer_model)

if args.explainer_model == "GNNExplainer":
    explainer = GNNExplainer(
        model,
        epochs=args.ex_epoch,
        lr=args.ex_lr,
        explain_graph=True,
    )
else:
    # PGExplainer is moved to the target device and its explanation network
    # is trained on the training split before any explanation is generated.
    explainer = PGExplainer(
        model,
        in_channels=args.gnn_hid_dim*2,
        epochs=args.ex_epoch,
        lr=args.ex_lr,
        explain_graph=True,
        device=device,
    ).to(device)
    print("training PGExplainer")
    explainer.train_explanation_network(train_dataset)

# Generate explanations for the test split; the returned collector exposes
# the fidelity / fidelity_inv / sparsity values printed below.
x_collector, ex_list = generate_explanation_dig(
    args=args,
    model=model,
    explainer=explainer,
    dataset=test_dataset,
    device=device,
)

# Persist the edge masks only when requested in the run configuration.
if args.save_explanation:
    save_edge_masks_dig(args=args, ex_list=ex_list)

# Explanation metrics, followed by the accuracies measured after GNN training.
print(
    f'Fidelity: {x_collector.fidelity:.4f}\n'
    f'Fidelity_inv: {x_collector.fidelity_inv:.4f}\n'
    f'Sparsity: {x_collector.sparsity:.4f}'
)
print(f'Train Acc: {train_acc:.4f}, Valid Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')

training PGExplainer


100%|██████████| 700/700 [00:03<00:00, 181.80it/s]
100%|██████████| 700/700 [00:07<00:00, 88.30it/s]


Epoch: 0 | Loss: 1283.1314597129822


100%|██████████| 700/700 [00:07<00:00, 88.76it/s]


Epoch: 1 | Loss: 1259.301480859518


100%|██████████| 700/700 [00:07<00:00, 88.86it/s]


Epoch: 2 | Loss: 1247.1258131563663


100%|██████████| 700/700 [00:07<00:00, 88.90it/s]


Epoch: 3 | Loss: 1223.521841943264


100%|██████████| 700/700 [00:07<00:00, 93.21it/s]


Epoch: 4 | Loss: 1206.6200764775276


100%|██████████| 700/700 [00:07<00:00, 95.51it/s]


Epoch: 5 | Loss: 1184.319948360324


100%|██████████| 700/700 [00:07<00:00, 88.57it/s]


Epoch: 6 | Loss: 1168.714027017355


100%|██████████| 700/700 [00:07<00:00, 88.76it/s]


Epoch: 7 | Loss: 1154.4424302726984


100%|██████████| 700/700 [00:07<00:00, 88.86it/s]


Epoch: 8 | Loss: 1139.9920891970396


100%|██████████| 700/700 [00:07<00:00, 89.10it/s]


Epoch: 9 | Loss: 1126.943503215909


100%|██████████| 700/700 [00:07<00:00, 89.05it/s]


Epoch: 10 | Loss: 1112.5348004698753


100%|██████████| 700/700 [00:07<00:00, 89.01it/s]


Epoch: 11 | Loss: 1103.136912599206


100%|██████████| 700/700 [00:07<00:00, 87.87it/s]


Epoch: 12 | Loss: 1092.232505902648


100%|██████████| 700/700 [00:08<00:00, 87.30it/s]


Epoch: 13 | Loss: 1083.4716954827309


100%|██████████| 700/700 [00:08<00:00, 87.42it/s]


Epoch: 14 | Loss: 1075.9405370950699


100%|██████████| 700/700 [00:07<00:00, 87.86it/s]


Epoch: 15 | Loss: 1064.563217997551


100%|██████████| 700/700 [00:07<00:00, 88.21it/s]


Epoch: 16 | Loss: 1060.7282246053219


100%|██████████| 700/700 [00:07<00:00, 88.20it/s]


Epoch: 17 | Loss: 1058.1052361875772


100%|██████████| 700/700 [00:07<00:00, 88.03it/s]


Epoch: 18 | Loss: 1049.612675562501


100%|██████████| 700/700 [00:07<00:00, 88.06it/s]


Epoch: 19 | Loss: 1042.0073565393686


100%|██████████| 700/700 [00:07<00:00, 87.96it/s]


Epoch: 20 | Loss: 1036.8068566918373


100%|██████████| 700/700 [00:07<00:00, 87.85it/s]


Epoch: 21 | Loss: 1032.5767567455769


100%|██████████| 700/700 [00:08<00:00, 83.53it/s]


Epoch: 22 | Loss: 1025.7818055301905


100%|██████████| 700/700 [00:09<00:00, 76.97it/s]


Epoch: 23 | Loss: 1025.0290607959032


100%|██████████| 700/700 [00:09<00:00, 76.96it/s]


Epoch: 24 | Loss: 1021.5198516398668


100%|██████████| 700/700 [00:09<00:00, 77.02it/s]


Epoch: 25 | Loss: 1019.9152511954308


100%|██████████| 700/700 [00:09<00:00, 76.91it/s]


Epoch: 26 | Loss: 1014.2341841608286


100%|██████████| 700/700 [00:09<00:00, 76.90it/s]


Epoch: 27 | Loss: 1010.0257352292538


100%|██████████| 700/700 [00:08<00:00, 82.21it/s]


Epoch: 28 | Loss: 1009.8767112195492


100%|██████████| 700/700 [00:07<00:00, 88.56it/s]


Epoch: 29 | Loss: 1004.942217707634


100%|██████████| 700/700 [00:07<00:00, 87.75it/s]


Epoch: 30 | Loss: 1002.132453083992


100%|██████████| 700/700 [00:07<00:00, 87.89it/s]


Epoch: 31 | Loss: 998.1884600818157


100%|██████████| 700/700 [00:07<00:00, 87.68it/s]


Epoch: 32 | Loss: 994.7130023390055


100%|██████████| 700/700 [00:07<00:00, 87.87it/s]


Epoch: 33 | Loss: 989.3069160431623


100%|██████████| 700/700 [00:07<00:00, 87.87it/s]


Epoch: 34 | Loss: 992.6620364487171


100%|██████████| 700/700 [00:07<00:00, 87.84it/s]


Epoch: 35 | Loss: 986.1289637237787


100%|██████████| 700/700 [00:07<00:00, 87.91it/s]


Epoch: 36 | Loss: 982.8520016521215


100%|██████████| 700/700 [00:08<00:00, 83.38it/s]


Epoch: 37 | Loss: 985.1311707645655


100%|██████████| 700/700 [00:07<00:00, 88.17it/s]


Epoch: 38 | Loss: 978.5982787460089


100%|██████████| 700/700 [00:07<00:00, 88.31it/s]


Epoch: 39 | Loss: 977.1664851903915


100%|██████████| 700/700 [00:07<00:00, 88.17it/s]


Epoch: 40 | Loss: 976.8341814428568


100%|██████████| 700/700 [00:07<00:00, 88.19it/s]


Epoch: 41 | Loss: 975.2293108403683


100%|██████████| 700/700 [00:07<00:00, 88.29it/s]


Epoch: 42 | Loss: 968.5541907995939


100%|██████████| 700/700 [00:07<00:00, 88.04it/s]


Epoch: 43 | Loss: 968.0839033573866


100%|██████████| 700/700 [00:07<00:00, 88.02it/s]


Epoch: 44 | Loss: 966.1660152077675


100%|██████████| 700/700 [00:07<00:00, 87.79it/s]


Epoch: 45 | Loss: 965.8425975441933


100%|██████████| 700/700 [00:07<00:00, 87.95it/s]


Epoch: 46 | Loss: 959.139415949583


100%|██████████| 700/700 [00:07<00:00, 88.00it/s]


Epoch: 47 | Loss: 959.0613195300102


100%|██████████| 700/700 [00:07<00:00, 87.97it/s]


Epoch: 48 | Loss: 957.0862386524677


100%|██████████| 700/700 [00:07<00:00, 87.97it/s]


Epoch: 49 | Loss: 952.1943862885237


100%|██████████| 700/700 [00:07<00:00, 87.96it/s]


Epoch: 50 | Loss: 950.0863481089473


100%|██████████| 700/700 [00:07<00:00, 87.98it/s]


Epoch: 51 | Loss: 950.8759074360132


100%|██████████| 700/700 [00:07<00:00, 88.03it/s]


Epoch: 52 | Loss: 949.1204069256783


100%|██████████| 700/700 [00:07<00:00, 87.97it/s]


Epoch: 53 | Loss: 947.904205262661


100%|██████████| 700/700 [00:07<00:00, 88.05it/s]


Epoch: 54 | Loss: 944.2169529870152


100%|██████████| 700/700 [00:07<00:00, 87.96it/s]


Epoch: 55 | Loss: 943.1463708505034


100%|██████████| 700/700 [00:07<00:00, 88.01it/s]


Epoch: 56 | Loss: 940.7397491708398


100%|██████████| 700/700 [00:07<00:00, 88.65it/s]


Epoch: 57 | Loss: 937.6944184303284


100%|██████████| 700/700 [00:08<00:00, 80.44it/s]


Epoch: 58 | Loss: 939.3224634826183


100%|██████████| 700/700 [00:08<00:00, 85.82it/s]


Epoch: 59 | Loss: 934.601231880486


100%|██████████| 700/700 [00:07<00:00, 88.54it/s]


Epoch: 60 | Loss: 932.4049851149321


100%|██████████| 700/700 [00:07<00:00, 88.53it/s]


Epoch: 61 | Loss: 934.8224903196096


100%|██████████| 700/700 [00:07<00:00, 88.46it/s]


Epoch: 62 | Loss: 930.6418070346117


100%|██████████| 700/700 [00:07<00:00, 88.47it/s]


Epoch: 63 | Loss: 931.3982972279191


100%|██████████| 700/700 [00:07<00:00, 88.96it/s]


Epoch: 64 | Loss: 931.5238173082471


100%|██████████| 700/700 [00:07<00:00, 88.53it/s]


Epoch: 65 | Loss: 928.8769996017218


100%|██████████| 700/700 [00:07<00:00, 88.51it/s]


Epoch: 66 | Loss: 925.2079020291567


100%|██████████| 700/700 [00:07<00:00, 88.48it/s]


Epoch: 67 | Loss: 923.8571067452431


100%|██████████| 700/700 [00:07<00:00, 88.61it/s]


Epoch: 68 | Loss: 922.437731936574


100%|██████████| 700/700 [00:07<00:00, 88.71it/s]


Epoch: 69 | Loss: 920.6525191217661


100%|██████████| 700/700 [00:07<00:00, 88.06it/s]


Epoch: 70 | Loss: 915.8125018700957


100%|██████████| 700/700 [00:07<00:00, 87.89it/s]


Epoch: 71 | Loss: 915.8664417862892


100%|██████████| 700/700 [00:07<00:00, 87.88it/s]


Epoch: 72 | Loss: 914.8006425648928


100%|██████████| 700/700 [00:07<00:00, 88.04it/s]


Epoch: 73 | Loss: 909.6516040861607


100%|██████████| 700/700 [00:07<00:00, 87.82it/s]


Epoch: 74 | Loss: 906.1255554780364


100%|██████████| 700/700 [00:07<00:00, 87.66it/s]


Epoch: 75 | Loss: 905.1694183275104


100%|██████████| 700/700 [00:07<00:00, 87.80it/s]


Epoch: 76 | Loss: 904.6280265897512


100%|██████████| 700/700 [00:08<00:00, 84.76it/s]


Epoch: 77 | Loss: 903.0168069750071


100%|██████████| 700/700 [00:07<00:00, 87.64it/s]


Epoch: 78 | Loss: 899.7734067589045


100%|██████████| 700/700 [00:07<00:00, 88.51it/s]


Epoch: 79 | Loss: 897.9345528781414


100%|██████████| 700/700 [00:07<00:00, 88.43it/s]


Epoch: 80 | Loss: 896.1013332083821


100%|██████████| 700/700 [00:07<00:00, 88.42it/s]


Epoch: 81 | Loss: 894.1942103430629


100%|██████████| 700/700 [00:07<00:00, 88.81it/s]


Epoch: 82 | Loss: 891.5017141997814


100%|██████████| 700/700 [00:07<00:00, 88.94it/s]


Epoch: 83 | Loss: 891.4996096715331


100%|██████████| 700/700 [00:08<00:00, 87.03it/s]


Epoch: 84 | Loss: 887.3829964548349


100%|██████████| 700/700 [00:09<00:00, 76.97it/s]


Epoch: 85 | Loss: 885.8313204273582


100%|██████████| 700/700 [00:08<00:00, 78.65it/s]


Epoch: 86 | Loss: 884.6600716263056


100%|██████████| 700/700 [00:07<00:00, 88.56it/s]


Epoch: 87 | Loss: 883.3944241628051


100%|██████████| 700/700 [00:07<00:00, 88.46it/s]


Epoch: 88 | Loss: 881.5626002922654


100%|██████████| 700/700 [00:07<00:00, 88.55it/s]


Epoch: 89 | Loss: 880.2914931252599


100%|██████████| 700/700 [00:07<00:00, 88.54it/s]


Epoch: 90 | Loss: 878.2455834895372


100%|██████████| 700/700 [00:07<00:00, 88.77it/s]


Epoch: 91 | Loss: 877.3840539976954


100%|██████████| 700/700 [00:07<00:00, 96.97it/s]


Epoch: 92 | Loss: 875.9399463459849


100%|██████████| 700/700 [00:07<00:00, 97.92it/s]


Epoch: 93 | Loss: 874.1190492063761


100%|██████████| 700/700 [00:07<00:00, 92.16it/s]


Epoch: 94 | Loss: 873.1512119919062


100%|██████████| 700/700 [00:07<00:00, 88.69it/s]


Epoch: 95 | Loss: 873.402904137969


100%|██████████| 700/700 [00:07<00:00, 88.69it/s]


Epoch: 96 | Loss: 870.9827398136258


100%|██████████| 700/700 [00:07<00:00, 88.60it/s]


Epoch: 97 | Loss: 870.6921853721142


100%|██████████| 700/700 [00:07<00:00, 88.46it/s]


Epoch: 98 | Loss: 870.151337146759


100%|██████████| 700/700 [00:07<00:00, 88.15it/s]


Epoch: 99 | Loss: 869.7161606401205


Generate explanation for each data: 100%|██████████| 200/200 [00:03<00:00, 55.35it/s]

Fidelity: 0.0029
Fidelity_inv: 0.4645
Sparsity: 0.7678
Train Acc: 0.9943, Valid Acc: 1.0000, Test Acc: 0.0150





In [7]:
# Attach the explanation metrics and accuracies (rounded to four decimals) to
# `args`, so that one record holds both the configuration and the results.
args.fidelity, args.fidelity_inv, args.sparsity = (
    round(x_collector.fidelity, 4),
    round(x_collector.fidelity_inv, 4),
    round(x_collector.sparsity, 4),
)
args.train_acc, args.val_acc, args.test_acc = (
    round(train_acc, 4),
    round(val_acc, 4),
    round(test_acc, 4),
)

# Append the record to the running log: a newline, then the text form of args.
with open('result_record.txt', 'a') as file:
    file.write('\n' + str(args))