# Count star structures
... without considering node/edge attributes. Also ignore ducplicate edges and make graph undirected.


In [1]:
"""example code for counting subgraph structures"""
from subgraph_counting import pattern, conversion, datasets
from subgraph_counting.graph import Graph
from torch_geometric.loader import DataLoader
import torch
import os

PLOT_ROOT = '../data/figures'
DATA_ZINC = '../data/datasets/ZINC'
ANNOTATED_DATA_ZINC_DIRECTORY = '../data/datasets/annotated/ZINC'
os.makedirs(ANNOTATED_DATA_ZINC_DIRECTORY, exist_ok=True)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# get ZINC dataset
datasets_pyg = datasets.get_zinc_dataset(DATA_ZINC)

In [3]:
# get ZINC dataset
star_5_ig = pattern.create_star(5)
star_5_pyg = conversion.igraph_to_pyg_graph(star_5_ig)
star_5 = Graph(star_5_pyg, pyg_to_undirected=True)

look_for = {
    'star_5': star_5
}

for dataset_name in datasets_pyg:
    data_list = []
    for i, graph_pyg in enumerate(datasets_pyg[dataset_name]):
        graph = Graph(graph_pyg, pyg_to_undirected=True)
        y = []
        for key in look_for:
            subisomorphisms = graph.count_subisomorphisms(star_5)
            y.append(subisomorphisms)
        del graph_pyg.edge_attr
        graph_pyg.x = torch.ones_like(graph_pyg.x, dtype=torch.float32)
        graph_pyg.y = torch.as_tensor(y)
        data_list.append(graph_pyg)
    torch.save(data_list, f"{ANNOTATED_DATA_ZINC_DIRECTORY}/{dataset_name}.datalist")

with open(f"{ANNOTATED_DATA_ZINC_DIRECTORY}/Annotations.txt", "w") as f:
    f.writelines("Graphs are annotated with the subgraphcount of following graph structures:\n")
    for graph_name in look_for:
        f.writelines(graph_name + "\n")


# Load Data
ZINC Dataset without edge attributes and node attributes equal 1.

In [4]:
datasets_annotated_pyg = datasets.get_annotated_zinc_datalist(ANNOTATED_DATA_ZINC_DIRECTORY)
train_loader = DataLoader(datasets_annotated_pyg['train'])
test_loader = DataLoader(datasets_annotated_pyg['test'])

# Train

In [13]:
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool

device = 'cuda'

node_feature_size = 1
target_features_size = 1

class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        torch.manual_seed(12345)
        self.conv1 = GCNConv(node_feature_size, hidden_channels)
        #self.conv2 = GCNConv(hidden_channels, hidden_channels)
        #self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, target_features_size)

    def forward(self, x, edge_index, batch):
        # 1. Obtain node embeddings 
        x = self.conv1(x, edge_index)
        #x = x.relu()
        #x = self.conv2(x, edge_index)
        #x = x.relu()
        #x = self.conv3(x, edge_index)

        # 2. Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier
        x = F.dropout(x, p=0.1, training=self.training)
        x = self.lin(x)
        
        return x

model = GCN(hidden_channels=64).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.L1Loss()

def train():
    model.train()

    for data in train_loader:  # Iterate in batches over the training dataset.
         out = model(data.x.to(device), data.edge_index.to(device), data.batch.to(device))  # Perform a single forward pass.
         loss = criterion(out, data.y.unsqueeze(0).to(device))  # Compute the loss.
         loss.backward()  # Derive gradients.
         optimizer.step()  # Update parameters based on gradients.
         optimizer.zero_grad()  # Clear gradients.

def test(loader):
    model.eval()
    losses = []
    for data in loader:  # Iterate in batches over the training/test dataset.
        out = model(data.x.to(device), data.edge_index.to(device), data.batch.to(device))  
        loss = criterion(out, data.y.unsqueeze(0).to(device))
        losses.append(loss)
    return torch.mean(torch.as_tensor(losses)).item()

for epoch in range(1, 171):
    train()
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    print(f'Epoch: {epoch:03d}, Train MSE: {train_acc:.4f}, Test MSE: {test_acc:.4f}')

Epoch: 001, Train MSE: 0.3126, Test MSE: 0.3459
Epoch: 002, Train MSE: 0.3114, Test MSE: 0.3450
Epoch: 003, Train MSE: 0.3111, Test MSE: 0.3446
Epoch: 004, Train MSE: 0.3116, Test MSE: 0.3452
Epoch: 005, Train MSE: 0.3123, Test MSE: 0.3457
Epoch: 006, Train MSE: 0.3127, Test MSE: 0.3460
Epoch: 007, Train MSE: 0.3118, Test MSE: 0.3452
Epoch: 008, Train MSE: 0.3127, Test MSE: 0.3461
Epoch: 009, Train MSE: 0.3131, Test MSE: 0.3464
Epoch: 010, Train MSE: 0.3178, Test MSE: 0.3505
Epoch: 011, Train MSE: 0.3121, Test MSE: 0.3455
Epoch: 012, Train MSE: 0.3128, Test MSE: 0.3461
Epoch: 013, Train MSE: 0.3128, Test MSE: 0.3461
Epoch: 014, Train MSE: 0.3110, Test MSE: 0.3445
Epoch: 015, Train MSE: 0.3115, Test MSE: 0.3450
Epoch: 016, Train MSE: 0.3121, Test MSE: 0.3455
Epoch: 017, Train MSE: 0.3132, Test MSE: 0.3464
Epoch: 018, Train MSE: 0.3124, Test MSE: 0.3458
Epoch: 019, Train MSE: 0.3113, Test MSE: 0.3448
Epoch: 020, Train MSE: 0.3131, Test MSE: 0.3463
Epoch: 021, Train MSE: 0.3120, Test MSE:

# Explain

from torch_geometric.data import Data
from torch_geometric.explain import Explainer, PGExplainer

dataset = datasets_annotated_pyg['val']
loader = DataLoader(dataset, batch_size=1, shuffle=True)

explainer = Explainer(
    model=model,
    algorithm=PGExplainer(epochs=30, lr=0.003),
    explanation_type='phenomenon',
    edge_mask_type='object',
    model_config=dict(
        mode='regression',
        task_level='graph',
        return_type='raw',
    ),
    # Include only the top 10 most important edges:
    threshold_config=dict(threshold_type='topk', value=10),
)

# PGExplainer needs to be trained separately since it is a parametric
# explainer i.e it uses a neural network to generate explanations:
for epoch in range(30):
    for batch in loader:
        loss = explainer.algorithm.train(
            epoch, model, batch.x, batch.edge_index, target=batch.target)

# Generate the explanation for a particular graph:
explanation = explainer(dataset[0].x, dataset[0].edge_index)
print(explanation.edge_mask)