In [1]:
# Environment check: print the PyTorch version and the CUDA version that
# torch.version.cuda reports.
import torch
print("Version PyTorch :", torch.__version__)
print("Version CUDA utilisée par PyTorch :", torch.version.cuda)

# Print whether CUDA is usable from this process, then the cuDNN version
# number reported by torch.backends.cudnn.
print(torch.cuda.is_available())
print(torch.backends.cudnn.version())

Version PyTorch : 2.1.2
Version CUDA utilisée par PyTorch : 11.8
True
8700


In [2]:
##########################################
############ TEST MAMBA LAYER  ###########
##########################################
# Smoke test: push one random batch through a single Mamba layer on the GPU
# and check that the output has the same shape as the input.
# Both the input and the layer are moved to "cuda", so a CUDA device is needed.

import torch
from mamba_ssm import Mamba

# Input of shape (batch, length, dim); no seed is set, so values differ per run.
batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=16,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")
y = model(x)
# The layer is expected to preserve the (batch, length, dim) shape.
assert y.shape == x.shape
print("input", x.shape)
print("output", y.shape)

input torch.Size([2, 64, 16])
output torch.Size([2, 64, 16])


In [3]:
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence

import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
from mmengine.model import BaseModule, ModuleList
from mmengine.model.weight_init import trunc_normal_

from mmpretrain.registry import MODELS
from mmpretrain.models.utils import (MultiheadAttention, SwiGLUFFNFused, build_norm_layer,
                     resize_pos_embed, to_2tuple)
from mmpretrain.models.backbones.base_backbone import BaseBackbone
from mamba_ssm import Mamba
from torch import Tensor
from typing import Optional
from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
from functools import partial
#from .ssm2d import Block2D,Mamba2D,SplitHead2D
from einops import rearrange
# from .mamband import MambaND
from mmpretrain.models.utils.embed import PatchMerging
from prettytable import PrettyTable

# Placeholders for the components whose imports are commented out above
# (`.ssm2d`: Block2D, Mamba2D, SplitHead2D; `.mamband`: MambaND).
# They only keep the names resolvable so that this file can be executed; they
# do NOT implement the 2D / N-D behaviour, and code paths that rely on it
# (is_2d, split_head, use_nd) will not work with these stand-ins.
# TODO: Clean implementation and release
Block2D = SplitHead2D = Mamba2D = nn.Identity
MambaND = nn.Identity

from mmcv.cnn.bricks.drop import build_dropout
# Print tensors with 4 digits of floating-point precision (process-wide setting).
torch.set_printoptions(precision=4)

