In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [2]:
import argparse

import torch
import torch.nn as nn

from datautils import *
from database import *
from modelutils import *
from quant import *
import time
from timm.optim import Lamb

In [3]:
dataloader, testloader = get_loaders(
    "imagenet", path="",
    batchsize=-1, workers=8,
    nsamples=1024, seed=0,
    noaug=False
)
get_model, test, run = get_functions("rn50")
modelp = get_model()
model_orig = get_model()



In [4]:
db = SparsityDatabase("unstr", "rn50", prefix='', dev='cpu')
modelp = modelp.to('cpu')
layersp = find_layers(modelp)
with open("rn50_unstr_400x_dp.txt", 'r') as f:
    config = {}
    for l in f.readlines():
        level, name = l.strip().split(' ')
        config[name] = level 
db.stitch(layersp, config)
modelp = modelp.to(DEV)
layersp = find_layers(modelp)

In [5]:
total_nz = 0

for n, p in modelp.named_parameters():
    if "weight" not in n:
        continue
    print(n, (p != 0).sum().item() / p.numel())
    total_nz += (p != 0).sum().item()

conv1.weight 0.28241921768707484
bn1.weight 1.0
layer1.0.conv1.weight 0.59033203125
layer1.0.bn1.weight 1.0
layer1.0.conv2.weight 0.1500922309027778
layer1.0.bn2.weight 1.0
layer1.0.conv3.weight 0.13507080078125
layer1.0.bn3.weight 1.0
layer1.0.downsample.0.weight 0.38739013671875
layer1.0.downsample.1.weight 1.0
layer1.1.conv1.weight 0.228759765625
layer1.1.bn1.weight 1.0
layer1.1.conv2.weight 0.1350640190972222
layer1.1.bn2.weight 1.0
layer1.1.conv3.weight 0.28240966796875
layer1.1.bn3.weight 1.0
layer1.2.conv1.weight 0.12152099609375
layer1.2.bn1.weight 1.0
layer1.2.conv2.weight 0.1852756076388889
layer1.2.bn2.weight 1.0
layer1.2.conv3.weight 0.28240966796875
layer1.2.bn3.weight 1.0
layer2.0.conv1.weight 0.166748046875
layer2.0.bn1.weight 1.0
layer2.0.conv2.weight 0.1500922309027778
layer2.0.bn2.weight 1.0
layer2.0.conv3.weight 0.254180908203125
layer2.0.bn3.weight 1.0
layer2.0.downsample.0.weight 0.15009307861328125
layer2.0.downsample.1.weight 1.0
layer2.1.conv1.weight 0.150085449

In [6]:
test(modelp, testloader)

Evaluating ...
72.63


In [7]:
handles = []

def add_batch(layer, inp, out):
    layer.batches = [(inp[0].detach(), out.detach())]
    X = inp[0].detach().float()
    #print(X.shape)
    #assert X.shape[2] == 1
    # TODO: unfold
    #X = X.permute(0, 2, 3, 1)
    #X = X.reshape(-1, X.shape[-1])
    if isinstance(layer, nn.Conv2d):
        unfold = nn.Unfold(
            layer.kernel_size,
            dilation=layer.dilation,
            padding=layer.padding,
            stride=layer.stride
        )
        X = unfold(X)
        X = X.permute([1, 0, 2])
        X = X.flatten(1)
    layer.XX += X.matmul(X.T)

for n, m in model_orig.named_modules():
    if type(m) == nn.Conv2d:
        Wf = m.weight.flatten(1)
        m.XX = torch.zeros(Wf.shape[1], Wf.shape[1], device=m.weight.device)
        handles.append(m.register_forward_hook(add_batch))
        
for i in range(10):
    for j, batch in enumerate(dataloader):
        print(i, j)
        with torch.no_grad():
            run(model_orig, batch)
        
for h in handles:
    h.remove()

0 0
0 1
0 2
0 3
0 4
0 5
0 6
0 7
1 0
1 1
1 2
1 3
1 4
1 5
1 6
1 7
2 0
2 1
2 2
2 3
2 4
2 5
2 6
2 7
3 0
3 1
3 2
3 3
3 4
3 5
3 6
3 7
4 0
4 1
4 2
4 3
4 4
4 5
4 6
4 7
5 0
5 1
5 2
5 3
5 4
5 5
5 6
5 7
6 0
6 1
6 2
6 3
6 4
6 5
6 6
6 7
7 0
7 1
7 2
7 3
7 4
7 5
7 6
7 7
8 0
8 1
8 2
8 3
8 4
8 5
8 6
8 7
9 0
9 1
9 2
9 3
9 4
9 5
9 6
9 7


