In [8]:
import argparse
import os
import time
import shutil

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

from tensorboardX import SummaryWriter      

import torchvision
import torchvision.transforms as transforms

from models import *

import os
# Restrict the process to GPU 0, numbering devices by PCI bus id.
# These variables are set before the first CUDA query below.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Best validation accuracy seen so far. A module-level `global` statement
# does not create the name, so it is initialised explicitly here.
best_prec = 0
use_gpu = torch.cuda.is_available()
print('=> Building model...')

batch_size = 256
# model_name also selects the checkpoint directory: result/<model_name>/
model_name = "Resnet20_quant_unstructure_pruning"
model = resnet20_quant()
print(model)

# Per-channel mean / std applied to every image after conversion to a tensor.
normalize = transforms.Normalize(mean=[0.491, 0.482, 0.447], std=[0.247, 0.243, 0.262])

# Training images get random crop + horizontal flip augmentation.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])
# Test images are only converted and normalised.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

train_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=train_transform)
test_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=test_transform)

trainloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

# Progress is printed every `print_freq` batches (each batch holds `batch_size` samples).
# CIFAR10 has 50,000 training images and 10,000 validation images.
print_freq = 100
# CIFAR10 has 50,000 training data, and 10,000 validation data.

def train(trainloader, model, criterion, optimizer, epoch):
    """Run one training epoch and print running statistics.

    Args:
        trainloader: iterable yielding ``(inputs, target)`` batches.
        model: network to optimise; it is put into training mode here.
        criterion: loss callable, invoked as ``criterion(output, target)``.
        optimizer: optimizer stepped once per batch.
        epoch: epoch index, used only in the progress message.

    Every ``print_freq`` batches the batch/data timing, loss and top-1
    precision (current value and running average) are printed.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    model.train()

    # Send batches to the device the model lives on rather than assuming CUDA.
    # A parameter-less model keeps the previous behaviour (default CUDA device).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    end = time.time()
    for i, (inputs, target) in enumerate(trainloader):
        # measure data loading time
        data_time.update(time.time() - end)

        inputs, target = inputs.to(device), target.to(device)

        # compute output
        output = model(inputs)
        loss = criterion(output, target)

        # measure accuracy and record loss (weighted by the batch size)
        prec = accuracy(output, target)[0]
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec.item(), inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time for the whole iteration (loading + compute)
        batch_time.update(time.time() - end)
        end = time.time()

        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec {top1.val:.3f}% ({top1.avg:.3f}%)'.format(
                   epoch, i, len(trainloader), batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1))

            

def validate(val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader`` and return the average top-1 precision.

    Args:
        val_loader: iterable yielding ``(inputs, target)`` batches.
        model: network to evaluate; it is put into eval mode here.
        criterion: loss callable, invoked as ``criterion(output, target)``.

    Returns:
        The sample-weighted average top-1 precision (in percent) over all batches.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    # Send batches to the device the model lives on rather than assuming CUDA.
    # A parameter-less model keeps the previous behaviour (default CUDA device).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    end = time.time()
    with torch.no_grad():
        for i, (inputs, target) in enumerate(val_loader):

            inputs, target = inputs.to(device), target.to(device)

            # compute output
            output = model(inputs)
            loss = criterion(output, target)

            # measure accuracy and record loss (weighted by the batch size)
            prec = accuracy(output, target)[0]
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec.item(), inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % print_freq == 0:  # print the status every `print_freq` batches
                print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec {top1.val:.3f}% ({top1.avg:.3f}%)'.format(
                   i, len(val_loader), batch_time=batch_time, loss=losses,
                   top1=top1))

    print(' * Prec {top1.avg:.3f}% '.format(top1=top1))
    return top1.avg


def accuracy(output, target, topk=(1,)):
    """Compute the precision@k (in percent) for each requested k.

    Args:
        output: score tensor; the top-k is taken along dimension 1.
        target: tensor of true class indices; its first dimension is the batch size.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one 0-dim tensor per entry of ``topk``, each holding the
        percentage of samples whose target is among the k highest scores.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Pure bookkeeping: no gradient needs to flow through the metric.
    with torch.no_grad():
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()  # rows are now ranks, columns are samples
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # `correct` derives from a transposed tensor and may be
            # non-contiguous, so `view(-1)` can fail for k > 1; `reshape`
            # copies only when it has to.
            correct_k = correct[:k].reshape(-1).float().sum(0)
            res.append(correct_k.mul_(100.0 / batch_size))
    return res


