In [1]:
import math
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import os
import time
import tqdm
from torch.utils.data import Dataset, DataLoader

In [2]:
class DFCAttention(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.dfc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels, kernel_size=1, bias=False),
            nn.Hardswish(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=1, bias=False),
            nn.Sigmoid().add_(1e-6)  # Added epsilon for numerical stability
        )
        
    def forward(self, x):
        return x * self.dfc(x)

class GhostModuleV2(nn.Module):
    def __init__(self, in_channels, out_channels,
                 ratio=2, kernel_size=1, dw_size=3,
                 stride=1, use_attn=True):
        super().__init__()
        self.out_channels = out_channels
        init_ch = math.ceil(out_channels / ratio)
        new_ch = init_ch * (ratio - 1)

        # Primary convolution with dynamic padding
        self.primary_conv = nn.Sequential(
            nn.Conv2d(in_channels, init_ch,
                     kernel_size, stride,
                     (kernel_size - 1) // 2,  # More precise padding
                     bias=False),
            nn.BatchNorm2d(init_ch),
            nn.Hardswish(inplace=True),
        )

        # Cheap operation with dynamic padding
        self.cheap_operation = nn.Sequential(
            nn.Conv2d(init_ch, new_ch, dw_size, 1,
                     (dw_size - 1) // 2,  # More precise padding
                     groups=init_ch, bias=False),
            nn.BatchNorm2d(new_ch),
            nn.Hardswish(inplace=True),
        )

        # Attention with stride warning
        self.use_attn = use_attn
        if use_attn:
            if stride > 1:
                print(f"Warning: Stride {stride}>1 may reduce attention effectiveness")
            self.attn = DFCAttention(out_channels)

    def forward(self, x):
        x1 = self.primary_conv(x)
        x2 = self.cheap_operation(x1)
        out = T.cat([x1, x2], dim=1)[:, :self.out_channels]
        return self.attn(out) if self.use_attn else out

In [3]:
class VaryingWindowAttention(nn.Module):
    def __init__(self, dim, num_heads, window_size, context_ratio=2):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.context_ratio = context_ratio
        
        # Projections for Q/K/V
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        
        # Output projection
        self.to_out = nn.Linear(dim, dim, bias=False)
        
    def forward(self, x):
        B, N, C = x.shape
        H = W = int(N ** 0.5)
        P = self.window_size
        R = self.context_ratio
        
        # Reshape to (B, H, W, C)
        x_spatial = x.view(B, H, W, C)
        
        # 1. Create P×P query windows
        q_windows = x_spatial.unfold(1, P, P).unfold(2, P, P)
        q_windows = q_windows.contiguous().view(B, -1, P*P, C)
        
        # 2. Create (R*P)×(R*P) context windows
        pad = (R * P - P) // 2
        x_pad = F.pad(x_spatial, (0, 0, pad, pad, pad, pad))
        ctx_windows = x_pad.unfold(1, R*P, P).unfold(2, R*P, P)
        ctx_windows = ctx_windows.contiguous().view(B, -1, (R*P)**2, C)
        
        # 3. Concatenate queries and context
        seq = T.cat([q_windows, ctx_windows], dim=2)
        
        # 4. Compute Q, K, V
        qkv = self.to_qkv(seq)
        q, k, v = qkv.chunk(3, dim=-1)  # Split into q, k, v
        
        # 5. Scaled dot-product attention
        attn = (q @ k.transpose(-2, -1)) * (C ** -0.5)
        attn = attn.softmax(dim=-1)
        out = attn @ v
        
        # 6. Keep only query window outputs
        out = out[:, :, :P*P, :]
        
        # 7. Reconstruct spatial layout
        out = out.view(B, H//P, W//P, P, P, C)
        out = out.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, C)
        out = out.view(B, N, C)
        
        return self.to_out(out)


class GhostMLP(nn.Module):
    """
    Ghost MLP with two GhostModuleV2 layers and GELU activation.
    Expands C→4C then projects back 4C→C.
    """
    def __init__(self, dim, mlp_ratio=4):
        super().__init__()
        hidden_dim = dim * mlp_ratio
        self.fc1 = GhostModuleV2(dim, hidden_dim, use_attn=False)
        self.act = nn.GELU()
        self.fc2 = GhostModuleV2(hidden_dim, dim, use_attn=False)

    def forward(self, x: T.Tensor) -> T.Tensor:
        """
        x: (B, N, C)
        returns: (B, N, C)
        """
        B, N, C = x.shape
        # reshape to (B, C, H, W) for GhostModuleV2, then back
        H = W = int(N ** 0.5)
        x_spatial = x.view(B, H, W, C).permute(0,3,1,2)
        x_spatial = self.act(self.fc1(x_spatial))
        x_spatial = self.fc2(x_spatial)
        x_flat = x_spatial.permute(0,2,3,1).view(B, N, C)
        return x_flat




