 import numpy as np
 import torch
+import torch_tensorrt
 from torch_tensorrt._Device import Device
 from torch_tensorrt.dynamo import _defaults
 from torch_tensorrt.dynamo._compiler import compile as dynamo_compile
@@ -61,6 +62,7 @@ def __init__(
         *,
         device: Optional[Union[Device, torch.device, str]] = _defaults.DEVICE,
         use_python_runtime: bool = _defaults.USE_PYTHON_RUNTIME,
+        enable_cuda_graph: bool = True,
         immutable_weights: bool = False,
         strict: bool = True,
         allow_complex_guards_as_runtime_asserts: bool = False,
@@ -127,6 +129,7 @@ def __init__(
         self.arg_inputs: tuple[Any, ...] = tuple()
         self.kwarg_inputs: dict[str, Any] = {}
         self.additional_settings = kwargs
+        self.enable_cuda_graph = enable_cuda_graph
         self.strict = strict
         self.allow_complex_guards_as_runtime_asserts = (
             allow_complex_guards_as_runtime_asserts
@@ -142,7 +145,11 @@ def __init__(
         self.run_info: Optional[tuple[Any, ...]] = None
         self.state_dict_metadata: dict[str, torch.Size] = {}
         self._store_state_dict_metadata()
-
+        self.enable_weight_streaming = (
+            kwargs["enable_weight_streaming"]
+            if "enable_weight_streaming" in kwargs
+            else False
+        )
         cls = self.__class__
         self.__class__ = type(
             self.original_model.__class__.__name__,
@@ -193,7 +200,7 @@ def forward(a, b, c=0, d=0):
 
         self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
 
-    def _get_total_dynamic_shapes(self) -> Union[dict[str, Any], None]:
+    def _get_total_dynamic_shapes(self) -> dict[str, Any] | None:
         if not self.arg_dynamic_shapes and not self.kwarg_dynamic_shapes:
             return None
         total_dynamic_shape = {}
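A side note on the annotation change in this hunk: `dict[str, Any] | None` uses PEP 604 union syntax, which is evaluated at function-definition time and therefore requires Python 3.10+ unless the module enables `from __future__ import annotations`; `Optional[dict[str, Any]]` is the equivalent spelling that also works on older interpreters.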
@@ -266,15 +273,17 @@ def refit_gm(self) -> None:
         MutableTorchTensorRTModule automatically catches weight value updates and calls this function to refit the module.
         If it fails to catch the changes, please call this function manually to update the TRT graph module.
         """
-        self.original_model.to(to_torch_device(self.trt_device))
+
         if self.exp_program is None:
+            self.original_model.to(to_torch_device(self.trt_device))
             self.exp_program = self.get_exported_program()
         else:
             self.exp_program._state_dict = (
                 MutableTorchTensorRTModule._transform_state_dict(
                     self.original_model.state_dict()
                 )
             )
+            self.exp_program.module().to(to_torch_device(self.trt_device))
         self.gm = refit_module_weights(
             self.gm,
             self.exp_program,
@@ -284,7 +293,7 @@ def refit_gm(self) -> None:
             in_place=True,
         )
 
-        self.original_model.cpu()
+        self.original_model.to("cpu")
         torch.cuda.empty_cache()
 
     def get_exported_program(self) -> torch.export.ExportedProgram:
@@ -324,8 +333,15 @@ def compile(self) -> None:
             use_python_runtime=self.use_python_runtime,
             **self.additional_settings,
         )
-        self.original_model.cpu()
+        self.original_model.to("cpu")
         torch.cuda.empty_cache()
+        # torch_tensorrt.runtime.set_cudagraphs_mode(self.enable_cuda_graph)
+        # if self.enable_cuda_graph:
+        #     self.gm = torch_tensorrt.runtime.enable_cudagraphs(self.gm)
+        if self.enable_weight_streaming:
+            self.weight_streaming_ctx = torch_tensorrt.runtime.weight_streaming(self.gm)
+            requested_budget = int(16 * 2 << 20)
+            self.weight_streaming_ctx.device_budget = requested_budget
 
     def _validate_inputs(self, *args: Any, **kwargs: Any) -> None:
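For context, the hard-coded budget above evaluates to (16 * 2) << 20 = 32 << 20 bytes, i.e. 32 MiB, since `*` binds tighter than `<<`. A hedged sketch of sizing the budget from the compiled module instead of hard-coding it, assuming the context returned by torch_tensorrt.runtime.weight_streaming exposes total_device_budget and get_automatic_weight_streaming_budget() as in the library's weight-streaming example:

    # Sketch only: derive the streaming budget from the compiled module rather than a fixed 32 MiB.
    ws_ctx = torch_tensorrt.runtime.weight_streaming(trt_gm)  # trt_gm: a compiled TRT graph module
    total = ws_ctx.total_device_budget                        # total bytes of streamable weights (assumed attribute)
    auto = ws_ctx.get_automatic_weight_streaming_budget()     # heuristic budget (assumed helper)
    ws_ctx.device_budget = max(auto, min(total, 32 << 20))    # keep at least the heuristic amount resident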
@@ -446,14 +462,21 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
             self._store_state_dict_metadata()
             self.refit_state.set_state(RefitFlag.LIVE)
 
+        # weight_streaming_ctx = self.weight_streaming_ctx if self.enable_weight_streaming else None
         result = self.gm(*args, **kwargs)
         # Storing inputs and outputs for verification when the state is unknown
         self.run_info = (args, kwargs, result)
         return result
 
-    def to(self, device: str) -> None:
-        logger.warning("Original PyTorch model is moved. CPU offload may failed.")
-        self.original_model.to(device)
+    def to(self, *args: Any, **kwargs: Any) -> None:
+        logger.warning(
+            "Trying to move the original PyTorch model. This will cause CPU offloading to fail and increase GPU memory usage. "
+            + "If this is absolutely necessary, please call module.pytorch_model.to(...)"
+        )
+
+    @property
+    def device(self) -> torch.device:
+        return to_torch_device(self.trt_device)
 
     def __deepcopy__(self, memo: Any) -> Any:
         cls = self.__class__
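For reference, a hedged usage sketch of the options this change introduces: enable_cuda_graph becomes an explicit keyword argument and enable_weight_streaming is picked up from **kwargs. The model, input shape, and the use_explicit_typing setting below are illustrative assumptions, not part of this diff:

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()  # MyModel: placeholder torch.nn.Module
    mutable = torch_tensorrt.MutableTorchTensorRTModule(
        model,
        enable_cuda_graph=True,        # new explicit keyword from this change
        enable_weight_streaming=True,  # forwarded through **kwargs by the new code path
        use_explicit_typing=True,      # weight streaming generally requires a strongly typed engine
    )
    out = mutable(torch.randn(2, 3, 224, 224, device="cuda"))  # first call triggers compilation
    mutable.refit_gm()  # manual refit after in-place weight updates, per the docstring above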