In [None]:
# --- 0. Pre-flight Check: Verify Data Integrity ---
import os
import re

# Banner printed once so the per-fragment lines emitted further down are easy to locate in the cell output.
print("--- Verifying available Z-slices for each fragment ---")

def get_available_slices(fragment_id, data_root='train'):
    """Return the lowest and highest Z-slice index found for one fragment.

    The function lists ``{data_root}/{fragment_id}/surface_volume`` and reads the
    first run of digits in every ``.tif`` file name as that file's slice index.

    Args:
        fragment_id: Fragment folder name (e.g. ``'1'``).
        data_root: Folder containing the fragment folders. Defaults to ``'train'``,
            which is the location used before this parameter existed.

    Returns:
        ``(min_slice, max_slice)`` as ints, or ``(None, None)`` when the folder is
        missing or holds no ``.tif`` file with a numeric index.
    """
    folder_path = f'{data_root}/{fragment_id}/surface_volume'
    try:
        files = os.listdir(folder_path)
    except (FileNotFoundError, NotADirectoryError):
        return None, None

    slice_numbers = []
    for file_name in files:
        if not file_name.endswith('.tif'):
            continue
        match = re.search(r'(\d+)', file_name)
        if match is None:
            # A .tif without any digits carries no slice index; skip it instead of
            # crashing on `None.group(1)`.
            continue
        slice_numbers.append(int(match.group(1)))

    if not slice_numbers:
        return None, None
    return min(slice_numbers), max(slice_numbers)

# Fragments whose slice folders are inspected by this pre-flight check.
fragments_to_check = ['1', '2', '3']
all_slices_min = []
all_slices_max = []

for frag_id in fragments_to_check:
    min_slice, max_slice = get_available_slices(frag_id)
    if min_slice is not None:
        print(f"Fragment {frag_id}: Slices from {min_slice} to {max_slice}")
        all_slices_min.append(min_slice)
        all_slices_max.append(max_slice)
    else:
        print(f"Fragment {frag_id}: Could not find slice data.")

# Determine the common range: the intersection starts at the latest first slice
# and ends at the earliest last slice.
if all_slices_min and all_slices_max:
    common_min = max(all_slices_min)
    common_max = min(all_slices_max)
    if common_min > common_max:
        # The per-fragment ranges do not intersect, so any recommendation would
        # produce Z_START > Z_END; report that instead.
        print(f"\nFragments share no common Z-slice range (latest start {common_min} > earliest end {common_max}).")
    else:
        print(f"\nCommon available Z-slice range: {common_min} to {common_max}")
        print(f"Recommended CFG.Z_START = {common_min}")
        # The end of range in Python is exclusive, so we need to add 1 to the max slice number
        print(f"Recommended CFG.Z_END = {common_max + 1}")
else:
    print("\nCould not determine a common slice range.")

In [None]:
import os
import gc
import random
import time
import glob
import numpy as np
import pandas as pd
import cv2
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp

# --- Configuration ---
class CFG:
    """Static settings for this cell's training setup; read as ``CFG.<NAME>``."""

    # --- Runtime ---
    DEBUG = False
    USE_AMP = True
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

    # --- Input data ---
    DATA_PATH = 'train/'
    TRAIN_FRAGMENTS, VAL_FRAGMENTS = [1], [2]

    # --- Z-slice window (as per expert advice); Z_END is exclusive ---
    Z_START, Z_END = 20, 45
    IN_CHANNELS = Z_END - Z_START  # one input channel per slice in the window

    # --- Tiling (as per expert advice) ---
    TILE_SIZE = 256
    STRIDE = TILE_SIZE // 2  # neighbouring tiles overlap by half a tile

    # --- Network ---
    MODEL_NAME = 'Unet'
    ENCODER_NAME = 'efficientnet-b4'
    ENCODER_WEIGHTS = 'imagenet'

    # --- Optimisation (as per expert advice) ---
    EPOCHS = 12
    BATCH_SIZE = 8
    LEARNING_RATE = 1e-4
    WEIGHT_DECAY = 1e-6
    EARLY_STOPPING_PATIENCE = 3

    # --- Tile pre-caching (as per expert advice) ---
    CACHE_DIR = './cache_tif_only_v2/'
    REBUILD_CACHE = True

