In [3]:
import torch
from torch import Tensor
class SoftSort(torch.nn.Module):
    def __init__(self, tau=1.0, hard=False, pow=1.0):
        super(SoftSort, self).__init__()
        self.hard = hard
        self.tau = tau
        self.pow = pow

    def forward(self, scores: Tensor):
        """
        scores: elements to be sorted. Typical shape: batch_size x n
        """
        scores = scores.unsqueeze(-1)
        sorted = scores.sort(descending=True, dim=1)[0]
        pairwise_diff = (scores.transpose(1, 2) - sorted).abs().pow(self.pow).neg() / self.tau
        P_hat = pairwise_diff.softmax(-1)

        if self.hard:
            P = torch.zeros_like(P_hat, device=P_hat.device)
            P.scatter_(-1, P_hat.topk(1, -1)[1], value=1)
            P_hat = (P - P_hat).detach() + P_hat
        return P_hat

In [2]:
import numpy
ss = SoftSort(hard=True)

In [8]:
x = torch.randn(1,3)
x

tensor([[ 1.0710,  1.0674, -0.1502]])

In [10]:
value = torch.tensor([[3.0,2.0,5.0]])
mat = ss(-value)
mat

tensor([[[0., 1., 0.],
         [1., 0., 0.],
         [0., 0., 1.]]])

In [11]:
torch.einsum('blk, bl -> bk', mat, x)

tensor([[ 1.0674,  1.0710, -0.1502]])

In [18]:
torch.einsum('blk, bl -> bk', mat, value)

tensor([[1., 2., 5.]], dtype=torch.float64)

In [19]:
import torch

# Assuming dot_prod is your input tensor with shape [128, 49]
dot_prod = torch.randn(128, 49)  # Example input; replace with your actual tensor

# Create an instance of SoftSort
soft_sort = SoftSort(tau=1.0, hard=True)

# Use SoftSort to rearrange the values based on dot_prod
rearranged_values = soft_sort(dot_prod)

# The rearranged_values will have the same shape as dot_prod
print(rearranged_values.shape)  # Sh

torch.Size([128, 49, 49])


In [22]:
torch.einsum('blk, bl -> bk', rearranged_values, dot_prod)[0]

tensor([-1.1124,  0.3074, -1.1164,  0.3205, -1.1716,  2.3287, -0.8495,  1.1453,
         0.5923, -1.5569, -0.1205, -0.8293, -0.6210,  0.5117, -0.3681,  0.2268,
         1.3512,  0.9039, -1.2890,  0.0369, -0.4759,  1.0174, -0.2304, -1.6933,
        -1.0110,  0.5619, -1.6541, -1.0059, -0.8064, -0.7143,  0.1221,  0.5179,
         1.4403, -0.7206,  1.5478,  0.7394,  0.2535,  0.7082, -2.4703,  1.3706,
        -0.3273, -0.7106, -1.3301, -2.0515, -0.0367, -1.3309,  0.0812,  1.7324,
        -1.0255])

In [23]:
torch.einsum('blk, bl -> bk', rearranged_values, dot_prod)[1]

tensor([-0.6355,  1.7180, -0.5275, -0.8325, -0.7621,  0.5230,  1.1946,  0.2854,
        -0.3085, -0.0471, -2.4254, -0.4359, -0.7054, -1.3043, -0.3422, -0.0800,
         1.1700, -1.1339, -1.0904, -1.1093,  0.3365,  0.6108, -2.7919, -1.8271,
        -0.0098, -0.3422, -0.5166,  1.1289,  1.1492, -1.1392,  2.1992, -1.3139,
         2.1645,  0.2463,  1.5276, -2.4445, -1.2331,  0.8298,  0.0398,  0.0240,
         0.3940,  1.8558,  0.0504,  0.7366, -2.2436, -2.2673, -0.2939, -0.8592,
         1.4809])

In [29]:
dot_prod = torch.randn(3, 4)  # Example input; replace with your actual tensor
dot_prod

tensor([[-0.3010, -0.0103,  0.2229,  0.9927],
        [ 1.4477, -0.4305,  0.3110, -0.3709],
        [ 0.4433, -2.0150,  0.9629, -0.3408]])

In [34]:
import torch

# Assuming dot_prod is your input tensor with shape [128, 49]


# Create an instance of SoftSort
soft_sort = SoftSort(tau=1.0, hard=True)

# Use SoftSort to rearrange the values based on dot_prod
rearranged_values = soft_sort(-1 * dot_prod)

# The rearranged_values will have the same shape as dot_prod
print(rearranged_values.shape)  # Should be [128, 49]

torch.Size([3, 4, 4])


In [35]:
rearranged_values

tensor([[[1., 0., 0., 0.],
         [0., 1., 0., 0.],
         [0., 0., 1., 0.],
         [0., 0., 0., 1.]],

        [[0., 1., 0., 0.],
         [0., 0., 0., 1.],
         [0., 0., 1., 0.],
         [1., 0., 0., 0.]],

        [[0., 1., 0., 0.],
         [0., 0., 0., 1.],
         [1., 0., 0., 0.],
         [0., 0., 1., 0.]]])

In [38]:
dot_prod

tensor([[-0.3010, -0.0103,  0.2229,  0.9927],
        [ 1.4477, -0.4305,  0.3110, -0.3709],
        [ 0.4433, -2.0150,  0.9629, -0.3408]])

In [33]:
torch.einsum('blk, bl -> bk', rearranged_values, dot_prod)[1]

tensor([ 1.4477, -0.3709, -0.4305,  0.3110])

In [36]:
torch.einsum('blk, bl -> bk', rearranged_values, dot_prod)[1]

tensor([-0.3709,  1.4477,  0.3110, -0.4305])

