[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1t0_4BxEJ0XncyYvn_VyEQhxwNMvtSUNx?usp=sharing)

In [1]:
import sys
sys.path.append('../src')

import torch
import torch.nn as nn
from copy import deepcopy
from pathlib import Path
from gsat import GSAT, ExtractorMLP
from utils import get_data_loaders, get_model, set_seed, Criterion, init_metric_dict, load_checkpoint
from trainer import run_one_epoch, update_best_epoch_res, get_viz_idx, visualize_results
from datetime import datetime






In [2]:
dataset_name = 'ogbg_molbace'
# dataset_name = 'mutag'
model_name = 'GIN'

# dataset_name = 'ogbg_molhiv'
# model_name = 'PNA'

method_name = 'GSAT'
cuda_id = -1
seed = 0
set_seed(seed)


In [3]:
data_dir = Path('../data')
device = torch.device(f'cuda:{cuda_id}' if cuda_id >= 0 else 'cpu')

if model_name == 'GIN':
    model_config = {'model_name': 'GIN', 'hidden_size': 64, 'n_layers': 4, 'dropout_p': 0.3, 'use_edge_attr': True}
else:
    assert model_name == 'PNA'
    model_config = {'model_name': 'PNA', 'hidden_size': 80, 'n_layers': 4, 'dropout_p': 0.3, 'use_edge_attr': False, 
                    'atom_encoder': True, 'aggregators': ['mean', 'min', 'max', 'std'], 'scalers': False}

metric_dict = deepcopy(init_metric_dict)
model_dir = data_dir / dataset_name / 'logs' / (datetime.now().strftime("%m_%d_%Y-%H_%M_%S") + '-' + dataset_name + '-' + model_name + '-seed' + str(seed) + '-' + method_name)


In [4]:
batch_size = 128
loaders, test_set, x_dim, edge_attr_dim, num_class, aux_info = get_data_loaders(data_dir, dataset_name, batch_size=batch_size, random_state=seed,
                                                                                splits={'train': 0.8, 'valid': 0.1, 'test': 0.1}, 
                                                                                mutag_x=True if dataset_name == 'mutag' else False)
model_config['deg'] = aux_info['deg']
print ('train/val/test split:')
print (len(loaders['train'])*batch_size,len(loaders['valid'])*batch_size,len(loaders['test'])*batch_size)



[INFO] Using train splits!
[INFO] Calculating degree...
train/val/test split:
896 384 384




In [5]:
clf = get_model(x_dim, edge_attr_dim, num_class, aux_info['multi_label'], model_config, device)
extractor = ExtractorMLP(model_config['hidden_size'], learn_edge_att=False).to(device)
optimizer = torch.optim.Adam(list(extractor.parameters()) + list(clf.parameters()), lr=1e-3, weight_decay=3.0e-6)
criterion = Criterion(num_class, aux_info['multi_label'])
gsat = GSAT(clf, extractor, criterion, optimizer, learn_edge_att=False, final_r=0.7)

[INFO] Using multi_label: False


In [6]:
for epoch in range(80):
    train_res = run_one_epoch(gsat, loaders['train'], epoch, 'train', dataset_name, seed, model_config['use_edge_attr'], aux_info['multi_label'])
    valid_res = run_one_epoch(gsat, loaders['valid'], epoch, 'valid', dataset_name, seed, model_config['use_edge_attr'], aux_info['multi_label'])
    test_res = run_one_epoch(gsat, loaders['test'], epoch, 'test', dataset_name, seed, model_config['use_edge_attr'], aux_info['multi_label'])
    
    metric_dict = update_best_epoch_res(gsat, train_res, valid_res, test_res, metric_dict, dataset_name, epoch, model_dir)
    print(f'[Seed {seed}, Epoch: {epoch}]: Best Epoch: {metric_dict["metric/best_clf_epoch"]}, '
          f'Best Val Pred ACC/ROC: {metric_dict["metric/best_clf_valid"]:.3f}, Best Test Pred ACC/ROC: {metric_dict["metric/best_clf_test"]:.3f}, '
          f'Best Test X AUROC: {metric_dict["metric/best_x_roc_test"]:.3f}')
    print('='*50)
    print('='*50)

[Seed 0, Epoch: 0]: gsat_train finished, loss: 2.177, pred: 1.496, info: 0.681, clf_acc: 0.503, clf_roc: 0.474, att_roc: 0.000: 100%|██████████| 7/7 [00:01<00:00,  4.11it/s]
[Seed 0, Epoch: 0]: gsat_valid finished, loss: 1.211, pred: 0.763, info: 0.447, clf_acc: 0.631, clf_roc: 0.348, att_roc: 0.000: 100%|██████████| 3/3 [00:00<00:00,  9.63it/s]
[Seed 0, Epoch: 0]: gsat_test  finished, loss: 1.211, pred: 0.763, info: 0.447, clf_acc: 0.631, clf_roc: 0.348, att_roc: 0.000: 100%|██████████| 3/3 [00:00<00:00, 15.35it/s]


[Seed 0, Epoch: 0]: Best Epoch: 0, Best Val Pred ACC/ROC: 0.000, Best Test Pred ACC/ROC: 0.000, Best Test X AUROC: 0.000