In [10]:
def find_other2(A, W, nnz, Z, U, print_sc=None, debug=False, reg=0, rho_start=0.03, iters=5, prune_iters=2):
    XX = A.T.matmul(A)
    norm2 = torch.diag(XX).sqrt() + 1e-8
    An = A / norm2
    XX = An.T.matmul(An)
    XX += torch.diag(torch.ones_like(XX.diag())) * XX.diag().mean() * reg
    
    #norm2 = torch.ones_like(norm2)
    Wnn = W# * norm2.unsqueeze(1)
    rho = 1
    XY = An.T.matmul(Wnn)
    XXinv = torch.inverse(XX + torch.eye(XX.shape[1], device=XX.device)*rho)
    XXinv2 = torch.inverse(XX + torch.eye(XX.shape[1], device=XX.device)*rho_start)
    U = U * norm2.unsqueeze(1)
    Z = Z * norm2.unsqueeze(1)
    
    #B = torch.linalg.solve(XX, XY)
    B = XXinv2.matmul(XY + rho_start*(Z-U))
    
    #U = torch.zeros_like(B)
    
    #Z = B
    
    bsparsity = min(0.99, 1 - nnz/B.numel())
    #print("bs", bsparsity)


    for itt in range(iters):
        if itt < prune_iters:
            cur_sparsity = bsparsity# - bsparsity * (1 - (itt + 1) / iterative_prune) ** 3
            thres = (B+U).abs().flatten().sort()[0][int(B.numel() * cur_sparsity)]
            mask = ((B+U).abs() > thres)
            del thres

        Z = (B + U) * mask    

        U = U + (B - Z)    

        B = XXinv.matmul(XY + rho*(Z-U))
        #B = torch.linalg.solve(XX + torch.eye(XX.shape[1], device=XX.device)*rho, XY + rho*(Z-U))
        if debug:
            print(itt, cur_sparsity, (Z != 0).sum().item() / Z.numel())
            print_sc(A.matmul(B / norm2.unsqueeze(1)))
            print_sc(A.matmul(Z / norm2.unsqueeze(1)))
            print(((An != 0).sum() + (Z != 0).sum()) / W.numel())
            print("-------")
    if debug:
        print("opt end")

    return Z / norm2.unsqueeze(1), U / norm2.unsqueeze(1)    
    
def mag_prune(W, sp=0.6):
    thres = (W).abs().flatten().sort()[0][int(W.numel() * sp)]
    mask = ((W).abs() > thres)
    return W * mask

def ent(p):
    return -(p * np.log2(p) + (1-p) * np.log2(1-p))

def factorizeT(W, XX, asp=0.16, sp=0.4, iters=40):
    #W = lx.weight.detach().T.float()
    nza = int(W.shape[0]**2 * asp)
    nzb = int(W.numel() * sp - nza)
    
    Az = torch.eye(W.shape[0], device=W.device)
    Au = torch.zeros_like(Az)
    norm = XX.diag().sqrt().unsqueeze(1) + 1e-8
    norm = torch.ones_like(norm)
       
    Wn = W * norm
       
    Bz = mag_prune(Wn, (1 - nzb/2/W.numel()))
    Bu = torch.zeros_like(Bz)
    
    for itt in range(iters):
        #if itt < 10:
        #    rho_start = 0.0
        #elif itt < 15:
        #    rho_start = 0.00
        #else:
        #    rho_start = 0.1
        rho_start = min(1.0, itt / (iters-3))**3
        Az, Au = (x.T for x in find_other2(Bz.T, Wn.T, nza, Az.T, Au.T, reg=1e-2, debug=False, rho_start=rho_start))
                
        Bz, Bu = find_other2(Az, Wn, nzb, Bz, Bu, reg=1e-2, debug=False, rho_start=rho_start)
    
    #print(((Az != 0).sum() + (Bz != 0).sum()).item() / W.numel(), (Az != 0).sum().item() / Az.numel(),
    #      (Bz != 0).sum().item() / Bz.numel(), Az.shape, Bz.shape,
    #     (Az.numel()*ent((Az != 0).sum().item() / Az.numel()) + Bz.numel()*ent((Bz != 0).sum().item() / Bz.numel())) / W.numel(), 
    #    ent(0.4), ent(0.5))
    return ((Az / norm).matmul(Bz)).T, Bz.T, (Az / norm).T


