In [None]:
import numpy as np
%load_ext autoreload
%autoreload 2
from load_data import load_data
import torch
from modules import GNN
from train_model import train_model
from subgraph_relevance import subgraph_original, subgraph_mp_transcription, subgraph_mp_forward_hook, get_H_transform
import time
from tqdm import tqdm
import matplotlib.pyplot as plt
import sys
import pandas as pd
from io import StringIO
import pickle as pkl
from top_walks import *
from utils import *
from IPython.display import SVG, display
from data_structure import Graph

In [None]:
# [dataset name, model checkpoint file under models/] pairs.
# The sections below select entries by index:
#   6 -> Graph-SST2 (3-layer GCN), 7 -> Mutagenicity (4-layer GIN).
dataset_model_dirs = [
    ['BA-2motif',     'gin-3-ba2motif.torch'],
    ['BA-2motif',     'gin-5-ba2motif.torch'],
    ['BA-2motif',     'gin-7-ba2motif.torch'],
    ['MUTAG',         'gin-3-mutag.torch'],
    ['Mutagenicity',  'gin-3-mutagenicity.torch'],
    ['REDDIT-BINARY', 'gin-5-reddit.torch'],
    ['Graph-SST2',    'gcn-3-sst2graph.torch'],
    ['Mutagenicity',  'gin-4-Mutagenicity.torch'],
]

In [None]:
from captum.attr import IntegratedGradients

def model_edge_forward(edge_mask, nn, g):
    """Run ``nn`` on ``g`` with each adjacency entry replaced by a mask weight.

    Intended as the forward function handed to captum's IntegratedGradients.

    Args:
        edge_mask: 1-D sequence with one weight per non-zero entry of
            ``g.get_adj()``, in ``nonzero()`` order. Only the weights of
            entries (row, col) with row <= col are read. Each one is written to
            both (row, col) and (col, row), so weights of entries with
            row > col are ignored and receive zero gradient.
        nn: model exposing ``forward(adjacency, H0=node_features)``.
        g: graph exposing ``get_adj()`` and ``node_features``.

    Returns:
        The model output reshaped to ``(1, 2)``.
    """
    adjacency = g.get_adj()
    weighted_adj = torch.zeros_like(adjacency)
    for (row, col), weight in zip(adjacency.nonzero().tolist(), edge_mask):
        if row <= col:
            weighted_adj[row, col] = weight
            weighted_adj[col, row] = weight
    return nn.forward(weighted_adj, H0=g.node_features).reshape(1, 2)

def model_node_forward(node_mask, nn, g):
    """Run ``nn`` on ``g`` with every node's incident entries scaled by its mask.

    Entry (i, j) of the adjacency is multiplied by ``node_mask[i] * node_mask[j]``.
    Intended as the forward function handed to captum's IntegratedGradients.

    Args:
        node_mask: 1-D tensor with one weight per node of ``g``.
        nn: model exposing ``forward(adjacency, H0=node_features)``.
        g: graph exposing ``get_adj()`` and ``node_features``.

    Returns:
        The model output reshaped to ``(1, 2)``.

    Raises:
        AssertionError: if the masked adjacency is not symmetric.
    """
    adjacency = g.get_adj()
    masked_adj = adjacency * node_mask.reshape(1, -1) * node_mask.reshape(-1, 1)
    # `(a - a.T).sum() == 0` holds for ANY square matrix, because each
    # a[i, j] - a[j, i] term is cancelled by a[j, i] - a[i, j]. Compare
    # element-wise so that an asymmetric adjacency is actually detected.
    assert torch.allclose(masked_adj, masked_adj.T), "masked adjacency is not symmetric"
    return nn.forward(masked_adj, H0=g.node_features).reshape(1, 2)

# Mutagenicity

In [None]:
# Entry 7: Mutagenicity with the 4-layer GIN checkpoint.
dataset, model_dir = dataset_model_dirs[7]

graphs, pos_idx, neg_idx = load_data(dataset)
# NOTE(review): torch.load unpickles the whole model object; only load trusted checkpoints.
nn = torch.load(f"models/{model_dir}")

In [None]:
g = graphs[4233]
H, transforms = get_H_transform(g.get_adj(), nn, H0=g.node_features, gammas=None, mode='gamma')
pred = nn.forward(g.get_adj(), H0=g.node_features)
print(pred)
pred = pred.argmax()

# Keep only the predicted-class column of H; every other column stays zero.
init_rel = np.zeros_like(H)
init_rel[:, pred] = H[:, pred]