[Seed 0, Epoch: 1]: gsat_train finished, loss: 1.571, pred: 0.934, info: 0.636, clf_acc: 0.484, clf_roc: 0.490, att_roc: 0.000: 100%|██████████| 7/7 [00:01<00:00,  5.41it/s]
[Seed 0, Epoch: 1]: gsat_valid finished, loss: 1.078, pred: 0.684, info: 0.394, clf_acc: 0.534, clf_roc: 0.558, att_roc: 0.000: 100%|██████████| 3/3 [00:00<00:00, 14.30it/s]
[Seed 0, Epoch: 1]: gsat_test  finished, loss: 1.078, pred: 0.684, info: 0.394, clf_acc: 0.534, clf_roc: 0.558, att_roc: 0.000: 100%|██████████| 3/3 [00:00<00:00, 14.87it/s]


[Seed 0, Epoch: 1]: Best Epoch: 0, Best Val Pred ACC/ROC: 0.000, Best Test Pred ACC/ROC: 0.000, Best Test X AUROC: 0.000


[Seed 0, Epoch: 2]: gsat_train finished, loss: 1.356, pred: 0.760, info: 0.596, clf_acc: 0.560, clf_roc: 0.503, att_roc: 0.000: 100%|██████████| 7/7 [00:01<00:00,  5.34it/s]
[Seed 0, Epoch: 2]: gsat_valid finished, loss: 0.992, pred: 0.672, info: 0.320, clf_acc: 0.631, clf_roc: 0.484, att_roc: 0.000: 100%|██████████| 3/3 [00:00<00:00, 13.36it/s]
[Seed 0, Epoch: 2]: gsat_test  finished, loss: 0.992, pred: 0.672, info: 0.320, clf_acc: 0.631, clf_roc: 0.484, att_roc: 0.000: 100%|██████████| 3/3 [00:00<00:00, 11.62it/s]


[Seed 0, Epoch: 2]: Best Epoch: 0, Best Val Pred ACC/ROC: 0.000, Best Test Pred ACC/ROC: 0.000, Best Test X AUROC: 0.000


[Seed 0, Epoch: 3]: gsat_train........., loss: 1.298, pred: 0.731, info: 0.567, clf_acc: 0.547:  29%|██▊       | 2/7 [00:00<00:01,  2.82it/s]


KeyboardInterrupt: 

In [None]:
best_epoch = metric_dict['metric/best_clf_epoch']
load_checkpoint(gsat, model_dir, model_name=f'gsat_epoch_{best_epoch}', map_location=device)

## get pred node label

In [95]:
from torch_geometric.loader import DataLoader
from utils import process_data, get_preds, save_checkpoint
from torch_geometric.explain.metric import groundtruth_metrics,fidelity
from torchmetrics import AUROC
from torchmetrics.classification import BinaryF1Score
from tqdm import tqdm
from trainer import eval_one_batch


data_loader_single = DataLoader(loaders['test'].dataset, batch_size=1, shuffle=False)
if num_class>=3:
    auc_func = AUROC(task='multiclass',num_classes=num_class).to(gsat.device)
else:
    auc_func = AUROC(task='binary').to(gsat.device)
    


In [107]:
topK=5
gt = []
y_hat_preds = []
y_hat_removal_preds = []
for batch in tqdm(data_loader_single):
        batch = batch.to(gsat.device)
        att, loss_dict, clf_logits = eval_one_batch(gsat,batch,500)
        if num_class<=2:
            y = batch.y.view(-1,)
            y_hat = torch.sigmoid(clf_logits).detach().cpu().view(-1,)
            new_edge_index,new_edge_attr = remove_topk_edges(att,batch,topK)
            batch.edge_index = new_edge_index
            batch.edge_attr = new_edge_attr
            att, loss_dict, clf_logits = eval_one_batch(gsat,batch,500)
            y_hat_removal = torch.sigmoid(clf_logits).detach().cpu().view(-1,)
            print (y,y_hat,y_hat_removal)
            gt.append(y)
            y_hat_preds.append(y_hat)
            y_hat_removal_preds.append(y_hat_removal)
        else:
            y = batch.y.view(-1,)
            y_hat = torch.softmax(clf_logits,dim=-1).detach().cpu().view(1,-1)
            new_edge_index,new_edge_attr = remove_topk_edges(att,batch,topK)
            batch.edge_index = new_edge_index
            batch.edge_attr = new_edge_attr
            att, loss_dict, clf_logits = eval_one_batch(gsat,batch,500)
            y_hat_removal = torch.softmax(clf_logits,dim=-1).detach().cpu().view(1,-1)
            # print (y,y_hat,y_hat_removal)
            gt.append(y)
            y_hat_preds.append(y_hat)
            y_hat_removal_preds.append(y_hat_removal)
            

gt = torch.cat(gt,dim=0)
y_hat_preds = torch.cat(y_hat_preds,dim=0)
y_hat_removal_preds = torch.cat(y_hat_removal_preds,dim=0)
print('Groundtruth Metrics')
auc1 = auc_func(y_hat_preds,gt)
auc2 = auc_func(y_hat_removal_preds,gt)
print (auc1,auc2)

    
    

  5%|▍         | 17/363 [00:00<00:04, 78.95it/s]

