# Torch to ONNX Conversion

This notebook converts Torch to ONNX models.

In [8]:
%matplotlib inline
import torch
import torch.nn as nn
import onnx
import dill
# from onnx2pytorch import convert
import os
import matplotlib.pyplot as plt  # Add this line
import onnxruntime as ort
import numpy as np
import copy


In [9]:
import sys
sys.path.append('/Users/billb/github/nnUNet-Adjustment')


In [10]:
import nnunetv2.training.nnUNetTrainer

In [11]:
import nnunetv2.training.nnUNetTrainer.nnUNetTrainer_LowDoseContrastSim

In [12]:
# Check if running on Mac OS
is_mac = os.name == 'posix' and os.uname().sysname == 'Darwin'
print('posix' if os.name == 'posix' else 'not posix')
print('mac' if is_mac else 'not mac')


posix
mac


In [13]:
# Set up paths
is_mac = os.name == 'posix' and os.uname().sysname == 'Darwin'
netModelPath = "/Volumes/X10Pro/AWIBuffer/NetModels/" if is_mac else '/mnt/SliskiDrive/AWI/AWIBuffer/NetModels/' # '/Volumes/X10Pro/AWIBuffer/' 
dataPath = "/Volumes/X10Pro/AWIBuffer/Angiostore/" if is_mac else '/mnt/SliskiDrive/AWI/AWIBuffer/Angiostore/' # '/Volumes/X10Pro/AWIBuffer/' 
# "~/Projects/AWI/NetExploration/"

In [32]:
# Set up device
gpuDevice = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
cpuDevice = torch.device('cpu')
gpuDevice = cpuDevice

print(f"Using device: {gpuDevice}") 


Using device: cpu


## Test Model with Random Input

Import a dill PyTorch model

In [None]:
torchModelPath = netModelPath + "PlainConvUNet-nnUNetPlans_2d-reduced3-lowdosesim-DC_and_CE_loss-w-1-20-20-dill.pth"
# torchModelPath = netModelPath + "PlainConvUNet-nnUNetPlans_2d-reduced3-lowdosesim-DC_and_CE_loss-w-1-20-20-dill.pth"
# torchModelPath =  "/Volumes/X10Pro/AWIBuffer/UXlstmBot-nnUNetPlans_2d-reduced3-DC_and_CE_loss-w-1-20-40-dill.pth"

In [None]:
torchModelPath

In [None]:
# Check if exists file at path torchModelPath
if not os.path.exists(torchModelPath):
    raise FileNotFoundError(f"Model file not found at path: {torchModelPath}")

In [None]:
# Load the model if its PyTorch
model = torch.load(torchModelPath,map_location=gpuDevice, weights_only=False)

In [33]:
# Load the model if its onnx
# onnxPath = rootPath + "UMambaBot-plans_unet_edge8_epochs250_2d-DC_and_CE_loss-w-1-20-20.onnx"
onnxPath = netModelPath + "ResidualEncoderUNet-nnUNetPlans_2d-reduced3-DC_and_CE_loss-w-1-20-20-lowdosesim.onnx"

from onnx2torch import convert
model = convert(onnxPath)





In [34]:
# Convert to float32
model.to(torch.float32)

