Reference
- [Github](https://github.com/berniwal/swin-transformer-pytorch/blob/master/swin_transformer_pytorch/swin_transformer.py)
- [Youtube](https://www.youtube.com/playlist?list=PL9iXGo3xD8jokWaLB8ZHUkjjv5Y_vPQnZ)
# To Do
- ✅ Delete Shifted Swin Transformer from Swin Block
- ✅ Dual Switch
- ✅ CNN
- ✅ Edit Function to support CNN
- Hybrid Network Backbone
- ✅ Hyneter Module
- Extract Feature from model [Pytorch | Feature extraction for model inspection](https://docs.pytorch.org/vision/main/feature_extraction.html)
- Modify Mask R-CNN backbone [TorchVision Object Detection Finetuning Tutorial](https://docs.pytorch.org/tutorials/intermediate/torchvision_tutorial.html#modifying-the-model-to-add-a-different-backbone)

## Import Library

In [41]:
import torch
from torch import nn, einsum
import numpy as np
from einops import rearrange, repeat

## Residual

In [42]:
class Residual(nn.Module):
    """Skip-connection wrapper: adds the input back onto the wrapped callable's output.

    Args:
        fn: Callable (typically an ``nn.Module``) whose output can be added to
            its input.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(x, **kwargs) + x``; keyword arguments are forwarded to ``fn``."""
        branch = self.fn(x, **kwargs)
        return branch + x

## Pre Normalization

In [43]:
class PreNorm(nn.Module):
    """Apply ``LayerNorm`` over the last dimension before calling the wrapped callable.

    Args:
        dim: Size of the last (normalized) dimension.
        fn: Callable that receives the normalized tensor.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(LayerNorm(x), **kwargs)``."""
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)

## Feed Forward

In [44]:
class FeedForward(nn.Module):
    """Two-layer MLP: ``Linear(dim -> hidden_dim)``, GELU, ``Linear(hidden_dim -> dim)``.

    Args:
        dim: Input and output feature size (last dimension).
        hidden_dim: Width of the intermediate layer.
    """

    def __init__(self, dim, hidden_dim):
        super().__init__()
        expand = nn.Linear(dim, hidden_dim)
        project = nn.Linear(hidden_dim, dim)
        self.net = nn.Sequential(expand, nn.GELU(), project)

    def forward(self, x):
        """Run ``x`` through the MLP; the output has the same shape as ``x``."""
        out = self.net(x)
        return out


def create_mask(window_size, displacement, upper_lower, left_right):
    """Build an additive ``(window_size**2, window_size**2)`` mask of ``0`` / ``-inf``.

    The mask is indexed by flattened (row-major) token positions inside one
    window. ``WindowAttention`` adds it to the attention logits of shifted
    windows.

    Args:
        window_size: Side length of the square window.
        displacement: Number of trailing rows/columns treated as a separate group.
        upper_lower: If true, block attention between the last ``displacement``
            rows of the window and the rows before them (both directions).
        left_right: If true, block attention between the last ``displacement``
            columns of the window and the columns before them (both directions).

    Returns:
        A float tensor with ``-inf`` at blocked (query, key) pairs and ``0`` elsewhere.
    """
    blocked = float('-inf')
    tokens = window_size ** 2
    mask = torch.zeros(tokens, tokens)

    if upper_lower:
        # In flattened order, the last `displacement` rows occupy the final
        # displacement * window_size token positions.
        cut = displacement * window_size
        mask[-cut:, :-cut] = blocked
        mask[:-cut, -cut:] = blocked

    if left_right:
        # Unflatten both token axes to (row, column) so the column split can be
        # expressed as plain slices, then flatten back.
        mask = rearrange(mask, '(h1 w1) (h2 w2) -> h1 w1 h2 w2', h1=window_size, h2=window_size)
        mask[:, -displacement:, :, :-displacement] = blocked
        mask[:, :-displacement, :, -displacement:] = blocked
        mask = rearrange(mask, 'h1 w1 h2 w2 -> (h1 w1) (h2 w2)')

    return mask


def get_relative_distances(window_size):
    """Return pairwise (row, column) offsets between all positions of a square window.

    Positions are enumerated row-major, giving ``N = window_size ** 2`` entries.

    Returns:
        An integer tensor of shape ``(N, N, 2)`` where entry ``[i, j]`` is
        ``position[j] - position[i]``.
    """
    coords = [[row, col] for row in range(window_size) for col in range(window_size)]
    positions = torch.tensor(np.array(coords))
    # Broadcast (1, N, 2) against (N, 1, 2) to get every pairwise difference.
    return positions[None, :, :] - positions[:, None, :]

## Window Attention

In [45]:
class WindowAttention(nn.Module):
    """Multi-head self-attention computed independently inside square windows.

    The input is unpacked as a 4-D tensor ``(batch, height, width, dim)`` and is
    split into non-overlapping ``window_size x window_size`` windows; attention
    is computed among the tokens of each window.

    Args:
        dim: Size of the last input dimension (and of the output).
        heads: Number of attention heads.
        head_dim: Feature size per head; the q/k/v projection width is
            ``heads * head_dim``.
        shifted: If true, the input is cyclically shifted by ``window_size // 2``
            before attention and shifted back afterwards, with masks added to the
            logits of the last row / last column of windows.
        window_size: Side length of each attention window.
        relative_pos_embedding: If true, use a learned
            ``(2 * window_size - 1, 2 * window_size - 1)`` table looked up by
            relative offset; otherwise a learned
            ``(window_size ** 2, window_size ** 2)`` bias.
    """

    def __init__(self, dim, heads, head_dim, shifted, window_size, relative_pos_embedding):
        super().__init__()
        inner_dim = head_dim * heads

        self.heads = heads
        # Standard 1 / sqrt(head_dim) scaling of the q.k logits.
        self.scale = head_dim ** -0.5
        self.window_size = window_size
        self.relative_pos_embedding = relative_pos_embedding
        self.shifted = shifted

        if self.shifted:
            # NOTE(review): CyclicShift is not defined in the visible part of this
            # file; shifted=True needs it to be available at construction time.
            displacement = window_size // 2
            self.cyclic_shift = CyclicShift(-displacement)
            self.cyclic_back_shift = CyclicShift(displacement)
            # Masks are stored as non-trainable parameters (requires_grad=False).
            self.upper_lower_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement,
                                                             upper_lower=True, left_right=False), requires_grad=False)
            self.left_right_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement,
                                                            upper_lower=False, left_right=True), requires_grad=False)

        # Single bias-free projection producing q, k and v concatenated on the last axis.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        if self.relative_pos_embedding:
            # Offsets lie in [-(window_size - 1), window_size - 1]; adding
            # window_size - 1 makes them valid non-negative table indices.
            # NOTE(review): this is a plain tensor attribute, not a registered
            # buffer, so it is not moved by module.to(device) — confirm intended.
            self.relative_indices = get_relative_distances(window_size) + window_size - 1
            self.pos_embedding = nn.Parameter(torch.randn(2 * window_size - 1, 2 * window_size - 1))
        else:
            self.pos_embedding = nn.Parameter(torch.randn(window_size ** 2, window_size ** 2))

        self.to_out = nn.Linear(inner_dim, dim)

       
    def forward(self, x):
        """Apply windowed attention to ``x`` and return a tensor of the same shape.

        ``x`` is unpacked as ``(batch, height, width, channels)``; the window
        counts are computed with floor division by ``window_size``.
        """
        if self.shifted:
            x = self.cyclic_shift(x)

        b, n_h, n_w, _, h = *x.shape, self.heads

        qkv = self.to_qkv(x).chunk(3, dim=-1)
        # Number of windows along height and width.
        nw_h = n_h // self.window_size
        nw_w = n_w // self.window_size

        # Regroup each of q, k, v to (batch, heads, windows, tokens-per-window, head_dim).
        q, k, v = map(
            lambda t: rearrange(t, 'b (nw_h w_h) (nw_w w_w) (h d) -> b h (nw_h nw_w) (w_h w_w) d',
                                h=h, w_h=self.window_size, w_w=self.window_size), qkv)

        # Scaled dot-product logits between every token pair within a window.
        dots = einsum('b h w i d, b h w j d -> b h w i j', q, k) * self.scale

        if self.relative_pos_embedding:
            # Look up one bias per (query, key) pair from its relative (row, column) offset.
            dots += self.pos_embedding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]]
        else:
            dots += self.pos_embedding

        if self.shifted:
            # Windows are flattened row-major: the last nw_w entries are the bottom
            # row of windows, and every nw_w-th entry starting at nw_w - 1 is the
            # right-most window of a row.
            dots[:, :, -nw_w:] += self.upper_lower_mask
            dots[:, :, nw_w - 1::nw_w] += self.left_right_mask

        attn = dots.softmax(dim=-1)

        out = einsum('b h w i j, b h w j d -> b h w i d', attn, v)
        # Undo the window partition and merge heads back into the channel axis.
        out = rearrange(out, 'b h (nw_h nw_w) (w_h w_w) d -> b (nw_h w_h) (nw_w w_w) (h d)',
                        h=h, w_h=self.window_size, w_w=self.window_size, nw_h=nw_h, nw_w=nw_w)
        out = self.to_out(out)

        if self.shifted:
            out = self.cyclic_back_shift(out)
        return out


## Transformer Block

