In [None]:
import torch
import numpy as np
import argparse
from tqdm import tqdm
import random
import logging
import time
import os
import sys
from torch_geometric.data import Data
import networkx as nx
from torch_geometric.utils import to_undirected
import torch.nn.functional as F
from brec.dataset import BRECDataset
from brec.evaluator import evaluate

In [None]:
def mon_op(A, B):
    """Combine two square matrices as ``A + B + A @ B``.

    Algebraically this equals ``(I + A)(I + B) - I``, so the operation is
    associative and has the zero matrix as its identity.

    Args:
        A: 2-D tensor (``torch.mm`` requires matrices).
        B: 2-D tensor with a shape compatible with ``A`` for ``torch.mm``.

    Returns:
        Tensor ``A + B + torch.mm(A, B)``.
    """
    summed = A + B
    return summed + torch.mm(A, B)

def image(X, n, m):
    """Grow one "cover" matrix per node of ``X`` and return them log1p-compressed.

    For every node ``i``:

    1. ``cover[i]`` starts as all zeros except column ``i``, which is copied from
       column ``i`` of ``X``.
    2. ``dec`` is a copy of ``X`` with row ``i`` and column ``i`` zeroed.
    3. A column ``k`` is *active* when row ``k`` of ``cover[i]`` has a non-zero sum.
       After the first pass, column ``k`` of ``dec`` must also have a non-zero sum.
       While any column is active:

       - ``P = M * dec`` keeps only the active columns of ``dec``.
       - ``P`` is folded in via ``cover[i] = mon_op(P - P * P.T, cover[i])``,
         where ``*`` is element-wise.
       - The active columns are then removed from ``dec``.

    Each pass zeroes the active columns of ``dec``. A column whose ``dec`` sum is
    zero is never active again, so the loop ends after at most ``n + 1`` passes
    per node.

    Args:
        X: square matrix, presumably the ``(n, n)`` adjacency matrix of a graph
            (TODO confirm). The zero-sum tests assume non-negative entries, so
            that a sum is zero exactly when the whole row/column is zero.
        n: side length of ``X`` (number of nodes).
        m: unused. An iteration cap (``c < m``) appears to have been considered
            and disabled. The parameter is kept so existing callers still work.

    Returns:
        ``torch.float64`` tensor of shape ``(n, n, n)`` equal to ``log1p(cover)``.
    """
    cover = torch.zeros(n, n, n, dtype=torch.float64)
    for i in range(n):
        # Seed cover[i] with column i of X.
        cover[i][:, i] = X[:, i]

        # Working copy of X with node i's row and column removed.
        dec = torch.clone(X)
        dec[i] = 0
        dec[:, i] = 0

        # First pass: a column k is active when row k of cover[i] is non-zero.
        # (Vectorized; previously a Python loop of n per-row reductions.)
        active = cover[i].sum(dim=1) != 0
        while active.any():
            # Column k of M is all ones when k is active, all zeros otherwise.
            # Kept float64 (not a bool mask) so P is float64 whatever dtype X has.
            M = active.to(torch.float64).expand(n, n)
            P = M * dec
            cover[i] = mon_op(P - P * P.t(), cover[i])
            dec = dec - P
            # Later passes also require column k of dec to still have a non-zero sum.
            active = (cover[i].sum(dim=1) != 0) & (dec.sum(dim=0) != 0)

    return torch.log1p(cover)

In [None]:
"""We utilized the Non-GNNs code from the BREC dataset repository (https://github.com/GraphPKU/BREC/tree/Release/Non-GNNs)
and integrated our code as a function named snn within this framework."""

# Set random seeds for reproducibility.
# Every RNG library imported by this notebook is seeded, torch included. Nothing
# visible here draws from torch's RNG, but seeding it keeps results stable if a
# randomized method is plugged in later.
np.random.seed(2022)
random.seed(2022)
torch.manual_seed(2022)

# Placeholder functions for methods not used
def func_None():
    raise NotImplementedError(f"Cannot find func {args.method}")

# Your custom model function
import networkx as nx
import torch

