In [35]:
import copy
import gzip
import json
import os
import pprint
import shutil
import subprocess
import tempfile
import warnings

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from sklearn.cluster import MiniBatchKMeans
from torch import nn
from torchvision import models
from torchvision import transforms, datasets, models
warnings.filterwarnings("ignore")

In [36]:
def test(model):
    model.eval()
    correct = 0
    for x, y in ds_test:
        x, y = x.cuda(), y.cuda()
        output = model(x)
        predictions = output.argmax(1)
        correct += (predictions == y).sum().item()
    return correct / len(ds_test.dataset)

# CIFAR10

In [2]:
def get_transforms(nx=224):
	"""Build the (train, test) preprocessing pipelines for ``nx`` x ``nx`` inputs.

	Both pipelines are identical: bicubic resize to ``nx``, centre crop to
	``nx``, force RGB, convert to tensor, then per-channel normalisation.
	Two separate ``Compose`` objects are returned.

	NOTE(review): the mean/std constants look like the ones used for CLIP
	preprocessing rather than CIFAR/ImageNet statistics; confirm they match
	what the evaluated checkpoints were trained with.
	"""
	tv = torchvision.transforms
	channel_mean = (0.48145466, 0.4578275, 0.40821073)
	channel_std = (0.26862954, 0.26130258, 0.27577711)

	def _convert_image_to_rgb(image):
		return image.convert("RGB")

	def _build_pipeline():
		steps = [
			tv.Resize(nx, interpolation=tv.InterpolationMode.BICUBIC, antialias=True),
			tv.CenterCrop(nx),
			_convert_image_to_rgb,
			tv.ToTensor(),
			tv.Normalize(channel_mean, channel_std),
		]
		return tv.Compose(steps)

	train_transform = _build_pipeline()
	test_transform = _build_pipeline()
	return train_transform, test_transform

In [3]:
# CIFAR10 data loaders at the native 32x32 resolution.
train_transform, test_transform = get_transforms(nx=32)

# The training loader must read the training split (train=True); it previously
# used train=False and therefore silently served the test images.
ds_train = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR10(root='~/datasets', train=True, download=True, transform=train_transform),
    batch_size=128,
    shuffle=True,
)

# Held-out split, fixed order; this is the loader `test(model)` reads by default.
ds_test = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR10(root='~/datasets', train=False, download=True, transform=test_transform),
    batch_size=128,
    shuffle=False,
)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
# Fetch the pretrained CIFAR10 VGG16-BN checkpoint from torch.hub, then move it to the GPU.
model = torch.hub.load(
    "chenyaofo/pytorch-cifar-models",
    "cifar10_vgg16_bn",
    pretrained=True,
)
model = model.cuda()


# Alternative backbone (needs a local checkout):
## git clone https://github.com/huyvnphan/PyTorch_CIFAR10.git
# from PyTorch_CIFAR10.cifar10_models.resnet import resnet18
# model = resnet18(pretrained=True).cuda()


# Accuracy of the unmodified model on the current `ds_test` loader.
print(f'Baseline accuracy: {test(model)}')

# ImageNet

In [8]:
batch_size = 64

# Evaluation preprocessing: resize to 256, centre-crop to 224, tensor, normalise.
transform_test = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

test_dataset = torchvision.datasets.ImageFolder(
    root='/data/ImageNet/ILSVRC2012_img_val/',
    transform=transform_test,
)

# NOTE: this rebinds `ds_test`, so `test(model)` evaluates on ImageNet from here on.
ds_test = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    num_workers=8,
    shuffle=False,
)

In [38]:
# Alternative backbone:
# model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True).cuda()
model = torch.hub.load(
    'pytorch/vision:v0.10.0',
    'vgg16',
    pretrained=True,
)
model = model.cuda()

# Accuracy of the unmodified model on the current `ds_test` loader.
print(f'Baseline accuracy: {test(model)}')

# Weight Clustering

