In [None]:
import numpy as np
import json
import gzip
from scipy.sparse import coo_matrix
import pandas as pd
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F

from scipy.stats import binned_statistic_2d
from collections import defaultdict

from numpy.linalg import eig
from torch_geometric.utils import (get_laplacian, to_scipy_sparse_matrix, to_undirected, to_dense_adj)
from scipy.sparse.linalg import eigsh

import pymetis

## Processing Raw Data


In [None]:
# Where the raw DigIC dump lives and where processed artifacts are written.
raw_data_dir = '../data/DigIC_dataset/'
clean_data_dir = '../data/clean_data/'

# Design to process and the number of numbered variant folders it has.
design = 'xbar'
n_variants = 13

# Three parallel lists with one entry per variant; variant folders are
# numbered starting from 1.
corresponding_variant = list(range(1, n_variants + 1))
corresponding_design = [design] * n_variants
sample_names = [f'{raw_data_dir}{design}/{variant}/' for variant in corresponding_variant]

In [None]:
# Parse the gzip-compressed JSON file of cell records into `cell_data`.
cells_fn = f'{raw_data_dir}cells.json.gz'
with gzip.open(cells_fn, 'rb') as cells_file:
    cell_data = json.load(cells_file)

In [None]:
# Raw width and height of every cell record, in the order they appear in
# `cell_data`.
widths = [record['width'] for record in cell_data]
heights = [record['height'] for record in cell_data]

In [None]:
# Parse the gzip-compressed JSON file of cell records.
cells_fn = f'{raw_data_dir}cells.json.gz'
with gzip.open(cells_fn, 'rb') as cells_file:
    cell_data = json.load(cells_file)

# Raw dimensions, one entry per cell record, in file order.
widths = np.array([record['width'] for record in cell_data])
heights = np.array([record['height'] for record in cell_data])

min_cell_width, max_cell_width = np.min(widths), np.max(widths)
min_cell_height, max_cell_height = np.min(heights), np.max(heights)

# Min-max scale both dimensions so the smallest cell maps to 0 and the
# largest to 1.
widths = (widths - min_cell_width) / (max_cell_width - min_cell_width)
heights = (heights - min_cell_height) / (max_cell_height - min_cell_height)

# cell id -> {term id -> term direction}, used later to tell a net's driver
# pin from its sink pins.
cell_to_edge_dict = {}
for record in cell_data:
    cell_to_edge_dict[record['id']] = {term['id']: term['dir'] for term in record['terms']}

In [None]:
# Per-variant preprocessing. For every variant folder this cell:
#   1. loads instances/nets and builds per-instance features,
#   2. derives each net's driver and sinks from the connectivity matrix,
#   3. appends Laplacian eigenvector coordinates to the instance features,
#   4. labels every instance as congested / not congested,
#   5. pickles the four resulting artifacts into `clean_data_dir`.

# An instance is labelled congested when demand exceeds this fraction of the
# capacity at its grid location (summed over all layers).
CONGESTION_CAPACITY_FRACTION = 0.9
# Number of Laplacian eigenvectors appended to each instance's features.
NUM_EIGENVECTORS = 10

