In [1]:
# https://github.com/cccntu/minLoRA/blob/main/demo.ipynb

## Minilora Network

In [2]:
import math
from functools import partial  # 用于固定某些函数的参数，从而创建一个新的函数。这个新函数会记住被固定的参数，并在调用时使用这些固定参数。

import torch
import torch.nn.utils.parametrize as parametrize
from torch import nn

### lora class

In [3]:
class LoRAParametrization(nn.Module):
    """Parametrization that adds a scaled low-rank update to a weight tensor.

    Two factors are learned: ``lora_A`` (Kaiming-uniform initialised) and
    ``lora_B`` (zero initialised), so a freshly built instance returns its
    input unchanged. ``forward`` returns ``X + (B @ A) * scaling`` with
    ``scaling = lora_alpha / rank``; when ``fan_in_fan_out`` is true both the
    factor shapes and the product order are swapped (``A @ B``).

    Args:
        fan_in: Input dimension of the wrapped weight.
        fan_out: Output dimension of the wrapped weight.
        fan_in_fan_out: True when the weight is laid out as (fan_in, fan_out).
        rank: Inner dimension of the low-rank factors.
        lora_dropout_p: Dropout probability; dropout is only wired in when > 0.
        lora_alpha: Numerator of the scaling factor.
    """

    def __init__(self, fan_in, fan_out, fan_in_fan_out=False, rank=4, lora_dropout_p=0.0, lora_alpha=1):
        super().__init__()
        # A weight stored as (fan_out, fan_in) gives the layout (W + BA)x; otherwise it is
        # x(W + AB). `swap` reverses a pair in the second case, which lets the same code
        # serve both linear layers and embeddings.
        if fan_in_fan_out:
            self.swap = lambda pair: (pair[1], pair[0])
        else:
            self.swap = lambda pair: pair

        shape_a = self.swap((rank, fan_in))
        shape_b = self.swap((fan_out, rank))
        self.lora_A = nn.Parameter(torch.zeros(shape_a))
        self.lora_B = nn.Parameter(torch.zeros(shape_b))
        # Only A gets a random init; B stays at zero so the initial update is zero.
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))

        self.lora_alpha = lora_alpha
        self.rank = rank
        self.scaling = lora_alpha / rank

        if lora_dropout_p > 0:
            self.lora_dropout = nn.Dropout(p=lora_dropout_p)
            self.dropout_fn = self._dropout
        else:
            # No dropout requested: both hooks are identity functions.
            self.lora_dropout = lambda t: t
            self.dropout_fn = lambda t: t

        mask_shape = self.swap((1, fan_in))
        self.register_buffer("lora_dropout_mask", torch.ones(mask_shape, dtype=self.lora_A.dtype))
        # `forward` dispatches through this attribute so LoRA can be toggled at runtime.
        self.forward_fn = self.lora_forward

    def _dropout(self, A):
        """Apply dropout to ``A`` by multiplying it with a dropped-out mask of ones."""
        # Mimics the original formulation A @ dropout(x) by computing (A * dropout(ones)) @ x.
        dropped_mask = self.lora_dropout(self.lora_dropout_mask)
        return A * dropped_mask

    def lora_forward(self, X):
        """Return ``X`` plus the scaled low-rank update, reshaped to ``X.shape``."""
        left, right = self.swap((self.lora_B, self.dropout_fn(self.lora_A)))
        update = torch.matmul(left, right).view(X.shape)
        return X + update * self.scaling

    def forward(self, X):
        """Delegate to whichever function is currently installed in ``forward_fn``."""
        return self.forward_fn(X)

    def disable_lora(self):
        """Make ``forward`` an identity so the wrapped weight is used unchanged."""
        # With the identity installed, lora_forward is never called and the layer
        # computes with its original weight.
        self.forward_fn = lambda x: x

    def enable_lora(self):
        """Re-install ``lora_forward`` as the active forward function."""
        self.forward_fn = self.lora_forward

    @classmethod
    def from_linear(cls, layer, rank=4, lora_dropout_p=0.0, lora_alpha=1):
        """Build a parametrization sized for ``layer.weight`` of shape (fan_out, fan_in)."""
        out_dim, in_dim = layer.weight.shape
        return cls(
            in_dim,
            out_dim,
            fan_in_fan_out=False,
            rank=rank,
            lora_dropout_p=lora_dropout_p,
            lora_alpha=lora_alpha,
        )

    @classmethod
    def from_conv2d(cls, layer, rank=4, lora_dropout_p=0.0, lora_alpha=1):
        """Build a parametrization for a conv weight viewed as (weight.shape[0], -1)."""
        flattened = layer.weight.view(layer.weight.shape[0], -1)
        out_dim, in_dim = flattened.shape
        return cls(
            in_dim,
            out_dim,
            fan_in_fan_out=False,
            rank=rank,
            lora_dropout_p=lora_dropout_p,
            lora_alpha=lora_alpha,
        )

    @classmethod
    def from_embedding(cls, layer, rank=4, lora_dropout_p=0.0, lora_alpha=1):
        """Build a parametrization for an embedding weight read as (fan_in, fan_out)."""
        in_dim, out_dim = layer.weight.shape
        return cls(
            in_dim,
            out_dim,
            fan_in_fan_out=True,
            rank=rank,
            lora_dropout_p=lora_dropout_p,
            lora_alpha=lora_alpha,
        )

