In [1]:
import torch
import torch.nn.functional as F 

import numpy as np
from scipy.stats import truncnorm 

In [2]:
def l2_normalize(x):
    return x / (torch.sqrt(torch.sum(x**2.)) + 1e-9)

def trunc(shape):
    return torch.from_numpy(truncnorm.rvs(0.5, 1, size=shape)).float()

def linear_lipschitz(w, power_iters=5):
    rand_x = trunc(w.shape[1]).type_as(w)
    for _ in range(power_iters):
        x = l2_normalize(rand_x)
        x_p = F.linear(x, w) 
        rand_x = F.linear(x_p, w.T)

    lc = torch.sqrt(torch.abs(torch.sum(w @ x)) / (torch.abs(torch.sum(x)) + 1e-9)).data.cpu().item()
    return lc

def conv_lipschitz(w, in_channels, stride=1, padding=0, power_iters=5):
    rand_x = trunc((1, in_channels, 32, 32)).type_as(w)
    for _ in range(power_iters):
        x = l2_normalize(rand_x)
        x_p = F.conv2d(x, w, 
                       stride=stride, 
                       padding=padding) 
        rand_x = F.conv_transpose2d(x_p, w, 
                                    stride=stride, 
                                    padding=padding)

    Wx = F.conv2d(rand_x, w, 
                  stride=stride, padding=padding)
    lc = torch.sqrt(torch.abs(torch.sum(Wx**2.)) / 
                    (torch.abs(torch.sum(rand_x**2.)) + 1e-9)).data.cpu().item()
    return lc

In [4]:
weight = torch.load('../trained_models/CIFAR10_cifar_model_large_Unstructured_BiC_FullClaim_0to0.05_accelerate/prune/0--k-0.0500_trainer-bilevel_epochs-100_arch-cifar_model_large/checkpoint/model_best.pth.tar', map_location='cpu')

In [8]:
for k, v in weight['state_dict'].items():
    if "popup" in k:
        if 'conv' in k or 'linear' in k or 'features' in k or 'classifier' in k:
            w = weight['state_dict'][k.replace('popup_scores', 'weight')]
            m = torch.clamp(weight['state_dict'][k], 0, 0.05)
            final = w * m
            if 'conv' in k or 'features' in k:
                lc = conv_lipschitz(final, v.shape[1], power_iters=10)
                lc_m = conv_lipschitz(m, v.shape[1], power_iters=10)
                lc_org = conv_lipschitz(w, v.shape[1], power_iters=10)
            else:
                lc = linear_lipschitz(final)
                lc_m = linear_lipschitz(m)
                lc_org = linear_lipschitz(w)

            print (k, lc, lc_m ,lc_org)

conv1.popup_scores 0.7940506339073181 0.2443625032901764 21.589733123779297
conv2.popup_scores 0.5392088890075684 0.6842237710952759 19.000612258911133
conv3.popup_scores 0.3050970435142517 0.9473257660865784 11.75218677520752
conv4.popup_scores 0.42373237013816833 1.821757435798645 12.645577430725098
linear1.popup_scores 0.2104976773262024 0.708825409412384 1.2051116228103638
linear2.popup_scores 0.32855159044265747 1.1262460947036743 0.3829747438430786
linear3.popup_scores 0.4583682715892792 0.33438408374786377 2.0023016929626465


In [12]:
for k, v in weight['state_dict'].items():
    if len(weight['state_dict'][k].shape) > 1:
        if 'conv' in k or 'linear' in k or 'features' in k or 'classifier' in k:
            w = weight['state_dict'][k]
            if 'conv' in k or 'features' in k:
                lc_org = conv_lipschitz(w, v.shape[1], power_iters=10)
            else:
                lc_org = linear_lipschitz(w)

            print (k,lc_org)

conv1.weight 8.175479888916016
conv2.weight 11.387417793273926
conv3.weight 11.928955078125
conv4.weight 18.98225975036621
linear1.weight 1.4400274753570557
linear2.weight 1.990193486213684
linear3.weight 0.001821956830099225


In [None]:
conv1.popup_scores 4.75579309463501 8.144059181213379
conv2.popup_scores 2.5422472953796387 6.285704612731934
conv3.popup_scores 1.517775297164917 7.100604057312012
conv4.popup_scores 1.4837596416473389 9.892857551574707
linear1.popup_scores 0.3069261610507965 0.9017798900604248
linear2.popup_scores 0.5078356862068176 1.025170087814331
linear3.popup_scores 0.6284152865409851 1.3841636180877686

In [None]:
conv1.popup_scores 4.960141181945801 9.786927223205566
conv2.popup_scores 2.6566529273986816 6.438645362854004
conv3.popup_scores 1.6294360160827637 7.351409912109375
conv4.popup_scores 1.5126166343688965 9.763284683227539
linear1.popup_scores 0.3041444420814514 0.8431172370910645
linear2.popup_scores 0.5116637349128723 0.9854820370674133
linear3.popup_scores 0.5906422734260559 1.1100908517837524

In [None]:
conv1.popup_scores 0.8496606945991516 12.465849876403809
conv2.popup_scores 0.7736638784408569 11.91736888885498
conv3.popup_scores 0.5257560014724731 9.201114654541016
conv4.popup_scores 0.769010066986084 9.800241470336914
linear1.popup_scores 0.2583720088005066 1.010354995727539
linear2.popup_scores 0.3878075182437897 0.6782602071762085
linear3.popup_scores 0.49336615204811096 1.8147735595703125

In [None]:
conv1.popup_scores 0.7940255403518677 21.616945266723633
conv2.popup_scores 0.5391982197761536 18.976423263549805
conv3.popup_scores 0.30509668588638306 11.75127124786377
conv4.popup_scores 0.42373186349868774 12.634675979614258
linear1.popup_scores 0.21049615740776062 1.2048084735870361
linear2.popup_scores 0.3285246193408966 0.4244447946548462
linear3.popup_scores 0.45802968740463257 2.0037381649017334