In [1]:
import torch
import numpy as np
import argparse
from tqdm import tqdm
import random
import logging
import time
import os
import sys
from torch_geometric.data import Data
import networkx as nx
from torch_geometric.utils import to_undirected
import torch.nn.functional as F
from brec.dataset import BRECDataset
from brec.evaluator import evaluate

In [2]:
def mon_op(A, B):
    return A + B + torch.mm(A, B)
def image(X, m):
    
    n = X.size(0)
    device, out_dtype = X.device, torch.float64

    cover = torch.zeros(n, n, n, device=device, dtype=out_dtype)
    if m == 0:
        return cover

    
    base_dec = (X != 0).to(out_dtype)
    wdec = X.clone()  

    for i in range(n):
        dec = base_dec.clone()

        
        cover[i].t()[i] = X.t()[i]   
        dec[i] = 0                   
        dec[:, i] = 0               

        c = 1
        while c < m:
            
            row_active = (cover[i].sum(dim=1) != 0)   
            col_active = (dec.sum(dim=0) != 0)        
            mask = row_active & col_active            
            M = mask.to(out_dtype).unsqueeze(0).expand(n, -1)  

            Md = M * dec                               
            # IMPORTANT: element-wise product with transpose (NOT matmul)
            om = Md - (Md * Md.t())                    

            # Update cover 
            cover[i] = mon_op(om * wdec, cover[i])    

            # Update dec 
            dec = dec - (om + Md.t())
            c += 1

    return torch.log1p(cover)

In [3]:
import math, numbers
import numpy as np
import torch

def _tuple_contains_inf(x):
    if isinstance(x, torch.Tensor):
        return torch.isinf(x).any().item()
    if isinstance(x, np.ndarray):
        return np.isinf(x).any()
    if isinstance(x, (list, tuple)):
        return any(_tuple_contains_inf(v) for v in x)
    if isinstance(x, numbers.Real):
        return math.isinf(float(x))
    return False

def _tuple_contains_nan(x):
    if isinstance(x, torch.Tensor):
        return torch.isnan(x).any().item()
    if isinstance(x, np.ndarray):
        return np.isnan(x).any()
    if isinstance(x, (list, tuple)):
        return any(_tuple_contains_nan(v) for v in x)
    if isinstance(x, numbers.Real):
        return math.isnan(float(x))
    return False

def _tuple_contains_nonfinite(x):
    # True if any NaN or ±inf anywhere
    if isinstance(x, torch.Tensor):
        return (~torch.isfinite(x)).any().item()
    if isinstance(x, np.ndarray):
        return (~np.isfinite(x)).any()
    if isinstance(x, (list, tuple)):
        return any(_tuple_contains_nonfinite(v) for v in x)
    if isinstance(x, numbers.Real):
        return not math.isfinite(float(x))
    return False



In [4]:
"""We utilized the Non-GNNs code from the BREC dataset repository (https://github.com/GraphPKU/BREC/tree/Release/Non-GNNs)
and integrated our code as a function named snn within this framework."""

# Set random seeds for reproducibility
np.random.seed(2022)
random.seed(2022)

# Placeholder functions for methods not used
def func_None():
    raise NotImplementedError(f"Cannot find func {args.method}")


import networkx as nx
import torch

def snn(gr):
    # Ensure the input is a NetworkX graph
    if not isinstance(gr, nx.Graph):
        raise TypeError(f"Expected a NetworkX graph, got {type(gr)}")
    
    # Get the number of nodes and edges
    num_nodes = gr.number_of_nodes()
    num_edges = gr.number_of_edges()
    
    # Create adjacency matrix from the graph
    Ad_mat = torch.tensor(nx.to_numpy_array(gr), dtype=torch.float64)
    # Apply custom processing
    Image = image(Ad_mat,25)  
    su = torch.sum(Image, 0)
    output_snn_beta = mon_op(mon_op(su.t(),su)**(1/4),mon_op(su.t(),su)**(1/4))
    s = torch.var(output_snn_beta)
    det=torch.linalg.det(output_snn_beta)
    det_su=torch.linalg.det(su)
    result = (det, s)

    if _tuple_contains_nan(result):
        print("[WARNING] snn returned NaN in its output tuple:", result)
    if _tuple_contains_inf(result):
        print("[WARNING] snn returned ±inf in its output tuple:", result)
    # or a single combined check:
    if _tuple_contains_nonfinite(result):
        print("[WARNING] snn returned non-finite (NaN/±inf):", result)

    return (det,s)


# Dictionary mapping methods to their respective functions
func_dict = {
    "fwl": func_None,  
    "wl": func_None,   
    "snn": snn,  
}

def wl_method(method, G, k=None, mode=None):
    return func_dict.get(method, func_None)(G)

# Define dataset partitions
part_dict = {
    "Basic": (0, 60),
    "Regular": (60, 160),
    "Extension": (160, 260),
    "CFI": (260, 360),
    "4-Vertex_Condition": (360, 380),
    "Distance_Regular": (380, 400),
    "Reliability": (400, 800),
}

