In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.nn.functional import conv2d

import numpy as np
import math
import os
import time

import tensorflow as tf

resnet18 = models.resnet18(pretrained=True).cuda()

def singular_values_tf(conv, inp_shape):
    start_time = time.time()
    conv_tr = tf.cast(tf.transpose(conv, perm=[2, 3, 0, 1]), tf.complex64)
    conv_shape = conv.get_shape().as_list()

    padding = tf.constant([[0, 0], [0, 0],
                         [0, inp_shape[0] - conv_shape[0]],
                         [0, inp_shape[1] - conv_shape[1]]])
    transform_coeff = tf.signal.fft2d(tf.pad(conv_tr, padding))

    transform_coeff_perm = tf.transpose(transform_coeff, perm = [2, 3, 0, 1])
    norms = tf.norm(transform_coeff_perm, ord=2, axis=(2, 3))
    max_norm = tf.math.reduce_max(norms)
    total_time = time.time() - start_time
    return total_time, max_norm

def singular_values_np(conv, inp_shape):
    start_time = time.time()
    transform_coeff = np.fft.fft2(conv, inp_shape, axes=[0, 1])
    
    norms = np.linalg.norm(transform_coeff, ord=2, axis=(2, 3))
    total_time = time.time() - start_time
    return total_time, np.amax(norms)

def l2_normalize(tensor, eps=1e-12):
    norm = float(torch.sqrt(torch.sum(tensor * tensor)))
    norm = max(norm, eps)
    ans = tensor / norm
    return ans

def real_spectral_norm(conv_filter, shape):
    start_time = time.time()
    H, W = shape
    c_out = conv_filter.shape[0]
    c_in = conv_filter.shape[1]
    pad_size = (conv_filter.shape[2] - 1)//2
    u = l2_normalize(conv_filter.new_empty(1, c_out, H, W).normal_(0, 1))
    v = l2_normalize(conv_filter.new_empty(1, c_in, H, W).normal_(0, 1))
    for _ in range(50):
        v.data = l2_normalize(F.conv_transpose2d(u.data, conv_filter, padding=pad_size))
        u.data = l2_normalize(F.conv2d(v, conv_filter, padding=pad_size))
    sigma = torch.sum(u * conv2d(v, conv_filter, padding=pad_size))
    total_time = time.time() - start_time
    return total_time, sigma

def our_bounds_np(conv_filter):
    start_time = time.time()
    out_ch, in_ch, h, w = conv_filter.shape
    
    permute1 = np.transpose(conv_filter, axes=[0, 2, 1, 3])
    matrix1 = np.reshape(permute1, [out_ch*h, in_ch*w])
    norm1 = math.sqrt(h*w)*np.linalg.norm(matrix1, ord=2, axis=(0, 1))

    permute2 = np.transpose(conv_filter, axes=[0, 3, 1, 2])
    matrix2 = np.reshape(permute2, [out_ch*w, in_ch*h])
    norm2 = math.sqrt(h*w)*np.linalg.norm(matrix2, ord=2, axis=(0, 1))

    permute3 = conv_filter
    matrix3 = np.reshape(permute3, [out_ch, in_ch*h*w])
    norm3 = math.sqrt(h*w)*np.linalg.norm(matrix3, ord=2, axis=(0, 1))

    permute4 = np.transpose(conv_filter, axes=[0, 2, 3, 1])
    matrix4 = np.reshape(permute4, [out_ch*h*w, in_ch])
    norm4 = math.sqrt(h*w)*np.linalg.norm(matrix4, ord=2, axis=(0, 1))
    
    norm_tensor = np.stack([norm1, norm2, norm3, norm4], axis=0)
    min_norm = np.amin(norm_tensor)
    total_time = time.time() - start_time
    return total_time, min_norm

def our_bounds_tf(conv_filter):
    start_time = time.time()
    out_ch, in_ch, h, w = conv_filter.shape
    
    permute1 = tf.transpose(conv_filter, perm=[0, 2, 1, 3])
    matrix1 = tf.reshape(permute1, [out_ch*h, in_ch*w])
    norm1 = math.sqrt(h*w)*tf.norm(matrix1, ord=2, axis=(0, 1))

    permute2 = tf.transpose(conv_filter, perm=[0, 3, 1, 2])
    matrix2 = tf.reshape(permute2, [out_ch*w, in_ch*h])
    norm2 = math.sqrt(h*w)*tf.norm(matrix2, ord=2, axis=(0, 1))

    permute3 = conv_filter
    matrix3 = tf.reshape(permute3, [out_ch, in_ch*h*w])
    norm3 = math.sqrt(h*w)*tf.norm(matrix3, ord=2, axis=(0, 1))

    permute4 = tf.transpose(conv_filter, perm=[0, 2, 3, 1])
    matrix4 = tf.reshape(permute4, [out_ch*h*w, in_ch])
    norm4 = math.sqrt(h*w)*tf.norm(matrix4, ord=2, axis=(0, 1))
    
    norm_tensor = tf.stack([norm1, norm2, norm3, norm4], axis=0)
    min_norm = tf.reduce_min(norm_tensor)
    total_time = time.time() - start_time
    return total_time, min_norm

