In [None]:
# import os
# os.environ['CUDA_VISIBLE_DEVICES'] = ''
# `os` and `np` are used by later cells; import them explicitly instead of relying
# on the star imports below to provide them. Order is left untouched because the
# star imports are order-sensitive.
import os
import numpy as np
from torch import nn
import torch
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from syn_dataset import SynGraphDataset
from spmotif_dataset import *
import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINConv, global_mean_pool, global_max_pool, global_add_pool
from utils import *
from sklearn.model_selection import train_test_split
import shutil
import glob
from torch.optim.lr_scheduler import ReduceLROnPlateau
import pandas as pd
import argparse
import pickle
import json
import io
from model import GIN
from train_baseline import test_epoch


In [None]:
# Dataset whose trained baseline is looked up under `results/<dataset_name>/`.
dataset_name = 'Ba2MotifsNoisy'
# Name of the per-seed sub-directory (inside the best run directory) to load from.
seed = 5
def get_best_baseline_path(dataset_name):
    """Return the run directory of the best `nogumbel=True` baseline of a dataset.

    Candidate runs are the `results/<dataset_name>/*/results.json` files whose
    path contains `nogumbel=True`. The winner has the highest `val_acc_mean`;
    ties are broken by the lowest `val_acc_std`, then the lowest `test_acc_std`.

    Args:
        dataset_name: Name of the dataset sub-directory under `results/`.

    Returns:
        The directory containing the winning `results.json`, or `None` when no
        candidate run exists.
    """
    # Filter first so the emptiness check covers exactly the runs that can win
    # (previously an unfiltered, non-empty listing led to an IndexError).
    result_files = [
        path
        for path in glob.glob(f'results/{dataset_name}/*/results.json')
        if 'nogumbel=True' in path
    ]
    if not result_files:
        return None

    records = []
    for path in result_files:
        # Context manager so the file handle is closed deterministically.
        with open(path) as handle:
            records.append(json.load(handle))

    df = pd.DataFrame(records)
    df['fname'] = result_files
    # Ascending mean / descending stds: the best run ends up in the last row.
    df = df.sort_values(
        by=['val_acc_mean', 'val_acc_std', 'test_acc_std'],
        ascending=[True, False, False],
    )
    return os.path.dirname(df.iloc[-1]['fname'])


# Resolve the per-seed directory of the best baseline run. Fail with an explicit
# message when no run was found, rather than the TypeError that
# os.path.join(None, ...) would raise.
best_baseline_path = get_best_baseline_path(dataset_name)
if best_baseline_path is None:
    raise FileNotFoundError(
        f"No baseline results found under 'results/{dataset_name}/'."
    )
results_path = os.path.join(best_baseline_path, str(seed))

In [None]:
import pickle
# Load the pickled run data (it holds the train/val/test indices used below).
# NOTE(review): pickle can execute arbitrary code on load; only open files
# produced by trusted runs.
with open(os.path.join(results_path, 'data.pkl'), 'rb') as data_file:
    data = pickle.load(data_file)

In [None]:
# Load the arguments stored with the run (`num_layers` and `hidden_dim` are read
# from it below); the file is closed as soon as it has been parsed.
with open(os.path.join(results_path, 'args.json'), 'r') as args_file:
    args = json.load(args_file)
args

In [None]:
# Use the GPU when PyTorch can see one, otherwise the CPU.
# device = torch.device('cpu')
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(device)

In [None]:
# Load the dataset and collect the sizes needed to rebuild the network:
# class/feature counts come from the dataset, depth/width from the stored args.
dataset = get_dataset(dataset_name)
num_classes, num_features = dataset.num_classes, dataset.num_features
num_layers, hidden_dim = args['num_layers'], args['hidden_dim']

In [None]:
# Instantiate the GIN with the sizes gathered above; `nogumbel` and `dropout`
# are fixed here rather than read from `args`.
gin_kwargs = dict(
    num_classes=num_classes,
    num_features=num_features,
    num_layers=num_layers,
    hidden_dim=hidden_dim,
    nogumbel=True,
    dropout=0.1,
)
model = GIN(**gin_kwargs)

In [None]:
# Restore the weights saved as `best.pt` in the run directory, then move the
# model to the selected device.
checkpoint_path = os.path.join(results_path, 'best.pt')
state_dict = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state_dict)
model = model.to(device)

In [None]:
# Re-use the exact split stored with the baseline run.
train_indices, val_indices, test_indices = (
    data['train_indices'],
    data['val_indices'],
    data['test_indices'],
)

