In [1]:
import os
import torch
import networks.torch_vgg as vgg
import torch.nn as nn
import numpy as np
import math
import copy

In [2]:
from main import validate
from datasets import dataprep
from gsp_model import GSP_Model


In [3]:
class Args:
    """Hand-written stand-in for a parsed command-line argument namespace.

    Every setting is a class-level attribute, so values can be read (and
    overridden) directly on the class without creating an instance.  Most
    fields are consumed by project code that is not visible in this
    notebook; the group headings below follow the attribute names only.
    """

    # Model / dataset selection (arch and dataset are read when the model
    # and data loaders are built).
    arch = 'vgg19_bn'
    dataset = 'cifar10'
    workers = 4

    # Schedule and optimiser hyper-parameters (lr, momentum and weight_decay
    # are handed to torch.optim.SGD).
    epochs = 160
    start_epoch = 0
    batch_size = 128
    lr = 0.1
    momentum = 0.9
    weight_decay = 1e-4

    # Run-control flags.
    print_freq = 50
    resume = False  # placeholder; a checkpoint path is assigned later
    evaluate = False
    pretrained = False
    half = False
    exp_name = 'gsp_test'

    # Device / logging.
    gpu = None
    logdir = '/logdir'

    # GSP-specific options -- semantics are defined by the project code;
    # TODO confirm meanings against gsp_model / main.
    gsp_training = True
    gsp_sps = 0.8
    scheduled_sps_run = True
    proj_filters = False
    proj_model = False
    gsp_int = 150
    gsp_start_ep = -1
    finetune = False  # when True, epoch / best accuracy from the checkpoint are ignored
    finetune_sps = 0.9

    filelogger = None

# `args` aliases the Args class itself (not an instance), so any later
# `args.<name> = value` assignment rewrites the class attribute.
# The former `global args, best_acc1` statement was dropped: at module/cell
# scope every name is already global, so the declaration had no effect.
args = Args

In [4]:
# Checkpoint to load.  Set the GSP_CHECKPOINT environment variable to point at
# a different file; the fallback is the original machine-specific location.
args.resume = os.environ.get(
    "GSP_CHECKPOINT",
    "/private/home/riohib/explore/gsp_cifar/cifar/model_best.pth.tar",
)

In [6]:
# Load Model
# Map each supported dataset to its number of output classes and fail loudly
# on anything else, instead of leaving `num_classes` undefined.
num_classes_by_dataset = {'cifar10': 10, 'cifar100': 100}
try:
    num_classes = num_classes_by_dataset[args.dataset]
except KeyError:
    raise ValueError(f"Unsupported dataset: {args.dataset!r}") from None

# Only VGG architectures are handled here; reject others up front rather than
# hitting a NameError on `model` further down.
if 'vgg' not in args.arch:
    raise ValueError(f"Unsupported architecture: {args.arch!r}")

# NOTE(review): the constructor is pinned to "vgg19_bn" regardless of the
# exact value of args.arch -- confirm whether args.arch should be used here.
model = vgg.__dict__["vgg19_bn"](num_classes=num_classes)
model = torch.nn.DataParallel(model)

# Move the wrapped model to the GPU (requires CUDA).
model.cuda()

train_loader, val_loader = dataprep.get_data_loaders(dataset=args.dataset, args=args)

Files already downloaded and verified


In [7]:
# model_gsp = GSP_Model(model)
# model_gsp.logger = args.filelogger # Initiate Logger

# Cross-entropy loss, placed on the GPU.
criterion = nn.CrossEntropyLoss().cuda()

# SGD over all model parameters, with hyper-parameters taken from `args`.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=args.lr,
    momentum=args.momentum,
    weight_decay=args.weight_decay,
)

In [8]:
# Resume from the checkpoint at args.resume.
# map_location='cpu' makes loading independent of the device the checkpoint
# was saved from and avoids a second copy of the weights on the GPU;
# load_state_dict copies the values into the model's existing parameters.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load(args.resume, map_location='cpu')

# When fine-tuning, start counting from scratch instead of continuing the
# epoch / best-accuracy bookkeeping stored in the checkpoint.
if args.finetune:
    args.start_epoch = 0
    best_acc1 = 0
else:
    args.start_epoch = checkpoint['epoch']
    best_acc1 = checkpoint['best_acc1']

# Last expression: its return value is shown as the cell output.
model.load_state_dict(checkpoint['state_dict'])


