In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Lenet(nn.Module):
    def __init__(self):
        super().__init__()
        self.c1 = nn.Conv2d(1, 6, kernel_size=5)
        self.s2 = nn.MaxPool2d(kernel_size=2)
        self.c3 = nn.Conv2d(6, 16, kernel_size=5)
        self.s4 = nn.MaxPool2d(kernel_size=2)
        self.f5 = nn.Linear(16 * 5 * 5, 120)
        self.f6 = nn.Linear(120, 84)
        self.f7 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.c1(x)
        x = F.relu(x)
        x = self.s2(x)
        x = self.c3(x)
        x = F.relu(x)
        x = self.s4(x)
        x = torch.flatten(x, start_dim=1)
        x = self.f5(x)
        x = F.relu(x)
        x = self.f6(x)
        x = F.relu(x)
        x = self.f7(x)

        return x


lenet = Lenet()

In [2]:
import os
import torchvision.transforms as T
import torchvision.datasets as datasets
from torch.utils.data import random_split, DataLoader

NUM_WORKERS = min(4, os.cpu_count())
BATCH_SIZE = 64

transform = T.Compose(
    [
        T.ToTensor(),
        T.Normalize((0.1307,), (0.3081,)),
        T.Pad(2),
    ]
)
train_set = datasets.MNIST("tmp/data", train=True, download=True, transform=transform)
test_set = datasets.MNIST("tmp/data", train=False, download=True, transform=transform)

train_size = int(0.8 * len(train_set))
val_size = len(train_set) - train_size

train_set, val_set = random_split(train_set, [train_size, val_size])

train_loader = DataLoader(train_set, BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_set, BATCH_SIZE, num_workers=NUM_WORKERS)
exp_loader = DataLoader(test_set, BATCH_SIZE, num_workers=NUM_WORKERS)

In [3]:
def calculate_accuracy(model: nn.Module, loader: DataLoader) -> float:
    model.eval()
    total = correct = 0

    with torch.no_grad():
        for input, target in loader:
            output = model(input)
            pred = torch.argmax(output, dim=1)
            total += target.size(0)
            correct += (pred == target).sum().item()

    return correct / total

In [4]:
import torch.optim as optim

EPOCHS = 3
LEARNING_RATE = 0.01

optimizer = optim.Adam(lenet.parameters(), lr=LEARNING_RATE)

lenet.train()
for epoch in range(EPOCHS):
    for batch_idx, (input, target) in enumerate(train_loader):
        optimizer.zero_grad()

        foutput = lenet(input)
        loss = F.cross_entropy(foutput, target)

        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print(
                f"[TRAINING] epoch={epoch + 1} batch={batch_idx} loss={loss.item():.2f}"
            )


print(f"[TESTING] accuracy={(100 * calculate_accuracy(lenet, exp_loader)):.2f}%")

[TRAINING] epoch=1 batch=0 loss=2.31
[TRAINING] epoch=1 batch=100 loss=0.31
[TRAINING] epoch=1 batch=200 loss=0.22
[TRAINING] epoch=1 batch=300 loss=0.13
[TRAINING] epoch=1 batch=400 loss=0.18
[TRAINING] epoch=1 batch=500 loss=0.32
[TRAINING] epoch=1 batch=600 loss=0.06
[TRAINING] epoch=1 batch=700 loss=0.07
[TRAINING] epoch=2 batch=0 loss=0.03
[TRAINING] epoch=2 batch=100 loss=0.11
[TRAINING] epoch=2 batch=200 loss=0.06
[TRAINING] epoch=2 batch=300 loss=0.19
[TRAINING] epoch=2 batch=400 loss=0.13
[TRAINING] epoch=2 batch=500 loss=0.03
[TRAINING] epoch=2 batch=600 loss=0.26
[TRAINING] epoch=2 batch=700 loss=0.32
[TRAINING] epoch=3 batch=0 loss=0.17
[TRAINING] epoch=3 batch=100 loss=0.07
[TRAINING] epoch=3 batch=200 loss=0.12
[TRAINING] epoch=3 batch=300 loss=0.14
[TRAINING] epoch=3 batch=400 loss=0.02
[TRAINING] epoch=3 batch=500 loss=0.12
[TRAINING] epoch=3 batch=600 loss=0.14
[TRAINING] epoch=3 batch=700 loss=0.14
[TESTING] accuracy=96.63%


