In [None]:
import os
import time
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from glob import glob
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader, Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
import torch.nn.functional as F
from PIL import Image
from pathlib import Path


In [None]:
# os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Replaced by more flexible logic below
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' # Kept as per your original setup

# --- GPU Configuration ---
SPECIFIC_GPU_INDEX = 1  # <<< YOU CAN CHANGE THIS GPU INDEX (e.g., 0, 1, 2, ...)


def select_device(preferred_index: int) -> torch.device:
    """Pick the compute device for this notebook.

    Returns ``cuda:<preferred_index>`` when that GPU exists, otherwise falls
    back to ``cuda:0``, and to the CPU when CUDA is unavailable. The choice
    (and the reason for any fallback) is printed so it is recorded in the
    notebook output.

    Args:
        preferred_index: index of the GPU to use; negative or too-large
            values trigger the ``cuda:0`` fallback instead of crashing.
    """
    if not torch.cuda.is_available():
        print("CUDA not available. Using CPU.")
        return torch.device('cpu')

    n_gpus = torch.cuda.device_count()
    # Explicit lower bound: torch.device('cuda:-1') would raise.
    if 0 <= preferred_index < n_gpus:
        print(f"Attempting to use CUDA device: cuda:{preferred_index}")
        return torch.device(f'cuda:{preferred_index}')

    # CUDA being available implies at least one visible GPU, so cuda:0 exists.
    print(f"Warning: GPU index {preferred_index} is out of range (0-{n_gpus - 1}).")
    print("Falling back to GPU 0 (cuda:0).")
    return torch.device('cuda:0')


SELECTED_DEVICE = select_device(SPECIFIC_GPU_INDEX)
print(f"Selected device: {SELECTED_DEVICE}")


CONFIG = {
    "version_name": "v1-100epoch-lr1e-3-batch4-18may-13h-21m", # Example, keep your versioning
    "log_base_dir": Path("experiments"), # This will be experiments within unet_train/
    "img_size": 256,  # UNet below pools 5 times; multiples of 32 avoid any padding
    "batch_size": 4,
    "num_workers": 2,
    "lr": 1e-3,
    "bce_weight": 0.5,
    "num_epochs": 100,
    "device": SELECTED_DEVICE, # Use the dynamically selected device
    # data_root should point from unet_train/unet_train.ipynb to glomeruli_segmentation/datasets/
    "data_root": Path("../datasets"), # Relative path from unet_train/ to datasets/
}

# Verify paths (optional, for debugging)
print(f"Data root resolved to: {CONFIG['data_root'].resolve()}")
print(f"Log base directory: {CONFIG['log_base_dir'].resolve()}")