num_walks = 10_000_000
top_max_walks, top_min_walks = topk_walks(g, nn, num_walks=num_walks, lrp_mode="gamma", mode="node")
top_walks = list(top_max_walks + top_min_walks)
plot_top_k_walks(
    g, top_walks, transforms, init_rel,
    mode="node",
    factor=0.3,
    width=5,
    figname=f"imgs/molecule_topk_walks_{num_walks}.svg",
    dataset="Mutagenicity",
    linewidth=3 if num_walks > 100 else 13,  # thinner strokes when many walks are drawn
)

In [None]:
target = nn.forward(g.get_adj(), H0=g.node_features).argmax()

# Edge-level Integrated Gradients: one mask entry per non-zero adjacency entry.
ig = IntegratedGradients(model_edge_forward)
input_mask = torch.ones(len(g.get_adj().nonzero()), requires_grad=True)
ig_mask = ig.attribute(
    input_mask,
    target=target,
    additional_forward_args=(nn, g),
    internal_batch_size=len(input_mask),
)

edge_mask = ig_mask.detach().cpu().numpy()

# (edge, score) pairs, highest score first.
edges = g.get_adj().nonzero()
sorted_edges = [(edges[i].tolist(), edge_mask[i]) for i in edge_mask.argsort()[::-1]]

In [None]:
# Draw the edge-IG scores on the molecule. Saved under its own name: the old
# figname was identical to the top-k walk figure's and silently overwrote it.
plot_top_k_walks(g, sorted_edges, transforms, init_rel, mode="node", factor=4,
                 figname="imgs/molecule_edge.svg", dataset="Mutagenicity", width=5, compute_rel=False)

In [None]:
target = nn.forward(g.get_adj(), H0=g.node_features).argmax()

# Node-level Integrated Gradients: one mask entry per node.
ig = IntegratedGradients(model_node_forward)
input_mask = torch.ones(g.nbnodes, requires_grad=True)
ig_mask = ig.attribute(
    input_mask,
    target=target,
    additional_forward_args=(nn, g),
    internal_batch_size=len(input_mask),
)

node_mask = ig_mask.detach().cpu().numpy()

In [None]:
# Element symbol for each Mutagenicity node tag.
node_label_dict = {
    0: 'C', 1: 'O', 2: 'Cl', 3: 'H', 4: 'N', 5: 'F', 6: 'Br',
    7: 'S', 8: 'P', 9: 'I', 10: 'Na', 11: 'K', 12: 'Li', 13: 'Ca',
}

atoms = [node_label_dict[i] for i in g.node_tags]

# Rebuild the molecule in RDKit only to obtain a 2-D depiction layout; every
# bond is added as SINGLE.
molecule = Chem.RWMol()
for atom in atoms:
    molecule.AddAtom(Chem.Atom(atom))
A = g.get_adj().nonzero()

for x, y in A:
    # Each bond appears as both (x, y) and (y, x); x < y keeps one copy and
    # skips diagonal entries.
    if x < y:
        molecule.AddBond(int(x), int(y), Chem.rdchem.BondType.SINGLE)

AllChem.Compute2DCoords(molecule)
# compute 2D positions
conformer = molecule.GetConformer()
n_nodes = molecule.GetNumAtoms()
pos = []
for i in range(n_nodes):
    conformer_pos = conformer.GetAtomPosition(i)
    pos.append([conformer_pos.x, conformer_pos.y])

pos = np.array(pos)
node_labels = atoms
factor = 1.
width = 6
fig_width = width
# Fixed width; height follows the aspect ratio of the layout.
pos_size = pos.max(axis=0) - pos.min(axis=0)
fig_height = (width / pos_size[0]) * pos_size[1]
fig = plt.figure(figsize=(fig_width, fig_height))
ax = plt.subplot(1, 1, 1)

# nx.from_numpy_matrix was removed in NetworkX 3.0; from_numpy_array is the
# equivalent call. Subtracting the identity presumably strips self-loops
# (assumes get_adj() has ones on its diagonal -- TODO confirm).
G = nx.from_numpy_array(g.get_adj().numpy().astype(int) - np.eye(g.get_adj().shape[0]))

# plot atoms: colour encodes the sign of the IG score, opacity its magnitude
node_colors = []
alphas = []
for score in node_mask:
    node_colors.append('red' if score >= 0 else 'blue')
    # node_mask is not normalised in this section and matplotlib rejects
    # alpha values above 1, so cap the opacity.
    alphas.append(min(abs(score * factor), 1.0))

