In [1]:
import torch
import torch.nn as nn

In [9]:
class LoraLinear(nn.Module):
    """Linear layer whose output is augmented by a low-rank (LoRA) branch.

    The layer computes
    ``linear(x) + dropout(lora_up(selector(lora_down(x)))) * scale``.
    ``lora_down`` maps ``in_features -> r`` and ``lora_up`` maps
    ``r -> out_features``; neither has a bias. ``lora_up`` is initialised to
    zeros, so the low-rank branch contributes nothing until it is trained.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: Whether the wrapped ``nn.Linear`` has a bias term.
        r: Rank of the low-rank branch; must not exceed
            ``min(in_features, out_features)``.
        dropout_p: Dropout probability applied to the low-rank branch output.
        scale: Multiplier applied to the low-rank branch output.
    """

    def __init__(self, in_features, out_features, bias=False, r=4, dropout_p=0.1, scale=1.0):
        super().__init__()
        assert r <= min(in_features, out_features), f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"

        self.r = r
        self.scale = scale

        # Sub-modules are created in a fixed order so that seeded
        # initialisation stays reproducible.
        self.linear = nn.Linear(in_features, out_features, bias)
        self.lora_down = nn.Linear(in_features, r, bias=False)
        self.dropout = nn.Dropout(dropout_p)
        self.lora_up = nn.Linear(r, out_features, bias=False)
        # Pass-through until set_selector_from_diag() installs a real gate.
        self.selector = nn.Identity()

        # Down projection: Gaussian with std 1/r. Up projection: all zeros.
        nn.init.normal_(self.lora_down.weight, std=1/r)
        nn.init.zeros_(self.lora_up.weight)

    def forward(self, input):
        """Return the wrapped linear output plus the scaled low-rank update."""
        base = self.linear(input)
        rank_features = self.selector(self.lora_down(input))
        update = self.dropout(self.lora_up(rank_features))
        return base + update * self.scale

    def realize_as_lora(self):
        """Return ``(lora_up weight * scale, lora_down weight)`` as raw tensors."""
        scaled_up = self.lora_up.weight.data * self.scale
        down = self.lora_down.weight.data
        return scaled_up, down

    def set_selector_from_diag(self, diag: torch.Tensor):
        """Replace the selector with a bias-free linear map whose weight is ``diag(diag)``.

        ``diag`` must be a 1-D tensor of length ``r``; the i-th rank component
        is then multiplied by ``diag[i]``. The new weight is moved to the
        device and dtype of ``lora_up``.
        """
        assert diag.shape == (self.r, )
        reference = self.lora_up.weight
        gate = nn.Linear(self.r, self.r, bias=False)
        gate.weight.data = torch.diag(diag).to(reference.device).to(reference.dtype)
        self.selector = gate
        