In [None]:
class GlomeruliDataset(Dataset):
    """Image/mask pairs for glomeruli segmentation, matched by filename.

    Every regular, non-hidden file in ``image_dir`` is one sample; its mask is
    the file with the same name in ``mask_dir``. Masks are binarised: any
    non-zero pixel becomes 1.0 (float32).

    With ``transform`` set, it is called as ``transform(image=..., mask=...)``
    and must return tensors (e.g. an albumentations pipeline ending in
    ``ToTensorV2``), because the mask gets a leading channel dim -> (1, H, W).
    Without a transform the raw numpy arrays are returned unchanged:
    image (H, W, 3) uint8 and mask (H, W) float32.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = Path(image_dir)
        self.mask_dir = Path(mask_dir)
        # os.listdir would also return sub-directories (".ipynb_checkpoints")
        # and hidden files (".DS_Store"), which Image.open cannot read.
        self.filenames = sorted(
            p.name
            for p in self.image_dir.iterdir()
            if p.is_file() and not p.name.startswith(".")
        )
        self.transform = transform

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        img_name = self.filenames[idx]
        mask_path = self.mask_dir / img_name
        # Fail with a readable message instead of a bare PIL error.
        if not mask_path.is_file():
            raise FileNotFoundError(f"No mask found for image '{img_name}' at {mask_path}")

        with Image.open(self.image_dir / img_name) as img:
            image = np.array(img.convert("RGB"))
        with Image.open(mask_path) as mask_img:
            mask = (np.array(mask_img.convert("L")) != 0).astype("float32")

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"].unsqueeze(0)  # (H, W) -> (1, H, W)
        return image, mask


In [None]:
def get_transforms(img_size):
    """Build the albumentations pipelines used for training and evaluation.

    Both pipelines resize to ``img_size`` x ``img_size``, apply ``A.Normalize()``
    with its defaults and convert to tensors via ``ToTensorV2``. The training
    pipeline additionally inserts random flips, 90-degree rotations, elastic
    warping and brightness/contrast jitter between the resize and the
    normalisation.

    Returns:
        (train_transform, val_transform)
    """
    # Fresh transform objects for each pipeline, so the two Compose
    # instances never share state.
    def leading():
        return [A.Resize(img_size, img_size)]

    def trailing():
        return [A.Normalize(), ToTensorV2()]

    augmentations = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.ElasticTransform(alpha=1, sigma=50, p=0.5),  # alpha_affine deliberately omitted (earlier fix)
        A.RandomBrightnessContrast(p=0.2),
    ]

    train_pipeline = A.Compose(leading() + augmentations + trailing())
    eval_pipeline = A.Compose(leading() + trailing())
    return train_pipeline, eval_pipeline


In [None]:
class UNet(nn.Module):
    """Five-level U-Net for segmentation; ``forward`` returns raw logits.

    Args:
        in_channels: channels of the input image (3 for RGB).
        out_channels: number of output logit maps (1 for a binary mask).
        base_channels: width of the first encoder stage; each deeper stage
            doubles it. The default 64 reproduces the original 64 -> 2048
            layout, so state_dict keys and shapes match existing checkpoints.

    The input height/width must be at least 32 (five 2x max-pools). Sizes
    that are not multiples of 32 are supported: each upsampled map is
    zero-padded to its skip connection's size before concatenation.
    """

    def __init__(self, in_channels=3, out_channels=1, base_channels=64):
        super().__init__()

        def conv_block(ic, oc):
            return nn.Sequential(
                nn.Conv2d(ic, oc, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(oc, oc, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
            )

        c = base_channels
        # Registration order is kept identical to the original so that
        # _init_weights consumes the RNG in the same order.
        self.enc1 = conv_block(in_channels, c)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = conv_block(c, 2 * c)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = conv_block(2 * c, 4 * c)
        self.pool3 = nn.MaxPool2d(2)
        self.enc4 = conv_block(4 * c, 8 * c)
        self.pool4 = nn.MaxPool2d(2)
        self.enc5 = conv_block(8 * c, 16 * c)
        self.pool5 = nn.MaxPool2d(2)
        self.drop = nn.Dropout2d(0.5)
        self.bottleneck = conv_block(16 * c, 32 * c)
        self.up5 = nn.ConvTranspose2d(32 * c, 16 * c, kernel_size=2, stride=2)
        self.dec5 = conv_block(32 * c, 16 * c)
        self.up4 = nn.ConvTranspose2d(16 * c, 8 * c, kernel_size=2, stride=2)
        self.dec4 = conv_block(16 * c, 8 * c)
        self.up3 = nn.ConvTranspose2d(8 * c, 4 * c, kernel_size=2, stride=2)
        self.dec3 = conv_block(8 * c, 4 * c)
        self.up2 = nn.ConvTranspose2d(4 * c, 2 * c, kernel_size=2, stride=2)
        self.dec2 = conv_block(4 * c, 2 * c)
        self.up1 = nn.ConvTranspose2d(2 * c, c, kernel_size=2, stride=2)
        self.dec1 = conv_block(2 * c, c)
        self.final = nn.Conv2d(c, out_channels, kernel_size=1)
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal (ReLU gain) conv weights and zero biases."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.)

    @staticmethod
    def _match_size(x, ref):
        """Zero-pad ``x`` on H/W so its spatial size equals ``ref``'s.

        MaxPool2d floors odd sizes, so an upsampled map can be one pixel
        smaller than its skip tensor; for even sizes this is a no-op.
        """
        dh = ref.size(2) - x.size(2)
        dw = ref.size(3) - x.size(3)
        if dh == 0 and dw == 0:
            return x
        return F.pad(x, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])

    def _decode(self, up, dec, x, skip):
        """Upsample ``x``, align it to ``skip``, concatenate on channels, convolve."""
        return dec(torch.cat([self._match_size(up(x), skip), skip], dim=1))

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))
        e4 = self.enc4(self.pool3(e3))
        e5 = self.enc5(self.pool4(e4))
        b = self.drop(self.bottleneck(self.pool5(e5)))
        d5 = self._decode(self.up5, self.dec5, b, e5)
        d4 = self._decode(self.up4, self.dec4, d5, e4)
        d3 = self._decode(self.up3, self.dec3, d4, e3)
        d2 = self._decode(self.up2, self.dec2, d3, e2)
        d1 = self._decode(self.up1, self.dec1, d2, e1)
        return self.final(d1)

class BCEDiceLoss(nn.Module):
    """Weighted sum of BCE-with-logits and soft Dice loss.

    ``loss = bce_weight * BCE + (1 - bce_weight) * (1 - mean soft Dice)``

    The soft Dice is computed per sample over dims (1, 2, 3) on sigmoid
    probabilities, with ``smooth`` added to numerator and denominator, and
    then averaged over the batch. ``pred`` holds raw logits; ``target`` must
    have the same shape (rank 4 given the reduction dims).
    """

    def __init__(self, bce_weight=0.5, smooth=1e-7):
        super().__init__()
        self.bce_weight = bce_weight
        self.bce = nn.BCEWithLogitsLoss()
        self.smooth = smooth

    def forward(self, pred, target):
        per_sample_dims = (1, 2, 3)
        bce_term = self.bce(pred, target)

        probs = pred.sigmoid()
        overlap = torch.sum(probs * target, dim=per_sample_dims)
        denom = torch.sum(probs, dim=per_sample_dims) + torch.sum(target, dim=per_sample_dims)
        soft_dice = (2 * overlap + self.smooth) / (denom + self.smooth)
        dice_term = 1 - soft_dice.mean()

        w = self.bce_weight
        return w * bce_term + (1 - w) * dice_term


In [None]:
def _to_display(images):
    """Undo ``A.Normalize()`` so a (B, 3, H, W) batch renders correctly in TensorBoard.

    NOTE(review): assumes the albumentations defaults used by get_transforms
    (ImageNet mean/std, max_pixel_value=255); update these constants if that
    normalisation changes.
    """
    mean = images.new_tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = images.new_tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    return (images * std + mean).clamp(0.0, 1.0)


def _dice_iou_per_sample(pred_bin, target, eps=1e-7):
    """Per-sample Dice and IoU (each shape (B,)) for binary (B, C, H, W) tensors."""
    dims = (1, 2, 3)
    inter = (pred_bin * target).sum(dim=dims)
    total = pred_bin.sum(dim=dims) + target.sum(dim=dims)
    dice = (2. * inter + eps) / (total + eps)
    iou = (inter + eps) / (total - inter + eps)
    return dice, iou


def _train_one_epoch(model, loader, criterion, optimizer, device, desc):
    """Run one optimisation pass over ``loader`` and return the mean batch loss."""
    model.train()
    running_loss = 0.0
    pbar = tqdm(loader, desc=desc, leave=False)
    for images, masks in pbar:
        images, masks = images.to(device), masks.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), masks)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        pbar.set_postfix(loss=loss.item())
    return running_loss / max(len(loader), 1)


def _validate_one_epoch(model, loader, criterion, device, desc, writer=None, epoch=0):
    """Evaluate on ``loader``; return (mean batch loss, mean per-image Dice, mean per-image IoU).

    When ``writer`` is given, the first batch's images, masks and binary
    predictions are logged under ``val/*`` at step ``epoch``.
    """
    model.eval()
    val_loss = 0.0
    dice_all, iou_all = [], []
    pbar = tqdm(loader, desc=desc, leave=False)
    with torch.no_grad():
        for i, (images, masks) in enumerate(pbar):
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            loss = criterion(outputs, masks)
            val_loss += loss.item()

            pred_bin = (torch.sigmoid(outputs) > 0.5).float()
            dice, iou = _dice_iou_per_sample(pred_bin, masks)
            # Collect per-image scores so a short final batch is not over-weighted.
            dice_all.extend(dice.tolist())
            iou_all.extend(iou.tolist())
            pbar.set_postfix(loss=loss.item(), dice=dice.mean().item())

            if i == 0 and writer is not None:
                # add_images builds the grid itself (no torchvision needed).
                writer.add_images("val/image", _to_display(images), epoch)
                writer.add_images("val/mask", masks, epoch)
                writer.add_images("val/prediction", pred_bin, epoch)

    mean_dice = float(np.mean(dice_all)) if dice_all else 0.0
    mean_iou = float(np.mean(iou_all)) if iou_all else 0.0
    return val_loss / max(len(loader), 1), mean_dice, mean_iou


def train():
    """Train the UNet described by CONFIG and keep the best-Dice checkpoint.

    TensorBoard logs go to ``<log_base_dir>/<version_name>/logs``. Whenever
    the mean per-image validation Dice improves, the weights are written
    atomically to ``<log_base_dir>/<version_name>/best_model.pth``.

    Raises:
        ValueError: if the train or validation split contains no images.
    """
    cfg = CONFIG
    device = cfg["device"]
    n_epochs = cfg["num_epochs"]
    version_dir = cfg["log_base_dir"] / cfg["version_name"]
    log_dir = version_dir / "logs"
    checkpoint_path = version_dir / "best_model.pth"
    log_dir.mkdir(parents=True, exist_ok=True)  # also creates version_dir

    train_tf, val_tf = get_transforms(cfg["img_size"])
    data_root = cfg["data_root"]
    train_ds = GlomeruliDataset(
        data_root / "train" / "images",
        data_root / "train" / "masks",
        transform=train_tf,
    )
    val_ds = GlomeruliDataset(
        data_root / "val" / "images",
        data_root / "val" / "masks",
        transform=val_tf,
    )
    if len(train_ds) == 0 or len(val_ds) == 0:
        raise ValueError(
            f"Empty dataset under {data_root}: {len(train_ds)} train / {len(val_ds)} val images"
        )

    train_loader = DataLoader(train_ds, batch_size=cfg["batch_size"], shuffle=True, num_workers=cfg["num_workers"])
    val_loader = DataLoader(val_ds, batch_size=cfg["batch_size"], shuffle=False, num_workers=cfg["num_workers"])

    model = UNet().to(device)
    criterion = BCEDiceLoss(cfg["bce_weight"])
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])
    # `verbose` is deprecated for LR schedulers; LR changes are reported below instead.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5)

    writer = SummaryWriter(log_dir=str(log_dir))
    best_dice = float("-inf")  # guarantees a checkpoint after the first epoch
    try:
        for epoch in range(n_epochs):
            tag = f"Epoch {epoch+1}/{n_epochs}"
            avg_train_loss = _train_one_epoch(
                model, train_loader, criterion, optimizer, device, f"{tag} [TRAIN]"
            )
            writer.add_scalar("Loss/train", avg_train_loss, epoch)

            avg_val_loss, avg_dice, avg_iou = _validate_one_epoch(
                model, val_loader, criterion, device, f"{tag} [VAL]", writer=writer, epoch=epoch
            )
            writer.add_scalar("Loss/val", avg_val_loss, epoch)
            writer.add_scalar("Dice/val", avg_dice, epoch)
            writer.add_scalar("IoU/val", avg_iou, epoch)

            prev_lr = optimizer.param_groups[0]["lr"]
            scheduler.step(avg_dice)
            new_lr = optimizer.param_groups[0]["lr"]
            writer.add_scalar("LR", new_lr, epoch)
            if new_lr < prev_lr:
                print(f"Reducing learning rate: {prev_lr:.2e} -> {new_lr:.2e}")

            if avg_dice > best_dice:
                best_dice = avg_dice
                # Write to a temp file, then rename, so an interrupt cannot corrupt the checkpoint.
                temp_checkpoint_path = checkpoint_path.with_name(checkpoint_path.name + ".tmp")
                torch.save(model.state_dict(), temp_checkpoint_path)
                os.replace(temp_checkpoint_path, checkpoint_path)
                print(f"Saved new best model with Dice {avg_dice:.4f}")

            print(f"[{epoch+1}/{n_epochs}] Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Dice: {avg_dice:.4f} | IoU: {avg_iou:.4f}")
    finally:
        writer.close()
    print("Training finished.")

if __name__ == "__main__":
    train()

In [None]:
def test():
    """Evaluate CONFIG's best checkpoint on the test split.

    Loads ``<log_base_dir>/<version_name>/best_model.pth`` and scores every
    test image individually (batch size 1). It logs the first five
    images/masks/predictions plus the overall Dice/IoU to
    ``<log_base_dir>/<version_name>/logs_test``.

    Returns:
        ``(mean_dice, mean_iou)``, or ``None`` if the checkpoint is missing.
    """
    cfg = CONFIG
    device = cfg["device"]

    # === Paths based on version ===
    version_dir = cfg["log_base_dir"] / cfg["version_name"]
    checkpoint_path = version_dir / "best_model.pth"
    # Check before building the dataset/model so a missing run fails fast.
    if not checkpoint_path.exists():
        print(f"Error: Model checkpoint not found at {checkpoint_path}")
        return None
    test_log_dir = version_dir / "logs_test"  # separate from training logs
    os.makedirs(test_log_dir, exist_ok=True)

    # === Test set: evaluation (val) transform, no augmentation ===
    _, test_transform = get_transforms(cfg["img_size"])
    test_ds = GlomeruliDataset(
        cfg["data_root"] / "test" / "images",
        cfg["data_root"] / "test" / "masks",
        transform=test_transform,
    )
    test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=cfg["num_workers"])

    # === Best model ===
    model = UNet().to(device)
    # The checkpoint is a plain state_dict; weights_only refuses arbitrary pickled objects.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))
    model.eval()

    def to_display(images):
        # Undo A.Normalize() for logging. NOTE(review): assumes the default
        # ImageNet mean/std used by get_transforms.
        mean = images.new_tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        std = images.new_tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        return (images * std + mean).clamp(0.0, 1.0)

    dice_scores, iou_scores = [], []
    writer = SummaryWriter(log_dir=str(test_log_dir))
    try:
        test_pbar = tqdm(test_loader, desc="Testing", leave=False)
        with torch.no_grad():
            for i, (image, mask) in enumerate(test_pbar):
                image, mask = image.to(device), mask.to(device)
                pred_bin = (torch.sigmoid(model(image)) > 0.5).float()

                # batch_size=1, so summing over every element scores one image.
                intersection = (pred_bin * mask).sum()
                total = pred_bin.sum() + mask.sum()
                dice = (2. * intersection + 1e-7) / (total + 1e-7)
                iou = (intersection + 1e-7) / (total - intersection + 1e-7)

                dice_scores.append(dice.item())
                iou_scores.append(iou.item())
                test_pbar.set_postfix(dice=dice.item(), iou=iou.item())

                if i < 5:  # log the first 5 test images/masks/predictions
                    writer.add_image(f"test/image_{i}", to_display(image).squeeze(0), global_step=0)
                    writer.add_image(f"test/mask_{i}", mask.squeeze(0), global_step=0)
                    writer.add_image(f"test/prediction_{i}", pred_bin.squeeze(0), global_step=0)

        if not dice_scores:  # empty test set
            print("Warning: No data in test set or dice_scores list is empty.")
            mean_dice, mean_iou = 0.0, 0.0
        else:
            mean_dice = float(np.mean(dice_scores))
            mean_iou = float(np.mean(iou_scores))

        print(f"Test Dice: {mean_dice:.4f} | Test IoU: {mean_iou:.4f}")
        writer.add_scalar("Test/Dice_Overall", mean_dice, global_step=0)
        writer.add_scalar("Test/IoU_Overall", mean_iou, global_step=0)
    finally:
        writer.close()
    print("Testing finished.")
    return mean_dice, mean_iou

if __name__ == "__main__":
    # In a notebook __name__ == "__main__", so running this cell evaluates the
    # saved checkpoint right away. Run the training cell first if no
    # checkpoint exists for CONFIG["version_name"] yet.
    test()