def our_bounds_ch(conv_filter, num_iters=50):
    start_time = time.time()
    out_ch, in_ch, h, w = conv_filter.shape
        
    permute1 = torch.transpose(conv_filter, 1, 2)
    matrix1 = permute1.reshape(out_ch*h, in_ch*w)
    u1 = torch.randn(matrix1.shape[1], device='cuda', requires_grad=False)
    v1 = torch.randn(matrix1.shape[0], device='cuda', requires_grad=False)

    permute2 = torch.transpose(conv_filter, 1, 3)
    matrix2 = permute2.reshape(out_ch*w, in_ch*h)
    u2 = torch.randn(matrix2.shape[1], device='cuda', requires_grad=False)
    v2 = torch.randn(matrix2.shape[0], device='cuda', requires_grad=False)

    permute3 = conv_filter
    matrix3 = permute3.reshape(out_ch, in_ch*h*w)
    u3 = torch.randn(matrix3.shape[1], device='cuda', requires_grad=False)
    v3 = torch.randn(matrix3.shape[0], device='cuda', requires_grad=False)

    permute4 = torch.transpose(conv_filter, 0, 1)
    matrix4 = permute4.reshape(in_ch, out_ch*h*w)
    u4 = torch.randn(matrix4.shape[1], device='cuda', requires_grad=False)
    v4 = torch.randn(matrix4.shape[0], device='cuda', requires_grad=False)
    
    for i in range(num_iters):
        v1.data = F.normalize(torch.mv(matrix1.data, u1.data), dim=0)
        u1.data = F.normalize(torch.mv(torch.t(matrix1.data), v1.data), dim=0)

        v2.data = F.normalize(torch.mv(matrix2.data, u2.data), dim=0)
        u2.data = F.normalize(torch.mv(torch.t(matrix2.data), v2.data), dim=0)

        v3.data = F.normalize(torch.mv(matrix3.data, u3.data), dim=0)
        u3.data = F.normalize(torch.mv(torch.t(matrix3.data), v3.data), dim=0)

        v4.data = F.normalize(torch.mv(matrix4.data, u4.data), dim=0)
        u4.data = F.normalize(torch.mv(torch.t(matrix4.data), v4.data), dim=0)

    sigma1 = torch.mv(v1.unsqueeze(0), torch.mv(matrix1, u1))
    sigma2 = torch.mv(v2.unsqueeze(0), torch.mv(matrix2, u2))
    sigma3 = torch.mv(v3.unsqueeze(0), torch.mv(matrix3, u3)) 
    sigma4 = torch.mv(v4.unsqueeze(0), torch.mv(matrix4, u4)) 

    min_norm = math.sqrt(h*w)*(torch.min(torch.min(torch.min(sigma1, sigma2), sigma3), sigma4)).item()
    total_time = time.time() - start_time
    return total_time, min_norm

