# VGG-11 on CIFAR-10: test results
### Float model: 92.13 (test accuracy, %)
### Binary-weight model: 89.21 (test accuracy, %)

In [None]:
import sys
import os
sys.path.append('./util/')

from get_model import get_model, get_model2
from tools import progress_bar
import torch.nn as nn
from collections import OrderedDict

from math import ceil
import torch
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
%matplotlib inline

# from utils.input_pipeline import get_image_folders
from train_net import train, optimization_step_float, train_eta
from quant import optimization_step_eta, optimization_step

# A bare `torch.cuda.is_available()` in the middle of a cell displays nothing;
# print it so a missing GPU is noticed before net.cuda() fails further down.
print('CUDA available: %s' % torch.cuda.is_available())
import torch.backends.cudnn as cudnn

print('==> Preparing data..')
# Per-channel normalisation constants are shared by the train and test splits.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  # augmentation: random shifts of up to 4 px
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Training mini-batch size: single source for the loader and for n_batches below.
batch_size = 128

trainset = torchvision.datasets.CIFAR10(root='../pytorch-tutorial/data', train=True,
                                        download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='../pytorch-tutorial/data', train=False,
                                       download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

torch.backends.cudnn.benchmark = True
LEARNING_RATE = 1e-1  # learning rate for all possible weights
WEIGHT_DECAY = 1e-4  # hyperparameter for quantization

n_epochs = 200  # total number of epochs
m_epochs = 150  # epoch at which phase II (relaxation turned off) starts

# `len(dataset)` works on every torchvision version (`train_labels` was renamed
# to `targets`).  float() avoids Python 2 floor division, which turned ceil()
# into a no-op and dropped the final partial batch from the count.
train_size = len(trainloader.dataset)
n_batches = int(ceil(float(train_size) / batch_size))  # total number of batches in the train set
n_validation_batches = 100
print('There are %d batches in the train set.' % n_batches)

# Layer recipes for the VGG variants, one pooling stage per line.
# An integer is the output-channel count of a 3x3 conv block; 'M' is a 2x2 max-pool.
cfg = dict(
    VGG11=[64, 'M',
           128, 'M',
           256, 256, 'M',
           512, 512, 'M',
           512, 512, 'M'],
    VGG13=[64, 64, 'M',
           128, 128, 'M',
           256, 256, 'M',
           512, 512, 'M',
           512, 512, 'M'],
    VGG16=[64, 64, 'M',
           128, 128, 'M',
           256, 256, 256, 'M',
           512, 512, 512, 'M',
           512, 512, 512, 'M'],
    VGG19=[64, 64, 'M',
           128, 128, 'M',
           256, 256, 256, 256, 'M',
           512, 512, 512, 512, 'M',
           512, 512, 512, 512, 'M'],
)

class VGG(nn.Module):
    """VGG-style classifier built from an entry of the module-level ``cfg`` table.

    Each integer in the recipe becomes Conv2d(3x3, padding=1) -> BatchNorm2d ->
    ReLU(inplace); each 'M' becomes MaxPool2d(2, stride=2).  The flattened
    feature map feeds one Linear layer with 512 inputs.  That assumes the
    features end at 512x1x1, which holds for 32x32 inputs and the recipes in
    ``cfg`` (five pools, last conv with 512 channels).

    Args:
        vgg_name: key into ``cfg`` ('VGG11', 'VGG13', 'VGG16' or 'VGG19').
        num_classes: number of output logits (default 10).
    """

    def __init__(self, vgg_name, num_classes=10):
        super(VGG, self).__init__()
        self.features = self._make_layers(cfg[vgg_name])
        self.classifier = nn.Linear(512, num_classes)

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)  # flatten everything but the batch dimension
        return self.classifier(out)

    def _make_layers(self, cfg):
        """Return an nn.Sequential implementing the layer recipe ``cfg``.

        Conv blocks are named conv<i>/norm<i>/relu<i> and pools pool<j>, each
        counted from 0.  These names become the state_dict keys, so they must
        not change or existing checkpoints will no longer load.
        """
        layers = nn.Sequential()
        in_channels = 3  # RGB input
        index_pool = 0
        index_block = 0
        for x in cfg:
            if x == 'M':
                layers.add_module('pool%d' % index_pool,
                                  nn.MaxPool2d(kernel_size=2, stride=2))
                index_pool += 1
            else:
                layers.add_module('conv%d' % index_block,
                                  nn.Conv2d(in_channels, x, kernel_size=3, padding=1))
                layers.add_module('norm%d' % index_block,
                                  nn.BatchNorm2d(x))
                layers.add_module('relu%d' % index_block,
                                  nn.ReLU(inplace=True))
                in_channels = x
                index_block += 1
        return layers
    
    
# Instantiate VGG-11; `thisname` is the prefix for this architecture's
# checkpoint file names (e.g. 'vgg11_float.t7').
net, thisname = VGG('VGG11'), 'vgg11_'

def load_model(net, name=thisname + 'float.t7'):
    """Restore ``net``'s weights from ``checkpoint/<name>``.

    The default ``name`` is built from ``thisname`` once, when this function is
    defined.

    Args:
        net: module whose state_dict is overwritten with checkpoint['net'].
        name: checkpoint file name inside the ``checkpoint`` directory.

    Returns:
        (best_acc, start_epoch) as stored in the checkpoint.

    Raises:
        IOError: if the ``checkpoint`` directory does not exist.  This is an
            explicit raise rather than an ``assert``, so it still fires under
            ``python -O``.
    """
    print('==> Resuming from checkpoint..')
    if not os.path.isdir('checkpoint'):
        raise IOError('Error: no checkpoint directory found!')
    checkpoint = torch.load(os.path.join('checkpoint', name))
    net.load_state_dict(checkpoint['net'])
    return checkpoint['best_acc'], checkpoint['start_epoch']

