
QLoRA #478

Merged: 34 commits merged into main from qlora on Mar 22, 2024

Conversation

rohan-varma (Member) commented Mar 9, 2024

Context

  • This PR enables e2e support for training with the QLoRA technique, implementing ideas from https://arxiv.org/abs/2305.14314 specifically for single-device LoRA finetuning use cases.

Changelog

I've broken down the set of changes into different components since this is quite a large change:

NF4Tensor dispatch changes
  • Added two new ops via __torch_dispatch__, representing a minimal set of changes needed to use NF4Tensor as a parameter in LoRALinear: a clone implementation to make copies of the tensor (needed to register the NF4Tensor as a parameter), and an in-place copy method needed when loading a state_dict. A minimal sketch of the registration pattern follows this list.
  • In-place copy works by creating a reference NF4 tensor from the incoming bf16 tensor, then setting all attributes of the destination tensor to those of the reference tensor; tests verify that this performs the proper copy.
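The snippet below is a minimal sketch of that registration pattern, assuming a registry-style decorator like the nf4_tensor_impl seen later in this diff; the real implementation lives in torchtune/modules/low_precision and may differ in its details.

import torch
from torchao.dtypes.nf4tensor import to_nf4

# Hypothetical registry consulted from NF4Tensor's __torch_dispatch__.
NF4_DISPATCH_TABLE = {}

def nf4_tensor_impl(aten_ops):
    # Register an implementation for the given aten ops on NF4Tensor.
    def decorator(func):
        for op in aten_ops:
            NF4_DISPATCH_TABLE[op] = func
        return func
    return decorator

@nf4_tensor_impl([torch.ops.aten.clone.default])
def clone_nf4(func, *args, **kwargs):
    # Clone by round-tripping through the dequantized bf16 weight; this is what
    # lets an NF4Tensor be registered as an nn.Parameter.
    nf4_tensor = args[0][0]
    return to_nf4(nf4_tensor.get_original_weight())

@nf4_tensor_impl([torch.ops.aten.copy_.default])
def inplace_copy_nf4(func, *args, **kwargs):
    # In-place copy from an incoming bf16 tensor (hit when loading a state_dict):
    # quantize the source into a reference NF4Tensor, then take over its attributes.
    dest, src = args[0][0], args[0][1]
    ref = to_nf4(src.to(torch.bfloat16))
    dest.__dict__.update(ref.__dict__)  # assumption: quantization state lives in plain attributes
    return dest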
LoRA related changes
  • Separate creation of the base weight and bias into a helper function.
  • QLoRA does not support biases at the moment.
  • In forward, call the regular F.linear or NF4's linear operator accordingly (see the sketch after this list).
  • Add quantize_base and plumb it throughout the LoRA builders.
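For reference, here's a minimal sketch (illustrative, not the exact diff) of how a LoRALinear-style module branches between F.linear and the NF4 linear op based on quantize_base; the helper name _create_weight_and_bias, shapes, and dtypes are assumptions.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchao.dtypes.nf4tensor import linear_nf4, to_nf4

class LoRALinearSketch(nn.Module):
    def __init__(self, in_dim, out_dim, rank=8, alpha=16.0, quantize_base=False):
        super().__init__()
        self._quantize_base = quantize_base
        weight, bias = self._create_weight_and_bias(in_dim, out_dim)
        # Registering an NF4Tensor as a Parameter relies on the dispatch ops sketched above.
        self.weight = nn.Parameter(weight, requires_grad=False)  # frozen base weight
        self.bias = bias
        self.lora_a = nn.Linear(in_dim, rank, bias=False, dtype=torch.bfloat16)
        self.lora_b = nn.Linear(rank, out_dim, bias=False, dtype=torch.bfloat16)
        self.scaling = alpha / rank

    def _create_weight_and_bias(self, in_dim, out_dim):
        # Helper mirroring the refactor described above; the quantized path has no bias.
        weight = torch.randn(out_dim, in_dim, dtype=torch.bfloat16)
        if self._quantize_base:
            return to_nf4(weight), None
        return weight, nn.Parameter(torch.zeros(out_dim, dtype=torch.bfloat16))

    def forward(self, x):
        if self._quantize_base:
            out = linear_nf4(input=x, weight=self.weight)  # NF4 linear operator
        else:
            out = F.linear(x, self.weight, self.bias)
        return out + self.scaling * self.lora_b(self.lora_a(x))

With quantize_base=True the frozen base weight is stored as NF4 while the LoRA A/B matrices stay in bf16 and remain trainable.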
Checkpointing changes
  • For load, we currently support taking in a bf16 checkpoint and quantizing it on the fly to NF4Tensor.
  • When loading a checkpoint, the to_copy dispatch implementation takes care of the above.
  • When saving full weights with the LoRA weight merge, we implement in-place add and sub torch dispatch operators. This ensures that we use the restored bf16 weight when adding / subtracting the LoRA adapter deltas.
  • We add a model-level state_dict post-hook to reparametrize the model as bf16 and offload the bf16 parameters to CPU on the fly, so as not to increase peak memory usage while checkpointing (see the sketch after this list).
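A minimal sketch of that post-hook, assuming dequantization goes through NF4Tensor.get_original_weight() (the actual hook, reparametrize_as_bf16_state_dict_post_hook, appears further down in this diff):

import torch
import torch.nn as nn
from torchao.dtypes.nf4tensor import NF4Tensor

def reparametrize_as_bf16_state_dict_post_hook(
    model: nn.Module,
    state_dict: dict,
    *args,
    offload_to_cpu: bool = True,
    **kwargs,
) -> None:
    # Called by nn.Module after state_dict() is populated: replace every NF4Tensor
    # with its dequantized bf16 weight so saved checkpoints are plain bf16, and
    # optionally move it to CPU right away to avoid a GPU memory spike.
    for name, param in state_dict.items():
        if isinstance(param, NF4Tensor):
            restored = param.get_original_weight().to(torch.bfloat16)
            state_dict[name] = restored.cpu() if offload_to_cpu else restored

# Registration (private PyTorch API, as discussed in the review below):
# model._register_state_dict_hook(reparametrize_as_bf16_state_dict_post_hook)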
Caveats and follow-ups needed
  • The NF4 dispatch ops should live in torchao and be thoroughly tested there. We'll upstream them as part of follow-up work.
  • Applying QLoRA to the output projection results in higher memory usage, due to an NF4Tensor high reserved-memory issue when working with very large tensors. So we are not enabling quantize_base for the model's output_proj yet, and explicitly pass quantize_base=False there.
  • Checkpoints are all saved in bf16. Need to investigate whether there is any use in saving the NF4 tensors themselves, perhaps for efficiency when loading a checkpoint back into torchtune for continued training.
  • Docstrings and tutorials need to be updated appropriately. Will do this in follow-up PRs.
  • FP32 compute dtype support.

Test plan

There's a bunch of different testing to do, so I've broken it down again.

Unittests
  • LoRALinear tests are updated for quantize_base=True cases.
  • LoRALinear save/load is explicitly tested in the QLoRA case.
  • LoRA builder tests are updated.
  • The model returned by lora_llama2 with quantize_base=True is unit-tested for save/load with expected dtypes, ensuring forward parity after save / load.
Checkpointing

Save checkpoint with weight merging: tune lora_finetune_single_device --config recipes/configs/alpaca_llama2_lora_finetune_single_device.yaml --override model_checkpoint=/home/rvarm1/local/dev/assets/llama2-bf16-latest seed=18 tokenizer_checkpoint=/home/rvarm1/local/dev/assets/tokenizer.model output_dir=/tmp/lora_debug device=cuda batch_size=2 enable_activation_checkpointing=True full_bf16=True quantize_base=True max_steps_per_epoch=1 &> out

Results in Model checkpoint of size 12852 MB saved to /tmp/lora_debug/model_0.ckpt.

Loading that checkpoint back in for continued training works: tune lora_finetune_single_device --config recipes/configs/alpaca_llama2_lora_finetune_single_device.yaml --override model_checkpoint=/home/rvarm1/local/dev/assets/llama2-bf16-latest seed=18 tokenizer_checkpoint=/home/rvarm1/local/dev/assets/tokenizer.model output_dir=/tmp/lora_debug device=cuda batch_size=2 enable_activation_checkpointing=True full_bf16=True quantize_base=True max_steps_per_epoch=1 model_checkpoint=/tmp/lora_debug/model_0.ckpt.

Memory analysis
  • Baseline LoRA that fine-tunes in ~16GB: tune lora_finetune_single_device --config recipes/configs/alpaca_llama2_lora_finetune_single_device.yaml --override model_checkpoint=/home/rvarm1/local/dev/assets/llama2-bf16-latest seed=18 tokenizer_checkpoint=/home/rvarm1/local/dev/assets/tokenizer.model output_dir=/tmp/lora_debug device=cuda batch_size=2 enable_activation_checkpointing=True full_bf16=True
  • Setting quantize_base to True in the config and running the same command, nvidia-smi reports ~9.8 GB peak GPU memory with bf16 as our compute dtype. This will probably increase on T4s since fp32 needs to be our compute dtype there. This is ~35% memory savings.
  • Update: memory numbers with the newly added memory logging:
    Memory Stats:
      GPU peak memory allocation: 5.77175808 GB
      GPU peak memory reserved: 9.294577664 GB
      GPU peak memory active: 5.77175808 GB
Correctness comparison to LoRA
  • Loss curves are on par (visually): LoRA bf16 [loss curve image] vs. QLoRA [loss curve image]

Eval results:

  • Trained a QLoRA model e2e: tune lora_finetune_single_device --config recipes/configs/alpaca_llama2_lora_finetune_single_device.yaml dtype=bf16 device=cuda with quantize_base=True. Ran eval with tune eval --config torchtune/_cli/eval_configs/default_eval_config.yaml model_checkpoint=/home/rvarm1/local/dev/assets/qlora_trained.pt, results: truthfulqa_mc2: {'acc,none': 0.491684249068877, 'acc_stderr,none': 0.014723538004649834, 'alias': 'truthfulqa_mc2'}
  • Followed similar process for LoRA without quantization, but ran LoRA distributed recipe to get results faster. tune --nproc_per_node 8 lora_finetune_distributed --config recipes/configs/alpaca_llama2_lora_finetune_distributed.yaml dtype=bf16 batch_size=32 enable_activation_checkpointing=True &> out_dist &. eval result: truthfulqa_mc2: {'acc,none': 0.46393356726865514, 'acc_stderr,none': 0.014448827209276354, 'alias': 'truthfulqa_mc2'}
  • These results are comparable to the lit-gpt QLoRA results at https://lightning.ai/pages/community/lora-insights/#toc12 (see "LoRA Hyperparameter Tuning Part 3: Changing Alpha"; their highest result is r=256, alpha=32 with 0.5055 on truthfulqa_mc2, with batch_size=128 and micro batch size = 1).
  • Running with lit-gpt's best QLoRA config (r=256, a=32, effective batch size of 128 with micro batch size=1) - results in eval results of truthfulqa_mc2: {'acc,none': 0.48860782692137417, 'acc_stderr,none': 0.014737964298503407, 'alias': 'truthfulqa_mc2'}, compared to their 0.5055.

@rohan-varma rohan-varma marked this pull request as draft March 9, 2024 21:12
@facebook-github-bot added the CLA Signed label Mar 9, 2024
netlify bot commented Mar 9, 2024

Deploy Preview for torchtune-preview ready!

🔨 Latest commit: f8cf66d
🔍 Latest deploy log: https://app.netlify.com/sites/torchtune-preview/deploys/65fddb9acafda000088e1d9f
😎 Deploy Preview: https://deploy-preview-478--torchtune-preview.netlify.app

hook_name: str


def _register_lora_hooks(
Contributor

Nice general APIs. If we scrap the LoRA merge state dict hooks do we still intend to keep the design here the same? (No strong preference from my end, mainly just curious)

Member Author

Yeah, just scrapping these for now per our discussion and we can re-introduce if needed.

torchtune/modules/peft/lora.py (resolved)

pytorch-bot bot commented Mar 14, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/478

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit f8cf66d with merge base ba7289b:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@rohan-varma rohan-varma changed the title [WIP] QLoRA QLoRA Mar 14, 2024
@rohan-varma rohan-varma marked this pull request as ready for review March 14, 2024 14:59
lora_attn_modules: ['q_proj', 'v_proj']
apply_lora_to_mlp: False
apply_lora_to_output: False
lora_attn_modules: ['q_proj', 'v_proj', 'k_proj', 'output_proj']
Contributor

Is this a conscious change to lora single device?

Member Author

Just for testing. For QLoRA we may want to offer a different config entirely

Contributor

Adding a bit more detail to explain why this is changed:

We currently only support NF4 linears inside of our LoRALinears. From @rohan-varma's research and mine, there is some ambiguity around whether QLoRA actually applies the quantization to all linear layers or just to the ones wrapped in LoRA.

The conclusion we reached was that the original QLoRA repo only applies NF4 quantization to base model weights inside of LoRA layers (so any other linear layers in the model are not quantized, minimal script to repro this is here). But for lit-gpt they actually replace every linear layer (ref from deep inside Fabric).

Since we are using lit-gpt as our comparison, we ideally want to set all linear layers to use NF4 to match their scheme. In our case that also means setting them all to use LoRA.

Member Author

+1 to what @ebsmothers said. Due to this, it might make sense to offer a QLoRA-specific config cc @kartikayk @ebsmothers. LMK your thoughts

Contributor

Due to this, it might make sense to offer a QLoRA specific config

This makes sense

Sorry, I didn't follow the comments above. So we're saying that the original paper does NOT quantize any base model weights other than those within the LoRA layers? Is there a reason for this? This is a bit surprising to me since it's different from my understanding, but I'm also not fully ramped up on QLoRA. Anyway, it seems like we're applying it to ALL linear layers, but doing so by replacing each layer with a LoRA layer? That seems a bit counterintuitive to me. Why does quantization need to be tied to the LoRA layer? I would imagine we could just loop through all of the modules and do a to_nf4 for the linear layers (caveat: I haven't looked at the code below)?
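For illustration, here's a minimal sketch of the loop-through-modules alternative described above, assuming to_nf4 can be applied directly to a bare nn.Linear weight (this is not what the PR implements):

import torch
import torch.nn as nn
from torchao.dtypes.nf4tensor import to_nf4

def quantize_all_linears_(model: nn.Module) -> None:
    # Walk every submodule and quantize the base weight of each bare nn.Linear to NF4.
    # Relies on the same dispatch ops discussed above to register NF4Tensors as Parameters.
    for module in model.modules():
        if isinstance(module, nn.Linear):
            nf4_weight = to_nf4(module.weight.to(torch.bfloat16))
            module.weight = nn.Parameter(nf4_weight, requires_grad=False)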

Comment on lines 17 to 19
_component_: torchtune.utils.FullModelTorchTuneCheckpointer
checkpoint_dir: /home/rvarm1/local/dev/assets/
checkpoint_files: [llama2-bf16-latest.pt]
Contributor

I'd recommend just testing this as a regular user since the consolidated.00.pth file is bf16. The MetaCheckpointer should work OOTB for you

torchtune/models/llama2/_lora_llama2_builders.py (outdated, resolved)
"FrozenNF4Linear",
]

def reparametrize_as_bf16_state_dict_post_hook(
Contributor

Why is this in __init__.py? Can we move this to a utils file or something?

Member Author

Yep - will take care of all code organization, lints, tests etc

Comment on lines 23 to 28
A state_dict hook that replaces nf4 tensors with their restored
bf16 weight and optionally offloads the restored weight to CPU.

This function is meant to be used with PyTorch's ``nn.Module._register_state_dict_hook``, i.e.
>>> m = MyModule()
>>> m._register_state_dict_hook(reparametrize_as_bf16_state_dict_post_hook)
Contributor

I don't think these hooks are that well understood. When I google them, the first link is to the PyTorch forums, which dissuades their use, calling them private APIs. I think this has changed since then, but nonetheless I don't think a lot of people understand how these work. This docstring should include very detailed information on when this hook gets called, what impact it has on the state dict, etc., especially since this is a public API available to users.

Member Author

Added some docs. But actually, we could maybe even make this private to users and just always register it where appropriate.

@@ -34,6 +34,7 @@ def lora_llama2_7b(
lora_rank: int = 8,
lora_alpha: float = 16,
max_batch_size: Optional[int] = None,
quantize_base: bool = False,
Contributor

Wondering if we also wanna add a top-level qlora builder (basically just a partial with quantize_base=True)

Member Author

Sure, shall we add this?

Contributor

My vote would be for yes. I don't see it in the latest version (unless I missed it), can we add it? Also thinking about how best to expose LoRA in our configs (see my other comment), do you think it makes sense to expose e.g. qlora_llama2_7b there? Tbh I'm a bit torn cause I think it somewhat obscures how easy it is to switch LoRA <-> QLoRA with just the quantize_base flag. Either way, not a blocker for this PR.
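For reference, a sketch of what the suggested top-level builder could look like; the name qlora_llama2_7b and the import path are assumptions, and it's essentially just a partial with quantize_base=True:

from functools import partial

# Assumed import path; lora_llama2_7b is the existing LoRA builder shown in this diff.
from torchtune.models.llama2 import lora_llama2_7b

# QLoRA builder: identical to the LoRA builder but with the base weights quantized to NF4.
qlora_llama2_7b = partial(lora_llama2_7b, quantize_base=True)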

Comment on lines 321 to 322
# TODO (rohan-varma): not quantizing output_proj as NF4Tensor quantization of large weight leads to large increase
# in reserved memory.
Contributor

Can you give more details on this? I didn't fully understand what was happening here

Member Author

Essentially, the current implementation in torchao allocates intermediates during the quantization process. For very large tensors, these intermediates result in a large memory spike, so we're working around this for now by not quantizing this layer.

@@ -96,6 +96,7 @@ def __init__(self, cfg: DictConfig) -> None:
self.total_training_steps = 0

self._resume_from_checkpoint = cfg.resume_from_checkpoint
self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
Contributor

Thanks for adding this! Now only one recipe remains.. (though it's the one that actually needs this feature the least)

Comment on lines +16 to +18
*args: Tuple[Any, ...],
offload_to_cpu: bool = True,
**kwargs: Dict[Any, Any],
Contributor

Dumb q but any particular reason offload_to_cpu is between *args and **kwargs? I'm not used to seeing this

Member Author

So I'm not sure of the best way to design this sort of API. This function is meant to be passed into _register_state_dict_hook, and it's invoked with some arguments that we don't use when taking the state_dict in nn.Module. Since these arguments are unused but still provided to the function, I need to accept them via *args and **kwargs. However, I want to offer offload_to_cpu as a keyword arg with a default, so I can't put it before *args. So this is what I came up with; not sure if there's something better.


from torchao.dtypes.nf4tensor import linear_nf4
from torchtune.modules.low_precision import ( # noqa: F401
_register_nf4_dispatch_ops,
Contributor

Trying to understand how this works: so basically we import _register_nf4_dispatch_ops at the first import of lora.py, which will then ensure any NF4Tensor used henceforth in the same run will have all those overrides in the ops table? Assuming that is the correct understanding, is there a way to do this with less of a hammer? (I don't think it matters for this particular PR, just curious if we try to extend things down the line)

Member Author

Yep, when this is imported these ops get registered on the NF4Tensor.

is there a way to do this with less of a hammer

Hmm, not sure what would be less of a hammer. I thought this was already pretty light, in that we just import something and the registration happens automatically; we don't have to call any specific register_ops API or anything like that. Do you have an example of what you'd like this to be?

nf4_tensor = args[0][0]
sub_tensor = args[0][1]
assert sub_tensor.dtype == torch.bfloat16
return to_nf4(nf4_tensor.get_original_weight().sub_(sub_tensor))
Contributor

I believe you had a comment yesterday about whether we were doing this in the best way, but I didn't fully understand it. What was that about?

Member Author

So I don't think we need add and sub anymore, at least for now, since when we're merging the lora_state_dict, the state_dict is already in full bf16.



@nf4_tensor_impl([torch.ops.aten.add_.Tensor])
def add_bf16_tensor(func, *args, **kwargs):
Contributor

This will apply for both weight merge and adding bias, right?

Member Author

Probably not needed anymore - see other comment


# Modules from CausalSelfAttention that LoRA can be applied to
LORA_ATTN_MODULES = Literal["q_proj", "k_proj", "v_proj", "output_proj"]
Contributor

remove, you're already importing this

Member Author

Oops, might've been a bad merge.

@ebsmothers (Contributor) left a comment

This is a really impressive PR. I left a bunch of comments but no huge concerns from my side. As always, make sure to clean up any debug code before landing. Otherwise looks good!

@rohan-varma rohan-varma merged commit 412a4ec into main Mar 22, 2024
22 checks passed
@joecummings joecummings deleted the qlora branch April 11, 2024 15:40
@rohan-varma rohan-varma mentioned this pull request Apr 23, 2024