
Commit

update peft config.
shibing624 committed Apr 21, 2023
1 parent 20244f3 commit 633e376
Showing 3 changed files with 3 additions and 17 deletions.
textgen/chatglm/chatglm_model.py: 2 changes (1 addition, 1 deletion)
@@ -234,7 +234,7 @@ def train_model(
             if os.path.exists(checkpoint_name):
                 logger.info(f"Restarting from {checkpoint_name}")
                 adapters_weights = torch.load(checkpoint_name)
-                self.model = set_peft_model_state_dict(self.model, adapters_weights)
+                set_peft_model_state_dict(self.model, adapters_weights)
             else:
                 logger.warning(f"Checkpoint {checkpoint_name} not found")

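The same one-line fix recurs in textgen/llama/llama_model.py below: set_peft_model_state_dict loads the adapter weights into the model in place, and in recent peft releases it returns a load-result object rather than the model, so assigning its return value to self.model would replace the model with that object. A minimal sketch of the corrected resume pattern, assuming `model` is an existing peft-wrapped (LoRA) model and using an illustrative checkpoint path:

    import torch
    from peft import set_peft_model_state_dict

    # `model` is assumed to be a PeftModel created elsewhere; the path is illustrative.
    adapters_weights = torch.load("outputs/checkpoint-500/adapter_model.bin")
    # Loads the LoRA weights in place; do not assign the return value to the model.
    set_peft_model_state_dict(model, adapters_weights)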
textgen/config/model_args.py: 4 changes (1 addition, 3 deletions)
@@ -373,7 +373,6 @@ class ChatGlmArgs(ModelArgs):
     lora_dropout = 0.05
     lora_target_modules = ["query_key_value"]
     lora_bias = "none"
-    only_lora_state_dict: bool = False
     num_train_epochs = 1
     max_steps = -1
     per_device_train_batch_size = 2
@@ -425,11 +424,10 @@ class LlamaArgs(ModelArgs):
     use_lora: bool = True
     lora_bin_name: str = field(default="adapter_model.bin")
     lora_r: int = 8
-    lora_alpha = 16
+    lora_alpha = 32
     lora_dropout = 0.05
     lora_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
     lora_bias = "none"
-    only_lora_state_dict: bool = True
     num_train_epochs = 3
     max_steps = -1
     per_device_train_batch_size = 2
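For LlamaArgs, lora_alpha rises from 16 to 32 while lora_r stays 8. LoRA scales each adapter update by lora_alpha / r, so the effective scaling doubles from 2 to 4, giving the adapters more weight relative to the frozen base model. A sketch of the peft LoraConfig these defaults presumably map to (the task_type value is an assumption, not shown in this diff):

    from peft import LoraConfig

    lora_config = LoraConfig(
        r=8,
        lora_alpha=32,  # updates are scaled by lora_alpha / r = 32 / 8 = 4
        lora_dropout=0.05,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        bias="none",
        task_type="CAUSAL_LM",  # assumed; not part of this commit
    )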
textgen/llama/llama_model.py: 14 changes (1 addition, 13 deletions)
@@ -134,12 +134,8 @@ def __init__(
         if self.args.use_lora:
             self.load_lora()

-        # unwind broken decapoda-research config
         self.tokenizer.padding_side = "left"
         self.tokenizer.pad_token_id = 0  # unk. we want this to be different from the eos token
-        self.model.config.pad_token_id = 0  # unk
-        self.model.config.bos_token_id = 1
-        self.model.config.eos_token_id = 2

     def train_model(
         self,
@@ -228,7 +224,7 @@ def train_model(
             if os.path.exists(checkpoint_name):
                 logger.info(f"Restarting from {checkpoint_name}")
                 adapters_weights = torch.load(checkpoint_name)
-                self.model = set_peft_model_state_dict(self.model, adapters_weights)
+                set_peft_model_state_dict(self.model, adapters_weights)
             else:
                 logger.warning(f"Checkpoint {checkpoint_name} not found")

@@ -294,14 +290,6 @@ def train_model(
             data_collator=data_collator,
         )

-        if self.args.only_lora_state_dict:
-            old_state_dict = self.model.state_dict
-            self.model.state_dict = (
-                lambda self, *_, **__: get_peft_model_state_dict(
-                    self, old_state_dict()
-                )
-            ).__get__(self.model, type(self.model))
-
         if self.args.enable_torch_compile:
             if torch.__version__ >= "2" and sys.platform != "win32":
                 self.model = torch.compile(self.model)
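The deleted block is the alpaca-lora trick that monkey-patched model.state_dict so that Trainer checkpoints contained only the LoRA weights returned by get_peft_model_state_dict. With the only_lora_state_dict flag removed from both argument classes, the commit presumably relies on peft's own saving path instead; a minimal sketch, assuming `model` is a PeftModel:

    # save_pretrained writes only the adapter (adapter_model.bin plus
    # adapter_config.json), so the state_dict patch is unnecessary.
    model.save_pretrained("outputs")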
