# Import & Parameter

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
import os
import time
from enum import Enum
import math
import torch.ao.quantization as tq
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader, TensorDataset
import torch.nn as nn
import torch.nn.quantized as nnq
import pyverilator                     # pip install pyverilator

# 參數
num_classes = 100
input_shape = (32, 32, 3)
learning_rate = 0.001
weight_decay = 0.0001
batch_size = 256
num_epochs = 100
image_size = 72
patch_size = 6
num_patches = (image_size // patch_size) ** 2
projection_dim = 64
num_heads = 4
transformer_units = [projection_dim * 2, projection_dim]
transformer_layers = 8
mlp_head_units = [2048, 1024]

MODEL_PATH = "./weights/vit.pt"
QUANT_MODEL_PATH = "./weights/quant_vit.pt"

DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Is cuda avaliable: " , torch.cuda.is_available())

: 

In [2]:
def get_cifar_loaders(batch_size, root="data/cifar100", split_ratio=0.1, image_size=72):
    # train 資料增強
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(2),
        transforms.RandomAffine(0, translate=(0.2, 0.2)),
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408),  # CIFAR-100 平均值
                             (0.2675, 0.2565, 0.2761)), # CIFAR-100 標準差
    ])
    # val/test 無增強
    eval_transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408),
                             (0.2675, 0.2565, 0.2761)),
    ])
    trainset = datasets.CIFAR10(root=root, train=True,  download=True, transform=transform)
    testset  = datasets.CIFAR10(root=root, train=False, download=True, transform=eval_transform)
    # 切分出小比例當 val
    val_len   = int(split_ratio * len(trainset))
    train_len = len(trainset) - val_len
    trainset, valset = random_split(trainset, [train_len, val_len])
    return (
        DataLoader(trainset, batch_size=batch_size, shuffle=True,  num_workers=2),
        DataLoader(valset,   batch_size=batch_size, shuffle=False, num_workers=2),
        DataLoader(testset,  batch_size=batch_size, shuffle=False, num_workers=2),
    )

In [3]:
class MLP(nn.Module):
    def __init__(self, hidden_units, dropout_rate):
        super().__init__()
        layers = []
        for units in hidden_units:
            layers.append(nn.Linear(units[0], units[1]))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout_rate))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


In [None]:
def extract_mlp_params(pt_path,
                       w1_key="mlp.0.weight", b1_key="mlp.0.bias",
                       w2_key="mlp.3.weight", b2_key="mlp.3.bias"):
    """
    從 .pt 檔撈兩層 Linear 的權重/偏差，回傳 tuple:
    (w1, b1, w2, b2)，皆為 numpy array (float32)
    """
    sd = torch.load(pt_path, map_location="cpu")        # state_dict 或整個 nn.Module
    if isinstance(sd, torch.nn.Module):
        sd = sd.state_dict()

    w1 = sd[w1_key].cpu().numpy().astype(np.float32)    # shape [out1, in1]
    b1 = sd[b1_key].cpu().numpy().astype(np.float32)    # shape [out1]
    w2 = sd[w2_key].cpu().numpy().astype(np.float32)    # shape [out2, in2]
    b2 = sd[b2_key].cpu().numpy().astype(np.float32)    # shape [out2]

    return w1, b1, w2, b2

def quantize_w_int8(W):
    s = np.max(np.abs(W))
    scale = s / 127.0 if s != 0 else 1.0
    q = np.round(W / scale).astype(np.int8)
    return q, scale

def quantize_b_int32(b, in_scale, w_scale):
    scale = in_scale * w_scale
    q = np.round(b / scale).astype(np.int32)
    return q

def pack_int8_to_words(arr_int8):
    """
    arr_int8: 1D np.int8
    回傳 List[int32]，每 4 個 byte 拼 1 word, 低位在前 (little-endian)。
    """
    assert arr_int8.ndim == 1 and len(arr_int8) % 4 == 0
    words = []
    it = arr_int8.astype(np.uint8)        # 先轉成 0~255
    for i in range(0, len(it), 4):
        w =  int(it[i+0])        \
          | (int(it[i+1]) <<  8) \
          | (int(it[i+2]) << 16) \
          | (int(it[i+3]) << 24)
        words.append(w)
    return words


## 硬體MLP模擬

In [None]:
# --------------------------------- hardware_mlp.py ----------------------------
import numpy as np
import torch
import torch.nn as nn
import pyverilator                       # pip install pyverilator

