In [21]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing, dense_diff_pool, GCNConv, DenseSAGEConv, global_mean_pool
from torch_geometric.utils import to_dense_adj, remove_self_loops, subgraph, k_hop_subgraph
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import DataLoader

In [18]:
dataset = torch.load('../../data/MLgSA/train_loader.pt')

In [49]:
data = next(iter(dataset))
data

DataBatch(x=[140, 1], edge_index=[2, 900], y=[1], pos=[140, 3], batch=[140], ptr=[2])

In [50]:
class BasicMessagePassing(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(BasicMessagePassing, self).__init__()
        self.conv = GCNConv(in_channels, out_channels)

    def forward(self, x, edge_index):
        return F.relu(self.conv(x, edge_index))

class RankNodesByTotalVariation(nn.Module):
    def __init__(self, in_channels):
        super(RankNodesByTotalVariation, self).__init__()
        self.in_channels = in_channels

    def forward(self, x, edge_index):
        num_nodes = x.size(0)
        # Create sparse adjacency matrix
        adj = torch.sparse_coo_tensor(edge_index, torch.ones(edge_index.size(1)), (num_nodes, num_nodes))
        
        # Perform sparse matrix multiplication to get sum of inputs
        sum_inputs = torch.sparse.mm(adj, x)
        
        # Get the count of inputs (degree of each node)
        count_inputs = torch.sparse.sum(adj, dim=1).to_dense().view(-1, 1).clamp(min=1)
        
        # Calculate the mean inputs
        mean_inputs = sum_inputs / count_inputs
        
        # Calculate total variation
        total_variation = torch.norm(x - mean_inputs, dim=-1)
        return total_variation

In [51]:
class GNNWithDenseDiffPool(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_clusters):
        super(GNNWithDenseDiffPool, self).__init__()
        self.conv1 = BasicMessagePassing(in_channels, hidden_channels)
        self.rank_nodes = RankNodesByTotalVariation(hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, num_clusters)
        self.lin1 = nn.Linear(hidden_channels, hidden_channels)
        self.lin2 = nn.Linear(hidden_channels, num_clusters)
        self.lin3 = nn.Linear(hidden_channels*num_clusters, 1)
        
    def forward(self, data):
        x, edge_index, batch, pos = data.x.view((-1, 1)), data.edge_index, data.batch, data.pos
        x = self.conv1(x, edge_index)

        # Rank nodes by total variation
        total_variation = self.rank_nodes(x, edge_index)

        # Select top 50 to 100 nodes with highest total variation values
        num_top_nodes = min(100, x.size(0))
        top_k_nodes = torch.argsort(total_variation, descending=True)[:num_top_nodes]
        del total_variation

        # Find k-hop subgraphs for top nodes
        k = 5  # Define k for k-hop neighborhood
        sub_nodes, sub_edge_index, mapping, _ = k_hop_subgraph(top_k_nodes, k, edge_index, relabel_nodes=True)
        sub_x = x[sub_nodes]
        sub_pos = pos[sub_nodes]

        # Apply linear transformation if needed
        if sub_x.size(0) != x.size(0):
            sub_x = F.relu(self.lin1(sub_x))

        # Convert edge_index to dense adjacency matrix
        adj = to_dense_adj(sub_edge_index, max_num_nodes=sub_x.size(0))

        # Apply DiffPool
        s = F.relu(self.conv2(sub_x, sub_edge_index))
        s = self.conv3(s, sub_edge_index)
        sub_x = F.relu(self.conv2(sub_x, sub_edge_index))
        x, adj, _, _ = dense_diff_pool(sub_x, adj, s)

        x = self.lin3(x.view((1,-1)))
        print(x.shape, x)

        # Cleanup memory
        del sub_nodes, sub_edge_index, sub_x, adj, s
        torch.cuda.empty_cache()

        return F.log_softmax(x.view(-1), dim=-1)

In [52]:
class GNNWithDenseDiffPool(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_clusters):
        super(GNNWithDenseDiffPool, self).__init__()
        self.conv1 = BasicMessagePassing(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin1 = nn.Linear(hidden_channels, 125)
        self.lin2 = nn.Linear(125, hidden_channels)
        self.lin3 = nn.Linear(hidden_channels, 2)
        
    def forward(self, data, batch=None):
        x, edge_index, batch, pos = data.x.view((-1, 1)), data.edge_index, data.batch, data.pos
        x = self.conv1(x, edge_index)
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)
        x1 = F.relu(self.lin1(x))
        x = F.relu(self.lin2(x1) + x)
        x = F.relu(self.lin3(x))
        out = global_mean_pool(x, batch)
        print("output ", out)
        return out


In [53]:
# Define model parameters
in_channels = 1
hidden_channels = 4
out_channels = 2
num_clusters = 3  # Number of clusters for DiffPool

model = GNNWithDenseDiffPool(in_channels, hidden_channels, out_channels, num_clusters)
loss_fn = nn.CrossEntropyLoss()
# Training the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
print('training')
optimizer.zero_grad()
out = model(data)
print(out.clone().detach())
#y = torch.tensor(data.y, dtype=torch.float)
loss = loss_fn(out, data.y)
loss.backward()
optimizer.step()
print(loss.detach())


training
output  tensor([[0.3849, 0.0000]], grad_fn=<DivBackward0>)
tensor([[0.3849, 0.0000]])
tensor(0.5191)


In [25]:
data.y.to(torch.float)

tensor([1.])

In [20]:
x.to(torch.long).dtype

torch.int64

In [2]:
import torch
torch.cuda.is_available()

False