<a href="https://colab.research.google.com/github/sinagho/W-Mamba/blob/main/model_modified.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Libraries

In [1]:
# !pip install torch torchvision torchaudio --upgrade
# !pip install causal-conv1d && pip install mamba-ssm

In [2]:
# !pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu121
# !pip uninstall mamba-ssm causal-conv1d
# !pip install causal-conv1d && pip install mamba-ssm

In [3]:
# !pip install triton

In [4]:
# !pip install packaging
# !pip install timm==0.4.12
# !pip install pytest
# !pip install chardet
# !pip install yacs
# !pip install termcolor

In [5]:
import time
import math
from functools import partial
from typing import Optional, Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from einops import rearrange, repeat
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

# Base Code

In [6]:
import io
from contextlib import redirect_stderr
from fvcore.nn import FlopCountAnalysis


def count_parameters(model):
    """Return the number of elements across the model's trainable parameters.

    Only parameters whose ``requires_grad`` flag is set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
  
def count_flops(model, x):
    """Return the total FLOP count that fvcore's ``FlopCountAnalysis`` reports.

    The analysis is built for ``model`` with ``(x,)`` as its inputs. Anything
    written to stderr while it runs is redirected into an in-memory buffer
    and thrown away.
    """
    discarded_stderr = io.StringIO()
    with redirect_stderr(discarded_stderr):
        analysis = FlopCountAnalysis(model, (x,))
        total_flops = analysis.total()
    return total_flops

def count_parameters_and_flops(model, x):
    """Return ``(trainable parameter count, FLOP count)`` for ``model`` on input ``x``."""
    # Tuple elements are evaluated left to right: parameters first, then FLOPs.
    return count_parameters(model), count_flops(model, x)
  
def print_parameters_and_flops(model, x, inout=False):
    """Print the trainable-parameter count (in millions) and FLOPs (in billions).

    Args:
        model: Module to analyse; it is called as ``model(x)``.
        x: Example input handed to the FLOP counter and, when ``inout`` is
            set, to one extra forward pass.
        inout: When true, also run ``model(x)`` and print the input and
            output shapes above the summary line.
    """
    params, flops_amount = count_parameters_and_flops(model, x)
    if inout:
        # This forward pass is only needed for its output shape, so run it
        # without autograd instead of recording a graph for every activation.
        with torch.no_grad():
            output = model(x)
        print(f"Input: {x.shape},\tOutput: {output.shape}\n{80*'-'}")
    print(f"Parameters: {params/1e6:.6f} M,\tFLOPs: {flops_amount/1e9:.6f} G")

In [7]:
def _drop_path_repr(self):
    """Compact printed form for DropPath: only the drop probability is shown."""
    return f"timm.DropPath({self.drop_prob})"


DropPath.__repr__ = _drop_path_repr

## Check the details

In [8]:
# my_model = VSSBlock(hidden_dim = 64).cuda()
# print(my_model)
# x = torch.randn(1,64,128,128).cuda()
# print_parameters_and_flops(my_model, x, inout=True)

# Make the Models (V1 & V2)

In [9]:
from blocks.ssm import SS2D

class VSSBlock(nn.Module):
    """Residual block: normalisation -> SS2D -> drop-path, added back onto the input.

    Takes and returns channels-first tensors; the normalisation and the SS2D
    module are applied to a channels-last copy of the input.

    Args:
        hidden_dim: Size passed to ``norm_layer`` and to SS2D as ``d_model``.
        drop_path: Probability handed to ``DropPath``.
        norm_layer: Factory called with ``hidden_dim`` to build the norm.
        attn_drop_rate: Forwarded to SS2D as ``dropout``.
        d_state: Forwarded to SS2D as ``d_state``.
        **kwargs: Extra keyword arguments forwarded to SS2D unchanged.
    """

    def __init__(
        self,
        hidden_dim: int = 0,
        drop_path: float = 0,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
        attn_drop_rate: float = 0,
        d_state: int = 16,
        **kwargs,
    ):
        super().__init__()
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = SS2D(
            d_model=hidden_dim,
            dropout=attn_drop_rate,
            d_state=d_state,
            **kwargs,
        )
        self.drop_path = DropPath(drop_path)

    def forward(self, input: torch.Tensor):
        """Apply the block to a ``B, C, H, W`` tensor and return the same layout."""
        # B, C, H, W -> B, H, W, C
        channels_last = input.permute(0, 2, 3, 1).contiguous()
        branch = self.self_attention(self.ln_1(channels_last))
        residual_sum = channels_last + self.drop_path(branch)
        # B, H, W, C -> B, C, H, W
        return residual_sum.permute(0, 3, 1, 2).contiguous()


In [10]:
# Smoke test: run one VSSBlock on a random CUDA tensor and print shapes, parameters and FLOPs.
x = torch.randn(1, 64, 128, 128).cuda() # (B, C, H, W)
block = VSSBlock(hidden_dim=64, drop_path=0.1, attn_drop_rate=0, d_state=16).cuda()
print_parameters_and_flops(block, x, inout=True)

Input: torch.Size([1, 64, 128, 128]),	Output: torch.Size([1, 64, 128, 128])
--------------------------------------------------------------------------------
Parameters: 0.055936 M,	FLOPs: 0.772811 G


## V1 & V2

In [11]:
# class VSSModule(nn.Module):
#     def __init__(
#         self,
#         hidden_dim: int = 0,
#         drop_path: float = 0,
#         norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
#         attn_drop_rate: float = 0,
#         d_state: int = 16,
#         init_value: float = 1.0,
#         vss_block: Callable[..., torch.nn.Module] = VSSBlock,
#         **kwargs,
#     ):
#         super().__init__()
#         self.vssm = vss_block(hidden_dim, drop_path, norm_layer, attn_drop_rate, d_state, **kwargs)
#         self.lambda_ = nn.Parameter(init_value*torch.ones(hidden_dim,1,1), requires_grad=True)
#         self.conv_bbone = nn.Sequential(nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, groups=hidden_dim),
#                                         nn.BatchNorm2d(num_features=hidden_dim),
#                                         nn.LeakyReLU(),
#                                         nn.Conv2d(hidden_dim, hidden_dim, kernel_size=1),
#                                         nn.BatchNorm2d(num_features=hidden_dim),
#                                         nn.SiLU())
#         self.beta =  nn.Parameter(init_value*torch.ones(hidden_dim,1,1), requires_grad=True)
#         self.mlp = nn.Sequential(nn.Conv2d(hidden_dim, hidden_dim, kernel_size=1),
#                                  nn.BatchNorm2d(num_features=hidden_dim),
#                                  nn.SiLU())
#         def forward(self, x):
#             raise NotImplementedError

# from models.vmaco import VSSModule

# class VSSModuleV1(VSSModule):
#     def __init__(self, *args, **kwargs):
#         super().__init__(*args, **kwargs)

#     def forward(self, x):
#         y_ssm = self.lambda_ * self.vssm(x)
#         y_cnn = self.conv_bbone(y_ssm + x)
#         x = y_cnn + self.beta*x
#         x = self.mlp(x)
#         return x
    

# class VSSModuleV2(VSSModule):
#     def __init__(self, *args, **kwargs):
#         super().__init__(*args, **kwargs)

#     def forward(self, x):
#         x_vss = self.lambda_ * self.vssm(x)
#         x_cnn = self.beta * self.conv_bbone(x_vss + x)
#         x = self.mlp(x_vss + x_cnn)
#         return x
    

# from blocks.cnn import CBAM
# from blocks.vit import IBIBlock

# class VSSModuleV3(VSSModule):
#     def __init__(self, *args, **kwargs):
#         kwargs['init_value'] = None
#         fmap_size = kwargs['fmap_size']        
        
#         super().__init__(*args, **kwargs)
        
#         self.cbbam = CBAM(self.hidden_dim, ratio=16, kernel_size=5)
#         self.ibi = IBIBlock(fmap_size, 
#                             window_size=7, 
#                             dim_in=self.hidden_dim, 
#                             dim_embed=self.hidden_dim, 
#                             depths=2, stage_spec='LS', heads=4, 
#                             attn_drop=0.0, proj_drop=0.0, expansion_mlp=1,
#                             drop=0.0, drop_path_rate=0.0, use_dwc_mlp=False,
#         )

#     def forward(self, x):
#         x_cbb = self.conv_bbone(x)
#         x_cnn = self.cbbam(x_cbb) + x
#         x_ibi, _, _ = self.ibi(x_cnn)
#         x = x_ibi + x_cnn
#         x_vss = self.vssm(x)
#         x = self.mlp(x_vss + x)
#         return x


In [12]:
from models.vmaco import VSSModuleV1, VSSModuleV2, VSSModuleV3

# Compare the three VSSModule variants on the same random (1, 64, 112, 112) CUDA input.
x = torch.randn(1, 64, 112, 112).cuda()

print("VSSModuleV1")
model1 = VSSModuleV1(hidden_dim=64, vss_block=VSSBlock).cuda()
print_parameters_and_flops(model1, x, inout=True)

print("\nVSSModuleV2")
model2 = VSSModuleV2(hidden_dim=64, vss_block=VSSBlock).cuda()
print_parameters_and_flops(model2, x, inout=True)

# V3 is additionally given fmap_size, set here to the input's spatial side (112).
print("\nVSSModuleV3")
model3 = VSSModuleV3(fmap_size=112, hidden_dim=64, vss_block=VSSBlock).cuda()
print_parameters_and_flops(model3, x, inout=True)

VSSModuleV1
Input: torch.Size([1, 64, 112, 112]),	Output: torch.Size([1, 64, 112, 112])
--------------------------------------------------------------------------------
Parameters: 0.065408 M,	FLOPs: 0.713692 G

VSSModuleV2
Input: torch.Size([1, 64, 112, 112]),	Output: torch.Size([1, 64, 112, 112])
--------------------------------------------------------------------------------
Parameters: 0.065408 M,	FLOPs: 0.713692 G

VSSModuleV3


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Input: torch.Size([1, 64, 112, 112]),	Output: torch.Size([1, 64, 112, 112])
--------------------------------------------------------------------------------
Parameters: 0.117626 M,	FLOPs: 1.505102 G


In [13]:
class BaseModule(nn.Module):
    """``nn.Module`` base class providing the ``_init_weights`` hook used with ``apply``."""

    def _init_weights(self, m: nn.Module):
        """Initialise one submodule in place.

        ``nn.Linear`` weights get a truncated normal (std 0.02) and their bias
        is zeroed; ``nn.LayerNorm`` is reset to weight 1 / bias 0. Modules
        built without those parameters (``bias=False``,
        ``elementwise_affine=False``) are left alone instead of raising.

        Args:
            m: The submodule to initialise.
        """
        for name, p in m.named_parameters():
            if name in ["out_proj.weight"]:
                # Fake init: the tensor being initialised is a detached copy, so
                # the real parameter is untouched. The call is kept only so the
                # random number generator advances ("just to keep the seed").
                p = p.clone().detach_()
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))

        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # Both are None when the norm was built without affine parameters.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)
            

class VSSLayer(BaseModule):
    """A layer for one stage: optional upsample, ``depth`` blocks, optional downsample.

    Args:
        dim (int): Number of input channels; passed to every block as ``hidden_dim``.
        depth (int): Number of blocks.
        attn_drop (float, optional): Passed to every block as ``attn_drop_rate``. Default: 0.0
        drop_path (float | list[float] | tuple[float], optional): Stochastic depth rate.
            A list or tuple is indexed per block; a scalar is shared by all blocks. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        upsample (callable | None, optional): Called as ``upsample(dim, norm_layer=norm_layer)``;
            the result is applied before the blocks. Default: None
        downsample (callable | None, optional): Called as ``downsample(dim, norm_layer=norm_layer)``;
            the result is applied after the blocks. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        d_state (int): Passed to every block as ``d_state``. Default: 16
        init_value (float): Passed to every block as ``init_value``. Default: 1.0
        vss_module (callable): Factory building one block. Default: VSSModuleV1 with VSSBlock.
        **kwargs: Extra keyword arguments forwarded to every ``vss_module`` call.
    """
    def __init__(
        self,
        dim,
        depth,
        attn_drop=0.,
        drop_path=0.,
        norm_layer=nn.LayerNorm,
        upsample=None,
        downsample=None,
        use_checkpoint=False,
        d_state=16,
        init_value: float =1.0,
        vss_module: Callable[..., torch.nn.Module] = partial(VSSModuleV1, vss_block=VSSBlock),
        **kwargs,
    ):
        super().__init__()
        self.dim = dim
        self.use_checkpoint = use_checkpoint

        # Per-block rates may arrive as a list or a tuple; a tuple used to be
        # handed whole to every block instead of being indexed.
        per_block_rates = isinstance(drop_path, (list, tuple))
        self.blocks = nn.ModuleList([vss_module(
            hidden_dim = dim,
            drop_path = drop_path[i] if per_block_rates else drop_path,
            norm_layer = norm_layer,
            attn_drop_rate = attn_drop,
            d_state = d_state,
            init_value = init_value,
            **kwargs,
        ) for i in range(depth)])

        self.upsample = upsample(dim, norm_layer=norm_layer) if callable(upsample) else None
        self.downsample = downsample(dim, norm_layer=norm_layer) if callable(downsample) else None
        self.apply(self._init_weights)

    def forward(self, x):
        """Run the optional upsample, every block in order, then the optional downsample."""
        if self.upsample is not None:
            x = self.upsample(x)
        for blk in self.blocks:
            x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

# One layer factory per block variant.
VSSLayerV1, VSSLayerV2 = (
    partial(VSSLayer, vss_module=block_variant)
    for block_variant in (VSSModuleV1, VSSModuleV2)
)

# The *_down_* factories pin upsample=None; the *_up_* factories pin downsample=None.
VSSLayer_down_V1, VSSLayer_down_V2 = (
    partial(layer_factory, upsample=None)
    for layer_factory in (VSSLayerV1, VSSLayerV2)
)
VSSLayer_up_V1, VSSLayer_up_V2 = (
    partial(layer_factory, downsample=None)
    for layer_factory in (VSSLayerV1, VSSLayerV2)
)

In [14]:
# Smoke test: a depth-2 VSSLayer built from VSSModuleV3 blocks; fmap_size is forwarded
# to each block through VSSLayer's **kwargs and matches the 112x112 input below.
layer = VSSLayer(fmap_size=112, dim=64, depth=2, vss_module=VSSModuleV3).cuda()
a = torch.randn(1, 64, 112, 112).cuda()
print_parameters_and_flops(layer, a, inout=True)

Input: torch.Size([1, 64, 112, 112]),	Output: torch.Size([1, 64, 112, 112])
--------------------------------------------------------------------------------
Parameters: 0.235252 M,	FLOPs: 3.010205 G


## Networks

In [15]:
from blocks.patch import PatchEmbed2D, PatchMerging2D, PatchExpand2D, Final_PatchExpand2D


class VMACO(BaseModule):
    """Encoder-decoder network assembled from VSS layers with additive skip connections.

    The input is patch-embedded, passed through ``len(depths)`` down layers
    (all but the last receive ``PatchMerging2D`` as ``downsample``), then
    through the same number of up layers (all but the first receive
    ``PatchExpand2D`` as ``upsample``), expanded by ``Final_PatchExpand2D``
    and projected to ``num_classes`` channels by a 1x1 convolution.

    Args:
        patch_size: Patch size given to ``PatchEmbed2D``.
        in_chans: Number of input channels given to ``PatchEmbed2D``.
        num_classes: Output channels of the final 1x1 convolution.
        depths: Number of blocks in each down layer.
        depths_decoder: Number of blocks in each up layer.
        dims: Channels of each down layer; an int is expanded to
            ``dims * 2**i`` for layer ``i``.
        dims_decoder: Channels of each up layer.
        d_state: Forwarded to every layer.
        drop_rate: Dropout probability after patch embedding; also forwarded
            to every layer as ``drop``.
        attn_drop_rate: Forwarded to every layer as ``attn_drop``.
        drop_path_rate: Upper end of the linearly spaced stochastic-depth rates.
        norm_layer: Normalization layer factory.
        patch_norm: Whether ``PatchEmbed2D`` receives ``norm_layer``.
        use_checkpoint: Forwarded to every layer.
        vss_layer: Factory building one layer.
        spatial_size: Input side length (an int, or a sequence whose first
            entry is used) from which each layer's ``fmap_size`` is derived.
        **kwargs: Accepted and ignored.
    """
    def __init__(self, patch_size=4, in_chans=3, num_classes=1000, depths=[2, 2, 9, 2], depths_decoder=[2, 9, 2, 2],
                 dims=[96, 192, 384, 768], dims_decoder=[768, 384, 192, 96], d_state=16, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, patch_norm=True, use_checkpoint=False,
                 vss_layer=VSSLayerV1,
                 spatial_size=224, **kwargs):
        super().__init__()
        self.num_classes = num_classes
        self.num_layers = len(depths)
        if isinstance(dims, int):
            dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)]
        self.embed_dim = dims[0]
        self.num_features = dims[-1]
        self.dims = dims
        self.patch_embed = PatchEmbed2D(patch_size, in_chans, embed_dim=self.embed_dim, norm_layer=norm_layer if patch_norm else None)

        # WASTED absolute position embedding ======================
        # Hard-disabled; the branch below only runs if this flag is flipped by hand.
        self.ape = False
        if self.ape:
            self.patches_resolution = self.patch_embed.patches_resolution
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, *self.patches_resolution, self.embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        dpr_decoder = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths_decoder))][::-1]

        # Feature-map side length handed to each down layer. Derived from the
        # input size instead of the former hard-coded [56, 28, 14, 7], which it
        # reproduces for the defaults (224 // 4 = 56, then halved per layer).
        # NOTE(review): assumes each PatchMerging2D halves the resolution and
        # that feature maps are square - confirm against blocks.patch.
        side = spatial_size if isinstance(spatial_size, int) else spatial_size[0]
        patch = patch_size if isinstance(patch_size, int) else patch_size[0]
        fmas = [(side // patch) // 2 ** i_layer for i_layer in range(self.num_layers)]
        self.layers_down = nn.ModuleList([
            vss_layer(
                dim=dims[i_layer],
                depth=depths[i_layer],
                d_state=d_state,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging2D if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint,
                fmap_size=fmas[i_layer],

            ) for i_layer in range(self.num_layers)
        ])

        # The up layers walk the same resolutions in the opposite order.
        fmas = fmas[::-1]
        self.layers_up = nn.ModuleList([
            vss_layer(
                dim=dims_decoder[i_layer],
                depth=depths_decoder[i_layer],
                d_state=d_state,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr_decoder[sum(depths_decoder[:i_layer]):sum(depths_decoder[:i_layer + 1])],
                norm_layer=norm_layer,
                upsample=PatchExpand2D if (i_layer != 0) else None,
                use_checkpoint=use_checkpoint,
                fmap_size=fmas[i_layer],
            ) for i_layer in range(self.num_layers)
        ])

        self.final_up = Final_PatchExpand2D(dim=dims_decoder[-1], dim_scale=4, norm_layer=norm_layer)
        self.final_conv = nn.Conv2d(dims_decoder[-1]//4, num_classes, 1)
        # self.norm = norm_layer(self.num_features)
        # self.avgpool = nn.AdaptiveAvgPool1d(1)
        # self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parameter names to exclude from weight decay."""
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        """Return the name keywords of parameters to exclude from weight decay."""
        return {'relative_position_bias_table'}

    def _embed(self, x):
        """Patch-embed ``x``, add the position embedding when enabled, then apply dropout."""
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        return self.pos_drop(x)

    def forward_features(self, x):
        """Run the down path; return its output and the list of inputs to each down layer."""
        skip_list = []
        x = self._embed(x)
        for layer in self.layers_down:
            skip_list.append(x)
            x = layer(x)
        return x, skip_list

    def forward_features_up(self, x, skip_list):
        """Run the up path; every up layer after the first receives ``x + skip_list[-inx]``."""
        for inx, layer_up in enumerate(self.layers_up):
            x = layer_up(x) if inx == 0 else layer_up(x+skip_list[-inx])
        return x

    def forward_final(self, x):
        """Apply the final patch expansion followed by the 1x1 output convolution."""
        x = self.final_up(x)
        x = self.final_conv(x)
        return x

    def forward_backbone(self, x):
        """Run only the embedding and the down path, without collecting skips."""
        x = self._embed(x)
        for layer in self.layers_down:
            x = layer(x)
        return x

    def forward(self, x):
        """Full pass: down path, up path with skip additions, final expansion and projection."""
        x, skip_list = self.forward_features(x)
        x = self.forward_features_up(x, skip_list)
        x = self.forward_final(x)
        return x

