In [2]:
from hashlib import new
import math
import torch.nn as nn
import torch
import torch.nn.functional as F
from typing import Callable
from enum import Enum

In [3]:
class QuantizationMode(Enum):
    """Selects which weight quantizer is applied.

    Code in this notebook dispatches on the enum member itself; the integer
    values are only nominal bit widths.
    """

    # Quantize with torch.sign.
    one_bit = 1
    # Quantize by rounding and clipping to [-1, 1] (values -1, 0 or 1).
    two_bit = 2
    
def compute_adjustment_factor(input_tensor: torch.Tensor):
    """Return the scalar scale used to normalise a tensor before quantization.

    The scale is half of the mean absolute value of ``input_tensor`` plus a
    small constant, returned as a 0-dim tensor.
    """
    mean_magnitude = input_tensor.abs().mean()
    # The 1e-4 offset keeps the scale non-zero so dividing by it is safe
    # even when every element of the input is zero.
    return 1e-4 + mean_magnitude / 2

# Module-level switch read by compute_quantized_tensor() to choose between the
# 2-bit and 1-bit quantizer for the standalone demo cells.
quantization_mode = QuantizationMode.two_bit
def compute_2bit_quantized_tensor(input_tensor: torch.Tensor):
    """Round every element to the nearest integer and clamp it to [-1, 1].

    The result has the same shape as ``input_tensor`` and only contains the
    values -1, 0 and 1.
    """
    return input_tensor.round().clamp(min=-1, max=1)

def compute_1bit_quantized_tensor(input_tensor: torch.Tensor):
    """Replace every element with its sign (-1, 0 or 1)."""
    signs = input_tensor.sign()
    return signs

def compute_quantized_tensor(input_tensor: torch.Tensor):
    """Quantize ``input_tensor`` with the quantizer chosen by the global mode.

    Uses the 2-bit quantizer when the module-level ``quantization_mode`` is
    ``QuantizationMode.two_bit`` and the 1-bit quantizer for any other value.
    """
    quantizer = (
        compute_2bit_quantized_tensor
        if quantization_mode == QuantizationMode.two_bit
        else compute_1bit_quantized_tensor
    )
    return quantizer(input_tensor)
    

In [4]:
# Seed the global RNG so the two random 5x5 matrices drawn below are
# reproducible; `input_tensor` is drawn first, `weight` second.
torch.manual_seed(42)
input_tensor = torch.randn(5, 5)
weight = torch.randn(5, 5)

# Show the adjustment factor of the first matrix.
print(compute_adjustment_factor(input_tensor))

tensor(0.4495)


In [5]:
# Rescale the weights by their adjustment factor, then quantize the result.
weight_adjustment_factor = compute_adjustment_factor(weight)
adjusted_weight = weight / weight_adjustment_factor
quantized_weight = compute_quantized_tensor(adjusted_weight)

# Show the original weights and their quantized counterpart, separated by a rule.
print(weight, "----", quantized_weight, sep="\n")

tensor([[ 0.3211,  1.5736, -0.8455,  1.3123,  0.6872],
        [-1.0892, -0.3553, -0.9138,  0.8963,  2.2181],
        [ 0.5232,  0.3466, -0.1973, -1.0546,  1.2780],
        [ 0.1453,  0.2311,  0.0566,  0.4263,  0.5750],
        [-0.6417, -2.2064, -0.7508,  2.8140,  0.3598]])
----
tensor([[ 1.,  1., -1.,  1.,  1.],
        [-1., -1., -1.,  1.,  1.],
        [ 1.,  1., -0., -1.,  1.],
        [ 0.,  1.,  0.,  1.,  1.],
        [-1., -1., -1.,  1.,  1.]])


