basic knot merging #140
Conversation
Pull Request Overview
This PR introduces a new Knot merging transform for LoRA experts, computing SVD components, storing them locally, and merging expert weights uniformly via TIES.
- Added `KnotMerge` and `KnotMergeConfig` to perform SVD-based merges.
- Updated `TiesMerge` to factor out parameter-merging logic into `merge_param`.
- Added a `test_knot_merge` unit test to validate the new transform.
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| tests/test_library_transforms.py | Added `test_knot_merge` to verify the `KnotMerge` flow. |
| mttl/models/library/library_transforms.py | Implemented `KnotMerge` transform and refactored `TiesMerge`. |
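To make the overview concrete, here is a minimal NumPy sketch of the idea the PR implements, following the KnoT recipe from the linked paper: compute a shared SVD basis over the stacked expert deltas, express each expert in that basis, merge the coefficients with a TIES-style trim/elect/average rule, and map the result back to weight space. The function name `knot_merge` and all shapes are illustrative assumptions, not the PR's actual API.

```python
import numpy as np

def knot_merge(expert_deltas, k):
    """Hypothetical sketch of a KnoT-style merge: shared SVD basis,
    then a TIES-style merge of per-expert coefficients.
    `expert_deltas` is a list of (d_out, d_in) LoRA delta matrices."""
    # Shared left singular basis of the horizontally stacked expert deltas.
    stacked = np.concatenate(expert_deltas, axis=1)        # (d_out, n_experts * d_in)
    U, _, _ = np.linalg.svd(stacked, full_matrices=False)  # U: (d_out, r)

    # Express each expert in the shared basis.
    coeffs = np.stack([U.T @ W for W in expert_deltas])    # (n_experts, r, d_in)
    flat = coeffs.reshape(len(expert_deltas), -1)

    # TIES-style merge: trim each expert to its top-k magnitudes,
    # elect a sign per coordinate by majority vote, average agreeing entries.
    cutoff = -np.sort(-np.abs(flat), axis=1)[:, k - 1 : k]
    trimmed = np.where(np.abs(flat) >= cutoff, flat, 0.0)
    elected = np.sign(np.sign(trimmed).sum(axis=0))        # majority sign per entry
    agree = (np.sign(trimmed) == elected) & (trimmed != 0)
    counts = np.maximum(agree.sum(axis=0), 1)
    merged = (trimmed * agree).sum(axis=0) / counts

    # Back to weight space: U @ merged coefficients.
    return U @ merged.reshape(coeffs.shape[1:])
```

A usage note: with a single expert the trim/elect steps are no-ops up to the top-k cutoff, so the merge degenerates to reconstructing that expert from its SVD projection.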
Comments suppressed due to low confidence (3)
mttl/models/library/library_transforms.py:11
- Add `import torch` at the top of this file so that all references to `torch.save`, `torch.load`, and other torch APIs resolve correctly.
from typing import Dict, List, Union
mttl/models/library/library_transforms.py:361
- [nitpick] Rename the variable `ties_mergert` to `ties_merger` to fix the typo and clarify its purpose.
ties_mergert = TiesMerge()
tests/test_library_transforms.py:96
- [nitpick] Consider adding assertions to verify that the merged weights themselves match expected values (e.g., compare against a manual U @ final_param calculation) to improve test coverage.
assert len(merged_layers) == len(exp.expert_weights.keys()) == 1
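The reviewer's suggestion is to compare values, not just counts. A sketch of what such an assertion could look like, using `np.testing.assert_allclose` and an SVD round-trip as a stand-in (the names `U` and `final_param` mirror the reviewer's wording and are illustrative, not the test's actual variables):

```python
import numpy as np

# Manual U @ final_param reconstruction, compared elementwise to the
# expected weight rather than only checking the number of merged layers.
rng = np.random.default_rng(0)
W = rng.normal(size=(4, 3))                 # stand-in for an expert weight
U, S, Vt = np.linalg.svd(W, full_matrices=False)
final_param = np.diag(S) @ Vt               # coefficients in the shared basis
reconstructed = U @ final_param             # manual reconstruction
np.testing.assert_allclose(reconstructed, W, atol=1e-8)
```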
@LibraryTransform.register("weighted_knot_merge", KnotMergeConfig)
class KnotMerge(LibraryTransform):
    """
    Computes a weighted KnoT merge for LoRA ezperts as in https://arxiv.org/pdf/2410.19735
Correct the typo `ezperts` to `experts` in the docstring.

Suggested change:
- Computes a weighted KnoT merge for LoRA ezperts as in https://arxiv.org/pdf/2410.19735
+ Computes a weighted KnoT merge for LoRA experts as in https://arxiv.org/pdf/2410.19735
    used += keep_mask.sum().item()
else:
    # sign majority vote
    sign_per_dim = expert_weights.sign().sum(0, keepdim=True).sign()
The second assignment to `sign_per_dim` overrides the first; remove the redundant line or clarify which operation is intended.

sign_per_dim = expert_weights.sign().sum(0, keepdim=True).sign()
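For reference, a minimal NumPy stand-in for the sign election the flagged line performs (sum per-expert signs along the expert axis, then take the sign of that sum); a single assignment suffices, which is why the duplicate is suspicious:

```python
import numpy as np

# Each row is one expert's weights; columns are parameter dimensions.
expert_weights = np.array([[ 1.0, -2.0,  3.0],
                           [-1.0, -4.0,  2.0],
                           [ 2.0, -1.0, -5.0]])

# Majority-vote sign per dimension: sign of the sum of per-expert signs.
sign_per_dim = np.sign(np.sign(expert_weights).sum(axis=0, keepdims=True))
assert (sign_per_dim == np.array([[1.0, -1.0, 1.0]])).all()
```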