In [None]:
import onnx
import torch
import tensorrt as trt
import numpy as np
import os

In [None]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

# Input file location, resolved against the current working directory.
input_filename = "inputs_test_scaled_1000000.txt"
data_dir = os.getcwd() + "/speed_test/gpu/"

# Parse the text file with NumPy, then convert to a float32 tensor on `device`.
data = np.loadtxt(os.path.join(data_dir, input_filename))
inputs_test_tensor = torch.tensor(data, dtype=torch.float32).to(device)

In [None]:
# Preview the first five rows of the loaded input tensor.
inputs_test_tensor[:5]

In [None]:
# Load the serialized TorchScript model and place it on `device`.
model_dir = "../models/"
model_name = "NNC2PL"

model_path = os.path.join(model_dir, f"{model_name}.pth")
model = torch.jit.load(model_path, map_location=device).to(device)

# Switch to evaluation mode; as the last expression this also displays the module.
model.eval()

In [None]:
# Export the loaded model to ONNX, using the full test tensor as the example input.
onnx_model_path = os.path.join(model_dir, f"{model_name}_1M.onnx")
torch.onnx.export(
    model,
    inputs_test_tensor,
    onnx_model_path,
    verbose=True,
)


# TODO: the dynamic-batch export below currently causes issues; revisit later.
# torch.onnx.export(
#     model,
#     inputs_test_tensor,
#     onnx_model_path,
#     verbose=True,
#     input_names=["input"],
#     output_names=["output"],
#     dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}
# )

In [None]:
# BUILD A TRT ENGINE FROM THE EXPORTED ONNX MODEL
#
# Steps: parse the ONNX file into a TensorRT network, configure the builder
# (workspace limit, FP16, optional dynamic-batch optimization profile), build
# the serialized engine and write it to `engine_path`.

engine_path = os.path.join(model_dir, model_name + "_1M" + ".engine")
# # engine_path = os.path.join(model_dir, model_name + "_1M" + "_FP16" + ".engine") # Alternative name for FP16 builds
# NOTE(review): FP16 is enabled below while the active filename has no "_FP16"
# suffix -- confirm which of the two is intended.

logger = trt.Logger(trt.Logger.VERBOSE)
builder = trt.Builder(logger)

network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)

# Use a dedicated handle name: binding the file to `model` would overwrite the
# TorchScript module that the export cell still needs when it is re-run.
with open(onnx_model_path, 'rb') as onnx_file:
    parsed_ok = parser.parse(onnx_file.read())

if not parsed_ok:
    print("Failed to parse the ONNX model")
    for error in range(parser.num_errors):
        print(parser.get_error(error))
    raise RuntimeError("Failed to parse ONNX model.")

config = builder.create_builder_config()

config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 2 << 30)  # 2 GiB workspace
config.set_flag(trt.BuilderFlag.FP16)  # FP16 enabled; comment this line out for an FP32 build

# Batch-size window for the optimization profile: 1/20th of the test set, +/-5%.
min_batch_size = int(inputs_test_tensor.shape[0] * 0.95 / 20)
optim_batch_size = int(inputs_test_tensor.shape[0] / 20)
max_batch_size = int(inputs_test_tensor.shape[0] * 1.05 / 20)
dim = inputs_test_tensor.shape[1]

# Look the input up on the parsed network instead of hard-coding its name: the
# profile must be keyed on the network's input tensor name, which is whatever
# the ONNX export produced, not the name of a Python variable.
network_input = network.get_input(0)

# An optimization profile only applies to inputs with dynamic (-1) dimensions.
if -1 in tuple(network_input.shape):
    profile = builder.create_optimization_profile()
    profile.set_shape(
        network_input.name,
        trt.Dims((min_batch_size, dim)),
        trt.Dims((optim_batch_size, dim)),
        trt.Dims((max_batch_size, dim)),
    )
    config.add_optimization_profile(profile)
else:
    print(
        "Input '{}' has static shape {}; skipping the dynamic-batch optimization profile.".format(
            network_input.name, network_input.shape
        )
    )

serialized_engine = builder.build_serialized_network(network, config)
# Fail before touching the output file so a failed build cannot leave an
# empty engine file behind.
if serialized_engine is None:
    raise RuntimeError("Failed to build the TensorRT engine; see the TensorRT log output above.")

with open(engine_path, "wb") as f:
    f.write(serialized_engine)

In [None]:
# Module-level TensorRT runtime, shared by every call to `load_engine`.
TRT_LOGGER = trt.Logger(trt.Logger.INFO)
runtime = trt.Runtime(TRT_LOGGER)


def load_engine(engine_path):
    """Read a serialized engine file and deserialize it with the shared runtime.

    Args:
        engine_path: Path of the serialized engine file to read.

    Returns:
        Whatever `runtime.deserialize_cuda_engine` returns for the file's bytes.
    """
    with open(engine_path, "rb") as engine_file:
        serialized = engine_file.read()
    return runtime.deserialize_cuda_engine(serialized)

In [None]:
def print_engine_details(engine):
    """Print a one-line summary for every binding of `engine`.

    For each binding index the name, shape, data type and direction
    (input or output) are queried from the engine and printed.

    Args:
        engine: Object exposing `num_bindings`, `get_binding_name`,
            `get_binding_shape`, `get_binding_dtype` and `binding_is_input`.
    """
    binding_count = engine.num_bindings
    print(f"Engine has {binding_count} bindings:")
    for idx in range(binding_count):
        name = engine.get_binding_name(idx)
        shape = engine.get_binding_shape(idx)
        dtype = engine.get_binding_dtype(idx)
        direction = "Input" if engine.binding_is_input(idx) else "Output"
        print(f"Binding {idx}: Name = {name}, Shape = {shape}, DataType = {dtype}, {direction}")

In [None]:
# Deserialize the engine file at `engine_path`, then list its bindings.
# `eng` must be assigned here: nothing earlier in the notebook defines it, so
# the call would raise NameError on a fresh kernel.
eng = load_engine(engine_path)
print_engine_details(eng)