def factorizef(W, XX, asp=0.16, sp=0.4, iters=200, l_prev=None):
    s_time = time.time()
    if W.shape[0] >= W.shape[1]:
        return factorizeT(W.T, XX, sp=sp, asp=asp, iters=iters)
    
    nza = int(W.shape[0]**2 * asp)
    nzb = int(W.numel() * sp - nza)
    norm = XX.diag().sqrt() + 1e-8
    norm = torch.ones_like(norm)

    Wn = W * norm
    
    Az = torch.eye(W.shape[0], device=W.device)
    Au = torch.zeros_like(Az)

    Bz = mag_prune(Wn, (1 - nzb/2/W.numel()))
    Bu = torch.zeros_like(Bz)
    
    for itt in range(iters):
        #if itt < 10:
        #    rho_start = 0.0
        #elif itt < 15:
        #    rho_start = 0.00
        #else:
        #    rho_start = 0.1
            
        rho_start = min(1.0, itt / (iters-3))**3
        Az, Au = (x.T for x in find_other2(Bz.T, Wn.T, nza, Az.T, Au.T, reg=1e-2, debug=False, rho_start=rho_start))
                
        Bz, Bu = find_other2(Az, Wn, nzb, Bz, Bu, reg=1e-2, debug=False, rho_start=rho_start)
        
        #print(itt, time.time() - s_time, end =" ") 
        #print_scores(Az.matmul(Bz / norm))
        
        
    #print(((Az != 0).sum() + (Bz != 0).sum()).item() / W.numel(), (Az != 0).sum().item() / Az.numel(),
    #      (Bz != 0).sum().item() / Bz.numel(), Az.shape, Bz.shape,
    #     (Az.numel()*ent((Az != 0).sum().item() / Az.numel()) + Bz.numel()*ent((Bz != 0).sum().item() / Bz.numel())) / W.numel(), 
    #    ent(0.4), ent(0.5))
    return Az.matmul(Bz / norm), Az, Bz / norm

def finalize(XXb, W, Ab, Bb):
    fsparsity = 1 - (Ab != 0).sum() / Ab.numel()
    mask = (Ab != 0).T

    XX = Bb.matmul(XXb).matmul(Bb.T)
    XY = Bb.matmul(XXb).matmul(W.detach().float().T)

    norm2 = torch.diag(XX).sqrt() + 1e-8
    XX = XX / norm2 / norm2.unsqueeze(1)
    XY = XY / norm2.unsqueeze(1)
    Ax = (Ab * norm2).T.clone()
    #Ax = torch.linalg.solve(XX, XY)

    rho = 1
    XXinv = torch.inverse(XX + torch.eye(XX.shape[1], device=XX.device)*rho)
    U = torch.zeros_like(Ax)
    for itt in range(200):
        #if itt < 150:
        #    cur_sparsity = fsparsity - fsparsity * (1 - (itt + 1) / 150) ** 3
        #    thres = (Ax+U).abs().flatten().sort()[0][int(Ax.numel() * cur_sparsity)]
        #    mask = ((Ax+U).abs() > thres)
        #    del thres

        
        Z = (Ax + U) * mask    

        U = U + (Ax - Z)    

        Ax = XXinv.matmul(XY + rho*(Z-U))

    Ac = Z.T / norm2
    return Ac

def find_a(B, Za, Ua, rho, D, Q, E, R, XX, W):
    F = rho*(Za-Ua) + XX.matmul(W).matmul(B.T)
    
    right = Q.T.matmul(F).matmul(R)
    
    div = D.unsqueeze(1).matmul(E.unsqueeze(0)) + rho
    QAR = right / div
    
    A3 = Q.matmul(QAR).matmul(R.T)
    return A3