train_dataset, val_dataset, test_dataset = (
    dataset[train_indices],
    dataset[val_indices],
    dataset[test_indices],
)

# No shuffling anywhere, so predictions come back in index order.
train_loader = DataLoader(train_dataset, shuffle=False, batch_size=32)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=64)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=64)

In [None]:
# Width of the one-hot node-attribute vectors built for computation trees.
NUM_NODE_FEATURES = dataset.num_node_features
# Pinned to 1 instead of being read from the dataset (see the trailing comment).
NUM_EDGE_FEATURES = 1#dataset.num_edge_features
# Datasets that get labelled drawings in the visualisation cell.
TU_DATASETS = ["MUTAG", "Mutagenicity", "NCI1"]

In [None]:
"""Helper functions"""
from tqdm import tqdm
import torch_geometric as pyg
import torch
import networkx as nx
from sympy.logic.boolalg import Or, And, Not
from sympy.parsing.sympy_parser import parse_expr

from pygcanl import canonical

def preprocess_dataset(dataset):
    """Convert graphs to the integer-labelled form handed to `canonical`.

    For every graph, node features are reduced to their arg-max column
    (assumes `x` is one-hot, nodes x classes - TODO confirm), a node `id` vector
    `0..num_nodes-1` is attached, and edge attributes are reduced the same way.
    Graphs without usable edge attributes get a vector of ones, one per edge.

    Args:
        dataset: Iterable of PyG graphs with `x`, `edge_index`, `y` and
            optionally `edge_attr`.

    Returns:
        A list of `pyg.data.Data` objects, one per input graph.
    """
    dataset_ = []
    for graph in dataset:
        raw_edge_attr = getattr(graph, "edge_attr", None)
        # Explicit check instead of a bare `except:`: arg-max over dim 1 needs a
        # tensor with at least two dimensions and a non-empty second one.
        has_edge_labels = (
            raw_edge_attr is not None
            and raw_edge_attr.dim() >= 2
            and raw_edge_attr.size(1) > 0
        )
        if has_edge_labels:
            edge_attr = raw_edge_attr.argmax(dim=1)
        else:
            edge_attr = torch.ones(graph.edge_index.shape[1])
        data = pyg.data.Data(
            x=graph.x.argmax(dim=1),
            id=torch.arange(graph.num_nodes),
            edge_index=graph.edge_index,
            y=graph.y,
            edge_attr=edge_attr,
        )
        dataset_.append(data)
    return dataset_

def graph_from_dfs_code(dfs_code):
    """Rebuild a tree from its space-separated DFS code.

    The first token is the root's integer label. After that, each
    `<edge label> <node label>` pair adds a child below the current node and
    descends into it, while a `$` token moves back up to the parent. Nodes are
    numbered in order of creation, starting at 0 for the root.

    Returns:
        The tree as an `nx.DiGraph` with every edge reversed (edges are created
        parent -> child and flipped before returning); labels are stored under
        the `attr` key of nodes and edges.
    """
    tokens = dfs_code.split(" ")
    tree = nx.DiGraph()
    parent_of = {-1: -2}
    current = -1
    next_id = 0
    awaiting_root = True
    pos = 0
    while pos < len(tokens):
        token = tokens[pos]
        if awaiting_root:
            tree.add_node(next_id, attr=int(token))
            current = next_id
            awaiting_root = False
            # The root is recorded as its own parent.
            parent_of[next_id] = next_id
            next_id += 1
            pos += 1
            continue
        if token == '$':
            current = parent_of[current]
            pos += 1
            continue
        node_label = tokens[pos + 1]
        parent_of[next_id] = current
        tree.add_node(next_id, attr=int(node_label))
        tree.add_edge(current, next_id, attr=int(token))
        current = next_id
        next_id += 1
        pos += 2
    return tree.reverse()

def int_to_onehot(attr: int, num_features: int):
    """Return a list of length `num_features` that is 0 everywhere except for a
    `1.` at position `attr`."""
    encoding = [0] * num_features
    encoding[attr] = 1.
    return encoding