for sample in range(n_variants):
    folder = sample_names[sample]
    design = corresponding_design[sample]
    instances_nets_fn = folder + design + '.json.gz'

    print('--------------------------------------------------')
    print('Folder:', folder)
    print('Design:', design)
    print('Instances & nets info:', instances_nets_fn)

    with gzip.open(instances_nets_fn, 'r') as fin:
        instances_nets_data = json.load(fin)

    instances = instances_nets_data['instances']
    nets = instances_nets_data['nets']

    inst_to_cell = {item['id']: item['cell'] for item in instances}

    num_instances = len(instances)
    num_nets = len(nets)

    print('Number of instances:', num_instances)
    print('Number of nets:', num_nets)

    # ---- 1. Per-instance features ------------------------------------
    xloc_list = [inst['xloc'] for inst in instances]
    yloc_list = [inst['yloc'] for inst in instances]
    cell_ids = [inst['cell'] for inst in instances]
    # `widths` / `heights` are indexed by cell id here.
    # NOTE(review): this assumes a cell's id equals its position in cells.json.gz — confirm.
    cell_width = [widths[cell_id] for cell_id in cell_ids]
    cell_height = [heights[cell_id] for cell_id in cell_ids]
    orient = [inst['orient'] for inst in instances]

    x_min, x_max = min(xloc_list), max(xloc_list)
    y_min, y_max = min(yloc_list), max(yloc_list)

    # Placement coordinates scaled to [0, 1] within this variant.
    X = (np.array(xloc_list) - x_min) / (x_max - x_min)
    Y = (np.array(yloc_list) - y_min) / (y_max - y_min)

    # Columns: X, Y, cell id, scaled width, scaled height, orientation.
    instance_features = np.column_stack((X, Y, cell_ids, cell_width, cell_height, orient))

    # ---- 2. Driver / sinks of every net ------------------------------
    connection_fn = folder + design + '_connectivity.npz'
    connection_data = np.load(connection_fn)
    print('Connection info:', connection_fn)

    # One entry per (instance, net) connection; 'data' holds the term id used.
    instance_idx = connection_data['row']
    net_idx = connection_data['col']
    edge_t = connection_data['data']

    # Direction of each connection, looked up through the instance's cell.
    dirs = np.array([
        cell_to_edge_dict[inst_to_cell[inst]][term]
        for inst, term in zip(instance_idx, edge_t)
    ])

    # net -> (driver instance or None, [sink instances])
    driver_sink_map = defaultdict(lambda: (None, []))
    for node, edge, direction in zip(instance_idx, net_idx, dirs):
        if direction == 1:  # Driver
            driver_sink_map[edge] = (node, driver_sink_map[edge][1])
        elif direction == 0:  # Sink
            driver_sink_map[edge][1].append(node)

    # Convert to a plain dict so it can be pickled.
    driver_sink_map = dict(driver_sink_map)

    # Net feature: number of connected instances (sinks plus the driver).
    # The driver test must be `is not None`: instance index 0 is a valid
    # driver but is falsy, so a truthiness test would undercount its nets.
    net_features = {
        net: [len(sinks) + (1 if driver is not None else 0)]
        for net, (driver, sinks) in driver_sink_map.items()
    }

    # ---- 3. Spectral features from the bipartite instance/net graph --
    # Nets are numbered after the instances in the bipartite graph.
    bipartite_net_idx = net_idx + num_instances

    # Build the index tensors directly as int64. Going through torch.Tensor
    # (float32) would silently corrupt indices above 2**24.
    sources = torch.from_numpy(np.concatenate([instance_idx, bipartite_net_idx]).astype(np.int64))
    targets = torch.from_numpy(np.concatenate([bipartite_net_idx, instance_idx]).astype(np.int64))
    undir_edge_index = torch.stack([sources, targets], dim=0)

    L = to_scipy_sparse_matrix(
        *get_laplacian(undir_edge_index, normalization="sym", num_nodes=num_instances + num_nets)
    )
    _, evects = eigsh(L, k=NUM_EIGENVECTORS, which='SM')
    print(evects.shape)

    # Drop the X/Y columns and append the instance's eigenvector row.
    node_features = {
        i: np.concatenate([instance_features[i, 2:], evects[i]])
        for i in range(num_instances)
    }

    # Report sizes only; dumping the full dictionaries floods the cell output.
    print('Nets with a driver or sink:', len(driver_sink_map))
    print('Net feature entries:', len(net_features))
    print('Node feature entries:', len(node_features))

    # ---- 4. Congestion labels ----------------------------------------
    congestion_fn = folder + design + '_congestion.npz'
    congestion_data = np.load(congestion_fn)
    print('Congestion info:', congestion_fn)

    congestion_data_demand = congestion_data['demand']
    congestion_data_capacity = congestion_data['capacity']

    layer_list = list(congestion_data['layerList'])
    num_layers = len(layer_list)
    print('Number of layers:', num_layers)
    print('Layers:', layer_list)

    ybl = congestion_data['yBoundaryList']
    xbl = congestion_data['xBoundaryList']

    # The grid cell of each instance does not depend on the layer, so the
    # binning is done once instead of once per layer.
    # NOTE(review): with edges xbl[1:] / ybl[1:], a coordinate below the second
    # boundary gets bin number 0 and therefore index -1, which wraps to the
    # last grid row/column — confirm this matches the dataset's convention.
    ret = binned_statistic_2d(xloc_list, yloc_list, None, 'count',
                              bins=[xbl[1:], ybl[1:]], expand_binnumbers=True)
    i_list = ret.binnumber[0] - 1
    j_list = ret.binnumber[1] - 1

    all_demand = []
    all_capacity = []
    # Layers are addressed by position; this equals the original lookup by
    # name whenever the layer names are unique.
    for lyr in range(num_layers):
        all_demand.append(congestion_data_demand[lyr, i_list, j_list])
        all_capacity.append(congestion_data_capacity[lyr, i_list, j_list])

    # Total demand / capacity over all layers at each instance's location.
    demand = np.stack(all_demand, axis=1).sum(axis=1)
    capacity = np.stack(all_capacity, axis=1).sum(axis=1)

    congestion_actual = {
        i: int(((capacity[i] * CONGESTION_CAPACITY_FRACTION) - demand[i]) < 0)
        for i in range(len(node_features))
    }

    # ---- 5. Persist the artifacts (files are numbered from 1) --------
    artifacts = {
        'driver_sink_map': driver_sink_map,
        'node_features': node_features,
        'net_features': net_features,
        'congestion': congestion_actual,
    }
    for artifact_name, payload in artifacts.items():
        with open(f'{clean_data_dir}{sample+1}.{artifact_name}.pkl', 'wb') as f:
            pickle.dump(payload, f)

