In [30]:
import torch

In [31]:
# 4x3 float matrix holding 1..12 in row-major order.
W = torch.arange(1, 13, dtype=torch.float).reshape(4, 3)

In [32]:
# torch.svd is deprecated in favour of torch.linalg.svd, which returns the
# right factor already transposed (Vh). Transpose it back so that V keeps the
# one-column-per-singular-vector layout the cells below index into.
# full_matrices=False gives the reduced factors: U is 4x3 and Vh is 3x3 here.
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
V = Vh.t()

In [33]:
W, U, S, V

(tensor([[ 1.,  2.,  3.],
         [ 4.,  5.,  6.],
         [ 7.,  8.,  9.],
         [10., 11., 12.]]),
 tensor([[-0.1409,  0.8247,  0.4541],
         [-0.3439,  0.4263, -0.3773],
         [-0.5470,  0.0278, -0.6078],
         [-0.7501, -0.3706,  0.5310]]),
 tensor([2.5462e+01, 1.2907e+00, 2.7206e-07]),
 tensor([[-0.5045, -0.7608,  0.4082],
         [-0.5745, -0.0571, -0.8165],
         [-0.6445,  0.6465,  0.4082]]))

In [34]:
V.size()

torch.Size([3, 3])

In [35]:
# Truncation rank: how many leading singular values/vectors the factors below keep.
r = 2

In [36]:
# Left factor: the first r columns of U, each scaled by the square root of
# its singular value (the other square root goes into A).
B = U[:, :r] * S[:r].sqrt()
B

tensor([[-0.7109,  0.9369],
        [-1.7356,  0.4843],
        [-2.7603,  0.0316],
        [-3.7850, -0.4211]])

In [37]:
# Right factor: the first r rows of V^T, each scaled by the square root of
# its singular value, so that B @ A rebuilds the rank-r approximation of W.
A = S[:r].sqrt()[:, None] * V[:, :r].t()
A

tensor([[-2.5459, -2.8990, -3.2522],
        [-0.8643, -0.0649,  0.7345]])

In [38]:
B @ A, W

(tensor([[ 1.0000,  2.0000,  3.0000],
         [ 4.0000,  5.0000,  6.0000],
         [ 7.0000,  8.0000,  9.0000],
         [10.0000, 11.0000, 12.0000]]),
 tensor([[ 1.,  2.,  3.],
         [ 4.,  5.,  6.],
         [ 7.,  8.,  9.],
         [10., 11., 12.]]))

In [39]:
(U * torch.sqrt(S)) @ (torch.sqrt(S).unsqueeze(1) * V.t())

tensor([[ 1.0000,  2.0000,  3.0000],
        [ 4.0000,  5.0000,  6.0000],
        [ 7.0000,  8.0000,  9.0000],
        [10.0000, 11.0000, 12.0000]])

In [40]:
# Rank-r reconstruction written out as U_r @ diag(S_r) @ V_r^T.
r = 2
W_reconstructed = U[:, :r] @ torch.diag(S[:r]) @ V[:, :r].t()
W_reconstructed

tensor([[ 1.0000,  2.0000,  3.0000],
        [ 4.0000,  5.0000,  6.0000],
        [ 7.0000,  8.0000,  9.0000],
        [10.0000, 11.0000, 12.0000]])

In [41]:
torch.linalg.matrix_rank(W)

tensor(2)

In [42]:
# Small reference convolution used to inspect the weight and bias layout.
conv = torch.nn.Conv2d(in_channels=3, out_channels=5, kernel_size=3)

In [43]:
conv.weight.size()

torch.Size([5, 3, 3, 3])

In [44]:
conv.bias

Parameter containing:
tensor([ 0.0293,  0.0547,  0.1739,  0.1672, -0.1361], requires_grad=True)

In [45]:
import torch
import torch.nn as nn

