In [None]:
import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torch.nn.init
import torch.nn as nn
import torch.nn.functional as F
import os
import numpy as np

### Hyperparameters

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Fix the random seed so runs are repeatable; when training on the GPU the
# same seed is also applied to every CUDA device.
torch.manual_seed(777)
if device != "cpu":
    torch.cuda.manual_seed_all(777)

# Hyperparameters.
learning_rate = 1e-3
training_epochs = 40
batch_size = 128
quant_epoch = 20  # the training loop enables quantization from this epoch index on

### dataset and data loader

In [None]:
# Build the MNIST training split first, then the test split. Both are stored
# under the current directory, downloaded when missing, and yield images
# converted to tensors by ToTensor().
mnist_train, mnist_test = (
    dsets.MNIST(
        root="./",
        train=use_train_split,
        transform=transforms.ToTensor(),
        download=True,
    )
    for use_train_split in (True, False)
)

# Shuffled mini-batches over the training split; a trailing batch smaller
# than `batch_size` is dropped.
data_loader = torch.utils.data.DataLoader(
    mnist_train,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

### Straight-through estimator (STE)

In [None]:
class roundpass(torch.autograd.Function):
    """Element-wise rounding with a straight-through gradient.

    The forward pass rounds its argument; the backward pass hands the
    incoming gradient back unchanged, as if rounding were the identity.
    """

    @staticmethod
    def forward(ctx, tensor):
        # Nothing is stored on ctx: backward does not depend on the input.
        return tensor.round()

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: the gradient w.r.t. the input equals grad_output.
        return grad_output


# Rebind the name to the callable form so users write ``roundpass(t)``.
roundpass = roundpass.apply

### Quantization module

In [None]:
class Quantizer(nn.Module):
    """Symmetric uniform fake-quantizer with a learnable clipping range.

    The clipping half-range is ``alpha = softplus(alpha_baseline + alpha_delta)``.
    ``alpha_baseline`` is frozen and calibrated once, from the first usable
    tensor passed to ``forward``, so that ``alpha`` starts at three standard
    deviations of that tensor. ``alpha_delta`` is the trainable correction.

    Values are mapped to integer levels in ``[Qn, Qp]`` with ``roundpass``,
    clamped, and mapped back to real values.

    Args:
        bits: bit width of the quantization grid; must be at least 2.
        always_pos: when True the grid is shifted by ``alpha`` so it spans
            ``[0, 2 * alpha]``; otherwise it spans ``[-alpha, alpha]``.

    Raises:
        ValueError: if ``bits`` is smaller than 2 (the grid would have zero
            steps and the step size would be undefined).
    """

    def __init__(self, bits=8, always_pos=False):
        super(Quantizer, self).__init__()

        if bits < 2:
            raise ValueError(f"bits must be >= 2, got {bits}")

        # True until alpha_baseline has been calibrated from data.
        self.first = True

        self.alpha_baseline = nn.Parameter(
            torch.zeros(1, device=device), requires_grad=False)
        self.alpha_delta = nn.Parameter(
            torch.zeros(1, device=device), requires_grad=True)

        self.always_pos = always_pos

        self.Qp = 2**(bits-1) - 1
        self.Qn = -self.Qp
        self.num_steps = self.Qp - self.Qn

    @staticmethod
    def _inverse_softplus(y):
        """Return ``z`` such that ``softplus(z) == y`` for ``y > 0``.

        Uses ``y + log(1 - exp(-y))``, which equals ``log(exp(y) - 1)`` but
        neither overflows for large ``y`` nor loses precision for small ``y``.
        """
        return float(y + np.log(-np.expm1(-y)))

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # ``first`` is a plain attribute and is not stored in checkpoints.
        # Recover it from the loaded baseline: an all-zero baseline is the
        # un-calibrated initial value, anything else is treated as calibrated,
        # so a restored quantizer is not calibrated a second time.
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
        if prefix + "alpha_baseline" in state_dict:
            self.first = bool((self.alpha_baseline == 0).all())

    def get_alpha(self):
        """Return the current (strictly positive) clipping half-range."""
        return F.softplus(self.alpha_baseline + self.alpha_delta)

    def forward(self, x):
        """Fake-quantize ``x`` and return the de-quantized result."""
        if self.first:
            spread = x.detach().std().item() * 3
            # A constant tensor (std == 0) or a single-element tensor
            # (std == nan) carries no usable statistic; calibrating on it
            # would push the baseline to -inf/nan for good. Postpone the
            # calibration until a usable tensor shows up.
            if np.isfinite(spread) and spread > 0:
                with torch.no_grad():
                    self.alpha_baseline.fill_(self._inverse_softplus(spread))
                self.first = False

        alpha = self.get_alpha()

        # step_size_r is the reciprocal of step_size (real -> level scale).
        step_size_r = 0.5 * self.num_steps * torch.reciprocal(alpha)
        step_size = 2 * alpha / self.num_steps

        # Shift the grid to the non-negative range when requested.
        off = alpha if self.always_pos else 0

        levels = torch.clamp(roundpass((x - off) * step_size_r), self.Qn, self.Qp)
        return levels * step_size + off

### Quantization aware modules

In [None]:
class CustomConv2d(nn.Conv2d):
    """``nn.Conv2d`` with an optional quantized forward path.

    Constructor arguments are forwarded to ``nn.Conv2d`` unchanged. While
    ``is_quant`` is False (the default) the layer is a plain convolution; once
    it is True, the weight goes through ``q_w`` and the input through ``q_a``
    before the convolution.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.q_w = Quantizer()
        self.q_a = Quantizer(always_pos=True)
        # Quantization stays off until someone flips this flag.
        self.is_quant = False

    def forward(self, x):
        kernel, activations = self.weight, x
        if self.is_quant:
            kernel = self.q_w(kernel)
            activations = self.q_a(activations)

        return F.conv2d(
            activations,
            kernel,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )


class CustomLinear(nn.Linear):
    """``nn.Linear`` with an optional quantized forward path.

    Constructor arguments are forwarded to ``nn.Linear`` unchanged. While
    ``is_quant`` is False (the default) the layer is a plain linear map; once
    it is True, the weight goes through ``q_w`` and the input through ``q_a``
    before the matrix product.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.q_w = Quantizer()
        self.q_a = Quantizer(always_pos=True)
        # Quantization stays off until someone flips this flag.
        self.is_quant = False

    def forward(self, x):
        matrix, activations = self.weight, x
        if self.is_quant:
            matrix = self.q_w(matrix)
            activations = self.q_a(activations)

        return F.linear(activations, matrix, self.bias)

### neural network 

In [None]:
class CNN(nn.Module):
    """Small classifier: conv -> batch norm -> ReLU -> linear -> ReLU -> linear.

    The quantization-aware layers are kept both as direct attributes
    (``conv``, ``fc1``, ``fc2``) and as members of the ``layer1`` / ``layer2``
    pipelines that ``forward`` runs.
    """

    def __init__(self):
        super().__init__()

        # Feature extractor: one bias-free 3x3 convolution with 6 filters.
        self.conv = CustomConv2d(
            1, 6, kernel_size=3, stride=1, padding=0, bias=False
        )
        self.layer1 = nn.Sequential(self.conv, nn.BatchNorm2d(6), nn.ReLU())

        # Classifier head: 6 * 26 * 26 = 4056 flattened features (the conv
        # output for a 28x28 input) -> 30 hidden units -> 10 outputs.
        self.fc1 = CustomLinear(6 * 26 * 26, 30, bias=False)
        self.fc2 = CustomLinear(30, 10, bias=False)
        self.layer2 = nn.Sequential(self.fc1, nn.ReLU(), self.fc2)

    def forward(self, x):
        features = self.layer1(x)
        # Collapse everything except the batch dimension.
        flat = features.view(features.size(0), -1)
        return self.layer2(flat)

### custom function for evaluation

In [None]:
# Create the checkpoint directory used by eval_custom. exist_ok makes this
# safe to re-run and avoids the separate exists-then-create check.
os.makedirs("./weight", exist_ok=True)

def eval_custom(model_, num_imgs, dataset=None, save_dir="./weight"):
    """Evaluate ``model_`` on a test set, print accuracies and save a checkpoint.

    The whole test set is pushed through the model in one batch. Two numbers
    are printed: accuracy over all images and accuracy over the first
    ``num_imgs`` images. The weights of ``model_`` are then written to
    ``<save_dir>/model_<acc_all>_<acc_first>.pth``.

    Args:
        model_: the network to evaluate and save.
        num_imgs: size of the leading subset used for the second accuracy.
        dataset: object exposing ``data`` (uint8 images viewable as
            N x 1 x 28 x 28) and ``targets``. Defaults to the notebook's
            ``mnist_test``.
        save_dir: directory for the checkpoint; created when missing.

    Returns:
        None.
    """
    if dataset is None:
        dataset = mnist_test

    # Evaluate on whatever device the model lives on.
    first_param = next(model_.parameters(), None)
    target_device = (
        first_param.device if first_param is not None else torch.device("cpu")
    )

    # Evaluation mode makes BatchNorm use its running statistics instead of
    # test-batch statistics, and stops it from updating them with test data.
    # The previous mode is restored afterwards so training can continue.
    was_training = model_.training
    model_.eval()
    try:
        with torch.no_grad():
            # Raw pixels are 0..255; the training pipeline (ToTensor) feeds
            # the model values in 0..1, so apply the same scaling here.
            X_test = (
                dataset.data.view(-1, 1, 28, 28).float().div(255).to(target_device)
            )
            Y_test = dataset.targets.to(target_device)
            prediction = model_(X_test)
            correct_prediction = torch.argmax(prediction, 1) == Y_test
            correct_prediction_100 = (
                torch.argmax(prediction[:num_imgs], 1) == Y_test[:num_imgs]
            )
            accuracy = correct_prediction.float().mean()
            accuracy_100 = correct_prediction_100.float().mean()
            print("Accuracy_all:", accuracy.item())
            print(f"Accuracy_{num_imgs}:", accuracy_100.item())

            os.makedirs(save_dir, exist_ok=True)
            # Save the model that was evaluated (the argument), not whatever
            # happens to be bound to a global name.
            torch.save(
                model_.state_dict(),
                os.path.join(
                    save_dir,
                    f"model_{str(accuracy.item()):.7}_{str(accuracy_100.item()):.7}.pth",
                ),
            )
    finally:
        model_.train(was_training)

### Training loop

In [None]:
model = CNN().to(device)

# CrossEntropyLoss applies the softmax itself, so the model outputs raw logits.
criterion = torch.nn.CrossEntropyLoss().to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# The learning-rate factor goes linearly from 1.0 to 0.01 over
# `training_epochs` scheduler steps, so the scheduler must be stepped once
# per epoch (see the end of the epoch loop), not once per mini-batch.
scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=1.0, end_factor=1e-2, total_iters=training_epochs
)