In [4]:
class HybridConvNeXtBlock(nn.Module):
    def __init__(self, dim, num_heads, window_size,
                 context_ratio=2, mlp_ratio=4):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, 7, padding=3,
                                groups=dim, bias=False)
        self.norm1  = nn.LayerNorm(dim)
        self.attn   = VaryingWindowAttention(dim, num_heads,
                                             window_size, context_ratio)
        self.norm2  = nn.LayerNorm(dim)
        self.mlp    = GhostMLP(dim, mlp_ratio)

    def forward(self, x):
        x = self.dwconv(x)
        B,C,H,W = x.shape
        seq = x.permute(0,2,3,1).reshape(B, H*W, C)
        seq = seq + self.attn(self.norm1(seq))
        seq = seq + self.mlp(self.norm2(seq))
        return seq.view(B, H, W, C).permute(0,3,1,2)



In [5]:
class ASPPNeXtEncoder(nn.Module):
    def __init__(self, in_ch, base_dim=64,
                 num_heads=(4,8,16,32),
                 window_sizes=(8,4,2,1),
                 mlp_ratio=4):
        super().__init__()
        # Stage 1 patch‐embedding (no further downsampling here)
        self.stem = nn.Sequential(
            nn.Conv2d(in_ch, base_dim, 4, stride=4, bias=False),
            nn.LayerNorm([base_dim, None, None])
        )
        self.stages = nn.ModuleList()
        for i in range(4):
            dim = base_dim * (2**i)
            # downsample before stages 2–4 (i=1,2,3)
            down = None if i == 0 else nn.Sequential(
                GhostModuleV2(in_channels=dim//2,
                              out_channels=dim,
                              ratio=2,
                              kernel_size=2,
                              stride=2,
                              use_attn=True),
                nn.LayerNorm([dim, None, None])
            )
            block = HybridConvNeXtBlock(
                dim=dim,
                num_heads=num_heads[i],
                window_size=window_sizes[i],
                context_ratio=2,
                mlp_ratio=mlp_ratio
            )
            self.stages.append(nn.ModuleDict({'down': down, 'block': block}))

    def forward(self, x):
        # x: (B, in_ch, H, W)
        x = self.stem(x)            # Stage 1 input → (B, base_dim, H/4, W/4)
        skips = []
        for stage in self.stages:
            if stage['down'] is not None:
                x = stage['down'](x) # downsample before stages 2–4
            x = stage['block'](x)   # Hybrid block
            skips.append(x)         # capture skip for decoder
        # skips[0] = stage 1 output (no downsample)
        # skips[1] = stage 2 output (1/8 input res)
        # skips[2] = stage 3 output (1/16 input res)
        # skips[3] = stage 4 output (1/32 input res)
        return x, skips

