In [1]:
import numpy as np
%load_ext autoreload
%autoreload 2
from load_data import load_data
import torch
from modules import GNN
from train_model import train_model
from subgraph_relevance import subgraph_original, subgraph_mp_transcription, subgraph_mp_forward_hook, get_H_transform
import time
from tqdm import tqdm
import matplotlib.pyplot as plt
import sys
import pandas as pd
from io import StringIO
import pickle as pkl
from top_walks import *
from utils import *
from IPython.display import SVG, display

from torch_geometric.utils import to_dense_adj

# Generate Dataset
The Infection dataset follows the benchmark introduced in: Lukas Faber, Amin K. Moghaddam, and Roger Wattenhofer. 2021. When Comparing to Ground Truth is Wrong: On Evaluating GNN Explanation Methods. In Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery & Data Mining (KDD '21). Association for Computing Machinery, New York, NY, USA, 332–341. https://doi.org/10.1145/3447548.3467283


In [None]:
import random
from collections import defaultdict

# import mlflow
import networkx as nx
from torch_geometric.utils import from_networkx

def create_infection_dataset(num_layers=4, size=10, num_nodes=1000, edge_prob=0.004,
                             num_infected=50, seed=None):
    """Generate the synthetic "Infection" benchmark (Faber et al., KDD '21).

    Each graph is a directed Erdos-Renyi graph in which ``num_infected`` randomly
    chosen nodes are initially infected. A node's label is its shortest-path
    distance from the nearest infected node, capped at ``num_layers + 1``. The capped
    class means "far away" and also covers unreachable nodes, so there are
    ``num_layers + 2`` classes in total.

    Args:
        num_layers: Largest distance that still gets its own class.
        size: Number of graphs to generate.
        num_nodes: Nodes per graph (default 1000).
        edge_prob: Probability of each directed edge (default 0.004).
        num_infected: Initially infected nodes per graph (default 50).
        seed: Optional int for reproducible graphs. ``None`` uses the global
            ``random`` state, which was the previous behavior.

    Returns:
        A list of PyG ``Data`` objects with these attributes:
        ``x``: one-hot ``[infected, not infected]`` features.
        ``y``: the capped distance label.
        ``num_classes``: the number of classes.
        ``unique_solution_nodes``: nodes whose shortest path from an infected node
        is unique.
        ``unique_solution_explanations``: for each such node, that path, starting
        at the infected node.
    """
    print(num_layers + 2, "classes")
    max_dist = num_layers  # distances beyond max_dist collapse into the "far away" label
    rng = random.Random(seed) if seed is not None else None

    dataset = []
    for _ in range(size):
        # seed=None makes networkx draw from the global `random` state.
        g = nx.erdos_renyi_graph(num_nodes, edge_prob, seed=rng, directed=True)
        N = g.number_of_nodes()
        # random.sample() needs a sequence. Passing a NodeView (set-like) raises
        # TypeError on Python 3.11+ (it was deprecated since 3.9).
        sampler = rng if rng is not None else random
        infected_nodes = sampler.sample(list(g.nodes()), num_infected)
        infected_set = set(infected_nodes)

        # The dummy source 'X' is linked to every infected node. This turns "distance
        # to the nearest infected node" into one single-source shortest-path query.
        # It is removed again before conversion.
        g.add_node('X')
        for u in infected_nodes:
            g.add_edge('X', u)
        shortest_path_length = nx.single_source_shortest_path_length(g, 'X')

        unique_solution_nodes = []
        unique_solution_explanations = []
        labels = []
        features = np.zeros((N, 2))
        for node in range(N):
            # Unreachable nodes get an infinite distance and land in the far-away class.
            length = shortest_path_length.get(node, float('inf')) - 1
            labels.append(min(max_dist + 1, length))
            features[node, 0 if node in infected_set else 1] = 1
            if 0 < length <= max_dist:
                paths = nx.all_shortest_paths(g, 'X', node)
                unique_shortest_path = next(paths)
                if next(paths, None) is not None:
                    continue  # several shortest paths: the explanation is ambiguous
                unique_shortest_path.pop(0)  # drop the dummy 'X' node
                if not unique_shortest_path:
                    continue
                unique_solution_explanations.append(unique_shortest_path)
                unique_solution_nodes.append(node)
        g.remove_node('X')

        data = from_networkx(g)
        data.x = torch.tensor(features, dtype=torch.float)
        data.y = torch.tensor(labels)
        data.unique_solution_nodes = unique_solution_nodes
        data.unique_solution_explanations = unique_solution_explanations
        data.num_classes = 1 + max_dist + 1
        dataset.append(data)
    return dataset

In [None]:
# 100 graphs whose labels use distances up to num_layers=4 (num_layers + 2 = 6 classes).
dataset = create_infection_dataset(size=100, num_layers=4)

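# Transform to SIR dataset

This step replaces the shortest-path labels above with the outcome of a stochastic SIR-style spreading process. The process runs for `steps` rounds on each graph, with these parameters:

- INIT_INFECT_RATE: upper bound on the initial infection probability. Each graph draws its own probability as `rand() * INIT_INFECT_RATE`.
- INFECTION_RATE: probability that an infected node infects each of its out-neighbours in one step.
- CURE_RATE: probability that an infected node recovers (becomes immune) in one step.
- IMMUNE_RATE: probability that a recovered node resists re-infection. The default of 1 means immunity is permanent.

Labels (after the last step):
- 0: Not infected
- 1: Infected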

In [None]:
# The first 80 generated graphs are used for training; the rest are held out for testing.
train_set = list(dataset[:80])
test_set = list(dataset[80:])

In [None]:
# Each graph draws its initial infection probability as rand() * INIT_INFECT_RATE.
INIT_INFECT_RATE = 0.02
# Probability that an infected node infects a given out-neighbour in one step.
INFECTION_RATE = 0.6
# Per-step probability that an infected node recovers (0 disables recovery).
CURE_RATE = 0.
# Probability that a recovered node resists re-infection (1 = permanent immunity).
IMMUNE_RATE = 1

# Node states: susceptible, infected, recovered/immune.
SUSPECTFUL, INFECTED, IMMUNED = 0, 1, 2

zero_patient = False  # True: only node 0 is infected initially
np.random.seed(0)

In [None]:
steps = 4


def simulate_sir(data, steps=steps):
    """Overwrite one graph's features and labels with a stochastic SIR outcome, in place.

    The initially infected set is either node 0 alone (when ``zero_patient``) or each
    node independently with probability ``rand() * INIT_INFECT_RATE``. The process then
    runs ``steps`` rounds of spreading along the directed edges of ``data.edge_index``.

    Sets these attributes on ``data``:
    ``x``: one-hot ``[infected, not infected]`` at time 0 (float).
    ``y``: 1 for nodes infected after the last step, 0 otherwise.
    ``infection_chains``: ``{node: [initial_infector, ..., node]}`` for the nodes
    infected at the end.
    ``num_classes``: set to 2.

    Uses the global ``np.random`` state and the SIR constants from the config cell.
    """
    data.num_classes = 2
    # The shortest-path ground truth from create_infection_dataset does not apply to
    # SIR labels. The deletion is guarded so re-running the cell does not raise.
    for key in ('unique_solution_nodes', 'unique_solution_explanations'):
        if hasattr(data, key):
            delattr(data, key)

    # max_num_nodes keeps A at num_nodes x num_nodes. Without it, to_dense_adj sizes A
    # from the largest index in edge_index, so trailing isolated nodes would be out of range.
    A = to_dense_adj(data.edge_index, max_num_nodes=data.num_nodes)[0]

    if zero_patient:
        data.x = torch.zeros_like(data.x)
        data.x[:, 1] = 1
        data.x[0, 0] = 1
        data.x[0, 1] = 0
    else:
        init_infect_rate_ = np.random.rand() * INIT_INFECT_RATE
        x = torch.tensor([int(np.random.rand() < init_infect_rate_) for _ in range(data.num_nodes)])
        data.x = torch.column_stack([x, 1 * (x != 1)]).float()

    x_state = torch.zeros(data.num_nodes)
    x_state[torch.where(data.x[:, 0] == 1)] = INFECTED

    infection_chains = {node: [node] for node in range(data.num_nodes) if x_state[node] == INFECTED}

    for _ in range(steps):
        # Snapshot: nodes infected during this step only start spreading in the next step.
        I_nodes = torch.where(x_state == INFECTED)[0].tolist()
        I_set = set(I_nodes)  # O(1) membership test; iteration still follows I_nodes order
        for I_node in I_nodes:
            for node in A[I_node].nonzero().flatten().tolist():
                if node in I_set:
                    continue

                if np.random.rand() < INFECTION_RATE:
                    if x_state[node] == IMMUNED:
                        if np.random.rand() < 1 - IMMUNE_RATE:
                            x_state[node] = INFECTED
                            infection_chains[node] = infection_chains[I_node].copy() + [node]
                    else:
                        # If several infectors reach the same node in one step, the last one wins.
                        x_state[node] = INFECTED
                        infection_chains[node] = infection_chains[I_node].copy() + [node]

            if np.random.rand() < CURE_RATE:
                x_state[I_node] = IMMUNED
                del infection_chains[I_node]

    data.infection_chains = infection_chains.copy()
    data.y = (x_state == INFECTED) * 1


for data in train_set + test_set:
    simulate_sir(data)

# Nothing in this notebook previously created `val_data`; it only existed as leftover
# kernel state, so Restart & Run All failed here. Build a fresh held-out graph instead.
val_data = create_infection_dataset(num_layers=4, size=1)[0]
simulate_sir(val_data)

# Train model

In [None]:
from data_structure import Graph
from train_model import modules
import torch.nn.functional as F

num_layer = 4

# Hyper-parameters for the GNN trained below.
config = dict(
    num_layer=num_layer,
    mode='gcn',
    epochs=100,
    lr=0.0005,
    model_dir=f'models/gcn-{num_layer}-infection-sir.torch',
    nbclasses=2,
    inter_feat_dim=32,
    print_out_nb=100,  # evaluate and print every this many epochs
    optimizer='adam',
)

(num_layer, mode, epochs, lr, model_dir,
 nbclasses, inter_feat_dim, print_out_nb) = (config[key] for key in (
    'num_layer', 'mode', 'epochs', 'lr', 'model_dir',
    'nbclasses', 'inter_feat_dim', 'print_out_nb'))
optimizer_label = config['optimizer']

H0_dim = 2  # input feature width: one-hot [infected, not infected]

torch.manual_seed(0)
print(f'train {model_dir}')

In [None]:
model = modules.Net1(num_node_features=H0_dim, num_classes=nbclasses, num_layers=num_layer, concat_features=False, conv_type='GraphConv')
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
model.train()
losses = []  # mean training loss per epoch
accs = []    # mean training node accuracy per epoch

for ep in range(epochs):
    loss_ = []
    acc_ = []
    for data in train_set:
        optimizer.zero_grad()
        output = model(data.x, data.edge_index)
        loss = F.nll_loss(output, data.y)
        loss.backward()
        optimizer.step()
        loss_.append(loss.item())
        acc_.append((output.argmax(axis=1) == data.y).float().mean().item())
    losses.append(np.mean(loss_))
    accs.append(np.mean(acc_))

    if (ep + 1) % print_out_nb == 0:
        # Evaluate in inference mode (disables any dropout / batch-norm updates) and
        # without building autograd graphs, then switch back to training mode.
        model.eval()
        with torch.no_grad():
            loss_all = 0.0
            acc_all = 0.0
            for data in test_set:
                output = model(data.x, data.edge_index)
                loss_all += F.nll_loss(output, data.y).item()
                acc_all += (output.argmax(axis=1) == data.y).float().mean().item()

            output = model(val_data.x, val_data.edge_index)
            loss_val = F.nll_loss(output, val_data.y).item()
            acc_val = (output.argmax(axis=1) == val_data.y).float().mean().item()
        model.train()

        # Average over the actual test-set size (this was previously hard-coded to 20).
        n_test = max(len(test_set), 1)
        print("Ep:", ep+1, "Loss: ", loss_all / n_test, "Acc: ", acc_all / n_test, "Val loss: ", loss_val, "Val acc: ", acc_val)

In [None]:
# Per-epoch training curves. Loss and accuracy share one y-axis.
fig, ax = plt.subplots()
ax.plot(losses, label='loss')
ax.plot(accs, label='acc')
ax.set_xlabel('epoch')
ax.set_title('Training loss and accuracy per epoch')
ax.legend();

# Infection Chain Detection Evaluation

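## Oracle

The Monte-Carlo oracle re-runs the SIR process many times from the same initially infected nodes. It then ranks each node's possible infection chains by how often they occur.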

In [None]:
def mc_oracle_ks_edge(data, model, skip_edge=False, steps=4, sim_steps=1000):
    """Monte-Carlo oracle for infection chains under the SIR process.

    Re-runs the spreading process ``sim_steps`` times, starting from the initially
    infected nodes encoded in ``data.x[:, 0]``. It then tallies, for every node, how
    often each infection chain occurred.

    Args:
        data: Graph with ``x``, ``edge_index`` and ``num_nodes``.
        model: Unused; kept for call-site compatibility.
        skip_edge: If True, skip the per-edge tally. The second return value then
            holds only empty dicts.
        steps: Spreading rounds per simulation (default 4).
        sim_steps: Number of Monte-Carlo simulations (default 1000).

    Returns:
        A pair ``(infection_chains_prob_dict, edges_prob_dict)``:
        ``infection_chains_prob_dict[node][chain_tuple]`` is the fraction of
        simulations in which ``node`` ended infected via ``chain_tuple``.
        ``edges_prob_dict[node][(u, v)]`` is a raw occurrence COUNT, not a fraction,
        of edge ``u -> v`` in ``node``'s chains, summed over simulations.

    Uses the global ``np.random`` state and the SIR constants from the config cell.
    """
    # max_num_nodes keeps A square at num_nodes even when the highest-indexed nodes
    # have no edges.
    A = to_dense_adj(data.edge_index, max_num_nodes=data.num_nodes)[0]
    # Out-neighbours of each node, in the same ascending order that A[i].nonzero()
    # yields them. Computed once instead of per infected node, per step, per simulation.
    neighbors = [row.nonzero().flatten().tolist() for row in A]
    initially_infected = torch.where(data.x[:, 0] == 1)

    chain_counts = {}
    edges_prob_dict = {}
    for _ in range(sim_steps):
        x_state = torch.zeros(data.num_nodes)
        x_state[initially_infected] = INFECTED
        infection_chains = {node: [node] for node in range(data.num_nodes) if x_state[node] == INFECTED}

        for _step in range(steps):
            # Snapshot: nodes infected during this step start spreading in the next one.
            I_nodes = torch.where(x_state == INFECTED)[0].tolist()
            I_set = set(I_nodes)
            for I_node in I_nodes:
                for node in neighbors[I_node]:
                    if node in I_set:
                        continue

                    if np.random.rand() < INFECTION_RATE:
                        if x_state[node] == IMMUNED:
                            if np.random.rand() < 1 - IMMUNE_RATE:
                                x_state[node] = INFECTED
                                infection_chains[node] = infection_chains[I_node].copy() + [node]
                        else:
                            x_state[node] = INFECTED
                            infection_chains[node] = infection_chains[I_node].copy() + [node]

                if np.random.rand() < CURE_RATE:
                    x_state[I_node] = IMMUNED
                    del infection_chains[I_node]

        # Tally this simulation right away instead of storing every chain dict.
        for key, chain in infection_chains.items():
            chain = tuple(chain)
            node_chains = chain_counts.setdefault(key, {})
            node_edges = edges_prob_dict.setdefault(key, {})
            node_chains[chain] = node_chains.get(chain, 0) + 1

            if not skip_edge:
                for edge in zip(chain, chain[1:]):
                    if edge[0] == edge[1]:
                        continue
                    node_edges[edge] = node_edges.get(edge, 0) + 1

    # Count / sim_steps is exact, unlike summing 1/sim_steps repeatedly.
    infection_chains_prob_dict = {
        node: {chain: count / sim_steps for chain, count in counts.items()}
        for node, counts in chain_counts.items()
    }
    return infection_chains_prob_dict, edges_prob_dict

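## Evaluation

For each explained node, record the rank K at which a method's top-ranked walks first recover the ground-truth infection chain. The rank is `inf` if the chain is never found.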

In [None]:
# Per-node ranks of the ground-truth infection chain (1-based, inf = not found),
# one list per method. Some of these lists are not filled in this section.
(acc_edge_igs, acc_edge_igs_prod, acc_top_ks_zero, acc_top_ks_ab,
 acc_top_ks_gamma, acc_top_ks_02, oracle_ks_list, oracle_edge_nos_list) = ([] for _ in range(8))

num_walks = 25  # number of top edges / walks requested from each explainer
threshold = 0.05
np.random.seed(0)
data_idxs = np.random.choice(100, 20, replace=False)  # 20 distinct indices from range(100)

def compare_topk_groundtruth(real_top_k_max_walks_rels, inf_chain):
    """Return the 1-based rank of ``inf_chain`` among ranked candidate walks.

    Args:
        real_top_k_max_walks_rels: Ranked candidates, best first. Either a sequence
            of ``(walk, score)`` pairs (only ``walk`` is used) or a mapping
            ``{walk: score}``, which is ranked by descending score.
        inf_chain: The ground-truth infection chain, as a sequence of node ids.

    Returns:
        The rank ``k`` (1 = best) of the first candidate whose walk equals
        ``inf_chain`` once consecutive repeated nodes are collapsed (self-loop steps).
        Returns ``float('inf')`` if no candidate matches.
    """
    import itertools
    from collections.abc import Mapping

    if isinstance(real_top_k_max_walks_rels, Mapping):
        # e.g. the oracle's {chain: probability}: iterating it directly would yield
        # bare chain tuples, not (walk, score) pairs.
        ranked = sorted(real_top_k_max_walks_rels.items(), key=lambda kv: kv[1], reverse=True)
    else:
        ranked = real_top_k_max_walks_rels

    target = tuple(inf_chain)
    for rank, item in enumerate(ranked, start=1):
        walk = item[0]
        # groupby collapses runs of the same node; it is also safe for empty walks.
        collapsed = tuple(node for node, _ in itertools.groupby(walk))
        if collapsed == target:
            return rank
    return float('inf')

# The training cell leaves the model in training mode; predict and explain in
# inference mode instead.
model.eval()

for data_idx in data_idxs:
    data = dataset[data_idx]
    # Explain only true positives: nodes that are infected AND predicted infected.
    with torch.no_grad():
        pred_infected = set(model(data.x, data.edge_index).argmax(axis=1).nonzero().flatten().tolist())
    true_infected = set(data.y.nonzero().flatten().tolist())
    pos_correct_pred_nodes = sorted(true_infected & pred_infected)

    infection_chains_prob_dict, _ = mc_oracle_ks_edge(data, model, skip_edge=True)

    n_explain = min(200, len(pos_correct_pred_nodes))
    for node_idx_to_explain in tqdm(np.random.choice(pos_correct_pred_nodes, n_explain, replace=False)):
        inf_chain = data.infection_chains[node_idx_to_explain]
        # A chain of length 1 means the node was infected at time 0: nothing to explain.
        if len(inf_chain) == 1:
            continue

        # Edge-IG baseline: chain the top-scoring edges into walks, then score each
        # walk by the sum and by the product of its edge relevances.
        edges_ig, edge_rels_ig = edge_ig(data, model, node_idx_to_explain, top_n=num_walks)
        edges_ig = [tuple(edge) for edge in edges_ig]
        edges_ig_dict = dict(zip(edges_ig, edge_rels_ig))

        walks_candidates = set(edges_ig)
        for _ in range(4):
            walks_candidates_toadd = set()
            for walk in walks_candidates:
                for edge in edges_ig:
                    if walk[-1] == edge[0]:
                        walks_candidates_toadd.add(tuple(list(walk) + [edge[1]]))
            walks_candidates = walks_candidates.union(walks_candidates_toadd)

        walks_candidates_rels_sum = []
        walks_candidates_rels_prod = []
        for walk in walks_candidates:
            rel_sum = 0
            rel_prod = 1
            for i in range(len(walk) - 1):
                rel_sum += edges_ig_dict[(walk[i], walk[i + 1])]
                rel_prod *= edges_ig_dict[(walk[i], walk[i + 1])]
            walks_candidates_rels_sum.append([walk, rel_sum])
            walks_candidates_rels_prod.append([walk, rel_prod])

        top_walks_edge_ig = sorted(walks_candidates_rels_sum, key=lambda x: x[1], reverse=True)
        acc_edge_igs.append(compare_topk_groundtruth(top_walks_edge_ig, inf_chain))

        top_walks_edge_ig = sorted(walks_candidates_rels_prod, key=lambda x: x[1], reverse=True)
        acc_edge_igs_prod.append(compare_topk_groundtruth(top_walks_edge_ig, inf_chain))

        # NOTE(review): the get_rel_components calls below do not depend on
        # node_idx_to_explain. They could be hoisted out of this loop if
        # approx_top_k_walk does not mutate its inputs; confirm before doing so.
        rel_components, X_fc, A = get_rel_components(data=data, model=model, lrp_rule='gamma', gammas=None, normalize=False)
        real_top_k_max_walks_rels, real_top_k_min_walks_rels = approx_top_k_walk(num_walks, node_idx_to_explain, data, model, rel_components, X_fc, A, verbose=False)
        acc_top_ks_gamma.append(compare_topk_groundtruth(real_top_k_max_walks_rels, inf_chain))

        rel_components, X_fc, A = get_rel_components(data=data, model=model, lrp_rule='gamma', gammas=[0.2] * 4, normalize=False)
        real_top_k_max_walks_rels, real_top_k_min_walks_rels = approx_top_k_walk(num_walks, node_idx_to_explain, data, model, rel_components, X_fc, A, verbose=False)
        acc_top_ks_02.append(compare_topk_groundtruth(real_top_k_max_walks_rels, inf_chain))

        rel_components, X_fc, A = get_rel_components(data=data, model=model, lrp_rule='pos-clip', gammas=None, normalize=False)
        real_top_k_max_walks_rels, real_top_k_min_walks_rels = approx_top_k_walk(num_walks, node_idx_to_explain, data, model, rel_components, X_fc, A, verbose=False)
        acc_top_ks_ab.append(compare_topk_groundtruth(real_top_k_max_walks_rels, inf_chain))

        # Oracle: rank the simulated chains by empirical probability. Passing the raw
        # {chain: prob} dict yielded bare chain tuples instead of (walk, score) pairs.
        # A node never infected in any simulation has no candidates (rank inf).
        oracle_chains = infection_chains_prob_dict.get(node_idx_to_explain, {})
        oracle_ranked = sorted(oracle_chains.items(), key=lambda kv: kv[1], reverse=True)
        oracle_ks_list.append(compare_topk_groundtruth(oracle_ranked, inf_chain))
    

In [None]:
import os


def recall_at_k(ranks, ks):
    """Return, for each k in ``ks``, the fraction of ``ranks`` that are <= k."""
    ranks = np.array(ranks)
    return [np.mean(ranks <= k) for k in ks]


x_range = 31
ks = range(1, x_range)
# (ranks, matplotlib format string, extra plot kwargs, legend label)
series = [
    (oracle_ks_list, '--', {'color': 'gray'}, 'oracle'),
    (acc_top_ks_gamma, 'b-.', {}, 'AMP-ave (ours)'),
    (acc_edge_igs, 'g-+', {}, 'Edge-IG sum'),
    (acc_edge_igs_prod, '-', {}, 'Edge-IG prod'),
]

fig, ax = plt.subplots(figsize=(5, 2))
for ranks, fmt, kwargs, label in series:
    ax.plot(np.arange(1, x_range), recall_at_k(ranks, ks), fmt, label=label, **kwargs)

ax.grid()
ax.set_ylabel('Walk Recall')
ax.set_xlabel('K')
ax.legend()
ax.set_xlim(1, x_range - 1)
ax.set_ylim(0.4, 1.01)
ax.set_xticks([1, 5, 10, 15, 20, 25, 30])

# savefig does not create missing directories.
os.makedirs("imgs", exist_ok=True)
fig.savefig("imgs/infection_res_walk.svg", dpi=600, format='svg', bbox_inches='tight')