# Drawn once only: drawing the same collection again composited each node over
# itself and inflated the opacity of weak scores.
collection = nx.draw_networkx_nodes(G, pos, node_color=node_colors, alpha=alphas, node_size=500)
collection.set_zorder(2.)
# plot bonds (node_size=0 so only the dotted bond lines are added)
nx.draw(
    G,
    pos=pos,
    with_labels=False,
    node_color="k" if node_labels is None else "w",
    width=2,
    style="dotted",
    node_size=0
)

if node_labels is not None:
    pos_labels = pos
    nx.draw_networkx_labels(G, pos_labels, {i: name.split('_')[0] for i, name in enumerate(node_labels)}, font_size=10)
plt.axis('off')

# Own file name. The previous literal lacked the f-prefix, so "{num_walks}"
# was written verbatim, and it collided with the top-k walk figure's name.
plt.savefig("imgs/molecule_node.svg", dpi=600, format='svg', bbox_inches='tight', transparent=True)

# Graph-SST2

In [None]:
# Entry 6: Graph-SST2 with the 3-layer GCN checkpoint.
dataset, model_dir = dataset_model_dirs[6]

graphs, pos_idx, neg_idx = load_data(dataset)
# NOTE(review): torch.load unpickles the whole model object; only load trusted checkpoints.
nn = torch.load(f"models/{model_dir}")

In [None]:
graph_idx = 277
g = graphs[graph_idx]
print(graph_idx, 'pos' if g.label == 0 else 'neg')

# Only graphs with label 1 are explained here.
if g.label == 1:
    H, transforms = get_H_transform(g.get_adj(), nn, H0=g.node_features, gammas=None, mode='gamma')
    pred = nn.forward(g.get_adj(), H0=g.node_features).argmax()

    # Keep only the predicted-class column of H.
    init_rel = np.zeros_like(H)
    init_rel[:, pred] = H[:, pred]

    num_walks = 10
    # Request twice as many walks, then keep the first num_walks of the
    # concatenated (max walks first) list.
    top_max_walks, top_min_walks = topk_walks(g, nn, num_walks=num_walks * 2, lrp_mode="gamma", mode="node")
    top_walks = list(top_max_walks + top_min_walks)[:num_walks]

    plot_top_k_walks(
        g, top_walks, transforms, init_rel,
        idx=graph_idx,
        mode="node",
        factor=0.3,
        figname=f"imgs/graphsst2_topk_walks_{num_walks}.svg",
        dataset="Graph-SST2",
        width=5,
    )
    plt.show()

# graph_idx ends one past the graph shown above; later cells refer to it as graph_idx - 1.
graph_idx += 1

In [None]:
import itertools

# Score every length-4 node sequence (nbnodes**4 of them) and rank by relevance.
walk_rels = {
    tuple(walk): walk_rel(transforms, init_rel, walk, mode="node")
    for walk in tqdm(itertools.product(np.arange(g.nbnodes), repeat=4))
}
sorted_walk_rels = sorted(walk_rels.items(), key=lambda pair: pair[1], reverse=True)

plot_top_k_walks(
    g, sorted_walk_rels, transforms, init_rel,
    idx=graph_idx - 1,
    mode="node",
    factor=0.3,
    figname=f"imgs/graphsst2_topk_walks_{num_walks}.svg",
    dataset="Graph-SST2",
    width=5,
    linewidth=3,
)
plt.show()

In [None]:
# Re-draw the exhaustive top-walk figure from the previous cell (same call).
plot_top_k_walks(
    g,
    sorted_walk_rels,
    transforms,
    init_rel,
    idx=graph_idx - 1,
    dataset="Graph-SST2",
    mode="node",
    width=5,
    factor=0.3,
    linewidth=3,
    figname=f"imgs/graphsst2_topk_walks_{num_walks}.svg",
)
plt.show()

In [None]:
target = nn.forward(g.get_adj(), H0=g.node_features).argmax()

# Edge-level Integrated Gradients: one mask entry per non-zero adjacency entry.
ig = IntegratedGradients(model_edge_forward)
input_mask = torch.ones(len(g.get_adj().nonzero())).requires_grad_(True)
ig_mask = ig.attribute(input_mask, target=target, additional_forward_args=(nn, g),
                        internal_batch_size=len(input_mask))

edge_mask = ig_mask.cpu().detach().numpy()
# Scale by the largest magnitude so scores lie in [-1, 1]; skip when every
# score is zero instead of dividing by zero.
edge_scale = np.abs(edge_mask).max(initial=0.0)
if edge_scale > 0:
    edge_mask = edge_mask / edge_scale

