In [1]:
import math
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import os
import time
from tqdm.notebook import tqdm
from torch.utils.data import Dataset, DataLoader
import tensorrt as trt
import numpy as np
import pycuda.driver as cuda
import pycuda.autoinit
import cv2

In [2]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = T.device("cuda") if T.cuda.is_available() else T.device("cpu")

In [3]:
# Folder named "ASPPNeXt Models" under the current working directory
# (presumably where model files are saved/loaded — confirm against later cells).
MODELS_DIR = os.path.join(os.getcwd(), "ASPPNeXt Models")
# exist_ok=True makes re-running this cell harmless when the folder already exists.
os.makedirs(MODELS_DIR, exist_ok=True)

In [4]:
class DFCAttention(nn.Module):
    """Channel gate applied multiplicatively to the input.

    Pipeline: global average pool -> 1x1 conv -> Hardswish -> 1x1 conv -> sigmoid.
    The resulting (B, C, 1, 1) gate, plus a 1e-6 offset, scales ``x``.
    """

    def __init__(self, channels):
        super().__init__()
        gate_layers = [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels, kernel_size=1, bias=False),
            nn.Hardswish(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=1, bias=False),
            nn.Sigmoid(),
        ]
        self.dfc = nn.Sequential(*gate_layers)

    def forward(self, x):
        # The offset is added out-of-place (a new tensor), not on the sigmoid output.
        gate = self.dfc(x)
        return x * (gate + 1e-6)

class GhostModuleV2(nn.Module):
    """Ghost convolution block with an optional DFC channel gate.

    A regular convolution produces ``ceil(out_channels / ratio)`` primary
    feature maps; a depthwise convolution then derives ``ratio - 1`` extra maps
    per primary channel. Both sets are concatenated along the channel axis and
    truncated to ``out_channels``.

    Args:
        in_channels:  channels of the input tensor.
        out_channels: channels of the returned tensor.
        ratio:        total-to-primary channel ratio.
        kernel_size:  kernel of the primary convolution.
        dw_size:      kernel of the depthwise (cheap) convolution.
        stride:       stride of the primary convolution.
        use_attn:     when true, the output is passed through ``DFCAttention``.
    """

    def __init__(self, in_channels, out_channels,
                 ratio=2, kernel_size=1, dw_size=3,
                 stride=1, use_attn=True):
        super().__init__()
        self.out_channels = out_channels

        primary_ch = math.ceil(out_channels / ratio)
        ghost_ch = primary_ch * (ratio - 1)
        primary_pad = (kernel_size - 1) // 2
        cheap_pad = (dw_size - 1) // 2

        # Primary path: conv -> BN -> Hardswish.
        self.primary_conv = nn.Sequential(
            nn.Conv2d(in_channels, primary_ch,
                      kernel_size=kernel_size, stride=stride,
                      padding=primary_pad, bias=False),
            nn.BatchNorm2d(primary_ch),
            nn.Hardswish(inplace=True),
        )

        # Cheap path: depthwise conv (one group per primary channel) -> BN -> Hardswish.
        self.cheap_operation = nn.Sequential(
            nn.Conv2d(primary_ch, ghost_ch,
                      kernel_size=dw_size, stride=1,
                      padding=cheap_pad, groups=primary_ch, bias=False),
            nn.BatchNorm2d(ghost_ch),
            nn.Hardswish(inplace=True),
        )

        # The gate is only instantiated when requested.
        self.use_attn = use_attn
        if use_attn:
            self.attn = DFCAttention(out_channels)

    def forward(self, x):
        primary = self.primary_conv(x)
        ghost = self.cheap_operation(primary)
        # Concatenation may overshoot out_channels; keep only the first out_channels maps.
        merged = T.cat([primary, ghost], dim=1)[:, :self.out_channels]
        if self.use_attn:
            return self.attn(merged)
        return merged

In [5]:
class VaryingWindowAttention(nn.Module):
    """Window attention in which each query window also sees a larger context window.

    The (B, H*W, C) token sequence is split into non-overlapping ``P x P`` query
    windows (``P = window_size``). For every query window a ``(R*P) x (R*P)``
    context window (``R = context_ratio``) is cut from a zero-padded copy of the
    feature map. Query and context tokens are concatenated, attended jointly
    with a single scaled dot-product attention, and only the query positions
    are kept and scattered back to their original locations.

    Args:
        dim:           token channel dimension C.
        num_heads:     stored on the module.
                       NOTE(review): not used by ``forward`` — attention is single-head.
        window_size:   side length P of a query window.
        context_ratio: context window side length as a multiple of P.
    """

    def __init__(self, dim, num_heads, window_size, context_ratio=2):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.context_ratio = context_ratio

        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim, bias=False)

    @staticmethod
    def _windows_to_tokens(windows):
        """Turn an unfolded (B, nH, nW, C, h, w) tensor into (B, nH*nW, h*w, C) tokens.

        ``Tensor.unfold`` appends the window dimensions *after* the channel
        dimension, so the channel axis has to be moved last before flattening;
        viewing the raw layout as (h*w, C) would interleave channels and positions.
        """
        B, nH, nW, C, h, w = windows.shape
        return windows.permute(0, 1, 2, 4, 5, 3).reshape(B, nH * nW, h * w, C)

    def forward(self, x, H, W):
        """
        Args:
            x: (B, H*W, C) tokens in row-major spatial order.
            H, W: spatial size of the token grid; both must be divisible by ``window_size``.
        Returns:
            (B, H*W, C) attended tokens.
        Raises:
            ValueError: if the sequence length is not H*W, or H/W is not a
                multiple of the window size, or query/context window counts differ.
        """
        B, N, C = x.shape
        P = self.window_size
        R = self.context_ratio

        if N != H * W:
            raise ValueError(f"Sequence length {N} does not match H*W = {H}*{W}")
        if H % P != 0 or W % P != 0:
            raise ValueError(
                f"Feature map {H}x{W} is not divisible by window size {P}")

        x_spatial = x.reshape(B, H, W, C)

        # Non-overlapping P x P query windows.
        q_windows = self._windows_to_tokens(
            x_spatial.unfold(1, P, P).unfold(2, P, P))

        # Context windows need (R-1)*P extra rows/cols in total. Split the padding
        # so that an odd total (e.g. P=1, R=2) still yields one context window per
        # query window instead of one too few.
        pad_total = (R - 1) * P
        pad_before = pad_total // 2
        pad_after = pad_total - pad_before
        x_pad = F.pad(x_spatial, (0, 0, pad_before, pad_after, pad_before, pad_after))
        ctx_windows = self._windows_to_tokens(
            x_pad.unfold(1, R * P, P).unfold(2, R * P, P))

        if q_windows.size(1) != ctx_windows.size(1):
            raise ValueError(f"Query and context windows mismatch: {q_windows.shape[1]} vs {ctx_windows.shape[1]}")

        # Query tokens come first so they can be sliced back out after attention.
        seq = T.cat([q_windows, ctx_windows], dim=2)

        qkv = self.to_qkv(seq)
        q, k, v = qkv.chunk(3, dim=-1)

        attn = (q @ k.transpose(-2, -1)) * (C ** -0.5)
        attn = attn.softmax(dim=-1)
        out = attn @ v

        # Keep only the query positions and undo the window partition.
        out = out[:, :, :P * P, :]
        out = out.reshape(B, H // P, W // P, P, P, C)
        out = out.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, C)
        out = out.view(B, N, C)

        return self.to_out(out)
        

class GhostMLP(nn.Module):
    """Feed-forward block built from two GhostModuleV2 layers with a GELU in between.

    Tokens are reshaped to a (B, C, H, W) map, expanded to ``dim * mlp_ratio``
    channels, projected back to ``dim`` and flattened to tokens again.
    """

    def __init__(self, dim, mlp_ratio=4):
        super().__init__()
        expanded = dim * mlp_ratio
        self.fc1 = GhostModuleV2(dim, expanded, use_attn=False)
        self.act = nn.GELU()
        self.fc2 = GhostModuleV2(expanded, dim, use_attn=False)

    def forward(self, x, H, W):
        """x: (B, H*W, C) tokens -> (B, H*W, C) tokens."""
        B, N, C = x.shape
        # Tokens -> channels-first feature map for the convolutional layers.
        feat = x.view(B, H, W, C).permute(0, 3, 1, 2)
        feat = self.fc1(feat)
        feat = self.act(feat)
        feat = self.fc2(feat)
        # Feature map -> tokens.
        return feat.permute(0, 2, 3, 1).view(B, N, C)


In [6]:
class LayerNorm2d(nn.Module):
    """
    LayerNorm over the channel axis of a (B, C, H, W) tensor.

    The tensor is moved to channels-last, normalised with ``nn.LayerNorm``
    and moved back to channels-first.
    """

    def __init__(self, num_channels, eps=1e-6):
        super().__init__()
        self.norm = nn.LayerNorm(num_channels, eps=eps)

    def forward(self, x):
        channels_last = x.permute(0, 2, 3, 1)    # (B, H, W, C)
        normed = self.norm(channels_last)
        return normed.permute(0, 3, 1, 2)        # (B, C, H, W)

In [7]:
class HybridConvNeXtBlock(nn.Module):
    """7x7 depthwise conv followed by window attention and a Ghost MLP.

    The attention and MLP sub-layers are each pre-normalised with LayerNorm and
    wrapped in a residual connection. Input and output are (B, dim, H, W).
    """

    def __init__(self, dim, num_heads, window_size,
                 context_ratio=2, mlp_ratio=4):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3,
                                groups=dim, bias=False)
        self.norm1 = nn.LayerNorm(dim)
        self.attn = VaryingWindowAttention(
            dim, num_heads, window_size, context_ratio)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = GhostMLP(dim, mlp_ratio)

    def forward(self, x):
        feat = self.dwconv(x)
        B, C, H, W = feat.shape

        # Channels-first map -> (B, H*W, C) token sequence.
        tokens = feat.permute(0, 2, 3, 1).reshape(B, H * W, C)

        attn_out = self.attn(self.norm1(tokens), H, W)
        tokens = tokens + attn_out

        mlp_out = self.mlp(self.norm2(tokens), H, W)
        tokens = tokens + mlp_out

        # Tokens -> channels-first map.
        return tokens.view(B, H, W, C).permute(0, 3, 1, 2)