def nx_to_pyg(
        ctree: nx.DiGraph,
        num_node_features: int,
        num_edge_features: int
) -> pyg.data.Data:
    """Convert a computation tree to a PyG `Data` object.

    Integer node labels (`attr`) are expanded to one-hot vectors of width
    `num_node_features`. Edge labels are passed through unchanged, so
    `num_edge_features` is currently unused.

    Args:
        ctree: Tree with integer `attr` on every node (and edge, if any).
        num_node_features: Length of the one-hot node vectors.
        num_edge_features: Accepted for interface compatibility; not used.

    Returns:
        The graph built by `pyg.utils.from_networkx`.
    """
    # Encode on a copy: the caller's graph keeps its integer labels, so calling
    # this function twice on the same tree no longer breaks.
    ctree = ctree.copy()
    for node in ctree.nodes:
        ctree.nodes[node]["attr"] = int_to_onehot(
            ctree.nodes[node]["attr"], num_node_features
        )

    # Only group edge attributes when the tree actually has edges.
    group_edge_attrs = ["attr"] if len(ctree.edges) > 0 else None
    graph_pyg = pyg.utils.from_networkx(
        ctree, group_node_attrs=["attr"], group_edge_attrs=group_edge_attrs
    )

    return graph_pyg

def simplify_expression(str_exp):
    """Parse the formula string with `parse_expr` and return its simplified form."""
    return parse_expr(str_exp).simplify()

def getVariables(expr):
    """Return the result of `expr.atoms()` for the given expression."""
    atoms = expr.atoms()
    return atoms

def dfs(ctree, ctree_id, node_mapping=None):
    """
    Merge a computation tree and its node-id twin into one undirected graph.

    NOTE: the original author marked this procedure as "Incorrect".
    Rationale given: the id code is written down while the canonical label of
    `ctree` is generated, so both trees list their nodes in the same order and
    node i of `ctree` supplies the attribute of node i of `ctree_id`.

    Nodes of the result are keyed by the ids stored in `ctree_id`; their `attr`
    is the attribute of the matching `ctree` node, translated through
    `node_mapping` when one is given. Edges copy the `attr` of `ctree_id` edges.
    """
    merged = nx.Graph()
    for idx in range(len(ctree_id.nodes)):
        label = (
            node_mapping[ctree.nodes[idx]['attr']]
            if node_mapping is not None
            else ctree.nodes[idx]['attr']
        )
        merged.add_node(ctree_id.nodes[idx]['attr'], attr=label)
    for edge in ctree_id.edges:
        tail, head = edge
        merged.add_edge(
            ctree_id.nodes[tail]['attr'],
            ctree_id.nodes[head]['attr'],
            attr=ctree_id.edges[edge]['attr'],
        )
    return merged


In [None]:
# * ----- Indicator vectors of training graphs
# start_time = process_time()

processed_dataset = preprocess_dataset(
    [dataset[i] for i in train_indices])
# NOTE(review): the second argument is hard-coded to 3 - presumably the tree
# depth; confirm it should not follow `num_layers`.
list_of_dfs_id_codes = canonical(processed_dataset, 3)

# dfs code -> id code, over all training graphs (later graphs overwrite earlier).
dict_dfs_id_codes = {}
# Per graph: the dfs codes in the order returned by `canonical`.
all_ctree_codes = []
# Per graph: {dfs code: number of occurrences}.
graph_cnt_lst = []

# Backslash written as an explicit escape: '\ ' is an invalid escape sequence.
str1 = 'dfs_code \\ id_code'

for l in list_of_dfs_id_codes:
    temp = []
    d = {}
    for s in l:
        # Each entry has the form "<dfs code>/<id code>".
        parts = s.split('/')
        key, val = parts[0], parts[1]
        temp.append(key)
        dict_dfs_id_codes[key] = val
        d[key] = d.get(key, 0) + 1
    graph_cnt_lst.append(d)
    all_ctree_codes.append(temp)

In [None]:
# Calculated in train, used directly in val, test.
# The column order of every count matrix follows this list.
unique_ctree_codes = list(dict_dfs_id_codes)

# with open(f"{FOLDER}/dict_dfs_id_codes.pkl", "wb") as file:
#     dump(dict_dfs_id_codes, file)
# with open(f"{FOLDER}/unique_ctree_codes.pkl", "wb") as file:
#     dump(unique_ctree_codes, file)

# One row per training graph, one column per known tree; entries are counts.
cnt_ind_vec = np.array([
    [tree_counts.get(code, 0) for code in unique_ctree_codes]
    for tree_counts in graph_cnt_lst
])

In [None]:
# * ----- Indicator vectors of validation graphs
val_graphs = [dataset[i] for i in val_indices]
processed_dataset_val = preprocess_dataset(val_graphs)
list_of_dfs_id_codes_val = canonical(processed_dataset_val, 3)

dict_dfs_id_codes_val = {}
all_ctree_codes_val = []
graph_cnt_lst_val = []