In [5]:
LENET_PATH = "lenet.pt"

torch.save(lenet.state_dict(), LENET_PATH)

In [6]:
from torch.quantization import QuantStub, DeQuantStub


class QLenet(Lenet):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.quant(x)
        x = super().forward(x)
        x = self.dequant(x)

        return x


qlenet = QLenet()

qlenet.load_state_dict(torch.load(LENET_PATH, weights_only=True))

<All keys matched successfully>

In [7]:
import torch.quantization as quantization

BACKEND = "fbgemm"

qlenet.qconfig = quantization.get_default_qconfig(BACKEND)
quantization.prepare(qlenet, inplace=True)

qlenet.eval()
with torch.no_grad():
    for batch_idx, (input, _) in enumerate(val_loader):
        qlenet(input)
        print(f"[CALIBRATING] batch={batch_idx}")

quantization.convert(qlenet, inplace=True)

quant_accuracy = 100 * calculate_accuracy(qlenet, exp_loader)
print(f"[TESTING] accuracy={quant_accuracy:.2f}%")



[CALIBRATING] batch=0
[CALIBRATING] batch=1
[CALIBRATING] batch=2
[CALIBRATING] batch=3
[CALIBRATING] batch=4
[CALIBRATING] batch=5
[CALIBRATING] batch=6
[CALIBRATING] batch=7
[CALIBRATING] batch=8
[CALIBRATING] batch=9
[CALIBRATING] batch=10
[CALIBRATING] batch=11
[CALIBRATING] batch=12
[CALIBRATING] batch=13
[CALIBRATING] batch=14
[CALIBRATING] batch=15
[CALIBRATING] batch=16
[CALIBRATING] batch=17
[CALIBRATING] batch=18
[CALIBRATING] batch=19
[CALIBRATING] batch=20
[CALIBRATING] batch=21
[CALIBRATING] batch=22
[CALIBRATING] batch=23
[CALIBRATING] batch=24
[CALIBRATING] batch=25
[CALIBRATING] batch=26
[CALIBRATING] batch=27
[CALIBRATING] batch=28
[CALIBRATING] batch=29
[CALIBRATING] batch=30
[CALIBRATING] batch=31
[CALIBRATING] batch=32
[CALIBRATING] batch=33
[CALIBRATING] batch=34
[CALIBRATING] batch=35
[CALIBRATING] batch=36
[CALIBRATING] batch=37
[CALIBRATING] batch=38
[CALIBRATING] batch=39
[CALIBRATING] batch=40
[CALIBRATING] batch=41
[CALIBRATING] batch=42
[CALIBRATING] batch=4

In [8]:
QLENET_PATH = "qlenet.pt"

torch.save(qlenet.state_dict(), QLENET_PATH)

In [9]:
class FLenet(QLenet):
    pass


flenet = FLenet()

flenet.qconfig = quantization.get_default_qconfig(BACKEND)
quantization.prepare(flenet, inplace=True)
quantization.convert(flenet, inplace=True)

flenet.load_state_dict(torch.load(QLENET_PATH, weights_only=True))

  device=storage.device,


<All keys matched successfully>

In [19]:
from typing import Dict, Any, Tuple
import random


def get_bit(byte: int, pos: int) -> int:
    return (byte >> pos) & 1


def flip_bit(byte: int, pos: int) -> int:
    flipped = byte ^ (1 << pos)

    return max(-128, min(127, flipped))


def randlayer(sd: Dict[str, Any]) -> torch.Tensor:
    layers = [
        k
        for k in sd.keys()
        if k.endswith((".weight", "._packed_params._packed_params"))
    ]

    return random.choice(layers)

