# "Deep Compression: Compressing Deep Neural Network with Pruning, Trained Quantization and Huffman Coding" paper implementation - https://arxiv.org/pdf/1510.00149.pdf

In [1]:
import torch
import torch.nn as nn
from torchvision.transforms import transforms
import torchvision
import torch.utils as utils
import torch.optim as optim

data_root = "./data"


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.features = nn.Sequential(nn.Conv2d(1, 6, 5),
                                      nn.ReLU(inplace=True),
                                      nn.MaxPool2d(kernel_size=2),
                                      nn.Conv2d(6, 16, 5),
                                      nn.ReLU(inplace=True),
                                      nn.MaxPool2d(kernel_size=2),
                                      )
        self.classifier = nn.Sequential(nn.Linear(16 * 5 * 5, 120),
                                        nn.Dropout(p=0.5),
                                        nn.Linear(120, 84),
                                        nn.Linear(84, 10),
                                        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


transforms = transforms.Compose([transforms.Resize(32),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),])

dataset = torchvision.datasets.MNIST(
    root=data_root, transform=transforms, download=True, train=True)
train_data = utils.data.DataLoader(
    dataset, shuffle=True, batch_size=100, num_workers=2)

test_dataset = torchvision.datasets.MNIST(
    root=data_root, transform=transforms, download=True, train=False)
test_data = utils.data.DataLoader(
    test_dataset, shuffle=False, batch_size=100, num_workers=2)


def model_test(model, test_data):
    correct = 0
    total = 0
    for (images, labels) in test_data:
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum()

    accuracy = 100 * correct / total
    print('Test accuracy: %d %%' % (accuracy))
    return accuracy


def train():
    net = Net()
    net.train()

    criterion = nn.CrossEntropyLoss()
    lr = 0.001
    optimizer = optim.SGD(net.parameters(), lr=lr, weight_decay=5e-4)

    for epoch in range(100):
        print("Epoch : %d" % (epoch + 1))
        running_loss = 0
        for batch_index, (inputs, target) in enumerate(train_data):
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, target)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if batch_index % 1000 == 0:
                print("[%d     , %5d] loss:%.4f" %
                      (epoch + 1, batch_index, running_loss / 1000))

    torch.save(net, "l-lenet.pth")
    model_test(net, test_data=test_data)