# -------- 量化工具 ------------------------------------------------------------
def quantize_w_int8(w: np.ndarray):
    s = np.max(np.abs(w))
    scale = s / 127. if s != 0 else 1.
    return np.round(w / scale).astype(np.int8), scale

def quantize_b_int32(b: np.ndarray, in_scale: float, w_scale: float):
    scale = in_scale * w_scale
    return np.round(b / scale).astype(np.int32)

def pack_int8_to_words(vec: np.ndarray):
    """每 4 個 int8 → 1 個 little-endian int32"""
    assert vec.ndim == 1 and len(vec) % 4 == 0
    words = []
    u = vec.astype(np.uint8)
    for i in range(0, len(u), 4):
        words.append(int(u[i]) | int(u[i+1]) << 8 |
                     int(u[i+2]) << 16 | int(u[i+3]) << 24)
    return words

# -------- 權重/偏差抽取 --------------------------------------------------------
def extract_mlp_params(pt_path, keys):
    """keys = (w1, b1, w2, b2)"""
    sd = torch.load(pt_path, map_location="cpu")
    if isinstance(sd, torch.nn.Module):
        sd = sd.state_dict()

    w1 = sd[keys[0]].cpu().numpy().astype(np.float32)
    b1 = sd[keys[1]].cpu().numpy().astype(np.float32)
    w2 = sd[keys[2]].cpu().numpy().astype(np.float32)
    b2 = sd[keys[3]].cpu().numpy().astype(np.float32)
    return w1, b1, w2, b2

# -------- PyVerilator Wrapper  -------------------------------------------------
class HardwareMLP(nn.Module):
    """
    Two-layer int8 MLP implemented in Top.sv.
    * feature_dim 必須和 ViT projection_dim 相同 (預設 64)
    """
    def __init__(self,
                 top_rtl="Top.sv",
                 feature_dim=64,
                 input_scale=0.008,      # 根據校準結果填
                 output_scale=0.008,
                 scaling_factor=32):     # Top.sv PPU 右移 5bit (= 32)
        super().__init__()
        self.feature_dim   = feature_dim
        self.input_scale   = input_scale
        self.output_scale  = output_scale
        self.scaling_factor = scaling_factor

        self.sim = pyverilator.PyVerilator.build(top_rtl, trace_en=False)
        self._hw_reset()

    # ---------------- low-level ----------------
    def _tick(self, n=1):
        for _ in range(n):
            self.sim.clock.tick()

    def _hw_reset(self):
        self.sim.io.rst = 1; self._tick(); self.sim.io.rst = 0

    def _write_stream_to_sram(self, mode: int, stream):
        """依 tb.sv handshake：ready↑→送 data_in+i_en"""
        self.sim.io.mode  = mode
        self.sim.io.ready = 1
        self._tick(); self.sim.io.ready = 0
        for wd in stream:
            self.sim.io.data_in = wd
            self.sim.io.i_en    = 1
            self._tick()
        self.sim.io.i_en = 0

    # ---------------- Public API ----------------
    def load_from_pt(self, pt_path, pt_keys):
        """
        1. 從 .pt 擷取 Linear1/2 權重 2. 量化 3. 寫入 SRAM
        pt_keys = (w1, b1, w2, b2)
        """
        w1, b1, w2, b2 = extract_mlp_params(pt_path, pt_keys)

        w1_q, s_w1 = quantize_w_int8(w1)
        w2_q, s_w2 = quantize_w_int8(w2)
        b1_q = quantize_b_int32(b1, self.input_scale, s_w1)
        b2_q = quantize_b_int32(b2, s_w1,          s_w2)

        # 打包成 32-bit stream
        w1_stream = pack_int8_to_words(w1_q.flatten())
        w2_stream = pack_int8_to_words(w2_q.flatten())
        b1_stream = [int(v & 0xFFFFFFFF) for v in b1_q]
        b2_stream = [int(v & 0xFFFFFFFF) for v in b2_q]

        # 依 RTL mode 寫入：2=W1 3=B1 4=W2 5=B2
        self._hw_reset()
        self._write_stream_to_sram(2, w1_stream)
        self._write_stream_to_sram(3, b1_stream)
        self._write_stream_to_sram(4, w2_stream)
        self._write_stream_to_sram(5, b2_stream)
        self.sim.io.mode = 0

    # ---------------- forward ----------------
    def forward(self, x: torch.Tensor):
        """
        x : [batch, feature_dim] float32
        回傳同形狀 float32 tensor
        """
        if x.dim() == 1:
            x = x.unsqueeze(0)
        batch = x.shape[0]
        outs = []

        for i in range(batch):
            vec = x[i]
            # --- quantize ---
            xi8 = np.round(vec.numpy() / self.input_scale) \
                                                       .clip(-128, 127).astype(np.int8)
            # --- handshake: ready↑ (1clk) → 送 ifmap 16 words → wait done ---
            self.sim.io.mode           = 0
            self.sim.io.scaling_factor = self.scaling_factor
            self.sim.io.ready = 1; self._tick(); self.sim.io.ready = 0

            for j in range(0, self.feature_dim, 4):
                wd = (int(xi8[j  ]) & 0xFF)       | (int(xi8[j+1]) & 0xFF) << 8 | \
                     (int(xi8[j+2]) & 0xFF) << 16 | (int(xi8[j+3]) & 0xFF) << 24
                self.sim.io.data_in = wd
                self._tick()

            yo = []
            while len(yo) < self.feature_dim:
                self._tick()
                if int(self.sim.io.valid):
                    yo.append(self.sim.io.ofmap & 0xFF)
            y_f = torch.tensor(yo, dtype=torch.float32) * self.output_scale
            outs.append(y_f)
        return torch.stack(outs, 0)


