In [1]:
import math
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import os
import time
import tqdm

In [2]:
# Ghost building blocks: DFCAttention (channel gate) and GhostModuleV2 (Ghost convolution).
class DFCAttention(nn.Module):
    """Channel gate that rescales every channel of ``x`` by a learned factor in (0, 1).

    The factor comes from a global average pool, then two bias-free 1x1
    convolutions with a Hardswish between them, and finally a sigmoid.
    Assumes ``x`` is (B, C, H, W) with C == ``channels``.

    NOTE(review): this gate is driven purely by global pooling. GhostNetV2's
    DFC attention uses decoupled horizontal/vertical convolutions instead.
    Confirm this simplification is intended.
    """

    def __init__(self, channels):
        super().__init__()
        layers = [
            nn.AdaptiveAvgPool2d(1),  # squeeze H and W to 1x1
            nn.Conv2d(channels, channels, kernel_size=1, bias=False),
            nn.Hardswish(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=1, bias=False),
            nn.Sigmoid(),
        ]
        self.dfc = nn.Sequential(*layers)

    def forward(self, x):
        gate = self.dfc(x)  # 1x1 spatial gate, broadcast over H and W
        return x * gate

class GhostModuleV2(nn.Module):
    """Ghost convolution: a few "intrinsic" maps from a regular convolution plus
    cheap depthwise "ghost" maps derived from them, concatenated to
    ``out_channels`` and optionally gated by DFCAttention.

    Args:
        in_channels:  input channel count.
        out_channels: output channel count.
        ratio:        the primary conv produces ceil(out_channels / ratio) maps;
                      the depthwise step derives (ratio - 1) maps from each, and
                      the concatenation is truncated to ``out_channels``.
        kernel_size:  kernel of the primary conv. It is padded by
                      (kernel_size - 1) // 2: "same" size for odd kernels at
                      stride 1, and an exact H/stride for a 2x2/stride-2 conv.
        dw_size:      kernel of the depthwise cheap operation; must be odd so the
                      ghost maps keep the intrinsic maps' spatial size.
        stride:       stride of the primary conv.
        use_attn:     if True, gate the output with DFCAttention.

    Raises:
        ValueError: if ``dw_size`` is even.
    """
    def __init__(self, in_channels, out_channels,
                 ratio=2, kernel_size=1, dw_size=3,
                 stride=1, use_attn=True):
        super().__init__()
        if dw_size % 2 == 0:
            # An even depthwise kernel with dw_size//2 padding changes H/W by 1,
            # so the two halves could never be concatenated.
            raise ValueError(
                f"dw_size must be odd to preserve spatial size, got {dw_size}")
        self.out_channels = out_channels
        init_ch = math.ceil(out_channels / ratio)
        new_ch  = init_ch * (ratio - 1)

        self.primary_conv = nn.Sequential(
            # (k - 1) // 2 equals k // 2 for odd k, but is 0 for k = 2, so the
            # 2x2/stride-2 downsample yields H/2 instead of H/2 + 1.
            nn.Conv2d(in_channels, init_ch,
                      kernel_size, stride, (kernel_size - 1) // 2, bias=False),
            nn.BatchNorm2d(init_ch),
            nn.Hardswish(inplace=True),
        )
        self.cheap_operation = nn.Sequential(
            # Depthwise (groups=init_ch): each intrinsic map spawns ratio-1 ghosts.
            nn.Conv2d(init_ch, new_ch, dw_size, 1,
                      dw_size//2, groups=init_ch, bias=False),
            nn.BatchNorm2d(new_ch),
            nn.Hardswish(inplace=True),
        )
        self.use_attn = use_attn
        if use_attn:
            self.attn = DFCAttention(out_channels)

    def forward(self, x):
        x1 = self.primary_conv(x)
        x2 = self.cheap_operation(x1)
        # Use the `T` alias bound in the first cell; the bare name `torch` is
        # only imported by a later cell.
        out = T.cat([x1, x2], dim=1)[:, :self.out_channels]
        if self.use_attn:
            out = self.attn(out)
        return out

In [3]:
class VaryingWindowAttention(nn.Module):
    """Windowed multi-head attention.

    Each P x P query window attends to its own tokens plus an (R*P) x (R*P)
    context window placed around it.

    Input and output are token sequences (B, N, C) laid out row-major on a
    square H x W grid (N = H * W). H must be divisible by ``window_size``.

    Args:
        dim:           channel count C.
        num_heads:     number of attention heads; must divide ``dim``.
        window_size:   query window side P.
        context_ratio: context window side as a multiple R of P (R >= 1).

    Raises:
        ValueError: if ``dim`` is not divisible by ``num_heads`` (in __init__),
            or if the tokens do not form a square grid divisible by P (in forward).
    """

    def __init__(self, dim, num_heads, window_size, context_ratio=2):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})")
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.window_size = window_size
        self.context_ratio = context_ratio

        # Joint projection; the last dim is split into Q, K, V in that order.
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)

        # Output projection
        self.to_out = nn.Linear(dim, dim, bias=False)

    @staticmethod
    def _windows(x_spatial, size, step):
        """Cut (B, H, W, C) into ``size`` x ``size`` windows with stride ``step``.

        Returns (B, num_windows, size*size, C). Windows are ordered row-major
        over the grid, and tokens row-major inside each window.
        """
        B, _, _, C = x_spatial.shape
        win = x_spatial.unfold(1, size, step).unfold(2, size, step)  # (B, nH, nW, C, size, size)
        # unfold appends the window axes after C. Move C back to the end before
        # flattening so channels are not interleaved with positions.
        win = win.permute(0, 1, 2, 4, 5, 3)
        return win.reshape(B, -1, size * size, C)

    def forward(self, x):
        B, N, C = x.shape
        H = W = math.isqrt(N)
        P = self.window_size
        R = self.context_ratio
        if H * W != N:
            raise ValueError(f"expected a square token grid, got N={N}")
        if H % P != 0:
            raise ValueError(
                f"grid side {H} is not divisible by window_size {P}")
        nH, nW = H // P, W // P

        # Reshape to (B, H, W, C)
        x_spatial = x.reshape(B, H, W, C)

        # 1. P x P query windows: (B, nH*nW, P*P, C)
        q_windows = self._windows(x_spatial, P, P)

        # 2. (R*P) x (R*P) context windows, exactly one per query window.
        #    Padding the grid by (R-1)*P in total yields nH x nW windows. When
        #    that total is odd, the extra row/column goes at the bottom/right.
        pad_total = (R - 1) * P
        pad_lo = pad_total // 2
        pad_hi = pad_total - pad_lo
        x_pad = F.pad(x_spatial, (0, 0, pad_lo, pad_hi, pad_lo, pad_hi))
        ctx_windows = self._windows(x_pad, R * P, P)

        # 3. Tokens visible to each query: its own window, then the context.
        #    The query window is also contained in the context.
        seq = T.cat([q_windows, ctx_windows], dim=2)  # (B, nWin, L, C)
        L = seq.shape[2]

        # 4. Q, K, V split into heads, each (B, nWin, heads, L, head_dim)
        qkv = self.to_qkv(seq).reshape(
            B, nH * nW, L, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(3, 0, 1, 4, 2, 5).unbind(0)
        # Only query-window tokens are kept as outputs, so skip the other rows.
        q = q[:, :, :, :P * P]

        # 5. Scaled dot-product attention. Zero-padded context tokens take part
        #    as keys/values.
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        out = attn @ v  # (B, nWin, heads, P*P, head_dim)

        # 6. Merge heads and restore the row-major (B, N, C) layout.
        out = out.transpose(2, 3).reshape(B, nH, nW, P, P, C)
        out = out.permute(0, 1, 3, 2, 4, 5).reshape(B, N, C)

        return self.to_out(out)


class GhostMLP(nn.Module):
    """
    Token MLP built from two GhostModuleV2 layers with a GELU between them.

    Widths go C -> C * mlp_ratio -> C. GhostModuleV2 operates on feature maps,
    so tokens are temporarily viewed as a square (B, C, side, side) grid. The
    token count N must therefore be a perfect square.

    NOTE(review): each GhostModuleV2 already ends in Hardswish. The GELU is
    therefore stacked on a Hardswish, and the block output is itself
    Hardswish-activated. Confirm this is intended for a residual MLP branch.
    """
    def __init__(self, dim, mlp_ratio=4):
        super().__init__()
        expanded = dim * mlp_ratio
        self.fc1 = GhostModuleV2(dim, expanded, use_attn=False)
        self.act = nn.GELU()
        self.fc2 = GhostModuleV2(expanded, dim, use_attn=False)

    def forward(self, x: T.Tensor) -> T.Tensor:
        """
        x: (B, N, C) tokens on a square grid, row-major
        returns: (B, N, C)
        """
        batch, num_tokens, channels = x.shape
        side = int(num_tokens ** 0.5)
        # (B, N, C) -> (B, C, side, side)
        grid = x.view(batch, side, side, channels).permute(0, 3, 1, 2)
        hidden = self.act(self.fc1(grid))
        projected = self.fc2(hidden)
        # (B, C, side, side) -> (B, N, C)
        return projected.permute(0, 2, 3, 1).view(batch, num_tokens, channels)




In [4]:
class HybridConvNeXtBlock(nn.Module):
    """Depthwise 7x7 convolution followed by a pre-norm transformer block.

    The conv output is flattened into tokens. Two residual sub-layers are then
    applied: LayerNorm -> VaryingWindowAttention, and LayerNorm -> GhostMLP.
    Input and output are channels-first maps (B, dim, H, W).

    NOTE(review): the depthwise conv replaces ``x`` instead of being added to
    it, so there is no identity path around the conv. Confirm this is intended.
    """

    def __init__(self, dim, num_heads, window_size,
                 context_ratio=2, mlp_ratio=4):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3,
                                groups=dim, bias=False)
        self.norm1 = nn.LayerNorm(dim)
        self.attn = VaryingWindowAttention(dim, num_heads,
                                           window_size, context_ratio)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = GhostMLP(dim, mlp_ratio)

    def forward(self, x):
        feat = self.dwconv(x)
        batch, channels, height, width = feat.shape
        # (B, C, H, W) -> (B, H*W, C) token sequence
        tokens = feat.permute(0, 2, 3, 1).reshape(batch, height * width, channels)

        attn_out = self.attn(self.norm1(tokens))
        tokens = tokens + attn_out
        mlp_out = self.mlp(self.norm2(tokens))
        tokens = tokens + mlp_out

        # (B, H*W, C) -> (B, C, H, W)
        grid = tokens.view(batch, height, width, channels)
        return grid.permute(0, 3, 1, 2)