# (edge, score) pairs, highest score first.
edges = g.get_adj().nonzero()
sorted_edges = [(edges[i].tolist(), edge_mask[i]) for i in edge_mask.argsort()[::-1]]

# Own file name: the old one was the top-k walk figure's and overwrote it.
plot_top_k_walks(g, sorted_edges, transforms, init_rel, idx=graph_idx-1, mode="node", factor=0.6,
                 figname="imgs/graphsst2_edge.svg", dataset="Graph-SST2", width=5, compute_rel=False)

In [None]:
target = nn.forward(g.get_adj(), H0=g.node_features).argmax()

# Node-level Integrated Gradients for the predicted class.
ig = IntegratedGradients(model_node_forward)
input_mask = torch.ones(g.nbnodes).requires_grad_(True)
ig_mask = ig.attribute(input_mask, target=target, additional_forward_args=(nn, g),
                        internal_batch_size=len(input_mask))

node_mask = ig_mask.cpu().detach().numpy()
# Scale by the largest |score|, not the signed max. Dividing by a negative max
# flips every colour, and a negative score larger in magnitude than the max
# would give an opacity above 1. Skip when all scores are zero.
node_scale = np.abs(node_mask).max(initial=0.0)
if node_scale > 0:
    node_mask = node_mask / node_scale

width = 5
idx = graph_idx-1
# Own file name: the walk and edge figures of this graph use other names.
figname = "imgs/graphsst2_node.svg"

factor = 1.
edge_index = np.genfromtxt("datasets/Graph-SST2/raw/Graph_SST2_edge_index.txt")
node_indicator = np.genfromtxt("datasets/Graph-SST2/raw/Graph_SST2_node_indicator.txt")

# Global ids of this graph's nodes (the indicator values are compared after subtracting 1).
nodes = np.where(node_indicator-1 == idx)[0].tolist()
node_set = set(nodes)
edges = []
found = False
for edge in edge_index:
    if edge[0] in node_set:
        found = True
        edges.append(edge.astype(int).tolist())
    elif found:
        # assumes one graph's edges are contiguous in the file: stop at the first miss after a hit
        break

# Shift to graph-local ids (assumes this graph's node ids are contiguous).
edges = (np.array(edges) - min(nodes)).tolist()

# JSON text is UTF-8 by specification; don't depend on the platform default.
with open("datasets/Graph-SST2/raw/Graph_SST2_sentence_tokens.json", "r", encoding="utf-8") as f:
    token_json = json.load(f)

tokens = token_json[str(idx)]
# Make repeated tokens unique ("<token>_<position>") so they can serve as graph
# node keys; the suffix is stripped again for the drawn labels.
for i in range(len(tokens)):
    if tokens.count(tokens[i]) != 1:
        tokens[i] = tokens[i]+"_"+str(i)

G = nx.DiGraph()
for n in range(len(g.node_features)):
    G.add_node(tokens[n])

for edge in edges:
    G.add_edge(tokens[edge[0]], tokens[edge[1]])

# Hierarchical ('dot') layout of the token graph.
pos = np.array(list(nx.drawing.nx_pydot.graphviz_layout(G, prog='dot').values()))
node_labels = tokens

fig_width = width
pos_size = pos.max(axis=0) - pos.min(axis=0)
fig_height = (width / pos_size[0]) * pos_size[1]
fig = plt.figure(figsize=(fig_width, fig_height))
ax = plt.subplot(1, 1, 1)

# nx.from_numpy_matrix was removed in NetworkX 3.0; from_numpy_array is the equivalent call.
G = nx.from_numpy_array(g.get_adj().numpy().astype(int)-np.eye(g.get_adj().shape[0]))
# plot tokens: colour encodes the sign of the IG score, opacity its magnitude
node_colors = []
alphas = []
for score in node_mask:
    node_colors.append('red' if score >= 0 else 'blue')
    # matplotlib rejects alpha above 1 (possible if `factor` is raised)
    alphas.append(min(abs(score * factor), 1.0))

collection = nx.draw_networkx_nodes(G, pos, node_color=node_colors, alpha=alphas, node_size=500)
collection.set_zorder(2.)
nx.draw_networkx_edges(G, pos)

if node_labels is not None:
    pos_labels = pos
    nx.draw_networkx_labels(G, pos_labels, {i: name.split('_')[0] for i, name in enumerate(node_labels)}, font_size=10)

