# Torch to ONNX Conversion

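This notebook loads a serialized PyTorch model and checks that it runs on a random tensor and on real angiogram frames read from an HDF5 file. It then exports the model to ONNX.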

In [1]:
import torch
import torch.nn as nn
import onnx
import dill
from onnx2torch import convert
import os
import matplotlib.pyplot as plt  # Add this line

In [2]:
import sys

# Make the local nnUNet-Adjustment checkout importable. The `nnunetv2` import below comes from it,
# presumably so torch.load can resolve the model's classes. Override the location with
# the NNUNET_ADJUSTMENT_PATH environment variable on other machines.
nnunetRepoPath = os.environ.get('NNUNET_ADJUSTMENT_PATH', '/Users/billb/github/nnUNet-Adjustment')
# Guard so re-running this cell does not keep appending duplicate entries to sys.path.
if nnunetRepoPath not in sys.path:
    sys.path.append(nnunetRepoPath)


In [3]:
import nnunetv2.training.nnUNetTrainer

In [None]:
# Detect macOS (Darwin) so later cells can pick machine-specific paths.
# os.uname() exists only on POSIX, so it is evaluated only after the POSIX check passes.
onPosix = os.name == 'posix'
is_mac = onPosix and os.uname().sysname == 'Darwin'
print('posix' if onPosix else 'not posix')
print('mac' if is_mac else 'not mac')


In [5]:
# Set up paths: the local project folder on macOS, the shared drive elsewhere.
# (Another location used previously: '/Volumes/Crucial X8/AWIBuffer/'.)
# os.path.expanduser is required because Python's file APIs do not expand "~" on their own,
# so an unexpanded "~/..." path would never be found.
is_mac = os.name == 'posix' and os.uname().sysname == 'Darwin'
rootPath = (
    os.path.expanduser("~/Projects/AWI/NetExploration/")
    if is_mac
    else '/mnt/SliskiDrive/AWI/AWIBuffer/'
)

In [None]:
# Show the resolved root path (Jupyter displays the cell's last expression).
rootPath

In [None]:
# Pick the compute device in order of preference: CUDA GPU 0, then Apple MPS, then CPU.
if torch.cuda.is_available():
    gpuDevice = torch.device("cuda:0")
elif torch.backends.mps.is_available():
    gpuDevice = torch.device("mps")
else:
    gpuDevice = torch.device("cpu")
print(f"Using device: {gpuDevice}")

## Test Model with Random Input

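Load the PyTorch model (saved with `dill`, per its file name) and run a random tensor through it to confirm that the forward pass works.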

In [11]:
# Serialized model to convert. This is an absolute path, so edit it for your machine.
# Previously used instead: rootPath + "PlainConvUNet-nnUNetPlans_2d-DC_and_CE_loss-w-1-20-20-dill.pth"
torchModelPath = (
    "/Users/billb/Projects/AWI/NetExploration/"
    "UXlstmBot-nnUNetPlans_2d-reduced3-DC_and_CE_loss-w-1-20-40-dill.pth"
)

In [None]:
# Show the model path that will be loaded (Jupyter displays the cell's last expression).
torchModelPath

In [13]:
# Fail fast if the model path is missing or is not a regular file.
# os.path.exists would also accept a directory, which torch.load below would then reject
# with a less helpful error.
if not os.path.isfile(torchModelPath):
    raise FileNotFoundError(f"Model file not found at path: {torchModelPath}")

In [None]:
# Load the full pickled model object onto the selected device.
# The result is used directly as a module (model.eval(), model(x)), so this file stores a
# whole nn.Module rather than a state_dict. PyTorch >= 2.6 defaults to weights_only=True,
# which refuses such files, so opt out explicitly (the keyword exists since PyTorch 1.13).
# SECURITY: unpickling can execute arbitrary code. Only load model files you trust.
model = torch.load(torchModelPath, map_location=gpuDevice, weights_only=False)

In [None]:

# Put the network in evaluation mode before any inference or export.
model.eval()

# Dummy batch shaped (1, 5, 512, 512) on the chosen device, used only to check that a
# forward pass runs. The 5 channels presumably match the 5 stacked frames used later.
inputShape = (1, 5, 512, 512)
random_tensor = torch.randn(inputShape, dtype=torch.float32, device=gpuDevice)
print("Input tensor shape:", random_tensor.shape)

In [None]:
# Smoke-test a forward pass on the dummy batch. inference_mode disables autograd tracking.
with torch.inference_mode():
    output = model(random_tensor)
outputShape = output.shape
print("Output tensor shape:", outputShape)


## Test Model with HDF5 Input

In [None]:
import platform

# Choose the directory that holds the angiogram HDF5 file for the current OS.
system = platform.system()
if system == "Darwin":
    # Prefer the external SSD when it is mounted; otherwise use the local project folder.
    if os.path.isdir("/Volumes/Crucial X8"):
        dataDir = "/Volumes/Crucial X8/AWIBuffer"
    else:
        dataDir = "/Users/billb/Projects/AWI/NetExploration"
elif system == "Linux":
    dataDir = "/mnt/SliskiDrive/AWI/AWIBuffer"
else:
    # Previously dataDir was set to None here, which then crashed with an unhelpful
    # TypeError on the string concatenation below. Fail with a clear message instead.
    raise RuntimeError(f"No angiogram data directory configured for platform {system!r}")

angiogramH5Path = os.path.join(dataDir, "AngiogramsDistilledUInt8List.h5")
angiogramH5Path

In [None]:
import h5py

# Collect the names of the datasets at the root of the HDF5 file.
# `keys` is reused by the next cell to pick an angiogram.
with h5py.File(angiogramH5Path, 'r') as h5File:
    keys = [name for name in h5File.keys()]

print("Dataset keys in HDF5 file:")
for datasetName in keys:
    print(f"- {datasetName}")


In [None]:
import random

# Load ONE randomly chosen angiogram (not the first one) from the HDF5 file.
# Set ANGIOGRAM_SEED to an int to make the choice reproducible. None keeps it random.
ANGIOGRAM_SEED = None
if not keys:
    raise ValueError(f"No datasets found in HDF5 file: {angiogramH5Path}")
keyPicker = random.Random(ANGIOGRAM_SEED)
with h5py.File(angiogramH5Path, 'r') as f:
    hdfKey = keyPicker.choice(keys)
    print(f"Loading dataset: {hdfKey}")
    # Read the whole dataset into memory and convert it to float32 for normalization/inference.
    agram = torch.from_numpy(f[hdfKey][:]).float()
    print(f"Loaded tensor shape: {agram.shape}")


In [None]:
# Show frame 30 of the raw angiogram in grayscale, with an intensity colorbar.
fig, ax = plt.subplots()
frameImage = ax.imshow(agram[30], cmap='gray')
fig.colorbar(frameImage, ax=ax)
plt.show()


In [None]:
# Z-score the whole clip with a single global mean and standard deviation.
clipMean = agram.mean()
clipStd = agram.std()
xagram = (agram - clipMean) / clipStd
print(f"Normalized tensor shape: {xagram.shape}")


In [None]:
# Build one model input from 5 consecutive frames centred on frame 30, stacked along
# dim 1, plus a leading batch dimension.
CENTER_FRAME = 30
HALF_WINDOW = 2                            # frames on each side -> window of 2*HALF_WINDOW+1 = 5
start_idx = CENTER_FRAME - HALF_WINDOW     # 28
end_idx = CENTER_FRAME + HALF_WINDOW + 1   # 33 (exclusive)
# Slicing silently truncates past the end of the clip, which would give the model fewer
# than 5 channels. The angiogram is chosen at random, so short clips are possible; fail clearly.
if start_idx < 0 or end_idx > xagram.shape[0]:
    raise ValueError(
        f"Need frames {start_idx}..{end_idx - 1}, but the angiogram has only {xagram.shape[0]} frames"
    )
z = xagram[start_idx:end_idx].unsqueeze(0)  # Add batch dimension
print(f"Input tensor shape: {z.shape}")


In [23]:
# Move the real-data input to the device the model was loaded onto.
z = z.to(device=gpuDevice)

In [None]:
# Forward pass on the real 5-frame input. The outputs are only visualised, so no_grad
# skips building an autograd graph that would hold activations in memory for nothing.
with torch.no_grad():
    y = model(z)
y.shape

In [None]:
# Convert the raw outputs to probabilities with a softmax over dim=1 (the output channels).
# Note: this overwrites `y`, so re-running this cell alone applies softmax twice.
y = y.softmax(dim=1)
print(f"Output tensor shape after softmax: {y.shape}")


In [None]:
# Plot output channel index 2 (the third channel) for the single input sample.
channelMap = y[0, 2].detach().cpu().numpy()
fig, ax = plt.subplots()
mapImage = ax.imshow(channelMap, cmap='gray')
fig.colorbar(mapImage, ax=ax)
ax.set_title('Output Channel 3')
plt.show()


In [None]:
# Build every overlapping window of WINDOW consecutive frames (stride 1).
# Window i holds frames i..i+WINDOW-1 and becomes batch element i of z5.
WINDOW = 5
num_frames = xagram.shape[0]
num_groups = num_frames - (WINDOW - 1)
if num_groups < 1:
    raise ValueError(f"Need at least {WINDOW} frames, but the angiogram has {num_frames}")

# Stack the windows directly instead of filling a preallocated 512x512 tensor. The frame
# size now comes from the data, and the result stays on xagram's device and dtype.
z5 = torch.stack([xagram[i:i + WINDOW] for i in range(num_groups)])

print(f"Shape of tensor containing all valid 5-frame groups: {z5.shape}")


In [None]:
# Run every 5-frame window through the model as one batch.
# no_grad avoids keeping activations for backprop. For a whole clip of 512x512 windows,
# those activations can exhaust device memory, and the results are only visualised anyway.
with torch.no_grad():
    y5 = model(z5.to(gpuDevice))
y5.shape

In [None]:
# Per-window probabilities: softmax over dim=1 (the output channels).
ys5 = y5.softmax(dim=1)
print(f"Output tensor shape after softmax: {ys5.shape}")


In [None]:
# Plot output channel index 2 for batch element 35, i.e. the window covering frames 35..39.
windowMap = ys5[35, 2].detach().cpu().numpy()
fig, ax = plt.subplots()
fig.colorbar(ax.imshow(windowMap, cmap='gray'), ax=ax)
ax.set_title('Output Channel 3 - Batch 35')
plt.show()


## Export to ONNX

In [31]:
# Derive the ONNX output path by swapping the file extension.
# splitext changes only the final extension. str.replace(".pth", ...) would also rewrite a
# ".pth" elsewhere in the path, and it would return the model path unchanged when there is
# no ".pth", so the export would overwrite the source model.
onnxOutputPath = os.path.splitext(torchModelPath)[0] + ".onnx"


In [None]:
# Show where the ONNX file will be written (Jupyter displays the cell's last expression).
onnxOutputPath

In [None]:

# Export the loaded model to ONNX next to the original .pth file.
# The real-data input `z` (built above from 5 stacked frames plus a batch dimension) is the
# tracing example. Dim 0 is marked dynamic so the exported graph accepts any batch size.
# (Earlier attempts exported `modelPerOnnx` / `random_tensor` with opsets 14 and 18.)
with torch.inference_mode():
    torch.onnx.export(
        model,
        z,
        onnxOutputPath,
        export_params=True,
        opset_version=18,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'},
        },
        # NOTE(review): this also lists every weight as a graph input. The PyTorch docs say
        # False "may allow for better optimizations (e.g. constant folding)". It is not
        # established that this flag affects batch-dimension handling, so verify before relying on it.
        keep_initializers_as_inputs=True,
        do_constant_folding=True,
    )

# Validate the written file against the ONNX spec (raises onnx.checker.ValidationError if invalid).
onnx.checker.check_model(onnxOutputPath)
print(f"Exported and validated ONNX model: {onnxOutputPath}")