In [8]:
class ASPPNeXtEncoder(nn.Module):
    """
    Four-stage hierarchical encoder.

    A stride-4 patch-embedding stem is followed by four stages. Stage ``i``
    works at ``base_dim * 2**i`` channels; stages 2-4 start with a stride-2
    GhostModuleV2 + LayerNorm2d downsample. Every stage ends with one
    HybridConvNeXtBlock whose output is also collected as a skip feature.
    """

    def __init__(self, in_ch, base_dim=64,
                 num_heads=(4, 8, 16, 32),
                 window_sizes=(16, 8, 4, 2),
                 mlp_ratio=4):
        super().__init__()

        # Stem: non-overlapping 4x4 patches -> base_dim channels at 1/4 resolution.
        patch_embed = nn.Conv2d(in_ch, base_dim, kernel_size=4, stride=4, bias=False)
        self.stem = nn.Sequential(patch_embed, LayerNorm2d(base_dim))

        self.stages = nn.ModuleList()
        for stage_idx in range(4):
            width = base_dim * (2 ** stage_idx)

            if stage_idx == 0:
                # The first stage runs directly on the stem output.
                downsample = None
            else:
                # Halve the resolution and double the channels.
                downsample = nn.Sequential(
                    GhostModuleV2(
                        in_channels=width // 2,
                        out_channels=width,
                        ratio=2,
                        kernel_size=2,
                        stride=2,
                        use_attn=True,
                    ),
                    LayerNorm2d(width),
                )

            stage_block = HybridConvNeXtBlock(
                dim=width,
                num_heads=num_heads[stage_idx],
                window_size=window_sizes[stage_idx],
                context_ratio=2,
                mlp_ratio=mlp_ratio,
            )

            self.stages.append(
                nn.ModuleDict({'down': downsample, 'block': stage_block}))

    def forward(self, x):
        """x: (B, C, H, W) -> (last stage output, list of the four stage outputs)."""
        feat = self.stem(x)
        skips = []
        for stage in self.stages:
            downsample = stage['down']
            if downsample is not None:
                feat = downsample(feat)
            feat = stage['block'](feat)
            skips.append(feat)
        return feat, skips


In [9]:
class PreDAAFGhostV2(nn.Module):
    """Ghost bottleneck: reduce channels, 3x3 depthwise conv, expand back.

    Channels go C -> C // reduction -> C; the spatial size is unchanged.
    Both 1x1 GhostModuleV2 layers use the DFC attention gate.
    """

    def __init__(self, channels: int, reduction: int = 4, ratio: int = 2):
        super().__init__()
        bottleneck_ch = channels // reduction
        # Settings shared by the reduce and expand Ghost layers.
        ghost_cfg = dict(ratio=ratio, kernel_size=1, stride=1, use_attn=True)

        # C -> C // reduction
        self.reduce = GhostModuleV2(
            in_channels=channels, out_channels=bottleneck_ch, **ghost_cfg)

        # Depthwise 3x3 -> BN -> Hardswish on the reduced features.
        self.dwconv = nn.Sequential(
            nn.Conv2d(bottleneck_ch, bottleneck_ch,
                      kernel_size=3, stride=1, padding=1,
                      groups=bottleneck_ch, bias=False),
            nn.BatchNorm2d(bottleneck_ch),
            nn.Hardswish(inplace=True),
        )

        # C // reduction -> C
        self.expand = GhostModuleV2(
            in_channels=bottleneck_ch, out_channels=channels, **ghost_cfg)

    def forward(self, x: T.Tensor) -> T.Tensor:
        """(B, C, H, W) -> (B, C, H, W)."""
        reduced = self.reduce(x)
        mixed = self.dwconv(reduced)
        return self.expand(mixed)


In [10]:
class RDSCBLocal(nn.Module):
    """
    Local multi-scale branch.

    Runs four parallel GhostModuleV2 layers with primary kernel sizes 1, 3, 5
    and 7 on the same input, concatenates their activated outputs along the
    channel axis and fuses them back to ``channels`` with a 1x1 GhostModuleV2.
    """

    def __init__(self, channels: int):
        super().__init__()
        # Parallel branches, one per kernel size.
        self.conv1 = GhostModuleV2(channels, channels, kernel_size=1, use_attn=False)
        self.conv3 = GhostModuleV2(channels, channels, kernel_size=3, use_attn=False)
        self.conv5 = GhostModuleV2(channels, channels, kernel_size=5, use_attn=False)
        self.conv7 = GhostModuleV2(channels, channels, kernel_size=7, use_attn=False)
        # 4C -> C fusion.
        self.fuse = GhostModuleV2(channels * 4, channels, kernel_size=1, use_attn=False)
        self.act = nn.LeakyReLU(inplace=True)

    def forward(self, x: T.Tensor) -> T.Tensor:
        """(B, C, H, W) -> (B, C, H, W)."""
        branches = (self.conv1, self.conv3, self.conv5, self.conv7)
        multi_scale = [self.act(branch(x)) for branch in branches]
        stacked = T.cat(multi_scale, dim=1)          # (B, 4C, H, W)
        return self.act(self.fuse(stacked))          # (B, C, H, W)