### lora config

In [25]:
# Selects which layers receive LoRA: maps a layer class to {attribute name: factory}.
# Each factory is called with the layer itself and must return the parametrization
# module to register on that attribute. By default only nn.Linear weights get LoRA.
# NOTE(review): the recorded outputs further down show rank=3, so they appear to
# predate this rank=4 setting.
default_lora_config = {
    nn.Linear: dict(
        # A single parametrization for `weight`: a LoRA module sized from the layer,
        # with the rank pre-bound through functools.partial.
        weight=partial(LoRAParametrization.from_linear, rank=4),
    ),
}

### lora functions

In [6]:
'''
    lora_config:  {<class 'torch.nn.modules.linear.Linear'>: {'weight': functools.partial(<bound method 
    LoRAParametrization.from_linear of <class '__main__.LoRAParametrization'>>, rank=4)}}
    type(layer):  <class 'torch.nn.modules.linear.Linear'>
    ----------
    contained
    type(layer):  <class 'torch.nn.modules.linear.Linear'>
    ----------
    contained
    type(layer):  <class 'torch.nn.modules.container.Sequential'>
    ----------
'''

"\n    lora_config:  {<class 'torch.nn.modules.linear.Linear'>: {'weight': functools.partial(<bound method \n    LoRAParametrization.from_linear of <class '__main__.LoRAParametrization'>>, rank=4)}}\n    type(layer):  <class 'torch.nn.modules.linear.Linear'>\n    ----------\n    contained\n    type(layer):  <class 'torch.nn.modules.linear.Linear'>\n    ----------\n    contained\n    type(layer):  <class 'torch.nn.modules.container.Sequential'>\n    ----------\n"

In [27]:
'''
simple example: torch.nn.utils.parametrize.register_parametrization
    output: 原始参数(weight或bias)会被替换为一个通过指定参数模块生成的参数。
    Linear(
      (weight): ParametrizationList(
        (0): MyParametrization()
      )
      (bias): Parameter containing: [torch.FloatTensor of size 5]
    )
'''
linear = nn.Linear(5, 5)
print(linear)
class LowRankParametrization(nn.Module):
    """Toy parametrization that produces a weight as the product ``U @ V``.

    ``U`` has shape (rows of ``original_weight``, rank) and ``V`` has shape
    (rank, columns of ``original_weight``); both are drawn from a standard
    normal distribution.

    Args:
        original_weight: Tensor whose first two dimensions size the factors.
        rank: Inner dimension shared by ``U`` and ``V``.
    """

    def __init__(self, original_weight, rank=4):
        super().__init__()
        self.rank = rank
        self.U = nn.Parameter(torch.randn(original_weight.shape[0], rank))
        self.V = nn.Parameter(torch.randn(rank, original_weight.shape[1]))

    def forward(self, x):
        """Return ``U @ V``; the incoming tensor ``x`` is not used."""
        return torch.matmul(self.U, self.V)

# 注册低秩参数化
'''
    torch.nn.utils.parametrize.register_parametrization函数用于在模型的参数上注册新的参数化方法。
    这个功能允许你在现有参数layer.weight上应用一些变换LoRAParametrization，特别适用于LoRA
'''
parametrize.register_parametrization(linear, 'weight', LowRankParametrization(linear.weight))
# Several parametrizations can be applied to the same tensor in sequence: just keep
# registering more of them (the original note relates this to DoRA).
# Define a second parametrization.
class MultiplyByTwoParametrization(nn.Module):
    """Parametrization that doubles the tensor it receives.

    The previous body was a verbatim copy of ``LowRankParametrization``: it
    returned ``U @ V`` and ignored its input, so it neither multiplied by two
    nor composed with a parametrization registered before it. ``forward`` now
    returns ``2 * x``, which makes chaining on the same tensor observable.

    Args:
        original_weight: Accepted for signature compatibility; not used.
        rank: Accepted for signature compatibility and stored on ``self.rank``;
            not used by ``forward``.
    """

    def __init__(self, original_weight, rank=4):
        super().__init__()
        # Kept only so existing call sites that pass `rank=` keep working.
        self.rank = rank

    def forward(self, x):
        """Return ``x`` scaled by two, preserving its shape and dtype."""
        return 2 * x
# Register a second parametrization on the same `weight`; it is appended to the
# ParametrizationList after the first one (see the printed module below).
parametrize.register_parametrization(linear, 'weight', MultiplyByTwoParametrization(linear.weight, rank=3))
    
# Print the linear layer again to inspect it after parametrization.
print(linear)
'''
output:
    Linear(in_features=5, out_features=5, bias=True)  # 原始linear层
    -------------------------------------------------
    ParametrizedLinear(                          # 替换后的参数化线性层para linear
      in_features=5, out_features=5, bias=True   # 这表示layer原始参数original weight
      (parametrizations): ModuleDict(            # parametrizations表示应用参数化方法，新模型参数会存储在ModuleDict中，ModuleDict是一个module容器，它像一个dict一样工作。
        (weight): ParametrizationList(           # 这表示weight原始参数现在被替换/应用了ParametrizationList中一个或多个参数化方法.
          (0): LowRankParametrization()          # (0)表示ParametrizationList的第一个参数化方法。
        # (1): MultiplyByTwoParametrization()    # 顺序应用：当ParametrizationList存储多个参数化方法时，所有方法会按顺序应用到weight参数上。
        )                                        
      )
    )
'''

