In [1]:
import os.path as osp
import torch
from torch.nn import Linear

import torch_geometric
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
from torch_geometric import utils
from torch_geometric.nn import Sequential
from torch_geometric.data import Data

import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from utils.eval_metrics import *

# Seed PyTorch's CPU and CUDA random number generators with a fixed value.
SEED = 1
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [2]:
# Load the Planetoid citation graph; NormalizeFeatures is applied on access.
# A separate name is used for the string so that `dataset` always refers to
# the dataset object.
dataset_name = 'Cora'
path = osp.join('..', 'data', dataset_name)
dataset = Planetoid(path, dataset_name, transform=T.NormalizeFeatures())
data = dataset[0]

# Replace the raw connectivity by the propagation operator
#     A = I - delta * L_sym
# where L_sym is the symmetrically normalised graph Laplacian.
delta = 0.85
# The node count is passed explicitly: when omitted it is inferred from the
# largest index found in `edge_index`, so trailing isolated nodes would make
# L smaller than the identity matrix it is subtracted from below.
edge_index, edge_weight = utils.get_laplacian(
    data.edge_index,
    data.edge_weight,
    normalization='sym',
    num_nodes=data.num_nodes,
)
L = utils.to_dense_adj(
    edge_index,
    edge_attr=edge_weight,
    max_num_nodes=data.num_nodes,
)
A = torch.eye(data.num_nodes) - delta*L
# Store the operator back on the graph in sparse (index, weight) form.
data.edge_index, data.edge_weight = utils.dense_to_sparse(A)

In [3]:
EPS = 1e-15

def just_balance_pool(x, adj, s, mask=None, normalize=True):
    r"""The Just Balance pooling operator from the `"Simplifying Clustering with 
    Graph Neural Networks" <https://arxiv.org/abs/2207.08779>`_ paper
   
    Args:
        x (Tensor): Node feature tensor :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}` 
            with batch-size :math:`B`, (maximum) number of nodes :math:`N` 
            for each graph, and feature dimension :math:`F`.
        adj (Tensor): Symmetrically normalized adjacency tensor
            :math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`.
        s (Tensor): Assignment tensor :math:`\mathbf{S} \in \mathbb{R}^{B \times N \times C}` 
            with number of clusters :math:`C`. The softmax does not have to be 
            applied beforehand, since it is executed within this method.
        mask (BoolTensor, optional): Mask matrix
            :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
            the valid nodes for each graph. (default: :obj:`None`)

    :rtype: (:class:`Tensor`, :class:`Tensor`, :class:`Tensor`,
        :class:`Tensor`)
    """

    x = x.unsqueeze(0) if x.dim() == 2 else x
    adj = adj.unsqueeze(0) if adj.dim() == 2 else adj
    s = s.unsqueeze(0) if s.dim() == 2 else s

    (batch_size, num_nodes, _), k = x.size(), s.size(-1)

    s = torch.softmax(s, dim=-1)

    if mask is not None:
        mask = mask.view(batch_size, num_nodes, 1).to(x.dtype)
        x, s = x * mask, s * mask

    out = torch.matmul(s.transpose(1, 2), x)
    out_adj = torch.matmul(torch.matmul(s.transpose(1, 2), adj), s)
    
    ss = torch.matmul(s.transpose(1, 2), s)
    ss_sqrt = torch.sqrt(ss + EPS)
    loss = torch.mean(-_rank3_trace(ss_sqrt))
    if normalize:
        loss = loss / torch.sqrt(torch.tensor(num_nodes * k))

    ind = torch.arange(k, device=out_adj.device)
    out_adj[:, ind, ind] = 0
    d = torch.einsum('ijk->ij', out_adj)
    d = torch.sqrt(d)[:, None] + EPS
    out_adj = (out_adj / d) / d.transpose(1, 2)

    return out, out_adj, loss


def _rank3_trace(x):
    return torch.einsum('ijj->i', x)