def get_at(XX, W, A, B):
    mask = (A != 0)
    
    norm2 = torch.diag(XX).sqrt() + 1e-8
    XXn = XX / norm2 / norm2.unsqueeze(1)
    
    #XXn += torch.diag(XXn.diag()*0 + 0.01*XXn.diag().mean())
    
    Wn = W * norm2.unsqueeze(1)
    #XY = XY / norm2.unsqueeze(1)
    
    normB = torch.norm(B, dim=1) + 1e-8
    Bn = B / normB.unsqueeze(1)
    BBn = Bn.matmul(Bn.T)
    #BBn += torch.diag(BBn.diag()*0 + 0.01*BBn.diag().mean())
    #print(BBn.diag())
    
    D, Q = torch.linalg.eigh(XXn)
    E, R = torch.linalg.eigh(BBn)
    
    #print(D, E)
    
    Za = A * norm2.unsqueeze(1) * normB
    Ua = torch.zeros_like(Za)
    rho = 1
    
    for itt in range(20):
        A2 = find_a(Bn, Za, Ua, rho, D, Q, E, R, XXn, Wn)
        Wx = (A2 / norm2.unsqueeze(1) / normB).matmul(B)
        #print(itt)
        #print("   errx", (Wx - W).T.matmul(XX).matmul((Wx - W)).diag().sum().item())
        
        Za = (A2 + Ua) * mask
        Ua = Ua + (A2 - Za)
        Wx = (Za / norm2.unsqueeze(1) / normB).matmul(B)
        #print("   errz", (Wx - W).T.matmul(XX).matmul((Wx - W)).diag().sum().item())
    return Za / norm2.unsqueeze(1) / normB

def factorize(XX, W, sp, l_prev=None):
    W = W.detach().float()
    asp = max(0.16, sp/2)
    W2, Ab, Bb = factorizef(W, XX, sp=sp, asp=asp, l_prev=l_prev)
    print("err_prefin", (W2 - W).matmul(XX).matmul((W2 - W).T).diag().sum().item())
    Ac = finalize(XX, W, Ab, Bb)
    W3 = Ac.matmul(Bb)
    assert W3.shape == W.shape
    print("err_fin   ", (W3 - W).matmul(XX).matmul((W3 - W).T).diag().sum().item())
    #fin_b(XX, W, Ac, Bb)
    
    Bc = get_at(XX, W.T, Bb.T, Ac.T).T
    
    W4 = Ac.matmul(Bc)
    assert W3.shape == W.shape
    print("err_fin2   ", (W4 - W).matmul(XX).matmul((W4 - W).T).diag().sum().item())
    
    print("sparsity check", ((Ac != 0).sum() + (Bc != 0).sum()).item() / W3.numel())
    return W4, (Ac.cpu(), Bc.cpu())

In [11]:
sd_pruned = modelp.state_dict()
out_admm = {}

for n, m in model_orig.named_modules():
    if type(m) == nn.Conv2d and m.weight.shape[1] > 3:
        w_pruned = sd_pruned[n+".weight"].flatten(1)
        sparsity = (w_pruned != 0).sum().item() / w_pruned.numel()
        w_orig = m.weight.flatten(1)
        w_admm, facts = factorize(m.XX, w_orig, sparsity)
        e1 = (w_orig - w_pruned).matmul(m.XX).matmul((w_orig - w_pruned).T).diag().sum().item()
        e2 = (w_orig - w_admm).matmul(m.XX).matmul((w_orig - w_admm).T).diag().sum().item()
        print(n, sparsity, m.weight.shape, 
              e1,
              e2,
              "bad" if e1 < e2 else ""
             )
        out_admm[n] = (w_admm.reshape(w_pruned.shape), facts)
        #m.XX = None
        
for n, m in modelp.named_modules():
    if n in out_admm:
        m.weight.data = out_admm[n][0].reshape(m.weight.shape)
        m.weight.facts = out_admm[n][1]

err_prefin 216261.34375
err_fin    18613.828125
err_fin2    15016.3046875
sparsity check 0.58984375
layer1.0.conv1 0.59033203125 torch.Size([64, 64, 1, 1]) 27640.7734375 15016.3046875 
err_prefin 5138185.5
err_fin    833410.3125
err_fin2    566937.25
sparsity check 0.15003797743055555
layer1.0.conv2 0.1500922309027778 torch.Size([64, 64, 3, 3]) 326139.6875 566937.25 bad
err_prefin 4077767.0
err_fin    1004332.4375
err_fin2    930536.375
sparsity check 0.13494873046875
layer1.0.conv3 0.13507080078125 torch.Size([256, 64, 1, 1]) 683674.4375 930536.375 bad
err_prefin 3806934.75
err_fin    424953.0
err_fin2    418320.1875
sparsity check 0.38726806640625
layer1.0.downsample.0 0.38739013671875 torch.Size([256, 64, 1, 1]) 556299.25 418320.1875 
err_prefin 1717065.25
err_fin    506788.0625
err_fin2    268020.375
sparsity check 0.2286376953125
layer1.1.conv1 0.228759765625 torch.Size([64, 256, 1, 1]) 230354.375 268020.375 bad
err_prefin 6621598.0
err_fin    2514085.0
err_fin2    1669178.5
spars