In [29]:
def storage_size(model):
    """Measure the on-disk size of ``model``'s state dict, raw and gzip-compressed.

    The state dict is saved with ``torch.save`` into a private temporary
    directory, gzip-compressed there, and the directory is removed again, so
    nothing is left behind and concurrent calls cannot clobber each other.

    Args:
        model: any object exposing ``state_dict()`` (e.g. an ``nn.Module``).

    Returns:
        ``(size_before, size_after, ratio)`` where the sizes are in MiB and
        ``ratio`` is ``size_before / size_after``.
    """
    mib = 1024 * 1024
    # No deepcopy needed: saving a state dict does not modify the model, and
    # copying it would double memory use for the duration of the call.
    with tempfile.TemporaryDirectory() as workdir:
        raw_path = os.path.join(workdir, "model.h5")
        gz_path = raw_path + ".gz"

        torch.save(model.state_dict(), raw_path)
        size_before = os.path.getsize(raw_path) / mib

        # Stream through the stdlib gzip module instead of shelling out to an
        # external `gzip` binary. Level 6 mirrors the gzip CLI default.
        with open(raw_path, "rb") as src, gzip.open(gz_path, "wb", compresslevel=6) as dst:
            shutil.copyfileobj(src, dst)
        size_after = os.path.getsize(gz_path) / mib

    return size_before, size_after, size_before / size_after


In [31]:
# Hyperparameters
BATCH_SIZE = 32
BLOCK_SIZE = 4    # blocks hold BLOCK_SIZE * BLOCK_SIZE weights
N_CENTROIDS = 50  # number of k-means clusters
DELTA = abs(2.170973157686992 - 0.003227713898589287) / 2


def MODULE_FILTER(name, module):
    """Default layer selector: every Conv2d, plus every Linear except 'classifier.6'."""
    if isinstance(module, nn.Conv2d):
        return True
    return isinstance(module, nn.Linear) and name != 'classifier.6'


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [32]:
def blockify(x, block_width, block_height):
    """Reshape ``x`` into rows of ``block_width * block_height`` consecutive elements."""
    elements_per_block = block_width * block_height
    return x.reshape(-1, elements_per_block)
def deblockify(x, *original_dimensions):
    """Reshape a block matrix ``x`` back to the shape given by ``original_dimensions``."""
    restored = x.reshape(*original_dimensions)
    return restored
def legoifier(model, block_size=BLOCK_SIZE, n_centroids=N_CENTROIDS, module_filter=MODULE_FILTER):
    """Cluster the weight blocks of ``model`` once and return a ``legoify`` closure.

    The weights of every module accepted by ``module_filter(name, module)`` are
    cut into rows of ``block_size ** 2`` consecutive values (see ``blockify``)
    and clustered with ``MiniBatchKMeans``. The returned ``legoify(delta)``
    produces a deep copy of ``model`` in which every block closer than
    ``delta`` to its assigned centroid is replaced by that centroid.

    Args:
        model: network to compress; it is never modified.
        block_size: side length of a block (``block_size ** 2`` weights each).
        n_centroids: number of k-means clusters.
        module_filter: ``(name, module) -> bool`` selecting modules to cluster.

    Returns:
        ``legoify(delta=DELTA) -> (lego_model, blocks_replaced, block_count)``
        where ``blocks_replaced`` is a 0-dim integer tensor and
        ``block_count`` is an ``int``.

    Raises:
        ValueError: if ``module_filter`` selects no module.
    """
    def _selected_modules(net):
        return [module for name, module in net.named_modules() if module_filter(name, module)]

    def _gather_blocks(modules):
        return torch.cat([blockify(module.weight, block_size, block_size) for module in modules])

    source_modules = _selected_modules(model)
    if not source_modules:
        raise ValueError("module_filter selected no modules to cluster")

    # Group into blocks (no autograd tracking needed for clustering).
    with torch.no_grad():
        source_blocks = _gather_blocks(source_modules)
    # Follow the model's own device/dtype rather than a notebook-level global.
    target_device, target_dtype = source_blocks.device, source_blocks.dtype

    # Cluster blocks
    kmeans = MiniBatchKMeans(n_init='auto', n_clusters=n_centroids, init='k-means++', random_state=0)
    distances = kmeans.fit_transform(source_blocks.detach().cpu().numpy())
    # Keep only each block's distance to the centroid it was assigned to.
    distances = distances[np.arange(distances.shape[0]), kmeans.labels_]
    distances = torch.from_numpy(distances).to(target_device)

    # Small lookup tables kept on the device, so each legoify() call can build
    # the per-block centroid matrix there instead of re-indexing a full-size
    # array on the CPU and transferring it.
    centroids = torch.from_numpy(kmeans.cluster_centers_).to(device=target_device, dtype=target_dtype)
    labels = torch.from_numpy(kmeans.labels_).to(device=target_device, dtype=torch.long)

    def legoify(delta=DELTA):
        """Return ``(lego_model, blocks_replaced, block_count)`` for threshold ``delta``."""
        lego_model = copy.deepcopy(model)
        modules_to_update = _selected_modules(lego_model)

        with torch.no_grad():
            blocks = _gather_blocks(modules_to_update)

            # Replace blocks closest to centers with those centers
            replace_mask = distances < delta
            legos = torch.where(replace_mask.view(-1, 1), centroids[labels], blocks)
            # Count the mask itself: comparing values element-wise would miss
            # replaced blocks that happen to share a value with their centroid.
            blocks_replaced = replace_mask.sum()
            block_count = len(legos)

            # Convert legos back to original shape and update model
            legos_used = 0
            for module in modules_to_update:
                n_legos = module.weight.numel() // block_size ** 2
                module_legos = legos[legos_used:legos_used + n_legos]
                legos_used += n_legos
                module.weight.copy_(deblockify(module_legos, *module.weight.shape))

        return lego_model, blocks_replaced, block_count

    return legoify

