In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import argparse
import os
import shutil
import time

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import resnet
import numpy as np

from datetime import datetime

import sys 
sys.path.append('..')
from utils_gsp.logger import Logger
from utils_gsp import sps_tools
from gsp_model import GSP_Model

from torch.utils.tensorboard import SummaryWriter

from main import train, validate, accuracy, save_checkpoint, setup_experiment, gsp_sparse_training, AverageMeter

model_names = sorted(name for name in resnet.__dict__
    if name.islower() and not name.startswith("__")
                     and name.startswith("resnet")
                     and callable(resnet.__dict__[name]))

print(model_names)

['resnet110', 'resnet1202', 'resnet20', 'resnet32', 'resnet44', 'resnet56']
['resnet110', 'resnet1202', 'resnet20', 'resnet32', 'resnet44', 'resnet56']


In [3]:
class Args:
    arch = 'resnet20'
    workers = 4
    epochs=160
    start_epoch=0
    batch_size = 128
    lr = 0.1
    momentum = 0.9
    weight_decay=1e-4
    print_freq = 50
    resume = False
    evaluate = False
    pretrained = False
    half = False
    exp_name = 'gsp_test'

    gpu=None
    logdir = '/logdir'
    gsp_training = True 
    gsp_sps = 0.8
    gsp_int = 150
    gsp_start_ep = -1
    finetuning = False
    finetune_sps = 0.9


global args, best_acc1
args = Args
# writer = SummaryWriter()

In [4]:
args.resume = './results/gspS80/model_best.pth.tar' # LOAD MODEL

In [31]:
# torch.manual_seed(0)
# Setup the experiment
flogger = setup_experiment(args)
args.logger.log_cmd_arguments(args)

model = torch.nn.DataParallel(resnet.__dict__[args.arch]())
model.cuda()

# optionally resume from a checkpoint
if args.resume:
    if os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        best_acc1 = checkpoint['best_acc1']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})"
                .format(args.evaluate, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))

cudnn.benchmark = True

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])

train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10(root='./data', train=True, transform=transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, 4),
        transforms.ToTensor(),
        normalize,
    ]), download=True),
    batch_size=args.batch_size, shuffle=True,
    num_workers=args.workers, pin_memory=True)

val_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10(root='./data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])),
    batch_size=128, shuffle=False,
    num_workers=args.workers, pin_memory=True)


# ----------------------- Make a GSP Model -----------------------
model_gsp = GSP_Model(model)

print(f"The sparsity of the model is: {model_gsp.get_model_sps():.2f}")
args.writer = SummaryWriter(log_dir=f'results/{args.exp_name}/runs/{datetime.now().strftime("%m-%d_%H:%M")}')

# define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss().cuda()

optimizer = torch.optim.SGD(model.parameters(), args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)

# ============================ Setup GSP model ============================
if args.gsp_training:
    gsp_sparse_training(model_gsp, train_loader, args)
    flogger.info(15*"*" + " Model will be trained with GSP Sparsity!! " + 15*"*" )

# ============== PRUNE the model and Register Mask ==============
if args.finetuning:
    flogger.info(15*"*" + " Model will be finetuned!! " + 15*"*")
    sps_tools.prune_with_sps(model_gsp.model.module, sparsity = args.finetune_sps)
    masks_d, masks_l = sps_tools.get_conv_linear_mask(model_gsp.model.module)
    model_gsp.register_pre_hook_mask(masks_d) # This for forward pre hook mask registration
    # model_gsp.register_hook_mask(model.module, masks_l) # Does not work with DDP


# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
#                 milestones=[80, 120], last_epoch=args.start_epoch - 1)

['resnet110', 'resnet1202', 'resnet20', 'resnet32', 'resnet44', 'resnet56']

 input logdir: ./results/gsp_test/logdir 


 RETURNED FILE LOGGER! 

=> loading checkpoint './results/gspS80/model_best.pth.tar'
=> loaded checkpoint 'False' (epoch 185)
Files already downloaded and verified
The sparsity of the model is: 3.82
Current Epoch: 185


In [32]:

train(train_loader, model_gsp, criterion, optimizer, 0, args, gsp_mode=False)
# validate(val_loader, model, criterion, args)



Applying GSP!! GSP_Mode: True
Applying GSP!! GSP_Mode: True
Applying GSP!! GSP_Mode: True


In [26]:
model_gsp.get_model_sps()

90.20384419523546

In [25]:
# model_gsp.force_apply_gsp(0.80)
model_gsp.prune_and_mask_model(sps=0.9)

Pruning with threshold : 0.05073028802871704 for layer module.linear.bias


In [None]:
sps_tools.prune_with_sps(model, 0.9)
masks_d ,_ = sps_tools.get_conv_linear_mask(model)


In [19]:
for name, module in model_gsp.model.named_modules():
    if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
        module.weight.data = module.weight.data * model_gsp.masks[module]