tensor([1]) tensor([0.7012]) tensor([0.5193])
tensor([1]) tensor([0.9463]) tensor([0.8453])
tensor([0]) tensor([0.4792]) tensor([0.2523])
tensor([0]) tensor([0.1554]) tensor([0.2044])
tensor([1]) tensor([0.1905]) tensor([0.1584])
tensor([0]) tensor([0.2276]) tensor([0.1314])
tensor([1]) tensor([0.7034]) tensor([0.3056])
tensor([1]) tensor([0.9606]) tensor([0.8395])
tensor([0]) tensor([0.1266]) tensor([0.1781])
tensor([0]) tensor([0.0520]) tensor([0.0693])
tensor([0]) tensor([0.9561]) tensor([0.8598])
tensor([1]) tensor([0.7075]) tensor([0.3423])
tensor([0]) tensor([0.1693]) tensor([0.2359])
tensor([1]) tensor([0.2290]) tensor([0.0899])
tensor([1]) tensor([0.8704]) tensor([0.7149])
tensor([1]) tensor([0.7045]) tensor([0.5576])
tensor([1]) tensor([0.9377]) tensor([0.6949])


  7%|▋         | 25/363 [00:00<00:05, 60.02it/s]

tensor([0]) tensor([0.8888]) tensor([0.6602])
tensor([0]) tensor([0.5075]) tensor([0.6993])
tensor([0]) tensor([0.1807]) tensor([0.5372])
tensor([0]) tensor([0.2342]) tensor([0.1381])
tensor([0]) tensor([0.0709]) tensor([0.1053])
tensor([1]) tensor([0.8102]) tensor([0.6903])
tensor([1]) tensor([0.9295]) tensor([0.7868])
tensor([0]) tensor([0.0296]) tensor([0.0142])
tensor([0]) tensor([0.1679]) tensor([0.5062])
tensor([0]) tensor([0.3740]) tensor([0.1054])


 11%|█         | 39/363 [00:00<00:05, 60.06it/s]

tensor([0]) tensor([0.9416]) tensor([0.7278])
tensor([0]) tensor([0.3432]) tensor([0.8141])
tensor([0]) tensor([0.3901]) tensor([0.2845])
tensor([0]) tensor([0.4255]) tensor([0.2118])
tensor([1]) tensor([0.9675]) tensor([0.8407])
tensor([1]) tensor([0.2181]) tensor([0.1415])
tensor([0]) tensor([0.1640]) tensor([0.0893])
tensor([1]) tensor([0.6586]) tensor([0.4199])
tensor([0]) tensor([0.5953]) tensor([0.5978])
tensor([0]) tensor([0.0501]) tensor([0.1137])
tensor([1]) tensor([0.9198]) tensor([0.6681])
tensor([1]) tensor([0.3146]) tensor([0.1425])


 15%|█▌        | 55/363 [00:00<00:04, 68.27it/s]

tensor([0]) tensor([0.0639]) tensor([0.0488])
tensor([1]) tensor([0.9224]) tensor([0.7968])
tensor([0]) tensor([0.1359]) tensor([0.2207])
tensor([0]) tensor([0.0982]) tensor([0.2513])
tensor([0]) tensor([0.2535]) tensor([0.0989])
tensor([0]) tensor([0.1475]) tensor([0.1993])
tensor([1]) tensor([0.7881]) tensor([0.6532])
tensor([0]) tensor([0.0853]) tensor([0.4381])
tensor([0]) tensor([0.1843]) tensor([0.1457])
tensor([0]) tensor([0.2267]) tensor([0.2286])
tensor([0]) tensor([0.1094]) tensor([0.0733])
tensor([1]) tensor([0.4826]) tensor([0.6462])
tensor([0]) tensor([0.2181]) tensor([0.0806])
tensor([0]) tensor([0.7356]) tensor([0.6804])
tensor([0]) tensor([0.3831]) tensor([0.4043])
tensor([0]) tensor([0.0228]) tensor([0.0525])


 19%|█▉        | 69/363 [00:01<00:04, 66.94it/s]

tensor([1]) tensor([0.8240]) tensor([0.5057])
tensor([0]) tensor([0.1844]) tensor([0.6918])
tensor([0]) tensor([0.1004]) tensor([0.5656])
tensor([1]) tensor([0.4873]) tensor([0.4266])
tensor([1]) tensor([0.6603]) tensor([0.6408])
tensor([0]) tensor([0.1206]) tensor([0.1081])
tensor([1]) tensor([0.8228]) tensor([0.6847])
tensor([0]) tensor([0.1657]) tensor([0.1482])
tensor([0]) tensor([0.2394]) tensor([0.1442])
tensor([1]) tensor([0.8550]) tensor([0.5466])
tensor([0]) tensor([0.8230]) tensor([0.5192])
tensor([1]) tensor([0.1783]) tensor([0.7089])
tensor([1]) tensor([0.4789]) tensor([0.0897])
tensor([1]) tensor([0.8474]) tensor([0.5646])


 23%|██▎       | 84/363 [00:01<00:04, 69.36it/s]