In [5]:
class _LayerNorm2d(nn.LayerNorm):
    """LayerNorm over the channel axis of a channels-first (B, C, H, W) tensor.

    ``nn.LayerNorm`` normalises trailing dimensions only. The input is therefore
    moved to channels-last, normalised independently at every spatial position,
    and moved back. Construct it with the channel count, e.g. ``_LayerNorm2d(64)``.
    """

    def forward(self, x):
        return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)


class ASPPNeXtEncoder(nn.Module):
    """Hierarchical encoder.

    Structure:
      - a stride-4 patch stem;
      - one HybridConvNeXtBlock per stage;
      - a stride-2 GhostModuleV2 downsample (doubling channels) before every
        stage after the first.

    Args:
        in_ch:        input image channels.
        base_dim:     channels of the first stage; stage i has base_dim * 2**i.
        num_heads:    attention heads per stage (one entry per stage).
        window_sizes: attention window side per stage (one entry per stage);
                      each stage's feature-map side must be divisible by it.
        mlp_ratio:    hidden expansion of every block's GhostMLP.

    ``forward`` returns ``(x, skips)``. ``x`` is the last stage's output and
    ``skips`` lists every stage's output, shallowest first.

    Raises:
        ValueError: if ``num_heads`` and ``window_sizes`` differ in length.
    """
    def __init__(self, in_ch, base_dim=64,
                 num_heads=(4,8,16,32),
                 window_sizes=(8,4,2,1),
                 mlp_ratio=4):
        super().__init__()
        if len(num_heads) != len(window_sizes):
            raise ValueError(
                "num_heads and window_sizes need one entry per stage, got "
                f"{len(num_heads)} and {len(window_sizes)}")
        # Stage 1 patch-embedding: 4x4/stride-4 conv, then per-pixel channel
        # LayerNorm. nn.LayerNorm only normalises trailing dims, so a
        # channels-first wrapper is used for (B, C, H, W) maps.
        self.stem = nn.Sequential(
            nn.Conv2d(in_ch, base_dim, 4, stride=4, bias=False),
            _LayerNorm2d(base_dim),
        )
        self.stages = nn.ModuleList()
        for i, (heads, window) in enumerate(zip(num_heads, window_sizes)):
            dim = base_dim * (2**i)
            # Downsample before every stage except the first.
            # NOTE(review): this relies on GhostModuleV2 padding its
            # 2x2/stride-2 conv by 0 so that H and W halve exactly. Verify.
            down = None if i == 0 else nn.Sequential(
                GhostModuleV2(in_channels=dim//2,
                              out_channels=dim,
                              ratio=2,
                              kernel_size=2,
                              stride=2,
                              use_attn=True),
                _LayerNorm2d(dim),
            )
            block = HybridConvNeXtBlock(
                dim=dim,
                num_heads=heads,
                window_size=window,
                context_ratio=2,
                mlp_ratio=mlp_ratio
            )
            self.stages.append(nn.ModuleDict({'down': down, 'block': block}))

    def forward(self, x):
        # x: (B, in_ch, H, W)
        x = self.stem(x)            # Stage 1 input → (B, base_dim, H/4, W/4)
        skips = []
        for stage in self.stages:
            if stage['down'] is not None:
                x = stage['down'](x) # downsample before stages 2+
            x = stage['block'](x)   # Hybrid block
            skips.append(x)         # capture skip for decoder
        # With the default 4 stages:
        # skips[0] = stage 1 output (1/4 input res)
        # skips[1] = stage 2 output (1/8 input res)
        # skips[2] = stage 3 output (1/16 input res)
        # skips[3] = stage 4 output (1/32 input res)
        return x, skips

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# Ensure DFCAttention and GhostModuleV2 are defined in your notebook
# from ghost_module_v2 import GhostModuleV2