err_prefin 82400.53125
err_fin    57309.78125
err_fin2    41713.375
sparsity check 0.3486773173014323
layer4.0.conv2 0.34867816501193577 torch.Size([512, 512, 3, 3]) 83596.25 41713.375 
err_prefin 28323.662109375
err_fin    12818.5263671875
err_fin2    12665.55859375
sparsity check 0.47829437255859375
layer4.0.conv3 0.47829627990722656 torch.Size([2048, 512, 1, 1]) 37507.96875 12665.55859375 
err_prefin 49622.78125
err_fin    28954.6171875
err_fin2    28306.537109375
sparsity check 0.3486771583557129
layer4.0.downsample.0 0.3486781120300293 torch.Size([2048, 1024, 1, 1]) 88051.0234375 28306.537109375 
err_prefin 591586.0
err_fin    292809.875
err_fin2    183296.140625
sparsity check 0.3874177932739258
layer4.1.conv1 0.3874197006225586 torch.Size([512, 2048, 1, 1]) 348949.75 183296.140625 
err_prefin 48011.2265625
err_fin    34919.7109375
err_fin2    24890.3515625
sparsity check 0.43046612209743923
layer4.1.conv2 0.4304669698079427 torch.Size([512, 512, 3, 3]) 50902.359375 24890.3515625

In [12]:
test(modelp, testloader)

Evaluating ...
74.44


In [13]:
print('Batchnorm tuning ...')

loss = 0
with torch.no_grad():
    for batch in dataloader:
        loss += run(modelp, batch, loss=True)
print(loss / 1024)

batchnorms = find_layers(modelp, [nn.BatchNorm2d])
for bn in batchnorms.values():
    bn.reset_running_stats()
    bn.momentum = .1
modelp.train()
with torch.no_grad():
    i = 0
    while i < 100:
        for batch in dataloader:
            if i == 100:
                break
            print('%03d' % i)
            run(modelp, batch)
            i += 1
modelp.eval()

loss = 0
with torch.no_grad():
    for batch in dataloader:
        loss += run(modelp, batch, loss=True)
print(loss / 1024)

Batchnorm tuning ...
0.8818617463111877
000
001
002
003
004
005
006
007
008
009
010
011
012
013
014
015
016
017
018
019
020
021
022
023
024
025
026
027
028
029
030
031
032
033
034
035
036
037
038
039
040
041
042
043
044
045
046
047
048
049
050
051
052
053
054
055
056
057
058
059
060
061
062
063
064
065
066
067
068
069
070
071
072
073
074
075
076
077
078
079
080
081
082
083
084
085
086
087
088
089
090
091
092
093
094
095
096
097
098
099
0.8327280580997467


In [14]:
test(modelp, testloader)

Evaluating ...
74.95


In [15]:
total_nz = 0
total = 0

for n, p in modelp.named_parameters():
    if "weight" not in n or "bn" in n:
        continue
    
    if hasattr(p, "facts"):
        ff = (p.facts[0] != 0).sum().item() + (p.facts[1] != 0).sum().item() #(p != 0).sum().item()
        total_nz += ff
        print(n, ff / p.numel(), "ff")
    else:
        total_nz += (p != 0).sum().item()
        print(n, (p != 0).sum().item() / p.numel())
    total += p.numel()
    
total_nz, total

conv1.weight 0.28241921768707484
layer1.0.conv1.weight 0.58984375 ff
layer1.0.conv2.weight 0.15003797743055555 ff
layer1.0.conv3.weight 0.13494873046875 ff
layer1.0.downsample.0.weight 0.38726806640625 ff
layer1.0.downsample.1.weight 1.0
layer1.1.conv1.weight 0.2286376953125 ff
layer1.1.conv2.weight 0.135009765625 ff
layer1.1.conv3.weight 0.28228759765625 ff
layer1.2.conv1.weight 0.12139892578125 ff
layer1.2.conv2.weight 0.18522135416666666 ff
layer1.2.conv3.weight 0.28228759765625 ff
layer2.0.conv1.weight 0.16668701171875 ff
layer2.0.conv2.weight 0.1500786675347222 ff
layer2.0.conv3.weight 0.254150390625 ff
layer2.0.downsample.0.weight 0.15007781982421875 ff
layer2.0.downsample.1.weight 1.0
layer2.1.conv1.weight 0.150054931640625 ff
layer2.1.conv2.weight 0.07975260416666667 ff
layer2.1.conv3.weight 0.0984344482421875 ff
layer2.2.conv1.weight 0.1852569580078125 ff
layer2.2.conv2.weight 0.09846327039930555 ff
layer2.2.conv3.weight 0.3486328125 ff
layer2.3.conv1.weight 0.071746826171875 