## Creating Partition Data 

In [None]:
# Split every variant's instances into `num_partitions` groups by running
# METIS on the bipartite instance/net graph; the group id of each instance is
# saved as `<variant>.partition.npy`.
data_dir = raw_data_dir + 'xbar/'
n_variants = 13
num_partitions = 2

for variant in range(1, n_variants + 1):
    connection_data = np.load(f'{data_dir}{variant}/xbar_connectivity.npz')
    rows = connection_data['row']
    cols = connection_data['col']

    # Sizes are inferred from the largest index that appears in a connection.
    # NOTE(review): instances with no connection past the largest connected
    # index are not counted, so the saved array may be shorter than the
    # instance list — confirm against the node feature count.
    num_nodes = rows.max() + 1
    num_nets = cols.max() + 1

    # Bipartite view of the hypergraph: vertices [0, num_nodes) are instances,
    # vertices [num_nodes, num_nodes + num_nets) are nets.
    adj_list = [[] for _ in range(num_nodes + num_nets)]
    for node, net in zip(rows, cols):
        net_vertex = num_nodes + net
        adj_list[node].append(net_vertex)
        adj_list[net_vertex].append(node)

    _cuts, membership = pymetis.part_graph(num_partitions, adjacency=adj_list)

    # Keep only the instance vertices; net vertices are dropped.
    arr = np.array(membership[:num_nodes])
    np.save(f'{clean_data_dir}{variant}.partition.npy', arr)