for graph_entries in list_of_dfs_id_codes_val:
    codes_in_graph = []
    tree_counts = {}
    for entry in graph_entries:
        # "<dfs code>/<id code>"
        pieces = entry.split('/')
        dfs_part, id_part = pieces[0], pieces[1]
        codes_in_graph.append(dfs_part)
        dict_dfs_id_codes_val[dfs_part] = id_part
        tree_counts[dfs_part] = tree_counts.get(dfs_part, 0) + 1
    graph_cnt_lst_val.append(tree_counts)
    all_ctree_codes_val.append(codes_in_graph)

# Columns follow the training vocabulary; trees unseen in training are dropped.
cnt_ind_vec_val = np.array([
    [tree_counts.get(code, 0) for code in unique_ctree_codes]
    for tree_counts in graph_cnt_lst_val
])


In [None]:
# * ----- Indicator vectors of test graphs
test_graphs = [dataset[i] for i in test_indices]
processed_dataset_test = preprocess_dataset(test_graphs)
list_of_dfs_id_codes_test = canonical(processed_dataset_test, 3)

dict_dfs_id_codes_test = {}
all_ctree_codes_test = []
graph_cnt_lst_test = []

for graph_entries in list_of_dfs_id_codes_test:
    codes_in_graph = []
    tree_counts = {}
    for entry in graph_entries:
        # "<dfs code>/<id code>"
        pieces = entry.split('/')
        dfs_part, id_part = pieces[0], pieces[1]
        codes_in_graph.append(dfs_part)
        dict_dfs_id_codes_test[dfs_part] = id_part
        tree_counts[dfs_part] = tree_counts.get(dfs_part, 0) + 1
    graph_cnt_lst_test.append(tree_counts)
    all_ctree_codes_test.append(codes_in_graph)

# Columns follow the training vocabulary; trees unseen in training are dropped.
cnt_ind_vec_test = np.array([
    [tree_counts.get(code, 0) for code in unique_ctree_codes]
    for tree_counts in graph_cnt_lst_test
])

In [None]:
# Display the node-feature width passed to `nx_to_pyg` in the next cell.
NUM_NODE_FEATURES

In [None]:
# * ----- Ctree Embeddings
model.eval()
# The model was moved to `device` earlier; inputs must live on the same device
# or the forward pass fails when a GPU is in use.
model_device = next(model.parameters()).device

ctree_embeddings = []
with torch.no_grad():
    for ct in tqdm(unique_ctree_codes, desc="Ctree embeddings", colour="green"):
        ctree_nx = graph_from_dfs_code(ct)
        ctree_pyg = nx_to_pyg(
            ctree_nx,
            NUM_NODE_FEATURES,
            NUM_EDGE_FEATURES,
        ).to(model_device)
        # The one-hot features only provide the shape: every entry is replaced
        # by the constant 0.1 before the forward pass.
        xs, ys = model.forward_e(
            torch.ones_like(ctree_pyg.x)*0.1,
            ctree_pyg.edge_index,
            batch=None
        )

        # Concatenate all but the last entry of `ys` and keep row 0
        # (presumably the tree root, which is created first - TODO confirm).
        ctree_embeddings.append(torch.hstack(ys[:-1])[0].tolist())

ctree_embeddings = np.array(ctree_embeddings)
# end_time = process_time()
# print(f"[TIME] gen_ctree: {end_time - start_time} s.ms")


In [None]:
import multiprocess as mp
from argparse import ArgumentParser
from pickle import load, dump

import numpy as np
import shap
import torch


In [None]:
def get_graph_embedding(row, embeddings=None):
    """Assemble a graph-level embedding from computation-tree counts.

    Args:
        row: 1-D vector; entry i is the count of tree i in the graph.
        embeddings: Optional (num_trees, dim) matrix of tree embeddings.
            Defaults to the notebook-level `ctree_embeddings`.

    Returns:
        1-D array of length 3 * dim, the concatenation of
        - the count-weighted sum divided by max(1, total count),
        - the element-wise maximum of the count-scaled embeddings, floored at 0,
        - the count-weighted sum.
    """
    if embeddings is None:
        embeddings = ctree_embeddings
    row = np.asarray(row)
    embeddings = np.asarray(embeddings)

    freq = row.sum()
    embd = row @ embeddings
    # The running maximum starts from zeros of its own. It used to alias the sum
    # accumulator, so the first tree's negative entries could leak into it.
    max_embd = np.max(row[:, None] * embeddings, axis=0, initial=0.0)

    return np.hstack([embd / max(1, freq), max_embd, embd])


