In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
import onnx
import onnxruntime as ort
import numpy as np
from model import CNNBiLSTMClassifier

ModuleNotFoundError: No module named 'model'

# Load Model

In [3]:
class CNNBiLSTMClassifier(nn.Module):
    """1-D CNN feature extractor followed by a bidirectional-LSTM classifier.

    ``forward`` expects ``x`` in Conv1d layout, ``(batch, in_channels, time)``.
    Two Conv1d/BatchNorm/ReLU/MaxPool(2) stages lift channels to 64 and
    downsample time by roughly 4x; the BiLSTM's final forward and backward
    hidden states are concatenated, batch-normalised and passed to a small
    MLP head that returns raw logits (no softmax).

    Args:
        num_classes: Number of output logits.
        in_channels: Channels of the input signal (4 by default).
        hidden_size: LSTM hidden size per direction; also the width of the
            classifier's hidden layer.
        dropout: Dropout probability inside the classifier head.

    The defaults reproduce the original architecture exactly, so the
    state_dict keys and tensor shapes of existing checkpoints are unchanged.
    """

    def __init__(self, num_classes=11, in_channels=4, hidden_size=32, dropout=0.3):
        super().__init__()

        # Module attribute names and Sequential indices are part of the
        # state_dict keys (e.g. "cnn.4.weight", "classifier.3.bias") and must
        # not change, or saved checkpoints stop loading.
        self.cnn = nn.Sequential(
            nn.Conv1d(in_channels, 32, kernel_size=5, padding=2),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Conv1d(32, 64, kernel_size=5, padding=2),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(2),
        )

        self.lstm = nn.LSTM(input_size=64, hidden_size=hidden_size, batch_first=True, bidirectional=True)

        # Normalises the concatenated (forward, backward) final states.
        self.bn = nn.BatchNorm1d(hidden_size * 2)

        self.classifier = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, num_classes),
        )

    def forward(self, x):
        """Map ``x`` of shape ``(B, in_channels, T)`` to logits of shape ``(B, num_classes)``."""
        x = self.cnn(x)  # (B, 64, ~T/4)
        x = x.permute(0, 2, 1)  # (B, ~T/4, 64) — the LSTM is batch_first

        _, (hn, _) = self.lstm(x)  # hn: (2, B, hidden_size) — one row per direction
        # Single-layer BiLSTM: index 0 is the forward state, 1 the backward state.
        hn = torch.cat((hn[0], hn[1]), dim=1)  # (B, 2 * hidden_size)

        hn = self.bn(hn)
        return self.classifier(hn)

In [12]:
# Trained weights: a state_dict written with torch.save.
model_path = "models/CNNBiLSTM_0.0.1/model.pth"
torch_model = CNNBiLSTMClassifier()
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code (default since torch 2.6).
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
torch_model.load_state_dict(state_dict)
torch_model.eval()

CNNBiLSTMClassifier(
  (cnn): Sequential(
    (0): Conv1d(4, 32, kernel_size=(5,), stride=(1,), padding=(2,))
    (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv1d(32, 64, kernel_size=(5,), stride=(1,), padding=(2,))
    (5): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (lstm): LSTM(64, 32, batch_first=True, bidirectional=True)
  (bn): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (classifier): Sequential(
    (0): Linear(in_features=64, out_features=32, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.3, inplace=False)
    (3): Linear(in_features=32, out_features=11, bias=True)
  )
)

In [9]:
# Export the loaded eval-mode model. Use `torch_model` (created in the load
# cell): the previous `model` name was never defined in this notebook and only
# worked via leftover kernel state.
onnx_path = "cnn_bilstm.onnx"
dummy_input = torch.randn(1, 4, 500)  # (batch, channels, time); used only for tracing
torch.onnx.export(
    torch_model,
    dummy_input,
    onnx_path,
    input_names=["input"],
    output_names=["output"],
    # Batch and time axes are dynamic so the graph accepts any batch size and
    # sequence length; the channel axis stays fixed.
    dynamic_axes={"input": {0: "batch", 2: "sequence"}, "output": {0: "batch"}},
)

In [13]:
# ----- Prepare input -----
# Seed so the random validation input, and therefore the printed outputs,
# are reproducible across re-runs.
torch.manual_seed(0)
input_tensor = torch.randn(1, 4, 500)
input_numpy = input_tensor.numpy()

# ----- Run PyTorch inference -----
with torch.no_grad():
    torch_output = torch_model(input_tensor).numpy()

# ----- Load ONNX model -----
onnx_model = onnx.load("cnn_bilstm.onnx")
onnx.checker.check_model(onnx_model)  # structural validation of the exported graph

# ----- Run ONNX Runtime inference -----
# Pin the execution provider explicitly: since onnxruntime 1.9, omitting
# `providers` raises on builds that ship several providers (e.g. onnxruntime-gpu).
ort_session = ort.InferenceSession("cnn_bilstm.onnx", providers=["CPUExecutionProvider"])
input_name = ort_session.get_inputs()[0].name
onnx_output = ort_session.run(None, {input_name: input_numpy})[0]

In [14]:
# ----- Compare outputs -----
print("Torch output:", torch_output)
print("ONNX output:", onnx_output)

# ----- Check closeness -----
# np.allclose broadcasts its arguments, so e.g. (1, 11) vs (11,) would
# "match"; require identical shapes before comparing values.
shapes_match = torch_output.shape == onnx_output.shape
outputs_match = shapes_match and np.allclose(torch_output, onnx_output, atol=1e-5)

if shapes_match:
    print(f"Max |torch - onnx|: {np.max(np.abs(torch_output - onnx_output)):.3e}")
else:
    print(f"Shape mismatch: torch {torch_output.shape} vs onnx {onnx_output.shape}")

if outputs_match:
    print("✅ Outputs match! ONNX model is accurate.")
else:
    print("❌ Outputs differ! Something might be wrong.")

# Fail loudly so "Restart & Run All" cannot silently pass a broken export.
assert outputs_match, "PyTorch and ONNX Runtime outputs differ"

Torch output: [[-15.183826    8.439772    6.972706   -5.57649   -13.262698  -11.6480665
  -11.400796  -26.552748  -20.910198  -16.212461   -8.480648 ]]
ONNX output: [[-15.183824    8.439774    6.972702   -5.5764937 -13.262694  -11.648068
  -11.400796  -26.55274   -20.910196  -16.212458   -8.480652 ]]
✅ Outputs match! ONNX model is accurate.


In [16]:
# List every initializer (stored weight tensor) in the exported graph together
# with its ONNX element type name.
for initializer in onnx_model.graph.initializer:
    dtype_name = onnx.TensorProto.DataType.Name(initializer.data_type)
    print(f"{initializer.name}: {dtype_name}")

bn.weight: FLOAT
bn.bias: FLOAT
bn.running_mean: FLOAT
bn.running_var: FLOAT
classifier.0.weight: FLOAT
classifier.0.bias: FLOAT
classifier.3.weight: FLOAT
classifier.3.bias: FLOAT
onnx::Conv_190: FLOAT
onnx::Conv_191: FLOAT
onnx::Conv_193: FLOAT
onnx::Conv_194: FLOAT
onnx::LSTM_237: FLOAT
onnx::LSTM_238: FLOAT
onnx::LSTM_239: FLOAT