plt.axis('off')

if figname is not None:
    plt.savefig(figname, dpi=600, format='svg', bbox_inches='tight', transparent=True)

# Infection

In [None]:
from torch_geometric.utils import to_dense_adj
# Use context managers so both file handles are closed after loading.
# NOTE(review): pickle.load can execute arbitrary code; only load these files
# from a trusted source.
with open("models/gcn-4-infection-sir_100_many_init_patient.torch", 'rb') as model_file:
    model = pkl.load(model_file)
with open('datasets/infection-sir/dataset_100_simulateded.pt', 'rb') as dataset_file:
    dataset = pkl.load(dataset_file)

# First 80 simulated graphs form the training split; the rest the test split.
train_set = list(dataset[:80])
test_set = list(dataset[80:])

In [None]:
# The graph explained in the infection cells below: entry 4 of the test split.
data = test_set[4]

In [None]:
# One node per row of data.x, tagged with its first feature column (as int)
# under the 'infect' attribute.
x_list = [(node, {"infect": infected}) for node, infected in enumerate(data.x[:, 0].int().tolist())]

G = nx.DiGraph()
G.add_nodes_from(x_list)
G.add_edges_from(data.edge_index.T.tolist())

In [None]:
from captum.attr import IntegratedGradients
from infection_utils import get_rel_components, approx_most_rel_walk, plot_n_hop_infect_graph, walk_to_walk_edges, walk_to_walk_edges_set, approx_top_k_walk
def edge_ig(data, model, node_idx, top_n=10):
    """Rank edges by Integrated-Gradients attribution for one node's output.

    A per-edge mask is passed to ``model`` as its third argument (presumably
    an edge weight -- TODO confirm). IG then attributes the output row of
    ``node_idx`` for class ``data.y[node_idx]`` (the ground-truth label, not
    the model's prediction) to that mask.

    Args:
        data: graph with ``x``, ``edge_index`` and ``y`` attributes.
        model: callable ``model(x, edge_index, edge_mask)``.
        node_idx: index of the node whose output is explained.
        top_n: number of edges to return.

    Returns:
        ``(edges, scores)``: the ``top_n`` highest-scoring columns of
        ``data.edge_index`` as a list of pairs, and their IG scores as a
        NumPy array, both in descending order of score.
    """
    def model_forward(edge_mask, model, node_idx, x, edge_index):
        out = model(x, edge_index, edge_mask)
        return out[[node_idx]]

    x, edge_index = data.x, data.edge_index
    target = data.y[node_idx]

    ig = IntegratedGradients(model_forward)
    input_mask = torch.ones(edge_index.shape[1]).requires_grad_(True)
    ig_mask = ig.attribute(input_mask, target=target, additional_forward_args=(model, node_idx, x, edge_index),
                            internal_batch_size=edge_index.shape[1])

    edge_mask = ig_mask.cpu().detach().numpy()
    # Sort once (descending) and reuse the same order for edges and scores.
    top = (-edge_mask).argsort()[:top_n]
    return edge_index[:, top].T.tolist(), edge_mask[top]

def node_ig(data, model, node_idx, top_n=10):
    """Rank nodes by Integrated-Gradients attribution for one node's output.

    Every row of ``data.x`` is scaled by a per-node mask. IG then attributes
    the output row of ``node_idx`` for class ``data.y[node_idx]`` (the
    ground-truth label, not the model's prediction) to that mask.

    Args:
        data: graph with ``x``, ``edge_index`` and ``y`` attributes.
        model: callable ``model(x, edge_index)``.
        node_idx: index of the node whose output is explained.
        top_n: number of nodes to return.

    Returns:
        ``(nodes, scores)``: the ``top_n`` highest-scoring node indices as a
        list, and their IG scores as a NumPy array, both in descending order.
    """
    def model_forward(node_mask, model, node_idx, x, edge_index):
        out = model(x * node_mask.reshape(-1, 1), edge_index)
        return out[[node_idx]]

    x, edge_index = data.x, data.edge_index
    target = data.y[node_idx]

    ig = IntegratedGradients(model_forward)
    input_mask = torch.ones(x.shape[0]).requires_grad_(True)
    ig_mask = ig.attribute(input_mask, target=target, additional_forward_args=(model, node_idx, x, edge_index),
                            internal_batch_size=len(input_mask))

    node_mask = ig_mask.cpu().detach().numpy()
    # Sort once (descending) and reuse the same order for indices and scores.
    top = (-node_mask).argsort()[:top_n]
    return top.tolist(), node_mask[top]

