In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.utils.data as data
import torch.nn as nn
# import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import utils.utils as util

import numpy as np
import os, time, sys
import argparse

import utils.pg_utils as q
import model as m

In [2]:
def load_data(data_dir='tiny-224/'):
    """Build ImageFolder train/val/test loaders from `data_dir`.

    Expects the ImageFolder layout `data_dir/{train,val,test}/<class>/<image>`.
    The normalization constants and the 200-class model built elsewhere suggest
    a Tiny-ImageNet style dataset -- NOTE(review): confirm against the data.
    Uses the notebook-level global `batch_size`.

    Args:
        data_dir: root folder containing the `train`, `val` and `test` splits.

    Returns:
        trainloader: shuffled DataLoader over the train split (with augmentation).
        testloader: DataLoader over the test split, in a fixed order.
        classes: tuple of class-folder names from the train split, index-aligned
            with the integer labels ImageFolder assigns.
        dataset_sizes: [len(train split), len(test split)].
    """
    splits = ('train', 'val', 'test')
    num_workers = {'train': 2, 'val': 0, 'test': 0}
    normalize = transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262])
    # Evaluation splits get no augmentation, only tensor conversion + normalization.
    eval_transform = transforms.Compose([transforms.ToTensor(), normalize])
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomRotation(20),
            transforms.RandomHorizontalFlip(0.5),
            transforms.ToTensor(),
            normalize,
        ]),
        'val': eval_transform,
        'test': eval_transform,
    }
    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                      for x in splits}
    # Only the training split needs shuffling; evaluation order stays deterministic.
    dataloaders = {x: data.DataLoader(image_datasets[x], batch_size=batch_size,
                                      shuffle=(x == 'train'), num_workers=num_workers[x],
                                      pin_memory=True)
                   for x in splits}

    # Derive class names from the dataset itself so they match the labels
    # (the previous hard-coded CIFAR-10 names did not).
    classes = tuple(image_datasets['train'].classes)
    dataset_sizes = [len(image_datasets['train']), len(image_datasets['test'])]

    return dataloaders['train'], dataloaders['test'], classes, dataset_sizes

In [3]:
from time import perf_counter
import matplotlib.pyplot as plt
from livelossplot import PlotLosses
import copy

def train_model(trainloader, dataset_sizes, testloader, net, device):
    """Train `net` with SGD and step LR decay, keeping the best-validation weights.

    Relies on notebook-level globals defined in other cells: batch_size,
    num_epoch, sigma, gtarget (threshold regularizer), save, save_folder,
    _ARCH, the `util` helpers, PlotLosses, and test_accu() for per-epoch
    evaluation.

    Args:
        trainloader: DataLoader yielding (inputs, labels) training batches.
        dataset_sizes: sequence whose element 0 is the training-set size.
        testloader: DataLoader evaluated by test_accu() after every epoch.
        net: model to train; wrapped in nn.DataParallel when >1 GPU is visible.
        device: torch.device the model and batches are moved to.

    Returns:
        (loss_list, accuracy_list, val_loss_list, val_acc_list): one float per
        epoch each; accuracies are fractions in [0, 1].
    """
    if torch.cuda.device_count() > 1:
        # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
        print("Activate multi GPU support.")
        net = nn.DataParallel(net)
    net.to(device)
    # Unwrapped module: its state_dict keys carry no "module." prefix, so the
    # saved weights load back into a plain (non-DataParallel) model.
    bare_net = net.module if isinstance(net, nn.DataParallel) else net

    criterion = nn.CrossEntropyLoss().to(device)
    # Scale the lr linearly with the batch size (0.1 when batch_size=128).
    initial_lr = 0.1 * batch_size / 128
    optimizer = optim.SGD(net.parameters(),
                          lr=initial_lr,
                          momentum=0.9,
                          weight_decay=1e-4)
    # Multiply the lr by 0.1 after 1/2 and after 3/4 of the epochs.
    quarter = num_epoch // 4
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=[quarter * 2, quarter * 3],
        gamma=0.1,
        last_epoch=-1)

    # Bookkeeping.
    since = perf_counter()
    liveloss = PlotLosses()
    best_acc = 0.0
    best = 0
    # Start from the initial weights so `save` works even if no epoch improves.
    best_model_wts = copy.deepcopy(bare_net.state_dict())

    loss_list, accuracy_list = [], []
    val_loss_list, val_acc_list = [], []

    for epoch in range(num_epoch):
        # Per-epoch progress meters.
        batch_time = util.AverageMeter('Time/batch', ':.3f')
        losses = util.AverageMeter('Loss', ':6.2f')
        top1 = util.AverageMeter('Acc', ':6.2f')
        progress = util.ProgressMeter(
            len(trainloader),
            [losses, top1, batch_time],
            prefix="Epoch: [{}]".format(epoch + 1))

        net.train()
        print('current learning rate = {}'.format(optimizer.param_groups[0]['lr']))

        running_loss = 0.0
        running_corrects = 0
        end = time.time()
        for i, batch in enumerate(trainloader):
            inputs, labels = batch[0].to(device), batch[1].to(device)

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            # Regularize every parameter whose name contains 'threshold'
            # towards gtarget, weighted by sigma.
            for name, param in net.named_parameters():
                if 'threshold' in name:
                    loss = loss + sigma * torch.norm(param - gtarget)
            loss.backward()
            optimizer.step()

            # Accuracy / loss bookkeeping as plain Python numbers.
            n = labels.size(0)
            _, batch_predicted = torch.max(outputs.detach(), 1)
            batch_correct = (batch_predicted == labels).sum().item()
            batch_loss = loss.item()
            losses.update(batch_loss, n)
            top1.update(100.0 * batch_correct / n, n)
            running_loss += batch_loss * n
            running_corrects += batch_correct

            batch_time.update(time.time() - end)
            end = time.time()

            if i % 100 == 99:
                # Print statistics every 100 mini-batches; i = batch id in the epoch.
                progress.display(i)

        epoch_loss = running_loss / dataset_sizes[0]
        epoch_acc = running_corrects / dataset_sizes[0]
        loss_list.append(epoch_loss)
        accuracy_list.append(epoch_acc)

        scheduler.step()

        print('epoch {}'.format(epoch + 1))
        val_loss, val_acc = test_accu(testloader, net, device)
        val_loss_list.append(val_loss)
        val_acc_list.append(val_acc)

        # Select the checkpoint on validation accuracy (as reported below),
        # not on training accuracy.
        if val_acc > best_acc:
            best_acc = val_acc
            best = epoch + 1
            best_model_wts = copy.deepcopy(bare_net.state_dict())

        liveloss.update({
            'log loss': epoch_loss,
            'val_log loss': val_loss,
            'accuracy': epoch_acc,
            'val_accuracy': val_acc,
        })
        liveloss.draw()

    if save:
        print("Saving the best trained model.")
        util.save_models(best_model_wts, save_folder, suffix=_ARCH)

    time_elapsed = perf_counter() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best Validation Accuracy: {}, Epoch: {}'.format(best_acc, best))

    return (loss_list, accuracy_list, val_loss_list, val_acc_list)

