In [1]:
import numpy as np
import json
import gzip
from scipy.sparse import coo_matrix
import pandas as pd
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
class DEHNNLayer(nn.Module):
    def __init__(self, node_in_features, edge_in_features):
        super(DEHNNLayer, self).__init__()
        self.node_mlp1 = nn.Linear(edge_in_features, edge_in_features)
        self.edge_mlp2 = nn.Linear(node_in_features, node_in_features)
        self.edge_mlp3 = nn.Linear(2 * node_in_features, 2 * node_in_features)

        self.node_to_virtual_mlp = nn.Linear(node_in_features, node_in_features)
        self.virtual_to_higher_virtual_mlp = nn.Linear(node_in_features, edge_in_features)
        self.higher_virtual_to_virtual_mlp = nn.Linear(edge_in_features, edge_in_features)
        self.virtual_to_node_mlp = nn.Linear(edge_in_features, edge_in_features)

        # Learnable defaults for missing driver or sink
        self.default_driver = nn.Parameter(torch.zeros(node_in_features))
        self.default_sink_agg = nn.Parameter(torch.zeros(node_in_features))
        self.default_edge_agg = nn.Parameter(torch.zeros(edge_in_features))
        self.default_virtual_node = nn.Parameter(torch.zeros(node_in_features))
        self.higher_virtual_node = nn.Parameter(torch.zeros(node_in_features))

    def forward(self, node_features, edge_features, hypergraph):
        # Node update
        updated_node_features = {}
        for node in hypergraph.nodes:
            incident_edges = hypergraph.get_incident_edges(node)
            if incident_edges:
                agg_features = torch.sum(torch.stack([self.node_mlp1(edge_features[edge]) for edge in incident_edges]), dim=0)
            else:
                agg_features = self.default_edge_agg  # Fallback for isolated nodes
            updated_node_features[node] = agg_features

        # Edge update
        updated_edge_features = {}
        for edge in hypergraph.edges:
            driver, sinks = hypergraph.get_driver_and_sinks(edge)

            # Handle missing driver
            driver_feature = node_features[driver] if driver is not None else self.default_driver

            # Handle missing sinks
            if sinks:
                sink_agg = torch.sum(torch.stack([self.edge_mlp2(node_features[sink]) for sink in sinks]), dim=0)
            else:
                sink_agg = self.default_sink_agg

            # Concatenate and update
            concatenated = torch.cat([driver_feature, sink_agg])
            updated_edge_features[edge] = self.edge_mlp3(concatenated)
        
        virtual_node_agg = {}
        for virtual_node in range(hypergraph.num_virtual_nodes):
            assigned_nodes = [node for node in hypergraph.nodes if hypergraph.get_virtual_node(node) == virtual_node]
            if assigned_nodes:
                agg_features = torch.sum(torch.stack([self.node_to_virtual_mlp(node_features[node]) for node in assigned_nodes]), dim=0)
            else:
                agg_features = self.default_virtual_node
            virtual_node_agg[virtual_node] = agg_features

        higher_virtual_feature = torch.sum(
            torch.stack([self.virtual_to_higher_virtual_mlp(virtual_node_agg[vn]) for vn in virtual_node_agg]), dim=0
        )

        propagated_virtual_node_features = {}
        for virtual_node in range(hypergraph.num_virtual_nodes):
            propagated_virtual_node_features[virtual_node] = self.higher_virtual_to_virtual_mlp(higher_virtual_feature)

        for node in hypergraph.nodes:
            virtual_node = hypergraph.get_virtual_node(node)
            propagated_feature = self.virtual_to_node_mlp(propagated_virtual_node_features[virtual_node])
            updated_node_features[node] += propagated_feature  # Add propagated feature to node

        return updated_node_features, updated_edge_features


class DEHNN(nn.Module):
    def __init__(self, num_layers, node_in_features, edge_in_features):
        super(DEHNN, self).__init__()
        self.num_layers = num_layers
        self.layers = nn.ModuleList()
        
        # Create multiple layers for DEHNN
        for i in range(num_layers):
            self.layers.append(DEHNNLayer(node_in_features, edge_in_features))
            node_in_features, edge_in_features = edge_in_features, node_in_features
            edge_in_features *= 2

        edge_in_features  = int(edge_in_features / 2)
        # Final output layer for node classification (binary classification for congestion)
        self.output_layer = nn.Linear(node_in_features, 2)  # Output 2 classes: congested or not congested

    def forward(self, node_features, edge_features, hypergraph):
        # Pass through each layer
        for layer in self.layers:
            node_features, edge_features = layer(node_features, edge_features, hypergraph)
        
        # Output prediction for nodes
        final_node_features = torch.stack([node_features[node] for node in hypergraph.nodes], dim=0)
        output = self.output_layer(final_node_features)
        output = F.softmax(output, dim=1)
        return output