In [16]:
# Build the full network with each block variant and report shapes, parameters and FLOPs
# on one random (1, 3, 224, 224) CUDA input.
a = torch.randn(1, 3, 224, 224).cuda()

print("VMACO V1")
model_v1 = VMACO(vss_layer=VSSLayerV1).cuda()
print_parameters_and_flops(model_v1, a, inout=True)

print("\nVMACO V2")
model_v2 = VMACO(vss_layer=VSSLayerV2).cuda()
print_parameters_and_flops(model_v2, a, inout=True)

# There is no ready-made V3 layer factory, so one is built inline with partial.
print("\nVMACO V3")
model_v3 = VMACO(vss_layer=partial(VSSLayer, vss_module=VSSModuleV3)).cuda()
print_parameters_and_flops(model_v3, a, inout=True)

VMACO V1
Input: torch.Size([1, 3, 224, 224]),	Output: torch.Size([1, 1000, 224, 224])
--------------------------------------------------------------------------------
Parameters: 54.917176 M,	FLOPs: 10.533784 G

VMACO V2
Input: torch.Size([1, 3, 224, 224]),	Output: torch.Size([1, 1000, 224, 224])
--------------------------------------------------------------------------------
Parameters: 54.917176 M,	FLOPs: 10.533784 G

VMACO V3
Input: torch.Size([1, 3, 224, 224]),	Output: torch.Size([1, 1000, 224, 224])
--------------------------------------------------------------------------------
Parameters: 118.183300 M,	FLOPs: 21.659133 G


