from dataclasses import dataclass, field
import os
import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    HfArgumentParser,
    TrainingArguments,
    BitsAndBytesConfig,
    AutoModelForCausalLM,
)
from peft import LoraConfig
from trl import (
    setup_chat_format,
    SFTTrainer,
)

tqdm.pandas()


@dataclass
class ScriptArguments:
    dataset_path: str = field(
        default=None,
        metadata={
            "help": "Path to the dataset, should be /opt/ml/input/data/train_dataset.json"
        },
    )
    model_id: str = field(
        default=None, metadata={"help": "Model ID to use for SFT training"}
    )
    max_seq_length: int = field(
        default=512, metadata={"help": "The maximum sequence length for SFT Trainer"}
    )
    use_qlora: bool = field(default=False, metadata={"help": "Whether to use QLoRA"})
    merge_adapters: bool = field(
        metadata={"help": "Whether to merge the LoRA adapter weights into the base model."},
        default=False,
    )
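
# HfArgumentParser (used below) exposes these dataclass fields as CLI flags,
# e.g. --dataset_path /opt/ml/input/data/train_dataset.json --use_qlora True,
# alongside all standard transformers TrainingArguments.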
if __name__ == "__main__":
    parser = HfArgumentParser((ScriptArguments, TrainingArguments))
    script_args, training_args = parser.parse_args_into_dataclasses()
    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
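    # use_reentrant=False selects the non-reentrant gradient checkpointing
    # implementation; the reentrant variant can error out when most parameters
    # are frozen, as is the case with LoRA/QLoRA fine-tuning.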

    ################
    # Dataset
    ################
    dataset = load_dataset(
        "json",
        data_files=script_args.dataset_path,
        split="train",
    )
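    # Assumption: the JSON file holds one training record per entry in a format
    # SFTTrainer understands (e.g. a pre-formatted "text" field or chat
    # "messages"); special tokens are expected to already be part of the
    # template (see dataset_kwargs below).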

    ################
    # Model & Tokenizer
    ################
    torch_dtype = torch.bfloat16 if training_args.bf16 else torch.float32
    if script_args.use_qlora:
        print("Using QLoRA")
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch_dtype,
        )
    else:
        quantization_config = None
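    # With --use_qlora True the base model below is loaded as 4-bit NF4 weights
    # with double quantization (the QLoRA recipe) and computation runs in
    # torch_dtype; otherwise the full-precision weights are loaded directly.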

    model = AutoModelForCausalLM.from_pretrained(
        script_args.model_id,
        device_map="auto",
        attn_implementation="flash_attention_2",
        torch_dtype=torch_dtype,
        quantization_config=quantization_config,
    )
    tokenizer = AutoTokenizer.from_pretrained(script_args.model_id, use_fast=True)
    tokenizer.pad_token = tokenizer.eos_token
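    # Many causal LMs ship without a dedicated pad token, so padding falls back
    # to the EOS token; with packing enabled below, padding is rarely needed.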

    ################
    # PEFT
    ################
    # LoRA config based on QLoRA paper & Sebastian Raschka experiment
    peft_config = LoraConfig(
        lora_alpha=8,
        lora_dropout=0.05,
        r=16,
        bias="none",
        target_modules="all-linear",
        task_type="CAUSAL_LM",
    )
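    # target_modules="all-linear" attaches LoRA adapters to every linear layer
    # of the base model (excluding the output head); with r=16 and lora_alpha=8
    # the adapter updates are scaled by alpha/r = 0.5.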

    ################
    # Training
    ################
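    # packing=True concatenates multiple short examples into fixed-length
    # sequences of max_seq_length tokens, which improves GPU utilization when
    # samples are short.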
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        peft_config=peft_config,
        max_seq_length=script_args.max_seq_length,
        tokenizer=tokenizer,
        packing=True,
        dataset_kwargs={
            "add_special_tokens": False,  # We template with special tokens
            "append_concat_token": False,  # No need to add additional separator token
        },
    )
    trainer.train()

    ##########################
    # SAVE MODEL FOR SAGEMAKER
    ##########################
    sagemaker_save_dir = "/opt/ml/model"
    trainer.tokenizer.save_pretrained(sagemaker_save_dir)
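    # Everything written to /opt/ml/model is packaged by SageMaker into
    # model.tar.gz and uploaded to S3 when the training job finishes.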
    if script_args.merge_adapters:
        # merge adapter weights into the base model before saving
        # first persist the trained adapter (and tokenizer) to output_dir
        trainer.model.save_pretrained(training_args.output_dir)
        trainer.tokenizer.save_pretrained(training_args.output_dir)
        # clear memory
        del model
        del trainer
        torch.cuda.empty_cache()
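
        # reload the trained adapter on top of the fp16 base model, merge the
        # weights, and write the merged model to the SageMaker model directory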
        from peft import AutoPeftModelForCausalLM

        # list files in output_dir
        print(os.listdir(training_args.output_dir))
        # load PEFT model in fp16
        model = AutoPeftModelForCausalLM.from_pretrained(
            training_args.output_dir,
            low_cpu_mem_usage=True,
            torch_dtype=torch.float16,
        )
        # Merge LoRA and base model and save
        model = model.merge_and_unload()
        model.save_pretrained(
            sagemaker_save_dir, safe_serialization=True, max_shard_size="2GB"
        )
    else:
        trainer.model.save_pretrained(sagemaker_save_dir, safe_serialization=True)