In [20]:
import torch
import torch.nn.functional as F 

import numpy as np
from scipy.stats import truncnorm 

In [28]:
def l2_normalize(x):
    return x / (torch.sqrt(torch.sum(x**2.)) + 1e-9)

def trunc(shape):
    return torch.from_numpy(truncnorm.rvs(0.5, 1, size=shape)).float()

def linear_lipschitz(w, power_iters=5):
    rand_x = trunc(w.shape[1]).type_as(w)
    for _ in range(power_iters):
        x = l2_normalize(rand_x)
        x_p = F.linear(x, w) 
        rand_x = F.linear(x_p, w.T)

    lc = torch.sqrt(torch.abs(torch.sum(w @ x)) / (torch.abs(torch.sum(x)) + 1e-9)).data.cpu().item()
    return lc

def conv_lipschitz(w, in_channels, stride=1, padding=0, power_iters=5):
    rand_x = trunc((1, in_channels, 32, 32)).type_as(w)
    for _ in range(power_iters):
        x = l2_normalize(rand_x)
        x_p = F.conv2d(x, w, 
                       stride=stride, 
                       padding=padding) 
        rand_x = F.conv_transpose2d(x_p, w, 
                                    stride=stride, 
                                    padding=padding)

    Wx = F.conv2d(rand_x, w, 
                  stride=stride, padding=padding)
    lc = torch.sqrt(torch.abs(torch.sum(Wx**2.)) / 
                    (torch.abs(torch.sum(rand_x**2.)) + 1e-9)).data.cpu().item()
    return lc

In [71]:
weight = torch.load('../trained_models/CIFAR10_VGG19_Hydra/pruned_model_best.pth.tar', map_location='cpu')

In [72]:
for k, v in weight['state_dict'].items():
    if "popup" in k:
        if 'conv' in k or 'linear' in k or 'features' in k or 'classifier' in k:
            w = weight['state_dict'][k.replace('popup_scores', 'weight')]
            m = weight['state_dict'][k]
            final = w * m
            if 'conv' in k or 'features' in k:
                lc = conv_lipschitz(final, v.shape[1], power_iters=10)
                lc_org = conv_lipschitz(w, v.shape[1], power_iters=10)
            else:
                lc = linear_lipschitz(final)
                lc_org = linear_lipschitz(w)

            print (k, lc, lc_org)

features.0.popup_scores 31.61339569091797 6.366943836212158
features.3.popup_scores 2.0482726097106934 4.719913959503174
features.7.popup_scores 1.0981258153915405 5.389538288116455
features.10.popup_scores 0.5951802730560303 6.607568740844727
features.14.popup_scores 0.4345148801803589 4.906081676483154
features.17.popup_scores 0.3194398880004883 6.869846820831299
features.20.popup_scores 0.2904692590236664 7.108527183532715
features.24.popup_scores 0.2080521583557129 4.668399810791016
features.27.popup_scores 0.08053360134363174 5.3613996505737305
features.30.popup_scores 0.05451308190822601 5.0163350105285645
features.34.popup_scores 0.020527759566903114 2.9826128482818604
features.37.popup_scores 0.01119234599173069 2.5878677368164062
features.40.popup_scores 0.01152096502482891 3.3716065883636475
classifier.0.popup_scores 0.2551078796386719 0.7238613367080688
classifier.2.popup_scores 3.691840410232544 1.3103415966033936
classifier.4.popup_scores 5.800979137420654 0.55902099609375

In [68]:
for k, v in weight['state_dict'].items():
    if len(weight['state_dict'][k].shape) > 1:
        if 'conv' in k or 'linear' in k or 'features' in k or 'classifier' in k:
            w = weight['state_dict'][k]
            if 'conv' in k or 'features' in k:
                lc_org = conv_lipschitz(w, v.shape[1], power_iters=10)
            else:
                lc_org = linear_lipschitz(w)

            print (k,lc_org)

features.0.weight 3.622633457183838
features.3.weight 3.627657890319824
features.7.weight 3.4148292541503906
features.10.weight 4.630601406097412
features.14.weight 3.592601776123047
features.17.weight 4.500280857086182
features.20.weight 4.9865498542785645
features.24.weight 3.8369691371917725
features.27.weight 3.7561752796173096
features.30.weight 3.9328665733337402
features.34.weight 3.051018714904785
features.37.weight 3.801638126373291
features.40.weight 6.44829797744751
classifier.0.weight 0.896481454372406
classifier.2.weight 1.234708547592163
classifier.4.weight 0.041216135025024414