Linear(in_features=5, out_features=5, bias=True)
ParametrizedLinear(
  in_features=5, out_features=5, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): LowRankParametrization()
      (1): MultiplyByTwoParametrization()
    )
  )
)


'\noutput:\n    Linear(in_features=5, out_features=5, bias=True)  # 原始linear层\n    -------------------------------------------------\n    ParametrizedLinear(                          # 替换后的参数化线性层para linear\n      in_features=5, out_features=5, bias=True   # 这表示layer原始参数original weight\n      (parametrizations): ModuleDict(            # parametrizations表示应用参数化方法，新模型参数会存储在ModuleDict中，ModuleDict是一个module容器，它像一个dict一样工作。\n        (weight): ParametrizationList(           # 这表示weight原始参数现在被替换/应用了ParametrizationList中一个或多个参数化方法.\n          (0): LowRankParametrization()          # (0)表示ParametrizationList的第一个参数化方法。\n        )                                        # 顺序应用：当ParametrizationList存储多个参数化方法时，所有方法会按顺序应用到weight参数上。\n      )\n    )\n'

In [97]:
def apply_lora(layer, register=True, merge=False, lora_config=default_lora_config):  # layer=simple model
    """Add or remove LoRA parametrizations on one layer; meant for ``model.apply``.

    Args:
        layer: The module visited by ``model.apply``.
        register: When true, register the parametrizations that ``lora_config``
            lists for the layer's exact type. When false, remove every
            parametrization from the layer (not only LoRA ones).
        merge: Only used when ``register`` is false; forwarded to
            ``remove_parametrizations`` as ``leave_parametrized``.
        lora_config: Mapping ``{layer type: {attribute name: factory}}``; each
            factory is called with the layer and returns the parametrization.

    The type and separator lines printed here are progress output kept from
    the notebook.
    """
    print('Original layer: ', type(layer))
    if register:
        # Exact type lookup: subclasses and already-parametrized layers do not match.
        if type(layer) in lora_config:
            # .items() yields (attribute name, factory) pairs.
            for attr_name, parametrization in lora_config[type(layer)].items():
                # register_parametrization installs a transformation on an existing
                # parameter such as layer.weight, which is how LoRA is attached here.
                parametrize.register_parametrization(layer, attr_name, parametrization(layer))  # LoRAParametrization
                print('LoRA Layer: ', type(layer))
                print('-'*20)
    else:  # this will remove all parametrizations, use with caution
        if hasattr(layer, "parametrizations"):
            # Snapshot the names first: remove_parametrizations deletes entries from
            # layer.parametrizations, so iterating the live keys view would mutate
            # the container while it is being iterated.
            for attr_name in list(layer.parametrizations.keys()):
                parametrize.remove_parametrizations(layer, attr_name, leave_parametrized=merge)

# Every module of `model` is handed to apply_lora in turn.
def add_lora(model, lora_config=default_lora_config):  # lora_config: {layer type: {attribute name: factory}}
    """Register LoRA parametrizations on every matching layer of ``model``.

    Prints the config, then runs ``apply_lora`` over all modules via
    ``model.apply``. ``apply_lora`` matches on the exact layer type, so a layer
    whose type has already changed through parametrization is not matched by
    its original type again.
    """
    print('lora_config: ', lora_config)
    register_on_layer = partial(apply_lora, lora_config=lora_config)
    model.apply(register_on_layer)


def add_lora_by_name(model, target_module_names, lora_config=default_lora_config):
    """Add LoRA only under modules whose qualified name contains a target substring.

    Args:
        model: Module whose ``named_modules()`` are scanned.
        target_module_names: Substrings tested against each module name.
        lora_config: Forwarded to ``add_lora`` for every matching module.
    """
    for module_name, module in model.named_modules():
        matches = [target in module_name for target in target_module_names]
        if not any(matches):
            continue
        add_lora(module, lora_config=lora_config)


def merge_lora(model):
    """Remove all parametrizations from ``model``, keeping the parametrized values.

    Runs ``apply_lora`` with ``register=False, merge=True`` on every module.
    """

    def _merge_layer(layer):
        return apply_lora(layer, register=False, merge=True)

    model.apply(_merge_layer)


def remove_lora(model):
    """Remove all parametrizations from ``model`` without merging them.

    Runs ``apply_lora`` with ``register=False, merge=False`` on every module.
    """

    def _strip_layer(layer):
        return apply_lora(layer, register=False, merge=False)

    model.apply(_strip_layer)

### utils

In [48]:
def apply_to_lora(fn):
    """Wrap ``fn`` so it only runs on ``LoRAParametrization`` modules.

    The returned callable is meant for ``model.apply``: it calls ``fn(layer)``
    for LoRA modules and does nothing for every other module.
    """

    def apply_fn(layer):
        if not isinstance(layer, LoRAParametrization):
            return
        fn(layer)

    return apply_fn