# Custom LoRA-style Conv2d layer (placeholder).
class LoRaConv2d(nn.Module):
    """Stand-in for a low-rank Conv2d replacement.

    The constructor accepts the usual convolution hyper-parameters but stores
    none of them and registers no parameters or sub-modules. ``forward`` is
    not implemented yet and returns ``None``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        # The actual LoRA conv layers would be created here.

    def forward(self, x):
        # The forward logic has not been written; nothing is computed.
        return None

# Recursively swap every Conv2d layer for a LoRaConv2d layer.
def replace_conv2d_with_loraconv2d(module):
    """Replace, in place, each ``nn.Conv2d`` found under ``module``.

    Direct children that are ``nn.Conv2d`` are replaced by a ``LoRaConv2d``
    built from the child's in/out channels, kernel size, stride and padding.
    Every other child is searched recursively. Returns ``None``.
    """
    for child_name, submodule in module.named_children():
        if not isinstance(submodule, nn.Conv2d):
            # Not a convolution: look for Conv2d layers further down.
            replace_conv2d_with_loraconv2d(submodule)
            continue

        # Build the replacement from the original layer's hyper-parameters.
        replacement = LoRaConv2d(
            submodule.in_channels,
            submodule.out_channels,
            submodule.kernel_size,
            submodule.stride,
            submodule.padding,
        )
        # Rebinding the attribute swaps the layer inside the parent module.
        setattr(module, child_name, replacement)

# Build a sample nn.Sequential model.
layers = [
    nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    nn.Linear(128 * 7 * 7, 10),
]
model = nn.Sequential(*layers)

# Show the model before the swap.
print(model)

# Swap every Conv2d layer for a LoRaConv2d layer (modifies the model in place).
replace_conv2d_with_loraconv2d(model)

# Show the model after the swap.
print(model)

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU()
  (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (4): ReLU()
  (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (6): Flatten(start_dim=1, end_dim=-1)
  (7): Linear(in_features=6272, out_features=10, bias=True)
)
Sequential(
  (0): LoRaConv2d()
  (1): ReLU()
  (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (3): LoRaConv2d()
  (4): ReLU()
  (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (6): Flatten(start_dim=1, end_dim=-1)
  (7): Linear(in_features=6272, out_features=10, bias=True)
)


In [46]:
from models.conv import conv
from models.utils import replace_model

In [47]:
# Random batch: 2 images, 3 channels, 32x32.
x = torch.rand(2, 3, 32, 32)
# Project model; the argument meanings are defined in models.conv.
model = conv([3], [64, 128, 256, 512], 10)
# Output shape before the conversion.
print(model(x).shape)
# Convert the model in place with rate 0.125; it prints the rank chosen per conv.
replace_model(model, 0.125)
# Output shape after the conversion, then the converted architecture.
print(model(x).shape)
print(model)

torch.Size([2, 10])
conv RANK is: 2
conv RANK is: 13
conv RANK is: 26
conv RANK is: 52
torch.Size([2, 10])
Conv(
  (scaler): Scaler()
  (layer0): Sequential(
    (conv): LoraConv(
      (up_conv): Conv2d(3, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (down_conv): Conv2d(2, 64, kernel_size=(1, 1), stride=(1, 1), padding=(1, 1))
    )
    (scale): Scaler()
    (norm): BatchNorm2d(64, eps=1e-05, momentum=None, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxPool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (layer1): Sequential(
    (conv): LoraConv(
      (up_conv): Conv2d(64, 13, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (down_conv): Conv2d(13, 128, kernel_size=(1, 1), stride=(1, 1), padding=(1, 1))
    )
    (scale): Scaler()
    (norm): BatchNorm2d(128, eps=1e-05, momentum=None, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxPool): MaxPool2d(kernel_size=2, s

In [48]:
# Pull out the two kernels of layer2's LoRA conv.
# Note: `x` held the input batch in the previous cell and is rebound here.
x, y = model.layer2.conv.up_conv.weight, model.layer2.conv.down_conv.weight

In [49]:
# Shapes of the up_conv and down_conv kernels, reused by the merge below.
x_s, y_s = x.shape, y.shape
x_s, y_s

(torch.Size([26, 128, 3, 3]), torch.Size([256, 26, 1, 1]))

In [50]:
# Compose the two kernels: flatten down_conv to (out, rank) and up_conv to
# (rank, in*kh*kw), then multiply to get one (out, in*kh*kw) matrix.
z = y.view(y_s[0], -1) @ x.view(x_s[0], -1)
# Store the 4-D kernel layout (out, in, kh, kw); without the assignment the
# reshaped view is thrown away and z stays 2-D.
z = z.view(y_s[0], x_s[1], x_s[2], x_s[3])
# Placeholder zero bias with one entry per output channel.
z_bias = torch.zeros((y_s[0]))
z_bias

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 

In [51]:
# State dict of the converted model: parameters and buffers keyed by dotted name.
params = model.state_dict()

In [52]:
for k, v in params.items():
    print(k)

layer0.conv.up_conv.weight
layer0.conv.up_conv.bias
layer0.conv.down_conv.weight
layer0.conv.down_conv.bias
layer0.norm.weight
layer0.norm.bias
layer0.norm.running_mean
layer0.norm.running_var
layer0.norm.num_batches_tracked
layer1.conv.up_conv.weight
layer1.conv.up_conv.bias
layer1.conv.down_conv.weight
layer1.conv.down_conv.bias
layer1.norm.weight
layer1.norm.bias
layer1.norm.running_mean
layer1.norm.running_var
layer1.norm.num_batches_tracked
layer2.conv.up_conv.weight
layer2.conv.up_conv.bias
layer2.conv.down_conv.weight
layer2.conv.down_conv.bias
layer2.norm.weight
layer2.norm.bias
layer2.norm.running_mean
layer2.norm.running_var
layer2.norm.num_batches_tracked
layer3.conv.up_conv.weight
layer3.conv.up_conv.bias
layer3.conv.down_conv.weight
layer3.conv.down_conv.bias
layer3.norm.weight
layer3.norm.bias
layer3.norm.running_mean
layer3.norm.running_var
layer3.norm.num_batches_tracked
linear.weight
linear.bias


In [53]:
from collections import OrderedDict
# Destination for the merged state dict (one weight/bias pair per conv layer).
new_params = OrderedDict()

In [54]:
# Collapse every LoRA conv pair into the weight and bias of a single conv.
# The weight merge (down @ up) assumes the layer applies up_conv first and the
# 1x1 down_conv second — TODO confirm against the LoraConv forward in models.
for k, v in params.items():
    if 'conv' not in k:
        # Norm entries and the linear head are copied through unchanged.
        new_params[k] = v
        continue
    if 'conv.up_conv.weight' not in k:
        # The other conv entries are consumed when their layer's
        # up_conv.weight key is handled below.
        continue

    prefix = k.split('.')[0]
    x = params[prefix + '.conv.up_conv.weight']
    y = params[prefix + '.conv.down_conv.weight']
    x_s, y_s = x.size(), y.size()

    # (out, rank) @ (rank, in*kh*kw) -> (out, in*kh*kw) -> (out, in, kh, kw)
    y_2d = y.view(y_s[0], -1)
    z = y_2d @ x.view(x_s[0], -1)
    z = z.view(y_s[0], x_s[1], x_s[2], x_s[3])

    # Composed bias: the up_conv bias passes through the down_conv weights
    # and the down_conv bias is added on top. A missing bias entry counts as
    # zero, so bias-free layers still merge.
    up_bias = params.get(prefix + '.conv.up_conv.bias')
    down_bias = params.get(prefix + '.conv.down_conv.bias')
    z_bias = torch.zeros(y_s[0], dtype=z.dtype, device=z.device)
    if up_bias is not None:
        z_bias = z_bias + y_2d @ up_bias
    if down_bias is not None:
        z_bias = z_bias + down_bias

    new_params[prefix + '.conv.weight'] = z
    new_params[prefix + '.conv.bias'] = z_bias


In [55]:
# Scratch check of Tensor.view / Tensor.size on a random 4x4 tensor.
xx = torch.rand((4, 4))
xx.view(2, 2, 4)  # result is discarded; only shows 16 elements can be viewed as (2, 2, 4)
xx.size(0)  # length of the first dimension — the value this cell displays

4

In [56]:
for k, v in new_params.items():
    print(k)

layer0.conv.weight
layer0.conv.bias
layer0.norm.weight
layer0.norm.bias
layer0.norm.running_mean
layer0.norm.running_var
layer0.norm.num_batches_tracked
layer1.conv.weight
layer1.conv.bias
layer1.norm.weight
layer1.norm.bias
layer1.norm.running_mean
layer1.norm.running_var
layer1.norm.num_batches_tracked
layer2.conv.weight
layer2.conv.bias
layer2.norm.weight
layer2.norm.bias
layer2.norm.running_mean
layer2.norm.running_var
layer2.norm.num_batches_tracked
layer3.conv.weight
layer3.conv.bias
layer3.norm.weight
layer3.norm.bias
layer3.norm.running_mean
layer3.norm.running_var
layer3.norm.num_batches_tracked
linear.weight
linear.bias


In [57]:
new_params['layer0.conv.weight'].size()  # not the last expression, so nothing is displayed
# Destination for the re-factorised state dict (up_conv / down_conv entries).
nnew_params = OrderedDict()

In [58]:
# Split every merged conv kernel into a low-rank up_conv / down_conv pair using
# a truncated SVD, and write the result into nnew_params.
rate = 0.125
for key, value in new_params.items():
    if 'conv' not in key:
        # Norm entries and the linear head are copied through unchanged.
        nnew_params[key] = value
        continue

    print(key)
    if 'weight' not in key:
        # The conv bias is carried into down_conv.bias when the matching
        # weight key is handled, so it gets no entry of its own.
        continue

    cout, cin, m, n = value.size()
    w_2dim = value.reshape(cout, -1)
    # Rank that keeps about `rate` of the original weight count:
    # rank * (cout + cin*m*n) ~= rate * cout*cin*m*n.
    rank = int((cin*cout*m*n*rate) / (cout + cin*m*n))

    # torch.svd is deprecated; torch.linalg.svd returns the right factor
    # already transposed (Vh), so its leading rows are taken directly.
    U, S, Vh = torch.linalg.svd(w_2dim, full_matrices=False)

    # Give each factor the square root of the singular values.
    sqrt_s = torch.sqrt(S[:rank])
    B = U[:, :rank] * sqrt_s
    A = sqrt_s.unsqueeze(1) * Vh[:rank, :]

    print(f"{rank}   {B.size()}     {A.size()}")

    # up_conv keeps the spatial kernel: (rank, cin, m, n).
    up_weight = A.reshape(A.size(0), -1, m, n)
    up_bias = torch.zeros((up_weight.size(0)))

    # down_conv is a 1x1 conv: (cout, rank, 1, 1).
    down_weight = B.reshape(B.size(0), B.size(1), 1, 1)
    # Keep the original conv bias on the second conv instead of discarding
    # it; fall back to zeros when the state dict has no bias for this layer.
    conv_bias = new_params.get(key.replace('weight', 'bias'))
    if conv_bias is None:
        down_bias = torch.zeros((down_weight.size(0)))
    else:
        down_bias = conv_bias.clone()

    nnew_params[key.replace('weight', 'up_conv.weight')] = up_weight
    nnew_params[key.replace('weight', 'up_conv.bias')] = up_bias

    nnew_params[key.replace('weight', 'down_conv.weight')] = down_weight
    nnew_params[key.replace('weight', 'down_conv.bias')] = down_bias
        

layer0.conv.weight
2   torch.Size([64, 2])     torch.Size([2, 27])
layer0.conv.bias
layer1.conv.weight
13   torch.Size([128, 13])     torch.Size([13, 576])
layer1.conv.bias
layer2.conv.weight
26   torch.Size([256, 26])     torch.Size([26, 1152])
layer2.conv.bias
layer3.conv.weight
52   torch.Size([512, 52])     torch.Size([52, 2304])
layer3.conv.bias


In [59]:
for k, v in nnew_params.items():
    print(k)

layer0.conv.up_conv.weight
layer0.conv.up_conv.bias
layer0.conv.down_conv.weight
layer0.conv.down_conv.bias
layer0.norm.weight
layer0.norm.bias
layer0.norm.running_mean
layer0.norm.running_var
layer0.norm.num_batches_tracked
layer1.conv.up_conv.weight
layer1.conv.up_conv.bias
layer1.conv.down_conv.weight
layer1.conv.down_conv.bias
layer1.norm.weight
layer1.norm.bias
layer1.norm.running_mean
layer1.norm.running_var
layer1.norm.num_batches_tracked
layer2.conv.up_conv.weight
layer2.conv.up_conv.bias
layer2.conv.down_conv.weight
layer2.conv.down_conv.bias
layer2.norm.weight
layer2.norm.bias
layer2.norm.running_mean
layer2.norm.running_var
layer2.norm.num_batches_tracked
layer3.conv.up_conv.weight
layer3.conv.up_conv.bias
layer3.conv.down_conv.weight
layer3.conv.down_conv.bias
layer3.norm.weight
layer3.norm.bias
layer3.norm.running_mean
layer3.norm.running_var
layer3.norm.num_batches_tracked
linear.weight
linear.bias


In [60]:
nnew_params['layer1.conv.up_conv.weight'].size()

torch.Size([13, 64, 3, 3])

In [61]:
# Fresh random batch and a second model built with the same arguments.
x = torch.rand(2, 3, 32, 32)
model1 = conv([3], [64, 128, 256, 512], 10)
# Output shape before the conversion.
print(model1(x).shape)
# Convert in place with rate 0.125 so model1 can receive the factorised state dict.
replace_model(model1, 0.125)

torch.Size([2, 10])
conv RANK is: 2
conv RANK is: 13
conv RANK is: 26
conv RANK is: 52


In [62]:
model1.load_state_dict(nnew_params)

<All keys matched successfully>

In [63]:
import torch
# Two separate all-ones tensors of shape (4, 7, 8).
dx = torch.ones(4, 7, 8)
dy = torch.ones_like(dx)

In [79]:
# Two independent, empty state-dict containers, filled in the next cell.
one_params, zero_params = OrderedDict(), OrderedDict()

In [82]:
# For every entry of the factorised state dict, create an all-ones and an
# all-zeros tensor of the same shape (in torch's default dtype).
for k, v in nnew_params.items():
    shape = v.shape
    one_params[k], zero_params[k] = torch.ones(shape), torch.zeros(shape)

In [85]:
from utils import merge_list_params
# Combine the all-ones and all-zeros state dicts with the project helper.
# NOTE(review): merge_list_params comes from the project's utils module; its
# merge rule (e.g. averaging) is not visible here — confirm before relying on
# the resulting values.
two_params = merge_list_params([one_params, zero_params])
# two_params

In [86]:
def loadmodel(model, params):
    """Load ``params`` into ``model`` with strict key matching.

    Args:
        model: object exposing ``load_state_dict`` (e.g. an ``nn.Module``).
        params: state dict mapping names to tensors.

    Returns:
        None. The result object from ``load_state_dict`` is not returned.
    """
    model.load_state_dict(params, strict=True)
    return None

In [87]:
loadmodel(model1, two_params)

In [91]:
# model1.layer0.conv.up_conv.weight