class Block(nn.Module):
    """Wrap a mixer with a norm layer plus optional re-ordering, FFN and downsampling.

    ``forward`` has two code paths:

    * ``fused_add_norm=False``: ``norm -> mixer``. When ``skip`` is true the
      normalised input is added to the drop-path'ed mixer output. The incoming
      ``residual`` is only re-ordered alongside the hidden states; it is not
      added to anything in this path.
    * ``fused_add_norm=True``: ``rms_norm_fn`` / ``layer_norm_fn`` is called
      with ``residual=residual`` and ``prenorm=True``; its two return values
      become the mixer input and the new residual.

    Around the mixer, the token sequence can be flipped (``reverse``) and/or
    re-ordered from row-major to column-major (``transpose``); both are undone
    before returning.
    """

    def __init__(
        self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False,reverse=False,
        transpose=False,split_head=False,
        drop_path_rate=0.0,drop_rate=0.0,use_mlp=False,downsample=False,
    ):
        """
        Args:
            dim: channel dimension handed to ``mixer_cls``, ``norm_cls``, the
                FFN and the downsampling layer.
            mixer_cls: callable ``mixer_cls(dim)`` returning the mixer module.
            norm_cls: callable ``norm_cls(dim)`` returning the norm module.
            fused_add_norm: use the fused add+norm function (the norm must be
                ``nn.LayerNorm`` or ``RMSNorm``).
            residual_in_fp32: forwarded to the fused add+norm function.
            reverse: flip the sequence along dim 1 around the mixer.
            transpose: re-order ``(h w)`` tokens to ``(w h)`` around the mixer;
                the grid is taken to be square (``h = w = int(sqrt(length))``).
            split_head: apply ``SplitHead2D`` before and after the mixer
                (non-fused path only).
            drop_path_rate: drop probability of the DropPath layer.
            drop_rate: drop probability of the Dropout layer (only used when
                ``forward`` is called with ``skip=False``).
            use_mlp: append a SwiGLU FFN (with its own LayerNorm) after the mixer.
            downsample: apply a ``PatchMerging`` layer at the end of ``forward``.
        """
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.mixer = mixer_cls(dim)
        self.norm = norm_cls(dim)
        self.split_head = split_head
        self.reverse = reverse
        self.transpose = transpose
        self.drop_path = build_dropout(
            dict(type='DropPath', drop_prob=drop_path_rate)
        )
        self.dropout = build_dropout(
            dict(type='Dropout', drop_prob=drop_rate)
        )
        self.downsample = downsample
        if downsample:
            # Input and output channel counts are both `dim`.
            self.down_sample_layer = PatchMerging(dim, dim)
        if use_mlp:
            self.ffn = SwiGLUFFNFused(
                    embed_dims=dim,
                    feedforward_channels=int(dim*4),
                    layer_scale_init_value=0.0)
            self.ln2 = build_norm_layer(dict(type='LN'), dim)
        else:
            self.ffn = None
        if self.fused_add_norm:
            assert RMSNorm is not None, "RMSNorm import fails"
            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"

    def forward(
        self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None,skip=True,**kwargs
    ):
        r"""Run the block on a token sequence.

        Args:
            hidden_states: input sequence; indexed as (n, length, c) by the
                rearrange patterns below.
            residual: optional residual stream; re-ordered together with
                ``hidden_states`` and updated by the fused add+norm path.
            inference_params: forwarded to the mixer.
            skip: non-fused path only. If true, add the normalised input to
                the mixer output; otherwise return dropout + drop-path of the
                mixer output alone.
            **kwargs: extra arguments. ``h`` / ``w`` (if present) give the
                token grid used for downsampling. They are forwarded to the
                mixer only when it is a ``MambaND`` instance.

        Returns:
            ``(hidden_states, residual)``. When ``downsample`` is enabled the
            second element is the new ``(h, w)`` grid size instead of a tensor.
        """
        # Callers in this file pass h= and w= to every block, but only MambaND
        # is expected to accept them (Mamba_Ref.forward, for example, does
        # not). Apply the same guard on both the fused and the non-fused path.
        mixer_kwargs = kwargs if isinstance(self.mixer, MambaND) else {}

        h = w = 0
        if self.transpose:
            l = hidden_states.shape[1]
            # Square grid assumed; a length that is not a perfect square makes
            # the rearrange below fail.
            h = w = int(np.sqrt(l))
            hidden_states = rearrange(hidden_states,'n (h w) c -> n (w h) c',h=h,w=w)
            if residual is not None:
                residual = rearrange(residual,'n (h w) c -> n (w h) c',h=h,w=w)
        if self.reverse:
            hidden_states = hidden_states.flip(1)
            if residual is not None:
                residual = residual.flip(1)
        if not self.fused_add_norm:
            hidden_states = self.norm(hidden_states)
            if self.split_head:
                l = hidden_states.shape[1]
                h = w = int(np.sqrt(l))
                hidden_states = SplitHead2D.apply(hidden_states,4,h,w)
            if skip:
                hidden_states = hidden_states + self.drop_path(
                    self.mixer(hidden_states, inference_params=inference_params, **mixer_kwargs))
            else:
                # NOTE(review): this branch never forwarded the extra kwargs,
                # not even to MambaND; kept as-is.
                hidden_states = self.drop_path(self.dropout(self.mixer(hidden_states, inference_params=inference_params)))
            if self.split_head:
                hidden_states = SplitHead2D.apply(hidden_states,4,h,w)
        else:
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
            hidden_states, residual = fused_add_norm_fn(
                hidden_states,
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
            )
            hidden_states = self.drop_path(
                self.mixer(hidden_states, inference_params=inference_params, **mixer_kwargs))
        if self.ffn is not None:
            hidden_states = self.ffn(self.ln2(hidden_states),identity=hidden_states)
        # Undo the re-ordering in the opposite order it was applied.
        if self.reverse:
            hidden_states = hidden_states.flip(1)
            if residual is not None:
                residual = residual.flip(1)
        if self.transpose:
            hidden_states = rearrange(hidden_states,'n (w h) c -> n (h w) c',h=h,w=w)
            if residual is not None:
                residual = rearrange(residual,'n (w h) c -> n (h w) c',h=h,w=w)
        if self.downsample:
            # Prefer the grid size supplied by the caller over the locally
            # derived one (which stays 0 unless transpose/split_head ran).
            if 'h' in kwargs:
                h,w = kwargs['h'],kwargs['w']
            hidden_states,(h,w) = self.down_sample_layer(
                hidden_states,(h,w)
            )
            # The residual slot is reused to hand the new grid size back.
            assert residual is None
            residual = (h,w)

        return hidden_states, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Delegate inference-cache allocation to the wrapped mixer."""
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

import torch.nn.functional as F
from mamba_ssm.ops.selective_scan_interface import selective_scan_ref
def causal_conv1d_fn_ref(x, weight, bias=None, seq_idx=None, activation=None):
    """Pure-PyTorch causal depthwise 1D convolution followed by SiLU.

    The input is padded by ``width - 1`` and the result is cut back to the
    first ``seqlen`` positions.

    Args:
        x: (batch, dim, seqlen)
        weight: (dim, width)
        bias: (dim,) or None
        seq_idx: accepted for signature compatibility; not used.
        activation: must be "silu" or "swish" (both are applied as ``F.silu``).

    Returns:
        out: (batch, dim, seqlen)
    """
    assert activation in ("silu", "swish")
    seqlen = x.shape[-1]
    dim, width = weight.shape
    # Shape entries are plain Python ints in eager mode, which have no
    # `.item()`; int() works for ints and for 0-dim tensors alike.
    dim, width = int(dim), int(width)
    x = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)[..., :seqlen]
    return F.silu(x)

def mamba_inner_ref(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, out_proj_bias,
    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True
):
    """Reference (non-fused) Mamba inner function built on ``selective_scan_ref``.

    Args:
        xz: (b, 2 * d, l) tensor; split in half along dim 1 into ``x`` and ``z``.
        conv1d_weight: depthwise conv weight laid out as (d, 1, w).
        conv1d_bias: conv bias or None.
        x_proj_weight: weight of the projection producing delta / B / C inputs.
        delta_proj_weight: weight mapping the first ``delta_rank`` projected
            columns to delta; ``delta_rank`` is its second dimension.
        out_proj_weight, out_proj_bias: output linear layer parameters.
        A: state matrix; its last dimension gives ``d_state`` (doubled if complex).
        B, C: fixed B / C; when None they are taken from the projection of ``x``.
        D: forwarded to ``selective_scan_ref``.
        delta_bias: forwarded to ``selective_scan_ref``.
        B_proj_bias, C_proj_bias: optional biases added to input-dependent B / C.
        delta_softplus: forwarded to ``selective_scan_ref``.

    Returns:
        Output of the final linear layer applied to the scan result laid out
        as (b, l, d).
    """
    L = xz.shape[-1]
    delta_rank = delta_proj_weight.shape[1]
    d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
    x, z = xz.chunk(2, dim=1)
    x = causal_conv1d_fn_ref(x, rearrange(conv1d_weight, "d 1 w -> d w"), bias=conv1d_bias,seq_idx=None, activation="silu")
    # We're being very careful here about the layout, to avoid extra transposes.
    # We want delta to have d as the slowest moving dimension
    # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
    x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
    delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
    delta = rearrange(delta, "d (b l) -> b d l", l=L)
    if B is None:  # variable B
        B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl d)
        if B_proj_bias is not None:
            B = B + B_proj_bias.to(dtype=B.dtype)
        if not A.is_complex():
            B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
        else:
            B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
    if C is None:  # variable C
        C = x_dbl[:, -d_state:]  # (bl d)
        if C_proj_bias is not None:
            C = C + C_proj_bias.to(dtype=C.dtype)
        if not A.is_complex():
            C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
        else:
            C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
    # Honour the caller's delta_softplus instead of always passing True.
    y = selective_scan_ref(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=delta_softplus)
    return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)

# for FLOPS calc only
class Mamba_Ref(Mamba):
    """``Mamba`` variant whose forward pass uses the reference scan functions.

    The fast path goes through ``mamba_inner_ref`` and the slow path through
    ``selective_scan_ref``. Intended for FLOPs calculation only.
    """

    def forward(self, hidden_states, inference_params=None):
        """
        hidden_states: (B, L, D)
        inference_params: optional cache holder; when given, the conv / SSM
            states are read from it and updated in place.
        Returns: same shape as hidden_states
        """
        batch, seqlen, dim = hidden_states.shape

        conv_state, ssm_state = None, None
        if inference_params is not None:
            conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
            if inference_params.seqlen_offset > 0:
                # The states are updated inplace
                out, _, _ = self.step(hidden_states, conv_state, ssm_state)
                return out

        # We do matmul and transpose BLH -> HBL at the same time
        xz = rearrange(
            self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
            "d (b l) -> b d l",
            l=seqlen,
        )
        if self.in_proj.bias is not None:
            xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")

        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
        if self.use_fast_path and inference_params is None:  # Doesn't support outputting the states
            out = mamba_inner_ref(
                xz,
                self.conv1d.weight,
                self.conv1d.bias,
                self.x_proj.weight,
                self.dt_proj.weight,
                self.out_proj.weight,
                self.out_proj.bias,
                A,
                None,  # input-dependent B
                None,  # input-dependent C
                self.D.float(),
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
            )
        else:
            x, z = xz.chunk(2, dim=1)
            # Compute short convolution
            if conv_state is not None:
                # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
                # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
                conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0)))  # Update state (B D W)
            # The optional `causal_conv1d_fn` kernel is not imported in this
            # file, so referencing it raised NameError. This reference class
            # always uses the plain conv1d + activation path instead.
            assert self.activation in ["silu", "swish"]
            x = self.act(self.conv1d(x)[..., :seqlen])

            # We're careful here about the layout, to avoid extra transposes.
            # We want dt to have d as the slowest moving dimension
            # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
            x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)
            dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
            dt = self.dt_proj.weight @ dt.t()
            dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
            B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
            C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
            y = selective_scan_ref(
                x,
                dt,
                A,
                B,
                C,
                self.D.float(),
                z=z,
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
                return_last_state=ssm_state is not None,
            )
            if ssm_state is not None:
                # With return_last_state the scan returns (output, last_state).
                y, last_state = y
                ssm_state.copy_(last_state)
            y = rearrange(y, "b d l -> b l d")
            out = self.out_proj(y)
        return out


def create_block(
    d_model,
    ssm_cfg=None,
    norm_epsilon=1e-5,
    rms_norm=False,
    residual_in_fp32=False,
    fused_add_norm=False,
    layer_idx=None,
    device=None,
    dtype=None,
    reverse=None,
    is_2d=False,
    drop_rate=0.1,
    drop_path_rate=0.1,
    use_mlp=False,
    transpose=False,
    split_head=False,
    use_nd=False,
    downsample=False,
    use_ref=False,
    n_dim=2,
):
    """Assemble one block: choose the mixer and norm classes, then wrap them.

    Mixer selection, in priority order: ``MambaND`` when ``use_nd`` (which also
    forces ``transpose`` and ``reverse`` off), ``Mamba_Ref`` when ``use_ref``,
    ``Mamba2D`` when ``is_2d``, otherwise ``Mamba``. ``ssm_cfg`` plus
    ``device`` / ``dtype`` are bound to the mixer constructor; the norm is
    ``RMSNorm`` when ``rms_norm`` is set, else ``nn.LayerNorm``.

    The wrapper is ``Block2D`` when ``is_2d`` (it does not receive ``use_mlp``,
    ``split_head`` or ``downsample``) and ``Block`` otherwise.

    Returns:
        The constructed block, with ``layer_idx`` stored as an attribute.
    """
    ssm_cfg = {} if ssm_cfg is None else ssm_cfg
    factory_kwargs = dict(device=device, dtype=dtype)

    # Pick the mixer type. Each name is only looked up on the branch that
    # actually selects it.
    nd_kwargs = {}
    if use_nd:
        transpose = reverse = False
        mixer_type = MambaND
        nd_kwargs['n_dim'] = n_dim
    elif use_ref:
        mixer_type = Mamba_Ref
    elif is_2d:
        mixer_type = Mamba2D
    else:
        mixer_type = Mamba
    mixer_cls = partial(mixer_type, layer_idx=layer_idx, **nd_kwargs, **ssm_cfg, **factory_kwargs)

    norm_type = RMSNorm if rms_norm else nn.LayerNorm
    norm_cls = partial(norm_type, eps=norm_epsilon, **factory_kwargs)

    # Keyword arguments understood by both wrapper classes.
    shared_kwargs = dict(
        norm_cls=norm_cls,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
        reverse=reverse,
        transpose=transpose,
        drop_rate=drop_rate,
        drop_path_rate=drop_path_rate,
    )
    if is_2d:
        block = Block2D(d_model, mixer_cls, **shared_kwargs)
    else:
        block = Block(
            d_model,
            mixer_cls,
            use_mlp=use_mlp,
            split_head=split_head,
            downsample=downsample,
            **shared_kwargs,
        )
    block.layer_idx = layer_idx
    return block
    


In [4]:
@MODELS.register_module()
class Mamba2DModel(BaseBackbone):
    """Image backbone made of a stack of blocks created by ``create_block``.

    The constructor interface is adapted from the Vision Transformer backbone
    (`An Image is Worth 16x16 Words: Transformers for Image Recognition at
    Scale <https://arxiv.org/abs/2010.11929>`_): it builds a patch embedding,
    a learnable position embedding and an optional class token, but the
    encoder layers are built with ``create_block`` instead of transformer
    layers. Several ViT-era arguments are kept in the signature even though
    the constructor does not use them (see the notes below).

    Args:
        arch (str | dict): Architecture. If a string, choose from 'small'
            and 'base'. If a dict, it must have the keys below:

            - **embed_dims** (int): The dimensions of embedding.
            - **num_layers** (int): The number of layers.
            - **num_heads** (int): Required key; not read by the constructor.
            - **feedforward_channels** (int): Required key; not read by the
              constructor.

            An optional **downsample** key, when present, takes precedence
            over the ``downsample`` argument.
            Defaults to 'small'.
        img_size (int | tuple): The expected input image shape.
            Defaults to 32.
        patch_size (int | tuple): The patch size in patch embedding.
            Defaults to 8.
        in_channels (int): The num of input channels. Defaults to 3.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, means the last stage.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate; the per-layer rates
            increase linearly from 0 to this value. Defaults to 0.
        qkv_bias (bool): Not used by the constructor. Defaults to True.
        norm_cfg (dict): Config dict for the pre-norm and final norm layers.
            Defaults to ``dict(type='LN', eps=1e-6)``.
        norm_cfg_2 (dict): Config dict for the extra norm layer that is built
            only when ``out_type == "avg_featmap"``.
            Defaults to ``dict(type='LN', eps=1e-6)``.
        final_norm (bool): Whether to add an additional layer to normalize
            the final feature map. Defaults to True.
        out_type (str): The type of output features. Please choose from

            - ``"cls_token"``: The class token tensor with shape (B, C).
            - ``"featmap"``: The feature map tensor from the patch tokens
              with shape (B, C, H, W).
            - ``"avg_featmap"``: The global averaged feature map tensor
              with shape (B, C).
            - ``"raw"``: The raw feature tensor includes patch tokens and
              class tokens with shape (B, L, C).

            Defaults to ``"cls_token"``.
        with_cls_token (bool): Whether to concatenate a class token to the
            image tokens. Must be True when ``out_type="cls_token"``.
            Defaults to True.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            Freezing is only applied when this is greater than 0.
            Defaults to -1.
        interpolate_mode (str): Select the interpolate mode for position
            embedding vector resize. Defaults to "bicubic".
        layer_scale_init_value (float or torch.Tensor): Not used by the
            constructor. Defaults to 0.
        patch_cfg (dict): Extra configs of the patch embedding; they override
            the values derived from the other arguments.
            Defaults to an empty dict.
        layer_cfgs (Sequence | dict): Expanded to one entry per layer but not
            passed on to the blocks. Defaults to an empty dict.
        pre_norm (bool): Whether to build a norm layer that is applied before
            the blocks; also disables the patch-embedding bias.
            Defaults to False.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
        is_2d, use_nd, use_ref, n_dim, split_head, use_mlp, fused_add_norm:
            Forwarded to ``create_block`` for every layer.
        embed_dims (int, optional): Overrides the arch's ``embed_dims``.
        num_layers (int, optional): Overrides the arch's ``num_layers``. The
            number of blocks actually built is twice this value unless
            ``use_mlp`` is set.
        has_transpose (bool): Enable ``transpose`` on blocks with
            ``i % 4 >= 2`` (only when ``split_head`` is false).
        has_reverse (bool): Enable ``reverse`` on odd-indexed blocks (only
            when ``split_head`` is false).
        downsample (Sequence[int], optional): Indices of the blocks that
            downsample. Defaults to None (no downsampling).
        constant_dim (bool): If False, the channel dimension handed to the
            following blocks doubles after each downsampling block.
        d_state, expand, duplicate, use_v2, force_a2: Entries of the
            ``ssm_cfg`` dict handed to the mixer (``use_v2`` only when
            ``is_2d`` is set; ``expand`` / ``duplicate`` / ``force_a2`` only
            when given).
        update_interval (int, optional): Stored on the instance for use in
            the forward pass. Defaults to None.
    """
    # Built-in architectures; each name maps to its settings dict.
    arch_zoo = {
        **dict.fromkeys(
            ['small'], {
                'embed_dims': 384,
                'num_layers': 8,
                'num_heads': 6,
                'feedforward_channels': 384 * 4
            }),
        **dict.fromkeys(
            ['base'], {
                'embed_dims': 768,
                'num_layers': 12,
                'num_heads': 12,
                'feedforward_channels': 768 * 4
            }),
    }
    num_extra_tokens = 1  # class token
    # Accepted values of the `out_type` constructor argument.
    OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'}

    def __init__(self,
                 arch='small',
                 img_size=32,
                 patch_size=8,
                 in_channels=3,
                 out_indices=-1,
                 drop_rate=0.,
                 drop_path_rate=0.,
                 qkv_bias=True,
                 norm_cfg=dict(type='LN', eps=1e-6),
                 norm_cfg_2=dict(type='LN', eps=1e-6),
                 final_norm=True,
                 out_type='cls_token',
                 with_cls_token=True,
                 frozen_stages=-1,
                 interpolate_mode='bicubic',
                 layer_scale_init_value=0.,
                 patch_cfg=dict(),
                 layer_cfgs=dict(),
                 pre_norm=False,
                 init_cfg=None,
                 is_2d=True,
                 use_v2=False,
                 force_a2=False,
                 embed_dims=None,
                 has_transpose=True,
                 fused_add_norm=True,
                 use_mlp=False,
                 split_head=False,
                 use_nd=False,
                 downsample=None,
                 expand=None,
                 constant_dim=False,
                 has_reverse=True,
                 update_interval=None,
                 duplicate=None,
                 n_dim=2,
                 num_layers=None,
                 use_ref=False,
                 d_state=16):
        """Build patch embedding, position embedding, mixer blocks and norms.

        The steps are: resolve the architecture settings, create the patch
        embedding / class token / position embedding, normalise
        ``out_indices``, create ``self.num_layers`` blocks through
        ``create_block`` and finally the pre / final norm layers.

        Raises:
            ValueError: unsupported ``out_type``, or ``out_type="cls_token"``
                without a class token.
            AssertionError: unknown arch name, incomplete arch dict or
                invalid ``out_indices``.
        """
        super(Mamba2DModel, self).__init__(init_cfg)

        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {
                'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels'
            }
            assert isinstance(arch, dict) and essential_keys <= set(arch), \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch
        self.update_interval = update_interval
        self.embed_dims = self.arch_settings['embed_dims']
        if embed_dims is not None:
            self.embed_dims = embed_dims
        if num_layers is None:
            num_layers = self.arch_settings['num_layers']
        # Without an FFN in each block, twice as many blocks are stacked.
        self.num_layers = num_layers * (2 if not use_mlp else 1)
        # An arch-level 'downsample' entry wins over the argument.
        self.downsample = self.arch_settings.get('downsample',downsample)
        if self.downsample is None:
            self.downsample = []
        self.img_size = to_2tuple(img_size)
        self.is_2d = is_2d
        self.use_nd=use_nd

        # Set patch embedding
        _patch_cfg = dict(
            in_channels=in_channels,
            input_size=img_size,
            embed_dims=self.embed_dims,
            conv_type='Conv2d',
            kernel_size=patch_size,
            stride=patch_size,
            bias=not pre_norm,  # disable bias if pre_norm is used(e.g., CLIP)
        )
        _patch_cfg.update(patch_cfg)
        self.patch_embed = PatchEmbed(**_patch_cfg)
        self.patch_resolution = self.patch_embed.init_out_size
        num_patches = self.patch_resolution[0] * self.patch_resolution[1]

        # Set out type
        if out_type not in self.OUT_TYPES:
            raise ValueError(f'Unsupported `out_type` {out_type}, please '
                             f'choose from {self.OUT_TYPES}')
        self.out_type = out_type

        # Set cls token
        self.with_cls_token = with_cls_token
        if with_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
        elif out_type != 'cls_token':
            self.cls_token = None
            self.num_extra_tokens = 0
        else:
            raise ValueError(
                'with_cls_token must be True when `out_type="cls_token"`.')

        # Set position embedding
        self.interpolate_mode = interpolate_mode
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + self.num_extra_tokens,
                        self.embed_dims))
        self._register_load_state_dict_pre_hook(self._prepare_pos_embed)

        self.drop_after_pos = nn.Dropout(p=drop_rate)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must by a sequence or int, ' \
            f'get {type(out_indices)} instead.'
        # Work on a private list: a tuple cannot be assigned to below, and a
        # list owned by the caller must not be modified in place.
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = self.num_layers + index
            assert 0 <= out_indices[i] <= self.num_layers, \
                f'Invalid out_indices {index}'
        self.out_indices = out_indices

        # stochastic depth decay rule
        dpr = np.linspace(0, drop_path_rate, self.num_layers)

        self.layers = ModuleList()
        if isinstance(layer_cfgs, dict):
            layer_cfgs = [layer_cfgs] * self.num_layers
        # Keyword arguments for the mixer; optional entries are only added
        # when requested so that the mixer's own defaults apply otherwise.
        ssm_cfg={"d_state":d_state}
        if expand is not None:
            ssm_cfg['expand'] = expand
        if duplicate:
            ssm_cfg['duplicate'] = duplicate
        if use_v2 and is_2d:
            ssm_cfg['use_v2'] = use_v2
        if force_a2:
            ssm_cfg['force_a2'] = force_a2
        # `dim` tracks the channel width handed to each block; it may grow
        # after downsampling blocks.
        dim = self.embed_dims
        for i in range(self.num_layers):
            do_downsample = i in self.downsample
            self.layers.append(
                create_block(
                d_model=dim,
                ssm_cfg=ssm_cfg,
                fused_add_norm=fused_add_norm,
                residual_in_fp32=True,
                drop_rate=drop_rate,
                drop_path_rate=dpr[i],
                # Odd blocks scan in reverse; blocks 2 and 3 of every group
                # of four scan the transposed grid.
                reverse= (not split_head ) and (i % 2) > 0 and has_reverse,
                transpose = (not split_head ) and has_transpose and ( i % 4) >=2,
                use_mlp=use_mlp,
                is_2d=is_2d,
                rms_norm=False,
                split_head=split_head,
                use_nd=use_nd,
                downsample=do_downsample,
                n_dim=n_dim,
                use_ref=use_ref
                )
            )
            if do_downsample and not constant_dim:
                dim *= 2

        self.frozen_stages = frozen_stages
        if pre_norm:
            # The pre-norm runs before any block, i.e. on embed_dims channels,
            # not on the (possibly doubled) width after downsampling.
            self.pre_norm = build_norm_layer(norm_cfg, self.embed_dims)
        else:
            self.pre_norm = nn.Identity()

        # The final norms act on the output of the last block, hence `dim`.
        self.final_norm = final_norm
        if final_norm:
            self.ln1 = build_norm_layer(norm_cfg, dim)
        if self.out_type == 'avg_featmap':
            self.ln2 = build_norm_layer(norm_cfg_2, dim)

        # freeze stages only when self.frozen_stages > 0
        if self.frozen_stages > 0:
            self._freeze_stages()
        # NOTE(review): count_parameters is not defined in the part of the
        # class visible here; presumably defined further down - confirm.
        self.count_parameters()
    @property
    def norm1(self):
        """Alias for ``self.ln1``, which ``__init__`` creates only when ``final_norm`` is true."""
        return self.ln1

    @property
    def norm2(self):
        """Alias for ``self.ln2``, which ``__init__`` creates only when ``out_type == 'avg_featmap'``."""
        return self.ln2

    def init_weights(self):
        """Initialise the weights, optionally from a pretrained checkpoint.

        Without a ``Pretrained`` ``init_cfg`` the position embedding is drawn
        from a truncated normal (std 0.02). With one, the checkpoint's
        ``state_dict`` entries prefixed with ``backbone.`` are loaded
        non-strictly; ``patch_embed`` and ``pos_embed`` tensors whose shapes
        differ from the model's are resized by interpolation first, and any
        other mismatching key is printed.
        """
        super(Mamba2DModel, self).init_weights()

        if not (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            if self.pos_embed is not None:
                trunc_normal_(self.pos_embed, std=0.02)
            return

        ckpt = self.init_cfg['checkpoint']
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        ckpt = torch.load(ckpt,map_location='cpu')['state_dict']
        ckpt = {k.replace('backbone.','',1):v for k,v in ckpt.items() if  k.startswith('backbone.') }
        curr_dict = self.state_dict()
        for k,v in ckpt.items():
            if k in curr_dict and (v_new:=curr_dict[k]).shape != v.shape:
                if 'patch_embed' in k:
                    # n c H W
                    v_resized = torch.nn.functional.interpolate(v,v_new.shape[-2:])
                    assert v_resized.shape == v_new.shape
                    ckpt[k] = v_resized
                elif 'pos_embed' in k:
                    # The first `num_extra_tokens` positions are not part of
                    # the patch grid (forward() slices them off the same way),
                    # so keep them as-is and resize only the grid part.
                    # Assumes the checkpoint uses the same number of extra
                    # tokens as this model - the shape assert below catches a
                    # mismatch.
                    n_extra = self.num_extra_tokens
                    extra_tokens, grid_tokens = v[:, :n_extra], v[:, n_extra:]
                    b,old_len,dim = grid_tokens.shape
                    old_d = int(np.sqrt(old_len))
                    new_len = v_new.shape[1] - n_extra
                    new_d = int(np.sqrt(new_len))
                    grid = grid_tokens.reshape(b,old_d,old_d,dim).permute(0,3,1,2)
                    grid = torch.nn.functional.interpolate(grid,(new_d,new_d)).flatten(2).permute(0,2,1)
                    v_resized = torch.cat((extra_tokens, grid), dim=1)
                    assert v_resized.shape == v_new.shape
                    ckpt[k] = v_resized
                else:
                    print(k,v_new.shape,v.shape)
        res = self.load_state_dict(ckpt,strict=False)
        print('----------init-------------------')
        print(res)

    def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
        """Pre-hook for ``load_state_dict`` that adapts the stored ``pos_embed``.

        If the checkpoint carries one more position than the model and the
        model has no class token, the first position is dropped. If the shape
        still differs from ``self.pos_embed``, the entry is replaced in
        ``state_dict`` by a resized version. Does nothing when the key is
        absent.
        """
        key = prefix + 'pos_embed'
        if key not in state_dict.keys():
            return

        loaded_shape = state_dict[key].shape
        surplus_cls_slot = (not self.with_cls_token
                            and loaded_shape[1] == self.pos_embed.shape[1] + 1)
        if surplus_cls_slot:
            # Remove cls token from state dict if it's not used.
            state_dict[key] = state_dict[key][:, 1:]
            loaded_shape = state_dict[key].shape

        if self.pos_embed.shape == loaded_shape:
            return

        from mmengine.logging import MMLogger
        MMLogger.get_current_instance().info(
            f'Resize the pos_embed shape from {loaded_shape} '
            f'to {self.pos_embed.shape}.')

        # The checkpoint grid is taken to be square; the target grid is the
        # patch embedding's initial output size.
        src_side = int(np.sqrt(loaded_shape[1] - self.num_extra_tokens))
        src_grid = to_2tuple(src_side)
        dst_grid = self.patch_embed.init_out_size

        state_dict[key] = resize_pos_embed(state_dict[key],
                                           src_grid,
                                           dst_grid,
                                           self.interpolate_mode,
                                           self.num_extra_tokens)

    @staticmethod
    def resize_pos_embed(*args, **kwargs):
        """Interface for backward-compatibility.

        Forwards all arguments unchanged to the module-level
        ``resize_pos_embed`` helper and returns its result.
        """
        return resize_pos_embed(*args, **kwargs)

    def _freeze_stages(self):
        """Stop gradients for the stem and the first ``frozen_stages`` blocks.

        Freezes the position embedding, patch embedding, pre-norm and class
        token, then the first ``self.frozen_stages`` layers (also switched to
        eval mode). When every layer is frozen, the final norm layers that
        exist for the current configuration are frozen as well.
        """
        def _stop_grad(parameters):
            # Mark every parameter of the iterable as non-trainable.
            for p in parameters:
                p.requires_grad = False

        # Stem: position embedding, dropout, patch embedding, pre-norm, cls token.
        if self.pos_embed is not None:
            self.pos_embed.requires_grad = False
        self.drop_after_pos.eval()
        self.patch_embed.eval()
        _stop_grad(self.patch_embed.parameters())
        _stop_grad(self.pre_norm.parameters())
        if self.cls_token is not None:
            self.cls_token.requires_grad = False

        # The first `frozen_stages` blocks.
        for idx in range(self.frozen_stages):
            stage = self.layers[idx]
            stage.eval()
            _stop_grad(stage.parameters())

        # The output norms are frozen only when all blocks are frozen.
        if self.frozen_stages != len(self.layers):
            return
        if self.final_norm:
            self.ln1.eval()
            _stop_grad(self.ln1.parameters())
        if self.out_type == 'avg_featmap':
            self.ln2.eval()
            _stop_grad(self.ln2.parameters())

    def forward(self, x):
        """Embed the input, run every layer and collect the requested outputs.

        Args:
            x: Input batch passed to ``self.patch_embed`` (presumably images of
                shape ``(B, C, H, W)`` — confirm against the patch embedding).

        Returns:
            tuple: One entry per layer index listed in ``self.out_indices``,
            each produced by ``self._format_output``.
        """
        B = x.shape[0]
        x, patch_resolution = self.patch_embed(x)
        # Add the resized positional embedding, skipping the leading entries
        # reserved for extra tokens.
        x = x + resize_pos_embed(
            self.pos_embed,
            self.patch_resolution,
            patch_resolution,
            mode=self.interpolate_mode,
            num_extra_tokens=self.num_extra_tokens)[:, self.num_extra_tokens:]
        if self.is_2d:
            assert self.cls_token is None
            x = rearrange(x, 'n (h w) c-> n c h w', h=patch_resolution[0], w=patch_resolution[1])
        if self.cls_token is not None:
            # stole cls_tokens impl from Phil Wang, thanks
            cls_token = self.cls_token.expand(B, -1, -1)
            x = torch.cat((x, cls_token), dim=1)  # append last
        x = self.drop_after_pos(x)

        x = self.pre_norm(x)
        # (h, w) tracks the current token grid.  A layer whose ``downsample``
        # flag is set returns the new grid in place of the residual.
        h = patch_resolution[0]
        w = patch_resolution[1]
        outs = []
        residual = None
        last_idx = len(self.layers) - 1
        if self.update_interval:
            # Layers read ``raw_x``, which is refreshed from the accumulated
            # ``x`` only every ``update_interval`` layers (and at the end).
            raw_x = x
            for i, layer in enumerate(self.layers):
                x_l, residual = layer(raw_x, residual, skip=False, h=h, w=w)
                if layer.downsample:
                    h, w = residual
                    residual = None
                x = x + x_l
                if (i + 1) % self.update_interval == 0 or i == last_idx:
                    raw_x = x
                if i == last_idx:
                    # Fold in the pending residual before the final norm.
                    x = (x + residual) if residual is not None else x
                    if self.final_norm:
                        x = self.ln1(x)

                if i in self.out_indices:
                    # Format with the *current* grid so outputs taken after a
                    # downsampling layer are reshaped correctly.
                    outs.append(self._format_output(x, (h, w)))
        else:
            for i, layer in enumerate(self.layers):
                x, residual = layer(x, residual, h=h, w=w)
                if layer.downsample:
                    h, w = residual
                    residual = None

                if i == last_idx:
                    # Fold in the pending residual before the final norm.
                    x = (x + residual) if residual is not None else x
                    if self.final_norm:
                        x = self.ln1(x)

                if i in self.out_indices:
                    # Format with the *current* grid (see above).
                    outs.append(self._format_output(x, (h, w)))

        return tuple(outs)

    def count_parameters(self, model=None):
        """Print a table of trainable parameters and return their total count.

        Args:
            model: Module to inspect; defaults to ``self`` when ``None``.

        Returns:
            int: Number of elements over all parameters that require
            gradients.  The same value is stored on ``self.total_parms``.
        """
        target = self if model is None else model
        trainable = [
            (name, weight.numel())
            for name, weight in target.named_parameters()
            if weight.requires_grad
        ]

        report = PrettyTable(["Modules", "Parameters"])
        for name, size in trainable:
            report.add_row([name, size])

        total = sum(size for _, size in trainable)
        self.total_parms = total
        print(report)
        print(f"Total Trainable Params: {total}")
        return total
    
    def _format_output(self, x, hw):
        if self.out_type == 'raw':
            return x
        if self.out_type == 'cls_token':
            return x[:, -1]
        if not self.is_2d:
            patch_token = x[:, self.num_extra_tokens:]
        else:
            patch_token = x
        if self.out_type == 'featmap':
            B = x.size(0)
            # (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
            if self.is_2d:
                return patch_token
            else:
                return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
        if self.out_type == 'avg_featmap':
            if self.is_2d:
                return self.ln2(patch_token.flatten(2).mean(dim=-1))
            else:
                return self.ln2(patch_token.mean(dim=1))

    def get_layer_depth(self, param_name: str, prefix: str = ''):
        """Map a parameter name to its layer-wise depth.

        Args:
            param_name (str): The name of the parameter.
            prefix (str): The prefix under which this module's parameters
                live.  Defaults to an empty string.

        Returns:
            Tuple[int, int]: ``(layer_depth, num_layers)`` where
            ``num_layers`` is ``self.num_layers + 2``.

        Note:
            Depth ``0`` is the stem (class token, position embedding, patch
            embedding), ``layers.<k>`` maps to depth ``k + 1``, and everything
            else — including names outside ``prefix`` such as a head — gets the
            last depth ``num_layers - 1``.
        """
        total_depths = self.num_layers + 2
        deepest = total_depths - 1

        # Names outside this module's prefix belong to a subsequent module.
        if not param_name.startswith(prefix):
            return deepest, total_depths

        local_name = param_name[len(prefix):]

        if local_name in ('cls_token', 'pos_embed') or local_name.startswith('patch_embed'):
            depth = 0
        elif local_name.startswith('layers'):
            depth = int(local_name.split('.')[1]) + 1
        else:
            depth = deepest

        return depth, total_depths


In [5]:
##  """""""""""""""""""""""      ##
##              CONFIG           ##   
##    """"""""""""""""""""""""   ##