In [6]:
class BitNetLinearLayer(nn.Module):
    """Linear layer whose weight is quantized on the fly in ``forward``.

    The full-precision ``weight`` of shape ``(out_features, in_features)`` is
    kept as the trainable parameter. On every forward pass it is divided by a
    scalar adjustment factor, quantized according to ``quantization_mode``,
    and applied to the input, which is multiplied by the same factor.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        bias: When truthy, a learnable bias of shape ``(out_features,)`` is
            added; when falsy (``False`` or ``None``) the layer has no bias.
        quantization_mode: ``QuantizationMode.two_bit`` rounds and clips the
            scaled weight to [-1, 1]; any other value uses ``torch.sign``.
    """

    # Added once to the scale so that dividing by it is safe even when every
    # weight is zero.
    _SCALE_EPS = 1e-4

    def __init__(
        self,
        in_features,
        out_features,
        bias=False,
        quantization_mode: QuantizationMode = QuantizationMode.two_bit,
    ):
        super().__init__()
        self.binary_layer = True
        self.in_features = in_features
        self.out_features = out_features

        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        # Honour the flag by truthiness: the previous `bias is not None` test
        # was True for `bias=False`, so a bias was always created.
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter("bias", None)
        self.quantization_mode = quantization_mode

        self.reset_parameters()

    def reset_parameters(self):
        """Fill weight (and bias, if present) uniformly in ±1/sqrt(in_features)."""
        bound = 1.0 / math.sqrt(self.weight.size(1))
        with torch.no_grad():
            self.weight.uniform_(-bound, bound)
            if self.bias is not None:
                self.bias.uniform_(-bound, bound)

    def compute_adjustment_factor(self, input_tensor: torch.Tensor):
        """Return the scalar scale: twice the mean absolute value plus epsilon."""
        absmean_weight = torch.mean(torch.abs(input_tensor))
        # NOTE(review): the module-level compute_adjustment_factor() in this
        # notebook uses `absmean / 2` while this method uses `absmean * 2`;
        # confirm which scale is intended. Only the duplicated epsilon was
        # removed here.
        return absmean_weight * 2 + self._SCALE_EPS

    def compute_2bit_quantized_tensor(self, input_tensor: torch.Tensor):
        """Round to the nearest integer and clip to [-1, 1]."""
        return torch.clip(torch.round(input_tensor), min=-1, max=1)

    def compute_1bit_quantized_tensor(self, input_tensor: torch.Tensor):
        """Replace every element with its sign."""
        return torch.sign(input_tensor)

    def compute_quantized_tensor(self, input_tensor: torch.Tensor):
        """Quantize ``input_tensor`` with the quantizer selected for this layer."""
        if self.quantization_mode == QuantizationMode.two_bit:
            return self.compute_2bit_quantized_tensor(input_tensor)
        return self.compute_1bit_quantized_tensor(input_tensor)

    def forward(self, x):
        """Apply the quantized linear map to ``x``.

        Returns ``F.linear(factor * x, quantized_weight, bias)`` where
        ``factor`` is the adjustment factor of the current weight.
        """
        weight_adjustment_factor = self.compute_adjustment_factor(self.weight)
        adjusted_weight = self.weight / weight_adjustment_factor
        quantized_weight = self.compute_quantized_tensor(adjusted_weight)

        if self.training:
            # Straight-through estimator: the forward value is the quantized
            # weight, but the detached difference contributes no gradient, so
            # gradients flow to `adjusted_weight` as if no rounding happened.
            quantized_weight = (
                adjusted_weight + (quantized_weight - adjusted_weight).detach()
            )

        return F.linear(weight_adjustment_factor * x, quantized_weight, self.bias)

    def extra_repr(self):
        """Summarise the layer's configuration for ``print(model)``."""
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"bias={self.bias is not None}, "
            f"quantization_mode={self.quantization_mode.name}"
        )


In [7]:
import copy

def create_quantized_copy_of_model(
    input_model: nn.Module, quantization_mode: QuantizationMode
):
    """Return a deep copy of ``input_model`` with every ``nn.Linear`` replaced.

    Each ``nn.Linear`` in the copy is swapped for a freshly initialised
    ``BitNetLinearLayer`` with the same ``in_features``/``out_features``, the
    same presence of a bias, and the same device and dtype as the layer it
    replaces. The original weights are NOT carried over. ``input_model``
    itself is left untouched.

    Args:
        input_model: Model to copy.
        quantization_mode: Mode passed to every replacement layer.

    Returns:
        The converted copy. If ``input_model`` is itself an ``nn.Linear``, the
        replacement layer is returned directly.

    Raises:
        AssertionError: If an ``nn.Linear`` is still reachable in the copy
            after conversion.
    """
    model_copy = copy.deepcopy(input_model)
    modules_by_name = dict(model_copy.named_modules())

    for module_name, module in modules_by_name.items():
        if not isinstance(module, nn.Linear):
            continue

        replacement = BitNetLinearLayer(
            in_features=module.in_features,
            out_features=module.out_features,
            bias=module.bias is not None,
            quantization_mode=quantization_mode,
        ).to(device=module.weight.device, dtype=module.weight.dtype)

        if module_name == "":
            # The root module has no parent to re-attach to; the replacement
            # layer is the whole converted model.
            return replacement

        # "a.b.c" -> parent "a.b", attribute "c"; a top-level child "a" gives
        # parent "" which is the root module's key in named_modules().
        parent_name, _, attr_name = module_name.rpartition(".")
        setattr(modules_by_name[parent_name], attr_name, replacement)

    leftover = [
        name for name, module in model_copy.named_modules()
        if isinstance(module, nn.Linear)
    ]
    if leftover:
        # Raised explicitly (not via `assert`) so the check also runs under -O.
        raise AssertionError(f"nn.Linear modules were not replaced: {leftover}")
    return model_copy