In [58]:
def enable_lora(model):
    """Call ``enable_lora()`` on every ``LoRAParametrization`` inside ``model``.

    Returns whatever ``model.apply`` returns, as the previous lambda did.
    """
    return model.apply(apply_to_lora(lambda lora_layer: lora_layer.enable_lora()))


def disable_lora(model):
    """Call ``disable_lora()`` on every ``LoRAParametrization`` inside ``model``.

    Returns whatever ``model.apply`` returns, as the previous lambda did.
    """
    return model.apply(apply_to_lora(lambda lora_layer: lora_layer.disable_lora()))

In [177]:
def name_is_lora(name):
    """Return True when a dotted parameter name refers to a LoRA factor.

    The name must have at least four dot-separated parts, with
    ``parametrizations`` fourth from the end and ``lora_A`` or ``lora_B`` last.
    """
    parts = name.split(".")
    if len(parts) < 4:
        return False
    if parts[-4] != "parametrizations":
        return False
    return parts[-1] in ["lora_A", "lora_B"]


def name_is_bias(name):
    """Return True when the last dot-separated part of ``name`` is ``bias``."""
    last_part = name.rpartition(".")[2]
    return last_part == "bias"


def get_params_by_name(model, print_shapes=False, name_filter=None):
    """Lazily yield the parameters of ``model`` whose names pass ``name_filter``.

    Args:
        model: Module whose ``named_parameters()`` are scanned.
        print_shapes: When true, print each selected name and shape.
        name_filter: Predicate on the parameter name; ``None`` selects everything.
    """
    for param_name, param in model.named_parameters():
        if name_filter is not None and not name_filter(param_name):
            continue
        if print_shapes:
            print(param_name, param.shape)
        yield param


def get_lora_params(model, print_shapes=False):
    """Return a generator over the parameters whose names satisfy ``name_is_lora``."""
    lora_params = get_params_by_name(model, print_shapes=print_shapes, name_filter=name_is_lora)
    return lora_params


def get_bias_params(model, print_shapes=False):
    """Return a generator over the parameters whose names satisfy ``name_is_bias``."""
    bias_params = get_params_by_name(model, print_shapes=print_shapes, name_filter=name_is_bias)
    return bias_params


def get_lora_state_dict(model):
    """Return the entries of ``model.state_dict()`` whose keys satisfy ``name_is_lora``."""
    lora_entries = {}
    for key, tensor in model.state_dict().items():
        if name_is_lora(key):
            lora_entries[key] = tensor
    return lora_entries

In [162]:
def _prepare_for_multiple_lora(lora_layer):
    lora_layer.lora_As = []
    lora_layer.lora_Bs = []


def _append_lora(lora_layer):
    lora_layer.lora_As.append(nn.Parameter(lora_layer.lora_A.clone()))
    lora_layer.lora_Bs.append(nn.Parameter(lora_layer.lora_B.clone()))


def load_multiple_lora(model, lora_state_dicts):
    """Load several LoRA state dicts and keep a copy of each on the LoRA layers.

    For every state dict, in order, the weights are loaded non-strictly into
    ``model`` and each LoRA layer then appends a clone of its current factors
    to ``lora_As`` / ``lora_Bs``. The live factors end up holding the last
    state dict that was loaded. Returns ``model``.
    """
    model.apply(apply_to_lora(_prepare_for_multiple_lora))
    snapshot_current = apply_to_lora(_append_lora)
    for adapter_weights in lora_state_dicts:
        # strict=False because each dict holds only a subset of the model's keys.
        model.load_state_dict(adapter_weights, strict=False)
        model.apply(snapshot_current)
    return model

In [163]:
def _select_lora(lora_layer, index):
    lora_layer.lora_A = lora_layer.lora_As[index]
    lora_layer.lora_B = lora_layer.lora_Bs[index]


def select_lora(model, index):
    """Activate the stored LoRA factors at ``index`` on every LoRA layer; return ``model``."""

    def _pick(lora_layer):
        return _select_lora(lora_layer, index)

    model.apply(apply_to_lora(_pick))
    return model

## simple model

In [66]:
# Turn gradient tracking off for the rest of the notebook; the cells below only run
# forward passes. Assigning to `_` keeps the returned object out of the cell output.
# NOTE(review): this is a global switch, so an actual backward pass in the
# "Training a model" section would need it turned back on.
_ = torch.set_grad_enabled(False)

In [118]:
# a simple model
model = torch.nn.Sequential(
    torch.nn.Linear(in_features=5, out_features=7),
    torch.nn.Linear(in_features=7, out_features=3),
)
x = torch.randn(1,5)
y = model(x)

In [119]:
print(x)
print(y)
y0 = y

tensor([[ 0.4970,  0.2374, -1.5260, -2.1606, -1.0945]])
tensor([[0.3338, 0.0372, 0.1265]])


In [120]:
type(model)

torch.nn.modules.container.Sequential

In [121]:
print(model)

Sequential(
  (0): Linear(in_features=5, out_features=7, bias=True)
  (1): Linear(in_features=7, out_features=3, bias=True)
)


In [122]:
model.state_dict()