## TEST FOR TinyIMAGENET TRAINING

In [9]:
import torch
from mmengine.hooks import Hook
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import pandas as pd
from typing import Any, Dict, List, Optional, Sequence, Tuple

from mmengine.evaluator import BaseMetric
from mmengine.registry import METRICS

def train_step(self, data_batch, optim_wrapper):
    """Run one optimisation step and return the loss with detached predictions.

    Args:
        self: Callable model, invoked as ``self(**data_batch)``.
        data_batch: Mapping unpacked into the model call; its ``'labels'``
            entry is used as the cross-entropy target.
        optim_wrapper: Object exposing ``zero_grad()`` and ``step()``.

    Returns:
        dict: ``{'loss': <loss tensor>, 'preds': <detached model outputs>}``.
    """
    predictions = self(**data_batch)

    # Change this to your loss function.
    criterion = torch.nn.CrossEntropyLoss()
    loss = criterion(predictions, data_batch['labels'])

    # Clear old gradients, back-propagate, update.
    optim_wrapper.zero_grad()
    loss.backward()
    optim_wrapper.step()

    # Predictions are detached so they do not keep the graph alive.
    return dict(loss=loss, preds=predictions.detach())


def val_step(self, data_batch):
    """Evaluate one batch and return its loss together with the raw outputs.

    Args:
        self: Callable model, invoked as ``self(**data_batch)``.
        data_batch: Mapping unpacked into the model call; its ``'labels'``
            entry is used as the cross-entropy target.

    Returns:
        dict: ``{'loss': <loss tensor>, 'preds': <model outputs>}``.
    """
    predictions = self(**data_batch)

    # Change this to your loss function.
    criterion = torch.nn.CrossEntropyLoss()
    batch_loss = criterion(predictions, data_batch['labels'])

    # Extend this dict if further metrics are needed.
    return dict(loss=batch_loss, preds=predictions)