(10194502, 25506752)

In [16]:
modelp.layer3[4].conv2

Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)

In [17]:
modelp.layer3[4].conv2.weight.facts[0].shape, modelp.layer3[4].conv2.weight.facts[1].shape

(torch.Size([256, 256]), torch.Size([256, 2304]))

In [18]:
f1 = nn.Conv2d(256, 256, 1, bias=False)
f2 = nn.Conv2d(256, 256, 3, padding=1, bias=False)
f1.weight.data = modelp.layer3[4].conv2.weight.facts[0].reshape(f1.weight.shape)
f2.weight.data = modelp.layer3[4].conv2.weight.facts[1].reshape(f2.weight.shape)
f1 = f1.cuda()
f2 = f2.cuda()
f1.weight.device, f2.weight.device

(device(type='cuda', index=0), device(type='cuda', index=0))

In [19]:
xx = torch.randn(10,256,15,15).cuda()
with torch.amp.autocast("cuda"):
    o1 = modelp.layer3[4].conv2(xx)
    ot = f2(xx)
    o2 = f1(ot)
    print((o1 - o2).abs().max())

tensor(0.0020, device='cuda:0', dtype=torch.float16, grad_fn=<MaxBackward1>)


In [20]:
def boo(m, i, o):
    print("boo", i[0].shape)

for n, m in modelp.named_modules():
    if "Bottleneck" in str(type(m)):
        print(n)
        if hasattr(m, "conv1b"):
            m.conv1 = m.conv1b
        ff = m.conv1.weight.facts
        m.conv1b = m.conv1
        m.conv1 = nn.Sequential(
            nn.Conv2d(m.conv1b.in_channels, m.conv1b.out_channels, 1, bias=False),
            nn.Conv2d(m.conv1b.out_channels, m.conv1b.out_channels, 1, bias=False)
        )
        m.conv1[0].weight.data = ff[1].reshape(m.conv1[0].weight.shape)
        m.conv1[1].weight.data = ff[0].reshape(m.conv1[1].weight.shape)
        m.conv1.cuda()
        
        #print(n)
        if hasattr(m, "conv2b"):
            m.conv2 = m.conv2b
        ff = m.conv2.weight.facts
        m.conv2b = m.conv2
        m.conv2 = nn.Sequential(
            nn.Conv2d(m.conv2b.in_channels, m.conv2b.out_channels, 3, padding=1, stride=m.conv2b.stride, bias=False),
            nn.Conv2d(m.conv2b.out_channels, m.conv2b.out_channels, 1, bias=False)
        )
        #m.conv2[0].register_forward_hook(boo)
        m.conv2[0].weight.data = ff[1].reshape(m.conv2[0].weight.shape)
        m.conv2[1].weight.data = ff[0].reshape(m.conv2[1].weight.shape)
        m.conv2.cuda()
        
        if hasattr(m, "conv3b"):
            m.conv3 = m.conv3b
        ff = m.conv3.weight.facts
        m.conv3b = m.conv3
        m.conv3 = nn.Sequential(
            nn.Conv2d(m.conv3b.in_channels, m.conv3b.in_channels, 1, bias=False),
            nn.Conv2d(m.conv3b.in_channels, m.conv3b.out_channels, 1, bias=False)
        )
        m.conv3[0].weight.data = ff[1].reshape(m.conv3[0].weight.shape)
        m.conv3[1].weight.data = ff[0].reshape(m.conv3[1].weight.shape)
        m.conv3.cuda()

layer1.0
layer1.1
layer1.2
layer2.0
layer2.1
layer2.2
layer2.3
layer3.0
layer3.1
layer3.2
layer3.3
layer3.4
layer3.5
layer4.0
layer4.1
layer4.2


In [21]:
test(modelp, testloader)

Evaluating ...
74.95


In [66]:
modelp.layer2[0].conv1

Sequential(
  (0): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
)

In [67]:
modelp.layer2[0].conv2

Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)

In [68]:
modelp.layer2[0].conv3

Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)