OrderedDict([('0.weight',
              tensor([[-0.4462,  0.3822,  0.1326, -0.3324, -0.1871],
                      [-0.2870,  0.3693, -0.2978, -0.3866, -0.1442],
                      [-0.2644,  0.4316,  0.1074, -0.4259,  0.2431],
                      [ 0.1036, -0.2520, -0.0315,  0.2129,  0.3285],
                      [ 0.3182,  0.1956, -0.4419, -0.3272, -0.2717],
                      [-0.4340,  0.4016, -0.2010, -0.1053,  0.2090],
                      [-0.2288, -0.0553,  0.1590, -0.4251, -0.3897]])),
             ('0.bias',
              tensor([ 0.3816, -0.2553,  0.2329,  0.1657, -0.1987, -0.1720,  0.2264])),
             ('1.weight',
              tensor([[-0.1907, -0.2413, -0.1805, -0.1045,  0.3615,  0.3016,  0.3322],
                      [ 0.3732, -0.0422,  0.0975,  0.3448,  0.0249,  0.1269, -0.1858],
                      [-0.3176, -0.0966,  0.0756, -0.0021, -0.0807, -0.0337,  0.3398]])),
             ('1.bias', tensor([-0.1578,  0.0465,  0.2190]))])

In [123]:
for name, layer in model.named_modules():
    print(f"Layer name: {name}, Layer type: {type(layer)}")

Layer name: , Layer type: <class 'torch.nn.modules.container.Sequential'>
Layer name: 0, Layer type: <class 'torch.nn.modules.linear.Linear'>
Layer name: 1, Layer type: <class 'torch.nn.modules.linear.Linear'>


## add lora model

In [124]:
# add lora to the model
# because B is initialized to 0, the output is the same as before
add_lora(model)
y = model(x)
assert torch.allclose(y, y0)  # allclose checks that two tensors are approximately equal (True if close, False otherwise)

lora_config:  {<class 'torch.nn.modules.linear.Linear'>: {'weight': functools.partial(<bound method LoRAParametrization.from_linear of <class '__main__.LoRAParametrization'>>, rank=3)}}
Original layer:  <class 'torch.nn.modules.linear.Linear'>
LoRA Layer:  <class 'torch.nn.utils.parametrize.ParametrizedLinear'>
--------------------
Original layer:  <class 'torch.nn.modules.linear.Linear'>
LoRA Layer:  <class 'torch.nn.utils.parametrize.ParametrizedLinear'>
--------------------
Original layer:  <class 'torch.nn.modules.container.Sequential'>


In [125]:
print(model)

Sequential(
  (0): ParametrizedLinear(
    in_features=5, out_features=7, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): LoRAParametrization()
      )
    )
  )
  (1): ParametrizedLinear(
    in_features=7, out_features=3, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): LoRAParametrization()
      )
    )
  )
)


In [126]:
model.state_dict()

OrderedDict([('0.bias',
              tensor([ 0.3816, -0.2553,  0.2329,  0.1657, -0.1987, -0.1720,  0.2264])),
             ('0.parametrizations.weight.original',
              tensor([[-0.4462,  0.3822,  0.1326, -0.3324, -0.1871],
                      [-0.2870,  0.3693, -0.2978, -0.3866, -0.1442],
                      [-0.2644,  0.4316,  0.1074, -0.4259,  0.2431],
                      [ 0.1036, -0.2520, -0.0315,  0.2129,  0.3285],
                      [ 0.3182,  0.1956, -0.4419, -0.3272, -0.2717],
                      [-0.4340,  0.4016, -0.2010, -0.1053,  0.2090],
                      [-0.2288, -0.0553,  0.1590, -0.4251, -0.3897]])),
             ('0.parametrizations.weight.0.lora_A',
              tensor([[ 0.0139, -0.2692,  0.2622, -0.1078,  0.1925],
                      [ 0.2196,  0.3292,  0.3107, -0.2908, -0.0264],
                      [ 0.1487, -0.2220,  0.0665, -0.0764,  0.3098]])),
             ('0.parametrizations.weight.0.lora_B',
              tensor([[0., 0., 0.],


In [127]:
for name, layer in model.named_modules():
    print(f"Layer name: {name}, Layer type: {type(layer)}")

Layer name: , Layer type: <class 'torch.nn.modules.container.Sequential'>
Layer name: 0, Layer type: <class 'torch.nn.utils.parametrize.ParametrizedLinear'>
Layer name: 0.parametrizations, Layer type: <class 'torch.nn.modules.container.ModuleDict'>
Layer name: 0.parametrizations.weight, Layer type: <class 'torch.nn.utils.parametrize.ParametrizationList'>
Layer name: 0.parametrizations.weight.0, Layer type: <class '__main__.LoRAParametrization'>
Layer name: 1, Layer type: <class 'torch.nn.utils.parametrize.ParametrizedLinear'>
Layer name: 1.parametrizations, Layer type: <class 'torch.nn.modules.container.ModuleDict'>
Layer name: 1.parametrizations.weight, Layer type: <class 'torch.nn.utils.parametrize.ParametrizationList'>
Layer name: 1.parametrizations.weight.0, Layer type: <class '__main__.LoRAParametrization'>


### initialize lora_B

In [128]:
# to make the output different, we need to initialize B to something non-zero
model.apply(apply_to_lora(lambda x: torch.nn.init.ones_(x.lora_B)))
y = model(x)
print(y)
assert not torch.allclose(y, y0)  # the assert passing silently means y != y0
# Keep the LoRA-enabled output as the reference for the checks in later cells.
y1 = y

tensor([[ 0.2669, -0.0787,  0.1015]])


### disable lora

In [129]:
# now let's try to disable lora, the output is the same as before lora is added.
disable_lora(model)
y = model(x)
assert torch.allclose(y, y0)

In [130]:
x

tensor([[ 0.4970,  0.2374, -1.5260, -2.1606, -1.0945]])

In [131]:
y

tensor([[0.3338, 0.0372, 0.1265]])

In [132]:
print(model)

Sequential(
  (0): ParametrizedLinear(
    in_features=5, out_features=7, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): LoRAParametrization()
      )
    )
  )
  (1): ParametrizedLinear(
    in_features=7, out_features=3, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): LoRAParametrization()
      )
    )
  )
)