<All keys matched successfully>

In [9]:
# Independent deep copy of the loaded model.  Its Conv2d/Linear weights are
# overwritten later in the notebook, while `model` is kept as the reference.
fft_model = copy.deepcopy(model)

In [10]:
# Baseline: run validation on the model as loaded from the checkpoint.
validate(val_loader, model, criterion, args)


 Validation Acc@1: 93.090 



93.09

In [11]:
# Sanity check: the untouched deep copy should score the same as `model`.
validate(val_loader, fft_model, criterion, args)


 Validation Acc@1: 93.090 



93.09

### Transforms

In [11]:
# Index the model's parameters two ways: `names` keeps the iteration order,
# `params_d` maps each parameter name to its tensor.
names, params_d = [], {}
for name, params in model.named_parameters():
    params_d[name] = params
    names.append(name)

In [13]:
# Show the name of the parameter at index 4.
names[4]

'module.features.3.weight'

In [28]:
# params_d[names[4]]

In [19]:
# N-dimensional FFT of the parameter at index 4; with no `dim` argument
# torch.fft.fftn transforms over every dimension and returns a complex tensor.
param_fft = torch.fft.fftn(params_d[names[4]])

In [27]:
# torch.isclose(torch.fft.ifftn(param_fft).real, params_d[names[4]])

In [34]:
# Collect the raw weight tensor of every Conv2d / Linear layer, keyed by the
# module's qualified name.  Note that `name` keeps its last value after the loop.
module_d = {}
for name, module in model.named_modules():
    if not isinstance(module, (nn.Conv2d, nn.Linear)):
        continue
    module_d[name] = module.weight.data

In [46]:
# Flatten the weights of the layer currently held in `name` (left over from
# the most recent loop) and report 10% of its element count.
weight_tensor = module_d[name].flatten()
numelem = weight_tensor.numel()
topk_weights = 0.1 * numelem  # float, not an integer count
print(topk_weights)

512.0


In [25]:
# Shape of the weight tensor for the layer currently held in `name`.
module_d[name].shape

torch.Size([10, 512])

In [28]:
# NOTE: threshold_tensor is defined in a later cell of this notebook, so this
# cell raises NameError on a fresh top-to-bottom run until that cell has run.
# Thresholding must not change the tensor's shape; display it to confirm.
outt = threshold_tensor(module_d[name], sparsity=0.9)
outt.shape

torch.Size([10, 512])