In [19]:
class LoraConv2d(nn.Module):
    """Conv2d layer whose output is augmented by a low-rank (LoRA) branch.

    The layer computes
    ``conv(x) + dropout(lora_up(selector(lora_down(x)))) * scale``.
    ``lora_down`` is built with the same kernel size, stride, padding,
    dilation and groups as ``conv`` but produces ``r`` channels; ``lora_up``
    is a 1x1 convolution from ``r`` to ``out_channels``. Neither has a bias.
    ``lora_up`` is initialised to zeros, so the low-rank branch contributes
    nothing until it is trained.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias: Forwarded to the wrapped ``nn.Conv2d``.
        r: Rank (channel count) of the low-rank branch; must not exceed
            ``min(in_channels, out_channels)``.
        dropout_p: Dropout probability applied to the low-rank branch output.
        scale: Multiplier applied to the low-rank branch output.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size, stride=1,
                 padding=0, dilation=1, groups: int=1, bias: bool=True, r: int=4,
                 dropout_p: float=0.1, scale: float=1.0,):
        super().__init__()
        assert r <= min(in_channels, out_channels), f"LoRA rank {r} must be less or equal than {min(in_channels, out_channels)}"

        self.r = r
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, groups=groups, bias=bias)

        # The down projection mirrors the geometry of `conv` so both branches
        # produce feature maps of the same spatial size.
        self.lora_down = nn.Conv2d(in_channels=in_channels, out_channels=r, kernel_size=kernel_size, stride=stride,
                                   padding=padding, dilation=dilation, groups=groups, bias=False)
        self.dropout = nn.Dropout(dropout_p)
        self.lora_up = nn.Conv2d(in_channels=r, out_channels=out_channels, kernel_size=1, stride=1, padding=0, bias=False)

        # Pass-through until set_selector_from_diag() installs a real gate.
        self.selector = nn.Identity()
        self.scale = scale

        # Down projection: Gaussian with std 1/r. Up projection: all zeros.
        nn.init.normal_(self.lora_down.weight, std=1/r)
        nn.init.zeros_(self.lora_up.weight)

    def forward(self, input):
        """Return the wrapped convolution output plus the scaled low-rank update."""
        return (
            self.conv(input)
            + self.dropout(self.lora_up(self.selector(self.lora_down(input))))
            * self.scale
        )

    def realize_as_lora(self):
        """Return ``(lora_up weight * scale, lora_down weight)`` as raw tensors."""
        return self.lora_up.weight.data * self.scale, self.lora_down.weight.data

    def set_selector_from_diag(self, diag: torch.Tensor):
        """Replace the selector with a 1x1 convolution that scales channel i by ``diag[i]``.

        ``diag`` must be a 1-D tensor of length ``r``. The new kernel is moved
        to the device and dtype of ``lora_up``.
        """
        assert diag.shape == (self.r,)
        self.selector = nn.Conv2d(in_channels=self.r, out_channels=self.r, kernel_size=1, stride=1, padding=0, bias=False)

        # A Conv2d kernel is 4-D: (out_channels, in_channels, kH, kW).
        # torch.diag() returns a 2-D (r, r) matrix, which conv2d rejects at
        # forward time, so add the two trailing 1x1 spatial dimensions.
        kernel = torch.diag(diag).reshape(self.r, self.r, 1, 1)

        self.selector.weight.data = kernel.to(
            self.lora_up.weight.device
        ).to(self.lora_up.weight.dtype)

In [18]:
# Build a small bias-free 3x3 convolution (8 -> 32 channels) and inspect how
# its kernel tensor is laid out: (out_channels, in_channels, kH, kW).
conv = nn.Conv2d(8, 32, 3, bias=False)
conv.weight.shape

torch.Size([32, 8, 3, 3])

In [30]:
from typing import List, Type, Optional, Set
# Default for `target_replace_module` in `inject_trainable_lora`, which passes
# it on as the ancestor filter of the module search.
DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"}

In [31]:
def _find_modules_v2(model, ancestor_class: Optional[Set[str]] = None, search_class: List[Type[nn.Module]] = [nn.Linear],
                    exclude_children_of: Optional[List[Type[nn.Module]]] = [LoraLinear, LoraConv2d],):
    """Yield ``(parent, name, module)`` for each searched layer under the selected ancestors.

    Args:
        model: Root module to search.
        ancestor_class: Class names (strings) of the modules whose
            descendants are searched; class objects are accepted as well.
            ``None`` means every module of ``model`` acts as an ancestor.
        search_class: Module types to look for.
        exclude_children_of: A match is skipped when its direct parent is an
            instance of one of these types. ``None`` or an empty list
            disables the filter.

    Yields:
        ``(parent, name, module)`` where ``module`` is registered on
        ``parent`` under ``name``. A module reachable from several ancestors
        is yielded once per ancestor.

    The list defaults are never mutated here, so sharing them across calls
    is safe.
    """
    if ancestor_class is not None:
        # ancestor_class holds class *names*, so compare against __name__;
        # comparing the class object against strings would never match.
        # Kept as a lazy generator, as before, so traversal of the model is
        # interleaved with whatever the caller does between yields.
        ancestors = (
            module for module in model.modules()
            if module.__class__.__name__ in ancestor_class or module.__class__ in ancestor_class
        )
    else:
        ancestors = [module for module in model.modules()]

    search_types = tuple(search_class)
    excluded_parent_types = tuple(exclude_children_of) if exclude_children_of else ()

    for ancestor in ancestors:
        for fullname, module in ancestor.named_modules():
            if not fullname:
                # named_modules() lists the ancestor itself under the empty
                # name; it has no parent inside this subtree to report.
                continue
            if not isinstance(module, search_types):
                continue
            # "a.b.c" -> parent path "a.b", attribute name "c".
            parent_path, _, name = fullname.rpartition(".")
            parent = ancestor.get_submodule(parent_path) if parent_path else ancestor
            if excluded_parent_types and isinstance(parent, excluded_parent_types):
                continue
            yield parent, name, module


In [33]:
def inject_trainable_lora(model:nn.Module, 
                         target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE,
                         r: int=4,
                         loras=None, # path to lora .pt
                         verbose:bool=False,
                         dropout_p:float=0.0,
                         scale:float=1.0,):
    """Replace the ``nn.Linear`` layers found by ``_find_modules_v2`` with ``LoraLinear`` in place.

    Each replacement reuses the original layer's weight and bias parameters
    and is moved to the original weight's device and dtype. The ``lora_up``
    and ``lora_down`` weights of every replacement are marked trainable.

    Args:
        model: Module to modify in place.
        target_replace_module: Passed to ``_find_modules_v2`` as its
            ancestor filter.
        r: LoRA rank for every injected layer.
        loras: Optional path given to ``torch.load``. The loaded object is
            consumed with ``pop(0)`` as alternating ``lora_up`` /
            ``lora_down`` weights, one pair per replaced layer, in injection
            order. Presumably these are ``nn.Parameter`` objects; verify
            against the code that saved them.
        verbose: Print the name and weight shape of every replaced layer.
        dropout_p: Dropout probability for the injected layers.
        scale: Scale for the injected layers.

    Returns:
        ``(require_grad_params, names)``. ``require_grad_params`` holds, per
        replaced layer, the ``parameters()`` iterators of ``lora_up`` and
        ``lora_down`` (single-use iterators). ``names`` holds the attribute
        name of every replaced layer, in the same order.
    """
    require_grad_params = []
    names = []

    if loras is not None:
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from trusted sources.
        loras = torch.load(loras)

    for _module, name, _child_module in _find_modules_v2(model, target_replace_module, search_class=[nn.Linear]):
        weight = _child_module.weight
        bias = _child_module.bias
        if verbose:
            print("LoRA Injection : injecting lora into ", name)
            print("LoRA Injection : weight shape", weight.shape)
        lora_layer = LoraLinear(_child_module.in_features, _child_module.out_features, bias is not None,
                                r=r, dropout_p=dropout_p, scale=scale)
        # Reuse the original layer's parameters instead of the freshly
        # initialised ones.
        lora_layer.linear.weight = weight
        if bias is not None:
            lora_layer.linear.bias = bias

        # Match the original layer's device and dtype, then swap the module
        # in on its parent.
        lora_layer.to(weight.device).to(weight.dtype)
        _module._modules[name] = lora_layer

        if loras is not None:
            lora_layer.lora_up.weight = loras.pop(0)
            lora_layer.lora_down.weight = loras.pop(0)

        lora_layer.lora_up.weight.requires_grad = True
        lora_layer.lora_down.weight.requires_grad = True

        require_grad_params.append(lora_layer.lora_up.parameters())
        require_grad_params.append(lora_layer.lora_down.parameters())
        # Record which layer was replaced; previously `names` was returned
        # but never filled.
        names.append(name)

    return require_grad_params, names
            

In [26]:
# A layer without children: modules() yields only the layer itself.
conv = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=1, stride=1, padding=1)
for module in conv.modules():
    print(type(module))

<class 'torch.nn.modules.conv.Conv2d'>
