In [5]:
from PIL import Image

# Open one sample RGB frame from the training split to check its resolution.
# NOTE(review): "C:.\..." looks like a Windows drive-relative path (resolved
# against the current directory on drive C:) — confirm this is intended
# rather than an absolute "C:\..." path.
img = Image.open(r"C:.\Nutrition5K\Nutrition5K\train\color\dish_0000\rgb.png")
print(img.size)  # prints (W, H)

(640, 480)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBNReLU(nn.Module):
    """Elementary building block: 2-D convolution, then batch norm, then ReLU.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        k: convolution kernel size.
        s: convolution stride.
        p: convolution zero-padding.
    """

    def __init__(self, in_ch, out_ch, k=3, s=1, p=1):
        super().__init__()
        # The convolution carries no bias: the batch norm that follows has its
        # own learnable shift.
        self.conv = nn.Conv2d(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=k,
            stride=s,
            padding=p,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(num_features=out_ch, eps=1e-5, momentum=0.1)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv -> BN -> ReLU to ``x`` and return the activated map."""
        convolved = self.conv(x)
        normalized = self.bn(convolved)
        return self.act(normalized)

class TinyCNN(nn.Module):
    """
    Ultra-small CNN for RGB-D regression.
    - Early fusion: set in_ch=4 to accept RGB(3)+Depth(1)
    - By default uses standard GAP (no mask).
    - You can enable masked-GAP later via use_mask=True and passing a mask tensor to forward().
    """
    def __init__(self, in_ch=4, widths=(24, 48, 96, 128), dropout=0.3, out_dim=1, use_mask: bool=False):
        super().__init__()
        self.use_mask = use_mask
        w1, w2, w3, w4 = widths

        # Stem -> H/2
        self.stem = nn.Sequential(
            ConvBNReLU(in_ch, w1, k=3, s=2, p=1),
            ConvBNReLU(w1, w1, k=3, s=1, p=1),
        )
        # Stage1 -> H/4
        self.stage1 = nn.Sequential(
            ConvBNReLU(w1, w2, k=3, s=2, p=1),
            ConvBNReLU(w2, w2, k=3, s=1, p=1),
        )
        # Stage2 -> H/8
        self.stage2 = nn.Sequential(
            ConvBNReLU(w2, w3, k=3, s=2, p=1),
            ConvBNReLU(w3, w3, k=3, s=1, p=1),
        )
        # Stage3 -> H/16
        self.stage3 = nn.Sequential(
            ConvBNReLU(w3, w4, k=3, s=2, p=1),
            ConvBNReLU(w4, w4, k=3, s=1, p=1),
        )

        # Head: GAP/Masked-GAP -> Dropout -> MLP(128) -> Linear(out_dim)
        self.dropout = nn.Dropout(dropout)
        self.fc1     = nn.Linear(w4, 128)
        self.fc2     = nn.Linear(128, out_dim)  # regression head (no activation)

    def forward_features(self, x):
        x = self.stem(x)
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)   # [B, C, H/16, W/16]
        return x

    @staticmethod
    def masked_gap(feat: torch.Tensor, mask: torch.Tensor | None):
        """
        Optional masked global average pooling.
        If mask is None, this falls back to standard GAP outside this method.
        """
        if mask is None:
            # Should not be called when mask=None; kept for completeness.
            return F.adaptive_avg_pool2d(feat, 1).flatten(1)
        m = F.interpolate(mask, size=feat.shape[-2:], mode="nearest").clamp(0, 1)  # [B,1,Hf,Wf]
        num = (feat * m).sum(dim=(2,3))           # [B,C]
        den = (m.sum(dim=(2,3)) + 1e-6)           # [B,1]
        return num / den

    def forward(self, x, mask: torch.Tensor | None = None):
        """
        x:    [B, in_ch, H, W]  (e.g., in_ch=4 for RGB-D)
        mask: [B, 1, H, W] or None. Ignored unless self.use_mask is True.
        """
        f = self.forward_features(x)
        if self.use_mask:
            # Use masked-GAP ONLY when explicitly enabled and mask is provided.
            pooled = self.masked_gap(f, mask)
        else:
            # Standard GAP (default baseline, no mask involved).
            pooled = F.adaptive_avg_pool2d(f, 1).flatten(1)

        z = self.dropout(pooled)
        z = F.relu(self.fc1(z), inplace=True)
        out = self.fc2(z)  # [B, out_dim]
        return out

# ------- tiny self-test -------
if __name__ == "__main__":
    # Dataset frames are 640x480 as (W, H), so in [B, C, H, W] layout the
    # height is 480 and the width is 640.
    B, H, W = 8, 480, 640

    # Shape-only smoke test: no gradients are needed, so skip autograd
    # bookkeeping to keep memory use low.
    with torch.no_grad():
        # Default: no mask path (use_mask=False)
        model = TinyCNN(in_ch=4, widths=(24,48,96,128), dropout=0.3, out_dim=1, use_mask=False)
        x = torch.randn(B, 4, H, W)
        y = model(x)
        print("Output (no mask) shape:", y.shape)

        # Later: enable masked-GAP easily
        model_mask = TinyCNN(in_ch=4, widths=(24,48,96,128), dropout=0.3, out_dim=1, use_mask=True)
        m = torch.randint(0, 2, (B, 1, H, W)).float()
        y2 = model_mask(x, mask=m)
        print("Output (with mask) shape:", y2.shape)