total_batch = len(data_loader)
print("총 배치의 수 : {}".format(total_batch))

for epoch in range(training_epochs):
    # Make sure the model is in training mode at the start of every epoch,
    # regardless of what the evaluation helper did to it.
    model.train()
    avg_cost = 0.0

    # From `quant_epoch` onwards, run every conv/linear layer on its
    # quantized path.
    if epoch >= quant_epoch:
        for m in model.modules():
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                m.is_quant = True

    for X, Y in data_loader:  # X: mini-batch of images, Y: integer class labels
        # Images are already 1x28x28 tensors and labels are not one-hot
        # encoded, so no reshaping is needed.
        X = X.to(device)
        Y = Y.to(device)

        optimizer.zero_grad()
        hypothesis = model(X)
        cost = criterion(hypothesis, Y)
        cost.backward()
        optimizer.step()

        # .item() yields a plain float, so the running average does not keep
        # a reference to each batch's autograd history.
        avg_cost += cost.item() / total_batch

    scheduler.step()
    eval_custom(model, 100)

    print("[Epoch: {:>4}] cost = {:>.9}".format(epoch + 1, avg_cost))

총 배치의 수 : 468
Accuracy_all: 0.8931999802589417
Accuracy_100: 0.8899999856948853
[Epoch:    1] cost = 0.530423105
Accuracy_all: 0.9013999700546265
Accuracy_100: 0.8899999856948853
[Epoch:    2] cost = 0.432589144
Accuracy_all: 0.9099999666213989
Accuracy_100: 0.9099999666213989
[Epoch:    3] cost = 0.389850825
Accuracy_all: 0.9146999716758728
Accuracy_100: 0.9399999976158142
[Epoch:    4] cost = 0.358769208
Accuracy_all: 0.9185000061988831
Accuracy_100: 0.9300000071525574
[Epoch:    5] cost = 0.335046858
Accuracy_all: 0.9236999750137329
Accuracy_100: 0.9300000071525574
[Epoch:    6] cost = 0.31562832
Accuracy_all: 0.9258999824523926
Accuracy_100: 0.949999988079071
[Epoch:    7] cost = 0.299797207
Accuracy_all: 0.9282999634742737
Accuracy_100: 0.949999988079071
[Epoch:    8] cost = 0.285898119
Accuracy_all: 0.9299999475479126
Accuracy_100: 0.949999988079071
[Epoch:    9] cost = 0.273763508
Accuracy_all: 0.9315999746322632
Accuracy_100: 0.949999988079071
[Epoch:   10] cost = 0.2632038
Acc