class PreDAAFGhostV2(nn.Module):
    """
    Ghost bottleneck applied to an encoder feature map before DAAF fusion.

    Pipeline (stride 1 throughout, so H and W are unchanged):
      reduce : GhostModuleV2, 1x1 primary conv, C -> C // reduction, DFC gate on
      dwconv : 3x3 depthwise conv + BatchNorm + Hardswish on the reduced maps
      expand : GhostModuleV2, 1x1 primary conv, C // reduction -> C, DFC gate on

    NOTE(review): there is no residual connection around the bottleneck.
    """
    def __init__(self, channels: int, reduction: int = 4, ratio: int = 2):
        """
        Args:
          channels:  input/output channel count C
          reduction: bottleneck factor; the middle width is C // reduction
          ratio:     Ghost ratio forwarded to both GhostModuleV2 layers
        """
        super().__init__()
        mid = channels // reduction

        # C -> mid
        self.reduce = GhostModuleV2(channels, mid, ratio=ratio,
                                    kernel_size=1, stride=1, use_attn=True)

        # spatial mixing on the narrow representation
        self.dwconv = nn.Sequential(
            nn.Conv2d(mid, mid, kernel_size=3, stride=1, padding=1,
                      groups=mid, bias=False),
            nn.BatchNorm2d(mid),
            nn.Hardswish(inplace=True),
        )

        # mid -> C
        self.expand = GhostModuleV2(mid, channels, ratio=ratio,
                                    kernel_size=1, stride=1, use_attn=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, C, H, W)
        returns: (B, C, H, W)
        """
        narrow = self.reduce(x)        # (B, C//reduction, H, W)
        mixed = self.dwconv(narrow)    # (B, C//reduction, H, W)
        return self.expand(mixed)      # (B, C, H, W)

# Example usage in the ASPPNeXt pipeline.
#
# `channels` must equal the channel count of the encoder's final stage.
# ASPPNeXtEncoder's stage i has base_dim * 2**i channels, so with the default
# base_dim=64 and 4 stages the last stage has 64 * 2**3 = 512.
pre_daaf = PreDAAFGhostV2(512, reduction=4, ratio=2)

# In the forward pass for the RGB and depth encoders:
#   f_rgb_out, rgb_skips     = rgb_encoder(rgb_input)
#   f_depth_out, depth_skips = depth_encoder(depth_input)
#   f_rgb_pre   = pre_daaf(f_rgb_out)
#   f_depth_pre = pre_daaf(f_depth_out)
# NOTE(review): reusing one `pre_daaf` instance for both streams shares its
# weights and BatchNorm statistics between RGB and depth. Confirm intended.
#
# Then feed (f_rgb_pre, f_depth_pre) into the DAAF fusion block.
