In [7]:
import os
import torch
from os.path import join
from monai.networks.nets import UNet

# Choose the compute device once; the ONNX export code below reuses `device`
# for its dummy input, so model and input always live on the same device.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# 3D UNet: one input channel -> two output channels
# (presumably background / spleen — confirm against the training notebook).
# Five feature widths with four stride-2 steps between them, two residual
# units (num_res_units=2) and batch normalisation.
model = UNet(
    in_channels=1,
    out_channels=2,
    spatial_dims=3,
    channels=(16, 32, 64, 128, 256),
    strides=(2,) * 4,
    norm="batch",
    num_res_units=2,
)
model = model.to(device)

# Function to Convert to ONNX
def Convert_ONNX(root_dir="/tmp/tmpk9rec3f2"):
    """Load the best checkpoint into the global ``model`` and export it to ONNX.

    Reads ``<root_dir>/best_metric_model.pth`` (a state dict for ``model``),
    switches the model to eval mode and writes
    ``<root_dir>/models/spleen_3d_seg.onnx``, then prints that path.

    Args:
        root_dir: Directory that holds the checkpoint. The ONNX file is written
            to its ``models`` sub-directory, which is created if missing.
            Defaults to the path that was previously hard-coded here.

    Note:
        Relies on the module-level ``model`` and ``device`` defined earlier in
        the notebook. Only the batch axis is exported as dynamic, so the ONNX
        graph expects 64x64x64 spatial input (the dummy input's size).
    """
    onnx_dir = join(root_dir, "models")
    os.makedirs(onnx_dir, exist_ok=True)

    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine,
    # matching the CPU fallback used when `device` was chosen.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
    # (on torch>=1.13 consider weights_only=True for a plain state dict).
    checkpoint_path = join(root_dir, "best_metric_model.pth")
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    # Inference mode, so batch-norm layers use their running statistics while
    # the graph is traced for export.
    model.eval()

    # The dummy input only fixes the traced shapes: (batch, channel, D, H, W).
    # 64 is divisible by the total down-sampling factor 2**4 of the UNet above.
    dummy_input = torch.randn(1, 1, 64, 64, 64, device=device)

    onnx_file_path = join(onnx_dir, "spleen_3d_seg.onnx")
    with open(onnx_file_path, "wb") as f:
        torch.onnx.export(
            model,                         # model being run
            dummy_input,                   # example input used for tracing
            f,                             # file object to save the model
            export_params=True,            # store the trained weights inside the model file
            opset_version=10,              # ONNX operator-set version to target
            do_constant_folding=True,      # pre-compute constant sub-graphs during export
            input_names=["modelInput"],    # the model's input names
            output_names=["modelOutput"],  # the model's output names
            # only the batch dimension may vary at inference time
            dynamic_axes={"modelInput": {0: "batch_size"}, "modelOutput": {0: "batch_size"}},
        )

    print(f"Model has been converted to ONNX and saved at: {onnx_file_path}")

# Call the function to convert the model to ONNX.
# Writes spleen_3d_seg.onnx under the "models" sub-directory and prints its path.
Convert_ONNX()


Model has been converted to ONNX and saved at: /tmp/tmpk9rec3f2/models/spleen_3d_seg.onnx


In [8]:
# List the export directory written by Convert_ONNX to confirm the .onnx file exists.
!ls -lart /tmp/tmpk9rec3f2/models/

total 18804
drwx------. 4 1001050000 root      118 Mar 28 22:27 ../
drwxr-xr-x. 2 1001050000 root       32 Mar 28 23:57 ./
-rw-r--r--. 1 1001050000 root 19251671 Mar 29 00:12 spleen_3d_seg.onnx
