In [1]:
import torch
import torchvision

In [2]:
import torch.nn as nn

class Autoencoder(nn.Module):
    """Fully connected autoencoder over flattened 28*28 (784-feature) inputs.

    Layer widths: 784 -> 1000 -> 256 -> 256 -> 1000 -> 784. Every ``Linear``
    except the last is followed by ``ReLU``; the last is followed by
    ``Sigmoid``.
    """

    def __init__(self):
        super().__init__()
        widths = (28 * 28, 1000, 256, 256, 1000, 28 * 28)
        final = len(widths) - 2  # position of the output Linear in the pair list
        stack = []
        for idx, (n_in, n_out) in enumerate(zip(widths, widths[1:])):
            # Linear layers land on even Sequential indices (0, 2, 4, ...),
            # which keeps state_dict keys such as "model.0.weight" unchanged.
            stack.append(nn.Linear(n_in, n_out))
            stack.append(nn.Sigmoid() if idx == final else nn.ReLU())
        self.model = nn.Sequential(*stack)

    def forward(self, flat):
        """Run ``flat`` through the whole stack and return the result."""
        return self.model(flat)

class AutoencoderBig(nn.Module):
    """Deeper fully connected autoencoder over flattened 28*28 inputs.

    Layer widths: 784 -> 2048 -> 1024 -> 512 -> 256 -> 512 -> 1024 -> 2048
    -> 784. ``ReLU`` follows every ``Linear`` except the last, which is
    followed by ``Sigmoid``.
    """

    def __init__(self):
        super().__init__()
        encoder_widths = [28 * 28, 2048, 1024, 512, 256]
        # The decoder mirrors the encoder, so walk the same widths backwards.
        widths = encoder_widths + encoder_widths[-2::-1]

        layers = []
        pairs = list(zip(widths[:-1], widths[1:]))
        for n_in, n_out in pairs[:-1]:
            layers += [nn.Linear(n_in, n_out), nn.ReLU()]
        last_in, last_out = pairs[-1]
        layers += [nn.Linear(last_in, last_out), nn.Sigmoid()]

        self.model = nn.Sequential(*layers)

    def forward(self, flat):
        """Run ``flat`` through the whole stack and return the result."""
        return self.model(flat)

In [3]:
# Device string passed as `map_location` when the checkpoint is loaded below.
DEVICE = "cpu"

In [4]:
# Build the architecture, then restore the weights stored in models/ae-1.pth,
# mapping every tensor onto DEVICE.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (or pass weights_only=True if the installed PyTorch supports it).
model = Autoencoder()
model.load_state_dict(torch.load('models/ae-1.pth', map_location=DEVICE))

<All keys matched successfully>

In [5]:
# Show the module tree of the freshly loaded (not yet quantized) model.
model

Autoencoder(
  (model): Sequential(
    (0): Linear(in_features=784, out_features=1000, bias=True)
    (1): ReLU()
    (2): Linear(in_features=1000, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=256, bias=True)
    (5): ReLU()
    (6): Linear(in_features=256, out_features=1000, bias=True)
    (7): ReLU()
    (8): Linear(in_features=1000, out_features=784, bias=True)
    (9): Sigmoid()
  )
)

In [12]:
from sklearn.cluster import MiniBatchKMeans
import torch
import tqdm


class QuantizedParams(torch.nn.Module):
    """Frozen parameter tensor stored as codebook indexes.

    Holds two non-trainable parameters: ``indexes`` (one codebook position per
    element) and ``codebook`` (the values those positions refer to). Calling
    the module gathers ``codebook`` at ``indexes``.
    """

    def __init__(self, indexes, codebook):
        super().__init__()
        frozen = {"requires_grad": False}
        self.indexes = torch.nn.Parameter(indexes, **frozen)
        self.codebook = torch.nn.Parameter(codebook, **frozen)

    def forward(self):
        """Return the codebook entries selected by ``indexes``."""
        # Indexes are widened to int32 before being used for the lookup
        # (presumably so a uint8 tensor is not treated as a mask - confirm).
        lookup = self.indexes.to(torch.int32)
        return self.codebook[lookup]


class RegularParams(torch.nn.Module):
    """Frozen, unquantized parameter tensor with a call-to-read interface.

    Wraps ``weights`` in a non-trainable ``Parameter`` and returns it
    unchanged when the module is called.
    """

    def __init__(self, weights):
        super().__init__()
        frozen = torch.nn.Parameter(weights, requires_grad=False)
        self.weights = frozen

    def forward(self):
        """Return the stored tensor as-is."""
        stored = self.weights
        return stored


