Skip to content

Commit

Permalink
pass the chat_template value correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
chiragjn committed Jun 12, 2024
1 parent 063d828 commit 872152d
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
6 changes: 4 additions & 2 deletions data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,11 @@ def dataset_uri_to_axolotl_datasources(
datasources = []
if os.path.isdir(uri):
for filepath in find_all_jsonl_files(uri):
datasources.append(_make_dataset_file_source(path=filepath, dataset_type=dataset_type))
datasources.append(
_make_dataset_file_source(path=filepath, dataset_type=dataset_type, chat_template=chat_template)
)
else:
datasources = [_make_dataset_file_source(path=uri, dataset_type=dataset_type)]
datasources = [_make_dataset_file_source(path=uri, dataset_type=dataset_type, chat_template=chat_template)]
return datasources
else:
raise ValueError("Unsupported data uri or path does not exist: {uri}")
Expand Down
9 changes: 5 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,10 @@ def make_axolotl_config(config_base, kwargs, timestamp=None):
set_cfg_option_if_auto(cfg, "unsloth_lora_qkv", use_unsloth)
set_cfg_option_if_auto(cfg, "unsloth_lora_o", use_unsloth)

model_type = getattr(model_hf_config, "model_type", None)
chat_template = MODEL_TYPE_TO_CHAT_TEMPLATE.get(model_type, "chatml")
set_cfg_option_if_auto(cfg, "chat_template", chat_template)
if cfg.chat_template == "auto":
model_type = getattr(model_hf_config, "model_type", None)
chat_template = MODEL_TYPE_TO_CHAT_TEMPLATE.get(model_type, "chatml")
set_cfg_option_if_auto(cfg, "chat_template", chat_template)

if cfg.datasets == "auto":
if not cfg.train_data_uri:
Expand All @@ -187,7 +188,7 @@ def make_axolotl_config(config_base, kwargs, timestamp=None):
uri=cfg.train_data_uri,
download_dir=cfg.data_dir,
dataset_type=cfg.dataset_type,
chat_template=chat_template,
chat_template=cfg.chat_template,
)
if cfg.test_datasets == "auto":
if cfg.val_data_uri and str(cfg.val_data_uri).lower() != "na":
Expand Down

0 comments on commit 872152d

Please sign in to comment.