In [4]:
class Net(torch.nn.Module):
    """Stack of GCN layers followed by an MLP that scores cluster assignments.

    Args:
        mp_units (list[int]): Output size of each message-passing (GCNConv)
            layer; must contain at least one entry.
        mp_act (str): Name of the ``torch.nn`` activation applied after every
            message-passing layer (e.g. ``"ReLU"``).
        in_channels (int): Size of the input node features.
        n_clusters (int): Number of clusters, i.e. output size of the MLP.
        mlp_units (list[int], optional): Hidden sizes of the MLP placed before
            the final linear layer. (default: no hidden layers)
        mlp_act (str, optional): Name of the ``torch.nn`` activation used
            after each hidden MLP layer. (default: ``"Identity"``)

    Raises:
        ValueError: If ``mp_units`` is empty.
    """

    def __init__(self,
                 mp_units,
                 mp_act,
                 in_channels,
                 n_clusters,
                 mlp_units=None,
                 mlp_act="Identity"):
        super().__init__()

        mp_units = list(mp_units)
        if not mp_units:
            raise ValueError("mp_units must contain at least one layer size")
        # `None` replaces the former mutable `[]` default.
        mlp_units = [] if mlp_units is None else list(mlp_units)

        mp_act = self._make_activation(mp_act)
        mlp_act = self._make_activation(mlp_act)

        # Message-passing stack: GCNConv -> activation, repeated per entry.
        conv_signature = 'x, edge_index, edge_weight -> x'
        mp = []
        prev_chan = in_channels
        for units in mp_units:
            conv = GCNConv(prev_chan, units, normalize=False, cached=False)
            mp.append((conv, conv_signature))
            mp.append(mp_act)
            prev_chan = units
        self.mp = Sequential('x, edge_index, edge_weight', mp)
        out_chan = mp_units[-1]

        # MLP head: optional hidden layers, then one logit per cluster.
        self.mlp = torch.nn.Sequential()
        for units in mlp_units:
            self.mlp.append(Linear(out_chan, units))
            out_chan = units
            self.mlp.append(mlp_act)
        self.mlp.append(Linear(out_chan, n_clusters))

    @staticmethod
    def _make_activation(name):
        """Instantiate the ``torch.nn`` activation ``name``, in-place if it supports it."""
        act_cls = getattr(torch.nn, name)
        try:
            return act_cls(inplace=True)
        except TypeError:
            # Activations such as Tanh or Sigmoid take no `inplace` argument.
            return act_cls()

    def forward(self, x, edge_index, edge_weight):
        """Compute cluster assignments and the pooled graph.

        Returns:
            tuple: ``(assignments, balance_loss, x_pooled, adj_pooled)`` where
            ``assignments`` is the softmax of the MLP output and the remaining
            three values come from :func:`just_balance_pool`.
        """
        x = self.mp(x, edge_index, edge_weight)
        s = self.mlp(x)
        # Pin the dense size to the number of nodes in `x`; otherwise it is
        # inferred from the largest index present in `edge_index`.
        adj = utils.to_dense_adj(edge_index, edge_attr=edge_weight,
                                 max_num_nodes=x.size(0))
        x_pooled, adj_pooled, b_loss = just_balance_pool(x, adj, s)

        return torch.softmax(s, dim=-1), b_loss, x_pooled, adj_pooled

In [5]:
# Put the graph on the compute device and build the clustering network:
# ten 64-unit GCN layers with ReLU, one 16-unit ReLU MLP layer, and one
# output cluster per dataset class.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = data.to(device)
model = Net(
    mp_units=[64]*10,
    mp_act="ReLU",
    in_channels=dataset.num_features,
    n_clusters=dataset.num_classes,
    mlp_units=[16],
    mlp_act="ReLU",
)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


def train():
    """Run one optimisation step on the balance loss and return its value.

    Uses the notebook-level ``model``, ``optimizer`` and ``data``.
    """
    model.train()
    optimizer.zero_grad()
    outputs = model(data.x, data.edge_index, data.edge_weight)
    balance_loss = outputs[1]  # second element of the model output is the loss
    balance_loss.backward()
    optimizer.step()
    return balance_loss.item()