tensor([0]) tensor([0.5221]) tensor([0.2082])
tensor([0]) tensor([0.0152]) tensor([0.0166])
tensor([0]) tensor([0.1922]) tensor([0.2527])
tensor([1]) tensor([0.9083]) tensor([0.8504])
tensor([0]) tensor([0.0119]) tensor([0.0155])
tensor([1]) tensor([0.6171]) tensor([0.2021])
tensor([1]) tensor([0.8622]) tensor([0.5865])
tensor([0]) tensor([0.2797]) tensor([0.2555])
tensor([0]) tensor([0.0378]) tensor([0.1110])
tensor([0]) tensor([0.0937]) tensor([0.3189])
tensor([0]) tensor([0.1866]) tensor([0.0845])
tensor([1]) tensor([0.6201]) tensor([0.4296])
tensor([0]) tensor([0.9015]) tensor([0.6488])
tensor([1]) tensor([0.9094]) tensor([0.6732])
tensor([0]) tensor([0.1630]) tensor([0.1956])


 25%|██▌       | 92/363 [00:01<00:03, 70.55it/s]

tensor([0]) tensor([0.0832]) tensor([0.1342])
tensor([1]) tensor([0.8773]) tensor([0.8045])
tensor([1]) tensor([0.9450]) tensor([0.8146])
tensor([0]) tensor([0.0879]) tensor([0.2024])
tensor([0]) tensor([0.8300]) tensor([0.5941])
tensor([1]) tensor([0.8454]) tensor([0.7220])
tensor([0]) tensor([0.9414]) tensor([0.7018])
tensor([1]) tensor([0.5662]) tensor([0.5362])
tensor([0]) tensor([0.2898]) tensor([0.1347])
tensor([1]) tensor([0.8372]) tensor([0.4899])
tensor([0]) tensor([0.3023]) tensor([0.3133])
tensor([0]) tensor([0.0877]) tensor([0.5489])
tensor([0]) tensor([0.3876]) tensor([0.4390])
tensor([1]) tensor([0.3737]) tensor([0.3698])


 29%|██▉       | 107/363 [00:01<00:03, 64.56it/s]

tensor([0]) tensor([0.0391]) tensor([0.0334])
tensor([0]) tensor([0.0505]) tensor([0.0302])
tensor([1]) tensor([0.2584]) tensor([0.3983])
tensor([1]) tensor([0.7491]) tensor([0.3512])
tensor([0]) tensor([0.1090]) tensor([0.1920])
tensor([0]) tensor([0.3133]) tensor([0.1535])
tensor([0]) tensor([0.4362]) tensor([0.3424])
tensor([0]) tensor([0.1667]) tensor([0.1034])
tensor([0]) tensor([0.1387]) tensor([0.1067])
tensor([0]) tensor([0.2150]) tensor([0.1277])
tensor([0]) tensor([0.2161]) tensor([0.6021])
tensor([1]) tensor([0.5140]) tensor([0.3113])
tensor([0]) tensor([0.0644]) tensor([0.1409])


 33%|███▎      | 121/363 [00:01<00:04, 57.02it/s]

tensor([0]) tensor([0.4610]) tensor([0.1406])
tensor([1]) tensor([0.7963]) tensor([0.5238])
tensor([1]) tensor([0.0842]) tensor([0.0444])
tensor([0]) tensor([0.0690]) tensor([0.2478])
tensor([0]) tensor([0.1440]) tensor([0.1109])
tensor([0]) tensor([0.1158]) tensor([0.4260])
tensor([1]) tensor([0.2880]) tensor([0.2452])
tensor([1]) tensor([0.8885]) tensor([0.6615])
tensor([0]) tensor([0.5720]) tensor([0.2337])
tensor([1]) tensor([0.5463]) tensor([0.3572])


 37%|███▋      | 135/363 [00:02<00:03, 60.56it/s]

tensor([1]) tensor([0.2979]) tensor([0.0732])
tensor([0]) tensor([0.5401]) tensor([0.1832])
tensor([0]) tensor([0.0939]) tensor([0.1936])
tensor([1]) tensor([0.7283]) tensor([0.4798])
tensor([0]) tensor([0.3185]) tensor([0.1891])
tensor([0]) tensor([0.2580]) tensor([0.0753])
tensor([0]) tensor([0.0999]) tensor([0.3145])
tensor([1]) tensor([0.0616]) tensor([0.0221])
tensor([0]) tensor([0.5397]) tensor([0.3007])
tensor([0]) tensor([0.7772]) tensor([0.4165])
tensor([0]) tensor([0.2560]) tensor([0.2337])
tensor([0]) tensor([0.6000]) tensor([0.3354])
tensor([1]) tensor([0.9091]) tensor([0.6345])
tensor([1]) tensor([0.7984]) tensor([0.4500])


 41%|████      | 149/363 [00:02<00:03, 62.25it/s]

tensor([1]) tensor([0.4313]) tensor([0.2268])
tensor([1]) tensor([0.9402]) tensor([0.7264])
tensor([0]) tensor([0.1291]) tensor([0.6361])
tensor([0]) tensor([0.1981]) tensor([0.5450])
tensor([0]) tensor([0.0760]) tensor([0.1857])
tensor([1]) tensor([0.8189]) tensor([0.5712])
tensor([1]) tensor([0.9353]) tensor([0.7633])
tensor([0]) tensor([0.3616]) tensor([0.1365])
tensor([1]) tensor([0.9112]) tensor([0.6464])
tensor([0]) tensor([0.4050]) tensor([0.2796])
tensor([0]) tensor([0.0620]) tensor([0.0975])
tensor([1]) tensor([0.2411]) tensor([0.4385])
tensor([0]) tensor([0.0174]) tensor([0.0389])
tensor([0]) tensor([0.0855]) tensor([0.1140])


 45%|████▍     | 163/363 [00:02<00:03, 64.09it/s]

