In [6]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing, dense_diff_pool, GCNConv, DenseSAGEConv
from torch_geometric.utils import to_dense_adj, remove_self_loops, subgraph, k_hop_subgraph
from torch_geometric.loader import DataLoader

In [5]:
X_train = torch.load('../../data/train_split.pt')
X_test = torch.load('../../data/test_split.pt')

In [7]:
train_loader = DataLoader(X_train, batch_size=1, shuffle=True)

In [20]:
data = next(iter(train_loader))[0]
print(data)

Data(x=[15357, 1], edge_index=[2, 182982], y=[1], pos=[15357, 3])


In [22]:
class BasicMessagePassing(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(BasicMessagePassing, self).__init__()
        self.conv = GCNConv(in_channels, out_channels)

    def forward(self, x, edge_index):
        return F.relu(self.conv(x, edge_index))


class RankNodesByTotalVariation(nn.Module):
    def __init__(self, in_channels):
        super(RankNodesByTotalVariation, self).__init__()
        self.in_channels = in_channels

    def forward(self, x, edge_index):
        num_nodes = x.size(0)
        adj = torch.sparse_coo_tensor(edge_index, torch.ones(edge_index.size(1)), (num_nodes, num_nodes)).to_dense()

        sum_inputs = torch.matmul(adj, x)
        count_inputs = adj.sum(dim=1, keepdim=True).clamp(min=1)
        mean_inputs = sum_inputs / count_inputs

        total_variation = torch.norm(x - mean_inputs, dim=-1)
        return total_variation

In [12]:
rank_nodes = RankNodesByTotalVariation(1)

In [18]:
var = rank_nodes(data[0].x, data[0].edge_index)

In [19]:
class GNNWithDenseDiffPool(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_clusters):
        super(GNNWithDenseDiffPool, self).__init__()
        self.conv1 = BasicMessagePassing(in_channels, hidden_channels)
        self.rank_nodes = RankNodesByTotalVariation(hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, num_clusters)
        self.conv4 = GCNConv(num_clusters, out_channels)
        self.lin1 = nn.Linear(hidden_channels, hidden_channels)
        self.lin2 = nn.Linear(hidden_channels, num_clusters)
        
    def forward(self, data):
        x, edge_index, batch, pos = data.x, data.edge_index, data.batch, data.pos
        x = self.conv1(x, edge_index)

        # Rank nodes by total variation
        with torch.no_grad():
            total_variation = self.rank_nodes(x, edge_index)

        # Select top 50 to 100 nodes with highest total variation values
        num_top_nodes = min(100, x.size(0))
        top_k_nodes = torch.argsort(total_variation, descending=True)[:num_top_nodes]

        # Find k-hop subgraphs for top nodes
        k = 2  # Define k for k-hop neighborhood
        sub_nodes, sub_edge_index, mapping, _ = k_hop_subgraph(top_k_nodes, k, edge_index, relabel_nodes=True)
        sub_x = x[sub_nodes]
        sub_pos = pos[sub_nodes]

        # Apply linear transformation if needed
        if sub_x.size(0) != x.size(0):
            sub_x = F.relu(self.lin1(sub_x))

        # Convert edge_index to dense adjacency matrix
        adj = to_dense_adj(sub_edge_index, max_num_nodes=sub_x.size(0))

        # Apply DiffPool
        s = F.relu(self.conv2(sub_x, sub_edge_index))
        s = self.conv3(s, sub_edge_index)
        sub_x = F.relu(self.conv2(sub_x, sub_edge_index))
        sub_x, adj, _, _ = dense_diff_pool(sub_x, adj, s)

        # Further graph convolution after pooling
        # Convert dense adjacency matrix back to edge_index format
        pooled_edge_index, _ = to_dense_adj(sub_x).nonzero(as_tuple=False).t()
        x = self.conv4(sub_x, pooled_edge_index)

        # Cleanup memory
        del sub_nodes, sub_edge_index, sub_x, adj, s
        torch.cuda.empty_cache()

        return F.log_softmax(x, dim=1)

In [23]:
# Define model parameters
in_channels = 1
hidden_channels = 4
out_channels = 2
num_clusters = 3  # Number of clusters for DiffPool

model = GNNWithDenseDiffPool(in_channels, hidden_channels, out_channels, num_clusters)
loss_fn = nn.CrossEntropyLoss()
# Training the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
print('training')
optimizer.zero_grad()
out = model(data)
print(out.clone().detach())
loss = loss_fn(out, data.y)
loss.backward()
optimizer.step()
print(loss.detach())

model.eval()
_, pred = model(data).max(dim=1)
correct = pred[data.test_mask].eq(data.y[data.test_mask]).sum().item()
accuracy = correct / data.test_mask.sum().item()
print('Test Accuracy: {:.4f}'.format(accuracy))

training


RuntimeError: scatter(): Expected dtype int64 for index