@torch.no_grad()
def test():
    """Evaluate the current cluster assignments against ``data.y``.

    Each node is assigned to the cluster with the highest score; the result
    of ``eval_metrics`` on the true and predicted labels is returned as is.
    """
    model.eval()
    assignments = model(data.x, data.edge_index, data.edge_weight)[0]
    predicted = assignments.max(1).indices.cpu()
    true_labels = data.y.cpu()
    return eval_metrics(true_labels, predicted)

# NOTE(review): `patience` and `best_nmi` are assigned but never read in the
# visible notebook (no early stopping is implemented); they are kept in case
# another cell relies on them.
patience = 50
best_nmi = 0

NUM_EPOCHS = 1000
LOG_EVERY = 50  # evaluate and report every this many epochs

for epoch in range(1, NUM_EPOCHS + 1):
    train_loss = train()
    # The metrics are only used for the log line, so evaluate on reporting
    # epochs only instead of after every optimisation step.
    if epoch % LOG_EVERY == 0:
        acc, nmi = test()
        print(f'Epoch: {epoch:03d}, Loss: {train_loss:.4f}, NMI: {nmi:.3f}, ACC: {acc:.3f}')

Epoch: 050, Loss: -0.3780, NMI: 0.000, ACC: 0.302
Epoch: 100, Loss: -0.3782, NMI: 0.000, ACC: 0.302
Epoch: 150, Loss: -0.3909, NMI: 0.000, ACC: 0.302
Epoch: 200, Loss: -0.3949, NMI: 0.000, ACC: 0.302
Epoch: 250, Loss: -0.3972, NMI: 0.000, ACC: 0.302
Epoch: 300, Loss: -0.4010, NMI: 0.000, ACC: 0.302
Epoch: 350, Loss: -0.4392, NMI: 0.037, ACC: 0.322
Epoch: 400, Loss: -0.4850, NMI: 0.207, ACC: 0.403
Epoch: 450, Loss: -0.4946, NMI: 0.230, ACC: 0.403
Epoch: 500, Loss: -0.5558, NMI: 0.253, ACC: 0.390
Epoch: 550, Loss: -0.5759, NMI: 0.255, ACC: 0.391
Epoch: 600, Loss: -0.6155, NMI: 0.354, ACC: 0.469
Epoch: 650, Loss: -0.6356, NMI: 0.359, ACC: 0.473
Epoch: 700, Loss: -0.6375, NMI: 0.361, ACC: 0.480
Epoch: 750, Loss: -0.6398, NMI: 0.361, ACC: 0.479
Epoch: 800, Loss: -0.6412, NMI: 0.359, ACC: 0.475
Epoch: 850, Loss: -0.6442, NMI: 0.353, ACC: 0.467
Epoch: 900, Loss: -0.6472, NMI: 0.352, ACC: 0.464
Epoch: 950, Loss: -0.7276, NMI: 0.341, ACC: 0.462
Epoch: 1000, Loss: -0.9002, NMI: 0.336, ACC: 0.468

In [6]:
import os
import sys

from torch_geometric.data import Data

module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

module_path = os.path.abspath(os.path.join('../gmn_config/'))
if module_path not in sys.path:
    sys.path.append(module_path)

from gmn_config.graph_utils import *

from gmn_config.evaluation import compute_similarity, auc
from gmn_config.loss import pairwise_loss, triplet_loss
from gmn_config.gmn_utils import *
from gmn_config.configure_cosine import *

# NOTE(review): these variables only influence CUDA device selection if they
# are set before CUDA is initialised; earlier cells already query CUDA and
# move `data` to `device`, so on a GPU machine this may have no effect.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

# Print configure
config = get_default_config()
for k, v in config.items():
    print("%s= %s" % (k, v))

# Set random seeds
# NOTE(review): `random` and `np` are not imported explicitly in this
# notebook; presumably they are provided by the star imports above.
seed = config["seed"]
random.seed(seed)
np.random.seed(seed + 1)
torch.manual_seed(seed + 2)
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
# NOTE(review): anomaly detection is a debugging aid that slows autograd;
# consider disabling it once the notebook is stable.
torch.autograd.set_detect_anomaly(True)


