In [None]:
# Train a single-layer neural network with binarized weights on MNIST
import copy
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.ao.quantization as quant
from torch.autograd.function  import Function, InplaceFunction
import torch.nn.init as init
from tqdm.notebook import tqdm
import numpy as np

# Flat history of test accuracies; the experiment loop below appends to it.
acc_list = []

# Hyperparameters
batch_size = 64
learning_rate = 0.1
epochs = 1

# Preprocessing pipeline: convert images to tensors, then normalize with
# mean 0.1307 and standard deviation 0.3081.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# MNIST train/test splits, downloaded into ./data when not already present.
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    transform=transform,
    download=True,
)

# Mini-batch loaders: training batches are shuffled, test batches are not.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
)


class Binarize(InplaceFunction):
    """Autograd function that binarizes a tensor and passes gradients straight through.

    Forward maps every element to ``-scale`` or ``+scale`` (``scale`` is 1
    unless ``allow_scale`` is set); backward returns the incoming gradient
    unchanged (straight-through estimator).
    """

    @staticmethod
    def forward(ctx, input, quant_mode='det', allow_scale=False, inplace=False):
        """Binarize ``input``.

        Args:
            ctx: Autograd context supplied by ``Function.apply``.
            input: Tensor to binarize.
            quant_mode: ``'det'`` applies ``sign`` (so exact zeros stay zero);
                any other value selects stochastic rounding, where uniform
                noise in [-0.5, 0.5) is added to ``(x / scale + 1) / 2``
                before clamping to [0, 1] and rounding.
            allow_scale: If true, scale the result by ``input.abs().max()``
                instead of 1.
            inplace: If true, the result is written back into ``input`` and
                ``input`` itself is returned.

        Returns:
            The binarized tensor (``input`` itself when ``inplace`` is true).
        """
        ctx.inplace = inplace
        if inplace:
            ctx.mark_dirty(input)

        scale = input.abs().max() if allow_scale else 1

        if quant_mode == 'det':
            result = input.div(scale).sign().mul(scale)
        else:
            # rand_like keeps the noise on the same device/dtype as the input,
            # so the stochastic mode also works for CUDA tensors.
            noise = torch.rand_like(input).add_(-0.5)
            result = (
                input.div(scale)
                .add_(1)
                .div_(2)
                .add_(noise)
                .clamp_(0, 1)
                .round()
                .mul_(2)
                .add_(-1)
                .mul(scale)
            )

        if inplace:
            # A tensor marked dirty must actually be modified and returned.
            input.copy_(result)
            return input
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """Straight-through estimator: pass the gradient to ``input`` unchanged.

        The three ``None`` entries correspond to the non-tensor arguments
        ``quant_mode``, ``allow_scale`` and ``inplace``.
        """
        grad_input = grad_output
        return grad_input, None, None, None

def binarized(input, quant_mode='det'):
    """Return ``input`` binarized by ``Binarize`` using ``quant_mode``.

    The remaining ``Binarize`` options (scaling, in-place) keep their defaults.
    """
    binarize = Binarize.apply
    return binarize(input, quant_mode)

class BinarizeLinear(nn.Linear):
    """Linear layer whose forward pass multiplies by binarized weights.

    The real-valued weights stay in ``self.weight`` (and are what the
    optimizer updates); only the copy used in the matrix product is
    binarized. The bias is used as-is, without binarization.
    """

    def __init__(self, *kargs, **kwargs):
        # nn.Linear's constructor calls reset_parameters(), so the override
        # below is applied on construction without an explicit call here.
        super().__init__(*kargs, **kwargs)

    def reset_parameters(self):
        """Initialize weights uniformly in [-1, 1] and the bias (if any) to 0."""
        init.uniform_(self.weight, a=-1, b=1)
        if self.bias is not None:
            init.constant_(self.bias, 0)

    def forward(self, input):
        """Apply the layer using binarized weights plus the real-valued bias.

        Args:
            input: Tensor whose last dimension matches ``in_features``.

        Returns:
            ``input @ binarized(weight).T`` with the bias added when present.
        """
        weight_b = binarized(self.weight)
        out = nn.functional.linear(input, weight_b)
        if self.bias is not None:
            # Snapshot of the current bias values, refreshed on every forward.
            # NOTE(review): nothing in the visible code reads ``bias.org`` —
            # presumably kept for external tooling; confirm before removing.
            self.bias.org = self.bias.data.clone()
            out += self.bias.view(1, -1).expand_as(out)
        return out

# One list of per-epoch test accuracies for each independent training run.
list_of_acc_per_run = []
for _ in tqdm(range(5)):
    # Fresh model for every run: flatten 28x28 images to 784 features, then a
    # single binarized fully connected layer with 10 outputs (one per class).
    model = nn.Sequential(
        nn.Flatten(),
        BinarizeLinear(28 * 28, 10),
    )

    # Snapshots of the layer's real-valued weights, one per forward call.
    weights_during_forward = []

    # Snapshots of every parameter gradient, one per parameter per step.
    gradients_ = []

    def store_weights_hook(module, input, output):
        """Forward hook: record a detached copy of the module's weights."""
        weights_during_forward.append(module.weight.clone().detach())

    # The hook fires on every forward call of the linear layer (model[1]),
    # including the forward passes made during evaluation in test().
    model[1].register_forward_hook(store_weights_hook)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)

    def train(model, device, train_loader, optimizer, epoch):
        """Run one pass over ``train_loader``, recording gradients each step.

        ``epoch`` is accepted for call-site symmetry but is not used.
        """
        model.train()
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            for param in model.parameters():
                # Store a copy: a bare reference to param.grad can be zeroed
                # or overwritten in place by later optimizer/backward calls.
                grad = param.grad
                gradients_.append(None if grad is None else grad.detach().clone())
            optimizer.step()

    def test(model, device, test_loader):
        """Evaluate on ``test_loader``; print and return accuracy in percent."""
        model.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                # criterion returns the batch mean, so weight it by the batch
                # size; dividing by the dataset size below then yields a true
                # per-sample average (the last batch may be smaller).
                test_loss += criterion(output, target).item() * data.size(0)
                pred = output.argmax(dim=1, keepdim=True)  # highest-scoring class
                correct += pred.eq(target.view_as(pred)).sum().item()

        num_samples = len(test_loader.dataset)
        test_loss /= num_samples
        accuracy = 100. * correct / num_samples
        print(f'Test set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{num_samples} ({accuracy:.0f}%)')
        return accuracy

    # Check if CUDA is available and use GPU if possible
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Train and test the model.
    # NOTE: the epoch count is hard-coded to 8 here; the module-level
    # ``epochs`` hyperparameter is not consulted.
    run_acc = []
    for epoch in range(8):
        train(model, device, train_loader, optimizer, epoch)
        accuracy = test(model, device, test_loader)
        acc_list.append(accuracy)  # flat history across all runs
        run_acc.append(accuracy)   # accuracies for this run only
    # Append this run's own list; appending the shared ``acc_list`` would make
    # every entry alias the same ever-growing list.
    list_of_acc_per_run.append(run_acc)
