In [None]:
import numpy as np
import onnxruntime as rt
import scipy as sp
import scipy.signal  # ensures `sp.signal` is loaded; older SciPy does not import submodules on attribute access
import torch
import torch.nn as nn
import wfdb

In [None]:
# Load data: read one ECG record and resample it to the fixed length the models expect.
N_SAMPLES = 1000  # samples per lead after resampling; MLP/ConvNet are sized for 12 leads x 1000 samples

ecg = wfdb.rdsamp("../data/ECG/ath_001")  # (signals, fields) tuple
ecg_signal = ecg[0]  # presumably (n_samples, n_leads) per wfdb docs — resampled along axis 0 (time) below
ecg_resampled = sp.signal.resample(ecg_signal, N_SAMPLES, axis=0)
# Transpose to (n_leads, N_SAMPLES), prepend a batch axis, and cast to float32 for torch/ONNX Runtime.
x = np.expand_dims(np.transpose(ecg_resampled), axis=0).astype(np.float32)
# Fail here, not deep inside export/inference, if the record does not match the models' input shape.
assert x.shape == (1, 12, N_SAMPLES), f"unexpected input shape {x.shape}; models expect (1, 12, {N_SAMPLES})"

# Plaintext models

In [None]:
# 3.1M params with the default sizes: 12000*256+256 + 256*128+128 + 128*71+71 = 3,114,311
class MLP(nn.Module):
    """Three-layer fully connected network with square (``x * x``) activations.

    The input is flattened to ``(batch, -1)``, so any input whose non-batch
    dimensions multiply to ``in_features`` is accepted (e.g. ``(n, 12, 1000)``
    with the defaults).

    NOTE(review): squaring instead of ReLU presumably keeps the network a
    polynomial (e.g. for encrypted inference) — confirm.

    Args:
        in_features: Size of the flattened input. Defaults to 12000.
        hidden1: Width of the first hidden layer. Defaults to 256.
        hidden2: Width of the second hidden layer. Defaults to 128.
        output: Number of outputs. Defaults to 71.
    """

    def __init__(self, in_features=12000, hidden1=256, hidden2=128, output=71):
        super().__init__()
        # Attribute names and creation order are unchanged, so state_dict keys and
        # seeded initial weights match the original fixed-size version.
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_features, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, output)

    def forward(self, x):
        """Return ``(batch, output)`` logits for input ``x``."""
        x = self.flatten(x)
        x = self.fc1(x)
        x = x * x
        x = self.fc2(x)
        x = x * x
        x = self.fc3(x)
        return x

In [None]:
# 4.6M params with the default sizes: 12*36 + 36000*128+128 + 128*71+71 = 4,617,719
class ConvNet(nn.Module):
    """Pointwise 1-D convolution followed by two linear layers, with square activations.

    Args:
        hidden: Width of the hidden linear layer. Defaults to 128.
        output: Number of outputs. Defaults to 71.
        in_channels: Channels of the ``(n, in_channels, seq_len)`` input. Defaults to 12.
        seq_len: Length of the input sequence. Defaults to 1000.
        conv_channels: Output channels of the pointwise convolution. Defaults to 36.
    """

    def __init__(self, hidden=128, output=71, in_channels=12, seq_len=1000, conv_channels=36):
        super().__init__()
        # Input: (n, in_channels, seq_len)
        self.conv1 = nn.Conv1d(in_channels, conv_channels, kernel_size=1, bias=False)

        # kernel_size=1 with the default stride 1 and no padding preserves the length,
        # so after conv1 the activation is (n, conv_channels, seq_len).
        self.fc1 = nn.Linear(conv_channels * seq_len, hidden)
        self.fc2 = nn.Linear(hidden, output)

    def forward(self, x):
        """Return ``(n, output)`` logits for input ``x`` of shape ``(n, in_channels, seq_len)``."""
        x = self.conv1(x)
        x = x * x
        x = x.view(x.size(0), -1)  # flatten to (n, conv_channels * seq_len)
        x = self.fc1(x)
        x = x * x
        x = self.fc2(x)
        return x

## Export models

In [None]:
torch.manual_seed(0)  # seeds the dummy input and the random (untrained) initial weights below
dummy_input = torch.randn(1, 12, 1000)


def export_model(model, name, example_input, opset_version=14):
    """Save ``model`` as ``<name>.pt`` and export it to ``<name>.onnx``.

    Both files are written relative to the current working directory. The ONNX
    graph names its input ``"input"`` and its output ``"output"``; axis 0 of
    both is dynamic (``"batch_size"``), so other batch sizes are accepted at
    inference time.

    Args:
        model: The ``nn.Module`` to save and export.
        name: Base file name without extension.
        example_input: Tensor passed to ``torch.onnx.export`` to trace the model.
        opset_version: ONNX opset to target. Defaults to 14.
    """
    # NOTE(review): torch.save on a whole module pickles it, so torch.load needs the
    # class definition importable; model.state_dict() is more portable, but the
    # whole-module format is kept in case other code loads these files as modules.
    torch.save(model, f"{name}.pt")
    torch.onnx.export(
        model,
        example_input,
        f"{name}.onnx",
        export_params=True,  # store the weights inside the ONNX file
        input_names=["input"],
        output_names=["output"],
        opset_version=opset_version,
        dynamic_axes={
            "input": {0: "batch_size"},
            "output": {0: "batch_size"},
        },
        keep_initializers_as_inputs=False,  # weights are initializers, not graph inputs
    )


# Keep construction order (MLP, then ConvNet) after the seed: it determines each model's initial weights.
mlp = MLP()
export_model(mlp, "mlp", dummy_input)

convnet = ConvNet()
export_model(convnet, "convnet", dummy_input)

## ONNX plaintext inference

In [None]:
# Run the exported MLP with ONNX Runtime on the ECG input and check it agrees with PyTorch.
# `providers` is explicit: since ORT 1.9, builds with several execution providers
# (e.g. onnxruntime-gpu) raise if it is omitted.
mlp_session = rt.InferenceSession("mlp.onnx", providers=["CPUExecutionProvider"])
out_mlp_pt = mlp_session.run(["output"], {"input": x})  # list holding one (1, 71) array

# Parity check: the ONNX graph should reproduce the PyTorch forward pass up to float32 rounding.
with torch.no_grad():
    out_mlp_torch = mlp(torch.from_numpy(x)).numpy()
np.testing.assert_allclose(out_mlp_pt[0], out_mlp_torch, rtol=1e-3, atol=1e-4)

out_mlp_pt