In [2]:
# Utils

from matplotlib import pyplot as plt
import torchvision.transforms.functional as TF


def plot_images(images, cols=8):
    """Show a batch of images on a grid of matplotlib subplots.

    Args:
        images: Batch indexed as ``images[i]``; each item is permuted with
            ``(1, 2, 0)`` before plotting, i.e. it is expected channels-first.
        cols: Number of grid columns (default 8). Rows are derived from the
            batch size by ceiling division.

    Raises:
        ValueError: If ``cols`` is smaller than 1.

    Float images are clipped to [0, 1]; other dtypes are passed to ``imshow``
    unchanged. An empty batch draws nothing.
    """
    n_images = images.shape[0]
    if n_images == 0:
        # A zero-row grid would ask matplotlib for a zero-height figure.
        return
    if cols < 1:
        raise ValueError(f"cols must be >= 1, got {cols}")

    rows = (n_images + cols - 1) // cols  # ceiling division

    plt.figure(figsize=(cols * 2, rows * 2))
    for i in range(n_images):
        plt.subplot(rows, cols, i + 1)

        # detach() lets tensors that track gradients be converted to NumPy.
        img = images[i].detach().permute(1, 2, 0).cpu().numpy()
        if img.dtype.kind == "f":  # float
            img = img.clip(0, 1)  # ensure in [0, 1]

        plt.imshow(img)
        plt.axis("off")

    plt.tight_layout()
    plt.show()

In [3]:
import torch
import torch.nn as nn

import numpy as np

# One seed shared by every random number generator the notebook draws from.
SEED = 42

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
# AugMixKornia picks its op chains and Dirichlet mixing weights through
# np.random, so seeding torch alone does not make the augmentations repeatable.
np.random.seed(SEED)

import torchvision
import time
from tqdm import tqdm

# One ImageNet-A sample read at notebook scope for ad-hoc experiments.
_sample_image_path = "datasets/imagenet-a/n01641577/0.038738_agama _ newt_0.7465035.jpg"
image = torchvision.io.read_image(_sample_image_path)

# Experiment sizes. Neither name is referenced by the cells visible below;
# presumably they feed timing experiments elsewhere — TODO confirm.
n_augmentations = 63
n_times = 100

In [4]:
from typing import Optional, List

import numpy as np
import kornia
import kornia.augmentation as K
import kornia.enhance as Ke


