chore: cherry pick of #2832 (#2852)
zewenli98 committed May 29, 2024
1 parent 5ef1dec commit 856f33d
Showing 2 changed files with 10 additions and 8 deletions.
10 changes: 6 additions & 4 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -317,8 +317,10 @@ def run(
         )
         timing_cache = self._create_timing_cache(builder_config, existing_cache)
 
-        engine = self.builder.build_serialized_network(self.ctx.net, builder_config)
-        assert engine
+        serialized_engine = self.builder.build_serialized_network(
+            self.ctx.net, builder_config
+        )
+        assert serialized_engine
 
         serialized_cache = (
             bytearray(timing_cache.serialize())
@@ -328,10 +330,10 @@
         _LOGGER.info(
             f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
         )
-        _LOGGER.info(f"TRT Engine uses: {engine.nbytes} bytes of Memory")
+        _LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory")
 
         return TRTInterpreterResult(
-            engine, self._input_names, self._output_names, serialized_cache
+            serialized_engine, self._input_names, self._output_names, serialized_cache
         )
 
     def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
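For context, `build_serialized_network` returns a serialized engine plan (an `IHostMemory` buffer), not a live `ICudaEngine`, which is what the rename from `engine` to `serialized_engine` reflects. Below is a minimal standalone sketch of that builder flow; it is not the repository's code, and the toy network (single identity layer, hard-coded input shape) is illustrative only.

import tensorrt as trt

# Minimal sketch (not the repository's code): build_serialized_network() returns
# a serialized plan (IHostMemory buffer), not a live ICudaEngine, hence the
# "serialized_engine" name in the diff above.
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
x = network.add_input("x", trt.float32, (1, 3, 224, 224))
identity = network.add_identity(x)
network.mark_output(identity.get_output(0))
config = builder.create_builder_config()

serialized_engine = builder.build_serialized_network(network, config)
assert serialized_engine  # None is returned when the build fails
print(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory")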
@@ -30,7 +30,7 @@ class PythonTorchTensorRTModule(Module): # type: ignore[misc]
 
     def __init__(
         self,
-        engine: trt.ICudaEngine,
+        engine: bytes,
         input_names: Optional[List[str]] = None,
         output_names: Optional[List[str]] = None,
         target_device: Device = Device._current_device(),
@@ -61,9 +61,9 @@ def _initialize(self) -> None:
         self.engine = runtime.deserialize_cuda_engine(self.engine)
         self.context = self.engine.create_execution_context()
 
-        assert (
-            self.engine.num_io_tensors // self.engine.num_optimization_profiles
-        ) == (len(self.input_names) + len(self.output_names))
+        assert self.engine.num_io_tensors == (
+            len(self.input_names) + len(self.output_names)
+        )
 
         self.input_dtypes = [
             dtype._from(self.engine.get_tensor_dtype(input_name))
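On the runtime side, `PythonTorchTensorRTModule` now accepts the serialized plan as `bytes` and deserializes it in `_initialize`, and the I/O sanity check compares `num_io_tensors` directly against the named inputs and outputs. A hedged standalone sketch of that deserialize-and-check step follows; the helper name `load_engine` and its arguments are hypothetical, not the module's API.

import tensorrt as trt

# Hedged sketch (load_engine is a hypothetical helper, not the module's API):
# deserialize the plan bytes lazily, then verify the I/O tensor count as the
# updated _initialize() does.
def load_engine(serialized_engine: bytes, input_names: list, output_names: list):
    logger = trt.Logger(trt.Logger.WARNING)
    runtime = trt.Runtime(logger)
    engine = runtime.deserialize_cuda_engine(serialized_engine)
    context = engine.create_execution_context()

    # Updated check from the diff: no division by num_optimization_profiles.
    assert engine.num_io_tensors == (len(input_names) + len(output_names))

    input_dtypes = [engine.get_tensor_dtype(name) for name in input_names]
    return engine, context, input_dtypes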
