-
Notifications
You must be signed in to change notification settings - Fork 834
/
Copy pathmerge_lora.py
75 lines (63 loc) · 1.96 KB
/
merge_lora.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 Statistics and Machine Learning Research Group at HKUST. All rights reserved.
"""
Merge base model and lora model into a full model.
"""
import sys
import os


def remove_script_dir_from_sys_path(argv0=None, search_path=None):
    """Remove this script's own directory from the module search path.

    Python prepends the running script's directory to ``sys.path``. Dropping
    it means modules that sit next to this script cannot shadow the packages
    imported below (``transformers``, ``lmflow``).

    The directory is looked up both as ``abspath(dirname(argv0))`` and as its
    ``realpath``, because the interpreter may record a symlink-resolved path.
    Only the first matching entry is removed (as ``list.remove`` does). If
    neither form is present, e.g. when launched via ``python -m``, the search
    path is left untouched instead of raising ``ValueError``.

    Args:
        argv0: Path of the running script; defaults to ``sys.argv[0]``.
        search_path: List edited in place; defaults to ``sys.path``.
    """
    if argv0 is None:
        argv0 = sys.argv[0]
    if search_path is None:
        search_path = sys.path
    script_dir = os.path.abspath(os.path.dirname(argv0))
    for candidate in (script_dir, os.path.realpath(script_dir)):
        if candidate in search_path:
            search_path.remove(candidate)
            return


# Must run before the third-party / lmflow imports below.
remove_script_dir_from_sys_path()

from dataclasses import dataclass, field
from transformers import HfArgumentParser
from typing import Optional
from lmflow.args import (
ModelArguments,
AutoArguments,
)
from lmflow.models.auto_model import AutoModel
@dataclass
class MergeLoraArguments:
    """Options controlling how the LoRA adapter is merged and where it is saved.

    Parsed by ``HfArgumentParser`` together with ``ModelArguments``. Each
    field's ``metadata["help"]`` string is the description shown for the
    matching command-line flag.
    """

    # Device to run the merge on. main() raises NotImplementedError for 'gpu'.
    device: str = field(
        default='cpu',
        metadata=dict(help="device to merge model on"),
    )

    # DeepSpeed config file path, forwarded to AutoModel.get_model by main().
    ds_config: str = field(
        default='configs/ds_config_eval.json',
        metadata=dict(help="deepspeed config file path"),
    )

    # Destination passed to model.save(); None unless supplied by the user.
    output_model_path: Optional[str] = field(
        default=None,
        metadata=dict(help="output merged full model path"),
    )

    # Not read by main(); presumably accepted so a DeepSpeed launcher's
    # --local_rank flag parses cleanly (verify).
    local_rank: Optional[int] = field(
        default=-1,
        metadata=dict(help="local rank for deepspeed"),
    )
def main():
    """Merge a LoRA adapter into its base model and save the full model.

    Arguments are parsed into ``ModelArguments`` and ``MergeLoraArguments``.
    A single ``.json`` path given as the only CLI argument is read as a config
    file. Otherwise the regular command-line flags are parsed.

    Raises:
        NotImplementedError: if ``device`` is ``'gpu'`` (only CPU merging is
            supported).
        ValueError: if ``output_model_path`` is missing or empty. This is
            checked before the model is loaded, so the expensive load/merge
            is not wasted.
    """
    parser = HfArgumentParser((ModelArguments, MergeLoraArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # A lone .json argument supplies every field from a config file.
        model_args, merge_lora_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1])
        )
    else:
        model_args, merge_lora_args = parser.parse_args_into_dataclasses()

    if merge_lora_args.device == 'gpu':
        raise NotImplementedError('Merging LoRA weight using GPU not supported yet. Please use cpu.')

    # Fail fast: otherwise model.save() below would receive None (presumably
    # failing there) only after the whole model has been loaded and merged.
    if not merge_lora_args.output_model_path:
        raise ValueError(
            'output_model_path must be specified to save the merged model.'
        )

    # Force LoRA mode so the adapter is part of the model that
    # merge_lora_weights() operates on (assumes get_model honours this flag).
    model_args.use_lora = True
    model = AutoModel.get_model(
        model_args,
        tune_strategy='none',
        device=merge_lora_args.device,
        ds_config=merge_lora_args.ds_config,
    )
    model.activate_model_for_inference()
    model.merge_lora_weights()
    # save_full_model=True asks save() for the merged full weights rather
    # than just the adapter (presumed from the flag name - confirm in lmflow).
    model.save(merge_lora_args.output_model_path, save_full_model=True)


if __name__ == '__main__':
    main()