In [2]:
import torch
import torch.nn as nn

In [3]:
class LoraLinear(nn.Module):
    """Linear layer with an additive low-rank (LoRA) branch.

    The forward pass computes
    ``linear(x) + dropout(lora_up(selector(lora_down(x)))) * scale``.
    ``lora_up`` is zero-initialised, so a freshly constructed layer produces
    the same output as its inner ``linear``.
    """

    def __init__(self, in_features, out_features, bias=False, r=4, dropout_p=0.1, scale=1.0):
        super().__init__()
        # The rank cannot exceed the smaller of the two feature dimensions.
        assert r <= min(in_features, out_features), f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"

        self.r = r
        self.scale = scale

        # Base projection followed by the low-rank pair: in_features -> r -> out_features.
        self.linear = nn.Linear(in_features, out_features, bias)
        self.lora_down = nn.Linear(in_features, r, bias=False)
        self.dropout = nn.Dropout(dropout_p)
        self.lora_up = nn.Linear(r, out_features, bias=False)
        # Pass-through until set_selector_from_diag() installs a diagonal r x r map.
        self.selector = nn.Identity()

        nn.init.normal_(self.lora_down.weight, std=1/r)
        nn.init.zeros_(self.lora_up.weight)

    def forward(self, input):
        """Return the base projection plus the scaled low-rank update."""
        base_out = self.linear(input)
        low_rank = self.selector(self.lora_down(input))
        update = self.dropout(self.lora_up(low_rank))
        return base_out + update * self.scale

    def realize_as_lora(self):
        """Return ``(lora_up.weight * scale, lora_down.weight)`` as raw tensors."""
        scaled_up = self.lora_up.weight.data * self.scale
        down = self.lora_down.weight.data
        return scaled_up, down

    def set_selector_from_diag(self, diag: torch.Tensor):
        """Replace the selector with a bias-free linear map whose weight is ``diag(diag)``.

        ``diag`` must be a 1-D tensor of length ``r``.
        """
        assert diag.shape == (self.r, )  # diag is a 1-D vector of length r
        reference = self.lora_up.weight
        selector = nn.Linear(self.r, self.r, bias=False)
        # The selector weight is a diagonal matrix on the same device/dtype as lora_up.
        selector.weight.data = torch.diag(diag).to(device=reference.device, dtype=reference.dtype)
        self.selector = selector
        

