In [1]:
import torch.nn.functional as F
import torch.nn as nn
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

import time
import numpy as np
import math
import pandas as pd
from matplotlib import pyplot as plt

import tensorflow as tf

import resnet

In [2]:
im_size = 224

def filter_input_sizes(model):
    param_dict = {}
    size_dict = {}
    i = 0
    for (name, param) in model.named_parameters():
        if 'conv' in name:
            param_clone = param.permute(2, 3, 1, 0).contiguous().clone()
            param_clone = param_clone.detach().cpu().numpy()
            param_dict[name] = tf.convert_to_tensor(param_clone)
            if i < 12:
                size_dict[name] = im_size
            elif i >= 12 and i < 22:
                size_dict[name] = im_size//2
            elif i>=22 and i < 32:
                size_dict[name] = im_size//4
            elif i>=32 and i < 42:
                size_dict[name] = im_size//8
            elif i>=42 and i < 52:
                size_dict[name] = im_size//16
            elif i>=52:
                size_dict[name] = im_size//32
            i = i + 1
    return param_dict, size_dict

def Clip_OperatorNorm(conv, inp_shape, clip_to):
    conv_shape = conv.get_shape().as_list()
    conv_tr = tf.cast(tf.transpose(conv, perm=[2, 3, 0, 1]), tf.complex64)
    padding = tf.constant([[0, 0], [0, 0],
                         [0, inp_shape[0] - conv_shape[0]],
                         [0, inp_shape[1] - conv_shape[1]]])
    transform_coeff = tf.signal.fft2d(tf.pad(conv_tr, padding))
    D, U, V = tf.linalg.svd(tf.transpose(transform_coeff, perm = [2, 3, 0, 1]))
    
    norm = tf.reduce_max(D)
    D_clipped = tf.cast(tf.minimum(D, clip_to), tf.complex64)
    clipped_coeff = tf.matmul(U, tf.matmul(tf.linalg.diag(D_clipped),
                                         V, adjoint_b=True))
    clipped_conv_padded = tf.math.real(tf.signal.ifft2d(
        tf.transpose(clipped_coeff, perm=[2, 3, 0, 1])))
    return tf.slice(tf.transpose(clipped_conv_padded, perm=[2, 3, 0, 1]),
                  [0] * len(conv_shape), conv_shape), norm

def clip_all_convs(param_dict, size_dict):
    for name, conv_filter in param_dict.items():
        inp_shape = (size_dict[name], size_dict[name])
        Clip_OperatorNorm(conv_filter, inp_shape, 1.0)

In [3]:
def SVD_Conv_Tensor(conv, inp_shape):
    conv_tr = tf.cast(tf.transpose(conv, perm=[2, 3, 0, 1]), tf.complex64)
    conv_shape = conv.get_shape().as_list()
    padding = tf.constant([[0, 0], [0, 0],
                         [0, inp_shape[0] - conv_shape[0]],
                         [0, inp_shape[1] - conv_shape[1]]])
    transform_coeff = tf.signal.fft2d(tf.pad(conv_tr, padding))
    
    transform_matrix = tf.transpose(transform_coeff, perm = [2, 3, 0, 1])
    singular_values = tf.linalg.norm(transform_matrix, axis=(2, 3))
    return singular_values
        
def calc_all_convs(param_dict, size_dict):
    for name, conv_filter in param_dict.items():
        inp_shape = (size_dict[name], size_dict[name])
        SVD_Conv_Tensor(conv_filter, inp_shape)

In [4]:
model = torch.nn.DataParallel(resnet.__dict__['resnet32']())
# model = models.resnet34(pretrained=False)
model = model.cuda()

train_dataset = datasets.FakeData(size=10000, image_size=(3, im_size, im_size), 
                                  num_classes=10, transform=transforms.ToTensor())

with torch.no_grad():
    param_dict, size_dict = filter_input_sizes(model)

In [5]:
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=32, shuffle=True,
    num_workers=8, pin_memory=False)
optimizer = torch.optim.SGD(model.parameters(), 0.1,
                            momentum=0.9,
                            weight_decay=0.)
criterion = nn.CrossEntropyLoss().cuda()

In [9]:
from spectral_norm import ConvFilterNorm

class SpectralModule(nn.Module):
    def __init__(self, model):
        super(SpectralModule, self).__init__()
        self.spectra = nn.ModuleList()
        for name, param in model.named_parameters():
            if len(param.shape) > 2:
                self.spectra.append(ConvFilterNorm(param))

    def forward(self):
        sigma_list = []
        for s in self.spectra:
            sigma_list.append(s())
        sigma_arr = torch.Tensor(sigma_list)
        return sigma_arr

start_time = time.time()
spectral_module = SpectralModule(model)
for i, (X, y) in enumerate(train_loader):
    X = X.cuda()
    y = y.cuda()
    
    output = model(X)
    ce_loss = criterion(output, y)
    all_sigma = spectral_module()
    spectral_loss = all_sigma.sum()#spectralnorm_sum(model, spectral_dict)
    loss = ce_loss + 8e-4*spectral_loss

    # compute gradient and do SGD step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
print('{:.4f}'.format(time.time() - start_time))

54.3857


In [7]:
start_time = time.time()
for i, (X, y) in enumerate(train_loader):
    X = X.cuda()
    y = y.cuda()
    
    output = model(X)
    loss = criterion(output, y)

    # compute gradient and do SGD step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
print('{:.4f}'.format(time.time() - start_time))

49.4059


In [8]:
start_time = time.time()
for i, (X, y) in enumerate(train_loader):
    X = X.cuda()
    y = y.cuda()
    
    output = model(X)
    loss = criterion(output, y)

    # compute gradient and do SGD step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if i % 100 == 0:
        clip_all_convs(param_dict, size_dict)
print('{:.4f}'.format(time.time() - start_time))

156.7245