In [None]:
node_idx_to_explain = 1
# Top-10 edges and nodes by Integrated Gradients for this node.
edges_ig, edge_rels_ig = edge_ig(data, model, node_idx_to_explain, top_n=10)

nodes_ig, nodes_rel_ig = node_ig(data, model, node_idx_to_explain, top_n=10)
# Scale by the largest magnitude so scores lie in [-1, 1]. Dividing by the
# signed max flips every sign when the max is negative, and divides by zero
# when it is 0.
node_rel_scale = np.abs(nodes_rel_ig).max(initial=0.0)
if node_rel_scale > 0:
    nodes_rel_ig = nodes_rel_ig / node_rel_scale

num_walks = 25  # NOTE(review): not used by the calls below
rel_components, X_fc, A = get_rel_components(data=data, model=model, lrp_rule='gamma', gammas=torch.linspace(4, 1, 4), normalize=False)

from infection_utils import approx_top_k_walk
# Presumably the 3 most relevant walks ending at node_idx_to_explain (TODO confirm against infection_utils).
real_top_k_max_walks_rels, _ = approx_top_k_walk(3, node_idx_to_explain, data, model, rel_components, X_fc, A, verbose=False)


In [None]:
edge_rels_ig = np.array(edge_rels_ig)
# Scale by the largest magnitude (sign-preserving, and no division by zero when
# every score is 0). Re-running the cell leaves the values unchanged.
edge_rel_scale = np.abs(edge_rels_ig).max(initial=0.0)
if edge_rel_scale > 0:
    edge_rels_ig = edge_rels_ig / edge_rel_scale

In [None]:
# Every node some explanation points at: the IG top nodes, the nodes of the top
# relevant walks, and both endpoints of each IG top edge (deduplicated).
candidate_nodes = list(nodes_ig)
for walk in real_top_k_max_walks_rels:
    candidate_nodes.extend(walk[0])
for edge in edges_ig:
    candidate_nodes.extend(edge)
important_nodes = list(set(candidate_nodes))

## Plot the explanations

In [None]:
hop = 4

# Nodes within `hop` incoming steps of the explained node. Each node gets a
# 'layer' attribute on G: the first round in which it was part of the reached
# set. Nodes first reached in the last round get layer `hop`.
reached = {node_idx_to_explain}
for i in range(hop):
    for n in list(reached):  # snapshot: nodes added in this round are tagged next round
        if 'layer' not in G.nodes[n]:
            G.nodes[n]['layer'] = i
        reached.update(G.predecessors(n))
for n in reached:
    if 'layer' not in G.nodes[n]:
        G.nodes[n]['layer'] = hop
pr_nodes = list(reached)

subgraph = G.subgraph(pr_nodes)

topk_to_plot = 3
# NOTE(review): `pos_seed` is not assigned anywhere in this notebook, so this
# cell only ran thanks to leftover kernel state. Reuse a seed the kernel
# already has; otherwise fall back to 0. TODO: pin the seed used for the
# published figure.
if 'pos_seed' not in globals():
    pos_seed = 0
pos = nx.spring_layout(subgraph, seed=pos_seed, k=7/np.sqrt(G.order()))

# manually set positions
pos[700] = np.array([0.3, -0.3])
pos[767] = np.array([0.3, 0.])
pos[631] = np.array([0.3, 0.3])
pos[1] = np.array([0., 0.6])
pos[621] = np.array([-0.3, -0.3])
pos[295] = np.array([-0.3, 0.])
pos[922] = np.array([-0.2, 0.3])
pos[206] = np.array([-0., 0.3])

pos[256] = np.array([-0.6, -0.3])
pos[754] = np.array([-0.6, 0.])
pos[314] = np.array([-0.6, 0.3])

# Nodes on the first 4 edges of each of the top `topk_to_plot` walks
# (assumes each walk has at least 5 nodes).
real_important_nodes = list(set(np.array([[real_top_k_max_walks_rels[rank][0][i:i+2] for i in range(4)] for rank in range(topk_to_plot)]).flatten()))

# Keep peripheral nodes inside [-1, 1] so they do not stretch the figure.
for key in pos.keys():
    if key not in important_nodes:
        pos[key] = np.clip(pos[key], a_min=-1, a_max=1)

# Important nodes whose first feature is 1 (initially infected). Computed here,
# before its first use below; previously it was referenced one draw call
# before being defined.
ini_infect_nodes = [idx for idx in important_nodes if data.x[idx][0] == 1]

