23 changes: 12 additions & 11 deletions .ci/scripts/test_lora.sh
@@ -41,12 +41,14 @@ HF_ADAPTER_PATH=$(
--files "adapter_config.json" "adapter_model.safetensors"
)

# Set environment variables for OmegaConf interpolation in yaml.
export LORA_ADAPTER_CHECKPOINT="${HF_ADAPTER_PATH}/adapter_model.safetensors"
export LORA_ADAPTER_CONFIG="${HF_ADAPTER_PATH}/adapter_config.json"

### SINGLE LORA PTE ###
# Export LoRA PTE file.
$PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
--config examples/models/qwen3/config/qwen3_xnnpack.yaml \
+base.adapter_checkpoint="${HF_ADAPTER_PATH}/adapter_model.safetensors" \
+base.adapter_config="${HF_ADAPTER_PATH}/adapter_config.json" \
--config examples/models/qwen3/config/qwen3_xnnpack_lora.yaml \
+export.output_name="qwen_lora_math_full.pte"

# Capture the path of the downloaded qwen artifacts
@@ -93,9 +95,7 @@ fi
### PROGRAM DATA SEPARATION ###
# Export LoRA PTE, LoRA PTD, foundation PTD file.
$PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
--config examples/models/qwen3/config/qwen3_xnnpack.yaml \
+base.adapter_checkpoint="${HF_ADAPTER_PATH}/adapter_model.safetensors" \
+base.adapter_config="${HF_ADAPTER_PATH}/adapter_config.json" \
--config examples/models/qwen3/config/qwen3_xnnpack_lora.yaml \
+export.output_name="qwen_lora_math.pte" \
+export.foundation_weights_file="qwen_foundation.ptd" \
+export.lora_weights_file="qwen_lora_math.ptd"
@@ -108,7 +108,7 @@ cmake-out/examples/models/llama/llama_main --model_path=qwen_lora_math.pte --dat
NOW=$(date +"%H:%M:%S")
echo "Finished at ${NOW}"

RESULT=$(cat result.txt)
RESULT=$(cat result2.txt)
if [[ "${RESULT}" == "${EXPECTED_PREFIX}"* ]]; then
echo "Expected result prefix: ${EXPECTED_PREFIX}"
echo "Actual result: ${RESULT}"
@@ -143,18 +143,19 @@ So, 15% of 80 is equal to (80 * 15) / 100 = 1200 / 100 = 12.
The answer is: 12<|im_end|>"

# Export Quantized PTE, PTD file, no LoRA.
# override base.lora_config=null to avoid creating a lora model
# and loading lora weights.
$PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
--config examples/models/qwen3/config/qwen3_xnnpack.yaml \
--config examples/models/qwen3/config/qwen3_xnnpack_lora.yaml \
base.lora_config=null \
+export.output_name="qwen_q.pte" \
+export.foundation_weights_file="qwen_foundation_q.ptd" \
+quantization.qmode="8da4w" \
+quantization.group_size=32

# Export Quantized LoRA PTE, LoRA PTD, foundation PTD file.
$PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
--config examples/models/qwen3/config/qwen3_xnnpack.yaml \
+base.adapter_checkpoint="${HF_ADAPTER_PATH}/adapter_model.safetensors" \
+base.adapter_config="${HF_ADAPTER_PATH}/adapter_config.json" \
--config examples/models/qwen3/config/qwen3_xnnpack_lora.yaml \
+export.output_name="qwen_lora_math_q.pte" \
+export.foundation_weights_file="qwen_foundation_lora_q.ptd" \
+export.lora_weights_file="qwen_lora_math_q.ptd" \
49 changes: 20 additions & 29 deletions examples/models/llama/model.py
@@ -31,12 +31,8 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
checkpoint_path = self.llm_config.base.checkpoint
params_path = self.llm_config.base.params

# Adapter checkpoint and config.
adapter_checkpoint_path = self.llm_config.base.adapter_checkpoint
adapter_config_path = self.llm_config.base.adapter_config
assert (adapter_checkpoint_path is None and adapter_config_path is None) or (
adapter_checkpoint_path is not None and adapter_config_path is not None
), "Both adapter_checkpoint_path and adapter_config_path must be specified or neither must be specified."
# LoRA adapter configuration.
lora_config = self.llm_config.base.lora_config

self.use_kv_cache = self.llm_config.model.use_kv_cache
self.use_sdpa_with_kv_cache_op = self.llm_config.model.use_sdpa_with_kv_cache
@@ -69,10 +65,18 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
with open(params_path, "r") as f:
params = json.loads(f.read())

# Get adapter checkpoint and config.
# Get adapter checkpoint.
adapter_checkpoint = {}
adapter_config = {}
if adapter_checkpoint_path:
if lora_config:
# Resolve LoRA params from adapter_config JSON if not already set.
if lora_config.adapter_config and lora_config.lora_rank == 0:
with open(lora_config.adapter_config, "r") as f:
cfg = json.load(f)
lora_config.lora_rank = cfg["r"]
lora_config.lora_alpha = cfg["lora_alpha"]
lora_config.target_modules = cfg["target_modules"]
Copilot AI commented on lines +72 to +77 (Feb 13, 2026):

Directly mutating dataclass fields can lead to unexpected behavior, especially when using OmegaConf which creates structured configs. Instead of mutating lora_config in-place, the config values should be validated and populated during initialization or in a post_init hook. This mutation breaks the principle that configs should be immutable after creation and can cause issues if the same config object is reused or if OmegaConf features like interpolation or merging are used.

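For illustration only, a minimal sketch of how the resolution could avoid in-place mutation, assuming lora_config is a plain LoraConfig dataclass instance (helper name hypothetical, not part of this PR):

import dataclasses
import json

def resolve_lora_config(lora_config):
    # Build a resolved copy from the adapter_config JSON instead of mutating
    # the original config object.
    if not (lora_config.adapter_config and lora_config.lora_rank == 0):
        return lora_config
    with open(lora_config.adapter_config, "r") as f:
        cfg = json.load(f)
    return dataclasses.replace(
        lora_config,
        lora_rank=cfg["r"],
        lora_alpha=cfg["lora_alpha"],
        target_modules=cfg["target_modules"],
    )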
Copilot AI commented on lines +71 to +77 (Feb 13, 2026):

The condition 'lora_config.lora_rank == 0' is used to decide whether to parse adapter_config JSON. However, 0 is the default value for lora_rank, which creates an ambiguous situation: it's unclear whether lora_rank is 0 because it wasn't set (and should be loaded from JSON) or because the user explicitly wants a rank of 0. A rank of 0 would be invalid for LoRA (since it means no low-rank adaptation). Consider using None as the default value for lora_rank instead of 0, or checking if adapter_config is provided and any of the LoRA parameters are still at their defaults.

Suggested change
# Resolve LoRA params from adapter_config JSON if not already set.
if lora_config.adapter_config and lora_config.lora_rank == 0:
with open(lora_config.adapter_config, "r") as f:
cfg = json.load(f)
lora_config.lora_rank = cfg["r"]
lora_config.lora_alpha = cfg["lora_alpha"]
lora_config.target_modules = cfg["target_modules"]
# Resolve LoRA params from adapter_config JSON if provided, without
# relying on a specific sentinel value for lora_rank.
if lora_config.adapter_config:
with open(lora_config.adapter_config, "r") as f:
cfg = json.load(f)
if not lora_config.lora_rank:
lora_config.lora_rank = cfg["r"]
if not lora_config.lora_alpha:
lora_config.lora_alpha = cfg["lora_alpha"]
if not lora_config.target_modules:
lora_config.target_modules = cfg["target_modules"]

Copilot AI commented on lines +74 to +77 (Feb 13, 2026):

Missing error handling when loading adapter_config JSON. If the JSON file is malformed or missing required keys like 'r', 'lora_alpha', or 'target_modules', this will raise a KeyError without a clear error message. Consider wrapping this in a try-except block with a more informative error message, or add validation after loading to check that all required fields are present.

Suggested change
cfg = json.load(f)
lora_config.lora_rank = cfg["r"]
lora_config.lora_alpha = cfg["lora_alpha"]
lora_config.target_modules = cfg["target_modules"]
try:
cfg = json.load(f)
except json.JSONDecodeError as e:
raise ValueError(
f"Failed to parse LoRA adapter config JSON file "
f"'{lora_config.adapter_config}': {e}"
) from e
try:
lora_config.lora_rank = cfg["r"]
lora_config.lora_alpha = cfg["lora_alpha"]
lora_config.target_modules = cfg["target_modules"]
except KeyError as e:
missing_key = e.args[0] if e.args else "unknown"
raise ValueError(
"Missing required key "
f"'{missing_key}' in LoRA adapter config JSON file "
f"'{lora_config.adapter_config}'. Expected keys: "
"'r', 'lora_alpha', 'target_modules'."
) from e

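For context, the keys read here follow the PEFT-style adapter_config.json layout; a hypothetical minimal example (values illustrative, real files carry additional keys):

{
  "r": 16,
  "lora_alpha": 32,
  "target_modules": ["q_proj", "v_proj"]
}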

adapter_checkpoint_path = lora_config.adapter_checkpoint
Copilot AI commented (Feb 13, 2026):

Missing validation for adapter_checkpoint path. When lora_config is provided, adapter_checkpoint_path could be invalid (file doesn't exist, path is empty, etc.), but there's no explicit check before attempting to load it. The code will only fail with an unclear error when torch.load or the safetensors loader is called. Consider adding early validation to check that the file exists and is readable, with a clear error message.

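A minimal sketch of the early check this comment suggests, assuming a plain filesystem path (helper name hypothetical, not part of this PR):

import os

def validate_adapter_checkpoint(path):
    # Fail fast with a clear message before torch.load or the safetensors
    # loader is invoked on an unusable path.
    if not path or not os.path.isfile(path):
        raise FileNotFoundError(
            f"LoRA adapter checkpoint not found or unreadable: {path!r}"
        )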
if adapter_checkpoint_path.endswith(".pt"):
adapter_checkpoint = torch.load(
adapter_checkpoint_path, map_location=device, mmap=True
@@ -92,22 +96,6 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
raise ValueError(
f"Unsupported adapter checkpoint format: {adapter_checkpoint_path}"
)

with open(adapter_config_path, "r") as f:
adapter_config_full = json.loads(f.read())
if (
"r" not in adapter_config_full
or "lora_alpha" not in adapter_config_full
or "target_modules" not in adapter_config_full
):
raise ValueError(
"Adapter config must contain r, lora_alpha, and target_modules."
)
adapter_config = {
"r": adapter_config_full["r"],
"lora_alpha": adapter_config_full["lora_alpha"],
"target_modules": adapter_config_full["target_modules"],
}
checkpoint.update(adapter_checkpoint)

output_prune_map = None
@@ -133,8 +121,10 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
input_prune_map=input_prune_map,
output_prune_map=output_prune_map,
enable_dynamic_shape=self.enable_dynamic_shape,
r=lora_config.lora_rank if lora_config else None,
lora_alpha=lora_config.lora_alpha if lora_config else None,
target_modules=lora_config.target_modules if lora_config else None,
**params,
**adapter_config,
)

if model_args.use_scaled_rope:
@@ -356,9 +346,10 @@ def _transform_for_pre_quantization(self, checkpoint, model_args):

embedding_bit_width, embedding_group_size = None, None
if self.llm_config.base.preq_embedding_quantize:
embedding_bit_width, embedding_group_size = (
self.llm_config.base.preq_embedding_quantize.split(",")
)
(
embedding_bit_width,
embedding_group_size,
) = self.llm_config.base.preq_embedding_quantize.split(",")
from .source_transformation.pre_quantization import (
transform_embedding_for_pre_quantization,
)
@@ -2,6 +2,9 @@ base:
model_class: "qwen3_0_6b"
params: "examples/models/qwen3/config/0_6b_config.json"
metadata: '{"get_bos_id": 151644, "get_eos_ids":[151645]}'
lora_config:
adapter_checkpoint: ${oc.env:LORA_ADAPTER_CHECKPOINT}
adapter_config: ${oc.env:LORA_ADAPTER_CONFIG}

model:
use_kv_cache: True
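The ${oc.env:...} entries above resolve against the LORA_ADAPTER_CHECKPOINT and LORA_ADAPTER_CONFIG variables exported in .ci/scripts/test_lora.sh. A small sketch of that resolution with OmegaConf (paths hypothetical):

import os
from omegaconf import OmegaConf

os.environ["LORA_ADAPTER_CHECKPOINT"] = "/tmp/adapter_model.safetensors"  # illustrative
os.environ["LORA_ADAPTER_CONFIG"] = "/tmp/adapter_config.json"  # illustrative

cfg = OmegaConf.create(
    {
        "lora_config": {
            "adapter_checkpoint": "${oc.env:LORA_ADAPTER_CHECKPOINT}",
            "adapter_config": "${oc.env:LORA_ADAPTER_CONFIG}",
        }
    }
)
# Interpolations are resolved lazily on access:
print(cfg.lora_config.adapter_checkpoint)  # /tmp/adapter_model.safetensors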
50 changes: 39 additions & 11 deletions extension/llm/export/config/llm_config.py
@@ -62,6 +62,36 @@ class PreqMode(str, Enum):
preq_8da4w_out_8da8w = "8da4w_output_8da8w"


@dataclass
class LoraConfig:
"""LoRA adapter configuration.

Can be created in two ways:

1. From an adapter_config JSON file:
LoraConfig(
adapter_checkpoint="/path/to/adapter.safetensors",
adapter_config="/path/to/adapter_config.json",
)
Note: user is responsible for parsing the config and
ensure it doesn't conflict with any explicit values.
Copilot AI commented on lines +76 to +77 (Feb 13, 2026):

The documentation note on lines 76-77 states "user is responsible for parsing the config and ensure it doesn't conflict with any explicit values", but this is misleading. Looking at the implementation in model.py (lines 72-77), the config JSON is only parsed when lora_rank is 0, and the values are directly assigned without any conflict checking. If a user provides both adapter_config and explicit values, the explicit values will be silently overwritten if lora_rank is 0. This behavior should either be prevented through validation, or the documentation should be updated to accurately reflect what happens.

Suggested change
Note: user is responsible for parsing the config and
ensure it doesn't conflict with any explicit values.
Note: when adapter_config is provided and lora_rank is left at its
default value (0), values loaded from the JSON (such as lora_rank,
lora_alpha, and target_modules) will overwrite any explicit values
passed to this dataclass; no conflict checking is performed.


2. With explicit values:
LoraConfig(
adapter_checkpoint="/path/to/adapter.safetensors",
lora_rank=16,
lora_alpha=32,
target_modules=["q_proj", "v_proj"],
)
"""

adapter_checkpoint: str
adapter_config: Optional[str] = None
lora_rank: int = 0
lora_alpha: int = 0
target_modules: List[str] = field(default_factory=list)


Copilot AI commented on lines +93 to +94 (Feb 13, 2026):

The LoraConfig dataclass lacks validation to ensure that either adapter_config is provided OR all explicit values (lora_rank, lora_alpha, target_modules) are provided. According to the documentation, these are two mutually exclusive ways to configure LoRA, but there's no post_init method to validate this constraint. This could lead to invalid configurations where adapter_checkpoint is provided but neither adapter_config nor the explicit parameters are set, resulting in runtime errors when the config is used in model.py (lines 72-77).

Suggested change
def __post_init__(self) -> None:
"""
Validate that LoRA configuration is provided in exactly one of the two
supported ways:
1. Via adapter_config JSON (adapter_config is not None), in which case
explicit LoRA parameters must not be set.
2. Via explicit values (adapter_config is None), in which case all of
lora_rank, lora_alpha, and target_modules must be provided.
"""
has_adapter_config = self.adapter_config is not None
has_explicit_params = (
self.lora_rank != 0
or self.lora_alpha != 0
or bool(self.target_modules)
)
# Enforce mutual exclusivity between adapter_config and explicit params.
if has_adapter_config and has_explicit_params:
raise ValueError(
"LoraConfig must be configured either with 'adapter_config' or "
"with explicit parameters ('lora_rank', 'lora_alpha', "
"'target_modules'), but not both."
)
# If no adapter_config is provided, require all explicit parameters.
if not has_adapter_config:
if self.lora_rank <= 0 or self.lora_alpha <= 0 or not self.target_modules:
raise ValueError(
"Invalid LoraConfig: when 'adapter_config' is not provided, "
"'lora_rank' and 'lora_alpha' must be positive and "
"'target_modules' must be a non-empty list."
)

@dataclass
class BaseConfig:
"""
@@ -77,11 +107,7 @@ class BaseConfig:
If left empty, the model will either be initialized with random weights
if it is a Llama model or the weights will be downloaded from HuggingFace
if it is a non-Llama model.
adapter_checkpoint: Path to the adapter.pt file from torchtune. Used if
the model has trained LoRA adapters. Must provide
adapter_config.json.
adapter_config: Path to the adapter_config.json file from torchtune.
Used if the model has trained LoRA adapters. Must provide adapter.pt.
lora_config: LoRA adapter configuration.
tokenizer_path: Path to the tokenizer file.
metadata: Json string containing metadata information.
e.g. '"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"'
@@ -98,8 +124,7 @@ class BaseConfig:
model_class: ModelType = ModelType.llama3
params: Optional[str] = None
checkpoint: Optional[str] = None
adapter_checkpoint: Optional[str] = None
adapter_config: Optional[str] = None
lora_config: Optional[LoraConfig] = None
tokenizer_path: Optional[str] = None
metadata: Optional[str] = None
use_lora: int = 0
@@ -536,10 +561,13 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig":  # noqa: C901
llm_config.base.params = args.params
if hasattr(args, "checkpoint"):
llm_config.base.checkpoint = args.checkpoint
if hasattr(args, "adapter_checkpoint"):
llm_config.base.adapter_checkpoint = args.adapter_checkpoint
if hasattr(args, "adapter_config"):
llm_config.base.adapter_config = args.adapter_config
if hasattr(args, "adapter_checkpoint") and args.adapter_checkpoint:
if not hasattr(args, "adapter_config") or not args.adapter_config:
raise ValueError("--adapter_checkpoint requires --adapter_config")
llm_config.base.lora_config = LoraConfig(
adapter_checkpoint=args.adapter_checkpoint,
adapter_config=args.adapter_config,
)
if hasattr(args, "tokenizer_path"):
llm_config.base.tokenizer_path = args.tokenizer_path
if hasattr(args, "metadata"):
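A quick illustration of the new argument handling, assuming the from_args flow shown above (paths hypothetical):

import argparse

# An adapter checkpoint without its config now fails fast:
args = argparse.Namespace(
    adapter_checkpoint="/tmp/adapter_model.safetensors",
    adapter_config=None,
)
# LlmConfig.from_args(args) raises:
#   ValueError: --adapter_checkpoint requires --adapter_config

# With both provided, base.lora_config is populated with a LoraConfig
# built from the two paths:
args = argparse.Namespace(
    adapter_checkpoint="/tmp/adapter_model.safetensors",
    adapter_config="/tmp/adapter_config.json",
)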