In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch_geometric.nn import MessagePassing, dense_diff_pool, GCNConv
from torch_geometric.utils import to_dense_adj, remove_self_loops
from torch_geometric.datasets import Planetoid

In [2]:
class BasicMessagePassing(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(BasicMessagePassing, self).__init__(aggr='add')  # Use sum aggregation
        self.lin = nn.Linear(in_channels, out_channels)
        self.update_mlp = nn.Sequential(
            nn.Linear(out_channels, out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, out_channels)
        )

    def forward(self, x, edge_index):
        # x: Node feature matrix of shape [num_nodes, in_channels]
        # edge_index: Graph connectivity matrix of shape [2, num_edges]
        x = self.lin(x)
        return self.propagate(edge_index, x=x)
    
    def message(self, x_j, x_i):
        return x_j - x_i
    
    def update(self, aggr_out, x):
        return self.update_mlp(aggr_out) + x

In [3]:
class RankNodesByVarianceAndProximity(nn.Module):
    def __init__(self, in_channels):
        super(RankNodesByVarianceAndProximity, self).__init__()
        self.in_channels = in_channels

    def forward(self, x, edge_index, pos):
        num_nodes = x.size(0)
        adj = torch.sparse_coo_tensor(edge_index, torch.ones(edge_index.size(1)), (num_nodes, num_nodes)).to_dense()

        # Compute the sum and count of the messages for each target node
        sum_inputs = torch.matmul(adj, x)
        count_inputs = adj.sum(dim=1, keepdim=True).clamp(min=1)

        # Compute the mean of the messages for each target node
        mean_inputs = torch.zeros_like(x)
        variances = torch.zeros(num_nodes, device=x.device)
        batch_size = 1000  # Adjust batch size based on memory constraints
        for i in range(0, num_nodes, batch_size):
            x_batch = x[i:i+batch_size]
            sum_inputs_batch = sum_inputs[i:i+batch_size]
            count_inputs_batch = count_inputs[i:i+batch_size]

            # Compute mean inputs
            mean_inputs[i:i+batch_size] = sum_inputs_batch / count_inputs_batch

            # Compute variance inputs
            sum_sq_diff_batch = torch.norm(x_batch - mean_inputs[i:i+batch_size].unsqueeze(1), dim=-1)**2
            adj_batch = adj[i:i+batch_size, i:i+batch_size]
            mean_sq_diff_batch = (adj_batch.unsqueeze(-1) * sum_sq_diff_batch).sum(dim=1) / count_inputs[i:i+batch_size]
            variances[i:i+batch_size] = mean_sq_diff_batch.sum(dim=1)


        # Enhance ranking with proximity to high-variance nodes
        adjacency_scores = torch.matmul(adj, variances.unsqueeze(1)).squeeze()

        # Compute geometric distances
        pos_diff = pos.unsqueeze(1) - pos.unsqueeze(0)
        dist = torch.norm(pos_diff, dim=-1)

        # Select top-ranked nodes
        top_k = int(len(variances) * 0.5)
        top_k_indices = torch.argsort(variances, descending=True)[:top_k]

        # Enhance ranking based on proximity
        for i in top_k_indices:
            proximity_scores = torch.exp(-dist[i])
            adjacency_scores += proximity_scores * variances[i]

        combined_scores = variances + adjacency_scores
        ranks = torch.argsort(combined_scores, descending=True)
        
        return ranks, combined_scores

In [4]:
class GNNWithDenseDiffPool(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_clusters):
        super(GNNWithDenseDiffPool, self).__init__()
        self.conv1 = BasicMessagePassing(in_channels, hidden_channels)
        self.rank_nodes = RankNodesByVarianceAndProximity(hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, num_clusters)
        self.conv4 = GCNConv(hidden_channels, out_channels)
        self.lin1 = nn.Linear(hidden_channels, hidden_channels)
        self.lin2 = nn.Linear(hidden_channels, num_clusters)
        self._act = nn.ReLU()
        
    def forward(self, data):
        x, edge_index, batch, pos = data.x.view((-1,1)), data.edge_index, data.batch, data.pos
        x = self.conv1(x, edge_index)
        print(x.shape)

        # Rank nodes by variance and proximity
        with torch.no_grad():  # No gradients needed for ranking
            ranks, combined_scores = self.rank_nodes(x, edge_index, pos)

        # Select top-ranked nodes
        top_k_nodes = ranks[:min(int(len(ranks) * 0.1), x.size(0))] 
        print(top_k_nodes.shape)
        
        # Create subgraph mask
        subgraph_mask = torch.zeros(x.size(0), dtype=torch.bool, device=x.device)
        # print(subgraph_mask.shape)
        subgraph_mask[top_k_nodes] = True
        # print(subgraph_mask[edge_index[0]].shape, subgraph_mask[edge_index[1]].shape, edge_index[0][subgraph_mask].shape)
        
        # mask = subgraph_mask[edge_index[0]]  subgraph_mask[edge_index[1]]
        # print(mask.shape)
        subgraph_edge_index = edge_index[:, subgraph_mask]
        print(subgraph_edge_index.shape)
        subgraph_x = x[subgraph_mask]

        # Apply linear transformation if needed
        # if subgraph_x.size(0) != x.size(0):
        #    subgraph_x =self._act(self.lin1(subgraph_x))

        # Convert edge_index to dense adjacency matrix
        adj = to_dense_adj(subgraph_edge_index, max_num_nodes=subgraph_x.size(0))

        # Apply DiffPool
        print(subgraph_x.shape)
        s = self._act(self.conv2(subgraph_x, subgraph_edge_index))
        s = self.conv3(s, subgraph_edge_index)
        x = self._act(self.conv2(subgraph_x, subgraph_edge_index))
        x, adj, _, _ = dense_diff_pool(x, adj, s)

        # Further graph convolution after pooling
        x = self.conv4(x, subgraph_edge_index)

        # Cleanup memory
        del subgraph_mask, subgraph_edge_index, subgraph_x, adj, s
        torch.cuda.empty_cache()

        return F.log_softmax(x, dim=1)

In [5]:
dataset = torch.load('../../data/MLgSA/train_loader.pt')
data = next(iter(dataset))

In [6]:
data

DataBatch(x=[28811], edge_index=[2, 343986], y=[1], pos=[28811, 3], batch=[28811], ptr=[2])

In [7]:
# Define model parameters
in_channels = 1
hidden_channels = 4
out_channels = 2
num_clusters = 10  # Number of clusters for DiffPool

model = GNNWithDenseDiffPool(in_channels, hidden_channels, out_channels, num_clusters)
loss_fn = nn.CrossEntropyLoss()
# Training the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
print('training')
optimizer.zero_grad()
out = model(data)
print(out.clone().detach())
loss = loss_fn(out, data.y)
loss.backward()
optimizer.step()
print(loss.detach())

model.eval()
_, pred = model(data).max(dim=1)
correct = pred[data.test_mask].eq(data.y[data.test_mask]).sum().item()
accuracy = correct / data.test_mask.sum().item()
print('Test Accuracy: {:.4f}'.format(accuracy))

training
torch.Size([28811, 4])
torch.Size([2881])


IndexError: The shape of the mask [28811] at index 0 does not match the shape of the indexed tensor [2, 343986] at index 1

In [8]:
32671-30120

2551

In [14]:
nodes = torch.tensor(range(30))
top_k = [1, 3, 5, 7]
nodes_mask = torch.zeros(nodes.size(0), dtype=torch.bool)

In [15]:
nodes_mask[top_k] = True

In [27]:
new_ind = edges[nodes_mask.any(dim=0)]

In [28]:
new_ind

tensor([[[ 0,  1,  0,  2,  1,  2,  0,  2,  0,  3,  2,  3,  4,  3,  4,  5,  3,
           5,  2,  5,  2,  3,  5,  3,  0,  6,  0,  1,  6,  1,  7,  1,  7,  6,
           1,  6,  5,  8,  5,  4,  8,  4,  9,  8,  9, 10,  8, 10,  9,  4,  9,
           8,  4,  8, 11, 10, 11,  8, 10,  8,  9, 10,  9, 12, 10, 12,  0,  3,
           0, 13,  3, 13,  0, 13,  0,  6, 13,  6, 14, 15, 14,  7, 15,  7,  6,
          14,  6,  7, 14,  7, 16, 14, 16,  6, 14,  6,  4, 17,  4,  3],
         [ 1,  0,  2,  0,  2,  1,  2,  0,  3,  0,  3,  2,  3,  4,  5,  4,  5,
           3,  5,  2,  3,  2,  3,  5,  6,  0,  1,  0,  1,  6,  1,  7,  6,  7,
           6,  1,  8,  5,  4,  5,  4,  8,  8,  9, 10,  9, 10,  8,  4,  9,  8,
           9,  8,  4, 10, 11,  8, 11,  8, 10, 10,  9, 12,  9, 12, 10,  3,  0,
          13,  0, 13,  3, 13,  0,  6,  0,  6, 13, 15, 14,  7, 14,  7, 15, 14,
           6,  7,  6,  7, 14, 14, 16,  6, 16,  6, 14, 17,  4,  3,  4]]])