if CFG.DEBUG:
    # Debug mode: train and validate on fragment 1 only, run two epochs, and
    # always rebuild the tile cache.
    CFG.TRAIN_FRAGMENTS, CFG.VAL_FRAGMENTS = [1], [1]
    CFG.EPOCHS, CFG.REBUILD_CACHE = 2, True

def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and put cuDNN in deterministic mode.

    Args:
        seed: Integer applied to every generator (default 42).
    """
    # Hash seed is exported as a string because environment values must be text.
    os.environ['PYTHONHASHSEED'] = str(seed)

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)

    # Benchmark mode off + deterministic mode on.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

set_seed(42)
# NOTE: cuDNN benchmark mode is deliberately NOT switched back on here.
# set_seed() disables it (and enables deterministic mode) for reproducibility;
# re-enabling it immediately afterwards would undo that setting.
print(f"Device: {CFG.DEVICE}")
print(f"Input Channels: {CFG.IN_CHANNELS}")
print(f"Rebuilding Cache: {CFG.REBUILD_CACHE}")

In [1]:
# --- 1. Configuration ---
import os
import gc
import random
import numpy as np
import pandas as pd
import cv2
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import segmentation_models_pytorch as smp
import albumentations as A
from albumentations.pytorch import ToTensorV2
import time
import shutil

class CFG:
    """Static settings for the tile-caching + U-Net training pipeline below."""
    # General
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    SEED = 42
    
    # Data Paths & Cache
    DATA_PATH = './'
    CACHE_DIR = 'cache_tif_only_v5_b0' # Changed version to avoid conflicts
    # The notebook's final "Run Training" cell reads CFG.REBUILD_CACHE; without
    # this attribute that cell raises AttributeError when this CFG is in scope.
    REBUILD_CACHE = True
    
    # Fragments
    TRAIN_FRAGMENTS = [1]
    VALID_FRAGMENTS = [2]
    
    # Slices (as per expert advice); Z_END is exclusive
    Z_START = 20
    Z_END = 45
    IN_CHANNELS = Z_END - Z_START  # one model input channel per slice

    # Tiling & Sampling
    TILE_SIZE = 256
    STRIDE = TILE_SIZE // 2
    # Negative (ink-free) tiles kept per positive tile when building the cache
    NEGATIVE_INK_RATIO = 1.0

    # Model
    ENCODER_NAME = 'efficientnet-b0'
    ENCODER_WEIGHTS = 'imagenet'
    MODEL_SAVE_PATH = 'best_tif_only_model_v5_b0.pth' # Changed version

    # Training
    EPOCHS = 15
    BATCH_SIZE = 16
    LEARNING_RATE = 1e-4
    WEIGHT_DECAY = 1e-6
    USE_AMP = True

def set_seed(seed):
    """Make runs repeatable by seeding every RNG in use and pinning cuDNN behaviour.

    Args:
        seed: Integer seed shared by Python's ``random``, NumPy and PyTorch.
    """
    # Interpreter-level randomness
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)

    # Numeric libraries (CPU generator first, then the CUDA one)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # cuDNN: no benchmark mode, deterministic mode on
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

set_seed(CFG.SEED)
# Echo the two settings most worth checking before a long run.
for label, value in (("Using device", CFG.DEVICE), ("Input channels", CFG.IN_CHANNELS)):
    print(f"{label}: {value}")

# --- 2. Data Preparation (with Caching) ---

def _read_grayscale(path):
    """Load ``path`` with cv2 as a single-channel image, failing loudly if it cannot be read."""
    image = cv2.imread(path, 0)  # flag 0 == cv2.IMREAD_GRAYSCALE
    if image is None:
        # cv2.imread reports a missing/unreadable file by returning None instead of
        # raising, which would otherwise surface later as an obscure numpy error.
        raise FileNotFoundError(f"Could not read image (missing or unreadable): {path}")
    return image


def get_train_valid_tiles(fragment_id):
    """Cut one fragment into tiles, cache them as ``.npy`` files and return their ids.

    Steps:
      1. Stack slices ``CFG.Z_START .. CFG.Z_END - 1`` of the fragment along the last axis.
      2. Binarise ``inklabels.png`` (the ink mask) and ``mask.png`` (the ROI).
      3. Walk a ``CFG.TILE_SIZE`` window with step ``CFG.STRIDE``; keep windows whose ROI
         coverage exceeds 0.99, splitting them into positive (any ink pixel) and negative.
      4. Keep every positive tile plus a random sample of at most
         ``num_positive * CFG.NEGATIVE_INK_RATIO`` negative tiles.
      5. Save each kept tile to ``CFG.CACHE_DIR`` as ``{fragment_id}_{i}_img.npy`` and
         ``{fragment_id}_{i}_mask.npy``.

    Args:
        fragment_id: Name of the fragment folder under ``{CFG.DATA_PATH}train/``.

    Returns:
        List of tile ids of the form ``'{fragment_id}_{i}'``.

    Raises:
        FileNotFoundError: If any slice, the ink mask or the ROI mask cannot be read.
    """
    fragment_dir = f"{CFG.DATA_PATH}train/{fragment_id}"

    images = []
    for i in range(CFG.Z_START, CFG.Z_END):
        images.append(_read_grayscale(f"{fragment_dir}/surface_volume/{i:02}.tif"))
    images = np.stack(images, axis=-1)

    mask = _read_grayscale(f"{fragment_dir}/inklabels.png")
    mask = (mask > 0).astype(np.uint8)

    roi = _read_grayscale(f"{fragment_dir}/mask.png")
    roi = (roi > 0).astype(np.uint8)

    coords = []
    positive_coords = []
    negative_coords = []

    for y in range(0, images.shape[0] - CFG.TILE_SIZE + 1, CFG.STRIDE):
        for x in range(0, images.shape[1] - CFG.TILE_SIZE + 1, CFG.STRIDE):
            # Only tiles (almost) fully inside the ROI are considered.
            if roi[y:y+CFG.TILE_SIZE, x:x+CFG.TILE_SIZE].mean() > 0.99:
                tile_mask = mask[y:y+CFG.TILE_SIZE, x:x+CFG.TILE_SIZE]
                if tile_mask.sum() > 0:
                    positive_coords.append((y, x))
                else:
                    negative_coords.append((y, x))

    num_positive = len(positive_coords)
    num_negative_to_sample = int(num_positive * CFG.NEGATIVE_INK_RATIO)
    num_negative_kept = min(num_negative_to_sample, len(negative_coords))

    print(f"Found {num_positive} positive tiles and sampling {num_negative_kept} negative tiles.")

    # Shuffle so the kept negatives are a random subset rather than the top-left ones.
    random.shuffle(negative_coords)
    coords.extend(positive_coords)
    coords.extend(negative_coords[:num_negative_kept])

    os.makedirs(CFG.CACHE_DIR, exist_ok=True)
    for i, (y, x) in enumerate(tqdm(coords, desc=f'Saving Tiles for {fragment_id}')):
        tile_img = images[y:y+CFG.TILE_SIZE, x:x+CFG.TILE_SIZE]
        tile_mask = mask[y:y+CFG.TILE_SIZE, x:x+CFG.TILE_SIZE]
        np.save(f'{CFG.CACHE_DIR}/{fragment_id}_{i}_img.npy', tile_img)
        np.save(f'{CFG.CACHE_DIR}/{fragment_id}_{i}_mask.npy', tile_mask)

    print(f"Saved {len(coords)} total tiles for fragment {fragment_id}.")
    return [f'{fragment_id}_{i}' for i in range(len(coords))]

class VesuviusDataset(Dataset):
    """Dataset over tiles cached as ``{CFG.CACHE_DIR}/{tile_id}_img.npy`` / ``_mask.npy``.

    Each item is ``(image, mask)``:
      * ``image``: float32 tensor, percentile-normalised to ``[0, 1]``.
      * ``mask``: float32 tensor with a leading channel axis, ready to be used as a
        target for ``BCEWithLogitsLoss``.

    Args:
        tile_ids: Sequence of tile id strings (file name prefixes inside the cache).
        transform: Optional albumentations-style callable invoked as
            ``transform(image=..., mask=...)`` and returning a dict with the same keys.
    """

    def __init__(self, tile_ids, transform=None):
        self.tile_ids = tile_ids
        self.transform = transform

    def __len__(self):
        """Number of cached tiles in this dataset."""
        return len(self.tile_ids)

    def __getitem__(self, idx):
        """Load, normalise and (optionally) augment the tile at position ``idx``."""
        tile_id = self.tile_ids[idx]
        image = np.load(f'{CFG.CACHE_DIR}/{tile_id}_img.npy')
        mask = np.load(f'{CFG.CACHE_DIR}/{tile_id}_mask.npy')

        # Per-tile contrast stretch between the 1st and 99th percentile.
        p_low, p_high = np.percentile(image, [1, 99])
        image = np.clip((image.astype(np.float32) - p_low) / (p_high - p_low + 1e-6), 0, 1)
        # np.percentile yields float64 values, which can upcast the result to float64
        # (NumPy >= 2 promotion rules); pin float32 so the model input dtype is stable.
        image = image.astype(np.float32, copy=False)

        if self.transform is not None:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # Without a tensor-producing transform the arrays are still numpy; convert
        # them here so the return type does not depend on the transform.
        if not torch.is_tensor(image):
            # assumes the cached tile is (H, W, C) -> model layout (C, H, W)
            image = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1)))
        if not torch.is_tensor(mask):
            mask = torch.from_numpy(np.ascontiguousarray(mask))

        # The cached mask is uint8; the loss needs a floating-point target.
        return image, mask.unsqueeze(0).float()

def get_transforms():
    """Build the augmentation pipelines.

    Returns:
        ``(train_transform, valid_transform)``: training applies random horizontal
        flip, vertical flip and 90-degree rotation (each with p=0.5) before tensor
        conversion; validation only converts to tensors.
    """
    geometric_augs = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
    ]
    train_transform = A.Compose(geometric_augs + [ToTensorV2(transpose_mask=True)])
    valid_transform = A.Compose([ToTensorV2(transpose_mask=True)])
    return train_transform, valid_transform

# --- 3. Loss, Model, and Training Loop ---

class BCETverskyLoss(nn.Module):
    """Weighted sum of BCE-with-logits and a Tversky loss.

    ``loss = bce_weight * BCE + (1 - bce_weight) * (1 - TverskyIndex)`` where
    ``TverskyIndex = TP / (TP + alpha * FP + beta * FN + 1e-6)`` is computed from
    sigmoid probabilities summed over the whole batch.

    Args:
        bce_weight: Share of the BCE term in the final loss.
        tversky_alpha: Weight of false positives in the Tversky index.
        tversky_beta: Weight of false negatives in the Tversky index.
    """

    def __init__(self, bce_weight=0.5, tversky_alpha=0.5, tversky_beta=0.5):
        super().__init__()
        self.bce_weight = bce_weight
        self.tversky_alpha = tversky_alpha
        self.tversky_beta = tversky_beta
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, inputs, targets):
        """Return the combined loss.

        Args:
            inputs: Raw logits.
            targets: Binary ground truth with the same shape as ``inputs``; any
                numeric/bool dtype is accepted.

        Returns:
            Scalar loss tensor.
        """
        # BCEWithLogitsLoss needs a floating-point target; masks commonly arrive
        # as uint8, so normalise the dtype here.
        targets = targets.float()

        bce_loss = self.bce(inputs, targets)
        probs = torch.sigmoid(inputs)
        tp = (probs * targets).sum()
        fp = (probs * (1 - targets)).sum()
        fn = ((1 - probs) * targets).sum()
        tversky_index = tp / (tp + self.tversky_alpha * fp + self.tversky_beta * fn + 1e-6)
        tversky_loss = 1 - tversky_index
        return self.bce_weight * bce_loss + (1 - self.bce_weight) * tversky_loss

def fbeta_score(y_pred, y_true, beta=0.5, thr=0.5):
    """F-beta score of thresholded predictions against a binary target.

    Args:
        y_pred: Tensor of scores; values strictly above ``thr`` count as positive.
        y_true: Binary target tensor of the same shape.
        beta: Recall weight relative to precision (default 0.5).
        thr: Decision threshold applied to ``y_pred``.

    Returns:
        0-dim tensor holding the F-beta value (1e-6 terms guard the divisions).
    """
    hard_pred = (y_pred > thr).float()
    beta_sq = beta**2

    true_pos = (y_true * hard_pred).sum()
    false_pos = ((1 - y_true) * hard_pred).sum()
    false_neg = (y_true * (1 - hard_pred)).sum()

    precision = true_pos / (true_pos + false_pos + 1e-6)
    recall = true_pos / (true_pos + false_neg + 1e-6)

    numerator = (1 + beta_sq) * (precision * recall)
    denominator = (beta_sq * precision) + recall + 1e-6
    return numerator / denominator

def get_dataloaders():
    """Cache tiles for every configured fragment and wrap them in DataLoaders.

    Returns:
        ``(train_loader, valid_loader)``: the training loader shuffles, the
        validation loader keeps tile order.
    """
    def collect_tiles(fragment_ids):
        # Tile (and cache) each fragment in turn, gathering the resulting tile ids.
        tile_ids = []
        for fragment_id in fragment_ids:
            print(f"Processing fragment {fragment_id}...")
            tile_ids.extend(get_train_valid_tiles(fragment_id))
        return tile_ids

    print("--- Preparing Training Data ---")
    train_tiles = collect_tiles(CFG.TRAIN_FRAGMENTS)
    print("--- Preparing Validation Data ---")
    valid_tiles = collect_tiles(CFG.VALID_FRAGMENTS)

    train_transform, valid_transform = get_transforms()
    train_dataset = VesuviusDataset(train_tiles, transform=train_transform)
    valid_dataset = VesuviusDataset(valid_tiles, transform=valid_transform)

    # Settings shared by both loaders; only `shuffle` differs.
    shared_kwargs = dict(batch_size=CFG.BATCH_SIZE, num_workers=4, pin_memory=True)
    train_loader = DataLoader(train_dataset, shuffle=True, **shared_kwargs)
    valid_loader = DataLoader(valid_dataset, shuffle=False, **shared_kwargs)
    return train_loader, valid_loader

def train_one_epoch(model, optimizer, criterion, loader, scaler):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train (switched to train mode here).
        optimizer: Optimiser stepping the model parameters.
        criterion: Loss callable taking ``(outputs, masks)``.
        loader: Iterable of ``(images, masks)`` batches.
        scaler: Gradient scaler used for the backward pass and optimiser step.

    Returns:
        Mean of the per-batch loss values.
    """
    model.train()
    running_loss = 0
    for batch_images, batch_masks in tqdm(loader, desc="Training"):
        batch_images = batch_images.to(CFG.DEVICE)
        batch_masks = batch_masks.to(CFG.DEVICE)

        optimizer.zero_grad()
        with autocast(enabled=CFG.USE_AMP):
            loss = criterion(model(batch_images), batch_masks)

        # Scaled backward + step, then let the scaler adjust its factor.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()
    return running_loss / len(loader)