def snn(gr):
    """Compute the SNN invariant used to try to tell a pair of graphs apart.

    Pipeline:

    1. Build the float64 adjacency matrix ``A`` of ``gr``.
    2. ``cover = image(A, n, 20)``, shape ``(n, n, n)``.
    3. ``su`` is ``cover`` summed over its first axis, shape ``(n, n)``.
    4. ``root = mon_op(su.T, su) ** (1/2)``, an element-wise square root.
    5. ``beta = mon_op(root, root)``.

    Args:
        gr: a ``networkx.Graph`` instance (subclasses pass the check too).

    Returns:
        Tuple ``(det, s)`` of 0-d float64 tensors: the determinant of ``beta``
        and the variance of all its entries (``torch.var`` defaults).

    Raises:
        TypeError: if ``gr`` is not a NetworkX graph.
    """
    if not isinstance(gr, nx.Graph):
        raise TypeError(f"Expected a NetworkX graph, got {type(gr)}")

    num_nodes = gr.number_of_nodes()
    Ad_mat = torch.tensor(nx.to_numpy_array(gr), dtype=torch.float64)
    Image = image(Ad_mat, num_nodes, 20)
    su = torch.sum(Image, 0)

    # The inner term was previously evaluated twice; compute it once.
    root = mon_op(su.t(), su) ** (1 / 2)
    output_snn_beta = mon_op(root, root)

    s = torch.var(output_snn_beta)
    det = torch.linalg.det(output_snn_beta)
    # The unused edge count and the unused det(su) are no longer computed; the
    # latter cost an extra O(n^3) determinant per graph.
    # NOTE(review): callers compare these floats with exact ==, so tiny
    # round-off differences would count as "distinguished". Confirm this is intended.
    return (det, s)


# Registry consulted by wl_method: method name -> callable taking a single graph.
# "fwl" and "wl" are placeholders bound to func_None; swap in real implementations
# here if they are needed.
func_dict = dict(
    fwl=func_None,
    wl=func_None,
    snn=snn,  # the SNN invariant
)

def wl_method(method, G, k=None, mode=None):
    """Run the callable registered as ``method`` in ``func_dict`` on graph ``G``.

    Names missing from ``func_dict`` resolve to ``func_None``. ``k`` and ``mode``
    are accepted but ignored, keeping the call signature compatible.

    Returns:
        Whatever the registered callable returns for ``G``.
    """
    handler = func_dict.get(method, func_None)
    return handler(G)

# BREC pair-index ranges [start, end) per category, in evaluation order.
# count_distinguish_num walks this dict and indexes pairs with range(start, end).
part_dict = {
    "Basic": (0, 60),                  # 60 pairs
    "Regular": (60, 160),              # 100 pairs
    "Extension": (160, 260),           # 100 pairs
    "CFI": (260, 360),                 # 100 pairs
    "4-Vertex_Condition": (360, 380),  # 20 pairs
    "Distance_Regular": (380, 400),    # 20 pairs
    # NOTE(review): lies beyond the first 400 pairs. Confirm that the loaded .npy
    # holds indices 400-799 and what "distinguished" should mean for this part.
    "Reliability": (400, 800),         # 400 pairs
}

# Handle argument parsing for terminal and interactive environments
if len(sys.argv) > 1 and "--file" in sys.argv:
    parser = argparse.ArgumentParser(description="Test non-GNN methods on BREC.")
    parser.add_argument("--file", type=str, default="brec_nonGNN.npy")
    parser.add_argument("--method", type=str, default="snn")
    parser.add_argument("--graph_type", type=str, default="none")
    args = parser.parse_args()
else:
    # Manual argument setup for interactive environments
    class Args:
        file = r"C:\Users\shira\Downloads\brec_nonGNN.npy"  # Path to the dataset file
        method = "snn"      # Your method
        graph_type = "none"      # Default graph type

    args = Args()

G_TYPE = args.graph_type.strip()
if G_TYPE == "none":
    method_name = args.method
else:
    if G_TYPE in part_dict:
        method_name = f"{args.method}_{G_TYPE}"
    else:
        raise NotImplementedError(f"{G_TYPE} does not exist!")

path = os.path.join("result", method_name)
os.makedirs(path, exist_ok=True)
os.makedirs(os.path.join(path, "part_result"), exist_ok=True)

LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
DATE_FORMAT = "%m/%d/%Y %H:%M:%S %p"
logging.basicConfig(
    filename=os.path.join(path, "logging.log"),
    level=logging.INFO,
    format=LOG_FORMAT,
    datefmt=DATE_FORMAT,
)
logging.info(args)