class HybridLinear(torch.nn.Module):
    """Linear layer whose weight and bias are produced by callables.

    ``weight`` and ``bias`` are invoked with no arguments on every forward
    pass, so each can supply its tensor however it likes (stored directly,
    rebuilt from a codebook, ...).
    """

    def __init__(self, q_w, q_b):
        """Store the weight provider ``q_w`` and the bias provider ``q_b``."""
        super().__init__()
        self.weight = q_w
        self.bias = q_b

    def forward(self, X):
        """Return ``X @ W.T + b`` using the tensors the providers return."""
        w = self.weight()
        projected = X @ w.T
        b = self.bias()
        return projected + b


def k_means(X, k=2):
    """Cluster the rows of ``X`` into ``k`` groups with ``MiniBatchKMeans``.

    The estimator is seeded with ``random_state=0``.

    Returns:
        ``(cluster_centers_, labels_)`` of the fitted estimator.
    """
    estimator = MiniBatchKMeans(n_clusters=k, random_state=0, n_init="auto")
    fitted = estimator.fit(X)
    centers, assignments = fitted.cluster_centers_, fitted.labels_
    return centers, assignments


def quantize(weights, k=2, dtype=torch.float32):
    """Quantize ``weights`` to ``k`` shared values via 1-D k-means.

    Every element of ``weights`` is treated as one scalar sample, the samples
    are clustered into ``k`` groups, and each element is replaced by the
    index of its cluster.

    Args:
        weights: Tensor (or array-like exposing ``reshape``/``shape``) to
            quantize.
        k: Number of clusters, i.e. codebook entries.
        dtype: Torch dtype of the returned codebook.

    Returns:
        ``(codebook, indexes)``: ``codebook`` is a flat tensor with one
        centroid per cluster; ``indexes`` is an integer tensor shaped like
        ``weights`` holding the codebook position of each element. The index
        dtype is the narrowest of uint8 / int16 / int32 that can hold
        ``k - 1``. For tensor inputs both results are placed on the input's
        device.
    """
    if isinstance(weights, torch.Tensor):
        device = weights.device
        # scikit-learn works on host NumPy data: drop autograd tracking and
        # copy to CPU so CUDA or grad-requiring tensors are accepted too.
        samples = weights.detach().cpu().reshape(-1, 1).numpy()
    else:
        device = None
        samples = weights.reshape(-1, 1)

    centroids, labels = k_means(samples, k)

    # uint8 can address only 256 codebook entries; casting larger labels to
    # it would wrap around silently, so widen the index type when needed.
    if k <= 2**8:
        index_dtype = torch.uint8
    elif k <= 2**15:
        index_dtype = torch.int16
    else:
        index_dtype = torch.int32

    codebook = torch.tensor(centroids, dtype=dtype, device=device).reshape(-1)
    new_weights = torch.tensor(labels, dtype=index_dtype, device=device).reshape(
        weights.shape
    )
    return codebook, new_weights


@torch.no_grad()
def quantize_linear_layer(layer, bits=8, dtype=torch.float32):
    """Build a ``HybridLinear`` from ``layer`` with k-means quantized parameters.

    The weight and the bias are handled independently. A tensor with more
    than ``2**bits`` elements is replaced by a ``QuantizedParams`` (codebook
    of ``2**bits`` values plus per-element indexes); a smaller tensor is kept
    as a ``RegularParams``. A missing bias becomes a scalar zero.

    Args:
        layer: Module exposing ``weight`` and ``bias`` (``bias`` may be None).
        bits: Index width; the codebook has ``int(2**bits)`` entries.
        dtype: Torch dtype of the codebooks and of any parameter kept
            unquantized, so both kinds share one precision.

    Returns:
        A ``HybridLinear`` wrapping the new weight and bias providers.
    """
    weight = layer.weight
    bias = layer.bias

    # Number of k-means clusters (codebook entries) addressable with `bits`.
    num_clusters = int(2**bits)

    def _wrap(param):
        # Quantize only when there are more values than codebook entries;
        # numel() counts elements without requiring a contiguous tensor.
        if param.numel() > num_clusters:
            codebook, indexes = quantize(param, num_clusters, dtype)
            return QuantizedParams(indexes, codebook)
        # no quantization :( - still honour the requested dtype
        return RegularParams(param.to(dtype))

    new_weight = _wrap(weight)
    if bias is None:
        # Scalar zero keeps `X @ W.T + b` valid for bias-free layers.
        new_bias = RegularParams(
            torch.tensor(0.0, dtype=dtype, device=weight.device)
        )
    else:
        new_bias = _wrap(bias)

    return HybridLinear(new_weight, new_bias)