tensor([0]) tensor([0.1581]) tensor([0.2187])
tensor([1]) tensor([0.8608]) tensor([0.5250])
tensor([0]) tensor([0.0527]) tensor([0.0339])
tensor([0]) tensor([0.6417]) tensor([0.3388])
tensor([0]) tensor([0.0384]) tensor([0.3249])
tensor([0]) tensor([0.3845]) tensor([0.3008])
tensor([1]) tensor([0.0457]) tensor([0.2915])
tensor([0]) tensor([0.1737]) tensor([0.1435])
tensor([0]) tensor([0.0499]) tensor([0.0914])
tensor([0]) tensor([0.0524]) tensor([0.0811])
tensor([0]) tensor([0.5710]) tensor([0.3149])
tensor([0]) tensor([0.2149]) tensor([0.0908])
tensor([0]) tensor([0.0955]) tensor([0.5588])
tensor([1]) tensor([0.3557]) tensor([0.2976])


 49%|████▉     | 178/363 [00:02<00:02, 67.72it/s]

tensor([1]) tensor([0.7597]) tensor([0.5620])
tensor([0]) tensor([0.7031]) tensor([0.2983])
tensor([0]) tensor([0.1185]) tensor([0.2015])
tensor([0]) tensor([0.9009]) tensor([0.6961])
tensor([0]) tensor([0.2771]) tensor([0.2374])
tensor([0]) tensor([0.2282]) tensor([0.0574])
tensor([0]) tensor([0.2752]) tensor([0.7634])
tensor([0]) tensor([0.1327]) tensor([0.2494])
tensor([0]) tensor([0.2890]) tensor([0.1250])
tensor([1]) tensor([0.9357]) tensor([0.7439])
tensor([0]) tensor([0.8021]) tensor([0.4613])
tensor([1]) tensor([0.4856]) tensor([0.2498])
tensor([0]) tensor([0.0679]) tensor([0.1220])
tensor([1]) tensor([0.6809]) tensor([0.5662])
tensor([0]) tensor([0.0902]) tensor([0.1171])


 51%|█████     | 186/363 [00:02<00:02, 69.12it/s]

tensor([0]) tensor([0.9132]) tensor([0.6107])
tensor([0]) tensor([0.3926]) tensor([0.6387])
tensor([0]) tensor([0.1279]) tensor([0.0955])
tensor([0]) tensor([0.8440]) tensor([0.5523])
tensor([0]) tensor([0.1612]) tensor([0.1194])
tensor([0]) tensor([0.1067]) tensor([0.3689])
tensor([0]) tensor([0.2751]) tensor([0.2165])
tensor([0]) tensor([0.2635]) tensor([0.3066])
tensor([1]) tensor([0.6594]) tensor([0.2429])
tensor([0]) tensor([0.8744]) tensor([0.5160])
tensor([0]) tensor([0.1362]) tensor([0.6530])
tensor([0]) tensor([0.6677]) tensor([0.4473])
tensor([0]) tensor([0.4391]) tensor([0.2294])
tensor([1]) tensor([0.9204]) tensor([0.6745])


 55%|█████▌    | 200/363 [00:03<00:02, 66.38it/s]

tensor([1]) tensor([0.9117]) tensor([0.6649])
tensor([0]) tensor([0.3298]) tensor([0.4184])
tensor([0]) tensor([0.2706]) tensor([0.1440])
tensor([0]) tensor([0.1458]) tensor([0.1424])
tensor([0]) tensor([0.0175]) tensor([0.0233])
tensor([0]) tensor([0.1719]) tensor([0.2091])
tensor([0]) tensor([0.0557]) tensor([0.0875])
tensor([1]) tensor([0.8058]) tensor([0.6069])
tensor([0]) tensor([0.2673]) tensor([0.1276])
tensor([0]) tensor([0.4436]) tensor([0.1674])
tensor([1]) tensor([0.7841]) tensor([0.5852])
tensor([0]) tensor([0.3520]) tensor([0.3824])
tensor([0]) tensor([0.1076]) tensor([0.2586])
tensor([1]) tensor([0.7747]) tensor([0.6461])


 59%|█████▉    | 215/363 [00:03<00:02, 70.71it/s]

tensor([1]) tensor([0.1179]) tensor([0.6094])
tensor([0]) tensor([0.3784]) tensor([0.2701])
tensor([0]) tensor([0.0729]) tensor([0.2696])
tensor([0]) tensor([0.0785]) tensor([0.0667])
tensor([1]) tensor([0.5999]) tensor([0.3909])
tensor([0]) tensor([0.0878]) tensor([0.0828])
tensor([1]) tensor([0.7137]) tensor([0.5519])
tensor([0]) tensor([0.0413]) tensor([0.0706])
tensor([1]) tensor([0.8645]) tensor([0.5413])
tensor([1]) tensor([0.8201]) tensor([0.7096])
tensor([1]) tensor([0.7353]) tensor([0.3298])
tensor([1]) tensor([0.4683]) tensor([0.2282])
tensor([1]) tensor([0.0545]) tensor([0.0603])
tensor([0]) tensor([0.4451]) tensor([0.2494])


 63%|██████▎   | 230/363 [00:03<00:02, 58.95it/s]

