In [1]:
import sys
from pathlib import Path

import torch
import torch.nn as nn
import torch.onnx


def convert_onnx(
    dummy_input,
    model: nn.Module,
    save_path: str = "./model.onnx",
    *,
    opset_version: int = 15,
    input_names=None,
    output_names=None,
    dynamic_axes=None,
):
    """Export ``model`` to an ONNX file by running it on ``dummy_input``.

    The model is switched to eval mode before export and is left in eval mode
    afterwards (call ``model.train()`` yourself if you keep training it).

    Args:
        dummy_input: Example input passed straight to ``torch.onnx.export`` --
            a tensor, or a tuple of tensors with one entry per model input.
        model: The module to export.
        save_path: Destination path of the ONNX file.
        opset_version: ONNX opset to target (default 15).
        input_names: Graph input names; defaults to ``["X", "H"]``.
        output_names: Graph output names; defaults to ``["Score"]``.
        dynamic_axes: ``{name: {axis_index: symbolic_name}}`` for axes that may
            vary at inference time. When ``None``, the default vertex/hyperedge
            axes are used, restricted to the names actually in ``input_names`` /
            ``output_names``. Pass ``{}`` for fully static shapes.
    """
    # None sentinels instead of list/dict defaults so no mutable object is
    # shared between calls.
    if input_names is None:
        input_names = ["X", "H"]
    if output_names is None:
        output_names = ["Score"]
    if dynamic_axes is None:
        # NOTE(review): axis 1 of "X" (feature channels) is declared dynamic, but
        # the printed net's first layer is Linear(in_features=10), which cannot
        # accept a different width -- confirm this axis should really be dynamic.
        default_axes = {
            "X": {0: "v_num", 1: "v_channel"},
            "H": {0: "v_num", 1: "e_num"},
            "Score": {0: "v_num"},
        }
        # Only keep entries whose key names a real graph input/output; otherwise
        # the exporter complains about unknown dynamic-axis keys.
        known_names = set(input_names) | set(output_names)
        dynamic_axes = {
            name: axes for name, axes in default_axes.items() if name in known_names
        }

    model.eval()

    torch.onnx.export(
        model,  # model being run
        dummy_input,  # model input (or a tuple for multiple inputs)
        save_path,  # where to save the model
        export_params=True,  # store the trained parameter weights inside the model file
        opset_version=opset_version,  # the ONNX opset to export the model to
        do_constant_folding=True,  # execute constant folding for optimization
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,  # variable-length axes
    )

In [None]:
# `from ..models import ...` cannot work in a notebook: the kernel runs cells in
# `__main__`, which has no parent package, so Python raises
# "attempted relative import with no known parent package".
# Instead, put the directory that `..` referred to on sys.path and import absolutely.
# NOTE(review): assumes the notebook runs from a sub-directory of the package root
# (the layout `..models` implied) -- adjust PROJECT_ROOT if the layout differs.
PROJECT_ROOT = Path.cwd().parent
if str(PROJECT_ROOT) not in sys.path:  # guard keeps the cell idempotent on re-run
    sys.path.insert(0, str(PROJECT_ROOT))

from models import HGNNPSchedulabilityPredictor

# Positional sizes kept from the original; the printed module shows
# hgconv1: 10 -> 20 and hgconv2: 20 -> 10.
net = HGNNPSchedulabilityPredictor(10, 20, 10, use_bn=True)
net

SchedulabilityPredictor(
  (hgconv1): HGConv(
    (theta): Linear(in_features=10, out_features=20, bias=True)
    (v2e_msg_pass): V2EMsgPass(
      (dropout_layer): Dropout(p=0.5, inplace=False)
    )
    (e2v_msg_pass): E2VMsgPass(
      (dropout_layer): Dropout(p=0.5, inplace=False)
    )
    (bn): BatchNorm1d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): ReLU(inplace=True)
    (drop): Dropout(p=0.5, inplace=False)
  )
  (hgconv2): HGConv(
    (theta): Linear(in_features=20, out_features=10, bias=True)
    (v2e_msg_pass): V2EMsgPass(
      (dropout_layer): Dropout(p=0.5, inplace=False)
    )
    (e2v_msg_pass): E2VMsgPass(
      (dropout_layer): Dropout(p=0.5, inplace=False)
    )
    (bn): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): ReLU(inplace=True)
    (drop): Dropout(p=0.5, inplace=False)
  )
  (v2e_msg_pass): V2EMsgPass(
    (dropout_layer): Dropout(p=0.0, inplace=False)
  )
  (predictor): Sequentia

In [3]:
# Dummy inputs only drive tracing: 5 vertices x 10 features (presumably must match
# the model's input width of 10) and a 5 x 4 vertex/hyperedge matrix.
X = torch.randn(5, 10, requires_grad=True)
H = torch.randn(5, 4, requires_grad=True)

# torch.onnx.export writes an ONNX protobuf, so use the .onnx extension; `.pth` is
# PyTorch's checkpoint convention and misleads anyone who later tries torch.load on it.
# NOTE(review): the tracer warning on `torch.eye(len(x))` in the model code suggests
# the vertex count (5) may be baked into the graph despite the dynamic "v_num" axis --
# verify by running the exported model with a different number of vertices.
convert_onnx((X, H), net, "./net2.onnx")

  diag_matrix = x.unsqueeze(1) * torch.eye(len(x)).to(x.device)


In [4]:
# Disabled experiment: dynamic quantization of `net`, kept commented out for reference.
# NOTE(review): either re-enable this cell or delete it before sharing the notebook.
# from torch.quantization import quantize_dynamic

# quantized_net = quantize_dynamic(model=net)
# print(quantized_net)
# torch.save(quantized_net, "./quantized_net.pth")