# 🌋 Vesuvius Challenge: SOTA Multi-Cube Pipeline (0.65+)

### 🚨 VERSION 3.0 - UNIVERSAL FIX
> [!IMPORTANT]
> **If you still see the error**: It means your browser/Kaggle is caching the old code. 
> **PLEASE**: Press `Ctrl + F5` to force-refresh this page, then **Kernel -> Restart & Run All**.

In [None]:
!pip install -q monai tifffile

import os
import glob
import random
import zipfile
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast
from monai.networks.nets import SwinUNETR
from tqdm.auto import tqdm
import tifffile as tiff
from sklearn.model_selection import train_test_split

# Print a version banner when this cell runs. Presumably this lets the user
# confirm the current code (not a cached copy) is executing -- see the note at
# the top of the notebook. The f-string has no placeholders, so the text is
# emitted verbatim.
print(f"DEBUG: Code Version 3.0 Loaded (Positional Args Mode)")

# --- CONFIGURATION ---
class CFG:
    """Static configuration namespace read by the dataset, training and
    submission code in this notebook (attributes are accessed as ``CFG.<name>``;
    the class is never instantiated in the visible code)."""

    # Root of the competition data and the sub-directories derived from it.
    input_dir = '/kaggle/input/vesuvius-challenge-surface-detection'
    test_dir = os.path.join(input_dir, 'test_images')
    train_images = os.path.join(input_dir, 'train_images')
    train_labels = os.path.join(input_dir, 'train_labels')
    
    # Use the GPU when one is visible to PyTorch, otherwise fall back to CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Passed as the first positional argument to SwinUNETR; submit() also uses
    # its first entry as the sliding-window edge length.
    patch_size = (128, 128, 128)
    # Step between consecutive sliding-window origins in submit().
    stride = 64
    # Batch size handed to both DataLoaders in train().
    batch_size = 1 
    # Learning rate for the Adam optimizer in train().
    lr = 2e-4
    # Number of training epochs in train().
    epochs = 15
    
    # NOTE(review): the two flags below are not read anywhere in the visible
    # code -- confirm whether they are used elsewhere or are leftovers.
    use_distance_map = True
    tta = True
    # File name train() saves the best state_dict to and returns.
    best_weights = "best_sota_model.pth"