In [4]:
class Patches(nn.Module):
    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def forward(self, images):
        # images: [B, C, H, W]
        B, C, H, W = images.shape
        ph, pw = self.patch_size, self.patch_size
        assert H % ph == 0 and W % pw == 0
        # 轉成 [B, num_patches, patch_dim]
        patches = images.unfold(2, ph, ph).unfold(3, pw, pw)  # [B, C, nph, npw, ph, pw]
        patches = patches.permute(0, 2, 3, 1, 4, 5)  # [B, nph, npw, C, ph, pw]
        patches = patches.reshape(B, -1, C * ph * pw)  # [B, num_patches, patch_dim]
        return patches

class PatchEncoder(nn.Module):
    def __init__(self, num_patches, projection_dim):
        super().__init__()
        self.proj = nn.Linear(patch_size * patch_size * 3, projection_dim)
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches, projection_dim))
        self.adder   = nnq.FloatFunctional()
        self.quant   = tq.QuantStub()
        self.dequant = tq.DeQuantStub()
        
    def forward(self, x):
        # 1) 如果是量化張量，反量化到 float
        if x.is_quantized:
            x = self.dequant(x)
        # 2) 在 float 做投影 + 位置加法
        x = self.proj(x) + self.pos_embed
        # 3) 再一次量化
        return self.quant(x)