def load_model_quant(net, name):
    """Restore ``net`` and its auxiliary kernels from ``checkpoint/<name>``.

    Args:
        net: module whose state_dict is overwritten with checkpoint['net'].
        name: checkpoint file name inside the ``checkpoint`` directory.

    Returns:
        (best_acc, start_epoch, G_kernels) as stored in the checkpoint.
        G_kernels is returned exactly as saved (no device remapping).

    Raises:
        IOError: if the ``checkpoint`` directory does not exist.  This is an
            explicit raise rather than an ``assert``, so it still fires under
            ``python -O``.
    """
    print('==> Resuming from checkpoint..')
    if not os.path.isdir('checkpoint'):
        raise IOError('Error: no checkpoint directory found!')
    checkpoint = torch.load(os.path.join('checkpoint', name))
    net.load_state_dict(checkpoint['net'])
    return checkpoint['best_acc'], checkpoint['start_epoch'], checkpoint['G_kernels']

# Float model

In [None]:
# Train the full-precision (float) VGG model.
use_cuda = True
net.cuda()
cudnn.benchmark = True
model, loss, optimizer = get_model2(net,
                                    learning_rate=LEARNING_RATE,
                                    weight_decay=WEIGHT_DECAY)

# [best test accuracy, first epoch, checkpoint file name] passed to train().
# NOTE(review): presumably train() saves its best model under this name --
# the same name load_model() reads by default.
params_float = [best_acc, start_epoch, thisname + 'float.t7']


def optimization_step_fn(model, loss, x_batch, y_batch):
    """Run one float training step with the optimizer returned by get_model2."""
    return optimization_step_float(
        model, loss, x_batch, y_batch,
        optimizer=optimizer
    )


# Multiply the learning rate by 0.1 at milestones 80 and 140.
# NOTE(review): these are epochs only if train() steps the scheduler once per epoch.
lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[80, 140], gamma=0.1)

# Use the shared n_epochs setting instead of a hard-coded 200, so changing it
# once affects both the float and binary runs.
all_losses = train(
    model, loss, optimization_step_fn,
    trainloader, testloader,
    params_float,
    threshold=0.001,
    n_epochs=n_epochs, steps_per_epoch=n_batches, n_validation_batches=n_validation_batches,
    lr_scheduler=lr_scheduler
)
# epoch  train_logloss test_logloss train_accuracy test_accuracy     time

# Binary-weight model

In [None]:
# training binary weight model
modelname = 'bw.t7'  # checkpoint file suffix for the binary-weight run
load_float_model = 1  # 1: initialise from the float checkpoint; 0: resume a saved binary checkpoint
eta = 1  # relaxation parameter
eta_rate = 1.04  # growth factor for eta

if load_float_model:
    # Only the float weights are reused.  The binary run keeps its own
    # best-accuracy and epoch bookkeeping, so the values stored with the
    # float model are discarded.
    load_model(net, name=thisname + 'float.t7')
    best_acc = 0
    start_epoch = 0
else:
    best_acc, start_epoch, all_G_kernels = load_model_quant(net, name=thisname + modelname)

# Common to both branches: move the (restored) net to the GPU and build the
# optimizer over its parameters.
use_cuda = True
net.cuda()
cudnn.benchmark = True
model, loss, optimizer = get_model2(net, learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

params_quant = [best_acc, start_epoch, thisname + modelname]

# W kernels: the live tensors in the optimizer's second parameter group.
# NOTE(review): assumes get_model2 puts the weights to be binarised in
# param_groups[1] -- confirm against get_model.py.
all_W_kernels = list(optimizer.param_groups[1]['params'])

if load_float_model:
    # G kernels: independent trainable copies of the W kernels, initialised
    # from the float weights just loaded.
    all_G_kernels = [
        Variable(kernel.data.clone(), requires_grad=True)
        for kernel in all_W_kernels
    ]

kernels = [
    {'params': all_G_kernels}
]

# NOTE(review): lr=0 -- presumably quant.optimization_step* updates the G
# kernels itself rather than relying on this SGD step; confirm in quant.py.
optimizer_quant = optim.SGD(kernels, lr=0)


def optimization_step_fn(model, loss, x_batch, y_batch):
    """Run one binary-weight training step using both optimizers."""
    return optimization_step(
        model, loss, x_batch, y_batch,
        optimizer_list=[optimizer, optimizer_quant]
    )


def optimization_step_fn_eta(model, loss, x_batch, y_batch, eta):
    """Run one relaxed training step with relaxation parameter ``eta``."""
    return optimization_step_eta(
        model, loss, x_batch, y_batch,
        optimizer_list=[optimizer, optimizer_quant],
        eta=eta
    )


# Multiply the learning rate by 0.1 at milestones 80 and 140 (same schedule as the float run).
lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[80, 140], gamma=0.1)

# eta starts at `eta` and is grown by `eta_rate`; m_epochs marks the start of
# phase II (relaxation off), per its definition above.
all_losses = train_eta(
    model, loss, optimization_step_fn,
    [all_W_kernels, all_G_kernels],
    trainloader, testloader,
    params_quant, optimization_step_fn_eta,
    n_epochs=n_epochs, steps_per_epoch=n_batches, n_validation_batches=n_validation_batches,
    lr_scheduler=lr_scheduler,
    eta=eta, eta_rate=eta_rate, m_epochs=m_epochs
)
# epoch  train_logloss test_logloss train_accuracy test_accuracy     time