# Custom metric that reports the validation loss
@METRICS.register_module()
class LossValMetric(BaseMetric):
    """Evaluation metric that reports the mean cross-entropy loss.

    ``process`` stores one ``{'loss', 'num_samples'}`` record per batch and
    ``compute_metrics`` returns their sample-weighted mean under the key
    ``'loss'``.

    NOTE(review): ``CrossEntropyLoss`` applies log-softmax itself.  If
    ``pred_score`` already holds softmax probabilities (presumably the case for
    the mmpretrain classification heads — confirm), the reported value is not
    the cross-entropy of the logits.
    """
    default_prefix: Optional[str] = 'loss'

    def __init__(self, collect_device: str = 'cpu', prefix: Optional[str] = None) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)
        self.loss_fn = torch.nn.CrossEntropyLoss()  # mean-reduced loss function

    def process(self, data_batch: Sequence[Tuple[Any, Dict]], data_samples: Sequence[Dict]) -> None:
        """Compute the loss of one batch and store it in ``self.results``.

        Args:
            data_batch: The batch fed to the model (unused here).
            data_samples: Per-sample outputs, each indexable by
                ``'pred_score'`` and ``'gt_label'``.
        """
        if not data_samples:
            return

        # Work on whatever device the predictions already live on instead of
        # forcing a CUDA device.
        outputs = torch.stack([sample['pred_score'] for sample in data_samples])
        labels = torch.cat(
            [sample['gt_label'].to(outputs.device) for sample in data_samples])

        loss = self.loss_fn(outputs, labels).item()
        # Keep the batch size so a short final batch is not over-weighted.
        self.results.append({'loss': loss, 'num_samples': len(data_samples)})

    def compute_metrics(self, results: List[Dict]) -> Dict:
        """Return the sample-weighted mean of the stored batch losses.

        Records without a ``'num_samples'`` entry count as a single sample.
        With no records at all the loss is reported as NaN.
        """
        total_samples = sum(record.get('num_samples', 1) for record in results)
        if total_samples == 0:
            return {'loss': float('nan')}
        weighted_sum = sum(
            record['loss'] * record.get('num_samples', 1) for record in results)
        return {'loss': weighted_sum / total_samples}



