Skip to content

Commit

Permalink
Cherry-pick of #2832
Browse files Browse the repository at this point in the history
  • Loading branch information
zewenli98 committed May 18, 2024
1 parent 722457b commit 7fafea4
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
12 changes: 7 additions & 5 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set

import numpy as np
import tensorrt as trt
import torch
import torch.fx
from torch.fx.node import _get_qualified_name
Expand All @@ -25,7 +26,6 @@
from torch_tensorrt.fx.observer import Observer
from torch_tensorrt.logging import TRT_LOGGER

import tensorrt as trt
from packaging import version

_LOGGER: logging.Logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -313,8 +313,10 @@ def run(
)
timing_cache = self._create_timing_cache(builder_config, existing_cache)

engine = self.builder.build_serialized_network(self.ctx.net, builder_config)
assert engine
serialized_engine = self.builder.build_serialized_network(
self.ctx.net, builder_config
)
assert serialized_engine

serialized_cache = (
bytearray(timing_cache.serialize())
Expand All @@ -324,10 +326,10 @@ def run(
_LOGGER.info(
f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
)
_LOGGER.info(f"TRT Engine uses: {engine.nbytes} bytes of Memory")
_LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory")

return TRTInterpreterResult(
engine, self._input_names, self._output_names, serialized_cache
serialized_engine, self._input_names, self._output_names, serialized_cache
)

def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class PythonTorchTensorRTModule(Module): # type: ignore[misc]

def __init__(
self,
engine: trt.ICudaEngine,
engine: bytes,
input_names: Optional[List[str]] = None,
output_names: Optional[List[str]] = None,
target_device: Device = Device._current_device(),
Expand Down Expand Up @@ -60,9 +60,9 @@ def _initialize(self) -> None:
self.engine = runtime.deserialize_cuda_engine(self.engine)
self.context = self.engine.create_execution_context()

assert (
self.engine.num_io_tensors // self.engine.num_optimization_profiles
) == (len(self.input_names) + len(self.output_names))
assert self.engine.num_io_tensors == (
len(self.input_names) + len(self.output_names)
)

self.input_dtypes = [
dtype._from(self.engine.get_tensor_dtype(input_name))
Expand Down

0 comments on commit 7fafea4

Please sign in to comment.