# Deep Learning course Project Assignment
### Test Time Adaptation

Yesun-Erdene Jargalsaikhan [247523], y.jargalsaikhan@studenti.unitn.it

18 June, 2025

Quick Access:
1. [Introduction](#introduction)  
   1.1 [Test Time Adaptation](#test-time-adaptation)  
   1.2 [MEMO](#memo)  
   1.3 [Marginal Entropy Minimization with One test point](#marginal-entropy-minimization-with-one-test-point)  
   1.4 [My modifications](#my-modifications)  
   1.5 [Instructions for running experiments ⬇️](#instructions-for-running-experiments)

2. [Implementation](#implementation)     
   2.1 [MEMO re-implementation](#memo-re-implementation)  
   2.2 [Entropy Loss](#entropy-loss)  
      2.2.1  [Marginal Entropy Loss with Sharpened Softmax](#marginal-entropy-loss-with-sharpened-softmax)   
      2.2.2  [Augmentation-Weighted Entropy](#marginal-entropy-loss-with-weighted-augmentation)   
   2.3 [Dataset for the domain shift challenge](#dataset-for-the-domain-shift-challenge)  

3. [Preparation for the experiments](#preparation-for-experiments)  
    3.1 [Pre-trained models](#pre-trained-models)   
    3.3 [Test functions](#test-functions)

4. [Experiments](#experiments)  
   4.1 [Baselines](#baselines)  
   4.2 [TTA applied](#tta-applied)

5. [Results](#results)   
   5.1 [Discussion](#discussion)  
   5.2 [Conclusion](#conclusion)

## Introduction


In this project I re-implemented MEMO (Marginal Entropy Minimization with One test point) to improve the performance of a pre-trained model on an image classification task at test time. The goal is to adapt the model so that it performs better when the test data distribution differs from the distribution of the dataset it was originally trained on.


#### Test Time Adaptation

For domain shift challenges in image classification, test-time adaptation improves the performance of a pre-trained model on test data that is visually different from the training data. It does so without access to the labels of the test samples, and it processes them one at a time.

#### MEMO

**Test Time Robustness via Adaptation and Augmentation** [Zhang et al. (2021)](https://arxiv.org/abs/2110.09506)

The paper proposes a simple method that only changes how inference is done. It makes no assumptions about the training process or about the availability of test-time data, and it can be used with any model that is probabilistic and adaptable.

MEMO applies a set of different augmentations to the test sample independently and takes the pre-trained model's conditional output distribution for each augmented version of the image. It then computes the marginal distribution by averaging these conditional output distributions, computes the entropy of that marginal distribution, and minimizes it. This adapts the model parameters for each test sample so that the model predicts the same label across the augmented images (invariance to augmentations) and becomes more confident. Finally, the adapted model makes its prediction on the clean test point rather than on the augmented copies.

Minimizing only the entropy of the conditional probability distributions would merely ask the model to be confident in its prediction for each augmentation, regardless of whether the output is correct. The marginal entropy additionally penalizes disagreement between the augmentations.

#### Marginal Entropy Minimization with One test point

**Image classification**

The task of image classification is to predict the category of an image. The model computes a conditional probability distribution over the set of classes given the input and the weights, and the prediction $\widehat{y}$ is the label given by the argmax over that probability distribution:

$$
\widehat{y} = M(x \mid W) = \arg\max_{y \in Y} \, p(y \mid x, W)
$$


<div align="center">where $M$ trained model, $W$ weight space, $X$ input space, $Y$ output space, $w \in W, x \in X, y \in Y$ corresponding variables</div>



**Augmentation**

- $A = \{ a_1, a_2, \ldots, a_K \}$: set of $K$ augmentations. <br>
- $x_k = a_k(\mathbf{x})$: the $k$-th augmented input. <br>
- $p_W(y \mid x_k) = M_W(x_k)$: model's conditional output distribution for augmented input.


**Marginal Distribution**

The marginal distribution is computed by averaging the conditional output distributions over the augmented versions of a single test input.

$$
\bar{p}_W(y \mid \mathbf{x}) = \frac{1}{K} \sum_{k=1}^{K} p_W(y \mid x_k)
$$


**Entropy**

We need to measure how confident the classifier is in its prediction without access to the label. Entropy measures the uncertainty of a probability distribution: the higher the entropy, the less confident the classifier; the lower the entropy, the more confident it is.

$$
H(p) = - \sum_{i=1}^{C} p_i \log p_i
$$
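
As a small worked example of the formula above, the following sketch computes the entropy of a uniform and of a peaked distribution. The two distributions are made-up values for illustration, not outputs of the notebook's models.

```python
import math

def entropy(probs):
    """Shannon entropy (in nats) of a discrete distribution given as a list."""
    # Terms with p == 0 contribute nothing, by the convention 0 * log 0 = 0.
    return -sum(p * math.log(p) for p in probs if p > 0)

# Illustrative distributions over C = 4 classes (assumed values)
uniform = [0.25, 0.25, 0.25, 0.25]   # maximally uncertain
peaked = [0.97, 0.01, 0.01, 0.01]    # confident in the first class

print(entropy(uniform))  # equals log(4), the maximum for 4 classes
print(entropy(peaked))   # much smaller than the uniform case
```

The uniform distribution reaches the maximum entropy $\log C$, while the peaked one is close to zero, which is the sense in which low entropy stands for high confidence.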


 >  &nbsp;&nbsp;&nbsp;&nbsp;**Marginal Entropy** <br>
  &nbsp;&nbsp;&nbsp;&nbsp;Marginal entropy is computed as the entropy of the marginal distribution, representing the model's overall uncertainty.

$$
H(\bar{p}_W(y \mid x)) = - \sum_{c=1}^{C} \bar{p}_W(y = c \mid x) \log \bar{p}_W(y = c \mid x)
$$


<div align="center">where $C$ is the number of output classes.</div>


**Minimization of the entropy**

Through the adaptation iterations on this objective, the model's confidence in one class increases while the probabilities of the other classes decrease, so the model becomes more confident. One issue is that a wrong prediction is amplified in the same way.
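
The difference between the marginal entropy and the entropies of the individual conditional distributions can be seen in a small sketch. The per-augmentation outputs below are hypothetical numbers for $K = 2$ augmentations and $C = 2$ classes; the helper names are introduced only for this illustration.

```python
import math

def entropy(probs):
    """Shannon entropy (in nats) of a discrete distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)

def marginal(dists):
    """Average K per-augmentation distributions class by class."""
    k = len(dists)
    return [sum(d[c] for d in dists) / k for c in range(len(dists[0]))]

def mean_conditional_entropy(dists):
    """Average of the entropies of the individual distributions."""
    return sum(entropy(d) for d in dists) / len(dists)

# Hypothetical model outputs for two augmentations of the same image
agree = [[0.99, 0.01], [0.99, 0.01]]      # confident and consistent
disagree = [[0.99, 0.01], [0.01, 0.99]]   # confident but contradictory

for name, dists in [("agree", agree), ("disagree", disagree)]:
    print(name, mean_conditional_entropy(dists), entropy(marginal(dists)))
```

Both cases have the same low average conditional entropy, but only the consistent case has a low marginal entropy. In the contradictory case the marginal distribution is close to uniform, so its entropy is close to $\log 2$.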

### My modifications

The experiment results of each modification are discussed in the [discussion part](#discussion).

#### Adaptive batch normalization

Since the MEMO paper reports that adaptive batch normalization improves performance, I added it as a modification to the model and also applied it in all the subsequent modifications.
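
As a rough numeric illustration of the blending used by the notebook's `_modified_bn_forward`, where the stored statistics get the weight $N / (N + 1)$ for a prior strength $N$, the sketch below mixes a single scalar statistic. The function name `blend_statistics` and the example values are assumptions made for this illustration.

```python
def blend_statistics(running, batch, prior_strength):
    """Mix a stored statistic with a test-batch statistic.

    prior = N / (N + 1) is the weight of the stored (training) statistic,
    as in the notebook's adaptive batch normalization.
    """
    prior = prior_strength / (prior_strength + 1)
    return prior * running + (1 - prior) * batch

# Assumed example: stored mean 0.0, mean of the augmented test batch 1.7
mixed = blend_statistics(running=0.0, batch=1.7, prior_strength=16)
print(mixed)  # 1/17 of the way from the stored mean to the batch mean

# With prior strength 0 only the batch statistic is used
print(blend_statistics(running=0.0, batch=1.7, prior_strength=0))
```

A larger prior strength keeps the normalization closer to the training statistics, while a smaller one trusts the statistics of the augmented test batch more.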

#### Different ways to compute the marginal Entropy

Since one of the objectives is to increase the model's confidence in its prediction, I modified how the entropy is computed to make it more robust to the different augmented versions. For this reason I implemented the following methods:

- **Sharpened softmax marginal entropy**

The motivation comes from [Veličković et al. (2025)](https://arxiv.org/abs/2410.01104), which proposed adaptive temperature as an ad-hoc technique for improving the sharpness of softmax at inference. The details can be found in the dedicated [cell here](#marginal-entropy-loss-with-sharpened-softmax).

- **Augmentation weighted marginal entropy**

The motivation comes from the idea that not all augmentations are equally useful, so the contribution of each augmentation is weighted. The result is a weighted marginal entropy loss in which the weights are based on the model’s confidence (max softmax probability) for each augmentation. The details can be found in the dedicated [cell here](#marginal-entropy-loss-with-weighted-augmentation).


#### Different choices of augmentation methods

Because the distribution of the test data differs from the original one and only one test example is available at a time, data augmentation is necessary to improve the model's performance. For choosing the augmentation methods, I followed the MEMO experiments and also tried other combinations of augmentation methods to improve the model's performance.

- **RandomResizedCrop**

Since the MEMO paper used RandomResizedCrop in its experiments, I conducted most of my experiments with this augmentation method.

- **RandomResizedCrop + RandomHorizontalFlip**

I chose the combination of the two methods because the MEMO paper reports an improvement when using the standard horizontal flip augmentation.

- **RandomResizedCrop + RandomAffine + RandomPerspective + RandomHorizontalFlip**

### Instructions for running experiments
‼️

*   To run the experiments on ***Google Colab***, please set the variable $IS\_COLAB$ to $True$ in the first Preparation cell.
    *   The dataset needs to be present in your Google Drive; please change the image root variables accordingly.
*   To run the experiments on ***AWS***, please set the variable $IS\_COLAB$ to $False$ in the first Preparation cell.
    *   The dataset is uploaded to usergroup-49, and there is no need to change directories; only the notebook needs to be in one directory. If necessary, please adjust the image root variables to suit your environment.
    *   Uncomment the lines in the 2nd cell if you need to place the dataset in the desired directory.

*   Each experiment is placed in a single cell, so that individual experiments are convenient to run ([here](#experiments)).

If anything needs clarification, please send an email to me.

## Implementation

### Preparation

In [None]:
# --- Set to True to run in Colab (the dataset must be available in Colab too) -----#
IS_COLAB = False

In [None]:
# --- Uncomment following lines if need to place the dataset in desired directory -----#

# extract dataset if needed (in aws)
# !tar -xf datasets/ImageNet-A/imagenet-a.tar
# !tar -xf datasets/ImageNet-V2/imagenetv2-matched-frequency.tar.gz

In [None]:
# --- Uncomment and run the following lines if error related to torch happens -----#

# !pip uninstall -y torch torchvision
# %pip install torch torchvision
# %pip install ftfy regex tqdm

In [None]:
import os
import sys
import json
import types
import logging

import numpy as np
import pandas as pd

from PIL import Image
from io import BytesIO

from copy import deepcopy
from datetime import datetime
from tqdm.notebook import tqdm

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torch.utils.data import Dataset
import torchvision.datasets as datasets

from torchvision.transforms import v2
from torchvision.transforms.functional import to_tensor

from torch.utils.tensorboard import SummaryWriter
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.models import vit_b_16, ViT_B_16_Weights

import matplotlib.pyplot as plt

# Considering TTA, one input prediction at a time
DEFAULT_BATCH_SIZE = 1

# chosen from the MEMO paper
NUM_AUGMENTATION = 32

# prior strength number for adaptive batch normalization
PRIOR_STRENGTH = 16

# softmax temperature for the sharpened softmax entropy loss
TEMPERATURE = 0.5

# device where the computation should take place
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# if running in COLAB, please change the path according to your dataset path
if IS_COLAB:
    from google.colab import drive
    drive.mount("/content/drive/")
    path = '/content/drive/MyDrive/Colab Notebooks/DL' # change the path if needed
    os.chdir(path)

# defining image root for datasets
if IS_COLAB:
    im_groot_imagenet_a = "/content/drive/MyDrive/Colab Notebooks/DL/datasets/ImageNet-A/imagenet-a"
    im_groot_imagenet_v2 = "/content/drive/MyDrive/Colab Notebooks/DL/datasets/ImageNet-V2/imagenetv2-matched-frequency-format-val"
else: # in aws, I added the datasets in aws
    im_groot_imagenet_a = "../datasets/ImageNet-A/imagenet-a/"
    im_groot_imagenet_v2 = "../datasets/ImageNet-V2/imagenetv2-matched-frequency-format-val/"


# Logger for saving the experiment result
logger = logging.getLogger('my_test_logger')
logger.setLevel(logging.INFO)


### Memo re-implementation

#### MEMO

In [None]:
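# Adapted from https://github.com/bethgelab/robustness/blob/main/robusta/batchnorm/bn.py#L175
def _modified_bn_forward(self, input):
    """Batch norm forward pass that blends stored and test-batch statistics using `self.prior`."""
    # Buffers that receive the statistics of the current batch
    est_mean = torch.zeros(self.running_mean.shape, device=self.running_mean.device)
    est_var = torch.ones(self.running_var.shape, device=self.running_var.device)
    # Training-mode call with momentum 1.0 overwrites est_mean / est_var with the batch statistics
    nn.functional.batch_norm(input, est_mean, est_var, None, None, True, 1.0, self.eps)
    # Convex combination of the stored (training) statistics and the batch statistics
    running_mean = self.prior * self.running_mean + (1 - self.prior) * est_mean
    running_var = self.prior * self.running_var + (1 - self.prior) * est_var
    # Normalize with the blended statistics in evaluation mode
    return nn.functional.batch_norm(input, running_mean, running_var, self.weight, self.bias, False, 0, self.eps)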

class MEMO(torch.nn.Module):
    """
    A neural network wrapper that re-implements MEMO for test-time adaptation
    Args:
        model:                pre-trained model to adapt
        learning_rate:        learning rate used for adapting the model during test time
        optimizer:            optimizer used during test time
        loss_function:        entropy loss function minimized during adaptation
        label_mask:           selects the output classes that are kept in the prediction
        apply_tta:            whether to apply TTA during test time
        augmentations:        set of augmentations
        prior_strenght_bn:    prior strength for adaptive BN
        apply_adaptive_bn:    whether to apply adaptive BN
        temperature:          temperature passed to the loss function (if given)
        apply_transform:      whether to apply `preprocess` to the original input before the final prediction
        preprocess:           preprocessing of the base model
        device:               device where the computation takes place
        number_of_augmentation: number of augmentations to generate (will be the intermediate batch size for test-time adaptation)
    Other
        initial_configuration : initial configuration of the model; before analysing each new sample, the state of the network is reset to it
    """
    def __init__(self,
                 model,
                 learning_rate,
                 optimizer,
                 loss_function,
                 label_mask,
                 apply_tta = True,
                 augmentations = None,
                 prior_strenght_bn = None,
                 apply_adaptive_bn = False,
                 temperature = None,
                 apply_transform = False,
                 preprocess = None,
                 device=device,
                 number_of_augmentation=NUM_AUGMENTATION):
        super(MEMO, self).__init__()

        self.device = device
        self.optimizer = optimizer
        self.apply_tta = apply_tta
        self.label_mask = label_mask
        self.model = model.to(device)
        self.loss_function = loss_function
        self.temperature = temperature
        self.augmentations = augmentations
        self.learning_rate = learning_rate
        self.apply_transform = apply_transform
        self.prior_strenght_bn = prior_strenght_bn
        self.apply_adaptive_bn = apply_adaptive_bn
        self.preprocess = preprocess
        self.number_of_augmentation = number_of_augmentation
        self.initial_configuration = deepcopy(self.model.state_dict())

        # apply the adaptive batch normalization if the switch parameter is True
        if apply_adaptive_bn:
            # `prior` is a class-level attribute read by _modified_bn_forward;
            # 1.0 means that only the stored running statistics are used
            nn.BatchNorm2d.prior = 1.0
            for module in self.model.modules():
                # Patch forward only on this model's BatchNorm2d instances. Overriding
                # nn.BatchNorm2d.forward itself would change BatchNorm everywhere in
                # PyTorch and could affect other models or modules unintentionally.
                if isinstance(module, nn.BatchNorm2d):
                    module.forward = types.MethodType(_modified_bn_forward, module)

    def augment_image(self, x):
        """
        Applies the configured augmentations to the image `number_of_augmentation` times

        Returns: a batch tensor of augmented versions of the input image (the original image is not included)

        """
        with torch.no_grad():
            augmented_images = []
            # augmented_images.append(x)

            for _ in range(self.number_of_augmentation):
                aug_img = self.augmentations(x)  # always apply transforms
                if not isinstance(aug_img, torch.Tensor):
                    aug_img = to_tensor(aug_img)
                augmented_images.append(aug_img)

            inputs = torch.stack(augmented_images).to(self.device)

            return inputs

    def test_time_adaptation(self, x):
        """
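        This function implements the MEMO approach for each image in `x`:
            1. Augment the image
            2. Compute the marginal distribution
            3. Compute the entropy and minimize the entropy
            4. Predict the original input with the updated model
            5. Reset the model parameters to their initial state

        Returns: logits of the adapted model for the original input, restricted to `label_mask`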
        """

        # Set the network to evaluation mode
        self.model.eval()

        # Load the model to device to do computation (GPU)
        self.model.to(self.device)

        # reset prediction variable
        predictions = None

        if self.prior_strenght_bn is None:
            nn.BatchNorm2d.prior = 1
        else:
            nn.BatchNorm2d.prior = float(self.prior_strenght_bn) / float(self.prior_strenght_bn + 1)

        for image in x:
            # Reset the gradient
            self.optimizer.zero_grad()

            # Get the augmented variants of the image
            inputs = self.augment_image(image)

            # First forward pass and take the logits
            outputs = self.model(inputs)

            # Compute the entropy loss
            if self.temperature is None:
                loss = self.loss_function(outputs)
            else:
                # if temperatutre is given, pass to the loss function with the temperature parameter
                loss = self.loss_function(outputs, self.temperature)

            # Compute the gradient
            loss.backward()

            # Parameters update
            self.optimizer.step()

            # Predict the original input with the updated model (no gradients are needed here)
            with torch.no_grad():
                if not self.apply_transform:
                    # the image is already preprocessed, only the batch dimension is restored
                    original_input = image.unsqueeze(0).to(self.device)
                else:
                    # applies preprocess, as the preprocessing of the base model has not been done on the image
                    original_input = self.preprocess(image).unsqueeze(0).to(self.device)
                prediction = self.model(original_input)[:, self.label_mask]

            # Collect the predictions so that batches with more than one image are handled too
            predictions = prediction if predictions is None else torch.cat([predictions, prediction], dim=0)

            # Reset the model parameters to initial
            self.model.load_state_dict(deepcopy(self.initial_configuration))

        # reset the prior back to 1
        nn.BatchNorm2d.prior = 1

        return predictions


    def forward(self, x):
        """
        Forward pass
        """
        if self.apply_tta:
            return self.test_time_adaptation(x)
        else:
            output = self.model(x)
            output = output[:, self.label_mask]

        return output

#### Entropy Loss


##### Marginal Entropy Loss

As in the MEMO paper, the marginal entropy loss is computed on the marginal distribution, which is the average of the conditional output distributions over the augmented versions of a single input.
The model outputs logits, so softmax is applied first to turn them into probabilities.

$$
\bar{p}_W(y \mid \mathbf{x}) = \frac{1}{K} \sum_{k=1}^{K} p_W(y \mid x_k)
$$

$$
H(\bar{p}_W(y \mid x)) = - \sum_{c=1}^{C} \bar{p}_W(y = c \mid x) \log \bar{p}_W(y = c \mid x)
$$



where:
- $K$ is the number of augmentations
- $C$ is the number of classes
- $\bar{p}_W(y = c \mid x)$ is marginal prediction after averaging softmax over augmentations
- $H(⋅)$ is entropy of the marginal prediction
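
The two formulas above can also be evaluated in log space, which avoids taking the logarithm of a probability that has underflowed to zero. The following is a sketch of such a variant, not the function used in this notebook; the name `marginal_entropy_logspace` and the test inputs are assumptions for illustration.

```python
import math
import torch
import torch.nn.functional as F

def marginal_entropy_logspace(logits):
    """Entropy of the averaged softmax, computed from log-probabilities.

    logits: tensor of shape (K, C), one row per augmentation.
    """
    log_probs = F.log_softmax(logits, dim=1)
    # log of the mean over K rows = logsumexp over K rows - log K
    log_marginal = torch.logsumexp(log_probs, dim=0) - math.log(logits.shape[0])
    return -(log_marginal.exp() * log_marginal).sum()

def marginal_entropy_naive(logits):
    """Direct evaluation of the formulas, without any safeguard."""
    marginal = F.softmax(logits, dim=1).mean(dim=0)
    return -(marginal * marginal.log()).sum()

torch.manual_seed(0)
logits = torch.randn(32, 1000)           # K = 32 augmentations, C = 1000 classes
print(marginal_entropy_logspace(logits), marginal_entropy_naive(logits))

# Extreme logits: one class probability underflows to exactly 0
extreme = torch.tensor([[1000.0, 0.0], [1000.0, 0.0]])
print(marginal_entropy_logspace(extreme))  # finite, close to 0
print(marginal_entropy_naive(extreme))     # nan, because 0 * log(0) is evaluated
```

On ordinary logits both versions agree up to floating-point error; they differ only in how they behave when a marginal probability is exactly zero.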

In [None]:
def marginal_entropy_loss(logits):
    # Apply softmax to get probability distribution
    probabilities = F.softmax(logits, dim=1)

    # Compute marginal distribution averaged over the outputs of augmentations
    marginal_outputs = probabilities.mean(dim=0)

    # Compute the entropy of the marginal output
    marginal_entropy = -torch.sum(marginal_outputs * torch.log(marginal_outputs))

    return marginal_entropy

##### Marginal Entropy Loss with Sharpened Softmax

The motivation comes from [Veličković et al. (2025)](https://arxiv.org/abs/2410.01104), which proposed adaptive temperature as an ad-hoc technique for improving the sharpness of softmax at inference time, in order to overcome a fundamental limitation of softmax that arises as the number of items grows at test time. The main observation of that work is that softmax loses sharpness (a low-entropy decision boundary) on out-of-distribution (OOD) inputs, which in my case corresponds to the distribution shift of the test data. The suggested remedy is to lower the softmax temperature dynamically at inference time (“adaptive temperature”): making $T$ smaller for larger or more difficult inputs sharpens the softmax and leads to a lower entropy. In my case, however, I decided to keep $T$ fixed rather than decreasing it dynamically:

$$
\bar{p}_W(y = c \mid x) = \frac{1}{K} \sum_{k=1}^{K} \text{Softmax}\left( \frac{z^{(k)}}{T} \right)_c
$$

$$
H(\bar{p}_W(y \mid x)) = - \sum_{c=1}^{C} \bar{p}_W(y = c \mid x) \log \bar{p}_W(y = c \mid x)
$$

where:
- $K$ is the number of augmentations
- $C$ is the number of classes
- $z^{(k)}$: logits of the model for the $k^{\text{th}}$ augmentation
- $T$: temperature parameter to sharpen the softmax (in my case, it's fixed at $0.5$)
- $\bar{p}_W(y = c \mid x)$ is the marginal prediction after averaging softmax over augmentations with temperature scaling
- $H(⋅)$ is the entropy of the marginal prediction
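
To illustrate the effect of the temperature, the sketch below applies softmax with $T = 1$ and $T = 0.5$ to the same logits. The logits are made-up values and the helper functions are written only for this example.

```python
import math

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax of a list of logits."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)                       # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(probs):
    """Shannon entropy (in nats) of a discrete distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)

logits = [2.0, 1.0, 0.0]                  # assumed example logits for C = 3 classes
plain = softmax(logits)                   # T = 1
sharp = softmax(logits, temperature=0.5)  # T = 0.5, as fixed in this notebook

print(plain, entropy(plain))
print(sharp, entropy(sharp))
```

Dividing the logits by $T < 1$ widens the gaps between them, so the largest probability grows and the entropy of the distribution drops.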

In [None]:
def sharpened_softmax_entropy(logits, temperature=0.5):
    """
    Compute the entropy of the marginal prediction using temperature-scaled softmax.
    """
    # Apply temperature-scaled softmax to each augmentation
    probabilities = F.softmax(logits / temperature, dim=1)

    # Average across K augmentations to get marginal distribution
    marginal_outputs = probabilities.mean(dim=0)

    # Compute entropy of the marginal distribution with prevention of numerical error
    marginal_entropy = -torch.sum(marginal_outputs * torch.log(marginal_outputs + 1e-8))

    return marginal_entropy

##### Marginal Entropy Loss with Weighted Augmentation

The motivation comes from the idea that not all augmentations are equally useful, so the entropy contribution of each augmentation is weighted. This gives a weighted marginal entropy loss in which the weights are based on the model’s confidence (max softmax probability) for each augmentation.

The loss takes the logits for $K$ augmentations (shape $K \times C$, where $C$ is the number of classes) and converts them to probabilities with softmax. For each augmentation, it takes the max probability (the model’s confidence) and normalizes these confidences so that they sum to 1; these become the weights. It then computes a weighted average (marginal) of the probability distributions and finally calculates the entropy of this weighted marginal distribution.

$$
\quad w_k = \frac{\max_c p^{(k)}(y = c \mid x)}{\sum_{j=1}^{K} \max_c p^{(j)}(y = c \mid x)}
$$

$$
\bar{p}_w(y = c \mid x) = \sum_{k=1}^{K} w_k \cdot p^{(k)}(y = c \mid x)
$$


$$
H(\bar{p}_w(y \mid x)) = - \sum_{c=1}^{C} \bar{p}_w(y = c \mid x) \log \bar{p}_w(y = c \mid x)
$$



Where:

*   $K$ is the number of augmentations
*   $C$ is the number of classes
*   $p^{(k)}(y = c \mid x)$ is the softmax output for class $c$  on the $k^{\text{th}}$ augmented input
*   $w_k$  is the confidence weight of the $k^{\text{th}}$ augmentation (based on max probability)
*   $\bar{p}_w(y = c \mid x)$ is the weighted marginal probability distribution over classes
* $H(⋅)$ is entropy of the marginal prediction



This approach emphasizes augmentations where the model is more confident, weighting their contributions more heavily in the entropy calculation. This can help focus training or test-time adaptation on more reliable augmentations.
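
A small numeric example may make the weighting concrete. The three probability vectors below are hypothetical outputs for $K = 3$ augmentations and $C = 3$ classes, and the helper `confidence_weights` is written only for this illustration.

```python
def confidence_weights(dists):
    """Normalize the max probability of each distribution so the weights sum to 1."""
    confidences = [max(d) for d in dists]
    total = sum(confidences)
    return [c / total for c in confidences]

# Hypothetical softmax outputs for three augmentations of one image
dists = [
    [0.90, 0.05, 0.05],   # confident augmentation
    [0.60, 0.30, 0.10],
    [0.50, 0.25, 0.25],   # least confident augmentation
]

weights = confidence_weights(dists)   # confidences 0.9, 0.6, 0.5 sum to 2.0

# Weighted marginal distribution over the classes
num_classes = len(dists[0])
weighted_marginal = [
    sum(w * d[c] for w, d in zip(weights, dists)) for c in range(num_classes)
]
# Unweighted marginal for comparison
plain_marginal = [sum(d[c] for d in dists) / len(dists) for c in range(num_classes)]

print(weights)
print(weighted_marginal)
print(plain_marginal)
```

The most confident augmentation receives the largest weight, so in this example the weighted marginal puts more mass on the first class than the plain average does.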

In [None]:
def augmentation_weighted_entropy(logits):
    """
    Computes augmentation-weighted marginal entropy loss.
    """
    # Convert logits to probabilities
    probabilities = F.softmax(logits, dim=1)

    # Use max probability (confidence) as weight per augmentation
    confidences, _ = probabilities.max(dim=1)

    # Normalize weights to sum to 1
    weights = confidences / confidences.sum()

    # Weighted marginal distribution
    marginal = torch.sum(probabilities * weights.unsqueeze(1), dim=0)

    # Compute entropy of the marginal distribution with prevention of numerical error
    entropy = -torch.sum(marginal * torch.log(marginal + 1e-8))

    return entropy

#### Optimizer

For the optimizer, I used AdamW with the default $weight\_decay=0.01$, as suggested by the MEMO paper.

In [None]:
def get_optimizer(model, learning_rate):
    optimizer = torch.optim.AdamW(model.parameters(), learning_rate)

    return optimizer

### Dataset for the domain shift challenge

For the test dataset, the following benchmarks are used:

*   [ImageNet-A](https://arxiv.org/abs/1907.07174) (ImageNet-Adversarial) - a set of natural adversarial examples, i.e. images that a classifier trained on ImageNet misclassifies or on which it performs very poorly. [more info](https://github.com/hendrycks/natural-adv-examples).
*   [ImageNet-V2](https://github.com/modestyachts/ImageNetV2) - a more recent re-collection of ImageNet with the same categories, which therefore contains a distribution shift with respect to the original dataset. [more info](https://arxiv.org/abs/1902.10811).

#### Data label mapping

Because ImageNet-A contains only **200** of the 1000 classes in the original ImageNet dataset, I needed to take this difference into account and handle it in a custom way. The following class names and indices are taken from the [Natural Adversarial Examples repository](https://github.com/hendrycks/natural-adv-examples/blob/master/eval.py) and the [TPT repository](https://github.com/azshue/TPT/blob/main/data/imagenet_variants.py).
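
The idea behind the mapping can be sketched on a toy example: the model still outputs a score for every original class, and the prediction is restricted to the indices that exist in the smaller dataset, as the `MEMO` class does with `prediction[:, self.label_mask]`. The 6-class output and the index list below are made-up values, and `restrict_to_subset` is a hypothetical helper.

```python
def restrict_to_subset(logits, label_mask):
    """Keep only the logits of the classes that exist in the subset."""
    return [logits[i] for i in label_mask]

# Toy stand-in for a 1000-class output: 6 classes, of which 3 are in the subset
full_logits = [0.1, 2.5, 0.3, 4.0, 1.2, 0.7]
label_mask = [0, 2, 4]   # assumed indices of the subset classes

subset_logits = restrict_to_subset(full_logits, label_mask)

# Argmax inside the subset, then mapped back to the original class index
pred_in_subset = max(range(len(subset_logits)), key=subset_logits.__getitem__)
pred_full_index = label_mask[pred_in_subset]

print(subset_logits)      # scores of the subset classes only
print(pred_in_subset)     # position within the subset
print(pred_full_index)    # index in the original label space
```

Without the mask, the argmax over all six scores would pick class 3, which is not part of the subset, so masking before the argmax keeps predictions and labels in the same label space.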

In [None]:
imagenet_classes = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", "box turtle", "banded gecko", "green iguana", "Carolina anole", "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", "American alligator", "triceratops", "worm snake", "ring-necked snake", "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", "isopod", 
"white stork", "black stork", "spoonbill", "flamingo", "little blue heron", "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", 
"Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky", "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", "gorilla", "chimpanzee", "gibbon", 
"siamang", "guenon", "patas monkey", "baboon", "macaque", "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", "coffee mug", "coffeemaker", "spiral or coil", 
"combination lock", "computer keyboard", "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", "freight car", "French horn", "frying pan", "fur coat", "garbage truck", "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", "matchstick", "maypole", "maze", "measuring cup", "medicine 
cabinet", "megalith", "microphone", "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill", "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", "shoji screen / room divider", "shopping basket", "shopping 
cart", "shovel", "shower cap", "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", 
"cabbage", "broccoli", "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"]

# ImageNet-A indexes to ImageNet
# https://github.com/hendrycks/natural-adv-examples/blob/master/eval.py
thousand_k_to_200 = {0: -1, 1: -1, 2: -1, 3: -1, 4: -1, 5: -1, 6: 0, 7: -1, 8: -1, 9: -1, 10: -1, 11: 1, 12: -1, 13: 2, 14: -1, 15: 3, 16: -1, 17: 4, 18: -1, 19: -1, 20: -1, 21: -1, 22: 5, 23: 6, 24: -1, 25: -1, 26: -1, 27: 7, 28: -1, 29: -1, 30: 8, 31: -1, 32: -1, 33: -1, 34: -1, 35: -1, 36: -1, 37: 9, 38: -1, 39: 10, 40: -1, 41: -1, 42: 11, 43: -1, 44: -1, 45: -1, 46: -1, 47: 12, 48: -1, 49: -1, 50: 13, 51: -1, 52: -1, 53: -1, 54: -1, 55: -1, 56: -1, 57: 14, 58: -1, 59: -1, 60: -1, 61: -1, 62: -1, 63: -1, 64: -1, 65: -1, 66: -1, 67: -1, 68: -1, 69: -1, 70: 15, 71: 16, 72: -1, 73: -1, 74: -1, 75: -1, 76: 17, 77: -1, 78: -1, 79: 18, 80: -1, 81: -1, 82: -1, 83: -1, 84: -1, 85: -1, 86: -1, 87: -1, 88: -1, 89: 19, 90: 20, 91: -1, 92: -1, 93: -1, 94: 21, 95: -1, 96: 22, 97: 23, 98: -1, 99: 24, 100: -1, 101: -1, 102: -1, 103: -1, 104: -1, 105: 25, 106: -1, 107: 26, 108: 27, 109: -1, 110: 28, 111: -1, 112: -1, 113: 29, 114: -1, 115: -1, 116: -1, 117: -1, 118: -1, 119: -1, 120: -1, 121: -1, 122: -1, 123: -1, 124: 30, 125: 31, 126: -1, 127: -1, 128: -1, 129: -1, 130: 32, 131: -1, 132: 33, 133: -1, 134: -1, 135: -1, 136: -1, 137: -1, 138: -1, 139: -1, 140: -1, 141: -1, 142: -1, 143: 34, 144: 35, 145: -1, 146: -1, 147: -1, 148: -1, 149: -1, 150: 36, 151: 37, 152: -1, 153: -1, 154: -1, 155: -1, 156: -1, 157: -1, 158: -1, 159: -1, 160: -1, 161: -1, 162: -1, 163: -1, 164: -1, 165: -1, 166: -1, 167: -1, 168: -1, 169: -1, 170: -1, 171: -1, 172: -1, 173: -1, 174: -1, 175: -1, 176: -1, 177: -1, 178: -1, 179: -1, 180: -1, 181: -1, 182: -1, 183: -1, 184: -1, 185: -1, 186: -1, 187: -1, 188: -1, 189: -1, 190: -1, 191: -1, 192: -1, 193: -1, 194: -1, 195: -1, 196: -1, 197: -1, 198: -1, 199: -1, 200: -1, 201: -1, 202: -1, 203: -1, 204: -1, 205: -1, 206: -1, 207: 38, 208: -1, 209: -1, 210: -1, 211: -1, 212: -1, 213: -1, 214: -1, 215: -1, 216: -1, 217: -1, 218: -1, 219: -1, 220: -1, 221: -1, 222: -1, 223: -1, 224: -1, 225: -1, 226: -1, 227: -1, 228: -1, 229: -1, 230: -1, 231: -1, 232: -1, 
233: -1, 234: 39, 235: 40, 236: -1, 237: -1, 238: -1, 239: -1, 240: -1, 241: -1, 242: -1, 243: -1, 244: -1, 245: -1, 246: -1, 247: -1, 248: -1, 249: -1, 250: -1, 251: -1, 252: -1, 253: -1, 254: 41, 255: -1, 256: -1, 257: -1, 258: -1, 259: -1, 260: -1, 261: -1, 262: -1, 263: -1, 264: -1, 265: -1, 266: -1, 267: -1, 268: -1, 269: -1, 270: -1, 271: -1, 272: -1, 273: -1, 274: -1, 275: -1, 276: -1, 277: 42, 278: -1, 279: -1, 280: -1, 281: -1, 282: -1, 283: 43, 284: -1, 285: -1, 286: -1, 287: 44, 288: -1, 289: -1, 290: -1, 291: 45, 292: -1, 293: -1, 294: -1, 295: 46, 296: -1, 297: -1, 298: 47, 299: -1, 300: -1, 301: 48, 302: -1, 303: -1, 304: -1, 305: -1, 306: 49, 307: 50, 308: 51, 309: 52, 310: 53, 311: 54, 312: -1, 313: 55, 314: 56, 315: 57, 316: -1, 317: 58, 318: -1, 319: 59, 320: -1, 321: -1, 322: -1, 323: 60, 324: 61, 325: -1, 326: 62, 327: 63, 328: -1, 329: -1, 330: 64, 331: -1, 332: -1, 333: -1, 334: 65, 335: 66, 336: 67, 337: -1, 338: -1, 339: -1, 340: -1, 341: -1, 342: -1, 343: -1, 344: -1, 345: -1, 346: -1, 347: 68, 348: -1, 349: -1, 350: -1, 351: -1, 352: -1, 353: -1, 354: -1, 355: -1, 356: -1, 357: -1, 358: -1, 359: -1, 360: -1, 361: 69, 362: -1, 363: 70, 364: -1, 365: -1, 366: -1, 367: -1, 368: -1, 369: -1, 370: -1, 371: -1, 372: 71, 373: -1, 374: -1, 375: -1, 376: -1, 377: -1, 378: 72, 379: -1, 380: -1, 381: -1, 382: -1, 383: -1, 384: -1, 385: -1, 386: 73, 387: -1, 388: -1, 389: -1, 390: -1, 391: -1, 392: -1, 393: -1, 394: -1, 395: -1, 396: -1, 397: 74, 398: -1, 399: -1, 400: 75, 401: 76, 402: 77, 403: -1, 404: 78, 405: -1, 406: -1, 407: 79, 408: -1, 409: -1, 410: -1, 411: 80, 412: -1, 413: -1, 414: -1, 415: -1, 416: 81, 417: 82, 418: -1, 419: -1, 420: 83, 421: -1, 422: -1, 423: -1, 424: -1, 425: 84, 426: -1, 427: -1, 428: 85, 429: -1, 430: 86, 431: -1, 432: -1, 433: -1, 434: -1, 435: -1, 436: -1, 437: 87, 438: 88, 439: -1, 440: -1, 441: -1, 442: -1, 443: -1, 444: -1, 445: 89, 446: -1, 447: -1, 448: -1, 449: -1, 450: -1, 451: -1, 452: -1, 453: -1, 454: -1, 
455: -1, 456: 90, 457: 91, 458: -1, 459: -1, 460: -1, 461: 92, 462: 93, 463: -1, 464: -1, 465: -1, 466: -1, 467: -1, 468: -1, 469: -1, 470: 94, 471: -1, 472: 95, 473: -1, 474: -1, 475: -1, 476: -1, 477: -1, 478: -1, 479: -1, 480: -1, 481: -1, 482: -1, 483: 96, 484: -1, 485: -1, 486: 97, 487: -1, 488: 98, 489: -1, 490: -1, 491: -1, 492: 99, 493: -1, 494: -1, 495: -1, 496: 100, 497: -1, 498: -1, 499: -1, 500: -1, 501: -1, 502: -1, 503: -1, 504: -1, 505: -1, 506: -1, 507: -1, 508: -1, 509: -1, 510: -1, 511: -1, 512: -1, 513: -1, 514: 101, 515: -1, 516: 102, 517: -1, 518: -1, 519: -1, 520: -1, 521: -1, 522: -1, 523: -1, 524: -1, 525: -1, 526: -1, 527: -1, 528: 103, 529: -1, 530: 104, 531: -1, 532: -1, 533: -1, 534: -1, 535: -1, 536: -1, 537: -1, 538: -1, 539: 105, 540: -1, 541: -1, 542: 106, 543: 107, 544: -1, 545: -1, 546: -1, 547: -1, 548: -1, 549: 108, 550: -1, 551: -1, 552: 109, 553: -1, 554: -1, 555: -1, 556: -1, 557: 110, 558: -1, 559: -1, 560: -1, 561: 111, 562: 112, 563: -1, 564: -1, 565: -1, 566: -1, 567: -1, 568: -1, 569: 113, 570: -1, 571: -1, 572: 114, 573: 115, 574: -1, 575: 116, 576: -1, 577: -1, 578: -1, 579: 117, 580: -1, 581: -1, 582: -1, 583: -1, 584: -1, 585: -1, 586: -1, 587: -1, 588: -1, 589: 118, 590: -1, 591: -1, 592: -1, 593: -1, 594: -1, 595: -1, 596: -1, 597: -1, 598: -1, 599: -1, 600: -1, 601: -1, 602: -1, 603: -1, 604: -1, 605: -1, 606: 119, 607: 120, 608: -1, 609: 121, 610: -1, 611: -1, 612: -1, 613: -1, 614: 122, 615: -1, 616: -1, 617: -1, 618: -1, 619: -1, 620: -1, 621: -1, 622: -1, 623: -1, 624: -1, 625: -1, 626: 123, 627: 124, 628: -1, 629: -1, 630: -1, 631: -1, 632: -1, 633: -1, 634: -1, 635: -1, 636: -1, 637: -1, 638: -1, 639: -1, 640: 125, 641: 126, 642: 127, 643: 128, 644: -1, 645: -1, 646: -1, 647: -1, 648: -1, 649: -1, 650: -1, 651: -1, 652: -1, 653: -1, 654: -1, 655: -1, 656: -1, 657: -1, 658: 129, 659: -1, 660: -1, 661: -1, 662: -1, 663: -1, 664: -1, 665: -1, 666: -1, 667: -1, 668: 130, 669: -1, 670: -1, 671: -1, 672: -1, 673: 
-1, 674: -1, 675: -1, 676: -1, 677: 131, 678: -1, 679: -1, 680: -1, 681: -1, 682: 132, 683: -1, 684: 133, 685: -1, 686: -1, 687: 134, 688: -1, 689: -1, 690: -1, 691: -1, 692: -1, 693: -1, 694: -1, 695: -1, 696: -1, 697: -1, 698: -1, 699: -1, 700: -1, 701: 135, 702: -1, 703: -1, 704: 136, 705: -1, 706: -1, 707: -1, 708: -1, 709: -1, 710: -1, 711: -1, 712: -1, 713: -1, 714: -1, 715: -1, 716: -1, 717: -1, 718: -1, 719: 137, 720: -1, 721: -1, 722: -1, 723: -1, 724: -1, 725: -1, 726: -1, 727: -1, 728: -1, 729: -1, 730: -1, 731: -1, 732: -1, 733: -1, 734: -1, 735: -1, 736: 138, 737: -1, 738: -1, 739: -1, 740: -1, 741: -1, 742: -1, 743: -1, 744: -1, 745: -1, 746: 139, 747: -1, 748: -1, 749: 140, 750: -1, 751: -1, 752: 141, 753: -1, 754: -1, 755: -1, 756: -1, 757: -1, 758: 142, 759: -1, 760: -1, 761: -1, 762: -1, 763: 143, 764: -1, 765: 144, 766: -1, 767: -1, 768: 145, 769: -1, 770: -1, 771: -1, 772: -1, 773: 146, 774: 147, 775: -1, 776: 148, 777: -1, 778: -1, 779: 149, 780: 150, 781: -1, 782: -1, 783: -1, 784: -1, 785: -1, 786: 151, 787: -1, 788: -1, 789: -1, 790: -1, 791: -1, 792: 152, 793: -1, 794: -1, 795: -1, 796: -1, 797: 153, 798: -1, 799: -1, 800: -1, 801: -1, 802: 154, 803: 155, 804: 156, 805: -1, 806: -1, 807: -1, 808: -1, 809: -1, 810: -1, 811: -1, 812: -1, 813: 157, 814: -1, 815: 158, 816: -1, 817: -1, 818: -1, 819: -1, 820: 159, 821: -1, 822: -1, 823: 160, 824: -1, 825: -1, 826: -1, 827: -1, 828: -1, 829: -1, 830: -1, 831: 161, 832: -1, 833: 162, 834: -1, 835: 163, 836: -1, 837: -1, 838: -1, 839: 164, 840: -1, 841: -1, 842: -1, 843: -1, 844: -1, 845: 165, 846: -1, 847: 166, 848: -1, 849: -1, 850: 167, 851: -1, 852: -1, 853: -1, 854: -1, 855: -1, 856: -1, 857: -1, 858: -1, 859: 168, 860: -1, 861: -1, 862: 169, 863: -1, 864: -1, 865: -1, 866: -1, 867: -1, 868: -1, 869: -1, 870: 170, 871: -1, 872: -1, 873: -1, 874: -1, 875: -1, 876: -1, 877: -1, 878: -1, 879: 171, 880: 172, 881: -1, 882: -1, 883: -1, 884: -1, 885: -1, 886: -1, 887: -1, 888: 173, 889: -1, 890: 
174, 891: -1, 892: -1, 893: -1, 894: -1, 895: -1, 896: -1, 897: 175, 898: -1, 899: -1, 900: 176, 901: -1, 902: -1, 903: -1, 904: -1, 905: -1, 906: -1, 907: 177, 908: -1, 909: -1, 910: -1, 911: -1, 912: -1, 913: 178, 914: -1, 915: -1, 916: -1, 917: -1, 918: -1, 919: -1, 920: -1, 921: -1, 922: -1, 923: -1, 924: 179, 925: -1, 926: -1, 927: -1, 928: -1, 929: -1, 930: -1, 931: -1, 932: 180, 933: 181, 934: 182, 935: -1, 936: -1, 937: 183, 938: -1, 939: -1, 940: -1, 941: -1, 942: -1, 943: 184, 944: -1, 945: 185, 946: -1, 947: 186, 948: -1, 949: -1, 950: -1, 951: 187, 952: -1, 953: -1, 954: 188, 955: -1, 956: 189, 957: 190, 958: -1, 959: 191, 960: -1, 961: -1, 962: -1, 963: -1, 964: -1, 965: -1, 966: -1, 967: -1, 968: -1, 969: -1, 970: -1, 971: 192, 972: 193, 973: -1, 974: -1, 975: -1, 976: -1, 977: -1, 978: -1, 979: -1, 980: 194, 981: 195, 982: -1, 983: -1, 984: 196, 985: -1, 986: 197, 987: 198, 988: 199, 989: -1, 990: -1, 991: -1, 992: -1, 993: -1, 994: -1, 995: -1, 996: -1, 997: -1, 998: -1, 999: -1}
indices_in_1k = [k for k in thousand_k_to_200 if thousand_k_to_200[k] != -1]

# ImageNet-V2 indexes to ImageNet
# https://github.com/azshue/TPT/blob/main/data/imagenet_variants.py
imagenet_v_mask = [0, 1, 10, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 11,110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 12, 120, 121, 122,123, 124, 125, 126, 127, 128, 129, 13, 130, 131, 132, 133, 134, 135,136, 137, 138, 139, 14, 140, 141, 142, 143, 144, 145, 146, 147, 148,149, 15, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 16, 160,161, 162, 163, 164, 165, 166, 167, 168, 169, 17, 170, 171, 172, 173,174, 175, 176, 177, 178, 179, 18, 180, 181, 182, 183, 184, 185, 186,187, 188, 189, 19, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 2, 20, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 21, 210,211, 212, 213, 214, 215, 216, 217, 218, 219, 22, 220, 221, 222, 223,224, 225, 226, 227, 228, 229, 23, 230, 231, 232, 233, 234, 235, 236,237, 238, 239, 24, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 25, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 26, 260, 261,262, 263, 264, 265, 266, 267, 268, 269, 27, 270, 271, 272, 273, 274,275, 276, 277, 278, 279, 28, 280, 281, 282, 283, 284, 285, 286, 287,288, 289, 29, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 3, 30, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 31, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 32, 320, 321, 322, 323, 324,325, 326, 327, 328, 329, 33, 330, 331, 332, 333, 334, 335, 336, 337,338, 339, 34, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 35,350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 36, 360, 361, 362,363, 364, 365, 366, 367, 368, 369, 37, 370, 371, 372, 373, 374, 375,376, 377, 378, 379, 38, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 39, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 4, 40,400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 41, 410, 411, 412,413, 414, 415, 416, 417, 418, 419, 42, 420, 421, 422, 423, 424, 425,426, 427, 428, 429, 43, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 44, 440, 441, 442, 443, 444, 445, 446, 447, 448, 449, 45, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 46, 460, 461, 462, 463,464, 465, 466, 467, 
468, 469, 47, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 48, 480, 481, 482, 483, 484, 485, 486, 487, 488, 489, 49, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 5, 50, 500, 501, 502, 503, 504, 505, 506, 507, 508, 509, 51, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 52, 520, 521, 522, 523, 524, 525, 526, 527, 528, 529, 53, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 54, 540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 55, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 56, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 57, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 58, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, 59, 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 6, 60, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 61, 610, 611, 612, 613, 614, 615, 616, 617, 618, 619, 62, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 63, 630, 631, 632, 633, 634, 635, 636, 637, 638, 639, 64, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 65, 650, 651, 652, 653, 654, 655, 656, 657, 658, 659, 66, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 67, 670, 671, 672, 673, 674, 675, 676, 677, 678, 679, 68, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 69, 690, 691, 692, 693, 694, 695, 696, 697, 698, 699, 7, 70, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 71, 710, 711, 712, 713, 714, 715, 716, 717, 718, 719, 72, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 73, 730, 731, 732, 733, 734, 735, 736, 737, 738, 739, 74, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 75, 750, 751, 752, 753, 754, 755, 756, 757, 758, 759, 76, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 77, 770, 771, 772, 773, 774, 775, 776, 777, 778, 779, 78, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 79, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, 8, 80, 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 81, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 82, 820, 821, 822, 823, 824, 825, 826, 827, 828, 829, 83, 830, 831, 832, 833, 834, 835, 
836, 837, 838, 839, 84, 840, 841, 842, 843, 844, 845, 846, 847, 848, 849, 85, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 86, 860, 861, 862, 863, 864, 865, 866, 867, 868, 869, 87, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 88, 880, 881, 882, 883, 884, 885, 886, 887, 888, 889, 89, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 9, 90, 900, 901, 902, 903, 904, 905, 906, 907, 908, 909, 91, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 92, 920, 921, 922, 923, 924, 925, 926, 927, 928, 929, 93, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 94, 940, 941, 942, 943, 944, 945, 946, 947, 948, 949, 95, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 96, 960, 961, 962, 963, 964, 965, 966, 967, 968, 969, 97, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 98, 980, 981, 982, 983, 984, 985, 986, 987, 988, 989, 99, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999]
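A minimal sketch of how these two index maps might be used, written in plain Python on toy data. The names `toy_map`, `toy_indices`, `toy_logits` and `v2_order` are illustrative and not part of the notebook. For ImageNet-A, the 1000 logits of the model are restricted to the 200 classes whose mapped value is not -1 before taking the argmax; with PyTorch tensors this is typically a column selection such as `logits[:, indices_in_1k]`. For ImageNet-V2, the mask above appears to be the class ids 0..999 sorted as strings, which is the order in which `ImageFolder` enumerates numerically named class folders. This is an observation about the list, not something stated in the linked repository.

```python
# Toy stand-in for thousand_k_to_200: 5 "ImageNet" classes, 3 kept in the subset.
toy_map = {0: -1, 1: 0, 2: -1, 3: 1, 4: 2}
toy_indices = [k for k in toy_map if toy_map[k] != -1]  # same pattern as indices_in_1k

# Hypothetical logits for one image over the 5 original classes.
# Class 2 has the largest logit, but it is not part of the subset.
toy_logits = [0.1, 2.0, 3.0, 0.5, 1.5]

# Keep only the logits of classes that exist in the subset, then take the argmax.
restricted = [toy_logits[k] for k in toy_indices]
pred_in_subset = max(range(len(restricted)), key=restricted.__getitem__)

# ImageNet-V2: sorting the ids 0..999 as strings gives 0, 1, 10, 100, 101, ...
v2_order = sorted(range(1000), key=str)

print(toy_indices, restricted, pred_in_subset)
print(v2_order[:6], v2_order[-3:])
```

The prediction index refers to the position inside the subset, so it can be compared directly with the labels that `ImageFolder` assigns to the subset folders.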

#### Data loader

In [None]:
def get_data(im_groot, transform):
    """
      Loads an image dataset from the given root directory.

      Args:
          im_groot: path to the root directory of the image dataset
          transform: composition of transform functions applied to each image
      Returns:
          A dataset object that can be passed to a DataLoader
    """
    dataset = torchvision.datasets.ImageFolder(root=im_groot, transform=transform)

    return dataset

def pil_collate_fn(batch):
    """
      Custom collate function for DataLoader that keeps (PIL image, label) pairs
      in the batch untransformed, so that the transformations can be applied
      during TTA rather than during data pre-processing

      Args:
          batch (list of tuples): Each element is a tuple (PIL image, label).
      Returns:
          list of images, tensor of labels
    """
    images, labels = zip(*batch)

    return list(images), torch.tensor(labels)
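To illustrate what `pil_collate_fn` does with a batch, here is a minimal sketch in plain Python. Strings stand in for PIL images and the labels stay a list instead of becoming a tensor; the names `unzip_batch` and `toy_batch` are hypothetical. PyTorch's default collate function cannot stack PIL images, which is why a custom function is needed when the dataset is loaded with `transform=None`.

```python
def unzip_batch(batch):
    """Split a list of (image, label) pairs into a list of images and a list of labels."""
    images, labels = zip(*batch)  # transposes [(img, y), ...] into (imgs, ys)
    return list(images), list(labels)


# Strings are placeholders for PIL images of possibly different sizes.
toy_batch = [("img_a", 3), ("img_b", 7), ("img_c", 3)]
images, labels = unzip_batch(toy_batch)

print(images)
print(labels)
```

Because the images are returned as a plain list, each one can still be augmented individually before it is converted to a tensor.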

## Preparation for experiments

### Pre-trained models

I used the following two pre-trained models as backbones for the implementation:

1.   First, I used [**ResNet-50**](https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html), pre-trained on the original ImageNet dataset, which consists of 1000 image classes. The model, introduced in [He et al.](https://arxiv.org/abs/1512.03385), is a CNN with residual connections ($y=f(x, θ) + x$, where $x$ is the block input and $f(x, θ)$ is the output of the block's convolutional layers), which keep gradients from vanishing and so make it possible to train deeper models. The model has 50 layers: it processes the input image through convolutional layers with residual connections, progressively extracting high-level features, then applies global average pooling and a fully connected layer to predict the class label. The architecture is as follows:

  *   1 convolution layer (7x7),
  *   1 maxpool layer (3x3),
  *   16 Residual blocks, consisting of:
      *   Conv layer (1x1)
      *   Conv layer (3x3)
      *   Conv layer (1x1)
      *   Identity Bypass
      *   ReLU output
  *   Global avg-pooling layer
  *   FC layer
  *   Softmax output layer
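The layer count and the residual connection described above can be checked with a small sketch. The stage layout `[3, 4, 6, 3]` is taken from the ResNet paper rather than from this notebook, and `residual_block` is a hypothetical helper that works on plain lists of floats instead of feature maps.

```python
# Number of bottleneck blocks in each of the four ResNet-50 stages (from He et al.).
blocks_per_stage = [3, 4, 6, 3]
convs_per_block = 3  # 1x1, 3x3, 1x1

n_blocks = sum(blocks_per_stage)
# Stem convolution + convolutions inside the blocks + final fully connected layer.
n_layers = 1 + n_blocks * convs_per_block + 1


def residual_block(f, x):
    """Return y = f(x) + x element-wise; f stands in for the convolutional path."""
    return [fx + xi for fx, xi in zip(f(x), x)]


# Toy "convolutional path" that halves every value.
y = residual_block(lambda x: [0.5 * v for v in x], [1.0, 2.0])

print(n_blocks, n_layers, y)
```

The identity term `x` passes through the block unchanged, so the gradient with respect to `x` always contains a direct path that does not go through `f`.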

The ResNet-50 model has BatchNorm layers; therefore, I also added adaptive batch norm when using ResNet-50 as the backbone. The differences between the available weight versions are documented [here](https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html#torchvision.models.ResNet50_Weights)
  
2.   Second, since the MEMO paper used a vision transformer model in its experiments, I used [**ViT-B/16**](https://docs.pytorch.org/vision/main/models/generated/torchvision.models.vit_b_16.html), the base Vision Transformer model, which applies the Transformer architecture to image classification. It splits an input image into patches, embeds them into vectors, adds a special [CLS] token and positional encodings, processes the sequence through 12 Transformer encoder blocks, and uses the final [CLS] token's representation to predict the class label via a linear classifier. The [CLS] token accumulates information from all the patches through attention, and the final classifier learns to use that token’s representation to output a prediction. The architecture of the model is as follows:
  *   Split into patches (16×16, hence ViT-B/16)
  *   Linear embedding
  *   Add CLS token
  *   Positional Embedding
  *   12 Transformer blocks, consisting of:
      * LayerNorm
      * Multi head Self-Attention
      * LayerNorm
      * MLP
      * Residual connection
  *   CLS Token Output
  *   Classification Head (Linear layer)
  *   Softmax output
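As a quick check of the patching step, the sketch below computes the sequence length the encoder sees. It assumes a 224×224 RGB input, which is the usual crop size for these weights but is not stated in this section; all variable names are illustrative.

```python
image_size = 224   # assumed input resolution (height = width)
patch_size = 16    # the "/16" in ViT-B/16
channels = 3       # RGB

patches_per_side = image_size // patch_size
n_patches = patches_per_side ** 2
seq_len = n_patches + 1  # one extra position for the [CLS] token

# Each patch is flattened before the linear embedding.
flat_patch_dim = patch_size * patch_size * channels

print(patches_per_side, n_patches, seq_len, flat_patch_dim)
```

With these assumptions the encoder processes 197 tokens per image, and only the first one, the [CLS] token, is passed to the classification head.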

The ViT-B/16 model uses only layer norm; therefore, when using ViT-B/16 as the backbone, adaptive batch norm is not applied.

***Instruction***: For each pre-trained model, I defined cells that prepare the model, data loader, learning rate, optimizer, and the choice of augmentations. Therefore, to run an experiment, you only need to run the cell for the desired pre-trained model.

##### ResNet50

In [None]:
# Pre-trained model ResNet50
def get_resnet(is_default_weight=True):
    """
      Returns the ResNet50 pre-trained model and its preprocessing transform for the requested weight version

      Args:
          is_default_weight: selects which weight version of the model to load
              True: weights `IMAGENET1K_V2`, with the higher top-1 accuracy of 80.85
              False: weights `IMAGENET1K_V1`, with top-1 accuracy of 76.13
    """
    if is_default_weight:
        preprocess = ResNet50_Weights.DEFAULT.transforms()
        return preprocess, resnet50(weights=ResNet50_Weights.DEFAULT)#.to(device)
    else:
        preprocess = ResNet50_Weights.IMAGENET1K_V1.transforms()
        return preprocess, resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)#.to(device)

def get_resnet_loader(is_default_weight=True, augmentation_option="option1", is_imgnet_a=True):
    preprocess, ResNet = get_resnet(is_default_weight)

    # Hyperparameters
    learning_rate = 0.00025

    # optimizer
    optimizer_resnet = get_optimizer(ResNet, learning_rate)

    match augmentation_option:
        case 'option1':
            # For baseline
            # Preprocessing is applied when the dataset is loaded; the pipeline below has no random transforms
            if is_imgnet_a:
                dataset_imagenet_a =  get_data(im_groot_imagenet_a, preprocess)
                data_loader_imagenet_a = torch.utils.data.DataLoader(dataset_imagenet_a, DEFAULT_BATCH_SIZE)
            else:
                dataset_imagenet_v2 =  get_data(im_groot_imagenet_v2, preprocess)
                data_loader_imagenet_v2 = torch.utils.data.DataLoader(dataset_imagenet_v2, DEFAULT_BATCH_SIZE)

            # augmentations = None
            augmentations = T.Compose([
                v2.ToImage(),
                v2.ToDtype(torch.uint8, scale=True),
                v2.Resize(preprocess.resize_size, interpolation=preprocess.interpolation,
                            antialias=preprocess.antialias),
                v2.CenterCrop(preprocess.crop_size),
                v2.ToDtype(torch.float32, scale=True),
                v2.Normalize(preprocess.mean, preprocess.std)
            ])

        case 'option2':
            # For testing MEMO
            # Augmentation - RandomResizedCrop
            if is_imgnet_a:
                dataset_imagenet_a =  get_data(im_groot_imagenet_a, None)
                data_loader_imagenet_a = torch.utils.data.DataLoader(dataset_imagenet_a, DEFAULT_BATCH_SIZE, collate_fn=pil_collate_fn)
            else:
                dataset_imagenet_v2 =  get_data(im_groot_imagenet_v2, None)
                data_loader_imagenet_v2 = torch.utils.data.DataLoader(dataset_imagenet_v2, DEFAULT_BATCH_SIZE, collate_fn=pil_collate_fn)

            augmentations = T.Compose([
                v2.ToImage(),
                v2.ToDtype(torch.uint8, scale=True),
                v2.RandomResizedCrop(224, antialias=True),
                v2.Resize(preprocess.resize_size, interpolation=preprocess.interpolation,
                            antialias=preprocess.antialias),
                v2.CenterCrop(preprocess.crop_size),
                v2.ToDtype(torch.float32, scale=True),
                v2.Normalize(preprocess.mean, preprocess.std)
            ])
        case 'option3':
            # For testing MEMO
            # Augmentation - RandomResizedCrop + RandomHorizontalFlip
            if is_imgnet_a:
                dataset_imagenet_a =  get_data(im_groot_imagenet_a, None)
                data_loader_imagenet_a = torch.utils.data.DataLoader(dataset_imagenet_a, DEFAULT_BATCH_SIZE, collate_fn=pil_collate_fn)
            else:
                dataset_imagenet_v2 =  get_data(im_groot_imagenet_v2, None)
                data_loader_imagenet_v2 = torch.utils.data.DataLoader(dataset_imagenet_v2, DEFAULT_BATCH_SIZE, collate_fn=pil_collate_fn)

            augmentations = T.Compose([
                v2.ToImage(),
                v2.ToDtype(torch.uint8, scale=True),
                v2.RandomResizedCrop(224, antialias=True),
                v2.RandomHorizontalFlip(),
                v2.Resize(preprocess.resize_size, interpolation=preprocess.interpolation,
                            antialias=preprocess.antialias),
                v2.CenterCrop(preprocess.crop_size),
                v2.ToDtype(torch.float32, scale=True),
                v2.Normalize(preprocess.mean, preprocess.std)
            ])
        case 'option4':
            # For testing MEMO
            # Augmentation - RandomResizedCrop + RandomPerspective + RandomAffine + RandomHorizontalFlip
            if is_imgnet_a:
                dataset_imagenet_a =  get_data(im_groot_imagenet_a, None)
                data_loader_imagenet_a = torch.utils.data.DataLoader(dataset_imagenet_a, DEFAULT_BATCH_SIZE, collate_fn=pil_collate_fn)
            else:
                dataset_imagenet_v2 =  get_data(im_groot_imagenet_v2, None)
                data_loader_imagenet_v2 = torch.utils.data.DataLoader(dataset_imagenet_v2, DEFAULT_BATCH_SIZE, collate_fn=pil_collate_fn)

            augmentations = T.Compose([
                v2.ToImage(),
                v2.ToDtype(torch.uint8, scale=True),
                v2.RandomResizedCrop(224, antialias=True),
                v2.RandomAffine(degrees=0, scale=(0.9, 1.2)),
                v2.RandomPerspective(),
                v2.RandomHorizontalFlip(),
                v2.Resize(preprocess.resize_size, interpolation=preprocess.interpolation,
                            antialias=preprocess.antialias),
                v2.CenterCrop(preprocess.crop_size),
                v2.ToDtype(torch.float32, scale=True),
                v2.Normalize(preprocess.mean, preprocess.std)
            ])

    if is_imgnet_a == True:
        return ResNet, learning_rate, optimizer_resnet, data_loader_imagenet_a, augmentations, preprocess

    return ResNet, learning_rate, optimizer_resnet, data_loader_imagenet_v2, augmentations, preprocess

##### VIT-B/16

In [None]:
# Pre-trained model VIT-B/16
def get_vit():
    """
      Returns the preprocessing transforms and the pre-trained ViT-B/16 model.
      The default weights are `IMAGENET1K_V1`, with a top-1 accuracy of 81.07.
    """
    preprocess_vit = ViT_B_16_Weights.DEFAULT.transforms()
    return preprocess_vit, vit_b_16(weights=ViT_B_16_Weights.DEFAULT)#.to(device)

def get_vit_loader(augmentation_option, is_imgnet_a):
    preprocess_vit, VIT_V1 = get_vit()

    # Hyperparameters
    learning_rate = 0.00005

    # optimizer
    optimizer_vit = get_optimizer(VIT_V1, learning_rate)

    match augmentation_option:
        case 'option1':
            # For baseline testing
            # The dataset is preprocessed directly; no random augmentation is applied
            if is_imgnet_a == True:
                dataset_imagenet_a_vit =  get_data(im_groot_imagenet_a, preprocess_vit)
                data_loader_imagenet_a_vit = torch.utils.data.DataLoader(dataset_imagenet_a_vit, DEFAULT_BATCH_SIZE)
            else:
                dataset_imagenet_v2_vit =  get_data(im_groot_imagenet_v2, preprocess_vit)
                data_loader_imagenet_v2_vit = torch.utils.data.DataLoader(dataset_imagenet_v2_vit, DEFAULT_BATCH_SIZE)

            # augmentations_vit = None
            augmentations_vit = T.Compose([
                v2.ToImage(),
                v2.ToDtype(torch.uint8, scale=True),
                v2.Resize(preprocess_vit.resize_size, interpolation=preprocess_vit.interpolation,
                            antialias=preprocess_vit.antialias),
                v2.CenterCrop(preprocess_vit.crop_size),
                v2.ToDtype(torch.float32, scale=True),
                v2.Normalize(preprocess_vit.mean, preprocess_vit.std)
            ])

        case 'option2':
            # For testing MEMO
            # Augmentation - RandomResizedCrop
            if is_imgnet_a == True:
                dataset_imagenet_a_vit =  get_data(im_groot_imagenet_a, transform=None)
                data_loader_imagenet_a_vit = torch.utils.data.DataLoader(dataset_imagenet_a_vit, DEFAULT_BATCH_SIZE, collate_fn=pil_collate_fn)
            else:
                dataset_imagenet_v2_vit =  get_data(im_groot_imagenet_v2, transform=None)
                data_loader_imagenet_v2_vit = torch.utils.data.DataLoader(dataset_imagenet_v2_vit, DEFAULT_BATCH_SIZE, collate_fn=pil_collate_fn)

            augmentations_vit = T.Compose([
                v2.ToImage(),
                v2.ToDtype(torch.uint8, scale=True),
                v2.RandomResizedCrop(224, antialias=True),
                v2.Resize(preprocess_vit.resize_size, interpolation=preprocess_vit.interpolation,
                          antialias=preprocess_vit.antialias),
                v2.CenterCrop(preprocess_vit.crop_size),
                v2.ToDtype(torch.float32, scale=True),
                v2.Normalize(preprocess_vit.mean, preprocess_vit.std)
            ])
        case 'option3':
            # For testing MEMO
            # Augmentation - RandomResizedCrop + RandomHorizontalFlip
            if is_imgnet_a == True:
                dataset_imagenet_a_vit =  get_data(im_groot_imagenet_a, transform=None)
                data_loader_imagenet_a_vit = torch.utils.data.DataLoader(dataset_imagenet_a_vit, DEFAULT_BATCH_SIZE, collate_fn=pil_collate_fn)
            else:
                dataset_imagenet_v2_vit =  get_data(im_groot_imagenet_v2, transform=None)
                data_loader_imagenet_v2_vit = torch.utils.data.DataLoader(dataset_imagenet_v2_vit, DEFAULT_BATCH_SIZE, collate_fn=pil_collate_fn)

            augmentations_vit = T.Compose([
                v2.ToImage(),
                v2.ToDtype(torch.uint8, scale=True),
                v2.RandomResizedCrop(224, antialias=True),
                v2.RandomHorizontalFlip(),
                v2.Resize(preprocess_vit.resize_size, interpolation=preprocess_vit.interpolation,
                          antialias=preprocess_vit.antialias),
                v2.CenterCrop(preprocess_vit.crop_size),
                v2.ToDtype(torch.float32, scale=True),
                v2.Normalize(preprocess_vit.mean, preprocess_vit.std)
            ])
        case 'option4':
            # For testing MEMO
            # Augmentation - RandomResizedCrop + RandomPerspective + RandomAffine + RandomHorizontalFlip
            if is_imgnet_a == True:
                dataset_imagenet_a_vit =  get_data(im_groot_imagenet_a, transform=None)
                data_loader_imagenet_a_vit = torch.utils.data.DataLoader(dataset_imagenet_a_vit, DEFAULT_BATCH_SIZE, collate_fn=pil_collate_fn)
            else:
                dataset_imagenet_v2_vit =  get_data(im_groot_imagenet_v2, transform=None)
                data_loader_imagenet_v2_vit = torch.utils.data.DataLoader(dataset_imagenet_v2_vit, DEFAULT_BATCH_SIZE, collate_fn=pil_collate_fn)

            augmentations_vit = T.Compose([
                v2.ToImage(),
                v2.ToDtype(torch.uint8, scale=True),
                v2.RandomResizedCrop(224, antialias=True),
                v2.RandomAffine(degrees=0, scale=(0.9, 1.2)),
                v2.RandomPerspective(),
                v2.RandomHorizontalFlip(),
                v2.Resize(preprocess_vit.resize_size, interpolation=preprocess_vit.interpolation,
                          antialias=preprocess_vit.antialias),
                v2.CenterCrop(preprocess_vit.crop_size),
                v2.ToDtype(torch.float32, scale=True),
                v2.Normalize(preprocess_vit.mean, preprocess_vit.std)
            ])

    if is_imgnet_a == True:
        return VIT_V1, learning_rate, optimizer_vit, data_loader_imagenet_a_vit, augmentations_vit, preprocess_vit

    return VIT_V1, learning_rate, optimizer_vit, data_loader_imagenet_v2_vit, augmentations_vit, preprocess_vit

### Test functions

In [None]:
def test_model(model, data_loader, device=device):
    """
    Evaluates a pre-trained model on a given test dataset.
    Used for the baselines, without applying the TTA method.

    Args:
        model:         The model to be evaluated.
        data_loader:   DataLoader for the test dataset.
        device:        Device to perform computation on.
    Returns:
        The top-1 test accuracy (in percent) computed over the dataset.
    """
    # for computing accuracy
    count = 0.0
    accuracy = 0.0

    # Load model into device (GPU)
    model.to(device)

    # Set the network to evaluation mode
    model.eval()

    # Disable gradient computation (it's only a test time, model update should not happen)
    with torch.no_grad():
        for _, (input, target) in enumerate(tqdm(data_loader)):
            # Load data into device (GPU)
            input = input.to(device)
            target = target.to(device)

            # Forward pass
            output = model(input)

            # Get the predicted class by taking the index of the max logit across classes
            _, predicted = output.max(1)

            # Increase total sample count by batch size (which is 1)
            count += input.shape[0]

            # Compare prediction with ground truth and count correct predictions
            accuracy += predicted.eq(target).sum().item()

            del output, input, target, predicted

    # Compute final accuracy
    return accuracy / count * 100

In [None]:
def test_model_tta_applied(model, data_loader, device=device):
    """
    Evaluates a model that applies the TTA method, one test sample at a time.

    Args:
        model:         The model to be evaluated.
        data_loader:   DataLoader for the test dataset.
        device:        Device to perform computation on.
    Returns:
        The top-1 test accuracy (in percent) computed over the dataset.
    """
    # for computing accuracy
    count = 0.0
    accuracy = 0.0

    # Load model into device (GPU)
    model.to(device)

    # Set the network to evaluation mode
    model.eval()

    for _, (input, target) in enumerate(tqdm(data_loader)):
        # Take the single sample out of the batch and load the target into device (GPU)
        input = input[0]
        target = target.to(device)

        # Forward pass - only one example at a time
        output = model([input])

        # Get the predicted class by taking the index of the max logit across classes
        _, predicted = output.max(1)

        # Increase total sample count by one (a single example per step)
        count += 1

        # Compare prediction with ground truth and count correct predictions
        accuracy += predicted.eq(target).sum().item()
        del output, input, target, predicted

    # Compute final accuracy
    return accuracy / count * 100
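
Both test functions keep the same bookkeeping: count the samples whose arg-max logit matches the label, then divide by the number of samples seen. Below is a minimal, framework-free sketch of that computation. The helper `top1_accuracy` and the toy logits are hypothetical and exist only for illustration; plain Python lists stand in for the tensors used above.

```python
def top1_accuracy(batches):
    """Return the top-1 accuracy (in percent) over (logits, targets) batches."""
    correct = 0
    count = 0
    for logits, targets in batches:
        for row, target in zip(logits, targets):
            # Index of the largest logit, i.e. the predicted class
            predicted = max(range(len(row)), key=row.__getitem__)
            correct += int(predicted == target)
            count += 1
    return correct / count * 100


# Four toy batches of size 1 over three classes (illustrative values only)
toy_batches = [
    ([[0.1, 2.0, 0.3]], [1]),  # predicted 1, correct
    ([[1.5, 0.2, 0.1]], [2]),  # predicted 0, wrong
    ([[0.0, 0.1, 3.0]], [2]),  # predicted 2, correct
    ([[0.9, 0.8, 0.7]], [0]),  # predicted 0, correct
]
accuracy = top1_accuracy(toy_batches)
print(accuracy)
```

Because the count is accumulated per sample, the same helper works for any batch size, which matches how `test_model` uses `input.shape[0]`.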

## Experiments

##### Logger util

In [None]:
def setup_experiment_logger(log_folder: str, experiment_name: str) -> logging.Logger:
    """
        Logs the experiment results to a file.
        Without it, the results could be lost if the AWS account signs out or the session is interrupted.
    """
    os.makedirs(log_folder, exist_ok=True)

    log_path = os.path.join(log_folder, f"{experiment_name}.log")

    logger = logging.getLogger(experiment_name)
    logger.setLevel(logging.INFO)

    if not logger.handlers:
        file_handler = logging.FileHandler(log_path)
        file_handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))
        logger.addHandler(file_handler)

    return logger
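
The `if not logger.handlers` guard matters in a notebook, since re-running a cell would otherwise attach a second file handler and write every result twice. The following is a small sketch of that behaviour using only the standard library. `make_file_logger` is a hypothetical stand-in that mirrors the guard above, and the temporary folder and the logged values are illustrative.

```python
import logging
import os
import tempfile


def make_file_logger(log_folder, name):
    """Hypothetical stand-in that mirrors the handler guard used above."""
    os.makedirs(log_folder, exist_ok=True)
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    if not logger.handlers:  # attach the file handler only once per logger name
        handler = logging.FileHandler(os.path.join(log_folder, f"{name}.log"))
        handler.setFormatter(logging.Formatter("%(asctime)s - %(message)s"))
        logger.addHandler(handler)
    return logger


log_dir = tempfile.mkdtemp()
first = make_file_logger(log_dir, "demo_experiment")
second = make_file_logger(log_dir, "demo_experiment")  # e.g. the cell is run again

second.info("Result: demo")
second.info("12.5")

# Make sure everything is written before reading the file back
for handler in second.handlers:
    handler.flush()

with open(os.path.join(log_dir, "demo_experiment.log")) as f:
    lines = f.read().splitlines()

print(len(second.handlers), len(lines))
```

`logging.getLogger` returns the same object for the same name, so the guard is what keeps the handler count at one across repeated calls.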

### Baselines

#### Baseline ~ ResNet50_V1

In [None]:
# ImageNet-A
ResNet, learning_rate, optimizer_resnet, data_loader_imagenet_a, _, _ = get_resnet_loader(is_default_weight=False, augmentation_option="option1", is_imgnet_a=True)

model = MEMO(ResNet, learning_rate, optimizer_resnet, marginal_entropy_loss, indices_in_1k, apply_tta=False, augmentations=None)
test_result = test_model(model, data_loader_imagenet_a)

print("\nResult: Baseline - ResNet50_V1 ~ ImageNet-A")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "Baseline_ResNet50_V1_ImageNet-A")
logger.info("Result: Baseline - ResNet50_V1 ~ ImageNet-A")
logger.info(f"{test_result}")

  0%|          | 0/7500 [00:00<?, ?it/s]


Result: Baseline - ResNet50_V1 ~ ImageNet-A
0.02666666666666667




In [None]:
# ImageNet-V2
ResNet, learning_rate, optimizer_resnet, data_loader_imagenet_v2, _, _ = get_resnet_loader(is_default_weight=False, augmentation_option="option1", is_imgnet_a=False)

model = MEMO(ResNet, learning_rate, optimizer_resnet, marginal_entropy_loss, imagenet_v_mask, apply_tta=False, augmentations=None)
test_result = test_model(model, data_loader_imagenet_v2)

print("\nResult: Baseline - ResNet50_V1 ~ ImageNet-V2")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "Baseline_ResNet50_V1_ImageNet-V2")
logger.info("Result: Baseline - ResNet50_V1 ~ ImageNet-V2")
logger.info(f"{test_result}")

  0%|          | 0/10000 [00:00<?, ?it/s]


Result: Baseline - ResNet50_V1 ~ ImageNet-V2
63.14999999999999




#### Baseline ~ ResNet50_V2

In [None]:
# ImageNet-A
ResNet, learning_rate, optimizer_resnet, data_loader_imagenet_a, _, _ = get_resnet_loader(is_default_weight=True, augmentation_option="option1", is_imgnet_a=True)

model = MEMO(ResNet, learning_rate, optimizer_resnet, marginal_entropy_loss, indices_in_1k, apply_tta=False, augmentations=None)
test_result = test_model(model, data_loader_imagenet_a)

print("\nResult: Baseline - ResNet50_V2 ~ ImageNet-A")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "Baseline_ResNet50_V2_ImageNet-A")
logger.info("Result: Baseline - ResNet50_V2 ~ ImageNet-A")
logger.info(f"{test_result}")

  0%|          | 0/7500 [00:00<?, ?it/s]


Result: Baseline - ResNet50_V2 ~ ImageNet-A
14.266666666666666




In [None]:
# ImageNet-V2
ResNet, learning_rate, optimizer_resnet, data_loader_imagenet_v2, _, _ = get_resnet_loader(is_default_weight=True, augmentation_option="option1", is_imgnet_a=False)

model = MEMO(ResNet, learning_rate, optimizer_resnet, marginal_entropy_loss, imagenet_v_mask, apply_tta=False, augmentations=None)
test_result = test_model(model, data_loader_imagenet_v2)

print("\nResult: Baseline - ResNet50_V2 ~ ImageNet-V2")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "Baseline_ResNet50_V2_ImageNet-V2")
logger.info("Result: Baseline - ResNet50_V2 ~ ImageNet-V2")
logger.info(f"{test_result}")

  0%|          | 0/10000 [00:00<?, ?it/s]


Result: Baseline - ResNet50_V2 ~ ImageNet-V2
69.89




#### Baseline ~ VIT-B/16

In [None]:
# ImageNet-A
VIT_V1, learning_rate, optimizer_vit, data_loader_imagenet_a_vit, _, _ = get_vit_loader(augmentation_option="option1", is_imgnet_a=True)

model = MEMO(VIT_V1, learning_rate, optimizer_vit, marginal_entropy_loss, indices_in_1k, apply_tta=False, augmentations=None)
test_result = test_model(model, data_loader_imagenet_a_vit)

print("\nResult: Baseline - ViT_B_16_V1 ~ ImageNet-A")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "Baseline_ViT_B_16_V1_ImageNet-A")
logger.info("Result: Baseline - ViT_B_16_V1 ~ ImageNet-A")
logger.info(f"{test_result}")

  0%|          | 0/7500 [00:00<?, ?it/s]


Result: Baseline - ViT_B_16_V1 ~ ImageNet-A
20.746666666666666




In [None]:
# ImageNet-V2
VIT_V1, learning_rate, optimizer_vit, data_loader_imagenet_v2_vit, _, _ = get_vit_loader(augmentation_option="option1", is_imgnet_a=False)

model = MEMO(VIT_V1, learning_rate, optimizer_vit, marginal_entropy_loss, imagenet_v_mask, apply_tta=False, augmentations=None)
test_result = test_model(model, data_loader_imagenet_v2_vit)

print("\nResult: Baseline - ViT_B_16_V1 ~ ImageNet-V2")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "Baseline_ViT_B_16_V1_ImageNet-V2")
logger.info("Result: Baseline - ViT_B_16_V1 ~ ImageNet-V2")
logger.info(f"{test_result}")

  0%|          | 0/10000 [00:00<?, ?it/s]


Result: Baseline - ViT_B_16_V1 ~ ImageNet-V2
69.57




### Test Time Adaptation applied

#### ResNet50_V2 + MEMO ~ applied TTA ~ RandomResizedCrop

In [None]:
# ImageNet-A
ResNet, learning_rate, optimizer_resnet, data_loader_imagenet_a, augmentations, preprocess = get_resnet_loader(is_default_weight=True, augmentation_option="option2", is_imgnet_a=True)

model = MEMO(ResNet,
             learning_rate,
             optimizer_resnet,
             marginal_entropy_loss,
             indices_in_1k,
             apply_tta=True,
             augmentations=augmentations,
             prior_strenght_bn=None,
             apply_adaptive_bn=False,
             temperature=None,
             apply_transform=True,
             preprocess=preprocess)

test_result = test_model_tta_applied(model, data_loader_imagenet_a)

print("\nResult: MEMO ~ TTA - ResNet50_V2 + RandomResizedCrop ~ ImageNet-A")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "ResNet50_V2_Memo_TTA_RandomResizedCrop_ImageNet-A")
logger.info("Result: MEMO ~ TTA - ResNet50_V2 + RandomResizedCrop ~ ImageNet-A")
logger.info(f"{test_result}")

  0%|          | 0/7500 [00:00<?, ?it/s]


Result: MEMO ~ TTA - ResNet50_V2 + RandomResizedCrop ~ ImageNet-A
18.413333333333334




In [None]:
# ImageNet-V2
ResNet, learning_rate, optimizer_resnet, data_loader_imagenet_v2, augmentations, preprocess = get_resnet_loader(is_default_weight=True, augmentation_option="option2", is_imgnet_a=False)

model = MEMO(ResNet,
             learning_rate,
             optimizer_resnet,
             marginal_entropy_loss,
             imagenet_v_mask,
             apply_tta=True,
             augmentations=augmentations,
             prior_strenght_bn=None,
             apply_adaptive_bn=False,
             temperature=None,
             apply_transform=True,
             preprocess=preprocess)

test_result = test_model_tta_applied(model, data_loader_imagenet_v2)

print("\nResult: MEMO ~ TTA - ResNet50_V2 + RandomResizedCrop ~ ImageNet-V2")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "ResNet50_V2_Memo_TTA_RandomResizedCrop_ImageNet-V2")
logger.info("Result: MEMO ~ TTA - ResNet50_V2 + RandomResizedCrop ~ ImageNet-V2")
logger.info(f"{test_result}")

  0%|          | 0/10000 [00:00<?, ?it/s]


Result: MEMO ~ TTA - ResNet50_V2 + RandomResizedCrop ~ ImageNet-V2
76.03




#### ResNet50_V2 + MEMO ~ applied TTA ~ RandomResizedCrop + BN

In [None]:
# ImageNet-A
ResNet, learning_rate, optimizer_resnet, data_loader_imagenet_a, augmentations, preprocess = get_resnet_loader(is_default_weight=True, augmentation_option="option2", is_imgnet_a=True)

model = MEMO(ResNet,
             learning_rate,
             optimizer_resnet,
             marginal_entropy_loss,
             indices_in_1k,
             apply_tta=True,
             augmentations=augmentations,
             prior_strenght_bn=PRIOR_STRENGTH,
             apply_adaptive_bn=True,
             temperature=None,
             apply_transform=True,
             preprocess=preprocess)

test_result = test_model_tta_applied(model, data_loader_imagenet_a)

print("\nResult: MEMO ~ TTA - ResNet50_V2 + BN + RandomResizedCrop ~ ImageNet-A")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "ResNet50_V2_Memo_TTA_BN_RandomResizedCrop_ImageNet-A")
logger.info("Result: MEMO ~ TTA - ResNet50_V2 + BN + RandomResizedCrop ~ ImageNet-A")
logger.info(f"{test_result}")

  0%|          | 0/7500 [00:00<?, ?it/s]


Result: MEMO ~ TTA - ResNet50_V2 + BN + RandomResizedCrop ~ ImageNet-A
24.240000000000002




In [None]:
# ImageNet-V2
ResNet, learning_rate, optimizer_resnet, data_loader_imagenet_v2, augmentations, preprocess = get_resnet_loader(is_default_weight=True, augmentation_option="option2", is_imgnet_a=False)

model = MEMO(ResNet,
             learning_rate,
             optimizer_resnet,
             marginal_entropy_loss,
             imagenet_v_mask,
             apply_tta=True,
             augmentations=augmentations,
             prior_strenght_bn=PRIOR_STRENGTH,
             apply_adaptive_bn=True,
             temperature=None,
             apply_transform=True,
             preprocess=preprocess)

test_result = test_model_tta_applied(model, data_loader_imagenet_v2)

print("\nResult: MEMO ~ TTA - ResNet50_V2 + BN + RandomResizedCrop ~ ImageNet-V2")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "ResNet50_V2_Memo_TTA_BN_RandomResizedCrop_ImageNet-V2")
logger.info("Result: MEMO ~ TTA - ResNet50_V2 + BN + RandomResizedCrop ~ ImageNet-V2")
logger.info(f"{test_result}")

  0%|          | 0/10000 [00:00<?, ?it/s]


Result: MEMO ~ TTA - ResNet50_V2 + BN + RandomResizedCrop ~ ImageNet-V2
78.69




#### ResNet50_V2 + MEMO ~ applied TTA + BN + Multi Augmentation

In [None]:
# ImageNet-A
ResNet, learning_rate, optimizer_resnet, data_loader_imagenet_a, augmentations, preprocess = get_resnet_loader(is_default_weight=True, augmentation_option="option3", is_imgnet_a=True)

model = MEMO(ResNet,
             learning_rate,
             optimizer_resnet,
             marginal_entropy_loss,
             indices_in_1k,
             apply_tta=True,
             augmentations=augmentations,
             prior_strenght_bn=PRIOR_STRENGTH,
             apply_adaptive_bn=True,
             temperature=None,
             apply_transform=True,
             preprocess=preprocess)

test_result = test_model_tta_applied(model, data_loader_imagenet_a)

# Augmentations: RandomResizedCrop + RandomHorizontalFlip
print("\nResult: MEMO ~ TTA - ResNet50_V2 + BN + Augmentation Random Order 2 ~ ImageNet-A")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "ResNet50_V2_Memo_TTA_BN_AugRandOrd_ImageNet-A")
logger.info("Result: MEMO ~ TTA - ResNet50_V2 + BN + Augmentation Random Order ~ ImageNet-A")
logger.info(f"{test_result}")

  0%|          | 0/7500 [00:00<?, ?it/s]


Result: MEMO ~ TTA - ResNet50_V2 + BN + Augmentation Random Order 2 ~ ImageNet-A
24.306666666666665




In [None]:
# ImageNet-V2
ResNet, learning_rate, optimizer_resnet, data_loader_imagenet_v2, augmentations, preprocess = get_resnet_loader(is_default_weight=True, augmentation_option="option3", is_imgnet_a=False)

model = MEMO(ResNet,
             learning_rate,
             optimizer_resnet,
             marginal_entropy_loss,
             imagenet_v_mask,
             apply_tta=True,
             augmentations=augmentations,
             prior_strenght_bn=PRIOR_STRENGTH,
             apply_adaptive_bn=True,
             temperature=None,
             apply_transform=True,
             preprocess=preprocess)

test_result = test_model_tta_applied(model, data_loader_imagenet_v2)

print("\nResult: MEMO ~ TTA - ResNet50_V2 + BN + RandomResizedCrop + RandomHorizontalFlip ~ ImageNet-V2")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "ResNet50_V2_Memo_TTA_BN_RResizedCrop_HorFlip_ImageNet-V2")
logger.info("Result: MEMO ~ TTA - ResNet50_V2 + BN + RandomResizedCrop + RandomHorizontalFlip ~ ImageNet-V2")
logger.info(f"{test_result}")

  0%|          | 0/10000 [00:00<?, ?it/s]


Result: MEMO ~ TTA - ResNet50_V2 + BN + RandomResizedCrop + RandomHorizontalFlip ~ ImageNet-V2
78.9




#### ResNet50_V2 + MEMO ~ applied TTA + BN + Mixture of Augmentations

In [None]:
# ImageNet-A
ResNet, learning_rate, optimizer_resnet, data_loader_imagenet_a, augmentations, preprocess = get_resnet_loader(is_default_weight=True, augmentation_option="option4", is_imgnet_a=True)

model = MEMO(ResNet,
             learning_rate,
             optimizer_resnet,
             marginal_entropy_loss,
             indices_in_1k,
             apply_tta=True,
             augmentations=augmentations,
             prior_strenght_bn=PRIOR_STRENGTH,
             apply_adaptive_bn=True,
             temperature=None,
             apply_transform=True,
             preprocess=preprocess)

test_result = test_model_tta_applied(model, data_loader_imagenet_a)

# Augmentations: RandomResizedCrop + RandomAffine + RandomPerspective + RandomHorizontalFlip
print("\nResult: MEMO ~ TTA - ResNet50_V2 + BN + Mixture of Augmentations ~ ImageNet-A")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "ResNet50_V2_Memo_TTA_BN_MixtureAugs_ImageNet-A")
logger.info("Result: MEMO ~ TTA - ResNet50_V2 + BN + Mixture of Augmentations ~ ImageNet-A")
logger.info(f"{test_result}")

  0%|          | 0/7500 [00:00<?, ?it/s]


Result: MEMO ~ TTA - ResNet50_V2 + BN + Mixture of Augmentations ~ ImageNet-A
24.133333333333333




In [None]:
# ImageNet-V2
ResNet, learning_rate, optimizer_resnet, data_loader_imagenet_v2, augmentations, preprocess = get_resnet_loader(is_default_weight=True, augmentation_option="option4", is_imgnet_a=False)

model = MEMO(ResNet,
             learning_rate,
             optimizer_resnet,
             marginal_entropy_loss,
             imagenet_v_mask,
             apply_tta=True,
             augmentations=augmentations,
             prior_strenght_bn=PRIOR_STRENGTH,
             apply_adaptive_bn=True,
             temperature=None,
             apply_transform=True,
             preprocess=preprocess)

test_result = test_model_tta_applied(model, data_loader_imagenet_v2)

print("\nResult: MEMO ~ TTA - ResNet50_V2 + BN + Mixture of Augmentations ~ ImageNet-V2")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "ResNet50_V2_Memo_TTA_BN_MixtureAugs_ImageNet-V2")
logger.info("Result: MEMO ~ TTA - ResNet50_V2 + BN + Mixture of Augmentations ~ ImageNet-V2")
logger.info(f"{test_result}")

  0%|          | 0/10000 [00:00<?, ?it/s]


Result: MEMO ~ TTA - ResNet50_V2 + BN + Mixture of Augmentations ~ ImageNet-V2
78.5




#### ResNet50_V2 + MEMO ~ applied TTA + BN + Sharpened Softmax Marginal Entropy

In [None]:
# ImageNet-A
ResNet, learning_rate, optimizer_resnet, data_loader_imagenet_a, augmentations, preprocess = get_resnet_loader(is_default_weight=True, augmentation_option="option2", is_imgnet_a=True)

model = MEMO(ResNet,
             learning_rate,
             optimizer_resnet,
             sharpened_softmax_entropy,
             indices_in_1k,
             apply_tta=True,
             augmentations=augmentations,
             prior_strenght_bn=PRIOR_STRENGTH,
             apply_adaptive_bn=True,
             temperature=TEMPERATURE,
             apply_transform=True,
             preprocess=preprocess)

test_result = test_model_tta_applied(model, data_loader_imagenet_a)

print("\nResult: MEMO ~ TTA - ResNet50_V2 + BN + RandomResizedCrop + Sharpened Softmax Entropy ~ ImageNet-A")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "ResNet50_V2_Memo_TTA_BN_RandomResizedCrop_SharpSoftMax_ImageNet-A")
logger.info("Result: MEMO ~ TTA - ResNet50_V2 + BN + RandomResizedCrop + Sharpened Softmax Entropy ~ ImageNet-A")
logger.info(f"{test_result}")

  0%|          | 0/7500 [00:00<?, ?it/s]


Result: MEMO ~ TTA - ResNet50_V2 + BN + RandomResizedCrop + Sharpened Softmax Entropy ~ ImageNet-A
29.86666666666667




In [None]:
# ImageNet-V2
ResNet, learning_rate, optimizer_resnet, data_loader_imagenet_v2, augmentations, preprocess = get_resnet_loader(is_default_weight=True, augmentation_option="option2", is_imgnet_a=False)

model = MEMO(ResNet,
             learning_rate,
             optimizer_resnet,
             sharpened_softmax_entropy,
             imagenet_v_mask,
             apply_tta=True,
             augmentations=augmentations,
             prior_strenght_bn=PRIOR_STRENGTH,
             apply_adaptive_bn=True,
             temperature=TEMPERATURE,
             apply_transform=True,
             preprocess=preprocess)

test_result = test_model_tta_applied(model, data_loader_imagenet_v2)

print("\nResult: MEMO ~ TTA - ResNet50_V2 + BN + RandomResizedCrop + Sharpened Softmax Entropy ~ ImageNet-V2")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "ResNet50_V2_Memo_TTA_BN_RandomResizedCrop_SharpSoftMax_ImageNet-V2")
logger.info("Result: MEMO ~ TTA - ResNet50_V2 + BN + RandomResizedCrop + Sharpened Softmax Entropy ~ ImageNet-V2")
logger.info(f"{test_result}")

  0%|          | 0/10000 [00:00<?, ?it/s]


Result: MEMO ~ TTA - ResNet50_V2 + BN + RandomResizedCrop + Sharpened Softmax Entropy ~ ImageNet-V2
75.49




#### ResNet50_V2 + MEMO ~ applied TTA + BN + Augmentation Weighted Marginal Entropy

In [None]:
# ImageNet-A
ResNet, learning_rate, optimizer_resnet, data_loader_imagenet_a, augmentations, preprocess = get_resnet_loader(is_default_weight=True, augmentation_option="option2", is_imgnet_a=True)

model = MEMO(ResNet,
             learning_rate,
             optimizer_resnet,
             augmentation_weighted_entropy,
             indices_in_1k,
             apply_tta=True,
             augmentations=augmentations,
             prior_strenght_bn=PRIOR_STRENGTH,
             apply_adaptive_bn=True,
             temperature=None,
             apply_transform=True,
             preprocess=preprocess)

test_result = test_model_tta_applied(model, data_loader_imagenet_a)

print("\nResult: MEMO ~ TTA - ResNet50_V2 + BN + RandomResizedCrop + Augmentation Weighted Entropy ~ ImageNet-A")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "ResNet50_V2_Memo_TTA_BN_RandomResizedCrop_WeightedEntropy_ImageNet-A")
logger.info("Result: MEMO ~ TTA - ResNet50_V2 + BN + RandomResizedCrop + Augmentation Weighted Entropy ~ ImageNet-A")
logger.info(f"{test_result}")

  0%|          | 0/7500 [00:00<?, ?it/s]


Result: MEMO ~ TTA - ResNet50_V2 + BN + RandomResizedCrop + Augmentation Weighted Entropy ~ ImageNet-A
24.226666666666667




In [None]:
# ImageNet-V2
ResNet, learning_rate, optimizer_resnet, data_loader_imagenet_v2, augmentations, preprocess = get_resnet_loader(is_default_weight=True, augmentation_option="option2", is_imgnet_a=False)

model = MEMO(ResNet,
             learning_rate,
             optimizer_resnet,
             augmentation_weighted_entropy,
             imagenet_v_mask,
             apply_tta=True,
             augmentations=augmentations,
             prior_strenght_bn=PRIOR_STRENGTH,
             apply_adaptive_bn=True,
             temperature=None,
             apply_transform=True,
             preprocess=preprocess)

test_result = test_model_tta_applied(model, data_loader_imagenet_v2)

print("\nResult: MEMO ~ TTA - ResNet50_V2 + BN + RandomResizedCrop + Augmentation Weighted Entropy ~ ImageNet-V2")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "ResNet50_V2_Memo_TTA_BN_RandomResizedCrop_WeightedEntropy_ImageNet-V2")
logger.info("Result: MEMO ~ TTA - ResNet50_V2 + BN + RandomResizedCrop + Augmentation Weighted Entropy ~ ImageNet-V2")
logger.info(f"{test_result}")

  0%|          | 0/10000 [00:00<?, ?it/s]


Result: MEMO ~ TTA - ResNet50_V2 + BN + RandomResizedCrop + Augmentation Weighted Entropy ~ ImageNet-V2
76.78
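
The ResNet50_V2 accuracies printed so far are easier to compare as gains over the baseline. The sketch below tabulates them. The values are copied from the cell outputs above and rounded to two decimals; the dictionary keys and variable names are illustrative labels, not identifiers from the notebook.

```python
# (ImageNet-A accuracy, ImageNet-V2 accuracy) in percent, rounded from the outputs above
results = {
    "Baseline": (14.27, 69.89),
    "MEMO": (18.41, 76.03),
    "MEMO + BN": (24.24, 78.69),
    "MEMO + BN + crop + flip": (24.31, 78.90),
    "MEMO + BN + mixture of augmentations": (24.13, 78.50),
    "MEMO + BN + sharpened softmax": (29.87, 75.49),
    "MEMO + BN + augmentation-weighted entropy": (24.23, 76.78),
}

baseline_a, baseline_v2 = results["Baseline"]

# Gain in percentage points over the baseline on each dataset
gains = {
    name: (round(acc_a - baseline_a, 2), round(acc_v2 - baseline_v2, 2))
    for name, (acc_a, acc_v2) in results.items()
}

# Best configuration per dataset
best_a = max(results, key=lambda name: results[name][0])
best_v2 = max(results, key=lambda name: results[name][1])

for name, (gain_a, gain_v2) in gains.items():
    print(f"{name:45s} ImageNet-A {gain_a:+6.2f}   ImageNet-V2 {gain_v2:+6.2f}")
print("Best on ImageNet-A: ", best_a)
print("Best on ImageNet-V2:", best_v2)
```

On these numbers, the best configuration differs between the two datasets, so the comparison is kept per dataset rather than averaged.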





#### ResNet50_V1 + MEMO ~ applied TTA ~ RandomResizedCrop

In [None]:
# ImageNet-A
ResNet, learning_rate, optimizer_resnet, data_loader_imagenet_a, augmentations, preprocess = get_resnet_loader(is_default_weight=False, augmentation_option="option2", is_imgnet_a=True)

model = MEMO(ResNet,
             learning_rate,
             optimizer_resnet,
             marginal_entropy_loss,
             indices_in_1k,
             apply_tta=True,
             augmentations=augmentations,
             prior_strenght_bn=None,
             apply_adaptive_bn=False,
             temperature=None,
             apply_transform=True,
             preprocess=preprocess)

test_result = test_model_tta_applied(model, data_loader_imagenet_a)

print("\nResult: MEMO ~ TTA - ResNet50_V1 + RandomResizedCrop ~ ImageNet-A")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "ResNet50_V1_Memo_TTA_RandomResizedCrop_ImageNet-A")
logger.info("Result: MEMO ~ TTA - ResNet50_V1 + RandomResizedCrop ~ ImageNet-A")
logger.info(f"{test_result}")

  0%|          | 0/7500 [00:00<?, ?it/s]


Result: MEMO ~ TTA - ResNet50_V1 + RandomResizedCrop ~ ImageNet-A
3.6533333333333333




In [None]:
# ImageNet-V2
ResNet, learning_rate, optimizer_resnet, data_loader_imagenet_v2, augmentations, preprocess = get_resnet_loader(is_default_weight=False, augmentation_option="option2", is_imgnet_a=False)

model = MEMO(ResNet,
             learning_rate,
             optimizer_resnet,
             marginal_entropy_loss,
             imagenet_v_mask,
             apply_tta=True,
             augmentations=augmentations,
             prior_strenght_bn=None,
             apply_adaptive_bn=False,
             temperature=None,
             apply_transform=True,
             preprocess=preprocess)

test_result = test_model_tta_applied(model, data_loader_imagenet_v2)

print("\nResult: MEMO ~ TTA - ResNet50_V1 + RandomResizedCrop ~ ImageNet-V2")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "ResNet50_V1_Memo_TTA_RandomResizedCrop_ImageNet-V2")
logger.info("Result: MEMO ~ TTA - ResNet50_V1 + RandomResizedCrop ~ ImageNet-V2")
logger.info(f"{test_result}")

  0%|          | 0/10000 [00:00<?, ?it/s]


Result: MEMO ~ TTA - ResNet50_V1 + RandomResizedCrop ~ ImageNet-V2
69.03




#### ResNet50_V1 + MEMO ~ applied TTA ~ RandomResizedCrop + BN

In [None]:
# ImageNet-A
ResNet, learning_rate, optimizer_resnet, data_loader_imagenet_a, augmentations, preprocess = get_resnet_loader(is_default_weight=False, augmentation_option="option2", is_imgnet_a=True)

model = MEMO(ResNet,
             learning_rate,
             optimizer_resnet,
             marginal_entropy_loss,
             indices_in_1k,
             apply_tta=True,
             augmentations=augmentations,
             prior_strenght_bn=PRIOR_STRENGTH,
             apply_adaptive_bn=True,
             temperature=None,
             apply_transform=True,
             preprocess=preprocess)

test_result = test_model_tta_applied(model, data_loader_imagenet_a)

print("\nResult: MEMO ~ TTA - ResNet50_V1 + BN + RandomResizedCrop ~ ImageNet-A")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "ResNet50_V1_Memo_TTA_BN_RandomResizedCrop_ImageNet-A")
logger.info("Result: MEMO ~ TTA - ResNet50_V1 + BN + RandomResizedCrop ~ ImageNet-A")
logger.info(f"{test_result}")

  0%|          | 0/7500 [00:00<?, ?it/s]


Result: MEMO ~ TTA - ResNet50_V1 + BN + RandomResizedCrop ~ ImageNet-A
8.426666666666668




In [None]:
# ImageNet-V2
ResNet, learning_rate, optimizer_resnet, data_loader_imagenet_v2, augmentations, preprocess = get_resnet_loader(is_default_weight=False, augmentation_option="option2", is_imgnet_a=False)

model = MEMO(ResNet,
             learning_rate,
             optimizer_resnet,
             marginal_entropy_loss,
             imagenet_v_mask,
             apply_tta=True,
             augmentations=augmentations,
             prior_strenght_bn=PRIOR_STRENGTH,
             apply_adaptive_bn=True,
             temperature=None,
             apply_transform=True,
             preprocess=preprocess)

test_result = test_model_tta_applied(model, data_loader_imagenet_v2)

print("\nResult: MEMO ~ TTA - ResNet50_V1 + BN + RandomResizedCrop ~ ImageNet-V2")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "ResNet50_V1_Memo_TTA_BN_RandomResizedCrop_ImageNet-V2")
logger.info("Result: MEMO ~ TTA - ResNet50_V1 + BN + RandomResizedCrop ~ ImageNet-V2")
logger.info(f"{test_result}")

  0%|          | 0/10000 [00:00<?, ?it/s]


Result: MEMO ~ TTA - ResNet50_V1 + BN + RandomResizedCrop ~ ImageNet-V2
68.2




#### ResNet50_V1 + MEMO ~ applied TTA + BN + Multi Augmentation

In [None]:
# ImageNet-A
ResNet, learning_rate, optimizer_resnet, data_loader_imagenet_a, augmentations, preprocess = get_resnet_loader(is_default_weight=False, augmentation_option="option3", is_imgnet_a=True)

model = MEMO(ResNet,
             learning_rate,
             optimizer_resnet,
             marginal_entropy_loss,
             indices_in_1k,
             apply_tta=True,
             augmentations=augmentations,
             prior_strenght_bn=PRIOR_STRENGTH,
             apply_adaptive_bn=True,
             temperature=None,
             apply_transform=True,
             preprocess=preprocess)

test_result = test_model_tta_applied(model, data_loader_imagenet_a)

# Augmentations: RandomResizedCrop + RandomHorizontalFlip
print("\nResult: MEMO ~ TTA - ResNet50_V1 + BN + RandomResizedCrop + RandomHorizontalFlip ~ ImageNet-A")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "ResNet50_V1_Memo_TTA_BN_RResizedCrop_HorFlip_ImageNet-A")
logger.info("Result: MEMO ~ TTA - ResNet50_V1 + BN + RandomResizedCrop + RandomHorizontalFlip ~ ImageNet-A")
logger.info(f"{test_result}")

  0%|          | 0/7500 [00:00<?, ?it/s]


Result: MEMO ~ TTA - ResNet50_V1 + BN + RandomResizedCrop + RandomHorizontalFlip ~ ImageNet-A
8.173333333333334




In [None]:
# ImageNet-V2
ResNet, learning_rate, optimizer_resnet, data_loader_imagenet_v2, augmentations, preprocess = get_resnet_loader(is_default_weight=False, augmentation_option="option3", is_imgnet_a=False)

model = MEMO(ResNet,
             learning_rate,
             optimizer_resnet,
             marginal_entropy_loss,
             imagenet_v_mask,
             apply_tta=True,
             augmentations=augmentations,
             prior_strenght_bn=PRIOR_STRENGTH,
             apply_adaptive_bn=True,
             temperature=None,
             apply_transform=True,
             preprocess=preprocess)

test_result = test_model_tta_applied(model, data_loader_imagenet_v2)

print("\nResult: MEMO ~ TTA - ResNet50_V1 + BN + RandomResizedCrop + RandomHorizontalFlip ~ ImageNet-V2")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "ResNet50_V1_Memo_TTA_BN_RResizedCrop_HorFlip_ImageNet-V2")
logger.info("Result: MEMO ~ TTA - ResNet50_V1 + BN + RandomResizedCrop + RandomHorizontalFlip ~ ImageNet-V2")
logger.info(f"{test_result}")

  0%|          | 0/10000 [00:00<?, ?it/s]


Result: MEMO ~ TTA - ResNet50_V1 + BN + RandomResizedCrop + RandomHorizontalFlip ~ ImageNet-V2
67.78999999999999




#### ResNet50_V1 + MEMO ~ applied TTA + BN + Mixture of Augmentations

In [None]:
# ImageNet-A
ResNet, learning_rate, optimizer_resnet, data_loader_imagenet_a, augmentations, preprocess = get_resnet_loader(is_default_weight=False, augmentation_option="option4", is_imgnet_a=True)

model = MEMO(ResNet,
             learning_rate,
             optimizer_resnet,
             marginal_entropy_loss,
             indices_in_1k,
             apply_tta=True,
             augmentations=augmentations,
             prior_strenght_bn=PRIOR_STRENGTH,
             apply_adaptive_bn=True,
             temperature=None,
             apply_transform=True,
             preprocess=preprocess)

test_result = test_model_tta_applied(model, data_loader_imagenet_a)

# Augmentations: RandomResizedCrop + RandomAffine + RandomPerspective + RandomHorizontalFlip
print("\nResult: MEMO ~ TTA - ResNet50_V1 + BN + Mixture of Augmentations ~ ImageNet-A")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "ResNet50_V1_Memo_TTA_BN_MixtureAugs_ImageNet-A")
logger.info("Result: MEMO ~ TTA - ResNet50_V1 + BN + Mixture of Augmentations ~ ImageNet-A")
logger.info(f"{test_result}")

  0%|          | 0/7500 [00:00<?, ?it/s]


Result: MEMO ~ TTA - ResNet50_V1 + BN + Mixture of Augmentations ~ ImageNet-A
7.506666666666667




In [None]:
# ImageNet-V2
ResNet, learning_rate, optimizer_resnet, data_loader_imagenet_v2, augmentations, preprocess = get_resnet_loader(is_default_weight=False, augmentation_option="option4", is_imgnet_a=False)

model = MEMO(ResNet,
             learning_rate,
             optimizer_resnet,
             marginal_entropy_loss,
             imagenet_v_mask,
             apply_tta=True,
             augmentations=augmentations,
             prior_strenght_bn=PRIOR_STRENGTH,
             apply_adaptive_bn=True,
             temperature=None,
             apply_transform=True,
             preprocess=preprocess)

test_result = test_model_tta_applied(model, data_loader_imagenet_v2)

print("\nResult: MEMO ~ TTA - ResNet50_V1 + BN + Mixture of Augmentations ~ ImageNet-V2")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "ResNet50_V1_Memo_TTA_BN_MixtureAugs_ImageNet-V2")
logger.info("Result: MEMO ~ TTA - ResNet50_V1 + BN + Mixture of Augmentations ~ ImageNet-V2")
logger.info(f"{test_result}")

  0%|          | 0/10000 [00:00<?, ?it/s]


Result: MEMO ~ TTA - ResNet50_V1 + BN + Mixture of Augmentations ~ ImageNet-V2
66.13




#### ResNet50_V1 + MEMO ~ applied TTA + BN + Sharpened Softmax Marginal Entropy

In [None]:
# ImageNet-A
ResNet, learning_rate, optimizer_resnet, data_loader_imagenet_a, augmentations, preprocess = get_resnet_loader(is_default_weight=False, augmentation_option="option2", is_imgnet_a=True)

model = MEMO(ResNet,
             learning_rate,
             optimizer_resnet,
             sharpened_softmax_entropy,
             indices_in_1k,
             apply_tta=True,
             augmentations=augmentations,
             prior_strenght_bn=PRIOR_STRENGTH,
             apply_adaptive_bn=True,
             temperature=TEMPERATURE,
             apply_transform=True,
             preprocess=preprocess)

test_result = test_model_tta_applied(model, data_loader_imagenet_a)

print("\nResult: MEMO ~ TTA - ResNet50_V1 + BN + RandomResizedCrop + Sharpened Softmax Entropy ~ ImageNet-A")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "ResNet50_V1_Memo_TTA_BN_RandomResizedCrop_SharpSoftMax_ImageNet-A")
logger.info("Result: MEMO ~ TTA - ResNet50_V1 + BN + RandomResizedCrop + Sharpened Softmax Entropy ~ ImageNet-A")
logger.info(f"{test_result}")

  0%|          | 0/7500 [00:00<?, ?it/s]


Result: MEMO ~ TTA - ResNet50_V1 + BN + RandomResizedCrop + Sharpened Softmax Entropy ~ ImageNet-A
4.84




In [None]:
# ImageNet-V2
ResNet, learning_rate, optimizer_resnet, data_loader_imagenet_v2, augmentations, preprocess = get_resnet_loader(is_default_weight=False, augmentation_option="option2", is_imgnet_a=False)

model = MEMO(ResNet,
             learning_rate,
             optimizer_resnet,
             sharpened_softmax_entropy,
             imagenet_v_mask,
             apply_tta=True,
             augmentations=augmentations,
             prior_strenght_bn=PRIOR_STRENGTH,
             apply_adaptive_bn=True,
             temperature=TEMPERATURE,
             apply_transform=True,
             preprocess=preprocess)

test_result = test_model_tta_applied(model, data_loader_imagenet_v2)

print("\nResult: MEMO ~ TTA - ResNet50_V1 + BN + RandomResizedCrop + Sharpened Softmax Entropy ~ ImageNet-V2")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "ResNet50_V1_Memo_TTA_BN_RandomResizedCrop_SharpSoftMax_ImageNet-V2")
logger.info("Result: MEMO ~ TTA - ResNet50_V1 + BN + RandomResizedCrop + Sharpened Softmax Entropy ~ ImageNet-V2")
logger.info(f"{test_result}")

  0%|          | 0/10000 [00:00<?, ?it/s]


Result: MEMO ~ TTA - ResNet50_V1 + BN + RandomResizedCrop + Sharpened Softmax Entropy ~ ImageNet-V2
68.03




#### ResNet50_V1 + MEMO ~ applied TTA + BN + Augmentation Weighted Marginal Entropy

In [None]:
# ImageNet-A
ResNet, learning_rate, optimizer_resnet, data_loader_imagenet_a, augmentations, preprocess = get_resnet_loader(is_default_weight=False, augmentation_option="option2", is_imgnet_a=True)

model = MEMO(ResNet,
             learning_rate,
             optimizer_resnet,
             augmentation_weighted_entropy,
             indices_in_1k,
             apply_tta=True,
             augmentations=augmentations,
             prior_strenght_bn=PRIOR_STRENGTH,
             apply_adaptive_bn=True,
             temperature=None,
             apply_transform=True,
             preprocess=preprocess)

test_result = test_model_tta_applied(model, data_loader_imagenet_a)

print("\nResult: MEMO ~ TTA - ResNet50_V1 + BN + RandomResizedCrop + Augmentation Weighted Entropy ~ ImageNet-A")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "ResNet50_V1_Memo_TTA_BN_RandomResizedCrop_WeightedEntropy_ImageNet-A")
logger.info("Result: MEMO ~ TTA - ResNet50_V1 + BN + RandomResizedCrop + Augmentation Weighted Entropy ~ ImageNet-A")
logger.info(f"{test_result}")

  0%|          | 0/7500 [00:00<?, ?it/s]


Result: MEMO ~ TTA - ResNet50_V1 + BN + RandomResizedCrop + Augmentation Weighted Entropy ~ ImageNet-A
9.173333333333334




In [None]:
# ImageNet-V2
ResNet, learning_rate, optimizer_resnet, data_loader_imagenet_v2, augmentations, preprocess = get_resnet_loader(is_default_weight=False, augmentation_option="option2", is_imgnet_a=False)

model = MEMO(ResNet,
             learning_rate,
             optimizer_resnet,
             augmentation_weighted_entropy,
             imagenet_v_mask,
             apply_tta=True,
             augmentations=augmentations,
             prior_strenght_bn=PRIOR_STRENGTH,
             apply_adaptive_bn=True,
             temperature=None,
             apply_transform=True,
             preprocess=preprocess)

test_result = test_model_tta_applied(model, data_loader_imagenet_v2)

print("\nResult: MEMO ~ TTA - ResNet50_V1 + BN + RandomResizedCrop + Augmentation Weighted Entropy ~ ImageNet-V2")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "ResNet50_V1_Memo_TTA_BN_RandomResizedCrop_WeightedEntropy_ImageNet-V2")
logger.info("Result: MEMO ~ TTA - ResNet50_V1 + BN + RandomResizedCrop + Augmentation Weighted Entropy ~ ImageNet-V2")
logger.info(f"{test_result}")

  0%|          | 0/10000 [00:00<?, ?it/s]


Result: MEMO ~ TTA - ResNet50_V1 + BN + RandomResizedCrop + Augmentation Weighted Entropy ~ ImageNet-V2
68.8




#### VIT-B/16 + MEMO ~ RandomResizedCrop

In [None]:
# ImageNet-A
VIT_V1, learning_rate, optimizer_vit, data_loader_imagenet_a_vit, augmentations_vit, preprocess_vit = get_vit_loader(augmentation_option="option2", is_imgnet_a=True)

model = MEMO(VIT_V1,
             learning_rate,
             optimizer_vit,
             marginal_entropy_loss,
             indices_in_1k,
             apply_tta=True,
             augmentations=augmentations_vit,
             prior_strenght_bn=None,
             apply_adaptive_bn=False,
             temperature=None,
             apply_transform=True,
             preprocess=preprocess_vit)

test_result = test_model_tta_applied(model, data_loader_imagenet_a_vit)

print("\nResult: MEMO ~ TTA - VIT-B/16_V1 + RandomResizedCrop ~ ImageNet-A")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "VIT-B16_V1_Memo_TTA_RandomResizedCrop_ImageNet-A")
logger.info("Result: MEMO ~ TTA - VIT-B/16_V1 + RandomResizedCrop ~ ImageNet-A")
logger.info(f"{test_result}")

  0%|          | 0/7500 [00:00<?, ?it/s]


Result: MEMO ~ TTA - VIT-B/16_V1 + RandomResizedCrop ~ ImageNet-A
25.426666666666662




In [None]:
# ImageNet-V2
VIT_V1, learning_rate, optimizer_vit, data_loader_imagenet_v2_vit, augmentations_vit, preprocess_vit = get_vit_loader(augmentation_option="option2", is_imgnet_a=False)

model = MEMO(VIT_V1,
             learning_rate,
             optimizer_vit,
             marginal_entropy_loss,
             imagenet_v_mask,
             apply_tta=True,
             augmentations=augmentations_vit,
             prior_strenght_bn=None,
             apply_adaptive_bn=False,
             temperature=None,
             apply_transform=True,
             preprocess=preprocess_vit)

test_result = test_model_tta_applied(model, data_loader_imagenet_v2_vit)

print("\nResult: MEMO ~ TTA - VIT-B/16_V1 + RandomResizedCrop ~ ImageNet-V2")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "VIT-B16_V1_Memo_TTA_RandomResizedCrop_ImageNet-V2")
logger.info("Result: MEMO ~ TTA - VIT-B/16_V1 + RandomResizedCrop ~ ImageNet-V2")
logger.info(f"{test_result}")

  0%|          | 0/10000 [00:00<?, ?it/s]


Result: MEMO ~ TTA - VIT-B/16_V1 + RandomResizedCrop ~ ImageNet-V2
72.89999999999999




#### VIT-B/16 + MEMO ~ Multi Augmentation

In [None]:
# ImageNet-A
VIT_V1, learning_rate, optimizer_vit, data_loader_imagenet_a_vit, augmentations_vit, preprocess_vit = get_vit_loader(augmentation_option="option3", is_imgnet_a=True)

model = MEMO(VIT_V1,
             learning_rate,
             optimizer_vit,
             marginal_entropy_loss,
             indices_in_1k,
             apply_tta=True,
             augmentations=augmentations_vit,
             prior_strenght_bn=None,
             apply_adaptive_bn=False,
             temperature=None,
             apply_transform=True,
             preprocess=preprocess_vit)

test_result = test_model_tta_applied(model, data_loader_imagenet_a_vit)

print("\nResult: MEMO ~ TTA - VIT-B/16_V1 + RandomResizedCrop + RandomHorizontalFlip ~ ImageNet-A")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "VIT-B16_V1_Memo_TTA_RResizedCrop_HorFlip_ImageNet-A")
logger.info("Result: MEMO ~ TTA - VIT-B/16_V1 + RandomResizedCrop + RandomHorizontalFlip ~ ImageNet-A")
logger.info(f"{test_result}")

  0%|          | 0/7500 [00:00<?, ?it/s]


Result: MEMO ~ TTA - VIT-B/16_V1 + RandomResizedCrop + RandomHorizontalFlip ~ ImageNet-A
24.973333333333333




In [None]:
# ImageNet-V2
VIT_V1, learning_rate, optimizer_vit, data_loader_imagenet_v2_vit, augmentations_vit, preprocess_vit = get_vit_loader(augmentation_option="option3", is_imgnet_a=False)

model = MEMO(VIT_V1,
             learning_rate,
             optimizer_vit,
             marginal_entropy_loss,
             imagenet_v_mask,
             apply_tta=True,
             augmentations=augmentations_vit,
             prior_strenght_bn=None,
             apply_adaptive_bn=False,
             temperature=None,
             apply_transform=True,
             preprocess=preprocess_vit)

test_result = test_model_tta_applied(model, data_loader_imagenet_v2_vit)

print("\nResult: MEMO ~ TTA - VIT-B/16_V1 + RandomResizedCrop + RandomHorizontalFlip ~ ImageNet-V2")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "VIT-B16_V1_Memo_TTA_RResizedCrop_HorFlip_ImageNet-V2")
logger.info("Result: MEMO ~ TTA - VIT-B/16_V1 + RandomResizedCrop + RandomHorizontalFlip ~ ImageNet-V2")
logger.info(f"{test_result}")

  0%|          | 0/10000 [00:00<?, ?it/s]


Result: MEMO ~ TTA - VIT-B/16_V1 + RandomResizedCrop + RandomHorizontalFlip ~ ImageNet-V2
72.84




#### VIT-B/16 + MEMO ~ Mixture of Augmentation

In [None]:
# ImageNet-A
VIT_V1, learning_rate, optimizer_vit, data_loader_imagenet_a_vit, augmentations_vit, preprocess_vit = get_vit_loader(augmentation_option="option4", is_imgnet_a=True)

model = MEMO(VIT_V1,
             learning_rate,
             optimizer_vit,
             marginal_entropy_loss,
             indices_in_1k,
             apply_tta=True,
             augmentations=augmentations_vit,
             prior_strenght_bn=None,
             apply_adaptive_bn=False,
             temperature=None,
             apply_transform=True,
             preprocess=preprocess_vit)

test_result = test_model_tta_applied(model, data_loader_imagenet_a_vit)

print("\nResult: MEMO ~ TTA - VIT-B/16_V1 + Mixture of Augmentations ~ ImageNet-A")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "VIT-B16_V1_Memo_TTA_MixtureAugs_ImageNet-A")
logger.info("Result: MEMO ~ TTA - VIT-B/16_V1 + Mixture of Augmentations ~ ImageNet-A")
logger.info(f"{test_result}")

  0%|          | 0/7500 [00:00<?, ?it/s]


Result: MEMO ~ TTA - VIT-B/16_V1 + Mixture of Augmentations ~ ImageNet-A
24.493333333333332




In [None]:
# ImageNet-V2
VIT_V1, learning_rate, optimizer_vit, data_loader_imagenet_v2_vit, augmentations_vit, preprocess_vit = get_vit_loader(augmentation_option="option4", is_imgnet_a=False)

model = MEMO(VIT_V1,
             learning_rate,
             optimizer_vit,
             marginal_entropy_loss,
             imagenet_v_mask,
             apply_tta=True,
             augmentations=augmentations_vit,
             prior_strenght_bn=None,
             apply_adaptive_bn=False,
             temperature=None,
             apply_transform=True,
             preprocess=preprocess_vit)

test_result = test_model_tta_applied(model, data_loader_imagenet_v2_vit)

print("\nResult: MEMO ~ TTA - VIT-B/16_V1 + Mixture of Augmentations ~ ImageNet-V2")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "VIT-B16_V1_Memo_TTA_MixtureAugs_ImageNet-V2")
logger.info("Result: MEMO ~ TTA - VIT-B/16_V1 + Mixture of Augmentations ~ ImageNet-V2")
logger.info(f"{test_result}")

#### VIT-B/16 + MEMO ~ applied TTA + Sharpened Softmax Marginal Entropy

In [None]:
# ImageNet-A
VIT_V1, learning_rate, optimizer_vit, data_loader_imagenet_a_vit, augmentations_vit, preprocess_vit = get_vit_loader(augmentation_option="option2", is_imgnet_a=True)
model = MEMO(VIT_V1,
             learning_rate,
             optimizer_vit,
             sharpened_softmax_entropy,
             indices_in_1k,
             apply_tta=True,
             augmentations=augmentations_vit,
             prior_strenght_bn=None,
             apply_adaptive_bn=False,
             temperature=0.5,
             apply_transform=True,
             preprocess=preprocess_vit)

test_result = test_model_tta_applied(model, data_loader_imagenet_a_vit)

print("\nResult: MEMO ~ TTA - VIT-B/16_V1 + RandomResizedCrop + Sharpened Softmax ~ ImageNet-A")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "VIT-B16_V1_Memo_TTA_RandomResizedCrop_SharpSoftMax_ImageNet-A")
logger.info("Result: MEMO ~ TTA - VIT-B/16_V1 + RandomResizedCrop + Sharpened Softmax Entropy ~ ImageNet-A")
logger.info(f"{test_result}")

  0%|          | 0/7500 [00:00<?, ?it/s]


Result: MEMO ~ TTA - VIT-B/16_V1 + RandomResizedCrop + Sharpened Softmax ~ ImageNet-A
23.42666666666667




In [None]:
# ImageNet-V2
VIT_V1, learning_rate, optimizer_vit, data_loader_imagenet_v2_vit, augmentations_vit, preprocess_vit = get_vit_loader(augmentation_option="option2", is_imgnet_a=False)

model = MEMO(VIT_V1,
             learning_rate,
             optimizer_vit,
             sharpened_softmax_entropy,
             imagenet_v_mask,
             apply_tta=True,
             augmentations=augmentations_vit,
             prior_strenght_bn=None,
             apply_adaptive_bn=False,
             temperature=TEMPERATURE,
             apply_transform=True,
             preprocess=preprocess_vit)

test_result = test_model_tta_applied(model, data_loader_imagenet_v2_vit)

print("\nResult: MEMO ~ TTA - VIT-B/16_V1 + RandomResizedCrop + Sharpened Softmax Entropy ~ ImageNet-V2")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "VIT-B16_V1_Memo_TTA_RandomResizedCrop_SharpSoftMax_ImageNet-V2")
logger.info("Result: MEMO ~ TTA - VIT-B/16_V1 + RandomResizedCrop + Sharpened Softmax Entropy ~ ImageNet-V2")
logger.info(f"{test_result}")

  0%|          | 0/10000 [00:00<?, ?it/s]


Result: MEMO ~ TTA - VIT-B/16_V1 + RandomResizedCrop + Sharpened Softmax Entropy ~ ImageNet-V2
72.3




#### VIT-B/16 + MEMO ~ applied TTA + Augmentation Weighted Marginal Entropy

In [None]:
# ImageNet-A
VIT_V1, learning_rate, optimizer_vit, data_loader_imagenet_a_vit, augmentations_vit, preprocess_vit = get_vit_loader(augmentation_option="option2", is_imgnet_a=True)
model = MEMO(VIT_V1,
             learning_rate,
             optimizer_vit,
             augmentation_weighted_entropy,
             indices_in_1k,
             apply_tta=True,
             augmentations=augmentations_vit,
             prior_strenght_bn=None,
             apply_adaptive_bn=False,
             temperature=None,
             apply_transform=True,
             preprocess=preprocess_vit)

test_result = test_model_tta_applied(model, data_loader_imagenet_a_vit)

print("\nResult: MEMO ~ TTA - VIT-B/16_V1 + RandomResizedCrop + Augmentation Weighted Entropy ~ ImageNet-A")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "VIT-B16_V1_Memo_TTA_RandomResizedCrop_WeightedEntropy_ImageNet-A")
logger.info("Result: MEMO ~ TTA - VIT-B/16_V1 + RandomResizedCrop + Augmentation Weighted Entropy ~ ImageNet-A")
logger.info(f"{test_result}")

  0%|          | 0/7500 [00:00<?, ?it/s]


Result: MEMO ~ TTA - VIT-B/16_V1 + RandomResizedCrop + Augmentation Weighted Entropy ~ ImageNet-A
25.173333333333332




In [None]:
# ImageNet-V2
VIT_V1, learning_rate, optimizer_vit, data_loader_imagenet_v2_vit, augmentations_vit, preprocess_vit = get_vit_loader(augmentation_option="option2", is_imgnet_a=False)

model = MEMO(VIT_V1,
             learning_rate,
             optimizer_vit,
             augmentation_weighted_entropy,
             imagenet_v_mask,
             apply_tta=True,
             augmentations=augmentations_vit,
             prior_strenght_bn=None,
             apply_adaptive_bn=False,
             temperature=None,
             apply_transform=True,
             preprocess=preprocess_vit)

test_result = test_model_tta_applied(model, data_loader_imagenet_v2_vit)

print("\nResult: MEMO ~ TTA - VIT-B/16_V1 + RandomResizedCrop + Augmentation Weighted Entropy ~ ImageNet-V2")
print(test_result)
print('\n')

logger = setup_experiment_logger("experiment_logs", "VIT-B16_V1_Memo_TTA_RandomResizedCrop_WeightedEntropy_ImageNet-V2")
logger.info("Result: MEMO ~ TTA - VIT-B/16_V1 + RandomResizedCrop + Augmentation Weighted Entropy ~ ImageNet-V2")
logger.info(f"{test_result}")

  0%|          | 0/10000 [00:00<?, ?it/s]


Result: MEMO ~ TTA - VIT-B/16_V1 + RandomResizedCrop + Augmentation Weighted Entropy ~ ImageNet-V2
72.82




## Results

The following tables show the top-1 accuracy (%) on ImageNet-A and ImageNet-V2 for the ResNet50 and ViT-B/16 backbones with test-time adaptation applied. Each row combines MEMO with a choice of adaptive Batch Normalization (BN), augmentation technique, and entropy computation. The best result in each column is shown in bold.

| Test results: ResNet50 with V1 weights | ImageNet-A | ImageNet-V2 |
|-----|:------------:|:-------------:|
| Baseline | 0.03 | 63.15 |
| +MEMO (RandomResizedCrop) | 3.65 | **69.03** |
| +MEMO + BN (RandomResizedCrop) | 8.43 | 68.20 |
| +MEMO + BN (RandomResizedCrop) + Sharpened SoftMax Marginal Entropy | 4.84 | 68.03 |
| +MEMO + BN (RandomResizedCrop) + Augmentation Weighted Marginal Entropy | **9.17** | 68.80 |
| +MEMO + BN (RandomResizedCrop + RandomHorizontalFlip) | 8.17 | 67.79 |
| +MEMO + BN (RandomResizedCrop + RandomAffine + RandomPerspective + RandomHorizontalFlip) | 7.51 | 66.13 |

| Test results: ResNet50 with V2 weights | ImageNet-A | ImageNet-V2 |
|-----|:------------:|:-------------:|
| Baseline | 14.27 | 69.89 |
| +MEMO (RandomResizedCrop) | 18.41 | 76.03 |
| +MEMO + BN (RandomResizedCrop) | 24.24 | 78.69 |
| +MEMO + BN (RandomResizedCrop) + Sharpened SoftMax Marginal Entropy | **29.87** | 75.49 |
| +MEMO + BN (RandomResizedCrop) + Augmentation Weighted Marginal Entropy | 24.22 | 76.78 |
| +MEMO + BN (RandomResizedCrop + RandomHorizontalFlip) | 24.31 | **78.90** |
| +MEMO + BN (RandomResizedCrop + RandomAffine + RandomPerspective + RandomHorizontalFlip) | 24.13 | 78.50 |

| Test results: ViT-B/16 | ImageNet-A | ImageNet-V2 |
|-----|:------------:|:------------:|
| Baseline | 20.75 | 69.57 |
| +MEMO (RandomResizedCrop) | **25.43** | **72.90** |
| +MEMO (RandomResizedCrop) + Sharpened SoftMax Marginal Entropy | 23.43 | 72.30 |
| +MEMO (RandomResizedCrop) + Augmentation Weighted Marginal Entropy | 25.17 | 72.82 |
| +MEMO (RandomResizedCrop + RandomHorizontalFlip) | 24.97 | 72.84 |
| +MEMO (RandomResizedCrop + RandomAffine + RandomPerspective + RandomHorizontalFlip) | 24.49 | n/a |
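
The gains over the baseline quoted in the discussion can be recomputed from the tables. The short sketch below does this for the best value in each column; the accuracies are copied from the tables, while the dictionary layout and the `gains` helper are only illustrative and not part of the notebook.

```python
# Top-1 accuracies (%) copied from the tables: (ImageNet-A, ImageNet-V2).
baselines = {
    "ResNet50 V1": (0.03, 63.15),
    "ResNet50 V2": (14.27, 69.89),
    "ViT-B/16": (20.75, 69.57),
}
# Best value per column for each backbone (the bold entries).
best = {
    "ResNet50 V1": (9.17, 69.03),
    "ResNet50 V2": (29.87, 78.90),
    "ViT-B/16": (25.43, 72.90),
}


def gains(baseline, adapted):
    """Absolute gain in percentage points, rounded to two decimals."""
    return tuple(round(a - b, 2) for b, a in zip(baseline, adapted))


improvements = {name: gains(baselines[name], best[name]) for name in baselines}
for name, (gain_a, gain_v2) in improvements.items():
    print(f"{name}: ImageNet-A +{gain_a:.2f}, ImageNet-V2 +{gain_v2:.2f}")
```

The gains are absolute differences in percentage points, not relative improvements.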

### Discussion

To keep the running time of the experiments manageable, I limited the number of augmentations per test image to 32, although the MEMO paper suggests 64. Even so, my modifications improved on the baseline, i.e. the pre-trained model evaluated without any TTA method. For the same reason, AugMix is not among the augmentation options: one experiment with AugMix took more than 4 hours and its result was not competitive with the others, so I left it out of the modifications.
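
To make the role of the augmentation count concrete, here is a minimal plain-Python sketch of the marginal entropy for one test point. The logits are made-up numbers for a 3-class toy problem with 4 augmented copies (the experiments use 32), and the function names are illustrative; this is not the notebook's `marginal_entropy_loss`.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def marginal_entropy(logits_per_aug):
    """Entropy of the class distribution averaged over all augmented copies."""
    probs = [softmax(logits) for logits in logits_per_aug]
    n_aug = len(probs)
    n_cls = len(probs[0])
    marginal = [sum(p[c] for p in probs) / n_aug for c in range(n_cls)]
    return -sum(p * math.log(p) for p in marginal if p > 0.0)


# Toy logits: augmented copies that agree on class 0 vs. copies that disagree.
agree = [[4.0, 0.5, 0.1], [3.8, 0.4, 0.3], [4.2, 0.2, 0.0], [3.9, 0.6, 0.2]]
disagree = [[4.0, 0.5, 0.1], [0.4, 3.8, 0.3], [0.2, 0.0, 4.2], [3.9, 0.6, 0.2]]

print("agreeing views:   ", marginal_entropy(agree))
print("disagreeing views:", marginal_entropy(disagree))
```

Minimizing this quantity pushes the model toward predictions that are both confident and consistent across the augmented copies. Each copy needs its own forward pass, so the cost per test image grows with the number of augmentations.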

Comparing the two ResNet50 weight sets, the newer V2 weights perform better than V1 on both ImageNet-A and ImageNet-V2, at the baseline and in every MEMO configuration.

For ViT-B/16, every recorded MEMO configuration outperforms the baseline, yet its best results (25.43 on ImageNet-A, 72.90 on ImageNet-V2) stay below the best results of ResNet50 with V2 weights (29.87 and 78.90). None of the other modifications outperformed MEMO with RandomResizedCrop as the only augmentation, even though all of them beat the baseline by a clear margin.

**Adaptive Batch Normalization**

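As the MEMO paper suggests, I added adaptive batch normalization to the models with a ResNet50 backbone, expecting it to improve accuracy on both datasets. All subsequent ResNet50 modifications therefore use adaptive BN; the ViT-B/16 models do not, since ViT uses layer normalization rather than batch normalization. In the tables, adaptive BN raised ImageNet-A accuracy for both weight sets (3.65 to 8.43 for V1, 18.41 to 24.24 for V2). On ImageNet-V2 it helped with V2 weights (76.03 to 78.69) but slightly lowered accuracy with V1 weights (69.03 to 68.20).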
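
The statistic mixing behind adaptive BN can be sketched in a few lines. The rule below follows the one described in the MEMO paper, where the training statistics act as a prior with strength N and the statistics of the test input receive weight 1 / (N + 1). The function name and the numbers are hypothetical, and it is an assumption that `PRIOR_STRENGTH` in the notebook plays the role of N.

```python
def mix_bn_statistics(train_mean, train_var, test_mean, test_var, prior_strength):
    """Blend stored (training) BN statistics with those of the test input.

    prior_strength (N) weights the training statistics: N / (N + 1) goes to
    the training value and 1 / (N + 1) to the test-time estimate.
    """
    w_train = prior_strength / (prior_strength + 1.0)
    w_test = 1.0 - w_train
    mean = w_train * train_mean + w_test * test_mean
    var = w_train * train_var + w_test * test_var
    return mean, var


# Hypothetical statistics for a single channel.
mean, var = mix_bn_statistics(
    train_mean=0.0, train_var=1.0, test_mean=0.8, test_var=2.0, prior_strength=16
)
print(round(mean, 4), round(var, 4))
```

A larger prior strength keeps the layer closer to its training statistics, while a value of 0 would use the test-time statistics alone.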

**Sharpened softmax marginal entropy**

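With ResNet-50 (V2 weights) as the backbone, this modification gave the best ImageNet-A accuracy of all experiments, **29.87 (+15.60)** over the baseline. It did not help with V1 weights (4.84, below the 8.43 obtained with the standard marginal entropy) or with ViT-B/16 (23.43 vs. 25.43).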
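
The effect of sharpening can be illustrated on a single prediction. The sketch assumes that `sharpened_softmax_entropy` divides the logits by `temperature` before the softmax, which is the usual meaning of a sharpened softmax but is not confirmed by the code shown in this section. The logits and helper names are made up for the example.

```python
import math


def softmax(logits, temperature=1.0):
    """Softmax of logits / temperature; temperature < 1 sharpens the output."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs):
    """Shannon entropy (in nats) of a probability vector."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


logits = [2.0, 1.0, 0.5]  # made-up logits for one augmented view
plain = softmax(logits)
sharp = softmax(logits, temperature=0.5)

print("top probability:", max(plain), "->", max(sharp))
print("entropy:        ", entropy(plain), "->", entropy(sharp))
```

With a temperature below 1 the distribution becomes more peaked, so the entropy loss starts from a more confident prediction. This may help when the top class is correct and hurt when it is not.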

**Augmentation weighted marginal entropy**

This modification to the entropy computation gave the best ImageNet-A result for ResNet-50 with V1 weights, **9.17 (+9.14)**. On ImageNet-V2 it gave the $2^{\text{nd}}$ best result for the same backbone, **68.80 (+5.65)**.
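
One way to weight augmentations in the marginal distribution is sketched below. The weighting scheme (a softmax over the negative per-view entropies, so that confident views count more) is an assumption chosen for illustration and may differ from what `augmentation_weighted_entropy` does in the notebook. All names and logits are hypothetical.

```python
import math


def softmax(values):
    """Numerically stable softmax over a list of numbers."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs):
    """Shannon entropy (in nats) of a probability vector."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def weighted_marginal_entropy(logits_per_aug):
    """Marginal entropy in which confident (low-entropy) views count more."""
    probs = [softmax(logits) for logits in logits_per_aug]
    # Assumed weighting: softmax over the negative per-view entropies.
    weights = softmax([-entropy(p) for p in probs])
    n_cls = len(probs[0])
    marginal = [sum(w * p[c] for w, p in zip(weights, probs)) for c in range(n_cls)]
    return entropy(marginal), weights


# Toy logits: the second view is almost uniform, i.e. uninformative.
views = [[4.0, 0.5, 0.1], [1.0, 0.9, 0.8], [3.5, 0.3, 0.2]]
loss, weights = weighted_marginal_entropy(views)
print("loss:", loss)
print("weights:", weights)
```

Compared with a plain average, this reduces the influence of augmented views on which the model is close to guessing.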


**RandomResizedCrop**

This augmentation was used for the two entropy modifications above because its accuracy was relatively stable across settings, even where it was slightly below the best. With ViT-B/16 as the backbone, it gave the best results on both ImageNet-A and ImageNet-V2, **25.43 (+4.68)** and **72.90 (+3.33)**, respectively.


**RandomResizedCrop + RandomHorizontalFlip**

Since the MEMO paper reports an improvement from the standard horizontal-flip augmentation, I added it to RandomResizedCrop. With ResNet50 (V2 weights) as the backbone, this combination gave the best ImageNet-V2 result, **78.90 (+9.01)**.


**RandomResizedCrop + RandomAffine + RandomPerspective + RandomHorizontalFlip**

As a further trial, I combined these four augmentation methods. The combination outperformed the baselines in every recorded run, but its results were slightly lower than those of the RandomResizedCrop + RandomHorizontalFlip combination.


### Conclusion

This project shows that MEMO improves a pre-trained model at test time through marginal entropy minimization on one test point at a time. Slight modifications to how the softmax and the entropy are computed improved accuracy further in some settings, while changing the augmentation techniques had limited effect. Because of the sharp distribution shift in the test data, especially in ImageNet-A, absolute accuracy remained low even though the adapted models outperformed their baselines. It is therefore reasonable to conclude that MEMO is effective for test-time adaptation under distribution shift, even when only a single test sample is available at a time, although further tuning, more adaptation steps, or changes to model components may be needed to reach optimal performance.