class MetricLoggerHook(Hook):
    """Hook that records training/validation metrics epoch by epoch.

    During training it collects the loss and a top-1 accuracy for every batch
    and stores their means at the end of each epoch.  After every validation
    epoch it stores the validation metrics, writes ``metrics.csv``, and draws
    the metric curves and a confusion matrix.
    """

    def __init__(self):
        self.train_loss = []  # mean training loss, one entry per epoch
        self.val_loss = []  # validation loss, one entry per validation epoch
        self.train_top1_acc = []  # mean batch top-1 accuracy, one entry per epoch
        self.val_top1_acc = []  # validation top-1, one entry per validation epoch
        self.batch_losses = []  # To store batch losses for averaging
        self.batch_top1_acc = []  # To store batch accuracies for averaging
        self.all_labels = []  # To store true labels for confusion matrix
        self.all_preds = []  # To store predicted labels for confusion matrix

    @staticmethod
    def _mean(values):
        """Return the arithmetic mean of ``values`` or ``None`` when empty."""
        return sum(values) / len(values) if values else None

    @staticmethod
    def _padded(series, length):
        """Return ``series`` as a list extended with ``None`` up to ``length``."""
        return list(series) + [None] * (length - len(series))

    @staticmethod
    def _plottable(series):
        """Replace ``None`` entries with NaN so the series can be plotted."""
        return [float('nan') if value is None else value for value in series]

    def after_train_epoch(self, runner):
        """Store the epoch means of the batch losses and batch accuracies."""
        self.train_loss.append(self._mean(self.batch_losses))
        self.batch_losses = []  # Reset batch losses for the next epoch

        # Averaged like the loss; the maximum over batches would overstate the
        # training accuracy.
        self.train_top1_acc.append(self._mean(self.batch_top1_acc))
        self.batch_top1_acc = []  # Reset batch accuracies for the next epoch

    def after_val_epoch(self, runner, metrics=None):
        """Store the validation metrics, then write the CSV and the plots."""
        metrics = metrics or {}
        self.val_top1_acc.append(metrics.get('accuracy/top1'))
        self.val_loss.append(metrics.get('loss/loss'))

        self.save_to_csv('metrics.csv')

        self.plot_metrics()  # Plot metrics after each validation epoch
        self.plot_confusion_matrix()  # Plot confusion matrix after each validation epoch
        self.all_labels = []  # Reset labels for the next epoch
        self.all_preds = []  # Reset predictions for the next epoch

    def after_train_iter(self, runner, batch_idx, data_batch=None, outputs=None):
        """Record the batch loss and a batch top-1 accuracy."""
        if outputs is not None and 'loss' in outputs:
            # Read the value once; every ``.item()`` call copies to the host.
            loss_value = outputs['loss'].item()
            self.batch_losses.append(loss_value)
            runner.message_hub.update_scalar('loss', loss_value)

        if data_batch is None:
            return

        # NOTE(review): this is a second forward pass per iteration, run in
        # whatever mode the training loop left the model in (presumably train
        # mode, so stochastic layers stay active) — confirm this is acceptable.
        with torch.no_grad():
            preds = runner.model(**data_batch)
            preds = preds.argmax(dim=1)

        # Assuming labels are in data_batch['data_samples']
        labels = [sample.gt_label.item() for sample in data_batch['data_samples']]
        if not labels:
            return
        labels = torch.tensor(labels).to(preds.device)  # same device as preds
        correct = preds.eq(labels).sum().item()
        accuracy = correct / len(labels)
        self.batch_top1_acc.append(accuracy)
        runner.message_hub.update_scalar('top1_acc', accuracy)

    def after_val_iter(self, runner, batch_idx, data_batch=None, outputs=None):
        """Collect true and predicted labels for the confusion matrix."""
        if data_batch is None or outputs is None:
            return
        true_labels = [sample.gt_label.item() for sample in data_batch['data_samples']]
        pred_labels = [output.pred_label.item() for output in outputs]

        self.all_labels.extend(true_labels)
        self.all_preds.extend(pred_labels)

    def plot_metrics(self):
        """Plot the training loss and the validation top-1 accuracy curves."""
        fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

        # Training loss curve.  Each series uses its own x-range so differing
        # lengths cannot make the plot call fail.
        if self.train_loss:
            loss_ax.plot(range(1, len(self.train_loss) + 1),
                         self._plottable(self.train_loss),
                         label='Training loss')
            loss_ax.legend()
        loss_ax.set_xlabel('Epochs')
        loss_ax.set_ylabel('Loss')
        loss_ax.set_title('Training Loss over Epochs')

        # Validation accuracy curve.
        if self.val_top1_acc:
            acc_ax.plot(range(1, len(self.val_top1_acc) + 1),
                        self._plottable(self.val_top1_acc),
                        label='Validation Top-1 Accuracy')
            acc_ax.legend()
        acc_ax.set_xlabel('Epochs')
        acc_ax.set_ylabel('Accuracy')
        acc_ax.set_title('Validation Accuracy over Epochs')

        plt.show()
        plt.close(fig)  # do not accumulate one open figure per epoch

    def plot_confusion_matrix(self):
        """Plot the row-normalised confusion matrix of the collected labels."""
        if not self.all_labels or not self.all_preds:
            return  # nothing collected; an empty matrix cannot be computed

        cm = confusion_matrix(self.all_labels, self.all_preds, normalize='true')
        disp = ConfusionMatrixDisplay(confusion_matrix=cm)

        fig, ax = plt.subplots(figsize=(25, 25))  # Adjust the figure size as needed
        disp.plot(ax=ax, cmap=plt.cm.Blues, xticks_rotation=90, values_format='.1f')

        ax.set_title('Confusion Matrix')
        ax.set_xlabel('Predicted Labels')
        ax.set_ylabel('True Labels')

        # Adjust the tick label font size.
        ax.tick_params(axis='both', labelsize=10)

        plt.show()
        plt.close(fig)  # do not accumulate one open figure per epoch

    def save_to_csv(self, filename):
        """Write all per-epoch series to ``filename`` as CSV.

        Series shorter than the longest one are padded at the end with empty
        values so the table can always be built.
        """
        columns = {
            'Training Loss': self.train_loss,
            'Validation Loss': self.val_loss,
            'Training Top-1 Accuracy': self.train_top1_acc,
            'Validation Top-1 Accuracy': self.val_top1_acc
        }
        num_epochs = max(len(series) for series in columns.values())

        data = {'Epoch': list(range(1, num_epochs + 1))}
        for name, series in columns.items():
            data[name] = self._padded(series, num_epochs)

        df = pd.DataFrame(data)
        df.to_csv(filename, index=False)