nx.draw_networkx_nodes(subgraph, pos,
                       node_color='lightgrey',
                       alpha=[1 if data.x[idx][0] == 1 else 0.1 for idx in subgraph.nodes], node_size=10)
nx.draw_networkx_nodes(set(important_nodes)-set([node_idx_to_explain]+real_important_nodes), pos,
                       node_color='lightgrey',
                       alpha=0.3, node_size=100, edgecolors='k')
nx.draw_networkx_nodes(set(real_important_nodes)-set([node_idx_to_explain]+ini_infect_nodes), pos,
                       node_color='lightgrey',
                       alpha=1, node_size=100, edgecolors='k', linewidths=2)
nx.draw_networkx_nodes(ini_infect_nodes, pos,
                       node_color='lightgrey',
                       alpha=1, node_size=100, edgecolors='k', node_shape='s', linewidths=2)
nx.draw_networkx_nodes([node_idx_to_explain], pos,
                       node_color='r' if data.x[node_idx_to_explain][0] == 1 else 'lightgrey',
                       node_size=200, edgecolors='r' if data.x[node_idx_to_explain][0] == 1 else 'k', node_shape='*')
labels = {idx: (idx if idx in important_nodes else '') for idx in subgraph.nodes}
nx.draw_networkx_labels(subgraph, pos, labels=labels, font_size=5, font_color='k')

nx.draw_networkx_edges(subgraph, pos, width=0.5, edgelist=subgraph.edges, edge_color='k', alpha=0.05, arrows=True)
# Top-3 walks: rank 0 solid, rank 1 dashed, rank 2 dotted.
for rank, (edge_alpha, edge_style) in enumerate([(0.9, 'solid'), (0.6, 'dashed'), (0.6, 'dotted')]):
    walk_nodes = real_top_k_max_walks_rels[rank][0]
    nx.draw_networkx_edges(subgraph, pos, edgelist=np.array([walk_nodes[i:i+2] for i in range(4)]), edge_color='r',
                           alpha=edge_alpha, arrows=True, style=edge_style, arrowsize=5, width=2)
plt.axis('off')
plt.savefig("imgs/infection_ours.svg", dpi=600, format='svg', bbox_inches='tight', transparent=True)

In [None]:
hop = 4

# Nodes within `hop` incoming steps of the explained node, each tagged with a
# 'layer' attribute on G: the first round in which it was in the reached set.
reached = {node_idx_to_explain}
for i in range(hop):
    for n in list(reached):  # snapshot: nodes added in this round are tagged next round
        if 'layer' not in G.nodes[n]:
            G.nodes[n]['layer'] = i
        reached.update(G.predecessors(n))
for n in reached:
    if 'layer' not in G.nodes[n]:
        G.nodes[n]['layer'] = hop
pr_nodes = list(reached)

subgraph = G.subgraph(pr_nodes)

topk_to_plot = 10

# Keep peripheral nodes inside [-1, 1] so they do not stretch the figure.
for key in pos.keys():
    if key not in important_nodes:
        pos[key] = np.clip(pos[key], a_min=-1, a_max=1)

nx.draw_networkx_nodes(subgraph, pos,
                       node_color='lightgrey',
                       alpha=[1 if data.x[idx][0] == 1 else 0.1 for idx in subgraph.nodes], node_size=10)
nx.draw_networkx_nodes(set(important_nodes)-set([node_idx_to_explain]+real_important_nodes), pos,
                       node_color='lightgrey',
                       alpha=0.3, node_size=110, edgecolors='k')
# Important nodes whose first feature is 1 (initially infected).
ini_infect_nodes = [idx for idx in important_nodes if data.x[idx][0] == 1]
nx.draw_networkx_nodes(set(real_important_nodes)-set([node_idx_to_explain]+ini_infect_nodes), pos,
                       node_color='lightgrey',
                       alpha=1, node_size=110, edgecolors='k', linewidths=2)

nx.draw_networkx_nodes(ini_infect_nodes, pos,
                       node_color='lightgrey',
                       alpha=1, node_size=110, edgecolors='k', node_shape='s', linewidths=2)
nx.draw_networkx_nodes([node_idx_to_explain], pos,
                       node_color='r' if data.x[node_idx_to_explain][0] == 1 else 'lightgrey',
                       node_size=200, edgecolors='r' if data.x[node_idx_to_explain][0] == 1 else 'k', node_shape='*')

labels = {idx: (idx if idx in important_nodes else '') for idx in subgraph.nodes}
nx.draw_networkx_labels(subgraph, pos, labels=labels, font_size=5, font_color='k')

