In [None]:
import kagglehub
from google.colab import drive


In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import math
import pandas as pd


In [None]:
def setup_and_download():
    """Mount Google Drive, log in to Kaggle and download the competition data.

    Returns:
        The value returned by ``kagglehub.competition_download`` for the
        'kaggle-competition-dl-f-2025' competition (the local download path).
    """
    # Mounting Drive stays best-effort (a failure does not abort the setup),
    # but the failure is now reported instead of being silently swallowed.
    try:
        drive.mount('/content/drive')
    except Exception as exc:
        print(f"Warning: could not mount Google Drive ({exc}); continuing without it.")

    # Login to Kaggle
    kagglehub.login()

    # Download competition data
    kaggle_competition_dl_f_2025_path = kagglehub.competition_download('kaggle-competition-dl-f-2025')

    print('Data source import complete.')
    print(kaggle_competition_dl_f_2025_path)

    # List every file found in the kagglehub competitions cache
    for dirname, _, filenames in os.walk('/root/.cache/kagglehub/competitions'):
        for filename in filenames:
            print(os.path.join(dirname, filename))

    return kaggle_competition_dl_f_2025_path


# Keep the returned path in a named constant so the cell does not echo it again.
COMPETITION_DATA_PATH = setup_and_download()


In [None]:
class MultiSpectralDataset(Dataset):
    """Index-based view over an image array ``X`` (and optional mask array ``Y``)
    that applies optional per-channel normalization."""

    def __init__(self, X, Y, indices, augment=False, mean=None, std=None, compute_stats=False):
        """
        X, Y          : mmap arrays or numpy arrays (Y may be None -> test mode)
        indices       : list of sample indices
        augment       : stored on the instance; __getitem__ in this class does not use it
        mean, std     : optional precomputed per-channel normalization stats
        compute_stats : if True -> compute mean/std from X[indices]
                        (any mean/std passed in are then ignored)

        Raises ValueError when compute_stats is True but there is nothing to
        compute statistics from.
        """
        self.X = X
        self.Y = Y
        self.indices = indices
        self.augment = augment

        if compute_stats:
            # Statistics should come from the TRAINING split only; other splits
            # should receive them through the ``mean`` / ``std`` arguments.
            self.mean, self.std = self._channel_stats(X, indices)
        else:
            self.mean = mean       # use externally provided stats
            self.std = std

    @staticmethod
    def _channel_stats(X, indices):
        """Return (mean, std) per channel (axis 1 of X) over X[indices].

        Images are visited one at a time and merged with the exact pairwise
        update for mean and sum of squared deviations, so the result equals
        the population mean/std over all selected pixels. ``1e-6`` is added
        to the std to keep the later division safe.
        """
        n_channels = X.shape[1]
        run_mean = np.zeros(n_channels, dtype=np.float64)
        run_m2 = np.zeros(n_channels, dtype=np.float64)
        count = 0

        for idx in indices:
            # float64 accumulation avoids precision loss on large datasets
            pixels = np.asarray(X[idx], dtype=np.float64).reshape(n_channels, -1)
            n_new = pixels.shape[1]
            if n_new == 0:
                continue

            batch_mean = pixels.mean(axis=1)
            batch_m2 = ((pixels - batch_mean[:, None]) ** 2).sum(axis=1)

            delta = batch_mean - run_mean
            total = count + n_new

            # Merge the running statistics with this image's statistics:
            # the cross term accounts for the shift between the two means.
            run_mean += delta * (n_new / total)
            run_m2 += batch_m2 + delta ** 2 * (count * n_new / total)
            count = total

        if count == 0:
            raise ValueError("compute_stats=True requires at least one non-empty sample")

        mean = run_mean.astype(np.float32)
        std = np.sqrt(run_m2 / count).astype(np.float32) + 1e-6
        return mean, std

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        """Return the normalized image, or (image, mask) when masks are available."""
        idx = self.indices[i]

        # Load image (copies out of the possibly memory-mapped array)
        img = torch.tensor(self.X[idx], dtype=torch.float32)

        # Apply normalization if available
        if self.mean is not None and self.std is not None:
            mean = torch.as_tensor(self.mean, dtype=torch.float32)[:, None, None]
            std = torch.as_tensor(self.std, dtype=torch.float32)[:, None, None]
            img = (img - mean) / std

        # TEST MODE
        if self.Y is None:
            return img

        # TRAIN / VAL MODE: the mask gets a leading channel axis
        mask = torch.tensor(self.Y[idx], dtype=torch.float32).unsqueeze(0)

        return img, mask
    

