LoRA fine-tuning with automatic relevance determination

Full rank matrix $V \in \mathbb{R}^{F \times N}$

Factorize as $V' \stackrel{\Delta}{\approx} WH$

Objective function:

$C(W, H, \lambda) \stackrel{\Delta}{=} -\log p(W, H, \lambda | V) = \frac{1}{\phi} D_{\beta}(V|WH) + \sum_{k=1}^{K} \left[ \frac{1}{\lambda_k} (f(w_k) + f(h_k) + b) + c \log \lambda_k \right] + \text{cst}$

Using L-1 regularization, we define

$f(x) = \| x \|_1$ and $c = F + N + a + 1$


```python
# Data-fit term; presumably the beta-divergence D_beta(V | WH) from the
# objective above (beta_div is not defined in this snippet -- confirm).
loss = beta_div(Beta,V,W,H,eps_,mask)
# Constant offset K * C * (1 - log C).
cst = (K*C)*(1.0-torch.log(C))
# Total cost: loss / phi + C * sum(log(lambda_ * C)) + cst.
return torch.pow(phi,-1)*loss + (C*torch.sum(torch.log(lambda_ * C))) + cst
```



In [3]:
import torch
import torch.nn as nn
from torch.nn import Module

import math
import numpy as np

In [17]:
class LoRALayer(Module):
    """Low-rank adapter computing ``alpha * (dropout(x) @ lora_A @ lora_B)``.

    ``lora_A`` has shape ``(in_dim, init_rank)`` and is Kaiming-uniform
    initialised; ``lora_B`` has shape ``(init_rank, out_dim)`` and starts at
    zero, so a freshly constructed adapter outputs zeros.

    Three constant buffers are registered next to the weights:

    * ``c``   -- ``(in_dim + out_dim) / 2 + a + 1``
    * ``cst`` -- ``init_rank * c * (1 - log(c))``
    * ``phi`` -- the ``phi`` argument as a floating-point tensor

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Size of the last dimension of the output.
        init_rank: Number of columns of ``lora_A`` / rows of ``lora_B``.
        alpha: Multiplier applied to the adapter output in ``forward``.
        a: Constant entering the ``c`` buffer.
        phi: Constant stored in the ``phi`` buffer.
        dropout: Dropout probability applied to the adapter input.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        init_rank: int,
        alpha: float = 1.,
        a: float = 1.,
        phi: float = 1.,
        dropout: float = 0.,
        # merge_weights: bool = False
    ):
        super().__init__()
        self.init_rank = init_rank
        self.alpha = alpha
        self.scaling = alpha / init_rank

        # Low-rank factors: A is randomly initialised, B starts at zero.
        self.lora_A = nn.Parameter(torch.empty(in_dim, init_rank))
        self.lora_B = nn.Parameter(torch.empty(init_rank, out_dim))
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

        # Constants. ``cst`` is already a tensor, so it is registered directly
        # instead of being re-wrapped in ``torch.tensor`` (which emits a
        # copy-construct warning).
        c = torch.tensor((in_dim + out_dim) / 2 + a + 1)
        self.register_buffer('c', c)
        self.register_buffer('cst', (init_rank * c) * (1 - torch.log(c)))
        # Force a floating dtype: an integer ``phi`` (e.g. the default ``1``
        # passed by LoRALinear) would otherwise yield an int64 buffer, on
        # which ``torch.pow(phi, -1)`` is not allowed.
        self.register_buffer('phi', torch.tensor(float(phi)))

        # Dropout applied to the adapter input only.
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Return ``alpha * (dropout(x) @ lora_A @ lora_B)``."""
        # NOTE(review): the output is scaled by ``alpha``; ``self.scaling``
        # (alpha / init_rank) is computed in __init__ but not used here --
        # confirm which scaling is intended.
        return self.alpha * (self.dropout(x) @ self.lora_A @ self.lora_B)