In [None]:
keys = torch.randn(1,1,3)
x = torch.randn(1,3,3)
dot_prod = torch.matmul(x, keys.transpose(1,2)).squeeze(0) # [1,3]
print('dot_prod = ', dot_prod)
rearrange = torch.einsum('blk,bk->bl', soft_sort(-dot_prod), dot_prod)
print('rearrange = ', rearrange)

x_reordered = torch.gather(x, 1, rearrange.unsqueeze(-1).expand(-1, -1, 3).long())  # [B, N, C]
print('x_reordered = ', x_reordered)

In [39]:
import torch

# Generate random tensors
keys = torch.randn(1, 1, 3)
x = torch.randn(1, 3, 3)

# Calculate the dot product
dot_prod = torch.matmul(x, keys.transpose(1, 2)).squeeze(0)  # [1, 3]
print('dot_prod = ', dot_prod)

# Calculate rearrange using soft_sort
rearranged_values = soft_sort(-dot_prod)  # Assuming soft_sort is defined elsewhere
rearrange = torch.einsum('blk,bk->bl', rearranged_values, dot_prod)  # Adjusted based on the original operation
print('rearrange = ', rearrange)

# Ensure indices are valid for gathering
# Clamp to non-negative and check maximum indices


dot_prod =  tensor([[1.7146],
        [4.3728],
        [1.4453]])
rearrange =  tensor([[1.7146],
        [4.3728],
        [1.4453]])


In [43]:
torch.einsum('blk,bk->bl', rearranged_values, x)  # Adjusted based on the original operation

RuntimeError: einsum(): the number of subscripts in the equation (2) does not match the number of dimensions (3) for operand 1 and no ellipsis was given

In [41]:
rearrange_indices = rearrange.long()
rearrange_indices = rearrange_indices.clamp(min=0)  # Ensure non-negative indices
print(rearrange_indices)
# Ensure the maximum index does not exceed bounds of x
if rearrange_indices.max() >= x.size(1):
    raise ValueError("Some indices in rearrange are out of bounds for x.")

# Gather the input x based on the rearrangement
x_reordered = torch.gather(x, 1, rearrange_indices.unsqueeze(-1).expand(-1, -1, 3))  # [B, N, C]
print('x_reordered = ', x_reordered)


tensor([[1],
        [4],
        [1]])


ValueError: Some indices in rearrange are out of bounds for x.

In [4]:
import torch
from torch import Tensor
class SoftSort(torch.nn.Module):
    def __init__(self, tau=1.0, hard=False, pow=1.0):
        super(SoftSort, self).__init__()
        self.hard = hard
        self.tau = tau
        self.pow = pow

    def forward(self, scores: Tensor):
        """
        scores: elements to be sorted. Typical shape: batch_size x n
        """
        scores = scores.unsqueeze(-1)
        sorted = scores.sort(descending=True, dim=1)[0]
        pairwise_diff = (scores.transpose(1, 2) - sorted).abs().pow(self.pow).neg() / self.tau
        P_hat = pairwise_diff.softmax(-1)

        if self.hard:
            P = torch.zeros_like(P_hat, device=P_hat.device)
            P.scatter_(-1, P_hat.topk(1, -1)[1], value=1)
            P_hat = (P - P_hat).detach() + P_hat
        return P_hat

In [7]:
import torch
keys = torch.randn(2,1,4)
x = torch.randn(2, 5, 4)

In [8]:
keys

tensor([[[-1.1482,  0.2828, -0.4154, -1.1326]],

        [[ 1.2830,  2.1231, -0.2014, -0.9114]]])

In [9]:
x

tensor([[[ 0.4553, -0.6099, -0.6173,  1.0400],
         [ 1.4106,  0.6913, -0.3748, -0.3752],
         [-0.5062, -1.0152,  1.5643,  0.2343],
         [-0.9921,  1.2435,  1.6380,  1.2160],
         [-0.5039, -0.7539, -1.5916,  1.4486]],

        [[ 1.3329,  0.1666,  2.4799, -2.7505],
         [ 0.3777,  1.2816,  0.6190,  0.1405],
         [ 0.5517, -0.2661, -0.7200, -1.8182],
         [ 0.0213,  0.1991,  1.0101,  0.8899],
         [-0.2186, -1.7436,  1.1259,  2.1794]]])

In [11]:
ss = SoftSort(hard=True)
dot_prod = torch.matmul(x, keys.transpose(1,2)).squeeze(2)
dot_prod

tensor([[-1.6165, -0.8435, -0.6210, -0.5669, -0.6141],
        [ 4.0713,  2.9528,  1.9451, -0.5644, -6.1952]])

In [13]:
-dot_prod

tensor([[ 1.6165,  0.8435,  0.6210,  0.5669,  0.6141],
        [-4.0713, -2.9528, -1.9451,  0.5644,  6.1952]])

In [16]:
perm = ss(-dot_prod)
perm

tensor([[[1., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0.],
         [0., 0., 1., 0., 0.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 1.],
         [0., 0., 0., 1., 0.],
         [0., 0., 1., 0., 0.],
         [0., 1., 0., 0., 0.],
         [1., 0., 0., 0., 0.]]])

In [20]:
torch.einsum('blk,bk->bl', perm, dot_prod)

tensor([[-1.6165, -0.8435, -0.6210, -0.6141, -0.5669],
        [-6.1952, -0.5644,  1.9451,  2.9528,  4.0713]])

In [18]:
print(x)
print(keys)

