In [None]:
import torch
from torch_geometric.data import Data
import networkx as nx
from torch_geometric.utils import to_undirected

In [None]:
"""Code of this function is from the repository of the official code
of the papers Weisfeiler and Lehman Go Cellular: CW Networks (NeurIPS 2021)
and Weisfeiler and Lehman Go Topological: Message Passing Simplicial Networks (ICML 2021)"""
def load_sr_dataset(path):
    """Load the Strongly Regular Graph Dataset from the supplied path.

    Args:
        path: path to a graph6 (``.g6``) file.

    Returns:
        A list of ``(edge_index, n)`` tuples, one per graph in the file.
        ``edge_index`` is a ``torch.long`` tensor with two rows (source,
        target) passed through ``to_undirected``; ``n`` is the node count.
    """
    nx_graphs = nx.read_graph6(path)
    # read_graph6 returns a single Graph (not a list) when the file contains
    # only one graph; iterating a Graph would yield its nodes instead.
    if isinstance(nx_graphs, nx.Graph):
        nx_graphs = [nx_graphs]
    graphs = []
    for nx_graph in nx_graphs:
        n = nx_graph.number_of_nodes()
        edges = list(nx_graph.edges())
        if edges:
            edge_index = torch.tensor(edges, dtype=torch.long).t()
        else:
            # torch.tensor([]) is 1-D, so transposing it would raise; build an
            # explicit empty (2, 0) edge list instead.
            edge_index = torch.empty((2, 0), dtype=torch.long)
        graphs.append((to_undirected(edge_index, num_nodes=n), n))
    return graphs

In [None]:
def mon_op(A, B):
    r"""Monoidal operation $\circ$: ``A o B = A + B + A @ B``.

    For square ``A`` and ``B`` this equals ``(I + A)(I + B) - I``. The
    operation is therefore associative and has the zero matrix as identity.

    Args:
        A, B: tensors of matching matrix shape. Plain 2-D ``(n, n)`` matrices
            behave exactly as with ``torch.mm``. Batched ``(..., n, n)``
            inputs are also accepted and use ``torch.matmul`` broadcasting.

    Returns:
        ``A + B + A @ B``.
    """
    return A + B + torch.matmul(A, B)

In [None]:
"""Computing Image for all nodes"""
def image(X, n, m):
    """Compute a cover matrix ("image") for every node.

    For each node ``i``:

    * ``cover[i]`` is seeded with column ``i`` of ``X``.
    * ``dec`` is a copy of ``X`` with row ``i`` and column ``i`` zeroed.
    * The frontier is grown repeatedly. It keeps the columns ``k`` of
      ``dec`` whose index ``k`` is a non-zero row of ``cover[i]``.
    * The elementwise product of the frontier with its own transpose is
      subtracted. For a 0/1 frontier this drops entries present in both
      orientations.
    * The result is folded into ``cover[i]`` with ``mon_op``, and the
      consumed columns are removed from ``dec``.
    * Growth stops when no such column remains.

    Args:
        X: ``(n, n)`` matrix; in this notebook a dense 0/1 adjacency matrix.
        n: number of nodes (side length of ``X``).
        m: not read by this function. Kept so existing callers still work.
            The original code carried a commented-out ``c<m`` loop condition,
            suggesting ``m`` was meant as an iteration cap.
            NOTE(review): confirm intent before wiring it up, since doing so
            would change results.

    Returns:
        ``(n, n, n)`` tensor of the default float dtype; ``cover[i]`` is the
        image of node ``i``.
    """
    cover = torch.zeros(n, n, n)
    for i in range(n):
        cover[i][:, i] = X[:, i]
        dec = X.clone()
        dec[i, :] = 0
        dec[:, i] = 0
        # Columns k whose row k in cover[i] is non-zero.
        active = cover[i].sum(dim=1) != 0
        while active.any():
            # Broadcasting the (n,) column mask over rows is the same as
            # multiplying by an (n, n) matrix whose flagged columns are ones.
            frontier = dec * active.to(cover.dtype)
            cover[i] = mon_op(frontier - frontier * frontier.t(), cover[i])
            dec = dec - frontier
            # Next frontier: non-zero rows of the updated cover whose
            # column in dec still has something left to consume.
            active = (cover[i].sum(dim=1) != 0) & (dec.sum(dim=0) != 0)
    return cover

In [None]:
"""Loading the collection of strongly regular graphs

https://users.cecs.anu.edu.au/~bdm/data/graphs.html

sr251256.g6 (15 graphs)
sr261034.g6 (10 graphs)
sr281264.g6 (4 graphs)
sr291467.g6 (41 graphs)
sr351668.g6 (3854 graphs)
sr351899.g6 (227 graphs)
sr361446.g6 (180 graphs)
sr361566.g6 (32548 graphs)
sr371889some.g6 (6760 graphs)
sr401224.g6 (28 graphs)
sr65321516some.g6 (32 graphs)
"""
# Which strongly-regular family to evaluate; any of the .g6 files listed in
# the cell description can be substituted here.
SR_DATASET_PATH = "sr291467.g6"
dataset = load_sr_dataset(SR_DATASET_PATH)

In [None]:
# Dataset-level sizes. Node and edge counts are read from the FIRST graph
# only; edge_index has one column per (directed) edge, so the edge count is
# its second dimension.
num_of_graphs = len(dataset)
num_edges, num_nodes = dataset[0][0].shape[1], dataset[0][1]

In [None]:
# For each graph: build its dense adjacency matrix, run it through
# `image` / `mon_op`, and summarise the result with four scalar statistics.
# Graphs whose statistic tuples coincide count as not distinguished.
Distinguished_graphs = []   # distinct (t, s, t1, s1) tuples, first-seen order
_seen_invariants = set()    # same tuples, for O(1) membership tests
for edge_index, n in dataset:
    # Use this graph's own size rather than graph 0's, so graphs with a
    # different node or edge count are neither truncated nor out of range.
    Ad_mat = torch.zeros(n, n)
    # Vectorised form of setting Ad_mat[u][v] = 1 for every edge (u, v).
    Ad_mat[edge_index[0], edge_index[1]] = 1

    Image = image(Ad_mat, n, 5)   # the third argument is not read by `image`
    su = torch.sum(Image, 0)

    output_snn_beta = mon_op(mon_op(su.t(), su), su.t())

    # Diagonal entries: same values as summing eye * output over dim 0, but
    # without producing 0 * inf = nan if an entry ever overflows.
    diag = torch.diagonal(output_snn_beta)
    t = torch.mean(output_snn_beta).item()
    s = torch.var(output_snn_beta).item()
    t1 = torch.mean(diag).item()
    s1 = torch.var(diag).item()

    # Plain floats are hashable; equality matches the old tensor comparison.
    invariants = (t, s, t1, s1)
    if invariants not in _seen_invariants:
        _seen_invariants.add(invariants)
        Distinguished_graphs.append(invariants)

In [None]:
"""Number of graphs that model can not distinguish"""
# How many graphs the model cannot tell apart: dataset size minus the number
# of distinct invariant tuples recorded in Distinguished_graphs.
num_indistinguishable = num_of_graphs - len(Distinguished_graphs)
print(num_indistinguishable)