def get_tensor_shape(param):
    if isinstance(param, torch.Tensor):
        return param.shape
    elif isinstance(param, tuple) and len(param) > 0:
        # Handle packed parameters
        return param[0].shape
    raise ValueError(f"Unsupported parameter type: {type(param)}")

def randidx(param) -> Tuple[int, ...]:
    shape = get_tensor_shape(param)
    return tuple(random.randint(0, s - 1) for s in shape)

def get_tensor_at_idx(param, idx):
    if isinstance(param, torch.Tensor):
        return param[idx]
    elif isinstance(param, tuple) and len(param) > 0:
        # Handle packed parameters
        return param[0][idx]
    raise ValueError(f"Unsupported parameter type: {type(param)}")

def set_tensor_at_idx(param, idx, value):
    if isinstance(param, torch.Tensor):
        param[idx] = value
    elif isinstance(param, tuple) and len(param) > 0:
        # Handle packed parameters
        param[0][idx] = value
    else:
        raise ValueError(f"Unsupported parameter type: {type(param)}")


def randbitpos() -> int:
    return random.randint(0, 8)


def dequant_int8(val: int, scale: float, zero_point: int) -> float:
    return float(val - zero_point) * scale

In [11]:
exp_loader = DataLoader(test_set, batch_size=1, num_workers=NUM_WORKERS)

In [None]:
import copy
from torch import quantize_per_tensor
import pandas as pd

results = []
orig_sd = flenet.state_dict()

for batch_idx, (input, target) in enumerate(exp_loader):
    sd = copy.deepcopy(orig_sd)
    output = torch.argmax(flenet(input), dim=1)
    layer = randlayer(sd)
    param = sd[layer]
    idx = randidx(param)
    tnsr = get_tensor_at_idx(param, idx)
    intval = tnsr.int_repr().item()
    pos = randbitpos()
    orig_bit = get_bit(intval, pos)
    fintval = flip_bit(intval, pos)
    fbit = get_bit(fintval, pos)
    ffval = dequant_int8(fintval, tnsr.q_scale(), tnsr.q_zero_point())
    fftnsr = torch.tensor(ffval, dtype=torch.float32)
    qfftnsr = quantize_per_tensor(
        fftnsr, tnsr.q_scale(), tnsr.q_zero_point(), torch.qint8
    )
    set_tensor_at_idx(param, idx, qfftnsr)
    flenet.load_state_dict(sd)
    foutput = torch.argmax(flenet(input), dim=1)

    results.append(
        {
            "image_id": batch_idx,
            "layer": layer,
            "bit_position": pos,
            "original_bit": orig_bit,
            "flipped_bit": fbit,
            "original_value": intval,
            "flipped_value": fintval,
            "original_output": output.item(),
            "flipped_output": foutput.item(),
            "target": target.item(),
        }
    )

    flenet.load_state_dict(orig_sd)

df = pd.DataFrame(results)
df.to_csv("results.csv", index=False)

In [24]:
df[df["bit_position"] == 8]


# errors_df = df[df["original_output"] != df["flipped_output"]]

# errors_df

Unnamed: 0,image_id,layer,bit_position,original_bit,flipped_bit,original_value,flipped_value,original_output,flipped_output,target
0,0,c3.weight,8,1,1,-5,-128,7,7,7
9,9,f7._packed_params._packed_params,8,1,1,-4,-128,9,9,9
17,17,f5._packed_params._packed_params,8,1,1,-22,-128,7,7,7
22,22,f7._packed_params._packed_params,8,1,1,-1,-128,6,6,6
40,40,f7._packed_params._packed_params,8,1,1,-35,-128,1,1,1
...,...,...,...,...,...,...,...,...,...,...
9937,9937,f7._packed_params._packed_params,8,1,1,-24,-128,2,2,2
9955,9955,c1.weight,8,0,0,17,127,1,1,1
9962,9962,f6._packed_params._packed_params,8,1,1,-19,-128,0,0,0
9970,9970,f6._packed_params._packed_params,8,0,0,41,127,5,5,5