In [133]:
for name, layer in model.named_modules():
    print(f"Layer name: {name}, Layer type: {type(layer)}")

Layer name: , Layer type: <class 'torch.nn.modules.container.Sequential'>
Layer name: 0, Layer type: <class 'torch.nn.utils.parametrize.ParametrizedLinear'>
Layer name: 0.parametrizations, Layer type: <class 'torch.nn.modules.container.ModuleDict'>
Layer name: 0.parametrizations.weight, Layer type: <class 'torch.nn.utils.parametrize.ParametrizationList'>
Layer name: 0.parametrizations.weight.0, Layer type: <class '__main__.LoRAParametrization'>
Layer name: 1, Layer type: <class 'torch.nn.utils.parametrize.ParametrizedLinear'>
Layer name: 1.parametrizations, Layer type: <class 'torch.nn.modules.container.ModuleDict'>
Layer name: 1.parametrizations.weight, Layer type: <class 'torch.nn.utils.parametrize.ParametrizationList'>
Layer name: 1.parametrizations.weight.0, Layer type: <class '__main__.LoRAParametrization'>


### enable lora

In [134]:
# enable lora again
enable_lora(model)
y = model(x)
assert torch.allclose(y, y1)

### save lora

In [135]:
# let's save the state dict for later use
state_dict_to_save = get_lora_state_dict(model)
state_dict_to_save.keys()

0.bias
0.parametrizations.weight.original
0.parametrizations.weight.0.lora_A
0.parametrizations.weight.0.lora_B
0.parametrizations.weight.0.lora_dropout_mask
1.bias
1.parametrizations.weight.original
1.parametrizations.weight.0.lora_A
1.parametrizations.weight.0.lora_B
1.parametrizations.weight.0.lora_dropout_mask


dict_keys(['0.parametrizations.weight.0.lora_A', '0.parametrizations.weight.0.lora_B', '1.parametrizations.weight.0.lora_A', '1.parametrizations.weight.0.lora_B'])

In [136]:
print(model)

Sequential(
  (0): ParametrizedLinear(
    in_features=5, out_features=7, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): LoRAParametrization()
      )
    )
  )
  (1): ParametrizedLinear(
    in_features=7, out_features=3, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): LoRAParametrization()
      )
    )
  )
)


### remove lora

In [137]:
remove_lora(model)  # remove 'parametrizations'

Original layer:  <class '__main__.LoRAParametrization'>
Original layer:  <class 'torch.nn.utils.parametrize.ParametrizationList'>
Original layer:  <class 'torch.nn.modules.container.ModuleDict'>
Original layer:  <class 'torch.nn.utils.parametrize.ParametrizedLinear'>
Original layer:  <class '__main__.LoRAParametrization'>
Original layer:  <class 'torch.nn.utils.parametrize.ParametrizationList'>
Original layer:  <class 'torch.nn.modules.container.ModuleDict'>
Original layer:  <class 'torch.nn.utils.parametrize.ParametrizedLinear'>
Original layer:  <class 'torch.nn.modules.container.Sequential'>


In [138]:
state_dict_of_model = get_lora_state_dict(model)
state_dict_of_model.keys()

0.bias
0.weight
1.bias
1.weight


dict_keys([])

In [139]:
print(model)

Sequential(
  (0): Linear(in_features=5, out_features=7, bias=True)
  (1): Linear(in_features=7, out_features=3, bias=True)
)


### load lora back

In [140]:
# let's try to load the lora back
# first we need to add lora to the model
add_lora(model)
# then we can load the parameters
# strict=False is needed because we are loading a subset of the parameters
_ = model.load_state_dict(state_dict_to_save, strict=False)
y = model(x)
assert torch.allclose(y, y1)

lora_config:  {<class 'torch.nn.modules.linear.Linear'>: {'weight': functools.partial(<bound method LoRAParametrization.from_linear of <class '__main__.LoRAParametrization'>>, rank=3)}}
Original layer:  <class 'torch.nn.modules.linear.Linear'>
LoRA Layer:  <class 'torch.nn.utils.parametrize.ParametrizedLinear'>
--------------------
Original layer:  <class 'torch.nn.modules.linear.Linear'>
LoRA Layer:  <class 'torch.nn.utils.parametrize.ParametrizedLinear'>
--------------------
Original layer:  <class 'torch.nn.modules.container.Sequential'>


