Multi GPU can not save adapter_model.bin #483

Open
ricksun2023 opened this issue May 27, 2023 · 6 comments

Comments

@ricksun2023

Hi, I'm trying to run finetune.py on 6 GPUs:

WORLD_SIZE=6 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 torchrun --nproc_per_node=6 --master_port=1234 finetune.py \
    --base_model='./llama-7b-hf' \
    --num_epochs=3 \
    --cutoff_len=512 \
    --group_by_length \
    --lora_target_modules='[q_proj,k_proj,v_proj,o_proj]' \
    --lora_r=16 \
    --micro_batch_size=64 \
    --batch_size=384

I also commented out L263~L269 in finetune.py, based on: #446 (comment)

Then I got the following error:
{'loss': 0.8493, 'learning_rate': 1.8384401114206126e-05, 'epoch': 2.88}
{'loss': 0.8812, 'learning_rate': 1.0027855153203342e-05, 'epoch': 2.94}
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 459/459 [2:04:05<00:00, 12.51s/it]
The intermediate checkpoints of PEFT may not be saved correctly, using TrainerCallback to save adapter_model.bin in corresponding folders, here are some examples huggingface/peft#96
(printed once by each of the six processes; the six tracebacks that follow are identical and were interleaved in the console output, so a single one is shown)

Traceback (most recent call last):
  File "/home/rick/llm/llm-training-data/github/alpaca-lora/finetune.py", line 283, in <module>
    fire.Fire(train)
  File "/home/rick/anaconda3/envs/alpaca-lora/lib/python3.10/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/home/rick/anaconda3/envs/alpaca-lora/lib/python3.10/site-packages/fire/core.py", line 475, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/home/rick/anaconda3/envs/alpaca-lora/lib/python3.10/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/home/rick/llm/llm-training-data/github/alpaca-lora/finetune.py", line 273, in train
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
  File "/home/rick/anaconda3/envs/alpaca-lora/lib/python3.10/site-packages/transformers/trainer.py", line 1696, in train
    return inner_training_loop(
  File "/home/rick/anaconda3/envs/alpaca-lora/lib/python3.10/site-packages/transformers/trainer.py", line 2094, in _inner_training_loop
    self._load_best_model()
  File "/home/rick/anaconda3/envs/alpaca-lora/lib/python3.10/site-packages/transformers/trainer.py", line 2291, in _load_best_model
    self._issue_warnings_after_load(load_result)
UnboundLocalError: local variable 'load_result' referenced before assignment

@ricksun2023
Author

  • transformers 4.30.0.dev0
  • peft 0.4.0.dev0
  • bitsandbytes 0.39.0

@maekawataiki

This is due to a change in transformers. huggingface/transformers@357f281

You may downgrade to transformers <= 4.29.2 or follow the steps in huggingface/peft#96. (If you are trying 4-bit QLoRA, you should follow the step below instead, since 4-bit support is not in 4.29.2: huggingface/transformers@9d73b92)

I've added the following snippet to finetune.py and it worked.

import os

from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.trainer_callback import TrainerCallback

class SavePeftModelCallback(TrainerCallback):
    def on_save(self, args, state, control, **kwargs):
        # Folder the Trainer just wrote, e.g. "<output_dir>/checkpoint-400"
        checkpoint_folder = os.path.join(
            args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        )

        # Save only the LoRA adapter (adapter_model.bin + adapter_config.json)
        peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
        kwargs["model"].save_pretrained(peft_model_path)

        # Remove the full pytorch_model.bin the Trainer wrote; the adapter is enough
        pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
        if os.path.exists(pytorch_model_path):
            os.remove(pytorch_model_path)
        return control

....

    trainer = transformers.Trainer(
        ...
        callbacks=[SavePeftModelCallback]
    )

With this code, the checkpoint directory looks like below. It seems the implementation in the main branch of transformers currently only accepts adapter_model.bin, not pytorch_model.bin, for LoRA (as of 2023/05/28).

checkpoint-400/
checkpoint-400/adapter_model/
checkpoint-400/adapter_model/adapter_model.bin
checkpoint-400/adapter_model/adapter_config.json
checkpoint-400/trainer_state.json
checkpoint-400/scaler.pt
checkpoint-400/scheduler.pt
checkpoint-400/optimizer.pt
checkpoint-400/rng_state.pth
checkpoint-400/training_args.bin
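The path handling in the callback can be rehearsed without a GPU or a model. The sketch below is stdlib-only: `PREFIX_CHECKPOINT_DIR` is hard-coded to the `"checkpoint"` prefix transformers uses, and `fake_save_pretrained` is a hypothetical stand-in for `kwargs["model"].save_pretrained`, so it is a sketch of the file layout, not the real training flow.

```python
import os
import tempfile

PREFIX_CHECKPOINT_DIR = "checkpoint"  # value from transformers.trainer_utils

def on_save_files(output_dir, global_step, save_adapter):
    # Same path arithmetic as SavePeftModelCallback.on_save
    checkpoint_folder = os.path.join(
        output_dir, f"{PREFIX_CHECKPOINT_DIR}-{global_step}"
    )
    peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
    save_adapter(peft_model_path)  # stands in for model.save_pretrained
    pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
    if os.path.exists(pytorch_model_path):
        os.remove(pytorch_model_path)
    return checkpoint_folder

def fake_save_pretrained(path):
    # Hypothetical stand-in: write empty adapter files where PEFT would
    os.makedirs(path, exist_ok=True)
    for name in ("adapter_model.bin", "adapter_config.json"):
        open(os.path.join(path, name), "wb").close()

with tempfile.TemporaryDirectory() as out:
    # Simulate what the Trainer leaves behind before on_save runs
    ckpt = os.path.join(out, "checkpoint-400")
    os.makedirs(ckpt)
    open(os.path.join(ckpt, "pytorch_model.bin"), "wb").close()

    folder = on_save_files(out, 400, fake_save_pretrained)
    assert not os.path.exists(os.path.join(folder, "pytorch_model.bin"))
    assert os.path.exists(os.path.join(folder, "adapter_model", "adapter_model.bin"))
```

Running it reproduces the directory shape shown above: an `adapter_model/` subfolder inside the checkpoint, and no `pytorch_model.bin`.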

@KKcorps

KKcorps commented May 28, 2023

With this I always find adapter_model.bin to be 443 bytes

@ricksun2023
Author

ricksun2023 commented May 29, 2023

> This is due to a change in transformers. huggingface/transformers@357f281 […]

@maekawataiki Thank you very much for your help. It works for me now!

From the main branch of https://github.com/tloen/alpaca-lora, I made the following changes:

  1. Downgraded to transformers == 4.29.2
  2. Added a revised snippet, based on @maekawataiki's version, to finetune.py
  3. Commented out L263~L269 in finetune.py
  4. Ran: WORLD_SIZE=6 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 torchrun --nproc_per_node=6 --master_port=1234 finetune.py --other-parameters

Finally I got the following files, and they work well.
-rw-rw-r-- 1 rick rick 393 May 29 03:01 adapter_config.json
-rw-rw-r-- 1 rick rick 67201357 May 29 03:01 adapter_model.bin
drwxrwxr-x 3 rick rick 4096 May 29 02:26 checkpoint-200/
drwxrwxr-x 3 rick rick 4096 May 29 02:43 checkpoint-400/
drwxrwxr-x 3 rick rick 4096 May 29 03:00 checkpoint-600/

BTW, to fix the issue "FileNotFoundError: [Errno 2] No such file or directory: 'xxx/checkpoint-200/pytorch_model.bin'" in a multi-GPU run, I revised the snippet a little bit:

    try:
        if os.path.exists(pytorch_model_path):
            os.remove(pytorch_model_path)
    except Exception as e:
        # Another rank may have removed the file first; log and continue
        print(f"remove {pytorch_model_path} failed: {e}")
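The FileNotFoundError happens because every rank runs the same callback and races to delete the same file. An alternative to catching the exception is to let only one process do the cleanup, keying off the LOCAL_RANK environment variable that torchrun sets. The helper name below is hypothetical, a sketch of the idea rather than code from finetune.py:

```python
import os

def remove_on_rank_zero(path):
    """Delete `path` only on the process torchrun launched as local rank 0."""
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    if local_rank != 0:
        return False  # other ranks skip the delete entirely, so no race
    if os.path.exists(path):
        os.remove(path)
        return True
    return False

# A non-zero rank never touches the file
os.environ["LOCAL_RANK"] = "1"
assert remove_on_rank_zero("/tmp/does-not-matter.bin") is False
```

Either approach works; the try/except above is simpler, while the rank guard also avoids redundant `save_pretrained` calls if applied to the whole callback body.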

Thanks

@diaojunxian

diaojunxian commented May 29, 2023

> With this I always find adapter_model.bin to be 443 bytes

Me too, did you solve it? @KKcorps

@maekawataiki

The 443-byte adapter_model.bin is addressed in #446 and #334.
It seems weights saved through model.save_pretrained() are not saved correctly, probably due to the series of commits around huggingface/peft@c21afbe in PEFT. (Weights are saved correctly with huggingface/peft@644d68e.)
Interestingly, the weights saved through the default checkpoint (pytorch_model.bin) were the correct LoRA weights.
As mentioned by @ricksun2023, commenting out the following lines will solve the problem.

    old_state_dict = model.state_dict
    model.state_dict = (
        lambda self, *_, **__: get_peft_model_state_dict(
            self, old_state_dict()
        )
    ).__get__(model, type(model))

I got the following result:

drwxr-xr-x 3 root root    6144 May 29 17:10 checkpoint-1400
drwxr-xr-x 3 root root    6144 May 29 16:54 checkpoint-1200
drwxr-xr-x 3 root root    6144 May 29 16:38 checkpoint-1000
-rw-r--r-- 1 root root 8412081 May 29 17:18 adapter_model.bin
-rw-r--r-- 1 root root     339 May 29 17:18 adapter_config.json

Inside the checkpoint folder checkpoint-1000/adapter_model:

-rw-r--r-- 1 root root     339 May 29 16:38 adapter_config.json
-rw-r--r-- 1 root root 8412081 May 29 16:38 adapter_model.bin
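A healthy adapter is megabytes (8412081 bytes above), while the broken save is only ~443 bytes, essentially an empty state dict. A quick guard against silently shipping an empty adapter is a size check before trusting a checkpoint; the helper name and the 1 KiB threshold below are heuristics of mine, not part of finetune.py:

```python
import os
import tempfile

def adapter_looks_empty(path, min_bytes=1024):
    # A real LoRA adapter_model.bin is megabytes; a few hundred bytes means
    # only an empty state dict was pickled, i.e. the weights were lost.
    return os.path.getsize(path) < min_bytes

# Mimic the broken 443-byte save and confirm the check flags it
with tempfile.NamedTemporaryFile(delete=False) as f:
    f.write(b"\x00" * 443)
    broken = f.name

assert adapter_looks_empty(broken)
os.remove(broken)
```

This doesn't fix the underlying PEFT issue, but it catches the symptom immediately instead of at inference time.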