In [17]:
# import time
# import math
# from functools import partial
# from typing import Optional, Callable

# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import torch.utils.checkpoint as checkpoint
# from einops import rearrange, repeat
# from timm.models.layers import DropPath, to_2tuple, trunc_normal_
# import io
# from contextlib import redirect_stderr
# from fvcore.nn import FlopCountAnalysis


# def count_parameters(model):
#     return sum(p.numel() for p in model.parameters() if p.requires_grad)
  
# def count_flops(model, x):
#   with redirect_stderr(io.StringIO()):
#     flops = FlopCountAnalysis(model, (x,))
#     # flops.count(ignore_modules=[torch.nn.ReLU, torch.nn.ReLU6, torch.nn.SiLU, torch.nn.Identity])
#     flops_amount = flops.total()
#   return flops_amount

# def count_parameters_and_flops(model, x):
#     params = count_parameters(model)
#     flops_amount = count_flops(model, x)
#     return params, flops_amount
  
# def print_parameters_and_flops(model, x, inout=False):
#     params, flops_amount = count_parameters_and_flops(model, x)
#     if inout:
#       output = model(x)
#       print(f"Input: {x.shape},\tOutput: {output.shape}\n{80*'-'}")
#     print(f"Parameters: {params/1e6:.6f} M,\tFLOPs: {flops_amount/1e9:.6f} G")


from models.vmaco import VMACO, VSSLayerV1, VSSLayerV2

# Same check as the notebook-defined models, but VMACO, VSSLayerV1 and VSSLayerV2 here are
# the versions imported from models.vmaco just above, which rebind the names defined earlier.
a = torch.randn(1, 3, 224, 224).cuda()

print("VMACO V1")
model_v1 = VMACO(vss_layer=VSSLayerV1).cuda()
print_parameters_and_flops(model_v1, a, inout=True)

print("\nVMACO V2")
model_v2 = VMACO(vss_layer=VSSLayerV2).cuda()
print_parameters_and_flops(model_v2, a, inout=True)

VMACO V1


Input: torch.Size([1, 3, 224, 224]),	Output: torch.Size([1, 1000, 224, 224])
--------------------------------------------------------------------------------
Parameters: 54.917176 M,	FLOPs: 10.533784 G

VMACO V2
Input: torch.Size([1, 3, 224, 224]),	Output: torch.Size([1, 1000, 224, 224])
--------------------------------------------------------------------------------
Parameters: 54.917176 M,	FLOPs: 10.533784 G