class AugMixKornia(nn.Module):
    """AugMix-style augmentation assembled from Kornia operations.

    ``mixture_width`` randomly composed augmentation chains are run on the
    batch, blended with Dirichlet-sampled weights, and the blend is then mixed
    with the clean input using a second Dirichlet draw. Chain composition and
    both Dirichlet draws use the global ``np.random`` state.
    """

    def __init__(
        self,
        severity: int = 3,
        width: int = 3,
        depth: int = -1,
        alpha: float = 1.0,
        mixture_width: int = 3,
        chain_depth: int = 3,
        all_ops: bool = True,
        device: Optional[str] = None,
    ):
        """
        Args:
            severity: Strength knob that scales the magnitude of every op.
            width: Unused; accepted only for signature compatibility.
            depth: When > 0 it overrides ``chain_depth``; any other value
                leaves ``chain_depth`` in effect (no random depth is sampled).
            alpha: Concentration parameter of both Dirichlet draws.
            mixture_width: Number of augmentation chains that are blended.
            chain_depth: Number of ops per chain when ``depth`` <= 0.
            all_ops: Use the full op list (True) or the reduced three-op list.
            device: Target device; defaults to CUDA when available, else CPU.
        """
        super().__init__()

        self.severity = severity
        self.alpha = alpha
        self.mixture_width = mixture_width
        self.chain_depth = chain_depth if depth <= 0 else depth
        self.all_ops = all_ops
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        # Define augmentation operations
        self.augmentations = self._get_augmentations()

    def _get_augmentations(self) -> List[nn.Module]:
        """Build the pool of ops that chains are sampled from.

        Every entry is a callable applied as ``op(tensor)``; ``Ke.equalize`` is
        a plain function rather than a module.
        """
        severity_factor = self.severity / 10.0

        if self.all_ops:
            # Full op pool; the labels name the AugMix slot each op fills.
            return [
                # AutoContrast slot (brightness/contrast jitter)
                K.ColorJitter(
                    brightness=0.1 * self.severity, contrast=0.1 * self.severity, p=1.0
                ),
                # Equalize
                Ke.equalize,
                # Posterize
                K.RandomPosterize(bits=max(1, 8 - self.severity), p=1.0),
                # Rotate
                K.RandomRotation(
                    degrees=(-30 * severity_factor, 30 * severity_factor), p=1.0
                ),
                # Solarize
                K.RandomSolarize(
                    thresholds=0.5, additions=(0.0, 0.1 * self.severity), p=1.0
                ),
                # Shear
                K.RandomAffine(
                    degrees=0,
                    shear=(-15 * severity_factor, 15 * severity_factor),
                    p=1.0,
                ),
                # Translate
                K.RandomAffine(
                    degrees=0,
                    translate=(0.1 * severity_factor, 0.1 * severity_factor),
                    p=1.0,
                ),
                # ColorJitter
                K.ColorJitter(
                    brightness=0.1 * self.severity,
                    contrast=0.1 * self.severity,
                    saturation=0.1 * self.severity,
                    hue=0.1,
                    p=1.0,
                ),
            ]
        else:
            # Simplified version
            return [
                K.ColorJitter(
                    brightness=0.1 * self.severity, contrast=0.1 * self.severity, p=1.0
                ),
                Ke.equalize,
                K.RandomAffine(
                    degrees=(-15 * severity_factor, 15 * severity_factor), p=1.0
                ),
            ]

    def _apply_augmentation_chain(self, image: torch.Tensor) -> torch.Tensor:
        """
        Run one randomly sampled chain of ``chain_depth`` ops on ``image``.

        Ops are drawn with replacement and the same sequence is applied to the
        whole input.

        Args:
            image: Input tensor, (C, H, W) or (B, C, H, W)

        Returns:
            Augmented tensor with the same number of dimensions as the input
        """
        # Randomly select augmentations for this chain
        op_indices = np.random.choice(
            len(self.augmentations), size=self.chain_depth, replace=True
        )

        augmented = image
        for op_idx in op_indices:
            augmented = self.augmentations[op_idx](augmented)

        # Drop a leading batch axis only if an op added one to an unbatched
        # input; a real batch of size 1 must keep its batch dimension.
        if image.dim() == 3 and augmented.dim() == 4:
            augmented = augmented.squeeze(0)
        return augmented

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Apply AugMix to a batch of images.

        Args:
            images: Input batch of images (B, C, H, W) or a single (C, H, W)

        Returns:
            Augmented tensor on ``self.device`` with the same shape as the
            input: a (C, H, W) input yields (C, H, W), and a (B, C, H, W)
            input yields (B, C, H, W), including B == 1.
        """
        # Input validation
        if not isinstance(images, torch.Tensor):
            images = K.image_to_tensor(images)

        # Remember whether the caller passed a single image so that only that
        # case is un-batched again on the way out.
        was_unbatched = images.dim() == 3
        if was_unbatched:
            images = images.unsqueeze(0)

        # Tensor.to returns the tensor itself when it is already on the target.
        images = images.to(self.device)

        batch_size = images.shape[0]

        # Sample mixing weights from Dirichlet distribution
        weights = (
            torch.from_numpy(
                np.random.dirichlet([self.alpha] * self.mixture_width, size=batch_size)
            )
            .float()
            .to(self.device)
        )  # Shape (B, mixture_width)

        # Sample weights for mixing with original
        mix_weights = (
            torch.from_numpy(
                np.random.dirichlet([self.alpha, self.alpha], size=batch_size)
            )
            .float()
            .to(self.device)
        )  # Shape (B, 2)

        # Pre-allocate one slot per mixture component.
        augmented = torch.empty(
            (self.mixture_width, batch_size, *images.shape[1:]), device=self.device
        )

        for i in range(self.mixture_width):
            augmented[i] = self._apply_augmentation_chain(images)

        # Weighted sum of augmented versions (already on self.device)
        mixed = torch.einsum("mbchw,bm->bchw", augmented, weights)

        # Final mix with original image
        result = (
            mix_weights[:, 0:1, None, None] * images
            + mix_weights[:, 1:2, None, None] * mixed
        )

        return result.squeeze(0) if was_unbatched else result

In [5]:
import kornia.constants


# Shared interpolation mode for both geometric steps.
_BICUBIC = kornia.constants.Resample.BICUBIC
# Per-channel normalisation statistics.
# NOTE(review): these look like CLIP's preprocessing constants — confirm.
_NORM_MEAN = torch.tensor([0.48145466, 0.4578275, 0.40821073])
_NORM_STD = torch.tensor([0.26862954, 0.26130258, 0.27577711])

# Model-side preprocessing: resize the shorter side to 224, take a 224x224
# centre crop, then normalise per channel.
kornia_preprocess = nn.Sequential(
    K.SmallestMaxSize(224, resample=_BICUBIC),
    K.CenterCrop(size=(224, 224), resample=_BICUBIC),
    kornia.enhance.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
)

# AugMix augmenter with all default settings.
kornia_augmix = AugMixKornia()

In [6]:
from torch.utils.data import Dataset
import os


class ImageNetADataset(Dataset):
    """
    Custom Dataset class for the ImageNet-A dataset.

    Set the `transform` parameter so that images work with your model.
    Example usage:
    ```python
        model, transform = clip.load("ViT-B/32")
        dataset = ImageNetADataset(<path>, transform=transform)
    ```
    ----

    The dataset is organized into subdirectories, each named with a class code (e.g., "n01614925").
    Each subdirectory contains images belonging to that class. The dataset also includes a README.txt file that maps class codes to human-readable names.

    The dataset is expected to be structured as follows:
    ```
    datasets/imagenet-a/
        n01440764/
            image1.jpg
            image2.jpg
            ...
        n01614925/
            image1.jpg
            image2.jpg
            ...
        ...
        README.txt
    ```

    Label indices follow the sorted order of the class directories that also
    appear in README.txt. Samples are listed class by class, with file names
    sorted inside each class, so the sample order is deterministic.
    """

    def __init__(self, root_dir="datasets/imagenet-a", transform=None):
        """
        Args:
            root_dir (str): Root directory of the ImageNet-A dataset.
            transform (callable, optional): Optional transform to be applied on a sample.

        Raises:
            FileNotFoundError: If `root_dir` does not exist.
        """
        self.root_dir = root_dir
        self.transform = transform

        self.__download_if_needed()

        # Load mapping from class codes (e.g., "n01614925") to human-readable names
        readme_path = os.path.join(root_dir, "README.txt")
        self.class_code_to_label = self._load_class_mapping(readme_path)

        # Keep only directories whose name is a known class code
        self.class_codes = sorted(
            [
                d
                for d in os.listdir(root_dir)
                if os.path.isdir(os.path.join(root_dir, d))
                and d in self.class_code_to_label
            ]
        )

        # Map class codes to indices
        self.class_code_to_idx = {
            code: idx for idx, code in enumerate(self.class_codes)
        }

        # Collect all image file paths and corresponding labels
        self.samples = self._gather_samples()

        # Inverse mapping from label index to class name
        self.idx_to_label = {
            idx: self.class_code_to_label[code]
            for code, idx in self.class_code_to_idx.items()
        }

    def __download_if_needed(self):
        """
        Verify that the dataset directory exists.

        Nothing is downloaded: a missing directory raises FileNotFoundError
        asking the user to fetch the dataset first.
        """
        if not os.path.exists(self.root_dir):
            raise FileNotFoundError(
                f"Dataset not found at {self.root_dir}. Please download it first."
            )

    def _load_class_mapping(self, readme_path):
        """
        Load class code to human-readable name mapping from README.txt.
        Skips header lines and parses lines in format: 'n01440764 tench'.

        Lines without a space-separated code and name are ignored.
        """
        mapping = {}
        with open(readme_path, "r") as file:
            lines = file.readlines()[12:]  # Skip first 12 header lines
            for line in lines:
                parts = line.strip().split(" ", 1)
                if len(parts) == 2:
                    code, name = parts
                    mapping[code] = name
        return mapping

    def _gather_samples(self):
        """
        Collect `(image_path, label_index)` pairs for every class directory.

        Only .jpg/.jpeg/.png files (case-insensitive) are kept.
        """
        samples = []
        for class_code in self.class_codes:
            class_dir = os.path.join(self.root_dir, class_code)
            label = self.class_code_to_idx[class_code]
            # os.listdir returns entries in arbitrary order; sort them so that
            # sample indices are stable across runs and machines.
            for filename in sorted(os.listdir(class_dir)):
                if filename.lower().endswith((".jpg", ".jpeg", ".png")):
                    image_path = os.path.join(class_dir, filename)
                    samples.append((image_path, label))
        return samples

    def __len__(self):
        """Number of collected image samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Load one image and return it with its label index.

        The image is read from disk, converted to float and divided by 255,
        coerced to three channels, and passed through `transform` if one is set.

        Returns:
            tuple: `(image, label)` where `image` is the (optionally
            transformed) image and `label` is `torch.tensor(label_index)`.

        Raises:
            ValueError: If the decoded image has a channel count other than 1, 3 or 4.
        """
        image_path, label = self.samples[idx]

        image = torchvision.io.read_image(image_path).float() / 255.0

        if image.shape[0] == 1:  # Grayscale → RGB
            image = image.repeat(3, 1, 1)

        elif image.shape[0] == 4:  # RGBA → RGB
            image = image[:3, :, :]

        elif image.shape[0] != 3:
            raise ValueError(f"Unsupported number of channels: {image.shape[0]}")

        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(label)

    def get_class_name(self, idx):
        """
        Get human-readable class name for a given label index.
        """
        return self.idx_to_label[idx]

In [7]:
class ImageTransform(nn.Module):
    """Turn one image into the batch of views fed to the model.

    With a ``custom_transform`` the image is repeated ``n_views`` times,
    augmented, the un-augmented image is appended as the LAST element, and the
    whole stack goes through ``model_transform``. Without one, only
    ``model_transform`` is applied to the image.
    """

    def __init__(self, model_transform, custom_transform=None, n_views=63, device=None):
        """
        Args:
            model_transform: Callable applied last (model preprocessing).
            custom_transform: Optional callable producing the augmented views.
            n_views: Number of augmented copies created per image.
            device: Target device; defaults to CUDA when available, else CPU.
        """
        super().__init__()
        self.model_transform = model_transform
        self.custom_transform = custom_transform
        self.n_views = n_views
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        self.eval()

    # Implemented as forward (not __call__) so the instance goes through the
    # regular nn.Module call path; calling the instance works exactly as before.
    @torch.no_grad()
    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """
        Apply the custom transform (if any) and the model transform to the image.

        Args:
            image: Single image tensor; it is repeated along a new leading
                axis, so it is expected without a batch dimension.

        Returns:
            ``model_transform`` output: ``n_views + 1`` stacked views when a
            custom transform is set, otherwise the transformed single image.
        """
        image = image.to(self.device)

        if self.custom_transform is None:
            return self.model_transform(image)

        views = image.repeat(self.n_views, 1, 1, 1)
        views = self.custom_transform(views)
        # Some transforms drop the batch axis for a single view; restore it so
        # the concatenation with the original image below is well-formed.
        if views.dim() == image.dim():
            views = views.unsqueeze(0)
        views = torch.cat([views, image.unsqueeze(0)], dim=0)

        return self.model_transform(views)

In [8]:
from torch.utils.data import DataLoader


def ResnetA(
    augmenter: ImageTransform,
    root_dir="datasets/imagenet-a",
):
    """
    Build the ImageNet-A dataset and a loader that yields one sample at a time.

    The loader is fixed to batch size 1, no shuffling and no worker processes.

    Args:
        augmenter (callable): Transform handed to the dataset for every image.
        root_dir (str): Root directory of the ImageNet-A dataset.

    Returns:
        dataloader (DataLoader): DataLoader for the ImageNet-A dataset.
        dataset (ImageNetADataset): The underlying dataset object.
    """

    def _unwrap_single(batch):
        """Return the only sample of the batch, giving the image a leading axis if it has 3 dims."""
        sample = batch[0]
        views, target = sample[0], sample[1]
        if views.ndim == 3:
            views = views.unsqueeze(0)
        return views, target

    dataset = ImageNetADataset(transform=augmenter, root_dir=root_dir)
    loader = DataLoader(
        dataset,
        collate_fn=_unwrap_single,
        num_workers=0,
        shuffle=False,
        batch_size=1,
    )

    return loader, dataset

# So

1. base model
2. tpt model
3. prompt learning (tpt)


In [9]:
# Per sample: 63 AugMix views plus the original image, all run through the
# model preprocessing pipeline.
augmenter = ImageTransform(
    custom_transform=kornia_augmix,
    model_transform=kornia_preprocess,
    n_views=63,
)

dataloader, dataset = ResnetA(augmenter=augmenter)

In [10]:
from copy import deepcopy
from dataclasses import dataclass
import clip
import clip.model


@dataclass(frozen=True)
class CLIPModels:
    """Named CLIP architecture identifiers; each field's default is the string passed to `clip.load`."""

    ViTB32: str = "ViT-B/32"
    # You can add more, but the `kornia_preprocess` should be modified accordingly
    # ViTB16: str = "ViT-B/16"
    # RN50: str = "RN50"