tensor([[[ 0.4553, -0.6099, -0.6173,  1.0400],
         [ 1.4106,  0.6913, -0.3748, -0.3752],
         [-0.5062, -1.0152,  1.5643,  0.2343],
         [-0.9921,  1.2435,  1.6380,  1.2160],
         [-0.5039, -0.7539, -1.5916,  1.4486]],

        [[ 1.3329,  0.1666,  2.4799, -2.7505],
         [ 0.3777,  1.2816,  0.6190,  0.1405],
         [ 0.5517, -0.2661, -0.7200, -1.8182],
         [ 0.0213,  0.1991,  1.0101,  0.8899],
         [-0.2186, -1.7436,  1.1259,  2.1794]]])
tensor([[[-1.1482,  0.2828, -0.4154, -1.1326]],

        [[ 1.2830,  2.1231, -0.2014, -0.9114]]])


In [17]:
print(perm.shape)

torch.Size([2, 5, 5])


In [19]:
torch.einsum('blk,bkd->bld', perm, x)

tensor([[[ 0.4553, -0.6099, -0.6173,  1.0400],
         [ 1.4106,  0.6913, -0.3748, -0.3752],
         [-0.5062, -1.0152,  1.5643,  0.2343],
         [-0.5039, -0.7539, -1.5916,  1.4486],
         [-0.9921,  1.2435,  1.6380,  1.2160]],

        [[-0.2186, -1.7436,  1.1259,  2.1794],
         [ 0.0213,  0.1991,  1.0101,  0.8899],
         [ 0.5517, -0.2661, -0.7200, -1.8182],
         [ 0.3777,  1.2816,  0.6190,  0.1405],
         [ 1.3329,  0.1666,  2.4799, -2.7505]]])

In [21]:
torch.einsum('bkl,bdl->bdk', perm, x)

RuntimeError: einsum(): subscript l has size 4 for operand 1 which does not broadcast with previously seen size 5

In [24]:
a = torch.Tensor([[[1., 2., 3., 4.],
                   [3., 4., 5., 7.]]])
importance = torch.Tensor([[2,1,3,0]])

In [27]:
ss(-importance)

tensor([[[0., 0., 0., 1.],
         [0., 1., 0., 0.],
         [1., 0., 0., 0.],
         [0., 0., 1., 0.]]])

In [31]:
print(torch.einsum('bkl,bl->bk', ss(-importance), importance))
torch.einsum('bkl,bdl->bdk', ss(-importance), a)

tensor([[0., 1., 2., 3.]])


tensor([[[4., 2., 1., 3.],
         [7., 4., 3., 5.]]])

### stage 3 reorder

In [None]:
import torch
from torchvision import transforms
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch.nn.functional as F
import seaborn as sns
import ast
import torch.nn as nn
from timm.models.registry import register_model
import math
from timm.models.layers import trunc_normal_, DropPath, LayerNorm2d
from timm.models._builder import resolve_pretrained_cfg
try:
    from timm.models._builder import _update_default_kwargs as update_args
except:
    from timm.models._builder import _update_default_model_kwargs as update_args
from timm.models.vision_transformer import Mlp, PatchEmbed
# import torchvision.transforms.functional as T
from torchvision import transforms as T
from timm.models.layers import DropPath, trunc_normal_
from timm.models.registry import register_model
import torch.nn.functional as F
from einops import rearrange, repeat
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
# from .registry import register_pip_model
from pathlib import Path
from sklearn.decomposition import PCA
import os
from torch import Tensor

class SoftSort(torch.nn.Module):
    def __init__(self, tau=1.0, hard=False, pow=1.0):
        super(SoftSort, self).__init__()
        self.hard = hard
        self.tau = tau
        self.pow = pow

    def forward(self, scores: Tensor):
        """
        scores: elements to be sorted. Typical shape: batch_size x n
        """
        scores = scores.unsqueeze(-1)
        sorted = scores.sort(descending=True, dim=1)[0]
        pairwise_diff = (scores.transpose(1, 2) - sorted).abs().pow(self.pow).neg() / self.tau
        P_hat = pairwise_diff.softmax(-1)

        if self.hard:
            P = torch.zeros_like(P_hat, device=P_hat.device)
            P.scatter_(-1, P_hat.topk(1, -1)[1], value=1)
            P_hat = (P - P_hat).detach() + P_hat
        return P_hat


def _cfg(url='', **kwargs):
    return {'url': url,
            'num_classes': 1000,
            'input_size': (3, 224, 224),
            'pool_size': None,
            'crop_pct': 0.875,
            'interpolation': 'bicubic',
            'fixed_input_size': True,
            'mean': (0.485, 0.456, 0.406),
            'std': (0.229, 0.224, 0.225),
            **kwargs
            }


default_cfgs = {
    'mamba_vision_T': _cfg(url='https://huggingface.co/nvidia/MambaVision-T-1K/resolve/main/mambavision_tiny_1k.pth.tar',
                           crop_pct=1.0,
                           input_size=(3, 224, 224),
                           crop_mode='center'),
    'mamba_vision_T2': _cfg(url='https://huggingface.co/nvidia/MambaVision-T2-1K/resolve/main/mambavision_tiny2_1k.pth.tar',
                            crop_pct=0.98,
                            input_size=(3, 224, 224),
                            crop_mode='center'),
    'mamba_vision_S': _cfg(url='https://huggingface.co/nvidia/MambaVision-S-1K/resolve/main/mambavision_small_1k.pth.tar',
                           crop_pct=0.93,
                           input_size=(3, 224, 224),
                           crop_mode='center'),
    'mamba_vision_B': _cfg(url='https://huggingface.co/nvidia/MambaVision-B-1K/resolve/main/mambavision_base_1k.pth.tar',
                           crop_pct=1.0,
                           input_size=(3, 224, 224),
                           crop_mode='center'),
    'mamba_vision_L': _cfg(url='https://huggingface.co/nvidia/MambaVision-L-1K/resolve/main/mambavision_large_1k.pth.tar',
                           crop_pct=1.0,
                           input_size=(3, 224, 224),
                           crop_mode='center'),
    'mamba_vision_L2': _cfg(url='https://huggingface.co/nvidia/MambaVision-L2-1K/resolve/main/mambavision_large2_1k.pth.tar',
                            crop_pct=1.0,
                            input_size=(3, 224, 224),
                            crop_mode='center')                                
}


