In [None]:
import sys

sys.path.append("/home/carlos/Desktop/projects/diff-gnn")
import argparse

import models
import utils
import train
import explain
import json

import random

import numpy as np
import seaborn as sns
import torch
import networkx as nx
import seaborn as sns
import matplotlib.pyplot as plt

from torch_geometric.utils import to_networkx
from torch_geometric.data import Batch
import torch_geometric.nn as pyg_nn

from torch_geometric.explain import Explainer, GNNExplainer

from skimage.filters import threshold_otsu, threshold_li

In [None]:
import importlib

# Re-import the project modules so edits to their source files are picked up
# without restarting the kernel; `models` is reloaded last so its module repr
# is what this cell displays.
for _project_module in (utils, train):
    importlib.reload(_project_module)
importlib.reload(models)

In [None]:
# build model
# Checkpoint to restore; its training args are stored next to it as JSON.
ckpt_path = "/home/carlos/Desktop/projects/diff-gnn/checkpoints/tag_0.ckpt"
# `with` closes the args file (json.load(open(...)) leaked the handle).
with open(ckpt_path + ".args.json", "r") as args_file:
    json_dict = json.load(args_file)
args = argparse.Namespace(**json_dict)
# Inference-time overrides of the saved training args.
args.model_path = ckpt_path
args.test = True
args.dropout = 0.0

model = train.build_model(args)
device = utils.get_device()
# NOTE(review): `model.eval()` is not called here (it was commented out).
# Dropout is zeroed via args.dropout, but if build_model returns the model in
# training mode, any other mode-dependent layers stay in training mode —
# confirm this is intended.

In [None]:
# Load the t0 / t12 chr1 arrays. `npz` is intentionally left open because the
# next cell reads from it as well.
hela_npz_path = (
    "/home/carlos/Desktop/projects/diff-gnn/datasets/HeLa_10000.t0-t12.chr1.obs_exp_qt.npz"
)
npz = np.load(hela_npz_path)
# The t0 call also returns `idx`, which is passed as idx_list so the t12 graphs
# are built for the same positions (presumably — TODO confirm in utils).
graphs_a_0, idx = utils.constGraphList(npz["t0_q30-chr1_p"], 31, 51, use_prcnt=1)
graphs_b_12 = utils.constGraphList(
    npz["t12_q30-chr1_p"], 31, 51, idx_list=idx, maxNodesQ=41
)

In [None]:
# Reverse direction: t12 graphs for the same `idx` positions, and t0 graphs
# built with maxNodesQ=41 (mirroring graphs_b_12 above).
# NOTE(review): use_prcnt=0.2 here vs use_prcnt=1 for graphs_a_0; presumably
# ignored when idx_list is given — confirm, since these lists are zipped
# position-wise with graphs_a_0 / graphs_b_12 later.
graphs_a_12 = utils.constGraphList(
    npz["t12_q30-chr1_p"], 31, 51, idx_list=idx, use_prcnt=0.2
)
graphs_b_0 = utils.constGraphList(
    npz["t0_q30-chr1_p"], 31, 51, idx_list=idx, maxNodesQ=41
)

In [None]:
def run_model_workflow(model, a, b):
    """Score graph ``a`` against graph ``b`` through the full model pipeline.

    Side effects (relied upon by later cells): both graphs are sent to
    ``utils.get_device()`` — the ``.to`` result is discarded, so this relies on
    ``.to`` moving the graph's tensors in place — and their ``x`` attribute is
    replaced by a float32 copy.

    Args:
        model: object exposing ``emb_model``, ``__call__``, ``predict`` and
            ``clf_model`` as used below.
        a, b: graph objects with ``x``, ``edge_index``, ``edge_attr`` and
            ``batch`` attributes.

    Returns:
        Tuple of (argmax over the last dim of the classifier output,
        embedding of ``a`` produced by ``model.emb_model``).
    """
    for graph in (a, b):
        graph.to(utils.get_device())
    for graph in (a, b):
        graph.x = graph.x.type(torch.float32)

    emb_a, emb_b = (
        model.emb_model(g.x, g.edge_index, g.edge_attr, g.batch) for g in (a, b)
    )
    scores = model.predict(model(emb_a, emb_b))
    logits = model.clf_model(scores.unsqueeze(1))
    return logits.argmax(dim=-1), emb_a