tensor([0]) tensor([0.0611]) tensor([0.0598])
tensor([0]) tensor([0.3970]) tensor([0.2292])
tensor([0]) tensor([0.4141]) tensor([0.2140])
tensor([0]) tensor([0.2019]) tensor([0.0883])
tensor([0]) tensor([0.3960]) tensor([0.4058])
tensor([1]) tensor([0.8353]) tensor([0.7216])
tensor([0]) tensor([0.0428]) tensor([0.0605])
tensor([1]) tensor([0.5771]) tensor([0.2113])
tensor([0]) tensor([0.1024]) tensor([0.2674])
tensor([0]) tensor([0.1297]) tensor([0.6415])
tensor([0]) tensor([0.2805]) tensor([0.1784])


 68%|██████▊   | 246/363 [00:03<00:01, 66.84it/s]

tensor([0]) tensor([0.4951]) tensor([0.3498])
tensor([1]) tensor([0.8062]) tensor([0.4654])
tensor([1]) tensor([0.9406]) tensor([0.7658])
tensor([1]) tensor([0.8466]) tensor([0.7184])
tensor([0]) tensor([0.1098]) tensor([0.2157])
tensor([1]) tensor([0.5143]) tensor([0.2440])
tensor([0]) tensor([0.0638]) tensor([0.1387])
tensor([0]) tensor([0.1075]) tensor([0.1406])
tensor([0]) tensor([0.1401]) tensor([0.4766])
tensor([1]) tensor([0.3453]) tensor([0.0881])
tensor([0]) tensor([0.3790]) tensor([0.4333])
tensor([1]) tensor([0.2077]) tensor([0.2697])
tensor([1]) tensor([0.8774]) tensor([0.7577])
tensor([0]) tensor([0.3650]) tensor([0.7466])
tensor([1]) tensor([0.6644]) tensor([0.3464])
tensor([0]) tensor([0.2167]) tensor([0.1119])


 70%|██████▉   | 253/363 [00:03<00:01, 66.25it/s]

tensor([0]) tensor([0.1841]) tensor([0.2632])
tensor([0]) tensor([0.2542]) tensor([0.1017])
tensor([0]) tensor([0.3797]) tensor([0.1866])
tensor([1]) tensor([0.9380]) tensor([0.7980])
tensor([1]) tensor([0.5871]) tensor([0.2120])
tensor([1]) tensor([0.8539]) tensor([0.5352])
tensor([1]) tensor([0.7789]) tensor([0.3858])
tensor([0]) tensor([0.3911]) tensor([0.2345])
tensor([0]) tensor([0.2162]) tensor([0.1618])
tensor([0]) tensor([0.0126]) tensor([0.0138])
tensor([0]) tensor([0.3967]) tensor([0.6793])
tensor([1]) tensor([0.3227]) tensor([0.4913])


 74%|███████▎  | 267/363 [00:04<00:01, 64.10it/s]

tensor([0]) tensor([0.1237]) tensor([0.1724])
tensor([0]) tensor([0.2286]) tensor([0.6008])
tensor([0]) tensor([0.4625]) tensor([0.1285])
tensor([0]) tensor([0.3946]) tensor([0.1250])
tensor([0]) tensor([0.3889]) tensor([0.1922])
tensor([1]) tensor([0.3475]) tensor([0.0936])
tensor([0]) tensor([0.3981]) tensor([0.4022])
tensor([0]) tensor([0.7901]) tensor([0.6274])
tensor([1]) tensor([0.4378]) tensor([0.3720])
tensor([1]) tensor([0.2654]) tensor([0.3146])
tensor([0]) tensor([0.2734]) tensor([0.1007])
tensor([1]) tensor([0.3652]) tensor([0.2996])
tensor([0]) tensor([0.3953]) tensor([0.1861])
tensor([0]) tensor([0.8874]) tensor([0.5984])
tensor([0]) tensor([0.1172]) tensor([0.1536])


 78%|███████▊  | 282/363 [00:04<00:01, 68.40it/s]

tensor([0]) tensor([0.1862]) tensor([0.1512])
tensor([1]) tensor([0.7624]) tensor([0.6259])
tensor([0]) tensor([0.1741]) tensor([0.0912])
tensor([0]) tensor([0.2928]) tensor([0.0979])
tensor([0]) tensor([0.1529]) tensor([0.0565])
tensor([0]) tensor([0.4518]) tensor([0.2233])
tensor([1]) tensor([0.9045]) tensor([0.6904])
tensor([0]) tensor([0.3966]) tensor([0.1234])
tensor([0]) tensor([0.3348]) tensor([0.1594])
tensor([1]) tensor([0.4037]) tensor([0.2545])
tensor([0]) tensor([0.1203]) tensor([0.0744])
tensor([1]) tensor([0.6752]) tensor([0.2632])
tensor([1]) tensor([0.4394]) tensor([0.1264])
tensor([0]) tensor([0.3726]) tensor([0.5582])


 82%|████████▏ | 296/363 [00:04<00:00, 68.95it/s]

