In [12]:
import torch
import torch.nn as nn
import onnx

# Define the model architecture
class NeuralNet(nn.Module):
    def __init__(self, input_size, output_size):
        super(NeuralNet, self).__init__()
        self.input_size = input_size
        self.fc1 = nn.Linear(input_size, input_size)
        self.fc2 = nn.Linear(input_size, input_size // 16)
        self.outputLayer = nn.Linear(input_size // 16, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = x.view(-1, self.input_size)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.outputLayer(x)
        return x

# Initialize the model
model = NeuralNet(2048, 20)

# Load model parameters
model.load_state_dict(torch.load("model_params.pt"))

# Example input tensor
x = torch.randn(1, 2048, requires_grad=False)

# Export the model to ONNX
torch.onnx.export(model, x, "torchToOnnx.onnx", verbose=True, input_names=['input'], output_names=['output'])

# Print the ONNX model's graph
print('Model Graph:\n\n{}'.format(onnx.helper.printable_graph(onnx_model.graph)))


Exported graph: graph(%input : Float(1, 2048, strides=[2048, 1], requires_grad=0, device=cpu),
      %fc1.weight : Float(2048, 2048, strides=[2048, 1], requires_grad=1, device=cpu),
      %fc1.bias : Float(2048, strides=[1], requires_grad=1, device=cpu),
      %fc2.weight : Float(128, 2048, strides=[2048, 1], requires_grad=1, device=cpu),
      %fc2.bias : Float(128, strides=[1], requires_grad=1, device=cpu),
      %outputLayer.weight : Float(20, 128, strides=[128, 1], requires_grad=1, device=cpu),
      %outputLayer.bias : Float(20, strides=[1], requires_grad=1, device=cpu)):
  %/Constant_output_0 : Long(2, strides=[1], device=cpu) = onnx::Constant[value=   -1  2048 [ CPULongType{2} ], onnx_name="/Constant"](), scope: __main__.NeuralNet:: # /tmp/ipykernel_3560254/343124900.py:16:0
  %/Reshape_output_0 : Float(1, 2048, strides=[2048, 1], requires_grad=0, device=cpu) = onnx::Reshape[allowzero=0, onnx_name="/Reshape"](%input, %/Constant_output_0), scope: __main__.NeuralNet:: # /tmp/ipyke

  model.load_state_dict(torch.load("model_params.pt"))