## DEHNN Model Architecture

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DEHNNLayer(nn.Module):
    """One message-passing layer over a directed hypergraph with virtual nodes.

    Feature widths change across the layer: given node features of width
    ``node_in_features`` and edge features of width ``edge_in_features``,
    :meth:`forward` returns node features of width ``edge_in_features`` and
    edge features of width ``2 * node_in_features``.

    Args:
        node_in_features: width of each incoming node feature vector.
        edge_in_features: width of each incoming edge feature vector.
    """

    def __init__(self, node_in_features, edge_in_features):
        super(DEHNNLayer, self).__init__()
        # Edge -> node message.
        self.node_mlp1 = nn.Sequential(
            nn.Linear(edge_in_features, edge_in_features),
            nn.ReLU()
        )
        # Sink node -> edge message.
        self.edge_mlp2 = nn.Sequential(
            nn.Linear(node_in_features, node_in_features),
            nn.ReLU()
        )
        # Applied to the concatenation [driver feature, aggregated sinks].
        self.edge_mlp3 = nn.Sequential(
            nn.Linear(2 * node_in_features, 2 * node_in_features),
            nn.ReLU()
        )

        # Node -> virtual node -> single higher virtual node -> back down.
        self.node_to_virtual_mlp = nn.Sequential(
            nn.Linear(node_in_features, node_in_features),
            nn.ReLU()
        )
        self.virtual_to_higher_virtual_mlp = nn.Sequential(
            nn.Linear(node_in_features, edge_in_features),
            nn.ReLU()
        )
        self.higher_virtual_to_virtual_mlp = nn.Sequential(
            nn.Linear(edge_in_features, edge_in_features),
            nn.ReLU()
        )
        self.virtual_to_node_mlp = nn.Sequential(
            nn.Linear(edge_in_features, edge_in_features),
            nn.ReLU()
        )

        # Learnable defaults for missing driver or sink
        self.default_driver = nn.Parameter(torch.zeros(node_in_features))
        self.default_sink_agg = nn.Parameter(torch.zeros(node_in_features))
        self.default_edge_agg = nn.Parameter(torch.zeros(edge_in_features))
        self.default_virtual_node = nn.Parameter(torch.zeros(node_in_features))

        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-uniform weights and zero biases for every ``nn.Linear``."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, node_features, edge_features, hypergraph):
        """Run one round of node, edge and virtual-node updates.

        Args:
            node_features: mapping node -> 1-D tensor of width ``node_in_features``.
            edge_features: mapping edge -> 1-D tensor of width ``edge_in_features``.
            hypergraph: object exposing ``nodes``, ``edges``,
                ``num_virtual_nodes``, ``get_incident_edges(node)``,
                ``get_driver_and_sinks(edge)`` and ``get_virtual_node(node)``.

        Returns:
            ``(updated_node_features, updated_edge_features)``: two dicts keyed
            by node and by edge respectively.
        """
        # Node update: sum the transformed features of all incident edges.
        updated_node_features = {}
        for node in hypergraph.nodes:
            incident_edges = hypergraph.get_incident_edges(node)
            if incident_edges:
                messages = [self.node_mlp1(edge_features[edge]) for edge in incident_edges]
                agg_features = torch.sum(torch.stack(messages), dim=0)
            else:
                agg_features = self.default_edge_agg
            updated_node_features[node] = agg_features

        # Edge update: combine the driver's raw feature with the summed,
        # transformed sink features.
        updated_edge_features = {}
        for edge in hypergraph.edges:
            driver, sinks = hypergraph.get_driver_and_sinks(edge)

            driver_feature = node_features[driver] if driver is not None else self.default_driver

            if sinks:
                sink_messages = [self.edge_mlp2(node_features[sink]) for sink in sinks]
                sink_agg = torch.sum(torch.stack(sink_messages), dim=0)
            else:
                sink_agg = self.default_sink_agg

            concatenated = torch.cat([driver_feature, sink_agg])
            updated_edge_features[edge] = self.edge_mlp3(concatenated)

        # Virtual node aggregation: group the nodes by virtual node in one pass.
        members = {virtual_node: [] for virtual_node in range(hypergraph.num_virtual_nodes)}
        for node in hypergraph.nodes:
            virtual_node = hypergraph.get_virtual_node(node)
            if virtual_node in members:
                members[virtual_node].append(node)

        virtual_node_agg = {}
        for virtual_node, assigned_nodes in members.items():
            if assigned_nodes:
                lifted = [self.node_to_virtual_mlp(node_features[node]) for node in assigned_nodes]
                virtual_node_agg[virtual_node] = torch.sum(torch.stack(lifted), dim=0)
            else:
                virtual_node_agg[virtual_node] = self.default_virtual_node

        higher_virtual_feature = torch.sum(
            torch.stack([self.virtual_to_higher_virtual_mlp(virtual_node_agg[vn]) for vn in virtual_node_agg]), dim=0
        )

        # The message sent back down depends only on the virtual node, so it
        # is computed once per virtual node rather than once per node.
        node_feedback = {
            virtual_node: self.virtual_to_node_mlp(
                self.higher_virtual_to_virtual_mlp(higher_virtual_feature)
            )
            for virtual_node in range(hypergraph.num_virtual_nodes)
        }

        for node in hypergraph.nodes:
            virtual_node = hypergraph.get_virtual_node(node)
            # Out-of-place add on purpose: for a node without incident edges the
            # stored value IS the `default_edge_agg` parameter, and `+=` on it
            # would be an in-place write to a leaf tensor that requires grad
            # (autograd raises), besides corrupting the shared parameter.
            updated_node_features[node] = updated_node_features[node] + node_feedback[virtual_node]

        return updated_node_features, updated_edge_features


