In [1]:
import math
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import os
import time
import tqdm

In [2]:
class DFCAttention(nn.Module):
    """Depthwise-Fully-Connected Attention from GhostNetV2"""
    def __init__(self, channels):
        super().__init__()
        self.dfc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # Global Avg Pool
            nn.Conv2d(channels, channels, kernel_size=1, bias=False),  # FC1
            nn.Hardswish(inplace=True),  # Replaced ReLU with Hardswish
            nn.Conv2d(channels, channels, kernel_size=1, bias=False),  # FC2
            nn.Sigmoid()  # Sigmoid for attention weights
        )

    def forward(self, x):
        return x * self.dfc(x)  # Channel-wise multiplication

class GhostModuleV2(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        ratio: int = 2,
        kernel_size: int = 1,
        dw_size: int = 3,
        stride: int = 1,
        use_attn: bool = True  # Enable/disable DFC Attention
    ):
        super().__init__()
        self.out_channels = out_channels
        init_channels = math.ceil(out_channels / ratio)
        new_channels = init_channels * (ratio - 1)

        # Primary 1×1 convolution (intrinsic features)
        self.primary_conv = nn.Sequential(
            nn.Conv2d(
                in_channels, init_channels,
                kernel_size=kernel_size, stride=stride,
                padding=kernel_size // 2, bias=False
            ),
            nn.BatchNorm2d(init_channels),
            nn.Hardswish(inplace=True),  # ReLU → Hardswish (V2 change)
        )

        # Cheap depthwise convolution (ghost features)
        self.cheap_operation = nn.Sequential(
            nn.Conv2d(
                init_channels, new_channels,
                kernel_size=dw_size, stride=1,
                padding=dw_size // 2,
                groups=init_channels, bias=False
            ),
            nn.BatchNorm2d(new_channels),
            nn.Hardswish(inplace=True),  # ReLU → Hardswish (V2 change)
        )

        # DFC Attention (V2 addition)
        self.use_attn = use_attn
        if use_attn:
            self.attn = DFCAttention(out_channels)

    def forward(self, x):
        # Intrinsic + Ghost features
        x1 = self.primary_conv(x)
        x2 = self.cheap_operation(x1)
        out = T.cat([x1, x2], dim=1)
        out = out[:, :self.out_channels, :, :]  # Trim to out_channels

        # Apply DFC Attention (V2)
        if self.use_attn:
            out = self.attn(out)
        return out

In [3]:
class VaryingWindowAttention(nn.Module):
    def __init__(self, dim, num_heads, window_size, context_ratio=2):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.context_ratio = context_ratio
        
        # Projections for Q/K/V
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        
        # Output projection
        self.to_out = nn.Linear(dim, dim, bias=False)
        
    def forward(self, x):
        B, N, C = x.shape
        H = W = int(N ** 0.5)
        P = self.window_size
        R = self.context_ratio
        
        # Reshape to (B, H, W, C)
        x_spatial = x.view(B, H, W, C)
        
        # 1. Create P×P query windows
        q_windows = x_spatial.unfold(1, P, P).unfold(2, P, P)
        q_windows = q_windows.contiguous().view(B, -1, P*P, C)
        
        # 2. Create (R*P)×(R*P) context windows
        pad = (R * P - P) // 2
        x_pad = F.pad(x_spatial, (0, 0, pad, pad, pad, pad))
        ctx_windows = x_pad.unfold(1, R*P, P).unfold(2, R*P, P)
        ctx_windows = ctx_windows.contiguous().view(B, -1, (R*P)**2, C)
        
        # 3. Concatenate queries and context
        seq = T.cat([q_windows, ctx_windows], dim=2)
        
        # 4. Compute Q, K, V
        qkv = self.to_qkv(seq)
        q, k, v = qkv.chunk(3, dim=-1)  # Split into q, k, v
        
        # 5. Scaled dot-product attention
        attn = (q @ k.transpose(-2, -1)) * (C ** -0.5)
        attn = attn.softmax(dim=-1)
        out = attn @ v
        
        # 6. Keep only query window outputs
        out = out[:, :, :P*P, :]
        
        # 7. Reconstruct spatial layout
        out = out.view(B, H//P, W//P, P, P, C)
        out = out.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, C)
        out = out.view(B, N, C)
        
        return self.to_out(out)