class AverageMeter(object):
    """Tracks the latest value and the weighted running mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the latest value, the weighted sum, the weight total and the mean."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, counted with weight ``n``."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

        
def save_checkpoint(state, is_best, fdir):
    """Write ``state`` to ``<fdir>/checkpoint.pth``.

    When ``is_best`` is true, the file just written is also copied to
    ``<fdir>/model_best.pth.tar``.
    """
    latest_path = os.path.join(fdir, 'checkpoint.pth')
    torch.save(state, latest_path)
    if not is_best:
        return
    best_path = os.path.join(fdir, 'model_best.pth.tar')
    shutil.copyfile(latest_path, best_path)


def adjust_learning_rate(optimizer, epoch):
    """Step decay: multiply every param group's lr by 0.1 at epochs 80 and 120.

    For any other epoch the optimizer is left untouched.
    """
    milestones = (80, 120)
    if epoch not in milestones:
        return
    for group in optimizer.param_groups:
        group['lr'] *= 0.1

#model = nn.DataParallel(model).cuda()
#all_params = checkpoint['state_dict']
#model.load_state_dict(all_params, strict=False)
#criterion = nn.CrossEntropyLoss().cuda()
#validate(testloader, model, criterion)

=> Building model...
ResNet_Cifar(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): QuantConv2d(
        16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
        (weight_quant): weight_quantize_fn()
      )
      (conv2): QuantConv2d(
        16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
        (weight_quant): weight_quantize_fn()
      )
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): QuantConv2d(
        16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
        (weight_quant): weight_quantize_fn()

In [12]:
from models import QuantConv2d
import torch.nn.utils.prune as prune

PATH = "result/{}/model_best.pth.tar"
checkpoint = torch.load(PATH.format(model_name))

# Apply the pruning structure to the model before loading the checkpoint.
# NOTE(review): presumably the checkpoint was saved from a pruned model, so the
# modules need the pruning reparametrization for the state-dict keys to match;
# the mask chosen here is then replaced by the loaded one — confirm.
for name, module in model.named_modules():
    if isinstance(module, QuantConv2d):
        print("Applying unstructured pruning to {}".format(name))
        prune.l1_unstructured(module, name='weight', amount=0.8)

# Load the weights once, move the model to the GPU and switch to eval mode.
model.load_state_dict(checkpoint['state_dict'])
device = torch.device("cuda")

model.cuda()
model.eval()

# Count correct top-1 predictions over the whole test set.
correct = 0

with torch.no_grad():
    for data, target in testloader:
        data, target = data.to(device), target.to(device)  # loading to GPU
        output = model(data)
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()

print('\nTest set: Accuracy: {}/{} ({:.0f}%)\n'.format(
        correct, len(testloader.dataset),
        100. * correct / len(testloader.dataset)))

Applying unstructured pruning to layer1.0.conv1
Applying unstructured pruning to layer1.0.conv2
Applying unstructured pruning to layer1.1.conv1
Applying unstructured pruning to layer1.1.conv2
Applying unstructured pruning to layer1.2.conv1
Applying unstructured pruning to layer1.2.conv2
Applying unstructured pruning to layer2.0.conv1
Applying unstructured pruning to layer2.0.conv2
Applying unstructured pruning to layer2.0.downsample.0
Applying unstructured pruning to layer2.1.conv1
Applying unstructured pruning to layer2.1.conv2
Applying unstructured pruning to layer2.2.conv1
Applying unstructured pruning to layer2.2.conv2
Applying unstructured pruning to layer3.0.conv1
Applying unstructured pruning to layer3.0.conv2
Applying unstructured pruning to layer3.0.downsample.0
Applying unstructured pruning to layer3.1.conv1
Applying unstructured pruning to layer3.1.conv2
Applying unstructured pruning to layer3.2.conv1
Applying unstructured pruning to layer3.2.conv2

Test set: Accuracy: 8994/

In [None]:
#### Prune 90% of the weights in every QuantConv2d layer in 1) an unstructured and 2) a structured manner. (Note: the cells below apply 80% unstructured pruning.)



In [3]:
# L1 unstructured pruning: 80% of the weights in each QuantConv2d layer.
# NOTE(review): if these modules were already pruned (e.g. by the
# checkpoint-loading cell), calling l1_unstructured again presumably stacks a
# further mask on top of the existing one — confirm the intended cell order.

import torch.nn.utils.prune as prune

quant_convs = [(layer_name, layer)
               for layer_name, layer in model.named_modules()
               if isinstance(layer, QuantConv2d)]
for layer_name, layer in quant_convs:
    print("Pruning 80% weights in {}".format(layer_name))
    prune.l1_unstructured(layer, name='weight', amount=0.8)

Pruning 80% weights in layer1.0.conv1
Pruning 80% weights in layer1.0.conv2
Pruning 80% weights in layer1.1.conv1
Pruning 80% weights in layer1.1.conv2
Pruning 80% weights in layer1.2.conv1
Pruning 80% weights in layer1.2.conv2
Pruning 80% weights in layer2.0.conv1
Pruning 80% weights in layer2.0.conv2
Pruning 80% weights in layer2.0.downsample.0
Pruning 80% weights in layer2.1.conv1
Pruning 80% weights in layer2.1.conv2
Pruning 80% weights in layer2.2.conv1
Pruning 80% weights in layer2.2.conv2
Pruning 80% weights in layer3.0.conv1
Pruning 80% weights in layer3.0.conv2
Pruning 80% weights in layer3.0.downsample.0
Pruning 80% weights in layer3.1.conv1
Pruning 80% weights in layer3.1.conv2
Pruning 80% weights in layer3.2.conv1
Pruning 80% weights in layer3.2.conv2


In [None]:
# Inspect one QuantConv2d layer of the model (here: layer3[2].conv2).
inspected_conv = model.layer3[2].conv2
# Parameter list: check whether the pruning-related entries show up.
print(list(inspected_conv.named_parameters()))
# Effective weight: check whether there are many zeros.
print(inspected_conv.weight)

In [4]:
### Check sparsity ###
# Fraction of zero entries in the pruning mask of layer3[2].conv2.
mask1 = model.layer3[2].conv2.weight_mask
masked_out = (mask1 == 0).sum()
sparsity_mask1 = masked_out / mask1.nelement()

print("Sparsity level: ", sparsity_mask1)

Sparsity level:  tensor(0.8000, device='cuda:0')


In [17]:
## check accuracy after pruning

model.eval()

# Count correct top-1 predictions over the whole test set.
# `device` is the one defined in the checkpoint-loading cell.
correct = 0

with torch.no_grad():
    for data, target in testloader:
        data, target = data.to(device), target.to(device)  # loading to GPU
        output = model(data)
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()

print('\nTest set: Accuracy: {}/{} ({:.0f}%)\n'.format(
        correct, len(testloader.dataset),
        100. * correct / len(testloader.dataset)))


Test set: Accuracy: 1000/10000 (10%)



In [5]:
## Fine-tune the pruned model and see how much accuracy can be recovered.
## Hyper-parameters such as `epochs` or `lr` can be changed here.
lr = 0.1
weight_decay = 1e-4
epochs = 160
best_prec = 0

model.cuda()
criterion = nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)

# Checkpoints go to result/<model_name>/ (created, together with result/, if missing).
# NOTE(review): this is the same directory the checkpoint-loading cell reads
# model_best.pth.tar from, so an improving epoch replaces that file — confirm intended.
fdir = 'result/{}'.format(model_name)
os.makedirs(fdir, exist_ok=True)

for epoch in range(epochs):
    adjust_learning_rate(optimizer, epoch)

    train(trainloader, model, criterion, optimizer, epoch)

    # evaluate on test set
    print("Validation starts")
    prec = validate(testloader, model, criterion)

    # remember best precision and save checkpoint
    is_best = prec > best_prec
    best_prec = max(prec, best_prec)
    print('best acc: {:1f}'.format(best_prec))

    snapshot = {
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'best_prec': best_prec,
        'optimizer': optimizer.state_dict(),
    }
    save_checkpoint(snapshot, is_best, fdir)

Epoch: [0][0/196]	Time 4.744 (4.744)	Data 4.622 (4.622)	Loss 1.4065 (1.4065)	Prec 60.938% (60.938%)
Epoch: [0][100/196]	Time 0.038 (0.069)	Data 0.001 (0.046)	Loss 0.3271 (0.5255)	Prec 88.672% (82.085%)
Validation starts
Test: [0/40]	Time 3.816 (3.816)	Loss 0.4900 (0.4900)	Prec 83.594% (83.594%)
 * Prec 84.270% 
best acc: 84.270000
Epoch: [1][0/196]	Time 4.460 (4.460)	Data 4.436 (4.436)	Loss 0.3332 (0.3332)	Prec 88.672% (88.672%)
Epoch: [1][100/196]	Time 0.036 (0.066)	Data 0.001 (0.044)	Loss 0.3160 (0.3268)	Prec 88.281% (88.614%)
Validation starts
Test: [0/40]	Time 3.878 (3.878)	Loss 0.5823 (0.5823)	Prec 81.641% (81.641%)
 * Prec 84.020% 
best acc: 84.270000
Epoch: [2][0/196]	Time 4.520 (4.520)	Data 4.502 (4.502)	Loss 0.2919 (0.2919)	Prec 89.062% (89.062%)
Epoch: [2][100/196]	Time 0.038 (0.076)	Data 0.001 (0.045)	Loss 0.3009 (0.2911)	Prec 89.453% (89.720%)
Validation starts
Test: [0/40]	Time 3.843 (3.843)	Loss 0.6630 (0.6630)	Prec 82.031% (82.031%)
 * Prec 82.010% 
best acc: 84.270000
E

In [14]:
## check your accuracy again after finetuning

model.eval()

# Count correct top-1 predictions over the whole test set.
# `device` is the one defined in the checkpoint-loading cell.
correct = 0

with torch.no_grad():
    for data, target in testloader:
        data, target = data.to(device), target.to(device)  # loading to GPU
        output = model(data)
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()

print('\nTest set: Accuracy: {}/{} ({:.0f}%)\n'.format(
        correct, len(testloader.dataset),
        100. * correct / len(testloader.dataset)))


Test set: Accuracy: 8348/10000 (83%)



In [13]:
from models import QuantConv2d

#### Check whether the global sparsity of weight_int is near 80% ####
# For every QuantConv2d layer, rescale the quantized weight to integer
# levels and count exact zeros; the ratio is taken over all layers together.

w_bit = 4
max_level = 2 ** (w_bit - 1) - 1  # scaling divisor applied to wgt_alpha
total_zeros = 0
total_elements = 0

quant_layers = (m for m in model.modules() if isinstance(m, QuantConv2d))
for layer in quant_layers:
    # NOTE(review): weight_q is read as an attribute of the layer; presumably
    # it is populated by a forward pass — confirm it is current.
    quantized = layer.weight_q
    step = layer.weight_quant.wgt_alpha / max_level
    weight_int = quantized / step

    total_zeros += (weight_int == 0).sum().item()
    total_elements += weight_int.nelement()

global_sparsity = total_zeros / total_elements
print(f"Global sparsity for weight_int: {global_sparsity:.4f} ({global_sparsity*100:.2f}%)")

Global sparsity for weight_int: 0.6982 (69.82%)


Sparsity level:  tensor(0.9204, device='cuda:0')