def window_partition(x, window_size):
    """
    Args:
        x: (B, C, H, W)
        window_size: window size
        h_w: Height of window
        w_w: Width of window
    Returns:
        local window features (num_windows*B, window_size*window_size, C)
    """
    B, C, H, W = x.shape
    x = x.view(B, C, H // window_size, window_size, W // window_size, window_size)
    windows = x.permute(0, 2, 4, 3, 5, 1).reshape(-1, window_size*window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: local window features (num_windows*B, window_size, window_size, C)
        window_size: Window size
        H: Height of image
        W: Width of image
    Returns:
        x: (B, C, H, W)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.reshape(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 5, 1, 3, 2, 4).reshape(B,windows.shape[2], H, W)
    return x


def _load_state_dict(module, state_dict, strict=False, logger=None):
    """Load state_dict to a module.

    This method is modified from :meth:`torch.nn.Module.load_state_dict`.
    Default value for ``strict`` is set to ``False`` and the message for
    param mismatch will be shown even if strict is False.

    Args:
        module (Module): Module that receives the state_dict.
        state_dict (OrderedDict): Weights.
        strict (bool): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
        logger (:obj:`logging.Logger`, optional): Logger to log the error
            message. If not specified, print function will be used.
    """
    unexpected_keys = []
    all_missing_keys = []
    err_msg = []

    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata
    
    def load(module, prefix=''):
        local_metadata = {} if metadata is None else metadata.get(
            prefix[:-1], {})
        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
                                     all_missing_keys, unexpected_keys,
                                     err_msg)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(module)
    load = None
    missing_keys = [
        key for key in all_missing_keys if 'num_batches_tracked' not in key
    ]

    if unexpected_keys:
        err_msg.append('unexpected key in source '
                       f'state_dict: {", ".join(unexpected_keys)}\n')
    if missing_keys:
        err_msg.append(
            f'missing keys in source state_dict: {", ".join(missing_keys)}\n')

    
    if len(err_msg) > 0:
        err_msg.insert(
            0, 'The model and loaded state dict do not match exactly\n')
        err_msg = '\n'.join(err_msg)
        if strict:
            raise RuntimeError(err_msg)
        elif logger is not None:
            logger.warning(err_msg)
        else:
            print(err_msg)


def _load_checkpoint(model,
                    filename,
                    map_location='cpu',
                    strict=False,
                    logger=None):
    """Load checkpoint from a file or URI.

    Args:
        model (Module): Module to load checkpoint.
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Whether to allow different params for the model and
            checkpoint.
        logger (:mod:`logging.Logger` or None): The logger for error message.

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    checkpoint = torch.load(filename, map_location=map_location)
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint
    if list(state_dict.keys())[0].startswith('module.'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}

    if sorted(list(state_dict.keys()))[0].startswith('encoder'):
        state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}

    _load_state_dict(model, state_dict, strict, logger)
    return checkpoint


class Downsample(nn.Module):
    """
    Down-sampling block"
    """

    def __init__(self,
                 dim,
                 keep_dim=False,
                 ):
        """
        Args:
            dim: feature size dimension.
            norm_layer: normalization layer.
            keep_dim: bool argument for maintaining the resolution.
        """

        super().__init__()
        if keep_dim:
            dim_out = dim
        else:
            dim_out = 2 * dim
        self.reduction = nn.Sequential(
            nn.Conv2d(dim, dim_out, 3, 2, 1, bias=False),
        )

    def forward(self, x):
        x = self.reduction(x)
        return x


class PatchEmbed(nn.Module):
    """
    Patch embedding block"
    """

    def __init__(self, in_chans=3, in_dim=64, dim=96):
        """
        Args:
            in_chans: number of input channels.
            dim: feature size dimension.
        """
        # in_dim = 1
        super().__init__()
        self.proj = nn.Identity()
        self.conv_down = nn.Sequential(
            nn.Conv2d(in_chans, in_dim, 3, 2, 1, bias=False),
            nn.BatchNorm2d(in_dim, eps=1e-4),
            nn.ReLU(),
            nn.Conv2d(in_dim, dim, 3, 2, 1, bias=False),
            nn.BatchNorm2d(dim, eps=1e-4),
            nn.ReLU()
            )

    def forward(self, x):
        x = self.proj(x)
        x = self.conv_down(x)
        return x

# [128, 80, 56, 56]
class ConvBlock(nn.Module):

    def __init__(self, dim,
                 drop_path=0.,
                 layer_scale=None,
                 kernel_size=3):
        super().__init__()

        self.conv1 = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=1)
        self.norm1 = nn.BatchNorm2d(dim, eps=1e-5)
        self.act1 = nn.GELU(approximate= 'tanh')
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=1)
        self.norm2 = nn.BatchNorm2d(dim, eps=1e-5)
        self.layer_scale = layer_scale
        if layer_scale is not None and type(layer_scale) in [int, float]:
            self.gamma = nn.Parameter(layer_scale * torch.ones(dim))
            self.layer_scale = True
        else:
            self.layer_scale = False
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input = x # torch.Size([128, 80, 56, 56])
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.act1(x)
        x = self.conv2(x)
        x = self.norm2(x)
        if self.layer_scale:
            x = x * self.gamma.view(1, -1, 1, 1)
        x = input + self.drop_path(x)
        return x # torch.Size([128, 80, 56, 56])


class MambaVisionMixer(nn.Module):
    def __init__(
        self,
        d_model,
        d_state=16,
        d_conv=4,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        conv_bias=True,
        bias=False,
        use_fast_path=True, 
        layer_idx=None,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
        self.use_fast_path = use_fast_path
        self.layer_idx = layer_idx
        self.in_proj = nn.Linear(self.d_model, self.d_inner, bias=bias, **factory_kwargs)    
        self.x_proj = nn.Linear(
            self.d_inner//2, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
        )
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner//2, bias=True, **factory_kwargs)
        dt_init_std = self.dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError
        dt = torch.exp(
            torch.rand(self.d_inner//2, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        self.dt_proj.bias._no_reinit = True
        A = repeat(
            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=self.d_inner//2,
        ).contiguous()
        A_log = torch.log(A)
        self.A_log = nn.Parameter(A_log)
        self.A_log._no_weight_decay = True
        self.D = nn.Parameter(torch.ones(self.d_inner//2, device=device))
        self.D._no_weight_decay = True
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
        self.conv1d_x = nn.Conv1d(
            in_channels=self.d_inner//2,
            out_channels=self.d_inner//2,
            bias=conv_bias//2,
            kernel_size=d_conv,
            groups=self.d_inner//2,
            **factory_kwargs,
        )
        self.conv1d_z = nn.Conv1d(
            in_channels=self.d_inner//2,
            out_channels=self.d_inner//2,
            bias=conv_bias//2,
            kernel_size=d_conv,
            groups=self.d_inner//2,
            **factory_kwargs,
        )

    def forward(self, hidden_states):
        """
        hidden_states: (B, L, D)
        Returns: same shape as hidden_states
        """
        # import ipdb; ipdb.set_trace()
        _, seqlen, _ = hidden_states.shape
        xz = self.in_proj(hidden_states)
        xz = rearrange(xz, "b l d -> b d l")
        x, z = xz.chunk(2, dim=1)
        A = -torch.exp(self.A_log.float())
        x = F.silu(F.conv1d(input=x, weight=self.conv1d_x.weight, bias=self.conv1d_x.bias, padding='same', groups=self.d_inner//2))
        z = F.silu(F.conv1d(input=z, weight=self.conv1d_z.weight, bias=self.conv1d_z.bias, padding='same', groups=self.d_inner//2))
        x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))
        dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
        dt = rearrange(self.dt_proj(dt), "(b l) d -> b d l", l=seqlen)
        B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
        C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
        y = selective_scan_fn(x, 
                              dt, 
                              A, 
                              B, 
                              C, 
                              self.D.float(), 
                              z=None, 
                              delta_bias=self.dt_proj.bias.float(), 
                              delta_softplus=True, 
                              return_last_state=None)
        
        y = torch.cat([y, z], dim=1)
        y = rearrange(y, "b d l -> b l d")
        out = self.out_proj(y)
        return out
    

class Attention(nn.Module):

    def __init__(
            self,
            dim,
            num_heads=8,
            qkv_bias=False,
            qk_norm=False,
            attn_drop=0.,
            proj_drop=0.,
            norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = True

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
             q, k, v,
                dropout_p=self.attn_drop.p,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):
    def __init__(self, 
                 dim, 
                 num_heads, 
                 counter, 
                 transformer_blocks, 
                 mlp_ratio=4., 
                 qkv_bias=False, 
                 qk_scale=False, 
                 drop=0., 
                 attn_drop=0.,
                 drop_path=0., 
                 act_layer=nn.GELU, 
                 norm_layer=nn.LayerNorm, 
                 Mlp_block=Mlp,
                 layer_scale=None,
                 ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        if counter in transformer_blocks:
            self.mixer = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            norm_layer=norm_layer,
        )
        else:
            self.mixer = MambaVisionMixer(d_model=dim, 
                                          d_state=8,  
                                          d_conv=3,    
                                          expand=1
                                          )

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        use_layer_scale = layer_scale is not None and type(layer_scale) in [int, float]
        self.gamma_1 = nn.Parameter(layer_scale * torch.ones(dim))  if use_layer_scale else 1
        self.gamma_2 = nn.Parameter(layer_scale * torch.ones(dim))  if use_layer_scale else 1

    def forward(self, x):
        x = x + self.drop_path(self.gamma_1 * self.mixer(self.norm1(x)))
        x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x


class MambaVisionLayer(nn.Module):
    """
    MambaVision layer
    """

    def __init__(self,
                 dim,
                 depth,
                 num_heads,
                 window_size,
                 conv=False,
                 downsample=True,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 layer_scale=None,
                 layer_scale_conv=None,
                 transformer_blocks=[],
                 heatmap=[]
    ):
        """
        Args:
            dim: feature size dimension.
            depth: number of layers in each stage.
            window_size: window size in each stage.
            conv: bool argument for conv stage flag.
            downsample: bool argument for down-sampling.
            mlp_ratio: MLP ratio.
            num_heads: number of heads in each stage.
            qkv_bias: bool argument for query, key, value learnable bias.
            qk_scale: bool argument to scaling query, key.
            drop: dropout rate.
            attn_drop: attention dropout rate.
            drop_path: drop path rate.
            norm_layer: normalization layer.
            layer_scale: layer scaling coefficient.
            layer_scale_conv: conv layer scaling coefficient.
            transformer_blocks: list of transformer blocks.
        """

        super().__init__()
        self.heatmap = heatmap
        self.conv = conv
        self.transformer_block = False
        if conv:
            self.blocks = nn.ModuleList([ConvBlock(dim=dim,
                                                   drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                                   layer_scale=layer_scale_conv)
                                         for i in range(depth)])
            self.transformer_block = False
        else:
            self.blocks = nn.ModuleList([Block(dim=dim,
                                               counter=i,
                                               transformer_blocks=transformer_blocks,
                                               num_heads=num_heads,
                                               mlp_ratio=mlp_ratio,
                                               qkv_bias=qkv_bias,
                                               qk_scale=qk_scale,
                                               drop=drop,
                                               attn_drop=attn_drop,
                                               drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                               layer_scale=layer_scale)
                                         for i in range(depth)])
            self.transformer_block = True

        self.downsample = None if not downsample else Downsample(dim=dim)
        self.window_size = window_size

    def forward(self, x):
        _, _, H, W = x.shape  # torch.Size([128, 80, 56, 56])

        if self.transformer_block:
            pad_r = (self.window_size - W % self.window_size) % self.window_size
            pad_b = (self.window_size - H % self.window_size) % self.window_size
            if pad_r > 0 or pad_b > 0:
                x = torch.nn.functional.pad(x, (0, pad_r, 0, pad_b))
                _, _, Hp, Wp = x.shape
            else:
                Hp, Wp = H, W
            x = window_partition(x, self.window_size)  # torch.Size([128, 196, 320])

        heatmaps = []  # Initialize heatmaps list to store activations after each block
        for _, blk in enumerate(self.blocks):
            x = blk(x)
            heatmaps.append(x)  # Store the output of each block

        if self.transformer_block:
            x = window_reverse(x, self.window_size, Hp, Wp)
            if pad_r > 0 or pad_b > 0:
                x = x[:, :, :H, :W].contiguous()

        if self.downsample is None:
            return x, heatmaps  # Return final output and heatmaps
        return self.downsample(x), heatmaps  # Return downsampled output and heatmaps

class MambaVisionLayer_reorder(nn.Module):
    """
    MambaVision layer"
    """

    def __init__(self,
                 dim,
                 depth,
                 num_heads,
                 window_size,
                 conv=False,
                 downsample=True,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 layer_scale=None,
                 layer_scale_conv=None,
                 transformer_blocks = [],
    ):
        """
        Args:
            dim: feature size dimension.
            depth: number of layers in each stage.
            window_size: window size in each stage.
            conv: bool argument for conv stage flag.
            downsample: bool argument for down-sampling.
            mlp_ratio: MLP ratio.
            num_heads: number of heads in each stage.
            qkv_bias: bool argument for query, key, value learnable bias.
            qk_scale: bool argument to scaling query, key.
            drop: dropout rate.
            attn_drop: attention dropout rate.
            drop_path: drop path rate.
            norm_layer: normalization layer.
            layer_scale: layer scaling coefficient.
            layer_scale_conv: conv layer scaling coefficient.
            transformer_blocks: list of transformer blocks.
        """

        super().__init__()
        self.conv = conv
        self.transformer_block = False
        self.heatmap = heatmap
        self.learnable_keys = nn.Parameter(torch.randn(1, 1, dim)) # ensure dimension = 320
        if conv:
            self.blocks = nn.ModuleList([ConvBlock(dim=dim,
                                                   drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                                   layer_scale=layer_scale_conv)
                                                   for i in range(depth)])
            self.transformer_block = False
        else:
            self.blocks = nn.ModuleList([Block(dim=dim,
                                               counter=i, 
                                               transformer_blocks=transformer_blocks,
                                               num_heads=num_heads,
                                               mlp_ratio=mlp_ratio,
                                               qkv_bias=qkv_bias,
                                               qk_scale=qk_scale,
                                               drop=drop,
                                               attn_drop=attn_drop,
                                               drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                               layer_scale=layer_scale)
                                               for i in range(depth)])
            self.transformer_block = True

        self.downsample = None if not downsample else Downsample(dim=dim)
        self.do_gt = False
        self.window_size = window_size
        self.soft_sort = SoftSort(hard=True)

    def forward(self, x):
        B, _, H, W = x.shape

        if self.transformer_block:
            pad_r = (self.window_size - W % self.window_size) % self.window_size
            pad_b = (self.window_size - H % self.window_size) % self.window_size
            if pad_r > 0 or pad_b > 0:
                x = torch.nn.functional.pad(x, (0,pad_r,0,pad_b))
                _, _, Hp, Wp = x.shape
            else:
                Hp, Wp = H, W
            x = window_partition(x, self.window_size)

        heatmaps = []  # Initialize heatmaps list to store activations after each block
        # perm_matrix = None
        for idx, blk in enumerate(self.blocks):
            heatmaps.append(x) # Store the output of each block   
            x = blk(x)
            if idx == 1:
                learn_key = self.learnable_keys.expand(B, -1, -1) # [B, 1, C], x [B, N, C]
                dot_prod = torch.matmul(x, learn_key.transpose(1,2)).squeeze(2) # [B, N]
                # import ipdb; ipdb.set_trace()
                perm_matrix = self.soft_sort(-dot_prod) # [B, N, N]
                x = torch.einsum('blk, bkn -> bln', perm_matrix, x)    
            # if idx == 3 and perm_matrix is not None:
            #     # Apply reverse permutation to restore the original order
            #     perm_matrix_inv = perm_matrix.transpose(1, 2)  # [B, N, N] inverse of the permutation matrix
            #     x = torch.einsum('blk, bkn -> bln', perm_matrix_inv, x)       

            
        if self.transformer_block:
            x = window_reverse(x, self.window_size, Hp, Wp)
            if pad_r > 0 or pad_b > 0:
                x = x[:, :, :H, :W].contiguous()
        if self.downsample is None:
            return x, heatmaps
        return self.downsample(x), heatmaps


class MambaVision(nn.Module):
    """
    MambaVision,
    """

    def __init__(self,
                 dim,
                 in_dim,
                 depths,
                 window_size,
                 mlp_ratio,
                 num_heads,
                 drop_path_rate=0.2,
                 in_chans=3,
                 num_classes=1000,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 layer_scale=None,
                 layer_scale_conv=None,
                 **kwargs):
        """
        Args:
            dim: feature size dimension.
            depths: number of layers in each stage.
            window_size: window size in each stage.
            mlp_ratio: MLP ratio.
            num_heads: number of heads in each stage.
            drop_path_rate: drop path rate.
            in_chans: number of input channels.
            num_classes: number of classes.
            qkv_bias: bool argument for query, key, value learnable bias.
            qk_scale: bool argument to scaling query, key.
            drop_rate: dropout rate.
            attn_drop_rate: attention dropout rate.
            norm_layer: normalization layer.
            layer_scale: layer scaling coefficient.
            layer_scale_conv: conv layer scaling coefficient.
        """
        super().__init__()
        num_features = int(dim * 2 ** (len(depths) - 1))
        self.num_classes = num_classes
        self.patch_embed = PatchEmbed(in_chans=in_chans, in_dim=in_dim, dim=dim)
        self.drop_path_rate = drop_path_rate
        dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, sum(depths))]
        self.levels = nn.ModuleList()
        self.heatmap = []
        
        for i in range(len(depths)):
            conv = True if (i == 0 or i == 1) else False
            if i == 2:
                level = MambaVisionLayer_reorder(dim=int(dim * 2 ** i),
                                        depth=depths[i],
                                        num_heads=num_heads[i],
                                        window_size=window_size[i],
                                        mlp_ratio=mlp_ratio,
                                        qkv_bias=qkv_bias,
                                        qk_scale=qk_scale,
                                        conv=conv,
                                        drop=drop_rate,
                                        attn_drop=attn_drop_rate,
                                        drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
                                        downsample=(i < 3),
                                        layer_scale=layer_scale,
                                        layer_scale_conv=layer_scale_conv,
                                        transformer_blocks=list(range(depths[i]//2+1, depths[i])) if depths[i]%2!=0 else list(range(depths[i]//2, depths[i])),
                                        )
            else:
                level = MambaVisionLayer(dim=int(dim * 2 ** i),
                                        depth=depths[i],
                                        num_heads=num_heads[i],
                                        window_size=window_size[i],
                                        mlp_ratio=mlp_ratio,
                                        qkv_bias=qkv_bias,
                                        qk_scale=qk_scale,
                                        conv=conv,
                                        drop=drop_rate,
                                        attn_drop=attn_drop_rate,
                                        drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
                                        downsample=(i < 3),
                                        layer_scale=layer_scale,
                                        layer_scale_conv=layer_scale_conv,
                                        transformer_blocks=list(range(depths[i]//2+1, depths[i])) if depths[i]%2!=0 else list(range(depths[i]//2, depths[i])),
                                        )
            # self.heatmap.append(heatmap)
            self.levels.append(level)
        self.norm = nn.BatchNorm2d(num_features)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, LayerNorm2d):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'rpb'}

    def forward_features(self, x):
        all_heatmaps = []
        # print('x_shape = ', x.shape)
        x = self.patch_embed(x) # torch.Size([128, 3, 224, 224]) -> torch.Size([128, 160, 28, 28])
        for level in self.levels:
            # print(x)
            x, heatmaps = level(x)
            all_heatmaps.extend(heatmaps)  # Aggregate heatmaps from each layer
        # torch.Size([128, 640, 7, 7])
        x = self.norm(x)
        x = self.avgpool(x) # torch.Size([128, 640, 1, 1])
        x = torch.flatten(x, 1) # torch.Size([128, 640])
        return x, all_heatmaps

    def forward(self, x):
        x, headmap = self.forward_features(x)
        x = self.head(x)
        # import ipdb; ipdb.set_trace()
        return x, headmap

    def _load_state_dict(self, 
                         pretrained, 
                         strict: bool = False):
        _load_checkpoint(self, 
                         pretrained, 
                         strict=strict)

def run_PCA(X):
    pca = PCA(n_components=3)
    res = pca.fit_transform(X)
    return res

# Function to generate a unique file name
def generate_filename(directory, base_name, extension):
    counter = 0
    while True:
        if counter == 0:
            file_name = f"{base_name}.{extension}"
        else:
            file_name = f"{base_name}_{counter}.{extension}"
        file_path = os.path.join(directory, file_name)
        if not os.path.exists(file_path):
            return file_path
        counter += 1


if __name__ == "__main__":

    # Model initialization settings (same as your code)
    kwargs = {'pretrained_cfg': None, 'pretrained_cfg_overlay': None, 
              'in_chans': 3, 'num_classes': 1000, 'img_size': (224, 224)}
    
    # checkpoint_path = "/home/ubuntu/workspace/mambavision_1/mambavision/model_weights/mambavision_tiny_1k.pth.tar"
    # checkpoint_path = "/home/ubuntu/workspace/mambavision_1/mambavision/model_weights/checkpoint-4.pth.tar"
    checkpoint_path = "/home/ubuntu/workspace/mambavision_1/test/weights/checkpoint-0.pth.tar"
    image_path = "/home/ubuntu/workspace/mambavision_1/test/bear.jpg"
    heatmap = []
    # Directory where the file will be saved
    directory = './bear'
    base_name = 'pca_interpolate_order2'
    extension = 'jpg'
    # Generate a unique file name
    file_path = generate_filename(directory, base_name, extension)
    
    # Model initialization (same as your code)
    model = MambaVision(depths=[1, 3, 8, 4],
                        num_heads=[2, 4, 8, 16],
                        window_size=[8, 8, 14, 7],
                        dim=80,
                        in_dim=32,
                        mlp_ratio=4,
                        resolution=224,
                        drop_path_rate=0.2, 
                        **kwargs)

    # Load the checkpoint and move the model to the correct device (same as your code)
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    print("Model loaded successfully.")

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Preprocess the input image (same as your code)
    image = Image.open(image_path)
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    input_tensor = preprocess(image).unsqueeze(0).to(device)

    # Capture model activations (same as your code)
    activations = []
    outputs = {}
    def hook_fn(module, input, output):
        outputs[module] = output.detach()

    for layer in model.children():
        layer.register_forward_hook(hook_fn)

    model.eval()
    with torch.no_grad():
        output, heatmaps = model(input_tensor)

    # Reshape heatmaps
    heatmap_reshape = []
    for heatmap in heatmaps:
        # print(heatmap.shape)
        # Check if the heatmap has 3 dimensions (batch, height*width, feature_dim)
        if len(heatmap.shape) == 4:
            batch_size, feature_dim, side_length, side_length = heatmap.shape
            # Try reshaping the feature maps into square-like dimensions
            heatmap_res = heatmap.view(batch_size, side_length * side_length, feature_dim)
            heatmap_reshape.append(heatmap_res)
        else:
            heatmap_reshape.append(heatmap)

    # PCA for each heatmap and resize using T.interpolate
    pca_images = []
    for i, heatmap in enumerate(heatmap_reshape):
        # print(f"Heatmap {i}: Shape = {heatmap.shape}")
        v = heatmap.squeeze(0)  # Remove batch dimension

        # Run PCA along the feature dimension (the last dimension)
        img = run_PCA(v.cpu())
        # Reshape to square for interpolation
        h = w = int(math.sqrt(img.shape[0])) #  Adjust this line
        img = img.reshape(h, w, -1)
        min = img.min(axis=(0, 1))
        max = img.max(axis=(0, 1))
        img = (img - min) / (max - min)
        img = Image.fromarray((img * 255).astype(np.uint8))
        img = T.Resize((224,224), interpolation=T.InterpolationMode.NEAREST)(img)
        pca_images.append(img)

    # Convert the initial image to [224, 224] format for consistency
    initial_image_resized = image.resize((224, 224))
    initial_image_np = np.array(initial_image_resized)

    # Calculate the total number of images (1 initial + number of heatmaps)
    total_images = len(pca_images) + 1

    # Calculate the grid size for subplots
    cols = 5
    rows = (total_images + cols - 1) // cols  # Ensure enough rows to fit all images

    # Plot the initial image and resized PCA-reduced images in a single figure
    fig, axs = plt.subplots(rows, cols, figsize=(12, 12))  # Adjust depending on the number of heatmaps
    axs = axs.ravel()

    # Show the original image in the first subplot
    init = axs[0].imshow(initial_image_np)
    axs[0].axis('off')
    axs[0].set_title("Initial Image")
    fig.colorbar(init, ax=axs[0], orientation='vertical')
    # Show the PCA heatmaps in the remaining subplots
    for i in range(1, total_images):
        # pca_images[i-1] = (pca_images[i-1] - pca_images[i-1].min())/(pca_images[i-1].max() - pca_images[i-1].min()).astype(int)
        im = axs[i].imshow(pca_images[i - 1], cmap='jet')
        axs[i].axis('off')
        axs[i].set_title(f"Heatmap {i-1}")
        # Add a color bar next to each heatmap
        

    # Hide unused subplots if any
    for j in range(total_images, len(axs)):
        axs[j].axis('off')

    plt.tight_layout()
    plt.savefig(file_path)
    plt.show()
        


### stage 3 + 4 reorder

In [1]:
import torch
x = torch.randn(2,3,4)
keys = torch.randn(2,3,1)

In [4]:
print(x)

tensor([[[ 0.6738,  0.5989,  0.7617, -0.7375],
         [-1.1518, -0.8145, -1.1547,  0.2358],
         [-1.5809, -0.2918, -0.8664,  0.3996]],

        [[ 0.2251, -1.5289, -0.4852,  0.8026],
         [ 0.2098, -0.5337,  0.2852, -0.6550],
         [-0.5836,  1.1107,  0.0136,  0.0351]]])


In [5]:
print(keys)

tensor([[[ 0.5900],
         [-0.1405],
         [-0.5588]],

        [[-1.9241],
         [ 1.8041],
         [ 0.2175]]])


In [12]:
x_params = x*keys
-x_params

tensor([[[-0.3975, -0.3533, -0.4494,  0.4351],
         [-0.1619, -0.1145, -0.1623,  0.0331],
         [-0.8835, -0.1631, -0.4842,  0.2233]],

        [[ 0.4332, -2.9418, -0.9336,  1.5442],
         [-0.3784,  0.9628, -0.5146,  1.1817],
         [ 0.1269, -0.2416, -0.0030, -0.0076]]])

In [11]:
torch.topk(-x_params, k=3, dim=1)

torch.return_types.topk(
values=tensor([[[-0.1619, -0.1145, -0.1623,  0.4351],
         [-0.3975, -0.1631, -0.4494,  0.2233],
         [-0.8835, -0.3533, -0.4842,  0.0331]],

        [[ 0.4332,  0.9628, -0.0030,  1.5442],
         [ 0.1269, -0.2416, -0.5146,  1.1817],
         [-0.3784, -2.9418, -0.9336, -0.0076]]]),
indices=tensor([[[1, 1, 1, 0],
         [0, 2, 0, 2],
         [2, 0, 2, 1]],

        [[0, 1, 2, 0],
         [2, 2, 1, 1],
         [1, 0, 0, 2]]]))

In [15]:
x,_ = torch.tensor([3,2,1])
torch.topk(-x, k=3)

torch.return_types.topk(
values=tensor([-1, -2, -3]),
indices=tensor([2, 1, 0]))