Error message when using ORPO fine-tuning #601

Open

MRQJsfhf opened this issue Jun 6, 2024 · 1 comment

MRQJsfhf commented Jun 6, 2024

When using ORPO to fine-tune mistral-7b-instruct-v0.3-bnb-4bit, running orpo_trainer.train() fails with the error message below.
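
For context, here is a minimal sketch of the kind of Unsloth + TRL ORPO setup involved. Only the model name comes from the report above; the LoRA settings, the dataset placeholder, and the trainer arguments are illustrative assumptions, not the reporter's exact notebook:

from unsloth import FastLanguageModel
from trl import ORPOConfig, ORPOTrainer
from datasets import load_dataset

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/mistral-7b-instruct-v0.3-bnb-4bit",
    max_seq_length = 2048,
    load_in_4bit = True,
)

model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"],
    lora_alpha = 16,
    use_gradient_checkpointing = "unsloth",
)

# Placeholder: any preference dataset with "prompt", "chosen", "rejected" columns.
dataset = load_dataset("your-username/your-orpo-dataset", split = "train")

orpo_trainer = ORPOTrainer(
    model = model,
    args = ORPOConfig(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        max_length = 1024,
        max_prompt_length = 512,
        output_dir = "outputs",
    ),
    train_dataset = dataset,
    tokenizer = tokenizer,
)

orpo_trainer.train()   # <- this call raises the NotImplementedError shown below

The full traceback: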

--------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
Cell In[15], line 1
----> 1 orpo_trainer.train()

File /usr/local/lib/python3.10/site-packages/transformers/trainer.py:1885, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
1883 hf_hub_utils.enable_progress_bars()
1884 else:
-> 1885 return inner_training_loop(
1886 args=args,
1887 resume_from_checkpoint=resume_from_checkpoint,
1888 trial=trial,
1889 ignore_keys_for_eval=ignore_keys_for_eval,
1890 )

File <string>:348, in _fast_inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)

File /usr/local/lib/python3.10/site-packages/transformers/trainer.py:3238, in Trainer.training_step(self, model, inputs)
3235 return loss_mb.reduce_mean().detach().to(self.args.device)
3237 with self.compute_loss_context_manager():
-> 3238 loss = self.compute_loss(model, inputs)
3240 del inputs
3241 torch.cuda.empty_cache()

File /usr/local/lib/python3.10/site-packages/trl/trainer/orpo_trainer.py:786, in ORPOTrainer.compute_loss(self, model, inputs, return_outputs)
783 compute_loss_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext
785 with compute_loss_context_manager():
--> 786 loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
788 # force log the metrics
789 self.store_metrics(metrics, train_eval="train")

File /usr/local/lib/python3.10/site-packages/trl/trainer/orpo_trainer.py:746, in ORPOTrainer.get_batch_loss_metrics(self, model, batch, train_eval)
737 """Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
738 metrics = {}
740 (
741 policy_chosen_logps,
742 policy_rejected_logps,
743 policy_chosen_logits,
744 policy_rejected_logits,
745 policy_nll_loss,
--> 746 ) = self.concatenated_forward(model, batch)
748 losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
749 policy_chosen_logps, policy_rejected_logps
750 )
751 # full ORPO loss

File /usr/local/lib/python3.10/site-packages/trl/trainer/orpo_trainer.py:686, in ORPOTrainer.concatenated_forward(self, model, batch)
676 len_chosen = batch["chosen_labels"].shape[0]
678 model_kwargs = (
679 {
680 "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
(...)
683 else {}
684 )
--> 686 outputs = model(
687 concatenated_batch["concatenated_input_ids"],
688 attention_mask=concatenated_batch["concatenated_attention_mask"],
689 use_cache=False,
690 **model_kwargs,
691 )
692 all_logits = outputs.logits
694 def cross_entropy_loss(logits, labels):

File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None

File /usr/local/lib/python3.10/site-packages/accelerate/utils/operations.py:822, in convert_outputs_to_fp32.<locals>.forward(*args, **kwargs)
821 def forward(*args, **kwargs):
--> 822 return model_forward(*args, **kwargs)

File /usr/local/lib/python3.10/site-packages/accelerate/utils/operations.py:810, in ConvertOutputsToFp32.__call__(self, *args, **kwargs)
809 def __call__(self, *args, **kwargs):
--> 810 return convert_to_fp32(self.model_forward(*args, **kwargs))

File /usr/local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:16, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
13 @functools.wraps(func)
14 def decorate_autocast(*args, **kwargs):
15 with autocast_instance:
---> 16 return func(*args, **kwargs)

File /usr/local/lib/python3.10/site-packages/accelerate/utils/operations.py:822, in convert_outputs_to_fp32.<locals>.forward(*args, **kwargs)
821 def forward(*args, **kwargs):
--> 822 return model_forward(*args, **kwargs)

File /usr/local/lib/python3.10/site-packages/accelerate/utils/operations.py:810, in ConvertOutputsToFp32.__call__(self, *args, **kwargs)
809 def __call__(self, *args, **kwargs):
--> 810 return convert_to_fp32(self.model_forward(*args, **kwargs))

File /usr/local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:16, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
13 @functools.wraps(func)
14 def decorate_autocast(*args, **kwargs):
15 with autocast_instance:
---> 16 return func(*args, **kwargs)

File /usr/local/lib/python3.10/site-packages/accelerate/utils/operations.py:822, in convert_outputs_to_fp32.<locals>.forward(*args, **kwargs)
821 def forward(*args, **kwargs):
--> 822 return model_forward(*args, **kwargs)

File /usr/local/lib/python3.10/site-packages/accelerate/utils/operations.py:810, in ConvertOutputsToFp32.__call__(self, *args, **kwargs)
809 def __call__(self, *args, **kwargs):
--> 810 return convert_to_fp32(self.model_forward(*args, **kwargs))

File /usr/local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:16, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
13 @functools.wraps(func)
14 def decorate_autocast(*args, **kwargs):
15 with autocast_instance:
---> 16 return func(*args, **kwargs)

File /usr/local/lib/python3.10/site-packages/unsloth/models/llama.py:883, in PeftModelForCausalLM_fast_forward(self, input_ids, causal_mask, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)
870 def PeftModelForCausalLM_fast_forward(
871 self,
872 input_ids=None,
(...)
881 **kwargs,
882 ):
--> 883 return self.base_model(
884 input_ids=input_ids,
885 causal_mask=causal_mask,
886 attention_mask=attention_mask,
887 inputs_embeds=inputs_embeds,
888 labels=labels,
889 output_attentions=output_attentions,
890 output_hidden_states=output_hidden_states,
891 return_dict=return_dict,
892 **kwargs,
893 )

File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None

File /usr/local/lib/python3.10/site-packages/peft/tuners/tuners_utils.py:179, in BaseTuner.forward(self, *args, **kwargs)
178 def forward(self, *args: Any, **kwargs: Any):
--> 179 return self.model.forward(*args, **kwargs)

File /usr/local/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
164 output = module._old_forward(*args, **kwargs)
165 else:
--> 166 output = module._old_forward(*args, **kwargs)
167 return module._hf_hook.post_forward(module, output)

File /usr/local/lib/python3.10/site-packages/unsloth/models/mistral.py:213, in MistralForCausalLM_fast_forward(self, input_ids, causal_mask, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, *args, **kwargs)
205 outputs = LlamaModel_fast_forward_inference(
206 self,
207 input_ids,
(...)
210 attention_mask = attention_mask,
211 )
212 else:
--> 213 outputs = self.model(
214 input_ids=input_ids,
215 causal_mask=causal_mask,
216 attention_mask=attention_mask,
217 position_ids=position_ids,
218 past_key_values=past_key_values,
219 inputs_embeds=inputs_embeds,
220 use_cache=use_cache,
221 output_attentions=output_attentions,
222 output_hidden_states=output_hidden_states,
223 return_dict=return_dict,
224 )
225 pass
227 hidden_states = outputs[0]

File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None

File /usr/local/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
164 output = module._old_forward(*args, **kwargs)
165 else:
--> 166 output = module._old_forward(*args, **kwargs)
167 return module._hf_hook.post_forward(module, output)

File /usr/local/lib/python3.10/site-packages/unsloth/models/llama.py:651, in LlamaModel_fast_forward(self, input_ids, causal_mask, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, *args, **kwargs)
648 past_key_value = past_key_values[idx] if past_key_values is not None else None
650 if offloaded_gradient_checkpointing:
--> 651 hidden_states = Unsloth_Offloaded_Gradient_Checkpointer.apply(
652 decoder_layer,
653 hidden_states,
654 causal_mask,
655 attention_mask,
656 position_ids,
657 past_key_values,
658 output_attentions,
659 use_cache,
660 )[0]
662 elif gradient_checkpointing:
663 def create_custom_forward(module):

File /usr/local/lib/python3.10/site-packages/torch/autograd/function.py:553, in Function.apply(cls, *args, **kwargs)
550 if not torch._C._are_functorch_transforms_active():
551 # See NOTE: [functorch vjp and autograd interaction]
552 args = _functorch.utils.unwrap_dead_wrappers(args)
--> 553 return super().apply(*args, **kwargs) # type: ignore[misc]
555 if not is_setup_ctx_defined:
556 raise RuntimeError(
557 "In order to use an autograd.Function with functorch transforms "
558 "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
559 "staticmethod. For more details, please see "
560 "https://pytorch.org/docs/master/notes/extending.func.html"
561 )

File /usr/local/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py:115, in custom_fwd.<locals>.decorate_fwd(*args, **kwargs)
113 if cast_inputs is None:
114 args[0]._fwd_used_autocast = torch.is_autocast_enabled()
--> 115 return fwd(*args, **kwargs)
116 else:
117 autocast_context = torch.is_autocast_enabled()

File /usr/local/lib/python3.10/site-packages/unsloth/models/_utils.py:385, in Unsloth_Offloaded_Gradient_Checkpointer.forward(ctx, forward_function, hidden_states, *args)
383 saved_hidden_states = hidden_states.to("cpu", non_blocking = True)
384 with torch.no_grad():
--> 385 output = forward_function(hidden_states, *args)
386 ctx.save_for_backward(saved_hidden_states)
387 ctx.forward_function = forward_function

File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None

File /usr/local/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
164 output = module._old_forward(*args, **kwargs)
165 else:
--> 166 output = module._old_forward(*args, **kwargs)
167 return module._hf_hook.post_forward(module, output)

File /usr/local/lib/python3.10/site-packages/unsloth/models/llama.py:434, in LlamaDecoderLayer_fast_forward(self, hidden_states, causal_mask, attention_mask, position_ids, past_key_value, output_attentions, use_cache, padding_mask, *args, **kwargs)
432 residual = hidden_states
433 hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
--> 434 hidden_states, self_attn_weights, present_key_value = self.self_attn(
435 hidden_states=hidden_states,
436 causal_mask=causal_mask,
437 attention_mask=attention_mask,
438 position_ids=position_ids,
439 past_key_value=past_key_value,
440 output_attentions=output_attentions,
441 use_cache=use_cache,
442 padding_mask=padding_mask,
443 )
444 hidden_states = residual + hidden_states
446 # Fully Connected

File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None

File /usr/local/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
164 output = module._old_forward(*args, **kwargs)
165 else:
--> 166 output = module._old_forward(*args, **kwargs)
167 return module._hf_hook.post_forward(module, output)

File /usr/local/lib/python3.10/site-packages/unsloth/models/mistral.py:129, in MistralAttention_fast_forward(self, hidden_states, causal_mask, attention_mask, position_ids, past_key_value, output_attentions, use_cache, padding_mask, *args, **kwargs)
126 pass
127 pass
--> 129 A = xformers_attention(Q, K, V, attn_bias = causal_mask)
130 A = A.view(bsz, q_len, n_heads, head_dim)
132 elif HAS_FLASH_ATTENTION and attention_mask is None:

File /usr/local/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py:268, in memory_efficient_attention(query, key, value, attn_bias, p, scale, op, output_dtype)
156 def memory_efficient_attention(
157 query: torch.Tensor,
158 key: torch.Tensor,
(...)
165 output_dtype: Optional[torch.dtype] = None,
166 ) -> torch.Tensor:
167 """Implements the memory-efficient attention mechanism following
168 "Self-Attention Does Not Need O(n^2) Memory" <[http://arxiv.org/abs/2112.05682>](http://arxiv.org/abs/2112.05682%3E%60).
169
(...)
266 :return: multi-head attention Tensor with shape [B, Mq, H, Kv]
267 """
--> 268 return _memory_efficient_attention(
269 Inputs(
270 query=query,
271 key=key,
272 value=value,
273 p=p,
274 attn_bias=attn_bias,
275 scale=scale,
276 output_dtype=output_dtype,
277 ),
278 op=op,
279 )

File /usr/local/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py:387, in _memory_efficient_attention(inp, op)
382 def _memory_efficient_attention(
383 inp: Inputs, op: Optional[AttentionOp] = None
384 ) -> torch.Tensor:
385 # fast-path that doesn't require computing the logsumexp for backward computation
386 if all(x.requires_grad is False for x in [inp.query, inp.key, inp.value]):
--> 387 return _memory_efficient_attention_forward(
388 inp, op=op[0] if op is not None else None
389 )
391 output_shape = inp.normalize_bmhk()
392 return _fMHA.apply(
393 op, inp.query, inp.key, inp.value, inp.attn_bias, inp.p, inp.scale
394 ).reshape(output_shape)

File /usr/local/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py:403, in _memory_efficient_attention_forward(inp, op)
401 output_shape = inp.normalize_bmhk()
402 if op is None:
--> 403 op = _dispatch_fw(inp, False)
404 else:
405 _ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp)

File /usr/local/lib/python3.10/site-packages/xformers/ops/fmha/dispatch.py:125, in _dispatch_fw(inp, needs_gradient)
116 def _dispatch_fw(inp: Inputs, needs_gradient: bool) -> Type[AttentionFwOpBase]:
117 """Computes the best operator for forward
118
119 Raises:
(...)
123 AttentionOp: The best operator for the configuration
124 """
--> 125 return _run_priority_list(
126 "memory_efficient_attention_forward",
127 _dispatch_fw_priority_list(inp, needs_gradient),
128 inp,
129 )

File /usr/local/lib/python3.10/site-packages/xformers/ops/fmha/dispatch.py:65, in _run_priority_list(name, priority_list, inp)
63 for op, not_supported in zip(priority_list, not_supported_reasons):
64 msg += "\n" + _format_not_supported_reasons(op, not_supported)
---> 65 raise NotImplementedError(msg)

NotImplementedError: No operator found for memory_efficient_attention_forward with inputs:
query : shape=(8, 569, 8, 4, 128) (torch.bfloat16)
key : shape=(8, 569, 8, 4, 128) (torch.bfloat16)
value : shape=(8, 569, 8, 4, 128) (torch.bfloat16)
attn_bias : <class 'xformers.ops.fmha.attn_bias.LowerTriangularMask'>
p : 0.0
flshattF@0.0.0 is not supported because:
xFormers wasn't build with CUDA support
operator wasn't built - see python -m xformers.info for more info
cutlassF is not supported because:
xFormers wasn't build with CUDA support
operator wasn't built - see python -m xformers.info for more info
smallkF is not supported because:
max(query.shape[-1] != value.shape[-1]) > 32
xFormers wasn't build with CUDA support
dtype=torch.bfloat16 (supported: {torch.float32})
attn_bias type is <class 'xformers.ops.fmha.attn_bias.LowerTriangularMask'>
operator wasn't built - see python -m xformers.info for more info
operator does not support BMGHK format
unsupported embed per head: 128

@danielhanchen
Contributor

Oh you need to update xformers!
Do pip install --upgrade "xformers<0.0.26" for torch 2.2 or lower, and pip install --upgrade xformers for torch 2.3 and above. If that does not work, try:

!pip install -U xformers --index-url https://download.pytorch.org/whl/cu121
!pip install "unsloth[kaggle-new] @ git+https://github.com/unslothai/unsloth.git"