train()

  warn(


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz


1.3%

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100.0%


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz


100.0%

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz



6.0%

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100.0%


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz


100.0%

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Epoch : 1





[1     ,     0] loss:0.0023
Epoch : 2
[2     ,     0] loss:0.0023
Epoch : 3
[3     ,     0] loss:0.0023
Epoch : 4
[4     ,     0] loss:0.0023
Epoch : 5
[5     ,     0] loss:0.0023
Epoch : 6
[6     ,     0] loss:0.0023
Epoch : 7
[7     ,     0] loss:0.0023
Epoch : 8
[8     ,     0] loss:0.0023
Epoch : 9
[9     ,     0] loss:0.0023
Epoch : 10
[10     ,     0] loss:0.0023
Epoch : 11
[11     ,     0] loss:0.0022
Epoch : 12
[12     ,     0] loss:0.0021
Epoch : 13
[13     ,     0] loss:0.0019
Epoch : 14
[14     ,     0] loss:0.0015
Epoch : 15
[15     ,     0] loss:0.0011
Epoch : 16
[16     ,     0] loss:0.0010
Epoch : 17
[17     ,     0] loss:0.0012
Epoch : 18
[18     ,     0] loss:0.0008
Epoch : 19
[19     ,     0] loss:0.0008
Epoch : 20
[20     ,     0] loss:0.0008
Epoch : 21
[21     ,     0] loss:0.0007
Epoch : 22
[22     ,     0] loss:0.0007
Epoch : 23
[23     ,     0] loss:0.0009
Epoch : 24
[24     ,     0] loss:0.0006
Epoch : 25
[25     ,     0] loss:0.0005
Epoch : 26
[26     ,     0] 

In [4]:
import torch
import torch.optim as optim
from torch.autograd import Variable
from sklearn.cluster import KMeans
import numpy as np
import torch.nn as nn


class DeepCompressor():
    def __init__(self, model_path, test_data, train_data, k, lr):
        self.test_data = test_data
        self.train_data = train_data
        self.model = torch.load(model_path)
        self.criterion = torch.nn.CrossEntropyLoss()
        self.k = k
        self.lr = lr

    def train(self, optimizer=None, epoches=10):
        if optimizer is None:
            optimizer = \
                optim.SGD(self.model.parameters(), lr=self.lr, momentum=0.9)

        for i in range(epoches):
            print("Epoch: ", i)
            self.train_epoch(optimizer, weight_share=True)
        print("Finished fine tuning.")
        return self.update_weights()

    def update_weights(self):
        model = self.model
        for layer, (name, module) in enumerate(model.classifier._modules.items()):
            module.register_backward_hook(self.scalar_quantization)
            weight = module.weight.data.cpu().numpy()
            weight_shape = weight.shape
            centroids = module.centroids
            labels = module.labeled_weight
            new_weight = self.get_finilized_weight(
                weight=weight, centroids=centroids, labels=labels)
            new_weight = new_weight.reshape(
                weight_shape[0], weight_shape[1], dtype=np.int8)
            module.weight = torch.from_numpy(new_weight).int()
            del module.labeled_weight
        return model

    def get_finilized_weight(self, weight, centroids, labels):
        for index, label in enumerate(labels):
            weight[index] = centroids[label][0]
        return weight

    def train_batch(self, optimizer, batch, label, weight_share):
        self.model.zero_grad()
        input = Variable(batch)
        if weight_share:
            output = self.forward(input)
            self.criterion(output, Variable(label)).backward()
        else:
            self.criterion(self.model(input), Variable(label)).backward()
            optimizer.step()

    def train_epoch(self, optimizer=None, weight_share=False):
        index = 1
        for batch_index, (batch, label) in enumerate(self.train_data, 0):
            self.train_batch(optimizer, batch,
                             label, weight_share)
            if batch_index % 100 == 0:
                print(batch_index)

    def forward(self, x):
        x = self.model.features(x)
        x = x.view(x.size(0), -1)
        for layer, (name, module) in enumerate(self.model.classifier._modules.items()):
            if isinstance(module, torch.nn.modules.Linear):
                module.register_backward_hook(self.scalar_quantization)
                weight = module.weight.data.cpu().numpy()
                weight_shape = weight.shape
                sorted_centroids, centroids, labeled_weight = self.find_centroids(
                    weight, self.k)
                new_weight = self.get_converted_weight(
                    labeled_weight=labeled_weight, centroids=centroids)
                new_weight = new_weight.reshape(
                    weight_shape[0], weight_shape[1])
                module.labeled_weight = labeled_weight
                module.centroids = centroids
                module.weight.data = torch.from_numpy(
                    new_weight).float()
            x = module(x)
        return x

    def find_centroids(self, weight, num_class):
        a = weight.reshape(-1, 1)
        kmeans = KMeans(n_clusters=num_class, random_state=0).fit(a)
        centroids = kmeans.cluster_centers_
        labels = kmeans.labels_
        sorted_centroids = -np.sort(-centroids, axis=0)
        return sorted_centroids, centroids, labels

    def get_converted_weight(self, labeled_weight, centroids):
        new_weight = np.zeros(shape=labeled_weight.shape, dtype=np.float32)
        for index, label in enumerate(labeled_weight):
            new_weight[index] = centroids[label][0]
        return labeled_weight

    def get_centroids_gradients(self, grad_input, labeled_weight, dw, grad_output):
        w_grad = grad_input[2].t().data.cpu().numpy()
        grad_w = w_grad.reshape(-1, 1)
        for index, label in enumerate(labeled_weight):
            dw[label][0] += grad_w[index]
        return dw

    def scalar_quantization(self, module, grad_input, grad_output):
        if isinstance(module, nn.Linear):
            labeled_weight = module.labeled_weight
            centroids = module.centroids
            dw = np.zeros(shape=centroids.shape, dtype=np.float32)
            dw = self.get_centroids_gradients(
                grad_input, labeled_weight, dw, grad_output)
            module.centroids = centroids - (self.lr * dw)


class NET(nn.Module):
    def __init__(self):
        super(NET, self).__init__()

    @classmethod
    def copy(cls, source, **kw):
        instance = cls(**kw)
        for name in dir(source):
            if not (name is 'forward' or name.startswith("__")):
                value = getattr(source, name)
                setattr(instance, name, value)
        return instance

    def forward(self, x):
        x = self.model.features(x)
        x = x.view(x.size(0), -1)
        for layer, (name, module) in enumerate(self.model.classifier._modules.items()):
            if isinstance(module, torch.nn.modules.Linear):
                weight_shape = module.weight.shape
                new_weight = self.get_converted_weight(
                    labeled_weight=module.weight, centroids=module.centroids)
                new_weight = new_weight.reshape(
                    weight_shape[0], weight_shape[1])
                new_weight = Variable(
                    torch.from_numpy(new_weight).float())
                if x.dim() == 2 and module.bias is not None:
                    return torch.addmm(module.bias, x, new_weight.t())
                output = x.matmul(new_weight.t())
                if module.bias is not None:
                    output += module.bias
                x = output
        return x


  if not (name is 'forward' or name.startswith("__")):


In [None]:
import torch
from torchvision.transforms import transforms
import torchvision
import torch.utils as utils

compressor = DeepCompressor("l-lenet.pth", test_data=test_data, train_data=train_data, k=32, lr=0.001)
model = compressor.train(epoches=10)
print(model)
torch.save(model, "DC-lenet.pth")
net = NET.copy(model)
torch.save(net, "DC-new-forward-lenet.pth")