class TPTPromptLearner(nn.Module):
    """Learnable prompt context around a fixed class-name slot.

    The base prompt is split at ``[CLS]`` into a prefix and a suffix. Their
    token embeddings become two trainable parameters, while the class-name
    embeddings stay frozen buffers. ``forward`` assembles, for every class,
    ``SOT + prefix + class name + suffix + EOT + padding`` in embedding space.
    """

    def __init__(
        self,
        class_names: List[str],
        clip_model: clip.model.CLIP,
        base_prompt: str = "a photo of a [CLS]",
        device=None,
    ):
        """
        Args:
            class_names: Class names, one prompt is built per entry. Any
                iterable of strings is accepted and stored as a list.
            clip_model: Loaded CLIP model; its token embedding is reused and frozen.
            base_prompt: Prompt template containing exactly one ``[CLS]`` placeholder.
            device: Target device; defaults to CUDA when available, else CPU.

        Raises:
            ValueError: If ``base_prompt`` does not contain exactly one ``[CLS]``.
        """
        super().__init__()
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        # Materialise once: callers may pass e.g. dict.values(), and the names
        # are iterated several times and measured with len().
        self.class_names = list(class_names)
        self.dtype = clip_model.visual.conv1.weight.dtype
        self.token_embedding = clip_model.token_embedding
        self.token_embedding.requires_grad_(False)

        self.__init_ctx_from_prompt(base_prompt=base_prompt)

    def __init_ctx_from_prompt(self, base_prompt: str) -> None:
        """
        Initialize the context embeddings from the base prompt.

        The prompt is split into prefix and suffix using [CLS] as the
        separator, so the class-name position never has to be tracked by hand.
        Prefix and suffix are trained as two separate parameters and
        concatenated around the class name in ``forward``.
        """
        placeholder = "[CLS]"
        if base_prompt.count(placeholder) != 1:
            # More than one placeholder would make prefix/suffix disagree with
            # the fully tokenized prompt used to locate the EOT position.
            raise ValueError(
                f"base_prompt must contain exactly one {placeholder} placeholder, "
                f"got: {base_prompt!r}"
            )
        prompt_prefix, _, prompt_suffix = base_prompt.partition(placeholder)

        # Special token ids, shaped [1, 1] so they embed to [1, 1, dim].
        # NOTE(review): id 0 is treated as padding; any real token with id 0
        # would also be stripped by the masks below — confirm against the vocab.
        c_token_sot = torch.tensor([[49406]]).to(self.device)  # SOT
        c_token_eot = torch.tensor([[49407]]).to(self.device)  # EOT
        c_token_pad = torch.tensor([[0]]).to(self.device)  # PAD

        # Tokenize prefix and suffix
        tokenized_prefix = clip.tokenize(prompt_prefix).to(self.device)
        tokenized_suffix = clip.tokenize(prompt_suffix).to(self.device)

        # Keep only the content tokens (drop SOT, EOT and PAD)
        c_tokenized_prefix = tokenized_prefix[
            (tokenized_prefix != c_token_sot)
            & (tokenized_prefix != c_token_eot)
            & (tokenized_prefix != c_token_pad)
        ].to(self.device)
        c_tokenized_suffix = tokenized_suffix[
            (tokenized_suffix != c_token_sot)
            & (tokenized_suffix != c_token_eot)
            & (tokenized_suffix != c_token_pad)
        ].to(self.device)

        tokenized_class_names = clip.tokenize(self.class_names).to(self.device)

        # Tokenized full prompt per class, computed once. The text encoder uses
        # it to find the EOT position of each prompt.
        self.tokenized_initial_full_prompt = clip.tokenize(
            [base_prompt.replace("[CLS]", c) for c in self.class_names]
        )

        # Get base embeddings
        with torch.no_grad():
            self.embedded_sot = self.token_embedding(c_token_sot)
            self.embedded_eot = self.token_embedding(c_token_eot)
            self.embedded_pad = self.token_embedding(c_token_pad)
            embedded_prefix = self.token_embedding(c_tokenized_prefix)
            embedded_suffix = self.token_embedding(c_tokenized_suffix)
            embedded_class_names = self.token_embedding(tokenized_class_names)
            # Sequence length of a tokenized prompt (the model's context size)
            self.embedded_max_len = embedded_class_names.shape[1]

        # Mask selecting the content tokens of every class name
        mask = (
            (tokenized_class_names != c_token_sot)
            & (tokenized_class_names != c_token_eot)
            & (tokenized_class_names != c_token_pad)
        )

        # Boolean indexing per class keeps the embedding axis intact
        # (masked_select would flatten). Each entry: [1, num_valid_tokens, dim]
        clean_embeddings = [
            embedded_class_names[i][mask[i]].unsqueeze(0)
            for i in range(embedded_class_names.shape[0])
        ]

        self.embedded_class_names = clean_embeddings

        for i, embed in enumerate(clean_embeddings):
            self.register_buffer(f"class_embed_{i}", embed)

        # Snapshot the initial state for reset(), then expose prefix/suffix as
        # learnable parameters (attribute assignment registers them).
        self.init_state_prefix = embedded_prefix.detach().clone()
        self.init_state_suffix = embedded_suffix.detach().clone()
        self.embedded_prefix = nn.Parameter(embedded_prefix)
        self.embedded_suffix = nn.Parameter(embedded_suffix)

    def forward(self) -> torch.Tensor:
        """
        Assemble the embedded prompt of every class.

        Returns:
            Tensor of shape [num_classes, embedded_max_len, dim].
        """
        # Loop-invariant pieces, each with a leading batch axis of 1
        prefix = self.embedded_prefix.unsqueeze(0)
        suffix = self.embedded_suffix.unsqueeze(0)

        prompts = []
        for i in range(len(self.class_names)):
            class_embed = getattr(self, f"class_embed_{i}")

            # Fill the remaining positions with PAD embeddings
            padding_size = (
                self.embedded_max_len
                - self.embedded_prefix.shape[0]
                - class_embed.shape[1]
                - self.embedded_suffix.shape[0]
            ) - 2  # -2 for SOT and EOT

            prompt = torch.cat(
                (
                    self.embedded_sot,
                    prefix,
                    class_embed,
                    suffix,
                    self.embedded_eot,
                    self.embedded_pad.repeat(1, padding_size, 1),
                ),
                dim=1,
            )

            prompts.append(prompt)

        return torch.cat(prompts, dim=0)

    def reset(self) -> None:
        """Restore prefix and suffix to the values captured at construction."""
        with torch.no_grad():
            self.embedded_prefix.copy_(self.init_state_prefix)
            self.embedded_suffix.copy_(self.init_state_suffix)