In [4]:
class LoraConv2d(nn.Module):
    """2-D convolution with an additive low-rank (LoRA) branch.

    The forward pass computes
    ``conv(x) + dropout(lora_up(selector(lora_down(x)))) * scale``.
    ``lora_down`` uses the same kernel size, stride, padding, dilation and
    groups as the base convolution but only ``r`` output channels;
    ``lora_up`` is a 1x1 convolution from ``r`` to ``out_channels``.
    ``lora_up`` is zero-initialised, so a freshly constructed layer produces
    the same output as its inner ``conv``.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size, stride=1, 
                 padding=0, dilation=1, groups: int=1, bias: bool=True, r: int=4, 
                 dropout_p: float=0.1, scale: float=1.0,):
        # With single inheritance the zero-argument form of super() is sufficient.
        super().__init__()
        assert r <= min(in_channels, out_channels), f"LoRA rank {r} must be less or equal than {min(in_channels, out_channels)}"
        
        self.r = r
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, groups=groups, bias=bias)
        
        # NOTE(review): lora_down inherits `groups`, so nn.Conv2d will reject an r
        # that is not divisible by groups.
        self.lora_down = nn.Conv2d(in_channels=in_channels, out_channels=r, kernel_size=kernel_size, stride=stride,
                                  padding=padding, dilation=dilation, groups=groups, bias=False)
        self.dropout = nn.Dropout(dropout_p)
        self.lora_up = nn.Conv2d(in_channels=r, out_channels=out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        
        # Pass-through until set_selector_from_diag() installs a diagonal 1x1 conv.
        self.selector = nn.Identity()
        self.scale = scale
        
        nn.init.normal_(self.lora_down.weight, std=1/r)
        nn.init.zeros_(self.lora_up.weight)
        
    def forward(self, input):
        """Return the base convolution plus the scaled low-rank update."""
        return (
            self.conv(input)
            + self.dropout(self.lora_up(self.selector(self.lora_down(input))))
            * self.scale
        )
    
    def realize_as_lora(self):
        """Return ``(lora_up.weight * scale, lora_down.weight)`` as raw tensors."""
        return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
    
    def set_selector_from_diag(self, diag: torch.Tensor):
        """Replace the selector with a bias-free 1x1 conv that scales channel ``i`` by ``diag[i]``.

        ``diag`` must be a 1-D tensor of length ``r``.
        """
        assert diag.shape == (self.r,)
        self.selector = nn.Conv2d(in_channels=self.r, out_channels=self.r, kernel_size=1, stride=1, padding=0, bias=False)
        # A Conv2d weight is 4-D (out, in, kH, kW). Assigning the bare (r, r)
        # diagonal matrix made the next forward pass fail, so add the two
        # trailing singleton kernel dimensions of a 1x1 convolution.
        selector_weight = torch.diag(diag).reshape(self.r, self.r, 1, 1)
        
        # Same device and dtype as lora_up.
        self.selector.weight.data = selector_weight.to(
            self.lora_up.weight.device
        ).to(self.lora_up.weight.dtype)

In [5]:
# Build a bias-free 3x3 convolution (8 -> 32 channels) and inspect its kernel shape.
conv = nn.Conv2d(8, 32, 3, stride=1, padding=0, bias=False)
conv.weight.data.shape

In [6]:
from typing import List, Type, Optional, Set
# Class names used as the default `target_replace_module` of inject_trainable_lora.
DEFAULT_TARGET_REPLACE = {
    "CrossAttention",
    "Attention",
    "GEGLU",
}

In [7]:
def _find_modules_v2(model, ancestor_class: Optional[Set[str]] = None, search_class: List[Type[nn.Module]] = [nn.Linear],
                    exclude_children_of: Optional[List[Type[nn.Module]]] = [LoraLinear, LoraConv2d],):
    """Yield ``(parent, name, module)`` for every matching descendant of the selected ancestors.

    Args:
        model: Root module to search.
        ancestor_class: Class *names* (e.g. ``{"Attention"}``) selecting the
            ancestor modules to search under. Class objects are accepted as
            well. ``None`` means every module of ``model`` acts as an ancestor.
        search_class: Module types to look for below each ancestor.
        exclude_children_of: Matches whose direct parent is an instance of one
            of these types are skipped. ``None`` or an empty list disables the
            filter.

    Yields:
        ``(parent, name, module)`` where ``module`` is reachable as the child
        ``name`` of ``parent``.

    Note:
        ``Type[X]`` from ``typing`` only annotates "a class object"; it does
        not create classes. The defaults are never mutated here.
    """
    if ancestor_class is not None:
        # The original test `module.__class__ in ancestor_class` compared a
        # class object with a set of name strings and therefore never matched.
        # Match on the class name, and keep accepting class objects too.
        ancestors = (
            module
            for module in model.modules()
            if module.__class__.__name__ in ancestor_class
            or module.__class__ in ancestor_class
        )
    else:
        # Naively treat every module as an ancestor. Nested modules are then
        # visited once per enclosing ancestor, so matches can repeat.
        ancestors = list(model.modules())

    search_types = tuple(search_class)
    excluded_types = tuple(exclude_children_of) if exclude_children_of else ()

    for ancestor in ancestors:
        for fullname, module in ancestor.named_modules():
            if not isinstance(module, search_types):
                continue
            # Walk the dotted path from the ancestor down to the direct parent.
            # NOTE(review): when the ancestor itself matches, fullname is ""
            # and the result is (ancestor, "", ancestor), as before.
            *path, name = fullname.split(".")
            parent = ancestor
            for part in path:
                parent = parent.get_submodule(part)
            if excluded_types and isinstance(parent, excluded_types):
                continue
            yield parent, name, module


In [8]:
def inject_trainable_lora(model:nn.Module, 
                         target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE,
                         r: int=4,
                         loras=None, # path to lora .pt
                         verbose:bool=False,
                         dropout_p:float=0.0,
                         scale:float=1.0,):
    """Replace every ``nn.Linear`` under the target modules with a ``LoraLinear``.

    The new layer reuses the original weight (and bias, if any) as its inner
    ``linear`` and is moved to the original layer's device and dtype.

    Args:
        model: Module to modify in place.
        target_replace_module: Names of the ancestor classes to search under.
        r: LoRA rank.
        loras: Optional path to a ``.pt`` file read with ``torch.load``. It is
            consumed as a flat sequence ``[up_0, down_0, up_1, down_1, ...]``
            in injection order.
        verbose: Print the name and weight shape of each replaced layer.
        dropout_p: Dropout probability of the LoRA branch.
        scale: Multiplier applied to the LoRA branch.

    Returns:
        ``(require_grad_params, names)``: a list of parameter iterators
        (``lora_up`` then ``lora_down`` for each injected layer) and the list
        of child names that were replaced.
    """
    def _as_parameter(tensor):
        # Module attributes registered as parameters only accept nn.Parameter;
        # assigning a plain tensor raises TypeError.
        return tensor if isinstance(tensor, nn.Parameter) else nn.Parameter(tensor)

    require_grad_params = []
    names = []
    
    if loras is not None:
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        loras = torch.load(loras)
        
    for _module, name, _child_module in _find_modules_v2(model, target_replace_module, search_class=[nn.Linear]):
        weight = _child_module.weight
        bias = _child_module.bias
        if verbose:
            print("LoRA Injection : injecting lora into ", name)
            print("LoRA Injection : weight shape", weight.shape)
        lora_layer = LoraLinear(_child_module.in_features, _child_module.out_features, bias is not None,
                                r=r, dropout_p=dropout_p, scale=scale)
        # Reuse the original linear layer's parameters inside the new LoRA layer.
        lora_layer.linear.weight = weight
        if bias is not None:
            lora_layer.linear.bias = bias
            
        # Switch the module, matching the original layer's device and dtype.
        lora_layer.to(weight.device).to(weight.dtype)
        _module._modules[name] = lora_layer
        
        if loras is not None:
            lora_layer.lora_up.weight = _as_parameter(loras.pop(0))
            lora_layer.lora_down.weight = _as_parameter(loras.pop(0))
            
        lora_layer.lora_up.weight.requires_grad = True
        lora_layer.lora_down.weight.requires_grad = True
        
        require_grad_params.append(lora_layer.lora_up.parameters())
        require_grad_params.append(lora_layer.lora_down.parameters())
        # `names` was returned but never filled in; record each replaced child.
        names.append(name)
        
    return require_grad_params, names
            

In [2]:
from diffusers import UNet2DConditionModel

In [3]:
# The downloaded model is stored under a hashed name in the cache: ~/.cache/huggingface/hub
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="unet",
    revision=None,
)

Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: 
```
pip install accelerate
```
.


In [None]:
# Dump every sub-module of the UNet, each followed by a 1000-dash separator line.
for module in unet.modules():
    print(module, "-" * 1000, sep="\n")

In [53]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

In [61]:
class LoRALayer:
    """Mix-in that stores LoRA hyper-parameters and merge bookkeeping.

    Attributes set by ``__init__``:
        r: LoRA rank.
        lora_alpha: Scaling numerator stored for use by subclasses.
        lora_dropout: ``nn.Dropout`` when the rate is positive, otherwise an
            identity callable.
        merged: Starts as ``False``.
        merge_weights: Stored flag passed in by the caller.
    """

    def __init__(self, r: int, lora_alpha:int, lora_dropout: float, merge_weights: bool):
        self.r, self.lora_alpha = r, lora_alpha

        # Dropout is optional: a non-positive rate becomes a no-op callable.
        self.lora_dropout = nn.Dropout(p=lora_dropout) if lora_dropout > 0. else (lambda x: x)

        self.merge_weights = merge_weights
        self.merged = False

In [67]:
class Embedding(nn.Embedding, LoRALayer):
    """Placeholder for a LoRA-augmented embedding layer (not implemented yet)."""

    # NOTE(review): __init__ is an empty stub. It calls neither nn.Embedding.__init__
    # nor LoRALayer.__init__, so none of its arguments are used.
    # TODO: implement before instantiating this class.
    def __init__(self, num_embeddings: int, embedding_dim: int, r: int=0, lora_alpha: int=1, merge_weights: bool=True, **kwargs):
        pass
        

In [62]:
# With this approach lora_down / lora_up are plain nn.Parameter attributes rather
# than sub-modules, so they are not listed by modules().
class Linear(nn.Linear, LoRALayer):
    """``nn.Linear`` with a low-rank update ``lora_up @ lora_down`` added to its weight.

    While the update is not merged, ``forward`` computes
    ``x W^T + b + (dropout(x) lora_down^T lora_up^T) * scaling`` with
    ``scaling = lora_alpha / r``. With ``merge_weights=True``, switching to
    eval mode folds the update into ``weight`` and switching back to train
    mode subtracts it again.
    """
    
    def __init__(self, in_features: int, out_features: int, r: int=0, lora_alpha: int=1, lora_dropout: float=0., 
                 fan_in_fan_out: bool=False, merge_weights: bool=True, **kwargs):
        '''
            fan_in_fan_out: if True, the weight is stored as (in_features, out_features)
                and transposed whenever it is used; if False it keeps nn.Linear's
                (out_features, in_features) layout. The result is the same either way.
            merge_weights: if True, eval mode computes (W + BA)x with the update folded
                into the weight; otherwise the layer always computes Wx + BAx.
        '''
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r, lora_alpha, lora_dropout, merge_weights)
        
        self.fan_in_fan_out = fan_in_fan_out
        
        if r > 0:
            self.lora_down = nn.Parameter(self.weight.new_zeros((r, in_features)))
            self.lora_up = nn.Parameter(self.weight.new_zeros((out_features, r)))
            self.scaling = self.lora_alpha / self.r
            
            # Only the low-rank factors are trained; the base weight is frozen.
            self.weight.requires_grad = False
        
        self.reset_parameters()
        if fan_in_fan_out:
            # The weight starts as (out, in); after the transpose it is (in, out).
            self.weight.data = self.weight.data.T

    def _maybe_transpose(self, w):
        """Transpose ``w`` when the weight is stored in (in, out) layout."""
        return w.T if self.fan_in_fan_out else w

    def _lora_delta(self):
        """Return the scaled low-rank update in the same layout as ``self.weight``."""
        return self._maybe_transpose(self.lora_up @ self.lora_down) * self.scaling
        
    def reset_parameters(self):
        """Re-initialise the base layer and, when present, the LoRA factors."""
        # Calls the function defined on nn.Linear directly, passing self explicitly.
        nn.Linear.reset_parameters(self)
        # nn.Linear.__init__ also calls this before the LoRA factors exist.
        if hasattr(self,'lora_down'):
            nn.init.kaiming_normal_(self.lora_down, a=math.sqrt(5))
            # lora_up starts at zero, so the initial update is zero.
            nn.init.zeros_(self.lora_up)
    
    def train(self, mode: bool=True):
        """Set train/eval mode and merge or un-merge the update accordingly.

        ``nn.Module.eval()`` and a parent module's ``eval()`` both arrive here
        as ``train(False)``, so the merge has to be handled in this method;
        previously ``train(False)`` un-merged instead of merging. Returns
        ``self`` like ``nn.Module.train``.
        """
        nn.Linear.train(self, mode)
        if self.merge_weights:
            if mode and self.merged:
                # Back to training: take the merged update out of the weight.
                if self.r > 0:
                    self.weight.data -= self._lora_delta()
                self.merged = False
            elif not mode and not self.merged:
                # Evaluation: fold the update into the weight (W + BA).
                if self.r > 0:
                    self.weight.data += self._lora_delta()
                self.merged = True
        return self
            
    def eval(self):
        """Switch to eval mode; equivalent to ``train(False)``."""
        return self.train(False)
    
    def forward(self, x:torch.Tensor):
        """Apply the linear map, adding the LoRA branch unless it is already merged."""
        result = F.linear(x, self._maybe_transpose(self.weight), bias=self.bias)
        if self.r > 0 and not self.merged:
            result = result + (self.lora_dropout(x) @ self.lora_down.T @ self.lora_up.T) * self.scaling
        return result
    

In [None]:
class MergedLinear(nn.Linear, LoRALayer):
    """``nn.Linear`` whose output features are split into equal groups with optional LoRA per group.

    ``enable_lora`` holds one flag per group; ``out_features`` must be
    divisible by its length. Only ``__init__`` is defined here, so the LoRA
    parameters are allocated but not yet used by any forward or merge logic.
    """

    def __init__(self, in_features: int, out_features: int, r: int=0, lora_alpha: int=1, 
                 lora_dropout: float=0., enable_lora: List[bool]=[False], fan_in_fan_out: bool=False, 
                 merge_weights: bool=True, **kwargs):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
        assert out_features % len(enable_lora) == 0, "The length of enable_lora must divide out_features"
        self.enable_lora = enable_lora
        self.fan_in_fan_out = fan_in_fan_out
        
        if r > 0 and any(enable_lora):
            n_enabled = sum(enable_lora)
            group_size = out_features // len(enable_lora)
            # lora_down has r rows per enabled group and lora_up has group_size
            # rows per enabled group.
            # NOTE(review): presumably each enabled group gets its own rank-r
            # adapter; confirm once forward() is written.
            # Fixed: the original referenced the undefined name `enable_lor`.
            self.lora_down = nn.Parameter(self.weight.new_zeros((r * n_enabled, in_features)))
            self.lora_up = nn.Parameter(self.weight.new_zeros((group_size * n_enabled, r)))
            self.scaling = self.lora_alpha / r
            self.weight.requires_grad = False
            
            # Boolean mask over the output features: True for every feature
            # that belongs to an enabled group.
            self.lora_ind = self.weight.new_zeros(
                (out_features, ), dtype=torch.bool
            ).view(len(enable_lora), -1)
            self.lora_ind[enable_lora, :] = True
            self.lora_ind = self.lora_ind.view(-1)
            

In [65]:
class Conv2d(nn.Conv2d, LoRALayer):
    """``nn.Conv2d`` with a low-rank update added to its kernel.

    The update is ``(lora_up @ lora_down).view(weight.shape) * scaling`` with
    ``scaling = lora_alpha / r``. With ``merge_weights=True``, switching to
    eval mode folds the update into ``weight`` and switching back to train
    mode subtracts it again. Only integer (square) kernel sizes are accepted.

    Note: ``lora_dropout`` is stored by ``LoRALayer`` but not applied here.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, r: int = 0, 
                 lora_alpha: int=1, lora_dropout: float=0., merge_weights: bool = True, **kwargs):
        assert type(kernel_size) is int
        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
        if r > 0:
            self.lora_down = nn.Parameter(
                self.weight.new_zeros((r*kernel_size, in_channels*kernel_size))
            )
            # The kernel has shape (out, in // groups, k, k). Dividing by groups
            # keeps the product's element count equal to the kernel's, so the
            # .view() below also works for grouped convolutions. groups == 1
            # gives the original shape.
            self.lora_up = nn.Parameter(
                self.weight.new_zeros((out_channels // self.groups * kernel_size, r*kernel_size))
            )
            self.scaling = self.lora_alpha / self.r
            # Only the low-rank factors are trained; the base kernel is frozen.
            self.weight.requires_grad = False
        
        self.reset_parameters()

    def _lora_delta(self):
        """Return the scaled low-rank update reshaped to the kernel's shape."""
        return (self.lora_up @ self.lora_down).view(self.weight.shape) * self.scaling
    
    def reset_parameters(self):
        """Re-initialise the base layer and, when present, the LoRA factors."""
        nn.Conv2d.reset_parameters(self)
        # nn.Conv2d.__init__ also calls this before the LoRA factors exist.
        if hasattr(self, 'lora_down'):
            nn.init.kaiming_normal_(self.lora_down, a=math.sqrt(5))
            # lora_up starts at zero, so the initial update is zero.
            nn.init.zeros_(self.lora_up)
            
    def train(self, mode: bool=True):
        """Set train/eval mode and merge or un-merge the update accordingly.

        ``nn.Module.eval()`` and a parent module's ``eval()`` both arrive here
        as ``train(False)``, so the merge has to be handled in this method.
        The ``r > 0`` guards avoid touching LoRA factors that were never
        created. Returns ``self`` like ``nn.Module.train``.
        """
        nn.Conv2d.train(self, mode)
        if self.merge_weights:
            if mode and self.merged:
                if self.r > 0:
                    self.weight.data -= self._lora_delta()
                self.merged = False
            elif not mode and not self.merged:
                if self.r > 0:
                    self.weight.data += self._lora_delta()
                self.merged = True
        return self
    
    def eval(self):
        """Switch to eval mode; equivalent to ``train(False)``."""
        return self.train(False)
    
    def forward(self, x: torch.Tensor):
        """Convolve ``x``, adding the LoRA update on the fly unless it is already merged."""
        # The update is applied only while it is NOT merged (the original
        # condition was inverted and called the non-existent F.Conv2d).
        if self.r > 0 and not self.merged:
            # NOTE(review): this path calls F.conv2d directly, so a non-default
            # padding_mode would not be honoured here.
            return F.conv2d(
                x,
                self.weight + self._lora_delta(),
                self.bias, self.stride, self.padding, self.dilation, self.groups
            )
        
        return nn.Conv2d.forward(self, x)
        

In [57]:
# Smoke test: a rank-4 LoRA Linear with the weight stored in (in, out) layout
# maps a batch of 8 vectors from 64 to 32 features.
lora_fc = Linear(in_features=64, out_features=32, r=4, fan_in_fan_out=True)
data = torch.randn(8, 64)
output = lora_fc(data)
print(output.shape)

torch.Size([8, 32])