# Handle argument parsing for terminal and interactive environments
if len(sys.argv) > 1 and "--file" in sys.argv:
    parser = argparse.ArgumentParser(description="Test non-GNN methods on BREC.")
    parser.add_argument("--file", type=str, default="brec_nonGNN.npy")
    parser.add_argument("--method", type=str, default="snn")
    parser.add_argument("--graph_type", type=str, default="none")
    args = parser.parse_args()
else:
    # Manual argument setup for interactive environments
    class Args:
        file = r"C:\brec_nonGNN.npy"  # Path to the dataset file
        method = "snn"      
        graph_type = "none"      

    args = Args()

G_TYPE = args.graph_type.strip()
if G_TYPE == "none":
    method_name = args.method
else:
    if G_TYPE in part_dict:
        method_name = f"{args.method}_{G_TYPE}"
    else:
        raise NotImplementedError(f"{G_TYPE} does not exist!")

path = os.path.join("result", method_name)
os.makedirs(path, exist_ok=True)
os.makedirs(os.path.join(path, "part_result"), exist_ok=True)

LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
DATE_FORMAT = "%m/%d/%Y %H:%M:%S %p"
logging.basicConfig(
    filename=os.path.join(path, "logging.log"),
    level=logging.INFO,
    format=LOG_FORMAT,
    datefmt=DATE_FORMAT,
)
logging.info(args)

def count_distinguish_num(graph_tuple_list):
    logging.info(f"{method_name} test starting ---")
    print(f"{method_name} test starting ---")

    cnt = 0
    correct_list = []
    time_cost = 0
    DATA_NUM = (
        400 if G_TYPE == "none" else int(part_dict[G_TYPE][1] - part_dict[G_TYPE][0])
    )

    for part_name, part_range in part_dict.items():
        if not (G_TYPE == "none" or G_TYPE == part_name):
            continue

        logging.info(f"{part_name} part starting ---")

        cnt_part = 0
        correct_list_part = []
        start = time.process_time()

        for id in tqdm(range(part_range[0], part_range[1])):
            graph_tuple = graph_tuple_list[id]
            if not wl_method(
                args.method, graph_tuple[0]
            ) == wl_method(args.method, graph_tuple[1]):
                cnt += 1
                cnt_part += 1
                correct_list.append(id)
                correct_list_part.append(id)
            else:
                logging.info(f"Wrong in {id}")

        end = time.process_time()
        time_cost_part = round(end - start, 2)
        time_cost += time_cost_part

        logging.info(
            f"{part_name} part costs time {time_cost_part}; Correct in {cnt_part} / {part_range[1] - part_range[0]}"
        )
        print(
            f"{part_name} part costs time {time_cost_part}; Correct in {cnt_part} / {part_range[1] - part_range[0]}"
        )
        np.save(os.path.join(path, "part_result", part_name), correct_list_part)

    time_cost = round(time_cost, 2)
    Acc = round(cnt / DATA_NUM, 2)

    logging.info(f"Costs time {time_cost}; Correct in {cnt} / {DATA_NUM}, Acc = {Acc}")
    print(f"Costs time {time_cost}; Correct in {cnt} / {DATA_NUM}, Acc = {Acc}")

    np.save(os.path.join(path, "result"), correct_list)

    return

def main():
    graph_tuple_list = np.load(args.file, allow_pickle=True)
    
    
    print("First item in graph_tuple_list:", graph_tuple_list[0])  
    print("Type of first graph in tuple:", type(graph_tuple_list[0][0]))  
    count_distinguish_num(graph_tuple_list)

    

if __name__ == "__main__":
    main()


First item in graph_tuple_list: [<networkx.classes.graph.Graph object at 0x000002B93C9206D0>
 <networkx.classes.graph.Graph object at 0x000002B93C9207F0>]
Type of first graph in tuple: <class 'networkx.classes.graph.Graph'>
snn test starting ---



00%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:02<00:00, 21.27it/s]

Basic part costs time 2.84; Correct in 60 / 60



00%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:12<00:00,  8.30it/s]

Regular part costs time 60.16; Correct in 100 / 100



00%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:05<00:00, 19.92it/s]

Extension part costs time 8.78; Correct in 100 / 100



00%|████████████████████████████████████████████████████████████████████████████████| 100/100 [01:43<00:00,  1.03s/it]

CFI part costs time 615.3; Correct in 100 / 100



00%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:10<00:00,  1.89it/s]

4-Vertex_Condition part costs time 63.47; Correct in 20 / 20



00%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  3.21it/s]

Distance_Regular part costs time 37.36; Correct in 20 / 20


100%|████████████████████████████████████████████████████████████████████████████████| 400/400 [02:15<00:00,  2.95it/s]

Reliability part costs time 775.53; Correct in 400 / 400
Costs time 1563.44; Correct in 800 / 400, Acc = 2.0



