Skip to content

Example for Compressing a Classifier

myd edited this page Jul 1, 2019 · 6 revisions

This is an example of how to compress a classifier (for CIFAR-10). You will see how to use mathematical compressors to compress a neural network and then fine-tune the compressed network with a classifier trainer designed for the CIFAR-10 project.

1. Preparations

Prepare the train/val data. With torchvision, loading the CIFAR-10 train/val data is straightforward:

# Per-channel normalization statistics shared by the train and test pipelines.
cifar_mean = (0.4914, 0.4822, 0.4465)
cifar_std = (0.2023, 0.1994, 0.2010)

# Training pipeline: small rotations, padded random crops and random
# horizontal/vertical flips, then tensor conversion and normalization.
transform_train = transforms.Compose([
    transforms.RandomRotation(degrees=5),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(cifar_mean, cifar_std),
])
# Evaluation pipeline: no augmentation, only tensor conversion + normalization.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar_mean, cifar_std),
])

# Batch size and worker count are identical for both loaders.
loader_kwargs = dict(batch_size=args.batch_size, num_workers=args.num_workers)

trainset = torchvision.datasets.CIFAR10(
    root=args.data_path, train=True, download=False, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_kwargs)

valset = torchvision.datasets.CIFAR10(
    root=args.data_path, train=False, download=False, transform=transform_test)
valloader = torch.utils.data.DataLoader(valset, shuffle=False, **loader_kwargs)

Define a neural network, and load pre-trained parameters.

# Define the network to be compressed.
net = ResNet18()

# Load the pre-trained parameters. map_location='cpu' allows a checkpoint saved
# from GPU tensors to be loaded on a machine without CUDA; the tracing pass
# below feeds the network a CPU tensor anyway.
pretrained_path = './cifar10/tmp/checkpoints/run8_resnet18_epoch_150_batch_128_lr_0.01_from_run7_epoch_149/epoch_169_loss_0.02847591015841345_accuracy_0.9209.pth'
net.load_state_dict(torch.load(pretrained_path, map_location='cpu'))

# Use graph.reconstructor to convert `net` (a plain torch.nn.Module) into
# `origin_net` (a graph.modules.ReconstructedNetwork): capture boundaries are
# placed around a single forward pass on a dummy 1x3x32x32 input, and the
# network is then (presumably) rebuilt from what was captured.
reconstructor.insertCaptureBoundaryStart(net)
oup = net(torch.rand(1, 3, 32, 32))
reconstructor.insertCaptureBoundaryEnd()
origin_net, graph = reconstructor.getReconstructedNetwork(
    ifDraw=True,
    drawPath=os.path.join(args.checkpoints_folder, "origin_net"),
)

# Show the computation amount (fused multiply-adds) of origin_net, in G.
fmlas_origin = showFMLAs(torch.rand(1, 3, 32, 32), origin_net)
print("""fmlas_origin: {} G""".format(fmlas_origin / 1e9))

# Show the test precision of origin_net on the validation loader.
test_testset(net=origin_net, testloader=valloader, device=device)

2. Mathematical Compression

First, create a `ChannelPrunning` object if the method is Channel Pruning, or a `LowRankDecompostion` object if using the Low-Rank Decomposition method (class names are spelled exactly as in the code):

# Build the compressor for the requested algorithm. Both compressors read
# their sampling/ratio/threshold settings from `args`, and both write their
# checkpoints into a per-iteration "compress-<n>" folder.
if compress_algorithm == 'CP':
    # Channel pruning.
    compressor = ChannelPrunning(
        origin_net              = origin_net,
        trainloader             = trainloader,
        valloader               = valloader,
        trainset_ratio          = args.compress_trainset_ratio,
        sampled_pixels_per_img  = args.compress_sampled_pixels_per_img,
        compress_ratio          = args.compress_ratios,
        checkpointfolder        = os.path.join(args.checkpoints_folder, """compress-{}""".format(compress_cnt)),
        device                  = device,
        drawgraph               = args.compress_draw_graph,
        verbose                 = args.verbose,
        lars_alpha_init         = args.lars_alpha_init,
        # A positive accuracy threshold switches the compressor to accuracy-first mode.
        accuracy_first          = True if args.compress_acc_thresh > 0 else False,
        accuracy_threshold      = args.compress_acc_thresh,
        args                    = args
        )
elif compress_algorithm == 'LRD':
    # Low-rank decomposition.
    compressor = LowRankDecompostion(
        origin_net              = origin_net,
        trainloader             = trainloader,
        valloader               = valloader,
        trainset_ratio          = args.compress_trainset_ratio,
        sampled_pixels_per_img  = args.compress_sampled_pixels_per_img,
        compress_ratio          = args.compress_ratios,
        checkpointfolder        = os.path.join(args.checkpoints_folder, """compress-{}""".format(compress_cnt)),
        device                  = device,
        verbose                 = args.verbose,
        drawgraph               = args.compress_draw_graph,
        # A positive accuracy threshold switches the compressor to accuracy-first mode.
        accuracy_first          = True if args.compress_acc_thresh > 0 else False,
        accuracy_threshold      = args.compress_acc_thresh,
        nonlinear_case          = args.compress_nonlinear_case,
        args                    = args
        )