def traverse_named_modules(m, filter="Linear"):
    """Yield ``(parent, child, module)`` for each submodule of a given class.

    Args:
        m: Root module, walked with ``named_modules()``.
        filter: Class name compared for equality with
            ``type(module).__name__``.

    Yields:
        ``(parent, child, module)`` where ``parent`` is the dotted path of
        the module that owns the match ("" when the owner is ``m`` itself),
        ``child`` is the last path component, and ``module`` is the match.
    """
    for name, module in m.named_modules():
        if type(module).__name__ != filter:
            continue
        # "model.0" -> ("model", "0");  "0" -> ("", "0")
        parent, _, child = name.rpartition(".")
        yield parent, child, module


def replace(model, filter, callback):
    """Replace, in place, every submodule of ``model`` whose class name is ``filter``.

    Args:
        model: Root module to modify.
        filter: Class name forwarded to ``traverse_named_modules``.
        callback: Called as ``callback(parent_path, child_name, module)``;
            its return value is assigned in place of ``module``.
    """
    # All matches are collected before the first replacement is made.
    matches = list(traverse_named_modules(model, filter))
    progress = tqdm.tqdm(matches, total=len(matches))
    for parent_path, attr, old_module in progress:
        owner = model.get_submodule(parent_path)
        setattr(owner, attr, callback(parent_path, attr, old_module))


def replace_linear_with_quantized(model, bits=8, dtype=torch.float16):
    """Swap every ``Linear`` submodule of ``model`` for its quantized version.

    Each match is passed to ``quantize_linear_layer`` with the given ``bits``
    and ``dtype``. The model is modified in place and nothing is returned.
    Note that ``dtype`` defaults to float16 here, whereas
    ``quantize_linear_layer`` itself defaults to float32.
    """

    def _to_quantized(p, c, l):
        return quantize_linear_layer(l, bits=bits, dtype=dtype)

    replace(model, "Linear", _to_quantized)

In [13]:
# Replace, in place, every module of `model` whose class name is "Linear";
# bits=8 means int(2**8) = 256 k-means clusters for each quantized tensor.
replace_linear_with_quantized(model, bits=8, dtype=torch.float32)

100%|██████████| 5/5 [00:01<00:00,  2.99it/s]


In [6]:
# Show the module tree again.
# NOTE(review): the stored output of this cell has execution count 6, lower
# than the quantization cell's 13, and still lists plain Linear layers - it
# looks stale. Re-run the notebook top to bottom to refresh it.
model

Autoencoder(
  (model): Sequential(
    (0): Linear(in_features=784, out_features=1000, bias=True)
    (1): ReLU()
    (2): Linear(in_features=1000, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=256, bias=True)
    (5): ReLU()
    (6): Linear(in_features=256, out_features=1000, bias=True)
    (7): ReLU()
    (8): Linear(in_features=1000, out_features=784, bias=True)
    (9): Sigmoid()
  )
)

In [17]:
# Persist the state_dict of `model` as it is when this cell runs.
# NOTE(review): once Linear layers have been replaced by HybridLinear, the
# parameters live under the weight/bias sub-modules (indexes, codebook,
# weights), so the file presumably cannot be loaded into a fresh Autoencoder()
# without applying the same replacement first - confirm.
torch.save(model.state_dict(), 'models/ae-1-quantized.pth')

In [8]:
def export(model, filename):
    """Dump ``model.state_dict()`` to ``filename`` as a JSON object.

    Every state-dict key maps to ``{"data": ..., "shape": ...}``, where
    ``data`` is the tensor flattened into a plain list and ``shape`` is the
    tensor's shape.
    """
    import json

    with open(filename, "w") as f:
        payload = {
            key: {"data": tensor.reshape(-1).tolist(), "shape": tensor.shape}
            for key, tensor in model.state_dict().items()
        }
        json.dump(payload, f)

In [9]:
# Write the current state_dict of `model` to models/ae-1.json.
# NOTE(review): this cell's recorded execution count (9) is lower than the
# quantization cell's (13), so the JSON on disk may have been produced from the
# unquantized model - re-run in order and confirm which version is intended.
export(model, "models/ae-1.json")