class LIA(nn.Module):
    """
    Local Interaction Attention.

    Builds a per-channel descriptor from the spatial mean and spatial standard
    deviation of the input (mixed with the learnable scalars ``alpha`` and
    ``beta``), turns it into a sigmoid channel gate with a two-layer 1x1-conv
    MLP, applies the gate to the input and refines the result with a 3x3
    GhostModuleV2 followed by LeakyReLU.
    """

    def __init__(self, channels: int):
        super().__init__()
        # Spatial mean descriptor; the std descriptor is the std_pool method below.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        # Learnable weights for mixing the mean and std descriptors.
        self.alpha = nn.Parameter(T.ones(1))
        self.beta = nn.Parameter(T.ones(1))

        # Squeeze-and-excite style gate: C -> C // 4 -> C -> sigmoid.
        self.mlp = nn.Sequential(
            nn.Conv2d(channels, channels // 4, kernel_size=1, bias=False),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(channels // 4, channels, kernel_size=1, bias=False),
            nn.Sigmoid()
        )

        # Final refinement convolution (no DFC gate).
        self.refine = GhostModuleV2(channels, channels, kernel_size=3, use_attn=False)
        self.act = nn.LeakyReLU(inplace=True)

    def std_pool(self, x: T.Tensor) -> T.Tensor:
        """Per-channel spatial standard deviation, shape (B, C, 1, 1).

        Defined as a method rather than a lambda attribute so the module can be
        pickled. The unbiased estimator is used whenever there is more than one
        spatial element; for a 1x1 map it would divide by zero and return NaN,
        so the population variance (which is 0) is used there instead.
        """
        spatial_elems = int(x.shape[2]) * int(x.shape[3])
        variance = T.var(x, dim=(2, 3), keepdim=True, unbiased=spatial_elems > 1)
        return T.sqrt(variance + 1e-8)

    def forward(self, x: T.Tensor) -> T.Tensor:
        """
        x: (B, C, H, W) feature map
        returns: (B, C, H, W) gated and refined feature map
        """
        avg_feat = self.avg_pool(x)      # (B, C, 1, 1)
        std_feat = self.std_pool(x)      # (B, C, 1, 1)

        # Mix both statistics with the learnable coefficients.
        combined_feat = self.alpha * avg_feat + self.beta * std_feat

        # Channel gate in (0, 1).
        attention = self.mlp(combined_feat)  # (B, C, 1, 1)

        # Gate the input, then refine.
        attended = x * attention
        refined = self.act(self.refine(attended))

        return refined


class InteractiveTransformerBlock(nn.Module):
    """
    Interactive Transformer Block (ITB): cross-modal attention + Ghost FFN.

    RGB tokens query depth tokens and vice versa through one shared
    ``nn.MultiheadAttention``; each stream then gets a residual + LayerNorm and
    a residual Ghost-based FFN (C -> 4C -> C). Returns one feature map per
    modality.
    """

    def __init__(self, channels: int, num_heads: int):
        super().__init__()
        self.channels = channels
        self.num_heads = num_heads

        # Normalisation applied to both modalities before attention.
        self.norm1 = nn.LayerNorm(channels)

        # Attention module shared by both cross-modal directions.
        self.attn = nn.MultiheadAttention(
            embed_dim=channels,
            num_heads=num_heads,
            batch_first=True,
            dropout=0.1
        )

        # Normalisation applied after the attention residual.
        self.norm2 = nn.LayerNorm(channels)

        # Ghost-based FFN: C -> 4C -> C with GELU, no DFC gate.
        self.ffn = nn.ModuleDict({
            'expand': GhostModuleV2(channels, channels * 4, kernel_size=1, use_attn=False),
            'contract': GhostModuleV2(channels * 4, channels, kernel_size=1, use_attn=False)
        })
        self.ffn_act = nn.GELU()
        self.ffn_dropout = nn.Dropout(0.1)

    def _apply_ffn(self, x: T.Tensor, H: int = None, W: int = None) -> T.Tensor:
        """Apply the Ghost FFN to a (B, N, C) token sequence.

        Args:
            x: (B, N, C) tokens in row-major spatial order.
            H, W: spatial size with ``H * W == N``. When omitted, a square map is
                assumed (kept for backward compatibility with older callers).
        Raises:
            ValueError: if ``H * W != N`` or, without H/W, N is not a perfect square.
        """
        B, N, C = x.shape
        if H is None or W is None:
            side = int(round(N ** 0.5))
            if side * side != N:
                raise ValueError(
                    f"Cannot infer a square layout for {N} tokens; pass H and W")
            H = W = side
        elif H * W != N:
            raise ValueError(f"H*W = {H}*{W} does not match sequence length {N}")

        # The FFN layers are convolutional (their cheap path is a 3x3 depthwise
        # conv), so the true spatial layout has to be restored.
        x_4d = x.reshape(B, H, W, C).permute(0, 3, 1, 2)  # (B, C, H, W)

        x_4d = self.ffn['expand'](x_4d)
        x_4d = self.ffn_act(x_4d)
        x_4d = self.ffn_dropout(x_4d)
        x_4d = self.ffn['contract'](x_4d)

        # Back to a token sequence.
        return x_4d.permute(0, 2, 3, 1).reshape(B, N, C)

    def forward(self, f_rgb: T.Tensor, f_depth: T.Tensor) -> tuple:
        """
        f_rgb, f_depth: (B, C, H, W), identical shapes (H and W may differ from each other)
        returns: (g_rgb, g_depth) each (B, C, H, W)
        Raises:
            ValueError: if the two inputs do not have the same shape.
        """
        if f_rgb.shape != f_depth.shape:
            raise ValueError(
                f"RGB and depth features must share a shape, got "
                f"{tuple(f_rgb.shape)} vs {tuple(f_depth.shape)}")

        B, C, H, W = f_rgb.shape
        N = H * W

        # Flatten spatial dimensions for attention.
        rgb_seq = f_rgb.permute(0, 2, 3, 1).reshape(B, N, C)      # (B, N, C)
        depth_seq = f_depth.permute(0, 2, 3, 1).reshape(B, N, C)  # (B, N, C)

        rgb_norm = self.norm1(rgb_seq)
        depth_norm = self.norm1(depth_seq)

        # RGB queries attend to depth keys/values.
        g_rgb_seq, _ = self.attn(
            query=rgb_norm,
            key=depth_norm,
            value=depth_norm
        )

        # Depth queries attend to RGB keys/values.
        g_depth_seq, _ = self.attn(
            query=depth_norm,
            key=rgb_norm,
            value=rgb_norm
        )

        # Residual connection, then normalisation.
        g_rgb_seq = self.norm2(rgb_seq + g_rgb_seq)
        g_depth_seq = self.norm2(depth_seq + g_depth_seq)

        # FFN with a second residual; H and W are passed explicitly so that
        # non-square feature maps are reshaped correctly.
        g_rgb_seq = g_rgb_seq + self._apply_ffn(g_rgb_seq, H, W)
        g_depth_seq = g_depth_seq + self._apply_ffn(g_depth_seq, H, W)

        # Back to channels-first maps.
        g_rgb = g_rgb_seq.view(B, H, W, C).permute(0, 3, 1, 2)     # (B, C, H, W)
        g_depth = g_depth_seq.view(B, H, W, C).permute(0, 3, 1, 2) # (B, C, H, W)

        return g_rgb, g_depth


In [11]:
class DAAFBlock(nn.Module):
    """Fuses RGB and depth features through a local and a global branch.

    Local branch: each modality goes through the shared RDSCBLocal and LIA
    modules; the two results are concatenated and fused to C channels.
    Global branch: InteractiveTransformerBlock yields one map per modality.
    The local map and both global maps are concatenated, fused to C channels
    and passed through three 3x3 GhostModuleV2 + LeakyReLU layers.
    """

    def __init__(self, channels: int, num_heads: int):
        super().__init__()
        # Local branch (the same modules are applied to both modalities).
        self.local_branch = RDSCBLocal(channels)
        self.local_fuse = GhostModuleV2(
            in_channels=channels * 2,
            out_channels=channels,
            kernel_size=1,
            use_attn=False,
        )
        self.lia = LIA(channels)

        # Global cross-modal branch.
        self.itb = InteractiveTransformerBlock(channels, num_heads)

        # 3C -> C fusion of [local, global_rgb, global_depth].
        self.global_fuse = GhostModuleV2(
            in_channels=channels * 3,
            out_channels=channels,
            kernel_size=3,
            use_attn=False,
        )
        self.act = nn.LeakyReLU(inplace=True)

        # Reconstruction head: three cascaded 3x3 Ghost layers.
        self.reconstruction = nn.Sequential(
            GhostModuleV2(channels, channels, kernel_size=3, use_attn=False),
            nn.LeakyReLU(inplace=True),
            GhostModuleV2(channels, channels, kernel_size=3, use_attn=False),
            nn.LeakyReLU(inplace=True),
            GhostModuleV2(channels, channels, kernel_size=3, use_attn=False),
            nn.LeakyReLU(inplace=True),
        )

    def forward(self,
                f_rgb_pre: T.Tensor,
                f_depth_pre: T.Tensor) -> T.Tensor:
        """Two (B, C, H, W) inputs -> one fused (B, C, H, W) map."""
        # Local features per modality, RGB first.
        local_rgb = self.lia(self.local_branch(f_rgb_pre))
        local_depth = self.lia(self.local_branch(f_depth_pre))
        local_feat = self.local_fuse(T.cat([local_rgb, local_depth], dim=1))
        local_feat = self.act(local_feat)                      # (B, C, H, W)

        # Global cross-modal features.
        global_rgb, global_depth = self.itb(f_rgb_pre, f_depth_pre)

        # Fuse local and global information.
        stacked = T.cat([local_feat, global_rgb, global_depth], dim=1)  # (B, 3C, H, W)
        fused = self.act(self.global_fuse(stacked))            # (B, C, H, W)

        return self.reconstruction(fused)


In [12]:
class PostDAAFGhostV2(nn.Module):
    """
    Ghost bottleneck with DFC attention, same layout as PreDAAFGhostV2:
      1) 1x1 GhostModuleV2, C -> C // reduction
      2) 3x3 depthwise conv + BatchNorm + Hardswish
      3) 1x1 GhostModuleV2, C // reduction -> C
    """

    def __init__(self, channels: int, reduction: int = 4, ratio: int = 2):
        super().__init__()
        mid_ch = channels // reduction
        # Settings shared by the reduce and expand Ghost layers.
        ghost_cfg = dict(ratio=ratio, kernel_size=1, stride=1, use_attn=True)

        self.reduce = GhostModuleV2(
            in_channels=channels, out_channels=mid_ch, **ghost_cfg)

        depthwise = nn.Conv2d(
            mid_ch, mid_ch,
            kernel_size=3, stride=1,
            padding=1, groups=mid_ch,
            bias=False,
        )
        self.dwconv = nn.Sequential(
            depthwise,
            nn.BatchNorm2d(mid_ch),
            nn.Hardswish(inplace=True),
        )

        self.expand = GhostModuleV2(
            in_channels=mid_ch, out_channels=channels, **ghost_cfg)

    def forward(self, x: T.Tensor) -> T.Tensor:
        """(B, C, H, W) -> (B, C, H, W); channels are squeezed and restored internally."""
        squeezed = self.reduce(x)       # (B, C//reduction, H, W)
        mixed = self.dwconv(squeezed)   # (B, C//reduction, H, W)
        return self.expand(mixed)       # (B, C, H, W)

# Example integration:

# After you obtain `out` from your DAAFBlock:
# daaf = DAAFBlock(channels=..., num_heads=...)
# fused = daaf(f_rgb_pre, f_depth_pre)
#
# post_daaf = PostDAAFGhostV2(channels=fused.shape[1], reduction=4, ratio=2)
# refined = post_daaf(fused)


In [13]:
class DilatedGhostBlock(nn.Module):
    """
    GhostModuleV2 whose depthwise (cheap) convolution is dilated, followed by LeakyReLU.

    Only the cheap path is dilated; the primary convolution keeps dilation 1.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, use_attn=False):
        super().__init__()

        self.block = GhostModuleV2(
            in_channels=in_channels,
            out_channels=out_channels,
            ratio=2,
            kernel_size=kernel_size,
            stride=1,
            use_attn=use_attn,
        )

        # The depthwise conv is patched after construction. With its 3x3 kernel,
        # padding == dilation keeps the spatial size unchanged.
        depthwise = self.block.cheap_operation[0]
        depthwise.dilation = (dilation, dilation)
        depthwise.padding = (dilation, dilation)

        self.act = nn.LeakyReLU(inplace=True)

    def forward(self, x):
        out = self.block(x)
        return self.act(out)


In [14]:
class GhostASPPFELAN(nn.Module):
    """
    ASPP-style block built from Ghost convolutions.

    Four parallel branches (3x3 GhostModuleV2 + LeakyReLU) use dilation
    1, 3, 5 and 7 in their depthwise path. Their outputs are concatenated and
    fused to ``out_channels`` by a 1x1 GhostModuleV2 + LeakyReLU.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()

        self.branch1 = self._make_branch(in_channels, out_channels, 1)
        self.branch2 = self._make_branch(in_channels, out_channels, 3)
        self.branch3 = self._make_branch(in_channels, out_channels, 5)
        self.branch4 = self._make_branch(in_channels, out_channels, 7)

        self.fuse = nn.Sequential(
            GhostModuleV2(out_channels * 4, out_channels, kernel_size=1, use_attn=False),
            nn.LeakyReLU(inplace=True),
        )

    @staticmethod
    def _make_branch(in_channels, out_channels, dilation):
        """Build one Ghost branch whose depthwise conv uses the given dilation."""
        ghost = GhostModuleV2(in_channels, out_channels, kernel_size=3, use_attn=False)
        # Patch the depthwise conv; padding == dilation preserves the spatial size
        # for its 3x3 kernel.
        depthwise = ghost.cheap_operation[0]
        depthwise.dilation = (dilation, dilation)
        depthwise.padding = (dilation, dilation)
        return nn.Sequential(ghost, nn.LeakyReLU(inplace=True))

    def forward(self, x: T.Tensor) -> T.Tensor:
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        pyramid = T.cat([branch(x) for branch in branches], dim=1)  # (B, 4*out_ch, H, W)
        return self.fuse(pyramid)


In [15]:
import torch as T
import torch.nn as nn
import torch.nn.functional as F

class CoordAttention(nn.Module):
    """
    Coordinate attention built from GhostModuleV2 layers.

    The input is average-pooled along the width and along the height, the two
    strips are processed jointly by a shared layer, then split again and turned
    into one sigmoid gate per axis. Both gates multiply the input.
    Input:  (B, C, H, W)
    Output: (B, C, H, W)
    """

    def __init__(self, channels: int, reduction: int = 4):
        super().__init__()
        # Never squeeze below 8 intermediate channels.
        squeezed_ch = max(8, channels // reduction)

        # Shared transform over the concatenated strips.
        self.conv1 = GhostModuleV2(
            in_channels=channels, out_channels=squeezed_ch,
            kernel_size=1, use_attn=False)
        self.act = nn.LeakyReLU(inplace=True)

        # One projection back to C channels per axis.
        self.conv_h = GhostModuleV2(
            in_channels=squeezed_ch, out_channels=channels,
            kernel_size=1, use_attn=False)
        self.conv_w = GhostModuleV2(
            in_channels=squeezed_ch, out_channels=channels,
            kernel_size=1, use_attn=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: T.Tensor) -> T.Tensor:
        B, C, H, W = x.shape

        # Directional pooling; the width strip is transposed so both strips are (.., L, 1).
        strip_h = F.adaptive_avg_pool2d(x, (H, 1))                       # (B, C, H, 1)
        strip_w = F.adaptive_avg_pool2d(x, (1, W)).permute(0, 1, 3, 2)   # (B, C, W, 1)

        # Shared transform on the stacked strips.
        shared = self.act(self.conv1(T.cat([strip_h, strip_w], dim=2)))  # (B, mid, H+W, 1)

        # Separate the strips again and build one gate per axis.
        feat_h, feat_w = T.split(shared, [H, W], dim=2)
        gate_h = self.sigmoid(self.conv_h(feat_h))                       # (B, C, H, 1)
        gate_w = self.sigmoid(self.conv_w(feat_w.permute(0, 1, 3, 2)))   # (B, C, 1, W)

        return x * gate_h * gate_w


In [16]:
class DySample(nn.Module):
    """
    Dynamic upsampling (x ``scale``) via learned point sampling.

    The input is bilinearly upsampled, a 1x1 conv predicts a per-output-pixel
    (x, y) offset in output-pixel units, and the upsampled map is re-sampled at
    the regular grid shifted by those offsets.

    Args:
        in_channels:        channels of the input (the output has the same count).
        scale:              integer upsampling factor.
        use_feat_transform: apply a 1x1 feature transform before upsampling.
        use_ghost:          use GhostModuleV2 (True) or a plain Conv2d (False)
                            for that transform.
    """
    def __init__(self,
                 in_channels: int,
                 scale: int = 2,
                 use_feat_transform: bool = False,
                 use_ghost: bool = True):
        super().__init__()
        self.scale = scale
        self.use_feat_transform = use_feat_transform

        # Optional feature transform layer
        if self.use_feat_transform:
            if use_ghost:
                self.transform = GhostModuleV2(
                    in_channels, in_channels,
                    kernel_size=1, use_attn=False
                )
            else:
                self.transform = nn.Conv2d(
                    in_channels, in_channels,
                    kernel_size=1, stride=1, padding=0, bias=False
                )
        else:
            self.transform = None

        # Offset prediction: 2 coordinates for each of the scale*scale sub-pixels.
        self.offset_conv = nn.Conv2d(
            in_channels,
            2 * scale * scale,
            kernel_size=1,
            bias=True
        )

    def forward(self, x: T.Tensor) -> T.Tensor:
        """(B, C, H, W) -> (B, C, H*scale, W*scale)."""
        B, _, H, W = x.shape

        if self.transform is not None:
            x = self.transform(x)

        # Bilinear upsample; this is the map that gets re-sampled below.
        H2, W2 = H * self.scale, W * self.scale
        x_interp = F.interpolate(
            x, size=(H2, W2),
            mode='bilinear',
            align_corners=False
        )

        # Predict offsets at low resolution and pixel-shuffle them to (B, 2, H2, W2).
        offsets = self.offset_conv(x)
        offsets = offsets.view(B, 2, self.scale, self.scale, H, W)
        offsets = offsets.permute(0, 1, 4, 2, 5, 3)
        offsets = offsets.reshape(B, 2, H2, W2)

        # Regular sampling grid in normalised [-1, 1] coordinates, last dim = (x, y).
        device = x.device
        ys = T.linspace(-1, 1, H2, device=device)
        xs = T.linspace(-1, 1, W2, device=device)
        grid_y, grid_x = T.meshgrid(ys, xs, indexing='ij')
        base_grid = T.stack((grid_x, grid_y), dim=-1).unsqueeze(0).expand(B, -1, -1, -1)

        # Convert pixel offsets to normalised units (channel 0 = x, channel 1 = y).
        norm_factor = T.tensor(
            [2.0 / max(W2 - 1, 1), 2.0 / max(H2 - 1, 1)],
            device=device
        ).view(1, 2, 1, 1)
        offsets_norm = offsets * norm_factor

        # grid_sample requires the grid and the input to share a dtype. The grid
        # above is built in the default dtype, so cast it to the feature dtype;
        # otherwise half-precision models (or a non-float32 default dtype) fail.
        sampling_grid = base_grid + offsets_norm.permute(0, 2, 3, 1)
        sampling_grid = sampling_grid.to(x_interp.dtype)

        x_up = F.grid_sample(
            x_interp, sampling_grid,
            mode='bilinear',
            padding_mode='border',
            align_corners=True
        )

        return x_up


In [17]:
class ASPPNeXtDecoder(nn.Module):
    """Three-stage decoder.

    Each stage resizes the running feature map to its skip feature, concatenates
    both, fuses them with a 1x1 GhostModuleV2, applies GhostASPPFELAN and
    CoordAttention, and upsamples x2 with DySample.

    Args:
        feature_channels:   channels of the tensor passed as ``x`` to ``forward``.
        skip_channels:      channels of ``skips[0..2]``, in the order they are consumed.
        decoder_channels:   output channels of the three stages.
        coord_reduction:    reduction factor forwarded to CoordAttention.
        use_feat_transform: forwarded to DySample.
    """

    def __init__(self,
                 feature_channels: int,
                 skip_channels: list,
                 decoder_channels: list,
                 coord_reduction: int = 4,
                 use_feat_transform: bool = False):
        super().__init__()
        assert len(skip_channels) == 3 and len(decoder_channels) == 3

        self.stages = nn.ModuleList()
        running_ch = feature_channels

        for stage_idx in range(3):
            stage_out = decoder_channels[stage_idx]
            fused_in = running_ch + skip_channels[stage_idx]

            self.stages.append(nn.ModuleDict({
                'fuse_skip': GhostModuleV2(fused_in, stage_out, kernel_size=1, use_attn=False),
                'aspp': GhostASPPFELAN(in_channels=stage_out, out_channels=stage_out),
                'coord': CoordAttention(channels=stage_out, reduction=coord_reduction),
                'upsample': DySample(in_channels=stage_out, scale=2,
                                     use_feat_transform=use_feat_transform),
            }))

            # This stage's output feeds the next stage.
            running_ch = stage_out

    def forward(self, x, skips):
        """
        x:     feature map with ``feature_channels`` channels.
        skips: sequence whose i-th entry is used by stage i (the original code
               comments describe the expected order as E4, E3, E2).
        """
        for stage, skip in zip(self.stages, skips):
            # Match the skip's spatial size before concatenating.
            x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
            x = stage['fuse_skip'](T.cat([x, skip], dim=1))
            x = stage['aspp'](x)
            x = stage['coord'](x)
            x = stage['upsample'](x)
        return x


In [18]:
class ASPPNeXtOutputLayer(nn.Module):
    """
    Output head: x4 upsampling, optional refinement convs, then a 1x1 classifier.

    Supported upsampling strategies:
    - 'transpose': learnable ConvTranspose2d (kernel 4, stride 4)
    - 'bilinear':  fixed bilinear interpolation
    - 'dysample':  DySample with scale 4
    """
    def __init__(self,
                 in_channels: int,
                 num_classes: int,
                 upsampling_method: str = 'transpose',
                 use_ghost_conv: bool = False,
                 refinement_layers: int = 1):
        """
        Args:
          in_channels:        channels from final decoder stage
          num_classes:        number of segmentation classes
          upsampling_method:  'transpose', 'bilinear', 'dysample'
          use_ghost_conv:     whether to use GhostModuleV2 for the refinement
                              and classifier convs
          refinement_layers:  number of refinement convs after upsampling
        Raises:
          ValueError: for an unknown ``upsampling_method``.
        """
        super().__init__()
        self.upsampling_method = upsampling_method
        self.num_classes = num_classes

        if upsampling_method == 'transpose':
            # Learnable 4x upsampling via transpose convolution
            self.upsample = nn.ConvTranspose2d(
                in_channels=in_channels,
                out_channels=in_channels,
                kernel_size=4,
                stride=4,
                padding=0,
                bias=False
            )

        elif upsampling_method == 'bilinear':
            # A parameter-free nn.Upsample instead of a lambda: it is a registered
            # submodule, so the layer stays picklable, and it adds nothing to the
            # state_dict.
            self.upsample = nn.Upsample(
                scale_factor=4,
                mode='bilinear',
                align_corners=False
            )

        elif upsampling_method == 'dysample':
            # Content-aware DySample upsampling, 4x in one step
            self.upsample = DySample(
                in_channels=in_channels,
                scale=4,
                use_feat_transform=True
            )
        else:
            raise ValueError(f"Unknown upsampling method: {upsampling_method}")

        # Optional refinement layers: conv -> BN -> ReLU, repeated.
        refinement = []
        for _ in range(refinement_layers):
            if use_ghost_conv:
                refinement.append(GhostModuleV2(
                    in_channels, in_channels,
                    kernel_size=3, use_attn=False
                ))
            else:
                refinement.append(nn.Conv2d(
                    in_channels, in_channels,
                    kernel_size=3, padding=1, bias=False
                ))
            refinement.append(nn.BatchNorm2d(in_channels))
            refinement.append(nn.ReLU(inplace=True))

        self.refinement = nn.Sequential(*refinement)

        # Final classification layer
        if use_ghost_conv:
            self.classifier = GhostModuleV2(
                in_channels, num_classes,
                kernel_size=1, use_attn=False
            )
        else:
            self.classifier = nn.Conv2d(
                in_channels, num_classes,
                kernel_size=1, bias=True
            )

    def forward(self, x: T.Tensor, target_size: tuple = None):
        """
        Args:
          x: (B, C, h, w) decoder output
          target_size: (H, W) output size (tuple or list); if None the output is
                       4x the input size
        Returns:
          logits: (B, num_classes, H, W)
        """
        # Normalise so list/torch.Size targets compare equal to tensor shapes below.
        if target_size is not None:
            target_size = tuple(target_size)

        # Bilinear mode can go straight to the requested size.
        if self.upsampling_method == 'bilinear' and target_size is not None:
            x = F.interpolate(x, size=target_size,
                              mode='bilinear', align_corners=False)
        else:
            x = self.upsample(x)

        x = self.refinement(x)
        logits = self.classifier(x)

        # Learned upsamplers always produce 4x; resize only if that differs from the target.
        if target_size is not None and tuple(logits.shape[-2:]) != target_size:
            logits = F.interpolate(logits, size=target_size,
                                   mode='bilinear', align_corners=False)

        return logits

# Example usage:
def create_asppnext_with_separate_output(feature_channels=512,
                                       skip_channels=None,
                                       decoder_channels=None,
                                       num_classes=21,
                                       upsampling_method='transpose',
                                       use_ghost_conv=False,
                                       refinement_layers=2):
    """
    Factory function to create decoder + output layer.

    Args:
      feature_channels: channel count of the feature map fed to the decoder
      skip_channels: channel counts of the encoder skip connections;
                     defaults to [128, 256, 512]
      decoder_channels: channel counts of the decoder stages;
                        defaults to [256, 128, 64]
      num_classes: number of output classes of the classification layer
      upsampling_method: forwarded to ASPPNeXtOutputLayer
      use_ghost_conv: forwarded to ASPPNeXtOutputLayer (set to True to
                      experiment with GhostModuleV2)
      refinement_layers: forwarded to ASPPNeXtOutputLayer
    Returns:
      (decoder, output_layer) tuple
    """
    # Build fresh lists on every call: list literals in the signature would be
    # shared between calls and could be mutated by the modules receiving them.
    if skip_channels is None:
        # NOTE(review): ASPPNeXtModel builds its default skip list deepest-first
        # ([512, 256, 128] for base_dim=64) while this default is
        # shallowest-first — confirm which order ASPPNeXtDecoder expects.
        skip_channels = [128, 256, 512]
    if decoder_channels is None:
        decoder_channels = [256, 128, 64]

    decoder = ASPPNeXtDecoder(
        feature_channels=feature_channels,
        skip_channels=skip_channels,
        decoder_channels=decoder_channels
    )

    output_layer = ASPPNeXtOutputLayer(
        in_channels=decoder_channels[-1],  # Final decoder stage channels
        num_classes=num_classes,
        upsampling_method=upsampling_method,
        use_ghost_conv=use_ghost_conv,
        refinement_layers=refinement_layers
    )

    return decoder, output_layer

# Complete forward pass example:
# decoder, output_layer = create_asppnext_with_separate_output()
# 
# # From your pipeline:
# decoder_features = decoder(post_daaf_output, encoder_skips)  # (B, 64, 96, 128), i.e. 1/4 of (384, 512)
# final_logits = output_layer(decoder_features, target_size=(384, 512))  # (B, 21, 384, 512)


In [19]:
class ASPPNeXtModel(nn.Module):
    """
    Two-stream (RGB + depth) segmentation network.

    Each modality has its own ASPPNeXtEncoder. The deepest feature map of each
    stream is passed through a PreDAAF block, the two are fused by a DAAFBlock,
    enhanced by a PostDAAF block, decoded with skip connections taken from the
    RGB encoder, and finally turned into per-pixel logits by the output layer.
    """

    def __init__(self,
                 in_ch_rgb: int = 3,
                 in_ch_depth: int = 1,
                 base_dim: int = 64,
                 num_heads: tuple = (4, 8, 16, 32),
                 window_sizes: tuple = (8, 4, 2, 1),
                 mlp_ratio: int = 4,
                 skip_channels: list = None,
                 decoder_channels: list = None,
                 num_classes: int = 21,
                 coord_reduction: int = 4,
                 output_upsample: str = 'transpose',
                 use_feat_transform: bool = False):
        super().__init__()

        # --- Encoders -------------------------------------------------------
        # The two streams differ only in their number of input channels.
        encoder_cfg = dict(
            base_dim=base_dim,
            num_heads=num_heads,
            window_sizes=window_sizes,
            mlp_ratio=mlp_ratio,
        )
        self.rgb_encoder = ASPPNeXtEncoder(in_ch=in_ch_rgb, **encoder_cfg)
        self.depth_encoder = ASPPNeXtEncoder(in_ch=in_ch_depth, **encoder_cfg)

        # Channel width of the last encoder stage (8 x base_dim, 512 by default)
        deepest_ch = base_dim * 2 ** 3

        # --- Fusion stack ---------------------------------------------------
        # Pre/Post DAAF blocks all share the same channel/reduction/ratio setup.
        ghost_cfg = dict(channels=deepest_ch, reduction=4, ratio=2)
        self.rgb_predaaf = PreDAAFGhostV2(**ghost_cfg)
        self.depth_predaaf = PreDAAFGhostV2(**ghost_cfg)
        self.daaf = DAAFBlock(channels=deepest_ch, num_heads=num_heads[-1])
        self.post_daaf = PostDAAFGhostV2(**ghost_cfg)

        # --- Decoder --------------------------------------------------------
        if skip_channels is None:
            # Deepest stage first: 8x, 4x, 2x base_dim -> [512, 256, 128]
            skip_channels = [base_dim * 2 ** exp for exp in (3, 2, 1)]
        if decoder_channels is None:
            # 4x, 2x, 1x base_dim -> [256, 128, 64]
            decoder_channels = [base_dim * 2 ** exp for exp in (2, 1, 0)]

        self.decoder = ASPPNeXtDecoder(
            feature_channels=deepest_ch,
            skip_channels=skip_channels,
            decoder_channels=decoder_channels,
            coord_reduction=coord_reduction,
            use_feat_transform=use_feat_transform,
        )

        # --- Output head ----------------------------------------------------
        self.output = ASPPNeXtOutputLayer(
            in_channels=decoder_channels[-1],
            num_classes=num_classes,
            upsampling_method=output_upsample,
            use_ghost_conv=False,
            refinement_layers=2,
        )

    def forward(self, rgb: T.Tensor, depth: T.Tensor) -> T.Tensor:
        """
        Run the whole pipeline and return logits at the RGB input resolution.

        Args:
          rgb: RGB input batch; its trailing two dimensions define the size
               the logits are resized to
          depth: depth input batch for the second encoder
        Returns:
          logits produced by the output layer
        """
        # Encode both modalities; only the list of stage features is used.
        _, rgb_pyramid = self.rgb_encoder(rgb)
        _, depth_pyramid = self.depth_encoder(depth)

        # Refine the deepest map of each stream, fuse them, then enhance.
        fused = self.daaf(
            self.rgb_predaaf(rgb_pyramid[-1]),
            self.depth_predaaf(depth_pyramid[-1]),
        )
        enhanced = self.post_daaf(fused)

        # Skips come from the RGB stream only: its last three stages, deepest first.
        skips = rgb_pyramid[-1:-4:-1]
        decoded = self.decoder(enhanced, skips)

        # Classify and restore the original input resolution.
        return self.output(decoded, target_size=rgb.shape[2:])


In [20]:
class ASPPNeXtDataset(Dataset):
    """
    Dataset of (RGB image, depth map, mask) triplets stored in three folders.

    RGB files are the ``*.jpg`` entries of ``image_dir`` whose name starts with
    an integer followed by ``_``; they are ordered by that integer. The depth
    and mask file names are derived from the RGB file name. Every sample is
    zero-padded to a square whose side is a multiple of ``window_size``.
    """

    def __init__(self, image_dir, depth_dir, mask_dir, window_size=8):
        """
        Args:
          image_dir: folder with the RGB ``.jpg`` files
          depth_dir: folder with the depth ``.png`` files
          mask_dir: folder with the mask ``.png`` files
          window_size: padded side length is rounded up to a multiple of this
        Raises:
          ValueError: no ``.jpg`` file found in ``image_dir``
          FileNotFoundError: an image, depth or mask file is missing
        """
        self.window_size = window_size

        # Numeric sort on the leading index so that e.g. "10_..." follows "9_..."
        self.image_files = sorted(
            [f for f in os.listdir(image_dir) if f.endswith('.jpg')],
            key=lambda x: int(x.split('_')[0])
        )

        if len(self.image_files) == 0:
            raise ValueError(f"No JPG files found in {image_dir}")

        self.depth_files = [f.replace('.jpg', '.png').replace('image', 'image_depth') for f in self.image_files]
        self.mask_files = [f.replace('.jpg', '.png') for f in self.image_files]

        self.image_paths = [os.path.join(image_dir, f) for f in self.image_files]
        self.depth_paths = [os.path.join(depth_dir, f) for f in self.depth_files]
        self.mask_paths = [os.path.join(mask_dir, f) for f in self.mask_files]

        # Fail early on incomplete triplets instead of in the middle of an epoch
        for p in self.image_paths + self.depth_paths + self.mask_paths:
            if not os.path.exists(p):
                raise FileNotFoundError(f"File {p} does not exist")

    def pad_to_square_multiple(self, arr, multiple=8):
        """
        Zero-pad ``arr`` (H, W) or (H, W, C) to a centred square.

        The side length is max(H, W) rounded up to a multiple of ``multiple``.
        When the padding is odd the extra row/column goes to the bottom/right.
        """
        h, w = arr.shape[:2]
        max_dim = max(h, w)
        final_dim = ((max_dim + multiple - 1) // multiple) * multiple
        pad_h = final_dim - h
        pad_w = final_dim - w

        pad_top = pad_h // 2
        pad_bottom = pad_h - pad_top
        pad_left = pad_w // 2
        pad_right = pad_w - pad_left

        if arr.ndim == 3:
            return np.pad(arr, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), mode='constant')
        else:
            return np.pad(arr, ((pad_top, pad_bottom), (pad_left, pad_right)), mode='constant')

    @staticmethod
    def _read_image(path, *flags):
        """Read ``path`` with cv2.imread and raise OSError if decoding fails."""
        # cv2.imread signals failure by returning None rather than raising,
        # which previously surfaced later as an unrelated-looking error.
        data = cv2.imread(path, *flags)
        if data is None:
            raise OSError(f"Could not read image file {path}")
        return data

    def __getitem__(self, idx):
        """
        Returns:
          (image, depth, mask) tensors: float [3, S, S] scaled to [0, 1],
          float [1, S, S] and long [S, S], where S is the padded side length.
        Raises:
          OSError: one of the three files cannot be decoded
        """
        img = cv2.cvtColor(self._read_image(self.image_paths[idx]), cv2.COLOR_BGR2RGB)
        depth = self._read_image(self.depth_paths[idx], cv2.IMREAD_GRAYSCALE)
        mask = self._read_image(self.mask_paths[idx], cv2.IMREAD_GRAYSCALE)

        img = self.pad_to_square_multiple(img, self.window_size)
        depth = self.pad_to_square_multiple(depth, self.window_size)
        # NOTE(review): the zero padding labels the padded border as class 0 —
        # confirm that this is intended rather than an ignore label.
        mask = self.pad_to_square_multiple(mask, self.window_size)

        # NOTE(review): unlike the RGB image, depth is not rescaled and keeps
        # its raw 8-bit range — confirm this matches how the model was designed.
        return (
            T.from_numpy(img).float().permute(2, 0, 1) / 255.0,  # [3, H, W]
            T.from_numpy(depth).float().unsqueeze(0),            # [1, H, W]
            T.from_numpy(mask).long()                            # [H, W]
        )

    def __len__(self):
        """Number of samples (one per RGB image)."""
        return len(self.image_paths)


# ------------------ Dataloader Setup ------------------

current_dir = os.getcwd()

# Base directories
# The RGB root can be overridden with the ASPPNEXT_IMAGE_DIR environment
# variable so the notebook is not tied to one machine; the original absolute
# path remains the fallback.
base_img = os.environ.get(
    "ASPPNEXT_IMAGE_DIR",
    r"D:\AAU Internship\Code\CWF-788\IMAGE512x384",
)
base_depth = os.path.join(current_dir, "Depth_Features")
base_mask = os.path.join(current_dir, "Weed_Masks")

for dir_path in [base_img, base_depth, base_mask]:
    if not os.path.isdir(dir_path):
        raise FileNotFoundError(f"Directory {dir_path} does not exist")

# Dataset splits: phase -> (image sub-folder, depth sub-folder, mask sub-folder)
splits = {
    "train": ("train_new", "train_new", "train"),
    "val": ("validation_new", "validation_new", "val"),
    "test": ("test_new", "test_new", "test"),
}

# Match window size to model's max attention window size.
# The model cell instantiates ASPPNeXtModel with window_sizes=(16, 8, 4, 2),
# so the padding multiple must follow that tuple, not the (8, 4, 2, 1)
# constructor default that was used here before.
model_window_sizes = (16, 8, 4, 2)
window_size = max(model_window_sizes)  # -> 16

dataloaders = {}
batch_size = 4

for phase, (img_s, depth_s, mask_s) in splits.items():
    img_dir = os.path.join(base_img, img_s)
    depth_dir = os.path.join(base_depth, depth_s)
    mask_dir = os.path.join(base_mask, mask_s)

    for d in [img_dir, depth_dir, mask_dir]:
        if not os.path.isdir(d):
            raise FileNotFoundError(f"Directory {d} does not exist")

    try:
        ds = ASPPNeXtDataset(
            image_dir=img_dir,
            depth_dir=depth_dir,
            mask_dir=mask_dir,
            window_size=window_size
        )
        print(f"{phase} dataset loaded with {len(ds)} samples")

        dataloaders[phase] = DataLoader(
            ds,
            batch_size=batch_size,
            shuffle=(phase == "train"),
            # Pinned host memory only helps (and only works without a warning)
            # when batches are going to be copied to a CUDA device.
            pin_memory=T.cuda.is_available(),
        )
    except Exception as e:
        # Report which split failed, then let the error propagate.
        print(f"Error loading {phase} dataset: {str(e)}")
        raise


train dataset loaded with 1600 samples
val dataset loaded with 352 samples
test dataset loaded with 1200 samples


In [21]:
def compute_metrics(preds, targets, num_classes=2):
    """
    Computes all metrics for a batch from its confusion matrix.

    Args:
        preds: logits/scores with the class axis at dim 1; the predicted label
               is the argmax over that axis
        targets: integer label map; labels outside [0, num_classes) are ignored
        num_classes: number of classes (class 1 is reported as the "weed" class)
    Returns: dict of {
        'miou', 'weed_iou', 'mPA', 'accuracy',
        'precision', 'recall', 'f1', 'fnr'
    }
    All values are Python floats; a ratio with an empty denominator is 0.0.
    """
    eps = 1e-10  # keeps every ratio finite when a denominator is empty

    pred_labels = preds.argmax(1)
    valid = (targets >= 0) & (targets < num_classes)

    # hist[i, j] = number of valid pixels with ground truth i predicted as j
    hist = T.bincount(
        num_classes * targets[valid] + pred_labels[valid],
        minlength=num_classes**2
    ).reshape(num_classes, num_classes).float()

    tp = T.diag(hist)
    fp = hist.sum(0) - tp  # column sums: predicted as the class, but wrong
    fn = hist.sum(1) - tp  # row sums: belong to the class, but missed

    metrics = {
        'miou': T.mean(tp / (tp + fp + fn + eps)).item(),
        'weed_iou': (tp[1] / (tp[1] + fp[1] + fn[1] + eps)).item(),
        # Mean pixel accuracy: per class, correctly labelled pixels over all
        # ground-truth pixels of that class (tp + fn). Dividing by tp + fp, as
        # before, yields mean precision instead.
        'mPA': T.mean(tp / (tp + fn + eps)).item(),
        # eps avoids NaN when the batch contains no valid pixel at all
        'accuracy': (tp.sum() / (hist.sum() + eps)).item(),
        'precision': (tp[1] / (tp[1] + fp[1] + eps)).item(),
        'recall': (tp[1] / (tp[1] + fn[1] + eps)).item(),
        'f1': (2 * tp[1] / (2 * tp[1] + fp[1] + fn[1] + eps)).item(),
        'fnr': (fn[1] / (fn[1] + tp[1] + eps)).item()
    }
    return metrics


In [22]:
# Build the RGB-D network for 2-class segmentation, then move it to `device`.
model = ASPPNeXtModel(
    in_ch_rgb=3, in_ch_depth=1,
    base_dim=64,
    num_heads=(4, 8, 16, 32),
    window_sizes=(16, 8, 4, 2),
    mlp_ratio=4,
    num_classes=2,
    output_upsample='dysample',
    # per the original note: switches between GhostModule and Conv1x1
    use_feat_transform=True,
)
model = model.to(device)


In [23]:
# AdamW over every model parameter.
#   lr           : base learning rate
#   betas        : decay rates of the running averages of the gradient (0.9)
#                  and of the squared gradient (0.999)
#   weight_decay : weight decay coefficient
optimizer = T.optim.AdamW(
    model.parameters(), lr=6e-5, betas=(0.9, 0.999), weight_decay=0.01
)

In [24]:
class FocalTverskyLoss(nn.Module):
    """
    Focal Tversky loss on softmax probabilities.

    Per class c: TI_c = (TP + smooth) / (TP + alpha*FP + beta*FN + smooth),
    and the loss is mean_c (1 - TI_c) ** gamma.
    """

    def __init__(self, alpha=0.7, beta=0.3, gamma=0.75, smooth=1e-6):
        """
        Args:
          alpha: weight of the false positives in the Tversky index
          beta: weight of the false negatives in the Tversky index
          gamma: focal exponent applied to (1 - Tversky index)
          smooth: additive smoothing of numerator and denominator
        """
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.smooth = smooth
        self.eps = 1e-8  # Numerical stability: lower bound of the focal base

    def update_hyperparams(self, epoch):
        """Dynamically adjust hyperparameters every 5 epochs"""
        # NOTE(review): the schedule restarts from fixed values and therefore
        # overwrites whatever was passed to the constructor; in particular
        # gamma becomes 0.5 for epochs 0-4 regardless of the constructor's
        # gamma — confirm that this is intended.
        steps = epoch // 5
        self.alpha = max(0.4, 0.7 - 0.03*steps)  # Linearly decrease
        self.beta = min(0.6, 0.3 + 0.03*steps)   # Complementary to alpha
        self.gamma = min(1.5, 0.5 + 0.1*steps)   # Gradually increase focus

    def forward(self, preds, targets):
        """
        Args:
          preds: logits with the class axis at dim 1, (B, C, H, W)
          targets: integer label map (B, H, W); labels are clamped into
                   [0, C - 1] before one-hot encoding
        Returns:
          scalar loss tensor
        """
        # The class count follows the logits instead of being fixed to 2
        num_classes = preds.shape[1]

        # Convert targets to one-hot, laid out like preds
        targets_one_hot = F.one_hot(
            targets.clamp(0, num_classes - 1), num_classes=num_classes
        ).permute(0, 3, 1, 2).float()

        # Softmax probabilities
        probs = F.softmax(preds, dim=1)

        # Soft confusion counts per class (summed over batch and space)
        TP = (probs * targets_one_hot).sum((0,2,3))  # True Positives
        FP = (probs * (1-targets_one_hot)).sum((0,2,3))  # False Positives
        FN = ((1-probs) * targets_one_hot).sum((0,2,3))  # False Negatives

        # Tversky index per class
        tversky = (TP + self.smooth) / (TP + self.alpha*FP + self.beta*FN + self.smooth)

        # Focal Tversky loss. For gamma < 1 the derivative of x**gamma is
        # infinite at x == 0, which turns into NaN gradients once a class is
        # predicted perfectly; clamping the base keeps the backward pass finite.
        focal_base = (1 - tversky).clamp(min=self.eps)
        return T.mean(T.pow(focal_base, self.gamma))

# Loss instance used by the training loop, placed on the training device.
loss_fn = FocalTverskyLoss(alpha=0.7, beta=0.3, gamma=0.75)
loss_fn = loss_fn.to(device)

In [25]:
def export_to_onnx(model, filename, input_size=(384, 512)):
    """
    Export ``model`` to ONNX with a dynamic batch axis.

    Args:
      model: module taking (rgb, depth) batches, 3 and 1 channels respectively
      filename: path of the ``.onnx`` file to write
      input_size: (H, W) of the dummy inputs used for tracing; the spatial
                  size is baked into the exported graph (only axis 0 is
                  declared dynamic)

    NOTE(review): ASPPNeXtDataset pads every sample to a square whose side is a
    multiple of the window size, so the default (384, 512) may differ from the
    tensors used for training and testing — confirm and pass the real size.
    """
    height, width = input_size

    # Trace on the device the model actually lives on; fall back to the
    # notebook-wide `device` for a parameter-less model.
    try:
        model_device = next(model.parameters()).device
    except StopIteration:
        model_device = device

    dummy_rgb   = T.randn(1, 3, height, width).to(model_device)
    dummy_depth = T.randn(1, 1, height, width).to(model_device)

    T.onnx.export(
        model,
        (dummy_rgb, dummy_depth),
        filename,
        export_params=True,
        opset_version=13,  # For TensorRT compatibility
        do_constant_folding=True,
        input_names=['rgb_input', 'depth_input'],
        output_names=['output'],
        dynamic_axes={
            'rgb_input': {0: 'batch_size'},
            'depth_input': {0: 'batch_size'},
            'output': {0: 'batch_size'}
        }
    )
    print(f"Exported {filename} successfully")

In [26]:
def build_engine_from_onnx(onnx_filepath: str,
                           engine_filepath: str,
                           max_workspace_size: int = 1 << 30,
                           min_batch: int = 1,
                           opt_batch: int = 4,
                           max_batch: int = 4) -> None:
    """
    Parse an ONNX file, build an FP16-enabled TensorRT engine and serialize it.

    Args:
      onnx_filepath: ONNX model to read
      engine_filepath: destination of the serialized engine
      max_workspace_size: builder workspace limit in bytes
      min_batch, opt_batch, max_batch: batch range of the optimization profile
          registered for inputs whose leading dimension is dynamic (-1).
          The defaults cover the batch size of 4 used by this notebook.
    Raises:
      ValueError: inconsistent batch range
      RuntimeError: ONNX parsing failed, an input has a dynamic non-batch
                    dimension, or the engine could not be built
    """
    if not (1 <= min_batch <= opt_batch <= max_batch):
        raise ValueError("Expected 1 <= min_batch <= opt_batch <= max_batch")

    TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    parser = trt.OnnxParser(network, TRT_LOGGER)

    with open(onnx_filepath, 'rb') as model_f:
        if not parser.parse(model_f.read()):
            for i in range(parser.num_errors):
                print(parser.get_error(i))
            raise RuntimeError("Failed to parse ONNX model")

    config = builder.create_builder_config()
    # `max_workspace_size` is deprecated in recent TensorRT releases in favour
    # of the memory-pool API; use whichever the installed version offers.
    if hasattr(config, "set_memory_pool_limit"):
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, max_workspace_size)
    else:
        config.max_workspace_size = max_workspace_size
    config.set_flag(trt.BuilderFlag.FP16)

    # The ONNX export declares axis 0 of every input as dynamic. TensorRT
    # refuses to build a network with dynamic inputs unless an optimization
    # profile gives min/opt/max shapes for them.
    profile = builder.create_optimization_profile()
    has_dynamic_input = False
    for idx in range(network.num_inputs):
        tensor = network.get_input(idx)
        dims = tuple(tensor.shape)
        if all(d >= 0 for d in dims):
            continue
        if any(d < 0 for d in dims[1:]):
            raise RuntimeError(
                f"Input {tensor.name} has a dynamic non-batch dimension {dims}; "
                "only a dynamic batch axis is supported"
            )
        has_dynamic_input = True
        tail = dims[1:]
        profile.set_shape(tensor.name,
                          (min_batch,) + tail,
                          (opt_batch,) + tail,
                          (max_batch,) + tail)
    if has_dynamic_input:
        config.add_optimization_profile(profile)

    # `build_engine` is deprecated in recent TensorRT releases; prefer building
    # the serialized plan directly.
    if hasattr(builder, "build_serialized_network"):
        serialized = builder.build_serialized_network(network, config)
    else:
        engine = builder.build_engine(network, config)
        serialized = engine.serialize() if engine is not None else None
    if serialized is None:
        raise RuntimeError("Failed to build the TensorRT engine")

    with open(engine_filepath, 'wb') as f:
        f.write(serialized)

    print(f"✅ Serialized TRT engine to {engine_filepath}")


In [27]:
def train_model(num_epochs=50):
    """
    Train the global ``model`` and keep the best checkpoints.

    Each epoch updates the loss hyper-parameters, runs one pass over
    ``dataloaders['train']`` (with gradient-norm clipping at 1.0) and evaluates
    mPA and accuracy on ``dataloaders['val']`` as the mean of the per-batch
    values. Whenever validation mPA or accuracy improves, the weights are
    saved to ``MODELS_DIR`` and re-exported to ONNX and a TensorRT engine.

    Args:
      num_epochs: number of epochs to run
    Raises:
      ValueError: the train or val dataloader yields no batch
    """
    train_loader = dataloaders['train']
    val_loader = dataloaders['val']
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train and val dataloaders must each yield at least one batch")

    best_mPA = 0.0
    best_accuracy = 0.0

    def _save_best(tag):
        """Write best_<tag>_model.pth/.onnx/.engine for the current weights."""
        stem = os.path.join(MODELS_DIR, f'best_{tag}_model')
        T.save(model.state_dict(), stem + '.pth')
        export_to_onnx(model, stem + '.onnx')
        build_engine_from_onnx(stem + '.onnx', stem + '.engine')

    for epoch in tqdm(range(1, num_epochs+1), desc="Epochs", leave=True):
        loss_fn.update_hyperparams(epoch)

        # -------- Training Phase --------
        model.train()
        running_loss = 0.0
        train_pbar = tqdm(train_loader,
                          desc=f"  Train (E{epoch})",
                          leave=False,
                          unit="batch")
        for img, depth, mask in train_pbar:
            optimizer.zero_grad()
            outputs = model(img.to(device), depth.to(device))
            loss = loss_fn(outputs, mask.to(device))
            loss.backward()
            T.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            batch_loss = loss.item()
            running_loss += batch_loss
            train_pbar.set_postfix(loss=batch_loss)

        # Report the mean over the epoch; the last batch alone is too noisy
        # to be presented as the epoch's training loss.
        epoch_loss = running_loss / len(train_loader)

        T.cuda.empty_cache()

        # -------- Validation Phase --------
        model.eval()
        val_metrics = {'mPA': 0.0, 'accuracy': 0.0}
        val_pbar = tqdm(val_loader,
                        desc=f"  Val   (E{epoch})",
                        leave=False,
                        unit="batch")
        with T.no_grad():
            for img, depth, mask in val_pbar:
                outputs = model(img.to(device), depth.to(device))
                m = compute_metrics(outputs, mask.to(device))
                val_metrics['mPA']      += m['mPA']
                val_metrics['accuracy'] += m['accuracy']

        num_val = len(val_loader)
        val_metrics = {k: v / num_val for k, v in val_metrics.items()}

        print(f"Epoch {epoch:02d} — "
              f"Train Loss: {epoch_loss:.4f} | "
              f"Val mPA: {val_metrics['mPA']:.4f} | "
              f"Val Acc: {val_metrics['accuracy']:.4f}")

        # -------- Checkpoint & Engine Build --------
        # The model is still in eval mode here, which is what gets exported.
        if val_metrics['mPA'] > best_mPA:
            best_mPA = val_metrics['mPA']
            _save_best('mPA')

        if val_metrics['accuracy'] > best_accuracy:
            best_accuracy = val_metrics['accuracy']
            _save_best('accuracy')

        T.cuda.empty_cache()


In [28]:
# Kick off training with the default number of epochs (50) when this cell
# runs as the main module.
if __name__ == "__main__":
    train_model()


Epochs:   0%|          | 0/50 [00:00<?, ?it/s]

  Train (E1):   0%|          | 0/400 [00:00<?, ?batch/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 1.56 GiB. GPU 0 has a total capacity of 4.00 GiB of which 0 bytes is free. Of the allocated memory 7.00 GiB is allocated by PyTorch, and 1.59 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
class HostDeviceMem:
    """Groups a host buffer, its device counterpart and the binding name."""

    def __init__(self, host_mem, device_mem, name: str):
        # host / device: the two buffers of one binding; name: its identifier
        self.host, self.device, self.name = host_mem, device_mem, name

class TRTInfer:
    """
    Wrapper for TensorRT inference. Allocates proper CUDA buffers
    and manages bindings by name.

    This variant feeds a single input array per call. A later cell defines
    ``TRTInfer`` again with an ``infer`` that takes one array per input
    binding; when the cells run in order, that later definition replaces
    this one.
    """
    def __init__(self, engine_path: str, batch_size: int = 1):
        """
        Deserialize the engine stored at ``engine_path``, create an execution
        context and allocate one host/device buffer pair per binding.

        ``batch_size`` scales every buffer and is prepended to the shape of
        the array returned by ``infer``.
        """
        self.batch_size = batch_size
        self.logger = trt.Logger(trt.Logger.ERROR)
        self.runtime = trt.Runtime(self.logger)

        with open(engine_path, 'rb') as f:
            engine_data = f.read()
        self.engine = self.runtime.deserialize_cuda_engine(engine_data)
        self.context = self.engine.create_execution_context()

        self.inputs, self.outputs, self.bindings, self.stream = self._allocate_buffers()

    def _allocate_buffers(self):
        """
        Allocate a host array (``cuda.pagelocked_empty``) and a device buffer
        of the same byte size for every binding the engine iterates over.

        Returns (inputs, outputs, bindings, stream): two lists of
        HostDeviceMem, the device addresses as ints in iteration order, and a
        new CUDA stream.
        """
        inputs, outputs, bindings = [], [], []
        stream = cuda.Stream()

        for binding in self.engine:
            shape = tuple(self.engine.get_binding_shape(binding))
            dtype = trt.nptype(self.engine.get_binding_dtype(binding))
            # NOTE(review): if the binding shape already contains the batch
            # axis, multiplying by batch_size over-allocates, and a dynamic
            # (-1) axis would make the volume negative — confirm against the
            # engines actually built.
            size = self.batch_size * int(trt.volume(shape))

            host_mem = cuda.pagelocked_empty(size, dtype)
            device_mem = cuda.mem_alloc(host_mem.nbytes)
            bindings.append(int(device_mem))

            if self.engine.binding_is_input(binding):
                inputs.append(HostDeviceMem(host_mem, device_mem, binding))
            else:
                outputs.append(HostDeviceMem(host_mem, device_mem, binding))

        return inputs, outputs, bindings, stream

    def infer(self, input_batch: np.ndarray) -> np.ndarray:
        """
        Run inference.
        input_batch: np.ndarray of shape (batch_size, C, H, W), dtype matching engine.
        Returns the first (or only) output as a numpy array.

        Only the first input binding is filled; any further input keeps
        whatever its buffer currently holds.
        """
        # Copy input to host buffer, then to device
        # (np.copyto requires the flattened input to match the buffer size)
        np.copyto(self.inputs[0].host, input_batch.ravel())
        cuda.memcpy_htod_async(self.inputs[0].device,
                               self.inputs[0].host, self.stream)

        # Set all binding addresses
        for idx, binding_ptr in enumerate(self.bindings):
            name = self.engine.get_binding_name(idx)
            self.context.set_tensor_address(name, binding_ptr)

        # Execute asynchronously
        self.context.execute_async_v3(stream_handle=self.stream.handle)

        # Retrieve outputs; the synchronize makes the host buffers safe to read
        for out in self.outputs:
            cuda.memcpy_dtoh_async(out.host, out.device, self.stream)
        self.stream.synchronize()

        # Reshape and return
        # The result is a view of the reused host buffer: batch_size is
        # prepended to the engine's binding shape.
        out_binding = self.outputs[0]
        out_shape = (self.batch_size,) + tuple(
            self.engine.get_binding_shape(out_binding.name)
        )
        return out_binding.host.reshape(out_shape)

In [None]:
class TRTInfer:
    """
    Wrapper for TensorRT inference. Allocates proper CUDA buffers
    and manages I/O tensors by name.

    The engine's own tensor shapes are authoritative: a dynamic leading (-1)
    dimension is resolved to ``batch_size``, while a static shape is used as
    is. ``batch_size`` is therefore never multiplied onto a shape that already
    carries a batch axis.
    """
    def __init__(self, engine_path: str, batch_size: int = 1):
        """
        Args:
            engine_path: serialized TensorRT engine file
            batch_size: batch used to resolve a dynamic leading dimension
        Raises:
            RuntimeError: the engine cannot be deserialized or has a dynamic
                          dimension other than the leading one
        """
        self.batch_size = batch_size
        self.logger = trt.Logger(trt.Logger.ERROR)
        self.runtime = trt.Runtime(self.logger)

        with open(engine_path, 'rb') as f:
            engine_data = f.read()
        self.engine = self.runtime.deserialize_cuda_engine(engine_data)
        if self.engine is None:
            raise RuntimeError(f"Failed to deserialize TensorRT engine: {engine_path}")
        self.context = self.engine.create_execution_context()

        # Concrete shape of every I/O tensor, keyed by tensor name
        self.shapes = {}
        self.inputs, self.outputs, self.bindings, self.stream = self._allocate_buffers()

    def _is_input(self, name):
        """True when ``name`` is an input tensor of the engine."""
        return self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT

    def _allocate_buffers(self):
        """
        Resolve all tensor shapes and allocate one host/device pair each.

        Returns (inputs, outputs, bindings, stream): two lists of
        HostDeviceMem, the device addresses as ints in I/O tensor order, and a
        new CUDA stream.
        """
        inputs, outputs, bindings = [], [], []
        stream = cuda.Stream()
        names = [self.engine.get_tensor_name(i)
                 for i in range(self.engine.num_io_tensors)]

        # Pass 1: fix the input shapes on the context. Output shapes can only
        # be queried once every dynamic input shape has been set.
        for name in names:
            if not self._is_input(name):
                continue
            dims = tuple(self.engine.get_tensor_shape(name))
            if any(d < 0 for d in dims[1:]):
                raise RuntimeError(
                    f"Input {name} has a dynamic non-batch dimension {dims}")
            if dims and dims[0] < 0:
                dims = (self.batch_size,) + dims[1:]
                self.context.set_input_shape(name, dims)
            self.shapes[name] = dims

        # Pass 2: allocate buffers of exactly the resolved size.
        for name in names:
            if self._is_input(name):
                shape = self.shapes[name]
            else:
                shape = tuple(self.context.get_tensor_shape(name))
                if any(d < 0 for d in shape):
                    raise RuntimeError(
                        f"Could not resolve the shape of output {name}: {shape}")
                self.shapes[name] = shape

            dtype = trt.nptype(self.engine.get_tensor_dtype(name))
            host_mem = cuda.pagelocked_empty(int(trt.volume(shape)), dtype)
            device_mem = cuda.mem_alloc(host_mem.nbytes)
            bindings.append(int(device_mem))

            if self._is_input(name):
                inputs.append(HostDeviceMem(host_mem, device_mem, name))
            else:
                outputs.append(HostDeviceMem(host_mem, device_mem, name))

        return inputs, outputs, bindings, stream

    def infer(self, input_batches):
        """
        Run inference.

        Args:
            input_batches: tuple or list of np.ndarray, one per input tensor
                           (a single array is accepted for one-input engines).
                           Each array shape should be (n, C, H, W) with n not
                           larger than the engine batch.

        Returns:
            The first (or only) output as a numpy array, restricted to the n
            samples that were passed in. The array is a copy and stays valid
            across later calls.

        Raises:
            ValueError: wrong number of inputs or mismatching array shapes
        """
        if isinstance(input_batches, np.ndarray):
            input_batches = (input_batches,)
        if len(input_batches) != len(self.inputs):
            raise ValueError(
                f"Expected {len(self.inputs)} input arrays, got {len(input_batches)}")

        n = int(input_batches[0].shape[0])

        # Copy each input to host buffer, then to device
        for inp, buf in zip(input_batches, self.inputs):
            expected = self.shapes[buf.name]
            if (inp.shape[0] != n or n > expected[0]
                    or tuple(inp.shape[1:]) != tuple(expected[1:])):
                raise ValueError(
                    f"Input {buf.name}: got shape {tuple(inp.shape)}, "
                    f"engine expects up to {expected}")
            flat = np.ascontiguousarray(inp, dtype=buf.host.dtype).ravel()
            # A smaller batch is padded with zeros; the engine always runs the
            # full batch and the surplus rows are dropped from the result.
            buf.host[:flat.size] = flat
            buf.host[flat.size:] = 0
            cuda.memcpy_htod_async(buf.device, buf.host, self.stream)

        # Set all tensor addresses
        for mem in self.inputs + self.outputs:
            self.context.set_tensor_address(mem.name, int(mem.device))

        # Execute asynchronously
        self.context.execute_async_v3(stream_handle=self.stream.handle)

        # Retrieve outputs
        for out in self.outputs:
            cuda.memcpy_dtoh_async(out.host, out.device, self.stream)
        self.stream.synchronize()

        # Reshape and return first output
        first = self.outputs[0]
        out_shape = self.shapes[first.name]
        result = first.host.reshape(out_shape)
        # assumes the output's leading axis is the batch axis — only slice
        # when it matches the engine batch of the first input
        if out_shape and out_shape[0] == self.shapes[self.inputs[0].name][0]:
            result = result[:n]
        # Copy: the pinned host buffer is overwritten by the next call
        return result.copy()


In [None]:
def test_models(dataloaders, device, batch_size=4, engine_dir=MODELS_DIR):
    """
    Evaluate the two exported TensorRT engines on ``dataloaders['test']``.

    Both ``best_mPA_model.engine`` and ``best_accuracy_model.engine`` from
    ``engine_dir`` are run on every test batch. Metrics are computed per batch
    with ``compute_metrics``, averaged over the batches and printed.

    Args:
        dataloaders: dict with a 'test' DataLoader yielding (img, depth, mask)
        device: torch device on which the metrics are computed
        batch_size: batch size handed to the TensorRT wrappers
        engine_dir: folder containing the two ``.engine`` files
    Raises:
        ValueError: the test dataloader yields no batch
    """
    test_loader = dataloaders['test']
    num_batches = len(test_loader)
    if num_batches == 0:
        # Previously this only failed at the very end with ZeroDivisionError
        raise ValueError("Test dataloader is empty; nothing to evaluate")

    metric_names = ('mPA', 'accuracy', 'miou', 'weed_iou', 'precision', 'recall')

    # One inference wrapper and one set of accumulators per engine,
    # keyed by the label used in the report.
    runs = {}
    for label, engine_file in (("mPA-Optimized", 'best_mPA_model.engine'),
                               ("Accuracy-Optimized", 'best_accuracy_model.engine')):
        runs[label] = {
            'engine': TRTInfer(os.path.join(engine_dir, engine_file),
                               batch_size=batch_size),
            'totals': dict.fromkeys(metric_names, 0.0),
        }

    with T.no_grad():
        for img, depth, mask in test_loader:
            # Convert to NumPy float32
            in_rgb   = img.cpu().numpy().astype(np.float32)
            in_depth = depth.cpu().numpy().astype(np.float32)
            targets = mask.to(device)

            for run in runs.values():
                # Perform inference with both inputs -> (B, num_classes, H, W)
                logits = T.from_numpy(run['engine'].infer((in_rgb, in_depth))).to(device)

                # Compute and accumulate batch metrics
                batch_metrics = compute_metrics(logits, targets)
                for k in metric_names:
                    run['totals'][k] += batch_metrics[k]

    # Average over batches and print results
    for label, run in runs.items():
        print(f"\nTensorRT {label} Model Test Results:")
        for k in metric_names:
            print(f"{k:>10}: {run['totals'][k] / num_batches:.4f}")


In [None]:
# Evaluate both exported TensorRT engines on the test split and print the averaged metrics.
test_models(dataloaders, device)