In [5]:
class TransformerBlock(nn.Module):
    def __init__(self, projection_dim, num_heads, transformer_units, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(projection_dim)
        self.attn = nn.MultiheadAttention(projection_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(projection_dim)
        self.mlp = nn.Sequential(
            nn.Linear(projection_dim, transformer_units[0]),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(transformer_units[0], transformer_units[1]),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # x: [B, num_patches, projection_dim]
        x1 = self.norm1(x)
        attn_output, _ = self.attn(x1, x1, x1)
        x2 = attn_output + x
        x3 = self.norm2(x2)
        x3 = self.mlp(x3)
        return x3 + x2



In [6]:
class ViT(nn.Module):
    def __init__(self, image_size, patch_size, num_classes, projection_dim, num_patches, num_heads, transformer_units, transformer_layers, mlp_head_units):
        super().__init__()
        self.patches = Patches(patch_size)
        self.encoder = PatchEncoder(num_patches, projection_dim)
        self.transformer_layers = nn.ModuleList([
            TransformerBlock(projection_dim, num_heads, transformer_units)
            for _ in range(transformer_layers)
        ])
        self.norm = nn.LayerNorm(projection_dim)
        self.flatten = nn.Flatten()
        self.dropout = nn.Dropout(0.5)
        mlp_layers = []
        in_dim = num_patches * projection_dim
        for out_dim in mlp_head_units:
            mlp_layers.append(nn.Linear(in_dim, out_dim))
            mlp_layers.append(nn.GELU())
            mlp_layers.append(nn.Dropout(0.5))
            in_dim = out_dim
        self.mlp_head = nn.Sequential(*mlp_layers)
        self.classifier = nn.Linear(in_dim, num_classes)

    def forward(self, x):
        # x: [B, 3, H, W]
        x = self.patches(x)  # [B, num_patches, patch_dim]
        x = self.encoder(x)  # [B, num_patches, projection_dim]
        for block in self.transformer_layers:
            x = block(x)
        x = self.norm(x)
        x = self.flatten(x)
        x = self.dropout(x)
        x = self.mlp_head(x)
        logits = self.classifier(x)
        return logits


## Utility

In [7]:
def evaluate(model, loader, criterion, device=DEFAULT_DEVICE):
    model.eval()
    running_loss, total, correct = 0.0, 0, 0
    all_preds, all_labels = [], []
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images) if device == "cuda" else model(images.cpu())
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            preds = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (preds == labels).sum().item()
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    avg_loss = running_loss / len(loader)
    acc = correct / total
    cm  = confusion_matrix(all_labels, all_preds)
    return avg_loss, acc, cm
    
def plot_loss_accuracy(
    train_loss, val_loss,
    train_top1, val_top1,
    train_top5, val_top5,
    filename="loss_accuracy.png"
):
    # 建立 1×3 的子圖：Loss / Top-1 / Top-5
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 4))

    # 1) Loss
    ax1.plot(train_loss, label="Train", color="C0")
    ax1.plot(val_loss,   label="Val",   color="C1")
    ax1.set_title("Loss")
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Loss")
    ax1.legend()
    ax1.grid(True)

    # 2) Top-1 Accuracy
    ax2.plot(train_top1, label="Train Top-1", color="C0")
    ax2.plot(val_top1,   label="Val Top-1",   color="C1")
    ax2.set_title("Top-1 Accuracy")
    ax2.set_xlabel("Epoch")
    ax2.set_ylabel("Accuracy")
    ax2.legend()
    ax2.grid(True)

    # 3) Top-5 Accuracy
    ax3.plot(train_top5, label="Train Top-5", color="C0")
    ax3.plot(val_top5,   label="Val Top-5",   color="C1")
    ax3.set_title("Top-5 Accuracy")
    ax3.set_xlabel("Epoch")
    ax3.set_ylabel("Accuracy")
    ax3.legend()
    ax3.grid(True)

    # 儲存與顯示
    os.makedirs(os.path.dirname(filename) or ".", exist_ok=True)
    plt.tight_layout()
    plt.savefig(filename)
    plt.show()

def preprocess_filename(filename: str = MODEL_PATH, existed: str = "keep_both") -> str:
    if existed == "overwrite":
        pass
    elif existed == "keep_both":
        base, ext = os.path.splitext(filename)
        cnt = 1
        while os.path.exists(filename):
            filename = f"{base}-{cnt}{ext}"
            cnt += 1
    elif existed == "raise" and os.path.exists(filename):
        raise FileExistsError(f"{filename} already exists.")
    else:
        raise ValueError(f"Unknown value for 'existed': {existed}")
    return filename
    
def save_model(
    model, filename: str = MODEL_PATH, verbose: bool = True, existed: str = "keep_both"
) -> None:
    filename = preprocess_filename(filename, existed)

    os.makedirs(os.path.dirname(filename), exist_ok=True)
    torch.save(model.state_dict(), filename)
    if verbose:
        print(f"Model saved at {filename} ({os.path.getsize(filename) / 1e6} MB)")
    else:
        print(f"Model saved at {filename}")


## Main

In [None]:
device = torch.device(DEFAULT_DEVICE)
model = ViT(
    image_size=image_size,
    patch_size=patch_size,
    num_classes=num_classes,
    projection_dim=projection_dim,
    num_patches=num_patches,
    num_heads=num_heads,
    transformer_units=[projection_dim*2, projection_dim],
    transformer_layers=transformer_layers,
    mlp_head_units=mlp_head_units
).to(device)
trainloader, valloader, testloader = get_cifar_loaders(batch_size, image_size=image_size)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
train_loss, train_acc1 = [], []
train_acc5 = []
val_loss, val_acc1 = [], []
val_acc5 = []