# -----------------------------
# Basic building blocks
# -----------------------------

class ConvBnAct(nn.Module):
    """Conv2d followed by BatchNorm2d and an activation.

    ``act`` selects the activation: 'leaky' -> LeakyReLU(0.2), 'relu' -> ReLU,
    anything else -> Identity (no activation).
    """

    def __init__(self, in_ch, out_ch, kernel_size=3, padding=1, stride=1, dilation=1, act='leaky', bias=False):
        super().__init__()

        conv_kwargs = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        self.conv = nn.Conv2d(in_ch, out_ch, **conv_kwargs)
        self.bn = nn.BatchNorm2d(out_ch)

        # Default to a pass-through and override for the two known names.
        activation = nn.Identity()
        if act == 'leaky':
            activation = nn.LeakyReLU(0.2, inplace=True)
        elif act == 'relu':
            activation = nn.ReLU(inplace=True)
        self.act = activation

    def forward(self, x):
        features = self.conv(x)
        features = self.bn(features)
        return self.act(features)


class ResidualConv(nn.Module):
    """Two stacked ConvBnAct layers plus a shortcut.

    The shortcut is the input itself, or a 1x1 conv + BatchNorm projection of
    it when the input and output channel counts differ.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = ConvBnAct(in_ch, out_ch)
        self.conv2 = ConvBnAct(out_ch, out_ch)
        self.need_proj = (in_ch != out_ch)
        if not self.need_proj:
            return
        # Projection is only created when the shortcut must change channels.
        self.proj = nn.Conv2d(in_ch, out_ch, 1, bias=False)
        self.bn_proj = nn.BatchNorm2d(out_ch)

    def forward(self, x):
        main_path = self.conv2(self.conv1(x))
        shortcut = self.bn_proj(self.proj(x)) if self.need_proj else x
        return main_path + shortcut


class SEBlock(nn.Module):
    """Channel gating: global average pool -> 1x1 conv -> ReLU -> 1x1 conv ->
    sigmoid, with the result multiplied onto the input.

    The bottleneck width is ``channels // reduction``, but never below 1.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        bottleneck = channels // reduction
        if bottleneck < 1:
            bottleneck = 1

        self.pool = nn.AdaptiveAvgPool2d(1)
        gate_layers = [
            nn.Conv2d(channels, bottleneck, 1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(bottleneck, channels, 1, bias=True),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*gate_layers)

    def forward(self, x):
        pooled = self.pool(x)
        gate = self.fc(pooled)
        return x * gate


# -----------------------------
# MiniInception with residual
# -----------------------------
class MiniInceptionRes(nn.Module):
    """
    Three-stage mini-inception block with a shortcut connection.

    Every stage runs a plain 3x3 conv and a dilated (dilation=2) 3x3 conv on
    the same input, each producing out_ch // 2 channels, and concatenates the
    two results. The shortcut is projected (1x1 conv + BN) when in_ch != out_ch,
    and a LeakyReLU(0.2) is applied after the addition.
    """
    def __init__(self, in_ch, out_ch):
        super().__init__()
        assert out_ch % 2 == 0, "out_channels must be divisible by 2"
        branch_ch = out_ch // 2

        plain = dict(padding=1, dilation=1, act='leaky')
        dilated = dict(padding=2, dilation=2, act='leaky')

        # stage 1 consumes the block input
        self.c1l = ConvBnAct(in_ch, branch_ch, **plain)
        self.c1r = ConvBnAct(in_ch, branch_ch, **dilated)

        # stages 2 and 3 consume the concatenated (out_ch) output of the previous stage
        self.c2l = ConvBnAct(out_ch, branch_ch, **plain)
        self.c2r = ConvBnAct(out_ch, branch_ch, **dilated)

        self.c3l = ConvBnAct(out_ch, branch_ch, **plain)
        self.c3r = ConvBnAct(out_ch, branch_ch, **dilated)

        self.need_proj = (in_ch != out_ch)
        if self.need_proj:
            self.proj = nn.Conv2d(in_ch, out_ch, 1, bias=False)
            self.bn_proj = nn.BatchNorm2d(out_ch)
        self.act = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        stages = (
            (self.c1l, self.c1r),
            (self.c2l, self.c2r),
            (self.c3l, self.c3r),
        )
        y = x
        for left, right in stages:
            y = torch.cat((left(y), right(y)), dim=1)

        shortcut = self.bn_proj(self.proj(x)) if self.need_proj else x
        return self.act(y + shortcut)


# -----------------------------
# Corrected MFNet (channel-safe)
# -----------------------------
class MFNet(nn.Module):
    """
    Two-branch encoder / single decoder segmentation network.

    The first 3 input channels are encoded by the "RGB" branch and the
    remaining ``in_ch - 3`` channels (when ``in_ch > 3``) by the "INF" branch.
    Features of both branches are concatenated at the deepest level (with an
    optional SE block) and at skip levels 4, 3 and 2 while decoding.

    forward() returns the main logits, or the tuple (main, ds3, ds4) when
    ``deep_supervision`` is True; ds3/ds4 are auxiliary logits resized to the
    main logits' spatial size.

    Upsampling targets the spatial size of the tensor it is concatenated with
    (or of the input, for the last step), so height/width do not have to be
    multiples of 16. For multiples of 16 this equals a fixed 2x upsampling.
    """
    def __init__(self, in_ch=16, n_class=1, use_se=True, deep_supervision=True):
        super().__init__()
        self.n_class = n_class
        self.use_se = use_se
        self.deep_supervision = deep_supervision

        # level indices: 1..5 (1 shallow -> 5 deepest)
        rgb_ch = [32, 96, 160, 256, 320]    # channels at levels 1..5 for RGB
        inf_ch = [32, 64, 96, 128, 160]     # channels at levels 1..5 for INF

        # NOTE: attribute names and creation order below define the
        # state_dict keys and parameter order; keep them stable so existing
        # checkpoints and optimizer states remain loadable.

        # RGB branch
        self.conv1_rgb   = ResidualConv(3, rgb_ch[0])
        self.conv2_1_rgb = ResidualConv(rgb_ch[0], rgb_ch[1])
        self.conv2_2_rgb = ResidualConv(rgb_ch[1], rgb_ch[1])
        self.conv3_1_rgb = ResidualConv(rgb_ch[1], rgb_ch[2])
        self.conv3_2_rgb = ResidualConv(rgb_ch[2], rgb_ch[2])
        self.conv4_rgb   = MiniInceptionRes(rgb_ch[2], rgb_ch[3])
        self.conv5_rgb   = MiniInceptionRes(rgb_ch[3], rgb_ch[4])

        # INF branch (only when there are channels beyond the first 3)
        self.inf_in_ch = max(0, in_ch - 3)
        has_inf = self.inf_in_ch > 0
        if has_inf:
            self.conv1_inf   = ResidualConv(self.inf_in_ch, inf_ch[0])
            self.conv2_1_inf = ResidualConv(inf_ch[0], inf_ch[1])
            self.conv2_2_inf = ResidualConv(inf_ch[1], inf_ch[1])
            self.conv3_1_inf = ResidualConv(inf_ch[1], inf_ch[2])
            self.conv3_2_inf = ResidualConv(inf_ch[2], inf_ch[2])
            self.conv4_inf   = MiniInceptionRes(inf_ch[2], inf_ch[3])
            self.conv5_inf   = MiniInceptionRes(inf_ch[3], inf_ch[4])
        else:
            # placeholder
            self.conv1_inf = None

        self.rgb_ch = rgb_ch
        self.inf_ch = inf_ch

        def fused(level):
            """Channel count of RGB (+ INF when present) at 0-based `level`."""
            return rgb_ch[level] + (inf_ch[level] if has_inf else 0)

        # deepest fused channels (320 + 160 = 480 with the INF branch)
        self.fused_deep_ch = fused(4)

        # skip channel counts
        skip4_ch = fused(3)  # 256 + 128 = 384
        skip3_ch = fused(2)  # 160 + 96  = 256
        skip2_ch = fused(1)  # 96  + 64  = 160
        skip1_ch = fused(0)  # 32  + 32  = 64

        # Each decoder stage takes (upsampled previous output + skip) and
        # projects to the channel count of the next shallower skip level.
        dec4_in_ch = self.fused_deep_ch + skip4_ch
        dec4_out_ch = skip3_ch

        dec3_in_ch = dec4_out_ch + skip3_ch
        dec3_out_ch = skip2_ch

        dec2_in_ch = dec3_out_ch + skip2_ch
        dec2_out_ch = skip1_ch

        # The last stage has no skip connection and keeps its width.
        dec1_in_ch = dec2_out_ch
        dec1_out_ch = dec1_in_ch

        self.decode4_proj = ResidualConv(dec4_in_ch, dec4_out_ch)
        self.decode3_proj = ResidualConv(dec3_in_ch, dec3_out_ch)
        self.decode2_proj = ResidualConv(dec2_in_ch, dec2_out_ch)
        self.decode1_proj = ResidualConv(dec1_in_ch, dec1_out_ch)

        # Heads
        self.head = nn.Conv2d(dec1_out_ch, self.n_class, kernel_size=1)
        if self.deep_supervision:
            self.head_ds3 = nn.Conv2d(dec3_out_ch, self.n_class, kernel_size=1)
            self.head_ds4 = nn.Conv2d(dec4_out_ch, self.n_class, kernel_size=1)

        # Optional SE after fusion
        if self.use_se:
            self.se = SEBlock(self.fused_deep_ch, reduction=8)

        self._init_weights()

    @staticmethod
    def _encode(x, conv1, conv2_1, conv2_2, conv3_1, conv3_2, conv4, conv5):
        """Run one encoder branch; return (skip2, skip3, skip4, deepest) features."""
        x = F.max_pool2d(conv1(x), 2)                    # level1 -> pooled
        p2 = conv2_2(conv2_1(x))                         # skip level2
        p3 = conv3_2(conv3_1(F.max_pool2d(p2, 2)))       # skip level3
        p4 = conv4(F.max_pool2d(p3, 2))                  # skip level4
        deep = conv5(F.max_pool2d(p4, 2))                # deepest
        return p2, p3, p4, deep

    @staticmethod
    def _merge(rgb_feat, inf_feat):
        """Concatenate RGB and INF features on the channel axis (RGB only if no INF)."""
        if inf_feat is None:
            return rgb_feat
        return torch.cat([rgb_feat, inf_feat], dim=1)

    @staticmethod
    def _upsample_to(x, size):
        """Nearest-neighbour upsample `x` to the spatial `size` (H, W)."""
        return F.interpolate(x, size=tuple(size), mode='nearest')

    def forward(self, x):
        # x: B, C, H, W
        assert x.shape[1] >= 3, "input must contain at least 3 channels for RGB"
        full_size = x.shape[2:]

        rgb_p2, rgb_p3, rgb_p4, rgb_deep = self._encode(
            x[:, :3, :, :],
            self.conv1_rgb, self.conv2_1_rgb, self.conv2_2_rgb,
            self.conv3_1_rgb, self.conv3_2_rgb, self.conv4_rgb, self.conv5_rgb,
        )

        if self.inf_in_ch > 0:
            inf_p2, inf_p3, inf_p4, inf_deep = self._encode(
                x[:, 3:, :, :],
                self.conv1_inf, self.conv2_1_inf, self.conv2_2_inf,
                self.conv3_1_inf, self.conv3_2_inf, self.conv4_inf, self.conv5_inf,
            )
        else:
            inf_p2 = inf_p3 = inf_p4 = inf_deep = None

        # fusion at deepest level
        fused = self._merge(rgb_deep, inf_deep)
        if self.use_se:
            fused = self.se(fused)

        # decode level 4
        skip4 = self._merge(rgb_p4, inf_p4)
        x = self._upsample_to(fused, skip4.shape[2:])
        x = self.decode4_proj(torch.cat([x, skip4], dim=1))
        ds4 = self.head_ds4(x) if self.deep_supervision else None

        # decode level 3
        skip3 = self._merge(rgb_p3, inf_p3)
        x = self._upsample_to(x, skip3.shape[2:])
        x = self.decode3_proj(torch.cat([x, skip3], dim=1))
        ds3 = self.head_ds3(x) if self.deep_supervision else None

        # decode level 2
        skip2 = self._merge(rgb_p2, inf_p2)
        x = self._upsample_to(x, skip2.shape[2:])
        x = self.decode2_proj(torch.cat([x, skip2], dim=1))

        # final upsample to the input resolution (no skip at this level)
        x = self._upsample_to(x, full_size)
        x = self.decode1_proj(x)
        main_logits = self.head(x)

        if not self.deep_supervision:
            return main_logits

        # resize the auxiliary logits to the main logits' spatial size
        out_size = main_logits.shape[2:]
        ds3_up = F.interpolate(ds3, size=out_size, mode='bilinear', align_corners=False)
        ds4_up = F.interpolate(ds4, size=out_size, mode='bilinear', align_corners=False)
        return main_logits, ds3_up, ds4_up

    def _init_weights(self):
        """Kaiming-normal init for every Conv2d, ones/zeros for BatchNorm2d.

        The gain uses kaiming_normal_'s default negative slope (no ``a`` is
        passed), not the 0.2 slope of the LeakyReLU layers.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity='leaky_relu')
                if getattr(m, 'bias', None) is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d,)):
                if getattr(m, 'weight', None) is not None:
                    nn.init.ones_(m.weight)
                if getattr(m, 'bias', None) is not None:
                    nn.init.zeros_(m.bias)

    
    
##############################
# IoU Metric and Dice Loss
##############################
# Shared binary cross-entropy criterion operating on raw logits.
bce_loss = nn.BCEWithLogitsLoss()

def dice_loss(pred, target, smooth=1.0):
    """
    Soft Dice loss for binary segmentation, averaged over the batch.

    Args:
        pred (torch.Tensor): Predicted logits of shape (B, 1, H, W)
        target (torch.Tensor): Ground truth mask of shape (B, 1, H, W)
        smooth (float): Added to numerator and denominator to avoid division by zero
    Returns:
        torch.Tensor: Scalar, 1 - mean per-sample Dice coefficient
    """
    batch = pred.size(0)

    # Logits -> probabilities, then one row of pixels per sample
    probs = torch.sigmoid(pred).reshape(batch, -1)
    truth = target.reshape(target.size(0), -1)

    overlap = torch.sum(probs * truth, dim=1)
    total = torch.sum(probs, dim=1) + torch.sum(truth, dim=1)
    per_sample_dice = (2. * overlap + smooth) / (total + smooth)

    return 1 - per_sample_dice.mean()


def combined_loss(outputs, target, alpha=0.7):
    """
    Weighted BCE + Dice loss: alpha * BCE + (1 - alpha) * Dice.

    outputs can be either:
     - single tensor (deep_supervision=False)
     - tuple: (main, ds3, ds4), combined as main + 0.5 * ds3 + 0.25 * ds4
    """

    def bce_dice(logits):
        # Same BCE/Dice mix for every head
        return alpha * bce_loss(logits, target) + (1 - alpha) * dice_loss(logits, target)

    # Single output
    if not isinstance(outputs, tuple):
        return bce_dice(outputs)

    main, ds3, ds4 = outputs
    loss_main = bce_dice(main)
    loss_ds3 = bce_dice(ds3)
    loss_ds4 = bce_dice(ds4)

    # Weighted deep supervision
    return loss_main + 0.5 * loss_ds3 + 0.25 * loss_ds4


def iou_score(pred, gt, n_classes=2, eps=1e-10, threshold=0.1, from_logits=True):
    """
    Mean IoU over classes, ignoring classes absent from both prediction and target.

    Args:
        pred: model output of shape (B, C, H, W), or a tuple/list whose first
            element is that tensor (deep-supervision output).
            C == 1 -> binary: scores are thresholded; C > 1 -> argmax over C.
        gt: ground-truth class ids, shape (B, 1, H, W) or (B, H, W).
        n_classes: number of class ids (0 .. n_classes-1) to score.
        eps: added to intersection and union.
        threshold: probability threshold for the binary case.
        from_logits: when True (default) the binary scores are logits and a
            sigmoid is applied before thresholding, matching how predictions
            are binarized at inference time. Pass False for inputs that are
            already probabilities.
    Returns:
        float: mean IoU over the scored classes (NaN if no class is present).
    """
    # Deep supervision: only the main output is scored
    if isinstance(pred, (tuple, list)):
        pred = pred[0]

    with torch.no_grad():
        if pred.shape[1] == 1:
            scores = torch.sigmoid(pred) if from_logits else pred
            labels = (scores > threshold).long()
        else:
            labels = torch.argmax(pred, dim=1, keepdim=True)

        labels = labels.squeeze(1)
        gt = gt.squeeze(1)

        iou_per_class = []
        for cls in range(n_classes):
            pred_cls = (labels == cls)
            gt_cls = (gt == cls)

            intersection = (pred_cls & gt_cls).sum().float()
            union = (pred_cls | gt_cls).sum().float()

            # A class missing from both masks carries no information
            if union == 0:
                continue
            iou_per_class.append(((intersection + eps) / (union + eps)).item())

        if not iou_per_class:
            return float('nan')
        return np.mean(iou_per_class)


In [None]:
# Device configuration
print("CUDA available:", torch.cuda.is_available())
print("GPU name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "None")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CHECKPOINT_PATH = "/content/mfnet(big)_checkpoint.pth"
CHECKPOINT_PATH_drive = "/content/drive/MyDrive/mfnet(big)_checkpoint.pth"
BEST_MODEL_PATH = "/content/best_model.pth"
BEST_DICE_MODEL_PATH = "/content/best_dice_model.pth"

# Seed the RNGs used for weight init and shuffling so runs are repeatable
# (same value as the train/val split below).
SEED = 303
torch.manual_seed(SEED)
np.random.seed(SEED)

####################
# Training Setup
####################
model = MFNet(in_ch=16, n_class=1, use_se=True, deep_supervision=True).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)
# mode='max' because the scheduler is stepped on validation Dice;
# min_lr=1e-11 is low enough to act as "no practical floor".
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.65, patience=3, min_lr=1e-11)

# Resume: model weights, optimizer state and epoch all come from the SAME
# checkpoint file (the one written every epoch), so they describe one
# consistent training step. start_epoch is always defined.
start_epoch = 0
if os.path.exists(CHECKPOINT_PATH):
    try:
        resume_state = torch.load(CHECKPOINT_PATH, map_location=device)

        model.load_state_dict(resume_state['model_state_dict'])
        optimizer.load_state_dict(resume_state['optimizer_state_dict'])

        start_epoch = resume_state['epoch'] + 1
        print(f"Loaded checkpoint from epoch {resume_state['epoch']}")
    except Exception as e:
        start_epoch = 0
        print(f"Failed to load checkpoint \n({e}). \nStarting from scratch.")
else:
    print("No checkpoint found - starting anew.")

#################
# Data Loader
#################
DATA_DIR = "/root/.cache/kagglehub/competitions/kaggle-competition-dl-f-2025"
X_train = np.load(os.path.join(DATA_DIR, 'X_train_256.npy'), mmap_mode='r')
Y_train = np.load(os.path.join(DATA_DIR, 'Y_train_256.npy'), mmap_mode='r')

print("Train shape:", X_train.shape)
print("Train mask shape:", Y_train.shape, "\n")

indices = np.arange(len(X_train))
train_idx, val_idx = train_test_split(indices, test_size=0.25, random_state=303)
del indices

# Normalization statistics are computed on the training split only and
# reused for validation, so both splits see the same input scaling.
train_ds = MultiSpectralDataset(X_train, Y_train, train_idx, augment=True, compute_stats=True)
val_ds   = MultiSpectralDataset(X_train, Y_train, val_idx, augment=False,
                                mean=train_ds.mean, std=train_ds.std)
del X_train, Y_train
BATCH_SIZE = 24
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)


In [None]:
def train_model(start_epoch=0,EPOCHS=300):
    """Train and validate the module-level ``model`` for epochs start_epoch..EPOCHS-1.

    Uses the module-level ``model``, ``optimizer``, ``scheduler``,
    ``train_loader``, ``val_loader`` and ``device``. After every epoch it
    writes a checkpoint to CHECKPOINT_PATH, and additionally to
    BEST_MODEL_PATH / BEST_DICE_MODEL_PATH when validation IoU / Dice improve
    on the best value seen during this call.
    """

    best_val_iou = 0.0
    best_val_dice = 0.0

    for epoch in range(start_epoch, EPOCHS):

        # ------------------------ TRAIN ------------------------
        model.train()
        total_loss = 0.0
        pbar = tqdm(train_loader, desc=f"Train {epoch+1}/{EPOCHS}")

        # An explicit step counter is used for the running mean because
        # pbar.n is only refreshed when the bar redraws and can lag behind.
        for step, (imgs, masks) in enumerate(pbar, start=1):
            imgs = imgs.to(device)
            masks = masks.to(device)

            optimizer.zero_grad()
            logits = model(imgs)

            loss = combined_loss(logits, masks)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            optimizer.step()

            total_loss += loss.item()
            pbar.set_postfix(loss=total_loss / step)

        avg_train_loss = total_loss / max(len(train_loader), 1)

        # ------------------------ VALIDATION ------------------------
        model.eval()
        val_iou_total = 0.0
        val_iou_batches = 0
        val_dice_total = 0.0
        val_dice_batches = 0

        with torch.no_grad():
            for imgs, masks in val_loader:
                imgs = imgs.to(device)
                masks = masks.float().to(device)

                # The model returns a tuple only with deep supervision;
                # the main output is always the first element.
                outputs = model(imgs)
                logits = outputs[0] if isinstance(outputs, tuple) else outputs

                # iou_score binarizes the prediction internally
                iou = iou_score(logits, masks)

                # Dice coefficient = 1 - Dice loss; stored as a Python float
                val_dice_total += 1.0 - dice_loss(logits, masks).item()
                val_dice_batches += 1

                if not math.isnan(iou):
                    val_iou_total += iou
                    val_iou_batches += 1

        # Each metric is averaged over the batches that contributed to it
        val_iou = (val_iou_total / val_iou_batches) if val_iou_batches > 0 else 0.0
        val_dice = (val_dice_total / val_dice_batches) if val_dice_batches > 0 else 0.0

        # ------------------------ SCHEDULER (on Dice) ------------------------
        scheduler.step(val_dice)

        # ------------------------ PRINT STATUS ------------------------
        print(
            f"Epoch {epoch+1}: "
            f"TrainLoss={avg_train_loss:.4f}, "
            f"Val IoU={val_iou:.4f}, "
            f"Val Dice ={val_dice:.4f}, "
            f"LR={optimizer.param_groups[0]['lr']:.2e}"
        )

        # ------------------------ SAVE CHECKPOINT EVERY EPOCH ------------------------
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }
        torch.save(checkpoint, CHECKPOINT_PATH)

        # ------------------------ BEST IoU SAVE ------------------------
        if val_iou > best_val_iou:
            best_val_iou = val_iou
            torch.save(checkpoint, BEST_MODEL_PATH)
            print(f"New BEST IoU model saved (Val IoU={best_val_iou:.4f})\n")

        # ------------------------ BEST DICE SAVE ------------------------
        if val_dice > best_val_dice:
            best_val_dice = val_dice
            torch.save(checkpoint, BEST_DICE_MODEL_PATH)
            print(f"New BEST Dice model saved (Val Dice={best_val_dice:.4f})\n")

    print(f"Best Validation IoU: {best_val_iou:.4f}")
    print(f"Best Validation Dice: {best_val_dice:.4f}")


# Continue from the epoch restored by the checkpoint-loading cell when it set
# `start_epoch`; otherwise start at 0.
train_model(start_epoch=globals().get("start_epoch", 0), EPOCHS=150)  #for better training, train till 300 is suggested


In [None]:

def run_inference(test_loader, model, device, threshold=0.099, submission_path="submission.csv"):
    """Predict binary masks for every batch of ``test_loader`` and write a submission CSV.

    Args:
        test_loader: iterable yielding image batches (no masks).
        model: network returning logits of shape (B, 1, H, W), either directly
            or as the first element of a tuple (deep supervision).
        device: device the batches are moved to before the forward pass.
        threshold: probability above which a pixel is predicted as 1.
        submission_path: where the CSV is written.

    The CSV has one row per image: ``id`` (running index) and ``pixels``
    (the flattened 0/1 mask joined with commas).
    """
    model.eval()

    predictions = []
    with torch.no_grad():
        for imgs in tqdm(test_loader, desc="Predicting "):
            imgs = imgs.to(device)
            outputs = model(imgs)
            logits = outputs[0] if isinstance(outputs, tuple) else outputs
            probs = torch.sigmoid(logits).cpu().numpy()  # (B,1,H,W)
            preds = (probs > threshold).astype(np.uint8)
            # one flat 0/1 vector per image
            predictions.extend(mask.reshape(-1) for mask in preds)

    #################
    # submission.csv
    #################

    # format required by the competition rules
    submission = pd.DataFrame({
        "id": np.arange(len(predictions)),
        "pixels": [",".join(map(str, p)) for p in predictions]
    })
    submission.to_csv(submission_path, index=False)
    print("Saved", submission_path)

    # GPU memory report / cleanup only makes sense when CUDA is present
    if torch.cuda.is_available():
        print(torch.cuda.memory_allocated())
        print(torch.cuda.memory_reserved())
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

    return


X_test = np.load(os.path.join(DATA_DIR, 'X_test_256.npy'), mmap_mode='r')
print("Test shape:", X_test.shape)
test_indices = np.arange(len(X_test))

# The test images must be normalized with the statistics of the training
# split (the ones the network was trained on), not with their own.
test_ds = MultiSpectralDataset(X_test, Y=None, indices=test_indices, augment=False,
                               mean=train_ds.mean, std=train_ds.std)
del X_test
test_loader = DataLoader(test_ds, batch_size=8, shuffle=False, num_workers=2, pin_memory=True)

# Initialize a fresh model and load the best-IoU weights saved during training
model = MFNet(in_ch=16, n_class=1, use_se=True, deep_supervision=True).to(device)
checkpoint = torch.load(BEST_MODEL_PATH, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
run_inference(test_loader, model, device)
