@@ -55,6 +55,7 @@ def test_mapping():
55
55
debug = debug ,
56
56
min_block_size = min_block_size ,
57
57
immutable_weights = False ,
58
+ offload_module_to_cpu = False ,
58
59
)
59
60
settings = trt_gm ._run_on_acc_0 .settings
60
61
runtime = trt .Runtime (TRT_LOGGER )
@@ -106,6 +107,7 @@ def test_refit_one_engine_with_weightmap():
106
107
debug = debug ,
107
108
min_block_size = min_block_size ,
108
109
immutable_weights = False ,
110
+ offload_module_to_cpu = False ,
109
111
)
110
112
111
113
new_trt_gm = refit_module_weights (
@@ -155,6 +157,7 @@ def test_refit_one_engine_no_map_with_weightmap():
155
157
debug = debug ,
156
158
min_block_size = min_block_size ,
157
159
immutable_weights = False ,
160
+ offload_module_to_cpu = False ,
158
161
)
159
162
160
163
trt_gm ._run_on_acc_0 .weight_name_map = None
@@ -205,6 +208,7 @@ def test_refit_one_engine_with_wrong_weightmap():
205
208
debug = debug ,
206
209
min_block_size = min_block_size ,
207
210
immutable_weights = False ,
211
+ offload_module_to_cpu = False ,
208
212
)
209
213
# Manually delete all batch norm layers. This is supposed to make the fast refit fail
210
214
trt_gm ._run_on_acc_0 .weight_name_map = {
@@ -262,6 +266,7 @@ def test_refit_one_engine_bert_with_weightmap():
262
266
debug = debug ,
263
267
min_block_size = min_block_size ,
264
268
immutable_weights = False ,
269
+ offload_module_to_cpu = False ,
265
270
)
266
271
267
272
new_trt_gm = refit_module_weights (
@@ -294,7 +299,7 @@ def test_refit_one_engine_bert_with_weightmap():
294
299
"TorchScript Frontend is not available" ,
295
300
)
296
301
@pytest .mark .unit
297
- def test_refit_one_engine_inline_runtime__with_weightmap ():
302
+ def test_refit_one_engine_inline_runtime_with_weightmap ():
298
303
trt_ep_path = os .path .join (tempfile .gettempdir (), "compiled.ep" )
299
304
model = models .resnet18 (pretrained = False ).eval ().to ("cuda" )
300
305
model2 = models .resnet18 (pretrained = True ).eval ().to ("cuda" )
@@ -315,6 +320,7 @@ def test_refit_one_engine_inline_runtime__with_weightmap():
315
320
debug = debug ,
316
321
min_block_size = min_block_size ,
317
322
immutable_weights = False ,
323
+ offload_module_to_cpu = False ,
318
324
)
319
325
torchtrt .save (trt_gm , trt_ep_path )
320
326
trt_gm = torch .export .load (trt_ep_path )
@@ -360,6 +366,7 @@ def test_refit_one_engine_python_runtime_with_weightmap():
360
366
debug = debug ,
361
367
min_block_size = min_block_size ,
362
368
immutable_weights = False ,
369
+ offload_module_to_cpu = False ,
363
370
)
364
371
365
372
new_trt_gm = refit_module_weights (
@@ -431,6 +438,7 @@ def forward(self, x):
431
438
immutable_weights = False ,
432
439
torch_executed_ops = torch_executed_ops ,
433
440
reuse_cached_engines = False ,
441
+ offload_module_to_cpu = False ,
434
442
)
435
443
436
444
new_trt_gm = refit_module_weights (
@@ -479,6 +487,7 @@ def test_refit_one_engine_without_weightmap():
479
487
debug = debug ,
480
488
min_block_size = min_block_size ,
481
489
immutable_weights = False ,
490
+ offload_module_to_cpu = False ,
482
491
)
483
492
484
493
new_trt_gm = refit_module_weights (
@@ -530,6 +539,7 @@ def test_refit_one_engine_bert_without_weightmap():
530
539
debug = debug ,
531
540
min_block_size = min_block_size ,
532
541
immutable_weights = False ,
542
+ offload_module_to_cpu = False ,
533
543
)
534
544
535
545
new_trt_gm = refit_module_weights (
@@ -583,6 +593,7 @@ def test_refit_one_engine_inline_runtime_without_weightmap():
583
593
debug = debug ,
584
594
min_block_size = min_block_size ,
585
595
immutable_weights = False ,
596
+ offload_module_to_cpu = False ,
586
597
)
587
598
torchtrt .save (trt_gm , trt_ep_path )
588
599
trt_gm = torch .export .load (trt_ep_path )
@@ -628,6 +639,7 @@ def test_refit_one_engine_python_runtime_without_weightmap():
628
639
debug = debug ,
629
640
min_block_size = min_block_size ,
630
641
immutable_weights = False ,
642
+ offload_module_to_cpu = False ,
631
643
)
632
644
633
645
new_trt_gm = refit_module_weights (
@@ -699,6 +711,7 @@ def forward(self, x):
699
711
immutable_weights = False ,
700
712
torch_executed_ops = torch_executed_ops ,
701
713
reuse_cached_engines = False ,
714
+ offload_module_to_cpu = False ,
702
715
)
703
716
704
717
new_trt_gm = refit_module_weights (
0 commit comments