t0 = time.time()
best_epoch_train_acc = 0.0
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    total, correct_1, correct_5, val_1, val_5 = 0, 0, 0, 0, 0

    loop = tqdm(trainloader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=True)
    for images, labels in loop:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, preds_1 = logits.topk(1, dim=1)  
        _, preds_5 = logits.topk(5, dim=1)  
        total += labels.size(0)
        correct_1 += (preds_1.squeeze(1) == labels).sum().item()
        correct_5 += (preds_5 == labels.view(-1,1)).any(dim=1).sum().item()
        # 動態顯示當前 loss 和 accuracy
        loop.set_postfix(loss=running_loss/(total/images.shape[0]), top1_acc=correct_1/total, top5_acc=correct_5/total)
        
    epoch_train_loss = running_loss / len(trainloader)
    epoch_train_acc = correct_1 / total
    train_loss.append(epoch_train_loss)
    train_acc1.append(epoch_train_acc)
    train_acc5.append(correct_5 / total)

    # 每次訓練完做一次validation
    with torch.no_grad():
        val_running_loss, val_correct, val_total = 0.0, 0, 0
        loop2 = tqdm(valloader, desc="Validation", leave=True)
        for images, labels in loop2:
            images, labels = images.to(device), labels.to(device)
            logits = model(images)
            loss = criterion(logits, labels)
            val_running_loss += loss.item()
            _, preds_1 = logits.topk(1, dim=1)  
            _, preds_5 = logits.topk(5, dim=1)  
            val_total += labels.size(0)
            val_1 += (preds_1.squeeze(1) == labels).sum().item()
            val_5 += (preds_5 == labels.view(-1,1)).any(dim=1).sum().item()

        epoch_val_loss = val_running_loss / len(valloader)
        epoch_val_acc = val_1 / val_total
        val_loss.append(epoch_val_loss)
        val_acc1.append(epoch_val_acc)
        val_acc5.append(val_5 / val_total)
    # save model if better
    if epoch_train_acc > best_epoch_train_acc:
        best_epoch_train_acc = epoch_train_acc
        save_model(model, MODEL_PATH, existed="overwrite")

    
# 測試
model.eval()
with torch.no_grad():
    val_running_loss, correct_1, correct_5, val_total = 0.0, 0, 0, 0
    for images, labels in valloader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        loss = criterion(logits, labels)
        val_running_loss += loss.item()
        _, preds_1 = logits.topk(1, dim=1)  
        _, preds_5 = logits.topk(5, dim=1)  
        val_total += labels.size(0)
        correct_1 += (preds_1.squeeze(1) == labels).sum().item()
        correct_5 += (preds_5 == labels.view(-1,1)).any(dim=1).sum().item()

    epoch_val_loss = val_running_loss / len(valloader)
    epoch_val_acc_1 = correct_1 / val_total
    epoch_val_acc_5 = correct_5 / val_total

# 5. 繪製並儲存訓練曲線
os.makedirs("figure", exist_ok=True)

plot_loss_accuracy(
    train_loss=train_loss, val_loss=val_loss,
    train_top1=train_acc1, val_top1=val_acc1,
    train_top5=train_acc1, val_top5=val_acc5,
    filename="figure/vit_loss_accuracy.png"
);
# 6. 測試並顯示最終結果
#test_loss, test_accuracy, _ = evaluate(model, testloader, criterion)
print(f"Test loss={epoch_val_loss:.4f}, top1 accuracy={epoch_val_acc_1:.4f}, top1 accuracy={epoch_val_acc_5:.4f}")
print(f"Total training time: {time.time() - t0:.1f}s")


## ViT + 硬體 MLP

In [None]:
# -------------------------------------- vit_hw.py -----------------------------
import torch
import torch.nn as nn
from vit_original import ViT, TransformerBlock   # 假設兩個 class 在原 ViT.py

# ---- 修改 TransformerBlock：加可插拔 mlp ----------------
class TransformerBlockHW(TransformerBlock):
    def __init__(self, projection_dim, num_heads, transformer_units,
                 dropout=0.1, hw_mlp=None):
        super().__init__(projection_dim, num_heads, transformer_units, dropout)
        if hw_mlp is not None:
            self.mlp = hw_mlp  # 把 nn.Sequential 換掉

# ---- ViT 包裝：在 __init__ 傳進同一顆 hw_mlp ----
class ViT_Hardware(ViT):
    def __init__(self, hw_mlp=None, **kwargs):
        super().__init__(**kwargs)
        if hw_mlp is not None:
            # 用新的 block 取代
            self.transformer_layers = nn.ModuleList([
                TransformerBlockHW(kwargs['projection_dim'],
                                   kwargs['num_heads'],
                                   kwargs['transformer_units'],
                                   hw_mlp=hw_mlp)
                for _ in range(kwargs['transformer_layers'])
            ])


## 硬體測試

In [None]:
# --------------------------------- train_and_eval.py --------------------------
import torch, torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from hardware_mlp import HardwareMLP
from vit_hw import ViT_Hardware

