In [1]:
import torch
import torch.nn as nn


class LinearLoRALayer(nn.Module):
    """Linear layer with an optional low-rank (LoRA) additive weight update.

    The effective weight is ``linear.weight + scale * (lora_a @ lora_b)`` with
    ``scale = lora_alpha / rank``. When ``rank > 0`` the base linear layer is
    frozen and only ``lora_a`` / ``lora_b`` remain trainable. ``lora_b`` starts
    at zero, so a freshly built layer computes exactly the base linear map.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        rank: Rank of the low-rank update. ``rank <= 0`` disables LoRA and the
            module behaves like a plain ``nn.Linear``.
        lora_alpha: Numerator of the scaling factor ``lora_alpha / rank``.
        merge: If True, fold the low-rank update into ``linear.weight`` at
            construction time.
        dropout: Dropout probability applied to the input of the low-rank
            path (only while the update is not merged). ``0`` disables it.

    Attributes:
        merged: True only while the low-rank update is folded into
            ``linear.weight``. ``forward`` branches on this flag.
        merge: The constructor flag; kept in sync with ``merged`` by
            ``merge_weight`` / ``unmerge_weight``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        rank: int,
        lora_alpha: float,
        merge: bool = False,
        dropout: float = 0,
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.rank = rank
        self.merge = merge
        # Real merge state. Tracked separately from the constructor flag so
        # that merge/unmerge are idempotent and forward never double-counts
        # (or drops) the low-rank update.
        self.merged = False

        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        # linear.weight shape is (out_features, in_features)
        self.linear = nn.Linear(in_features, out_features)

        if self.rank > 0:
            self.scale = lora_alpha / rank

            # (out_features, rank) factor, randomly initialised.
            self.lora_a = nn.Parameter(torch.zeros(out_features, rank))
            nn.init.kaiming_normal_(self.lora_a, a=0.01)

            # (rank, in_features) factor, zero-initialised so the initial
            # update lora_a @ lora_b is exactly zero.
            self.lora_b = nn.Parameter(torch.zeros(rank, in_features))

            # Freeze the base linear layer; only the LoRA factors are trained.
            self.linear.weight.requires_grad = False
            self.linear.bias.requires_grad = False

        if merge:
            self.merge_weight()

    def _lora_delta(self) -> torch.Tensor:
        """Return the dense ``(out_features, in_features)`` scaled update."""
        return self.scale * (self.lora_a @ self.lora_b)

    def merge_weight(
        self,
    ) -> None:
        """Fold the low-rank update into ``linear.weight``.

        Does nothing when LoRA is disabled (``rank <= 0``) or when the update
        is already merged, so repeated calls never add the update twice.
        """
        if self.rank > 0 and not self.merged:
            with torch.no_grad():
                self.linear.weight.add_(self._lora_delta())
            self.merged = True
            self.merge = True

    def unmerge_weight(
        self,
    ) -> None:
        """Remove a previously merged low-rank update from ``linear.weight``.

        Does nothing when LoRA is disabled or when nothing is merged, so the
        base weight cannot be corrupted by an unmatched call.
        """
        if self.rank > 0 and self.merged:
            with torch.no_grad():
                self.linear.weight.sub_(self._lora_delta())
            self.merged = False
            self.merge = False

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Apply the layer to ``X`` of shape ``(..., in_features)``.

        Returns a tensor of shape ``(..., out_features)``. While the update is
        not merged, the low-rank path is added on top of the base linear
        output; once merged, ``linear.weight`` already contains it.
        """
        output = self.linear(X)
        if self.rank > 0 and not self.merged:
            # Two skinny matmuls through the rank-sized bottleneck instead of
            # building the dense (out_features, in_features) matrix per call.
            low_rank = (self.dropout(X) @ self.lora_b.T) @ self.lora_a.T
            output = output + self.scale * low_rank
        return output

In [2]:
# Seed the global RNG so the random input and the layer's random
# initialisation are the same on every Restart & Run All.
torch.manual_seed(42)

batch_size = 32
seq_len = 128
in_features = 768
out_features = 512
rank = 8
lora_alpha = 16
dropout = 0.1

x = torch.randn(batch_size, seq_len, in_features)
lora_layer = LinearLoRALayer(
    in_features,
    out_features,
    rank,
    lora_alpha,
    merge=False,
    dropout=dropout,
)
output = lora_layer(x)

# The layer maps the last dimension in_features -> out_features and leaves
# the leading (batch, sequence) dimensions untouched.
assert output.shape == (batch_size, seq_len, out_features)
print(output.shape)  # (32, 128, 512)

torch.Size([32, 128, 512])


In [3]:
# Run the layer once after merge_weight() and once after unmerge_weight(),
# on the same input, and compare the two outputs.
# lora_b is zero-initialised and has not been trained here, so the low-rank
# update is exactly zero and a difference of 0.0 is the expected result.
lora_layer.merge_weight()
output_after_merge = lora_layer(x)

lora_layer.unmerge_weight()
output_after_unmerge = lora_layer(x)

max_cycle_diff = (output_after_merge - output_after_unmerge).abs().max().item()
print("Max difference after merge/unmerge cycle:", max_cycle_diff)

Max difference after merge/unmerge cycle: 0.0


In [4]:
# Second layer built with merge=True, so merge_weight() is invoked from
# __init__; arguments are listed in the constructor's own parameter order.
lora_layer_merged = LinearLoRALayer(
    in_features=in_features,
    out_features=out_features,
    rank=rank,
    lora_alpha=lora_alpha,
    merge=True,
    dropout=dropout,
)

output_merged = lora_layer_merged(x)
print(f"Output shape (merged): {output_merged.shape}")

Output shape (merged): torch.Size([32, 128, 512])