training_set, validation_set = build_datasets(config)

# Draw one batch only to read the node and edge feature sizes the model needs.
if config["training"]["mode"] == "pair":
    training_data_iter = training_set.pairs(config["training"]["batch_size"])
    first_batch_graphs, _ = next(training_data_iter)
else:
    training_data_iter = training_set.triplets(config["training"]["batch_size"])
    first_batch_graphs = next(training_data_iter)

node_feature_dim = first_batch_graphs.node_features.shape[-1]
edge_feature_dim = first_batch_graphs.edge_features.shape[-1]

# Bind the second return value to its own name so it does not overwrite the
# `optimizer` that `train()` uses for the clustering `model`.
gmn, gmn_optimizer = build_model(config, node_feature_dim, edge_feature_dim)
# `map_location` lets a checkpoint saved on another device (e.g. a GPU) be
# loaded on whichever device is selected here.
gmn.load_state_dict(
    torch.load("../gmn_config/model64new.pth", map_location=device)
)
gmn.to(device)
gmn.eval()

encoder= {'node_hidden_sizes': [64], 'node_feature_dim': 1, 'edge_hidden_sizes': [4]}
aggregator= {'node_hidden_sizes': [128], 'graph_transform_sizes': [128], 'input_size': [64], 'gated': True, 'aggregation_type': 'sum'}
graph_embedding_net= {'node_state_dim': 64, 'edge_state_dim': 4, 'edge_hidden_sizes': [128, 128], 'node_hidden_sizes': [128], 'n_prop_layers': 5, 'share_prop_params': True, 'edge_net_init_scale': 0.1, 'node_update_type': 'gru', 'use_reverse_direction': True, 'reverse_dir_param_different': False, 'layer_norm': False, 'prop_type': 'matching'}
graph_matching_net= {'node_state_dim': 64, 'edge_state_dim': 4, 'edge_hidden_sizes': [128, 128], 'node_hidden_sizes': [128], 'n_prop_layers': 5, 'share_prop_params': True, 'edge_net_init_scale': 0.1, 'node_update_type': 'gru', 'use_reverse_direction': True, 'reverse_dir_param_different': False, 'layer_norm': False, 'prop_type': 'matching', 'similarity': 'dotproduct'}
model_type= matching
data= {'problem': 'graph_edit_distance', 'dat

GraphMatchingNet(
  (_encoder): GraphEncoder(
    (MLP1): Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
    )
    (MLP2): Sequential(
      (0): Linear(in_features=4, out_features=4, bias=True)
    )
  )
  (_aggregator): GraphAggregator(
    (MLP1): Sequential(
      (0): Linear(in_features=64, out_features=256, bias=True)
    )
    (MLP2): Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
    )
  )
  (_prop_layers): ModuleList(
    (0-4): 5 x GraphPropMatchingLayer(
      (_message_net): Sequential(
        (0): Linear(in_features=132, out_features=128, bias=True)
        (1): ReLU()
        (2): Linear(in_features=128, out_features=128, bias=True)
      )
      (_reverse_message_net): Sequential(
        (0): Linear(in_features=132, out_features=128, bias=True)
        (1): ReLU()
        (2): Linear(in_features=128, out_features=128, bias=True)
      )
      (GRU): GRU(192, 64)
    )
  )
)

In [7]:
# Pool the graph with the clustering model and keep only the connectivity of
# the pooled adjacency (its edge weights are discarded).
_, _, x_pool, adj_pool = model(data.x, data.edge_index, data.edge_weight)
edge_index_pool = utils.dense_to_sparse(adj_pool)[0]
clustered_data = Data(x=x_pool[0], edge_index=edge_index_pool)

# `data` was moved to `device` earlier and Tensor.numpy() only works on CPU
# tensors, so bring the features back to host memory before converting.
original_features = data.x.detach().cpu().numpy()
original_data = Data(x=reduce_dimensions(original_features), edge_index=data.edge_index)

# Score the original graph against its pooled version with the matching network.
sim = similarity(gmn, config, original_data, clustered_data)
print(sim)

tensor(0.3520, grad_fn=<SelectBackward0>)
