 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import re
 import warnings

 import paddle
 from paddle import pir
 from paddle.base import core, framework
+from paddle.base.dygraph import base as imperative_base
 from paddle.base.framework import Variable, in_dynamic_or_pir_mode, in_pir_mode
 from paddle.base.libpaddle import DataType
 from paddle.optimizer.adamw import AdamW
@@ -583,3 +585,171 @@ def adamw_custom(
     moment2[:] = mom2
     beta1_pow[:], beta2_pow[:] = beta1 * beta1_pow[:], beta2 * beta2_pow[:]
     return
+
+
+class AdamWLoRAPro(AdamW):
+    def __init__(self, scaling_factor=2.0, x_mode="zero", *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert scaling_factor is not None
+        if x_mode not in ["zero", "sylvester", "symmetry"]:
+            raise ValueError(
+                f"Invalid x_mode value: {x_mode}, mode should be in ['zero', 'sylvester', 'symmetry']"
+            )
+        self.scaling_factor = scaling_factor
+        self.x_mode = x_mode
+
+    def _solve_sylvester(self, A, B, C, X=None):
+        if A.dtype in [paddle.bfloat16, paddle.float16]:
+            A = A.to("float32")
+            B = B.to("float32")
+            C = C.to("float32")
+        B = -B
+        m = tuple(B.shape)[-1]
+        n = tuple(A.shape)[-1]
+        R, U = paddle.linalg.eig(x=A)
+        S, V = paddle.linalg.eig(x=B)
+
+        CV = C @ V
+
+        U_real, U_imag = paddle.real(U), paddle.imag(U)
+        CV_real, CV_imag = paddle.real(CV), paddle.imag(CV)
+
+        n_dim = U_real.shape[0]
+
+        block_top = paddle.concat([U_real, -U_imag], axis=1)  # (n, 2n)
+        block_bot = paddle.concat([U_imag, U_real], axis=1)  # (n, 2n)
+        A_block = paddle.concat([block_top, block_bot], axis=0)  # (2n, 2n)
+        B_block = paddle.concat([CV_real, CV_imag], axis=0)  # (2n, m)
+
+        F_block = paddle.linalg.solve(A_block, B_block)  # [F_real; F_imag]
+
+        F_real = F_block[:n_dim, :]
+        F_imag = F_block[n_dim:, :]
+        F = paddle.complex(F_real, F_imag)
+
+        W = R[..., :, None] - S[..., None, :]
+        Y = F / W
+        try:
+            V_inv = paddle.linalg.inv(V)
+        except RuntimeError:
+            # Add regularization to handle singular matrices
+            epsilon = 1e-6 * paddle.mean(paddle.abs(V))
+            V_reg = V + epsilon * paddle.eye(V.shape[-1])
+            V_inv = paddle.linalg.inv(V_reg)
+        X = U[..., :n, :n] @ Y[..., :n, :m] @ V_inv[..., :m, :m]
+
+        if all(paddle.isreal(x.flatten()[0]) for x in [A, B, C]):
+            return paddle.real(X)
+        else:
+            return X
+
+    @imperative_base.no_grad
+    @framework.non_static_only
+    def step(self) -> None:
+        """
+        Execute the optimizer and update parameters once.
+
+        Returns:
+            None
+
+        Examples:
+            .. code-block:: python
+
+                >>> import paddle
+
+                >>> a = paddle.rand([2, 13], dtype="float32")
+                >>> linear = paddle.nn.Linear(13, 5)
+                >>> # This can be any optimizer supported by dygraph.
+                >>> opt = paddle.optimizer.AdamW(learning_rate=0.01,
+                ...                              parameters=linear.parameters())
+                >>> out = linear(a)
+                >>> out.backward()
+                >>> opt.step()
+                >>> opt.clear_grad()
+        """
+        if paddle.base.dygraph.base.in_to_static_mode():
+            self._declarative_step()
+            return
+
+        if not isinstance(self._parameter_list[0], dict):
+            param_id_to_idx = {id(param): idx for idx, param in enumerate(self._parameter_list)}
+
+            lora_params = {}
+            for idx, param in enumerate(self._parameter_list):
+                name = getattr(param, "name", f"param_{idx}")
+                match = re.match(r"lo_ra_linear_(\d+)\.w_(\d+)", name)
+                if match:
+                    layer_num = int(match.group(1))
+                    weight_type = match.group(2)
+                    if layer_num not in lora_params:
+                        lora_params[layer_num] = {}
+                    lora_params[layer_num][weight_type] = param
+
+            for layer_num, weights in lora_params.items():
+                if "1" in weights and "2" in weights:
+                    param_B = weights["1"]
+                    param_A = weights["2"]
+
+                    idx_B = param_id_to_idx[id(param_B)]
+                    idx_A = param_id_to_idx[id(param_A)]
+
+                    if param_A._grad_ivar() is not None and param_B._grad_ivar() is not None:
+                        A = param_A.detach()
+                        B = param_B.detach()
+                        grad_A = param_A._grad_ivar()
+                        grad_B = param_B._grad_ivar()
+
+                        delta = 1e-08
+                        AA_T = A @ A.T
+                        B_TB = B.T @ B
+                        AA_T_inv = paddle.linalg.pinv(AA_T + delta * paddle.eye(num_rows=AA_T.shape[0]))
+                        B_TB_inv = paddle.linalg.pinv(B_TB + delta * paddle.eye(num_rows=B_TB.shape[0]))
+
+                        if self.x_mode == "sylvester":
+                            X = self._solve_sylvester(
+                                B_TB, AA_T, -(1 / self.scaling_factor**2) * B_TB_inv @ grad_A @ A.T
+                            )
+                        elif self.x_mode == "symmetry":
+                            X = -0.5 * (1 / self.scaling_factor**2) * B_TB_inv @ B.T @ grad_B @ AA_T
+                        else:  # zero mode
+                            X = paddle.zeros(shape=(B_TB_inv.shape[0], B_TB_inv.shape[0]))
+
+                        X = X.clone().detach().cast(A.dtype)
+
+                        new_grad_A = (1 / self.scaling_factor**2) * B_TB_inv @ grad_A + X @ A
+                        new_grad_B = (1 / self.scaling_factor**2) * (
+                            (paddle.eye(num_rows=B.shape[0]) - B @ B_TB_inv @ B.T) @ grad_B @ AA_T_inv
+                        ) - B @ X
+
+                        self._parameter_list[idx_A]._grad_ivar()[:] = new_grad_A
+                        self._parameter_list[idx_B]._grad_ivar()[:] = new_grad_B
+
+            params_grads = []
+            for param in self._parameter_list:
+                if param.stop_gradient:
+                    continue
+                if param._grad_ivar() is not None:
+                    grad_var = param._grad_ivar()
+                    if framework.in_dygraph_mode():
+                        if (
+                            hasattr(grad_var, "is_selected_rows")
+                            and grad_var.is_selected_rows()
+                            and self.regularization is not None
+                        ):
+                            raise RuntimeError(
+                                "AdamW doesn't support weight_decay with sparse parameters, please set it to None."
+                            )
+                    else:
+                        if (
+                            hasattr(grad_var, "_is_sparse")
+                            and grad_var._is_sparse()
+                            and self.regularization is not None
+                        ):
+                            raise RuntimeError(
+                                "AdamW doesn't support weight_decay with sparse parameters, please set it to None."
+                            )
+                    params_grads.append((param, grad_var))
+
+            self._apply_optimize(loss=None, startup_program=None, params_grads=params_grads)
+        else:
+            raise NotImplementedError("AdamWLoRAPro does not support parameter groups")
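
For readers skimming the diff, the gradient rewrite that `step()` applies to a matched LoRA A/B pair can be restated directly from the code above (this summary is not part of the patch). Writing $s$ for `scaling_factor`, $g_A, g_B$ for the raw gradients, and $X$ for the mode-dependent correction, the adjusted gradients are

$$
\tilde g_A = \frac{1}{s^{2}}\,(B^{\top}B)^{-1} g_A + XA,
\qquad
\tilde g_B = \frac{1}{s^{2}}\,\bigl(I - B(B^{\top}B)^{-1}B^{\top}\bigr)\, g_B\,(AA^{\top})^{-1} - BX .
$$

With `x_mode="zero"` the code sets $X = 0$; with `"symmetry"` it sets $X = -\tfrac{1}{2s^{2}}(B^{\top}B)^{-1}B^{\top} g_B\,(AA^{\top})$; with `"sylvester"` it obtains $X$ from the Sylvester equation $(B^{\top}B)X + X(AA^{\top}) = -\tfrac{1}{s^{2}}(B^{\top}B)^{-1} g_A A^{\top}$, which `_solve_sylvester(A, B, C)` solves in the form $AX + XB = C$ via eigendecomposition of $A$ and $-B$. Both inverses are pseudo-inverses of the $\delta$-regularized Gram matrices with $\delta = 10^{-8}$.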
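A minimal usage sketch (also not part of the patch), assuming `AdamWLoRAPro` is importable from wherever this module lands in the package; the import path below is a placeholder. Parameters whose names do not match the `lo_ra_linear_<n>.w_<k>` pattern fall through to the ordinary AdamW update inside `step()`, so a plain `Linear` is enough for a smoke test:

```python
import paddle

# Placeholder import path -- adjust to the module that actually defines AdamWLoRAPro.
from my_project.adamw_lora_pro import AdamWLoRAPro

linear = paddle.nn.Linear(13, 5)
opt = AdamWLoRAPro(
    scaling_factor=2.0,   # LoRA scaling factor s used in the 1/s**2 terms
    x_mode="sylvester",   # one of "zero", "sylvester", "symmetry"
    learning_rate=0.01,   # remaining kwargs are forwarded to paddle.optimizer.AdamW
    parameters=linear.parameters(),
)

x = paddle.rand([2, 13], dtype="float32")
loss = linear(x).mean()
loss.backward()
opt.step()        # rewrites gradients of matched LoRA A/B pairs, then applies AdamW
opt.clear_grad()
```

Note that parameter groups (a list of dicts) are rejected with `NotImplementedError`, so only a flat parameter list is supported.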