diff --git a/Dockerfile b/Dockerfile index 5768865..af73a32 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ -# https://hub.docker.com/layers/winglian/axolotl/main-20240423-py3.11-cu121-2.2.1/images/sha256-fc2b9d2b1e46d6b7c47c28a65d2c1d2c3ae4f032fafef27ffaf6ec63bf442f44?context=explore -FROM --platform=linux/amd64 winglian/axolotl@sha256:e0b5b8a94934aaf183932c66ab3ce3ad822e91e19341ade8dbf9eccd9339d799 +# https://hub.docker.com/layers/winglian/axolotl/main-20240603-py3.11-cu121-2.3.0/images/sha256-e4b898a0f700eb86f9e802bb85c1ec6c509b2dec65d941ad43405fe323865017?context=explore +FROM --platform=linux/amd64 winglian/axolotl@sha256:a66d1469cdad472779f6419ea67d0fbb2cce984244aa86f40c99abaa4a21b3db USER root COPY requirements.txt /tmp/ RUN pip install -U pip wheel setuptools && \ @@ -9,7 +9,7 @@ RUN mkdir -p /packages && \ cd /packages && \ git clone https://github.com/truefoundry/axolotl && \ cd axolotl/ && \ - git checkout 4e8264e937571c53b9dc75345a14d4b9b9d68c4f + git checkout dffcb7adfb42dd3305fcabb0de106d5e2454315e RUN cd /packages/axolotl/ && \ MAX_JOBS=1 NVCC_APPEND_FLAGS="--threads 1" pip install -U --no-build-isolation -e .[flash-attn,mamba-ssm,fused-dense-lib] && \ pip install --no-cache-dir -U -r /tmp/requirements.txt && \ diff --git a/Dockerfile-notebook b/Dockerfile-notebook index a5c18db..f95982f 100644 --- a/Dockerfile-notebook +++ b/Dockerfile-notebook @@ -1,4 +1,4 @@ -FROM truefoundrycloud/jupyter:0.2.17-sudo +FROM truefoundrycloud/jupyter:0.2.19-sudo ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX" ENV DEBIAN_FRONTEND=noninteractive USER root @@ -21,7 +21,7 @@ USER jovyan RUN cd /packages && \ git clone https://github.com/truefoundry/axolotl && \ cd axolotl/ && \ - git checkout 4e8264e937571c53b9dc75345a14d4b9b9d68c4f + git checkout dffcb7adfb42dd3305fcabb0de106d5e2454315e RUN cd /packages/axolotl/ && \ MAX_JOBS=1 NVCC_APPEND_FLAGS="--threads 1" pip install -U --no-build-isolation -e .[flash-attn,mamba-ssm,fused-dense-lib] && \ pip install --no-cache-dir -U -r /tmp/llm-finetune/notebook-requirements.txt diff --git a/README.md b/README.md index bfa89c3..8c197ef 100644 --- a/README.md +++ b/README.md @@ -1,301 +1,22 @@ -Axolotl config options - -
- Click to expand all axolotl options - -Just dumping here, because some options are not documented - -``` -cfg.adam_beta1 -cfg.adam_beta2 -cfg.adam_epsilon -cfg.adapter -cfg.auto_resume_from_checkpoints -cfg.axolotl_config_path -cfg.base_model -cfg.base_model_config -cfg.batch_size -cfg.bench_dataset -cfg.bf16 -cfg.bfloat16 -cfg.bnb_config_kwargs -cfg.chat_template -cfg.conversation -cfg.cosine_min_lr_ratio -cfg.dataloader_drop_last -cfg.dataloader_num_workers -cfg.dataloader_pin_memory -cfg.dataloader_prefetch_factor -cfg.dataset_keep_in_memory -cfg.dataset_prepared_path -cfg.dataset_processes -cfg.dataset_shard_idx -cfg.dataset_shard_num -cfg.datasets -cfg.ddp -cfg.ddp_broadcast_buffers -cfg.ddp_bucket_cap_mb -cfg.ddp_timeout -cfg.debug -cfg.deepspeed -cfg.default_system_message -cfg.device -cfg.device_map -cfg.do_bench_eval -cfg.dpo_beta -cfg.dpo_label_smoothing -cfg.eager_attention -cfg.early_stopping_patience -cfg.eval_batch_size -cfg.eval_sample_packing -cfg.eval_steps -cfg.eval_table_max_new_tokens -cfg.eval_table_size -cfg.evals_per_epoch -cfg.evaluation_strategy -cfg.field_input -cfg.field_instruction -cfg.field_output -cfg.field_system -cfg.flash_attention -cfg.flash_attn_cross_entropy -cfg.flash_attn_fuse_mlp -cfg.flash_attn_fuse_qkv -cfg.flash_attn_rms_norm -cfg.flash_optimum -cfg.float16 -cfg.format -cfg.fp16 -cfg.fsdp -cfg.fsdp_config -cfg.gptq -cfg.gptq_disable_exllama -cfg.gpu_memory_limit -cfg.gradient_accumulation_steps -cfg.gradient_checkpointing -cfg.gradient_checkpointing_kwargs -cfg.greater_is_better -cfg.group_by_length -cfg.hf_use_auth_token -cfg.hub_model_id -cfg.hub_strategy -cfg.is_falcon_derived_model -cfg.is_file -cfg.is_llama_derived_model -cfg.is_mistral_derived_model -cfg.is_preprocess -cfg.is_qwen_derived_model -cfg.learning_rate -cfg.load_best_model_at_end -cfg.load_in_4bit -cfg.load_in_8bit -cfg.local_rank -cfg.logging_steps -cfg.lora_alpha -cfg.lora_dropout -cfg.lora_fan_in_fan_out -cfg.lora_model_dir -cfg.lora_modules_to_save -cfg.lora_on_cpu -cfg.lora_r -cfg.lora_target_linear -cfg.lora_target_modules -cfg.loss_watchdog_patience -cfg.loss_watchdog_threshold -cfg.lr_quadratic_warmup -cfg.lr_scheduler -cfg.lr_scheduler_kwargs -cfg.max_grad_norm -cfg.max_memory -cfg.max_packed_sequence_len -cfg.max_steps -cfg.merge_lora -cfg.metric_for_best_model -cfg.micro_batch_size -cfg.mlflow_experiment_name -cfg.model_config -cfg.model_config_type -cfg.model_kwargs -cfg.model_revision -cfg.model_type -cfg.neftune_noise_alpha -cfg.no_input_format -cfg.noisy_embedding_alpha -cfg.num_epochs -cfg.optimizer -cfg.output_dir -cfg.pad_to_sequence_len -cfg.path -cfg.peft -cfg.peft_adapter -cfg.peft_layers_to_transform -cfg.precompute_ref_log_probs -cfg.pretraining_dataset -cfg.push_dataset_to_hub -cfg.push_to_hub_model_id -cfg.read_text -cfg.relora_cpu_offload -cfg.relora_steps -cfg.relora_warmup_steps -cfg.remove_unused_columns -cfg.resize_token_embeddings_to_32x -cfg.resume_from_checkpoint -cfg.rl -cfg.rl_adapter_ref_model -cfg.rope_scaling -cfg.s2_attention -cfg.sample_packing -cfg.sample_packing_eff_est -cfg.save_safetensors -cfg.save_steps -cfg.save_strategy -cfg.save_total_limit -cfg.saves_per_epoch -cfg.sdp_attention -cfg.seed -cfg.sequence_len -cfg.special_tokens -cfg.strict -cfg.system_format -cfg.system_prompt -cfg.test_datasets -cfg.tf32 -cfg.tokenizer_config -cfg.tokenizer_legacy -cfg.tokenizer_type -cfg.tokenizer_use_fast -cfg.tokens -cfg.torch_compile -cfg.torch_compile_backend -cfg.torch_dtype -cfg.torchdistx_path -cfg.total_num_tokens -cfg.total_supervised_tokens -cfg.train_on_inputs -cfg.trust_remote_code -cfg.type -cfg.unfrozen_parameters -cfg.use_mlflow -cfg.use_wandb -cfg.val_set_size -cfg.wandb_name -cfg.wandb_project -cfg.wandb_run_id -cfg.warmup_ratio -cfg.warmup_steps -cfg.weight_decay -cfg.world_size -cfg.xformers_attention -cfg.zero_optimization -``` -
+> [!important] +> Please prefer using commits from [release tags](https://github.com/truefoundry/llm-finetune/releases). `main` branch is work in progress and may have partially working commits. ## LLM Finetuning with Truefoundry + Test QLoRA w/ Deepspeed Stage 2 ``` -#!/bin/bash - -# export CUDA_LAUNCH_BLOCKING=1 -# export NCCL_DEBUG=INFO -# export TORCH_PER_PROCESS_MEMORY_LIMIT=22000 -export CUDA_VISIBLE_DEVICES=0 -export DISABLE_MLFLOW_INTEGRATION=True - -TRAIN_BATCH_SIZE=1 -GRADIENT_ACCUMULATION_STEPS=4 -LORA_R=32 -LORA_ALPHA=64 -TORCH_PER_PROCESS_MEMORY_LIMIT=0.95 -CUDA_VISIBLE_DEVICES=0,1 -TRAIN_DATA="./data/standford_alpaca_train_49k.jsonl" -# TRAIN_DATA="./data/lima_llama2_1k.jsonl" -MAX_STEPS=10 -# MODEL_ID=TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T -# MODEL_ID=cognitivecomputations/Wizard-Vicuna-30B-Uncensored -# MODEL_ID=EleutherAI/pythia-70m -MODEL_ID=NousResearch/Llama-2-7b-chat-hf -# MODEL_ID=NousResearch/Llama-2-13b-chat-hf -# MODEL_ID=mistralai/Mistral-7B-Instruct-v0.2 -# MODEL_ID=NousResearch/Llama-2-70b-chat-hf -# MODEL_ID=mistralai/Mixtral-8x7B-Instruct-v0.1 -# MODEL_ID=stas/tiny-random-llama-2 -# MODEL_ID=microsoft/phi-1_5 -# MODEL_ID=microsoft/phi-2 -# MODEL_ID=Deci/DeciLM-7B -USE_FLASH_ATTENTION=True -GRADIENT_CHECKPOINTING=True -NUM_TRAIN_EPOCHS=3 - - -# --deepspeed ./deepspeed_configs/3_ds_z2_config.json \ -# --deepspeed ./deepspeed_configs/4_ds_z2_offload_optimizer_config.json \ -# --deepspeed ./deepspeed_configs/5_ds_z3_config.json \ -# --deepspeed ./deepspeed_configs/6_ds_z3_offload_param_config.json \ -# --deepspeed ./deepspeed_configs/7_ds_z3_offload_optimizer_config.json \ -# --deepspeed ./deepspeed_configs/8_ds_z3_offload_param_offload_optimizer_config.json \ - -accelerate launch \ ---mixed_precision bf16 \ ---use_deepspeed \ -train.py \ -config-base.yaml \ ---deepspeed ./deepspeed_configs/3_ds_z2_config.json \ ---flash_attention $USE_FLASH_ATTENTION \ ---base_model $MODEL_ID \ ---train_data_uri $TRAIN_DATA \ ---max_steps $MAX_STEPS \ ---val_data_uri None \ ---val_set_size 0.1 \ ---micro_batch_size $TRAIN_BATCH_SIZE \ ---num_epochs $NUM_TRAIN_EPOCHS \ ---gradient_accumulation_steps $GRADIENT_ACCUMULATION_STEPS \ ---gradient_checkpointing $GRADIENT_CHECKPOINTING \ ---learning_rate 0.00001 \ ---output_dir ./outputs \ ---train_on_inputs False \ ---logging_steps 1 \ ---save_strategy steps \ ---save_steps 0.05 \ ---evaluation_strategy steps \ ---eval_steps 0.05 \ ---adapter qlora \ ---lora_target_linear True \ ---lora_r $LORA_R \ ---lora_alpha $LORA_ALPHA \ ---mlfoundry_enable_reporting False \ ---mlfoundry_ml_repo my-ml-repo \ ---mlfoundry_run_name test \ ---mlfoundry_checkpoint_artifact_name chk-test \ ---mlfoundry_log_checkpoints False \ ---resume_from_checkpoint False \ ---cleanup_output_dir_on_start True +./sample_run.sh ``` +--- -- `TORCH_PER_PROCESS_MEMORY_LIMIT` allows limiting the max memory per gpu. Can be a fraction (denoting percentage) or integer (denoting limit in MiB). Useful for testing limited gpu memory scenarios -- CUDA_VISIBLE_DEVICES can be used to control the amount of GPUs -- `--mlfoundry_enable_reporting true/false` toggles reporting metrics, checkpoints and models to mlfoundry -- When you are testing locally, you can set `--cleanup_output_dir_on_start true` if you don't care about checkpoints between runs +TODO: + +- [ ] Setup C/I Tests +- [ ] Track and publish VRAM and Speed benchmarks for popular models and GPUs --- Generally we always try to optimize for memory footprint because that allows higher batch size and more gpu utilization Speedup is second priority but we take what we can easily get - -#### Experimental things we want to try - -- Memory Savings Optimizers - - AnyPrecision Adam: `--optim adamw_anyprecision --optim-args "use_kahan_summation=True, momentum_dtype=bfloat16, variance_dtype=bfloat16"` - - 8-bit Adam: `--optim adamw_bnb_8bit` - - Zero's BF16 optimizer -- torch.compile -> Works in some cases, can speedup training -- Zero++ quantized weights and gradients for faster comm -- Long context - - Sequence Parallelism w/ Deepspeed Ulysses - - LongLora with SSA - - Tricks mentioned in Meta: Effective Long-Context Scaling of Foundation Model - - Quantized Activations? - FP8 training is already a thing - - https://github.com/kaiokendev/alpaca_lora_4bit -- DP + TP + PP aka Megatron - - Difficult to configure, Megatron-Deepspeed provides lower throughput but easier to work with diff --git a/config-base.yaml b/config-base.yaml index 9f0538c..c2a06d9 100644 --- a/config-base.yaml +++ b/config-base.yaml @@ -16,6 +16,7 @@ mlfoundry_ml_repo: null # --------------------- # Auto computed and set by script based on environment and external state # Only edit them if you know what you are doing +chat_template: auto # type: string data_dir: auto # type: string datasets: auto # type: list test_datasets: auto # type: list @@ -29,6 +30,10 @@ load_in_4bit: auto # type: bool lora_modules_to_save: auto # type: list resume_from_checkpoint: auto # type: bool special_tokens: auto # type: dict +unsloth_cross_entropy_loss: auto # type: bool +unsloth_lora_mlp: auto # type: bool +unsloth_lora_qkv: auto # type: bool +unsloth_lora_o: auto # type: bool tf32: auto # type: bool ## Added by TrueFoundry, not native to Axolotl mlfoundry_run_name: auto # type: string @@ -49,8 +54,8 @@ base_model_ignore_patterns: - '*.ot' - '*.tflite' - '*.msgpack' -chat_template: chatml dataset_prepared_path: ./outputs/data/last_run_prepared +dataset_processes: 1 ddp_timeout: 21600 deepspeed: ./deepspeed_configs/3_ds_z2_config.json default_system_message: You are a helpful assistant. Please give a long and detailed answer. @@ -105,3 +110,4 @@ logging_dir: ./tensorboard_logs mlfoundry_log_checkpoints: True use_mflow: False use_wandb: False +use_tensorboard: True diff --git a/custom_prompt_strategies.py b/custom_prompt_strategies.py deleted file mode 100644 index ca0e0d2..0000000 --- a/custom_prompt_strategies.py +++ /dev/null @@ -1,73 +0,0 @@ -import logging -from typing import Any, Dict, Optional - -from axolotl.prompt_strategies.chat_template import ( - ChatTemplatePrompter, - ChatTemplateStrategy, -) -from axolotl.prompt_strategies.sharegpt import ( - ShareGPTPrompterV2, - SimpleShareGPTPromptTokenizingStrategy, -) -from axolotl.utils.chat_templates import chat_templates -from transformers import PreTrainedTokenizer - -logger = logging.getLogger("axolotl") - - -class OpenAIShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy): - """ - Sharegpt strategy that remaps openai chat data to sharegpt format - """ - - def get_conversation_thread(self, prompt): - conversations = prompt["messages"] - role_key = "role" - value_key = "content" - role_map = { - "user": "human", - "human": "human", - "assistant": "gpt", - "gpt": "gpt", - "system": "system", - } - turns = [{"from": role_map[t[role_key]], "value": t[value_key]} for t in conversations] - return turns - - -def load_openai_sharegpt(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): - conversation = ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None - strategy = OpenAIShareGPTPromptTokenizingStrategy( - ShareGPTPrompterV2( - conversation=conversation, - ), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - return strategy - - -# TODO (chiragjn): Add proper support for HF Tokenizers based chat templates -# Axolotl has provided an implementation but -# it requires the key to be "conversations" instead of "messages" -# secondly, it does not correctly mask the prompt tokens accounting for system prompt - -# Goal is to have something like follows: -# def load_hf_chat_template(tokenizer: PreTrainedTokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): -# chat_template = ( -# ds_cfg["chat_template"] if ds_cfg and "chat_template" in ds_cfg else None -# ) -# if chat_template is None: -# if not tokenizer.chat_template: -# logger.warning("No chat template provided and tokenizer also does not have one set. Using default 'chatml'.") -# chat_template = "chatml" -# else: -# chat_template = chat_templates(chat_template) -# strategy = ChatTemplateStrategy( -# ChatTemplatePrompter(tokenizer, chat_template), -# tokenizer, -# cfg.train_on_inputs, -# cfg.sequence_len, -# ) -# return strategy diff --git a/data_utils.py b/data_utils.py index 971714e..cc77170 100644 --- a/data_utils.py +++ b/data_utils.py @@ -25,7 +25,9 @@ class DatasetType(str, enum.Enum): chat = "chat" -def _make_dataset_file_source(path, split="train", dataset_type: DatasetType = DatasetType.completion): +def _make_dataset_file_source( + path, split="train", dataset_type: DatasetType = DatasetType.completion, chat_template: str = "chatml" +): """ Axolotl dynamically loads prompt strategies based on the `type` key The modules are present at axolotl.prompt_strategies.* @@ -49,9 +51,9 @@ def _make_dataset_file_source(path, split="train", dataset_type: DatasetType = D "field_system": "system", "field_instruction": "prompt", "field_output": "completion", - "format": "{instruction} {input} ", - "no_input_format": "{instruction}", - "system_format": "{system}", + "format": "{instruction}\n{input}\n", + "no_input_format": "{instruction}\n", + "system_format": "{system}\n", }, "split": split, } @@ -59,15 +61,21 @@ def _make_dataset_file_source(path, split="train", dataset_type: DatasetType = D return { "path": path, "ds_type": "json", - "type": "custom_prompt_strategies.load_openai_sharegpt", - "conversation": "chatml", + "type": "chat_template", + "chat_template": chat_template, + "field_messages": "messages", + "message_field_role": "role", + "message_field_content": "content", + "roles": {"system": ["system"], "user": ["user", "human"], "assistant": ["assistant"], "tool": ["tool"]}, "split": split, } else: raise ValueError(f"Unsupported dataset type: {dataset_type}") -def dataset_uri_to_axolotl_datasources(uri, download_dir, dataset_type: DatasetType = DatasetType.completion): +def dataset_uri_to_axolotl_datasources( + uri, download_dir, dataset_type: DatasetType = DatasetType.completion, chat_template: str = "chatml" +): # TODO: Add support for HF datasets if uri.startswith("https://"): return [_make_dataset_file_source(path=uri, dataset_type=dataset_type)] diff --git a/finetune.ipynb b/finetune.ipynb index f2bc361..7641003 100644 --- a/finetune.ipynb +++ b/finetune.ipynb @@ -363,7 +363,7 @@ "config-base.yaml \\\n", "--deepspeed ./deepspeed_configs/3_ds_z2_config.json \\\n", "--flash_attention True \\\n", - "--gradient_checkpointing True \\\n", + "--gradient_checkpointing unsloth \\\n", "--base_model {model_id} \\\n", "--output_dir {output_dir} \\\n", "--dataset_type {dataset_type} \\\n", diff --git a/monkey_patch.py b/monkey_patch.py index 492daed..0e55b56 100644 --- a/monkey_patch.py +++ b/monkey_patch.py @@ -46,12 +46,6 @@ class TruefoundryAxolotlConfigWCapabilities(AxolotlConfigWCapabilities, Truefoun return DictDefault(dict(TruefoundryAxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_unset=True))) -def add_custom_prompt_strategies(): - import custom_prompt_strategies - - sys.modules["axolotl.prompt_strategies.custom_prompt_strategies"] = custom_prompt_strategies - - def patched_pretrain_hooks(cfg, trainer): # type: (DictDefault, AxolotlTrainer) -> None # Bad hack because axolotl is not flexible at the moment @@ -128,9 +122,6 @@ def monkey_patch_axolotl_internals(): else: raise ValueError("Did not find `validate_config` on `axolotl.utils.config`. " "This is required") - logger.info("Adding custom data prompt strategies...") - add_custom_prompt_strategies() - if hasattr(axolotl.train, "pretrain_hooks"): logger.info("Patching pretrain_hooks...") axolotl.train.pretrain_hooks = patched_pretrain_hooks diff --git a/requirements.txt b/requirements.txt index e4f2a14..eea8adf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,8 @@ --extra-index-url https://download.pytorch.org/whl/cu121 -torch==2.2.1+cu121 +torch==2.3.0+cu121 cloud-files==4.15.2 -truefoundry[ml]==0.1.2 +truefoundry[ml]==0.2.4 snowflake-connector-python[pandas]==3.7.0 pyarrow==15.0.0 deepspeed @ git+https://github.com/truefoundry/DeepSpeed@0866580c316963ddda30ffee44de2c3e21129556 +unsloth @ git+https://github.com/unslothai/unsloth@27fa021a7bb959a53667dd4e7cdb9598c207aa0d diff --git a/sample_run.sh b/sample_run.sh new file mode 100755 index 0000000..6936f9a --- /dev/null +++ b/sample_run.sh @@ -0,0 +1,65 @@ +#!/bin/bash + +# --- Environment variables --- +export DISABLE_MLFLOW_INTEGRATION=True +export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" + +## This controls how many GPUs you want to use +export CUDA_VISIBLE_DEVICES=0 +## This controls how much memory to user per gpu +export TORCH_PER_PROCESS_MEMORY_LIMIT=0.99 + +## Add your token for private/gated models +# export HF_TOKEN= + +## Turn these on for debugging +# export CUDA_LAUNCH_BLOCKING=1 +# export NCCL_DEBUG=INFO + +# --- Agruments ---- + +## If to delete outputs/ dir before starting - to start from clean slate +CLEANUP_OUTPUT_DIR_ON_START=True + +## You can logs metrics, checkpoints and final model with TrueFoundry Experiment Tracking +MLFOUNDRY_ENABLE_REPORTING=False +MLFOUNDRY_ML_REPO=llm-finetuning +MLFOUNDRY_RUN_NAME=my-finetuning-run-name + +accelerate launch \ +--mixed_precision bf16 \ +--use_deepspeed \ +train.py \ +config-base.yaml \ +--deepspeed ./deepspeed_configs/3_ds_z2_config.json \ +--flash_attention True \ +--base_model TinyLlama/TinyLlama-1.1B-Chat-v1.0 \ +--train_data_uri ./sample_data/chatalpaca-openai-100.jsonl \ +--val_data_uri None \ +--val_set_size 0.1 \ +--dataset_type chat \ +--sequence_len 4096 \ +--max_steps 0 \ +--micro_batch_size 1 \ +--eval_batch_size 1 \ +--num_epochs 1 \ +--gradient_accumulation_steps 4 \ +--gradient_checkpointing unsloth \ +--learning_rate 0.00001 \ +--output_dir ./outputs \ +--train_on_inputs False \ +--logging_steps 1 \ +--save_strategy steps \ +--save_steps 0.2 \ +--evaluation_strategy steps \ +--eval_steps 0.2 \ +--adapter qlora \ +--lora_target_linear True \ +--lora_r 16 \ +--lora_alpha 32 \ +--mlfoundry_enable_reporting $MLFOUNDRY_ENABLE_REPORTING \ +--mlfoundry_ml_repo $MLFOUNDRY_ML_REPO \ +--mlfoundry_run_name $MLFOUNDRY_RUN_NAME \ +--mlfoundry_log_checkpoints True \ +--resume_from_checkpoint True \ +--cleanup_output_dir_on_start $CLEANUP_OUTPUT_DIR_ON_START diff --git a/train.py b/train.py index 06ed58e..9e90357 100644 --- a/train.py +++ b/train.py @@ -15,6 +15,7 @@ from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import barrier, is_main_process, zero_first from axolotl.utils.models import load_tokenizer +from transformers import AutoConfig from transformers.utils import is_torch_bf16_gpu_available, is_torch_tf32_available from checkpoint_utils import cleanup_checkpoints, get_last_checkpoint_for_resume_if_any @@ -48,6 +49,15 @@ DEFAULT_EOS_TOKEN = "" DEFAULT_BOS_TOKEN = "" DEFAULT_UNK_TOKEN = "" +MODEL_TYPE_TO_CHAT_TEMPLATE = { + "llama": "llama3", + "gemma": "gemma", + "cohere": "cohere", + "phi3": "phi_3", + "phi_3": "phi_3", + "phi": "phi_3", + None: "chatml", +} def set_cfg_option_if_auto(cfg, key, value, force=False): @@ -83,6 +93,8 @@ def make_axolotl_config(config_base, kwargs, timestamp=None): if os.path.exists(cfg.output_dir): shutil.rmtree(cfg.output_dir) + model_hf_config = AutoConfig.from_pretrained(cfg.base_model, trust_remote_code=True) + data_dir = os.path.join(os.path.abspath(cfg.output_dir), "data") set_cfg_option_if_auto(cfg, "data_dir", data_dir) cfg.output_dir = os.path.join(os.path.abspath(cfg.output_dir), "model") @@ -145,6 +157,16 @@ def make_axolotl_config(config_base, kwargs, timestamp=None): set_cfg_option_if_auto(cfg, "flash_attn_fuse_mlp", cfg.adapter not in {"qlora", "lora"}) set_cfg_option_if_auto(cfg, "flash_attn_fuse_qkv", cfg.adapter not in {"qlora", "lora"}) + use_unsloth = False # torch.cuda.device_count() == 1 + set_cfg_option_if_auto(cfg, "unsloth_cross_entropy_loss", use_unsloth) + set_cfg_option_if_auto(cfg, "unsloth_lora_mlp", use_unsloth) + set_cfg_option_if_auto(cfg, "unsloth_lora_qkv", use_unsloth) + set_cfg_option_if_auto(cfg, "unsloth_lora_o", use_unsloth) + + model_type = getattr(model_hf_config, "model_type", None) + chat_template = MODEL_TYPE_TO_CHAT_TEMPLATE.get(model_type, "chatml") + set_cfg_option_if_auto(cfg, "chat_template", chat_template) + if cfg.datasets == "auto": if not cfg.train_data_uri: raise ValueError("`train_data_uri` cannot be null when set to `datasets` is set to auto") @@ -152,6 +174,7 @@ def make_axolotl_config(config_base, kwargs, timestamp=None): uri=cfg.train_data_uri, download_dir=cfg.data_dir, dataset_type=cfg.dataset_type, + chat_template=chat_template, ) if cfg.test_datasets == "auto": if cfg.val_data_uri and str(cfg.val_data_uri).lower() != "na": @@ -159,6 +182,7 @@ def make_axolotl_config(config_base, kwargs, timestamp=None): uri=cfg.val_data_uri, download_dir=cfg.data_dir, dataset_type=cfg.dataset_type, + chat_template=chat_template, ) elif cfg.val_set_size: set_cfg_option_if_auto(cfg, "test_datasets", None, force=True) @@ -218,7 +242,7 @@ def train_with_truefoundry(config_base: Path = Path("examples/"), **kwargs): cleanup_checkpoints(output_dir=cfg.output_dir) if cfg.adapter in {"lora", "qlora"}: with temporarily_unset_accelerate_envs(): - axolotl_merge_lora_cli(config=axolotl_config, deepspeed=None, fsdp=None, device_map="auto") + axolotl_merge_lora_cli(config=axolotl_config, device_map="auto") model_dir = os.path.join(model_dir, "merged") model_parent_dir = os.path.dirname(model_dir) # Copy tensorboard logs