class DEHNN(nn.Module):
    """Stack of ``DEHNNLayer`` modules followed by a 2-way linear classifier.

    Args:
        num_layers: number of ``DEHNNLayer`` modules to stack.
        node_in_features: width of the input node feature vectors.
        edge_in_features: width of the input edge feature vectors.
    """

    def __init__(self, num_layers, node_in_features, edge_in_features):
        super(DEHNN, self).__init__()
        self.num_layers = num_layers
        self.layers = nn.ModuleList()

        node_dim, edge_dim = node_in_features, edge_in_features
        for _ in range(num_layers):
            self.layers.append(DEHNNLayer(node_dim, edge_dim))
            # A layer emits node features as wide as its edge input and edge
            # features twice as wide as its node input.
            node_dim, edge_dim = edge_dim, 2 * node_dim

        # Only the final node width feeds the classifier; the final edge width
        # is not used, so no further adjustment of it is needed.
        self.output_layer = nn.Sequential(
            nn.Linear(node_dim, 2)
        )

    def forward(self, node_features, edge_features, hypergraph):
        """Return a ``(len(hypergraph.nodes), 2)`` tensor of per-node logits.

        Args:
            node_features: mapping node -> 1-D feature tensor.
            edge_features: mapping edge -> 1-D feature tensor.
            hypergraph: graph object passed through to every layer; its
                ``nodes`` attribute fixes the row order of the output.
        """
        for layer in self.layers:
            node_features, edge_features = layer(node_features, edge_features, hypergraph)

        final_node_features = torch.stack([node_features[node] for node in hypergraph.nodes], dim=0)
        return self.output_layer(final_node_features)
    
class Hypergraph:
    """Directed hypergraph with a node-to-virtual-node assignment.

    Args:
        nodes: iterable of node identifiers.
        edges: iterable of edge identifiers; each must be a key of
            ``driver_sink_map``.
        driver_sink_map: mapping edge -> ``(driver, sinks)`` where ``driver``
            is a node or ``None`` and ``sinks`` is a list of nodes.
        node_to_virtual_map: indexable by node, giving its virtual node.
        num_virtual_nodes: number of virtual nodes.
    """

    def __init__(self, nodes, edges, driver_sink_map, node_to_virtual_map, num_virtual_nodes):
        self.nodes = nodes
        self.edges = edges
        self.driver_sink_map = driver_sink_map
        self.node_to_virtual_map = node_to_virtual_map
        self.num_virtual_nodes = num_virtual_nodes
        # node -> [incident edges]; built on the first get_incident_edges call.
        self._incident_edges = None

    def _build_incidence_index(self):
        """Map every node to the edges it drives or sinks, in ``edges`` order."""
        index = {}
        for edge in self.edges:
            entry = self.driver_sink_map[edge]
            driver, sinks = entry[0], entry[1]
            # A node listed as both driver and sink (or twice as a sink) is
            # still incident to the edge only once.
            touched = set()
            candidates = list(sinks)
            if driver is not None:
                candidates.append(driver)
            for node in candidates:
                if node not in touched:
                    touched.add(node)
                    index.setdefault(node, []).append(edge)
        return index

    def get_incident_edges(self, node):
        """Return the edges whose driver or sinks include ``node``.

        The edges come back in the order they appear in ``self.edges``. The
        lookup uses an index built once, on first use, instead of scanning
        every edge on every call; changes made to ``edges`` or
        ``driver_sink_map`` after that first call are not reflected.
        """
        if self._incident_edges is None:
            self._incident_edges = self._build_incidence_index()
        # Hand out a copy so callers cannot modify the cached lists.
        return list(self._incident_edges.get(node, ()))

    def get_driver_and_sinks(self, edge):
        """Return the ``(driver, sinks)`` entry stored for ``edge``."""
        return self.driver_sink_map[edge]

    def get_virtual_node(self, node):
        """Return the virtual node that ``node`` is assigned to."""
        return self.node_to_virtual_map[node]

