
float8 delayed scaling safety logic currently doesn't work with activation checkpointing #570

Open
vkuzo opened this issue Jul 30, 2024 · 2 comments

vkuzo (Contributor) commented Jul 30, 2024

Our current delayed scaling API asks the user to call `sync_float8_amax_and_scale_history` after each backward and before the optimizer step. This does not work on the first iteration if activation checkpointing is on: the first backward re-runs the forward to recompute activations, and since the sync can only happen after backward, the pre-forward safety check throws an exception.
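
For context, here is a minimal sketch of the per-iteration ordering the current API expects (the `torchao.float8` import path is assumed; `model`, `optimizer`, `loss_fn`, and `data_loader` are placeholders):

from torchao.float8 import sync_float8_amax_and_scale_history

for inputs, targets in data_loader:
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()                            # with AC, backward re-runs the forward internally
    sync_float8_amax_and_scale_history(model)  # sync amaxes/scales after backward, before step
    optimizer.step()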

For now this can be worked around with the config override below, but we need a better API design since we need to support AC. Example override:

# imports shown for the torchao.float8 package layout
from torchao.float8 import CastConfig, Float8LinearConfig, ScalingType, convert_to_float8_training

config = Float8LinearConfig(
    cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
    # disable the safety checks that raise on the AC-triggered forward recompute
    enable_amax_init=False,
    enable_pre_and_post_forward=False,
)
convert_to_float8_training(model, config=config, ...)

Example trace:

Root Cause (first observed failure):
[0]:
  time      : 2024-05-28_10:24:53
  host      : devgpu003.cco3.facebook.com
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 3789795)
  error_file: /tmp/torchelastic_9ib8xc3f/none_wmccvyr7/attempt_0/1/error.json
  traceback : Traceback (most recent call last):
    File "/data/users/vasiliy/pytorch/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper
      return f(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^
    File "/data/users/vasiliy/torchtitan/train.py", line 315, in main
      loss.backward()
    File "/data/users/vasiliy/pytorch/torch/_tensor.py", line 523, in backward
      torch.autograd.backward(
    File "/data/users/vasiliy/pytorch/torch/autograd/__init__.py", line 284, in backward
      _engine_run_backward(
    File "/data/users/vasiliy/pytorch/torch/autograd/graph.py", line 767, in _engine_run_backward
      return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/vasiliy/pytorch/torch/autograd/function.py", line 302, in apply
      return user_fn(self, *args)
             ^^^^^^^^^^^^^^^^^^^^
    File "/data/users/vasiliy/float8_experimental/float8_experimental/float8_linear.py", line 122, in backward
      fp8_amax_dL_dY, fp8_amax_history_dL_dY, fp8_scale_dL_dY = ctx.saved_tensors
                                                                ^^^^^^^^^^^^^^^^^
    File "/data/users/vasiliy/pytorch/torch/utils/checkpoint.py", line 1115, in unpack_hook
      frame.recompute_fn(*args)
    File "/data/users/vasiliy/pytorch/torch/utils/checkpoint.py", line 1399, in recompute_fn
      fn(*args, **kwargs)
    File "/data/users/vasiliy/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/vasiliy/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
      return forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/vasiliy/torchtitan/torchtitan/models/llama/model.py", line 317, in forward
      h = x + self.attention(self.attention_norm(x), freqs_cis)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/vasiliy/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/vasiliy/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
      return forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/vasiliy/torchtitan/torchtitan/models/llama/model.py", line 186, in forward
      xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
                   ^^^^^^^^^^
    File "/data/users/vasiliy/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/vasiliy/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
      return forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/vasiliy/float8_experimental/float8_experimental/float8_linear.py", line 383, in forward
      self.float8_pre_forward(x)
    File "/data/users/vasiliy/float8_experimental/float8_experimental/float8_linear.py", line 361, in float8_pre_forward
      raise AssertionError(
  AssertionError: amaxes and scales not synced, please call `sync_float8_amax_and_scale_history` before forward

copied from pytorch-labs/float8_experimental#267

yh8899 commented Nov 5, 2024

@vkuzo Hi, what do you mean by "config override"? Can you give an example?

vkuzo (Contributor, Author) commented Nov 5, 2024

Sure, I updated the issue summary to include example code for setting the config to avoid this.