tensor([0]) tensor([0.4825]) tensor([0.2356])
tensor([1]) tensor([0.5391]) tensor([0.2021])
tensor([0]) tensor([0.7494]) tensor([0.6423])
tensor([0]) tensor([0.2458]) tensor([0.1619])
tensor([1]) tensor([0.0164]) tensor([0.0490])
tensor([0]) tensor([0.1073]) tensor([0.1799])
tensor([0]) tensor([0.4648]) tensor([0.1187])
tensor([0]) tensor([0.1059]) tensor([0.1669])
tensor([1]) tensor([0.9248]) tensor([0.7529])
tensor([0]) tensor([0.4247]) tensor([0.2132])
tensor([0]) tensor([0.0073]) tensor([0.0213])
tensor([0]) tensor([0.8726]) tensor([0.5877])
tensor([1]) tensor([0.7273]) tensor([0.4004])
tensor([1]) tensor([0.6358]) tensor([0.3850])
tensor([0]) tensor([0.2687]) tensor([0.3833])


 86%|████████▌ | 311/363 [00:04<00:00, 67.40it/s]

tensor([1]) tensor([0.7356]) tensor([0.3639])
tensor([0]) tensor([0.7316]) tensor([0.5611])
tensor([1]) tensor([0.2844]) tensor([0.5821])
tensor([1]) tensor([0.0749]) tensor([0.0528])
tensor([1]) tensor([0.6510]) tensor([0.4975])
tensor([0]) tensor([0.0373]) tensor([0.0838])
tensor([0]) tensor([0.0414]) tensor([0.1211])
tensor([0]) tensor([0.1313]) tensor([0.0730])
tensor([0]) tensor([0.4018]) tensor([0.5798])
tensor([1]) tensor([0.3622]) tensor([0.2236])
tensor([0]) tensor([0.3716]) tensor([0.2293])
tensor([0]) tensor([0.0471]) tensor([0.0486])
tensor([0]) tensor([0.1990]) tensor([0.3012])
tensor([0]) tensor([0.1494]) tensor([0.2861])


 90%|████████▉ | 325/363 [00:04<00:00, 66.73it/s]

tensor([0]) tensor([0.0885]) tensor([0.1752])
tensor([0]) tensor([0.0854]) tensor([0.5321])
tensor([1]) tensor([0.8147]) tensor([0.4796])
tensor([0]) tensor([0.1122]) tensor([0.0896])
tensor([1]) tensor([0.1729]) tensor([0.7738])
tensor([0]) tensor([0.1650]) tensor([0.4782])
tensor([0]) tensor([0.4111]) tensor([0.6065])
tensor([1]) tensor([0.5915]) tensor([0.4122])
tensor([1]) tensor([0.7504]) tensor([0.3588])
tensor([1]) tensor([0.9305]) tensor([0.6973])
tensor([0]) tensor([0.4263]) tensor([0.2288])
tensor([1]) tensor([0.4290]) tensor([0.2217])


 93%|█████████▎| 339/363 [00:05<00:00, 57.85it/s]

tensor([1]) tensor([0.8615]) tensor([0.7504])
tensor([0]) tensor([0.7350]) tensor([0.3403])
tensor([0]) tensor([0.7806]) tensor([0.6642])
tensor([0]) tensor([0.0600]) tensor([0.0858])
tensor([0]) tensor([0.0693]) tensor([0.1231])
tensor([1]) tensor([0.8690]) tensor([0.6519])
tensor([1]) tensor([0.7759]) tensor([0.6621])
tensor([0]) tensor([0.1189]) tensor([0.2144])
tensor([1]) tensor([0.8105]) tensor([0.6786])
tensor([1]) tensor([0.9559]) tensor([0.8446])
tensor([0]) tensor([0.0197]) tensor([0.0269])


 98%|█████████▊| 354/363 [00:05<00:00, 63.73it/s]

tensor([0]) tensor([0.6385]) tensor([0.3796])
tensor([0]) tensor([0.5040]) tensor([0.7486])
tensor([1]) tensor([0.7902]) tensor([0.5857])
tensor([1]) tensor([0.5914]) tensor([0.2002])
tensor([0]) tensor([0.1803]) tensor([0.7025])
tensor([1]) tensor([0.9372]) tensor([0.8492])
tensor([0]) tensor([0.3990]) tensor([0.4299])
tensor([1]) tensor([0.6153]) tensor([0.4069])
tensor([0]) tensor([0.3362]) tensor([0.2008])
tensor([0]) tensor([0.1109]) tensor([0.1643])
tensor([1]) tensor([0.9299]) tensor([0.7059])
tensor([1]) tensor([0.8570]) tensor([0.6518])
tensor([0]) tensor([0.0518]) tensor([0.1720])
tensor([0]) tensor([0.0265]) tensor([0.0288])
tensor([0]) tensor([0.3143]) tensor([0.2468])


100%|██████████| 363/363 [00:05<00:00, 65.04it/s]