In [4]:
def test_accu(testloader, net, device):
    """Evaluate `net` on `testloader` and print its accuracy.

    When the global `pg` is true, this also does the following after each batch:
    - reads `num_out` / `num_high` from every `q.PGConv2d` layer;
    - accumulates them per layer;
    - prints the update-phase sparsity at the end.

    The returned loss includes the threshold regularizer used in training
    (globals `sigma`, `gtarget`).

    Args:
        testloader: iterable of (images, labels) batches.
        net: model to evaluate; moved to `device` and switched to eval mode.
        device: torch.device the model and batches are moved to.

    Returns:
        (val_loss, val_acc): mean per-sample loss and accuracy as a fraction.

    Raises:
        ValueError: if `testloader` yields no samples.
    """
    net.to(device)
    criterion = nn.CrossEntropyLoss().to(device)

    def _pg_layer_stats():
        """Return per-layer (num_out, num_high) arrays for all PGConv2d layers."""
        num_out, num_high = [], []

        def _report_sparsity(m):
            if isinstance(m, q.PGConv2d):
                num_out.append(m.num_out)
                num_high.append(m.num_high)

        net.apply(_report_sparsity)
        return np.array(num_out), np.array(num_high)

    correct = 0
    total = 0
    running_loss = 0.0
    # Per-layer running totals. They are sized from the first batch instead of
    # assuming a fixed number of PG layers.
    cnt_out = None
    cnt_high = None

    net.eval()
    with torch.no_grad():
        # The threshold penalty depends only on parameters, which do not change
        # during evaluation, so compute it once rather than per batch.
        penalty = 0.0
        for name, param in net.named_parameters():
            if 'threshold' in name:
                penalty += sigma * torch.norm(param - gtarget).item()

        for batch in testloader:
            images, labels = batch[0].to(device), batch[1].to(device)
            outputs = net(images)

            loss = criterion(outputs, labels).item() + penalty
            running_loss += loss * images.size(0)

            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Calculate statistics per PG layer.
            if pg:
                batch_out, batch_high = _pg_layer_stats()
                cnt_out = batch_out if cnt_out is None else cnt_out + batch_out
                cnt_high = batch_high if cnt_high is None else cnt_high + batch_high

    if total == 0:
        raise ValueError("testloader yielded no samples")

    print('Accuracy of the network on the %d test images: %.1f %%' % (
        total, 100 * correct / total))
    if pg and cnt_out is not None and np.sum(cnt_out) > 0:
        print('Sparsity of the update phase: %.1f %%' % (
            100 - np.sum(cnt_high) * 1.0 / np.sum(cnt_out) * 100))

    val_loss = running_loss / total
    val_acc = correct / total
    return val_loss, val_acc