In [46]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: residual window attention, then a residual MLP.

    Args:
        dim: Feature size of the last input dimension.
        heads: Number of attention heads.
        head_dim: Feature size per attention head.
        mlp_dim: Hidden width of the feed-forward network.
        shifted: Forwarded to ``WindowAttention``.
        window_size: Forwarded to ``WindowAttention``.
        relative_pos_embedding: Forwarded to ``WindowAttention``.
    """

    def __init__(self, dim, heads, head_dim, mlp_dim, shifted, window_size, relative_pos_embedding):
        super().__init__()
        attention = WindowAttention(
            dim=dim,
            heads=heads,
            head_dim=head_dim,
            shifted=shifted,
            window_size=window_size,
            relative_pos_embedding=relative_pos_embedding,
        )
        self.attention_block = Residual(PreNorm(dim, attention))

        feed_forward = FeedForward(dim=dim, hidden_dim=mlp_dim)
        self.mlp_block = Residual(PreNorm(dim, feed_forward))

    def forward(self, x):
        """Apply the attention sub-block followed by the MLP sub-block."""
        attended = self.attention_block(x)
        return self.mlp_block(attended)

## Conv Layer

Branch 5 = 1x1->3x3->3x3

In [47]:
class CNN(nn.Module):
    """Three parallel conv branches concatenated along the channel axis.

    Branches: ``1x1``; ``1x1 -> 3x3``; ``1x1 -> 3x3 -> 3x3``. Every convolution
    is followed by BatchNorm and an activation (GELU in the 1x1 branch, ReLU
    elsewhere). Each branch emits ``embed_dim // 3`` channels, so the output has
    ``3 * (embed_dim // 3)`` channels, which equals ``embed_dim`` only when
    ``embed_dim`` is a multiple of 3.

    Args:
        in_channels: Channel count of the input feature map.
        embed_dim: Channel budget split evenly across the three branches.
        stride1: Stride of the last convolution of every branch (default 1).
    """

    def __init__(self, in_channels, embed_dim, stride1=1):
        super().__init__()
        self.out_channels = embed_dim // 3
        width = self.out_channels

        def unit(c_in, kernel, stride=1, bias=True, gelu=False):
            # Conv -> BatchNorm -> activation, padded by kernel // 2.
            return [
                nn.Conv2d(c_in, width, kernel_size=kernel, stride=stride,
                          padding=kernel // 2, bias=bias),
                nn.BatchNorm2d(width),
                nn.GELU() if gelu else nn.ReLU(inplace=True),
            ]

        self.branch1x1 = nn.Sequential(
            *unit(in_channels, 1, stride=stride1, bias=False, gelu=True)
        )
        self.branch3x3 = nn.Sequential(
            *unit(in_channels, 1),
            *unit(width, 3, stride=stride1),
        )
        # Two stacked 3x3 convolutions (after the 1x1) cover a 5x5 receptive field.
        self.branch5x5 = nn.Sequential(
            *unit(in_channels, 1),
            *unit(width, 3),
            *unit(width, 3, stride=stride1),
        )

    def forward(self, x):
        """Run all branches on ``x`` and concatenate their outputs on dim 1."""
        outputs = (self.branch1x1(x), self.branch3x3(x), self.branch5x5(x))
        return torch.cat(outputs, dim=1)

Branch 5x5 = 1x1->5x5

In [48]:
class CNN(nn.Module):
    """Three parallel conv branches concatenated along the channel axis.

    Branches: ``1x1``; ``1x1 -> 3x3``; ``1x1 -> 5x5``. Every convolution is
    followed by BatchNorm and an activation (GELU in the 1x1 branch, ReLU
    elsewhere). Each branch emits ``embed_dim // 3`` channels, so the output has
    ``3 * (embed_dim // 3)`` channels.

    Args:
        in_channels: Channel count of the input feature map.
        embed_dim: Channel budget split evenly across the three branches.
        stride1: Stride of the last convolution of every branch (default 1).
    """

    def __init__(self, in_channels, embed_dim, stride1=1):
        super().__init__()
        self.out_channels = embed_dim // 3
        channels = self.out_channels

        def relu_stage(source, size, step=1):
            # Conv (with bias) -> BatchNorm -> in-place ReLU, padded by size // 2.
            conv = nn.Conv2d(source, channels, kernel_size=size, padding=size // 2, stride=step)
            return [conv, nn.BatchNorm2d(channels), nn.ReLU(inplace=True)]

        self.branch1x1 = nn.Sequential(
            nn.Conv2d(in_channels, channels, kernel_size=1, stride=stride1, bias=False),
            nn.BatchNorm2d(channels),
            nn.GELU(),
        )

        # 1x1 projection followed by a single 3x3 convolution.
        self.branch3x3 = nn.Sequential(
            *relu_stage(in_channels, 1),
            *relu_stage(channels, 3, step=stride1),
        )

        # 1x1 projection followed by a single 5x5 convolution.
        self.branch5x5 = nn.Sequential(
            *relu_stage(in_channels, 1),
            *relu_stage(channels, 5, step=stride1),
        )

    def forward(self, x):
        """Run all branches on ``x`` and concatenate their outputs on dim 1."""
        pointwise = self.branch1x1(x)
        small = self.branch3x3(x)
        large = self.branch5x5(x)
        return torch.cat((pointwise, small, large), dim=1)

Branch 5 = 5x5

In [80]:
class CNN(nn.Module):
    """Three single-convolution branches (1x1, 3x3, 5x5) concatenated on channels.

    Each branch is Conv -> BatchNorm -> activation (GELU for 1x1, ReLU for the
    others) applied directly to the input, and emits ``embed_dim // 3``
    channels; the output therefore has ``3 * (embed_dim // 3)`` channels.

    Args:
        in_channels: Channel count of the input feature map.
        embed_dim: Channel budget split evenly across the three branches.
        stride1: Stride of every branch's convolution (default 1).
    """

    def __init__(self, in_channels, embed_dim, stride1=1):
        super().__init__()
        self.out_channels = embed_dim // 3
        n_out = self.out_channels

        self.branch1x1 = nn.Sequential(
            nn.Conv2d(in_channels, n_out, kernel_size=1, stride=stride1, bias=False),
            nn.BatchNorm2d(n_out),
            nn.GELU(),
        )

        def direct(kernel):
            # One padded convolution straight from the input, then BN + ReLU.
            return nn.Sequential(
                nn.Conv2d(in_channels, n_out, kernel_size=kernel,
                          padding=kernel // 2, stride=stride1),
                nn.BatchNorm2d(n_out),
                nn.ReLU(inplace=True),
            )

        self.branch3x3 = direct(3)
        self.branch5x5 = direct(5)

    def forward(self, x):
        """Run all branches on ``x`` and concatenate their outputs on dim 1."""
        features = [branch(x) for branch in (self.branch1x1, self.branch3x3, self.branch5x5)]
        return torch.cat(tuple(features), dim=1)

branch 3x3, 5x5, 7x7

In [84]:
class CNN(nn.Module):
    """Single-convolution branches of kernel size 3, 5 and 7, concatenated on channels.

    Each used branch is Conv -> BatchNorm -> ReLU applied directly to the input
    and emits ``embed_dim // 3`` channels; the output has
    ``3 * (embed_dim // 3)`` channels.

    NOTE(review): ``branch1x1`` is constructed (and owns parameters) but
    ``forward`` never calls it — confirm whether it should be removed or used.

    Args:
        in_channels: Channel count of the input feature map.
        embed_dim: Channel budget split evenly across the three used branches.
        stride1: Stride of every branch's convolution (default 1).
    """

    def __init__(self, in_channels, embed_dim, stride1=1):
        super().__init__()
        self.out_channels = embed_dim // 3
        n_out = self.out_channels

        # Built for parity with the other variants; not used by forward().
        self.branch1x1 = nn.Sequential(
            nn.Conv2d(in_channels, n_out, kernel_size=1, stride=stride1, bias=False),
            nn.BatchNorm2d(n_out),
            nn.GELU(),
        )

        # One padded convolution per kernel size, registered as branch{k}x{k}.
        for kernel in (3, 5, 7):
            branch = nn.Sequential(
                nn.Conv2d(in_channels, n_out, kernel_size=kernel,
                          padding=kernel // 2, stride=stride1),
                nn.BatchNorm2d(n_out),
                nn.ReLU(inplace=True),
            )
            setattr(self, f"branch{kernel}x{kernel}", branch)

    def forward(self, x):
        """Run the 3x3, 5x5 and 7x7 branches on ``x`` and concatenate on dim 1."""
        small = self.branch3x3(x)
        medium = self.branch5x5(x)
        large = self.branch7x7(x)
        return torch.cat((small, medium, large), dim=1)

## Multigranularity CNN

Branch 5x5 = 1x1->3x3->3x3

In [51]:
class MultiGranularitySummingBlock(nn.Module):
    """Three parallel conv branches whose outputs are concatenated on channels.

    Branches: ``1x1``; ``1x1 -> 3x3``; ``1x1 -> 3x3 -> 3x3``. Every convolution
    is followed by BatchNorm and an activation (GELU in the 1x1 branch, ReLU
    elsewhere). Despite the class name, the branch outputs are concatenated,
    not summed. Each branch emits ``embed_dim // 3`` channels, so the output
    has ``3 * (embed_dim // 3)`` channels.

    Args:
        in_channels: Channel count of the input feature map.
        embed_dim: Channel budget split evenly across the three branches.
        stride1: Stride of the last convolution of every branch (default 1).
    """

    def __init__(self, in_channels: int, embed_dim: int, stride1: int = 1):
        super().__init__()
        self.out_channels = embed_dim // 3
        per_branch = self.out_channels

        def stage(c_in, kernel, stride=1):
            # Conv (with bias) -> BatchNorm -> in-place ReLU, padded by kernel // 2.
            return (
                nn.Conv2d(c_in, per_branch, kernel_size=kernel, padding=kernel // 2, stride=stride),
                nn.BatchNorm2d(per_branch),
                nn.ReLU(inplace=True),
            )

        self.branch1x1 = nn.Sequential(
            nn.Conv2d(in_channels, per_branch, kernel_size=1, stride=stride1, bias=False),
            nn.BatchNorm2d(per_branch),
            nn.GELU(),
        )

        self.branch3x3 = nn.Sequential(
            *stage(in_channels, 1),
            *stage(per_branch, 3, stride=stride1),
        )

        # Two stacked 3x3 convolutions (after the 1x1) cover a 5x5 receptive field.
        self.branch5x5 = nn.Sequential(
            *stage(in_channels, 1),
            *stage(per_branch, 3),
            *stage(per_branch, 3, stride=stride1),
        )

    def forward(self, x):
        """Run all branches on ``x`` and concatenate their outputs on dim 1."""
        parts = (self.branch1x1(x), self.branch3x3(x), self.branch5x5(x))
        return torch.cat(parts, dim=1)

Branch5 = 1x1 -> 5x5

In [52]:
class MultiGranularitySummingBlock(nn.Module):
    """Three parallel conv branches whose outputs are concatenated on channels.

    Branches: ``1x1``; ``1x1 -> 3x3``; ``1x1 -> 5x5``. Every convolution is
    followed by BatchNorm and an activation (GELU in the 1x1 branch, ReLU
    elsewhere). Despite the class name, the branch outputs are concatenated,
    not summed. Each branch emits ``embed_dim // 3`` channels, so the output
    has ``3 * (embed_dim // 3)`` channels.

    Args:
        in_channels: Channel count of the input feature map.
        embed_dim: Channel budget split evenly across the three branches.
        stride1: Stride of the last convolution of every branch (default 1).
    """

    def __init__(self, in_channels: int, embed_dim: int, stride1: int = 1):
        super().__init__()
        self.out_channels = embed_dim // 3

        self.branch1x1 = nn.Sequential(
            nn.Conv2d(in_channels, self.out_channels, kernel_size=1, stride=stride1, bias=False),
            nn.BatchNorm2d(self.out_channels),
            nn.GELU()
        )

        # 1x1 projection followed by a single 3x3 convolution.
        self.branch3x3 = nn.Sequential(
            nn.Conv2d(in_channels, self.out_channels, kernel_size=1),
            nn.BatchNorm2d(self.out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, padding=1, stride=stride1),
            nn.BatchNorm2d(self.out_channels),
            nn.ReLU(inplace=True),
        )

        # 1x1 projection followed by a single 5x5 convolution.
        self.branch5x5 = nn.Sequential(
            nn.Conv2d(in_channels, self.out_channels, kernel_size=1),
            nn.BatchNorm2d(self.out_channels),
            nn.ReLU(inplace=True),
            # The stride must match the other branches; otherwise the spatial
            # sizes differ for stride1 != 1 and the concatenation fails.
            nn.Conv2d(self.out_channels, self.out_channels, kernel_size=5, padding=2, stride=stride1),
            nn.BatchNorm2d(self.out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run all branches on ``x`` and concatenate their outputs on dim 1."""
        branch1x1 = self.branch1x1(x)
        branch3x3 = self.branch3x3(x)
        branch5x5 = self.branch5x5(x)
        return torch.cat((branch1x1, branch3x3, branch5x5), dim=1)

Branch 5 = 5x5

In [53]:
class MultiGranularitySummingBlock(nn.Module):
    """Three single-convolution branches (1x1, 3x3, 5x5) concatenated on channels.

    Each branch is a bias-free Conv -> BatchNorm -> GELU applied directly to the
    input and emits ``embed_dim // 3`` channels. Despite the class name, the
    branch outputs are concatenated, not summed, giving
    ``3 * (embed_dim // 3)`` output channels.

    Args:
        in_channels: Channel count of the input feature map.
        embed_dim: Channel budget split evenly across the three branches.
        stride1: Stride of every branch's convolution (default 1).
    """

    def __init__(self, in_channels: int, embed_dim: int, stride1: int = 1):
        super().__init__()
        self.out_channels = embed_dim // 3

        # Padding of kernel // 2 gives 0, 1 and 2 for the 1x1, 3x3 and 5x5 kernels.
        for kernel in (1, 3, 5):
            branch = nn.Sequential(
                nn.Conv2d(in_channels, self.out_channels, kernel_size=kernel,
                          stride=stride1, padding=kernel // 2, bias=False),
                nn.BatchNorm2d(self.out_channels),
                nn.GELU(),
            )
            setattr(self, f"branch{kernel}x{kernel}", branch)

    def forward(self, x):
        """Run all branches on ``x`` and concatenate their outputs on dim 1."""
        branches = (self.branch1x1, self.branch3x3, self.branch5x5)
        return torch.cat(tuple(branch(x) for branch in branches), dim=1)

Conv = 3x3, 5x5, 7x7

In [54]:
class MultiGranularitySummingBlock(nn.Module):
    """Conv branches with 3x3, 5x5 and 7x7 receptive fields, concatenated on channels.

    Used branches: ``1x1 -> 3x3``; ``1x1 -> 3x3 -> 3x3``;
    ``1x1 -> 3x3 -> 3x3 -> 3x3``. Every convolution is followed by BatchNorm
    and ReLU. Despite the class name, the branch outputs are concatenated, not
    summed. Each branch emits ``embed_dim // 3`` channels, so the output has
    ``3 * (embed_dim // 3)`` channels.

    NOTE(review): ``branch1x1`` is constructed (and owns parameters) but
    ``forward`` never calls it; it is kept so existing state dicts still load.

    Args:
        in_channels: Channel count of the input feature map.
        embed_dim: Channel budget split evenly across the three used branches.
        stride1: Stride of the last convolution of every branch (default 1).
    """

    def __init__(self, in_channels, embed_dim, stride1=1):
        super().__init__()
        self.out_channels = embed_dim // 3

        # Not used by forward(); see the class-level note.
        self.branch1x1 = nn.Sequential(
            nn.Conv2d(in_channels, self.out_channels, kernel_size=1, stride=stride1, bias=False),
            nn.BatchNorm2d(self.out_channels),
            nn.GELU()
        )

        # 1x1 projection followed by a single 3x3 convolution.
        self.branch3x3 = nn.Sequential(
            nn.Conv2d(in_channels, self.out_channels, kernel_size=1),
            nn.BatchNorm2d(self.out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, padding=1, stride=stride1),
            nn.BatchNorm2d(self.out_channels),
            nn.ReLU(inplace=True),
        )

        # 1x1 projection followed by two 3x3 convolutions (5x5 receptive field).
        self.branch5x5 = nn.Sequential(
            nn.Conv2d(in_channels, self.out_channels, kernel_size=1),
            nn.BatchNorm2d(self.out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(self.out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, padding=1, stride=stride1),
            nn.BatchNorm2d(self.out_channels),
            nn.ReLU(inplace=True),
        )

        # 1x1 projection followed by three 3x3 convolutions (7x7 receptive field).
        # Only the last convolution is strided: striding two of them would
        # downsample this branch twice and break the concatenation in forward()
        # whenever stride1 != 1.
        self.branch7x7 = nn.Sequential(
            nn.Conv2d(in_channels, self.out_channels, kernel_size=1),
            nn.BatchNorm2d(self.out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(self.out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(self.out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, padding=1, stride=stride1),
            nn.BatchNorm2d(self.out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run the 3x3, 5x5 and 7x7 branches on ``x`` and concatenate on dim 1."""
        branch3x3 = self.branch3x3(x)
        branch5x5 = self.branch5x5(x)
        branch7x7 = self.branch7x7(x)
        return torch.cat((branch3x3, branch5x5, branch7x7), dim=1)

## Patch Merging

In [55]:
class PatchMerging(nn.Module):
    """Downsample a feature map by flattening non-overlapping patches and projecting them.

    Args:
        in_channels: Channel count of the input ``(B, C, H, W)`` tensor.
        out_channels: Feature size produced by the linear projection.
        downscaling_factor: Side length (and stride) of each square patch.
    """

    def __init__(self, in_channels, out_channels, downscaling_factor):
        super().__init__()
        self.downscaling_factor = downscaling_factor
        # Unfold with kernel == stride extracts non-overlapping patches.
        self.patch_merge = nn.Unfold(kernel_size=downscaling_factor, stride=downscaling_factor, padding=0)
        patch_features = in_channels * downscaling_factor ** 2
        self.linear = nn.Linear(patch_features, out_channels)

    def forward(self, x):
        """Map ``(B, C, H, W)`` to channels-last ``(B, H // f, W // f, out_channels)``."""
        b, c, h, w = x.shape
        print(b,c,h,w)
        factor = self.downscaling_factor
        patches = self.patch_merge(x)
        # Restore the 2-D patch grid, then move the features to the last axis
        # so the linear layer acts on them.
        grid = patches.view(b, -1, h // factor, w // factor)
        return self.linear(grid.permute(0, 2, 3, 1))

# Dual Switch Just swap

In [56]:
import torch
import torch.nn as nn

class DualSwitch_SwapOnly(nn.Module):
    """Parameter-free reordering of a ``(B, C, H, W)`` feature map.

    Applies, in order: adjacent swap along H, adjacent swap along W, interlaced
    (block-of-2) swap along H, interlaced swap along W. The output has the
    same shape as the input and the input tensor is never modified.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _block_swap_order(size: int, block: int, device) -> torch.Tensor:
        """Return gather indices that swap consecutive pairs of ``block``-sized groups.

        Only the leading multiple of ``2 * block`` elements takes part; trailing
        indices keep their position.
        """
        order = torch.arange(size, device=device)
        span = 2 * block
        groups = size // span
        swappable = groups * span
        # View the swappable prefix as (group, half, block) and flip the halves.
        head = order[:swappable].view(groups, 2, block).flip(1).reshape(swappable)
        return torch.cat((head, order[swappable:]))

    def _switch_adjacent(self, input_tensor: torch.Tensor, dim: int) -> torch.Tensor:
        """Swap neighbouring elements (0<->1, 2<->3, ...) along ``dim``.

        If the dimension size is odd, the last element stays in place.
        Returns a new tensor.
        """
        order = self._block_swap_order(input_tensor.shape[dim], 1, input_tensor.device)
        # index_select takes an explicit dimension, which avoids indexing with
        # a Python list of slices (deprecated non-tuple sequence indexing).
        return input_tensor.index_select(dim, order)

    def _switch_interlaced(self, input_tensor: torch.Tensor, dim: int) -> torch.Tensor:
        """Swap neighbouring blocks of two elements ([0,1]<->[2,3], ...) along ``dim``.

        If the dimension size is not a multiple of 4, the trailing elements
        stay in place. Returns a new tensor.
        """
        order = self._block_swap_order(input_tensor.shape[dim], 2, input_tensor.device)
        return input_tensor.index_select(dim, order)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the four switching steps to ``x``.

        This class calls the Height axis (dim=2) "columns" and the Width axis
        (dim=3) "rows".

        Args:
            x: Input feature map of shape ``(B, C, H, W)``. No dimension checks
                are performed; trailing elements of odd / non-multiple-of-4
                axes are left untouched by the respective step.

        Returns:
            The reordered feature map, with the same shape as ``x``.
        """
        # Steps 1-2: adjacent swaps along H, then W.
        x = self._switch_adjacent(x, dim=2)
        x = self._switch_adjacent(x, dim=3)

        # Steps 3-4: interlaced (block-of-2) swaps along H, then W.
        x = self._switch_interlaced(x, dim=2)
        x = self._switch_interlaced(x, dim=3)

        return x

# Dual Switch

In [57]:
class DualSwitching(nn.Module):
    """Patch merging with a "dual switching" reorder of the patch grid before projection.

    The input ``(B, C, H, W)`` is unfolded into non-overlapping
    ``downscaling_factor``-sized patches arranged channels-last as
    ``(B, H // f, W // f, C * f**2)``. Columns and rows of that grid are then
    swapped (adjacent pairs first, then pairs offset by one), and each patch
    is projected to ``out_channels``.

    NOTE(review): ``multi_granularity_summing_block`` is constructed but not
    used by ``forward``; it is kept so the attribute stays available —
    confirm whether it should be applied or removed.

    Args:
        in_channels: Channel count of the input feature map.
        out_channels: Feature size produced by the linear projection.
        downscaling_factor: Side length (and stride) of each square patch.
    """

    def __init__(self, in_channels, out_channels, downscaling_factor):
        super().__init__()
        self.downscaling_factor = downscaling_factor
        self.multi_granularity_summing_block = MultiGranularitySummingBlock(in_channels, out_channels)
        # forward() relies on these two layers; without them it raised
        # AttributeError. They mirror the layers of PatchMerging.
        self.patch_merge = nn.Unfold(kernel_size=downscaling_factor, stride=downscaling_factor, padding=0)
        self.linear = nn.Linear(in_channels * downscaling_factor ** 2, out_channels)

    @staticmethod
    def _swap_pairs(x, dim, start):
        """Swap index pairs (start, start+1), (start+2, start+3), ... along ``dim``.

        Indices before ``start`` and a trailing unpaired index keep their place.
        Returns a new tensor.
        """
        size = x.shape[dim]
        order = torch.arange(size, device=x.device)
        # max() keeps the range valid (empty) when there is no complete pair.
        left = torch.arange(start, max(start, size - 1), 2, device=x.device)
        order[left] = left + 1
        order[left + 1] = left
        return x.index_select(dim, order)

    def forward(self, x):
        """Return the switched, projected patch grid ``(B, H // f, W // f, out_channels)``."""
        b, c, h, w = x.shape
        print(b,c,h,w)
        new_h, new_w = h // self.downscaling_factor, w // self.downscaling_factor
        # Channels-last patch grid: dim 1 indexes rows, dim 2 indexes columns.
        x = self.patch_merge(x).view(b, -1, new_h, new_w).permute(0, 2, 3, 1)

        # ------------- Dual Switching ---------------------
        # Swap adjacent columns.
        print("\n--- Performing Adjacent Column Swap (0<->1, 2<->3, ...) ---")
        if new_w < 2:
            print("Not enough columns for adjacent column swap. Skipping.")
        else:
            x = self._swap_pairs(x, dim=2, start=0)
        print(f"Shape after adjacent column swap: {x.shape}")

        # Swap adjacent rows.
        print("\n--- Performing Adjacent Row Swap (0<->1, 2<->3, ...) ---")
        if new_h < 2:
            print("Not enough rows for adjacent row swap. Skipping.")
        else:
            x = self._swap_pairs(x, dim=1, start=0)
        print(f"Shape after adjacent row swap: {x.shape}")

        # Swap interlaced columns (pairs offset by one; column 0 stays put).
        print("\n--- Performing Interlaced Column Swap (1<->2, 3<->4, ...) ---")
        if new_w < 3:
            print("Not enough columns for interlaced column swap (need at least 3). Skipping.")
        else:
            x = self._swap_pairs(x, dim=2, start=1)
        print(f"Shape after interlaced column swap: {x.shape}")

        # Swap interlaced rows (pairs offset by one; row 0 stays put).
        print("\n--- Performing Interlaced Row Swap (1<->2, 3<->4, ...) ---")
        if new_h < 3:
            print("Not enough rows for interlaced row swap (need at least 3). Skipping.")
        else:
            x = self._swap_pairs(x, dim=1, start=1)
        print(f"Shape after interlaced row swap: {x.shape}")
        # ----------------- End Dual Switching -----------

        x = self.linear(x)
        print(f"Final output shape: {x.shape}")
        return x

# Module

## Stage Module (del module 2)

In [58]:
class StageModule(nn.Module):
    """One backbone stage: patch merging followed by a stack of transformer blocks.

    Args:
        in_channels: Channel count of the incoming ``(B, C, H, W)`` feature map.
        hidden_dimension: Feature size after patch merging and inside the blocks.
        TB_layers: Number of transformer blocks in the stage.
        downscaling_factor: Patch size / stride used by the patch merging.
        num_heads: Attention heads per block.
        head_dim: Feature size per attention head.
        window_size: Attention window side length.
        relative_pos_embedding: Forwarded to every transformer block.
    """

    def __init__(self, in_channels, hidden_dimension, TB_layers, downscaling_factor, num_heads, head_dim, window_size,
                 relative_pos_embedding):
        super().__init__()

        self.patch_partition = PatchMerging(
            in_channels=in_channels,
            out_channels=hidden_dimension,
            downscaling_factor=downscaling_factor,
        )

        # Every block uses regular (non-shifted) windows.
        blocks = [
            TransformerBlock(
                dim=hidden_dimension,
                heads=num_heads,
                head_dim=head_dim,
                mlp_dim=hidden_dimension * 4,
                shifted=False,
                window_size=window_size,
                relative_pos_embedding=relative_pos_embedding,
            )
            for _ in range(TB_layers)
        ]
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Merge patches, apply the blocks, and return a channels-first tensor."""
        tokens = self.patch_partition(x)
        for layer in self.layers:
            tokens = layer(tokens)
        # The blocks work channels-last; move the channels back to dim 1.
        return tokens.permute(0, 3, 1, 2)

## Hyneter Module (del module 2)

### Hyneter_no_CNN

In [59]:
class HyneterModule_noCNN(nn.Module):
    """Hyneter stage without convolution layers: patch merging, then transformer blocks.

    Args:
        in_channels: Channel count of the incoming ``(B, C, H, W)`` feature map.
        hidden_dimension: Feature size after patch merging and inside the blocks.
        Conv_layers: Accepted for signature parity with the other Hyneter
            modules; unused here.
        TB_layers: Number of transformer blocks.
        downscaling_factor: Patch size / stride used by the patch merging.
        num_heads: Attention heads per block.
        head_dim: Feature size per attention head.
        window_size: Attention window side length.
        relative_pos_embedding: Forwarded to every transformer block.
    """

    def __init__(self, in_channels, hidden_dimension, Conv_layers,TB_layers, downscaling_factor, num_heads, head_dim, window_size,
                 relative_pos_embedding):
        super().__init__()
        self.patch_partition = PatchMerging(in_channels=in_channels, out_channels=hidden_dimension,
                                            downscaling_factor=downscaling_factor)

        # Every block uses regular (non-shifted) windows.
        self.TB_layers = nn.ModuleList([
            TransformerBlock(dim=hidden_dimension, heads=num_heads, head_dim=head_dim, mlp_dim=hidden_dimension * 4,
                             shifted=False, window_size=window_size, relative_pos_embedding=relative_pos_embedding)
            for _ in range(TB_layers)
        ])

    def forward(self, x):
        """Merge patches, apply the blocks, and return a channels-first tensor."""
        # Trace output names this class (it previously reported "noHNB" and a
        # non-existent convolution step).
        print("HyneterModule_noCNN forward pass")

        print(f"Input shape before Patch Partition: {x.shape}")
        x = self.patch_partition(x)
        print(f"Input shape after Patch Partition: {x.shape}")

        for block in self.TB_layers:
            x = block(x)
        print(f"Input shape after Transformer Blocks: {x.shape}")
        # The blocks work channels-last; move the channels back to dim 1.
        return x.permute(0, 3, 1, 2)

### HyneterModule_noHNB

In [60]:
class HyneterModule_noHNB(nn.Module):
    """Sequential Hyneter stage: patch merging -> CNN blocks -> transformer blocks.

    Args:
        in_channels: channel count handed to PatchMerging as ``in_channels``.
        hidden_dimension: PatchMerging ``out_channels``; also the channel width
            of every CNN block and the ``dim`` of every TransformerBlock.
        Conv_layers: number of CNN blocks to stack.
        TB_layers: number of TransformerBlock instances to stack.
        downscaling_factor: forwarded to PatchMerging.
        num_heads, head_dim, window_size, relative_pos_embedding: forwarded to
            every TransformerBlock.
    """

    def __init__(self, in_channels, hidden_dimension, Conv_layers,TB_layers, downscaling_factor, num_heads, head_dim, window_size,
                 relative_pos_embedding):
        super().__init__()
        self.patch_partition = PatchMerging(in_channels=in_channels, out_channels=hidden_dimension,
                                            downscaling_factor=downscaling_factor)

        # The CNN blocks run *after* patch merging (built with
        # out_channels=hidden_dimension) and are chained one after another, so
        # each must accept hidden_dimension channels. Building them with the
        # stage's raw in_channels only lines up when in_channels equals
        # hidden_dimension. NOTE(review): assumes CNN maps in_channels ->
        # embed_dim channels; confirm against the CNN definition.
        self.Conv_layers = nn.ModuleList([
            CNN(in_channels=hidden_dimension, embed_dim=hidden_dimension)
            for _ in range(Conv_layers)
        ])

        # Every block is built with shifted=False; the MLP width is 4x the block dim.
        self.TB_layers = nn.ModuleList([
            TransformerBlock(dim=hidden_dimension, heads=num_heads, head_dim=head_dim, mlp_dim=hidden_dimension * 4,
                             shifted=False, window_size=window_size, relative_pos_embedding=relative_pos_embedding)
            for _ in range(TB_layers)
        ])

    def forward(self, x):
        """Run patch merging, the CNN blocks and the transformer blocks, logging shapes.

        The tensor is permuted with ``(0, 3, 1, 2)`` before the CNN blocks and
        back with ``(0, 2, 3, 1)`` afterwards, then permuted with
        ``(0, 3, 1, 2)`` once more for the return value.
        """
        print("HyneterModule_noHNB forward pass")

        print(f"Input shape before Patch Partition: {x.shape}")
        x = self.patch_partition(x)
        print(f"Input shape after Patch Partition: {x.shape}")

        # Move the last axis to position 1 for the CNN blocks.
        x = x.permute(0, 3, 1, 2)
        print("CNN input shape:", x.shape)
        for block in self.Conv_layers:
            x = block(x)
        print(f"CNN Output shape: {x.shape}")
        x = x.permute(0, 2, 3, 1)  # Change to (batch_size, height, width, channels)

        print(f"Input shape after Conv layers: {x.shape}")
        for block in self.TB_layers:
            x = block(x)
        print(f"Input shape after Transformer Blocks: {x.shape}")
        return x.permute(0, 3, 1, 2)

### 1B

In [None]:
class HyneterModule(nn.Module):
    """Two-branch Hyneter stage: a CNN branch and a transformer branch fused by a tanh gate.

    The input first goes through a MultiGranularitySummingBlock. Its result
    feeds both branches; the output is ``tb + tanh(conv * tb)`` where ``conv``
    and ``tb`` are the branch outputs and ``*`` is an element-wise product.

    Args:
        in_channels: input channels of the MultiGranularitySummingBlock.
        hidden_dimension: ``embed_dim`` of the MultiGranularitySummingBlock,
            channel width of every CNN block and ``dim`` of every TransformerBlock.
        Conv_layers: number of CNN blocks in the conv branch.
        TB_layers: number of TransformerBlock instances in the transformer branch.
        downscaling_factor: accepted for signature compatibility; not used by this class.
        num_heads, head_dim, window_size, relative_pos_embedding: forwarded to
            every TransformerBlock.
    """

    def __init__(self, in_channels, hidden_dimension, Conv_layers,TB_layers, downscaling_factor, num_heads, head_dim, window_size,
                 relative_pos_embedding):
        super().__init__()
        self.mg = MultiGranularitySummingBlock(in_channels=in_channels, embed_dim=hidden_dimension, stride1=1)

        self.Conv_layers = nn.ModuleList([
            CNN(in_channels=hidden_dimension, embed_dim=hidden_dimension)
            for _ in range(Conv_layers)
        ])

        # Every block is built with shifted=False; the MLP width is 4x the block dim.
        self.TB_layers = nn.ModuleList([
            TransformerBlock(dim=hidden_dimension, heads=num_heads, head_dim=head_dim, mlp_dim=hidden_dimension * 4,
                             shifted=False, window_size=window_size, relative_pos_embedding=relative_pos_embedding)
            for _ in range(TB_layers)
        ])

    def forward(self, x):
        """Return ``tb + tanh(conv * tb)`` for the two branch outputs, logging as it goes."""
        print("\nHyneterModule: HNB: Patch -> Conv -> TB")
        print("HyneterModule forward pass")

        print(f"\nInput shape before Multigranularity CNN: {x.shape}") # Batch size, Channels, Height, Width
        x = self.mg(x)
        print(f"Input shape after Multigranularity CNN: {x.shape}") # Batch size, Channels, Height, Width

        # Both branches start from the same tensor.
        x_conv_path = x
        x_tb_path = x.permute(0, 2, 3, 1) # Change to (batch_size, Height, Width, Channels)

        print("X(CNN) shape before Conv layers:", x_conv_path.shape) # Batch size, Channels, Height, Width
        for block in self.Conv_layers:
            x_conv_path = block(x_conv_path)
        print(f"X(CNN) shape after Conv layers: {x_conv_path.shape}") # Batch size, Channels, Height, Width

        print("X(TB) shape before Transformer Blocks:", x_tb_path.shape) # Batch size, Height, Width, Channels
        for block in self.TB_layers:
            x_tb_path = block(x_tb_path)
        print(f"Input shape after Transformer Blocks: {x_tb_path.shape}") # Batch size, Height, Width, Channels
        x_tb_path = x_tb_path.permute(0, 3, 1, 2) # Change to (batch_size, channels, height, width)

        print("########### going to calculate Z ###########")
        print("X(CNN) shape:", x_conv_path.shape) # B, C, H, W
        print("X(TB) shape:", x_tb_path.shape) # B, C, H, W

        # Element-wise product of the two branches, squashed by tanh.
        Z = x_conv_path * x_tb_path
        print("Z shape", Z.shape) # B, C, H, W
        print("Z before tanh:", Z)
        Z = torch.tanh(Z)
        print("Z after tanh:", Z)
        print("Z is tanh(Dot product between x and S):", Z.shape) # B, C, H, W

        output = x_tb_path+Z
        # Report the shape of the tensor actually returned, not of the
        # pre-branch tensor x.
        print("output shape after adding Z:", output.shape) # B, C, H, W
        print("output value:", output)
        return output

In [None]:
class HyneterModule_DualSwitch(nn.Module):
    """Two-branch Hyneter stage whose transformer branch is preceded by DualSwitch_SwapOnly.

    The input first goes through a MultiGranularitySummingBlock. Its result
    feeds a CNN branch directly and a transformer branch via
    ``self.DualSwitching``; the output is ``tb + tanh(conv * tb)`` where
    ``conv`` and ``tb`` are the branch outputs and ``*`` is element-wise.

    Args:
        in_channels: input channels of the MultiGranularitySummingBlock.
        hidden_dimension: ``embed_dim`` of the MultiGranularitySummingBlock,
            channel width of every CNN block and ``dim`` of every TransformerBlock.
        Conv_layers: number of CNN blocks in the conv branch.
        TB_layers: number of TransformerBlock instances in the transformer branch.
        downscaling_factor: accepted for signature compatibility; not used by this class.
        num_heads, head_dim, window_size, relative_pos_embedding: forwarded to
            every TransformerBlock.
    """

    def __init__(self, in_channels, hidden_dimension, Conv_layers,TB_layers, downscaling_factor, num_heads, head_dim, window_size,
                 relative_pos_embedding):
        super().__init__()
        self.mg = MultiGranularitySummingBlock(in_channels=in_channels, embed_dim=hidden_dimension, stride1=1)

        self.Conv_layers = nn.ModuleList([
            CNN(in_channels=hidden_dimension, embed_dim=hidden_dimension)
            for _ in range(Conv_layers)
        ])

        self.DualSwitching = DualSwitch_SwapOnly()

        # Every block is built with shifted=False; the MLP width is 4x the block dim.
        self.TB_layers = nn.ModuleList([
            TransformerBlock(dim=hidden_dimension, heads=num_heads, head_dim=head_dim, mlp_dim=hidden_dimension * 4,
                             shifted=False, window_size=window_size, relative_pos_embedding=relative_pos_embedding)
            for _ in range(TB_layers)
        ])

    def forward(self, x):
        """Return ``tb + tanh(conv * tb)`` for the two branch outputs, logging as it goes."""
        print("\nHyneterModule: HNB: Patch -> Conv -> TB")
        print("HyneterModule forward pass")

        print(f"\nInput shape before Multigranularity CNN: {x.shape}") # Batch size, Channels, Height, Width
        x = self.mg(x)
        print(f"Input shape after Multigranularity CNN: {x.shape}") # Batch size, Channels, Height, Width

        # Both branches start from the same tensor.
        x_conv_path = x
        x_tb_path = x

        print("X(CNN) shape before Conv layers:", x_conv_path.shape) # Batch size, Channels, Height, Width
        for block in self.Conv_layers:
            x_conv_path = block(x_conv_path)
        print(f"X(CNN) shape after Conv layers: {x_conv_path.shape}") # Batch size, Channels, Height, Width

        # The switch is applied before the permute that precedes the transformer blocks.
        x_tb_path = self.DualSwitching(x_tb_path)

        x_tb_path = x_tb_path.permute(0, 2, 3, 1) # Change to (batch_size, Height, Width, Channels)
        print("X(TB) shape before Transformer Blocks:", x_tb_path.shape) # Batch size, Height, Width, Channels
        for block in self.TB_layers:
            x_tb_path = block(x_tb_path)
        print(f"Input shape after Transformer Blocks: {x_tb_path.shape}") # Batch size, Height, Width, Channels
        x_tb_path = x_tb_path.permute(0, 3, 1, 2) # Change to (batch_size, channels, height, width)

        print("########### going to calculate Z ###########")
        print("X(CNN) shape:", x_conv_path.shape) # B, C, H, W
        print("X(TB) shape:", x_tb_path.shape) # B, C, H, W

        # Element-wise product of the two branches, squashed by tanh.
        Z = x_conv_path * x_tb_path
        print("Z shape", Z.shape) # B, C, H, W
        print("Z before tanh:", Z)
        Z = torch.tanh(Z)
        print("Z after tanh:", Z)
        print("Z is tanh(Dot product between x and S):", Z.shape) # B, C, H, W

        output = x_tb_path+Z
        # Report the shape of the tensor actually returned, not of the
        # pre-branch tensor x.
        print("output shape after adding Z:", output.shape) # B, C, H, W
        print("output value:", output)
        return output

## Stage Module (Dual Switch)

In [63]:
class StageModule(nn.Module):
    """One backbone stage: a DualSwitching front end followed by ``TB_layers`` transformer blocks.

    ``TB_layers`` must be even (enforced by an assertion). Note that every
    block is created with ``shifted=False``, even though the assertion message
    still refers to regular and shifted blocks.
    """

    def __init__(self, in_channels, hidden_dimension, TB_layers, downscaling_factor, num_heads, head_dim, window_size,
                 relative_pos_embedding):
        super().__init__()
        assert TB_layers % 2 == 0, 'Stage TB_layers need to be divisible by 2 for regular and shifted block.'

        self.patch_partition = DualSwitching(
            in_channels=in_channels,
            out_channels=hidden_dimension,
            downscaling_factor=downscaling_factor,
        )

        # The MLP inside each block is four times as wide as the block itself.
        self.layers = nn.ModuleList([
            TransformerBlock(
                dim=hidden_dimension,
                heads=num_heads,
                head_dim=head_dim,
                mlp_dim=hidden_dimension * 4,
                shifted=False,
                window_size=window_size,
                relative_pos_embedding=relative_pos_embedding,
            )
            for _ in range(TB_layers)
        ])

    def forward(self, x):
        """Run the front end and all blocks, then return the result permuted with (0, 3, 1, 2)."""
        hidden = self.patch_partition(x)
        for layer in self.layers:
            hidden = layer(hidden)
        return hidden.permute(0, 3, 1, 2)

## Swin Transformer (No FPN)

In [64]:
class SwinTransformer(nn.Module):
    """Four StageModule stages followed by spatial mean pooling and a LayerNorm + Linear head.

    Stage output widths are ``hidden_dim``, ``2x``, ``4x`` and ``8x``
    ``hidden_dim``; ``out_channels`` records the last of these.
    """

    def __init__(self, *, hidden_dim, TB_layers, heads, channels=3, num_classes=1000, head_dim=32, window_size=7,
                 downscaling_factors=(4, 2, 2, 2), relative_pos_embedding=True):
        super().__init__()

        # widths[i] is the input width of stage i + 1 and widths[i + 1] its output width.
        widths = (channels, hidden_dim, hidden_dim * 2, hidden_dim * 4, hidden_dim * 8)
        shared = dict(head_dim=head_dim, window_size=window_size,
                      relative_pos_embedding=relative_pos_embedding)

        def make_stage(idx):
            return StageModule(in_channels=widths[idx], hidden_dimension=widths[idx + 1],
                               TB_layers=TB_layers[idx], downscaling_factor=downscaling_factors[idx],
                               num_heads=heads[idx], **shared)

        self.out_channels = widths[-1]

        # Explicit stage1..stage4 attributes keep the submodule names stable.
        self.stage1 = make_stage(0)
        self.stage2 = make_stage(1)
        self.stage3 = make_stage(2)
        self.stage4 = make_stage(3)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(widths[-1]),
            nn.Linear(widths[-1], num_classes),
        )

    def forward(self, img):
        """Run the four stages, average over axes 2 and 3, and apply the head."""
        features = img
        for stage in (self.stage1, self.stage2, self.stage3, self.stage4):
            features = stage(features)
        pooled = features.mean(dim=[2, 3])
        return self.mlp_head(pooled)


def swin_t(hidden_dim=96, TB_layers=(2, 2, 6, 2), heads=(3, 6, 12, 24), **kwargs):
    """Build a SwinTransformer from this preset (width 96, block counts 2/2/6/2).

    Each default can be overridden; any extra keyword arguments are passed
    through to the SwinTransformer constructor unchanged.
    """
    preset = {"hidden_dim": hidden_dim, "TB_layers": TB_layers, "heads": heads}
    return SwinTransformer(**preset, **kwargs)


def swin_s(hidden_dim=96, TB_layers=(2, 2, 18, 2), heads=(3, 6, 12, 24), **kwargs):
    """Build a SwinTransformer from this preset (width 96, block counts 2/2/18/2).

    Each default can be overridden; any extra keyword arguments are passed
    through to the SwinTransformer constructor unchanged.
    """
    preset = {"hidden_dim": hidden_dim, "TB_layers": TB_layers, "heads": heads}
    return SwinTransformer(**preset, **kwargs)


def swin_b(hidden_dim=128, TB_layers=(2, 2, 18, 2), heads=(4, 8, 16, 32), **kwargs):
    """Build a SwinTransformer from this preset (width 128, block counts 2/2/18/2).

    Each default can be overridden; any extra keyword arguments are passed
    through to the SwinTransformer constructor unchanged.
    """
    preset = {"hidden_dim": hidden_dim, "TB_layers": TB_layers, "heads": heads}
    return SwinTransformer(**preset, **kwargs)


def swin_l(hidden_dim=192, TB_layers=(2, 2, 18, 2), heads=(6, 12, 24, 48), **kwargs):
    """Build a SwinTransformer from this preset (width 192, block counts 2/2/18/2).

    Each default can be overridden; any extra keyword arguments are passed
    through to the SwinTransformer constructor unchanged.
    """
    preset = {"hidden_dim": hidden_dim, "TB_layers": TB_layers, "heads": heads}
    return SwinTransformer(**preset, **kwargs)



# Swin for FPN

## Module

In [65]:
import torch
import torch.nn as nn
from torchvision.models.detection import MaskRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.transforms import functional as F
from collections import OrderedDict


class SwinForFPN(nn.Module):
    """Four StageModule stages that expose every stage output for a feature pyramid.

    ``forward`` returns an OrderedDict keyed '0'..'3' with the outputs of
    stage1..stage4. ``out_channels`` is set to the last stage's width
    (``hidden_dim * 8``).
    """

    def __init__(self, *, hidden_dim, TB_layers, heads, channels=3, head_dim=32, window_size=7,
                 downscaling_factors=(4, 2, 2, 2), relative_pos_embedding=True):
        super().__init__()

        # widths[i] is the input width of stage i + 1 and widths[i + 1] its output width.
        widths = (channels, hidden_dim, hidden_dim * 2, hidden_dim * 4, hidden_dim * 8)
        shared = dict(head_dim=head_dim, window_size=window_size,
                      relative_pos_embedding=relative_pos_embedding)

        def make_stage(idx):
            return StageModule(in_channels=widths[idx], hidden_dimension=widths[idx + 1],
                               TB_layers=TB_layers[idx], downscaling_factor=downscaling_factors[idx],
                               num_heads=heads[idx], **shared)

        self.out_channels = widths[-1]

        # Explicit stage1..stage4 attributes keep the submodule names stable.
        self.stage1 = make_stage(0)
        self.stage2 = make_stage(1)
        self.stage3 = make_stage(2)
        self.stage4 = make_stage(3)

    def forward(self, img):
        """Chain the four stages, logging each output shape, and return all four outputs."""
        print(f"Input image shape: {img.shape}")

        c2 = self.stage1(img)
        print(f"Output of stage1 (c2) shape: {c2.shape}")

        c3 = self.stage2(c2)
        print(f"Output of stage2 (c3) shape: {c3.shape}")

        c4 = self.stage3(c3)
        print(f"Output of stage3 (c4) shape: {c4.shape}")

        c5 = self.stage4(c4)
        print(f"Output of stage4 (c5) shape: {c5.shape}")

        # String keys in stage order, for use with Mask R-CNN from
        # torchvision.models.detection.
        return OrderedDict([('0', c2), ('1', c3), ('2', c4), ('3', c5)])
        


def swin_t_fpn(hidden_dim=96, TB_layers=(2, 2, 6, 2), heads=(3, 6, 12, 24), **kwargs):
    """Build a SwinForFPN from this preset (width 96, block counts 2/2/6/2).

    Each default can be overridden; any extra keyword arguments are passed
    through to the SwinForFPN constructor unchanged.
    """
    preset = {"hidden_dim": hidden_dim, "TB_layers": TB_layers, "heads": heads}
    return SwinForFPN(**preset, **kwargs)

def swin_s_fpn(hidden_dim=96, TB_layers=(2, 2, 18, 2), heads=(3, 6, 12, 24), **kwargs):
    """Build a SwinForFPN from this preset (width 96, block counts 2/2/18/2).

    Each default can be overridden; any extra keyword arguments are passed
    through to the SwinForFPN constructor unchanged.
    """
    preset = {"hidden_dim": hidden_dim, "TB_layers": TB_layers, "heads": heads}
    return SwinForFPN(**preset, **kwargs)

def swin_b_fpn(hidden_dim=128, TB_layers=(2, 2, 18, 2), heads=(4, 8, 16, 32), **kwargs):
    """Build a SwinForFPN from this preset (width 128, block counts 2/2/18/2).

    Each default can be overridden; any extra keyword arguments are passed
    through to the SwinForFPN constructor unchanged.
    """
    preset = {"hidden_dim": hidden_dim, "TB_layers": TB_layers, "heads": heads}
    return SwinForFPN(**preset, **kwargs)

def swin_l_fpn(hidden_dim=192, TB_layers=(2, 2, 18, 2), heads=(6, 12, 24, 48), **kwargs):
    """Build a SwinForFPN from this preset (width 192, block counts 2/2/18/2).

    Each default can be overridden; any extra keyword arguments are passed
    through to the SwinForFPN constructor unchanged.
    """
    preset = {"hidden_dim": hidden_dim, "TB_layers": TB_layers, "heads": heads}
    return SwinForFPN(**preset, **kwargs)



## Swin + FPN

In [66]:
from torchvision.ops import FeaturePyramidNetwork, MultiScaleRoIAlign


class SwinFPNBackbone(nn.Module):
    """Chain a multi-output backbone into a torchvision FeaturePyramidNetwork.

    Args:
        backbone: module whose output is passed straight to the FPN.
        hidden_dim: base width; the FPN is configured for inputs of width
            ``hidden_dim``, ``2x``, ``4x`` and ``8x`` ``hidden_dim``, so it
            must agree with the backbone's stage widths.

    ``out_channels`` is 256, the FPN's output width.
    """

    def __init__(self, backbone, hidden_dim=96,):
        super().__init__()
        self.swin_backbone = backbone

        fpn_width = 256
        stage_widths = [hidden_dim * factor for factor in (1, 2, 4, 8)]
        self.fpn = FeaturePyramidNetwork(in_channels_list=stage_widths, out_channels=fpn_width)

        self.out_channels = fpn_width

    def forward(self, x):
        """Return the FPN output computed from the backbone's features for ``x``."""
        return self.fpn(self.swin_backbone(x))

## Mask R-CNN + Swin Transformer

In [67]:
from torchvision.ops import MultiScaleRoIAlign


def get_swin_mask_rcnn_model(num_class=91):
    """Build a torchvision MaskRCNN on top of ``SwinFPNBackbone(swin_t_fpn())``.

    Args:
        num_class: value passed to MaskRCNN as ``num_classes``.

    Returns:
        A MaskRCNN with a custom anchor generator and RoI poolers wired to the
        four FPN outputs, and with ``min_size`` and ``max_size`` both fixed to 224.
    """
    backbone = SwinFPNBackbone(swin_t_fpn())

    # SwinFPNBackbone builds its FPN without extra blocks, so these four keys
    # are the only feature maps the detector will see.
    fpn_output_keys = ['0', '1', '2', '3']
    num_fpn_levels = len(fpn_output_keys)

    # AnchorGenerator needs exactly one sizes entry and one aspect-ratio entry
    # per feature map. The list previously carried a fifth entry, (512,), with
    # no matching aspect-ratio tuple or feature map, so it could never be used.
    anchor_sizes = ((32,), (64,), (128,), (256,))
    aspect_ratios = ((0.5, 1.0, 2.0),) * num_fpn_levels
    if len(anchor_sizes) != num_fpn_levels:
        raise ValueError("anchor_sizes must have one entry per FPN output level")
    anchor_generator = AnchorGenerator(sizes=anchor_sizes, aspect_ratios=aspect_ratios)

    # featmap_names must match the keys of the FPN output dict.
    box_roi_pool = MultiScaleRoIAlign(featmap_names=fpn_output_keys, output_size=7, sampling_ratio=2)
    mask_roi_pool = MultiScaleRoIAlign(featmap_names=fpn_output_keys, output_size=14, sampling_ratio=2)

    model = MaskRCNN(
        backbone=backbone,
        num_classes=num_class,
        rpn_anchor_generator=anchor_generator,
        box_roi_pool=box_roi_pool,
        mask_roi_pool=mask_roi_pool,
        min_size=224,
        max_size=224,
    )

    return model

# Hyneter

In [68]:
class Hyneter(nn.Module):
    """Four-stage Hyneter classifier.

    Stages 1 and 2 are HyneterModule instances, stages 3 and 4 are
    HyneterModule_DualSwitch instances. The last stage's output is averaged
    over axes 2 and 3 and fed to a LayerNorm + Linear head.
    """

    def __init__(self, *, hidden_dim, Conv_layers, TB_layers, heads, channels=3, num_classes=1000, head_dim=32, window_size=7,
                 downscaling_factors=(4, 2, 2, 2), relative_pos_embedding=True):
        super().__init__()

        # widths[i] is the input width of stage i + 1 and widths[i + 1] its output width.
        widths = (channels, hidden_dim, hidden_dim * 2, hidden_dim * 4, hidden_dim * 8)
        shared = dict(head_dim=head_dim, window_size=window_size,
                      relative_pos_embedding=relative_pos_embedding)

        def make_stage(stage_cls, idx):
            return stage_cls(in_channels=widths[idx], hidden_dimension=widths[idx + 1],
                             Conv_layers=Conv_layers[idx], TB_layers=TB_layers[idx],
                             downscaling_factor=downscaling_factors[idx], num_heads=heads[idx], **shared)

        self.out_channels = widths[-1]

        # Explicit stage1..stage4 attributes keep the submodule names stable.
        self.stage1 = make_stage(HyneterModule, 0)
        self.stage2 = make_stage(HyneterModule, 1)
        self.stage3 = make_stage(HyneterModule_DualSwitch, 2)
        self.stage4 = make_stage(HyneterModule_DualSwitch, 3)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(widths[-1]),
            nn.Linear(widths[-1], num_classes),
        )

    def forward(self, img):
        """Run the four stages, average over axes 2 and 3, and apply the head."""
        features = img
        for stage in (self.stage1, self.stage2, self.stage3, self.stage4):
            features = stage(features)
        pooled = features.mean(dim=[2, 3])
        return self.mlp_head(pooled)



def hyneter_base(hidden_dim=96, Conv_layers=(2, 2, 2, 2), TB_layers=(2, 2, 2, 2), heads=(4, 8, 16, 32), **kwargs):
    """Build a Hyneter from this preset (conv counts 2/2/2/2, block counts 2/2/2/2).

    Each default can be overridden; any extra keyword arguments are passed
    through to the Hyneter constructor unchanged.
    """
    preset = {"hidden_dim": hidden_dim, "Conv_layers": Conv_layers, "TB_layers": TB_layers, "heads": heads}
    return Hyneter(**preset, **kwargs)


def hyneter_plus(hidden_dim=96, Conv_layers=(2, 2, 3, 2), TB_layers=(2, 2, 6, 2), heads=(4, 8, 16, 32), **kwargs):
    """Build a Hyneter from this preset (conv counts 2/2/3/2, block counts 2/2/6/2).

    Each default can be overridden; any extra keyword arguments are passed
    through to the Hyneter constructor unchanged.
    """
    preset = {"hidden_dim": hidden_dim, "Conv_layers": Conv_layers, "TB_layers": TB_layers, "heads": heads}
    return Hyneter(**preset, **kwargs)


def hyneter_max(hidden_dim=96, Conv_layers=(2, 2, 6, 2), TB_layers=(2, 2, 18, 2), heads=(4, 8, 16, 32), **kwargs):
    """Build a Hyneter from this preset (conv counts 2/2/6/2, block counts 2/2/18/2).

    Each default can be overridden; any extra keyword arguments are passed
    through to the Hyneter constructor unchanged.
    """
    preset = {"hidden_dim": hidden_dim, "Conv_layers": Conv_layers, "TB_layers": TB_layers, "heads": heads}
    return Hyneter(**preset, **kwargs)



# Hyneter with FPN

## Module

### Hyneter No CNN

In [69]:
import torch
import torch.nn as nn
from torchvision.models.detection import MaskRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.transforms import functional as F
from collections import OrderedDict


class HyneterForFPN(nn.Module):
    """Four HyneterModule_noCNN stages that expose every stage output for a feature pyramid.

    ``forward`` returns an OrderedDict keyed '0'..'3' with the outputs of
    stage1..stage4. ``mlp_head`` is constructed here but is not called by
    ``forward``.
    """

    def __init__(self, *, hidden_dim, Conv_layers, TB_layers, heads, channels=3, num_classes=1000, head_dim=32, window_size=7,
                 downscaling_factors=(4, 2, 2, 2), relative_pos_embedding=True):
        super().__init__()

        # widths[i] is the input width of stage i + 1 and widths[i + 1] its output width.
        widths = (channels, hidden_dim, hidden_dim * 2, hidden_dim * 4, hidden_dim * 8)
        shared = dict(head_dim=head_dim, window_size=window_size,
                      relative_pos_embedding=relative_pos_embedding)

        def make_stage(idx):
            return HyneterModule_noCNN(in_channels=widths[idx], hidden_dimension=widths[idx + 1],
                                       Conv_layers=Conv_layers[idx], TB_layers=TB_layers[idx],
                                       downscaling_factor=downscaling_factors[idx], num_heads=heads[idx],
                                       **shared)

        self.out_channels = widths[-1]

        # Explicit stage1..stage4 attributes keep the submodule names stable.
        self.stage1 = make_stage(0)
        self.stage2 = make_stage(1)
        self.stage3 = make_stage(2)
        self.stage4 = make_stage(3)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(widths[-1]),
            nn.Linear(widths[-1], num_classes),
        )

    def forward(self, img):
        """Chain the four stages, logging each output shape, and return all four outputs."""
        print(f"Input image shape: {img.shape}")

        c2 = self.stage1(img)
        print(f"#Output of stage1 (c2) shape: {c2.shape}")

        c3 = self.stage2(c2)
        print(f"#Output of stage2 (c3) shape: {c3.shape}")

        c4 = self.stage3(c3)
        print(f"#Output of stage3 (c4) shape: {c4.shape}")

        c5 = self.stage4(c4)
        print(f"#Output of stage4 (c5) shape: {c5.shape}")

        # String keys in stage order: the feature maps handed to the FPN.
        return OrderedDict([('0', c2), ('1', c3), ('2', c4), ('3', c5)])
        


def hyneter_base_fpn(hidden_dim=96, Conv_layers=(2, 2, 2, 2), TB_layers=(2, 2, 2, 2), heads=(3, 6, 12, 24), **kwargs):
    """Build a HyneterForFPN from this preset (conv counts 2/2/2/2, block counts 2/2/2/2).

    ``HyneterForFPN`` is looked up when the function is called, so this builds
    whichever definition of that name is current. Extra keyword arguments are
    passed through to the constructor unchanged.
    """
    preset = {"hidden_dim": hidden_dim, "Conv_layers": Conv_layers, "TB_layers": TB_layers, "heads": heads}
    return HyneterForFPN(**preset, **kwargs)


def hyneter_plus_fpn(hidden_dim=96, Conv_layers=(2, 2, 3, 2), TB_layers=(2, 2, 6, 2), heads=(3, 6, 12, 24), **kwargs):
    """Build a HyneterForFPN from this preset (conv counts 2/2/3/2, block counts 2/2/6/2).

    ``HyneterForFPN`` is looked up when the function is called, so this builds
    whichever definition of that name is current. Extra keyword arguments are
    passed through to the constructor unchanged.
    """
    preset = {"hidden_dim": hidden_dim, "Conv_layers": Conv_layers, "TB_layers": TB_layers, "heads": heads}
    return HyneterForFPN(**preset, **kwargs)


def hyneter_max_fpn(hidden_dim=96, Conv_layers=(2, 2, 6, 2), TB_layers=(2, 2, 18, 2), heads=(3, 6, 12, 24), **kwargs):
    """Build a HyneterForFPN from this preset (conv counts 2/2/6/2, block counts 2/2/18/2).

    ``HyneterForFPN`` is looked up when the function is called, so this builds
    whichever definition of that name is current. Extra keyword arguments are
    passed through to the constructor unchanged.
    """
    preset = {"hidden_dim": hidden_dim, "Conv_layers": Conv_layers, "TB_layers": TB_layers, "heads": heads}
    return HyneterForFPN(**preset, **kwargs)

def hyneter_swin_size_fpn(hidden_dim=128, Conv_layers=(2, 2, 6, 2), TB_layers=(2, 2, 6, 2), heads=(3, 6, 12, 24), **kwargs):
    """Build a HyneterForFPN from this preset (width 128, conv counts 2/2/6/2, block counts 2/2/6/2).

    ``HyneterForFPN`` is looked up when the function is called, so this builds
    whichever definition of that name is current. Extra keyword arguments are
    passed through to the constructor unchanged.
    """
    preset = {"hidden_dim": hidden_dim, "Conv_layers": Conv_layers, "TB_layers": TB_layers, "heads": heads}
    return HyneterForFPN(**preset, **kwargs)


### Hyneter No HNB

In [70]:
import torch
import torch.nn as nn
from torchvision.models.detection import MaskRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.transforms import functional as F
from collections import OrderedDict


class HyneterForFPN(nn.Module):
    """Four HyneterModule_noHNB stages that expose every stage output for a feature pyramid.

    ``forward`` returns an OrderedDict keyed '0'..'3' with the outputs of
    stage1..stage4. ``num_classes`` is accepted but not used by this class.
    """

    def __init__(self, *, hidden_dim, Conv_layers, TB_layers, heads, channels=3, num_classes=1000, head_dim=32, window_size=7,
                 downscaling_factors=(4, 2, 2, 2), relative_pos_embedding=True):
        super().__init__()

        # widths[i] is the input width of stage i + 1 and widths[i + 1] its output width.
        widths = (channels, hidden_dim, hidden_dim * 2, hidden_dim * 4, hidden_dim * 8)
        shared = dict(head_dim=head_dim, window_size=window_size,
                      relative_pos_embedding=relative_pos_embedding)

        def make_stage(idx):
            return HyneterModule_noHNB(in_channels=widths[idx], hidden_dimension=widths[idx + 1],
                                       Conv_layers=Conv_layers[idx], TB_layers=TB_layers[idx],
                                       downscaling_factor=downscaling_factors[idx], num_heads=heads[idx],
                                       **shared)

        self.out_channels = widths[-1]

        # Explicit stage1..stage4 attributes keep the submodule names stable.
        self.stage1 = make_stage(0)
        self.stage2 = make_stage(1)
        self.stage3 = make_stage(2)
        self.stage4 = make_stage(3)

    def forward(self, img):
        """Chain the four stages, logging each output shape, and return all four outputs."""
        print(f"Input image shape: {img.shape}")

        c2 = self.stage1(img)
        print(f"#Output of stage1 (c2) shape: {c2.shape}")

        c3 = self.stage2(c2)
        print(f"#Output of stage2 (c3) shape: {c3.shape}")

        c4 = self.stage3(c3)
        print(f"#Output of stage3 (c4) shape: {c4.shape}")

        c5 = self.stage4(c4)
        print(f"#Output of stage4 (c5) shape: {c5.shape}")

        # String keys in stage order: the feature maps handed to the FPN.
        return OrderedDict([('0', c2), ('1', c3), ('2', c4), ('3', c5)])
        


def hyneter_base_fpn(hidden_dim=96, Conv_layers=(2, 2, 2, 2), TB_layers=(2, 2, 2, 2), heads=(3, 6, 12, 24), **kwargs):
    """Build a HyneterForFPN from this preset (conv counts 2/2/2/2, block counts 2/2/2/2).

    ``HyneterForFPN`` is looked up when the function is called, so this builds
    whichever definition of that name is current. Extra keyword arguments are
    passed through to the constructor unchanged.
    """
    preset = {"hidden_dim": hidden_dim, "Conv_layers": Conv_layers, "TB_layers": TB_layers, "heads": heads}
    return HyneterForFPN(**preset, **kwargs)


def hyneter_plus_fpn(hidden_dim=96, Conv_layers=(2, 2, 3, 2), TB_layers=(2, 2, 6, 2), heads=(3, 6, 12, 24), **kwargs):
    """Build a HyneterForFPN from this preset (conv counts 2/2/3/2, block counts 2/2/6/2).

    ``HyneterForFPN`` is looked up when the function is called, so this builds
    whichever definition of that name is current. Extra keyword arguments are
    passed through to the constructor unchanged.
    """
    preset = {"hidden_dim": hidden_dim, "Conv_layers": Conv_layers, "TB_layers": TB_layers, "heads": heads}
    return HyneterForFPN(**preset, **kwargs)


def hyneter_max_fpn(hidden_dim=96, Conv_layers=(2, 2, 6, 2), TB_layers=(2, 2, 18, 2), heads=(3, 6, 12, 24), **kwargs):
    """Build a HyneterForFPN from this preset (conv counts 2/2/6/2, block counts 2/2/18/2).

    ``HyneterForFPN`` is looked up when the function is called, so this builds
    whichever definition of that name is current. Extra keyword arguments are
    passed through to the constructor unchanged.
    """
    preset = {"hidden_dim": hidden_dim, "Conv_layers": Conv_layers, "TB_layers": TB_layers, "heads": heads}
    return HyneterForFPN(**preset, **kwargs)



### Hyneter with HNB

In [71]:
import torch
import torch.nn as nn
from torchvision.models.detection import MaskRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.transforms import functional as F
from collections import OrderedDict


class HyneterForFPN(nn.Module):
    """Four HyneterModule stages that expose every stage output for a feature pyramid.

    ``forward`` returns an OrderedDict keyed '0'..'3' with the outputs of
    stage1..stage4. ``num_classes`` is accepted but not used by this class.
    """

    def __init__(self, *, hidden_dim, Conv_layers, TB_layers, heads, channels=3, num_classes=1000, head_dim=32, window_size=7,
                 downscaling_factors=(4, 2, 2, 2), relative_pos_embedding=True):
        super().__init__()

        # widths[i] is the input width of stage i + 1 and widths[i + 1] its output width.
        widths = (channels, hidden_dim, hidden_dim * 2, hidden_dim * 4, hidden_dim * 8)
        shared = dict(head_dim=head_dim, window_size=window_size,
                      relative_pos_embedding=relative_pos_embedding)

        def make_stage(idx):
            return HyneterModule(in_channels=widths[idx], hidden_dimension=widths[idx + 1],
                                 Conv_layers=Conv_layers[idx], TB_layers=TB_layers[idx],
                                 downscaling_factor=downscaling_factors[idx], num_heads=heads[idx],
                                 **shared)

        self.out_channels = widths[-1]

        # Explicit stage1..stage4 attributes keep the submodule names stable.
        self.stage1 = make_stage(0)
        self.stage2 = make_stage(1)
        self.stage3 = make_stage(2)
        self.stage4 = make_stage(3)

    def forward(self, img):
        """Chain the four stages, logging each output shape, and return all four outputs."""
        print(f"#Input image shape: {img.shape}")

        c2 = self.stage1(img)
        print(f"#Output of stage1 (c2) shape: {c2.shape}")

        c3 = self.stage2(c2)
        print(f"#Output of stage2 (c3) shape: {c3.shape}")

        c4 = self.stage3(c3)
        print(f"#Output of stage3 (c4) shape: {c4.shape}")

        c5 = self.stage4(c4)
        print(f"#Output of stage4 (c5) shape: {c5.shape}")

        # String keys in stage order: the feature maps handed to the FPN.
        return OrderedDict([('0', c2), ('1', c3), ('2', c4), ('3', c5)])
        


def hyneter_base_fpn(hidden_dim=96, Conv_layers=(2, 2, 2, 2), TB_layers=(2, 2, 2, 2), heads=(3, 6, 12, 24), **kwargs):
    """Build a HyneterForFPN from this preset (conv counts 2/2/2/2, block counts 2/2/2/2).

    ``HyneterForFPN`` is looked up when the function is called, so this builds
    whichever definition of that name is current. Extra keyword arguments are
    passed through to the constructor unchanged.
    """
    preset = {"hidden_dim": hidden_dim, "Conv_layers": Conv_layers, "TB_layers": TB_layers, "heads": heads}
    return HyneterForFPN(**preset, **kwargs)


def hyneter_plus_fpn(hidden_dim=96, Conv_layers=(2, 2, 3, 2), TB_layers=(2, 2, 6, 2), heads=(3, 6, 12, 24), **kwargs):
    """Build a HyneterForFPN from this preset (conv counts 2/2/3/2, block counts 2/2/6/2).

    ``HyneterForFPN`` is looked up when the function is called, so this builds
    whichever definition of that name is current. Extra keyword arguments are
    passed through to the constructor unchanged.
    """
    preset = {"hidden_dim": hidden_dim, "Conv_layers": Conv_layers, "TB_layers": TB_layers, "heads": heads}
    return HyneterForFPN(**preset, **kwargs)


def hyneter_max_fpn(hidden_dim=128, Conv_layers=(2, 2, 6, 2), TB_layers=(2, 2, 18, 2), heads=(3, 6, 12, 24), **kwargs):
    """Build a HyneterForFPN from this preset (width 128, conv counts 2/2/6/2, block counts 2/2/18/2).

    ``HyneterForFPN`` is looked up when the function is called, so this builds
    whichever definition of that name is current. Extra keyword arguments are
    passed through to the constructor unchanged.
    """
    preset = {"hidden_dim": hidden_dim, "Conv_layers": Conv_layers, "TB_layers": TB_layers, "heads": heads}
    return HyneterForFPN(**preset, **kwargs)


### Hyneter Classification

In [72]:
import torch
import torch.nn as nn
from torchvision.models.detection import MaskRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.transforms import functional as F
from collections import OrderedDict


class HyneterForFPN(nn.Module):
    """Four-stage Hyneter network followed by a classification head.

    Stages 1 and 2 are ``HyneterModule`` blocks; stages 3 and 4 are
    ``HyneterModule_DualSwitch`` blocks. The stage widths are ``hidden_dim``,
    ``hidden_dim * 2``, ``hidden_dim * 4`` and ``hidden_dim * 8``.

    NOTE(review): another class with this same name is defined later in this
    file and replaces this one once it runs, and the factory helpers that
    follow this class instantiate ``Hyneter`` rather than this class. Confirm
    which name this classifier is meant to carry.
    """

    def __init__(self, *, hidden_dim, Conv_layers, TB_layers, heads, channels=3, num_classes=1000, head_dim=32, window_size=7,
                 downscaling_factors=(4, 2, 2, 2), relative_pos_embedding=True):
        """Create the four stages and the classification head.

        ``Conv_layers``, ``TB_layers``, ``heads`` and ``downscaling_factors``
        are indexed 0..3, one entry per stage.
        """
        super().__init__()

        # Settings every stage receives unchanged.
        shared = dict(head_dim=head_dim, window_size=window_size,
                      relative_pos_embedding=relative_pos_embedding)
        final_dim = hidden_dim * 8

        self.out_channels = final_dim

        self.stage1 = HyneterModule(
            in_channels=channels, hidden_dimension=hidden_dim,
            Conv_layers=Conv_layers[0], TB_layers=TB_layers[0],
            downscaling_factor=downscaling_factors[0], num_heads=heads[0], **shared)
        self.stage2 = HyneterModule(
            in_channels=hidden_dim, hidden_dimension=hidden_dim * 2,
            Conv_layers=Conv_layers[1], TB_layers=TB_layers[1],
            downscaling_factor=downscaling_factors[1], num_heads=heads[1], **shared)
        self.stage3 = HyneterModule_DualSwitch(
            in_channels=hidden_dim * 2, hidden_dimension=hidden_dim * 4,
            Conv_layers=Conv_layers[2], TB_layers=TB_layers[2],
            downscaling_factor=downscaling_factors[2], num_heads=heads[2], **shared)
        self.stage4 = HyneterModule_DualSwitch(
            in_channels=hidden_dim * 4, hidden_dimension=final_dim,
            Conv_layers=Conv_layers[3], TB_layers=TB_layers[3],
            downscaling_factor=downscaling_factors[3], num_heads=heads[3], **shared)

        # LayerNorm + Linear over the final stage width.
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(final_dim),
            nn.Linear(final_dim, num_classes),
        )

    def forward(self, img):
        """Pass ``img`` through all stages, average dims 2 and 3, apply the head.

        Assumes the last stage returns a (batch, channels, H, W) tensor, since
        the mean is taken over dims 2 and 3 - TODO confirm against the stage
        modules.
        """
        feats = img
        for stage in (self.stage1, self.stage2, self.stage3, self.stage4):
            feats = stage(feats)
        pooled = feats.mean(dim=[2, 3])
        return self.mlp_head(pooled)
        


def hyneter_base(hidden_dim=96, Conv_layers=(2, 2, 2, 2), TB_layers=(2, 2, 2, 2), heads=(3, 6, 12, 24), **kwargs):
    """Build a ``Hyneter`` model with the "base" preset.

    The four named settings are forwarded unchanged; extra keyword arguments
    are passed through to the ``Hyneter`` constructor.
    """
    preset = {
        "hidden_dim": hidden_dim,
        "Conv_layers": Conv_layers,
        "TB_layers": TB_layers,
        "heads": heads,
    }
    return Hyneter(**preset, **kwargs)


def hyneter_plus_fpn(hidden_dim=96, Conv_layers=(2, 2, 3, 2), TB_layers=(2, 2, 6, 2), heads=(3, 6, 12, 24), **kwargs):
    """Build a ``Hyneter`` model with the "plus" preset.

    Extra keyword arguments are passed through to the ``Hyneter`` constructor.

    NOTE(review): despite the ``_fpn`` suffix this returns ``Hyneter``, not
    ``HyneterForFPN``, and a later function of the same name in this file
    replaces it. Confirm the intended name.
    """
    preset = {
        "hidden_dim": hidden_dim,
        "Conv_layers": Conv_layers,
        "TB_layers": TB_layers,
        "heads": heads,
    }
    return Hyneter(**preset, **kwargs)


def hyneter_max_fpn(hidden_dim=128, Conv_layers=(2, 2, 6, 2), TB_layers=(2, 2, 18, 2), heads=(4, 8, 16, 32), **kwargs):
    """Build a ``Hyneter`` model with the "max" preset.

    Extra keyword arguments are passed through to the ``Hyneter`` constructor.

    NOTE(review): despite the ``_fpn`` suffix this returns ``Hyneter``, not
    ``HyneterForFPN``, and a later function of the same name in this file
    replaces it. Confirm the intended name.
    """
    preset = {
        "hidden_dim": hidden_dim,
        "Conv_layers": Conv_layers,
        "TB_layers": TB_layers,
        "heads": heads,
    }
    return Hyneter(**preset, **kwargs)


### Hyneter With HNB + DS

In [73]:
import torch
import torch.nn as nn
from torchvision.models.detection import MaskRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.transforms import functional as F
from collections import OrderedDict


class HyneterForFPN(nn.Module):
    """Four-stage Hyneter backbone that returns every stage output for an FPN.

    Stages 1 and 2 are ``HyneterModule`` blocks; stages 3 and 4 are
    ``HyneterModule_DualSwitch`` blocks. The stage widths are ``hidden_dim``,
    ``hidden_dim * 2``, ``hidden_dim * 4`` and ``hidden_dim * 8``.

    ``mlp_head`` is constructed (so it appears in the parameters and the
    state dict) but ``forward`` never calls it.
    """

    def __init__(self, *, hidden_dim, Conv_layers, TB_layers, heads, channels=3, num_classes=1000, head_dim=32, window_size=7,
                 downscaling_factors=(4, 2, 2, 2), relative_pos_embedding=True):
        """Create the four stages and the (unused in ``forward``) MLP head.

        ``Conv_layers``, ``TB_layers``, ``heads`` and ``downscaling_factors``
        are indexed 0..3, one entry per stage.
        """
        super().__init__()

        # Settings every stage receives unchanged.
        shared = dict(head_dim=head_dim, window_size=window_size,
                      relative_pos_embedding=relative_pos_embedding)
        final_dim = hidden_dim * 8

        self.out_channels = final_dim

        self.stage1 = HyneterModule(
            in_channels=channels, hidden_dimension=hidden_dim,
            Conv_layers=Conv_layers[0], TB_layers=TB_layers[0],
            downscaling_factor=downscaling_factors[0], num_heads=heads[0], **shared)
        self.stage2 = HyneterModule(
            in_channels=hidden_dim, hidden_dimension=hidden_dim * 2,
            Conv_layers=Conv_layers[1], TB_layers=TB_layers[1],
            downscaling_factor=downscaling_factors[1], num_heads=heads[1], **shared)
        self.stage3 = HyneterModule_DualSwitch(
            in_channels=hidden_dim * 2, hidden_dimension=hidden_dim * 4,
            Conv_layers=Conv_layers[2], TB_layers=TB_layers[2],
            downscaling_factor=downscaling_factors[2], num_heads=heads[2], **shared)
        self.stage4 = HyneterModule_DualSwitch(
            in_channels=hidden_dim * 4, hidden_dimension=final_dim,
            Conv_layers=Conv_layers[3], TB_layers=TB_layers[3],
            downscaling_factor=downscaling_factors[3], num_heads=heads[3], **shared)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(final_dim),
            nn.Linear(final_dim, num_classes),
        )

    def forward(self, img):
        """Run the stages in order and return all four outputs.

        Prints a banner, the input shape and each stage's output shape.

        Returns:
            OrderedDict mapping '0', '1', '2', '3' to the outputs of
            stage1..stage4 (the feature maps handed to the FPN).
        """
        print("HyneterForFPN forward pass")
        print(f"#Input image shape: {img.shape}")

        stage1_out = self.stage1(img)
        print(f"#Output of stage1 (c2) shape: {stage1_out.shape}")

        stage2_out = self.stage2(stage1_out)
        print(f"#Output of stage2 (c3) shape: {stage2_out.shape}")

        stage3_out = self.stage3(stage2_out)
        print(f"#Output of stage3 (c4) shape: {stage3_out.shape}")

        stage4_out = self.stage4(stage3_out)
        print(f"#Output of stage4 (c5) shape: {stage4_out.shape}")

        return OrderedDict([
            ('0', stage1_out),
            ('1', stage2_out),
            ('2', stage3_out),
            ('3', stage4_out),
        ])
        


def hyneter_base_fpn(hidden_dim=96, Conv_layers=(2, 2, 2, 2), TB_layers=(2, 2, 2, 2), heads=(3, 6, 12, 24), **kwargs):
    """Build a ``HyneterForFPN`` backbone with the "base" preset.

    The four named settings are forwarded unchanged; extra keyword arguments
    are passed through to the ``HyneterForFPN`` constructor.
    """
    preset = {
        "hidden_dim": hidden_dim,
        "Conv_layers": Conv_layers,
        "TB_layers": TB_layers,
        "heads": heads,
    }
    return HyneterForFPN(**preset, **kwargs)


def hyneter_plus_fpn(hidden_dim=96, Conv_layers=(2, 2, 3, 2), TB_layers=(2, 2, 6, 2), heads=(3, 6, 12, 24), **kwargs):
    """Build a ``HyneterForFPN`` backbone with the "plus" preset.

    The defaults use a deeper third stage than the base preset. Extra keyword
    arguments are passed through to the ``HyneterForFPN`` constructor.
    """
    preset = {
        "hidden_dim": hidden_dim,
        "Conv_layers": Conv_layers,
        "TB_layers": TB_layers,
        "heads": heads,
    }
    return HyneterForFPN(**preset, **kwargs)


def hyneter_max_fpn(hidden_dim=128, Conv_layers=(2, 2, 6, 2), TB_layers=(2, 2, 18, 2), heads=(4, 8, 16, 32), **kwargs):
    """Build a ``HyneterForFPN`` backbone with the "max" preset.

    The defaults use a wider ``hidden_dim``, more heads and a deeper third
    stage than the other presets. Extra keyword arguments are passed through
    to the ``HyneterForFPN`` constructor.
    """
    preset = {
        "hidden_dim": hidden_dim,
        "Conv_layers": Conv_layers,
        "TB_layers": TB_layers,
        "heads": heads,
    }
    return HyneterForFPN(**preset, **kwargs)


## Hyneter + FPN

In [74]:
from torchvision.ops import FeaturePyramidNetwork, MultiScaleRoIAlign


class HyneterFPNBackbone(nn.Module):
    """Wrap a multi-scale backbone with a torchvision FeaturePyramidNetwork.

    The wrapped backbone is expected to return a dict of feature maps whose
    channel counts are ``hidden_dim``, ``hidden_dim * 2``, ``hidden_dim * 4``
    and ``hidden_dim * 8``, since those are the input widths the FPN is built
    with. Every FPN output has 256 channels, exposed as ``out_channels``.
    """

    def __init__(self, backbone, hidden_dim=96):
        super().__init__()
        self.hyneter_backbone = backbone

        # One FPN input per stage, each stage twice as wide as the previous.
        stage_channels = [hidden_dim * factor for factor in (1, 2, 4, 8)]
        self.fpn = FeaturePyramidNetwork(
            in_channels_list=stage_channels,
            out_channels=256,
        )

        self.out_channels = 256

    def forward(self, x):
        """Return the FPN outputs for ``x``, printing keys and shapes on the way."""
        features = self.hyneter_backbone(x)
        print(f"Features from Hyneter Backbone: {features.keys()}")
        print(f"Input image shape: {x.shape}")

        fpn_features = self.fpn(features)
        print(f"FPN Features: {fpn_features.keys()}")
        print(f"FPN Features shape: {[fpn_features[k].shape for k in fpn_features.keys()]}")

        return fpn_features

## Hyneter + FPN + Mask R-CNN

In [75]:
# Record the PyTorch version this notebook was run with.
print(torch.__version__)

2.5.1+cu121


In [76]:
import torchvision
from torchvision.ops import MultiScaleRoIAlign


def get_hyneter_mask_rcnn_model(backbone=None, num_class=91):
    """Assemble a torchvision ``MaskRCNN`` on top of a Hyneter + FPN backbone.

    Args:
        backbone: Module returning a dict of four feature maps keyed
            '0'..'3'. When ``None`` a fresh ``hyneter_base_fpn()`` is built
            for this call.
        num_class: Number of classes passed to ``MaskRCNN`` as ``num_classes``.

    Returns:
        The constructed ``MaskRCNN`` model, configured for 224x224 inputs.
    """
    # Build the default lazily: a default argument is evaluated once at
    # definition time, which would construct a full network on import and
    # share that single instance between every call.
    if backbone is None:
        backbone = hyneter_base_fpn()

    # HyneterForFPN exposes out_channels = hidden_dim * 8. Recover hidden_dim
    # from it so the FPN input widths match backbones other than the 96-wide
    # base preset; fall back to 96 when the attribute is absent.
    hidden_dim = getattr(backbone, "out_channels", 96 * 8) // 8
    backbone = HyneterFPNBackbone(backbone, hidden_dim=hidden_dim)

    fpn_output_keys = ['0', '1', '2', '3']  # Corresponds to P2, P3, P4, P5 in FPN

    # One anchor size per FPN level. The FPN here has no extra pooled level,
    # so there are exactly four feature maps and four size tuples.
    anchor_sizes = ((32,), (64,), (128,), (256,))
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
    anchor_generator = AnchorGenerator(sizes=anchor_sizes, aspect_ratios=aspect_ratios)

    # ROI poolers read the FPN outputs, so featmap_names must match the FPN's
    # output keys.
    box_roi_pool = MultiScaleRoIAlign(featmap_names=fpn_output_keys, output_size=7, sampling_ratio=2)
    mask_roi_pool = MultiScaleRoIAlign(featmap_names=fpn_output_keys, output_size=14, sampling_ratio=2)

    model = MaskRCNN(
        backbone=backbone,
        num_classes=num_class,
        rpn_anchor_generator=anchor_generator,
        box_roi_pool=box_roi_pool,
        mask_roi_pool=mask_roi_pool,
        min_size=224,
        max_size=224
    )

    return model

# Test Model Size & Architecture

In [85]:
# Candidate configurations for the size check below. Every line rebinds
# `model`, so each network is constructed but only the last assignment is the
# one that gets measured.
model = swin_t_fpn(hidden_dim=96, TB_layers=(2, 2, 6, 2), heads=(3, 6, 12, 24))
model = HyneterForFPN(hidden_dim=128, Conv_layers=(2, 2, 6, 2),TB_layers=(2, 2, 18, 2), heads=(3, 6, 12, 24))
model = HyneterForFPN(hidden_dim=128, Conv_layers=(2, 2, 6, 2),TB_layers=(2, 2, 18, 2), heads=(2, 4, 8, 16))
model = HyneterForFPN(hidden_dim=96, Conv_layers=(2, 2, 3, 2),TB_layers=(2, 2, 6, 2), heads=(3, 6, 12, 24))
model = HyneterForFPN(hidden_dim=96, Conv_layers=(2, 2, 2, 2),TB_layers=(2, 2, 2, 2), heads=(3, 6, 12, 24))


def count_parameters(model):
    """Return the total element count over all of ``model.parameters()``.

    Both trainable and frozen parameters are included.
    """
    return sum(map(lambda param: param.numel(), model.parameters()))

def count_trainable_parameters(model):
    """Return the element count over parameters with ``requires_grad`` set.

    Parameters whose ``requires_grad`` is falsy are left out of the total.
    """
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(param.numel() for param in trainable)


# Measure the model bound above and report both counts in millions.
total_params = count_parameters(model)
trainable_params = count_trainable_parameters(model)

print(f"Total parameters: {total_params / 1e6:.2f} M")
print(f"Trainable parameters: {trainable_params / 1e6:.2f} M")

# For a more detailed breakdown (like `model.summary()` in Keras):
from prettytable import PrettyTable

def model_summary_pytorch(model):
    """Print a two-column table listing every trainable parameter tensor.

    Each row holds a name from ``model.named_parameters()`` and that tensor's
    element count. Parameters with ``requires_grad`` unset are skipped.
    Nothing is returned.
    """
    table = PrettyTable(["Module", "Parameters"])
    trainable_rows = (
        [name, parameter.numel()]
        for name, parameter in model.named_parameters()
        if parameter.requires_grad
    )
    for row in trainable_rows:
        table.add_row(row)
    print(table)

model_summary_pytorch(model) # Prints the per-parameter table; comment this line out to skip the detailed summary

Total parameters: 68.71 M
Trainable parameters: 68.71 M
+--------------------------------------------------------+------------+
|                         Module                         | Parameters |
+--------------------------------------------------------+------------+
|              stage1.mg.branch1x1.0.weight              |     96     |
|              stage1.mg.branch1x1.1.weight              |     32     |
|               stage1.mg.branch1x1.1.bias               |     32     |
|              stage1.mg.branch3x3.0.weight              |     96     |
|               stage1.mg.branch3x3.0.bias               |     32     |
|              stage1.mg.branch3x3.1.weight              |     32     |
|               stage1.mg.branch3x3.1.bias               |     32     |
|              stage1.mg.branch3x3.3.weight              |    9216    |
|               stage1.mg.branch3x3.3.bias               |     32     |
|              stage1.mg.branch3x3.4.weight              |     32     |
|       

# Test on x = torch.randn(1, 3, 224, 224).to(device)

In [86]:
# Smoke test: push one random 224x224 RGB batch through `model` without gradients.
# Use the GPU when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


model.to(device)
model.eval()  # evaluation mode for the forward check

x = torch.randn(1, 3, 224, 224).to(device)  # Example input tensor

with torch.no_grad():
    output = model(x)
    
print(output)  # Prints whatever `model` returns; the model built in the size-test cell above is a HyneterForFPN backbone, not a Mask R-CNN

Using device: cuda
HyneterForFPN forward pass
#Input image shape: torch.Size([1, 3, 224, 224])

HyneterModule: HNB: Patch -> Conv -> TB
HyneterModule forward pass

Input shape before Multigranularity CNN: torch.Size([1, 3, 224, 224])
Input shape after Multigranularity CNN: torch.Size([1, 96, 224, 224])
X(CNN) shape before Conv layers: torch.Size([1, 96, 224, 224])
X(CNN) shape after Conv layers: torch.Size([1, 96, 224, 224])
X(TB) shape before Transformer Blocks: torch.Size([1, 224, 224, 96])
Input shape after Transformer Blocks: torch.Size([1, 224, 224, 96])
########### going to calculate Z ###########
X(CNN) shape: torch.Size([1, 96, 224, 224])
X(TB) shape: torch.Size([1, 96, 224, 224])
Z shape torch.Size([1, 96, 224, 224])
Z before tanh: tensor([[[[-0.0000e+00,  0.0000e+00, -0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00, -0.0000e+00],
          [-0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -0.0000e+00,
           -0.0000e+00, -0.0000e+00],
          [-0.0000e+00,  0.0000e+00

# Image Classification : ImageNet1K
 
Train Like CSwin

- 300 epochs
- 224 x 224
- AdamW with Weight decay of 0.05
- batch size 1024
- Lr 0.001
- Cosine LR scheduler with 20 warm-up epochs

In [None]:
from torchvision.datasets import ImageNet
from torchvision.datasets.folder import default_loader
from torchvision import transforms
import random
from torchvision.transforms import v2
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim
import timm
import time
import os

In [None]:
# --- Training hyperparameters (see the recipe listed above) ---
NUM_EPOCHS = 300
IMAGE_SIZE = 224
BATCH_SIZE = 1024
LEARNING_RATE = 0.001
WEIGHT_DECAY = 0.05
WARMUP_EPOCHS = 20
NUM_WORKERS = 8  # Number of workers for DataLoader
NUM_CLASSES = 1000  # Number of classes in ImageNet1K
DATA_DIR = '/imagenet'  # Update with your ImageNet dataset path
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


# --- Augmentation-specific hyperparameters ---
RAND_AUGMENT_NUM_OPS = 2
RAND_AUGMENT_MAGNITUDE = 9
MIXUP_ALPHA = 0.8
CUTMIX_ALPHA = 1.0
RANDOM_ERASING_PROB = 0.25

# NOTE(review): the HyneterForFPN definition that appears last in this file
# returns an OrderedDict of feature maps from forward(), which the
# CrossEntropyLoss used in the training loop cannot consume. Confirm that the
# classification model is what should be built here.
model = HyneterForFPN(hidden_dim=96, Conv_layers=(2, 2, 2, 2), TB_layers=(2, 2, 2, 2), heads=(3, 6, 12, 24), num_classes=NUM_CLASSES)


In [None]:
# Move the model to the training device and set up loss, optimizer and LR schedule.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(f"Using device: {device}")

criterion  = nn.CrossEntropyLoss()

# The recipe calls for AdamW: it applies decoupled weight decay, whereas
# plain Adam folds weight_decay into the gradient as an L2 penalty.
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Linear warm-up for WARMUP_EPOCHS, then a single cosine decay over the
# remaining epochs. Using WARMUP_EPOCHS as the cosine T_max would instead
# make the learning rate swing down and back up repeatedly over the run,
# with no warm-up at all.
warmup_scheduler = optim.lr_scheduler.LinearLR(
    optimizer, start_factor=1e-3, end_factor=1.0, total_iters=WARMUP_EPOCHS)
cosine_scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=NUM_EPOCHS - WARMUP_EPOCHS)
# `scheduler` is stepped once per epoch by the training loop.
scheduler = optim.lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[warmup_scheduler, cosine_scheduler],
    milestones=[WARMUP_EPOCHS],
)

In [None]:
# --- Data Transformations with timm's built-in options ---
# --- Data Transformations with timm's built-in options ---
# Training pipeline: timm builds the augmentation chain from the settings
# below (colour jitter, RandAugment policy string, random erasing) and
# normalises with the given mean/std.
train_transform = timm.data.create_transform(
    input_size=IMAGE_SIZE,
    is_training=True,
    color_jitter=0.4,
    auto_augment=f'rand-n{RAND_AUGMENT_NUM_OPS}-m{RAND_AUGMENT_MAGNITUDE}',
    interpolation='bicubic',
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
    re_prob=RANDOM_ERASING_PROB,
    re_mode='pixel',
    re_count=1,
)

# Validation pipeline: deterministic resize to 256, centre crop to IMAGE_SIZE,
# then the same normalisation as training.
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [None]:
# ImageNet train/val splits, each with its own transform pipeline.
train_dataset = ImageNet(root=DATA_DIR, split='train', transform=train_transform)
val_dataset = ImageNet(root=DATA_DIR, split='val', transform=val_transform)

# drop_last=True on the training loader: timm's batch-mode Mixup asserts an
# even batch size, and the final partial batch of an epoch can be odd-sized
# (it is for the standard ImageNet-1k train split at batch size 1024), which
# would abort training at the end of the first epoch.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS,
                          pin_memory=True, drop_last=True)
# Validation keeps every sample so accuracy is computed over the full split.
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

In [None]:
from timm.data import Mixup

# Build the Mixup/CutMix batch transform only when at least one alpha is
# positive; otherwise `mixup_fn` stays None and batches are used as-is.
# NOTE(review): timm's Mixup is understood to require an even batch size -
# verify the training loader never yields an odd-sized batch.
mixup_fn = None
if MIXUP_ALPHA > 0 or CUTMIX_ALPHA > 0:
    mixup_fn = Mixup(
        mixup_alpha=MIXUP_ALPHA,
        cutmix_alpha=CUTMIX_ALPHA,
        num_classes=NUM_CLASSES,
        prob=1.0,
        switch_prob=0.5,
        mode='batch',
        label_smoothing=0.1,
    )

In [None]:
CHECKPOINT_DIR = './checkpoints' # Directory where per-epoch checkpoints are written
os.makedirs(CHECKPOINT_DIR, exist_ok=True) # Create it up front; exist_ok avoids an error when it is already there

In [None]:
# Train for NUM_EPOCHS epochs; after each epoch validate, write a checkpoint
# and keep a copy of the weights with the best validation accuracy so far.
best_accuracy = 0.0

for epoch in range(NUM_EPOCHS):
    epoch_start_time = time.time()
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS} started at: {time.ctime(epoch_start_time)}")

    # --- Training Phase ---
    model.train()
    running_loss = 0.0
    train_samples_seen = 0  # samples actually processed this epoch

    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        # Mixup/CutMix replaces the integer labels with soft targets.
        if mixup_fn is not None:
            inputs, labels = mixup_fn(inputs, labels)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = inputs.size(0)
        running_loss += loss.item() * batch_size
        train_samples_seen += batch_size

        if (batch_idx + 1) % 100 == 0:
            print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Batch [{batch_idx+1}/{len(train_loader)}], "
                  f"Loss: {loss.item():.4f}")

    # Normalise by the samples actually seen, so the average stays correct
    # even when the loader drops or shortens the final batch.
    epoch_train_loss = running_loss / max(train_samples_seen, 1)
    print(f"Epoch {epoch+1} Train Loss: {epoch_train_loss:.4f}")

    scheduler.step()

    # --- Validation Phase ---
    model.eval()
    val_running_loss = 0.0
    val_correct_predictions = 0
    val_total_samples = 0

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            val_running_loss += loss.item() * inputs.size(0)
            predicted = outputs.argmax(dim=1)
            val_total_samples += labels.size(0)
            val_correct_predictions += (predicted == labels).sum().item()

    val_loss = val_running_loss / max(val_total_samples, 1)
    val_accuracy = val_correct_predictions / max(val_total_samples, 1)
    print(f"Epoch {epoch+1} Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")

    # Record epoch end time and calculate duration
    epoch_end_time = time.time()
    epoch_duration = epoch_end_time - epoch_start_time
    print(f"Epoch {epoch+1} ended at: {time.ctime(epoch_end_time)}")
    print(f"Epoch {epoch+1} duration: {epoch_duration:.2f} seconds")

    # Update the running best before writing the checkpoint, so the value
    # stored in it already accounts for this epoch.
    is_best = val_accuracy > best_accuracy
    if is_best:
        best_accuracy = val_accuracy

    # Save checkpoint
    checkpoint_path = os.path.join(CHECKPOINT_DIR, f'epoch_{epoch+1:03d}.pth')
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),  # needed to resume the LR schedule
        'best_accuracy': best_accuracy,
        'val_accuracy': val_accuracy,
        'val_loss': val_loss,
        'train_loss': epoch_train_loss,
    }, checkpoint_path)
    print(f"Saved checkpoint to {checkpoint_path}")

    # Save the best model
    if is_best:
        torch.save(model.state_dict(), 'custom_imagenet_best_model.pth')
        print(f"Saved best model with accuracy: {best_accuracy:.4f}")

print("\nTraining finished :D (I wish)")

# Training

Optimizer : AdamW (lr = 0.00001, Weight decay = 0.05)
with MMdetection

In [None]:
import torchvision.datasets as dset
import utils

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader



In [None]:
# Mask R-CNN built on a freshly constructed base Hyneter FPN backbone.
model = get_hyneter_mask_rcnn_model(hyneter_base_fpn())

In [None]:
# Rebinds `model` to the network returned by hyneter_base() with its default settings.
model = hyneter_base()

# Test

In [None]:
# Mask R-CNN built on a freshly constructed base Hyneter FPN backbone.
model = get_hyneter_mask_rcnn_model(hyneter_base_fpn())

In [None]:
# Rebinds `model` to the bare FPN backbone (no detection heads) with default settings.
model = hyneter_base_fpn()

In [None]:
# Smoke test: push one random 224x224 RGB batch through `model` without gradients.
# Use the GPU when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


model.to(device)
model.eval()  # evaluation mode for the forward check

x = torch.randn(1, 3, 224, 224).to(device)  # Example input tensor

with torch.no_grad():
    output = model(x)
    
print(output)  # Prints whatever the currently bound `model` returns; if it is the hyneter_base_fpn() backbone from the cell above, that is a dict of feature maps rather than Mask R-CNN detections

In [None]:
# Image-only preprocessing for the detection datasets below.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to 224x224
    transforms.ToTensor(),  # Convert to tensor
])

In [None]:
# COCO 2017 train/val detection datasets.
# NOTE(review): `transform` is presumably applied to the image only, so the
# resize to 224x224 would leave box/mask annotations in the original image
# coordinates. Confirm, and rescale the targets if so.
train_dataset = dset.CocoDetection(root='data/train2017', annFile='data/annotations/instances_train2017.json', transform=transform)
val_dataset  = dset.CocoDetection(root='data/val2017', annFile='data/annotations/instances_val2017.json', transform=transform)

In [None]:
def collate_fn(batch):
    """Split a batch of samples into two parallel lists.

    Element 0 of every sample goes into the first list and element 1 into the
    second, preserving batch order. Nothing is stacked, so samples may differ
    in size.

    Returns:
        A ``(images, targets)`` tuple of lists.
    """
    images = list(map(lambda sample: sample[0], batch))
    targets = list(map(lambda sample: sample[1], batch))
    return images, targets


In [None]:
if __name__ == "__main__":

    # Each COCO sample carries a list of annotations whose length varies from
    # image to image, which the default collate function cannot batch. Pass
    # collate_fn so every batch stays a pair of per-sample lists.
    train_loader = DataLoader(
        train_dataset,
        batch_size=32,
        shuffle=True,
        collate_fn=collate_fn,
        )
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4, collate_fn=collate_fn)

In [None]:
# NOTE(review): 80 classes + 1 background is 81, not 91; 91 presumably
# reflects COCO's sparse category-id range plus background - confirm.
num_classes = 91
# NOTE: `params` is not used below; the optimizer is given model.parameters() directly.
params = [p for p in model.parameters() if p.requires_grad]
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.00001, weight_decay=0.05)

# Multiply the learning rate by 0.1 every 3 scheduler steps.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

In [None]:
# Rebinds `model` to the network returned by hyneter_base() with its default settings.
model = hyneter_base()