<a href="https://colab.research.google.com/github/sajidcsecu/radioGenomic/blob/main/3DGPU(SS_3DCapsNet).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# This is the Code for the Segmentation on Rider Dataset (LUNG1). The Code is worked on the 3D volume over GPU using self-supervised 3D Capsule Net. The balanced sampler, preprocessed data (uniform volume spacing and clipping [-1000, 700]) and the strong augmentation is used in the code...

# (1) Import Required Libraries

In [1]:
!pip install SimpleITK
!pip install pydicom===2.4.3
!pip install pydicom-seg
!pip install numpy==1.23.5
!pip install monai
!pip install torch==1.13.1
!pip install nibabel>=5.0.0

Collecting numpy==1.23.5
  Using cached numpy-1.23.5.tar.gz (10.7 MB)
  Installing build dependencies ... [?25l[?25hdone
  [1;31merror[0m: [1msubprocess-exited-with-error[0m
  
  [31m×[0m [32mGetting requirements to build wheel[0m did not run successfully.
  [31m│[0m exit code: [1;36m1[0m
  [31m╰─>[0m See above for output.
  
  [1;35mnote[0m: This error originates from a subprocess, and is likely not a problem with pip.
  Getting requirements to build wheel ... [?25l[?25herror
[1;31merror[0m: [1msubprocess-exited-with-error[0m

[31m×[0m [32mGetting requirements to build wheel[0m did not run successfully.
[31m│[0m exit code: [1;36m1[0m
[31m╰─>[0m See above for output.

[1;35mnote[0m: This error originates from a subprocess, and is likely not a problem with pip.
Collecting monai
  Downloading monai-1.5.0-py3-none-any.whl.metadata (13 kB)
Collecting torch<2.7.0,>=2.4.1 (from monai)
  Downloading torch-2.6.0-cp312-cp312-manylinux1_x86_64.whl.metadata (28


# (2) Import required Libraries

In [2]:
import os
import time
import csv
import shutil
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import lr_scheduler
from monai.networks.nets import UNet  # keep available if needed elsewhere
from monai.networks.layers import Norm
from glob import glob
from monai.transforms import (
            Compose, LoadImaged, EnsureChannelFirstD, Spacingd, Orientationd,
            ScaleIntensityRanged, CropForegroundd, Resized, ToTensord,
            RandFlipd, RandAffined, RandGaussianNoised, RandScaleIntensityd
        )
from monai.data import Dataset, DataLoader
from sklearn.metrics import jaccard_score, f1_score, recall_score, precision_score, accuracy_score
import matplotlib.pyplot as plt
from monai.metrics import DiceMetric

# (3) Mount Google Drive

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## (4). Loss Function

In [4]:
class DiceBCELoss3D(nn.Module):
    def __init__(self, smooth=1e-6, epsilon=1e-8):
        super().__init__()
        self.smooth = smooth
        self.epsilon = epsilon
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, preds, targets):
        preds = preds.flatten()
        targets = targets.flatten()
        preds_sigmoid = torch.sigmoid(preds)
        intersection = (preds_sigmoid * targets).sum()
        dice_loss = 1 - (2. * intersection + self.smooth) / (
            preds_sigmoid.sum() + targets.sum() + self.smooth + self.epsilon)
        bce_loss = self.bce(preds, targets)
        return dice_loss + bce_loss

# (5). Early Stopping

In [5]:
class EarlyStopping:
    def __init__(self, patience=10, verbose=True, min_delta=0, path='checkpoint.pt',
                 start_val_loss_min=None, start_patience_counter=0):
        self.patience = patience
        self.verbose = verbose
        self.min_delta = min_delta
        self.path = path
        self.val_loss_min = start_val_loss_min if start_val_loss_min is not None else np.inf
        self.counter = start_patience_counter
        self.early_stop = False

    def __call__(self, val_loss, model, epoch=None, optimizer=None):
        improved = False
        if val_loss < self.val_loss_min - self.min_delta:
            self.val_loss_min = val_loss
            self.counter = 0
            improved = True
            if self.verbose:
                print(f"✅ Validation loss improved. Saving model...")
        else:
            self.counter += 1
            if self.verbose:
                print(f"⏳ EarlyStopping counter: {self.counter} out of {self.patience}")

        # Always save a full checkpoint
        self.save_checkpoint(model, epoch, optimizer)

        if self.counter >= self.patience:
            self.early_stop = True

        return self.early_stop

    def save_checkpoint(self, model, epoch=None, optimizer=None):
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict() if optimizer else None,
            'val_loss': self.val_loss_min,
            'patience_counter': self.counter
        }
        torch.save(checkpoint, self.path)

## (6)Squash Function

In [6]:
def squash(s, dim=-1, eps=1e-8):
    norm = torch.norm(s, dim=dim, keepdim=True)
    scale = (norm ** 2) / (1.0 + norm ** 2)
    return scale * s / (norm + eps)

# (7). Routing By Agreement

In [7]:
class CapsuleLayer3D(nn.Module):
    """
    Capsule layer with dynamic routing.

    Args:
        num_input_caps (int): Number of input capsules (I).
        dim_input_caps (int): Dimensionality of each input capsule (D).
        num_output_caps (int): Number of output capsules (O).
        dim_output_caps (int): Dimensionality of each output capsule (H).
        routing_iters (int): Number of routing iterations.

    Inputs:
        u: Tensor of shape [B, I, D] (input capsules).

    Outputs:
        v: Tensor of shape [B, O, H] (output capsules after routing).
        c: Tensor of shape [B, I, O] (final coupling coefficients).
    """
    def __init__(self, num_input_caps, dim_input_caps,
                 num_output_caps, dim_output_caps,
                 routing_iters=3):
        super().__init__()
        self.num_input_caps = num_input_caps
        self.dim_input_caps = dim_input_caps
        self.num_output_caps = num_output_caps
        self.dim_output_caps = dim_output_caps
        self.routing_iters = routing_iters

        # Weight matrix W: transforms each input capsule into each output capsule space
        # Shape: [I, O, H, D]
        self.W = nn.Parameter(
            0.01 * torch.randn(num_input_caps, num_output_caps, dim_output_caps, dim_input_caps)
        )

    def forward(self, u):
        """
        Forward pass through the capsule layer.

        Args:
            u: [B, I, D] input capsules.

        Returns:
            v: [B, O, H] output capsules.
            c: [B, I, O] coupling coefficients.
        """
        B, I, D = u.shape
        O, H = self.num_output_caps, self.dim_output_caps

        # (1) Transform input capsules into output capsule space
        # Expand u to [B, I, O, D] so it can interact with W
        u_expand = u.unsqueeze(2).expand(-1, -1, O, -1)   # [B, I, O, D]
        # Predicted capsules u_hat = u · W
        u_hat = torch.einsum("biod,iodh->bioh", u_expand, self.W)  # [B, I, O, H]

        # (2) Routing logits (initially zeros)
        b = u.new_zeros(B, I, O)  # [B, I, O]

        # (3) Dynamic routing
        for r in range(self.routing_iters):
            # Coupling coefficients (softmax over output capsules)
            c = F.softmax(b, dim=-1)  # [B, I, O]

            # Weighted sum of predictions
            s = (c.unsqueeze(-1) * u_hat).sum(dim=1)  # [B, O, H]

            # Apply squash non-linearity
            v = squash(s, dim=-1)

            if r < self.routing_iters - 1:
                # Update routing logits with agreement
                b = b + (u_hat * v.unsqueeze(1)).sum(dim=-1)  # [B, I, O]

        return v, c

# (8). Converting CT Features into primary capsule

In [8]:
class PrimaryCaps3DStem(nn.Module):
    """
    Conv3D stem that produces primary capsules from an input patch.
    - in_channels: 1 (CT)
    - conv_channels: internal feature width
    - num_capsules: how many capsule types per spatial location
    - capsule_dim: dimension of each primary capsule vector
    - downsample: number of stride-2 convs to reduce spatial resolution before forming capsules
    """
    def __init__(self, in_channels=1, conv_channels=32, num_capsules=4, capsule_dim=8, downsample=2):
        super().__init__()
        layers = []
        c = conv_channels
        layers.append(nn.Conv3d(in_channels, c, kernel_size=3, padding=1, stride=1))
        layers.append(nn.InstanceNorm3d(c)); layers.append(nn.ReLU(inplace=True))
        for _ in range(downsample):
            layers.append(nn.Conv3d(c, c, kernel_size=3, padding=1, stride=2))
            layers.append(nn.InstanceNorm3d(c)); layers.append(nn.ReLU(inplace=True))
        # additional conv(s)
        layers.append(nn.Conv3d(c, c, kernel_size=3, padding=1, stride=1))
        layers.append(nn.InstanceNorm3d(c)); layers.append(nn.ReLU(inplace=True))
        self.stem = nn.Sequential(*layers)

        # final conv produces num_capsules * capsule_dim channels
        self.to_caps = nn.Conv3d(c, num_capsules * capsule_dim, kernel_size=1)
        self.num_capsules = num_capsules
        self.capsule_dim = capsule_dim

    def forward(self, x):
        # x: [B, 1, D, H, W]
        f = self.stem(x)  # f: [B, C, D', H', W']  <-- return this for reconstruction
        caps = self.to_caps(f)  # [B, num_caps*capsdim, D', H', W']
        B, Ch, Dp, Hp, Wp = caps.shape
        caps = caps.view(B, self.num_capsules, self.capsule_dim, Dp, Hp, Wp)
        # permute to [B, num_capsules, D', H', W', cap_dim]
        caps = caps.permute(0, 1, 3, 4, 5, 2).contiguous()
        B, NC, d, h, w, dim = caps.shape
        I = NC * d * h * w
        caps = caps.view(B, I, dim)  # [B, I, capsule_dim]
        spatial_info = (Dp, Hp, Wp, NC) # required to reshape couplings back
        return caps, spatial_info, f     # <-- return f

## (9) Decoder

In [9]:
class SegCaps3D(nn.Module):
    """
    SegCaps-like 3D segmentation model:
      Input patch -> PrimaryCaps3DStem -> CapsuleLayer3D (with W param)
      -> voxelwise decision map -> decoder upsampling -> segmentation logits.
      Also provides an optional reconstruction head for self-supervised pretraining.
    """
    def __init__(self,
                 in_channels=1,
                 conv_channels=32,
                 num_capsules=4,
                 capsule_dim=8,
                 region_caps=2,
                 region_dim=8,
                 routing_iters=3,
                 upsample_mode='trilinear'):
        super().__init__()

        # Stem: extract low-level conv features + primary capsules
        self.stem = PrimaryCaps3DStem(
            in_channels, conv_channels, num_capsules, capsule_dim, downsample=2
        )

        # Capsule layer (lazy init)
        self.capsule_layer = None
        self.region_caps = region_caps
        self.region_dim = region_dim
        self.routing_iters = routing_iters
        self.upsample_mode = upsample_mode

        # Decoder: upsample capsule map to voxel resolution
        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose3d(1, 8, kernel_size=2, stride=2),
            nn.InstanceNorm3d(8), nn.ReLU(inplace=True),
            nn.ConvTranspose3d(8, 8, kernel_size=2, stride=2),
            nn.InstanceNorm3d(8), nn.ReLU(inplace=True),
            nn.Conv3d(8, 1, kernel_size=1)
        )

        # Reconstruction head (optional, for SS pretraining)
        self.recon_head = nn.Sequential(
            nn.Conv3d(conv_channels, conv_channels, kernel_size=3, padding=1),
            nn.InstanceNorm3d(conv_channels), nn.ReLU(inplace=True),
            nn.Conv3d(conv_channels, 1, kernel_size=1)
        )

    def _init_capsule_layer(self, num_input_caps, dim_input_caps):
        """Lazy init capsule layer when input size is known."""
        if self.capsule_layer is None:
            self.capsule_layer = CapsuleLayer3D(
                num_input_caps, dim_input_caps,
                self.region_caps, self.region_dim,
                routing_iters=self.routing_iters
            ).to(next(self.parameters()).device)

    def forward(self, x):
        """
        Args:
            x: input tensor [B, C, D, H, W]
        Returns:
            logits: segmentation logits [B, 1, D, H, W]
            recon:  reconstruction of input [B, 1, D, H, W]
        """
        B = x.shape[0]

        # --- Stem features ---
        primary_caps, spatial_info, feat_map = self.stem(x)
        Dp, Hp, Wp, NC = spatial_info
        I = primary_caps.shape[1]  # number of input capsules
        dim = primary_caps.shape[2]

        # --- Capsules ---
        self._init_capsule_layer(I, dim)
        v, c = self.capsule_layer(primary_caps)

        # Pick tumor capsule (index 0)
        tumor_index = 0
        c_tumor = c[..., tumor_index]  # [B, I]

        # Reshape to decision map
        c_map = c_tumor.view(B, NC, Dp, Hp, Wp)
        c_map_avg = c_map.mean(dim=1, keepdim=True)  # [B,1,Dp,Hp,Wp]

        # --- Decoder upsampling ---
        logits = self.decoder_conv(c_map_avg)  # [B,1,D,H,W]

        # --- Reconstruction head ---
        recon_coarse = self.recon_head(feat_map)  # [B,1,Dp,Hp,Wp]
        recon = F.interpolate(
            recon_coarse, size=x.shape[2:], mode=self.upsample_mode, align_corners=False
        )  # [B,1,D,H,W]

        return logits, recon



# (10) Test

In [10]:
class SegCapsTest:
    def __init__(self, test_result_path: str, metrics_csv: str, device: torch.device):
        self.test_result_path = test_result_path
        self.metrics_csv = metrics_csv
        self.device = device

        os.makedirs(self.test_result_path, exist_ok=True)
        self._init_metrics_csv()

    def _init_metrics_csv(self):
        if not os.path.exists(self.metrics_csv):
            with open(self.metrics_csv, 'w', newline='') as f:
                writer = csv.writer(f)
                writer.writerow(["SampleID", "Jaccard", "F1", "Recall", "Precision", "Accuracy", "Time"])

    def calculate_metrics(self, y_true: np.ndarray, y_pred: np.ndarray):
        y_true = y_true.astype(bool).flatten()
        y_pred = y_pred.astype(bool).flatten()
        return [
            jaccard_score(y_true, y_pred, zero_division=0),
            f1_score(y_true, y_pred, zero_division=0),
            recall_score(y_true, y_pred, zero_division=0),
            precision_score(y_true, y_pred, zero_division=0),
            accuracy_score(y_true, y_pred)
        ]

    def save_result_slices(self, image: np.ndarray, pred_mask: np.ndarray, true_mask: np.ndarray, sample_id: str):
        sample_dir = os.path.join(self.test_result_path, sample_id)
        os.makedirs(sample_dir, exist_ok=True)

        for i in range(image.shape[0]):
            try:
                fig, ax = plt.subplots(1, 3, figsize=(12, 4))
                ax[0].imshow(image[i], cmap='gray')
                ax[0].set_title('Image')

                ax[1].imshow(true_mask[i], cmap='gray')
                ax[1].set_title('Ground Truth')

                ax[2].imshow(pred_mask[i], cmap='gray')
                ax[2].set_title('Prediction')

                for a in ax: a.axis('off')
                plt.tight_layout()
                plt.savefig(os.path.join(sample_dir, f'slice_{i:03d}.png'))
                plt.close()
            except Exception as e:
                print(f"⚠️ Could not save slice {i} for {sample_id}: {e}")

    def append_metrics_to_csv(self, sample_id: str, metrics: list, elapsed_time: float):
        with open(self.metrics_csv, 'a', newline='') as f:
            writer = csv.writer(f)
            writer.writerow([sample_id] + [f"{m:.4f}" for m in metrics] + [f"{elapsed_time:.4f}"])

    def test(self, model: nn.Module, test_loader: DataLoader):
        model.eval()
        total_metrics = np.zeros(5)
        total_times = []

        roi_size = (96, 96, 96)
        sw_batch_size = 1

        with torch.no_grad():
            for batch_idx, batch in enumerate(test_loader):
                image, label = batch["vol"].to(self.device), batch["seg"].to(self.device)
                start_time = time.time()

                pred = sliding_window_inference(
                    inputs=image,
                    roi_size=roi_size,
                    sw_batch_size=sw_batch_size,
                    predictor=model
                )
                pred = torch.sigmoid(pred) > 0.5  # Binary thresholding

                elapsed = time.time() - start_time
                total_times.append(elapsed)

                # Convert to NumPy
                image_np = image[0, 0].cpu().numpy()
                label_np = label[0, 0].cpu().numpy()
                pred_np = pred[0, 0].cpu().numpy()

                # Metrics
                metrics = self.calculate_metrics(label_np, pred_np)
                total_metrics += np.array(metrics)

                # Sample ID
                sample_id = os.path.basename(batch["vol_meta_dict"]["filename_or_obj"][0]).replace(".nii.gz", "")
                self.save_result_slices(image_np, pred_np, label_np, sample_id)
                self.append_metrics_to_csv(sample_id, metrics, elapsed)

        # Print summary
        num_samples = len(test_loader)
        print("\n📊 Average Test Metrics:")
        print(f"Jaccard:  {total_metrics[0]/num_samples:.4f}")
        print(f"F1:       {total_metrics[1]/num_samples:.4f}")
        print(f"Recall:   {total_metrics[2]/num_samples:.4f}")
        print(f"Precision:{total_metrics[3]/num_samples:.4f}")
        print(f"Accuracy: {total_metrics[4]/num_samples:.4f}")
        print(f"⚡ FPS:    {1 / np.mean(total_times):.2f}")

## (11) Training

In [11]:
# ---------- Trainer adapted for SegCaps ----------
class SegCapsTrain:
    def __init__(self, model_file, loss_result_path, lr, num_epochs, device, self_supervised=False, ss_weight=0.1):
        self.model_file = model_file
        self.loss_result_path = loss_result_path
        self.lr = lr
        self.num_epochs = num_epochs
        self.device = device
        self.self_supervised = self_supervised
        self.ss_weight = ss_weight
        self.seeding(42)

    def seeding(self, seed):
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    def epoch_time(self, start_time, end_time):
        elapsed = end_time - start_time
        return int(elapsed / 60), int(elapsed % 60)

    def train_one_epoch(self, model, loader, optimizer, loss_fn):
        model.train()
        epoch_loss = 0.0
        scaler = torch.amp.GradScaler()
        device_type = 'cuda' if self.device.type == 'cuda' else 'cpu'

        for batch in loader:
            inputs, labels = batch["vol"].to(self.device), batch["seg"].to(self.device)
            optimizer.zero_grad()
            with torch.amp.autocast(device_type=device_type):
                logits, recon = model(inputs)  # logits: [B,1,D,H,W]
                seg_loss = loss_fn(logits, labels)
                if self.self_supervised:
                    # reconstruction MSE on voxel intensities (compare input and recon)
                    recon_loss = F.mse_loss(recon, inputs)
                    loss = seg_loss + self.ss_weight * recon_loss
                else:
                    loss = seg_loss
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            epoch_loss += loss.item()
        return epoch_loss / len(loader)

    def evaluate(self, model, loader, loss_fn):
        model.eval()
        epoch_loss = 0.0
        with torch.no_grad():
            for batch in loader:
                inputs, labels = batch["vol"].to(self.device), batch["seg"].to(self.device)
                logits, recon = model(inputs)
                seg_loss = loss_fn(logits, labels)
                if self.self_supervised:
                    recon_loss = F.mse_loss(recon, inputs)
                    loss = seg_loss + self.ss_weight * recon_loss
                else:
                    loss = seg_loss
                epoch_loss += loss.item()
        return epoch_loss / len(loader)

    def execute(self, train_loader, valid_loader):
        model = SegCaps3D(
            in_channels=1,
            conv_channels=32,
            num_capsules=4,
            capsule_dim=8,
            region_caps=2,
            region_dim=8,
            routing_iters=3
        ).to(self.device)

        optimizer = torch.optim.AdamW(model.parameters(), lr=self.lr, weight_decay=1e-5)
        scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)
        loss_fn = DiceBCELoss3D()

        start_epoch = 1
        start_val_loss_min = None
        start_patience_counter = 0
        history = {"train_loss": [], "valid_loss": []}

        # resume checkpoint if exists
        if os.path.exists(self.model_file):
            checkpoint = torch.load(self.model_file, map_location=self.device)
            model.load_state_dict(checkpoint['model_state_dict'], strict=False)
            if checkpoint.get('optimizer_state_dict'):
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            start_epoch = checkpoint.get('epoch', 1) + 1
            start_val_loss_min = checkpoint.get('val_loss', None)
            start_patience_counter = checkpoint.get('patience_counter', 0)

        if os.path.exists(self.loss_result_path):
            with open(self.loss_result_path, 'r') as f:
                reader = csv.reader(f)
                next(reader)
                rows = list(reader)
                if rows:
                    last_epoch = int(rows[-1][0])
                    start_epoch = last_epoch + 1
                    history['train_loss'] = [float(r[1]) for r in rows]
                    history['valid_loss'] = [float(r[2]) for r in rows]
                    if start_val_loss_min is None:
                        start_val_loss_min = min(history['valid_loss'])
            backup_path = self.loss_result_path.replace(".csv", "_backup.csv")
            shutil.copy(self.loss_result_path, backup_path)

        early_stopping = EarlyStopping(
            patience=10,
            min_delta=0.0005,
            path=self.model_file,
            start_val_loss_min=start_val_loss_min,
            start_patience_counter=start_patience_counter
        )

        if not os.path.exists(self.loss_result_path):
            with open(self.loss_result_path, "w", newline="") as f:
                csv.writer(f).writerow(["Epoch", "Train Loss", "Valid Loss"])

        for epoch in range(start_epoch, self.num_epochs + 1):
            start_time = time.time()
            train_loss = self.train_one_epoch(model, train_loader, optimizer, loss_fn)
            valid_loss = self.evaluate(model, valid_loader, loss_fn)
            scheduler.step()
            epoch_mins, epoch_secs = self.epoch_time(start_time, time.time())
            print(f"Epoch {epoch:03d} | Time: {epoch_mins}m {epoch_secs}s | Train: {train_loss:.6f} | Val: {valid_loss:.6f}")

            history['train_loss'].append(train_loss); history['valid_loss'].append(valid_loss)
            with open(self.loss_result_path, "a", newline="") as f:
                csv.writer(f).writerow([epoch, train_loss, valid_loss])

            if early_stopping(valid_loss, model, epoch, optimizer):
                print("🛑 Early stopping triggered.")
                break

            torch.cuda.empty_cache()

## (12) Pipeline

In [None]:
# ---------- Pipeline that mirrors your UnetPipeline but for SegCaps ----------
class SegCapsPipeline:
    def __init__(self, config):
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.setup_paths()
        print("📦 Loading datasets...")
        self.train_loader, self.valid_loader, self.test_loader = self.prepare_loaders()

    def setup_paths(self):
        os.chdir(self.config['target_dir'])
        self.output_dir = os.path.join(".", "results", self.config['output_folder_name'])
        os.makedirs(self.output_dir, exist_ok=True)

        self.loss_result_file = os.path.join(self.output_dir, "train_and_valid_loss_results.csv")
        self.model_file = os.path.join(self.output_dir, "model.pth")
        self.test_metrics_file = os.path.join(self.output_dir, "test_metrics.csv")
        self.test_result_path = os.path.join(self.output_dir, "test_outputs")
        os.makedirs(self.test_result_path, exist_ok=True)

        self.dataset_dir = os.path.join("./datasets", f"Datasets_{self.config['transformation']}")

    def prepare_loaders(self):
        pixdim = (1, 1, 1)
        a_min, a_max = -1000, 700
        spatial_size = (96, 96, 96)

        def get_files(split):
            ct = sorted(glob(os.path.join(self.dataset_dir, split, "ct", "*.nii.gz")))
            seg = sorted(glob(os.path.join(self.dataset_dir, split, "segment", "*.nii.gz")))
            return [{"vol": c, "seg": s} for c, s in zip(ct, seg)]

        train_transforms = Compose([
            LoadImaged(keys=["vol", "seg"]),
            EnsureChannelFirstD(keys=["vol", "seg"]),
            Spacingd(keys=["vol", "seg"], pixdim=pixdim, mode=("bilinear", "nearest")),
            Orientationd(keys=["vol", "seg"], axcodes="RAS"),
            ScaleIntensityRanged(keys=["vol"], a_min=a_min, a_max=a_max, b_min=0.0, b_max=1.0, clip=True),
            CropForegroundd(keys=["vol", "seg"], source_key="vol"),
            Resized(keys=["vol", "seg"], spatial_size=spatial_size),
            RandFlipd(keys=["vol", "seg"], prob=0.5, spatial_axis=0),
            RandFlipd(keys=["vol", "seg"], prob=0.5, spatial_axis=1),
            RandAffined(keys=["vol", "seg"], prob=0.3, rotate_range=(0.1,0.1,0.1), scale_range=(0.1,0.1,0.1), mode=("bilinear","nearest")),
            RandGaussianNoised(keys=["vol"], prob=0.2, mean=0.0, std=0.1),
            RandScaleIntensityd(keys=["vol"], factors=0.1, prob=0.5),
            ToTensord(keys=["vol", "seg"])
        ])

        base_transforms = Compose([
            LoadImaged(keys=["vol", "seg"]),
            EnsureChannelFirstD(keys=["vol", "seg"]),
            Spacingd(keys=["vol", "seg"], pixdim=pixdim, mode=("bilinear", "nearest")),
            Orientationd(keys=["vol", "seg"], axcodes="RAS"),
            ScaleIntensityRanged(keys=["vol"], a_min=a_min, a_max=a_max, b_min=0.0, b_max=1.0, clip=True),
            CropForegroundd(keys=["vol", "seg"], source_key="vol"),
            Resized(keys=["vol", "seg"], spatial_size=spatial_size),
            ToTensord(keys=["vol", "seg"])
        ])

        train_loader = DataLoader(Dataset(get_files("train"), train_transforms), batch_size=self.config['batch_size'], shuffle=True)
        valid_loader = DataLoader(Dataset(get_files("valid"), base_transforms), batch_size=self.config['batch_size'])
        test_loader = DataLoader(Dataset(get_files("test"), base_transforms), batch_size=1)

        return train_loader, valid_loader, test_loader

    def train(self):
        trainer = SegCapsTrain(
            model_file=self.model_file,
            loss_result_path=self.loss_result_file,
            lr=self.config['learning_rate'],
            num_epochs=self.config['num_epochs'],
            device=self.device,
            self_supervised=self.config.get('self_supervised_pretrain', False),
            ss_weight=self.config.get('ss_weight', 0.1)
        )
        trainer.execute(self.train_loader, self.valid_loader)

    def test(self):
        # Load model
        model = SegCaps3D(
            in_channels=1,
            conv_channels=32,
            num_capsules=4,
            capsule_dim=8,
            region_caps=2,
            region_dim=8,
            routing_iters=3
        ).to(self.device)
        checkpoint = torch.load(self.model_file, map_location=self.device)
        model.load_state_dict(checkpoint['model_state_dict'],strict=False)

        # Run test using SegCapsTest
        tester = SegCapsTest(self.test_result_path, self.test_metrics_file, self.device)
        tester.test(model, self.test_loader)

    def run(self):
        self.train()
        self.test()

# ---- Example main (adapt config as you had before) ----
def main():
    config = {
        'target_dir': "/content/drive/MyDrive/PhDwork/Segmentation",
        'output_folder_name': "Results_SegCaps_Augmented",
        'transformation': "OriginalCT_Nifti_Empty_NonEmpty_slices_In_Train",
        'batch_size': 2,
        'num_epochs': 100,
        'learning_rate': 1e-4,
        'self_supervised_pretrain': False,
        'ss_weight': 0.1
    }
    pipeline = SegCapsPipeline(config)
    pipeline.run()

if __name__ == "__main__":
    main()

📦 Loading datasets...
Epoch 052 | Time: 80m 2s | Train: 1.058874 | Val: 1.058432
✅ Validation loss improved. Saving model...
Epoch 053 | Time: 77m 23s | Train: 1.056699 | Val: 1.055222
✅ Validation loss improved. Saving model...
Epoch 054 | Time: 77m 46s | Train: 1.053670 | Val: 1.052430
✅ Validation loss improved. Saving model...
Epoch 055 | Time: 78m 11s | Train: 1.051080 | Val: 1.050078
✅ Validation loss improved. Saving model...
Epoch 056 | Time: 78m 22s | Train: 1.048905 | Val: 1.048107
✅ Validation loss improved. Saving model...
Epoch 057 | Time: 78m 48s | Train: 1.047247 | Val: 1.046712
✅ Validation loss improved. Saving model...


#(8) Mask Generation

In [None]:
import os
import torch
import nibabel as nib
import numpy as np
from pathlib import Path
from monai.networks.nets import UNet
from monai.transforms import (
    Compose,
    Resized,
    CopyItemsd,
    Invertd,
    LoadImaged,
    EnsureChannelFirstd,
    Spacingd,
    Orientationd,
    ScaleIntensityRanged,
    CropForegroundd,
    EnsureTyped,
    SaveImaged,
    ToTensord,
)
from monai.data import Dataset, DataLoader, decollate_batch
from monai.inferers import sliding_window_inference
from monai.utils import set_determinism
from monai.networks.layers import Norm
# from monai.transforms.utils import SaveTransform



class UNetInferencePipeline:
    def __init__(self, model_path, input_ct_dir, input_seg_dir, output_dir, device="cuda:0"):
        self.device = device if torch.cuda.is_available() else "cpu"
        self.input_ct_dir = input_ct_dir
        self.input_seg_dir = input_seg_dir
        self.output_dir = output_dir
        self.ct_out_dir = os.path.join(output_dir, "ct")
        self.seg_out_dir = os.path.join(output_dir, "segment")
        os.makedirs(self.ct_out_dir, exist_ok=True)
        os.makedirs(self.seg_out_dir, exist_ok=True)
        self.model_path = model_path
        self.model = self._load_model()
        set_determinism(seed=42)
        self.forward_transforms = self._get_forward_transforms()
        self.inverse_transforms = None
        self.dataloader = self._prepare_dataloader()

    def _load_model(self):
        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model file not found at: {self.model_path}")

        model = UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
            norm=Norm.BATCH
        ).to(self.device)

        state_dict = torch.load(self.model_path, map_location=self.device)
        model.load_state_dict(state_dict.get('model_state_dict', state_dict))

        print(f"✅ Model loaded successfully from {self.model_path}")
        return model



    def _get_forward_transforms(self):
        return Compose([
            LoadImaged(keys=["vol"]),
            EnsureChannelFirstd(keys=["vol"]),
            CopyItemsd(keys=["vol"], names=["vol_meta_dict"]),
            Spacingd(keys=["vol"], pixdim=(1.0, 1.0, 1.0), mode="bilinear"),
            Orientationd(keys=["vol"], axcodes="RAS"),
            ScaleIntensityRanged(keys=["vol"], a_min=-1000, a_max=700, b_min=0.0, b_max=1.0, clip=True),
            CropForegroundd(keys=["vol"], source_key="vol"),
            Resized(keys=["vol"], spatial_size=(96, 96, 96)),
            EnsureTyped(keys=["vol"]),
        ])

    def _get_inverse_transforms(self):
        return Compose([
            Invertd(
                keys=["seg"],
                transform=self.forward_transforms,
                orig_keys=["vol"],
                meta_keys=["vol_meta_dict"],
                nearest_interp=True,
                to_tensor=False,
            ),
            EnsureTyped(keys=["seg"])
        ])

    def _prepare_dataloader(self):
        data = []
        for f in os.listdir(self.input_ct_dir):
            if f.endswith(('.nii', '.nii.gz')):
                ct_path = os.path.join(self.input_ct_dir, f)
                data.append({"vol": ct_path})
        print(f"🔍 Found {len(data)} NIfTI files for inference.")
        return DataLoader(Dataset(data=data, transform=self.forward_transforms), batch_size=1, num_workers=0)

    def infer(self):
        self.model.eval()
        with torch.no_grad():
            for i, batch in enumerate(self.dataloader):
                batch = decollate_batch(batch)[0]
                vol_meta = batch["vol_meta_dict"]
                ct = batch["vol"]

                if ct.dim() == 4:
                    ct = ct.unsqueeze(0)
                ct = ct.to(self.device)

                filename = os.path.basename(vol_meta.meta["filename_or_obj"])
                orig_vol = nib.load(vol_meta.meta["filename_or_obj"]).get_fdata()
                print(f"🔍 Inference on [{i+1}] {filename} | shape = {ct.shape}")
                print(f"🔍 Original volume shape = {orig_vol.shape}")
                pred = self.model(ct)
                pred = (torch.sigmoid(pred) > 0.5).float()

                print(f"✅ Predicted mask shape: {pred.shape}")

                batch["seg"] = pred.cpu().squeeze(0)
                print(f"✅ Batch shape: {batch['seg'].shape}")

                if self.inverse_transforms is None:
                    self.inverse_transforms = self._get_inverse_transforms()

                inverted = self.inverse_transforms(batch)
                inv_seg = inverted["seg"].squeeze(0).numpy()
                inv_seg = (inv_seg > 0.5).astype(np.uint8)
                print(f"✅ Inverted mask shape: {inv_seg.shape}")

                self._save_nifti(inv_seg, vol_meta, self.seg_out_dir, filename, is_segmentation=True)


    def _save_nifti(self, array, meta_tensor, out_dir, filename, is_segmentation=False):
        os.makedirs(out_dir, exist_ok=True)
        affine = meta_tensor.meta.get("original_affine", meta_tensor.meta.get("affine", np.eye(4)))
        dtype = np.uint8 if is_segmentation else np.float32
        nib_img = nib.Nifti1Image(array.astype(dtype), affine)
        nib.save(nib_img, os.path.join(out_dir, filename))
        print(f"✅ Saved: {os.path.join(out_dir, filename)}")


if __name__ == "__main__":
    ROOT_DIR = "/content/drive/MyDrive/PhDwork/Segmentation"
    MODEL_PATH = os.path.join(ROOT_DIR, "results", "Results_MONAI_Augmented", "model.pth")
    INPUT_CT_FOLDER = os.path.join(ROOT_DIR, "datasets", "Datasets_OriginalCT_Nifti_Empty_NonEmpty_slices_In_Train", "Lung3", "ct")
    INPUT_SEG_FOLDER = os.path.join(ROOT_DIR, "datasets", "Datasets_OriginalCT_Nifti_Empty_NonEmpty_slices_In_Train", "Lung3", "segment")
    OUTPUT_FOLDER = os.path.join(ROOT_DIR, "datasets", "Datasets_OriginalCT_Nifti_Empty_NonEmpty_slices_In_Train", "Lung3_Predicted")

    os.makedirs(OUTPUT_FOLDER, exist_ok=True)
    os.chdir(ROOT_DIR)

    try:
        pipeline = UNetInferencePipeline(MODEL_PATH, INPUT_CT_FOLDER, INPUT_SEG_FOLDER, OUTPUT_FOLDER)
        pipeline.infer()
        print("🎉 Inference completed successfully for all patients!")
    except FileNotFoundError as e:
        print(f"Error: {e}")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")


✅ Model loaded successfully from /content/drive/MyDrive/PhDwork/Segmentation/results/Results_MONAI_Augmented/model.pth
🔍 Found 89 NIfTI files for inference.
🔍 Inference on [1] LUNG3-01.nii.gz | shape = torch.Size([1, 1, 96, 96, 96])
🔍 Original volume shape = (59, 512, 512)
✅ Predicted mask shape: torch.Size([1, 1, 96, 96, 96])
✅ Batch shape: torch.Size([1, 96, 96, 96])
✅ Inverted mask shape: (59, 512, 512)
✅ Saved: /content/drive/MyDrive/PhDwork/Segmentation/datasets/Datasets_OriginalCT_Nifti_Empty_NonEmpty_slices_In_Train/Lung3_Predicted/segment/LUNG3-01.nii.gz
🔍 Inference on [2] LUNG3-02.nii.gz | shape = torch.Size([1, 1, 96, 96, 96])
🔍 Original volume shape = (57, 512, 512)
✅ Predicted mask shape: torch.Size([1, 1, 96, 96, 96])
✅ Batch shape: torch.Size([1, 96, 96, 96])
✅ Inverted mask shape: (57, 512, 512)
✅ Saved: /content/drive/MyDrive/PhDwork/Segmentation/datasets/Datasets_OriginalCT_Nifti_Empty_NonEmpty_slices_In_Train/Lung3_Predicted/segment/LUNG3-02.nii.gz
🔍 Inference on [3] 

In [None]:
if __name__ == "__main__":
    ROOT_DIR = "/content/drive/MyDrive/PhDwork/Segmentation"
    MODEL_PATH = os.path.join(ROOT_DIR, "results", "Results_MONAI_Augmented", "model.pth")
    INPUT_CT_FOLDER = os.path.join(ROOT_DIR, "datasets", "Datasets_OriginalCT_Nifti_Empty_NonEmpty_slices_In_Train", "test", "ct")
    INPUT_SEG_FOLDER = os.path.join(ROOT_DIR, "datasets", "Datasets_OriginalCT_Nifti_Empty_NonEmpty_slices_In_Train", "test", "segment")
    OUTPUT_FOLDER = os.path.join(ROOT_DIR, "datasets", "Datasets_OriginalCT_Nifti_Empty_NonEmpty_slices_In_Train", "test_Predicted")

    os.makedirs(OUTPUT_FOLDER, exist_ok=True)
    os.chdir(ROOT_DIR)

    try:
        pipeline = UNetInferencePipeline(MODEL_PATH, INPUT_CT_FOLDER, INPUT_SEG_FOLDER, OUTPUT_FOLDER)
        pipeline.infer()
        print("🎉 Inference completed successfully for all patients!")
    except FileNotFoundError as e:
        print(f"Error: {e}")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")

✅ Model loaded successfully from /content/drive/MyDrive/PhDwork/Segmentation/results/Results_MONAI_Augmented/model.pth
🔍 Found 38 NIfTI files for inference.
🔍 Inference on [1] LUNG1-001.nii.gz | shape = torch.Size([1, 1, 96, 96, 96])
🔍 Original volume shape = (134, 512, 512)
✅ Predicted mask shape: torch.Size([1, 1, 96, 96, 96])
✅ Batch shape: torch.Size([1, 96, 96, 96])
✅ Inverted mask shape: (134, 512, 512)
✅ Saved: /content/drive/MyDrive/PhDwork/Segmentation/datasets/Datasets_OriginalCT_Nifti_Empty_NonEmpty_slices_In_Train/test_Predicted/segment/LUNG1-001.nii.gz
🔍 Inference on [2] LUNG1-025.nii.gz | shape = torch.Size([1, 1, 96, 96, 96])
🔍 Original volume shape = (106, 512, 512)
✅ Predicted mask shape: torch.Size([1, 1, 96, 96, 96])
✅ Batch shape: torch.Size([1, 96, 96, 96])
✅ Inverted mask shape: (106, 512, 512)
✅ Saved: /content/drive/MyDrive/PhDwork/Segmentation/datasets/Datasets_OriginalCT_Nifti_Empty_NonEmpty_slices_In_Train/test_Predicted/segment/LUNG1-025.nii.gz
🔍 Inference o

In [None]:
if __name__ == "__main__":
    ROOT_DIR = "/content/drive/MyDrive/PhDwork/Segmentation"
    MODEL_PATH = os.path.join(ROOT_DIR, "results", "Results_MONAI_Augmented", "model.pth")
    INPUT_CT_FOLDER = os.path.join(ROOT_DIR, "datasets", "Datasets_OriginalCT_Nifti_Empty_NonEmpty_slices_In_Train", "valid", "ct")
    INPUT_SEG_FOLDER = os.path.join(ROOT_DIR, "datasets", "Datasets_OriginalCT_Nifti_Empty_NonEmpty_slices_In_Train", "valid", "segment")
    OUTPUT_FOLDER = os.path.join(ROOT_DIR, "datasets", "Datasets_OriginalCT_Nifti_Empty_NonEmpty_slices_In_Train", "valid_Predicted")

    os.makedirs(OUTPUT_FOLDER, exist_ok=True)
    os.chdir(ROOT_DIR)

    try:
        pipeline = UNetInferencePipeline(MODEL_PATH, INPUT_CT_FOLDER, INPUT_SEG_FOLDER, OUTPUT_FOLDER)
        pipeline.infer()
        print("🎉 Inference completed successfully for all patients!")
    except FileNotFoundError as e:
        print(f"Error: {e}")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")

✅ Model loaded successfully from /content/drive/MyDrive/PhDwork/Segmentation/results/Results_MONAI_Augmented/model.pth
🔍 Found 43 NIfTI files for inference.
🔍 Inference on [1] LUNG1-010.nii.gz | shape = torch.Size([1, 1, 96, 96, 96])
🔍 Original volume shape = (91, 512, 512)
✅ Predicted mask shape: torch.Size([1, 1, 96, 96, 96])
✅ Batch shape: torch.Size([1, 96, 96, 96])
✅ Inverted mask shape: (91, 512, 512)
✅ Saved: /content/drive/MyDrive/PhDwork/Segmentation/datasets/Datasets_OriginalCT_Nifti_Empty_NonEmpty_slices_In_Train/valid_Predicted/segment/LUNG1-010.nii.gz
🔍 Inference on [2] LUNG1-031.nii.gz | shape = torch.Size([1, 1, 96, 96, 96])
🔍 Original volume shape = (153, 512, 512)
✅ Predicted mask shape: torch.Size([1, 1, 96, 96, 96])
✅ Batch shape: torch.Size([1, 96, 96, 96])
✅ Inverted mask shape: (153, 512, 512)
✅ Saved: /content/drive/MyDrive/PhDwork/Segmentation/datasets/Datasets_OriginalCT_Nifti_Empty_NonEmpty_slices_In_Train/valid_Predicted/segment/LUNG1-031.nii.gz
🔍 Inference o

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from typing import List
import os
import csv


class LossPlotter:
    def __init__(self, csv_path: str):
        self.csv_path = Path(csv_path)
        self.data = self._load_data()

    def _load_data(self):
        if not self.csv_path.exists():
            raise FileNotFoundError(f"CSV file not found: {self.csv_path}")
        df = pd.read_csv(self.csv_path, index_col=0)  # Read row labels as index
        return df  # Make rows into columns

    def plot(self, title: str = "Training and Validation Loss", save_path= None):
        plt.figure(figsize=(8, 5))
        plt.plot(self.data.index, self.data['Train Loss'], label='Train Loss', color='blue')
        plt.plot(self.data.index, self.data['Valid Loss'], label='Valid Loss', color='orange')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title(title)
        plt.legend()
        plt.grid(True)
        plt.tight_layout()

        if save_path:
            save_path = Path(save_path)
            save_path.parent.mkdir(parents=True, exist_ok=True)
            plt.savefig(save_path, format='pdf')
            print(f"[INFO] Loss plot saved to {save_path}")
        else:
            plt.show()

        plt.close()

if __name__ == "__main__":
    target_dir = "/content/drive/MyDrive/PhDwork/Segmentation"
    os.chdir(target_dir)
    loss_result_file = os.path.join(".","results",f"Results_PreProcessedCT_Fifty_Fifty_DiceLoss_And_Strong_Augmentation","train_and_valid_loss_results.csv")
    plotter = LossPlotter(loss_result_file)
    plotter.plot()


In [None]:
import h5py
os.chdir("/content/drive/MyDrive/PhDwork/Segmentation")
print(f"📁 Current Directory: {os.getcwd()}")
with h5py.File('./datasets/Datasets_PreprocessedCT_clipping_uniformSpacing_With_Empty_NonEmpty_slices_In_Train/train_dataset.hdf5', 'r') as f:
    print(list(f.keys()))