# 1. Introduction

This project studies **Test-Time Adaptation (TTA)**, which has recently gained traction because it can improve a model's performance at inference time without access to the training data.

Specifically, we focus on **TTA for image classification** using **CLIP** [[2](#ref-clip2021)] with **TPT** [[3](#ref-tpt2022)]. We adapt the model on **single-image test instances** and reset it to its pre-trained state after each instance. This resembles **TTIA**: no test-time knowledge is retained between batches and, since each batch is built from a single image, none is retained between images.

<!--- visualize image using html formatting, so that i can scale it properly -->
<div align="center">

<img src="img/tpt.png" alt="Test-Time Prompt Tuning (TPT) for CLIP" title="Test-Time Prompt Tuning (TPT) for CLIP" width="600" class="center"/>

</div>

## A. TTIA

> **Definition**: "_Test-Time Instance Adaption, TTIA._ Given a classifier $f_\mathcal{S}$ learned on the source domain $\mathcal{D}_S$, and an unlabeled target instance $x_t \in \mathcal{D}_T$ under distribution shift, _test-time instance adaption_ aims to leverage the labeled knowledge implied in $f_\mathcal{S}$ to infer the label of $x_t$ adaptively" [[1](#ref-liang2025)]. In other words, TTIA adapts the classifier $f_\mathcal{S}$ to the target instance $x_t$ by leveraging the knowledge of the source domain $\mathcal{D}_S$ [[1](#ref-liang2025)].

TTIA differs from TTBA (test-time batch adaptation) in that adaptation is performed on a single instance rather than on a batch. For example, TTIA corresponds to classifying a single video frame, whereas TTBA corresponds to classifying a sequence of frames. In both settings, no knowledge from previous test-time steps is retained.
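
The "adapt on one instance, then reset" constraint can be illustrated with a small framework-free sketch. The names `EpisodicAdapter` and `toy_adapt` are hypothetical, and the "model" is just a dictionary of two floats; the only point is that the parameters are restored from a checkpoint before every test instance.

```python
import copy
from typing import Callable, Dict, List

Params = Dict[str, float]


class EpisodicAdapter:
    """Adapts the parameters per instance, always starting from the checkpoint."""

    def __init__(self, params: Params, adapt_fn: Callable[[Params, float], Params]):
        self.checkpoint = copy.deepcopy(params)  # pre-trained state, never modified
        self.params = copy.deepcopy(params)
        self.adapt_fn = adapt_fn

    def reset(self) -> None:
        self.params = copy.deepcopy(self.checkpoint)

    def predict(self, x: float) -> float:
        self.reset()                                 # forget the previous instance
        self.params = self.adapt_fn(self.params, x)  # adapt on this instance only
        return self.params["w"] * x + self.params["b"]


def toy_adapt(params: Params, x: float) -> Params:
    # Hypothetical adaptation rule: shift the bias toward the current input.
    return {"w": params["w"], "b": params["b"] + 0.1 * x}


adapter = EpisodicAdapter({"w": 2.0, "b": 0.0}, toy_adapt)
outputs: List[float] = [adapter.predict(x) for x in (1.0, 3.0, 1.0)]
```

Because of the reset, the first and last inputs (both `1.0`) produce the same output, regardless of the instance seen in between.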

## B. Project Overview

We aim to reproduce the TPT results on ImageNet-A.

The project is structured as follows:

1. Introduction
1. Reproducing TPT + CoOp with full CoOp pretraining (on ImageNetV2)
   - Using OpenAI CLIP (both implementation and weights)
   - This lets us compare it with TPT + CoOp without pretraining (`a photo of a` initialization).
1. Reproducing TPT
   - Using OpenAI weights and OpenCLIP implementation
     - Compare zero-shot CLIP OpenAI (weights and implementation) with OpenCLIP (weights and implementation)
   - Using `Kornia` instead of `AugMix` / `torchvision.transforms` (**Our contribution**)
     - Recreate the AugMix pipeline in Kornia
     - Kornia is faster and can directly run on the GPU
     - Benchmarking the difference
   - Reproduce TPT + simplified CoOp (without pretraining) (**Our contribution**)
1. Trying to get better at TTA (**Our contribution**)
    - A. Augment Top 1
    - B. TPT with Top 1
    - C. Self-Supervised Retrieval (inspired by DINOv2) [[4](#ref-dinov2)]
    - D. Stupid idea with adaptive layer norm
    - E. TNT (recreating the paper)
    - F. TNT with Top 1
    - G. TPS
1. Results and Conclusion
1. Future Work
1. Bibliography

## C. Reproducibility
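The project is designed to be reproducible. The code is also available on [GitHub]() as standard Python files.

For reproducibility, the random number generators used by PyTorch, NumPy, and the `DataLoader` workers are seeded.

## D. Data

This section downloads the data, creates the datasets and data loaders, and sets the random seeds.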


In [None]:
!mkdir datasets

# Get datasets (ImageNet-A and ImageNetV2)
!gdown --fuzzy https://drive.google.com/file/d/1nfictDVptrdDwRNaxsBx0UIlP7tpDsN_/view?usp=sharing

!gdown --fuzzy https://drive.google.com/file/d/1TR1hrs9tV6rh_W-hDqRv6jH3KE-Hxsw2/view?usp=sharing

!tar -xvf imagenetv2-matched-frequency-format-val.tar.gz -C datasets
# json metadata of the datasets
!curl https://raw.githubusercontent.com/modestyachts/ImageNetV2/refs/heads/master/data/metadata/class_info.json -o datasets/imagenetv2-matched-frequency-format-val/class_info.json

!tar -xvf imagenet-a.tar -C datasets

In [None]:
# Standard library
import copy
import gc
import json
import os
import random
import sys
from copy import deepcopy
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

# Third-party
import clip
import kornia
import kornia.augmentation as K
import kornia.enhance as Ke
import numpy as np
import open_clip
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
from PIL import Image
from sklearn.cluster import KMeans
from torch.cuda.amp import autocast
from torch.utils.data import DataLoader, Dataset, random_split
from torch.utils.tensorboard.writer import SummaryWriter
from torchvision import transforms
from tqdm import tqdm

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Seeding and reproducibility

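torch.manual_seed(456)
torch.cuda.manual_seed(456)
# Draw (and discard) 456 random numbers; using DEVICE keeps this line working on CPU-only machines
torch.randn(456).to(DEVICE)
np.random.seed(42)
# Also seed Python's own RNG in the main process
random.seed(42)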

g = torch.Generator()
g.manual_seed(0)
# https://docs.pytorch.org/docs/stable/notes/randomness.html


def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
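
As a small illustration of why `seed_worker` reduces the seed modulo `2**32`: PyTorch seeds are 64-bit integers, while NumPy's legacy `np.random.seed` only accepts values in `[0, 2**32 - 1]`. The sketch below uses only the standard library. `derive_worker_seed` and `base_seed` are hypothetical names, and the `base_seed + worker_id` rule follows the description of per-worker seeding in the PyTorch documentation.

```python
import random

NUMPY_SEED_LIMIT = 2**32  # np.random.seed accepts integers in [0, 2**32 - 1]


def derive_worker_seed(base_seed: int, worker_id: int) -> int:
    """Mimic the per-worker seed and reduce it to a 32-bit value."""
    torch_like_seed = base_seed + worker_id  # stands in for torch.initial_seed() inside a worker
    return torch_like_seed % NUMPY_SEED_LIMIT


base_seed = 0x1234_5678_9ABC_DEF0  # illustrative 64-bit seed
worker_seeds = [derive_worker_seed(base_seed, w) for w in range(4)]

# Seeding with the same value twice reproduces the same random stream
random.seed(worker_seeds[0])
first = [random.random() for _ in range(3)]
random.seed(worker_seeds[0])
second = [random.random() for _ in range(3)]
```

Each worker ends up with a distinct seed that fits in 32 bits, and re-seeding with the same value yields the same sequence.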

# 2. Reproducing CoOp

We reproduce CoOp with pre-training on ImageNetV2.

TODO: explain why!


### Data


In [None]:
def log_values(writer, step, loss, accuracy, prefix):
    writer.add_scalar(f"{prefix}/loss", loss, step)
    writer.add_scalar(f"{prefix}/accuracy", accuracy, step)


_tokenizer = _Tokenizer()
vis_net, basic_image_transformations = clip.load("ViT-B/16", DEVICE)

In [None]:
class ImageNetA(Dataset):
    def __init__(
        self, root_dir="datasets/imagenet-a", transform=basic_image_transformations
    ):
        self.root_dir = root_dir
        self.transform = transform

        # Load class code to name mapping from README.txt
        self.class_code_to_name = self._load_class_mapping(
            os.path.join(root_dir, "README.txt")
        )

        # Map class codes to integer labels. `os.listdir` returns entries in
        # arbitrary order, so sort the codes to keep the label indices reproducible.
        self.class_codes: List[str] = sorted(
            d
            for d in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, d))
            and d in self.class_code_to_name
        )
        
        self.class_code_to_idx = {
            code: idx for idx, code in enumerate(self.class_codes)
        }

        # Collect all image paths and labels
        self.samples = []
        for class_code in self.class_codes:
            class_folder = os.path.join(root_dir, class_code)
            for fname in os.listdir(class_folder):
                if fname.lower().endswith((".png", ".jpg", ".jpeg")):
                    path = os.path.join(class_folder, fname)
                    label = self.class_code_to_idx[class_code]
                    self.samples.append((path, label))

    def _load_class_mapping(self, readme_path) -> Dict[str, str]:
        mapping = {}
        with open(readme_path, "r") as f:
            lines = f.readlines()[12:]  # Skip the first 12 lines
            for line in lines:
                parts = line.strip().split(" ", 1)
                if len(parts) == 2:
                    code, name = parts
                    mapping[code] = name
        return mapping

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        image_path, label = self.samples[idx]
        image = Image.open(image_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label

#### ImageNet-A 🔗 ImageNetV2
For CoOp training, the ImageNetV2 dataset is restricted to the classes available in ImageNet-A. The classes are matched by their WordNet ID (`wnid`) using the official JSON class metadata (`class_info.json`).
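
The matching step can be sketched on a tiny in-memory stand-in for `class_info.json`. The three entries and the two-class ImageNet-A index below are illustrative only; they use the same keys that the dataset class reads (`cid`, `wnid`, `synset`).

```python
import json

# Tiny stand-in for class_info.json (illustrative entries only)
class_info_json = """
[
  {"cid": 0, "wnid": "n01440764", "synset": ["tench", "Tinca tinca"]},
  {"cid": 6, "wnid": "n01498041", "synset": ["stingray"]},
  {"cid": 11, "wnid": "n01531178", "synset": ["goldfinch", "Carduelis carduelis"]}
]
"""

# Hypothetical ImageNet-A index: wnid -> label id (only two of the classes above)
imagenet_a_class_code_to_idx = {"n01498041": 0, "n01531178": 1}

mapping = {}
for item in json.loads(class_info_json):
    wnid = item["wnid"]
    if wnid in imagenet_a_class_code_to_idx:  # keep only classes present in ImageNet-A
        mapping[item["cid"]] = {
            "label_id": imagenet_a_class_code_to_idx[wnid],
            "ia_code": wnid,
            "label": item["synset"][0].lower().replace(" ", "_"),
        }

print(mapping)
```

Only the ImageNetV2 folders whose `wnid` appears in the ImageNet-A index survive, and each one is relabelled with the ImageNet-A label id.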

In [None]:
class ImageNetV2(Dataset):
    def __init__(
        self,
        root_dir="datasets/imagenetv2-matched-frequency-format-val",
        transform=basic_image_transformations,
        use_imagenet_a_classes=True,
        imagenet_a=None,
    ):
        self.root_dir = root_dir
        self.transform = transform
        self.use_imagenet_a = use_imagenet_a_classes

        if use_imagenet_a_classes:
            assert isinstance(
                imagenet_a, ImageNetA
            ), "use_imagenet_a_classes is True but no ImageNetA object was passed as imagenet_a"
            imagenet_a_class_code_to_idx = imagenet_a.class_code_to_idx
            self.v2id_to_info = self._load_class_mapping(
                os.path.join(root_dir, "class_info.json"),
                imagenet_a_class_code_to_idx
            )
            self.class_code_to_name = copy.deepcopy(
                imagenet_a.class_code_to_name)

        else:
            self.v2id_to_info = self._load_class_mapping(
                os.path.join(
                    root_dir, "class_info.json"), None
            )
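            # Look up the label of each class id (indexing the dict with "label" directly raises KeyError)
            self.class_code_to_name = {
                idx: self.v2id_to_info[idx]["label"] for idx in self.v2id_to_info.keys()
            }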

        self.samples = []
        for v2_class_code in self.v2id_to_info.keys():
            class_folder = os.path.join(root_dir, str(v2_class_code))
            for fname in os.listdir(class_folder):
                if fname.lower().endswith((".png", ".jpg", ".jpeg")):
                    path = os.path.join(class_folder, fname)
                    if use_imagenet_a_classes:
                        self.samples.append(
                            (path,
                             self.v2id_to_info[v2_class_code]["label_id"])
                        )
                    else:
                        self.samples.append((path, v2_class_code))

    def _load_class_mapping(
        self,
        infofile_path,
        imagenet_a_class_code_to_idx: Optional[dict[str, int]],
    )->Dict[str,Dict[str,str]]:
        
        mapping = {}
        with open(infofile_path) as f:
            data = json.load(f)
            for idx, item in enumerate(data):
                if imagenet_a_class_code_to_idx is not None:
                    if item["wnid"] in imagenet_a_class_code_to_idx.keys():
                        mapping[item["cid"]] = {
                            "label_id": imagenet_a_class_code_to_idx[item["wnid"]],
                            "ia_code": item["wnid"],
                            "label": item["synset"][0].lower().replace(" ", "_"),
                        }
                else:
                    mapping[item["cid"]] = {
                        "label": item["synset"][0].lower().replace(" ", "_")
                    }
        return mapping

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        image_path, label = self.samples[idx]
        image = Image.open(image_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label

In [None]:
def get_dataset_split(dataset: Dataset, train_percentage=0.5, validation_percentage=0.25) -> Tuple[Dataset, Dataset, Dataset]:
    # Load data

    # Create train validation and test samples
    num_samples = len(dataset)
    training_sample = int(num_samples * train_percentage + 1)
    validation_sample = int(num_samples * validation_percentage)
    test_sample = num_samples - training_sample - validation_sample

    training_dataset, validation_dataset, test_dataset = random_split(
        dataset, [training_sample, validation_sample, test_sample]
    )

    return (training_dataset, validation_dataset, test_dataset)


def get_data(
    training_dataset,
    validation_dataset,
    test_dataset,
    batch_size=64,
    transform=None,
    num_workers=8,
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Wrap the given train/validation/test datasets in DataLoaders.

    Note: `transform` is accepted for compatibility but is not applied here;
    the datasets apply their own transform in `__getitem__`.
    """

    # Create a DataLoader
    train_loader = DataLoader(
        training_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        worker_init_fn=seed_worker,
        generator=g,
    )
    val_loader = DataLoader(
        validation_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        worker_init_fn=seed_worker,
        generator=g,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        worker_init_fn=seed_worker,
        generator=g,
    )

    return train_loader, val_loader, test_loader


def embed_dataset_classnames(dataset: ImageNetA, model, templates=["a photo of a {}."]):
    """
    Embed the classnames in the prompt template.
    Return the classnames and the normalized textual features.
    """
    # Create the list of descriptions and tokenize them
    # Materialise the dict view so that callers receive a plain list
    classnames = list(dataset.class_code_to_name.values())

    texts_z_views = []
    for template in templates:
        descriptions = [template.format(c) for c in classnames]
        text_tokens = clip.tokenize(descriptions).to(DEVICE)

        # Get the normalized textual features
        with torch.no_grad():
            texts_z = model.encode_text(text_tokens).float()
            texts_z /= texts_z.norm(dim=-1, keepdim=True)
            texts_z_views.append(texts_z)

    # Evaluate the mean representation
    texts_z = torch.stack(texts_z_views).mean(dim=0)

    # Renormalise
    texts_z /= texts_z.norm(dim=-1, keepdim=True)

    return classnames, texts_z
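
The template ensembling performed by `embed_dataset_classnames` (normalise each template's embedding, average, then renormalise) can be sketched without any framework. The two-dimensional vectors below are made up for illustration, and `ensemble` is a hypothetical helper name.

```python
import math
from typing import List

Vector = List[float]


def l2_normalize(v: Vector) -> Vector:
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def ensemble(template_embeddings: List[Vector]) -> Vector:
    """Average unit vectors (one per template) and renormalise the mean."""
    unit = [l2_normalize(v) for v in template_embeddings]
    dim = len(unit[0])
    mean = [sum(v[i] for v in unit) / len(unit) for i in range(dim)]
    return l2_normalize(mean)


# Two made-up 2-D "text embeddings" of the same class under two templates
views = [[3.0, 0.0], [0.0, 4.0]]
class_embedding = ensemble(views)
ensemble_norm = math.sqrt(sum(x * x for x in class_embedding))
```

The second normalisation is needed because the mean of unit vectors generally has a norm below 1, which would otherwise distort the cosine similarities computed later.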

In [None]:
dataset_a = ImageNetA()
dataset_v2 = ImageNetV2(imagenet_a=dataset_a)

### Base Model (CoOp)


In [None]:
def new_model(model_class, dataset):
    classnames, _ = embed_dataset_classnames(dataset, vis_net)
    n_ctx = 4
    ctx_init = ""
    class_token_position = "end"
    csc = False
    coop_net = model_class(
        classnames=classnames,
        n_ctx=n_ctx,
        ctx_init=ctx_init,
        class_token_position=class_token_position,
        csc=csc,
    ).to(DEVICE)
    return coop_net


def load_model(model, path="./working_directory/model_coop.pth"):
    # The default path matches the checkpoint written after CoOp training
    model.load_state_dict(
        torch.load(path, weights_only=True)
    )
    model.eval()
    return model

In [None]:
class TextEncoder(nn.Module):
    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection

    def forward(self, prompts, tokenized_prompts):
        x = prompts + self.positional_embedding
        # [batch_size, n_ctx, transformer.width] -> [n_ctx, batch_size, transformer.width]
        x = x.permute(1, 0, 2)
        x = self.transformer(x)
        # [n_ctx, batch_size, transformer.width] -> [batch_size, n_ctx, transformer.width]
        x = x.permute(1, 0, 2)
        x = self.ln_final(x)

        # Take features from the eot embedding (eot_token is the highest number in each sequence)
        x = (
            x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)]
            @ self.text_projection
        )

        return x
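
The `argmax` trick in `TextEncoder.forward` works because the end-of-text token has the largest id in the CLIP vocabulary, so the position of the maximum id in each row is the position of the EOT token. A minimal sketch in plain Python follows; the word ids are made up and the context length is shortened for readability.

```python
from typing import List

SOT, EOT, PAD = 49406, 49407, 0  # special token ids of the CLIP tokenizer
CONTEXT_LENGTH = 10  # shortened for readability; CLIP uses 77


def pad(tokens: List[int]) -> List[int]:
    return tokens + [PAD] * (CONTEXT_LENGTH - len(tokens))


def eot_position(token_ids: List[int]) -> int:
    """Index of the largest token id, i.e. what `argmax(dim=-1)` returns per row."""
    return max(range(len(token_ids)), key=token_ids.__getitem__)


# Made-up word ids between the start and end tokens
short_prompt = pad([SOT, 320, 1125, EOT])
long_prompt = pad([SOT, 320, 1125, 539, 320, 2368, EOT])

positions = [eot_position(p) for p in (short_prompt, long_prompt)]
```

Here `positions` is `[3, 6]`: the feature vector is read at a different index for each prompt, depending on its length.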

#### Prompt learner
The prompt learner code is mostly taken from Lab 3 of AY 24/25.
We made a few changes so that the context weights obtained after CoOp training can be stored as a checkpoint and restored after each TPT iteration.

In [None]:
# Basic mechanics are taken from the Lab Number 3 of AY 24/25
class PromptLearner(nn.Module):
    def __init__(
        self, clip_model, classnames, n_ctx, ctx_init, class_token_position, csc=False
    ):
        super().__init__()
        n_cls = len(classnames)
        ctx_dim = clip_model.ln_final.weight.shape[0]
        clip_imsize = clip_model.visual.input_resolution

        # Use given words to initialize context vectors
        if ctx_init:
            ctx_init = ctx_init.replace("_", " ")
            n_ctx = len(ctx_init.split(" "))
            prompt = clip.tokenize(ctx_init).to(
                clip_model.token_embedding.weight.device
            )
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt)
            ctx_vectors = embedding[0, 1: 1 + n_ctx, :]
            prompt_prefix = ctx_init
        else:
            if csc:
                print("Initializing class-specific contexts")
                ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim)
            else:
                print("Initializing a generic context")
                ctx_vectors = torch.empty(n_ctx, ctx_dim)

            torch.nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        print(f"Initial context: '{prompt_prefix}'")
        print(f"Number of context words (tokens): {n_ctx}")

        self.ctx_init_state = ctx_vectors.detach().clone()
        # These are the `prompts` we want to optimize
        self.ctx = nn.Parameter(ctx_vectors)

        print(classnames)
        classnames = [name.replace("_", " ") for name in classnames]
        name_lens = [len(_tokenizer.encode(name)) for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]

        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(
            clip_model.token_embedding.weight.device
        )

        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts)

        # These token vectors will be saved when in save_model(),
        # but they should be ignored in load_model() as we want to use
        # those computed using the current class names
        self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS
        self.register_buffer(
            "token_suffix", embedding[:, 1 + n_ctx:, :])  # class tokens, ".", EOS, padding

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts
        self.name_lens = name_lens
        self.class_token_position = class_token_position
        self.ctx_checkpoint = ctx_vectors.detach().clone()

    # Reset the context to the stored checkpoint (called for each TPT step)
    def reset_ctx(self):  # https://discuss.pytorch.org/t/reset-model-weights/19180
        with torch.no_grad():
            self.ctx.copy_(self.ctx_checkpoint)

        self.ctx.requires_grad = True
        
    # set context checkpoint before the TPT procedure
    def set_ctx_checkpoint(self):
        with torch.no_grad():
            self.ctx_checkpoint.copy_(self.ctx)

    def forward(self):
        prefix = self.token_prefix
        suffix = self.token_suffix
        ctx = self.ctx

        # If CoOp, expand the ctx for all classes
        if ctx.dim() == 2:
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)

        if self.class_token_position == "end":
            prompts = torch.cat(
                [
                    prefix,  # (n_cls, 1, dim)
                    ctx,  # (n_cls, n_ctx, dim)
                    suffix,  # (n_cls, *, dim)
                ],
                dim=1,
            )

        elif self.class_token_position == "middle":
            half_n_ctx = self.n_ctx // 2
            prompts = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                prefix_i = prefix[i: i + 1, :, :]
                class_i = suffix[i: i + 1, :name_len, :]
                suffix_i = suffix[i: i + 1, name_len:, :]
                ctx_i_half1 = ctx[i: i + 1, :half_n_ctx, :]
                ctx_i_half2 = ctx[i: i + 1, half_n_ctx:, :]
                prompt = torch.cat(
                    [
                        prefix_i,  # (1, 1, dim)
                        ctx_i_half1,  # (1, n_ctx//2, dim)
                        class_i,  # (1, name_len, dim)
                        ctx_i_half2,  # (1, n_ctx//2, dim)
                        suffix_i,  # (1, *, dim)
                    ],
                    dim=1,
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)

        elif self.class_token_position == "front":
            prompts = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                prefix_i = prefix[i: i + 1, :, :]
                class_i = suffix[i: i + 1, :name_len, :]
                suffix_i = suffix[i: i + 1, name_len:, :]
                ctx_i = ctx[i: i + 1, :, :]
                prompt = torch.cat(
                    [
                        prefix_i,  # (1, 1, dim)
                        class_i,  # (1, name_len, dim)
                        ctx_i,  # (1, n_ctx, dim)
                        suffix_i,  # (1, *, dim)
                    ],
                    dim=1,
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)

        else:
            raise ValueError(
                f"Unknown class_token_position: {self.class_token_position!r}"
            )

        return prompts
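
The three values of `class_token_position` only change where the learnable context sits relative to the class name. The sketch below reproduces the layouts with string placeholders instead of embeddings; `build_prompt` is a hypothetical helper and the two-token class name is made up.

```python
from typing import List


def build_prompt(ctx: List[str], name: List[str], position: str) -> List[str]:
    """Arrange [SOS], context tokens, class-name tokens and the trailing '.', [EOS]."""
    prefix, tail = ["[SOS]"], [".", "[EOS]"]
    if position == "end":
        return prefix + ctx + name + tail
    if position == "middle":
        half = len(ctx) // 2
        return prefix + ctx[:half] + name + ctx[half:] + tail
    if position == "front":
        return prefix + name + ctx + tail
    raise ValueError(f"Unknown position: {position!r}")


ctx_tokens = ["X1", "X2", "X3", "X4"]  # stand-ins for the learnable context vectors
name_tokens = ["gold", "finch"]  # hypothetical two-token class name

layouts = {p: build_prompt(ctx_tokens, name_tokens, p) for p in ("end", "middle", "front")}
for position, tokens in layouts.items():
    print(f"{position:>6}: {' '.join(tokens)}")
```

With `"end"`, the layout used by `new_model`, the context precedes the class name, so the same context vectors can be shared across all classes with a single concatenation.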

In [None]:
class OurCLIP(nn.Module):
    def __init__(self, classnames, n_ctx, ctx_init, class_token_position, csc=False):
        super().__init__()
        # Load a fresh CLIP backbone (`clip.load` picks the GPU by default when one is available)
        clip_model, _ = clip.load("ViT-B/16")

        self.prompt_learner = PromptLearner(
            clip_model, classnames, n_ctx, ctx_init, class_token_position, csc=csc
        )
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts
        self.image_encoder = clip_model.visual
        self.text_encoder = TextEncoder(clip_model)
        self.logit_scale = clip_model.logit_scale

    def reset_ctx(self):
        self.prompt_learner.reset_ctx()

    def set_ctx_checkpoint(self):
        self.prompt_learner.set_ctx_checkpoint()

    def forward(self, image):
        image_features = self.image_encoder(image)

        prompts = self.prompt_learner()
        tokenized_prompts = self.tokenized_prompts
        text_features = self.text_encoder(prompts, tokenized_prompts)

        image_features = image_features / \
            image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / \
            text_features.norm(dim=-1, keepdim=True)

        logit_scale = self.logit_scale.exp()
        logits = logit_scale * image_features @ text_features.t()

        return logits
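
The logits returned by `OurCLIP.forward` are cosine similarities between the image embedding and each class embedding, multiplied by the exponentiated `logit_scale`. A minimal numeric sketch follows, with made-up two-dimensional features and an illustrative scale of `log(100)` (an assumption, not a value read from the model).

```python
import math
from typing import List

Vector = List[float]


def l2_normalize(v: Vector) -> Vector:
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def clip_logits(image: Vector, texts: List[Vector], logit_scale: float) -> List[float]:
    """Scaled cosine similarity between one image embedding and each class embedding."""
    img = l2_normalize(image)
    scale = math.exp(logit_scale)  # the model stores the scale in log space
    return [scale * sum(a * b for a, b in zip(img, l2_normalize(t))) for t in texts]


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


image_embedding = [1.0, 1.0]                 # made-up 2-D features
text_embeddings = [[1.0, 0.9], [1.0, -1.0]]  # class 0 is nearly parallel to the image
logits = clip_logits(image_embedding, text_embeddings, logit_scale=math.log(100.0))
probs = softmax(logits)
```

Because cosine similarities lie in `[-1, 1]`, the scale acts as an inverse temperature: a large value turns small differences in similarity into sharply peaked class probabilities.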

In [None]:
def training_step(
    net: OurCLIP,
    data_loader: torch.utils.data.DataLoader,  # type: ignore
    optimizer: torch.optim.Optimizer,
    cost_function,
    device=DEVICE,
):
    """
    Training step (for CoOp).
    """
    samples = 0.0
    cumulative_loss = 0.0
    cumulative_accuracy = 0.0

    # Set the network to training mode
    net.train()
    scaler = torch.cuda.amp.GradScaler(enabled=True)

    # Iterate over the training set
    pbar = tqdm(
        data_loader, desc="Training_Coop", position=0, leave=True, total=len(data_loader)
    )
    for batch_idx, (inputs, targets) in enumerate(data_loader):
        # Load data into GPU
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Forward pass
        with autocast():
            outputs = net(inputs)

        # Loss computation
        loss = cost_function(outputs, targets)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # Gradients reset
        optimizer.zero_grad()

        # Fetch prediction and loss value
        samples += inputs.shape[0]
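        # The loss is a batch mean (default `CrossEntropyLoss` reduction), so weight it
        # by the batch size to make `cumulative_loss / samples` a per-sample average
        cumulative_loss += loss.item() * inputs.shape[0]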
        # max() returns (maximum_value, index_of_maximum_value)
        _, predicted = outputs.max(dim=1)

        # Compute training accuracy
        cumulative_accuracy += predicted.eq(targets).sum().item()

        pbar.set_postfix(
            train_loss=loss.item(), train_acc=cumulative_accuracy / samples * 100
        )
        pbar.update(1)
        del inputs
        gc.collect()
        torch.cuda.empty_cache()

    return cumulative_loss / samples, cumulative_accuracy / samples * 100


def test_step(net: OurCLIP,
            data_loader: DataLoader,
            cost_function, device=DEVICE):
    samples = 0.0
    cumulative_loss = 0.0
    cumulative_accuracy = 0.0

    # Set the network to evaluation mode
    net.eval()

    # Disable gradient computation (we are only testing, we do not want our model to be modified in this step!)
    pbar = tqdm(
        data_loader, desc="Testing_Coop", position=0, leave=True, total=len(data_loader)
    )
    with torch.no_grad():
        # Iterate over the test set
        for batch_idx, (inputs, targets) in enumerate(data_loader):
            # Load data into GPU
            inputs = inputs.to(device)
            targets = targets.to(device)

            # Forward pass
            with autocast():
                outputs = net(inputs)

            # Loss computation
            loss = cost_function(outputs, targets)

            # Fetch prediction and loss value
            samples += inputs.shape[0]
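            # Note: the .item() is needed to extract scalars from tensors.
            # The loss is a batch mean, so weight it by the batch size to make
            # `cumulative_loss / samples` a per-sample average
            cumulative_loss += loss.item() * inputs.shape[0]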
            _, predicted = outputs.max(1)

            # Compute accuracy
            cumulative_accuracy += predicted.eq(targets).sum().item()

            pbar.set_postfix(test_acc=cumulative_accuracy / samples * 100)
            pbar.update(1)

    return cumulative_loss / samples, cumulative_accuracy / samples * 100

In [None]:
def main_coop(
    net: OurCLIP,
    dataset_splits: tuple,
    batch_size=16,
    learning_rate=0.002,
    weight_decay=0.0005,
    momentum=0.9,
    epochs=2,
    run_name="coop_training",
    skip_test=False,
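):
    """
    Train the CoOp prompt learner and log the metrics to TensorBoard.

    @param: net model whose `prompt_learner` parameters are trained
    @param: dataset_splits tuple that contains (training, validation, test)"""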

    # Create a logger for the experiment
    writer = SummaryWriter(log_dir=f"runs/{run_name}")

    # Get dataloaders
    train_loader, val_loader, test_loader = get_data(
        dataset_splits[0],
        dataset_splits[1],
        dataset_splits[2],
        transform=basic_image_transformations,
        batch_size=batch_size,
    )

    print("Turning off gradients in both the image and the text encoder")
    for name, param in net.named_parameters():
        if "prompt_learner" not in name:
            param.requires_grad_(False)

    print(f"Total parameters: {sum(p.numel() for p in net.parameters()):,}")
    print(
        f"Total trainable parameters: {sum(p.numel() for p in net.parameters() if p.requires_grad):,}"
    )

    # Instantiate the optimizer
    optimizer = torch.optim.SGD(
        [{"params": net.parameters()}],
        lr=learning_rate,
        weight_decay=weight_decay,
        momentum=momentum,
    )

    # Define the cost function
    cost_function = torch.nn.CrossEntropyLoss()

    # Computes evaluation results before training
    if not skip_test:
        print("Before training:")
        train_loss, train_accuracy = test_step(
            net, train_loader, cost_function)
        val_loss, val_accuracy = test_step(net, val_loader, cost_function)
        test_loss, test_accuracy = test_step(net, test_loader, cost_function)

        # Log to TensorBoard
        log_values(writer, -1, train_loss, train_accuracy, "train")
        log_values(writer, -1, val_loss, val_accuracy, "validation")
        log_values(writer, -1, test_loss, test_accuracy, "test")

        print(
            f"\tTraining loss {train_loss:.5f}, Training accuracy {train_accuracy:.2f}"
        )
        print(
            f"\tValidation loss {val_loss:.5f}, Validation accuracy {val_accuracy:.2f}"
        )
        print(
            f"\tTest loss {test_loss:.5f}, Test accuracy {test_accuracy:.2f}")

    # For each epoch, train the network and then compute evaluation results
    for e in range(epochs):
        train_loss, train_accuracy = training_step(
            net, train_loader, optimizer, cost_function
        )
        val_loss, val_accuracy = test_step(net, val_loader, cost_function)

        log_values(writer, e, train_loss, train_accuracy, "train")
        log_values(writer, e, val_loss, val_accuracy, "validation")

    # Compute final evaluation results
    if not skip_test:
        print("After training:")

        train_loss, train_accuracy = test_step(
            net, train_loader, cost_function)
        val_loss, val_accuracy = test_step(net, val_loader, cost_function)
        test_loss, test_accuracy = test_step(net, test_loader, cost_function)

        log_values(writer, epochs, train_loss, train_accuracy, "train")
        log_values(writer, epochs, val_loss, val_accuracy, "validation")
        log_values(writer, epochs, test_loss, test_accuracy, "test")
        print(
            f"\tTraining loss {train_loss:.5f}, Training accuracy {train_accuracy:.2f}"
        )
        print(
            f"\tValidation loss {val_loss:.5f}, Validation accuracy {val_accuracy:.2f}"
        )
        print(
            f"\tTest loss {test_loss:.5f}, Test accuracy {test_accuracy:.2f}")

    # Closes the logger
    writer.close()
    return net

In [None]:
coop_net = new_model(OurCLIP, dataset_v2)
splitted_datasets = get_dataset_split(dataset_v2)
main_coop(coop_net, splitted_datasets, batch_size=16, skip_test=True)
os.makedirs("./working_directory", exist_ok=True)  # torch.save does not create missing directories
torch.save(coop_net.state_dict(), "./working_directory/model_coop.pth")

# 3. Reproducing TPT

All experiments in this section use the OpenAI CLIP weights.

TODO: explain why we are doing this!


### Image Augmentation

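As in the paper, the augmented views are generated with random crops. Despite its name, `get_augmix_transform` only applies a resize followed by a random resized crop.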

In [None]:
def get_augmix_transform():
    return transforms.Compose(
        [
            transforms.Resize(256),
            transforms.RandomResizedCrop(224, scale=(0.5, 1.0)),
            #transforms.RandomHorizontalFlip(),
            #transforms.ColorJitter(0.4, 0.4, 0.4),
            transforms.ToTensor(),
        ]
    )


# Basic original transform (non-augmented)
original_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
)


# Wrapper Dataset that takes a base dataset + index of the sample to augment
class AugmentSingleSampleDataset(Dataset):
    def __init__(self, base_dataset, sample_idx, num_augments=63):
        self.base_dataset = base_dataset
        self.sample_idx = sample_idx
        self.num_augments = num_augments
        self.augmix_transform = get_augmix_transform()
        self.original_transform = original_transform

        # Extract the image once to avoid loading it 64 times
        image, label = self.base_dataset[self.sample_idx]
        if isinstance(image, torch.Tensor):
            self.image = transforms.ToPILImage()(image)
        else:
            self.image = image
        self.label = label

    def __len__(self):
        return self.num_augments + 1  # 63 augments + 1 original

    def __getitem__(self, idx):
        if idx == 0:
            image = self.original_transform(self.image)
        else:
            image = self.augmix_transform(self.image)
        return image, self.label

### TPT Procedure
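
Before the implementation, a small torch-free sketch of the two ingredients may help: entropy-based selection of views, and the difference between the mean of per-view entropies and the entropy of the averaged (marginal) distribution, which is the objective described in the TPT paper [[3](#ref-tpt2022)]. The helper names and the probability values are made up for illustration; they are not part of the notebook's code.

```python
import math


def entropy(p):
    """Shannon entropy (in nats) of a discrete distribution."""
    return -sum(q * math.log(q) for q in p if q > 0)


def select_confident(views, top_p):
    """Keep the fraction top_p of views with the lowest entropy."""
    k = int(len(views) * top_p)
    return sorted(views, key=entropy)[:k]


def mean_of_entropies(views):
    """Average of the entropies computed separately for each view."""
    return sum(entropy(v) for v in views) / len(views)


def entropy_of_mean(views):
    """Entropy of the distribution obtained by averaging all views."""
    n = len(views)
    marginal = [sum(v[c] for v in views) / n for c in range(len(views[0]))]
    return entropy(marginal)


# Toy predictions for 4 views over 2 classes (illustrative values only)
views = [[0.99, 0.01], [0.02, 0.98], [0.6, 0.4], [0.5, 0.5]]
confident = select_confident(views, top_p=0.5)

# The two selected views are confident but disagree with each other:
# the mean of their entropies is small, the entropy of their average is large.
print(mean_of_entropies(confident), entropy_of_mean(confident))
```

The two quantities reward different things: the mean of entropies is low whenever each view is confident on its own, while the marginal entropy is low only when the confident views also agree with each other.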


In [None]:
def select_confident_samples(logits, top_p):
    """
    Select the fraction `top_p` of samples with the lowest entropy, i.e. the highest confidence.
    """
    assert 0 <= top_p < 1, "The value must be between 0 and 1"
    batch_entropy = -(logits.softmax(1) * logits.log_softmax(1)).sum(1)
    idx = torch.argsort(batch_entropy, descending=False)[
        : int(batch_entropy.size()[0] * top_p)
    ]
    return logits[idx], idx


def compute_avg_entropy(outputs):
    """
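    Compute the entropy of each sample's prediction and return the mean.

    Note: this is the mean of per-sample entropies, not the entropy of the
    averaged (marginal) distribution.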
    """
    # Calculate probabilities from logits
    probs = outputs.softmax(dim=1)
    # To avoid log(0), clamp probabilities to a minimum value
    probs = probs.clamp(min=1e-9)
    entropy = -(probs * probs.log()).sum(dim=1)
    return entropy.mean()

In [None]:
def test_step_tpt(
    net, dataset, optimizer, optimizer_state_dict, log_writer, num_aug=63, batch_size=16
):
    """
    @param net: an `OurCLIP` model
    """
    samples = 0.0
    cumulative_loss = 0.0
    cumulative_accuracy = 0.0

    # Set the network to evaluation mode
    net.eval()

    # Progress bar over the test samples. Gradients stay enabled here: TPT updates the prompt on every sample.
    pbar = tqdm(
        range(len(dataset)),
        desc="TPT_testing",
        position=0,
        leave=True,
        total=len(dataset),
    )
    # Iterate over the indices of the test set
    try:
        for sample_idx in pbar:  # Iterate through indices
            net.reset_ctx()
            optimizer.load_state_dict(optimizer_state_dict)
            # print(f"after Reset{torch.cuda.mem_get_info()}")

            # Create augmented dataset for the current sample
            aug_data = AugmentSingleSampleDataset(
                dataset, sample_idx, num_augments=num_aug
            )  # Pass the base dataset and index

            # Create a DataLoader for the augmented samples of this single image
            aug_dataloader = torch.utils.data.DataLoader(  # type: ignore
                aug_data,
                batch_size=batch_size,
                shuffle=False,
                worker_init_fn=seed_worker,
                generator=g,
            )

            # Process the augmented images for this sample
            all_outputs = []
            for images, labels in aug_dataloader:
                try:
                    with autocast():
                        # print(f"size batch {len(images)}")
                        images = images.to(DEVICE)
                        # print(torch.cuda.mem_get_info(), LINE())
                        outputs = net(images)  # Use the provided net
                        # print(torch.cuda.mem_get_info(), LINE())
                        # cpu_outputs = outputs.to("cpu")
                        all_outputs.append(outputs)
                        # print(torch.cuda.mem_get_info(), LINE())

                        """ del images
                        del outputs
                        torch.cuda.empty_cache()
                        gc.collect()
                        print(torch.cuda.mem_get_info(), LINE()) """

                except:
                    torch.cuda.memory._dump_snapshot("my_snapshot.pickle")
                    raise

            # Get the original label for this sample
            original_image, target = dataset[sample_idx]
            original_image = original_image.unsqueeze(0).to(DEVICE)
            # print(torch.cuda.mem_get_info(), LINE())

            # Make target a tensor and move to device
            target = torch.tensor([target]).to(DEVICE)

            # Concatenate outputs from all batches for this sample
            all_outputs = torch.cat(all_outputs, dim=0)

            # Select confident samples and compute average entropy
            top_outputs, _ = select_confident_samples(all_outputs, 0.2)
            loss = compute_avg_entropy(top_outputs)
            # Loss computation
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Forward pass
            with autocast():
                outputs = net(original_image)

            # Fetch prediction and loss value
            samples += original_image.shape[0]
            # Note: the .item() is needed to extract scalars from tensors
            cumulative_loss += loss.item()
            _, predicted = outputs.max(1)

            # Compute accuracy
            cumulative_accuracy += predicted.eq(target).sum().item()

            pbar.set_postfix(test_acc=cumulative_accuracy / samples * 100)
            pbar.update(1)
    except:
        raise
    finally:
        del all_outputs  # type: ignore
        del aug_data  # type: ignore

    return cumulative_loss / samples, cumulative_accuracy / samples * 100

In [None]:
def tpt_test(net, dataset: Dataset, run_name="tpt1", num_aug=63, batch_size=64):

    # Create a logger for the experiment
    log_writer = SummaryWriter(log_dir=f"runs/{run_name}")
    net.set_ctx_checkpoint()

    for name, param in net.named_parameters():
        if "prompt_learner" not in name:
            param.requires_grad_(False)
        # print(f"{name}is in {param.requires_grad}")
    print(torch.cuda.mem_get_info(), LINE())

    # Define the optimizer
    # optimizer = get_optimizer(net, 0.002, 0.0005, 0.9)
    # , weight_decay=wd, momentum=momentum)
    optimizer = torch.optim.AdamW([{"params": net.parameters()}], lr=0.005)
    optimizer_state_dict = deepcopy(optimizer.state_dict())
    print(torch.cuda.mem_get_info(), LINE())

    print("Test tpt:")
    test_loss, test_accuracy = test_step_tpt(
        net,
        dataset,
        optimizer,
        optimizer_state_dict,
        log_writer,
        num_aug=num_aug,
        batch_size=batch_size,
    )

    # Closes the logger
    log_writer.close()

In [None]:
coop_net = new_model(OurCLIP, dataset_a)
coop_net = load_model(coop_net)

In [None]:
tpt_test(coop_net, dataset_a, batch_size=64, num_aug=63)

In [None]:
# TODO: clear memory. (delete everything)
del dataset_a
del dataset_v2
del coop_net

### Implementation

- `TPTPromptLearner`: a simple TPT prompt learner with a similar number of learnable parameters to the original TPT prompt learner, but a simpler and more readable implementation. Instead of supporting a `class_token_position` of `end`, `middle` or `front` as in the original TPT + CoOp model, we split the prompt into a `pre` and a `post` part, which are concatenated around the embedded class token. Note that the prompt is initialized from "a photo of a {}", as seen in CoOp [[4](#ref-coop2021)], without any pre-training, as performance is close.
- `TPTModel`: CLIP model with the TPT + CoOp prompt learner.
- `TPT`: a wrapper that manages the fine-tuning and the reset of the model after each image, transparently to the user.
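
The prefix/suffix layout can be illustrated without CLIP. The sketch below uses plain strings as stand-ins for token embeddings; `build_prompt`, the `<SOT>`/`<EOT>`/`<PAD>` markers and the whitespace "tokenizer" are hypothetical simplifications, while the context length of 77 matches the shapes noted in the code comments.

```python
def build_prompt(base_prompt, class_tokens, max_len=77):
    """Assemble SOT + prefix + class tokens + suffix + EOT, padded to max_len."""
    prefix_text, suffix_text = base_prompt.split("[CLS]")
    prefix = prefix_text.split()  # learnable part before the class name
    suffix = suffix_text.split()  # learnable part after the class name (may be empty)
    body = ["<SOT>"] + prefix + list(class_tokens) + suffix + ["<EOT>"]
    padding = max_len - len(body)
    if padding < 0:
        raise ValueError("prompt longer than the context length")
    return body + ["<PAD>"] * padding


# A single-token class name and a class name spanning two tokens
prompt = build_prompt("a photo of a [CLS]", ["goldfish"])
long_prompt = build_prompt("a photo of a [CLS]", ["sea", "lion"])
print(prompt[:8], len(prompt))
```

Only the prefix and suffix slots would be trainable; the class tokens, SOT, EOT and padding stay fixed, which is why the amount of padding depends on the length of each class name.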


In [None]:
# Note: from 5_tpt.py
@dataclass(frozen=True)
class CLIPModels:
    ViTB32: str = "ViT-B/32"


class TPTPromptLearner(nn.Module):
    def __init__(
        self,
        class_names: List[str],
        clip_model: open_clip.model.CLIP,
        arch: CLIPModels = CLIPModels.ViTB32,  # type: ignore
        base_prompt: str = "a photo of a [CLS]",
        device="cuda",
    ):
        super().__init__()
        self.device = device or (
            "cuda" if torch.cuda.is_available() else "cpu")

        self.class_names = class_names

        tokenizer = open_clip.get_tokenizer(arch)

        self.__init_ctx_from_prompt(
            tokenizer=tokenizer,
            token_embedding=clip_model.token_embedding,
            base_prompt=base_prompt,
        )

    def __init_ctx_from_prompt(
        self, tokenizer, token_embedding, base_prompt: str
    ) -> None:
        """
        Initialize the context tokens from the base prompt.

        We need to make sure that the CLS token is NOT "exploded" in the prompt.

        The idea is to have prompts tuned without having to manually manage where the CLS token is.

        To do this we need to keep the CLS token position in the prompt, and update it accordingly
        when needed.

        I'm splitting the prompt into prefix and suffix, using [CLS] as a separator.
        They are trained as two different parameters, and then concatenated together.

        """

        # Split the base prompt into prefix and suffix
        prompt_prefix = base_prompt.split("[CLS]")[0]
        prompt_suffix = base_prompt.split("[CLS]")[1]

        # "Clean" PAD, SOT and EOT tokens
        c_token_sot = torch.tensor([[tokenizer.sot_token_id]]).to(self.device)
        c_token_eot = torch.tensor([[tokenizer.eot_token_id]]).to(self.device)
        c_token_pad = torch.tensor([[0]]).to(self.device)  # PAD

        # Tokenize prefix, suffix and class names
        tokenized_prefix = tokenizer(prompt_prefix).to(self.device)
        tokenized_suffix = tokenizer(prompt_suffix).to(self.device)

        # remove PAD, SOT and EOT tokens
        # Extract "clean" tokens
        c_tokenized_prefix = tokenized_prefix[
            (tokenized_prefix != c_token_sot)
            & (tokenized_prefix != c_token_eot)
            & (tokenized_prefix != c_token_pad)
        ].to(self.device)
        c_tokenized_suffix = tokenized_suffix[
            (tokenized_suffix != c_token_sot)
            & (tokenized_suffix != c_token_eot)
            & (tokenized_suffix != c_token_pad)
        ].to(self.device)

        tokenized_class_names = tokenizer(self.class_names).to(self.device)

        # BASE full prompt
        # SOT + prefix + class_name + suffix + EOT
        # pre-computed as it's used for all classes and images :)
        self.tokenized_initial_full_prompt = tokenizer(
            [base_prompt.replace("[CLS]", c) for c in self.class_names]
        ).to(self.device)

        # Get base embeddings
        with torch.no_grad():
            self.embedded_sot = token_embedding(c_token_sot)
            self.embedded_eot = token_embedding(c_token_eot)
            self.embedded_pad = token_embedding(c_token_pad)
            self.embedded_prefix = token_embedding(c_tokenized_prefix)
            self.embedded_suffix = token_embedding(c_tokenized_suffix)
            embedded_class_names = token_embedding(tokenized_class_names)
            self.embedded_max_len = embedded_class_names.shape[1]

        # Setup clean embedded_class_names (list)
        # Mask to filter out SOT/EOT/PAD tokens (shape [200, 77])
        mask = (
            (tokenized_class_names != c_token_sot)
            & (tokenized_class_names != c_token_eot)
            & (tokenized_class_names != c_token_pad)
        )

        # Apply mask to embeddings (for each class)
        clean_embeddings = []
        for i in range(embedded_class_names.shape[0]):
            # masked_select would flatten, so we use boolean indexing
            # [num_valid_tokens, 512]
            clean_embed = embedded_class_names[i][mask[i]]
            clean_embeddings.append(
                clean_embed.unsqueeze(0)
            )  # [1, num_valid_tokens, 512]

        self.embedded_class_names = clean_embeddings

        for i, embed in enumerate(clean_embeddings):
            self.register_buffer(f"class_embed_{i}", embed)

        # Create "init" states and set learnable parameters
        self.init_state_prefix = self.embedded_prefix.detach().clone()
        self.init_state_suffix = self.embedded_suffix.detach().clone()
        self.embedded_prefix = nn.Parameter(self.embedded_prefix)
        self.embedded_suffix = nn.Parameter(self.embedded_suffix)
        self.register_parameter("embedded_prefix", self.embedded_prefix)  # type: ignore
        self.register_parameter("embedded_suffix", self.embedded_suffix)  # type: ignore

    def forward(self) -> torch.Tensor:
        prompts = []
        for i in range(len(self.class_names)):

            # embedded_max_len: 77
            # embedded_prefix: torch.Size([4, 512])
            # embedded_class_names: torch.Size([1, 1, 512])
            # embedded_suffix: torch.Size([0, 512])

            padding_size = (
                self.embedded_max_len
                - self.embedded_prefix.shape[0]
                - getattr(self, f"class_embed_{i}").shape[1]
                - self.embedded_suffix.shape[0]
            ) - 2  # -2 for SOT and EOT

            # embedded sot shape: torch.Size([1, 1, 512])
            # embedded prefix shape: torch.Size([1, 4, 512])
            # embedded class names shape: torch.Size([1, 1, 1, 512])
            # embedded suffix shape: torch.Size([1, 0, 512])
            # embedded eot shape: torch.Size([1, 1, 512])
            # effective padding shape: torch.Size([1, 70, 512])
            # Prompt shape: torch.Size([1, 77, 512])

            prompt = torch.cat(
                (
                    self.embedded_sot,
                    self.embedded_prefix.unsqueeze(0),
                    # self.embedded_class_names[i],
                    getattr(self, f"class_embed_{i}"),
                    self.embedded_suffix.unsqueeze(0),
                    self.embedded_eot,
                    self.embedded_pad.repeat(1, padding_size, 1),
                ),
                dim=1,
            )

            prompts.append(prompt)

        prompts = torch.cat(prompts, dim=0)
        # Must have shape torch.Size([200, 77, 512]) (classes, feature1, feature2)
        return prompts

    def reset(self) -> None:
        # TODO: check that resetting without `.data` works as expected

        # self.embedded_prefix.data.copy_(self.init_state_prefix)
        # self.embedded_suffix.data.copy_(self.init_state_suffix)
        with torch.no_grad():
            self.embedded_prefix.copy_(self.init_state_prefix)
            self.embedded_suffix.copy_(self.init_state_suffix)


class TPTModel(nn.Module):
    def __init__(
        self,
        class_names: List[str],
        arch: CLIPModels,
        pretrained: str,
        device="cuda",
    ):
        super().__init__()
        self.device = device or (
            "cuda" if torch.cuda.is_available() else "cpu")

        clip_model: open_clip.model.CLIP
        clip_model, _, _ = open_clip.create_model_and_transforms(  # type:ignore
            model_name=arch,  # type:ignore
            pretrained=pretrained,
            device=device,
            force_quick_gelu=True,
        )

        self.model = clip_model
        self.model.eval()

        self.tokenizer = open_clip.get_tokenizer(arch)  # type:ignore
        self.class_names = class_names

        self.visual: open_clip.transformer.VisionTransformer = (  # type:ignore
            clip_model.visual
        )  # type:ignore
        self.visual.eval()

        self.token_embedding = clip_model.token_embedding

        self.positional_embedding = clip_model.positional_embedding
        self.transformer = clip_model.transformer
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.attn_mask = clip_model.attn_mask
        self.text_pool_type = clip_model.text_pool_type

        for _, param in self.named_parameters():
            param.requires_grad_(False)

        self.prompt_learner = TPTPromptLearner(
            arch=arch, class_names=class_names, clip_model=clip_model
        )

    def _pool(self, x: torch.Tensor):
        if self.visual.attn_pool is not None:
            if self.visual.attn_pool_contrastive is not None:
                # This is untested, WIP pooling that should match paper
                x = self.visual.ln_post(
                    x
                )  # TBD LN first or separate one after each pool?
                tokens = self.visual.attn_pool(x)
                if self.visual.attn_pool_type == "parallel":
                    pooled = self.visual.attn_pool_contrastive(x)
                else:
                    assert self.visual.attn_pool_type == "cascade"
                    pooled = self.visual.attn_pool_contrastive(tokens)
            else:
                # this is the original OpenCLIP CoCa setup, does not match paper
                x = self.visual.attn_pool(x)
                x = self.visual.ln_post(x)
                pooled, tokens = self.visual._global_pool(x)
        elif self.visual.final_ln_after_pool:
            pooled, tokens = self.visual._global_pool(x)
            pooled = self.visual.ln_post(pooled)
        else:
            x = self.visual.ln_post(x)
            pooled, tokens = self.visual._global_pool(x)

        return pooled, tokens, x

    def _forward_image(self, x: torch.Tensor) -> torch.Tensor:
        x = self.visual._embeds(x)
        x = self.visual.transformer(x)

        pooled, tokens, x = self._pool(x)

        if self.visual.proj is not None:
            pooled = pooled @ self.visual.proj
        if self.visual.output_tokens:
            return pooled, tokens, x  # type:ignore

        return pooled

    def __encode_image(
        self, image, normalize: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        pooled_pre_norm = self._forward_image(image)
        return (
            F.normalize(pooled_pre_norm, dim=-1) if normalize else pooled_pre_norm
        )  # type:ignore

    def __encode_text(self, text=None, normalize: bool = False):
        cast_dtype = self.transformer.get_cast_dtype()

        x = self.prompt_learner().to(cast_dtype)

        text = self.prompt_learner.tokenized_initial_full_prompt

        x = x + self.positional_embedding.to(cast_dtype)
        x = self.transformer(x, attn_mask=self.attn_mask)
        x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]
        x = text_global_pool(x, text, self.text_pool_type)  # type:ignore
        if self.text_projection is not None:
            if isinstance(self.text_projection, nn.Linear):
                x = self.text_projection(x)
            else:
                x = x @ self.text_projection

        return F.normalize(x, dim=-1) if normalize else x

    def forward(self, image: torch.Tensor):
        """
        Inference function for the CLIP model.

        Args:
            image (torch.Tensor): Input images.
        Returns:
            logits (torch.Tensor): Logits from the CLIP model.
            image_features (torch.Tensor): Normalized image features.
        """

        with torch.no_grad():
            image_features = self.__encode_image(image, normalize=True)
        text_features = self.__encode_text(normalize=True)

        logit_scale = self.model.logit_scale.exp()
        logits = logit_scale * image_features @ text_features.t()

        return logits, image_features

    def reset(self):
        """
        Reset the prompt learner to its initial state.
        """
        self.prompt_learner.reset()


class TPT(nn.Module):
    def __init__(
        self,
        pretrained: str,
        arch: CLIPModels,
        class_names: List[str],
        tta_steps: int = 1,
        lr: float = 0.0001,
        device="cuda",
    ):
        super().__init__()
        self.device = device or (
            "cuda" if torch.cuda.is_available() else "cpu")
        self.tpt_steps = tta_steps

        model = TPTModel(
            class_names=class_names,
            arch=arch,
            pretrained=pretrained,
            device=self.device,
        )
        self.model = model.to(self.device)
        self.model.eval()

        # # # TEST - learnable layer norm
        # self.model.visual.ln_post.requires_grad_(True)
        # self.model.ln_final.requires_grad_(True)

        # Get all trainable parameters (filter by requires_grad)
        trainable_params = [
            p for p in self.model.parameters() if p.requires_grad]

        # Initialize optimizer with trainable parameters
        self.optim = torch.optim.AdamW(trainable_params, lr=lr)
        self.scaler = torch.cuda.amp.GradScaler()

        self.optim_init = deepcopy(self.optim.state_dict())

        # Initialize backup lists
        self.ln_backup = {
            "weights": [],  # For gamma (scale)
            "biases": [],  # For beta (shift)
        }

        # Backup all LN params in text encoder
        for block in self.model.transformer.resblocks:  # type:ignore
            self.ln_backup["weights"].append(
                block.ln_1.weight.data.detach().clone()
            )  # gamma for ln_1
            self.ln_backup["biases"].append(
                block.ln_1.bias.data.detach().clone()
            )  # beta for ln_1
            self.ln_backup["weights"].append(
                block.ln_2.weight.data.detach().clone()
            )  # gamma for ln_2
            self.ln_backup["biases"].append(
                block.ln_2.bias.data.detach().clone()
            )  # beta for ln_2

        # Backup final LN
        self.ln_backup["weights"].append(
            self.model.ln_final.weight.data.detach().clone()
        )
        self.ln_backup["biases"].append(
            self.model.ln_final.bias.data.detach().clone())

    def set_tta_steps(self, tta_steps: int) -> None:
        """
        Set the number of TTA steps.

        Args:
            tta_steps (int): Number of TTA steps.
        """
        self.tpt_steps = tta_steps

    def forward(self, input: torch.Tensor):
        selected_idx = None

        for step in range(self.tpt_steps):
            with torch.cuda.amp.autocast():

                logits, image_features = self.model(input)

                # Select the most confident samples
                if selected_idx is not None:
                    logits = logits[selected_idx]
                else:
                    logits, selected_idx = self.__select_confident_samples(
                        logits)

                # Compute the average entropy loss
                loss = self.__avg_entropy_loss(logits)

            self.optim.zero_grad()
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optim)
            self.scaler.update()

        # Actual inference
        with torch.no_grad(), torch.autocast("cuda"):
            # Predict from the marginal distribution over all views
            # (predicting from the last image only is left disabled)
            # input = input[-1].unsqueeze(0)
            logits, _ = self.model(input)

            marginal_prob = F.softmax(logits, dim=1).mean(0)
            pred_class = marginal_prob.argmax().item()

        self.__reset()

        return pred_class

    def __select_confident_samples(
        self, logits: torch.Tensor, top: float = 0.1
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Selects the top-k samples with the lowest entropy from the logits.

        Args:
            logits (torch.Tensor): The logits from the model.
            top (float): The fraction of samples to select.
                For example, if top=0.1, it selects the top 10% of samples.
        Returns:
            torch.Tensor: The selected logits.
            torch.Tensor: The indices of the selected samples.

        [Reference](https://github.com/azshue/TPT/blob/63ecbace79694205d7884e63fdc3137a200f0b0e/tpt_classification.py#L41C5-L41C11)
        """
        batch_entropy = -(logits.softmax(1) * logits.log_softmax(1)).sum(1)
        idx = torch.argsort(batch_entropy, descending=False)[
            : int(batch_entropy.size()[0] * top)
        ]

        return logits[idx], idx

    def __avg_entropy_loss(self, outputs: torch.Tensor) -> torch.Tensor:
        """
        Computes the average entropy of the model's outputs.
        Args:
            outputs (torch.Tensor): The model's outputs.
        Returns:
            torch.Tensor: The average entropy.

        [Reference](https://github.com/azshue/TPT/blob/63ecbace79694205d7884e63fdc3137a200f0b0e/tpt_classification.py#L46)
        """
        logits = outputs - outputs.logsumexp(
            dim=-1, keepdim=True
        )  # logits = outputs.log_softmax(dim=1) [N, 1000]
        avg_logits = logits.logsumexp(dim=0) - np.log(
            logits.shape[0]
        )  # avg_logits = logits.mean(0) [1, 1000]
        min_real = torch.finfo(avg_logits.dtype).min
        avg_logits = torch.clamp(avg_logits, min=min_real)

        return -(avg_logits * torch.exp(avg_logits)).sum(dim=-1)

    def __reset(self) -> None:
        """Full reset of prompt learner and optimizer state"""
        # 1. Reset prompt embeddings
        for p in self.model.parameters():
            p.grad = None

        self.model.reset()

        # with torch.no_grad():
        #     self.embedded_prefix.copy_(self.init_state_prefix)
        #     self.embedded_suffix.copy_(self.init_state_suffix)

        with torch.no_grad():
            idx = 0
            # Reset LN params in text encoder
            for block in self.model.transformer.resblocks:  # type:ignore
                block.ln_1.weight.data.copy_(self.ln_backup["weights"][idx].clone())
                block.ln_1.bias.data.copy_(self.ln_backup["biases"][idx].clone())
                idx += 1
                block.ln_2.weight.data.copy_(
                    self.ln_backup["weights"][idx].clone())
                block.ln_2.bias.data.copy_(
                    self.ln_backup["biases"][idx].clone())
                idx += 1

            # # Reset final LN
            self.model.ln_final.weight.data.copy_(
                self.ln_backup["weights"][-1].clone())
            self.model.ln_final.bias.data.copy_(
                self.ln_backup["biases"][-1].clone())

        # # 2. Reset optimizer state
        self.optim.load_state_dict(deepcopy(self.optim_init))

### Running


In [None]:
augmenter = ImageTransform(
    model_transform=kornia_preprocess,
    custom_transform=kornia_random_crop,
    n_views=63,
    device="cpu",
)

dataloader, dataset = ImagenetA(augmenter, num_workers=5)

clip_model, _, _ = open_clip.create_model_and_transforms(
    model_name="ViT-B-16",
    pretrained="openai",
    device=DEVICE,
    force_quick_gelu=True,
)
clip_model.eval()  # type:ignore

wrapper_clip = TPT(
    arch="ViT-B-16",  # type:ignore
    pretrained="openai",
    class_names=dataset.class_code_to_label.values(),  # type:ignore
    tta_steps=1,
    lr=5e-3,
)

bench(wrapper_clip, dataloader, DEVICE,
      reduce=None, comment="tpt", visualize=False)

# 4. Trying to get better at TTA (our contribution)

- augmix (as a note criticizing TPT's paper.)
- augment top 1


TPT with CoOp is quite slow because the prompt is fine-tuned for every test image. The idea is to reach better or similar performance by taking inspiration from TPT and other TTA methods, ideally without any fine-tuning or, if it is needed, with a faster one.




## A. Augment Top 1 🚀

The idea is to remove the prompt learner (CoOp style) from the TPT model and predict directly from the logits: we take the logits of the most confident augmented views (those with the lowest entropy; the implementation keeps the top 10%, `top=0.1`) together with the logits of the original image, average their softmax probabilities, and use the result as the final prediction.

We keep the logits of the most confident views because they are more likely to be correct, and we average them with those of the original image to avoid an overly biased prediction that "loses context".

<blockquote>

```python
final_logits = torch.cat((selected_logits, initial_logits), dim=0)

marginal_prob = F.softmax(final_logits, dim=1).mean(0)
pred_class = int(marginal_prob.argmax().item())
```

</blockquote>

We expect this to be much faster than TPT (no fine-tuning is done) and to perform slightly better, since it relies on the most confident views, which are more likely to be correct, while the original image keeps the prediction from becoming too biased.

- Why a "biased prediction"? The augmentations are random crops, so the prediction could be dominated by augmented views that are not representative of the original image. Averaging with the original image's prediction mitigates this bias and gives a more reliable result.
- Is this the best way to do it? No, but it is a simple starting point that works well in practice.
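
As a sanity check of the averaging rule, here is a torch-free toy version. `predict` is a hypothetical helper and the logits are invented; it follows the snippet above by keeping the lowest-entropy augmented views, appending the original view, averaging the softmax probabilities and taking the argmax.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(p):
    """Shannon entropy (in nats) of a discrete distribution."""
    return -sum(q * math.log(q) for q in p if q > 0)


def predict(aug_logits, original_logits, top=0.1):
    """Average the softmax of the most confident augmented views and the original view."""
    k = int(len(aug_logits) * top)
    selected = sorted(aug_logits, key=lambda z: entropy(softmax(z)))[:k]
    probs = [softmax(z) for z in selected + [original_logits]]
    n_classes = len(probs[0])
    marginal = [sum(p[c] for p in probs) / len(probs) for c in range(n_classes)]
    return max(range(n_classes), key=marginal.__getitem__)


# 10 augmented views over 3 classes: one confident view for class 2, nine nearly flat ones
aug = [[0.0, 0.0, 8.0]] + [[0.1, 0.0, 0.0]] * 9
original = [1.0, 0.5, 0.9]  # on its own, the original view slightly prefers class 0

pred_original_only = predict([], original)
pred = predict(aug, original, top=0.1)
print(pred_original_only, pred)
```

In this toy case the single confident view overrides the weak preference of the original image. With 63 augmented views and `top=0.1`, `int(63 * 0.1)` keeps 6 views.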


### Model


In [None]:
class ClipWrapper(nn.Module):
    def __init__(
        self,
        model: nn.Module,
        class_labels: dict,
        prompt: str = "a photo of a {}",
        device: str = "cuda",
    ):
        super().__init__()
        self.device = device

        self.tokenizer = open_clip.get_tokenizer("ViT-B-16")
        self.model: open_clip.model.CLIP = model
        self.logit_scale = model.logit_scale.data.exp()
        # self.logit_scale = model.log

        with torch.no_grad():
            prompts = torch.cat(
                [self.tokenizer(prompt.format(c)) for c in class_labels.values()]
            ).to(device)
            self.text_features = model.encode_text(prompts, normalize=True)

    def select_confident_samples(
        self, logits: torch.Tensor, top: float = 0.1
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Selects the top-k samples with the lowest entropy from the logits.

        Args:
            logits (torch.Tensor): The logits from the model.
            top (float): The fraction of samples to select.
                For example, if top=0.1, it selects the top 10% of samples.
        Returns:
            torch.Tensor: The selected logits.
            torch.Tensor: The indices of the selected samples.

        [Reference](https://github.com/azshue/TPT/blob/63ecbace79694205d7884e63fdc3137a200f0b0e/tpt_classification.py#L41C5-L41C11)
        """
        batch_entropy = -(logits.softmax(1) * logits.log_softmax(1)).sum(1)
        idx = torch.argsort(batch_entropy, descending=False)[
            : int(batch_entropy.size()[0] * top)
        ]

        return logits[idx], idx

    def forward(self, x: torch.Tensor) -> int:
        with torch.no_grad(), torch.autocast("cuda"):
            image_features = self.model.encode_image(x, normalize=True)

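            # The original (non-augmented) view is expected to be the last element of the batch
            initial_image_features = image_features[-1:]
            filtered_image_features = image_features[:-1]

            # Logits for the original view and for the augmented views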
            initial_logits = (
                self.logit_scale * initial_image_features @ self.text_features.t()
            )
            filtered_logits = (
                self.logit_scale * filtered_image_features @ self.text_features.t()
            )

            # Get top k logits
            selected_logits, _ = self.select_confident_samples(
                # filtered_logits, top=1 / filtered_logits.shape[0]
                filtered_logits,
                top=0.1,
            )

            # selected_logits = selected_logits.mean(0, keepdim=True)

            # final_logits = selected_logits
            final_logits = torch.cat((selected_logits, initial_logits), dim=0)

            marginal_prob = F.softmax(final_logits, dim=1).mean(0)
            pred_class = int(marginal_prob.argmax().item())

        return pred_class

### Running


In [None]:
clip_model, _, _ = open_clip.create_model_and_transforms(
    # model_name="ViT-B-32", pretrained="datacomp_xl_s13b_b90k", device=device#, force_quick_gelu=True
    model_name="ViT-B-16",
    pretrained="openai",
    device=DEVICE,
    force_quick_gelu=True,
)
clip_model.eval()  # type:ignore

# Create a ClipWrapper instance
wrapper_clip = ClipWrapper(
    clip_model, class_labels=dataset.class_code_to_label, device=DEVICE  # type:ignore
).to(DEVICE)

bench(wrapper_clip, dataloader, DEVICE, reduce=None, comment="top1", visualize=False)

## B. TPT with Top 1

The idea is simple: use the top-confidence + original image logits described above, but on top of TPT with CoOp.

The implementation is straightforward: override the `forward` method of the `TPT` class (the one that manages the fine-tuning of the `TPTModel`) so that the final prediction is computed from the most confident augmented views plus the original image, instead of from all views. The fine-tuning of the prompt is kept as is; only the final prediction changes.

**Diff**:

<blockquote>

```python
with torch.no_grad(), torch.autocast("cuda"):
    # forward all views; the original image is expected to be the last one
    logits, _ = self.model(input)

    original_logits = logits[-1:]
    filtered_logits = logits[:-1]

    # Get top k logits
    selected_logits, _ = self.__select_confident_samples(
        filtered_logits,
        top=0.1,
    )

    final_logits = torch.cat((selected_logits, original_logits), dim=0)

    marginal_prob = F.softmax(final_logits, dim=1).mean(0)
    pred_class = marginal_prob.argmax().item()
```

</blockquote>


### Model


In [None]:
class TPTTop1(TPT):
    def __init__(
        self,
        pretrained: str,
        arch: CLIPModels,
        class_names: List[str],
        tta_steps: int = 1,
        lr: float = 0.0001,
        device="cuda",
    ):
        super().__init__(
            pretrained=pretrained,
            arch=arch,
            class_names=class_names,
            tta_steps=tta_steps,
            lr=lr,
            device=device,
        )

    def forward(self, input: torch.Tensor) -> int:
        # NOTE: the helpers used below are assumed to be defined in `TPT` with a
        # double-underscore prefix. Python mangles such names per class, so
        # inside this subclass they must be called as `self._TPT__<name>`.
        selected_idx = None

        # Test-time prompt tuning (unchanged from TPT)
        for step in range(self.tpt_steps):
            with torch.autocast("cuda"):
                logits, image_features = self.model(input)

                # Select the most confident views once, then reuse the indices
                if selected_idx is not None:
                    logits = logits[selected_idx]
                else:
                    logits, selected_idx = self._TPT__select_confident_samples(logits)

                # Compute the average entropy loss
                loss = self._TPT__avg_entropy_loss(logits)

            self.optim.zero_grad()
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optim)
            self.scaler.update()

        # Actual inference
        with torch.no_grad(), torch.autocast("cuda"):
            # Run the tuned model on all views; the original image is the last one
            logits, _ = self.model(input)

            original_logits = logits[-1:]
            filtered_logits = logits[:-1]

            # Keep the most confident 10% of the augmented views
            selected_logits, _ = self._TPT__select_confident_samples(
                filtered_logits,
                top=0.1,
            )

            final_logits = torch.cat((selected_logits, original_logits), dim=0)

            marginal_prob = F.softmax(final_logits, dim=1).mean(0)
            pred_class = int(marginal_prob.argmax().item())

        # Restore the pre-trained prompt before the next test image
        self._TPT__reset()

        return pred_class

### Running


In [None]:
wrapper_clip = TPTTop1(
    arch="ViT-B-16",  # type:ignore
    pretrained="openai",
    class_names=dataset.class_code_to_label.values(),  # type:ignore
    tta_steps=1,
    lr=5e-3,
)

bench(
    wrapper_clip, dataloader, DEVICE, reduce=None, comment="tpt-top1", visualize=False
)

## C. Self-Supervised Retrieval (working title)

Here we take inspiration from DINOv2's [[5](#ref-dinov2)] self-supervised retrieval method. Using 63+1 views as in TPT (63 augmentations plus the original image), the image features are grouped into 6 clusters with k-means, and a confidence is computed for each cluster as the mean cosine similarity between its members and the original image. The most confident cluster is selected, and the softmax probabilities of its views are averaged together with those of the original image.

We do not expect much from this method: unlike DINOv2, CLIP has not been trained for extracting saliency maps, so the clusters are not expected to be very meaningful. It is still interesting, because visualizing the clusters gives a rough idea of what the model is doing. Of course, this interpretation is unreliable.
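
The cluster selection step can be illustrated without CLIP or scikit-learn. In the sketch below, cluster labels are taken as given and each cluster is scored by the mean cosine similarity of its members to the original image, which is assumed to be the last row. The function names and the 2-D toy features are assumptions made for this example.

```python
import math


def normalize(vec: list[float]) -> list[float]:
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / norm for v in vec]


def closest_cluster(
    features: list[list[float]], labels: list[int]
) -> tuple[int, dict[int, float]]:
    """Return the cluster whose members are, on average, most similar to the last row."""
    anchor = features[-1]  # the original (non-augmented) image
    sims: dict[int, list[float]] = {}
    for feat, label in zip(features, labels):
        # Features are unit-norm, so the dot product is the cosine similarity.
        sims.setdefault(label, []).append(sum(a * b for a, b in zip(feat, anchor)))
    # Only clusters that actually have members are scored.
    means = {label: sum(vals) / len(vals) for label, vals in sims.items()}
    return max(means, key=means.get), means


# Toy features (hypothetical): two views near the original, two far from it.
features = [normalize(f) for f in ([1, 0], [0.9, 0.1], [0, 1], [0.1, 0.9], [1, 0.05])]
labels = [0, 0, 1, 1, 0]

best, means = closest_cluster(features, labels)
print(best)  # 0
```

Because only non-empty clusters are scored, an empty cluster cannot produce a `NaN` mean in this sketch.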

TODO: add plotting examples: cherry-pick one case that works well and one that does not work at all, to show how the clusters differ and how little meaning they carry.

TODO: it could be interesting to use the clusters to get a better prediction when the model is failing to classify an image, e.g. when the confidence is low.


### Model


In [None]:
class ClipWrapper(nn.Module):
    def __init__(
        self,
        model: nn.Module,
        class_labels: dict,
        prompt: str = "a photo of a {}",
        device: str = "cuda",
    ):
        super().__init__()
        self.device = device
        self.class_labels = class_labels

        self.tokenizer = open_clip.get_tokenizer("ViT-B-16")
        self.model = model
        self.logit_scale = model.logit_scale.exp()

        # Precompute text features
        with torch.no_grad():
            prompts = torch.cat(
                [self.tokenizer(prompt.format(c)) for c in class_labels.values()]
            ).to(device)
            self.text_features = model.encode_text(prompts, normalize=True)

    def forward(self, x: torch.Tensor) -> int:
        with torch.no_grad(), torch.autocast("cuda"):
            # x: (B, 3, 224, 224)
            image_features = self.model.encode_image(x, normalize=True)

            # Move to CPU and convert to numpy for sklearn
            features_np = image_features.cpu().numpy()

            # Standardize features
            from sklearn.preprocessing import StandardScaler

            scaler = StandardScaler()
            X_scaled = scaler.fit_transform(features_np)

            # Cluster features
            from sklearn.cluster import KMeans

            kmeans = KMeans(n_clusters=6, random_state=42)
            cluster_labels = kmeans.fit_predict(X_scaled)

            ###################################################

            # # get the cluster with higher confidence
            # cluster_confidences = []
            # for cluster_idx in range(6):
            #     cluster_features = image_features[cluster_labels == cluster_idx]
            #     logits = self.logit_scale * cluster_features @ self.text_features.t()
            #     cluster_confidences.append(logits.mean().item())

            # # Get the cluster with the highest confidence
            # best_cluster_idx = cluster_confidences.index(max(cluster_confidences))

            # image_features_r = image_features[cluster_labels == best_cluster_idx]

            # image_features_r = torch.cat((image_features_r, image_features[-1:]), dim=0)

            # logits = self.logit_scale * image_features_r @ self.text_features.t()

            # marginal_prob = F.softmax(logits, dim=1).mean(0)

            # pred_class = marginal_prob.argmax().item()

            ###################################################

            # Cluster closer to the original image
            cluster_confidences = []
            for cluster_idx in range(6):
                cosine_sim = (
                    image_features[cluster_labels == cluster_idx]
                    @ image_features[-1:].t()
                )
                cluster_confidences.append(cosine_sim.mean().item())

            # Get the cluster with the highest confidence
            best_cluster_idx = cluster_confidences.index(max(cluster_confidences))

            image_features_r = image_features[cluster_labels == best_cluster_idx]

            image_features_r = torch.cat((image_features_r, image_features[-1:]), dim=0)

            logits = self.logit_scale * image_features_r @ self.text_features.t()

            marginal_prob = F.softmax(logits, dim=1).mean(0)

            pred_class = marginal_prob.argmax().item()

            # # Accuracy: 55.72%
            # # Latency: 234.35 ms

            ####################################################
            # # # ll = []

            # # # for cluster_idx in range(6):
            # # #     cluster_indices = (cluster_labels == cluster_idx)
            # # #     if len(cluster_indices) == 0:
            # # #         continue

            # # #     cluster_features = image_features[cluster_indices]
            # # #     logits = self.logit_scale * cluster_features @ self.text_features.t()
            # # #     logits = logits.mean(dim=0)

            # # #     ll.append(logits)
            # # #     print(logits.shape)
            # # # # ll
            # # # # print(ll.shape)
            # # # ll = torch.stack(ll, dim=0)
            # # # ll = ll.mean(dim=0, keepdim=True)

            # # # marginal_prob = F.softmax(ll, dim=1).mean(0)
            # # # pred_class = marginal_prob.argmax().item()
            ####################################################

            # exit()

            # image_features = torch.cat((image_features_r, image_features[-1:]), dim=0)

            # logits = self.logit_scale * image_features @ self.text_features.t()
            # marginal_prob = F.softmax(logits, dim=1).mean(0)
            # pred_class = marginal_prob.argmax().item()

            # # Visualize each cluster's images
            # import matplotlib.pyplot as plt
            # from torchvision.transforms.functional import to_pil_image

            # for cluster_idx in range(6):
            #     # Get indices of images in this cluster
            #     cluster_indices = (cluster_labels == cluster_idx).nonzero()[0]

            #     if len(cluster_indices) == 0:
            #         continue

            #     print(f"Cluster {cluster_idx} has {len(cluster_indices)} images")

            #     # Setup plot
            #     cols = min(8, len(cluster_indices))
            #     rows = (len(cluster_indices) + cols - 1) // cols
            #     plt.figure(figsize=(cols * 2, rows * 2))
            #     plt.suptitle(f"Cluster {cluster_idx} - {len(cluster_indices)} images")

            #     for plot_idx, img_idx in enumerate(cluster_indices, start=1):
            #         if plot_idx > cols * rows:
            #             break

            #         img = x[img_idx].permute(1, 2, 0).cpu().numpy()
            #         img = to_pil_image(img)
            #         plt.subplot(rows, cols, plot_idx)
            #         plt.imshow(img)
            #         plt.axis('off')

            #     plt.tight_layout()
            #     plt.show()

            ####################################################################

            # # # Visualize cluster_labels
            # import matplotlib.pyplot as plt
            # plt.figure(figsize=(10, 8))
            # plt.scatter(X_scaled[:, 0], X_scaled[:, 1], c=cluster_labels, cmap='viridis', alpha=0.6)
            # plt.title('KMeans Clustering of Image Features')
            # plt.xlabel('Feature 1')
            # plt.ylabel('Feature 2')
            # plt.colorbar()
            # plt.show()

            # from sklearn.manifold import TSNE

            # # Apply t-SNE
            # tsne = TSNE(n_components=2, random_state=42, perplexity=30, n_iter=1000)
            # X_tsne = tsne.fit_transform(X_scaled)

            # import matplotlib.pyplot as plt

            # plt.figure(figsize=(10, 8))

            # # If you did clustering, color by cluster
            # scatter = plt.scatter(X_tsne[:, 0], X_tsne[:, 1], c=clusters, cmap='viridis', alpha=0.6)

            # # If you have true labels, you could color by those instead
            # # scatter = plt.scatter(X_tsne[:, 0], X_tsne[:, 1], c=true_labels, cmap='viridis', alpha=0.6)

            # plt.colorbar(scatter)
            # plt.title('t-SNE Visualization with Clusters')
            # plt.xlabel('t-SNE dimension 1')
            # plt.ylabel('t-SNE dimension 2')
            # plt.show()

            # cosine_sim = torch.mm(image_features, image_features.t())
            # cosine_sim = image_features @ image_features.t()

            # # kmeans clusters
            # kmeans = KMeans(n_clusters=8, random_state=42)
            # kmeans.fit(cosine_sim.cpu().numpy())
            # print(kmeans.labels_)

            # show clusters

            # #  2. Get similarities of the last image ([-1]) with all others
            # last_img_similarities = cosine_sim[-1, :]  # (B,)

            # # 3. Sort indices (descending order, excluding the last image itself)
            # sorted_indices = torch.argsort(last_img_similarities, descending=True).cpu().numpy()
            # # sorted_indices = sorted_indices[sorted_indices != len(x)-1]  # Remove self-comparison

            # # 4. Visualize the last image + top-k most similar images
            # k = 63  # Number of similar images to display
            # cols = 8
            # rows = (image_features.shape[0] + cols - 1) // cols
            # plt.figure(figsize=(cols * 2, rows * 2))
            # for i, idx in enumerate(sorted_indices[:k], start=2):
            #     img = x[idx].permute(1, 2, 0).cpu().numpy()
            #     img = TF.to_pil_image(img)

            #     plt.subplot(rows, cols, i)
            #     plt.imshow(img)
            #     plt.title(f"Sim: {last_img_similarities[idx]:.3f}")
            #     plt.axis('off')

            # plt.tight_layout()
            # plt.show()

            # Perform KMeans clustering
            # kmeans = KMeans(n_clusters=4, random_state=42)
            # kmeans.fit(cosine_sim.cpu().numpy())
            # labels = kmeans.labels_
            # print(labels)

            # exit()

        return int(pred_class)

### Running


In [None]:
# Load the CLIP model
clip_model, _, _ = open_clip.create_model_and_transforms(
    model_name="ViT-B-16",
    pretrained="openai",
    device=DEVICE,
    force_quick_gelu=True,
)
clip_model.eval()  # type:ignore

# Wrap the CLIP model in a ClipWrapper instance
wrapper_clip = ClipWrapper(
    clip_model, class_labels=dataset.class_code_to_label, device=DEVICE  # type:ignore
).to(DEVICE)

bench(wrapper_clip, dataloader, DEVICE, reduce=200, comment="", visualize=False)

## C. (stupid) Adaptive Layer Norm 😥

TODO: mention that I also tried making the LayerNorm learnable and that it does work better, which is annoying, since I wanted a method without backprop.

TODO: add what we expect


## D. TNT

TODO: add amount of learnable parameters wrt TPT

This is an implementation of TNT [[6](#ref-tnt2023)]. It is fairly straightforward: CLIP with a learnable, randomly initialized noise tensor added to the augmented images (noise shape: `(3, height, width)`).
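
For a rough sense of scale, the arithmetic below compares the size of the noise tensor with the size of a learnable prompt context. Every number in it is an assumption made for illustration (224×224 inputs, 4 prompt tokens, 512-dimensional token embeddings) rather than a value read from the models in this notebook.

```python
# All sizes below are assumptions for illustration, not values read from this notebook.
image_size = 224  # assumed input resolution (height == width)
n_ctx = 4         # assumed number of learnable prompt tokens
ctx_dim = 512     # assumed token embedding width of the text encoder

tnt_params = 3 * image_size * image_size  # one noise tensor of shape (3, H, W)
tpt_params = n_ctx * ctx_dim              # one prompt context of shape (n_ctx, ctx_dim)

print(f"noise tensor parameters:   {tnt_params:,}")
print(f"prompt context parameters: {tpt_params:,}")
print(f"ratio: {tnt_params / tpt_params:.1f}x")
```

Under these assumptions the noise tensor holds far more learnable values than the prompt context, so any speed difference between the two methods would have to come from where the gradients flow rather than from the parameter count.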

We expect TNT to be slightly faster than TPT and slightly more accurate, as reported in the paper.
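
The noise update can be sketched as a signed-gradient step followed by a projection back into the `[-eps, eps]` range. The version below is a plain-Python toy on a flat list of values; `sign_step` is a hypothetical helper, and the default `lr` and `eps` mirror the defaults of the `TNT` class defined in this section.

```python
def sign_step(
    noise: list[float], grad: list[float], lr: float = 1e-3, eps: float = 1 / 255
) -> list[float]:
    """One signed-gradient descent step, then projection onto [-eps, eps]."""
    updated = []
    for n, g in zip(noise, grad):
        sign = (g > 0) - (g < 0)  # -1, 0 or +1
        stepped = n - lr * sign   # move against the gradient direction
        updated.append(max(-eps, min(eps, stepped)))  # keep the noise bounded
    return updated


# Toy values (hypothetical): the last two entries get clipped at the bounds.
noise = [0.0, 0.0035, -0.0039]
grad = [0.7, -2.0, 0.3]

updated = sign_step(noise, grad)
print(updated)
```

Only the sign of the gradient is used, so every entry moves by exactly `lr` per step unless the projection clips it.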


### Model


In [None]:
class TNT(nn.Module):
    def __init__(
        self,
        model: nn.Module,
        class_labels: dict,
        prompt: str = "a photo of a {}",
        device: str = "cuda",
        tnt_steps: int = 3,
        top_k: float = 0.1,
        epsilon: float = 1 / 255,
        lr: float = 1e-3,
        alpha: float = 1.0,
        beta: float = 1.0,
        temperature: float = 7e-3,
    ):
        super().__init__()
        self.device = device
        self.model: open_clip.model.CLIP = model
        self.logit_scale = model.logit_scale.data.exp()
        self.tokenizer = open_clip.get_tokenizer("ViT-B-16")
        self.tnt_steps = tnt_steps
        self.top_k = top_k
        self.eps = epsilon
        self.lr = lr
        self.alpha = alpha
        self.beta = beta
        self.temperature = temperature

        with torch.no_grad():
            prompts = torch.cat(
                [self.tokenizer(prompt.format(c)) for c in class_labels.values()]
            ).to(device)
            self.text_features = model.encode_text(prompts, normalize=True)

        self.noise = None

    def reset(self):
        self.noise = None

    def select_confident_samples(
        self, logits: torch.Tensor, top: float = 0.1
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Selects the top-k samples with the lowest entropy from the logits.

        Args:
            logits (torch.Tensor): The logits from the model.
            top (float): The fraction of samples to select.
                For example, if top=0.1, it selects the top 10% of samples.
        Returns:
            torch.Tensor: The selected logits.
            torch.Tensor: The indices of the selected samples.

        [Reference](https://github.com/azshue/TPT/blob/63ecbace79694205d7884e63fdc3137a200f0b0e/tpt_classification.py#L41C5-L41C11)
        """
        batch_entropy = -(logits.softmax(1) * logits.log_softmax(1)).sum(1)
        idx = torch.argsort(batch_entropy, descending=False)[
            : int(batch_entropy.size()[0] * top)
        ]

        return logits[idx], idx

    def forward(self, x: torch.Tensor) -> int:
        x = x.to(self.device)

        if self.noise is None:
            self.noise = torch.randn_like(
                x[0], requires_grad=True, device=self.device
            )  # , dtype=torch.float16)
            self.noise.data = self.noise.clamp(-self.eps, self.eps)

        self.noise.requires_grad = True

        with torch.autocast(self.device):  # , dtype=torch.float16):
            for _ in range(self.tnt_steps):
                # x_aug = x + torch.clamp(self.noise, 0, 1)[None, ...]
                x_aug = x + self.noise[None, ...].clamp(0, 1)

                image_features = self.model.encode_image(x_aug, normalize=True)
                logits = self.logit_scale * image_features @ self.text_features.t()

                # Select top-k logits
                top_logits, top_idx = self.select_confident_samples(
                    logits, top=self.top_k
                )
                top_features = image_features[top_idx]

                # Entropy loss
                prob = F.softmax(top_logits, dim=1).mean(dim=0)
                entropy_loss = -(prob * prob.log()).sum()

                # Inter-view consistency loss
                pairwise_dist = torch.cdist(top_features, top_features, p=2)
                inter_view_loss = pairwise_dist.sum()

                # Total loss
                loss = self.alpha * entropy_loss + self.beta * inter_view_loss
                loss.backward()

                # Update noise
                with torch.no_grad():
                    grad = self.noise.grad
                    self.noise -= self.lr * grad.sign()
                    self.noise.clamp_(-self.eps, self.eps)
                    self.noise.requires_grad = True
                    self.noise.grad = None

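        with torch.no_grad(), torch.autocast(self.device):
            # Predict on the original image only (the last view). The slice has to
            # apply to the noisy batch: slicing the broadcast noise is a no-op.
            x_aug = (x + self.noise[None, ...].clamp(0, 1))[-1:]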
            image_features = self.model.encode_image(x_aug, normalize=True)
            logits = self.logit_scale * image_features @ self.text_features.t()
            probs = F.softmax(logits / self.temperature, dim=1).mean(dim=0)
            pred_class = int(probs.argmax().item())

        return pred_class

### Running


In [None]:
clip_model, _, _ = open_clip.create_model_and_transforms(
    # model_name="ViT-B-32", pretrained="datacomp_xl_s13b_b90k", device=device#, force_quick_gelu=True
    model_name="ViT-B-16",
    pretrained="openai",
    device=DEVICE,
    force_quick_gelu=True,
)
clip_model.eval()  # type:ignore

# Freeze all CLIP parameters: only the noise is optimized at test time
for param in clip_model.parameters():  # type:ignore
    param.requires_grad = False

# Wrap the frozen CLIP model in a TNT instance
wrapper_clip = TNT(
    clip_model,  # type:ignore
    class_labels=dataset.class_code_to_label,
    device=DEVICE,
    tnt_steps=1,  # type:ignore
).to(DEVICE)


bench(wrapper_clip, dataloader, DEVICE, reduce=None, comment="tnt", visualize=False)

## E. TNT + Top 1

It's the same as TNT, but with the top 1 + original image logits as final prediction.

We expect this to be slightly more accurate than TNT.

**Diff**:

<blockquote>

```python
with torch.no_grad(), torch.autocast(self.device):
    x_aug = x + self.noise[None, ...].clamp(0, 1)
    image_features = self.model.encode_image(x_aug, normalize=True)
    logits = self.logit_scale * image_features @ self.text_features.t()

    selected_logits, _ = self.select_confident_samples(
        logits[:-1], top=self.top_k
    )
    final_logits = torch.cat((selected_logits, logits[-1:]), dim=0)
    probs = F.softmax(final_logits / self.temperature, dim=1).mean(dim=0)
    pred_class = int(probs.argmax().item())
```

</blockquote>


### Model


In [None]:
class TNTTop1(TNT):
    def __init__(
        self,
        model: nn.Module,
        class_labels: dict,
        prompt: str = "a photo of a {}",
        device: str = "cuda",
        tnt_steps: int = 3,
        top_k: float = 0.1,
        epsilon: float = 1 / 255,
        lr: float = 1e-3,
        alpha: float = 1.0,
        beta: float = 1.0,
        temperature: float = 7e-3,
    ):
        super().__init__(
            model=model,
            class_labels=class_labels,
            prompt=prompt,
            device=device,
            tnt_steps=tnt_steps,
            top_k=top_k,
            epsilon=epsilon,
            lr=lr,
            alpha=alpha,
            beta=beta,
            temperature=temperature,
        )

    def forward(self, x: torch.Tensor) -> int:
        x = x.to(self.device)

        if self.noise is None:
            self.noise = torch.randn_like(
                x[0], requires_grad=True, device=self.device
            )  # , dtype=torch.float16)
            self.noise.data = self.noise.clamp(-self.eps, self.eps)

        self.noise.requires_grad = True

        with torch.autocast(self.device):  # , dtype=torch.float16):
            for _ in range(self.tnt_steps):
                # x_aug = x + torch.clamp(self.noise, 0, 1)[None, ...]
                x_aug = x + self.noise[None, ...].clamp(0, 1)

                image_features = self.model.encode_image(x_aug, normalize=True)
                logits = self.logit_scale * image_features @ self.text_features.t()

                # Select top-k logits
                top_logits, top_idx = self.select_confident_samples(
                    logits, top=self.top_k
                )
                top_features = image_features[top_idx]

                # Entropy loss
                prob = F.softmax(top_logits, dim=1).mean(dim=0)
                entropy_loss = -(prob * prob.log()).sum()

                # Inter-view consistency loss
                pairwise_dist = torch.cdist(top_features, top_features, p=2)
                inter_view_loss = pairwise_dist.sum()

                # Total loss
                loss = self.alpha * entropy_loss + self.beta * inter_view_loss
                loss.backward()

                # Update noise
                with torch.no_grad():
                    grad = self.noise.grad
                    self.noise -= self.lr * grad.sign()
                    self.noise.clamp_(-self.eps, self.eps)
                    self.noise.requires_grad = True
                    self.noise.grad = None

        with torch.no_grad(), torch.autocast(self.device):
            x_aug = x + self.noise[None, ...].clamp(0, 1)
            image_features = self.model.encode_image(x_aug, normalize=True)
            logits = self.logit_scale * image_features @ self.text_features.t()

            selected_logits, _ = self.select_confident_samples(
                logits[:-1], top=self.top_k
            )
            final_logits = torch.cat((selected_logits, logits[-1:]), dim=0)
            probs = F.softmax(final_logits / self.temperature, dim=1).mean(dim=0)
            pred_class = int(probs.argmax().item())

        return pred_class

### Running


In [None]:
clip_model, _, _ = open_clip.create_model_and_transforms(
    model_name="ViT-B-16",
    pretrained="openai",
    device=DEVICE,
    force_quick_gelu=True,
)
clip_model.eval()  # type:ignore

# Freeze all CLIP parameters: only the noise is optimized at test time
for param in clip_model.parameters():  # type:ignore
    param.requires_grad = False

# Wrap the frozen CLIP model in a TNTTop1 instance
wrapper_clip = TNTTop1(
    clip_model,
    class_labels=dataset.class_code_to_label,
    device=DEVICE,
    tnt_steps=1,  # type:ignore
).to(DEVICE)


bench(
    wrapper_clip, dataloader, DEVICE, reduce=None, comment="tnt-top1", visualize=False
)

## F. TPS

TODO: add what we expect


## G. FILM

TODO: add what we expect


# 5. Thoughts and Conclusion

TODO: discuss the pros and cons of each method.

TODO: plot accuracy/latency.

TODO: conclude that using backprop is not worth it (cite "Frustratingly Easy").


# 6. Future Work

- AI-based augmentation (e.g. trying to optimize the random crops)
- do something silly, like merging TPT + TNT + Frustratingly Easy


---
# References
---
