In [1]:
import torch
from torch import nn
from torchinfo import summary
import numpy as np

import tensorrt as trt

In [2]:
import pycuda.driver as cuda

In [3]:
def engine_build_from_onnx(onnx_mdl):
    EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    TRT_LOGGER = trt.Logger(trt.Logger.ERROR)
    builder = trt.Builder(TRT_LOGGER)
    config = builder.create_builder_config()
    # config.set_flag(trt.BuilderFlag.FP16)
    config.set_flag(trt.BuilderFlag.TF32)
    config.max_workspace_size = 1 * (1 << 50)  # the maximum size that any layer in the network can use

    network = builder.create_network(EXPLICIT_BATCH)
    parser = trt.OnnxParser(network, TRT_LOGGER)
    # Load the Onnx model and parse it in order to populate the TensorRT network.
    success = parser.parse_from_file(onnx_mdl)

    for idx in range(parser.num_errors):
        print(parser.get_error(idx))

    if not success:
        return None

    return builder.build_engine(network, config)


def mem_allocation(engine):
    # Determine dimensions and create page-locked memory buffers (i.e. won't be swapped to disk) to hold host inputs/outputs.

    in_sz = trt.volume(engine.get_binding_shape(0)) * engine.max_batch_size
    h_input = cuda.pagelocked_empty(in_sz, dtype="float32")

    out_sz = trt.volume(engine.get_binding_shape(1)) * engine.max_batch_size
    h_output = cuda.pagelocked_empty(out_sz, dtype="float32")

    # Allocate device memory for inputs and outputs.
    d_input = cuda.mem_alloc(h_input.nbytes)
    d_output = cuda.mem_alloc(h_output.nbytes)

    # Create a stream in which to copy inputs/outputs and run inference.
    stream = cuda.Stream()

    return h_input, h_output, d_input, d_output, stream

In [4]:
class ReconSmallPhaseModel(nn.Module):
    def __init__(self, nconv: int = 32):
        super(ReconSmallPhaseModel, self).__init__()
        self.nconv = nconv

        self.encoder = nn.Sequential(  # Appears sequential has similar functionality as TF avoiding need for separate model definition and activ
            *self.down_block(1, self.nconv),
            *self.down_block(self.nconv, self.nconv * 2),
            *self.down_block(self.nconv * 2, self.nconv * 4),
            *self.down_block(self.nconv * 4, self.nconv * 8),
            *self.down_block(self.nconv * 8, self.nconv * 16),
            *self.down_block(self.nconv * 16, self.nconv * 32),
        )

        # amplitude model
        # self.decoder1 = nn.Sequential(
        #    *self.up_block(self.nconv * 32, self.nconv * 16),
        #    *self.up_block(self.nconv * 16, self.nconv * 8),
        #    *self.up_block(self.nconv * 8, self.nconv * 4),
        #   *self.up_block(self.nconv * 4, self.nconv * 2),
        #    *self.up_block(self.nconv * 2, self.nconv * 1),
        #    *self.up_block(self.nconv * 1, 16),
        #    nn.Conv2d(16 , 1, 3, stride=1, padding=(1,1)),
        #    nn.Tanh()
        # )

        # phase model
        self.decoder2 = nn.Sequential(
            *self.up_block(self.nconv * 32, self.nconv * 16),  # 16
            *self.up_block(self.nconv * 16, self.nconv * 8),  # 32
            *self.up_block(self.nconv * 8, self.nconv * 4),  # 64
            *self.up_block(self.nconv * 4, self.nconv * 2),  # 128
            # *self.up_block(self.nconv * 2, self.nconv * 1),
            # *self.up_block(self.nconv * 1, 16),
            nn.Conv2d(self.nconv * 2, 1, 3, stride=1, padding=(1, 1)),
            nn.Tanh(),
        )

    def down_block(self, filters_in, filters_out):
        block = [
            nn.Conv2d(
                in_channels=filters_in,
                out_channels=filters_out,
                kernel_size=3,
                stride=1,
                padding=(1, 1),
            ),
            nn.ReLU(),
            nn.Conv2d(filters_out, filters_out, 3, stride=1, padding=(1, 1)),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),
        ]
        return block

    def up_block(self, filters_in, filters_out):
        block = [
            nn.Conv2d(filters_in, filters_out, 3, stride=1, padding=(1, 1)),
            nn.ReLU(),
            nn.Conv2d(filters_out, filters_out, 3, stride=1, padding=(1, 1)),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode="bilinear"),
        ]
        return block

    def forward(self, x):
        with torch.cuda.amp.autocast():
            x1 = self.encoder(x)
            # amp = self.decoder1(x1)
            ph = self.decoder2(x1)

            # Restore -pi to pi range
            ph = (
                ph * np.pi
            )  # Using tanh activation (-1 to 1) for phase so multiply by pi

        return ph


In [5]:
bsz = 8

In [6]:
base_path = "/home/beams/SKANDEL/beamtime_data/sector26_02_28_23/ptychonn_02_28_23/iteration_03_02_04_00/"

In [7]:
#model_path = "/home/beams/SKANDEL/code/anakha_ptychoNN-test/models_11_22/best_model_reduced_model.pth"
model_path = f"{base_path}/best_model.pth"
model = ReconSmallPhaseModel()
   
model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
#summary(model, (1, 1, 512, 512))

dummy_input = torch.randn(bsz, 1, 512, 512)  # batchsize , 1, h, w
torch.onnx.export(
    model,
    dummy_input,
    f"{base_path}/best_model_bsz_{bsz}.onnx",
    opset_version=13,
)

  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(


In [7]:
data = np.load("/home/beams/SKANDEL/beamtime_data/sector26_02_28_23/Training5/scan168.npz")

In [8]:
engine = engine_build_from_onnx("/home/beams/SKANDEL/beamtime_data/sector26_02_28_23/ptychonn_02_28_23/iteration_03_01_13_19/best_model_bsz_8.onnx")

  config.max_workspace_size = 1 * (1 << 50)  # the maximum size that any layer in the network can use
  return builder.build_engine(network, config)