### merge lora

In [142]:
# we can merge it to make it a normal linear layer, so there is no overhead for inference.
merge_lora(model)
y = model(x)
assert torch.allclose(y, y1)

Original layer:  <class '__main__.LoRAParametrization'>
Original layer:  <class 'torch.nn.utils.parametrize.ParametrizationList'>
Original layer:  <class 'torch.nn.modules.container.ModuleDict'>
Original layer:  <class 'torch.nn.utils.parametrize.ParametrizedLinear'>
Original layer:  <class '__main__.LoRAParametrization'>
Original layer:  <class 'torch.nn.utils.parametrize.ParametrizationList'>
Original layer:  <class 'torch.nn.modules.container.ModuleDict'>
Original layer:  <class 'torch.nn.utils.parametrize.ParametrizedLinear'>
Original layer:  <class 'torch.nn.modules.container.Sequential'>


In [143]:
print(model)

Sequential(
  (0): Linear(in_features=5, out_features=7, bias=True)
  (1): Linear(in_features=7, out_features=3, bias=True)
)


## Training a model

In [183]:
# A fresh single-layer model; this rebinds `model`, replacing the Sequential used above.
model = torch.nn.Linear(in_features=5, out_features=3)
print(model)
# step 1: Add LoRA to the model
add_lora(model)
print(model)

# step 2: Collect the parameters, pass them to the optimizer
# Only parameters whose names pass name_is_lora are collected, so the original
# weight and the bias are not handed to the optimizer.
parameters = [
    {'params': list(get_lora_params(model))}
]
optimizer = torch.optim.AdamW(parameters, lr=1e-3)

# step 3: Train the model
# ...
# Simulate training, update the LoRA parameters
# (the optimizer above is never stepped in this cell; random re-initialisation
# stands in for real updates)
model.apply(apply_to_lora(lambda x: torch.nn.init.normal_(x.lora_A)))
model.apply(apply_to_lora(lambda x: torch.nn.init.normal_(x.lora_B)))

# step 4: export the LoRA parameters
# Same filtering as get_lora_state_dict(model), written out inline.
state_dict = model.state_dict()
lora_state_dict = {k: v for k,v in state_dict.items() if name_is_lora(k)}

Linear(in_features=5, out_features=3, bias=True)
lora_config:  {<class 'torch.nn.modules.linear.Linear'>: {'weight': functools.partial(<bound method LoRAParametrization.from_linear of <class '__main__.LoRAParametrization'>>, rank=3)}}
Original layer:  <class 'torch.nn.modules.linear.Linear'>
LoRA Layer:  <class 'torch.nn.utils.parametrize.ParametrizedLinear'>
--------------------
ParametrizedLinear(
  in_features=5, out_features=3, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): LoRAParametrization()
    )
  )
)


In [184]:
print(model)

ParametrizedLinear(
  in_features=5, out_features=3, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): LoRAParametrization()
    )
  )
)


## Loading and Inferencing with LoRA

In [185]:
# Step 1: Add LoRA to your model
# NOTE(review): when run right after the training cell, `model` still carries its
# LoRA parametrization, and apply_lora matches on the exact layer type, so this call
# registers nothing new here (the recorded output shows no "LoRA Layer" line).
add_lora(model)

# Step 2: Load the LoRA parameters
# strict=False because lora_state_dict holds only the LoRA keys.
_ = model.load_state_dict(lora_state_dict, strict=False)

# Step 3: Merge the LoRA parameters into the model
merge_lora(model)

lora_config:  {<class 'torch.nn.modules.linear.Linear'>: {'weight': functools.partial(<bound method LoRAParametrization.from_linear of <class '__main__.LoRAParametrization'>>, rank=3)}}
Original layer:  <class '__main__.LoRAParametrization'>
Original layer:  <class 'torch.nn.utils.parametrize.ParametrizationList'>
Original layer:  <class 'torch.nn.modules.container.ModuleDict'>
Original layer:  <class 'torch.nn.utils.parametrize.ParametrizedLinear'>
Original layer:  <class '__main__.LoRAParametrization'>
Original layer:  <class 'torch.nn.utils.parametrize.ParametrizationList'>
Original layer:  <class 'torch.nn.modules.container.ModuleDict'>
Original layer:  <class 'torch.nn.utils.parametrize.ParametrizedLinear'>


In [186]:
print(model)

Linear(in_features=5, out_features=3, bias=True)


## Inferencing with multiple LoRA models

In [187]:
# Remove any existing parametrization first, so that re-running this cell always
# starts from a plain layer before LoRA is added again.
remove_lora(model)
print(model)
# Step 1: Add lora to your model
add_lora(model)

Original layer:  <class 'torch.nn.modules.linear.Linear'>
Linear(in_features=5, out_features=3, bias=True)
lora_config:  {<class 'torch.nn.modules.linear.Linear'>: {'weight': functools.partial(<bound method LoRAParametrization.from_linear of <class '__main__.LoRAParametrization'>>, rank=3)}}
Original layer:  <class 'torch.nn.modules.linear.Linear'>
LoRA Layer:  <class 'torch.nn.utils.parametrize.ParametrizedLinear'>
--------------------