class GhostMLP(nn.Module):
    """
    Ghost MLP with two GhostModuleV2 layers and GELU activation.
    Expands C→4C then projects back 4C→C.
    """
    def __init__(self, dim, mlp_ratio=4):
        super().__init__()
        hidden_dim = dim * mlp_ratio
        self.fc1 = GhostModuleV2(dim, hidden_dim, use_attn=False)
        self.act = nn.GELU()
        self.fc2 = GhostModuleV2(hidden_dim, dim, use_attn=False)

    def forward(self, x: T.Tensor) -> T.Tensor:
        """
        x: (B, N, C)
        returns: (B, N, C)
        """
        B, N, C = x.shape
        # reshape to (B, C, H, W) for GhostModuleV2, then back
        H = W = int(N ** 0.5)
        x_spatial = x.view(B, H, W, C).permute(0,3,1,2)
        x_spatial = self.act(self.fc1(x_spatial))
        x_spatial = self.fc2(x_spatial)
        x_flat = x_spatial.permute(0,2,3,1).view(B, N, C)
        return x_flat




In [4]:
class HybridConvNeXtBlock(nn.Module):
    def __init__(self,
                 dim: int,
                 num_heads: int,
                 window_size: int,
                 context_ratio: int = 2,
                 mlp_ratio: int = 4):
        """
        Args:
          dim: number of input channels
          num_heads: heads for attention
          window_size: P (query window)
          context_ratio: R (context window = R·P)
          mlp_ratio: expansion factor for Ghost MLP
        """
        super().__init__()
        # 0) GhostModule downsampling: dim→2*dim, stride=2
        self.down = GhostModuleV2(
            in_channels=dim,
            out_channels=dim * 2,
            ratio=2,
            kernel_size=3,
            stride=2,
            use_attn=False
        )
        # 1) 7×7 depthwise convolution on downsampled features
        self.dwconv = nn.Conv2d(dim * 2, dim * 2, kernel_size=7,
                                padding=3, groups=dim * 2, bias=False)
        # 2) LayerNorms for sequence data
        self.norm1 = nn.LayerNorm(dim * 2)
        self.norm2 = nn.LayerNorm(dim * 2)
        # 3) Varying Window Attention
        self.attn = VaryingWindowAttention(
            dim=dim * 2,
            num_heads=num_heads,
            window_size=window_size,
            context_ratio=context_ratio
        )
        # 4) Ghost MLP expands (2C→8C) then projects back (8C→2C)
        self.mlp = GhostMLP(dim * 2, mlp_ratio)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, C, H, W)
        returns: (B, 2C, H/2, W/2)
        """
        # 0. Ghost downsample
        x = self.down(x)  # → (B, 2C, H/2, W/2)

        # 1. Depthwise conv
        x = self.dwconv(x)  # (B, 2C, H/2, W/2)

        # 2. Flatten to sequence for attention: (B, N, 2C)
        B, C2, H2, W2 = x.shape
        x_seq = x.permute(0, 2, 3, 1).reshape(B, H2*W2, C2)

        # 3. Attention + residual
        x_seq = x_seq + self.attn(self.norm1(x_seq))

        # 4. MLP + residual
        x_seq = x_seq + self.mlp(self.norm2(x_seq))

        # 5. Reshape back to (B, 2C, H/2, W/2)
        x_out = x_seq.view(B, H2, W2, C2).permute(0, 3, 1, 2)
        return x_out


NameError: name 'torch' is not defined