In [12]:
def threshold_tensor(in_tensor, sparsity=0.9):
    """Zero out the smallest-magnitude entries of a tensor.

    The cut-off is the magnitude found at position ``ceil(n * sparsity)`` of
    the ascending sort of ``|in_tensor|`` (``n`` = number of elements).
    Entries whose magnitude is strictly below the cut-off are set to zero;
    entries equal to it are kept, so ties can leave the result slightly less
    sparse than requested.  With ``sparsity=1.0`` the largest-magnitude
    entries still survive.

    Args:
        in_tensor: Tensor of any shape.  It is not modified.
        sparsity: Target fraction of entries to zero, in ``[0, 1]``.

    Returns:
        A new tensor with the same shape, dtype and device as ``in_tensor``.

    Raises:
        ValueError: If ``sparsity`` lies outside ``[0, 1]``.
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError(f"sparsity must be in [0, 1], got {sparsity}")

    num_elems = in_tensor.numel()
    if num_elems == 0:
        # Nothing to rank; indexing the sorted magnitudes would fail.
        return in_tensor.clone()

    magnitudes = in_tensor.abs()
    sorted_mags, _ = torch.sort(magnitudes.flatten())

    # Clamp so that sparsity == 1.0 selects the last (largest) magnitude
    # instead of running past the end of the sorted values.
    cutoff_idx = min(math.ceil(num_elems * sparsity), num_elems - 1)
    threshold = sorted_mags[cutoff_idx]

    # zeros_like keeps dtype and device aligned with the input, so e.g.
    # integer or double tensors are not silently converted.
    return torch.where(magnitudes < threshold, torch.zeros_like(in_tensor), in_tensor)

def get_mask(in_tensor, sparsity=0.9):
    """Return a boolean mask of the entries that survive thresholding.

    ``in_tensor`` is passed through ``threshold_tensor`` with the given
    ``sparsity``; the mask is True wherever the thresholded tensor is
    non-zero and False elsewhere.

    Args:
        in_tensor: Tensor of any shape.
        sparsity: Forwarded unchanged to ``threshold_tensor``.

    Returns:
        Boolean tensor with the same shape as ``in_tensor``.
    """
    out_tensor = threshold_tensor(in_tensor, sparsity=sparsity)
    # Test for "non-zero" rather than "> 0": surviving negative entries are
    # kept by the thresholding and therefore belong in the mask as well.
    mask = out_tensor != 0
    return mask

In [50]:
# def model_fft_ifft(model):
# Sparsity level handed to get_mask for every layer's FFT magnitudes.
sps = 0.5

# Per-layer results, keyed by module name.
fft_module = {}    # complex FFT of the weights
ifft_module = {}   # real part of the inverse FFT of the masked coefficients
sps_fft = {}       # created here but not filled in this cell
sps_abs_fft = {}   # magnitude of every FFT coefficient

for name, module in model.named_modules():
    if not isinstance(module, (nn.Conv2d, nn.Linear)):
        continue

    # torch.fft.fft is a 1-D transform along the last dimension only.
    fft_module[name] = torch.fft.fft(module.weight.data)

    # Magnitudes are real-valued, so they can be ranked for thresholding.
    sps_abs_fft[name] = fft_module[name].abs()

    # Mask of the coefficients whose magnitude survives the threshold.
    sps_mask = get_mask(sps_abs_fft[name], sparsity=sps)

    # The mask is applied to the complex coefficients, not to the magnitudes.
    sps_fft_tensor = fft_module[name] * sps_mask

    # Back to the weight domain; only the real part is kept.
    ifft_module[name] = torch.fft.ifft(sps_fft_tensor).real


In [51]:
# Overwrite each Conv2d / Linear weight of the copy with its reconstruction.
# The module names line up because fft_model is a deep copy of model.
for name, module in fft_model.named_modules():
    if not isinstance(module, (nn.Conv2d, nn.Linear)):
        continue
    module.weight.data = ifft_module[name]

In [52]:
# Reference weights of classifier[0] in the original model.
model.module.classifier[0].weight.data

tensor([[ 6.1433e-03,  2.6787e-03,  8.7375e-03,  ..., -7.8123e-03,
          1.2930e-02, -8.0516e-03],
        [ 1.6031e-03, -1.8802e-03, -2.0471e-03,  ..., -1.6106e-03,
         -1.5061e-03, -7.2228e-04],
        [-1.6352e-03, -2.9582e-03,  4.8631e-04,  ...,  3.9759e-03,
         -5.9812e-03,  1.5399e-02],
        ...,
        [ 3.8525e-03, -2.2186e-03,  8.3077e-03,  ...,  1.6121e-04,
          4.8145e-04,  3.8776e-03],
        [-1.0152e-03, -1.0338e-03,  1.7625e-04,  ...,  5.7871e-04,
         -2.7955e-04, -1.6501e-05],
        [-6.7887e-04,  2.7948e-04, -2.5118e-04,  ..., -1.1105e-03,
          2.9323e-05, -4.7149e-04]], device='cuda:0')

In [53]:
# Same layer in the copy, after its weights were replaced by the reconstruction.
fft_model.module.classifier[0].weight.data

tensor([[ 0.0061,  0.0026,  0.0087,  ..., -0.0077,  0.0130, -0.0080],
        [ 0.0012, -0.0015, -0.0011,  ..., -0.0014, -0.0015, -0.0004],
        [-0.0016, -0.0030,  0.0005,  ...,  0.0040, -0.0060,  0.0154],
        ...,
        [ 0.0038, -0.0022,  0.0083,  ...,  0.0001,  0.0006,  0.0038],
        [-0.0004, -0.0002, -0.0004,  ..., -0.0003, -0.0004, -0.0003],
        [-0.0010, -0.0004, -0.0009,  ..., -0.0010, -0.0003, -0.0009]],
       device='cuda:0')

In [54]:
# Validation of the copy whose Conv2d/Linear weights were replaced by the
# masked-FFT reconstruction.
validate(val_loader, fft_model, criterion, args)


 Validation Acc@1: 90.680 



90.68

In [55]:
# Re-validate the original model to confirm it was not modified by the
# experiment on the copy.
validate(val_loader, model, criterion, args)


 Validation Acc@1: 93.090 



93.09