# Example hypergraph representation class (simplified)
class Hypergraph:
    def __init__(self, nodes, edges, driver_sink_map, node_to_virtual_map, num_virtual_nodes):
        self.nodes = nodes
        self.edges = edges
        self.driver_sink_map = driver_sink_map
        self.node_to_virtual_map = node_to_virtual_map
        self.num_virtual_nodes = num_virtual_nodes

    def get_incident_edges(self, node):
        return [edge for edge in self.edges if node in self.driver_sink_map[edge][1] or node == self.driver_sink_map[edge][0]]

    def get_driver_and_sinks(self, edge):
        return self.driver_sink_map[edge]
    
    def get_virtual_node(self, node):
        return self.node_to_virtual_map[node]

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

clean_data_dir = '../../data/chips/clean_data/'

with open(clean_data_dir + '1.driver_sink_map.pkl', 'rb') as f:
    driver_sink_map = pickle.load(f)

with open(clean_data_dir + '1.node_features.pkl', 'rb') as f:
    node_features = pickle.load(f)

with open(clean_data_dir + '1.net_features.pkl', 'rb') as f:
    edge_features = pickle.load(f)

with open(clean_data_dir + '1.congestion.pkl', 'rb') as f:
    congestion = pickle.load(f)

partition = np.load(clean_data_dir + '1.partition.npy')

node_features = {k: torch.tensor(v).float().to(device) for k, v in node_features.items()}
edge_features = {k: torch.tensor(v).float().to(device) for k, v in edge_features.items()}

nodes = list(range(len(node_features)))
edges = list(range(len(edge_features)))
hypergraph = Hypergraph(nodes, edges, driver_sink_map, partition, 2)

In [4]:
# Initialize DE-HNN model
model = DEHNN(num_layers=2, node_in_features=14, edge_in_features=1).to(device)
epochs = 10

# Optimizer and Loss Function
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()  # Cross-entropy loss for classification

# Training Loop (example)
model.train()
for epoch in range(epochs):
    optimizer.zero_grad()
    
    # Forward pass
    output = model(node_features, edge_features, hypergraph)
    # print(output)
    
    # Dummy target for illustration (binary labels for each node: 0 for not congested, 1 for congested)
    target = torch.tensor(list(congestion.values())).to(device)
    
    # print(target)
    # Compute loss
    loss = criterion(output, target)
    
    # Backward pass
    loss.backward()
    optimizer.step()
    
    # Print loss
    # if (epoch + 1) % 10 == 0:
    #     print(f'Epoch [{epoch+1}/100], Loss: {loss.item():.4f}')
    print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

Epoch [1/10], Loss: 0.4112
Epoch [2/10], Loss: 0.4112
Epoch [3/10], Loss: 0.4112
Epoch [4/10], Loss: 0.4112
Epoch [5/10], Loss: 0.4112
Epoch [6/10], Loss: 0.4112
Epoch [7/10], Loss: 0.4112
Epoch [8/10], Loss: 0.4112
Epoch [9/10], Loss: 0.4112
Epoch [10/10], Loss: 0.4112


In [5]:
model.layers

ModuleList(
  (0): DEHNNLayer(
    (node_mlp1): Linear(in_features=1, out_features=1, bias=True)
    (edge_mlp2): Linear(in_features=14, out_features=14, bias=True)
    (edge_mlp3): Linear(in_features=28, out_features=28, bias=True)
    (node_to_virtual_mlp): Linear(in_features=14, out_features=14, bias=True)
    (virtual_to_higher_virtual_mlp): Linear(in_features=14, out_features=1, bias=True)
    (higher_virtual_to_virtual_mlp): Linear(in_features=1, out_features=1, bias=True)
    (virtual_to_node_mlp): Linear(in_features=1, out_features=1, bias=True)
  )
  (1): DEHNNLayer(
    (node_mlp1): Linear(in_features=28, out_features=28, bias=True)
    (edge_mlp2): Linear(in_features=1, out_features=1, bias=True)
    (edge_mlp3): Linear(in_features=2, out_features=2, bias=True)
    (node_to_virtual_mlp): Linear(in_features=1, out_features=1, bias=True)
    (virtual_to_higher_virtual_mlp): Linear(in_features=1, out_features=28, bias=True)
    (higher_virtual_to_virtual_mlp): Linear(in_feature

In [35]:
test_output = model(node_features, edge_features, hypergraph)

In [36]:
out = test_output.detach().cpu().numpy()
out = np.array([np.argmax(i) for i in out])
out

array([0, 0, 0, ..., 0, 0, 0], dtype=int64)

In [37]:
np.mean(np.array(list(congestion.values())) == out)

0.9020748987854251

In [None]:
np.array(list(congestion.values())).mean()

0.0979251012145749

In [None]:
out