nx.draw_networkx_edges(subgraph, pos, width=0.5, edgelist=subgraph.edges, edge_color='k', alpha=0.05, arrows=True)
for edge, rel in zip(edges_ig[:topk_to_plot], edge_rels_ig[:topk_to_plot]):
    # The IG score is used as opacity. matplotlib raises for alpha outside
    # [0, 1] (negative or unnormalised scores), so clip it.
    nx.draw_networkx_edges(subgraph, pos, edgelist=np.array([edge]), edge_color='r',
                           alpha=float(np.clip(rel, 0.0, 1.0)), arrows=True)
plt.axis('off')
plt.savefig("imgs/infection_edge.svg", dpi=600, format='svg', bbox_inches='tight', transparent=True)

In [None]:
hop = 4

# Nodes within `hop` incoming steps of the explained node, each tagged with a
# 'layer' attribute on G: the first round in which it was in the reached set.
reached = {node_idx_to_explain}
for i in range(hop):
    for n in list(reached):  # snapshot: nodes added in this round are tagged next round
        if 'layer' not in G.nodes[n]:
            G.nodes[n]['layer'] = i
        reached.update(G.predecessors(n))
for n in reached:
    if 'layer' not in G.nodes[n]:
        G.nodes[n]['layer'] = hop
pr_nodes = list(reached)

subgraph = G.subgraph(pr_nodes)

topk_to_plot = 3

# Keep peripheral nodes inside [-1, 1] so they do not stretch the figure.
for key in pos.keys():
    if key not in important_nodes:
        pos[key] = np.clip(pos[key], a_min=-1, a_max=1)

nx.draw_networkx_nodes(subgraph, pos,
                       node_color='lightgrey',
                       alpha=[1 if data.x[idx][0] == 1 else 0.1 for idx in subgraph.nodes], node_size=10)
# Important nodes whose first feature is 1 (initially infected).
ini_infect_nodes = [idx for idx in important_nodes if data.x[idx][0] == 1]
nx.draw_networkx_nodes(set(real_important_nodes)-set([node_idx_to_explain]+ini_infect_nodes), pos,
                       node_color='lightgrey',
                       alpha=1, node_size=110, edgecolors='k', linewidths=2)

nx.draw_networkx_nodes([node_idx_to_explain], pos,
                       node_color='r' if data.x[node_idx_to_explain][0] == 1 else 'lightgrey',
                       node_size=200, edgecolors='r' if data.x[node_idx_to_explain][0] == 1 else 'k', node_shape='*')

nx.draw_networkx_nodes(set(important_nodes)-set([node_idx_to_explain]+real_important_nodes), pos,
                       node_color='lightgrey',
                       alpha=0.3, node_size=110, edgecolors='k')

# IG node scores. Build each node list and its alphas together so alpha[k]
# belongs to node k. Previously a *set* of nodes was paired with alphas in
# nodes_ig order (and of a different length), which NetworkX cycles, so
# opacities landed on the wrong nodes. Clip because matplotlib rejects alpha
# outside [0, 1].
ig_rel_by_node = dict(zip(nodes_ig, nodes_rel_ig))
ig_nodes_plain = [n for n in nodes_ig if n not in ini_infect_nodes]
ig_nodes_infected = [n for n in nodes_ig if n in ini_infect_nodes]
nx.draw_networkx_nodes(ig_nodes_plain, pos,
                       node_color='red',
                       alpha=[float(np.clip(ig_rel_by_node[n]*0.7 + 0.3, 0.0, 1.0)) for n in ig_nodes_plain],
                       node_size=110, edgecolors='k')
nx.draw_networkx_nodes(ig_nodes_infected, pos,
                       node_color='red',
                       alpha=[float(np.clip(ig_rel_by_node[n]*0.7 + 0.3, 0.0, 1.0)) for n in ig_nodes_infected],
                       node_size=110, edgecolors='k', node_shape='s')

label_nodes = set(important_nodes) | {node_idx_to_explain}
labels = {idx: (idx if idx in label_nodes else '') for idx in subgraph.nodes}
nx.draw_networkx_labels(subgraph, pos, labels=labels, font_size=5, font_color='k')

nx.draw_networkx_edges(subgraph, pos, width=0.5, edgelist=subgraph.edges, edge_color='k', alpha=0.05, arrows=True)

plt.axis('off')
plt.savefig("imgs/infection_node.svg", dpi=600, format='svg', bbox_inches='tight', transparent=True)