Our current delayed scaling API asks the user to call `sync_float8_amax_and_scale_history` after each backward pass and before the optimizer step. This does not work on the first iteration if activation checkpointing (AC) is on, because the first backward recomputes the checkpointed forward before the first sync has happened, and an exception is thrown.
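For context, the call ordering the current API expects looks roughly like the sketch below. `sync_float8_amax_and_scale_history` is the real API; the model, optimizer, data, and the `float8_linear_utils` import path are placeholders/assumptions for illustration, not taken from this issue.

```python
import torch
from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history

# `model` is assumed to already have nn.Linear swapped for Float8Linear with
# delayed scaling and, in the failing case, to be wrapped with activation
# checkpointing; the optimizer and input below are stand-ins.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

for _ in range(num_training_steps):
    loss = model(torch.randn(16, 4096, device="cuda")).sum()
    loss.backward()  # with AC, this recomputes the checkpointed forward
    # delayed scaling: sync amax/scale buffers after backward, before the step.
    # On iteration 1 the AC recompute above runs forward before this has ever
    # been called, which trips the "amaxes and scales not synced" assertion.
    sync_float8_amax_and_scale_history(model)
    optimizer.step()
    optimizer.zero_grad()
```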
For now, we can work around this with a config override, but we need a better API design since we need to support AC. Example override:
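A minimal sketch of such an override, assuming the module-level flags `enable_amax_init` and `enable_pre_and_post_forward` in `float8_experimental.config` (the latter gates the `float8_pre_forward` check that raises in the trace below):

```python
import float8_experimental.config as config

# assumed module-level config flags; set them before swapping nn.Linear
# modules for Float8Linear
config.enable_amax_init = False             # skip amax init on the first forward
config.enable_pre_and_post_forward = False  # skip the pre/post-forward hooks,
                                            # including the "amaxes and scales
                                            # not synced" assertion
```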
Root Cause (first observed failure):
[0]:
  time      : 2024-05-28_10:24:53
  host      : devgpu003.cco3.facebook.com
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 3789795)
  error_file: /tmp/torchelastic_9ib8xc3f/none_wmccvyr7/attempt_0/1/error.json
  traceback : Traceback (most recent call last):
    File "/data/users/vasiliy/pytorch/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper
      return f(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^
    File "/data/users/vasiliy/torchtitan/train.py", line 315, in main
      loss.backward()
    File "/data/users/vasiliy/pytorch/torch/_tensor.py", line 523, in backward
      torch.autograd.backward(
    File "/data/users/vasiliy/pytorch/torch/autograd/__init__.py", line 284, in backward
      _engine_run_backward(
    File "/data/users/vasiliy/pytorch/torch/autograd/graph.py", line 767, in _engine_run_backward
      return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/vasiliy/pytorch/torch/autograd/function.py", line 302, in apply
      return user_fn(self, *args)
             ^^^^^^^^^^^^^^^^^^^^
    File "/data/users/vasiliy/float8_experimental/float8_experimental/float8_linear.py", line 122, in backward
      fp8_amax_dL_dY, fp8_amax_history_dL_dY, fp8_scale_dL_dY = ctx.saved_tensors
                                                                ^^^^^^^^^^^^^^^^^
    File "/data/users/vasiliy/pytorch/torch/utils/checkpoint.py", line 1115, in unpack_hook
      frame.recompute_fn(*args)
    File "/data/users/vasiliy/pytorch/torch/utils/checkpoint.py", line 1399, in recompute_fn
      fn(*args, **kwargs)
    File "/data/users/vasiliy/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/vasiliy/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
      return forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/vasiliy/torchtitan/torchtitan/models/llama/model.py", line 317, in forward
      h = x + self.attention(self.attention_norm(x), freqs_cis)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/vasiliy/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/vasiliy/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
      return forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/vasiliy/torchtitan/torchtitan/models/llama/model.py", line 186, in forward
      xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
                   ^^^^^^^^^^
    File "/data/users/vasiliy/float8_experimental/float8_experimental/float8_linear.py", line 383, in forward
      self.float8_pre_forward(x)
    File "/data/users/vasiliy/float8_experimental/float8_experimental/float8_linear.py", line 361, in float8_pre_forward
      raise AssertionError(
  AssertionError: amaxes and scales not synced, please call `sync_float8_amax_and_scale_history` before forward
copied from pytorch-labs/float8_experimental#267