In [None]:
from datasets import load_dataset
import lightning as L
from transformers.models.vit.configuration_vit import ViTConfig
from transformers.models.vit.modeling_vit import ViTModel, ViTForImageClassification


import torch
from torchvision import transforms
from torch.utils.data import DataLoader

# ViT hyper-parameters used to build the baseline model in this notebook.
# NOTE(review): image_size=28, num_channels=1 and num_labels=10 presumably
# match the "fashion_mnist" dataset loaded further down — confirm if the
# dataset is changed.
config = ViTConfig(
    hidden_size=128,
    num_hidden_layers=6,
    num_attention_heads=4,
    intermediate_size=256,
    hidden_act="gelu",
    image_size=28,
    patch_size=4,
    num_labels=10,
    num_channels=1,
)


class ViTImageClassifier(L.LightningModule):
    """Lightning module wrapping a ``ViTForImageClassification`` model.

    Each batch is unpacked as keyword arguments into the wrapped model and
    must contain a ``"labels"`` entry, which is compared against the arg-max
    of the logits to compute accuracy. Training metrics are logged as
    ``"tl"``/``"ta"`` and validation metrics as ``"vl"``/``"va"``. The
    optimizer is Adam with learning rate ``self.lr``.
    """

    def __init__(self, config: ViTConfig, lr=1e-3):
        super().__init__()
        self.model = ViTForImageClassification(config)
        self.config = config
        self.lr = lr

    def forward(self, batch):
        """Run the wrapped model with the batch entries as keyword arguments."""
        return self.model(**batch)

    def _loss_and_accuracy(self, batch):
        """Return ``(loss, accuracy)`` for one batch as tensors."""
        result = self(batch)
        predicted = result.logits.argmax(dim=1)
        hit_rate = (predicted == batch["labels"]).float().mean()
        return result.loss, hit_rate

    def _log_metrics(self, loss_key, accuracy_key, loss, accuracy):
        """Log loss and accuracy as Python floats, per step and per epoch."""
        self.log_dict(
            {loss_key: loss.item(), accuracy_key: accuracy.item()},
            prog_bar=True,
            on_step=True,
            on_epoch=True,
        )

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training loss for one batch."""
        loss, accuracy = self._loss_and_accuracy(batch)
        self._log_metrics("tl", "ta", loss, accuracy)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute, log and return the validation loss without gradients."""
        with torch.no_grad():
            loss, accuracy = self._loss_and_accuracy(batch)
        self._log_metrics("vl", "va", loss, accuracy)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over the wrapped model's parameters."""
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)


# Load the dataset that every model in this notebook is trained on.
dataset = load_dataset("fashion_mnist")

# Convert each image to a tensor, then normalise with mean 0.5 and std 0.5.
image_transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=0.5, std=0.5)]
)

# Produce the "pixel_values"/"labels" columns (the keyword names the batch is
# unpacked into by the classifier), drop the raw columns, and have the
# dataset return torch tensors.
processed_dataset = dataset.map(
    lambda example: {
        "pixel_values": image_transforms(example["image"]),
        "labels": example["label"],
    }
).remove_columns(["label", "image"])
processed_dataset.set_format("torch", columns=["pixel_values", "labels"])

# One loader per split, both with batches of 128 and default ordering.
train_dataloader, eval_dataloader = (
    DataLoader(processed_dataset[split_name], batch_size=128)
    for split_name in ("train", "test")
)

# Baseline model plus a 1-bit and a 2-bit converted copy (built in that order).
normal_model = ViTImageClassifier(config)
one_bit_quantized_model, two_bit_quantized_model = (
    create_quantized_copy_of_model(normal_model, quantization_mode=mode)
    for mode in (QuantizationMode.one_bit, QuantizationMode.two_bit)
)

from lightning.pytorch.loggers import WandbLogger



In [None]:
from lightning.pytorch.loggers import WandbLogger


def _fit_with_wandb(model, run_name):
    """Train ``model`` for 10 epochs, logging to the "BitNet" W&B project.

    Uses the notebook-level ``train_dataloader`` and ``eval_dataloader``.
    Returns the ``(logger, trainer)`` pair created for the run.
    """
    run_logger = WandbLogger(project="BitNet", name=run_name)
    run_trainer = L.Trainer(max_epochs=10, logger=run_logger)
    run_trainer.fit(
        model,
        train_dataloaders=train_dataloader,
        val_dataloaders=eval_dataloader,
    )
    return run_logger, run_trainer


# Baseline (unquantized) run.
normal_logger, normal_trainer = _fit_with_wandb(normal_model, "normal_f_mnist")

# 1-bit run; its learning rate is lowered to 1e-4 before training starts.
one_bit_quantized_model.lr = 1e-4
one_bit_logger, one_bit_trainer = _fit_with_wandb(
    one_bit_quantized_model, "one_bit_f_mnist"
)