def seed_everything(seed=42):
    """Seed every random number source used by this notebook.

    Sets ``PYTHONHASHSEED`` in the environment, seeds Python's ``random``,
    NumPy's global generator and PyTorch (CPU and CUDA), and asks cuDNN to
    use deterministic algorithms.

    Args:
        seed: Integer seed applied to all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True


# Seed once, as soon as the cell is executed, using the default seed.
seed_everything()

## 1. Utilities

In [None]:
def detect_umbilicus(volume):
    """Estimate the scroll centre as the centroid of above-threshold voxels.

    Voxels count as foreground when they exceed 0.2 for ``float32`` volumes
    and 0 for every other dtype. The centroid is truncated to integer
    indices. If no voxel is above the threshold, the geometric centre of the
    volume (``size // 2`` per axis) is returned instead.

    The centroid is computed from per-axis foreground counts rather than from
    ``np.argwhere``: materialising one int64 coordinate triple per foreground
    voxel costs far more memory than the volume itself on full-size cubes.

    Args:
        volume: N-dimensional NumPy array (3-D in this notebook).

    Returns:
        Tuple of plain Python ``int`` indices, one per axis of ``volume``.
    """
    threshold = 0.2 if volume.dtype == np.float32 else 0
    mask = volume > threshold
    total = int(np.count_nonzero(mask))
    if total == 0:
        return tuple(s // 2 for s in volume.shape)

    center = []
    for axis, size in enumerate(volume.shape):
        other_axes = tuple(a for a in range(volume.ndim) if a != axis)
        # Number of foreground voxels at each index along this axis.
        per_index = mask.sum(axis=other_axes, dtype=np.int64)
        index_sum = int(np.dot(np.arange(size, dtype=np.int64), per_index))
        # Exact integer sum / count, truncated like ``astype(int)`` did.
        center.append(int(index_sum / total))
    return tuple(center)

def get_radial_dist_map(shape, center):
    """Build a 3-D map of Euclidean distance to ``center``, scaled to [0, 1).

    Args:
        shape: Sequence whose first three entries are the (z, y, x) extents.
        center: Sequence whose first three entries are the (z, y, x) origin.

    Returns:
        ``float32`` array of shape ``shape[:3]``: the distance of every voxel
        from ``center`` divided by the largest distance in the volume (plus a
        small epsilon so a single-voxel volume does not divide by zero).
    """
    # Broadcastable per-axis offsets from the centre (no dense index grids).
    dz = np.arange(shape[0]).reshape(-1, 1, 1) - center[0]
    dy = np.arange(shape[1]).reshape(1, -1, 1) - center[1]
    dx = np.arange(shape[2]).reshape(1, 1, -1) - center[2]

    radius = np.sqrt(dz ** 2 + dy ** 2 + dx ** 2)
    farthest = radius.max()
    normalised = radius / (farthest + 1e-8)
    return normalised.astype(np.float32)

def compute_dice(pred, target):
    """Return the hard Dice score between thresholded logits and a target.

    Args:
        pred: Tensor of raw logits; a voxel is predicted positive when its
            sigmoid is strictly greater than 0.5.
        target: Tensor broadcastable against ``pred`` holding the reference
            mask.

    Returns:
        Scalar tensor ``(2 * |P * T| + eps) / (|P| + |T| + eps)`` with
        ``eps = 1e-6``, so two empty masks score 1.
    """
    binary = torch.sigmoid(pred).gt(0.5).float()
    overlap = torch.sum(binary * target)
    denominator = binary.sum() + target.sum() + 1e-6
    return (2. * overlap + 1e-6) / denominator

## 2. Dataset

In [None]:
class VesuviusCubeDataset(Dataset):
    """Dataset yielding two-channel (intensity, radial distance) volumes.

    Each item loads one TIFF volume, scales it by 1/255 to ``float32``, and
    stacks it with a radial distance map centred on the detected umbilicus.

    * ``is_train=True``: returns a random crop of ``patch_size`` from the
      two-channel image together with the matching crop of the binarised
      label (``label > 0``), shaped ``(2, *patch)`` and ``(1, *patch)``.
    * ``is_train=False``: returns the full two-channel volume and the volume
      id (file name up to its first dot).

    Args:
        img_paths: Sequence of image file paths.
        label_dir: Directory holding ``<vid>.tif`` label volumes.
        is_train: Whether to return random labelled crops (see above).
        patch_size: (depth, height, width) of the training crop. Defaults to
            the previously hard-coded 128-voxel cube.
    """

    def __init__(self, img_paths, label_dir, is_train=True, patch_size=(128, 128, 128)):
        self.img_paths = img_paths
        self.label_dir = label_dir
        self.is_train = is_train
        self.patch_size = tuple(patch_size)

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        path = self.img_paths[idx]
        vid = os.path.basename(path).split('.')[0]
        vol = (tiff.imread(path) / 255.0).astype(np.float32)
        center = detect_umbilicus(vol)
        dist = get_radial_dist_map(vol.shape, center)
        img_patch = np.stack([vol, dist], axis=0)

        if not self.is_train:
            return torch.from_numpy(img_patch), vid

        lab_path = os.path.join(self.label_dir, f"{vid}.tif")
        lab = (tiff.imread(lab_path) > 0).astype(np.float32)

        d, h, w = vol.shape
        pd, ph, pw = self.patch_size
        if d < pd or h < ph or w < pw:
            # random.randint would fail with an opaque "empty range" error.
            raise ValueError(
                f"Volume {vid!r} with shape {vol.shape} is smaller than "
                f"patch_size {self.patch_size}"
            )

        # Same draw order (z, y, x) as before so seeded runs are unchanged.
        z, y, x = [random.randint(0, s - p) for s, p in zip((d, h, w), self.patch_size)]
        window = (slice(z, z + pd), slice(y, y + ph), slice(x, x + pw))
        return (
            torch.from_numpy(img_patch[(slice(None),) + window]),
            torch.from_numpy(lab[window][None, :]),
        )

## 3. Loss & Training

In [None]:
def medial_surface_recall(pred, target):
    """Return ``1 - (fraction of the predicted soft skeleton inside target)``.

    A soft skeleton of ``pred`` is extracted in three orthogonal slice
    orientations as ``relu(x - opening(x))`` with a 3x3 in-plane window, and
    the three results are averaged. The returned loss is one minus the share
    of that skeleton mass that overlaps ``target``.

    The previous implementation computed ``-max_pool(-max_pool(x))``, which
    is a morphological *closing*. Closing never decreases a value, so
    ``relu(x - closing(x))`` was zero everywhere and this term was a constant
    0 with no gradient. The skeleton needs the *opening* (erode, then dilate).

    Args:
        pred: Tensor of shape (B, C, D, H, W) with values in [0, 1].
        target: Tensor broadcastable against ``pred`` holding the reference
            mask.

    Returns:
        Scalar tensor in [0, 1]; 0 when the skeleton lies entirely in target
        (or when the skeleton is empty).
    """
    def get_2d_skel(x, kernel=(1, 3, 3)):
        pad = (0, 1, 1)
        # max_pool3d pads with -inf, so borders do not leak into the min/max.
        eroded = -F.max_pool3d(-x, kernel, 1, pad)
        opened = F.max_pool3d(eroded, kernel, 1, pad)
        return F.relu(x - opened)

    # Swap the axis of interest into the depth slot, skeletonise, swap back.
    skel_z = get_2d_skel(pred, (1, 3, 3))
    skel_y = get_2d_skel(pred.transpose(2, 3), (1, 3, 3)).transpose(2, 3)
    skel_x = get_2d_skel(pred.transpose(2, 4), (1, 3, 3)).transpose(2, 4)
    combined_skel = (skel_z + skel_y + skel_x) / 3.0
    recall = (torch.sum(combined_skel * target) + 1e-6) / (torch.sum(combined_skel) + 1e-6)
    return 1.0 - recall

class VesuviusSotaLoss(nn.Module):
    """Weighted blend of BCE-with-logits and the ``medial_surface_recall`` term.

    Args:
        w_skel: Weight of the ``medial_surface_recall`` term; the BCE term is
            weighted by ``1 - w_skel``.
    """

    def __init__(self, w_skel=0.5):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.w_skel = w_skel

    def forward(self, pred, target):
        """Compute the combined loss.

        Args:
            pred: Raw logits from the model.
            target: Reference mask with the same shape as ``pred``.

        Returns:
            Scalar tensor ``(1 - w_skel) * bce + w_skel * skeleton_term``.
        """
        # The skeleton term works on probabilities, BCE on the raw logits.
        probabilities = torch.sigmoid(pred)
        skeleton_term = medial_surface_recall(probabilities, target)
        bce_term = self.bce(pred, target)
        return (1.0 - self.w_skel) * bce_term + self.w_skel * skeleton_term

def train():
    """Train SwinUNETR on random crops and keep the best-validating weights.

    Splits the training volumes 80/20, trains for ``CFG.epochs`` epochs with
    mixed precision, evaluates the mean hard Dice on the validation split
    after every epoch, and saves the ``state_dict`` to ``CFG.best_weights``
    whenever the validation Dice improves. Both loaders use the dataset's
    default ``is_train=True``, so validation is also scored on random crops.

    Returns:
        ``CFG.best_weights``, the path of the saved checkpoint.
    """
    print("\n--- [DEBUG] VERSION 3.0: UNIVERSAL MODE ---")
    # glob returns files in arbitrary, filesystem-dependent order; sort so the
    # seeded split below selects the same volumes on every run.
    all_img_paths = sorted(glob.glob(os.path.join(CFG.train_images, "*.tif")))
    train_paths, val_paths = train_test_split(all_img_paths, test_size=0.2, random_state=42)

    train_loader = DataLoader(
        VesuviusCubeDataset(train_paths, CFG.train_labels),
        batch_size=CFG.batch_size,
        shuffle=True,
    )
    val_loader = DataLoader(
        VesuviusCubeDataset(val_paths, CFG.train_labels),
        batch_size=CFG.batch_size,
        shuffle=False,
    )

    # UNIVERSAL INITIALIZATION (POSITIONAL ONLY)
    # Arguments: img_size, in_channels, out_channels, feature_size
    model = SwinUNETR(CFG.patch_size, 2, 1, 48, use_checkpoint=True).to(CFG.device)

    optimizer = optim.Adam(model.parameters(), lr=CFG.lr)
    criterion, scaler = VesuviusSotaLoss(), GradScaler()

    # Start below any reachable Dice so the first epoch always writes a
    # checkpoint; with 0 and a strict ">" a run stuck at Dice 0 saved nothing
    # and the returned path pointed at a file that did not exist.
    best_dice = float("-inf")
    for epoch in range(CFG.epochs):
        model.train()
        for img, lab in tqdm(train_loader, desc=f"Epoch {epoch+1} [TRAIN]"):
            img, lab = img.to(CFG.device), lab.to(CFG.device)
            optimizer.zero_grad()
            with autocast():
                loss = criterion(model(img), lab)
            # Scale the loss before backward, then let the scaler
            # step the optimizer and adjust its scale factor.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        model.eval()
        val_dice = 0
        with torch.no_grad():
            for img, lab in tqdm(val_loader, desc=f"Epoch {epoch+1} [VAL]"):
                img, lab = img.to(CFG.device), lab.to(CFG.device)
                val_dice += compute_dice(model(img), lab).item()

        avg_val_dice = val_dice / len(val_loader)
        print(f"Epoch {epoch+1} | Val Dice={avg_val_dice:.4f}")

        if avg_val_dice > best_dice:
            best_dice = avg_val_dice
            torch.save(model.state_dict(), CFG.best_weights)
    return CFG.best_weights

def _window_starts(size, patch, stride):
    """Return sliding-window start offsets that cover ``[0, size)`` fully."""
    last = max(size - patch, 0)
    starts = list(range(0, last + 1, stride))
    if starts[-1] != last:
        # Add a final window flush with the far edge so the tail is covered.
        starts.append(last)
    return starts


def submit(weights):
    """Run sliding-window inference on every test cube and zip the masks.

    For each ``*.tif`` in ``CFG.test_dir`` the volume is scaled by 1/255,
    paired with its radial distance map, and predicted patch by patch.
    Overlapping sigmoid outputs are averaged, thresholded at 0.51, written as
    an LZW-compressed ``<vid>.tif`` and added to ``submission.zip``; the
    temporary TIFF is deleted afterwards.

    Fixes relative to the earlier version:
      * weights are loaded with ``map_location`` so a checkpoint saved on GPU
        also loads on a CPU-only machine;
      * the window grid always includes a window flush with the far edge of
        each axis, so trailing voxels are no longer left unpredicted (and
        silently written as 0) when the size is not ``patch + k * stride``;
      * volumes smaller than the patch are zero-padded for inference and the
        result cropped back, instead of producing an all-zero mask;
      * each axis uses its own entry of ``CFG.patch_size``.

    Args:
        weights: Path to a ``state_dict`` saved by ``train``.
    """
    model = SwinUNETR(CFG.patch_size, 2, 1, 48, use_checkpoint=True).to(CFG.device)
    model.load_state_dict(torch.load(weights, map_location=CFG.device))
    model.eval()

    pz, py, px = CFG.patch_size
    test_paths = glob.glob(os.path.join(CFG.test_dir, "*.tif"))
    with zipfile.ZipFile('submission.zip', 'w') as out_zip, torch.no_grad():
        for path in test_paths:
            vid = os.path.basename(path).split('.')[0]
            vol = (tiff.imread(path) / 255.0).astype(np.float32)
            dist = get_radial_dist_map(vol.shape, detect_umbilicus(vol))
            depth, height, width = vol.shape

            # Pad only when an axis is shorter than the patch.
            pad = [(0, max(p - s, 0)) for s, p in zip(vol.shape, CFG.patch_size)]
            if any(extra for _, extra in pad):
                vol_in, dist_in = np.pad(vol, pad), np.pad(dist, pad)
            else:
                vol_in, dist_in = vol, dist

            full_pred, counts = np.zeros_like(vol_in), np.zeros_like(vol_in)
            z_starts = _window_starts(vol_in.shape[0], pz, CFG.stride)
            y_starts = _window_starts(vol_in.shape[1], py, CFG.stride)
            x_starts = _window_starts(vol_in.shape[2], px, CFG.stride)
            for z in tqdm(z_starts, desc=f"Cube {vid}"):
                for y in y_starts:
                    for x in x_starts:
                        window = (slice(z, z + pz), slice(y, y + py), slice(x, x + px))
                        patch = torch.from_numpy(
                            np.stack([vol_in[window], dist_in[window]])
                        )[None, :].to(CFG.device)
                        # Index batch 0 / channel 0 explicitly rather than
                        # squeeze(), which would also drop size-1 spatial axes.
                        probs = torch.sigmoid(model(patch))[0, 0]
                        full_pred[window] += probs.cpu().numpy()
                        counts[window] += 1.0

            mean_pred = (full_pred / (counts + 1e-8))[:depth, :height, :width]
            mask = (mean_pred > 0.51).astype(np.uint8)
            tiff.imwrite(f"{vid}.tif", mask, compression='lzw')
            out_zip.write(f"{vid}.tif")
            os.remove(f"{vid}.tif")

In [None]:
# Script entry point: train the model, then run inference with the checkpoint
# path that train() returns, and report completion.
if __name__ == "__main__":
    best_model_path = train()
    submit(best_model_path)
    print("\n--- SUCCESS: SUBMISSION.ZIP CREATED ---")