def validate(model, criterion, loader):
    """Evaluate ``model`` on ``loader`` and pick the best F0.5 threshold.

    Args:
        model: Network to evaluate (switched to eval mode here).
        criterion: Loss callable taking ``(outputs, masks)``.
        loader: Iterable of ``(images, masks)`` batches.

    Returns:
        ``(avg_loss, best_score, best_thr)``: mean per-batch loss, the highest F0.5
        found over thresholds ``np.arange(0.1, 0.9, 0.05)`` and the threshold that
        produced it. Both stay at 0 if no threshold scores above 0.
    """
    model.eval()
    loss_sum = 0
    pred_chunks = []
    target_chunks = []

    with torch.no_grad():
        for batch_images, batch_masks in tqdm(loader, desc="Validating"):
            batch_images = batch_images.to(CFG.DEVICE)
            batch_masks = batch_masks.to(CFG.DEVICE)
            with autocast(enabled=CFG.USE_AMP):
                logits = model(batch_images)
                loss = criterion(logits, batch_masks)
            loss_sum += loss.item()
            # Keep probabilities and targets on the CPU for the threshold sweep below.
            pred_chunks.append(torch.sigmoid(logits).cpu())
            target_chunks.append(batch_masks.cpu())

    avg_loss = loss_sum / len(loader)
    all_preds = torch.cat(pred_chunks)
    all_targets = torch.cat(target_chunks)

    best_thr, best_score = 0, 0
    for thr in np.arange(0.1, 0.9, 0.05):
        score = fbeta_score(all_preds, all_targets, beta=0.5, thr=thr)
        if score > best_score:
            best_score = score
            best_thr = thr
    return avg_loss, best_score, best_thr