def count_distinguish_num(graph_tuple_list):
    """Run ``args.method`` on the selected BREC pairs and report how many it separates.

    A pair counts as distinguished when the method's outputs for its two graphs
    do not compare equal with ``==``. Results are logged, printed and saved:

    - the ids distinguished per part go to ``<path>/part_result/<part_name>.npy``;
    - all distinguished ids go to ``<path>/result.npy``.

    Reads the module-level ``args``, ``G_TYPE``, ``part_dict``, ``method_name``
    and ``path``.

    Args:
        graph_tuple_list: indexable sequence supporting ``len()``. Item ``idx``
            holds the two graphs of pair ``idx`` at positions 0 and 1.

    Returns:
        None.
    """
    logging.info(f"{method_name} test starting ---")
    print(f"{method_name} test starting ---")

    # Parts to evaluate: all of them for G_TYPE == "none", otherwise only G_TYPE.
    selected_parts = [
        (part_name, part_range)
        for part_name, part_range in part_dict.items()
        if G_TYPE == "none" or G_TYPE == part_name
    ]
    if G_TYPE == "none":
        # In "run everything" mode, skip parts whose index range the loaded file
        # does not cover (e.g. "Reliability" against a 400-pair file). Without this
        # the run crashes with IndexError after all the other parts were evaluated.
        available = len(graph_tuple_list)
        kept_parts = []
        for part_name, part_range in selected_parts:
            if part_range[1] <= available:
                kept_parts.append((part_name, part_range))
                continue
            msg = (
                f"Skipping {part_name} part: needs pairs up to index "
                f"{part_range[1] - 1}, but only {available} pairs were loaded"
            )
            logging.warning(msg)
            print(msg)
        selected_parts = kept_parts

    # Denominator = number of pairs actually evaluated. The former hard-coded 400
    # ignored the "Reliability" part, so Acc could exceed 1 when that part ran.
    DATA_NUM = sum(hi - lo for _, (lo, hi) in selected_parts)

    cnt = 0
    correct_list = []
    time_cost = 0

    for part_name, part_range in selected_parts:
        logging.info(f"{part_name} part starting ---")

        cnt_part = 0
        correct_list_part = []
        start = time.process_time()

        for pair_id in tqdm(range(part_range[0], part_range[1])):
            graph_tuple = graph_tuple_list[pair_id]
            # NOTE(review): exact equality on the method outputs. For float-valued
            # invariants, round-off alone can make a pair look distinguished.
            if not (
                wl_method(args.method, graph_tuple[0])
                == wl_method(args.method, graph_tuple[1])
            ):
                cnt += 1
                cnt_part += 1
                correct_list.append(pair_id)
                correct_list_part.append(pair_id)
            else:
                logging.info(f"Wrong in {pair_id}")

        end = time.process_time()
        time_cost_part = round(end - start, 2)
        time_cost += time_cost_part

        logging.info(
            f"{part_name} part costs time {time_cost_part}; Correct in {cnt_part} / {part_range[1] - part_range[0]}"
        )
        print(
            f"{part_name} part costs time {time_cost_part}; Correct in {cnt_part} / {part_range[1] - part_range[0]}"
        )
        np.save(os.path.join(path, "part_result", part_name), correct_list_part)

    time_cost = round(time_cost, 2)
    # Guard against every part having been skipped.
    Acc = round(cnt / DATA_NUM, 2) if DATA_NUM else 0.0

    logging.info(f"Costs time {time_cost}; Correct in {cnt} / {DATA_NUM}, Acc = {Acc}")
    print(f"Costs time {time_cost}; Correct in {cnt} / {DATA_NUM}, Acc = {Acc}")

    np.save(os.path.join(path, "result"), correct_list)

    return

def main():
    """Load the graph pairs from ``args.file`` and evaluate ``args.method`` on them."""
    # NOTE(review): allow_pickle=True unpickles the file's contents, which can
    # execute arbitrary code. Only point args.file at trusted .npy files.
    pairs = np.load(args.file, allow_pickle=True)

    # Quick look at the data before the (slow) evaluation starts.
    first_pair = pairs[0]
    print("First item in graph_tuple_list:", first_pair)
    print("Type of first graph in tuple:", type(first_pair[0]))

    count_distinguish_num(pairs)


if __name__ == "__main__":
    main()
