In [1]:
import torch
from torch_geometric.data import Data
import networkx as nx
from torch_geometric.utils import to_undirected

In [2]:
"""This function is taken from the official code repository of the papers
"Weisfeiler and Lehman Go Cellular: CW Networks" (NeurIPS 2021) and
"Weisfeiler and Lehman Go Topological: Message Passing Simplicial Networks" (ICML 2021)."""
def load_sr_dataset(path):
    """Load the Strongly Regular Graph Dataset from the supplied path.

    Args:
        path: Path to a graph6 (``.g6``) file, one graph per line.

    Returns:
        A list of ``(edge_index, num_nodes)`` pairs, one per graph, where
        ``edge_index`` is a ``(2, E)`` long tensor holding every edge in both
        directions (as produced by ``to_undirected``).
    """
    nx_graphs = nx.read_graph6(path)
    # read_graph6 returns a bare Graph (not a list) when the file holds a
    # single graph; iterating that would walk its nodes instead of graphs.
    if isinstance(nx_graphs, nx.Graph):
        nx_graphs = [nx_graphs]
    graphs = []
    for nx_graph in nx_graphs:
        n = nx_graph.number_of_nodes()
        # reshape(-1, 2) yields a well-formed (0, 2) tensor for an edgeless
        # graph; transposing the 1-D tensor built from [] would raise.
        edges = torch.tensor(list(nx_graph.edges()), dtype=torch.long).reshape(-1, 2)
        edge_index = to_undirected(edges.t(), num_nodes=n)
        graphs.append((edge_index, n))
    return graphs

In [3]:
r"""Monoidal operation $\circ$: $A \circ B = A + B + AB$"""
def mon_op(A, B):
    r"""Circle (monoidal) composition of two matrices.

    Computes ``A + B + A @ B``, which mathematically equals
    ``(I + A)(I + B) - I``; the operation is associative and has the zero
    matrix as its identity.

    Args:
        A, B: 2-D tensors compatible with ``torch.mm``.

    Returns:
        The tensor ``A + B + torch.mm(A, B)``.
    """
    summed = A + B
    product = torch.mm(A, B)
    return summed + product

In [4]:
"""Compute the image of every node (one cover matrix per root node)"""
def image(X, n, m):
    """Compute the image (a per-root cover matrix) of every node of a graph.

    For each root node ``i``, the cover is seeded with column ``i`` of ``X``.
    It is then grown for ``m`` rounds. Each round:

    * restricts the still-available edges ``dec`` to columns of the current
      frontier;
    * folds them (weighted by ``X``) into the cover with ``mon_op``;
    * removes the touched edges from ``dec``.

    Args:
        X: ``(n, n)`` adjacency matrix; non-zero entries are edges and their
            values are used as weights.
        n: Number of nodes (size of ``X``).
        m: Number of expansion rounds per root node.

    Returns:
        ``(n, n, n)`` float64 tensor whose slice ``[i]`` is ``log1p`` of the
        cover matrix grown from node ``i``.
    """
    cover = torch.zeros(n, n, n, dtype=torch.float64)
    for i in range(n):
        # Edges still available for expansion; the root's own row/column are removed.
        dec = (X != 0).float()
        dec[i] = 0
        dec[:, i] = 0
        # Seed the root's cover with column i of X.
        cover[i, :, i] = X[:, i]
        # Frontier mask: column k is all ones when row k of the cover sums to non-zero.
        M = torch.zeros(n, n, dtype=torch.float64)
        M[:, cover[i].sum(dim=1) != 0] = 1
        for _ in range(m):
            Md = M * dec  # available edges restricted to frontier columns
            # Remove entries that also appear transposed in Md (pairs of frontier nodes).
            om = Md - Md * Md.t()
            cover[i] = mon_op(om * X, cover[i])
            # Mark every edge touched this round as used, in both directions.
            dec = dec - (Md + Md.t() - Md * Md.t())
            # Next frontier: covered nodes that still have available edges.
            M = torch.zeros(n, n, dtype=torch.float64)
            M[:, (cover[i].sum(dim=1) != 0) & (dec.sum(dim=0) != 0)] = 1
    return torch.log1p(cover)

In [5]:
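"""Load a collection of strongly regular graphs (graph6 format) from

https://users.cecs.anu.edu.au/~bdm/data/graphs.html

File names concatenate the parameters (v, k, lambda, mu) of the strongly regular
graphs they contain, e.g. sr371889some.g6 -> (37, 18, 8, 9); "some" marks an
incomplete collection. Available files: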

sr251256.g6 (15 graphs)
sr261034.g6 (10 graphs)
sr281264.g6 (4 graphs)
sr291467.g6 (41 graphs)
sr351668.g6 (3854 graphs)
sr351899.g6 (227 graphs)
sr361446.g6 (180 graphs)
sr361566.g6 (32548 graphs)
sr371889some.g6 (6760 graphs)
sr401224.g6 (28 graphs)
sr65321516some.g6 (32 graphs)
"""
SR_DATASET_FILE = "sr371889some.g6"  # parameters (37, 18, 8, 9); 6760 graphs
dataset = load_sr_dataset(SR_DATASET_FILE)

In [6]:
# Sizes taken from the first graph; num_edges counts each undirected edge
# twice because edge_index lists both directions.
num_of_graphs = len(dataset)
first_edge_index, num_nodes = dataset[0]
num_edges = first_edge_index.size(1)

In [7]:
# Compute an invariant signature for every graph and keep one entry per
# distinct signature. Each entry of Distinguished_graphs is the raw tuple
# (min, det, mean, var, diagonal mean, diagonal var) of output_snn_beta.
NUM_EXPANSION_ROUNDS = 5   # `m` argument passed to image()
SIGNATURE_SIG_DIGITS = 10  # significant digits compared between signatures

Distinguished_graphs = []
seen_signatures = set()
for edge_index, n_graph in dataset:
    # Dense adjacency matrix; edge_index already lists both directions.
    # Sizes are taken per graph rather than from the first graph.
    Ad_mat = torch.zeros(n_graph, n_graph)
    Ad_mat[edge_index[0], edge_index[1]] = 1

    Image = image(Ad_mat, n_graph, NUM_EXPANSION_ROUNDS)
    su = torch.sum(Image, 0)

    output_snn_beta = 0.01 * mon_op(mon_op(0.01 * su.t(), 0.01 * su), su.t())

    diag = torch.diagonal(output_snn_beta)
    invariants = (
        torch.min(output_snn_beta).item(),
        torch.linalg.det(output_snn_beta).item(),
        torch.mean(output_snn_beta).item(),
        torch.var(output_snn_beta).item(),
        torch.mean(diag).item(),
        torch.var(diag).item(),
    )
    # Compare the full 6-value signature with those already seen. Checking a
    # shorter tuple against the stored 6-tuples can never match, which would
    # count every graph as distinguished. Rounding keeps pure float noise
    # from separating two graphs.
    signature = tuple(float(f"{x:.{SIGNATURE_SIG_DIGITS}g}") for x in invariants)
    if signature not in seen_signatures:
        seen_signatures.add(signature)
        Distinguished_graphs.append(invariants)

In [8]:
"""Number of graphs the model cannot distinguish (their invariant signature matches that of an earlier graph)"""
indistinguishable_count = num_of_graphs - len(Distinguished_graphs)
print(indistinguishable_count)


0


In [None]:
# First few stored invariant tuples (min, det, mean, var, diagonal mean,
# diagonal var); the full list can hold one entry per graph (thousands).
Distinguished_graphs[:5]

In [10]:
from math import isfinite, isinf

def all_finite(tuples_list):
    """Return True if every number in every tuple is finite (no NaN, no +/-inf)."""
    for tup in tuples_list:
        for value in tup:
            if not isfinite(value):
                return False
    return True

def none_infinite(tuples_list):
    """Return True if no number in any tuple is +/-inf.

    NaN values are allowed; use this when only infinities matter.
    """
    return all(not isinf(value) for tup in tuples_list for value in tup)

In [11]:
# Sanity check: True when no stored invariant is NaN or +/-inf.
all_finite(Distinguished_graphs)

True