
FLUX FP8 Training Error #3111

@faruknane

Description

Hi,

I can't train the FLUX model in float8 using torchao. The sanity check (inference) succeeds: I can run 25-step inference in float8 and generate a good-quality image. However, as soon as training reaches the backward pass, the code fails with the error below and I can't fix it. Could you please help identify the issue?

A small script showing how the model is defined:

import torch
from diffusers import FluxKontextPipeline
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight


def LoadPipeline(dtype=torch.bfloat16):
    pipeline = FluxKontextPipeline.from_pretrained(
        "/home/cropy/flux_kontext",
        local_files_only=True,
        # quantization_config=pipeline_quant_config,
        torch_dtype=dtype,
    )
    pipeline.to("cuda")

    # The transformer is deliberately NOT quantized here; it is
    # converted for float8 training separately (see below).
    # quantize_(
    #     pipeline.transformer,
    #     float8_dynamic_activation_float8_weight(),
    # )

    quantize_(pipeline.vae, float8_dynamic_activation_float8_weight())
    quantize_(pipeline.text_encoder, float8_dynamic_activation_float8_weight())
    quantize_(pipeline.text_encoder_2, float8_dynamic_activation_float8_weight())

    pipeline.vae = torch.compile(pipeline.vae, mode="max-autotune", fullgraph=True)
    pipeline.text_encoder = torch.compile(pipeline.text_encoder, mode="max-autotune", fullgraph=True)
    pipeline.text_encoder_2 = torch.compile(pipeline.text_encoder_2, mode="max-autotune", fullgraph=True)

    return pipeline

----- My Model -----

from torchao.float8 import Float8LinearConfig, convert_to_float8_training

self.pipeline = LoadPipeline(self.target_dtype)
self.model = self.pipeline.transformer.to(self.target_dtype)

# "rowwise" recipe; "tensorwise" produces the same error
train_fp8_config = Float8LinearConfig.from_recipe_name("rowwise")
convert_to_float8_training(self.model, config=train_fp8_config, module_filter_fn=module_filter_fn)
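(`module_filter_fn` is defined elsewhere in my code and omitted here. For context, a minimal filter of the usual shape, illustrative only and not my exact code, would be:)

```python
import torch.nn as nn


def module_filter_fn(mod: nn.Module, fqn: str) -> bool:
    # Convert only nn.Linear layers whose in/out features are
    # divisible by 16, as the float8 _scaled_mm kernels require.
    return (
        isinstance(mod, nn.Linear)
        and mod.in_features % 16 == 0
        and mod.out_features % 16 == 0
    )
```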

The error:

E1002 02:15:43.590000 278392 site-packages/torch/_subclasses/fake_tensor.py:2721] [2/1] failed while attempting to run meta for aten._scaled_mm.default
E1002 02:15:43.590000 278392 site-packages/torch/_subclasses/fake_tensor.py:2721] [2/1] Traceback (most recent call last):
E1002 02:15:43.590000 278392 site-packages/torch/_subclasses/fake_tensor.py:2721] [2/1]   File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 2717, in _dispatch_impl
E1002 02:15:43.590000 278392 site-packages/torch/_subclasses/fake_tensor.py:2721] [2/1]     r = func(*args, **kwargs)
E1002 02:15:43.590000 278392 site-packages/torch/_subclasses/fake_tensor.py:2721] [2/1]         ^^^^^^^^^^^^^^^^^^^^^
E1002 02:15:43.590000 278392 site-packages/torch/_subclasses/fake_tensor.py:2721] [2/1]   File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_ops.py", line 829, in __call__
E1002 02:15:43.590000 278392 site-packages/torch/_subclasses/fake_tensor.py:2721] [2/1]     return self._op(*args, **kwargs)
E1002 02:15:43.590000 278392 site-packages/torch/_subclasses/fake_tensor.py:2721] [2/1]            ^^^^^^^^^^^^^^^^^^^^^^^^^
E1002 02:15:43.590000 278392 site-packages/torch/_subclasses/fake_tensor.py:2721] [2/1]   File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_meta_registrations.py", line 6448, in meta_scaled_mm
E1002 02:15:43.590000 278392 site-packages/torch/_subclasses/fake_tensor.py:2721] [2/1]     torch._check(
E1002 02:15:43.590000 278392 site-packages/torch/_subclasses/fake_tensor.py:2721] [2/1]   File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/__init__.py", line 1684, in _check
E1002 02:15:43.590000 278392 site-packages/torch/_subclasses/fake_tensor.py:2721] [2/1]     _check_with(RuntimeError, cond, message)
E1002 02:15:43.590000 278392 site-packages/torch/_subclasses/fake_tensor.py:2721] [2/1]   File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/__init__.py", line 1666, in _check_with
E1002 02:15:43.590000 278392 site-packages/torch/_subclasses/fake_tensor.py:2721] [2/1]     raise error_type(message_evaluated)
E1002 02:15:43.590000 278392 site-packages/torch/_subclasses/fake_tensor.py:2721] [2/1] RuntimeError: self must be row_major, got stride (1, 18432)
/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/autograd/graph.py:829: UserWarning: Error detected in matmul_with_hp_or_float8_argsBackward. Traceback of forward call that caused the error:
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/diffusers/models/transformers/transformer_flux.py", line 537, in forward
    return torch.utils.checkpoint.checkpoint(
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/diffusers/models/transformers/transformer_flux.py", line 481, in _forward
    norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/diffusers/models/normalization.py", line 168, in forward
    emb = self.linear(self.silu(emb))
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torchao/float8/float8_linear.py", line 264, in forward
    output = matmul_with_hp_or_float8_args.apply(
 (Triggered internally at /pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:122.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "/mnt/c/Users/user_1/Desktop/rectifiedflow_main_image_editing/rectifiedflow_main_image_editing_flux/main.py", line 155, in <module>
    trainer.fit(model, train_dataloader, val_dataloader)
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 561, in fit
    call._call_and_handle_interrupt(
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 48, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 599, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1012, in _run
    results = self._run_stage()
              ^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1056, in _run_stage
    self.fit_loop.run()
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py", line 216, in run
    self.advance()
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py", line 455, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 150, in run
    self.advance(data_fetcher)
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 320, in advance
    batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 192, in run
    self._optimizer_step(batch_idx, closure)
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 270, in _optimizer_step
    call._call_lightning_module_hook(
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 176, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/core/module.py", line 1302, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/core/optimizer.py", line 154, in step
    step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py", line 239, in optimizer_step
    return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/plugins/precision/amp.py", line 76, in optimizer_step
    return super().optimizer_step(optimizer, model=model, closure=closure, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/plugins/precision/precision.py", line 123, in optimizer_step
    return optimizer.step(closure=closure, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/optim/lr_scheduler.py", line 133, in wrapper
    return func.__get__(opt, opt.__class__)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/optim/optimizer.py", line 516, in wrapper
    out = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/bitsandbytes/optim/optimizer.py", line 272, in step
    loss = closure()
           ^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/plugins/precision/precision.py", line 109, in _wrap_closure
    closure_result = closure()
                     ^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 146, in __call__
    self._result = self.closure(*args, **kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 131, in closure
    step_output = self._step_fn()
                  ^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 319, in _training_step
    training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 328, in _call_strategy_hook
    output = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py", line 391, in training_step
    return self.lightning_module.training_step(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/user_1/Desktop/rectifiedflow_main_image_editing/rectifiedflow_main_image_editing_flux/models/models.py", line 71, in training_step
    loss = self.flow.loss_fn(self.apply_model, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/user_1/Desktop/rectifiedflow_main_image_editing/rectifiedflow_main_image_editing_flux/modules/diffusion/flows.py", line 46, in loss_fn
    pred = model_fn(xt, t=t, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/user_1/Desktop/rectifiedflow_main_image_editing/rectifiedflow_main_image_editing_flux/models/models.py", line 440, in apply_model
    v_pred = m(hidden_states=hidden_states, timestep=t, return_dict=False, **kwargs2)[0].clone()
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/diffusers/models/transformers/transformer_flux.py", line 802, in forward
    encoder_hidden_states, hidden_states = block(
                                           ^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 749, in compile_wrapper
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1871, in _call_user_compiler
    raise BackendCompilerFailed(
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1846, in _call_user_compiler
    compiled_fn = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_dynamo/repro/after_dynamo.py", line 150, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/__init__.py", line 2380, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 2418, in compile_fx
    return aot_autograd(
           ^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_dynamo/backends/common.py", line 109, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1199, in aot_module_simplified
    compiled_fn = AOTAutogradCache.load(
                  ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/autograd_cache.py", line 1140, in load
    compiled_fn = dispatch_and_compile()
                  ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1184, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 576, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 836, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
                               ^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 1262, in aot_dispatch_autograd
    fx_g, joint_inputs, maybe_subclass_meta = aot_dispatch_autograd_graph(
                                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py", line 318, in aot_dispatch_autograd_graph
    fx_g = _create_graph(joint_fn_to_trace, updated_joint_inputs, aot_config=aot_config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py", line 55, in _create_graph
    fx_g = make_fx(
           ^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 2318, in wrapped
    return make_fx_tracer.trace(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 2250, in trace
    return self._trace_inner(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 2221, in _trace_inner
    t = dispatch_trace(
        ^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_compile.py", line 53, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 1254, in dispatch_trace
    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 850, in trace
    (self.create_arg(fn(*args)),),
                     ^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 703, in flatten_fn
    tree_out = root_fn(*tree_args)
               ^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 1312, in wrapped
    out = f(*tensors)  # type:ignore[call-arg]
          ^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 720, in inner_fn
    outs = fn(*args)
           ^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 671, in joint_helper
    return _functionalized_f_helper(primals, tangents)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 419, in _functionalized_f_helper
    f_outs = fn(*f_args)
             ^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 286, in inner_fn_with_anomaly
    return inner_fn(*args)
           ^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 271, in inner_fn
    backward_out = torch.autograd.grad(
                   ^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/autograd/__init__.py", line 452, in grad
    return handle_torch_function(
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/overrides.py", line 1725, in handle_torch_function
    result = mode.__torch_function__(public_api, types, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 1360, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/autograd/__init__.py", line 503, in grad
    result = _engine_run_backward(
             ^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/autograd/graph.py", line 829, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/autograd/function.py", line 311, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torchao/float8/float8_linear.py", line 197, in backward
    grad_weight = torch.mm(
                  ^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torchao/float8/float8_training_tensor.py", line 374, in __torch_dispatch__
    return FLOAT8_OPS_TABLE[func](func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torchao/float8/float8_ops.py", line 385, in float8_mm
    tensor_out = addmm_float8_unwrapped(
                 ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torchao/float8/float8_ops.py", line 72, in addmm_float8_unwrapped
    output = torch._scaled_mm(
             ^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_subclasses/functional_tensor.py", line 511, in __torch_dispatch__
    outs_unwrapped = func._op_dk(
                     ^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/utils/_stats.py", line 28, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 1462, in __torch_dispatch__
    return proxy_call(self, func, self.pre_dispatch, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 914, in proxy_call
    out = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_ops.py", line 829, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/utils/_stats.py", line 28, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1352, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 2058, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1487, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 2717, in _dispatch_impl
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_ops.py", line 829, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_meta_registrations.py", line 6448, in meta_scaled_mm
    torch._check(
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/__init__.py", line 1684, in _check
    _check_with(RuntimeError, cond, message)
  File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/__init__.py", line 1666, in _check_with
    raise error_type(message_evaluated)
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: self must be row_major, got stride (1, 18432)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
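For what it's worth, the failing layout check is easy to reproduce in isolation: a stride of `(1, 18432)` is the signature of a transposed (column-major) matrix, which `torch._scaled_mm` rejects for its first operand, and `.contiguous()` restores a row-major copy. A tiny illustration with plain tensors (no float8, no GPU needed):

```python
import torch

# A (3, 18432) matrix is row-major: stride (18432, 1).
w = torch.empty(3, 18432)
assert w.stride() == (18432, 1)

# Transposing swaps the strides without moving data, producing the
# column-major stride (1, 18432) reported in the error message.
wt = w.t()
assert wt.stride() == (1, 18432)

# .contiguous() materializes a fresh row-major copy.
assert wt.contiguous().stride() == (3, 1)
```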

I tried forcing every input tensor to the nn.Linear layers to be contiguous using the code below, but it didn't solve the issue.

import torch.nn as nn


class SafeFloat8Linear(nn.Module):
    """Wrapper that forces the input of a linear layer to be contiguous."""

    def __init__(self, float8_linear):
        super().__init__()
        self.inner = float8_linear

    def forward(self, x, *args, **kwargs):
        return self.inner(x.contiguous(), *args, **kwargs)


def wrap_float8_linears(module: nn.Module):
    """Recursively replace every nn.Linear with a SafeFloat8Linear wrapper."""
    for name, child in list(module.named_children()):
        if isinstance(child, nn.Linear):
            setattr(module, name, SafeFloat8Linear(child))
        else:
            wrap_float8_linears(child)
    return module
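Since the traceback points at `grad_weight = torch.mm(` inside `Float8Linear`'s backward, the non-contiguous tensor presumably arises in the backward pass, so making the forward inputs contiguous cannot reach it. One direction I have not yet verified (untested sketch; `ContiguousGrad` and `GradContiguousLinear` are names I made up) is an identity autograd function that makes the incoming gradient contiguous instead:

```python
import torch
import torch.nn as nn


class ContiguousGrad(torch.autograd.Function):
    """Identity in the forward pass; forces the incoming gradient
    to be contiguous (row-major) in the backward pass."""

    @staticmethod
    def forward(ctx, x):
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out.contiguous()


class GradContiguousLinear(nn.Module):
    """Wraps a linear layer so the gradient flowing back into its
    backward function is contiguous."""

    def __init__(self, inner):
        super().__init__()
        self.inner = inner

    def forward(self, x, *args, **kwargs):
        return ContiguousGrad.apply(self.inner(x, *args, **kwargs))
```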

Labels

bug (Something isn't working), float8