GraphModule(
  (initializers): Module()
  (Identity_0): OnnxCopyIdentity()
  (Identity_1): OnnxCopyIdentity()
  (Identity_2): OnnxCopyIdentity()
  (Identity_3): OnnxCopyIdentity()
  (encoder/stem/convs/convs/0/all_modules/conv/Cast): OnnxCast()
  (encoder/stem/convs/convs/0/all_modules/conv/Conv): Conv2d(5, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (encoder/stem/convs/convs/0/all_modules/norm/InstanceNormalization): InstanceNorm2d(32, eps=9.999999747378752e-06, momentum=0.1, affine=True, track_running_stats=False)
  (encoder/stem/convs/convs/0/all_modules/nonlin/LeakyRelu): LeakyReLU(negative_slope=0.009999999776482582)
  (encoder/stages/0/blocks/blocks/0/conv1/all_modules/conv/Conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (encoder/stages/0/blocks/blocks/0/conv1/all_modules/norm/InstanceNormalization): InstanceNorm2d(32, eps=9.999999747378752e-06, momentum=0.1, affine=True, track_running_stats=False)
  (encoder/stages/0/blocks/blocks/0/conv1/

In [26]:
model.float()

GraphModule(
  (initializers): Module()
  (Identity_0): OnnxCopyIdentity()
  (Identity_1): OnnxCopyIdentity()
  (Identity_2): OnnxCopyIdentity()
  (Identity_3): OnnxCopyIdentity()
  (encoder/stem/convs/convs/0/all_modules/conv/Cast): OnnxCast()
  (encoder/stem/convs/convs/0/all_modules/conv/Conv): Conv2d(5, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (encoder/stem/convs/convs/0/all_modules/norm/InstanceNormalization): InstanceNorm2d(32, eps=9.999999747378752e-06, momentum=0.1, affine=True, track_running_stats=False)
  (encoder/stem/convs/convs/0/all_modules/nonlin/LeakyRelu): LeakyReLU(negative_slope=0.009999999776482582)
  (encoder/stages/0/blocks/blocks/0/conv1/all_modules/conv/Conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (encoder/stages/0/blocks/blocks/0/conv1/all_modules/norm/InstanceNormalization): InstanceNorm2d(32, eps=9.999999747378752e-06, momentum=0.1, affine=True, track_running_stats=False)
  (encoder/stages/0/blocks/blocks/0/conv1/

In [24]:
type(model)

torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl

In [21]:

model.eval()

# Test the model with a random input
random_tensor = torch.randn(1, 5, 512, 512, device=gpuDevice, dtype=torch.float32)
print("Input tensor shape:", random_tensor.shape)

Input tensor shape: torch.Size([1, 5, 512, 512])


In [35]:
# Fix: Convert ALL model parameters to float32
def convert_model_to_fp32(model):
    """
    Convert all model parameters and buffers to float32
    """
    # Convert model to float32
    model = model.float()
    
    # Also convert all parameters and buffers explicitly
    for param in model.parameters():
        param.data = param.data.float()
    
    for buffer in model.buffers():
        buffer.data = buffer.data.float()
    
    # Set model to eval mode
    model.eval()
    
    return model


In [None]:

# Verify all parameters are float32
print("Checking model parameter types:")
for name, param in model.named_parameters():
    print(f"{name}: {param.dtype}")

# Now test with float32 input
random_tensor = torch.randn(1, 5, 512, 512, device=cpuDevice)
random_tensor = random_tensor.to(torch.float32)  # Explicit conversion

print(f"Input tensor dtype: {random_tensor.dtype}")

with torch.inference_mode():
    output = model(random_tensor)
print(f"Output tensor dtype: {output.dtype}")

In [39]:

# Apply the conversion
model = convert_model_to_fp32(model)


In [38]:
print(f"Output tensor dtype: {random_tensor.dtype}")

Output tensor dtype: torch.float32


In [40]:
with torch.inference_mode():
    output = model(random_tensor)

print("Output tensor shape:", output.shape)


RuntimeError: Input type (c10::Half) and bias type (float) should be the same

In [42]:
# Debug: Check if there are any remaining Half precision tensors
def check_model_dtypes(model):
    """Check all parameter and buffer dtypes in the model"""
    print("=== Model Parameter Types ===")
    half_params = []
    for name, param in model.named_parameters():
        if param.dtype == torch.float16:
            half_params.append(name)
        print(f"{name}: {param.dtype}")
    
    print("\n=== Model Buffer Types ===")
    half_buffers = []
    for name, buffer in model.named_buffers():
        if buffer.dtype == torch.float16:
            half_buffers.append(name)
        print(f"{name}: {buffer.dtype}")
    
    if half_params:
        print(f"\n⚠️  Found Half precision parameters: {half_params}")
    if half_buffers:
        print(f"\n⚠️  Found Half precision buffers: {half_buffers}")
    
    return half_params, half_buffers

# Check current model state
half_params, half_buffers = check_model_dtypes(model)

=== Model Parameter Types ===
encoder/stem/convs/convs/0/all_modules/conv/Conv.weight: torch.float32
encoder/stem/convs/convs/0/all_modules/conv/Conv.bias: torch.float32
encoder/stem/convs/convs/0/all_modules/norm/InstanceNormalization.weight: torch.float32
encoder/stem/convs/convs/0/all_modules/norm/InstanceNormalization.bias: torch.float32
encoder/stages/0/blocks/blocks/0/conv1/all_modules/conv/Conv.weight: torch.float32
encoder/stages/0/blocks/blocks/0/conv1/all_modules/conv/Conv.bias: torch.float32
encoder/stages/0/blocks/blocks/0/conv1/all_modules/norm/InstanceNormalization.weight: torch.float32
encoder/stages/0/blocks/blocks/0/conv1/all_modules/norm/InstanceNormalization.bias: torch.float32
encoder/stages/0/blocks/blocks/0/conv2/all_modules/conv/Conv.weight: torch.float32
encoder/stages/0/blocks/blocks/0/conv2/all_modules/conv/Conv.bias: torch.float32
encoder/stages/0/blocks/blocks/0/conv2/all_modules/norm/InstanceNormalization.weight: torch.float32
encoder/stages/0/blocks/blocks

In [43]:
# Aggressive fix: Force ALL tensors to float32
def force_model_to_fp32(model):
    """Aggressively convert all model components to float32"""
    
    # Convert the model itself
    model = model.float()
    
    # Force convert all parameters
    for name, param in model.named_parameters():
        if param.dtype != torch.float32:
            print(f"Converting parameter {name} from {param.dtype} to float32")
            param.data = param.data.float()
    
    # Force convert all buffers
    for name, buffer in model.named_buffers():
        if buffer.dtype != torch.float32:
            print(f"Converting buffer {name} from {buffer.dtype} to float32")
            buffer.data = buffer.data.float()
    
    # Also check for any nested modules
    for module in model.modules():
        if hasattr(module, 'weight') and module.weight is not None:
            module.weight.data = module.weight.data.float()
        if hasattr(module, 'bias') and module.bias is not None:
            module.bias.data = module.bias.data.float()
    
    return model

# Apply the aggressive conversion
model = force_model_to_fp32(model)

# Verify conversion worked
print("\nAfter conversion:")
check_model_dtypes(model)

# Test with float32 input
random_tensor = torch.randn(1, 5, 512, 512, device=gpuDevice, dtype=torch.float32)
print(f"\nInput tensor dtype: {random_tensor.dtype}")

try:
    with torch.inference_mode():
        output = model(random_tensor)
    print(f"✅ Success! Output dtype: {output.dtype}, shape: {output.shape}")
except RuntimeError as e:
    print(f"❌ Still getting error: {e}")


After conversion:
=== Model Parameter Types ===
encoder/stem/convs/convs/0/all_modules/conv/Conv.weight: torch.float32
encoder/stem/convs/convs/0/all_modules/conv/Conv.bias: torch.float32
encoder/stem/convs/convs/0/all_modules/norm/InstanceNormalization.weight: torch.float32
encoder/stem/convs/convs/0/all_modules/norm/InstanceNormalization.bias: torch.float32
encoder/stages/0/blocks/blocks/0/conv1/all_modules/conv/Conv.weight: torch.float32
encoder/stages/0/blocks/blocks/0/conv1/all_modules/conv/Conv.bias: torch.float32
encoder/stages/0/blocks/blocks/0/conv1/all_modules/norm/InstanceNormalization.weight: torch.float32
encoder/stages/0/blocks/blocks/0/conv1/all_modules/norm/InstanceNormalization.bias: torch.float32
encoder/stages/0/blocks/blocks/0/conv2/all_modules/conv/Conv.weight: torch.float32
encoder/stages/0/blocks/blocks/0/conv2/all_modules/conv/Conv.bias: torch.float32
encoder/stages/0/blocks/blocks/0/conv2/all_modules/norm/InstanceNormalization.weight: torch.float32
encoder/sta

In [41]:
# Alternative: Reload and convert properly from ONNX
from onnx2torch import convert

# Reload the ONNX model
onnx_path = netModelPath + "ResidualEncoderUNet-nnUNetPlans_2d-reduced3-DC_and_CE_loss-w-1-20-20-lowdosesim.onnx"
model = convert(onnx_path)

# Immediately convert to float32 before any operations
model = model.float()

# Move to device as float32
model = model.to(gpuDevice, dtype=torch.float32)

# Set to eval mode
model.eval()

# Test immediately
random_tensor = torch.randn(1, 5, 512, 512, device=gpuDevice, dtype=torch.float32)
with torch.inference_mode():
    output = model(random_tensor)
print(f"Output shape: {output.shape}, dtype: {output.dtype}")

RuntimeError: Input type (c10::Half) and bias type (float) should be the same

In [44]:
# Debug: Add hooks to catch where the Half tensor is coming from
def debug_hook(module, input, output):
    """Hook to catch Half precision tensors"""
    for i, inp in enumerate(input):
        if hasattr(inp, 'dtype') and inp.dtype == torch.float16:
            print(f"❌ Found Half precision input {i} in {module.__class__.__name__}: {inp.dtype}")
    
    if hasattr(output, 'dtype') and output.dtype == torch.float16:
        print(f"❌ Found Half precision output in {module.__class__.__name__}: {output.dtype}")

# Register hooks on all modules
hooks = []
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Conv2d):  # Focus on Conv2d layers where error occurs
        hook = module.register_forward_hook(debug_hook)
        hooks.append(hook)

# Now run your model
try:
    with torch.inference_mode():
        output = model(random_tensor)
except RuntimeError as e:
    print(f"Caught error: {e}")

# Remove hooks
for hook in hooks:
    hook.remove()

Caught error: Input type (c10::Half) and bias type (float) should be the same


In [47]:
# Nuclear option: Force everything to CPU and float32, then move to device
model = model.cpu().float()

# Ensure ALL parameters are float32
for param in model.parameters():
    param.data = param.data.float()

for buffer in model.buffers():
    buffer.data = buffer.data.float()

# Now move to device
model = model.to(gpuDevice)

# Create input tensor
random_tensor = torch.randn(1, 5, 512, 512, device=gpuDevice, dtype=torch.float32)

# Test
with torch.inference_mode():
    output = model(random_tensor)

RuntimeError: Input type (c10::Half) and bias type (float) should be the same

## Test Model with HDF5 Input

In [None]:
import h5py
angiogramH5Path = dataPath + "WebknossosAngiogramsRevisedUInt8List.h5"
# "~/Projects/AWI/NetExploration/"
# Open the HDF5 file and print all dataset keys
with h5py.File(angiogramH5Path, 'r') as f:
    # Get all keys at root level
    keys = list(f.keys())
    print("Dataset keys in HDF5 file:")
    for key in keys:
        print(f"- {key}")


In [None]:
# Load first angiogram from HDF5 file
import random
with h5py.File(angiogramH5Path, 'r') as f:
    # Get first key
    hdfKey = random.choice(keys)
    print(f"Loading dataset: {hdfKey}")
    # Load data into tensor
    agram = torch.from_numpy(f[hdfKey][:]).float()
    print(f"Loaded tensor shape: {agram.shape}")


In [None]:
#Display the 30th frame of the angiogram
plt.imshow(agram[30], cmap='gray')
plt.colorbar()
plt.show()


In [None]:
# Normalize angiogram by subtracting mean and dividing by standard deviation
xagram = (agram - agram.mean()) / agram.std()
print(f"Normalized tensor shape: {xagram.shape}")


In [None]:
# Create input tensor with 5 consecutive frames centered around frame 30
start_idx = 28  # 30-2 to get 2 frames before
end_idx = 33    # 30+3 to get 2 frames after (exclusive)
z = xagram[start_idx:end_idx].unsqueeze(0)  # Add batch dimension
print(f"Input tensor shape: {z.shape}")


In [None]:
z = z.to(gpuDevice)

In [None]:
y=model(z)
y.shape

In [None]:
# Apply softmax along dimension 1 (second dimension) which has size 3
y = torch.nn.functional.softmax(y, dim=1)
print(f"Output tensor shape after softmax: {y.shape}")


In [None]:
# Display the 3rd channel (index 2) of the output
plt.imshow(y[0, 2].cpu().detach().numpy(), cmap='gray')
plt.colorbar()
plt.title('Output Channel 3')
plt.show()


In [None]:
# Calculate number of valid frame groups (each group has 5 consecutive frames)
num_frames = xagram.shape[0]
num_groups = num_frames - 4  # Each group needs 5 frames

# Create tensor to hold all valid frame groups
z5 = torch.zeros((num_groups, 5, 512, 512))

# Fill z5 with overlapping groups of 5 consecutive frames
for i in range(num_groups):
    z5[i] = xagram[i:i+5]

print(f"Shape of tensor containing all valid 5-frame groups: {z5.shape}")


In [None]:
# Feed z5 into the model and get the output
y5 = model(z5.to(gpuDevice))
y5.shape

In [None]:
# Apply softmax along dimension 1 (second dimension) which has size 3
ys5 = torch.nn.functional.softmax(y5, dim=1)
print(f"Output tensor shape after softmax: {ys5.shape}")


In [None]:
# Display the 3rd channel (index 2) of batch member 35
plt.imshow(ys5[35, 2].cpu().detach().numpy(), cmap='gray')
plt.colorbar()
plt.title('Output Channel 3 - Batch 35')
plt.show()


## Export to ONNX

In [None]:
# Export model back to ONNX
onnxOutputPath = torchModelPath.replace(".pth", "-fp32.onnx")


In [None]:
onnxOutputPath

In [None]:

# Move both model and input tensor to CPU for export
# model_for_export = modelPerOnnx.to(gpuDevice)
# input_for_export = z5.to(gpuDevice)

# with torch.inference_mode():
#     torch.onnx.export(modelPerOnnx,
#                      random_tensor,
#                      onnxOutputPath, 
#                      export_params=True,
#                      opset_version=18, 
#                      do_constant_folding=True,
#                      verbose=True,
#                      input_names=['input'],
#                      output_names=['output'], 
#                      dynamic_axes={'input': {0: 'batch_size'}, 
#                                  'output': {0: 'batch_size'}}, 
#                      training=torch.onnx.TrainingMode.EVAL)

# with torch.inference_mode():
#     torch.onnx.export(
#     model,
#     random_tensor,
#     onnxOutputPath,
#     export_params=True,
#     opset_version=14,
#     do_constant_folding=True,
#     input_names=['input'],
#     output_names=['output'],
#     dynamic_axes={
#         'input': {0: 'batch_size'},  # First dimension is batch size
#         'output': {0: 'batch_size'}
#     }
# )

with torch.inference_mode():
    torch.onnx.export(
    model,
    z,
    onnxOutputPath,
    export_params=True,
    opset_version=18,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'}
    },
    keep_initializers_as_inputs=True,  # This can help with some batch dimension issues
    do_constant_folding=True
)


In [None]:
# Ensure model is in float32 precision
model = model.float()
model.eval()

# Create a sample input tensor in float32
# Using the same shape as z but ensuring float32
dummy_input = torch.randn(1, 5, 512, 512, dtype=torch.float32, device=gpuDevice)

print(f"Model dtype: {next(model.parameters()).dtype}")
print(f"Input dtype: {dummy_input.dtype}")

In [None]:
# Export to ONNX with explicit float32 configuration
onnxOutputPath_fp32 = torchModelPath.replace(".pth", "-float32.onnx")
print(f"Exporting to: {onnxOutputPath_fp32}")

# Ensure model is completely moved to CPU and in float32
model = model.cpu().float()
dummy_input_cpu = dummy_input.cpu().float()

# Export with explicit float32 settings
with torch.no_grad():
    torch.onnx.export(
        model,
        dummy_input_cpu,
        onnxOutputPath_fp32,
        export_params=True,
        opset_version=14,  # Using opset 14 for better compatibility
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'}
        },
        operator_export_type=torch.onnx.OperatorExportTypes.ONNX,
        verbose=False
    )

print(f"Model exported successfully to {onnxOutputPath_fp32}")

In [None]:
# Verify the exported ONNX model uses float32
onnx_model = onnx.load(onnxOutputPath_fp32)

# Check data types of tensors in the model
print("Checking tensor data types in ONNX model:")
for initializer in onnx_model.graph.initializer:
    tensor_dtype = initializer.data_type
    dtype_str = onnx.helper.tensor_dtype_to_np_dtype(tensor_dtype)
    print(f"  Tensor '{initializer.name[:30]}...': dtype={dtype_str}")
    if len(onnx_model.graph.initializer) > 5:
        print("  ... (showing first 5 tensors)")
        break

# Check input and output types
for input_tensor in onnx_model.graph.input:
    if input_tensor.type.tensor_type.elem_type:
        dtype = onnx.helper.tensor_dtype_to_np_dtype(input_tensor.type.tensor_type.elem_type)
        print(f"\nInput '{input_tensor.name}': dtype={dtype}")

for output_tensor in onnx_model.graph.output:
    if output_tensor.type.tensor_type.elem_type:
        dtype = onnx.helper.tensor_dtype_to_np_dtype(output_tensor.type.tensor_type.elem_type)
        print(f"Output '{output_tensor.name}': dtype={dtype}")

print(f"\nONNX model saved as float32 to: {onnxOutputPath_fp32}")

In [None]:
# Test the exported ONNX model with ONNXRuntime to ensure it works
ort_session = ort.InferenceSession(onnxOutputPath_fp32)

# Prepare test input
test_input = dummy_input_cpu.numpy()

# Run inference
ort_inputs = {ort_session.get_inputs()[0].name:
test_input}
ort_outputs = ort_session.run(None, ort_inputs)

print(f"ONNX Runtime test successful!")
print(f"Input shape: {test_input.shape}, dtype: {test_input.dtype}")
print(f"Output shape: {ort_outputs[0].shape}, dtype: {ort_outputs[0].dtype}")

# Compare with PyTorch output (model is now on CPU)
# cpuDevice = torch.device('cpu')
model.to(cpuDevice)
with torch.no_grad():
    pytorch_output = model(dummy_input_cpu).numpy()

# Check if outputs are close
max_diff = np.max(np.abs(pytorch_output -
ort_outputs[0]))
print(f"\nMax difference between PyTorch and ONNX outputs: {max_diff:.6f}")