else:
    # Fail fast. Merely constructing the exception (without `raise`) let
    # execution continue with `compressor` undefined, or, inside a
    # compress-finetune loop, silently reuse the previous iteration's compressor.
    raise RuntimeError("""Unknown compress algorithm: {}""".format(compress_algorithm))
# endif

Then, run the compression:

# Run the selected compressor. Afterwards the script reads the result from
# `compressor.compressed_net`.
compressor.compress()

After compression completes, show the computation amount (fused multiply-add operations) and the test precision of the compressed network:

# Show the computation amount (fused multiply-adds) of the compressed network,
# in G, reported the same way as fmlas_origin so the two can be compared.
fmlas_compressed = showFMLAs(torch.rand((1, 3, 32, 32)).to(device), compressor.compressed_net)
print("""fmlas_compressed: {} G""".format(fmlas_compressed / 1e9))

# Show the test precision of the compressed network on the validation loader.
test_testset(net=compressor.compressed_net, testloader=valloader, device=device)

3. Finetune

Define a Trainer for CIFAR10:

import os
import warnings

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from trainer.basetrainer import *
from trainer.basedistiller import *
from utils.compressmethod import showFMLAs

# <class CIFAR10Trainer>
class CIFAR10Trainer(BaseTrainer):
    """BaseTrainer specialization for fine-tuning a CIFAR-10 classifier.

    Implements the BaseTrainer hooks that compute predictions and losses for
    train/val batches, a per-epoch evaluation that logs precision and FMLAs
    to TensorBoard and checkpoints the best network, and a learning-rate
    based stop criterion.
    """

    # <method: __init__>
    def __init__(self,
        train_loader,  val_loader,  eval_loader,
        network,
        criterion, optimizer, scheduler,
        epochs,
        device, tbx_writer, checkpoints_folder, additional_args):
        """Forward every argument unchanged to BaseTrainer and reset the best precision."""
        super(CIFAR10Trainer, self).__init__(
            train_loader = train_loader,
            val_loader = val_loader,
            eval_loader = eval_loader,
            network = network,
            criterion = criterion,
            optimizer = optimizer,
            scheduler = scheduler,
            epochs = epochs,
            device = device,
            tbx_writer = tbx_writer,
            checkpoints_folder = checkpoints_folder,
            additional_args = additional_args
            )
        # Best evaluation precision seen so far; __eval__ only writes a
        # checkpoint when a new evaluation exceeds it.
        self._best_precision = 0
    # <method: __init__>

    # <method: _forward_batch>
    def _forward_batch(self, batch):
        """Run the network on element 0 of `batch`, moved to the device if one is set."""
        inputs = batch[0]
        if self._device:
            inputs = inputs.to(self._device)
        return self._network(inputs)
    # <method: _forward_batch>

    # <method: _batch_loss>
    def _batch_loss(self, batch_preds, batch):
        """Return {'total_loss': criterion(batch_preds, labels)} with labels = element 1 of `batch`."""
        labels = batch[1]
        if self._device:
            labels = labels.to(self._device)
        return {'total_loss': self._criterion(batch_preds, labels)}
    # <method: _batch_loss>

    # <method: __get_val_batch_preds__>
    def __get_val_batch_preds__(self, batch_val_data, *args, **kwargs):
        """Required BaseTrainer hook: return the network's predictions for one validation batch."""
        return self._forward_batch(batch_val_data)
    # <method: __get_val_batch_preds__>

    # <method: __get_val_losses__>
    def __get_val_losses__(self, batch_preds, batch_val_data, *args, **kwargs):
        """Required BaseTrainer hook: return {'total_loss': ...} for one validation batch."""
        return self._batch_loss(batch_preds, batch_val_data)
    # <method: __get_val_losses__>

    # <method: __get_train_batch_preds__>
    def __get_train_batch_preds__(self, batch_train_data, *args, **kwargs):
        """Required BaseTrainer hook: return the network's predictions for one training batch."""
        return self._forward_batch(batch_train_data)
    # <method: __get_train_batch_preds__>

    # <method: __get_train_losses__>
    def __get_train_losses__(self, batch_preds, batch_train_data, *args, **kwargs):
        """Required BaseTrainer hook: return {'total_loss': ...} for one training batch."""
        return self._batch_loss(batch_preds, batch_train_data)
    # <method: __get_train_losses__>

    # <method: __eval__>
    def __eval__(self, epoch, *args, **kwargs):
        """Evaluate the network, log metrics, and checkpoint on a new best precision.

        Args:
            epoch: current epoch number; TensorBoard scalars are logged at
                step ``epoch - 1``.
            **kwargs: must contain ``eval``, a callable invoked as
                ``eval(net=..., testloader=..., device=...)`` whose return
                value is treated as the precision.

        Evaluation is best-effort: any ``Exception`` is reported as a
        ``RuntimeWarning`` instead of aborting training.
        """
        try:
            # FMLAs of the current network on a dummy 1x3x32x32 input, in millions.
            fmlas = showFMLAs(torch.rand(1, 3, 32, 32).to(self._device), self._network, self._device) / 1e6
            precision = kwargs['eval'](net=self._network, testloader=self._eval_loader, device=self._device)
            print("""Eval_{}_Precision: {}""".format(epoch, precision))
            self._tbx_writer.add_scalar('Eval_Precision', precision, epoch - 1)
            self._tbx_writer.add_scalar('Eval_fmlas / M', fmlas, epoch - 1)
            if precision > self._best_precision:
                self._best_precision = precision
                # Copy every tensor to CPU so the saved checkpoint is device-independent.
                net_sd = self._network.state_dict()
                for key in list(net_sd.keys()):
                    net_sd[key] = net_sd[key].to('cpu')
                state_dict = {
                    'device': 'cpu',
                    'net': net_sd,
                    'precision': self._best_precision,
                    'fmlas' : fmlas,
                    # NOTE(review): this pickles the optimizer object itself,
                    # not optimizer.state_dict(); kept for checkpoint compatibility.
                    'optimizer':self._optimizer
                }
                torch.save(state_dict, os.path.join(self._checkpoints_folder, "best_eval_checkpoint.pth.tar"))
        except Exception as err:
            # Actually emit the warning (merely constructing RuntimeWarning(...)
            # does nothing). Unlike a bare `except:`, this lets
            # KeyboardInterrupt/SystemExit propagate.
            warnings.warn("""error in eval at epoch {}: {!r}""".format(epoch, err), RuntimeWarning)
    # <method: __eval__>

    # <method: __if_stop_trainning__>
    def __if_stop_trainning__(self, *args, **kwargs):
        """Return True once the learning rate is strictly below ``kwargs['stop_lr']``.

        The learning rate is read from the optimizer's last param group.
        """
        lr = self._optimizer.param_groups[-1]["lr"]
        return lr < kwargs['stop_lr']
    # <method: __if_stop_trainning__>