In [6]:
class PreDAAFGhostV2(nn.Module):
    def __init__(self, channels: int, reduction: int = 4, ratio: int = 2):
        super().__init__()
        mid_channels = channels // reduction

        self.reduce = GhostModuleV2(
            in_channels=channels,
            out_channels=mid_channels,
            ratio=ratio,
            kernel_size=1,
            stride=1,
            use_attn=True
        )

        self.dwconv = nn.Sequential(
            nn.Conv2d(mid_channels, mid_channels,
                      kernel_size=3, stride=1, padding=1,
                      groups=mid_channels, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.Hardswish(inplace=True)
        )

        self.expand = GhostModuleV2(
            in_channels=mid_channels,
            out_channels=channels,
            ratio=ratio,
            kernel_size=1,
            stride=1,
            use_attn=True
        )

    def forward(self, x: T.Tensor) -> T.Tensor:
        x = self.reduce(x)    # → (B, C//reduction, H, W)
        x = self.dwconv(x)    # → (B, C//reduction, H, W)
        x = self.expand(x)    # → (B, C, H, W)
        return x

# Example usage in the ASPPNeXt pipeline:

# Instantiate once (channels must match final encoder output channels)
# e.g., if encoder final stage outputs C=512 feature maps:
#pre_daaf = PreDAAFGhostV2(channels=512, reduction=4, ratio=2)

# In forward pass for RGB and Depth encoders:
# f_rgb_out, rgb_skips   = rgb_encoder(rgb_input)
# f_depth_out, depth_skips = depth_encoder(depth_input)
# f_rgb_pre   = pre_daaf(f_rgb_out)
# f_depth_pre = pre_daaf(f_depth_out)
#
# Then feed (f_rgb_pre, f_depth_pre) into your DAAF fusion block.


In [7]:
class DAAFBlock(nn.Module):
    def __init__(self, channels: int, num_heads: int):
        super().__init__()
        # 1. Local branch per modality
        self.local_branch = RDSCBLocal(channels)
        # Fuse local RGB/depth outputs back to C channels
        self.local_fuse = GhostModuleV2(
            in_channels=channels * 2,
            out_channels=channels,
            kernel_size=1,
            use_attn=False
        )
        self.lia = LIA(channels)

        # 2. Global branch
        self.itb = InteractiveTransformerBlock(channels, num_heads)

        # 3. Global fusion conv: (local + global_rgb + global_depth) → channels
        self.global_fuse = GhostModuleV2(
            in_channels=channels * 3,
            out_channels=channels,
            kernel_size=3,
            use_attn=False
        )
        self.act = nn.LeakyReLU(inplace=True)

        # 4. Reconstruction head: three cascaded GhostConv(3×3)
        self.reconstruction = nn.Sequential(
            GhostModuleV2(channels, channels, kernel_size=3, use_attn=False),
            nn.LeakyReLU(inplace=True),
            GhostModuleV2(channels, channels, kernel_size=3, use_attn=False),
            nn.LeakyReLU(inplace=True),
            GhostModuleV2(channels, channels, kernel_size=3, use_attn=False),
            nn.LeakyReLU(inplace=True)
        )

    def forward(self,
                f_rgb_pre: T.Tensor,
                f_depth_pre: T.Tensor) -> T.Tensor:
        # 1. Local branch on each modality
        lr = self.lia(self.local_branch(f_rgb_pre))
        ld = self.lia(self.local_branch(f_depth_pre))
        # Fuse local RGB + Depth features
        local = T.cat([lr, ld], dim=1)   # (B, 2C, H, W)
        local = self.act(self.local_fuse(local))  # (B, C, H, W)

        # 2. Global interactive attention
        g_rgb, g_depth = self.itb(f_rgb_pre, f_depth_pre)

        # 3. Global fusion
        cat = T.cat([local, g_rgb, g_depth], dim=1)  # (B, 3C, H, W)
        fused = self.act(self.global_fuse(cat))         # (B, C, H, W)

        # 4. Reconstruction head
        out = self.reconstruction(fused)  # (B, C, H, W)
        return out


In [8]:
class PostDAAFGhostV2(nn.Module):
    """
    Post-DAAF GhostModule bottleneck using GhostModuleV2 with DFC attention.
    Mirrors the Pre-DAAF structure:
      1) GhostModuleV2 1×1 (C→C//reduction)
      2) 3×3 depthwise conv
      3) GhostModuleV2 1×1 (C//reduction→C)
    """
    def __init__(self, channels: int, reduction: int = 4, ratio: int = 2):
        super().__init__()
        mid_ch = channels // reduction

        # 1×1 GhostModuleV2 reduce channels
        self.reduce = GhostModuleV2(
            in_channels=channels,
            out_channels=mid_ch,
            ratio=ratio,
            kernel_size=1,
            stride=1,
            use_attn=True
        )
        # 3×3 depthwise conv on reduced features
        self.dwconv = nn.Sequential(
            nn.Conv2d(
                mid_ch, mid_ch,
                kernel_size=3, stride=1,
                padding=1, groups=mid_ch,
                bias=False
            ),
            nn.BatchNorm2d(mid_ch),
            nn.Hardswish(inplace=True)
        )
        # 1×1 GhostModuleV2 expand back to original channels
        self.expand = GhostModuleV2(
            in_channels=mid_ch,
            out_channels=channels,
            ratio=ratio,
            kernel_size=1,
            stride=1,
            use_attn=True
        )

    def forward(self, x: T.Tensor) -> T.Tensor:
        """
        x: (B, C, H, W) — fused feature map from DAAFBlock
        returns: (B, C, H, W) — channel-refined output
        """
        x = self.reduce(x)    # → (B, C//reduction, H, W)
        x = self.dwconv(x)    # → (B, C//reduction, H, W)
        x = self.expand(x)    # → (B, C, H, W)
        return x

# Example integration:

# After you obtain `out` from your DAAFBlock:
# daaf = DAAFBlock(channels=..., num_heads=...)
# fused = daaf(f_rgb_pre, f_depth_pre)
#
# post_daaf = PostDAAFGhostV2(channels=fused.shape[1], reduction=4, ratio=2)
# refined = post_daaf(fused)


In [9]:
class GhostASPPFELAN(nn.Module):
    """
    Ghost ASPPFELAN Block:
      - Branch 1: GhostConv(1×1) → LeakyReLU
      - Branch 2: GhostConv(3×3, dilation=2) → LeakyReLU
      - Branch 3: GhostConv(3×3, dilation=4) → LeakyReLU
      - Concatenate branches → GhostConv(1×1) → LeakyReLU
    """
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        # Branch 1: 1×1 GhostConv
        self.branch1 = nn.Sequential(
            GhostModuleV2(in_channels, out_channels,
                          kernel_size=1, use_attn=False),
            nn.LeakyReLU(inplace=True)
        )
        # Branch 2: 3×3 GhostConv with dilation=2
        self.branch2 = nn.Sequential(
            GhostModuleV2(in_channels, out_channels,
                          kernel_size=3, use_attn=False),
            nn.LeakyReLU(inplace=True)
        )
        # adjust dilation on the depthwise step inside GhostModuleV2
        # primary_conv uses 1×1 so no change; cheap_operation is depthwise
        self.branch2[0].cheap_operation[0].dilation = (2,2)
        self.branch2[0].cheap_operation[0].padding = (2,2)

        # Branch 3: 3×3 GhostConv with dilation=4
        self.branch3 = nn.Sequential(
            GhostModuleV2(in_channels, out_channels,
                          kernel_size=3, use_attn=False),
            nn.LeakyReLU(inplace=True)
        )
        self.branch3[0].cheap_operation[0].dilation = (4,4)
        self.branch3[0].cheap_operation[0].padding = (4,4)

        # Fusion: concat → 1×1 GhostConv → LeakyReLU
        self.fuse = nn.Sequential(
            GhostModuleV2(out_channels * 3, out_channels,
                          kernel_size=1, use_attn=False),
            nn.LeakyReLU(inplace=True)
        )

    def forward(self, x: T.Tensor) -> T.Tensor:
        """
        x: (B, in_channels, H, W)
        returns: (B, out_channels, H, W)
        """
        b1 = self.branch1(x)
        b2 = self.branch2(x)
        b3 = self.branch3(x)
        # Concatenate along channel dimension
        cat = T.cat([b1, b2, b3], dim=1)  # (B, 3*out_channels, H, W)
        out = self.fuse(cat)
        return out


In [10]:
class CoordAttention(nn.Module):
    """
    Coordinate Attention block with GhostModuleV2 replacing 1×1 convolutions.
    Input: X ∈ ℝ^(B×C×H×W)
    Output: X' = X × attn_h × attn_w
    """
    def __init__(self, channels: int, reduction: int = 4):
        super().__init__()
        self.channels = channels
        self.mid_channels = channels // reduction

        # 1×1 GhostConv to reduce channels: C→C/r
        self.conv1 = GhostModuleV2(
            in_channels=channels,
            out_channels=self.mid_channels,
            kernel_size=1,
            use_attn=False
        )
        self.act = nn.LeakyReLU(inplace=True)

        # 1×1 GhostConvs to expand back: C/r→C
        self.conv_h = GhostModuleV2(
            in_channels=self.mid_channels,
            out_channels=channels,
            kernel_size=1,
            use_attn=False
        )
        self.conv_w = GhostModuleV2(
            in_channels=self.mid_channels,
            out_channels=channels,
            kernel_size=1,
            use_attn=False
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: T.Tensor) -> T.Tensor:
        B, C, H, W = x.size()

        # 1. Pool along height and width
        z_h = F.adaptive_avg_pool2d(x, (H, 1))  # (B, C, H, 1)
        z_w = F.adaptive_avg_pool2d(x, (1, W))  # (B, C, 1, W)

        # 2. Transpose z_w to match concat shape
        z_w = z_w.permute(0, 1, 3, 2)           # (B, C, W, 1)

        # 3. Concat and reduce channels
        z = T.cat([z_h, z_w], dim=2)        # (B, C, H+W, 1)
        z = z.permute(0, 3, 2, 1)               # (B, 1, H+W, C)
        z = self.conv1(z)                       # (B, 1, H+W, C/r)
        z = self.act(z)

        # 4. Split into height and width contexts
        z = z.permute(0, 3, 2, 1)               # (B, C/r, H+W, 1)
        z_h, z_w = T.split(z, [H, W], dim=2)  # z_h: (B, C/r, H, 1), z_w: (B, C/r, W, 1)

        # 5. Generate attention maps
        a_h = self.sigmoid(self.conv_h(z_h))   # (B, C, H, 1)
        a_w = self.sigmoid(self.conv_w(z_w.permute(0,1,3,2)))  # (B, C, 1, W)

        # 6. Apply attention
        out = x * a_h * a_w
        return out


In [11]:
class DySample(nn.Module):
    """
    Dynamic content-aware upsampling (×scale) via point sampling.
    Based on “DySample: Learning to Upsample by Learning to Sample”[1].
    Input:  X ∈ ℝ^(B×C×H×W)
    Output: X_up ∈ ℝ^(B×C×(H·scale)×(W·scale))
    """
    def __init__(self,
                 in_channels: int,
                 scale: int = 2,
                 use_feat_transform: bool = False):
        super().__init__()
        self.scale = scale
        # Optional 1×1 GhostConv feature transform
        if use_feat_transform:
            self.transform = GhostModuleV2(
                in_channels, in_channels,
                kernel_size=1, use_attn=False
            )
        else:
            self.transform = None
        # Predict pixel‐level offsets: 2 coords × scale² shifts per input pixel
        self.offset_conv = nn.Conv2d(
            in_channels,
            2 * scale * scale,
            kernel_size=1,
            bias=True
        )

    def forward(self, x: T.Tensor) -> T.Tensor:
        B, C, H, W = x.shape

        # 1. Optional feature transform
        if self.transform is not None:
            x = self.transform(x)

        # 2. Initial upsampling via bilinear interpolation
        H2, W2 = H * self.scale, W * self.scale
        x_interp = F.interpolate(
            x, size=(H2, W2),
            mode='bilinear',
            align_corners=False
        )

        # 3. Predict offsets at input resolution
        #    shape: (B, 2*scale*scale, H, W)
        offsets = self.offset_conv(x)

        # 4. Rearrange offsets to per-output-pixel shifts
        #    from (B, 2·s², H, W)
        # to (B, 2, H·s, W·s)
        offsets = offsets.view(B, 2, self.scale, self.scale, H, W)
        # reorder to (B, 2, H, s, W, s)
        offsets = offsets.permute(0, 1, 4, 2, 5, 3)
        # collapse (H, s)→H2 and (W, s)→W2
        offsets = offsets.reshape(B, 2, H2, W2)

        # 5. Build base sampling grid in normalized coords [-1,1]
        # grid has shape (1, H2, W2, 2) with (x,y) in last dim
        device = x.device
        ys = T.linspace(-1, 1, H2, device=device)
        xs = T.linspace(-1, 1, W2, device=device)
        grid_y, grid_x = T.meshgrid(ys, xs, indexing='ij')
        base_grid = T.stack((grid_x, grid_y), dim=-1)  # (H2, W2, 2)
        base_grid = base_grid.unsqueeze(0).expand(B, -1, -1, -1)

        # 6. Convert pixel offsets to normalized offsets
        #    offset_x_pixel → offset_x_norm = offset_x * 2 / (W2 - 1)
        #    offset_y_pixel → offset_y_norm = offset_y * 2 / (H2 - 1)
        norm_factor = T.tensor(
            [2.0 / max(W2 - 1, 1), 2.0 / max(H2 - 1, 1)],
            device=device
        ).view(1, 2, 1, 1)
        offsets_norm = offsets * norm_factor

        # 7. Form final sampling grid and sample
        #    grid_sample expects grid in shape (B, H2, W2, 2)
        sampling_grid = base_grid + offsets_norm.permute(0, 2, 3, 1)
        x_up = F.grid_sample(
            x_interp,
            sampling_grid,
            mode='bilinear',
            padding_mode='border',
            align_corners=True
        )

        return x_up


In [12]:
class ASPPNeXtDecoder(nn.Module):
    """
    ASPPNeXt decoder with 3 stages - WITHOUT final classification layer:
      Stage 1: fuse Post-DAAF + skip E4 → GhostASPPFELAN → CoordAttention → DySample
      Stage 2: fuse Stage1_out + skip E3 → GhostASPPFELAN → CoordAttention → DySample  
      Stage 3: fuse Stage2_out + skip E2 → GhostASPPFELAN → CoordAttention → DySample
      
    Output: Raw feature maps at 1/4 input resolution (128×96 for 512×384 input)
    """
    def __init__(self,
                 feature_channels: int,
                 skip_channels: list,
                 decoder_channels: list,
                 coord_reduction: int = 4):
        """
        Args:
          feature_channels: channels of Post-DAAF output
          skip_channels:    [E2, E3, E4] channels from encoder stages
          decoder_channels: [stage1, stage2, stage3] desired output channels per stage
          coord_reduction:  reduction ratio for CoordAttention
        """
        super().__init__()
        assert len(skip_channels) == 3 and len(decoder_channels) == 3

        self.stages = nn.ModuleList()
        in_ch = feature_channels
        
        for i in range(3):
            skip_ch = skip_channels[2 - i]  # E4→stage0, E3→stage1, E2→stage2
            out_ch  = decoder_channels[i]
            
            block = nn.ModuleDict({
                'fuse_skip': GhostModuleV2(
                    in_channels=in_ch + skip_ch,
                    out_channels=out_ch,
                    kernel_size=1,
                    use_attn=False
                ),
                'aspp': GhostASPPFELAN(
                    in_channels=out_ch,
                    out_channels=out_ch
                ),
                'coord': CoordAttention(
                    channels=out_ch,
                    reduction=coord_reduction
                ),
                'upsample': DySample(
                    in_channels=out_ch,
                    scale=2,
                    use_feat_transform=False
                )
            })
            self.stages.append(block)
            in_ch = out_ch  # next stage input

    def forward(self, post_daaf: T.Tensor, skips: list):
        """
        Args:
          post_daaf: (B, C, H, W) output of Post-DAAF
          skips:     [E1, E2, E3, E4] encoder features
        Returns:
          features: (B, decoder_channels[-1], H*8, W*8) at 1/4 input resolution
        """
        x = post_daaf
        
        # Process each decoder stage
        for stage, skip in zip(self.stages, reversed(skips[1:])):  
            # skips[1:] = [E2,E3,E4]; reversed → [E4,E3,E2]
            
            # 1) Fuse skip connection
            x = T.cat([x, skip], dim=1)
            x = stage['fuse_skip'](x)
            
            # 2) GhostASPPFELAN
            x = stage['aspp'](x)
            
            # 3) CoordAttention  
            x = stage['coord'](x)
            
            # 4) DySample upsampling ×2
            x = stage['upsample'](x)

        return x  # Raw features at 1/4 input resolution


In [13]:
class ASPPNeXtOutputLayer(nn.Module):
    """
    Separate output layer that takes decoder features at 1/4 resolution
    and produces final segmentation masks at full input resolution.
    
    Supports multiple upsampling strategies:
    - Learnable transpose convolution (ConvTranspose2d)
    - Bilinear interpolation + refinement conv
    - DySample-based learnable upsampling
    """
    def __init__(self,
                 in_channels: int,
                 num_classes: int,
                 upsampling_method: str = 'transpose',
                 use_ghost_conv: bool = False,
                 refinement_layers: int = 1):
        """
        Args:
          in_channels:        channels from final decoder stage
          num_classes:        number of segmentation classes
          upsampling_method:  'transpose', 'bilinear', 'dysample'
          use_ghost_conv:     whether to use GhostModuleV2 for final conv
          refinement_layers:  number of refinement convs after upsampling
        """
        super().__init__()
        self.upsampling_method = upsampling_method
        self.num_classes = num_classes
        
        # Upsampling strategies
        if upsampling_method == 'transpose':
            # Learnable 4× upsampling via transpose convolution
            self.upsample = nn.ConvTranspose2d(
                in_channels=in_channels,
                out_channels=in_channels,
                kernel_size=4,
                stride=4,
                padding=0,
                bias=False
            )
            
        elif upsampling_method == 'bilinear':
            # Fixed bilinear interpolation
            self.upsample = lambda x: F.interpolate(
                x, scale_factor=4, 
                mode='bilinear', 
                align_corners=False
            )
            
        elif upsampling_method == 'dysample':
            # Content-aware DySample upsampling
            self.upsample = DySample(
                in_channels=in_channels,
                scale=4,  # 4× upsampling in one step
                use_feat_transform=True
            )
        else:
            raise ValueError(f"Unknown upsampling method: {upsampling_method}")
        
        # Optional refinement layers
        refinement = []
        for i in range(refinement_layers):
            if use_ghost_conv:
                refinement.append(GhostModuleV2(
                    in_channels, in_channels,
                    kernel_size=3, use_attn=False
                ))
            else:
                refinement.append(nn.Conv2d(
                    in_channels, in_channels,
                    kernel_size=3, padding=1, bias=False
                ))
            refinement.append(nn.BatchNorm2d(in_channels))
            refinement.append(nn.ReLU(inplace=True))
            
        self.refinement = nn.Sequential(*refinement)
        
        # Final classification layer
        if use_ghost_conv:
            self.classifier = GhostModuleV2(
                in_channels, num_classes,
                kernel_size=1, use_attn=False
            )
        else:
            self.classifier = nn.Conv2d(
                in_channels, num_classes,
                kernel_size=1, bias=True
            )

    def forward(self, x: T.Tensor, target_size: tuple = None):
        """
        Args:
          x: (B, C, H/4, W/4) decoder output at 1/4 resolution
          target_size: (H, W) target output size, if None uses 4×input size
        Returns:
          logits: (B, num_classes, H, W) at full resolution
        """
        # Upsample to full resolution
        if self.upsampling_method == 'bilinear':
            if target_size is not None:
                x = F.interpolate(x, size=target_size, 
                                mode='bilinear', align_corners=False)
            else:
                x = self.upsample(x)
        else:
            x = self.upsample(x)
            
        # Apply refinement layers
        x = self.refinement(x)
        
        # Final classification
        logits = self.classifier(x)
        
        # Ensure exact target size if specified
        if target_size is not None and logits.shape[-2:] != target_size:
            logits = F.interpolate(logits, size=target_size,
                                 mode='bilinear', align_corners=False)
            
        return logits

# Example usage:
def create_asppnext_with_separate_output(feature_channels=512,
                                       skip_channels=[128, 256, 512],
                                       decoder_channels=[256, 128, 64],
                                       num_classes=21,
                                       upsampling_method='transpose'):
    """
    Factory function to create decoder + output layer
    """
    decoder = ASPPNeXtDecoder(
        feature_channels=feature_channels,
        skip_channels=skip_channels,
        decoder_channels=decoder_channels
    )
    
    output_layer = ASPPNeXtOutputLayer(
        in_channels=decoder_channels[-1],  # Final decoder stage channels
        num_classes=num_classes,
        upsampling_method=upsampling_method,
        use_ghost_conv=False,  # Set to True to experiment with GhostModuleV2
        refinement_layers=2
    )
    
    return decoder, output_layer

# Complete forward pass example:
# decoder, output_layer = create_asppnext_with_separate_output()
# 
# # From your pipeline:
# decoder_features = decoder(post_daaf_output, encoder_skips)  # (B, 64, 128, 96)
# final_logits = output_layer(decoder_features, target_size=(384, 512))  # (B, 21, 384, 512)


In [14]:
class ASPPNeXtModel(nn.Module):
    def __init__(self,
                 in_ch_rgb: int = 3,
                 in_ch_depth: int = 1,
                 base_dim: int   = 64,
                 num_heads: tuple = (4,8,16,32),
                 window_sizes: tuple = (8,4,2,1),
                 mlp_ratio: int = 4,
                 skip_channels: list = None,
                 decoder_channels: list = None,
                 num_classes: int = 21,
                 coord_reduction: int = 4,
                 output_upsample: str = 'transpose'):
        super().__init__()
        # 1) Encoders for RGB and Depth
        self.rgb_encoder   = ASPPNeXtEncoder(in_ch=in_ch_rgb,
                                             base_dim=base_dim,
                                             num_heads=num_heads,
                                             window_sizes=window_sizes,
                                             mlp_ratio=mlp_ratio)
        self.depth_encoder = ASPPNeXtEncoder(in_ch=in_ch_depth,
                                             base_dim=base_dim,
                                             num_heads=num_heads,
                                             window_sizes=window_sizes,
                                             mlp_ratio=mlp_ratio)

        # Final encoder feature channels
        enc_out_ch = base_dim * 2**3  # stage4 channels

        # 2) Pre-DAAF Ghost bottlenecks
        self.pre_daaf = PreDAAFGhostV2(channels=enc_out_ch, reduction=4, ratio=2)

        # 3) DAAF fusion block
        self.daaf = DAAFBlock(channels=enc_out_ch, num_heads=num_heads[-1])

        # 4) Post-DAAF Ghost bottleneck
        self.post_daaf = PostDAAFGhostV2(channels=enc_out_ch, reduction=4, ratio=2)

        # 5) Decoder
        # If not provided, default skip and decoder channels:
        if skip_channels is None:
            skip_channels = [base_dim * 2**i for i in range(1,4)]  # [E2,E3,E4]
        if decoder_channels is None:
            decoder_channels = [base_dim * 2**i for i in reversed(range(3))]  # [C*4,C*2,C]
        self.decoder = ASPPNeXtDecoder(feature_channels=enc_out_ch,
                                       skip_channels=skip_channels,
                                       decoder_channels=decoder_channels,
                                       coord_reduction=coord_reduction)

        # 6) Output layer to restore full resolution
        self.output = ASPPNeXtOutputLayer(in_channels=decoder_channels[-1],
                                          num_classes=num_classes,
                                          upsampling_method=output_upsample,
                                          use_ghost_conv=False,
                                          refinement_layers=2)

    def forward(self, rgb: T.Tensor, depth: T.Tensor):
        """
        Inputs:
          rgb:   (B, 3, H, W)
          depth: (B, 1, H, W)
        Output:
          logits: (B, num_classes, H, W)
        """
        # Encode
        f_rgb,   rgb_skips   = self.rgb_encoder(rgb)
        f_depth, depth_skips = self.depth_encoder(depth)

        # Pre-fusion
        f_rgb_pre, f_depth_pre = self.pre_daaf(f_rgb, f_depth)

        # Cross-modal fusion
        daaf_out = self.daaf(f_rgb_pre, f_depth_pre)

        # Post-fusion refine
        post_daaf_out = self.post_daaf(daaf_out)

        # Decoder (returns features at 1/4 resolution)
        # Pass in skips: we use rgb_skips (or depth_skips—they have same shapes)
        decoder_feats = self.decoder(post_daaf_out, rgb_skips)

        # Final segmentation head (restore to full H×W)
        logits = self.output(decoder_feats, target_size=rgb.shape[-2:])

        return logits


In [15]:
class ASPPNeXtDataset(Dataset):
    def __init__(self, image_dir, depth_dir, mask_dir, transform=None):
        self.image_paths = sorted([os.path.join(image_dir, f)
                                   for f in os.listdir(image_dir)])
        self.depth_paths = sorted([os.path.join(depth_dir, f)
                                   for f in os.listdir(depth_dir)])
        self.mask_paths  = sorted([os.path.join(mask_dir, f)
                                   for f in os.listdir(mask_dir)])
        assert len(self.image_paths)==len(self.depth_paths)==len(self.mask_paths)
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # 1. Load RGB image with cv2 and convert BGR→RGB
        img_bgr = cv2.imread(self.image_paths[idx], cv2.IMREAD_COLOR)
        img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

        # 2. Load depth (assume single-channel image)
        depth_path = self.depth_paths[idx]
        if depth_path.endswith(".npy"):
            depth_np = np.load(depth_path)
        else:
            depth_np = cv2.imread(depth_path, cv2.IMREAD_GRAYSCALE)
        depth = depth_np.astype(np.float32)

        # 3. Load mask as grayscale
        mask_np = cv2.imread(self.mask_paths[idx], cv2.IMREAD_GRAYSCALE)
        mask = mask_np.astype(np.int64)

        # 4. Resize if needed (e.g., to 384×512)
        if self.transform:
            img = cv2.resize(img, self.transform["size"], interpolation=cv2.INTER_LINEAR)
            depth = cv2.resize(depth, self.transform["size"], interpolation=cv2.INTER_NEAREST)
            mask = cv2.resize(mask, self.transform["size"], interpolation=cv2.INTER_NEAREST)

        # 5. Normalize & to tensor
        #   img: H×W×3 → 3×H×W, scale [0,255]→[0,1]
        img = img.astype(np.float32) / 255.0
        img = np.transpose(img, (2,0,1))
        img_t = T.from_numpy(img)

        #   depth: H×W → 1×H×W, assume already scaled appropriately
        depth = depth[np.newaxis, ...]
        depth_t = T.from_numpy(depth)

        #   mask: H×W → H×W long tensor
        mask_t = T.from_numpy(mask)

        return img_t, depth_t, mask_t

# Example transforms dict
common_transform = {"size": (512,384)}  # width,height for cv2.resize

# Paths
base_img   = r"D:\AAU Internship\Code\CWF-788\IMAGE512x384"
base_depth = r"./Depth_Features"
base_mask  = r"./Weed_Masks"

splits = {
    "train":  ("train_new",    "train_new",    "train"),
    "val":    ("validation_new","validation_new","val"),
    "test":   ("test_new",     "test_new",     "test"),
}

dataloaders = {}
batch_size = 4
for phase, (img_s, depth_s, mask_s) in splits.items():
    ds = ASPPNeXtDataset(
        image_dir = os.path.join(base_img, img_s),
        depth_dir = os.path.join(base_depth, depth_s),
        mask_dir  = os.path.join(base_mask, mask_s),
        transform = common_transform
    )
    dataloaders[phase] = DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=(phase=="train"),
        pin_memory=True
    )


In [16]:
model = ASPPNeXtModel(
    in_ch_rgb=3,
    in_ch_depth=1,
    base_dim=64,
    num_heads=(4,8,16,32),
    window_sizes=(8,4,2,1),
    mlp_ratio=4,
    num_classes=2,
    output_upsample='bilinear'
)

# Create dummy inputs (batch_size=1)
rgb_input = T.randn(1, 3, 384, 512)    # (B, C, H, W) = (1, 3, 384, 512)
depth_input = T.randn(1, 1, 384, 512)   # (1, 1, 384, 512)

# Forward pass
with T.no_grad():
    output = model(rgb_input, depth_input)

# Print dimensions
print("Input shape (RGB):", rgb_input.shape)
print("Input shape (Depth):", depth_input.shape)
print("Output shape:", output.shape)

# Verify resolution match
assert output.shape[-2:] == (384, 512), \
    f"Output size {output.shape[-2:]} ≠ Input size (384, 512)"
print("\n✅ Size verification passed! Output matches input resolution.")

TypeError: empty(): argument 'size' failed to unpack the object at pos 2 with error "type must be tuple of ints,but got NoneType"