QLoRA #478
Conversation
torchtune/modules/peft/peft_utils.py
Outdated
hook_name: str

def _register_lora_hooks(
Nice general APIs. If we scrap the LoRA merge state dict hooks do we still intend to keep the design here the same? (No strong preference from my end, mainly just curious)
Yeah, just scrapping these for now per our discussion and we can re-introduce if needed.
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/478
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit f8cf66d with merge base ba7289b. This comment was automatically generated by Dr. CI and updates every 15 minutes.
lora_attn_modules: ['q_proj', 'v_proj']
apply_lora_to_mlp: False
apply_lora_to_output: False
lora_attn_modules: ['q_proj', 'v_proj', 'k_proj', 'output_proj']
Is this a conscious change to lora single device?
Just for testing. For QLoRA we may want to offer a different config entirely.
Adding a bit more detail to explain why this is changed:
We currently only support NF4 linears inside of our LoRALinears. From @rohan-varma's and my research, there is some ambiguity around whether QLoRA actually applies the quantization to all linear layers or just to the ones wrapped in LoRA.
The conclusion we reached was that the original QLoRA repo only applies NF4 quantization to base model weights inside of LoRA layers (so any other linear layers in the model are not quantized; a minimal script to repro this is here). But lit-gpt actually replaces every linear layer (ref from deep inside Fabric).
Since we are using lit-gpt as our comparison, we ideally want to set all linear layers to use NF4 to match their scheme. In our case that also means setting them all to use LoRA.
+1 to what @ebsmothers said. Due to this, it might make sense to offer a QLoRA-specific config. cc @kartikayk @ebsmothers, LMK your thoughts.
Due to this, it might make sense to offer a QLoRA specific config
This makes sense
Sorry, I didn't follow the comments above. So we're saying that the original paper does NOT quantize any base model weights other than within the LoRA layer? Is there a reason for this? This is a bit surprising to me since it's different from my understanding, but I'm also not fully ramped up on QLoRA. Anyways, it seems like we're applying it to ALL linear layers but then do this by replacing the layer with a LoRA layer? That seems a bit counterintuitive to me. Why does quantization need to be tied to the LoRA layer? I would imagine we can just do a loop through all of the modules and then do a to_nf4 for the linear layers (caveat - I haven't looked at the code below)?
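For illustration, here is a minimal sketch (assumed, not the approach taken in this PR) of the loop-and-convert alternative described above, using torchao's `to_nf4`. Note that registering the resulting NF4 tensors back as `nn.Parameter`s is exactly what requires the custom dispatch ops this PR adds, so this sketch only collects the quantized copies.

```python
# Hypothetical sketch of "loop through all modules and to_nf4 the linear layers":
# collect NF4-quantized copies of every nn.Linear weight, independent of LoRA.
import torch
from torch import nn
from torchao.dtypes.nf4tensor import to_nf4

def nf4_quantize_linear_weights(model: nn.Module) -> dict:
    quantized = {}
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            # Cast to bf16 before quantization; the result is an NF4Tensor copy.
            quantized[f"{name}.weight"] = to_nf4(module.weight.detach().to(torch.bfloat16))
    return quantized
```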
_component_: torchtune.utils.FullModelTorchTuneCheckpointer
checkpoint_dir: /home/rvarm1/local/dev/assets/
checkpoint_files: [llama2-bf16-latest.pt]
I'd recommend just testing this as a regular user since the consolidated.00.pth file is bf16. The MetaCheckpointer should work OOTB for you.
"FrozenNF4Linear", | ||
] | ||
|
||
def reparametrize_as_bf16_state_dict_post_hook( |
Why is this in __init__.py? Can we move this to a utils file or something?
Yep - will take care of all code organization, lints, tests etc
A state_dict hook that replaces nf4 tensors with their restored
bf16 weight and optionally offloads the restored weight to CPU.

This function is meant to be used with PyTorch's ``nn.Module._register_state_dict_hook``, i.e.
>>> m = MyModule()
>>> m._register_state_dict_hook(reparametrize_as_bf16_state_dict_post_hook)
I don't think these hooks are that well understood. When I google these, the first link is to a PyTorch forums post which dissuades their use, calling them private APIs. I think this has changed since then, but nonetheless I don't think a lot of people understand how these work. This docstring should include very detailed information on when this hook gets called, what impact it has on the state dict, etc., especially since this is a public API available to users.
Added some docs. But actually, we could maybe even make this private to users and just always register it where appropriate.
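For readers unfamiliar with the hook, here is a toy sketch (illustrative names, not the torchtune hook) of when a post-hook registered via `_register_state_dict_hook` fires and how it can alter the resulting state dict:

```python
# Toy example: a post-hook registered with nn.Module._register_state_dict_hook runs
# after the module's state dict has been assembled, receives
# (module, state_dict, prefix, local_metadata), and may mutate the dict in place or
# return a replacement dict. The torchtune hook instead swaps NF4 values for bf16 weights.
import torch
from torch import nn

def rename_keys_hook(module, state_dict, prefix, local_metadata):
    return {k.upper(): v for k, v in state_dict.items()}  # returned dict replaces the original

m = nn.Linear(4, 4)
m._register_state_dict_hook(rename_keys_hook)
print(list(m.state_dict().keys()))  # ['WEIGHT', 'BIAS']
```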
@@ -34,6 +34,7 @@ def lora_llama2_7b(
    lora_rank: int = 8,
    lora_alpha: float = 16,
    max_batch_size: Optional[int] = None,
    quantize_base: bool = False,
Wondering if we also wanna add a top-level qlora builder (basically just a partial with quantize_base=True)
Sure, shall we add this?
My vote would be for yes. I don't see it in the latest version (unless I missed it), can we add it? Also thinking about how best to expose LoRA in our configs (see my other comment), do you think it makes sense to expose e.g. qlora_llama2_7b there? Tbh I'm a bit torn cause I think it somewhat obscures how easy it is to switch LoRA <-> QLoRA with just the quantize_base flag. Either way, not a blocker for this PR.
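For reference, a sketch of what such a top-level builder could look like (hypothetical name and import path; the real one should follow the existing llama2 builder conventions):

```python
# Hypothetical QLoRA builder: the existing LoRA builder with quantize_base pinned to True.
from functools import partial

from torchtune.models.lora_llama2 import lora_llama2_7b  # import path assumed

qlora_llama2_7b = partial(lora_llama2_7b, quantize_base=True)
```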
# TODO (rohan-varma): not quantizing output_proj as NF4Tensor quantization of large weight leads to large increase
# in reserved memory.
Can you give more details on this? I didn't fully understand what was happening here
Essentially, the current implementation in torchAO allocates intermediates when doing the quantization process. For very large tensors, the intermediates actually result in a large memory spike, so we're working around this for now by just not quantizing this layer.
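A rough, hypothetical way to observe that spike (sizes and the overhead depend on the torchao version and GPU; the shape below just stands in for a large output projection):

```python
# Hypothetical measurement of reserved memory while NF4-quantizing a large weight.
import torch
from torchao.dtypes.nf4tensor import to_nf4

weight = torch.randn(32_000, 4096, dtype=torch.bfloat16, device="cuda")
torch.cuda.reset_peak_memory_stats()
nf4_weight = to_nf4(weight)
print(f"Peak reserved during quantization: {torch.cuda.max_memory_reserved() / 2**30:.2f} GiB")
```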
@@ -96,6 +96,7 @@ def __init__(self, cfg: DictConfig) -> None:
    self.total_training_steps = 0

    self._resume_from_checkpoint = cfg.resume_from_checkpoint
    self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
Thanks for adding this! Now only one recipe remains.. (though it's the one that actually needs this feature the least)
*args: Tuple[Any, ...],
offload_to_cpu: bool = True,
**kwargs: Dict[Any, Any],
Dumb q but any particular reason offload_to_cpu is between *args and **kwargs? I'm not used to seeing this.
So I'm not sure of the best way to design this sort of API. This API is meant to be a function that's passed into _register_state_dict_hook, and the function is invoked with some arguments that we don't use when taking a state_dict in nn.Module. Since these are unused arguments but still provided to the function, I need to take them as unused with *args and **kwargs. Though, I want to offer offload_to_cpu as a keyword arg with a default, so I can't put it before *args. So this is what I was able to come up with, not sure if there's something better.
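Concretely, a toy version of that signature (not the torchtune hook) showing why the placement works: *args swallows the positional arguments PyTorch passes to state-dict hooks, while offload_to_cpu stays a keyword-only argument that can be pre-bound at registration time.

```python
# Toy hook: offload_to_cpu is keyword-only because it comes after *args, so it can
# only be set explicitly (e.g. via functools.partial), never by a stray positional arg.
import functools
from torch import nn

def example_post_hook(*args, offload_to_cpu: bool = True, **kwargs):
    # When called by state_dict(), args is (module, state_dict, prefix, local_metadata).
    state_dict = args[1]
    if offload_to_cpu:
        for key, value in state_dict.items():
            state_dict[key] = value.cpu()

m = nn.Linear(4, 4)
m._register_state_dict_hook(functools.partial(example_post_hook, offload_to_cpu=False))
m.state_dict()  # runs the hook with offload_to_cpu=False
```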
from torchao.dtypes.nf4tensor import linear_nf4
from torchtune.modules.low_precision import (  # noqa: F401
    _register_nf4_dispatch_ops,
Trying to understand how this works: so basically we import _register_nf4_dispatch_ops at the first import of lora.py, which will then ensure any NF4Tensor used henceforth in the same run will have all those overrides in the ops table? Assuming that is the correct understanding, is there a way to do this with less of a hammer? (I don't think it matters for this particular PR, just curious if we try to extend things down the line)
Yep, when this is imported these ops get registered on the NF4Tensor.

is there a way to do this with less of a hammer

Hmm, not sure what would be less of a hammer - I thought this was already pretty light in that we just handle importing something and the registration automatically happens, we don't have to call any specific register_ops API or anything like that. Did you have any example of what you'd like this to be?
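To make the mechanism concrete, here is a small sketch of the import-time registration pattern (the table and decorator are illustrative stand-ins, not the actual torchao/torchtune internals):

```python
# Illustrative only: a decorator that fills a dispatch table as a side effect of
# import. Once the defining module is imported, every decorated override is
# registered; no explicit register_ops() call is needed.
import torch

NF4_OPS_TABLE = {}  # hypothetical mapping: aten op -> NF4-aware implementation

def nf4_tensor_impl(aten_ops):
    def decorator(func):
        for op in aten_ops:
            NF4_OPS_TABLE[op] = func
        return func
    return decorator

@nf4_tensor_impl([torch.ops.aten.clone.default])
def clone_nf4(func, *args, **kwargs):
    """Hypothetical NF4-aware clone, looked up from the table at dispatch time."""
    ...
```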
nf4_tensor = args[0][0]
sub_tensor = args[0][1]
assert sub_tensor.dtype == torch.bfloat16
return to_nf4(nf4_tensor.get_original_weight().sub_(sub_tensor))
I believe you had a comment yesterday about whether we were doing this in the best way, but I didn't fully understand it. What was that about?
So I don't think we need add and sub anymore, at least for now, since when we're merging the lora_state_dict, the state_dict is already in full bf16.
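In other words, once the base weight is back in bf16, the merge is plain dense arithmetic. A minimal sketch (not the torchtune merge code, just the standard LoRA merge formula):

```python
# Minimal LoRA weight merge in bf16: W_merged = W + (alpha / rank) * (B @ A).
import torch

def merge_lora_weight(
    base_weight: torch.Tensor,  # (out_dim, in_dim), bf16
    lora_a: torch.Tensor,       # (rank, in_dim)
    lora_b: torch.Tensor,       # (out_dim, rank)
    alpha: float,
    rank: int,
) -> torch.Tensor:
    return base_weight + (alpha / rank) * (lora_b @ lora_a)
```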
@nf4_tensor_impl([torch.ops.aten.add_.Tensor])
def add_bf16_tensor(func, *args, **kwargs):
This will apply for both weight merge and adding bias, right?
Probably not needed anymore - see other comment
# Modules from CausalSelfAttention that LoRA can be applied to
LORA_ATTN_MODULES = Literal["q_proj", "k_proj", "v_proj", "output_proj"]
remove, you're already importing this
Oops, might've been a bad merge.
This is a really impressive PR. I left a bunch of comments but no huge concerns from my side. As always, make sure to clean up any debug code before landing. Otherwise looks good!
Context
Changelog
I've broken down the set of changes into different components since this is quite a large change:

LoRA related changes
- LoRALinear. These are a clone implementation to make copies of the tensor (needed to register the NF4Tensor as a parameter) as well as an in-place copy method, which is needed when loading state_dict.

LoRA related changes
- quantize_base and plumbing throughout the LoRA builders.

Checkpointing changes
- to_copy dispatch implementation takes care of the above.

Caveats and follow ups needed
- quantize_base=False
Test plan
Bunch of different testing to do, so I've broken it down again.

Unittests
- lora_llama2 with quantize_base=True is unit-tested for save/load with expected dtypes, ensuring forward parity after save/load.

Checkpointing
- Save checkpoint with weight merging:
  tune lora_finetune_single_device --config recipes/configs/alpaca_llama2_lora_finetune_single_device.yaml --override model_checkpoint=/home/rvarm1/local/dev/assets/llama2-bf16-latest seed=18 tokenizer_checkpoint=/home/rvarm1/local/dev/assets/tokenizer.model output_dir=/tmp/lora_debug device=cuda batch_size=2 enable_activation_checkpointing=True full_bf16=True quantize_base=True max_steps_per_epoch=1 &> out
  Results in "Model checkpoint of size 12852 MB saved to /tmp/lora_debug/model_0.ckpt".
- Load that back in for continued training works:
  tune lora_finetune_single_device --config recipes/configs/alpaca_llama2_lora_finetune_single_device.yaml --override model_checkpoint=/home/rvarm1/local/dev/assets/llama2-bf16-latest seed=18 tokenizer_checkpoint=/home/rvarm1/local/dev/assets/tokenizer.model output_dir=/tmp/lora_debug device=cuda batch_size=2 enable_activation_checkpointing=True full_bf16=True quantize_base=True max_steps_per_epoch=1 model_checkpoint=/tmp/lora_debug/model_0.ckpt

Memory analysis
- tune lora_finetune_single_device --config recipes/configs/alpaca_llama2_lora_finetune_single_device.yaml --override model_checkpoint=/home/rvarm1/local/dev/assets/llama2-bf16-latest seed=18 tokenizer_checkpoint=/home/rvarm1/local/dev/assets/tokenizer.model output_dir=/tmp/lora_debug device=cuda batch_size=2 enable_activation_checkpointing=True full_bf16=True

Correctness comparison to LoRA
Eval results:
- tune lora_finetune_single_device --config recipes/configs/alpaca_llama2_lora_finetune_single_device.yaml dtype=bf16 device=cuda with quantize_base=True. Ran eval with tune eval --config torchtune/_cli/eval_configs/default_eval_config.yaml model_checkpoint=/home/rvarm1/local/dev/assets/qlora_trained.pt, results: truthfulqa_mc2: {'acc,none': 0.491684249068877, 'acc_stderr,none': 0.014723538004649834, 'alias': 'truthfulqa_mc2'}
- tune --nproc_per_node 8 lora_finetune_distributed --config recipes/configs/alpaca_llama2_lora_finetune_distributed.yaml dtype=bf16 batch_size=32 enable_activation_checkpointing=True &> out_dist &. Eval result: truthfulqa_mc2: {'acc,none': 0.46393356726865514, 'acc_stderr,none': 0.014448827209276354, 'alias': 'truthfulqa_mc2'}
- r=256, alpha=32 (with 0.5055 on truthfulqa_mc2), with batch_size=128 and micro batch size = 1: truthfulqa_mc2: {'acc,none': 0.48860782692137417, 'acc_stderr,none': 0.014737964298503407, 'alias': 'truthfulqa_mc2'}, compared to their 0.5055.