def f(ind_vectors):
    """Score function explained by SHAP: tree-count vectors -> class-1 output.

    Each row of `ind_vectors` is turned into a graph embedding with
    `get_graph_embedding` and pushed through the model head
    (`fc1` -> ReLU -> `fc2` -> sigmoid).

    Args:
        ind_vectors: 2-D array, one count vector per row.

    Returns:
        1-D NumPy array holding column 1 of the sigmoid output for each row.
    """
    embedings = np.apply_along_axis(
        get_graph_embedding, axis=1, arr=ind_vectors)
    # Build the tensor on the device of the model weights and bring the result
    # back to the CPU; `.numpy()` only works on CPU tensors.
    # NOTE(review): CUDA inside forked worker processes may still fail - confirm
    # the SHAP stage is meant to run with the model on the CPU.
    model_device = next(model.parameters()).device
    embedings = torch.as_tensor(
        embedings, dtype=torch.get_default_dtype(), device=model_device)
    with torch.no_grad():
        model.eval()
        x = model.fc1(embedings)
        x = torch.nn.functional.relu(x)
        # NOTE(review): sigmoid is applied per output unit, so this is not a
        # softmax-normalised probability - confirm it matches the model's forward.
        out = torch.sigmoid(model.fc2(x))
        # Probability of class 1
        out = out[:, 1]
    return out.cpu().numpy()


def calculate_shap(shap_arr, chunk_num):
    """Compute Kernel SHAP values of `f` for one chunk of count vectors.

    The background is a single all-zero vector.

    Args:
        shap_arr: 2-D array (rows x features) of count vectors to explain.
        chunk_num: Chunk identifier, used only in progress messages.

    Returns:
        The SHAP values returned by the explainer for `shap_arr`. An empty chunk
        yields an empty array with the same trailing shape.
    """
    shap_arr = np.asarray(shap_arr)
    if shap_arr.shape[0] == 0:
        # Nothing to explain (e.g. more workers than rows); indexing row 0
        # would raise an IndexError.
        print(f"Chunk {chunk_num} is empty, skipping.")
        return np.zeros((0,) + shap_arr.shape[1:])

    z = np.zeros((1, shap_arr.shape[1]))
    print('initializing shap')
    explainer = shap.KernelExplainer(f, z)
    print('running shap')
    shap_values = explainer.shap_values(X=shap_arr, gc_collect=True, silent=False,nsamples=1000)
    print('finished shap')
    print(f"Chunk {chunk_num} done!")
    return shap_values


procs = 5

# Split the training count vectors into at most `procs` contiguous chunks.
# np.array_split gives the first `len % procs` chunks one extra row, which is the
# same distribution the previous hand-written cumulative-sum code produced.
# Empty chunks (fewer graphs than workers) are dropped so no worker is handed
# nothing to explain.
chunked_ind_vectors = [
    chunk for chunk in np.array_split(cnt_ind_vec, procs) if len(chunk) > 0
]
if not chunked_ind_vectors:
    raise ValueError("cnt_ind_vec is empty: there are no training graphs to explain.")

# Number of rows per chunk. The name `indices` is deliberately not used here;
# it is reserved for the SHAP ranking computed in the next cell.
chunks = [len(chunk) for chunk in chunked_ind_vectors]

print("Chunk size:", chunks)
print("#Chunks:", len(chunked_ind_vectors))
print()

# One worker per chunk; results come back in chunk order, so the concatenated
# rows line up with `cnt_ind_vec`.
with mp.Pool(len(chunked_ind_vectors)) as p:
    results = p.starmap(calculate_shap, zip(chunked_ind_vectors, range(len(chunked_ind_vectors))))

shap_values = np.concatenate(results, axis=0)



In [None]:
# Importance of a tree = mean absolute SHAP value over the explained rows.
shap_imp = np.mean(np.abs(shap_values), axis=0)
# Ascending order: the most important trees sit at the end of `indices`.
indices = shap_imp.argsort()

In [None]:
def predict(loader):
    """Run the model over a loader and collect its outputs.

    Args:
        loader: Iterable of PyG batches exposing `x`, `edge_index` and `batch`.

    Returns:
        A pair `(predictions, probabilities)` of Python lists: the arg-max class
        of every graph and the raw model output row of every graph.
        NOTE(review): the second list holds whatever `model(...)` returns; check
        that these are probabilities before using them as weights.
    """
    model.eval()
    # Batches are created on the CPU; move them to wherever the weights live.
    model_device = next(model.parameters()).device
    predictions = []
    probabilities = []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for data in loader:
            data = data.to(model_device)
            out = model(
                x=data.x,
                edge_index=data.edge_index,
                batch=data.batch
            )
            predictions += out.argmax(dim=1).tolist()
            probabilities += out.tolist()
    return predictions, probabilities