# <class CIFAR10Trainer>

Set the loss criterion:

# Cross-entropy loss; the trainer applies it to (predictions, labels) for both
# training and validation batches.
criterion = nn.CrossEntropyLoss()

Set Optimizer and Learning-Rate Scheduler:

# SGD with momentum and weight decay over the compressed network's parameters.
optimizer = optim.SGD(compressor.compressed_net.parameters(), lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)

# Reduce the LR 10x after `args.patience` scheduler steps without improvement.
# Do NOT pass min_lr=args.stop_lr: CIFAR10Trainer stops only when lr < stop_lr
# (strictly), and ReduceLROnPlateau never lowers the lr below min_lr, so that
# combination would make the stop criterion unreachable.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, threshold=1e-5, patience=args.patience)

Create a tensorboardX writer:

# TensorBoard writer logging into the fine-tuning checkpoint folder; the
# trainer's __eval__ writes 'Eval_Precision' and 'Eval_fmlas / M' to it.
tbx_writer = SummaryWriter(log_dir=finetune_checkpointfolder)

Create a CIFAR10Trainer object:

# Fine-tune the compressed network. The validation loader doubles as the
# evaluation loader, and checkpoints/logs go to the fine-tuning folder.
trainer_kwargs = dict(
    train_loader=trainloader,
    val_loader=valloader,
    eval_loader=valloader,
    network=compressor.compressed_net,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    epochs=args.epochs,
    device=device,
    tbx_writer=tbx_writer,
    checkpoints_folder=finetune_checkpointfolder,
    additional_args=args,
)
trainer = CIFAR10Trainer(**trainer_kwargs)

Begin training:

# Start fine-tuning. The keyword arguments are presumably forwarded by
# BaseTrainer.__run__ to the hooks: `stop_lr` is read by __if_stop_trainning__,
# and `eval` (test_testset) is called by __eval__ to measure precision.
trainer.__run__(stop_lr = args.stop_lr, eval = test_testset)

Get the re-trained network with the best validation result:

# Retrieve the best re-trained network (provided by BaseTrainer; its exact
# selection criterion is not shown on this page).
retrain_net_best = trainer.__get_best_trained_network__()

After fine-tuning completes, show the test precision of the retrained network:

# Report the test precision of the best re-trained network on the validation loader.
test_testset(net=retrain_net_best, testloader=valloader, device=device)

Save the best re-trained network to a .pkl file:

# Save the best re-trained network as retrain_net_best.pkl in the fine-tuning
# folder, using the network object's own save() method.
retrain_net_best.save(os.path.join(finetune_checkpointfolder, "retrain_net_best.pkl"))

4. Code Downloads

  • File trainer.py shows how to define a trainer designed for training a CIFAR-10 classifier.
  • File compress.py shows how to use a mathematical compressor to compress a classifier and how to re-train the compressed classifier using the trainer above. It also shows the compress-finetune loop, which repeatedly compresses the classifier until it obtains one that strikes a balance between computation amount and precision.

Clone this wiki locally