# Convert a trained Pytorch-UNet model to OpenVINO IR

Reference for saving and loading PyTorch models (state dictionaries): https://pytorch.org/tutorials/beginner/saving_loading_models.html

In [1]:
from pathlib import Path
import torch
import openvino as ov
from unet import UNet

In [2]:
  import os
  # Show the working directory; relative paths below (e.g. the checkpoint path) are resolved against it.
  current_dir = os.getcwd()
  print(current_dir)

/home/qwang12/src/Pytorch-UNet-scottwangintel/Pytorch-UNet


In [3]:
# Specify the path to the model file, resolved against the current working directory
cwd = Path.cwd()
model_path = cwd.joinpath('checkpoints-0703', 'checkpoint_epoch5.pth')

In [4]:
# Load the state dictionary from the file without loading it into the model.
# map_location="cpu" lets a checkpoint whose tensors were saved from a GPU run be
# opened on a CPU-only machine, which is the usual host for OpenVINO conversion.
state_dict = torch.load(model_path, map_location="cpu")

# Print all keys in the state dictionary.
# NOTE(review): besides parameter/buffer names, the checkpoint may carry extra
# entries such as 'mask_values' (load_model pops that key before loading).
print("Keys in the loaded state_dict:")
for key in state_dict.keys():
    print(key)

Keys in the loaded state_dict:
inc.double_conv.0.weight
inc.double_conv.1.weight
inc.double_conv.1.bias
inc.double_conv.1.running_mean
inc.double_conv.1.running_var
inc.double_conv.1.num_batches_tracked
inc.double_conv.3.weight
inc.double_conv.4.weight
inc.double_conv.4.bias
inc.double_conv.4.running_mean
inc.double_conv.4.running_var
inc.double_conv.4.num_batches_tracked
down1.maxpool_conv.1.double_conv.0.weight
down1.maxpool_conv.1.double_conv.1.weight
down1.maxpool_conv.1.double_conv.1.bias
down1.maxpool_conv.1.double_conv.1.running_mean
down1.maxpool_conv.1.double_conv.1.running_var
down1.maxpool_conv.1.double_conv.1.num_batches_tracked
down1.maxpool_conv.1.double_conv.3.weight
down1.maxpool_conv.1.double_conv.4.weight
down1.maxpool_conv.1.double_conv.4.bias
down1.maxpool_conv.1.double_conv.4.running_mean
down1.maxpool_conv.1.double_conv.4.running_var
down1.maxpool_conv.1.double_conv.4.num_batches_tracked
down2.maxpool_conv.1.double_conv.0.weight
down2.maxpool_conv.1.double_conv.1.

In [5]:
# Initialize a new (untrained) UNet with the configuration the checkpoint is expected to match
model = UNet(n_channels=3, n_classes=2, bilinear=False)

# Get the state dictionary of the new model
new_state_dict = model.state_dict()

# Print all keys expected by the UNet model
print("Keys expected by the UNet model:")
for key in new_state_dict.keys():
    print(key)

# Comparing two long key listings by eye is error-prone, so report the set
# differences explicitly. `state_dict` is the raw checkpoint loaded in the previous cell.
missing_keys = sorted(set(new_state_dict) - set(state_dict))
unexpected_keys = sorted(set(state_dict) - set(new_state_dict))
print(f"Keys missing from the checkpoint ({len(missing_keys)}): {missing_keys}")
print(f"Unexpected keys in the checkpoint ({len(unexpected_keys)}): {unexpected_keys}")

Keys expected by the UNet model:
inc.double_conv.0.weight
inc.double_conv.1.weight
inc.double_conv.1.bias
inc.double_conv.1.running_mean
inc.double_conv.1.running_var
inc.double_conv.1.num_batches_tracked
inc.double_conv.3.weight
inc.double_conv.4.weight
inc.double_conv.4.bias
inc.double_conv.4.running_mean
inc.double_conv.4.running_var
inc.double_conv.4.num_batches_tracked
down1.maxpool_conv.1.double_conv.0.weight
down1.maxpool_conv.1.double_conv.1.weight
down1.maxpool_conv.1.double_conv.1.bias
down1.maxpool_conv.1.double_conv.1.running_mean
down1.maxpool_conv.1.double_conv.1.running_var
down1.maxpool_conv.1.double_conv.1.num_batches_tracked
down1.maxpool_conv.1.double_conv.3.weight
down1.maxpool_conv.1.double_conv.4.weight
down1.maxpool_conv.1.double_conv.4.bias
down1.maxpool_conv.1.double_conv.4.running_mean
down1.maxpool_conv.1.double_conv.4.running_var
down1.maxpool_conv.1.double_conv.4.num_batches_tracked
down2.maxpool_conv.1.double_conv.0.weight
down2.maxpool_conv.1.double_conv.