class TPTModel(nn.Module):
    """Frozen CLIP encoders combined with a learnable prompt.

    All CLIP weights referenced here are frozen; the only trainable
    parameters live in ``self.prompt_learner``.
    """

    def __init__(
        self, class_names: List[str], arch: CLIPModels = CLIPModels.ViTB32, device=None
    ):
        """
        Args:
            class_names: Class names used to build one prompt per class.
            arch: CLIP architecture identifier passed to ``clip.load``.
            device: Target device; defaults to CUDA when available, else CPU.
        """
        super().__init__()
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        clip_model: clip.model.CLIP
        clip_model, _ = clip.load(arch, device=self.device)

        self.dtype = clip_model.visual.conv1.weight.dtype
        self.image_encoder = clip_model.visual

        # Plain tensor copy of the logit scale (not registered as a parameter)
        self.logit_scale = clip_model.logit_scale.data
        self.positional_embedding = clip_model.positional_embedding
        self.transformer = clip_model.transformer
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection

        # Freeze everything registered so far; the prompt learner is created
        # afterwards so its parameters stay trainable.
        for param in self.parameters():
            param.requires_grad_(False)

        # Hand over the resolved device so the learner's tensors are created
        # on the same device as the CLIP weights.
        self.prompt_learner = TPTPromptLearner(
            class_names=class_names, clip_model=clip_model, device=self.device
        )

    def __encode_text(
        self, tokenized_prompt: torch.Tensor, embedded_prompt: torch.Tensor
    ) -> torch.Tensor:
        """
        Encode already-embedded prompts with CLIP's text transformer.

        Args:
            tokenized_prompt: Token ids of the prompts; only used to locate
                the EOT position (argmax along the sequence).
            embedded_prompt: Prompt embeddings, one sequence per class.

        Returns:
            Text features, one row per prompt (not normalized).

        Source: CLIP source code. model.py#L343
        """
        x = embedded_prompt + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = (
            x[torch.arange(x.shape[0]), tokenized_prompt.argmax(dim=-1)]
            @ self.text_projection
        )

        return x

    def encode_image(self, image: torch.Tensor) -> torch.Tensor:
        """
        Encode the image using the CLIP model.

        Args:
            image (torch.Tensor): Input image.

        Returns:
            image_features (torch.Tensor): Normalized encoded image features.
        """
        image_features = self.image_encoder(image.type(self.dtype))
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        return image_features

    def forward(
        self, image: torch.Tensor, is_image: bool = True
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """
        Compute class logits for images or pre-computed image features.

        Args:
            image (torch.Tensor): Input images, or image features when
                ``is_image`` is False.
            is_image (bool): Whether the input is an image. If False, the
                input is used directly as (already normalized) image features.

        Returns:
            ``(logits, image_features)`` when ``is_image`` is True, otherwise
            only ``logits``.
        """
        if is_image:
            # The image encoder is frozen, so no graph is needed for it.
            with torch.no_grad():
                image_features = self.encode_image(image)
        else:
            image_features = image

        embedded_prompt = self.prompt_learner().type(self.dtype)

        txt_features = self.__encode_text(
            self.prompt_learner.tokenized_initial_full_prompt, embedded_prompt
        )
        txt_features = txt_features / txt_features.norm(dim=-1, keepdim=True)

        logit_scale = self.logit_scale.exp()
        logits = logit_scale * image_features @ txt_features.t()

        if is_image:
            return logits, image_features
        else:
            return logits

    def reset(self):
        """
        Reset the prompt learner to its initial state.
        """
        self.prompt_learner.reset()


class TPT(nn.Module):
    """Test-time prompt tuning wrapper around ``TPTModel``.

    Each call tunes the prompt on the given batch of views for ``tpt_steps``
    optimizer steps, predicts on the last view, and then restores the prompt
    and optimizer to their initial state.
    """

    def __init__(
        self,
        class_names: List[str],
        tta_steps: int = 1,
        lr: float = 0.0001,
        arch: CLIPModels = CLIPModels.ViTB32,
        device=None,
    ):
        """
        Args:
            class_names: Class names used to build the prompts.
            tta_steps: Number of tuning steps performed per call.
            lr: Learning rate of the AdamW optimizer.
            arch: CLIP architecture identifier.
            device: Target device; defaults to CUDA when available, else CPU.
        """
        super().__init__()
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.tpt_steps = tta_steps

        model = TPTModel(
            class_names=class_names,
            arch=arch,
            device=self.device,
        )
        self.model = model.to(self.device)

        # Get all trainable parameters (filter by requires_grad)
        trainable_params = [p for p in self.model.parameters() if p.requires_grad]

        # Initialize optimizer with trainable parameters
        self.optim = torch.optim.AdamW(trainable_params, lr=lr)
        self.scaler = torch.cuda.amp.GradScaler()

        # Pristine optimizer state, restored after every call
        self.optim_init = deepcopy(self.optim.state_dict())

    def set_tta_steps(self, tta_steps: int) -> None:
        """
        Set the number of TTA steps.

        Args:
            tta_steps (int): Number of TTA steps.
        """
        self.tpt_steps = tta_steps

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Tune the prompt on ``input`` and predict on its last element.

        Args:
            input: Batch of views of one image; the LAST element is the one
                used for the final prediction.

        Returns:
            Logits for the last view, computed with the tuned prompt.
        """
        # The confident subset is chosen on the first step and reused afterwards
        selected_idx = None
        for _ in range(self.tpt_steps):
            with torch.cuda.amp.autocast():
                logits, image_features = self.model(input)

                # Select the most confident samples
                if selected_idx is not None:
                    logits = logits[selected_idx]
                else:
                    logits, selected_idx = self.__select_confident_samples(logits)

                # Compute the average entropy loss
                loss = self.__avg_entropy_loss(logits)

            self.optim.zero_grad()
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optim)
            self.scaler.update()

        # Actual inference
        with torch.no_grad():
            with torch.cuda.amp.autocast():
                # take only the last image of the input
                input = input[-1].unsqueeze(0)
                logits, _ = self.model(input)

        self.__reset()

        return logits

    def __select_confident_samples(
        self, logits: torch.Tensor, top: float = 0.1
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Selects the top-k samples with the lowest entropy from the logits.

        At least one sample is always kept, so small batches (fewer than
        ``1 / top`` rows) do not produce an empty selection.

        Args:
            logits (torch.Tensor): The logits from the model.
            top (float): The fraction of samples to select.
                For example, if top=0.1, it selects the top 10% of samples.
        Returns:
            torch.Tensor: The selected logits.
            torch.Tensor: The indices of the selected samples.

        [Reference](https://github.com/azshue/TPT/blob/63ecbace79694205d7884e63fdc3137a200f0b0e/tpt_classification.py#L41C5-L41C11)
        """
        batch_entropy = -(logits.softmax(1) * logits.log_softmax(1)).sum(1)
        # An empty selection would make the entropy loss NaN.
        n_keep = max(1, int(batch_entropy.size()[0] * top))
        idx = torch.argsort(batch_entropy, descending=False)[:n_keep]

        return logits[idx], idx

    def __avg_entropy_loss(self, outputs: torch.Tensor) -> torch.Tensor:
        """
        Computes the entropy of the averaged prediction (marginal entropy).
        Args:
            outputs (torch.Tensor): The model's outputs (logits), one row per sample.
        Returns:
            torch.Tensor: The entropy of the sample-averaged distribution.

        [Reference](https://github.com/azshue/TPT/blob/63ecbace79694205d7884e63fdc3137a200f0b0e/tpt_classification.py#L46)
        """
        # Per-sample log-probabilities (equivalent to log_softmax over classes)
        logits = outputs - outputs.logsumexp(dim=-1, keepdim=True)
        # Log of the mean probability across samples, computed in log space
        avg_logits = logits.logsumexp(dim=0) - np.log(logits.shape[0])
        # Clamp -inf so that exp(x) * x stays finite
        min_real = torch.finfo(avg_logits.dtype).min
        avg_logits = torch.clamp(avg_logits, min=min_real)

        return -(avg_logits * torch.exp(avg_logits)).sum(dim=-1)

    def __reset(self) -> None:
        """Reset the prompt learner and the optimizer state (the grad scaler is left untouched)."""
        # 1. Reset prompt embeddings
        self.model.reset()

        # 2. Reset optimizer state
        self.optim.load_state_dict(deepcopy(self.optim_init))


# Build the class names in label-index order so that logit column i matches
# dataset label i. `class_code_to_label.values()` follows README order and may
# contain classes without a directory, which would misalign the two.
tpt_class_names = [
    dataset.get_class_name(idx) for idx in range(len(dataset.class_codes))
]
my_tpt = TPT(class_names=tpt_class_names, tta_steps=0, lr=0.005)

print(f"Total parameters: {sum(p.numel() for p in my_tpt.parameters())}")
print(
    f"Trainable parameters: {sum(p.numel() for p in my_tpt.parameters() if p.requires_grad)}"
)

# for image, label in tqdm(dataloader):
#     l = my_tpt(image)
#     break

Total parameters: 151279360
Trainable parameters: 2048


For TPT, we initialize the prompt as the default hand-crafted one “a photo
of a”, and optimize the corresponding 4 tokens in the text input embedding space based on a single
test image. We augment a single test image 63 times using random resized crops and construct a
batch of 64 images, including the original one. Among the 64 predictions, we select the top 10%
(ρ=0.1) confident samples (lowest 10% in self-entropy) and compute the entropy of the averaged
probability of the selected predictions (i.e., marginal entropy). We optimize the prompt to minimize
the marginal entropy for 1 step, using the AdamW optimizer with a learning rate of 0.005.


In [None]:
import torch.nn.functional as F


def bench(
    model: nn.Module,
    dataloader: DataLoader,
    device: str,
    reduce: int | None = None,
):
    """Benchmark the model on the dataset.

    The model must return logits.
    """

    total = 0
    correct = 0

    start = time.time()

    for image, label in tqdm(dataloader):
        image = image.to(device)
        label = label.to(device)

        # with torch.no_grad():
        logits = model(image)

        # pred_class = logits.argmax(dim=-1)
        marginal_prob = F.softmax(logits, dim=1).mean(0)
        pred_class = marginal_prob.argmax().item()

        total += 1
        correct += int((pred_class == label).max().item())

        if reduce:
            if total > reduce:
                break

    end = time.time()

    accuracy = correct / total
    latency = (end - start) / total  # ms

    print(f"Accuracy: {accuracy * 100:.2f}%")
    print(f"Latency: {latency * 1000:.2f} ms")

    return accuracy, latency

# Switch on a single test-time tuning step, then evaluate on a capped subset.
my_tpt.set_tta_steps(tta_steps=1)

accuracy, latency = bench(
    model=my_tpt,
    dataloader=dataloader,
    device=device,
    reduce=256,
)

  0%|          | 15/7500 [00:14<2:04:21,  1.00it/s]

In [26]:
accuracy

NameError: name 'accuracy' is not defined