# ---- Dummy dataset -----------------------------------------------------------
def make_dummy_loader(batch=4, n_batch=20):
    x = torch.rand(batch * n_batch, 3, 224, 224)
    y = torch.randint(0, 10, (batch * n_batch,))
    ds = TensorDataset(x, y)
    return DataLoader(ds, batch_size=batch)

# ---- 1. 建 ViT (訓練階段先用軟體 MLP) ---------------------------------------
model = ViT_Hardware(           # hw_mlp=None → 全軟體
    hw_mlp=None,
    image_size=224, patch_size=16, num_classes=10,
    projection_dim=64, num_patches=(224//16)**2,
    num_heads=4, transformer_units=[128, 64],
    transformer_layers=1,       # demo 用 1 層
    mlp_head_units=[128]
).train()

optimizer = torch.optim.Adam(model.parameters(), 1e-3)
criterion = nn.CrossEntropyLoss()
loader = make_dummy_loader()

# ---- 2. 簡易 train 1 epoch ----------------------------------------------------
for xb, yb in loader:
    optimizer.zero_grad()
    logits = model(xb)
    loss = criterion(logits, yb)
    loss.backward(); optimizer.step()
print("train done, loss =", loss.item())

torch.save(model.state_dict(), "latest.pt")

# ---- 3. 構建硬體 MLP + 灌入剛存的權重 ----------------------------------------
hw_mlp = HardwareMLP(feature_dim=64,
                     input_scale=0.008, output_scale=0.008,
                     scaling_factor=32)
hw_mlp.load_from_pt(
    "latest.pt",
    pt_keys=("transformer_layers.0.mlp.0.weight",
             "transformer_layers.0.mlp.0.bias",
             "transformer_layers.0.mlp.3.weight",
             "transformer_layers.0.mlp.3.bias")
)

# ---- 4. 重新建立 “部署版” ViT：MLP → 硬體 -----------------------------------
vit_deploy = ViT_Hardware(
    hw_mlp=hw_mlp,
    image_size=224, patch_size=16, num_classes=10,
    projection_dim=64, num_patches=(224//16)**2,
    num_heads=4, transformer_units=[128, 64],
    transformer_layers=1, mlp_head_units=[128]
).eval()

# ---- 5. 驗證硬體輸出 vs 軟體輸出 --------------------------------------------
x_test = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    y_hw = vit_deploy(x_test)             # MLP 經 Top.sv
    # 為對比，build 另一個軟體-MLP 版本
    model.eval()
    y_sw = model(x_test)                  # 純 PyTorch

diff = (y_hw - y_sw).abs().max().item()
print(f"Max |HW-SW| = {diff:.6f}")


# QConfig

In [12]:
class PowerOfTwoObserver(tq.MinMaxObserver):
    """
    Observer module for power-of-two quantization (dyadic quantization with b = 1).
    """

    
    def scale_approximate(self, scale: float, max_shift_amount=8) -> float:
        #########Implement your code here##########
        if scale <= 0:  # Handle non-positive scale values
            return 2 ** -max_shift_amount  # Return a safe default value

        # Calculate the power of 2 closest to the scale
        exponent = math.log2(scale)

        # Clamp the exponent to the allowed range
        exponent = max(-max_shift_amount, min(exponent, max_shift_amount))

        # Return the power of 2
        return (2.0 ** exponent)
        ##########################################

    def calculate_qparams(self):
        """Calculates the quantization parameters with scale as power of two."""
        min_val, max_val = self.min_val.item(), self.max_val.item()

        """ Calculate zero_point as in the base class """
        #########Implement your code here##########
        max_num = 2**7 - 1
        max_range = max(abs(min_val), abs(max_val))
        scale = max_range / max_num
        
        zero_point = 0 if self.dtype == torch.qint8 else max_num
        
        ##########################################
        scale = self.scale_approximate(scale)
        scale = torch.tensor(scale, dtype=torch.float32)
        zero_point = torch.tensor(zero_point, dtype=torch.int64)
        return scale, zero_point

class CustomQConfig(Enum):
    POWER2 = tq.QConfig(
        activation=PowerOfTwoObserver.with_args(
            dtype=torch.quint8, qscheme=torch.per_tensor_symmetric
        ),
        weight=PowerOfTwoObserver.with_args(
            dtype=torch.qint8, qscheme=torch.per_tensor_symmetric
        ),
    )
    DEFAULT = None

In [None]:
'''checkpoint = torch.load(MODEL_PATH, map_location='cpu')  # [2]

# 如果你存的是 state_dict（dict of name→Tensor）
if isinstance(checkpoint, dict):
    print("Keys and tensor shapes in state_dict:")
    total_params = 0
    for name, tensor in checkpoint.items():
        shape = tuple(tensor.shape)
        numel = tensor.numel()
        total_params += numel
        print(f"  {name}: shape={shape}, params={numel}")
    print(f"Total parameters: {total_params}")
else:
    # 如果你存的是整個模型物件
    print("Checkpoint is a model instance:")
    print(checkpoint)
with open('model_parameters.txt', 'w') as f:
    for name, tensor in checkpoint.items():
        f.write(f"{name}: {tuple(tensor.shape)}\n")'''

In [18]:
# reset_seed(0)
DEVICE = "cpu"
torch.backends.quantized.engine = 'fbgemm'
# 1. 先從 DataLoader 拿到一筆範例，讀出通道數與圖像大小
in_channels, in_size = trainloader.dataset[0][0].shape[:2]

# 2. 建立與訓練時完全相同的 ViT 架構，並切到 CPU、eval 模式
model_q = ViT(
    image_size=image_size,
    patch_size=patch_size,
    num_classes=num_classes,
    projection_dim=projection_dim,
    num_patches=num_patches,
    num_heads=num_heads,
    transformer_units=transformer_units,
    transformer_layers=transformer_layers,
    mlp_head_units=mlp_head_units
).eval().cpu()

# 3. 載入你先前儲存的 state_dict
model_q.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
# --- 2. 包裝模型並指定 QConfig ---
# ViT 是非常規 CNN，無 fuse 需求，直接包 QuantWrapper
model_q = tq.QuantWrapper(model_q)  
model_q.qconfig = CustomQConfig.POWER2.value
# 放回CPU
model_q.cpu()
trainloader_cpu = DataLoader(trainloader.dataset, batch_size=batch_size, shuffle=False)

# --- 3. 準備量化 (插入 observer) ---
tq.prepare(model_q, inplace=True)

# --- 4. 校準 (用 trainloader 跑 forward) ---
model_q.eval()
with torch.no_grad():
    #for i, (images, _) in enumerate(trainloader):
    #    if i >= 10: break            # 取前 10 批做校準
    #    model_q(images.to(DEVICE))
    loop = tqdm(trainloader, desc="Calibrating", leave=True)
    for data, _ in loop:
        data = data.to(DEVICE)
        model_q(data.cpu())
# --- 5. 轉換為量化模型 ---
tq.convert(model_q, inplace=True)
'''
for name, module in model_q.named_modules():
    if isinstance(module, nnq.Linear):
        print(f"量化成功層: {name}")
    elif isinstance(module, nn.Linear):
        print(f"未量化層: {name}")
'''
# --- 7. 儲存量化模型 ---
os.makedirs("weights", exist_ok=True)
torch.save(model_q.state_dict(), QUANT_MODEL_PATH)
print("Quantized model saved to " + QUANT_MODEL_PATH)

# --- 6. 評估量化後模型效能 ---
#quant_loss, quant_acc, _ = evaluate(model_q, testloader, criterion, DEVICE)
model_q.eval()
with torch.no_grad():
    val_running_loss, correct_1, correct_5, val_total = 0.0, 0, 0, 0
    for images, labels in valloader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        logits = model_q(images.cpu())
        loss = criterion(logits, labels.cpu())
        val_running_loss += loss.item()
        _, preds_1 = logits.topk(1, dim=1)  
        _, preds_5 = logits.topk(5, dim=1)  
        val_total += labels.size(0)
        correct_1 += (preds_1.squeeze(1) == labels).sum().item()
        correct_5 += (preds_5 == labels.view(-1,1)).any(dim=1).sum().item()

    epoch_val_loss = val_running_loss / len(valloader)
    epoch_val_acc_1 = correct_1 / val_total
    epoch_val_acc_5 = correct_5 / val_total

print(f"Quantized top1 accurcay: {epoch_val_acc_1:.4f} Quantized top5 accurcay: {epoch_val_acc_5:.4f}")



#original_size = os.path.getsize('./weights/cifar10/vgg.pt') / 1e6
#quantized_size = os.path.getsize(QUANT_MODEL_PATH) / 1e6
#print(f"Original model size: {original_size:.2f} MB")
#print(f"Quantized model size: {quantized_size:.2f} MB")
#print(f"Size reduction: {(1 - quantized_size/original_size) * 100:.2f}%")

For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  tq.prepare(model_q, inplace=True)
Calibrating: 100%|███████████████████████████████████████████████████████████████████| 176/176 [01:01<00:00,  2.87it/s]
For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.pre

Quantized model saved to ./weights/quant_vit.pt


NotImplementedError: Could not run 'quantized::linear' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'quantized::linear' is only available for these backends: [Meta, QuantizedCPU, QuantizedCUDA, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradMPS, AutogradXPU, AutogradHPU, AutogradLazy, AutogradMTIA, AutogradMAIA, AutogradMeta, Tracer, AutocastCPU, AutocastMTIA, AutocastMAIA, AutocastXPU, AutocastMPS, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].

Meta: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\core\MetaFallbackKernel.cpp:23 [backend fallback]
QuantizedCPU: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\native\quantized\cpu\qlinear.cpp:1436 [kernel]
QuantizedCUDA: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\native\quantized\cudnn\Linear.cpp:359 [kernel]
BackendSelect: fallthrough registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\core\BackendSelectFallbackKernel.cpp:3 [backend fallback]
Python: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\core\PythonFallbackKernel.cpp:194 [backend fallback]
FuncTorchDynamicLayerBackMode: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\functorch\DynamicLayer.cpp:479 [backend fallback]
Functionalize: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\FunctionalizeFallbackKernel.cpp:375 [backend fallback]
Named: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\core\NamedRegistrations.cpp:7 [backend fallback]
Conjugate: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\ConjugateFallback.cpp:17 [backend fallback]
Negative: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\native\NegateFallback.cpp:18 [backend fallback]
ZeroTensor: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\ZeroTensorFallback.cpp:86 [backend fallback]
ADInplaceOrView: fallthrough registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\core\VariableFallbackKernel.cpp:104 [backend fallback]
AutogradOther: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\core\VariableFallbackKernel.cpp:63 [backend fallback]
AutogradCPU: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\core\VariableFallbackKernel.cpp:67 [backend fallback]
AutogradCUDA: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\core\VariableFallbackKernel.cpp:75 [backend fallback]
AutogradXLA: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\core\VariableFallbackKernel.cpp:87 [backend fallback]
AutogradMPS: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\core\VariableFallbackKernel.cpp:95 [backend fallback]
AutogradXPU: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\core\VariableFallbackKernel.cpp:71 [backend fallback]
AutogradHPU: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\core\VariableFallbackKernel.cpp:108 [backend fallback]
AutogradLazy: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\core\VariableFallbackKernel.cpp:91 [backend fallback]
AutogradMTIA: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\core\VariableFallbackKernel.cpp:79 [backend fallback]
AutogradMAIA: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\core\VariableFallbackKernel.cpp:83 [backend fallback]
AutogradMeta: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\core\VariableFallbackKernel.cpp:99 [backend fallback]
Tracer: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\torch\csrc\autograd\TraceTypeManual.cpp:294 [backend fallback]
AutocastCPU: fallthrough registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\autocast_mode.cpp:322 [backend fallback]
AutocastMTIA: fallthrough registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\autocast_mode.cpp:466 [backend fallback]
AutocastMAIA: fallthrough registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\autocast_mode.cpp:504 [backend fallback]
AutocastXPU: fallthrough registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\autocast_mode.cpp:542 [backend fallback]
AutocastMPS: fallthrough registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\autocast_mode.cpp:209 [backend fallback]
AutocastCUDA: fallthrough registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\autocast_mode.cpp:165 [backend fallback]
FuncTorchBatched: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\functorch\LegacyBatchingRegistrations.cpp:731 [backend fallback]
BatchedNestedTensor: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\functorch\LegacyBatchingRegistrations.cpp:758 [backend fallback]
FuncTorchVmapMode: fallthrough registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\functorch\VmapModeRegistrations.cpp:27 [backend fallback]
Batched: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\LegacyBatchingRegistrations.cpp:1075 [backend fallback]
VmapMode: fallthrough registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\VmapModeRegistrations.cpp:33 [backend fallback]
FuncTorchGradWrapper: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\functorch\TensorWrapper.cpp:210 [backend fallback]
PythonTLSSnapshot: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\core\PythonFallbackKernel.cpp:202 [backend fallback]
FuncTorchDynamicLayerFrontMode: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\functorch\DynamicLayer.cpp:475 [backend fallback]
PreDispatch: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\core\PythonFallbackKernel.cpp:206 [backend fallback]
PythonDispatcher: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\core\PythonFallbackKernel.cpp:198 [backend fallback]