In [None]:
# GNN predictions and raw outputs for the three splits, in loader order.
(train_pred, train_prob), (val_pred, val_prob), (test_pred, test_prob) = [
    predict(split_loader)
    for split_loader in (train_loader, val_loader, test_loader)
]

In [None]:
# Number of top-ranked trees (by SHAP importance) kept as candidate features.
k = 200
# Complexity assigned to every variable and operator in the PySR configuration.
c = 1.0

In [None]:
# Sample weight of a training graph = largest entry of the GNN output for it.
pysr_weights = [max(graph_output) for graph_output in train_prob]

# Keep only the k most important trees; the same columns are used for all splits.
top_k_columns = indices[-k:]
x_train = cnt_ind_vec[:, top_k_columns]
x_val = cnt_ind_vec_val[:, top_k_columns]
x_test = cnt_ind_vec_test[:, top_k_columns]

In [None]:
# Binarise the count features: 1 when a tree occurs at least once, else 0.
# Plain array operations replace the previous eval()-on-f-string loop; the
# results are still nested lists of Python ints.
x_train_bin = (np.asarray(x_train) > 0).astype(int).tolist()
x_val_bin = (np.asarray(x_val) > 0).astype(int).tolist()
x_test_bin = (np.asarray(x_test) > 0).astype(int).tolist()


In [None]:
from pysr import PySRRegressor # Import this first!


# * ----- Symbolic Regression
# start_time = process_time()

# Configure a symbolic regressor whose only operators are the boolean
# connectives Not/And/Or/Xor, encoded on floats (positive = true).
pysrmodel = PySRRegressor(
    # Operator definitions handed to PySR as strings; each returns 1 or 0.
    unary_operators = ["Not(x) = (x <= zero(x)) * one(x)"],
    binary_operators = [
        "And(x, y) = ((x > zero(x)) & (y > zero(y))) * one(x)",
        "Or(x, y)  = ((x > zero(x)) | (y > zero(y))) * one(x)",
        "Xor(x, y) = (((x > 0) & (y <= 0)) | ((x <= 0) & (y > 0))) * 1f0",
    ],
    # SymPy counterparts of the operators above, using the same thresholds.
    # The lambdas look up `sympy` only when called; the module is imported in
    # the next cell, before `fit` runs.
    extra_sympy_mappings = {
        "Not": lambda x: sympy.Piecewise((1.0, (x <= 0)), (0.0, True)),
        "And": lambda x, y: sympy.Piecewise((1.0, (x > 0) & (y > 0)), (0.0, True)),
        "Or":  lambda x, y: sympy.Piecewise((1.0, (x > 0) | (y > 0)), (0.0, True)),
        "Xor": lambda x, y: sympy.Piecewise((1.0, (x > 0) ^ (y > 0)), (0.0, True)),
    },

    # Per-sample loss: 1 when prediction and target differ, 0 otherwise.
    elementwise_loss = "loss(prediction, target) = sum(prediction != target)",
    model_selection="accuracy",

    # Every variable and every operator costs `c`.
    complexity_of_variables=c,
    complexity_of_operators={'Not': c, 'And': c, 'Or': c, 'Xor': c},

    # At most 10 of the k candidate columns are kept for the search.
    select_k_features = min(k, 10),
    # Per-sample weights derived from the GNN outputs on the training graphs.
    weights = pysr_weights,

    # NOTE(review): presumably only effective when batching is enabled - confirm.
    batch_size = 32,

    # Paperwork
    temp_equation_file = True,
    delete_tempfiles = True,

    # Determinism
    procs=0,
    deterministic=True,
    multithreading=False,
    random_state=0,
    warm_start=False,
)

In [None]:
import sympy
def cal_pysr_acc(X, Y, index=None):
    """Fraction of rows of `X` on which the PySR formula reproduces `Y`.

    `index` is forwarded to `pysrmodel.predict` to choose the equation.
    An AssertionError is raised when predictions and targets differ in shape.
    """
    targets = np.array(Y)
    predicted = pysrmodel.predict(X, index=index)
    assert targets.shape == predicted.shape , "Shape mismatch!"
    n_matches = (predicted == targets).sum()
    return n_matches / len(targets)


# Fit the formulae to the GNN's own predictions (not to the ground truth).
pysrmodel.fit(x_train_bin, train_pred)
print(pysrmodel)

# Convert the boolean feature-selection mask to integer column positions
# (positions within the k candidate columns).
selected_ctrees = np.where(pysrmodel.selection_mask_)[0]

