In [3]:
import numpy as np
import json
import gzip
from scipy.sparse import coo_matrix
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class DEHNNLayer(nn.Module):
    def __init__(self, node_in_features, edge_in_features):
        super(DEHNNLayer, self).__init__()
        self.node_mlp1 = nn.Linear(edge_in_features, edge_in_features)
        # self.edge_mlp1 = nn.Linear(node_in_features, edge_in_features)
        self.edge_mlp2 = nn.Linear(node_in_features, node_in_features)
        self.edge_mlp3 = nn.Linear(2 * node_in_features, 2 * node_in_features)  # No compression, keeps 2 * out_features

    def forward(self, node_features, edge_features, hypergraph):
        # Node update
        updated_node_features = {}
        for node in hypergraph.nodes:
            incident_edges = hypergraph.get_incident_edges(node)
            agg_features = torch.sum(torch.stack([self.node_mlp1(edge_features[edge]) for edge in incident_edges]), dim=0)
            updated_node_features[node] = agg_features

        # Edge update
        updated_edge_features = {}
        for edge in hypergraph.edges:
            driver, sinks = hypergraph.get_driver_and_sinks(edge)
            sink_agg = torch.sum(torch.stack([self.edge_mlp2(node_features[sink]) for sink in sinks]), dim=0)
            concatenated = torch.cat([node_features[driver], sink_agg])
            updated_edge_features[edge] = self.edge_mlp3(concatenated)

        return updated_node_features, updated_edge_features


class DEHNN(nn.Module):
    def __init__(self, num_layers, node_in_features, edge_in_features):
        super(DEHNN, self).__init__()
        self.num_layers = num_layers
        self.layers = nn.ModuleList()
        
        # Create multiple layers for DEHNN
        for i in range(num_layers):
            print(node_in_features, edge_in_features)
            # in_features = node_in_features if i > 0 else edge_in_features
            self.layers.append(DEHNNLayer(node_in_features, edge_in_features))
            node_in_features, edge_in_features = edge_in_features, node_in_features
            edge_in_features *= 2

        print(node_in_features, edge_in_features)
        edge_in_features  = int(edge_in_features / 2)
        # Final output layer for node classification (binary classification for congestion)
        self.output_layer = nn.Linear(node_in_features, 2)  # Output 2 classes: congested or not congested

    def forward(self, node_features, edge_features, hypergraph):
        # Pass through each layer
        for layer in self.layers:
            node_features, edge_features = layer(node_features, edge_features, hypergraph)
        
        # Output prediction for nodes
        final_node_features = torch.stack([node_features[node] for node in hypergraph.nodes], dim=0)
        # print(node_features[0].shape)
        # print(final_node_features.shape)
        output = self.output_layer(final_node_features)
        output = F.softmax(output, dim=1)
        return output


# Example hypergraph representation class (simplified)
class Hypergraph:
    def __init__(self, nodes, edges, driver_sink_map):
        self.nodes = nodes
        self.edges = edges
        self.driver_sink_map = driver_sink_map

    def get_incident_edges(self, node):
        return [edge for edge in self.edges if node in self.driver_sink_map[edge][1] or node == self.driver_sink_map[edge][0]]

    def get_driver_and_sinks(self, edge):
        return self.driver_sink_map[edge]


# Example usage
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Dummy data for illustration
nodes = [0, 1, 2, 3]
edges = [0, 1]
driver_sink_map = {0: (0, [1, 2]), 1: (2, [3])}  # edge 0: driver is 0, sinks are 1, 2
node_features = {i: torch.randn(10).to(device) for i in nodes}
edge_features = {i: torch.randn(15).to(device) for i in edges}

hypergraph = Hypergraph(nodes, edges, driver_sink_map)

# Initialize DE-HNN model
model = DEHNN(num_layers=3, node_in_features=10, edge_in_features=15).to(device)

# Optimizer and Loss Function
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()  # Cross-entropy loss for classification

# Training Loop (example)
model.train()
for epoch in range(100):
    optimizer.zero_grad()
    
    # Forward pass
    output = model(node_features, edge_features, hypergraph)
    print(output)
    
    # Dummy target for illustration (binary labels for each node: 0 for not congested, 1 for congested)
    target = torch.randint(0, 2, (len(nodes),)).to(device)
    
    print(target)
    # Compute loss
    loss = criterion(output, target)
    
    # Backward pass
    loss.backward()
    optimizer.step()
    
    # Print loss
    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch+1}/100], Loss: {loss.item():.4f}')


10 15
15 20
20 30
30 40
tensor([[0.5437, 0.4563],
        [0.5437, 0.4563],
        [0.5850, 0.4150],
        [0.5751, 0.4249]], device='cuda:0', grad_fn=<SoftmaxBackward0>)
tensor([1, 1, 1, 0], device='cuda:0')
tensor([[0.5198, 0.4802],
        [0.5198, 0.4802],
        [0.5462, 0.4538],
        [0.5593, 0.4407]], device='cuda:0', grad_fn=<SoftmaxBackward0>)
tensor([0, 1, 1, 0], device='cuda:0')
tensor([[0.4987, 0.5013],
        [0.4987, 0.5013],
        [0.5105, 0.4895],
        [0.5445, 0.4555]], device='cuda:0', grad_fn=<SoftmaxBackward0>)
tensor([0, 0, 0, 0], device='cuda:0')
tensor([[0.4996, 0.5004],
        [0.4996, 0.5004],
        [0.5192, 0.4808],
        [0.5523, 0.4477]], device='cuda:0', grad_fn=<SoftmaxBackward0>)
tensor([1, 1, 0, 0], device='cuda:0')
tensor([[0.5000, 0.5000],
        [0.5000, 0.5000],
        [0.5301, 0.4699],
        [0.5629, 0.4371]], device='cuda:0', grad_fn=<SoftmaxBackward0>)
tensor([1, 0, 1, 0], device='cuda:0')
tensor([[0.4978, 0.5022],
        [0