Example for Compressing a Classifier
myd edited this page Jun 28, 2019 · 6 revisions
This is an example of how to compress a classifier (for CIFAR-10). It shows how to use mathematical compressors to compress a neural network, and then how to fine-tune the compressed network with a classifier trainer designed for the CIFAR-10 project.
Prepare the train/val data. With torchvision, loading the CIFAR-10 train/val splits is straightforward:
import torch
import torchvision
import torchvision.transforms as transforms

# args holds the command-line options (data_path, batch_size, num_workers, ...).
# Training-time augmentation, then normalization with the CIFAR-10 channel mean/std.
transform_train = transforms.Compose([
    transforms.RandomRotation(degrees=5),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# Evaluation uses no augmentation: only tensor conversion and normalization.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# download=False expects the dataset to already exist under args.data_path;
# set download=True to fetch it on first use.
trainset = torchvision.datasets.CIFAR10(root=args.data_path, train=True, download=False, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
valset = torchvision.datasets.CIFAR10(root=args.data_path, train=False, download=False, transform=transform_test)
valloader = torch.utils.data.DataLoader(valset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
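For reference, `transforms.ToTensor()` turns a PIL image into a float tensor of shape C x H x W with values in [0, 1], and `transforms.Normalize(mean, std)` then standardizes each channel c as (x - mean[c]) / std[c]. The pure-Python sketch below reproduces only the normalization step on nested lists so the arithmetic is visible; the `normalize` helper is hypothetical and is not part of torchvision or this project.

```python
# Per-channel normalization, as done by transforms.Normalize(mean, std):
# every value x in channel c becomes (x - mean[c]) / std[c].
# Images are nested lists shaped [C][H][W] with values already in [0, 1].
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)


def normalize(image, mean=CIFAR10_MEAN, std=CIFAR10_STD):
    """Return a new [C][H][W] image with each channel standardized."""
    if len(image) != len(mean) or len(mean) != len(std):
        raise ValueError("channel count must match len(mean) and len(std)")
    return [
        [[(px - m) / s for px in row] for row in channel]
        for channel, m, s in zip(image, mean, std)
    ]


# A 3x2x2 image whose pixels equal the per-channel mean normalizes to all zeros.
img = [[[m, m], [m, m]] for m in CIFAR10_MEAN]
print(normalize(img))
```

A pure white pixel (1.0) in the red channel maps to (1.0 - 0.4914) / 0.2023, roughly 2.51, which is why normalized inputs are no longer confined to [0, 1].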
Define a neural network and load its pre-trained parameters.
import os
import torch

# ResNet18, reconstructor, showFMLAs, test_testset, args and device come from the project.
# define the network ...
net = ResNet18()
# load the pre-trained state_dict (map_location='cpu' also works on machines without a GPU) ...
net.load_state_dict(torch.load('./cifar10/tmp/checkpoints/run8_resnet18_epoch_150_batch_128_lr_0.01_from_run7_epoch_149/epoch_169_loss_0.02847591015841345_accuracy_0.9209.pth', map_location='cpu'))
# use graph.reconstructor to convert net (a plain torch.nn.Module) into origin_net
# (a graph.modules.ReconstructedNetwork, itself a torch.nn.Module) ...
reconstructor.insertCaptureBoundaryStart(net)
# a dummy forward pass between the capture boundaries lets the reconstructor record the graph
oup = net(torch.rand(1, 3, 32, 32))
reconstructor.insertCaptureBoundaryEnd()
origin_net, graph = reconstructor.getReconstructedNetwork(
    ifDraw=True,
    drawPath=os.path.join(args.checkpoints_folder, "origin_net")
)
# show the computation amount (fused multiply-adds) of origin_net
fmlas_origin = showFMLAs(torch.rand(1, 3, 32, 32), origin_net)
print("fmlas_origin: {} G".format(fmlas_origin / 1e9))
# show the test precision of origin_net
test_testset(net=origin_net, testloader=valloader, device=device)
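To get a feel for the number showFMLAs prints: a Conv2d layer needs (C_in / groups) x k x k multiply-adds per output element, and it produces C_out x H_out x W_out output elements. The helpers below are a hypothetical sketch of that arithmetic, not the project's implementation; showFMLAs may count differently (bias terms, pooling, element-wise ops), so treat the result as an estimate. The example assumes a CIFAR-style ResNet18 stem (3x3 conv, stride 1, padding 1, 3 to 64 channels), which may not match this project's ResNet18 exactly.

```python
def conv2d_output_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def conv2d_fmas(c_in: int, c_out: int, h_in: int, w_in: int, kernel: int,
                stride: int = 1, padding: int = 0, groups: int = 1) -> int:
    """Multiply-adds of one Conv2d layer, ignoring the bias."""
    h_out = conv2d_output_size(h_in, kernel, stride, padding)
    w_out = conv2d_output_size(w_in, kernel, stride, padding)
    # Each output element is a dot product over (c_in / groups) * k * k inputs.
    return c_out * h_out * w_out * (c_in // groups) * kernel * kernel


def linear_fmas(in_features: int, out_features: int) -> int:
    """Multiply-adds of a fully connected layer, ignoring the bias."""
    return in_features * out_features


# Assumed stem: 3 -> 64 channels, 3x3 kernel, stride 1, padding 1, on 32x32 input.
stem = conv2d_fmas(3, 64, 32, 32, kernel=3, stride=1, padding=1)
print(stem, "FMAs =", stem / 1e9, "G")
```

Summing such per-layer counts over the whole network and dividing by 1e9 gives a figure comparable in spirit to the "G" value printed for fmlas_origin.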
First, create a ChannelPrunning object if you are using channel pruning, or a LowRankDecompostion object if you are using low-rank decomposition:
if compress_algorithm == 'CP':
    # channel pruning ...
    compressor = ChannelPrunning(
        origin_net=origin_net,
        trainloader=trainloader,
        valloader=valloader,
        trainset_ratio=args.compress_trainset_ratio,
        sampled_pixels_per_img=args.compress_sampled_pixels_per_img,
        compress_ratio=args.compress_ratios,
        checkpointfolder=os.path.join(args.checkpoints_folder, "compress-{}".format(compress_cnt)),
        device=device,
        drawgraph=args.compress_draw_graph,
        verbose=args.verbose,
        lars_alpha_init=args.lars_alpha_init,
        # accuracy-first mode is enabled whenever a positive threshold is given
        accuracy_first=args.compress_acc_thresh > 0,
        accuracy_threshold=args.compress_acc_thresh,
        args=args
    )
elif compress_algorithm == 'LRD':
    # low-rank decomposition ...
    compressor = LowRankDecompostion(
        origin_net=origin_net,
        trainloader=trainloader,
        valloader=valloader,
        trainset_ratio=args.compress_trainset_ratio,
        sampled_pixels_per_img=args.compress_sampled_pixels_per_img,
        compress_ratio=args.compress_ratios,
        checkpointfolder=os.path.join(args.checkpoints_folder, "compress-{}".format(compress_cnt)),
        device=device,
        verbose=args.verbose,
        drawgraph=args.compress_draw_graph,
        accuracy_first=args.compress_acc_thresh > 0,
        accuracy_threshold=args.compress_acc_thresh,
        nonlinear_case=args.compress_nonlinear_case,
        args=args
    )
else:
    # the exception must be raised, otherwise execution continues without a compressor
    raise RuntimeError("Unknown compress algorithm: {}".format(compress_algorithm))
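Both compressors reduce the multiply-add count, but in different ways. As an illustration only (not necessarily how ChannelPrunning or LowRankDecompostion work internally, and not how compress_ratio is defined): pruning the intermediate channels of two stacked convolutions shrinks both the first layer's outputs and the second layer's inputs, while one common low-rank scheme replaces a k x k convolution with a k x k convolution into `rank` channels followed by a 1 x 1 convolution. All function names here are hypothetical.

```python
def conv_fmas(c_in: int, c_out: int, h_out: int, w_out: int, kernel: int) -> int:
    """Multiply-adds of a convolution, given its output spatial size."""
    return c_out * h_out * w_out * c_in * kernel * kernel


def pruned_pair_fmas(c_in: int, c_mid: int, c_out: int, h: int, w: int,
                     kernel: int, keep_ratio: float) -> int:
    """FMAs of conv(c_in -> c_mid) followed by conv(c_mid -> c_out)
    after keeping round(keep_ratio * c_mid) intermediate channels."""
    c_keep = max(1, round(keep_ratio * c_mid))
    first = conv_fmas(c_in, c_keep, h, w, kernel)   # fewer output channels
    second = conv_fmas(c_keep, c_out, h, w, kernel)  # fewer input channels
    return first + second


def low_rank_fmas(c_in: int, c_out: int, h: int, w: int, kernel: int, rank: int) -> int:
    """FMAs after replacing a k x k conv (c_in -> c_out) with a k x k conv
    (c_in -> rank) followed by a 1 x 1 conv (rank -> c_out)."""
    return conv_fmas(c_in, rank, h, w, kernel) + conv_fmas(rank, c_out, h, w, 1)


full_layer = conv_fmas(64, 64, 32, 32, 3)
print("one 64->64 3x3 layer:", full_layer)
print("pair, 50% channels kept:", pruned_pair_fmas(64, 64, 64, 32, 32, 3, 0.5))
print("low-rank, rank 16:", low_rank_fmas(64, 64, 32, 32, 3, 16))
```

For a 64-channel 3x3 layer on a 32x32 map, keeping half the intermediate channels halves the pair's cost, and rank 16 brings a single layer down to 10/36 of its original multiply-adds.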
Then, run the compression:
# run compression
compressor.compress()
After compression completes, show the computation amount (fused multiply-adds) and the test precision of the compressed network:
# show the computation amount (fused multiply-adds) of compressor.compressed_net
fmlas_compressed = showFMLAs(torch.rand((1, 3, 32, 32)).to(device), compressor.compressed_net)
print("fmlas_compressed: {} G".format(fmlas_compressed / 1e9))
# show the test precision of compressor.compressed_net
test_testset(net=compressor.compressed_net, testloader=valloader, device=device)
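To judge a compression round, it helps to put the before and after measurements side by side. This sketch assumes the reported test precision is top-1 accuracy and that the FMA counts and accuracies are available as plain numbers (what test_testset returns is not documented here); `top1_accuracy` and `summarize_compression` are hypothetical helpers.

```python
from typing import Sequence, Tuple


def top1_accuracy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Fraction of samples whose highest-scoring class equals the label."""
    if len(logits) != len(labels) or not labels:
        raise ValueError("need the same, non-zero number of logits and labels")
    correct = sum(
        1
        for row, label in zip(logits, labels)
        if max(range(len(row)), key=row.__getitem__) == label  # argmax of the row
    )
    return correct / len(labels)


def summarize_compression(fmas_before: float, fmas_after: float,
                          acc_before: float, acc_after: float) -> Tuple[float, float, float]:
    """Return (speed-up ratio, fraction of FMAs removed, accuracy drop)."""
    if fmas_before <= 0 or fmas_after <= 0:
        raise ValueError("FMA counts must be positive")
    speedup = fmas_before / fmas_after
    removed = 1.0 - fmas_after / fmas_before
    return speedup, removed, acc_before - acc_after


# Placeholder numbers for illustration, not measured results.
print(summarize_compression(560_000_000, 140_000_000, 0.92, 0.90))
```

With these placeholder inputs the helper reports a 4x reduction with 75% of the multiply-adds removed, at the cost of about two points of accuracy.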
- File trainer.py shows how to define a trainer for training the CIFAR-10 classifier.
- File compress.py shows how to use a mathematical compressor to compress a classifier and how to retrain the compressed classifier with the trainer above. It also shows the compress-finetune loop, which repeatedly compresses the classifier until it strikes a balance between computation amount and precision.
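The compress-finetune loop described for compress.py can be sketched roughly as follows. This is an assumption about its structure, not the file's actual code: `compress`, `finetune` and `evaluate` are hypothetical callables standing in for a compressor run, a retraining pass with the trainer from trainer.py, and the showFMLAs/test_testset measurements, and the stopping rules (an FMA budget, an accuracy floor, a round limit) are illustrative choices.

```python
from typing import Any, Callable, List, Tuple


def compress_finetune_loop(
    net: Any,
    compress: Callable[[Any], Any],                 # hypothetical: returns a compressed network
    finetune: Callable[[Any], Any],                 # hypothetical: retrains and returns the network
    evaluate: Callable[[Any], Tuple[float, float]], # hypothetical: returns (fmas, accuracy)
    min_accuracy: float,
    target_fmas: float,
    max_rounds: int = 10,
) -> Tuple[Any, List[Tuple[float, float]]]:
    """Alternate compression and fine-tuning until the FMA budget is met,
    accuracy falls below min_accuracy, or max_rounds is reached.
    Returns the last network that met the accuracy floor, plus the
    (fmas, accuracy) history of every evaluated network."""
    fmas, acc = evaluate(net)
    history = [(fmas, acc)]
    best = net
    for _ in range(max_rounds):
        if fmas <= target_fmas:
            break  # computation budget reached
        candidate = finetune(compress(best))
        cand_fmas, cand_acc = evaluate(candidate)
        history.append((cand_fmas, cand_acc))
        if cand_acc < min_accuracy:
            break  # too much accuracy lost: keep the previous network
        best, fmas = candidate, cand_fmas
    return best, history


# Toy stand-in: a "network" is just (fmas, accuracy in percent).
best, hist = compress_finetune_loop(
    (100, 95),
    compress=lambda n: (n[0] // 2, n[1] - 2),
    finetune=lambda n: (n[0], n[1] + 1),
    evaluate=lambda n: n,
    min_accuracy=92,
    target_fmas=10,
)
print(best, hist)
```

Keeping the last accepted network, rather than the final candidate, means a round that overshoots the accuracy floor never replaces a usable model.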