# Use the fitted attribute `equations_`, as the other cells do.
df_equations = pysrmodel.equations_.drop(["sympy_format", "lambda_format"], axis=1)
# Add a column for accuracy.
# NOTE(review): assumes `loss` is the mean 0/1 error - confirm for weighted fits.
df_equations["acc"] = 1 - df_equations["loss"]
# Re-arrange columns to have "acc" as the second column.
cols = df_equations.columns.tolist()
cols.insert(1, cols.pop(-1))
# Explicit copy so the rounding below writes to an independent frame.
df_equations = df_equations[cols].copy()
# Round values.
for col in ["acc", "loss", "score"]:
    df_equations[col] = df_equations[col].round(4)


In [None]:
# Display the equation strings of all candidate formulae found by PySR.
pysrmodel.equations_.equation

In [None]:
# Find the equation that performs the best on the validation set
best_val_acc = 0
# Stays None until a candidate has been evaluated, so the failure case is
# reported explicitly rather than as a NameError.
best_index = None
val_targets = np.asarray(val_pred)
print("\nValidation accuracies:")
for j in range(pysrmodel.equations_.shape[0]):
    # Skip the constant formula.
    if pysrmodel.equations_.equation.iloc[j] == '1.0':
        continue
    # PySR sometimes fails to evaluate certain formulae
    # it usually happens when C is set to a small value.
    # We've been unable to identify when and why it happens
    try:
        # Train/test predictions are only computed to make sure the formula can
        # be evaluated on every split before it is allowed to win.
        __ = pysrmodel.predict(x_train_bin, index=j)
        __ = pysrmodel.predict(x_test_bin, index=j)
        pysr_val_pred = pysrmodel.predict(x_val_bin, index=j)
    except ValueError:
        print(f"{j}: failed")
        continue
    val_acc = (pysr_val_pred == val_targets).sum() / len(val_targets)
    print(f"{j}: {val_acc}")
    if best_index is None or val_acc > best_val_acc:
        best_val_acc = val_acc
        best_index = j
if best_index is None:
    raise RuntimeError("No candidate equation could be evaluated on the validation set.")
print("Best equation index:", best_index)

In [None]:
# * ----- Metrics
# Formula predictions for the selected equation, stored as long tensors.
raw_train_pred, raw_test_pred = [
    pysrmodel.predict(split_features, index=best_index)
    for split_features in (x_train_bin, x_test_bin)
]
pysr_train_pred = torch.LongTensor(raw_train_pred)
pysr_test_pred = torch.LongTensor(raw_test_pred)

In [None]:
# NOTE(review): this hard-coded value replaces the index selected on the
# validation set above. The report below therefore describes equation 2, while
# the prediction tensors of the previous cell were built with the
# validation-selected index. Confirm the override is intended.
best_index = 2

In [None]:
# Best based on val set
equation = pysrmodel.get_best(index=best_index).equation
print("\n" + "=" * 50)
print("Equation:", equation)
print("C =", c)

# Agreement between the formula and the GNN predictions, rounded to 3 decimals.
train_acc = round(cal_pysr_acc(x_train_bin, train_pred, index=best_index), 3)
test_acc = round(cal_pysr_acc(x_test_bin, test_pred, index=best_index), 3)

# From here on `equation` is the simplified SymPy expression, not the string.
equation = simplify_expression(equation)
print("Simplified equation:", equation)
print("Train accuracy:", train_acc)
print("Test accuracy:", test_acc)

In [None]:
# * ----- Save stuff to disk
# Save equations
# df_equations.to_csv(f"{FOLDER}/equations_sample{args.sample}.csv", index=True)
# del df_equations

# Save predictions
# torch.save(torch.LongTensor(train_pred), f"{FOLDER}/gnn_train_pred.pt")
# torch.save(torch.LongTensor(test_pred), f"{FOLDER}/gnn_test_pred.pt")
# # torch.save(pysr_train_pred, f"{FOLDER}/pysr_train_pred_sample{args.sample}.pt")
# torch.save(pysr_test_pred, f"{FOLDER}/pysr_test_pred_sample{args.sample}.pt")
# del train_pred, test_pred, pysr_train_pred, pysr_test_pred

# # Save pysrmodel
# with open(f"{FOLDER}/pysrmodel_sample{args.sample}.pkl", "wb") as file:
#     dump(pysrmodel, file)
# del pysrmodel

# end_time = process_time()
# print(f"[TIME] gen_formulae: {end_time - start_time} s.ms")

In [None]:



import matplotlib.pyplot as plt


