Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion py/torch_tensorrt/fx/lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,18 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
input_specs_val = (
self.lower_setting.input_specs
if self.lower_setting.input_specs
else InputTensorSpec.from_tensors(input)
else (
InputTensorSpec.from_tensors_with_dynamic_batch_size(
input,
(
0,
self.lower_setting.max_batch_size,
self.lower_setting.max_batch_size,
),
)
if self.lower_setting.explicit_batch_dimension
else InputTensorSpec.from_tensors(input)
)
)

# Prepare algorithm selector and timing_cache for TRTInterpreter
Expand Down
7 changes: 7 additions & 0 deletions py/torch_tensorrt/fx/lower_setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ class LowerSetting(LowerSettingBasic):
save_timing_cache: Save updated timing cache data into timing cache file if the timing
cache file is provided.
cuda_graph_batch_size (int): Cuda graph batch size, default to be -1.
preset_lowerer (str): when specified, use a preset logic to build the
instance of Lowerer. Refer to
`caffe2.torch.fb.model_transform.fx2trt.presets.LowererPresetsManager` on
how presets are applied. Refer to
`caffe2.torch.fb.model_transform.fx2trt.presets.ESUHMLowererPreset` on how
to add a preset.
"""

input_specs: List[InputTensorSpec] = dc.field(default_factory=list)
Expand All @@ -79,3 +85,4 @@ class LowerSetting(LowerSettingBasic):
timing_cache_prefix: str = ""
save_timing_cache: bool = False
cuda_graph_batch_size: int = -1
preset_lowerer: str = ""
4 changes: 3 additions & 1 deletion py/torch_tensorrt/fx/passes/lower_basic_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ def skip_folding_quant_dequant(node: torch.fx.Node):
return True
return False

const_split_mod = split_const_subgraphs(traced_mod, skip_folding_quant_dequant)
const_split_mod = split_const_subgraphs(
traced_mod, skip_folding_quant_dequant, device_for_folded_attrs="cuda"
)
const_split_mod.run_folding()
return const_split_mod

Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2576,5 +2576,6 @@ def test_all_acc_ops_registered(self):
acc_ops.new_ones,
acc_ops.einsum,
acc_ops.as_strided,
acc_ops.var,
},
)
15 changes: 15 additions & 0 deletions py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2864,3 +2864,18 @@ def as_strided(*, input, size, stride, storage_offset=0):
return torch.as_strided(
input=input, size=size, stride=stride, storage_offset=storage_offset
)


@register_acc_op_mapping(op_and_target=("call_function", torch.var))
@register_acc_op_mapping(
    op_and_target=("call_method", "var"),
    arg_replacement_tuples=[
        ("input", "input"),
        ("dim", "dim"),
        ("unbiased", "unbiased"),
        ("keepdim", "keepdim"),
    ],
)
@register_acc_op
def var(*, input, dim=None, unbiased=True, keepdim=False):
    """
    Acc-op wrapper for ``torch.var`` / ``Tensor.var``.

    Args:
        input: tensor to reduce.
        dim: dimension(s) to reduce over; ``None`` (the default, matching
            ``torch.var``) reduces over all elements.
        unbiased: use Bessel's correction (``torch.var`` defaults to True).
        keepdim: retain reduced dimensions with size 1.

    Returns:
        The variance tensor produced by ``torch.var``.
    """
    # The plain call_function mapping above has no arg_replacement_tuples,
    # so traced calls like `torch.var(x)` arrive here without `dim`/`unbiased`.
    # Defaults mirror torch.var's own signature to keep that form working.
    # NOTE(review): the call_method mapping still lists "dim" as a required
    # replacement — confirm whether it should be marked optional in
    # acc_normalizer so `x.var()` also traces.
    if dim is None:
        # Full reduction: omit `dim` entirely — some torch versions reject
        # an explicit dim=None, and keepdim is meaningless here.
        return torch.var(input=input, unbiased=unbiased)
    return torch.var(input=input, dim=dim, unbiased=unbiased, keepdim=keepdim)