In [188]:
# Step 2: Load the lora parameters
# fake 3 sets of LoRA parameters
# (the exported one, an all-ones set and an all-zeros set with the same keys and shapes)
lora_state_dict_0 = lora_state_dict
lora_state_dict_1 = {k: torch.ones_like(v) for k,v in lora_state_dict.items()}
lora_state_dict_2 = {k: torch.zeros_like(v) for k,v in lora_state_dict.items()}
lora_state_dicts = [lora_state_dict_0, lora_state_dict_1, lora_state_dict_2]

# Load all the LoRA state dicts into the model
load_multiple_lora(model, lora_state_dicts)

# Inspect the model
print(model)

ParametrizedLinear(
  in_features=5, out_features=3, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): LoRAParametrization()
    )
  )
)


In [189]:
lora_state_dicts

[{'parametrizations.weight.0.lora_A': tensor([[-0.6389,  1.1025, -0.4028,  0.1417,  1.2188],
          [ 1.2653,  0.3596, -0.5528, -0.4428,  0.0385],
          [-1.1018, -1.9272, -1.0012, -0.9899,  2.2904]]),
  'parametrizations.weight.0.lora_B': tensor([[ 1.0520,  0.5619,  1.2419],
          [ 0.0190, -0.1718, -1.1967],
          [-1.5829, -2.0740, -0.3178]])},
 {'parametrizations.weight.0.lora_A': tensor([[1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.]]),
  'parametrizations.weight.0.lora_B': tensor([[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]])},
 {'parametrizations.weight.0.lora_A': tensor([[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]),
  'parametrizations.weight.0.lora_B': tensor([[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]])}]

In [190]:
# step 3: Select which LoRA to use at inference time
# select_lora returns the model, so each line switches the active factors and then
# immediately runs a forward pass on x.
Y0 = select_lora(model, 0)(x)  # sets lora_layer.lora_A / lora_B to the parameters stored at this index
Y1 = select_lora(model, 1)(x)
Y2 = select_lora(model, 2)(x)

In [191]:
Y0, Y1, Y2

(tensor([[-0.0684, -1.2928, -2.5007]]),
 tensor([[-4.2630, -5.1295, -5.3923]]),
 tensor([[-0.2162, -1.0827, -1.3455]]))

### merge lora 

In [193]:
# Strip the parametrization without merging, then record the plain layer's state so
# the next cell can restore it before applying each LoRA set.
remove_lora(model)
# NOTE(review): state_dict() is documented as returning references to the module's
# tensors rather than copies; consider cloning the values if the model may be
# modified in place before this snapshot is reloaded - TODO confirm.
init_state_dict = model.state_dict()
print(model)
print(init_state_dict)

Original layer:  <class '__main__.LoRAParametrization'>
Original layer:  <class 'torch.nn.utils.parametrize.ParametrizationList'>
Original layer:  <class 'torch.nn.modules.container.ModuleDict'>
Original layer:  <class 'torch.nn.utils.parametrize.ParametrizedLinear'>
Linear(in_features=5, out_features=3, bias=True)
OrderedDict([('bias', tensor([-0.2535, -0.3261,  0.3318])), ('weight', tensor([[-0.3439,  0.0394, -0.7497, -0.0762,  1.0139],
        [-0.0388,  0.7641,  0.7504,  0.2345, -0.6696],
        [-0.3932, -0.8996,  1.1107,  0.1917, -0.7682]]))])


In [194]:
# verify that it's the same as if we load the lora parameters one by one
for state_dict in lora_state_dicts:
    remove_lora(model)
    _ = model.load_state_dict(init_state_dict, strict=False)
    add_lora(model)
    _ = model.load_state_dict(state_dict, strict=False)
    merge_lora(model)
    y = model(x)
    print(y)

Original layer:  <class 'torch.nn.modules.linear.Linear'>
lora_config:  {<class 'torch.nn.modules.linear.Linear'>: {'weight': functools.partial(<bound method LoRAParametrization.from_linear of <class '__main__.LoRAParametrization'>>, rank=3)}}
Original layer:  <class 'torch.nn.modules.linear.Linear'>
LoRA Layer:  <class 'torch.nn.utils.parametrize.ParametrizedLinear'>
--------------------
Original layer:  <class '__main__.LoRAParametrization'>
Original layer:  <class 'torch.nn.utils.parametrize.ParametrizationList'>
Original layer:  <class 'torch.nn.modules.container.ModuleDict'>
Original layer:  <class 'torch.nn.utils.parametrize.ParametrizedLinear'>
tensor([[-0.0684, -1.2928, -2.5007]])
Original layer:  <class 'torch.nn.modules.linear.Linear'>
lora_config:  {<class 'torch.nn.modules.linear.Linear'>: {'weight': functools.partial(<bound method LoRAParametrization.from_linear of <class '__main__.LoRAParametrization'>>, rank=3)}}
Original layer:  <class 'torch.nn.modules.linear.Linear'>
