In [None]:
# Import the required modules and names
import torch
import random
import torchvision
from torch import nn
from torchvision import transforms
from torchvision import datasets, models, transforms
import os
import matplotlib.pyplot as plt
from pathlib import Path
from typing import List, Tuple
from PIL import Image
import warnings
from typing import Optional, Tuple
from torch import Tensor
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.parameter import Parameter
from torch.nn import functional as F
from torchvision.models import vision_transformer
from torch.nn.modules import activation
import numpy as np
from tqdm import tqdm
import os
from torch.utils.data import Dataset
from PIL import Image
import json

# Run on the GPU when CUDA is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Default pretrained weights for ViT-Base-16. Both names refer to the same
# weights entry: one is used for the model whose weights get rounded, the
# other for the untouched reference model.
pretrained_vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT
pretrained_vit_weights_org = pretrained_vit_weights

# Two separate model instances built from those weights, moved to `device`
pretrained_vit = torchvision.models.vit_b_16(weights=pretrained_vit_weights).to(device)
pretrained_vit_org = torchvision.models.vit_b_16(weights=pretrained_vit_weights_org).to(device)

# Freeze every parameter of both models (no gradients are needed here)
for frozen_model in (pretrained_vit, pretrained_vit_org):
    frozen_model.requires_grad_(False)
    
# Helper: round the weights of a single encoder block in place
def _round_encoder_block_weights(encoder_block: torch.nn.Module, decimals: int) -> None:
    """Round the weights of one ViT encoder block to `decimals` decimal places.

    Only the attention input projection (`self_attention.in_proj_weight`) and
    the weights of the `nn.Linear` layers inside `mlp` are rounded; biases,
    layer norms and the attention output projection are left untouched.
    The block is modified in place.
    """
    mha = encoder_block.self_attention
    mha.in_proj_weight.data = torch.round(mha.in_proj_weight.data, decimals=decimals)
    for module in encoder_block.mlp.modules():
        if isinstance(module, nn.Linear):
            module.weight.data = torch.round(module.weight.data, decimals=decimals)


# Helper: manual ViT forward pass that rounds activations of selected blocks
def _forward_with_rounding(
    model: torch.nn.Module,
    batch: torch.Tensor,
    rounded_layers: set,
    decimals: int,
) -> torch.Tensor:
    """Run the ViT forward pass step by step and return the head output.

    The pass is unrolled by hand so that, for every encoder block whose index
    is in `rounded_layers`, the inputs to the attention layer and to the MLP
    can be rounded to `decimals` decimal places.
    """
    tokens = model._process_input(batch)
    n = tokens.shape[0]
    batch_class_token = model.class_token.expand(n, -1, -1)  # Expand the class token to the full batch
    tokens = torch.cat([batch_class_token, tokens], dim=1)
    torch._assert(tokens.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {tokens.shape}")
    hidden = tokens + model.encoder.pos_embedding
    hidden = model.encoder.dropout(hidden)

    # Iterate over however many encoder blocks the model has (12 for ViT-Base-16)
    for index, encoder_block in enumerate(model.encoder.layers):
        torch._assert(hidden.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {hidden.shape}")
        round_here = index in rounded_layers

        # Multi head attention layer with residual connection
        x = encoder_block.ln_1(hidden)
        if round_here:
            x = torch.round(x, decimals=decimals)
        x, _ = encoder_block.self_attention(x, x, x, need_weights=False)
        x = encoder_block.dropout(x)
        x = x + hidden

        # MLP block with residual connection
        y = encoder_block.ln_2(x)
        if round_here:
            y = torch.round(y, decimals=decimals)
        y = encoder_block.mlp(y)
        hidden = x + y

    # Remaining layers are executed unchanged: final norm, class token, heads
    final_out = model.encoder.ln(hidden)
    return model.heads(final_out[:, 0])


# Function to find the accuracy loss
def findAccuracyLoss(
    model: torch.nn.Module,
    model_org : torch.nn.Module,
    encoders : List[int],
    decimals : int,
    root: str,
    image_size: Tuple[int, int] = (224, 224),
    transform: torchvision.transforms = None,
    device: torch.device = device,
) -> float:
    """Measure how much rounding weights/activations changes the model's confidence.

    For every file in `root`, the image is passed through `model_org`
    (untouched) and through `model` with the selected encoder blocks rounded
    to `decimals` decimal places. The per-image "loss" is the absolute
    difference between the two models' highest softmax probabilities,
    multiplied by 100. The average over all processed images is printed and
    returned.

    Note: `model` is modified in place -- the rounded weights stay rounded
    after this function returns.

    Args:
        model: ViT model whose selected encoder blocks are rounded.
        model_org: Reference ViT model, used as-is.
        encoders: Indices of the encoder blocks to round.
        decimals: Number of decimal places to round to.
        root: Directory containing the image files (non-file entries are skipped).
        image_size: Size passed to `transforms.Resize` when `transform` is None.
        transform: Optional image transform; a Resize/ToTensor/Normalize
            pipeline is built when omitted.
        device: Device both models and the images are moved to.

    Returns:
        The average per-image loss as a float.

    Raises:
        ValueError: If `root` contains no files.
    """
    # Create transformation for image (if one doesn't exist)
    if transform is not None:
        image_transform = transform
    else:
        image_transform = transforms.Compose(
            [
                transforms.Resize(image_size),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    # Collect the files up front so the average can be taken over the real count
    candidates = (os.path.join(root, name) for name in os.listdir(root))
    image_paths = [path for path in candidates if os.path.isfile(path)]
    if not image_paths:
        raise ValueError(f"No image files found in {root!r}")

    model.to(device)
    model_org.to(device)

    # Turn on model evaluation mode
    model_org.eval()
    model.eval()

    rounded_layers = set(encoders)
    accuracy_loss = 0.0  # Running sum of the per-image losses

    with torch.inference_mode():
        # The weights do not depend on the image, so round them once up front
        for index, encoder_block in enumerate(model.encoder.layers):
            if index in rounded_layers:
                _round_encoder_block_weights(encoder_block, decimals)

        for img_path in image_paths:
            # Close the file handle as soon as the RGB copy has been made
            with Image.open(img_path) as raw_image:
                imge = raw_image.convert("RGB")
            batch = image_transform(imge).unsqueeze(dim=0).to(device)

            # Prediction of the original model, and of the rounded model
            target_image_pred = model_org(batch)
            final_out = _forward_with_rounding(model, batch, rounded_layers, decimals)

            target_image_pred_probs = torch.softmax(target_image_pred, dim=1)
            target_image_pred_probs_approx = torch.softmax(final_out, dim=1)

            # Loss for this image: change in the top-1 probability, in percent
            accuracy_loss += abs(target_image_pred_probs.max() - target_image_pred_probs_approx.max()) * 100

    # Average over the number of images actually processed
    average_loss = float(accuracy_loss) / len(image_paths)
    print("Average accuracy loss : " , average_loss)
    return average_loss
    
# Indices of the encoder blocks to be rounded: all 12 blocks, 0 through 11
encoders = list(range(12))

# Run the comparison between the rounded model and the untouched original
findAccuracyLoss(
    model=pretrained_vit,  # Model whose weights are rounded
    model_org=pretrained_vit_org,  # Original model, left unchanged
    encoders=encoders,
    decimals=2,  # Number of decimal places to round to
    root='/home/palashdas/ILSVRC/Data/CLS-LOC/test',  # Path of image dataset
)