In [1]:
import os
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
import torch
from mmengine.config import Config, ConfigDict
from mmengine.registry import RUNNERS
from mmengine.runner import Runner
import logging
from PIL import Image



# Configure the root logger at INFO level and create this notebook's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class DataSample:
    """Minimal per-sample container holding a label and, later, predictions.

    Attributes set by the methods (``gt_label``, ``pred_score``,
    ``pred_label``) can also be queried with ``in`` and read with ``[]``.
    """

    def __init__(self, gt_label, device=None):
        """Store ``gt_label`` as a one-element tensor.

        Args:
            gt_label: Class index of the sample.
            device: Device for the label tensor.  When ``None`` (default) CUDA
                is used if available, otherwise the CPU.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.gt_label = torch.tensor([gt_label]).to(device)

    def set_pred_score(self, score):
        """Attach the prediction scores and return ``self`` for chaining."""
        self.pred_score = score
        return self

    def set_pred_label(self, label):
        """Attach the predicted label and return ``self`` for chaining."""
        self.pred_label = label
        return self

    @property
    def gt_score(self):
        """Ground-truth score; always ``None`` here (adapt to your needs)."""
        return None

    def __contains__(self, key):
        # Only attributes actually stored on the instance count as present.
        return key in self.__dict__

    def __getitem__(self, key):
        # Dict-style read access to the stored attributes.
        return self.__dict__[key]



class CustomDataset(Dataset):
    """Adapter exposing an indexable dataset of pairs as dict samples.

    Each item of the wrapped dataset is read by position: element ``0`` becomes
    ``'inputs'`` and element ``1`` becomes ``'labels'``.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        return dict(inputs=sample[0], labels=sample[1])
      