# * ----- Visualize the computation trees present in the formulae.
# Computed for reference only: the loop below iterates over `selected_ctrees`.
variables_eq = getVariables(equation)

# Integer node label -> atom symbol, for the molecular datasets only.
node_mapping = None
if dataset_name == "MUTAG":
    node_mapping = {0: "C", 1: "N", 2: "O", 3: "F", 4: "I", 5: "Cl", 6: "Br"}
elif dataset_name == "Mutagenicity":
    node_mapping = {0: 'C', 1: 'O', 2: 'Cl', 3: 'H', 4: 'N', 5: 'F', 6: 'Br',
                    7: 'S', 8: 'P', 9: 'I', 10: 'Na', 11: 'K', 12: 'Li', 13: 'Ca'}

FIGSIZE = (8, 6)
NODESIZE = ...  # placeholder, currently unused
EDGE_WIDTH = 1.75
NODE_COLOR = "#FD5D02"
colors = ['green', 'black', 'blue', 'red'] # aromatic, single, double, triple


def _draw_graph(graph, title, draw_fn, labels=None, edge_colors=None):
    """Draw `graph` in a new figure using the dataset-specific styling.

    NCI1 gets node labels only, the other TU datasets get node labels and
    coloured edges, and every other dataset is drawn unlabelled.
    `draw_fn` is the networkx drawing function to use.
    """
    plt.figure(figsize=FIGSIZE)
    plt.title(title)
    if dataset_name == "NCI1":
        draw_fn(
            graph,
            labels=labels,
            with_labels=True,
            node_color=NODE_COLOR,
            width=EDGE_WIDTH,
        )
    elif dataset_name in TU_DATASETS:
        draw_fn(
            graph,
            labels=labels,
            with_labels=True,
            node_color=NODE_COLOR,
            width=EDGE_WIDTH,
            edge_color=edge_colors,
        )
    else:
        draw_fn(graph, node_color=NODE_COLOR, width=EDGE_WIDTH)
    print(title)
    plt.show()


print(selected_ctrees)
# for v in variables_eq:
#     if str(v)[0] != "x":
#         continue
#     v = int(str(v)[1:])
for v in selected_ctrees:
    print(v)
    # `v` indexes the k candidate columns; map it back to the tree's DFS code.
    code_dfs = unique_ctree_codes[indices[- k:][v]]
    code_id = dict_dfs_id_codes[code_dfs]

    # * ----- Ctree using node attributes
    if dataset_name == 'BAMultiShapesDataset':
        ctree = graph_from_dfs_code(code_id)
    else:
        ctree = graph_from_dfs_code(code_dfs)
    # graph_from_dfs_code returns reversed edges; flip them back.
    ctree = ctree.reverse()

    edge_colors = None
    if dataset_name in ["MUTAG", "Mutagenicity"]:
        edge_colors = [colors[ctree.edges[edge]['attr']] for edge in ctree.edges()]

    labeldict = {}
    for i in range(len(ctree.nodes)):
        if dataset_name in ["MUTAG", "Mutagenicity"]:
            labeldict[i] = node_mapping[ctree.nodes[i]['attr']]
        elif dataset_name == "NCI1":
            labeldict[i] = ctree.nodes[i]['attr']

    # plt.savefig(f"{PLOT_FOLDER}/{v}_ctree.png")
    _draw_graph(ctree, v, nx.draw_planar, labels=labeldict, edge_colors=edge_colors)

    # * ----- Ctree using node ids
    ctree_id = graph_from_dfs_code(code_id)
    ctree_id = ctree_id.reverse()

    labeldict = None
    if dataset_name in TU_DATASETS:
        labeldict = {i: ctree_id.nodes[i]['attr'] for i in range(len(ctree_id.nodes))}

    # plt.savefig(f"{PLOT_FOLDER}/{v}_ctree_id.png")
    # Edge colours are the ones computed from `ctree` above.
    _draw_graph(ctree_id, v, nx.draw_planar, labels=labeldict, edge_colors=edge_colors)

    # * ----- Ctree to subgraph
    G = dfs(ctree=ctree, ctree_id=ctree_id, node_mapping=node_mapping)
    edge_colors = [colors[G.edges[edge]['attr']] for edge in G.edges()]

    labeldict = None
    if dataset_name in TU_DATASETS:
        labeldict = {node: G.nodes[node]['attr'] for node in G.nodes}

    # plt.savefig(f"{PLOT_FOLDER}/{v}_structure.png")
    _draw_graph(G, v, nx.draw_kamada_kawai, labels=labeldict, edge_colors=edge_colors)
