import numpy as np
import onnxruntime as rt
import tensorrt as trt
import pycuda.driver as cuda


class HostDeviceMem:
	"""Pair a host-side buffer with the device-side buffer it mirrors.

	Attributes:
		host: the host buffer (``host_mem`` as passed in).
		device: the device buffer (``device_mem`` as passed in).
	"""

	def __init__(self, host_mem, device_mem):
		self.host, self.device = host_mem, device_mem


def get_engine(model_fn):
	"""Build a TensorRT engine from an ONNX model file.

	The ONNX model at ``model_fn`` is parsed into a TensorRT network. The
	network is built with a 1 GiB workspace memory-pool limit, and the
	serialized result is deserialized into an engine.

	Args:
		model_fn: Path to the ONNX model file.

	Returns:
		The deserialized TensorRT engine.

	Raises:
		RuntimeError: If the ONNX model fails to parse (the parser's error
			messages are included), if the engine build fails, or if
			deserialization fails.
	"""
	TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
	# NOTE(review): flags=0 presumably relies on explicit batch being the
	# default (TensorRT 10+); older releases required the EXPLICIT_BATCH flag
	# for ONNX parsing -- confirm against the TensorRT version in use.
	with trt.Builder(TRT_LOGGER) as builder, builder.create_network(0) as network:
		with builder.create_builder_config() as config, trt.OnnxParser(network, TRT_LOGGER) as parser:
			with trt.Runtime(TRT_LOGGER) as runtime:
				config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)

				# Parse ONNX model; a failed parse must stop here, otherwise the
				# build below fails later with a much less useful message.
				with open(model_fn, 'rb') as model:
					parsed = parser.parse(model.read())
				if not parsed:
					errors = '\n'.join(str(parser.get_error(i)) for i in range(parser.num_errors))
					raise RuntimeError(f'Failed to parse ONNX model {model_fn!r}:\n{errors}')

				# build_serialized_network returns None on failure.
				serialized_engine = builder.build_serialized_network(network, config)
				if serialized_engine is None:
					raise RuntimeError(f'Failed to build TensorRT engine from {model_fn!r}')

				engine = runtime.deserialize_cuda_engine(serialized_engine)
				if engine is None:
					raise RuntimeError(f'Failed to deserialize TensorRT engine built from {model_fn!r}')
				return engine


def allocate_buffers(engine):
	"""Allocate a host buffer and a device buffer for every engine I/O tensor.

	Tensors are visited in I/O-tensor index order, so ``bindings[i]`` is the
	integer device address for ``engine.get_tensor_name(i)``.

	Args:
		engine: A TensorRT engine.

	Returns:
		A tuple ``(inputs, outputs, bindings, stream)``:
		inputs / outputs -- lists of ``HostDeviceMem`` for input and
			non-input tensors, respectively;
		bindings -- device addresses as ints, one per I/O tensor;
		stream -- a new ``cuda.Stream``.
	"""
	stream = cuda.Stream()
	inputs, outputs, bindings = [], [], []

	for idx in range(engine.num_io_tensors):
		name = engine.get_tensor_name(idx)
		# NOTE(review): assumes static shapes; a dynamic dimension would not
		# yield a usable element count here -- confirm for the target model.
		n_elems = trt.volume(engine.get_tensor_shape(name))
		np_dtype = trt.nptype(engine.get_tensor_dtype(name))

		# Page-locked (pinned) host memory cannot be swapped to disk; the
		# device buffer gets the same byte size as its host mirror.
		host_buf = cuda.pagelocked_empty(n_elems, np_dtype)
		dev_buf = cuda.mem_alloc(host_buf.nbytes)

		# int(device allocation) is its device address.
		bindings.append(int(dev_buf))

		is_input = engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT
		(inputs if is_input else outputs).append(HostDeviceMem(host_buf, dev_buf))

	return inputs, outputs, bindings, stream


def do_inference(engine, context, bindings, inputs, outputs, stream):
	"""Run one inference pass on ``stream`` and return the host output buffers.

	Steps:
		1. Copy every input's host buffer to its device buffer (async on ``stream``).
		2. Bind each I/O tensor to its device address; ``bindings[i]`` is used
		   for ``engine.get_tensor_name(i)``.
		3. Enqueue execution.
		4. Copy every output back to its host buffer.
		5. Wait for the stream.

	Args:
		engine: The TensorRT engine whose I/O tensors are bound.
		context: Execution context created from ``engine``.
		bindings: Device addresses, one per I/O tensor, in tensor-index order.
		inputs: ``HostDeviceMem`` objects for the input tensors.
		outputs: ``HostDeviceMem`` objects for the output tensors.
		stream: The CUDA stream on which all work is queued.

	Returns:
		List of the host buffers of ``outputs``, in order. These are the
		buffers themselves, not copies, so later calls overwrite them.

	Raises:
		RuntimeError: If ``execute_async_v3`` reports that enqueueing failed.
	"""
	for cur_input in inputs:
		cuda.memcpy_htod_async(cur_input.device, cur_input.host, stream)

	for i in range(engine.num_io_tensors):
		context.set_tensor_address(engine.get_tensor_name(i), bindings[i])

	# execute_async_v3 returns False when enqueueing fails; without this check
	# the stale contents of the output buffers would be returned silently.
	if not context.execute_async_v3(stream_handle=stream.handle):
		stream.synchronize()  # drain the already-queued input copies first
		raise RuntimeError('TensorRT execute_async_v3 failed to enqueue inference')

	for cur_output in outputs:
		cuda.memcpy_dtoh_async(cur_output.host, cur_output.device, stream)
	stream.synchronize()

	return [out.host for out in outputs]


if __name__ == '__main__':
	# Build a TensorRT engine from model.onnx, run it on a random input, and
	# compare the result with ONNX Runtime on the same input.
	model_fn = 'model.onnx'
	input_np = np.random.rand(1, 3, 8, 8).astype(np.float32)

	cuda.init()
	dev = cuda.Device(0)
	ctx = dev.make_context()
	try:
		with get_engine(model_fn) as engine:
			inputs, outputs, bindings, stream = allocate_buffers(engine)
			np.copyto(inputs[0].host, input_np.flatten())
			with engine.create_execution_context() as context:
				output = do_inference(engine, context, bindings, inputs, outputs, stream)
				# Copy out of the page-locked buffer so trt_output is ordinary
				# host memory, independent of CUDA allocations owned by ctx.
				trt_output = np.array(output[0]).reshape((1, 8, 16, 16))
			# Release the CUDA buffers and stream while ctx is still current.
			del inputs, outputs, bindings, stream, output
	finally:
		# Pop even when engine build or inference fails, so the context is
		# not left on the CUDA context stack.
		ctx.pop()

	ort_model = rt.InferenceSession(model_fn, providers=rt.get_available_providers())
	ort_output = ort_model.run(['output'], {'input' : input_np})[0]

	print(np.allclose(ort_output, trt_output, rtol=1e-4, atol=1e-4))
	np.save('input.npy', input_np)
	np.save('output_ort.npy', ort_output)
	np.save('output_trt.npy', trt_output)