In [5]:
def per_class_test_accu(testloader, classes, net, device):
    """Print the accuracy of `net` on `testloader` separately for each class.

    Args:
        testloader: iterable of (images, labels) batches; labels are integer
            indices into `classes`.
        classes: sequence of class names; its length is the number of classes.
        net: model mapping a batch of images to per-class scores.
        device: torch.device the batches are moved to.
    """
    num_classes = len(classes)
    class_correct = [0] * num_classes
    class_total = [0] * num_classes
    net.eval()
    with torch.no_grad():
        for batch in testloader:
            images, labels = batch[0].to(device), batch[1].to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs, 1)
            # Walk every sample in the batch (not a fixed 4), and avoid
            # squeeze(), which collapses a size-1 batch to a 0-d tensor.
            hits = (predicted == labels).tolist()
            for label, hit in zip(labels.tolist(), hits):
                class_total[label] += 1
                class_correct[label] += int(hit)

    for name, n_correct, n_total in zip(classes, class_correct, class_total):
        if n_total == 0:
            print('Accuracy of %5s : n/a (no samples)' % name)
        else:
            print('Accuracy of %5s : %.1f %%' % (name, 100 * n_correct / n_total))

In [6]:
# Training hyper-parameters.
num_epoch = 200
batch_size = 32
# NOTE(review): train_model passes its own literal weight_decay; confirm these
# two constants are read somewhere.
_WEIGHT_DECAY = 1e-4
_LAST_EPOCH = -1  # last_epoch arg is useful for restart

# Architecture tag; passed to generate_model and as `suffix` to util.save_models.
_ARCH = "resnet-20"

# Checkpoints are written below the current working directory.
# NOTE(review): folder name says CIFAR10 while the data loader reads 'tiny-224/'.
this_file_path = os.path.abspath(os.curdir)
save_folder = os.path.join(this_file_path, 'save_CIFAR10_model')
print('Save at', save_folder)

Save at /home/aperture/Git/dnn-gating/save_CIFAR10_model


In [7]:
def generate_model(model_arch, num_classes):
    """Build the quantized ResNet-20 selected by the notebook's config globals.

    Reads globals wbits, abits and pact. When `pg` is true it also reads pbits
    and sparse_bp, and builds the precision-gated variant.

    Args:
        model_arch: architecture tag; only 'resnet-20' is supported.
        num_classes: number of output classes passed to resnet20().

    Returns:
        The model instance returned by the selected module's resnet20().

    Raises:
        NotImplementedError: for any model_arch other than 'resnet-20'.
    """
    if model_arch != 'resnet-20':
        raise NotImplementedError("Model architecture is not supported.")

    if not pg:
        import model.quantized_cifar10_resnet as m
        return m.resnet20(num_classes, wbits=wbits, abits=abits, pact=pact)

    import model.pg_cifar10_resnet_s as m
    return m.resnet20(num_classes,
                      wbits=wbits,
                      abits=abits,
                      pred_bits=pbits,
                      sparse_bp=sparse_bp,
                      pact=pact)

In [8]:
# Run mode.
path = None   # checkpoint to load before running; None keeps the fresh weights
save = True   # save the best weights via util.save_models after training
test = False  # True: evaluate only, skip training

# Bit widths passed through to resnet20() as wbits / abits / pred_bits.
wbits = 8
abits = 3
pbits = 2
pact = True

# Precision gating: build the PG model variant and report its sparsity.
pg = True
sparse_bp = True

# Threshold regularizer added to the loss: sigma * ||threshold - gtarget||.
sigma = 0.001
gtarget = 1

In [9]:
# Use the first CUDA device when one is available, otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Available GPUs: %d" % torch.cuda.device_count())

print("Create %s model." % _ARCH)
# NOTE(review): num_classes is hard-coded; it should equal len(classes) from load_data.
net = generate_model(_ARCH, num_classes=200)
print(net)


Available GPUs: 1
Create resnet-20 model.
ResNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): PGConv2d(
        16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
        (quantize_w): TorchQuantize(
          (quantize): TorchRoundToBits()
        )
        (quantize_a): TorchQuantize(
          (quantize): TorchRoundToBits()
        )
        (trunc_a): TorchTruncate()
      )
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): QuantizedConv2d(
        32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
        (quantize_w): TorchQuantize(
          (quantize): TorchRoundToBits()
        )
        (quantize_a): TorchQuantize(
          (quantize): TorchRoundToBits()
        )
      )
      (bn2): BatchN

In [None]:
if path:
    print("@ Load trained model from {}.".format(path))
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    net.load_state_dict(torch.load(path, map_location=device))

print("Loading the data.")
trainloader, testloader, classes, dataset_sizes = load_data()
print("Dataset sizes:", dataset_sizes)
if test:
    print("Mode: Test only.")
    test_accu(testloader, net, device)
else:
    print("Start training.")
    train_model(trainloader, dataset_sizes, testloader, net, device)
    # `net` now holds the last-epoch weights; the best ones are only saved to disk.
    test_accu(testloader, net, device)
#     per_class_test_accu(testloader, classes, net, device)

Loading the data.
Dataset sizes: [100000, 5000]
Start training.
current learning rate = 0.025