## Creating Design 1 Train Data and Hypergraph

In [None]:
# Load the processed artifacts of design 1 (training set) and wrap them in a
# Hypergraph.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_prefix = f'{clean_data_dir}1'

with open(f'{train_prefix}.driver_sink_map.pkl', 'rb') as handle:
    train_driver_sink_map = pickle.load(handle)

with open(f'{train_prefix}.node_features.pkl', 'rb') as handle:
    train_node_features = pickle.load(handle)

with open(f'{train_prefix}.net_features.pkl', 'rb') as handle:
    train_edge_features = pickle.load(handle)

with open(f'{train_prefix}.congestion.pkl', 'rb') as handle:
    train_congestion = pickle.load(handle)

train_partition = np.load(f'{train_prefix}.partition.npy')

# Turn every stored feature vector into a float tensor on `device`.
train_node_features = {
    node: torch.tensor(values).float().to(device)
    for node, values in train_node_features.items()
}
train_edge_features = {
    net: torch.tensor(values).float().to(device)
    for net, values in train_edge_features.items()
}

# Nodes and edges are addressed by consecutive integers starting at 0.
# NOTE(review): this assumes the pickled feature dicts are keyed by exactly
# 0..n-1 — confirm for nets, whose keys come from the connectivity matrix.
train_nodes = [*range(len(train_node_features))]
train_edges = [*range(len(train_edge_features))]
train_hypergraph = Hypergraph(train_nodes, train_edges, train_driver_sink_map, train_partition, 2)

## Creating Design 2 Validation Data and Hypergraph

In [None]:
# Load the processed artifacts of design 2 (validation set) and wrap them in a
# Hypergraph.
val_prefix = f'{clean_data_dir}2'

with open(f'{val_prefix}.driver_sink_map.pkl', 'rb') as handle:
    val_driver_sink_map = pickle.load(handle)

with open(f'{val_prefix}.node_features.pkl', 'rb') as handle:
    val_node_features = pickle.load(handle)

with open(f'{val_prefix}.net_features.pkl', 'rb') as handle:
    val_edge_features = pickle.load(handle)

with open(f'{val_prefix}.congestion.pkl', 'rb') as handle:
    val_congestion = pickle.load(handle)

val_partition = np.load(f'{val_prefix}.partition.npy')

# Turn every stored feature vector into a float tensor on `device`.
val_node_features = {
    node: torch.tensor(values).float().to(device)
    for node, values in val_node_features.items()
}
val_edge_features = {
    net: torch.tensor(values).float().to(device)
    for net, values in val_edge_features.items()
}

# Nodes and edges are addressed by consecutive integers starting at 0.
val_nodes = [*range(len(val_node_features))]
val_edges = [*range(len(val_edge_features))]
val_hypergraph = Hypergraph(val_nodes, val_edges, val_driver_sink_map, val_partition, 2)

# Labels are kept as the raw {node: 0/1} dict at this point.
val_targets = val_congestion