def run_training():
    """Build the model, (re)create the tile cache, train, and keep the best checkpoint.

    The model is created before any data work so that a failure to place it on the
    device is reported immediately; in that case the error and ``nvidia-smi`` output
    are printed and the function returns without training.

    After each epoch the validation F0.5 is compared with the best so far and the
    model weights are written to ``CFG.MODEL_SAVE_PATH`` on improvement.

    Returns:
        None.
    """
    # 1. Instantiate model FIRST to isolate memory issues
    print("Instantiating model and moving to GPU...")
    try:
        model = smp.Unet(CFG.ENCODER_NAME, encoder_weights=CFG.ENCODER_WEIGHTS, in_channels=CFG.IN_CHANNELS, classes=1, activation=None).to(CFG.DEVICE)
        print("Model loaded on GPU successfully.")
    except Exception as e:
        # Deliberate best-effort: report the failure plus GPU state, then bail out.
        print(f"Failed to load model onto GPU: {e}")
        os.system('nvidia-smi')
        return

    # 2. Prepare data and cache (any existing cache directory is discarded first)
    if os.path.exists(CFG.CACHE_DIR):
        print(f"Removing old cache directory: {CFG.CACHE_DIR}")
        shutil.rmtree(CFG.CACHE_DIR)
    
    train_loader, valid_loader = get_dataloaders()
    print("Data loaded and cached.")
    
    # 3. Setup for training
    optimizer = optim.AdamW(model.parameters(), lr=CFG.LEARNING_RATE, weight_decay=CFG.WEIGHT_DECAY)
    criterion = BCETverskyLoss(bce_weight=0.5, tversky_alpha=0.7, tversky_beta=0.3)
    scaler = GradScaler(enabled=CFG.USE_AMP)
    best_val_score = 0

    # 4. Run training loop
    for epoch in range(CFG.EPOCHS):
        start_time = time.time()
        train_loss = train_one_epoch(model, optimizer, criterion, train_loader, scaler)
        val_loss, val_score, val_thr = validate(model, criterion, valid_loader)
        end_time = time.time()
        print(f"Epoch {epoch+1}/{CFG.EPOCHS} | Time: {end_time-start_time:.2f}s | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val F0.5: {val_score:.4f} @ Thr: {val_thr:.2f}")
        if val_score > best_val_score:
            best_val_score = val_score
            torch.save(model.state_dict(), CFG.MODEL_SAVE_PATH)
            print(f"   -> New best model saved with F0.5 score: {best_val_score:.4f}")
    
    print(f"\nTraining complete. Best validation F0.5 score: {best_val_score:.4f}")
    # The optimizer still references the model parameters (and holds its own state
    # tensors), so it has to go too; otherwise empty_cache() below has little to release.
    del model, optimizer, criterion, scaler, train_loader, valid_loader
    gc.collect()
    torch.cuda.empty_cache()

# --- 4. Execute Training ---
# Executing this cell runs the full pipeline implemented in run_training():
# model creation, cache rebuild, training loop and checkpointing.
run_training()

Using device: cuda
Input channels: 25
Instantiating model and moving to GPU...
Failed to load model onto GPU: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Fri Sep 26 04:19:14 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.144.06             Driver Version: 550.144.06     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |


|   0  NVIDIA A10-24Q                 On  |   00000002:00:00.0 Off |                    0 |
| N/A   N/A    P0             N/A /  N/A  |   21842MiB /  24512MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
+-----------------------------------------------------------------------------------------+


In [None]:
# --- 3. Run Training ---

if __name__ == '__main__':
    # Not every CFG definition in this notebook declares REBUILD_CACHE, so read it
    # defensively instead of raising AttributeError. Defaulting to False is safe
    # because run_training() removes an existing cache directory on its own.
    rebuild_cache = getattr(CFG, 'REBUILD_CACHE', False)
    if rebuild_cache and os.path.exists(CFG.CACHE_DIR):
        print(f"Removing existing cache directory: {CFG.CACHE_DIR}")
        import shutil
        shutil.rmtree(CFG.CACHE_DIR)
    
    run_training()