In [None]:
# Shuffle the four position-aligned graph lists together (keeping them
# aligned), then stop at the first sample that the model predicts as class 1
# in BOTH directions: (a_0 vs b_12) and (a_12 vs b_0).
# Set SHUFFLE_SEED to an int to make the selected example reproducible;
# None seeds from system randomness, like the previous global random.shuffle.
SHUFFLE_SEED = None
zipped = list(zip(graphs_a_0, graphs_b_12, graphs_a_12, graphs_b_0))
random.Random(SHUFFLE_SEED).shuffle(zipped)
graphs_a_0, graphs_b_12, graphs_a_12, graphs_b_0 = zip(*zipped)

with torch.no_grad():
    for a_0, b_12, a_12, b_0 in zip(graphs_a_0, graphs_b_12, graphs_a_12, graphs_b_0):
        # Skip samples where any of the other graphs is missing
        # (presumably constGraphList yields None for unusable positions).
        if b_12 is None or a_12 is None or b_0 is None:
            continue
        pred_0_12, emb_a_0 = run_model_workflow(model, a_0, b_12)
        pred_12_0, emb_a_12 = run_model_workflow(model, a_12, b_0)
        if pred_0_12 == 1 and pred_12_0 == 1:
            print("Found an example!")
            break
    else:
        # Without a match, the loop variables would hold whatever sample came
        # last (and emb_a_0 could belong to an earlier sample than a_0 if the
        # last one was skipped), so downstream cells would silently analyse
        # the wrong thing. Fail loudly instead.
        raise RuntimeError("No sample was predicted as class 1 in both directions.")

In [None]:
# Class to explain: the example above was selected for predicted class 1.
target = torch.tensor([1]).to(utils.get_device())

# Which direction of the selected example to explain (replaces the
# commented-out alternative assignments):
#   False -> graph a = a_0 (t0), graph b = b_12 (t12)
#   True  -> graph a = a_12 (t12), graph b = b_0 (t0)
EXPLAIN_REVERSE = False
if EXPLAIN_REVERSE:
    emb_a, a, b = emb_a_12, a_12, b_0
else:
    emb_a, a, b = emb_a_0, a_0, b_12

In [None]:
# Put b on the model's device with float32 node features. The cast used to
# happen only as a side effect of run_model_workflow; doing it here keeps the
# cell correct on its own (it is a no-op if b is already float32).
b = b.to(utils.get_device())
b.x = b.x.type(torch.float32)

alg = models.explainer(epochs=1000, lr=0.01)
explainer = Explainer(
    # Wraps the model around the fixed graph b and the target embedding emb_a
    # (presumably so b's features/edges can be masked and scored — TODO confirm).
    model=models.model2explainer(model, b, target_emb=emb_a),
    algorithm=alg,
    explanation_type="phenomenon",
    node_mask_type="attributes",
    edge_mask_type="object",
    model_config=dict(
        mode="multiclass_classification",
        task_level="edge",
        return_type="raw",
    ),
)

explanation = explainer(b.x, b.edge_index, target=target)

In [None]:
# Plot per-feature importance derived from the explanation's node mask
# (node_mask_type="attributes" gives one mask value per node feature).
explanation.visualize_feature_importance()

In [None]:
# Learned edge mask as a NumPy array; it is reused below by
# explain.apply_threshold, so inspect its distribution before choosing `perc`.
edge_mask = explanation.get("edge_mask").detach().cpu().numpy()
ax_mask = sns.histplot(edge_mask.flatten())
ax_mask.set(title="Explainer edge-mask values", xlabel="edge mask value", ylabel="count");

In [None]:
# Loss history recorded by the explainer algorithm
# (presumably one entry per optimisation epoch — TODO confirm in models.explainer).
losses = alg.losses
fig_loss, ax_loss = plt.subplots()
ax_loss.plot(losses)
ax_loss.set(title="Explainer optimisation loss", xlabel="step", ylabel="loss");

In [None]:
# Re-import the project `explain` module so edits to its source are picked up
# without restarting the kernel.
importlib.reload(explain)

In [None]:
# Parameters for the sub-graph analysis cells below.
margin = 0.25   # forwarded to explain.viz_attrs(margin=...)
anchor = 25     # centre node of the ego graph; also forwarded to viz_attrs
ego_radius = 2  # radius passed to nx.ego_graph around `anchor`
maxT = 51       # second argument passed to utils.relabel_nodes
perc = 75       # passed to explain.apply_threshold (presumably a percentile)

In [None]:
# Keep the edges of b selected by the edge mask, restrict to the ego network
# around `anchor`, and take the node-induced sub-graph of a on the same nodes.
b = b.to("cpu")
b_sub = explain.apply_threshold(b, edge_mask, perc=perc)
b_sub = utils.relabel_nodes(b_sub, maxT)
if anchor not in b_sub:
    raise ValueError(
        f"anchor node {anchor} is not in the thresholded graph; "
        "lower `perc` or choose another anchor"
    )