In [41]:
# Layer selections to evaluate: everything, convolutions only, linear layers only.
module_filters = {
    'all': MODULE_FILTER,
    'conv': lambda name, module: isinstance(module, nn.Conv2d),
    'linear': lambda name, module: isinstance(module, nn.Linear) and name != 'classifier.6',
}

# 50 evenly spaced thresholds, followed by the default DELTA as the last point.
deltas = np.append(np.linspace(0.001, 0.1, num=50), DELTA)

model.cuda()
results = {}

print('size before compression:', storage_size(model)[1] )

for filter_name, module_filter in module_filters.items():
    # The clustering is fitted once per filter and reused for every delta.
    legoify = legoifier(model, module_filter=module_filter)
    filter_results = {}
    results[filter_name] = filter_results
    for delta in deltas:
        lego_model, blocks_replaced, block_count = legoify(delta=delta)
        accuracy = test(lego_model)
        print(f'delta:{round(delta, 3)}, accuracy: {accuracy}, model_size: {round(storage_size(lego_model)[1], 3)}, replaced: {(blocks_replaced / block_count).item()*100}')
        replaced_fraction = (blocks_replaced / block_count).item()
        filter_results[delta] = {'accuracy': accuracy, 'replaced': replaced_fraction}
#del lego_model  # Takes up 0.5 GB of memory!

In [40]:
# Dump the nested {filter: {delta: {'accuracy', 'replaced'}}} results.
# `pprint` is imported in the notebook's import cell.
pprint.pprint(results)

In [39]:
# One figure per metric, one curve per module filter.
# The stored values are fractions in [0, 1]; the axes are labelled in percent,
# so the values are scaled by 100 when plotted.
for plot_title, y_label, metric in (
    ('Lego-ified Model Accuracy', '% Accuracy', 'accuracy'),
    ('Blocks Replaced by Cluster Centers', '% Replaced', 'replaced'),
):
    fig, ax = plt.subplots()
    for filter_name in module_filters.keys():
        per_delta = results[filter_name]
        ax.plot(
            list(per_delta.keys()),
            [100 * value[metric] for value in per_delta.values()],
            label=filter_name,
        )
    ax.set(title=plot_title, xlabel='Delta', ylabel=y_label)
    ax.legend()
    plt.show()