
Commit d67eaa4

LoRA-Pro Implementation (#10146)
* Initial LoRA-Pro submission
* Fix issues in LoRA-Pro tests and configuration
* Fix bugs in LoRA-Pro test files
* LoRA-Pro PR revisions
* Trigger CI re-run
1 parent e5858b1 commit d67eaa4

File tree

11 files changed: +637 additions, -0 deletions


llm/docs/finetune.md

Lines changed: 1 addition & 0 deletions
@@ -236,6 +236,7 @@ python ./predict/reft_predictor.py \
 - `strategy_type`: Type of the long-sequence extension strategy; defaults to None.
 - `strategy_name`: Name of the specific long-sequence extension strategy; defaults to None.
 - `rope_scaling_factor`: Scaling factor applied when using a RoPE extension strategy.
+- `lorapro`: Whether to enable the LoRA-Pro strategy.
 </div>

 <summary>&emsp; Data arguments (DataArgument)</summary><div>

llm/run_finetune.py

Lines changed: 17 additions & 0 deletions
@@ -14,6 +14,7 @@
 # import inspect
 import json
 import logging
+import math
 import os
 import sys
 from functools import partial
@@ -76,6 +77,7 @@
     init_chat_template,
 )
 from paddlenlp.utils.log import logger
+from paddlenlp.utils.optimizer import AdamWLoRAPro
 from paddlenlp.utils.tools import get_env_device

 # Fine-tune Environment Variables to support sharding stage1 overlap optimization.
@@ -447,6 +449,15 @@ def compute_metrics_do_generation(eval_preds):
         )
         trainable_parameters = [p for p in model.parameters() if not p.stop_gradient]
         trainer.set_optimizer_grouped_parameters(trainable_parameters)
+        if model_args.lorapro:
+            optimizer = AdamWLoRAPro(
+                learning_rate=training_args.learning_rate,
+                parameters=trainable_parameters,
+                weight_decay=training_args.weight_decay,
+                scaling_factor=model_args.lorapro_scaling_factor,
+                x_mode=model_args.lorapro_x_mode,
+            )
+            trainer.optimizer = optimizer

     # Train
     if training_args.do_train:
@@ -560,7 +571,13 @@ def create_peft_model(model_args, reft_args, training_args, dtype, model_config,
             use_quick_lora=model_args.use_quick_lora,
             lora_use_mixer=model_args.lora_use_mixer,
             use_mora=model_args.use_mora,
+            lorapro=model_args.lorapro,
         )
+        if model_args.lorapro:
+            if model_args.rslora:
+                model_args.lorapro_scaling_factor = lora_config.lora_alpha / math.sqrt(lora_config.r)
+            else:
+                model_args.lorapro_scaling_factor = lora_config.lora_alpha / lora_config.r
         model = LoRAModel(model, lora_config)
     else:
         model = LoRAModel.from_pretrained(model=model, lora_path=model_args.lora_path)
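For reference, the scaling factor wired into `AdamWLoRAPro` above follows the usual LoRA/rsLoRA convention. A minimal sketch of the same rule (the helper name is illustrative, not part of the PR):

```python
import math

def lorapro_scaling_factor(lora_alpha: float, r: int, rslora: bool) -> float:
    """Effective LoRA scaling passed to AdamWLoRAPro, mirroring create_peft_model:
    rsLoRA scales by alpha / sqrt(r); plain LoRA by alpha / r."""
    return lora_alpha / math.sqrt(r) if rslora else lora_alpha / r

print(lorapro_scaling_factor(lora_alpha=16, r=8, rslora=False))  # 2.0
print(lorapro_scaling_factor(lora_alpha=16, r=8, rslora=True))   # ~5.66
```

Note that when `lorapro` is enabled, this computed value overrides the `lorapro_scaling_factor` default of 2.0 defined in `ModelConfig`.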

paddlenlp/peft/lora/lora_config.py

Lines changed: 1 addition & 0 deletions
@@ -94,6 +94,7 @@ class LoRAConfig:
         default=False,
         metadata={"help": "Whether to use mos lora."},
     )
+    lorapro: bool = field(default=False, metadata={"help": "Whether to use LoRA-PRO"})

     def __post_init__(self):
         if self.use_quick_lora and self.lora_dropout > 0:
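As a rough usage sketch (not part of this commit), the new flag is simply another `LoRAConfig` field. The other fields shown (`target_modules`, `r`, `lora_alpha`) are assumed to follow the existing PaddleNLP LoRA API, and the checkpoint name is a placeholder:

```python
# Hypothetical sketch: build a LoRA model with the new lorapro flag enabled.
from paddlenlp.peft import LoRAConfig, LoRAModel
from paddlenlp.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b")  # placeholder checkpoint
lora_config = LoRAConfig(
    target_modules=[".*q_proj.*", ".*v_proj.*"],  # regex patterns, as in existing LoRA examples
    r=8,
    lora_alpha=16,
    lorapro=True,  # new field added by this commit
)
model = LoRAModel(model, lora_config)  # threads lorapro down to each LoRALinear layer
```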

paddlenlp/peft/lora/lora_layers.py

Lines changed: 2 additions & 0 deletions
@@ -65,6 +65,7 @@ def __init__(
         pissa: bool = False,
         lora_use_mixer: bool = False,
         use_mora: bool = False,
+        lorapro: bool = False,
         mp_moe: bool = False,
         is_distributed: bool = False,
         **kwargs
@@ -84,6 +85,7 @@ def __init__(
         self.merged = False
         self.pissa = pissa
         self.lora_use_mixer = lora_use_mixer
+        self.lorapro = lorapro

         # Actual trainable parameters
         if use_mora:  # reset the rank and create high rank matrix

paddlenlp/peft/lora/lora_model.py

Lines changed: 1 addition & 0 deletions
@@ -488,6 +488,7 @@ def _find_and_replace_module(self, model, module_name, lora_config):
                 use_mora=lora_config.use_mora,
                 mp_moe=getattr(module.weight, "mp_moe", False),
                 is_distributed=getattr(module.weight, "is_distributed", False),
+                lorapro=lora_config.lorapro,
             )
         elif isinstance(module, nn.Conv2D):
             lora_module = LoRAConv2D(

paddlenlp/trainer/trainer.py

Lines changed: 1 addition & 0 deletions
@@ -2119,6 +2119,7 @@ def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
             optimizer_kwargs.update(adam_kwargs)
         else:
             raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
+
         return optimizer_cls, optimizer_kwargs

     def create_scheduler(self, num_training_steps: int):

paddlenlp/trl/model_config.py

Lines changed: 11 additions & 0 deletions
@@ -86,6 +86,17 @@ class ModelConfig:
     use_mora: bool = field(
         default=False, metadata={"help": "Whether to use MoRA: https://arxiv.org/pdf/2405.12130.pdf"}
     )
+    lorapro: bool = field(
+        default=False, metadata={"help": "Whether to use LoRA-Pro: https://arxiv.org/pdf/2407.18242"}
+    )
+    lorapro_x_mode: str = field(
+        default="zero",
+        metadata={"help": "X mode for AdamWLoRAPro optimizer (zero, sylvester, symmetry)."},
+    )
+    lorapro_scaling_factor: float = field(
+        default=2.0,
+        metadata={"help": "Scaling factor for AdamWLoRAPro optimizer."},
+    )

     # vera related parameters
     vera: bool = field(default=False, metadata={"help": "Whether to use vera technique"})
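For illustration, a minimal sketch (not from this PR) of driving `AdamWLoRAPro` directly with the values these fields carry; a plain `Linear` layer stands in for a real LoRA model, so the LoRA-specific gradient adjustment is a no-op here:

```python
import paddle
from paddlenlp.utils.optimizer import AdamWLoRAPro

linear = paddle.nn.Linear(16, 16)  # stand-in; real use passes LoRA model parameters
opt = AdamWLoRAPro(
    learning_rate=1e-4,
    parameters=linear.parameters(),
    weight_decay=0.01,
    scaling_factor=2.0,  # lorapro_scaling_factor (default 2.0; recomputed from lora_alpha/r in run_finetune.py)
    x_mode="zero",       # lorapro_x_mode: "zero", "sylvester", or "symmetry"
)
loss = linear(paddle.rand([4, 16])).mean()
loss.backward()
opt.step()        # no lo_ra_linear_*.w_* parameters here, so this reduces to a plain AdamW step
opt.clear_grad()
```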

paddlenlp/utils/optimizer.py

Lines changed: 170 additions & 0 deletions
@@ -11,11 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import re
 import warnings

 import paddle
 from paddle import pir
 from paddle.base import core, framework
+from paddle.base.dygraph import base as imperative_base
 from paddle.base.framework import Variable, in_dynamic_or_pir_mode, in_pir_mode
 from paddle.base.libpaddle import DataType
 from paddle.optimizer.adamw import AdamW
@@ -583,3 +585,171 @@ def adamw_custom(
     moment2[:] = mom2
     beta1_pow[:], beta2_pow[:] = beta1 * beta1_pow[:], beta2 * beta2_pow[:]
     return
+
+
+class AdamWLoRAPro(AdamW):
+    def __init__(self, scaling_factor=2.0, x_mode="zero", *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert scaling_factor is not None
+        if x_mode not in ["zero", "sylvester", "symmetry"]:
+            raise ValueError(
+                f"Invalid x_mode value: {x_mode}, " f"mode should be in ['zero', 'sylvester', 'symmetry']"
+            )
+        self.scaling_factor = scaling_factor
+        self.x_mode = x_mode
+
+    def _solve_sylvester(self, A, B, C, X=None):
+        if A.dtype in [paddle.bfloat16, paddle.float16]:
+            A = A.to("float32")
+            B = B.to("float32")
+            C = C.to("float32")
+        B = -B
+        m = tuple(B.shape)[-1]
+        n = tuple(A.shape)[-1]
+        R, U = paddle.linalg.eig(x=A)
+        S, V = paddle.linalg.eig(x=B)
+
+        CV = C @ V
+
+        U_real, U_imag = paddle.real(U), paddle.imag(U)
+        CV_real, CV_imag = paddle.real(CV), paddle.imag(CV)
+
+        n_dim = U_real.shape[0]
+
+        block_top = paddle.concat([U_real, -U_imag], axis=1)  # (n, 2n)
+        block_bot = paddle.concat([U_imag, U_real], axis=1)  # (n, 2n)
+        A_block = paddle.concat([block_top, block_bot], axis=0)  # (2n, 2n)
+        B_block = paddle.concat([CV_real, CV_imag], axis=0)  # (2n, m)
+
+        F_block = paddle.linalg.solve(A_block, B_block)  # [F_real; F_imag]
+
+        F_real = F_block[:n_dim, :]
+        F_imag = F_block[n_dim:, :]
+        F = paddle.complex(F_real, F_imag)
+
+        W = R[..., :, None] - S[..., None, :]
+        Y = F / W
+        try:
+            V_inv = paddle.linalg.inv(V)
+        except RuntimeError:
+            # Add regularization to handle singular matrices
+            epsilon = 1e-6 * paddle.mean(paddle.abs(V))
+            V_reg = V + epsilon * paddle.eye(V.shape[-1])
+            V_inv = paddle.linalg.inv(V_reg)
+        X = U[..., :n, :n] @ Y[..., :n, :m] @ V_inv[..., :m, :m]
+
+        if all(paddle.isreal(x.flatten()[0]) for x in [A, B, C]):
+            return paddle.real(X)
+        else:
+            return X
+
+    @imperative_base.no_grad
+    @framework.non_static_only
+    def step(self) -> None:
+        """
+        Execute the optimizer and update parameters once.
+
+        Returns:
+            None
+
+        Examples:
+            .. code-block:: python
+
+                >>> import paddle
+
+                >>> a = paddle.rand([2,13], dtype="float32")
+                >>> linear = paddle.nn.Linear(13, 5)
+                >>> # This can be any optimizer supported by dygraph.
+                >>> opt = paddle.optimizer.AdamW(learning_rate = 0.01,
+                ...                              parameters = linear.parameters())
+                >>> out = linear(a)
+                >>> out.backward()
+                >>> opt.step()
+                >>> opt.clear_grad()
+        """
+        if paddle.base.dygraph.base.in_to_static_mode():
+            self._declarative_step()
+            return
+
+        if not isinstance(self._parameter_list[0], dict):
+            param_id_to_idx = {id(param): idx for idx, param in enumerate(self._parameter_list)}
+
+            lora_params = {}
+            for idx, param in enumerate(self._parameter_list):
+                name = getattr(param, "name", f"param_{idx}")
+                match = re.match(r"lo_ra_linear_(\d+)\.w_(\d+)", name)
+                if match:
+                    layer_num = int(match.group(1))
+                    weight_type = match.group(2)
+                    if layer_num not in lora_params:
+                        lora_params[layer_num] = {}
+                    lora_params[layer_num][weight_type] = param
+
+            for layer_num, weights in lora_params.items():
+                if "1" in weights and "2" in weights:
+                    param_B = weights["1"]
+                    param_A = weights["2"]
+
+                    idx_B = param_id_to_idx[id(param_B)]
+                    idx_A = param_id_to_idx[id(param_A)]
+
+                    if param_A._grad_ivar() is not None and param_B._grad_ivar() is not None:
+                        A = param_A.detach()
+                        B = param_B.detach()
+                        grad_A = param_A._grad_ivar()
+                        grad_B = param_B._grad_ivar()
+
+                        delta = 1e-08
+                        AA_T = A @ A.T
+                        B_TB = B.T @ B
+                        AA_T_inv = paddle.linalg.pinv(AA_T + delta * paddle.eye(num_rows=AA_T.shape[0]))
+                        B_TB_inv = paddle.linalg.pinv(B_TB + delta * paddle.eye(num_rows=B_TB.shape[0]))
+
+                        if self.x_mode == "sylvester":
+                            X = self._solve_sylvester(
+                                B_TB, AA_T, -(1 / self.scaling_factor**2) * B_TB_inv @ grad_A @ A.T
+                            )
+                        elif self.x_mode == "symmetry":
+                            X = -0.5 * (1 / self.scaling_factor**2) * B_TB_inv @ B.T @ grad_B @ AA_T
+                        else:  # zero mode
+                            X = paddle.zeros(shape=(B_TB_inv.shape[0], B_TB_inv.shape[0]))
+
+                        X = X.clone().detach().cast(A.dtype)
+
+                        new_grad_A = (1 / self.scaling_factor**2) * B_TB_inv @ grad_A + X @ A
+                        new_grad_B = (1 / self.scaling_factor**2) * (
+                            (paddle.eye(num_rows=B.shape[0]) - B @ B_TB_inv @ B.T) @ grad_B @ AA_T_inv
+                        ) - B @ X
+
+                        self._parameter_list[idx_A]._grad_ivar()[:] = new_grad_A
+                        self._parameter_list[idx_B]._grad_ivar()[:] = new_grad_B
+
+            params_grads = []
+            for param in self._parameter_list:
+                if param.stop_gradient:
+                    continue
+                if param._grad_ivar() is not None:
+                    grad_var = param._grad_ivar()
+                    if framework.in_dygraph_mode():
+                        if (
+                            hasattr(grad_var, "is_selected_rows")
+                            and grad_var.is_selected_rows()
+                            and self.regularization is not None
+                        ):
+                            raise RuntimeError(
+                                "AdamW don't support weight_decay with sparse parameters, please set it to None."
+                            )
+                    else:
+                        if (
+                            hasattr(grad_var, "_is_sparse")
+                            and grad_var._is_sparse()
+                            and self.regularization is not None
+                        ):
+                            raise RuntimeError(
+                                "AdamW don't support weight_decay with sparse parameters, please set it to None."
+                            )
+                    params_grads.append((param, grad_var))
+
+            self._apply_optimize(loss=None, startup_program=None, params_grads=params_grads)
+        else:
+            raise NotImplementedError("AdamWLoRAPro does not support parameter groups")
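In summary, for each matched LoRA pair (A, B) with raw gradients ∇A, ∇B, `step()` rewrites the gradients before delegating to the standard AdamW update, where s is `scaling_factor` and the inverses denote the damped pseudo-inverses computed above:

$$
\tilde{g}_A = \frac{1}{s^{2}}\,(B^{\top}B)^{-1}\,\nabla_A + XA,
\qquad
\tilde{g}_B = \frac{1}{s^{2}}\,\bigl(I - B(B^{\top}B)^{-1}B^{\top}\bigr)\,\nabla_B\,(AA^{\top})^{-1} - BX
$$

Here $X = 0$ in "zero" mode, $X = -\tfrac{1}{2s^{2}}(B^{\top}B)^{-1}B^{\top}\nabla_B\,AA^{\top}$ in "symmetry" mode, and in "sylvester" mode $X$ is obtained from `_solve_sylvester`, which solves a Sylvester equation of the form $(B^{\top}B)X + X(AA^{\top}) = -\tfrac{1}{s^{2}}(B^{\top}B)^{-1}\nabla_A A^{\top}$.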