def conv_power_iteration(conv_filter, u_list=None, v_list=None, num_iters=50):
    start_time = time.time()
    out_ch, in_ch, h, w = conv_filter.shape
    if u_list is None:
        u1 = torch.randn((1, in_ch, 1, w), device='cuda', requires_grad=False)
        u1.data = l2_normalize(u1.data)
        
        u2 = torch.randn((1, in_ch, h, 1), device='cuda', requires_grad=False)
        u2.data = l2_normalize(u2.data)

        u3 = torch.randn((1, in_ch, h, w), device='cuda', requires_grad=False)
        u3.data = l2_normalize(u3.data)

        u4 = torch.randn((out_ch, 1, h, w), device='cuda', requires_grad=False)
        u4.data = l2_normalize(u4.data)
        
    if v_list is None:
        v1 = torch.randn((out_ch, 1, h, 1), device='cuda', requires_grad=False)
        v1.data = l2_normalize(v1.data)
        
        v2 = torch.randn((out_ch, 1, 1, w), device='cuda', requires_grad=False)
        v2.data = l2_normalize(v2.data)

        v3 = torch.randn((out_ch, 1, 1, 1), device='cuda', requires_grad=False)
        v3.data = l2_normalize(v3.data)

        v4 = torch.randn((1, in_ch, 1, 1), device='cuda', requires_grad=False)
        v4.data = l2_normalize(v4.data)

    for i in range(num_iters):
        v1.data = l2_normalize((conv_filter.data*u1.data).sum((1, 3), keepdim=True).data)
        u1.data = l2_normalize((conv_filter.data*v1.data).sum((0, 2), keepdim=True).data)
        
        v2.data = l2_normalize((conv_filter.data*u2.data).sum((1, 2), keepdim=True).data)
        u2.data = l2_normalize((conv_filter.data*v2.data).sum((0, 3), keepdim=True).data)
        
        v3.data = l2_normalize((conv_filter.data*u3.data).sum((1, 2, 3), keepdim=True).data)
        u3.data = l2_normalize((conv_filter.data*v3.data).sum(0, keepdim=True).data)
        
        v4.data = l2_normalize((conv_filter.data*u4.data).sum((0, 2, 3), keepdim=True).data)
        u4.data = l2_normalize((conv_filter.data*v4.data).sum(1, keepdim=True).data)

    sigma1 = torch.sum(conv_filter.data*u1.data*v1.data)
    sigma2 = torch.sum(conv_filter.data*u2.data*v2.data)
    sigma3 = torch.sum(conv_filter.data*u3.data*v3.data)
    sigma4 = torch.sum(conv_filter.data*u4.data*v4.data)

    min_norm = math.sqrt(h*w)*(torch.min(torch.min(torch.min(sigma1, sigma2), sigma3), sigma4)).item()
    total_time = time.time() - start_time
    return total_time, min_norm

In [2]:
for name, param in resnet18.named_parameters():
    if 'conv' in name:
        out_channels, in_channels, H, W = param.shape
        
#         param_clone = param.clone()
        param_tf = tf.convert_to_tensor(param.permute(2, 3, 1, 0).contiguous().clone().detach().cpu().numpy())
    
        if out_channels == 512:
            inp_shape = (28, 28)
        elif out_channels == 256:
            inp_shape = (56, 56)
        elif out_channels == 128:
            inp_shape = (112, 112)
        else:
            inp_shape = (224, 224)

        print(list(param.shape), inp_shape)
        
        our_time, our_bound = our_bounds_ch(param)
        print("our: {:.4f}, {:.4f}".format(our_bound, our_time))

        real_time, real_singular = real_spectral_norm(param, (224, 224))
        print("real: {:.4f}, {:.4f}".format(real_singular, real_time))
        
        tf_time, tf_singular = singular_values_tf(param_tf, inp_shape)
        print("exact tf: {:.4f}, {:.4f}".format(tf_singular, tf_time))
        
        np_time, np_singular = singular_values_np(param_tf.numpy(), inp_shape)
        print("exact np: {:.4f}, {:.4f}".format(np_singular, np_time))

[64, 3, 7, 7] (224, 224)
our: 28.8947, 0.2373
real: 15.8215, 0.0372
Instructions for updating:
This op will be removed after the deprecation date. Please switch to tf.sets.difference().
exact tf: 15.9164, 0.9458
exact np: 15.9164, 0.5979
[64, 64, 3, 3] (224, 224)
our: 9.3270, 0.0344
real: 5.9705, 0.0377
exact tf: 6.0060, 8.3601
exact np: 6.0060, 24.3969
[64, 64, 3, 3] (224, 224)
our: 6.2902, 0.0359
real: 5.3196, 0.0383
exact tf: 5.3414, 9.2406
exact np: 5.3414, 24.3861
[64, 64, 3, 3] (224, 224)
our: 8.7121, 0.0343
real: 6.9658, 0.0372
exact tf: 6.9997, 9.1890
exact np: 6.9997, 24.9708
[64, 64, 3, 3] (224, 224)
our: 5.3942, 0.0346
real: 3.7930, 0.0370
exact tf: 3.8162, 9.1697
exact np: 3.8162, 24.4463
[128, 64, 3, 3] (112, 112)
our: 5.8978, 0.0338
real: 4.6765, 0.0562
exact tf: 4.7070, 3.0651
exact np: 4.7070, 37.0092
[128, 128, 3, 3] (112, 112)
our: 7.2012, 0.0347
real: 5.6888, 0.0888
exact tf: 5.7223, 9.2648
exact np: 5.7223, 31.9692
[128, 128, 3, 3] (112, 112)
our: 6.7763, 0.0351
rea