In [2]:
from functools import partial

import torch
import torch.nn as nn
from torch.nn.utils import parametrize

# star import (kept last so it wins on name clashes, as before); supplies at least
# LoRAFAParametrization, add_lora_by_layer_names, get_lora_state_dict, get_lora_params
from lora import *

In [3]:
# A small toy model. The nested `dumb_module` and the top level both contain a
# layer named 'Q', which exercises matching layers by short name at different
# depths of the module hierarchy.
SEED = 0
torch.manual_seed(SEED)  # fixes the layer init and `x`, so the printed outputs reproduce

model = nn.Sequential()
model.add_module('dumb_module', nn.Sequential())
model.dumb_module.add_module('linear', nn.Linear(5, 10))
model.dumb_module.add_module('relu', nn.ReLU())
model.dumb_module.add_module('Q', nn.Linear(10, 10))
model.add_module('linear', nn.Linear(10, 5))
model.add_module('Q', nn.Linear(5, 7))
model.add_module('K', nn.Linear(7, 3))
model.add_module('A', nn.Linear(3, 3))

x = torch.randn(1, 5)
y = model(x)
print(y)
Y0 = y  # keep the pre-LoRA output; `y` is rebound after LoRA is added

tensor([[-0.1226, -0.5784, -0.1457]], grad_fn=<AddmmBackward0>)


In [4]:
# Qualified name of every submodule, one per line (the first, empty line is the
# root Sequential itself). These are the names the LoRA config is keyed on.
print('\n'.join(qualified for qualified, _ in model.named_modules()))


dumb_module
dumb_module.linear
dumb_module.relu
dumb_module.Q
linear
Q
K
A


In [5]:
def construct_lorafa_config(model, rank, target_names=('Q', 'K')):
    """Build a LoRA-FA config for every ``nn.Linear`` whose short name is targeted.

    Layers are keyed by their fully qualified name from ``model.named_modules()``,
    which is unique within the module hierarchy, so identically named layers in
    different submodules (e.g. ``'Q'`` and ``'dumb_module.Q'``) each get an entry.

    Args:
        model: the ``nn.Module`` to scan.
        rank: LoRA rank forwarded to ``LoRAFAParametrization.from_linear``.
        target_names: last path components of the layers to parametrize
            (``'Q'`` matches both ``'Q'`` and ``'a.b.Q'``). A single string is
            treated as one name. Defaults to ``('Q', 'K')``.

    Returns:
        dict mapping qualified layer name -> ``{nn.Linear: {"weight": factory}}``,
        where ``factory`` is a ``functools.partial`` of
        ``LoRAFAParametrization.from_linear`` preconfigured with ``rank``,
        ``init_method="svd"`` and that layer's current weight.
    """
    # A bare string would otherwise be split into its characters.
    if isinstance(target_names, str):
        target_names = (target_names,)
    targets = frozenset(target_names)

    config = {}
    for name, module in model.named_modules():
        # rpartition yields the last path component for both nested and top-level names.
        short_name = name.rpartition('.')[2]
        if isinstance(module, nn.Linear) and short_name in targets:
            config[name] = {
                nn.Linear: {
                    "weight": partial(
                        LoRAFAParametrization.from_linear,
                        rank=rank,
                        init_method="svd",  # use SVD-based initialization
                        original_weights=module.weight,  # weights the SVD init is computed from
                    ),
                }
            }
    return config

# Rank-2 LoRA-FA adapters on the targeted layers; display which qualified
# layer names were matched (iterating a dict yields its keys).
lorafa_config = construct_lorafa_config(model, rank=2)
list(lorafa_config)

['dumb_module.Q', 'Q', 'K']

In [6]:
# Parametrize the targeted layers in place. add_lora_by_layer_names mutates
# `model`, and the state-dict keys ('...parametrizations.weight.0.lora_A') show it
# registers through torch.nn.utils.parametrize, so re-running this cell would
# presumably stack a second parametrization. Skip if the layers are already done.
already_parametrized = any(
    parametrize.is_parametrized(model.get_submodule(layer_name), "weight")
    for layer_name in lorafa_config
)
if not already_parametrized:
    add_lora_by_layer_names(model, lorafa_config)

y = model(x)
print(y)

tensor([[-0.1112, -0.6242, -0.0769]], grad_fn=<AddmmBackward0>)


In [7]:
# Names of the tensors LoRA added: one lora_A / lora_B pair per parametrized weight.
lora_state = get_lora_state_dict(model)
list(lora_state.keys())

['dumb_module.Q.parametrizations.weight.0.lora_A',
 'dumb_module.Q.parametrizations.weight.0.lora_B',
 'Q.parametrizations.weight.0.lora_A',
 'Q.parametrizations.weight.0.lora_B',
 'K.parametrizations.weight.0.lora_A',
 'K.parametrizations.weight.0.lora_B']

In [15]:
# Check the LoRA-FA contract: each lora_A is frozen and only lora_B is trainable.
# TODO(review): get_lora_named_parameters(model) reportedly shows requires_grad=False
# for every tensor. If it is built on state_dict(), that is expected: state_dict()
# returns detached tensors unless keep_vars=True. Confirm in lora.py.
for params in get_lora_params(model):
    print(params.shape, params.requires_grad)

# Parametrization modules live at '<layer>.parametrizations.weight.0' and own
# lora_A / lora_B (see the state-dict keys above), so check them by name.
lora_modules = {
    name: module
    for name, module in model.named_modules()
    if hasattr(module, 'lora_A') and hasattr(module, 'lora_B')
}
assert lora_modules, "no LoRA parametrizations found on the model"
for name, module in lora_modules.items():
    assert not module.lora_A.requires_grad, f"{name}.lora_A should be frozen"
    assert module.lora_B.requires_grad, f"{name}.lora_B should be trainable"

torch.Size([2, 10]) False
torch.Size([10, 2]) True
torch.Size([2, 7]) False
torch.Size([5, 2]) True
torch.Size([2, 3]) False
torch.Size([7, 2]) True