b_sub = nx.ego_graph(b_sub, anchor, radius=ego_radius).copy()

# Move a to CPU like b (a may still be on the model device).
a = a.to("cpu")
a_sub = to_networkx(a, to_undirected=True, node_attrs=["x"], edge_attrs=["edge_attr"])
# Same relabelling constant as b_sub (was a hard-coded 51 that could drift
# from maxT).
a_sub = utils.relabel_nodes(a_sub, maxT)
a_sub = a_sub.subgraph(b_sub.nodes).copy()

fig, (ax1, ax2) = plt.subplots(2, figsize=(24, 16))

explain.visualize_edges(b_sub, ax=ax1)
explain.visualize_edges(a_sub, ax=ax2)

In [None]:
# Node-level and edge-level comparison of the two ego sub-graphs via the
# project's explain helpers; indices [1] and [2] of the edge comparison are
# inspected in the following cells (their meaning — TODO confirm in explain).
sub_nodes, sub_edges = explain.node_diff(a_sub, b_sub), explain.edge_comp(a_sub, b_sub)
sub_nodes

In [None]:
# Summarise element 1 of the sub-graph edge comparison: print each entry of the
# returned mapping and plot the distribution of the returned values.
vals, res = explain.viz_attrs(sub_edges[1], margin=margin, anchor=anchor)
for entry in res.items():
    print(*entry)
sns.histplot(vals)

In [None]:
# Summarise element 2 of the sub-graph edge comparison: print each entry of the
# returned mapping and plot the distribution of the returned values.
vals, res = explain.viz_attrs(sub_edges[2], margin=margin, anchor=anchor)
for entry in res.items():
    print(*entry)
sns.histplot(vals)

In [None]:
# all to all: compare the complete graphs a and b (no threshold / ego cut).
# NOTE(review): these are converted with to_undirected=False, whereas a_sub
# above used to_undirected=True — confirm the difference is intended.
G_a, G_b = (
    utils.relabel_nodes(
        to_networkx(g, to_undirected=False, node_attrs=["x"], edge_attrs=["edge_attr"]),
        maxT,
    )
    for g in (a, b)
)

all_nodes = explain.node_diff(G_a, G_b)
all_edges = explain.edge_comp(G_a, G_b)
all_nodes

In [None]:
# Summarise element 1 of the full-graph edge comparison: print each entry of
# the returned mapping and plot the distribution of the returned values.
vals, res = explain.viz_attrs(all_edges[1], margin=margin, anchor=anchor)
for entry in res.items():
    print(*entry)
sns.histplot(vals)

In [None]:
# Summarise element 2 of the full-graph edge comparison: print each entry of
# the returned mapping and plot the distribution of the returned values.
vals, res = explain.viz_attrs(all_edges[2], margin=margin, anchor=anchor)
for entry in res.items():
    print(*entry)
sns.histplot(vals)

In [None]:
import data

# Re-read the checkpoint's saved args for the data pipeline. This rebinds
# `args` to a fresh Namespace, so the inference overrides applied when the
# model was built (model_path, test, dropout) are not carried over.
# `with` closes the file (json.load(open(...)) leaked the handle).
with open(
    "/home/carlos/Desktop/projects/diff-gnn/checkpoints/tag_0.ckpt.args.json",
    "r",
) as args_file:
    json_dict = json.load(args_file)
args = argparse.Namespace(**json_dict)

In [60]:
import importlib
# Re-import the project `data` module so edits to data.py are picked up
# without restarting the kernel.
importlib.reload(data)

<module 'data' from '/home/carlos/Desktop/projects/diff-gnn/data.py'>

In [61]:
# Build the project's data source from the checkpoint args loaded above
# (its constructor prints "Loading dataset..." and a progress bar).
data_source = data.DataSource(args)

Loading dataset...


100%|██████████| 1000/1000 [00:02<00:00, 494.65it/s]


In [62]:
# Generate a training batch of 4 * 4096 samples with par=True
# (presumably parallel generation — TODO confirm in data.DataSource).
batch = data_source.gen_batch(4 * 4096, par=True, train=True)

In [63]:
# Same batch request with par=False; this overwrites `batch` from the previous
# cell (presumably to compare against the par=True path — TODO confirm).
batch = data_source.gen_batch(4 * 4096, par=False, train=True)