class LoRALinear(Module):
    """Linear layer with a parallel LoRA adapter: ``base_layer(x) + lora(x)``.

    Args:
        linear_layer: Module exposing ``in_features`` and ``out_features``;
            it is kept as ``base_layer`` and called unchanged.
        init_rank: Rank passed to the ``LoRALayer`` adapter.
        alpha: Adapter output multiplier, passed to ``LoRALayer``.
        a: Constant passed to ``LoRALayer``.
        phi: Constant passed to ``LoRALayer``.
        dropout: Dropout probability for the adapter input.
    """

    def __init__(self, linear_layer, init_rank, alpha=1, a=1, phi=1, dropout=0.):
        super().__init__()
        self.base_layer = linear_layer
        # ``dropout`` is forwarded explicitly; it used to be accepted here
        # but never reached the adapter.
        self.lora = LoRALayer(
            linear_layer.in_features,
            linear_layer.out_features,
            init_rank,
            alpha=alpha,
            a=a,
            phi=phi,
            dropout=dropout,
        )

    def forward(self, x):
        """Return the base layer output plus the adapter output."""
        return self.base_layer(x) + self.lora(x)

In [None]:
# Functions 
def compute_sparsity_loss(A, B, c, b, cst):
    """Return the relevance regulariser ``c * sum(log(lambda_k * c)) + cst``.

    For each component ``k`` the relevance weight is

        lambda_k = (0.5 * ||A[:, k]||^2 + 0.5 * ||B[k, :]||^2 + b) / c

    so ``b`` sits inside the numerator and the log argument is rescaled by
    ``c``. With ``cst = K * c * (1 - log(c))`` the result equals
    ``c * sum(log(lambda_k)) + K * c``.

    Args:
        A: 2-D tensor whose columns are the components (reduced over dim 0).
        B: 2-D tensor whose rows are the components (reduced over dim 1).
        c: Scalar (float or tensor) dividing the per-component energy.
        b: Positive offset added to the per-component energy.
        cst: Constant added to the result.

    Returns:
        Scalar tensor with the regularisation value.
    """
    # Per-component half squared norm of the matching column of A / row of B.
    energy = 0.5 * torch.sum(A ** 2, dim=0) + 0.5 * torch.sum(B ** 2, dim=1)
    lambda_k = torch.div(energy + b, c)
    return c * torch.sum(torch.log(lambda_k * c)) + cst

In [None]:
def apply_lora(model, init_rank, alpha=1):
    """Wrap every ``nn.Linear`` in ``model`` with ``LoRALinear`` and freeze the rest.

    After the call, only parameters whose qualified name contains ``'lora'``
    have ``requires_grad=True``; all other parameters are frozen.

    Args:
        model: Module to adapt. Its ``nn.Linear`` submodules are replaced in
            place. If ``model`` itself is an ``nn.Linear`` it has no parent
            to attach to, so it is wrapped and the wrapper is returned.
        init_rank: Rank passed to each ``LoRALinear``.
        alpha: Adapter output multiplier passed to each ``LoRALinear``.

    Returns:
        The adapted model (the same object as ``model`` unless ``model``
        was itself an ``nn.Linear``).
    """
    if isinstance(model, nn.Linear):
        model = LoRALinear(model, init_rank, alpha)
    else:
        # Snapshot the targets before replacing anything, so the module tree
        # is not mutated while ``named_modules()`` is still walking it.
        targets = [
            (name, module)
            for name, module in model.named_modules()
            if isinstance(module, nn.Linear)
        ]
        for name, module in targets:
            # 'a.b.fc' -> parent 'a.b', attribute 'fc'; a top-level 'fc'
            # gives parent '' which get_submodule resolves to the model.
            parent_name, _, layer_name = name.rpartition('.')
            parent = model.get_submodule(parent_name)
            setattr(parent, layer_name, LoRALinear(module, init_rank, alpha))

    # Freeze everything except the LoRA parameters.
    for name, param in model.named_parameters():
        param.requires_grad = 'lora' in name
    return model