In [6]:
def load_model(model_path, n_channels=3, n_classes=2, bilinear=False, strict=True):
    """Build a UNet and load trained weights from a Pytorch-UNet checkpoint.

    Args:
        model_path: Path (str or ``pathlib.Path``) to the ``.pth`` checkpoint.
        n_channels: Passed to ``UNet``; must match the training configuration.
        n_classes: Passed to ``UNet``; must match the training configuration.
        bilinear: Passed to ``UNet``; must match the training configuration.
        strict: Passed to ``load_state_dict``. Defaults to True so that a
            checkpoint/architecture key mismatch raises instead of silently
            leaving some layers with their random initial weights.

    Returns:
        The UNet, with weights loaded onto the CPU, switched to evaluation mode.

    Raises:
        RuntimeError: If tensor shapes differ, or (when ``strict`` is True) if
            keys are missing from or unexpected in the checkpoint.
    """
    # Initialize the model with the requested architecture
    model = UNet(n_channels=n_channels, n_classes=n_classes, bilinear=bilinear)

    # map_location="cpu" allows loading a checkpoint saved from CUDA tensors on a CPU-only host.
    state_dict = torch.load(model_path, map_location="cpu")

    # The checkpoint carries an extra, non-parameter 'mask_values' entry. Drop it
    # (no-op if absent) so the remaining keys can be matched strictly against the model.
    state_dict.pop('mask_values', None)

    # With strict=True any remaining mismatch is reported instead of silently ignored.
    model.load_state_dict(state_dict, strict=strict)

    # Switch layers such as BatchNorm to inference behaviour before export.
    model.eval()

    return model

In [7]:
# Fail fast if the checkpoint is missing. Merely printing here would leave `loaded_model`
# undefined, and the conversion cell would then fail later with a confusing NameError.
if not model_path.exists():
    raise FileNotFoundError(f"The specified model path does not exist: {model_path}")

# Load the model
loaded_model = load_model(model_path)
print("Model loaded successfully!")

Model loaded successfully!


In [8]:
# Convert the in-memory PyTorch model to an OpenVINO model with a static input shape.
# NOTE(review): [1, 3, 256, 256] presumably means (batch, channels, height, width);
# channels=3 matches the n_channels used to build the UNet.
static_input_shape = [1, 3, 256, 256]
ov_model = ov.convert_model(loaded_model, input=static_input_shape)


In [9]:
# Save the OpenVINO IR files under <checkpoint dir>/<checkpoint stem>/shape-1-3-256-256.{xml,bin}
ir_dir = model_path.parent / model_path.stem
ir_model_path = ir_dir / "shape-1-3-256-256"
xml_path, bin_path = (ir_model_path.with_suffix(ext) for ext in ('.xml', '.bin'))

In [10]:
# Show where the IR XML file will be written
print(xml_path)

/home/qwang12/src/Pytorch-UNet-scottwangintel/Pytorch-UNet/checkpoints-0703/checkpoint_epoch5/shape-1-3-256-256.xml


In [11]:
# Serialize the converted model to IR. Per the OpenVINO docs, the .bin weights file is
# written alongside the .xml with the same stem (i.e. at bin_path).
# NOTE(review): ov.save_model defaults to compress_to_fp16=True (OpenVINO docs), so weights
# are stored as FP16; pass compress_to_fp16=False if FP32 weights are required.
ov.save_model(ov_model, output_model=xml_path)

In [12]:
# Preview the beginning of the generated IR XML, e.g. to confirm the input Parameter shape.
with open(xml_path, 'r') as xml_file:
    for _line_no in range(10):
        line = xml_file.readline()
        print(line.strip())

<?xml version="1.0"?>
<net name="Model0" version="11">
<layers>
<layer id="0" name="x" type="Parameter" version="opset1">
<data shape="1,3,256,256" element_type="f32" />
<output>
<port id="0" precision="FP32" names="x">
<dim>1</dim>
<dim>3</dim>
<dim>256</dim>