# Collate function used by the DataLoaders below
def custom_collate_fn(batch):
    """Collate dict samples into a batch for the runner.

    Args:
        batch: List of ``{'inputs': tensor, 'labels': class index}`` samples.

    Returns:
        dict: ``'inputs'`` is the stacked input tensor (on CUDA when available,
        otherwise on the CPU) and ``'data_samples'`` is a list with one
        ``DataSample`` per label.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    inputs = torch.stack([item['inputs'] for item in batch]).to(device)
    # Convert the labels on the CPU in one go; sending them to the GPU only to
    # read each one back with ``.item()`` costs a device sync per sample.
    labels = torch.tensor([item['labels'] for item in batch]).tolist()
    data_samples = [DataSample(label) for label in labels]
    return {'inputs': inputs, 'data_samples': data_samples}

# Read the validation annotations (image file name -> class id)
def read_val_annotations(val_dir):
    """Read ``val_annotations.txt`` from ``val_dir``.

    Every non-blank line is tab-separated; the first column is used as the key
    and the second column as the value.

    Args:
        val_dir: Directory containing ``val_annotations.txt``.

    Returns:
        dict: Mapping from first column to second column.

    Raises:
        IndexError: If a non-blank line has fewer than two columns.
    """
    val_annotations_path = os.path.join(val_dir, 'val_annotations.txt')

    val_labels = {}
    with open(val_annotations_path, 'r') as f:
        for line in f:
            # Drop the line terminator so a two-column line does not keep it
            # inside the label; skip blank lines instead of failing on them.
            line = line.rstrip('\r\n')
            if not line.strip():
                continue
            parts = line.split('\t')
            val_labels[parts[0]] = parts[1]

    return val_labels

# Dataset class for the Tiny ImageNet validation split
class TinyImageNetValidationDataset(Dataset):
    """Validation dataset built from ``val_annotations.txt`` and ``images/``.

    Class indices are assigned by sorting the distinct class ids found in the
    annotation file.

    NOTE(review): these indices match the training set only if both splits
    contain the same classes and the training set also numbers them in sorted
    order — confirm against the training dataset's ``class_to_idx``.
    """

    def __init__(self, val_dir, transform=None):
        """
        Args:
            val_dir: Directory holding ``val_annotations.txt`` and ``images/``.
            transform: Optional callable applied to each loaded RGB image.
        """
        self.val_dir = val_dir
        self.transform = transform
        self.val_labels = read_val_annotations(val_dir)
        self.img_dir = os.path.join(val_dir, 'images')
        self.img_names = list(self.val_labels.keys())

        self.class_to_idx = {cls: idx for idx, cls in enumerate(sorted(set(self.val_labels.values())))}

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        """Return ``(image, class_index)`` for the ``idx``-th annotated file."""
        img_name = self.img_names[idx]
        label = self.class_to_idx[self.val_labels[img_name]]
        img_path = os.path.join(self.img_dir, img_name)

        # ``convert`` loads the pixel data, so the file can be closed right
        # away instead of staying open until garbage collection.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)

        return img, label



# Define the configuration dictionary directly in the notebook
# Experiment configuration: runtime environment, model, optimiser, learning
# rate schedule, default hooks and data preprocessing.
cfg = ConfigDict(
  default_scope='mmpretrain',
  # Runtime environment: multiprocessing start method and distributed backend.
  env_cfg=dict(
   cudnn_benchmark=False,
   mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
   dist_cfg=dict(backend='nccl'),
  ),
  vis_backends=[dict(type='LocalVisBackend')],
  visualizer=dict(type='UniversalVisualizer', vis_backends=[
               dict(type='LocalVisBackend')]),
  log_level='INFO',
  # No checkpoint is loaded; the trailing comment keeps an example path.
  load_from=None, #'/home/simon/Documents/MambaND/MambaNDoff/image_classification/work_dirs/cifar10_experiment/best_accuracy_top1_epoch_30.pth',
  resume=False,
  # Fixed seed for reproducibility.
  randomness=dict(seed=77, diff_rank_seed=True),
  # Epoch-based training for 60 epochs.
  train_cfg=dict(by_epoch=True, max_epochs=60),
  val_cfg=dict(),
  test_cfg=dict(),
  # Reference batch size for automatic LR scaling (merge_args switches the
  # feature on only when requested).
  auto_scale_lr=dict(base_batch_size=32),
  model=dict(
   type='ImageClassifier',
   # Backbone: 64x64 inputs split into 8x8 patches, no class token, averaged
   # feature map as output.
   backbone=dict(
       type='Mamba2DModel',
       arch='small',
       img_size=64,
       patch_size=8,
       out_type='avg_featmap',
       drop_path_rate=0.1,
       drop_rate=0.1,
       with_cls_token=False,
       final_norm=True,
       fused_add_norm=False,
       d_state=16,
       is_2d=False,
       use_v2=False,
       constant_dim=True,
       downsample=(9,),
       force_a2=False,
       use_mlp=False,
   ),
   neck=None,
   # Linear classification head over 200 classes with label smoothing.
   # NOTE(review): in_channels=384 presumably equals the backbone output
   # width for arch='small' — confirm.
   head=dict(
       type='LinearClsHead',
       num_classes=200,
       in_channels=384,
       loss=dict(
           type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
       init_cfg=[dict(type='TruncNormal', layer='Linear', std=2e-5)]),
   # Batch augmentations used during training.
   train_cfg=dict(augments=[
       dict(type='Mixup', alpha=0.8),
       dict(type='CutMix', alpha=1.0)
   ])),
  # AdamW with gradient clipping, layer-wise LR decay and per-parameter weight
  # decay multipliers selected by name fragment.
  optim_wrapper=dict(
   optimizer=dict(
       type='AdamW', lr=1e-3, weight_decay=0.1, betas=(0.9, 0.999)),
   constructor='LearningRateDecayOptimWrapperConstructor',
   clip_grad=dict(max_norm=2.0),
   paramwise_cfg=dict(
       norm_decay_mult=0.1,
       layer_decay_rate=0.95,
       custom_keys={
           '.ln': dict(decay_mult=0.0),
           '.bias': dict(decay_mult=0.0),
           '.cls_token': dict(decay_mult=0.0),
           '.pos_embed': dict(decay_mult=0.0),
           '.A_log': dict(decay_mult=0.1),
           '.A2_log': dict(decay_mult=0.1),
           '.absolute_pos_embed': dict(decay_mult=0.0),
       })),
  # LR schedule: linear warm-up over epochs 0-10, then cosine annealing down
  # to eta_min; both are converted to per-iteration updates.
  param_scheduler=[
   dict(
     type='LinearLR',
     start_factor=1e-4,
     by_epoch=True,
     begin=0,
     end=10,
     convert_to_iter_based=True),
   dict(
       type='CosineAnnealingLR',
       by_epoch=True,
       begin=10,
       eta_min=1e-5,
       convert_to_iter_based=True)
  ],
  # Default hooks: logging every 100 iterations, a checkpoint every epoch
  # (at most 3 kept, plus the best one).
  default_hooks=dict(
   timer=dict(type='IterTimerHook'),
   logger=dict(type='LoggerHook', interval=100),
   param_scheduler=dict(type='ParamSchedulerHook'),
   checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3, save_best='auto'),
   sampler_seed=dict(type='DistSamplerSeedHook'),
   visualization=dict(type='VisualizationHook', enable=True),
  ),
  # NOTE(review): the torchvision transforms defined further down already
  # normalise with this mean/std; if the preprocessor built from this dict
  # normalises (and channel-flips via to_rgb) again, inputs are processed
  # twice — confirm this is intended.
  data_preprocessor=dict(
   num_classes=200,
   mean=[0.5, 0.5, 0.5],
   std=[0.5, 0.5, 0.5],
   to_rgb=True,
  )
)



# Paths to the Tiny ImageNet data
data_root = './'

# ``ToTensor`` produces values in [0, 1] and the mean/std in
# ``cfg.data_preprocessor`` are expressed on that same scale, so both pipelines
# pass them to ``Normalize`` unchanged.
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomVerticalFlip(0.3),
    transforms.ToTensor(),
    transforms.Normalize(cfg.data_preprocessor.mean, cfg.data_preprocessor.std),
    transforms.RandomErasing(p=0.2, scale=(0.03, 0.03))
])

# Evaluation uses exactly the same normalisation as training.  Dividing the
# statistics by 255 here would put validation images on a completely different
# scale from the one the model is trained on.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cfg.data_preprocessor.mean, cfg.data_preprocessor.std)
])

train_dataset = datasets.ImageFolder(root=os.path.join(data_root, 'train'), transform=transform_train)
val_dataset = TinyImageNetValidationDataset(val_dir=os.path.join(data_root, 'val'), transform=transform_test)

# Both splits must number the classes identically, otherwise validation labels
# would silently refer to different classes than the training labels.
assert val_dataset.class_to_idx == train_dataset.class_to_idx, (
    'train and validation splits disagree on the class-to-index mapping')

# DataLoader for training
train_dataloader = DataLoader(
    CustomDataset(train_dataset),
    batch_size=32,
    shuffle=True,
    num_workers=0,
    collate_fn=custom_collate_fn
)

# DataLoader for validation
val_dataloader = DataLoader(
    CustomDataset(val_dataset),
    batch_size=32,
    shuffle=False,
    num_workers=0,
    collate_fn=custom_collate_fn
)

# Evaluator: top-1/top-5 accuracy plus the custom validation loss metric
cfg.val_evaluator = [dict(type='Accuracy', topk=(1, 5)), dict(type='LossValMetric')]
cfg.test_evaluator = cfg.val_evaluator

# Show the resulting configuration
print("Train dataloader:", train_dataloader)
print("Validation dataloader:", val_dataloader)
print("Validation evaluator:", cfg.val_evaluator)

# ####################   DEBOGAGE VISUALISATION DONNEE   ###############################

# def visualize_samples(dataset):
#     for i in range(0,5):
#         img, label = dataset[i]
#         plt.imshow(img.permute(1, 2, 0))  # Change the format for displaying
#         plt.title(f'Label: {label}')
#         plt.show()


# # Visualiser quelques échantillons du dataset
# print("Visualisation des échantillons de l'ensemble d'entraînement :")
# visualize_samples(train_dataset)

# print("Visualisation des échantillons de l'ensemble de validation :")
# visualize_samples(val_dataset)
# ####################################################################################

# Stand-in for the command-line arguments of a training script
class Args:
    """Attribute holder mimicking parsed CLI options, all at their defaults."""

    # Options that are left unset.
    work_dir = resume = cfg_options = None

    # Boolean switches, all turned off.
    amp = no_validate = auto_scale_lr = False
    no_pin_memory = no_persistent_workers = False

    # Launch settings: no distributed launcher, first local rank.
    launcher, local_rank = 'none', 0


args = Args()

def merge_args(cfg, args):
    """Merge CLI arguments to config.

    Args:
        cfg: Configuration object, modified in place.
        args: Object carrying the options ``no_validate``, ``launcher``,
            ``work_dir``, ``amp``, ``resume``, ``auto_scale_lr`` and
            ``cfg_options``.

    Returns:
        The same ``cfg`` object after the updates.
    """
    # Validation can be switched off entirely.
    if args.no_validate:
        for key in ('val_cfg', 'val_dataloader', 'val_evaluator'):
            setattr(cfg, key, None)

    cfg.launcher = args.launcher

    # Work directory: CLI value first, then an existing config value, then the
    # built-in default.
    chosen_dir = args.work_dir
    if chosen_dir is None and cfg.get('work_dir', None) is None:
        chosen_dir = os.path.join('./work_dirs', 'tinyimagenet_experiment')
    if chosen_dir is not None:
        cfg.work_dir = chosen_dir

    # Automatic mixed precision.
    if args.amp:
        wrapper_cfg = cfg.optim_wrapper
        wrapper_cfg.type = 'AmpOptimWrapper'
        wrapper_cfg.setdefault('loss_scale', 'dynamic')

    # Resuming: 'auto' resumes without an explicit checkpoint, any other value
    # is used as the checkpoint to load.
    if args.resume is not None:
        cfg.resume = True
        cfg.load_from = None if args.resume == 'auto' else args.resume

    if args.auto_scale_lr:
        cfg.auto_scale_lr.enable = True

    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    return cfg

# Merge the simulated CLI arguments into the configuration
cfg = merge_args(cfg, args)

# Create the custom metric-logging hook
metric_logger_hook = MetricLoggerHook()

# Build the runner from the configuration and the dataloaders defined above
runner = Runner(
    model=cfg.model,
    work_dir=cfg.work_dir,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    test_dataloader=val_dataloader,  # Assuming test and val dataloaders are the same
    train_cfg=cfg.train_cfg,
    val_cfg=cfg.val_cfg,
    test_cfg=cfg.test_cfg,
    auto_scale_lr=cfg.auto_scale_lr,
    optim_wrapper=cfg.optim_wrapper,
    param_scheduler=cfg.param_scheduler,
    val_evaluator=cfg.val_evaluator,
    test_evaluator=cfg.test_evaluator,
    default_hooks=cfg.default_hooks,
    custom_hooks=[metric_logger_hook],
    data_preprocessor=cfg.data_preprocessor,
    load_from=cfg.load_from,
    resume=cfg.resume,
    launcher=cfg.launcher,
    env_cfg=cfg.env_cfg,
    log_level=cfg.log_level,
    visualizer=cfg.visualizer,
    default_scope=cfg.default_scope,
    randomness=cfg.randomness,
    cfg=cfg
)

# Move the model to the GPU
runner.model.to('cuda')

# Start training
logger.info("Starting training...")
# Sample batch display for debugging (draws one extra batch from the loader)
sample_batch = next(iter(train_dataloader))
print("Sample batch inputs shape:", sample_batch['inputs'].shape)
print("Sample batch labels shape:", sample_batch['data_samples'][0].gt_label.shape)

runner.train()

# After training, plot the metric curves
metric_logger_hook.plot_metrics()



KeyboardInterrupt

