-
Notifications
You must be signed in to change notification settings - Fork 986
Introduce LoraConfig #17458
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introduce LoraConfig #17458
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -31,12 +31,8 @@ def __init__(self, llm_config: Optional[LlmConfig] = None): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| checkpoint_path = self.llm_config.base.checkpoint | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| params_path = self.llm_config.base.params | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Adapter checkpoint and config. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| adapter_checkpoint_path = self.llm_config.base.adapter_checkpoint | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| adapter_config_path = self.llm_config.base.adapter_config | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert (adapter_checkpoint_path is None and adapter_config_path is None) or ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| adapter_checkpoint_path is not None and adapter_config_path is not None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ), "Both adapter_checkpoint_path and adapter_config_path must be specified or neither must be specified." | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # LoRA adapter configuration. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lora_config = self.llm_config.base.lora_config | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.use_kv_cache = self.llm_config.model.use_kv_cache | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.use_sdpa_with_kv_cache_op = self.llm_config.model.use_sdpa_with_kv_cache | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -69,10 +65,18 @@ def __init__(self, llm_config: Optional[LlmConfig] = None): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with open(params_path, "r") as f: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| params = json.loads(f.read()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Get adapter checkpoint and config. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Get adapter checkpoint. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| adapter_checkpoint = {} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| adapter_config = {} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if adapter_checkpoint_path: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if lora_config: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Resolve LoRA params from adapter_config JSON if not already set. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if lora_config.adapter_config and lora_config.lora_rank == 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with open(lora_config.adapter_config, "r") as f: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cfg = json.load(f) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lora_config.lora_rank = cfg["r"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lora_config.lora_alpha = cfg["lora_alpha"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lora_config.target_modules = cfg["target_modules"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+71
to
+77
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Resolve LoRA params from adapter_config JSON if not already set. | |
| if lora_config.adapter_config and lora_config.lora_rank == 0: | |
| with open(lora_config.adapter_config, "r") as f: | |
| cfg = json.load(f) | |
| lora_config.lora_rank = cfg["r"] | |
| lora_config.lora_alpha = cfg["lora_alpha"] | |
| lora_config.target_modules = cfg["target_modules"] | |
| # Resolve LoRA params from adapter_config JSON if provided, without | |
| # relying on a specific sentinel value for lora_rank. | |
| if lora_config.adapter_config: | |
| with open(lora_config.adapter_config, "r") as f: | |
| cfg = json.load(f) | |
| if not lora_config.lora_rank: | |
| lora_config.lora_rank = cfg["r"] | |
| if not lora_config.lora_alpha: | |
| lora_config.lora_alpha = cfg["lora_alpha"] | |
| if not lora_config.target_modules: | |
| lora_config.target_modules = cfg["target_modules"] |
Copilot
AI
Feb 13, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing error handling when loading adapter_config JSON. If the JSON file is malformed or missing required keys like 'r', 'lora_alpha', or 'target_modules', this will raise a KeyError without a clear error message. Consider wrapping this in a try-except block with a more informative error message, or add validation after loading to check that all required fields are present.
| cfg = json.load(f) | |
| lora_config.lora_rank = cfg["r"] | |
| lora_config.lora_alpha = cfg["lora_alpha"] | |
| lora_config.target_modules = cfg["target_modules"] | |
| try: | |
| cfg = json.load(f) | |
| except json.JSONDecodeError as e: | |
| raise ValueError( | |
| f"Failed to parse LoRA adapter config JSON file " | |
| f"'{lora_config.adapter_config}': {e}" | |
| ) from e | |
| try: | |
| lora_config.lora_rank = cfg["r"] | |
| lora_config.lora_alpha = cfg["lora_alpha"] | |
| lora_config.target_modules = cfg["target_modules"] | |
| except KeyError as e: | |
| missing_key = e.args[0] if e.args else "unknown" | |
| raise ValueError( | |
| "Missing required key " | |
| f"'{missing_key}' in LoRA adapter config JSON file " | |
| f"'{lora_config.adapter_config}'. Expected keys: " | |
| "'r', 'lora_alpha', 'target_modules'." | |
| ) from e |
Copilot
AI
Feb 13, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing validation for adapter_checkpoint path. When lora_config is provided, adapter_checkpoint_path could be invalid (file doesn't exist, path is empty, etc.), but there's no explicit check before attempting to load it. The code will only fail with an unclear error when torch.load or the safetensors loader is called. Consider adding early validation to check that the file exists and is readable, with a clear error message.
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -62,6 +62,36 @@ class PreqMode(str, Enum): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| preq_8da4w_out_8da8w = "8da4w_output_8da8w" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @dataclass | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class LoraConfig: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """LoRA adapter configuration. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Can be created in two ways: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 1. From an adapter_config JSON file: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| LoraConfig( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| adapter_checkpoint="/path/to/adapter.safetensors", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| adapter_config="/path/to/adapter_config.json", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Note: user is responsible for parsing the config and | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ensure it doesn't conflict with any explicit values. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+76
to
+77
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Note: user is responsible for parsing the config and | |
| ensure it doesn't conflict with any explicit values. | |
| Note: when adapter_config is provided and lora_rank is left at its | |
| default value (0), values loaded from the JSON (such as lora_rank, | |
| lora_alpha, and target_modules) will overwrite any explicit values | |
| passed to this dataclass; no conflict checking is performed. |
Copilot
AI
Feb 13, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The LoraConfig dataclass lacks validation to ensure that either adapter_config is provided OR all explicit values (lora_rank, lora_alpha, target_modules) are provided. According to the documentation, these are two mutually exclusive ways to configure LoRA, but there's no post_init method to validate this constraint. This could lead to invalid configurations where adapter_checkpoint is provided but neither adapter_config nor the explicit parameters are set, resulting in runtime errors when the config is used in model.py (lines 72-77).
| def __post_init__(self) -> None: | |
| """ | |
| Validate that LoRA configuration is provided in exactly one of the two | |
| supported ways: | |
| 1. Via adapter_config JSON (adapter_config is not None), in which case | |
| explicit LoRA parameters must not be set. | |
| 2. Via explicit values (adapter_config is None), in which case all of | |
| lora_rank, lora_alpha, and target_modules must be provided. | |
| """ | |
| has_adapter_config = self.adapter_config is not None | |
| has_explicit_params = ( | |
| self.lora_rank != 0 | |
| or self.lora_alpha != 0 | |
| or bool(self.target_modules) | |
| ) | |
| # Enforce mutual exclusivity between adapter_config and explicit params. | |
| if has_adapter_config and has_explicit_params: | |
| raise ValueError( | |
| "LoraConfig must be configured either with 'adapter_config' or " | |
| "with explicit parameters ('lora_rank', 'lora_alpha', " | |
| "'target_modules'), but not both." | |
| ) | |
| # If no adapter_config is provided, require all explicit parameters. | |
| if not has_adapter_config: | |
| if self.lora_rank <= 0 or self.lora_alpha <= 0 or not self.target_modules: | |
| raise ValueError( | |
| "Invalid LoraConfig: when 'adapter_config' is not provided, " | |
| "'lora_rank' and 'lora_alpha' must be positive and " | |
| "'target_modules' must be a non-empty list." | |
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Directly mutating dataclass fields can lead to unexpected behavior, especially when using OmegaConf which creates structured configs. Instead of mutating lora_config in-place, the config values should be validated and populated during initialization or in a post_init hook. This mutation breaks the principle that configs should be immutable after creation and can cause issues if the same config object is reused or if OmegaConf features like interpolation or merging are used.