## Train on Design 1, Validate on Design 2

In [None]:
# Full-batch training on the training design, with a validation pass on the
# validation design after every optimisation step.
model = DEHNN(num_layers=4, node_in_features=14, edge_in_features=1).to(device)
epochs = 20

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Place all features and labels on `device`; labels become int64 class ids.
train_node_features = {node: feat.to(device) for node, feat in train_node_features.items()}
train_edge_features = {net: feat.to(device) for net, feat in train_edge_features.items()}
train_targets = torch.tensor(list(train_congestion.values())).long().to(device)

val_node_features = {node: feat.to(device) for node, feat in val_node_features.items()}
val_edge_features = {net: feat.to(device) for net, feat in val_edge_features.items()}
val_targets = torch.tensor(list(val_congestion.values())).long().to(device)

for epoch in range(epochs):
    # One gradient step over the whole training graph.
    model.train()
    optimizer.zero_grad()
    train_output = model(train_node_features, train_edge_features, train_hypergraph)
    train_loss = criterion(train_output, train_targets)
    train_loss.backward()
    optimizer.step()

    # Validation: loss and accuracy without tracking gradients.
    model.eval()
    with torch.no_grad():
        val_output = model(val_node_features, val_edge_features, val_hypergraph)
        val_loss = criterion(val_output, val_targets)

        val_predictions = val_output.argmax(dim=1)
        val_total = len(val_targets)
        val_correct = val_predictions.eq(val_targets).sum().item()
        val_accuracy = val_correct / val_total

    print(f"Epoch [{epoch+1}/{epochs}]")
    print(f"Train Loss: {train_loss.item():.4f}")
    print(f"Validation Loss: {val_loss.item():.4f}, Validation Accuracy: {val_accuracy:.4f}")


## Train across multiple designs

In [None]:
# Continue training the existing `model` on several designs, taking one
# optimisation step per design per epoch.
file_indices = range(1, 9)
epochs = 10
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Load every design once, up front: the data does not change between epochs,
# so re-reading and re-tensorising it each epoch is wasted work.
# NOTE(review): this keeps all listed designs resident on `device` at the same
# time — confirm that fits in memory for the designs being used.
designs = []
for i in file_indices:
    with open(f'{clean_data_dir}{i}.driver_sink_map.pkl', 'rb') as f:
        driver_sink_map = pickle.load(f)

    with open(f'{clean_data_dir}{i}.node_features.pkl', 'rb') as f:
        node_features = pickle.load(f)

    with open(f'{clean_data_dir}{i}.net_features.pkl', 'rb') as f:
        edge_features = pickle.load(f)

    with open(f'{clean_data_dir}{i}.congestion.pkl', 'rb') as f:
        congestion = pickle.load(f)

    partition = np.load(f'{clean_data_dir}{i}.partition.npy')

    node_features = {k: torch.tensor(v).float().to(device) for k, v in node_features.items()}
    edge_features = {k: torch.tensor(v).float().to(device) for k, v in edge_features.items()}

    nodes = list(range(len(node_features)))
    edges = list(range(len(edge_features)))
    hypergraph = Hypergraph(nodes, edges, driver_sink_map, partition, 2)

    # int64 class ids, as CrossEntropyLoss expects.
    target = torch.tensor(list(congestion.values())).long().to(device)

    designs.append((i, node_features, edge_features, hypergraph, target))

model.train()
for epoch in range(epochs):
    epoch_loss = 0

    for i, node_features, edge_features, hypergraph, target in designs:
        # Clear gradients BEFORE the backward pass. Clearing only after the
        # step lets gradients left on the parameters by earlier code leak into
        # the first update.
        optimizer.zero_grad()

        output = model(node_features, edge_features, hypergraph)
        loss = criterion(output, target)

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        # Report this chip's own loss, not the running epoch total.
        print(f'Chip {i}: Loss: {loss.item():.4f}')

    # Sum of the per-chip losses over the epoch.
    print(f'Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss:.4f}')