10 changes: 9 additions & 1 deletion userbenchmark/dynamo/dynamobench/common.py
@@ -4430,7 +4430,15 @@ def run(runner, args, original_dir=None):
                 fullgraph=args.nopython,
                 mode=args.inductor_compile_mode,
             )
-            runner.model_iter_fn = baseline_ctx(runner.model_iter_fn)
+            model_iter_fn = baseline_ctx(runner.model_iter_fn)
+
+            # needed to avoid error that causes inconsistent timing due to:
+            # Unable to hit fast path of CUDAGraphs because of pending, uninvoked backwards
+            def model_iter_fn_and_mark_step(*args, **kwargs):
+                torch.compiler.cudagraph_mark_step_begin()
+                model_iter_fn(*args, **kwargs)
+
+            runner.model_iter_fn = model_iter_fn_and_mark_step
             optimize_ctx = torchao_optimize_ctx(args.quantization)
         else:
             optimize_ctx = torch._dynamo.optimize(args.backend, nopython=args.nopython)
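A note for anyone replaying this outside the harness: the wrapper added above is the standard pattern for keeping CUDAGraphs on its fast path when the same callable is timed repeatedly. A minimal self-contained sketch, assuming PyTorch 2.1+ (where torch.compiler.cudagraph_mark_step_begin is available); the names wrap_with_cudagraph_step and run_iteration are illustrative, not part of this PR:

    import torch

    def wrap_with_cudagraph_step(run_iteration):
        # Mark the start of a new CUDAGraphs step before every timed call, so
        # pending, uninvoked backwards from the previous iteration do not keep
        # CUDAGraphs off its fast path (the inconsistent-timing issue cited in
        # the comment in the diff above).
        def wrapper(*args, **kwargs):
            torch.compiler.cudagraph_mark_step_begin()
            return run_iteration(*args, **kwargs)

        return wrapper

This sketch forwards the return value, whereas the wrapper in the diff discards it.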
45 changes: 24 additions & 21 deletions userbenchmark/dynamo/dynamobench/torchao_backend.py
@@ -4,48 +4,51 @@
 
 
 def setup_baseline():
-    torch._dynamo.epilogue_fusion = False
+    from torchao.quantization.utils import recommended_inductor_config_setter
+
+    recommended_inductor_config_setter()
     torch._dynamo.config.automatic_dynamic_shapes = False
-    torch._dynamo.config.force_parameter_static_shapes = False
     torch._dynamo.config.cache_size_limit = 10000
-    torch._inductor.config.force_fuse_int_mm_with_mul = True
-    torch._inductor.config.use_mixed_mm = True
 
 
 def torchao_optimize_ctx(quantization: str):
-    import torchao
     from torchao.quantization.quant_api import (
-        change_linear_weights_to_int4_woqtensors,
-        change_linear_weights_to_int8_dqtensors,
-        change_linear_weights_to_int8_woqtensors,
+        autoquant,
+        int4_weight_only,
+        int8_dynamic_activation_int8_weight,
+        int8_weight_only,
+        quantize_,
     )
+    from torchao.utils import unwrap_tensor_subclass
 
     def inner(model_iter_fn: Callable):
         def _torchao_apply(module: torch.nn.Module, example_inputs: Any):
             if getattr(module, "_quantized", None) is None:
                 if quantization == "int8dynamic":
-                    change_linear_weights_to_int8_dqtensors(module)
+                    quantize_(
+                        module,
+                        int8_dynamic_activation_int8_weight(),
+                        set_inductor_config=False,
+                    )
                 elif quantization == "int8weightonly":
-                    change_linear_weights_to_int8_woqtensors(module)
+                    quantize_(module, int8_weight_only(), set_inductor_config=False)
                 elif quantization == "int4weightonly":
-                    change_linear_weights_to_int4_woqtensors(module)
-                elif quantization == "autoquant":
-                    torchao.autoquant(module, error_on_unseen=False)
+                    quantize_(module, int4_weight_only(), set_inductor_config=False)
+                if quantization == "autoquant":
+                    autoquant(module, error_on_unseen=False, set_inductor_config=False)
                     if isinstance(example_inputs, dict):
                         module(**example_inputs)
                     else:
                         module(*example_inputs)
                     from torchao.quantization.autoquant import AUTOQUANT_CACHE
 
-                    assert (
-                        len(AUTOQUANT_CACHE) > 0
-                    ), f"Err: found no autoquantizable layers in model {type(module)}, stopping autoquantization"
-                elif quantization == "noquant":
-                    pass
+                    if len(AUTOQUANT_CACHE) == 0:
+                        raise Exception(  # noqa: TRY002
+                            "NotAutoquantizable"
+                            f"Found no autoquantizable layers in model {type(module)}, stopping autoquantized run"
+                        )
                 else:
-                    raise AssertionError(
-                        f"Unsupposed quantization mode {quantization}."
-                    )
+                    unwrap_tensor_subclass(module)
                 setattr(module, "_quantized", True)  # noqa: B010
             model_iter_fn(module, example_inputs)
 
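For context on the migration above: the change_linear_weights_to_* functions were torchao's older per-mode entry points; the new unified entry point is quantize_, which rewrites eligible layers in place with quantized tensor subclasses. A minimal sketch of the new flow, assuming a torchao version matching this PR (quantize_ and autoquant exported from torchao.quantization.quant_api) and a CUDA build; the toy model, dtype, and shapes are illustrative only:

    import torch
    from torchao.quantization.quant_api import int8_weight_only, quantize_
    from torchao.utils import unwrap_tensor_subclass

    model = (
        torch.nn.Sequential(torch.nn.Linear(64, 64))
        .eval()
        .to(device="cuda", dtype=torch.bfloat16)
    )

    # In-place: swaps Linear weights for quantized tensor subclasses.
    # set_inductor_config=False mirrors this PR, which defers the inductor
    # flags to setup_baseline() / recommended_inductor_config_setter().
    quantize_(model, int8_weight_only(), set_inductor_config=False)

    # Mirrors the new else branch: unwrap the subclass parameters so
    # torch.compile traces plain tensors rather than subclass dispatch.
    unwrap_tensor_subclass(model)

    compiled = torch.compile(model)
    with torch.no_grad():
        out = compiled(torch.randn(8, 64, device="cuda", dtype=torch.bfloat16))

The autoquant path works differently: it instruments the model and needs one calibration call (the module(*example_inputs) lines in the diff) to populate AUTOQUANT_CACHE, and the PR's new check turns an empty cache into a hard error instead of an assert.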