Fix inference attention mask (#142)
* faster saving & inference

* Update llama.py

* Update save.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update mistral.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* fast inference

* Update llama.py

* Update save.py

* Update llama.py

* Mistral correct RoPE scaling

* Max sequence lengths

* Apache 2

* fast_linear_forward

* Update utils.py

* Update utils.py

* No print

* Update utils.py

* Update utils.py

* inference

* Update llama.py

* Fast inference RoPE

* Update llama.py

* Update llama.py

* RoPE

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* LoRA

* Fast LoRA saving

* Update llama.py

* hidden_states

* q_len == 1

* q_len issue

* Update mistral.py

* Update mistral.py

* incorrect inference

* Update to transformers 4.37

* Graceful FA2 error + torch 2.1.1

* Update mapper.py

* Update pyproject.toml

* Fix saving and bnb-4bit

* Update fast_lora.py

* Update fast_lora.py

* remove patching

* Update llama.py

* Update llama.py

* Update swiglu.py

* Repatch

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update llama.py

* Update fast_lora.py

* Update llama.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update swiglu.py

* Update fast_lora.py

* Update swiglu.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update save.py

* Update fast_lora.py

* Update utils.py

* Update llama.py

* Update fast_lora.py

* Update swiglu.py

* Update save.py

* Update save.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Revert "Update llama.py"

This reverts commit a208ec4.

* Update llama.py

* Works?

* Update pyproject.toml

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Swiglu

* Update swiglu.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update swiglu.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* attention_mask

* Update llama.py

* Update llama.py

* labels

* Update mistral.py

* Update llama.py

* attention mask

* Update save.py

* Update save.py

* Update mistral.py

* attention mask

* Update llama.py

* Update llama.py

* Update mistral.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update dpo.py

* Patch saving

* Update save.py

* Update save.py

* patch_saving_functions

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* print

* Mistral patch

* Update mistral.py

* Update save.py

* saving

* Update llama.py

* Update llama.py
danielhanchen committed Jan 29, 2024
1 parent 206a9b6, commit 0562464
Showing 1 changed file with 4 additions and 1 deletion.
unsloth/models/llama.py: 4 additions & 1 deletion
@@ -488,7 +488,10 @@ def LlamaModel_fast_forward(

     # Fix up attention mask by setting elements to 0
     # Specifically for DPO
-    if self._has_no_labels and attention_mask is not None:
+    if self._has_no_labels and attention_mask is not None and \
+        attention_mask.shape[1] == seq_length:
+        # Careful for inference the attention_mask is size (1, kv_seq_len)
+        # Whilst the input_embeds is size (1, 1, 4096)
         inputs_requires_grad = inputs_embeds.requires_grad
         if inputs_requires_grad: inputs_embeds.requires_grad_(False)
         inputs_embeds *= attention_mask.unsqueeze(0).transpose(0, 1).transpose(1, 2)
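For context, here is a minimal, hypothetical sketch (not Unsloth's actual code; the helper name, tensor names, and sizes are illustrative assumptions) of the shape mismatch this check guards against. During incremental decoding the attention mask spans the whole KV cache while the new token's embeddings have sequence length 1, so the element-wise masking used for the DPO path would silently broadcast to the wrong shape and must be skipped:

import torch

hidden_size = 4096  # illustrative hidden dimension

def zero_padded_embeddings(inputs_embeds, attention_mask):
    # Hypothetical helper mirroring the patched condition: only zero out padded
    # positions when the mask length matches the query length.
    if attention_mask is not None and attention_mask.shape[1] == inputs_embeds.shape[1]:
        # (bsz, seq_len) -> (bsz, seq_len, 1), broadcast over the hidden dimension
        inputs_embeds = inputs_embeds * attention_mask.unsqueeze(0).transpose(0, 1).transpose(1, 2)
    # Inference: inputs_embeds is (bsz, 1, hidden) but the mask is (bsz, kv_seq_len),
    # so multiplying would broadcast to (bsz, kv_seq_len, hidden); skip it instead.
    return inputs_embeds

# Training-shaped call: mask (1, 7) matches embeds (1, 7, hidden) -> masking applied.
train_embeds = torch.randn(1, 7, hidden_size)
train_mask   = torch.ones(1, 7)
assert zero_padded_embeddings(train_embeds, train_mask).shape == (1, 7, hidden_size)

# Decoding-shaped call: embeds (1, 1, hidden) vs a mask covering 20 cached tokens -> skipped.
decode_embeds = torch.randn(1, 1, hidden_size)
kv_mask       = torch.ones(1, 20)
assert zero_padded_embeddings(decode_embeds, kv_mask).shape == (1, 1, hidden_size)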
