# Logic Explained GNNs  (toy example)
1. Generate BA-shape-like dataset
2. Run LogDiffPool
3. Get explanations

In [1]:
import sys
sys.path.append('../../')
import torch
import torch_geometric as pyg
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_dense_adj, to_dense_batch, to_undirected
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
import matplotlib.pyplot as plt
import seaborn as sns
import torch_explain as te

In [2]:
torch.manual_seed(12345)

dataset = TUDataset(root='data/TUDataset', name='MUTAG')
dataset = dataset.shuffle()
train_dataset = dataset[:150]
test_dataset = dataset[150:]
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

In [3]:
EPS = 1e-15
entropy_loss = lambda alpha: -torch.sum(alpha * torch.log(alpha + EPS))

class LogGNNLayer(torch.nn.Module):
    def __init__(self, num_node_features, hidden_channels, hidden_channels_pool):
        super(LogGNNLayer, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GCNConv(num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels//2)
        self.pool1 = GCNConv(num_node_features, hidden_channels_pool)
        self.pool2 = GCNConv(hidden_channels_pool, hidden_channels_pool//2)
        self.logpool = te.nn.logic.LogDiffPool()

    def forward(self, x, edge_index, batch_index):
        # 1. Obtain node embeddings 
        e = self.conv1(x, edge_index)
        e = e.relu()
        e = self.conv2(e, edge_index)
        e = e.relu()
        
        # 2. Obtain pooled cluster assignments
        p = self.pool1(x, edge_index)
        p = p.relu()
        p = self.pool2(p, edge_index)
        p = p.relu()

        # 3. LogDiffPool
        x, edge_index, edge_attr, batch_index = self.logpool(e, p, edge_index, batch_index)
        return x, edge_index, edge_attr, batch_index, self.logpool
#         return e, p, x, edge_index, batch_index

class LogGNN(torch.nn.Module):
    def __init__(self, hidden_channels, hidden_channels_pool):
        super(LogGNN, self).__init__()
        torch.manual_seed(12345)
        self.gnn1 = LogGNNLayer(dataset.num_node_features, hidden_channels[0], hidden_channels_pool[0])
        self.gnn2 = LogGNNLayer(hidden_channels[1], hidden_channels[1], hidden_channels_pool[1])
        self.lin = Linear(hidden_channels[1]//2, dataset.num_classes)

    def forward(self, x, edge_index, batch):
        # 1. Obtain node embeddings
        x, edge_index, edge_attr, batch_index, ldp1 = self.gnn1(x, edge_index, batch.batch)
        print(x.shape, edge_index.shape, batch_index.shape)
        print(x)
        print(edge_index)
        print(batch_index)
        x, edge_index, edge_attr, batch_index, ldp2 = self.gnn2(x, edge_index, batch_index)
        
        # 2. Apply a final classifier
        x = self.lin(x)
        return x, ldp1, ldp2
    
data = next(iter(train_loader))
model = LogGNN(hidden_channels=[8, 4], hidden_channels_pool=[4, 2])
print(model)
model(data.x, data.edge_index, data)

# m = LogGNNLayer(7, hidden_channels=8, hidden_channels_pool=6)
# e, p, x, edge_index, batch_index = m(data.x, data.edge_index, data.batch)
# print(e.shape, p.shape, edge_index.shape, batch_index.shape)
# adj = to_dense_adj(edge_index, batch_index)
# embed, embfake = to_dense_batch(e, batch_index)
# pool, poolfake = to_dense_batch(p, batch_index)
# print(adj.shape, embed.shape, pool.shape)
# alpha_r = torch.softmax(pool, dim=-1)
# alpha_c = torch.softmax(pool, dim=-2)
# gamma = alpha_r * alpha_c
# out = torch.matmul(gamma.transpose(1, 2), embed)
# out_adj = torch.matmul(torch.matmul(gamma.transpose(1, 2), adj), gamma)
# print(out, out_adj) # (B, C, H)->(B x C, H) (B, C, C)->(B, C, C)
# index = torch.LongTensor([[i, j] for i in range(len(out_adj)) for j in range(len(out_adj))]).T
# batch = index[0]
# batch_mul = index[0] * out_adj.size(-2)
# index = (batch_mul + index[0], batch_mul + index[1])
# edge_index = torch.stack(index, dim=0)
# edge_index = to_undirected(edge_index, None)
# # edge_index
# edge_index

LogGNN(
  (gnn1): LogGNNLayer(
    (conv1): GCNConv(7, 8)
    (conv2): GCNConv(8, 4)
    (pool1): GCNConv(7, 4)
    (pool2): GCNConv(4, 2)
    (logpool): LogDiffPool()
  )
  (gnn2): LogGNNLayer(
    (conv1): GCNConv(4, 4)
    (conv2): GCNConv(4, 2)
    (pool1): GCNConv(4, 2)
    (pool2): GCNConv(2, 1)
    (logpool): LogDiffPool()
  )
  (lin): Linear(in_features=2, out_features=2, bias=True)
)
torch.Size([8, 4]) torch.Size([2, 204]) torch.Size([89])
tensor([[0.1467, 0.1255, 0.0000, 0.0000],
        [0.1608, 0.1329, 0.0000, 0.0000],
        [0.1486, 0.1116, 0.0000, 0.0000],
        [0.1625, 0.1194, 0.0000, 0.0000],
        [0.1470, 0.1256, 0.0000, 0.0000],
        [0.1613, 0.1330, 0.0000, 0.0000],
        [0.1850, 0.1379, 0.0000, 0.0000],
        [0.1999, 0.1459, 0.0000, 0.0000]], grad_fn=<ReshapeAliasBackward0>)
tensor([[ 0,  0,  1,  1,  2,  2,  3,  3,  3,  4,  4,  4,  5,  5,  6,  6,  6,  7,
          7,  8,  8,  8,  9,  9, 10, 10, 10, 11, 11, 11, 12, 12, 12, 13, 13, 14,
         14, 14

RuntimeError: index 12 is out of bounds for dimension 0 with size 8

In [17]:
import torch
adj = torch.ones(2, 3, 3).reshape(-1, 3)
index = torch.LongTensor([[i, j] for i in range(len(adj)) for j in range(len(adj))]).T
batch = index[0] * adj.size(-2)
index = (batch + index[0], batch + index[1])
index

(tensor([ 0,  0,  0,  0,  0,  0,  7,  7,  7,  7,  7,  7, 14, 14, 14, 14, 14, 14,
         21, 21, 21, 21, 21, 21, 28, 28, 28, 28, 28, 28, 35, 35, 35, 35, 35, 35]),
 tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35]))

In [69]:
model = LogGNN(hidden_channels=[8, 4], hidden_channels_pool=[4, 2])
optim = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

def train_epoch(train_loader):
    model.train()
    avg_loss = 0
    for data in train_loader:
        optim.zero_grad()
        y_pred, ldp1, ldp2 = model(data.x, data.edge_index, data)
        loss = criterion(y_pred, data.y) + entropy_loss(ldp1.alpha_r) + entropy_loss(ldp1.alpha_c) + entropy_loss(ldp2.alpha_r) + entropy_loss(ldp2.alpha_c)
        loss.backward()
        optim.step()
        avg_loss += loss
    return avg_loss, ldp
        
def test_epoch(loader):
    model.eval()
    correct = 0
    for data in loader:
        y_pred, _, _ = model(data.x, data.edge_index, data)
        pred = y_pred.argmax(dim=1)
        correct += int((pred == data.y).sum())
    return correct / len(loader.dataset)

for epoch in range(1000):
    avg_loss, ldp = train_epoch(train_loader)
    train_acc = test_epoch(train_loader)
    test_acc = test_epoch(test_loader)
    if epoch % 100 == 0:
        print(f'Epoch {epoch}: loss={avg_loss}, train_acc={train_acc}, test_acc={test_acc}')

torch.Size([128, 4]) torch.Size([2, 254]) torch.Size([128])


RuntimeError: index 129 is out of bounds for dimension 0 with size 128

In [46]:
ldp.gamma.shape

torch.Size([38, 26, 1])

In [10]:

def combine_random_graphs():
    """
    Combines BA random graph with a small graph with known structure.
    """
    n, m = 15, 6
    np.random.default_rng(42)
    p1 = np.random.randint(2)
    p2 = np.random.randint(3)

    # generate random BA graph
    g1 = random_graphs.barabasi_albert_graph(n, m)
    label = 0
    if p1:
        # generate a small graph with known structure
        if p2 == 0:
            g2 = lattice.grid_2d_graph(3, 3)
        elif p2 == 1:
            g2 = small.house_graph()
        else:
            g2 = classic.wheel_graph(m)

        # merge the two graphs
        g12 = union(g1, g2, rename=('G1', 'G2'))
        a12 = nx.to_numpy_array(g12)
        a12[len(g1) - 1, len(g1) + 1] = 1
        g12 = nx.from_numpy_array(a12)
        g1 = g12
        label = 1
    return g1, label


def generate_dataset_multishape(n_samples):
    """
    Generate a dataset where each sample is a combination of a random BA graph and a small graph with known structure.
    """
    graphs = []
    labels = []
    for i in range(n_samples):
        g, label = combine_random_graphs()
        graphs.append(g)
        labels.append(label)
    return graphs, labels

n_samples = 1000
graph, labels = generate_dataset_multishape(n_samples)

In [19]:
edge_index = list(graph[0].edges)


1

In [23]:
from torch_geometric.datasets import TUDataset

dataset = TUDataset(root='data/TUDataset', name='MUTAG')
data = dataset[0]
data

Downloading https://www.chrsmrrs.com/graphkerneldatasets/MUTAG.zip
Extracting data\TUDataset\MUTAG\MUTAG.zip
Processing...
Done!


Data(edge_index=[2, 38], x=[17, 7], edge_attr=[38, 4], y=[1])