Skip to content

Commit 3bd44ca

Browse files
committedMar 23, 2025
Changed the Refitting test to disable CPU offload
1 parent 76e2564 commit 3bd44ca

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed
 

‎tests/py/dynamo/models/test_model_refit.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def test_mapping():
5555
debug=debug,
5656
min_block_size=min_block_size,
5757
immutable_weights=False,
58+
offload_module_to_cpu=False,
5859
)
5960
settings = trt_gm._run_on_acc_0.settings
6061
runtime = trt.Runtime(TRT_LOGGER)
@@ -106,6 +107,7 @@ def test_refit_one_engine_with_weightmap():
106107
debug=debug,
107108
min_block_size=min_block_size,
108109
immutable_weights=False,
110+
offload_module_to_cpu=False,
109111
)
110112

111113
new_trt_gm = refit_module_weights(
@@ -155,6 +157,7 @@ def test_refit_one_engine_no_map_with_weightmap():
155157
debug=debug,
156158
min_block_size=min_block_size,
157159
immutable_weights=False,
160+
offload_module_to_cpu=False,
158161
)
159162

160163
trt_gm._run_on_acc_0.weight_name_map = None
@@ -205,6 +208,7 @@ def test_refit_one_engine_with_wrong_weightmap():
205208
debug=debug,
206209
min_block_size=min_block_size,
207210
immutable_weights=False,
211+
offload_module_to_cpu=False,
208212
)
209213
# Manually Deleted all batch norm layer. This suppose to fail the fast refit
210214
trt_gm._run_on_acc_0.weight_name_map = {
@@ -262,6 +266,7 @@ def test_refit_one_engine_bert_with_weightmap():
262266
debug=debug,
263267
min_block_size=min_block_size,
264268
immutable_weights=False,
269+
offload_module_to_cpu=False,
265270
)
266271

267272
new_trt_gm = refit_module_weights(
@@ -294,7 +299,7 @@ def test_refit_one_engine_bert_with_weightmap():
294299
"TorchScript Frontend is not available",
295300
)
296301
@pytest.mark.unit
297-
def test_refit_one_engine_inline_runtime__with_weightmap():
302+
def test_refit_one_engine_inline_runtime_with_weightmap():
298303
trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
299304
model = models.resnet18(pretrained=False).eval().to("cuda")
300305
model2 = models.resnet18(pretrained=True).eval().to("cuda")
@@ -315,6 +320,7 @@ def test_refit_one_engine_inline_runtime__with_weightmap():
315320
debug=debug,
316321
min_block_size=min_block_size,
317322
immutable_weights=False,
323+
offload_module_to_cpu=False,
318324
)
319325
torchtrt.save(trt_gm, trt_ep_path)
320326
trt_gm = torch.export.load(trt_ep_path)
@@ -360,6 +366,7 @@ def test_refit_one_engine_python_runtime_with_weightmap():
360366
debug=debug,
361367
min_block_size=min_block_size,
362368
immutable_weights=False,
369+
offload_module_to_cpu=False,
363370
)
364371

365372
new_trt_gm = refit_module_weights(
@@ -431,6 +438,7 @@ def forward(self, x):
431438
immutable_weights=False,
432439
torch_executed_ops=torch_executed_ops,
433440
reuse_cached_engines=False,
441+
offload_module_to_cpu=False,
434442
)
435443

436444
new_trt_gm = refit_module_weights(
@@ -479,6 +487,7 @@ def test_refit_one_engine_without_weightmap():
479487
debug=debug,
480488
min_block_size=min_block_size,
481489
immutable_weights=False,
490+
offload_module_to_cpu=False,
482491
)
483492

484493
new_trt_gm = refit_module_weights(
@@ -530,6 +539,7 @@ def test_refit_one_engine_bert_without_weightmap():
530539
debug=debug,
531540
min_block_size=min_block_size,
532541
immutable_weights=False,
542+
offload_module_to_cpu=False,
533543
)
534544

535545
new_trt_gm = refit_module_weights(
@@ -583,6 +593,7 @@ def test_refit_one_engine_inline_runtime_without_weightmap():
583593
debug=debug,
584594
min_block_size=min_block_size,
585595
immutable_weights=False,
596+
offload_module_to_cpu=False,
586597
)
587598
torchtrt.save(trt_gm, trt_ep_path)
588599
trt_gm = torch.export.load(trt_ep_path)
@@ -628,6 +639,7 @@ def test_refit_one_engine_python_runtime_without_weightmap():
628639
debug=debug,
629640
min_block_size=min_block_size,
630641
immutable_weights=False,
642+
offload_module_to_cpu=False,
631643
)
632644

633645
new_trt_gm = refit_module_weights(
@@ -699,6 +711,7 @@ def forward(self, x):
699711
immutable_weights=False,
700712
torch_executed_ops=torch_executed_ops,
701713
reuse_cached_engines=False,
714+
offload_module_to_cpu=False,
702715
)
703716

704717
new_trt_gm = refit_module_weights(

0 commit comments

Comments
 (0)
Failed to load comments.