tensor([1]) tensor([0.2733]) tensor([0.2048])
tensor([0]) tensor([0.4483]) tensor([0.6708])
tensor([1]) tensor([0.9590]) tensor([0.8020])
tensor([0]) tensor([0.2229]) tensor([0.3432])
tensor([1]) tensor([0.5704]) tensor([0.3645])
tensor([0]) tensor([0.3253]) tensor([0.2243])
tensor([0]) tensor([0.7943]) tensor([0.6517])
tensor([1]) tensor([0.3595]) tensor([0.1622])
Groundtruth Metrics
tensor(0.8294) tensor(0.7476)





In [106]:
score = calc_optimal_thres(gt,y_hat_preds)
score

(0.4800000000000001, tensor(0.7286))

In [105]:
import torch
from typing import Tuple
import numpy as np

def calc_optimal_thres(target, preds):
    score_dict ={}
    thres = np.arange(0.1,1,0.02)
    for t in thres:
        func = BinaryF1Score(threshold=t)
        score = func(preds,target)
        score_dict[t] = score
    best_thres = max(score_dict,key=score_dict.get)
    return best_thres,score_dict[best_thres]
    
    
    
    

def fidelity(target: torch.Tensor, logits: torch.Tensor, logits_removed_subgraph: torch.Tensor):
    """
    Calculates the fidelity scores based on the given target labels, prediction logits, 
    and prediction logits after the removal of the explanatory subgraph.

    Args:
        target (torch.Tensor): A tensor of shape (N,) containing the target labels.
        logits (torch.Tensor): A tensor of shape (N, C) containing the prediction logits.
        logits_removed_subgraph (torch.Tensor): A tensor of shape (N, C) containing the 
                                                 prediction logits after removal of the subgraph.

    Returns:
        Tuple[float, float]: The fidelity- and fidelity+ scores.
    """
    # Get the predicted labels from logits
    predicted_labels = torch.argmax(logits, dim=1)
    predicted_labels_removed = torch.argmax(logits_removed_subgraph, dim=1)

    # Calculate the fidelity+ score
    correct_predictions = (predicted_labels == target).float()
    correct_predictions_removed = (predicted_labels_removed == target).float()
    
    fidelity_plus = torch.mean(torch.abs(correct_predictions - correct_predictions_removed)).item()

    return fidelity_plus


fidelity_plus = fidelity(gt, y_hat_preds, y_hat_removal_preds)
print(f"Fidelity+: {fidelity_plus}")


IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

In [None]:
import torch

def remove_topk_edges(attn_score, batch, topK):
    """
    Given an attention score tensor, edge indices, and a value topK, this function removes the
    top K edges with the highest attention scores and returns the new edge_index.

    Args:
        attn_score (torch.Tensor): A tensor of shape (E,) containing the attention scores for each edge.
        edge_index (torch.Tensor): A tensor of shape (2, E) containing the edge indices.
        topK (int): The number of top attention scores to consider for removal.

    Returns:
        torch.Tensor: A tensor containing the new edge indices after the removal.
    """
    # Get the indices of the top K largest values in attn_score
    edge_index = batch.edge_index
    if 'edge_attr' in batch:
        edge_attr = batch.edge_attr
    else:
        edge_attr = None
    topk_indices = torch.topk(attn_score, topK).indices

    # Create a mask to keep all edges except the top K
    mask = torch.ones(attn_score.size(0), dtype=torch.bool)
    mask[topk_indices] = False

    # Apply the mask to edge_index to get the new edge_index
    new_edge_index = edge_index[:, mask]
    if edge_attr is not None:
        new_edge_attr = edge_attr[mask]
    else:
        new_edge_attr = None

    return new_edge_index,new_edge_attr



def topk_nodes_from_attn(attn_score: torch.Tensor, edge_index: torch.Tensor, topK: int) -> torch.Tensor:
    """
    Given an attention score tensor, edge indices, and a value topK, this function returns the 
    top K nodes with the highest attention scores.

    Args:
        attn_score (torch.Tensor): A tensor of shape (E,) containing the attention scores for each edge.
        edge_index (torch.Tensor): A tensor of shape (2, E) containing the edge indices.
        topK (int): The number of top attention scores to consider.

    Returns:
        torch.Tensor: A tensor containing the nodes corresponding to the top K attention scores.
    """
    # Get the indices of the top K largest values in attn_score
    topk_indices = torch.topk(attn_score, topK).indices

    # Get the corresponding edges from edge_index using the top K indices
    topk_edges = edge_index[:, topk_indices]

    # Get the unique nodes from the top K edges
    topk_nodes = torch.unique(topk_edges)

    return topk_nodes


    

In [None]:
pred_edges = []
gt_edges = []

for batch in data_loader_single:
    data = process_data(batch,True)
    batch_att, _, _ = eval_one_batch(gsat, data.to(gsat.device), epoch=500)
    pred_edges.append(batch_att.view(-1,))
    gt_edges.append(data.edge_label.view(-1,))
pred_edges = torch.cat(pred_edges)
gt_edges = torch.cat(gt_edges)


In [None]:
num_viz_samples = 10
assert aux_info['multi_label'] is False

all_viz_set = get_viz_idx(test_set, dataset_name, num_viz_samples)
visualize_results(gsat, all_viz_set, test_set, num_viz_samples, dataset_name, model_config['use_edge_attr'])

In [None]:
all_viz_set