In [1]:
import os
import numpy as np
import nibabel as nib
import pandas as pd
from glob import glob
from tqdm import tqdm
from scipy.ndimage import zoom, rotate
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt


In [2]:
ROOT_DIR = "/home/raghuram/ARPL/MR-Image-Reconstruction-Using-Deep-Learning/Task02_Heart" 


In [3]:
def load_nifti(path):
    """
    Load a .nii.gz file using nibabel and return:
      - data (as a np.ndarray)
      - affine (4x4 transform matrix)
      - header (NIfTI header)
    """
    nifti = nib.load(path)
    data = nifti.get_fdata(dtype=np.float32)  # or np.float64 if you prefer
    affine = nifti.affine
    header = nifti.header
    return data, affine, header


In [4]:
def get_spacing_from_affine(affine):
    """
    Given a 4x4 affine matrix from nibabel, extract voxel spacing
    along each dimension: (z, y, x).
    """
    # Typically, the diagonal or row vectors of the affine (in absolute value)
    # correspond to voxel size. A simple approach:
    sx = np.sqrt(affine[0, 0]**2 + affine[0, 1]**2 + affine[0, 2]**2)
    sy = np.sqrt(affine[1, 0]**2 + affine[1, 1]**2 + affine[1, 2]**2)
    sz = np.sqrt(affine[2, 0]**2 + affine[2, 1]**2 + affine[2, 2]**2)
    return np.array([sz, sy, sx])


In [5]:
# Jupyter Cell 4

def get_spacing_from_affine(affine):
    """
    Given a 4x4 affine matrix from nibabel, extract the voxel spacing
    along each dimension (z, y, x).
    """
    sx = np.sqrt(affine[0, 0]**2 + affine[0, 1]**2 + affine[0, 2]**2)
    sy = np.sqrt(affine[1, 0]**2 + affine[1, 1]**2 + affine[1, 2]**2)
    sz = np.sqrt(affine[2, 0]**2 + affine[2, 1]**2 + affine[2, 2]**2)
    return np.array([sz, sy, sx])


In [6]:
# Jupyter Cell 5

def resample_volume(volume, current_spacing, new_spacing, is_label=False, order=1):
    """
    Resample a 3D volume (or label) to new_spacing using scipy.ndimage.zoom.
      - volume: np.ndarray of shape (z, y, x).
      - current_spacing: array-like, e.g. [sz, sy, sx].
      - new_spacing: array-like, e.g. [nz, ny, nx].
      - is_label: If True, use nearest neighbor for labels (order=0).
      - order: interpolation order for images (1=linear by default).

    Returns the resampled volume as np.ndarray.
    """
    if is_label:
        order = 0  # nearest-neighbor for segmentations

    zoom_factors = current_spacing / new_spacing
    resampled = zoom(volume, zoom=zoom_factors, order=order)
    return resampled


In [7]:
# Jupyter Cell 6

def min_max_scale_intensity(img, min_val=-57, max_val=164, clip=True):
    """
    Scale intensity to [0, 1] range, assuming raw intensities
    lie roughly in [min_val, max_val]. Optionally clip to [0,1].
    """
    img = img.astype(np.float32)
    img = (img - min_val) / (max_val - min_val)
    if clip:
        img = np.clip(img, 0.0, 1.0)
    return img


In [8]:
# Jupyter Cell 8

def random_rotate90_3d(img, prob=0.5, max_k=3):
    """
    Randomly rotate the 3D image by 90, 180, or 270 degrees (k=1..max_k)
    around a random plane of axes, with probability prob.
    """
    if np.random.rand() < prob:
        k = np.random.randint(1, max_k + 1)  # 1, 2, or 3
        axis_pairs = [(0,1), (1,2), (0,2)]    # e.g., rotate in XY, YZ, or XZ plane
        axes = axis_pairs[np.random.randint(len(axis_pairs))]
        img = np.rot90(img, k=k, axes=axes)
    return img


In [9]:
# Jupyter Cell 9

def random_zoom_3d(img, prob=0.2, min_zoom=0.9, max_zoom=1.1, order=1, is_label=False):
    """
    Randomly zoom a 3D volume in/out with a factor in [min_zoom, max_zoom],
    with probability prob. If is_label=True, uses nearest neighbor.
    """
    if np.random.rand() < prob:
        zf = np.random.uniform(min_zoom, max_zoom, size=3)
        if is_label:
            order = 0  # nearest for labels
        img = zoom(img, zoom=zf, order=order)
    return img


In [10]:
# Jupyter Cell 10

# 1) Gather file paths
images = sorted(glob(os.path.join(ROOT_DIR, "imagesTr", "*.nii.gz")))
labels = sorted(glob(os.path.join(ROOT_DIR, "labelsTr", "*.nii.gz")))

# 2) Set a desired target spacing (z, y, x)
target_spacing = np.array([1.25, 1.25, 1.25])

# 3) Prepare a list for stats
stats_list = []

# 4) Loop over all images/labels
for idx, (img_path, lbl_path) in enumerate(tqdm(zip(images, labels), 
                                                total=len(images),
                                                desc="Processing")):
    # --- Load
    img, aff_img, _ = load_nifti(img_path)
    lbl, aff_lbl, _ = load_nifti(lbl_path)
    
    # --- Original spacings
    orig_spacing_img = get_spacing_from_affine(aff_img)
    orig_spacing_lbl = get_spacing_from_affine(aff_lbl)  # often the same as img
    
    # --- Resample
    img_rs = resample_volume(img, orig_spacing_img, target_spacing, is_label=False, order=1)
    lbl_rs = resample_volume(lbl, orig_spacing_lbl, target_spacing, is_label=True)
    
    # --- Intensity normalization (image only)
    img_scaled = min_max_scale_intensity(img_rs, min_val=-57, max_val=164, clip=True)
    
    # --- Data augmentation
    # Must apply the same random transforms to both image & label

    # Random flip
    if np.random.rand() < 0.5:
        flip_axis = np.random.choice([0,1,2])
        img_scaled = np.flip(img_scaled, axis=flip_axis)
        lbl_rs = np.flip(lbl_rs, axis=flip_axis)

    # Random rotate 90
    if np.random.rand() < 0.5:
        k = np.random.randint(1, 4)  # 1, 2, or 3
        axis_pairs = [(0,1), (1,2), (0,2)]
        ax = axis_pairs[np.random.randint(len(axis_pairs))]
        img_scaled = np.rot90(img_scaled, k=k, axes=ax)
        lbl_rs = np.rot90(lbl_rs, k=k, axes=ax)

    # Random zoom
    if np.random.rand() < 0.2:
        zf = np.random.uniform(0.9, 1.1, size=3)
        img_scaled = zoom(img_scaled, zoom=zf, order=1)
        lbl_rs = zoom(lbl_rs, zoom=zf, order=0)

    # --- Collect stats
    img_min, img_max = float(img_scaled.min()), float(img_scaled.max())
    lbl_min, lbl_max = float(lbl_rs.min()), float(lbl_rs.max())

    stats_list.append({
        "index": idx,
        "image_path": img_path,
        "label_path": lbl_path,
        "shape_img": img_scaled.shape,
        "shape_lbl": lbl_rs.shape,
        "img_min": img_min,
        "img_max": img_max,
        "lbl_min": lbl_min,
        "lbl_max": lbl_max
    })

# 5) Convert stats to a DataFrame
df_stats = pd.DataFrame(stats_list)
df_stats.to_csv("heart_dataset_stats.csv", index=False)
df_stats.head(10)


Processing: 100%|██████████| 20/20 [00:21<00:00,  1.08s/it]


Unnamed: 0,index,image_path,label_path,shape_img,shape_lbl,img_min,img_max,lbl_min,lbl_max
0,0,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,"(351, 320, 130)","(351, 320, 130)",0.257919,1.0,0.0,1.0
1,1,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,"(351, 320, 110)","(351, 320, 110)",0.257919,1.0,0.0,1.0
2,2,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,"(337, 291, 112)","(337, 291, 112)",0.257919,1.0,0.0,1.0
3,3,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,"(362, 325, 133)","(362, 325, 133)",0.257919,1.0,0.0,1.0
4,4,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,"(351, 320, 100)","(351, 320, 100)",0.257919,1.0,0.0,1.0
5,5,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,"(351, 320, 120)","(351, 320, 120)",0.257919,1.0,0.0,1.0
6,6,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,"(351, 120, 320)","(351, 120, 320)",0.257919,1.0,0.0,1.0
7,7,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,"(351, 320, 120)","(351, 320, 120)",0.257919,1.0,0.0,1.0
8,8,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,"(351, 320, 90)","(351, 320, 90)",0.257919,1.0,0.0,1.0
9,9,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,"(351, 320, 120)","(351, 320, 120)",0.257919,1.0,0.0,1.0


In [11]:
import os
import nibabel as nib
import numpy as np
import torch
from torch.utils.data import Dataset

def center_crop_3d(volume, crop_shape):
    """
    Center-crop a 3D volume (Z, Y, X) to the desired crop_shape.
    If the volume is smaller than crop_shape along any dimension,
    it will just return the original volume for that dimension.
    """
    z, y, x = volume.shape
    cz, cy, cx = crop_shape

    start_z = max((z - cz) // 2, 0)
    start_y = max((y - cy) // 2, 0)
    start_x = max((x - cx) // 2, 0)

    end_z = start_z + cz if (start_z + cz) <= z else z
    end_y = start_y + cy if (start_y + cy) <= y else y
    end_x = start_x + cx if (start_x + cx) <= x else x

    return volume[start_z:end_z, start_y:end_y, start_x:end_x]


class ModifiedHeartDataset(Dataset):
    """
    Loads 3D NIfTI images and labels, applies a center-crop to reduce size
    (and thus memory usage), and returns them as torch tensors.
    """

    def __init__(self, image_paths, label_paths, crop_shape=(128, 128, 128)):
        """
        Args:
            image_paths (list): list of paths to image .nii(.gz) files.
            label_paths (list): list of paths to label .nii(.gz) files.
            crop_shape  (tuple): desired 3D shape after center-cropping (Z, Y, X).
        """
        self.image_paths = image_paths
        self.label_paths = label_paths
        self.crop_shape = crop_shape

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        lbl_path = self.label_paths[idx]

        # Load volumes (shape: [Z, Y, X])
        img_nifti = nib.load(img_path)
        lbl_nifti = nib.load(lbl_path)

        img = img_nifti.get_fdata(dtype=np.float32)
        lbl = lbl_nifti.get_fdata(dtype=np.float32)

        # Center-crop
        img_cropped = center_crop_3d(img, self.crop_shape)
        lbl_cropped = center_crop_3d(lbl, self.crop_shape)

        # Expand dims to (1, Z, Y, X)
        img_cropped = np.expand_dims(img_cropped, axis=0)
        lbl_cropped = np.expand_dims(lbl_cropped, axis=0)

        # Convert to PyTorch tensors
        img_tensor = torch.from_numpy(img_cropped)
        lbl_tensor = torch.from_numpy(lbl_cropped)

        return img_tensor, lbl_tensor


In [13]:
# # Create dataset
# dataset = HeartDataset(
#     image_paths=images,
#     label_paths=labels,
#     target_spacing=np.array([1.25, 1.25, 1.25]),)

# # Create DataLoader
# loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv3D(nn.Module):
    """
    A helper module that performs a 3D convolution -> ReLU -> 3D convolution -> ReLU
    """
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)


class Down3D(nn.Module):
    """
    Downscaling (via MaxPool3D) followed by a DoubleConv3D.
    """
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)
        self.conv = DoubleConv3D(in_ch, out_ch)

    def forward(self, x):
        x = self.pool(x)
        x = self.conv(x)
        return x


class Up3D(nn.Module):
    """
    Upscaling then a DoubleConv3D. We can do either:
      - Transposed Conv if bilinear=False
      - nn.Upsample (trilinear) if bilinear=True
    """
    def __init__(self, in_ch, out_ch, bilinear=True):
        super().__init__()
        self.bilinear = bilinear
        # If using bilinear upsampling, keep #params lower:
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose3d(in_ch // 2, in_ch // 2, kernel_size=2, stride=2)

        self.conv = DoubleConv3D(in_ch, out_ch)

    def forward(self, x1, x2):
        # x1 is decoder feature, x2 is skip connection from encoder
        x1 = self.up(x1)

        # Match sizes by padding
        diffZ = x2.size()[2] - x1.size()[2]
        diffY = x2.size()[3] - x1.size()[3]
        diffX = x2.size()[4] - x1.size()[4]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2,
                        diffZ // 2, diffZ - diffZ // 2])

        # Concatenate
        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x


class UNet3D(nn.Module):
    """
    A 3D U-Net with 4 down-sampling / up-sampling levels.
    """
    def __init__(self, in_channels, out_channels, base_filters=32, bilinear=True):
        super().__init__()
        self.bilinear = bilinear

        # Encoder
        self.inc = DoubleConv3D(in_channels, base_filters)
        self.down1 = Down3D(base_filters, base_filters * 2)
        self.down2 = Down3D(base_filters * 2, base_filters * 4)
        self.down3 = Down3D(base_filters * 4, base_filters * 8)
        factor = 2 if bilinear else 1
        self.down4 = Down3D(base_filters * 8, base_filters * 16 // factor)

        # Decoder
        self.up1 = Up3D(base_filters * 16, base_filters * 8 // factor, bilinear)
        self.up2 = Up3D(base_filters * 8, base_filters * 4 // factor, bilinear)
        self.up3 = Up3D(base_filters * 4, base_filters * 2 // factor, bilinear)
        self.up4 = Up3D(base_filters * 2, base_filters, bilinear)

        # Final 1x1 convolution
        self.outc = nn.Conv3d(base_filters, out_channels, kernel_size=1)

    def forward(self, x):
        # Encoder
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        # Decoder
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits


In [15]:
def random_crop_3d(img, lbl, crop_size=(64, 128, 128)):
    """
    Randomly crop a 3D patch of size crop_size from img and lbl.
    Assume img, lbl shape: (Z, Y, X).
    """
    z, y, x = img.shape
    cz, cy, cx = crop_size
    
    # pick random start
    z0 = np.random.randint(0, z - cz) if z > cz else 0
    y0 = np.random.randint(0, y - cy) if y > cy else 0
    x0 = np.random.randint(0, x - cx) if x > cx else 0
    
    img_patch = img[z0:z0+cz, y0:y0+cy, x0:x0+cx]
    lbl_patch = lbl[z0:z0+cz, y0:y0+cy, x0:x0+cx]
    return img_patch, lbl_patch


In [16]:
import os
import numpy as np
import nibabel as nib
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from glob import glob

In [17]:
import os
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.amp import autocast, GradScaler
from glob import glob
from tqdm import tqdm

def main():
    # Example root directory containing imagesTr/labelsTr
    images = sorted(glob(os.path.join(ROOT_DIR, "imagesTr", "*.nii.gz")))
    labels = sorted(glob(os.path.join(ROOT_DIR, "labelsTr", "*.nii.gz")))

    # 1) Instantiate dataset & data loader
    train_dataset = ModifiedHeartDataset(
        image_paths=images,
        label_paths=labels,
        crop_shape=(128, 128, 128)
    )
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0)

    # 2) Create the 3D UNet model
    model = UNet3D(in_channels=1, out_channels=2, base_filters=32, bilinear=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # 3) Define optimizer & loss function
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()

    # 4) Mixed-precision (AMP) setup
    #    - Do NOT pass "cuda" as a positional arg. Just use GradScaler() or specify enabled=True.
    scaler = GradScaler()

    # (Optional) directory for checkpoints
    os.makedirs("checkpoints", exist_ok=True)

    # 5) Training loop (example: 100 epochs)
    epochs = 100
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=False)

        for step, (images_batch, labels_batch) in enumerate(pbar):
            images_batch = images_batch.to(device, dtype=torch.float32)
            labels_batch = labels_batch.to(device, dtype=torch.long)
            labels_batch = labels_batch.squeeze(1)  # shape: (B, Z, Y, X)

            optimizer.zero_grad()

            # Use the NEW recommended autocast signature:
            with autocast(device_type="cuda", dtype=torch.float16):
                outputs = model(images_batch)
                loss = criterion(outputs, labels_batch)

            # Scale gradients
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()
            if step % 5 == 0:
                pbar.set_postfix({"loss": f"{loss.item():.4f}"})

        epoch_loss = running_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{epochs} finished. Average loss: {epoch_loss:.4f}")

        # Save checkpoints every 50 epochs
        if (epoch + 1) % 50 == 0:
            ckpt_path = os.path.join("checkpoints", f"checkpoint_epoch_{epoch+1}.pth")
            torch.save(model.state_dict(), ckpt_path)
            print(f"Checkpoint saved to {ckpt_path}")
        torch.cuda.empty_cache()

if __name__ == "__main__":
    main()


                                                                         

Epoch 1/100 finished. Average loss: 0.7787


                                                                         

Epoch 2/100 finished. Average loss: 0.1701


                                                                         

Epoch 3/100 finished. Average loss: 0.1021


                                                                         

Epoch 4/100 finished. Average loss: 0.0920


                                                                         

Epoch 5/100 finished. Average loss: 0.0510


                                                                         

Epoch 6/100 finished. Average loss: 0.0508


                                                                         

Epoch 7/100 finished. Average loss: 0.0451


                                                                         

Epoch 8/100 finished. Average loss: 0.0438


                                                                         

Epoch 9/100 finished. Average loss: 0.0444


                                                                          

Epoch 10/100 finished. Average loss: 0.0402


                                                                          

Epoch 11/100 finished. Average loss: 0.0442


                                                                          

Epoch 12/100 finished. Average loss: 0.0450


                                                                          

Epoch 13/100 finished. Average loss: 0.0505


                                                                          

Epoch 14/100 finished. Average loss: 0.0373


                                                                          

Epoch 15/100 finished. Average loss: 0.0390


                                                                          

Epoch 16/100 finished. Average loss: 0.0355


                                                                          

Epoch 17/100 finished. Average loss: 0.0324


                                                                          

Epoch 18/100 finished. Average loss: 0.0270


                                                                          

Epoch 19/100 finished. Average loss: 0.0307


                                                                          

Epoch 20/100 finished. Average loss: 0.0282


                                                                          

Epoch 21/100 finished. Average loss: 0.0294


                                                                          

Epoch 22/100 finished. Average loss: 0.0295


                                                                          

Epoch 23/100 finished. Average loss: 0.0281


                                                                          

Epoch 24/100 finished. Average loss: 0.0249


                                                                          

Epoch 25/100 finished. Average loss: 0.0209


                                                                          

Epoch 26/100 finished. Average loss: 0.0199


                                                                          

Epoch 27/100 finished. Average loss: 0.0168


                                                                          

Epoch 28/100 finished. Average loss: 0.0159


                                                                          

Epoch 29/100 finished. Average loss: 0.0145


                                                                          

Epoch 30/100 finished. Average loss: 0.0132


                                                                          

Epoch 31/100 finished. Average loss: 0.0146


                                                                          

Epoch 32/100 finished. Average loss: 0.0137


                                                                          

Epoch 33/100 finished. Average loss: 0.0138


                                                                          

Epoch 34/100 finished. Average loss: 0.0122


                                                                          

Epoch 35/100 finished. Average loss: 0.0115


                                                                          

Epoch 36/100 finished. Average loss: 0.0110


                                                                          

Epoch 37/100 finished. Average loss: 0.0100


                                                                          

Epoch 38/100 finished. Average loss: 0.0099


                                                                          

Epoch 39/100 finished. Average loss: 0.0105


                                                                          

Epoch 40/100 finished. Average loss: 0.0106


                                                                          

Epoch 41/100 finished. Average loss: 0.0102


                                                                          

Epoch 42/100 finished. Average loss: 0.0098


                                                                          

Epoch 43/100 finished. Average loss: 0.0098


                                                                          

Epoch 44/100 finished. Average loss: 0.0105


                                                                          

Epoch 45/100 finished. Average loss: 0.0097


                                                                          

Epoch 46/100 finished. Average loss: 0.0092


                                                                          

Epoch 47/100 finished. Average loss: 0.0089


                                                                          

Epoch 48/100 finished. Average loss: 0.0088


                                                                          

Epoch 49/100 finished. Average loss: 0.0084


                                                                          

Epoch 50/100 finished. Average loss: 0.0084
Checkpoint saved to checkpoints/checkpoint_epoch_50.pth


                                                                          

Epoch 51/100 finished. Average loss: 0.0096


                                                                          

Epoch 52/100 finished. Average loss: 0.0104


                                                                          

Epoch 53/100 finished. Average loss: 0.0090


                                                                          

Epoch 54/100 finished. Average loss: 0.0090


                                                                          

Epoch 55/100 finished. Average loss: 0.0081


                                                                          

Epoch 56/100 finished. Average loss: 0.0077


                                                                          

Epoch 57/100 finished. Average loss: 0.0081


                                                                          

Epoch 58/100 finished. Average loss: 0.0077


                                                                          

Epoch 59/100 finished. Average loss: 0.0073


                                                                          

Epoch 60/100 finished. Average loss: 0.0070


                                                                          

Epoch 61/100 finished. Average loss: 0.0072


                                                                          

Epoch 62/100 finished. Average loss: 0.0070


                                                                          

Epoch 63/100 finished. Average loss: 0.0066


                                                                          

Epoch 64/100 finished. Average loss: 0.0064


                                                                          

Epoch 65/100 finished. Average loss: 0.0066


                                                                          

Epoch 66/100 finished. Average loss: 0.0064


                                                                          

Epoch 67/100 finished. Average loss: 0.0064


                                                                          

Epoch 68/100 finished. Average loss: 0.0076


                                                                          

Epoch 69/100 finished. Average loss: 0.0083


                                                                          

Epoch 70/100 finished. Average loss: 0.0073


                                                                          

Epoch 71/100 finished. Average loss: 0.0065


                                                                          

Epoch 72/100 finished. Average loss: 0.0061


                                                                          

Epoch 73/100 finished. Average loss: 0.0061


                                                                          

Epoch 74/100 finished. Average loss: 0.0060


                                                                          

Epoch 75/100 finished. Average loss: 0.0061


                                                                          

Epoch 76/100 finished. Average loss: 0.0060


                                                                          

Epoch 77/100 finished. Average loss: 0.0063


                                                                          

Epoch 78/100 finished. Average loss: 0.0069


                                                                          

Epoch 79/100 finished. Average loss: 0.0080


                                                                          

Epoch 80/100 finished. Average loss: 0.0085


                                                                          

Epoch 81/100 finished. Average loss: 0.0081


                                                                          

Epoch 82/100 finished. Average loss: 0.0076


                                                                          

Epoch 83/100 finished. Average loss: 0.0066


                                                                          

Epoch 84/100 finished. Average loss: 0.0067


                                                                          

Epoch 85/100 finished. Average loss: 0.0063


                                                                          

Epoch 86/100 finished. Average loss: 0.0068


                                                                          

Epoch 87/100 finished. Average loss: 0.0066


                                                                          

Epoch 88/100 finished. Average loss: 0.0067


                                                                          

Epoch 89/100 finished. Average loss: 0.0061


                                                                          

Epoch 90/100 finished. Average loss: 0.0060


                                                                          

Epoch 91/100 finished. Average loss: 0.0056


                                                                          

Epoch 92/100 finished. Average loss: 0.0056


                                                                          

Epoch 93/100 finished. Average loss: 0.0058


                                                                          

Epoch 94/100 finished. Average loss: 0.0052


                                                                          

Epoch 95/100 finished. Average loss: 0.0052


                                                                          

Epoch 96/100 finished. Average loss: 0.0051


                                                                          

Epoch 97/100 finished. Average loss: 0.0051


                                                                          

Epoch 98/100 finished. Average loss: 0.0054


                                                                          

Epoch 99/100 finished. Average loss: 0.0060


                                                                           

Epoch 100/100 finished. Average loss: 0.0056
Checkpoint saved to checkpoints/checkpoint_epoch_100.pth




In [18]:
def show_3d_slices(img_3d, title="", figsize=(12, 4)):
    """
    Show 3 slices from a 3D volume:
      1) A slice in the axial plane (z fixed)
      2) A slice in the coronal plane (y fixed)
      3) A slice in the sagittal plane (x fixed)
    img_3d: np.ndarray with shape (Z, Y, X)
    title: optional string for figure title
    """
    z, y, x = img_3d.shape

    # pick the middle indices in each dimension
    z_mid = z // 2
    y_mid = y // 2
    x_mid = x // 2

    fig, axes = plt.subplots(1, 3, figsize=figsize)

    # Axial plane: fix z index
    axes[0].imshow(img_3d[z_mid, :, :], cmap="gray")
    axes[0].set_title(f"Axial (z={z_mid})")

    # Coronal plane: fix y index
    # coronal slice is shape [z, x], so we do img_3d[:, y_mid, :]
    axes[1].imshow(img_3d[:, y_mid, :], cmap="gray")
    axes[1].set_title(f"Coronal (y={y_mid})")

    # Sagittal plane: fix x index
    # sagittal slice is shape [z, y], so we do img_3d[:, :, x_mid]
    # We might transpose so that it doesn't appear rotated, but that's optional
    axes[2].imshow(img_3d[:, :, x_mid].T, cmap="gray", origin="lower")
    axes[2].set_title(f"Sagittal (x={x_mid})")

    fig.suptitle(title, fontsize=16)
    plt.tight_layout()
    plt.show()


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Let's assume you already have:
#   dataset = HeartDataset(...)       # your custom dataset
#   OR a DataLoader: loader = DataLoader(dataset, ...)
# For simplicity, we'll just pull an item directly from the dataset.

# 1) Get the first sample from the dataset
img_tensor, lbl_tensor = dataset[0]  # shape: (1, Z, Y, X) each

# 2) Convert to NumPy (removing the channel dimension)
img_3d = img_tensor.squeeze(0).numpy()  # now shape = (Z, Y, X)

# We'll pick the middle slice (axial view) along Z
z_mid = img_3d.shape[0] // 2
slice_2d = img_3d[z_mid]  # shape = (Y, X)

# 3) Window/Level using percentile-based clipping to improve contrast
p1, p99 = np.percentile(slice_2d, [1, 99])         # 1st & 99th percentile
slice_clipped = np.clip(slice_2d, p1, p99)         # clamp intensities
slice_normalized = (slice_clipped - p1) / (p99 - p1)  # scale to [0..1]

# 4) Display this slice
plt.figure(figsize=(6,6))
plt.imshow(slice_normalized, cmap='gray')
plt.title(f"Middle Slice (z={z_mid}) with Percentile Windowing")
plt.axis('off')
plt.show()


In [None]:
import os
import numpy as np
import nibabel as nib
import torch
from torch.utils.data import Dataset
from scipy.ndimage import zoom
import nibabel as nib
import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm  # Progress tracking

class SlicesDataset(Dataset):
    def __init__(self, image_paths, label_paths=None, axis=0, transform=None):
        self.image_paths = image_paths
        self.label_paths = label_paths
        self.axis = axis
        self.transform = transform
        self.slice_index_mapping = []  # Stores (file_idx, slice_idx)

        print("Indexing slices from NIfTI files...")
        
        # Create index mapping without loading files into memory
        for file_idx, img_path in enumerate(tqdm(self.image_paths, desc="Indexing", unit="file")):
            img_nifti = nib.load(img_path)
            img_shape = img_nifti.shape  # Get shape without loading into memory
            num_slices = img_shape[self.axis]

            for slice_idx in range(num_slices):
                self.slice_index_mapping.append((file_idx, slice_idx))  # Store file & slice index

        print(f"Total slices indexed: {len(self.slice_index_mapping)}")

        # Optional: Cache the last loaded file to avoid redundant loading
        self.last_loaded_file = None
        self.last_loaded_data = None

    def __len__(self):
        return len(self.slice_index_mapping)

    def __getitem__(self, idx):
        file_idx, slice_idx = self.slice_index_mapping[idx]
        img_path = self.image_paths[file_idx]
        label_path = self.label_paths[file_idx] if self.label_paths else None

        # Load image if it's not already cached
        if self.last_loaded_file != img_path:
            img_nifti = nib.load(img_path)
            self.last_loaded_data = img_nifti.get_fdata(dtype=np.float32)
            self.last_loaded_file = img_path  # Update cache

        img_3d = self.last_loaded_data
        img_slice = img_3d.take(slice_idx, axis=self.axis)

        # Normalize image
        img_slice = (img_slice - img_slice.min()) / (img_slice.max() - img_slice.min())

        # Load label (if available)
        label_slice = None
        if label_path:
            if self.last_loaded_file != label_path:
                label_nifti = nib.load(label_path)
                label_3d = label_nifti.get_fdata(dtype=np.float32).astype(np.int32)
                self.last_loaded_data = label_3d
                self.last_loaded_file = label_path  # Update cache
            label_slice = self.last_loaded_data.take(slice_idx, axis=self.axis)

        # Convert to tensors
        img_tensor = torch.from_numpy(img_slice[np.newaxis, ...]).float()
        label_tensor = torch.from_numpy(label_slice).long() if label_slice is not None else None

        if self.transform:
            img_tensor = self.transform(img_tensor)

        return img_tensor, label_tensor

# ------------ Utility functions ------------
def load_nifti(path):
    nifti = nib.load(path)
    data = nifti.get_fdata(dtype=np.float32)
    affine = nifti.affine
    header = nifti.header
    return data, affine, header

def get_spacing_from_affine(affine):
    sx = np.sqrt(affine[0, 0]**2 + affine[0, 1]**2 + affine[0, 2]**2)
    sy = np.sqrt(affine[1, 0]**2 + affine[1, 1]**2 + affine[1, 2]**2)
    sz = np.sqrt(affine[2, 0]**2 + affine[2, 1]**2 + affine[2, 2]**2)
    return np.array([sz, sy, sx])

def resample_volume(volume, current_spacing, new_spacing, is_label=False, order=1):
    if is_label:
        order = 0  # nearest-neighbor for labels
    zoom_factors = current_spacing / new_spacing
    return zoom(volume, zoom=zoom_factors, order=order)

def min_max_scale_intensity(img, min_val=-57, max_val=164, clip=True):
    """
    Scale intensity to [0, 1] range, assuming raw intensities
    are roughly in [min_val, max_val]. Optionally clip to [0..1].
    """
    img = (img - min_val) / (max_val - min_val)
    if clip:
        img = np.clip(img, 0.0, 1.0)
    return img

# -------------- Dataset class --------------
class HeartAugCompareDataset(Dataset):
    """
    Returns both the pre-augmentation version and post-augmentation version.
    For each sample, you get (img_pre, lbl_pre, img_aug, lbl_aug).
    """
    def __init__(
        self,
        image_paths,
        label_paths,
        target_spacing=np.array([1.25, 1.25, 1.25]),
        do_augment=True
    ):
        self.image_paths = image_paths
        self.label_paths = label_paths
        self.target_spacing = target_spacing
        self.do_augment = do_augment

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # 1) Load image and label
        img_path = self.image_paths[idx]
        lbl_path = self.label_paths[idx]
        img, aff_img, _ = load_nifti(img_path)
        lbl, aff_lbl, _ = load_nifti(lbl_path)
        
        # 2) Resample
        spacing_img = get_spacing_from_affine(aff_img)
        spacing_lbl = get_spacing_from_affine(aff_lbl)
        
        img_rs = resample_volume(img, spacing_img, self.target_spacing, is_label=False)
        lbl_rs = resample_volume(lbl, spacing_lbl, self.target_spacing, is_label=True)

        # 3) Intensity normalization (for the image)
        img_rs = img_rs.astype(np.float32)  # ensure float32
        img_norm = min_max_scale_intensity(img_rs, min_val=-57, max_val=164, clip=True)
        
        # ---- Save a copy *before* augmentation ----
        img_pre = img_norm.copy()
        lbl_pre = lbl_rs.copy()

        # 4) (Optional) augmentation
        img_aug = img_norm
        lbl_aug = lbl_rs

        if self.do_augment:
            # Random flip
            if np.random.rand() < 0.5:
                axis = np.random.choice([0,1,2])
                img_aug = np.flip(img_aug, axis=axis).copy()
                lbl_aug = np.flip(lbl_aug, axis=axis).copy()

            # Random rotate 90
            if np.random.rand() < 0.5:
                k = np.random.randint(1, 4)
                ax_pairs = [(0,1), (1,2), (0,2)]
                axes = ax_pairs[np.random.randint(len(ax_pairs))]
                img_aug = np.rot90(img_aug, k=k, axes=axes).copy()
                lbl_aug = np.rot90(lbl_aug, k=k, axes=axes).copy()

            # Random zoom
            if np.random.rand() < 0.2:
                zf = np.random.uniform(0.9, 1.1, size=3)
                img_aug = zoom(img_aug, zoom=zf, order=1)
                lbl_aug = zoom(lbl_aug, zoom=zf, order=0)

        # 5) Convert everything to torch tensors
        # shape (1, Z, Y, X)
        img_pre_tensor = torch.from_numpy(img_pre).unsqueeze(0).float()
        lbl_pre_tensor = torch.from_numpy(lbl_pre).unsqueeze(0).long()

        img_aug_tensor = torch.from_numpy(img_aug).unsqueeze(0).float()
        lbl_aug_tensor = torch.from_numpy(lbl_aug).unsqueeze(0).long()

        # Return the "before" + "after" versions
        return (img_pre_tensor, lbl_pre_tensor, img_aug_tensor, lbl_aug_tensor)


In [None]:
from glob import glob
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

root_dir = "/home/raghuram/ARPL/MR-Image-Reconstruction-Using-Deep-Learning/Task02_Heart"
images = sorted(glob(root_dir + "/imagesTr/*.nii.gz"))
labels = sorted(glob(root_dir + "/labelsTr/*.nii.gz"))

# Create the dataset
dataset = HeartAugCompareDataset(
    image_paths=images,
    label_paths=labels,
    target_spacing=np.array([1.25, 1.25, 1.25]),
    do_augment=True  # set True if you want random flips/rotations
)

loader = DataLoader(dataset, batch_size=1, shuffle=False)

# Fetch one batch
img_pre, lbl_pre, img_aug, lbl_aug = next(iter(loader))
# shapes: (1, 1, Z, Y, X) each

# Convert to numpy, remove the batch & channel dimension
img_pre_3d = img_pre[0,0].cpu().numpy()  # shape (Z, Y, X)
img_aug_3d = img_aug[0,0].cpu().numpy()  # shape (Z, Y, X)

# Let's pick the middle slice along z
z_mid = img_pre_3d.shape[0] // 2
slice_pre = img_pre_3d[z_mid]
slice_aug = img_aug_3d[z_mid]

# Plot side-by-side
plt.figure(figsize=(10,5))

plt.subplot(1,2,1)
plt.imshow(slice_pre, cmap='gray')
plt.title("Before Augmentation")
plt.axis('off')

plt.subplot(1,2,2)
plt.imshow(slice_aug, cmap='gray')
plt.title("After Augmentation")
plt.axis('off')

plt.tight_layout()
plt.show()


In [23]:
# test_dataset.py
import nibabel as nib
import numpy as np
import torch
from torch.utils.data import Dataset

def center_crop_3d(volume, crop_shape):
    z, y, x = volume.shape
    cz, cy, cx = crop_shape

    start_z = max((z - cz) // 2, 0)
    start_y = max((y - cy) // 2, 0)
    start_x = max((x - cx) // 2, 0)

    end_z = min(start_z + cz, z)
    end_y = min(start_y + cy, y)
    end_x = min(start_x + cx, x)

    return volume[start_z:end_z, start_y:end_y, start_x:end_x]


class TestOnlyDataset(Dataset):
    """
    Loads 3D NIfTI images, optionally center-crops them,
    and returns them as PyTorch tensors (no labels).
    """

    def __init__(self, image_paths, crop_shape=(128, 128, 128)):
        self.image_paths = image_paths
        self.crop_shape = crop_shape

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        img_nifti = nib.load(img_path)
        img = img_nifti.get_fdata(dtype=np.float32)  # (Z, Y, X)

        # center-crop (if needed)
        if self.crop_shape is not None:
            img = center_crop_3d(img, self.crop_shape)

        # expand dims => (1, Z, Y, X)
        img = np.expand_dims(img, axis=0)

        img_tensor = torch.from_numpy(img)  # shape (1, Z, Y, X)
        return img_tensor, img_path  # We return the image path too for naming


In [None]:
import os
import nibabel as nib
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from glob import glob
from matplotlib.backends.backend_pdf import PdfPages

# =============================================================================
# 1. Sliding-window Inference Function
# =============================================================================
def sliding_window_inference(volume, model, patch_size, stride, device, use_amp=True):
    """
    Perform sliding-window (patch-based) inference on a 5D tensor.
    
    Args:
      volume (torch.Tensor): Input volume of shape (1, C, D, H, W).
      model (torch.nn.Module): The segmentation model.
      patch_size (tuple): The patch size as (pD, pH, pW).
      stride (tuple): The stride (step) for sliding the window (sD, sH, sW).
      device (torch.device): The device for inference.
      use_amp (bool): Whether to use AMP (mixed precision).
    
    Returns:
      aggregated_logits (torch.Tensor): Aggregated output logits of shape 
          (1, out_channels, D, H, W), averaged over overlapping patches.
    """
    _, C, D, H, W = volume.shape
    pD, pH, pW = patch_size
    sD, sH, sW = stride

    # Run a dummy patch through the model to get number of output channels.
    with torch.no_grad():
        dummy_patch = volume[:, :, :pD, :pH, :pW].to(device)
        if use_amp:
            with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
                dummy_out = model(dummy_patch)
        else:
            dummy_out = model(dummy_patch)
        out_channels = dummy_out.shape[1]

    # Create tensors to accumulate logits and a counter for overlap.
    aggregated_logits = torch.zeros((1, out_channels, D, H, W), device=device)
    count_map = torch.zeros((1, 1, D, H, W), device=device)

    # Loop over the volume with the given stride.
    for d in range(0, D, sD):
        for h in range(0, H, sH):
            for w in range(0, W, sW):
                d_start = d
                h_start = h
                w_start = w
                d_end = min(d_start + pD, D)
                h_end = min(h_start + pH, H)
                w_end = min(w_start + pW, W)
                
                # Extract the patch.
                patch = volume[:, :, d_start:d_end, h_start:h_end, w_start:w_end]
                # Determine needed padding (pad only on the end sides).
                pad_d = pD - patch.shape[2]
                pad_h = pH - patch.shape[3]
                pad_w = pW - patch.shape[4]
                if pad_d > 0 or pad_h > 0 or pad_w > 0:
                    # F.pad expects pad in the order: (w_left, w_right, h_left, h_right, d_left, d_right)
                    patch = F.pad(patch, (0, pad_w, 0, pad_h, 0, pad_d))2024 21st International Conference on Ubiquitous Robots (UR), 176-183	
                with torch.no_grad():
                    if use_amp:
                        with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
                            patch_logits = model(patch)
                    else:
                        patch_logits = model(patch)
                
                # Remove any extra padded predictions.
                actual_d = d_end - d_start
                actual_h = h_end - h_start
                actual_w = w_end - w_start
                patch_logits = patch_logits[:, :, :actual_d, :actual_h, :actual_w]
                
                # Accumulate logits and update count map.
                aggregated_logits[:, :, d_start:d_end, h_start:h_end, w_start:w_end] += patch_logits
                count_map[:, :, d_start:d_end, h_start:h_end, w_start:w_end] += 1

    # Average overlapping regions.
    aggregated_logits = aggregated_logits / count_map
    return aggregated_logits

def infer_full_volume(volume_np, model, patch_size=(64,64,64), stride=(32,32,32),
                      device=torch.device("cuda"), use_amp=True):
    """
    Given a full volume (numpy array of shape (D, H, W)), perform patch-based inference
    using the model and return the predicted segmentation mask.
    """
    # Convert to tensor with shape (1, 1, D, H, W)
    volume_tensor = torch.from_numpy(volume_np).unsqueeze(0).unsqueeze(0)
    volume_tensor = volume_tensor.to(device, dtype=torch.float32)
    
    aggregated_logits = sliding_window_inference(volume_tensor, model, patch_size, stride, device, use_amp)
    # Argmax to get final segmentation mask (shape: (1, D, H, W))
    pred_mask = torch.argmax(aggregated_logits, dim=1)
    return pred_mask.cpu().numpy().squeeze(0)  # shape: (D, H, W)

# =============================================================================
# 2. Inference and Visualization Function
# =============================================================================
def visualize_inference():
    """
    Loads full test volumes from a test directory, runs patch-based inference using
    mixed precision, applies the predicted mask to create a masked image, and saves
    a PDF file with side-by-side visualization (center axial slice) of the original 
    and masked images.
    """
    # ----- Setup paths and device -----
    test_folder = os.path.join(ROOT_DIR, "imagesTs")
    test_images = sorted(glob(os.path.join(test_folder, "*.nii.gz")))
    output_pdf = "inference_results.pdf"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # ----- Load the trained model -----
    # (Assuming UNet3D is already defined and imported)
    model = UNet3D(in_channels=1, out_channels=2, base_filters=32, bilinear=True)
    CHECKPOINT_PATH = "checkpoints/checkpoint_epoch_100.pth"  # Adjust path as needed
    # You may consider using weights_only=True if available and safe.
    model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location="cpu"))
    model.to(device)
    model.eval()
    
    # ----- Set patch-based inference parameters -----
    patch_size = (64, 64, 64)  # Smaller patches to fit in memory
    stride = (32, 32, 32)       # Overlap between patches
    
    # ----- Prepare PDF for visualization -----
    with PdfPages(output_pdf) as pdf:
        for idx, img_path in enumerate(test_images):
            base_name = os.path.basename(img_path)
            print(f"Processing {base_name} with patch-based inference...")
            
            # Load the full volume (shape: D x H x W)
            img_nifti = nib.load(img_path)
            img_np = img_nifti.get_fdata(dtype=np.float32)
            
            # Run patch-based inference to obtain predicted segmentation mask.
            pred_mask_np = infer_full_volume(img_np, model, patch_size, stride, device, use_amp=True)
            
            # Create a masked image (zeroing out background)
            masked_np = img_np * pred_mask_np
            
            # For visualization, we take the center axial slice.
            D, H, W = img_np.shape
            z_mid = D // 2
            original_slice = img_np[z_mid, :, :]
            masked_slice = masked_np[z_mid, :, :]
            
            # Create a figure with 2 subplots: original and masked.
            fig, axs = plt.subplots(1, 2, figsize=(10, 5))
            fig.suptitle(f"{base_name} - Center Axial Slice (z={z_mid})", fontsize=16)
            axs[0].imshow(original_slice, cmap='gray')
            axs[0].set_title("Original")
            axs[1].imshow(masked_slice, cmap='gray')
            axs[1].set_title("Masked")
            for ax in axs:
                ax.set_xticks([])
                ax.set_yticks([])
            plt.tight_layout()
            pdf.savefig(fig)
            plt.close(fig)
            print(f"Finished processing {base_name}")
            
    print(f"Inference results saved to: {output_pdf}")

# =============================================================================
# 3. Main Guard
# =============================================================================
if __name__ == "__main__":
    visualize_inference()


  model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location="cpu"))


Processing la_001.nii.gz with patch-based inference...
Finished processing la_001.nii.gz
Processing la_002.nii.gz with patch-based inference...
Finished processing la_002.nii.gz
Processing la_006.nii.gz with patch-based inference...
Finished processing la_006.nii.gz
Processing la_008.nii.gz with patch-based inference...
Finished processing la_008.nii.gz
Processing la_012.nii.gz with patch-based inference...
Finished processing la_012.nii.gz
Processing la_013.nii.gz with patch-based inference...
Finished processing la_013.nii.gz
Processing la_015.nii.gz with patch-based inference...
Finished processing la_015.nii.gz
Processing la_025.nii.gz with patch-based inference...
Finished processing la_025.nii.gz
Processing la_027.nii.gz with patch-based inference...
Finished processing la_027.nii.gz
Processing la_028.nii.gz with patch-based inference...
Finished processing la_028.nii.gz
Inference results saved to: inference_results.pdf


In [18]:
import nibabel as nib
import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm  # Progress tracking
import torchvision.transforms.functional as TF
TARGET_SIZE = (256, 256)  # Standardize all slices to this size

class SlicesDataset(Dataset):
    def __init__(self, image_paths, label_paths=None, axis=0, transform=None):
        self.image_paths = image_paths
        self.label_paths = label_paths
        self.axis = axis
        self.transform = transform
        self.slice_index_mapping = []  # Stores (file_idx, slice_idx)

        print("Indexing slices from NIfTI files...")
        
        # Create index mapping without loading files into memory
        for file_idx, img_path in enumerate(tqdm(self.image_paths, desc="Indexing", unit="file")):
            img_nifti = nib.load(img_path)
            img_shape = img_nifti.shape  # Get shape without loading into memory
            num_slices = img_shape[self.axis]

            for slice_idx in range(num_slices):
                self.slice_index_mapping.append((file_idx, slice_idx))  # Store file & slice index

        print(f"Total slices indexed: {len(self.slice_index_mapping)}")

        # Optional: Cache the last loaded file to avoid redundant loading
        self.last_loaded_file = None
        self.last_loaded_data = None

    def __len__(self):
        return len(self.slice_index_mapping)

    

    def __getitem__(self, idx):
        file_idx, slice_idx = self.slice_index_mapping[idx]
        img_path = self.image_paths[file_idx]
        label_path = self.label_paths[file_idx] if self.label_paths else None

        # Load image
        if self.last_loaded_file != img_path:
            img_nifti = nib.load(img_path)
            self.last_loaded_data = img_nifti.get_fdata(dtype=np.float32)
            self.last_loaded_file = img_path

        img_3d = self.last_loaded_data
        img_slice = img_3d.take(slice_idx, axis=self.axis)

        # Normalize safely
        if img_slice.max() - img_slice.min() > 0:
            img_slice = (img_slice - img_slice.min()) / (img_slice.max() - img_slice.min())
        else:
            img_slice = np.zeros_like(img_slice)  # Avoid NaNs

        # Convert to tensor and resize
        img_tensor = torch.from_numpy(img_slice[np.newaxis, ...]).float()  # Shape: (1, H, W)
        img_tensor = TF.resize(img_tensor, TARGET_SIZE)  # Resize to fixed shape

        # Process label (if available)
        label_tensor = None
        if label_path:
            if self.last_loaded_file != label_path:
                label_nifti = nib.load(label_path)
                label_3d = label_nifti.get_fdata(dtype=np.float32).astype(np.int32)
                self.last_loaded_data = label_3d
                self.last_loaded_file = label_path
            label_slice = self.last_loaded_data.take(slice_idx, axis=self.axis)

            # Convert label to tensor and resize
            label_tensor = torch.from_numpy(label_slice).long()
            label_tensor = TF.resize(label_tensor.unsqueeze(0), TARGET_SIZE).squeeze(0)

        return img_tensor, label_tensor


In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv2D(nn.Module):
    """
    A helper module that performs:
    (conv2d -> ReLU -> conv2d -> ReLU)
    """
    def __init__(self, in_channels, out_channels):
        super(DoubleConv2D, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)

class Down2D(nn.Module):
    """
    Downscaling with maxpool then DoubleConv2D.
    """
    def __init__(self, in_channels, out_channels):
        super(Down2D, self).__init__()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv = DoubleConv2D(in_channels, out_channels)

    def forward(self, x):
        x = self.pool(x)
        x = self.conv(x)
        return x

class Up2D(nn.Module):
    """
    Upscaling then DoubleConv2D.
    If `bilinear` is True, use nn.Upsample for upscaling.
    Otherwise, use ConvTranspose2d.
    """
    def __init__(self, in_channels, out_channels, bilinear=True):
        super(Up2D, self).__init__()
        self.bilinear = bilinear
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)

        self.conv = DoubleConv2D(in_channels, out_channels)

    def forward(self, x1, x2):
        # x1: decoder feature map
        # x2: skip connection from encoder
        x1 = self.up(x1)
        # Match x1 size to x2 (in case of odd dimensions)
        diffY = x2.size(2) - x1.size(2)
        diffX = x2.size(3) - x1.size(3)

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
        # Concatenate along the channel dimension
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class UNet2D(nn.Module):
    """
    A 2D U-Net implementation.
    """
    def __init__(self, in_channels, out_channels, base_filters=32, bilinear=True):
        super(UNet2D, self).__init__()
        self.bilinear = bilinear

        # Encoder
        self.inc = DoubleConv2D(in_channels, base_filters)
        self.down1 = Down2D(base_filters, base_filters * 2)
        self.down2 = Down2D(base_filters * 2, base_filters * 4)
        self.down3 = Down2D(base_filters * 4, base_filters * 8)
        factor = 2 if bilinear else 1
        self.down4 = Down2D(base_filters * 8, base_filters * 16 // factor)

        # Decoder
        self.up1 = Up2D(base_filters * 16, base_filters * 8 // factor, bilinear)
        self.up2 = Up2D(base_filters * 8, base_filters * 4 // factor, bilinear)
        self.up3 = Up2D(base_filters * 4, base_filters * 2 // factor, bilinear)
        self.up4 = Up2D(base_filters * 2, base_filters, bilinear)

        # Final layer
        self.outc = nn.Conv2d(base_filters, out_channels, kernel_size=1)

    def forward(self, x):
        # Encoder
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        # Decoder
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)

        return logits


In [21]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import DataLoader

# Function to compute Dice Similarity Coefficient (DSC)
def dice_score(pred, target, smooth=1e-5):
    pred = (pred > 0.5).float()  # Convert probabilities to binary mask
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum()
    dice = (2.0 * intersection + smooth) / (union + smooth)
    return dice

# Ensure model is on the correct device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Instantiate dataset & dataloader
train_dataset = SlicesDataset(
    image_paths=train_image_paths,
    label_paths=train_label_paths,
    axis=0  # Axial slicing
)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=2, pin_memory=True)

# Initialize UNet model
model = UNet2D(in_channels=1, out_channels=2, base_filters=32, bilinear=True).to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Mixed Precision (AMP) for faster training
scaler = torch.cuda.amp.GradScaler()

# Training loop
epochs = 10  # Adjust as needed
for epoch in range(epochs):
    model.train()
    epoch_loss = 0.0
    epoch_dice = 0.0
    batch_count = 0

    print(f"\nStarting Epoch {epoch + 1}/{epochs}")
    print(f"Number of batches in train_loader: {len(train_loader)}")

    with tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{epochs}", unit="batch") as pbar:
        for slices, labels in train_loader:
            # Move data to GPU (if available)
            slices = slices.to(device, dtype=torch.float32)  # Shape: (B, 1, H, W)
            labels = labels.to(device, dtype=torch.long)  # Shape: (B, H, W)

            optimizer.zero_grad()

            # Enable mixed precision (AMP)
            with torch.cuda.amp.autocast():
                outputs = model(slices)  # Shape: (B, C, H, W)
                loss = criterion(outputs, labels)

            # Backpropagation with AMP
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Compute Dice score
            with torch.no_grad():
                preds = torch.argmax(outputs, dim=1)  # Convert logits to predicted labels
                dice = dice_score(preds, labels)

            # Update metrics
            epoch_loss += loss.item()
            epoch_dice += dice.item()
            batch_count += 1

            # Update tqdm progress bar
            pbar.set_postfix(loss=f"{loss.item():.4f}", dice=f"{dice.item():.4f}")
            pbar.update(1)

    # Compute average loss and Dice score
    avg_loss = epoch_loss / batch_count
    avg_dice = epoch_dice / batch_count
    print(f"Epoch {epoch+1} completed. Average Loss: {avg_loss:.4f}, Average Dice: {avg_dice:.4f}")

    # Save checkpoint every 50 epochs
    if (epoch + 1) % 50 == 0:
        checkpoint_path = f"checkpoints/model_epoch_{epoch+1}.pth"
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Checkpoint saved: {checkpoint_path}")


Using device: cuda
Indexing slices from NIfTI files...


Indexing: 100%|██████████| 20/20 [00:00<00:00, 2038.89file/s]
  scaler = torch.cuda.amp.GradScaler()


Total slices indexed: 6400

Starting Epoch 1/10
Number of batches in train_loader: 800


  with torch.cuda.amp.autocast():
Epoch 1/10: 100%|██████████| 800/800 [15:08<00:00,  1.14s/batch, dice=0.0000, loss=0.0046]


Epoch 1 completed. Average Loss: 0.0471, Average Dice: 0.2939

Starting Epoch 2/10
Number of batches in train_loader: 800


Epoch 2/10: 100%|██████████| 800/800 [15:22<00:00,  1.15s/batch, dice=0.0000, loss=0.0101]


Epoch 2 completed. Average Loss: 0.0114, Average Dice: 0.3150

Starting Epoch 3/10
Number of batches in train_loader: 800


Epoch 3/10: 100%|██████████| 800/800 [15:30<00:00,  1.16s/batch, dice=0.4279, loss=0.0024]


Epoch 3 completed. Average Loss: 0.0058, Average Dice: 0.6203

Starting Epoch 4/10
Number of batches in train_loader: 800


Epoch 4/10: 100%|██████████| 800/800 [15:11<00:00,  1.14s/batch, dice=1.0000, loss=0.0000]


Epoch 4 completed. Average Loss: 0.0036, Average Dice: 0.7354

Starting Epoch 5/10
Number of batches in train_loader: 800


Epoch 5/10: 100%|██████████| 800/800 [15:25<00:00,  1.16s/batch, dice=1.0000, loss=0.0004]


Epoch 5 completed. Average Loss: 0.0027, Average Dice: 0.8091

Starting Epoch 6/10
Number of batches in train_loader: 800


Epoch 6/10: 100%|██████████| 800/800 [15:20<00:00,  1.15s/batch, dice=1.0000, loss=0.0000]


Epoch 6 completed. Average Loss: 0.0021, Average Dice: 0.8358

Starting Epoch 7/10
Number of batches in train_loader: 800


Epoch 7/10: 100%|██████████| 800/800 [15:23<00:00,  1.15s/batch, dice=0.8798, loss=0.0030]


Epoch 7 completed. Average Loss: 0.0018, Average Dice: 0.8596

Starting Epoch 8/10
Number of batches in train_loader: 800


Epoch 8/10: 100%|██████████| 800/800 [15:22<00:00,  1.15s/batch, dice=0.9453, loss=0.0015]


Epoch 8 completed. Average Loss: 0.0016, Average Dice: 0.8828

Starting Epoch 9/10
Number of batches in train_loader: 800


Epoch 9/10: 100%|██████████| 800/800 [15:24<00:00,  1.16s/batch, dice=0.8670, loss=0.0036]


Epoch 9 completed. Average Loss: 0.0015, Average Dice: 0.8809

Starting Epoch 10/10
Number of batches in train_loader: 800


Epoch 10/10: 100%|██████████| 800/800 [15:23<00:00,  1.15s/batch, dice=0.9294, loss=0.0014]

Epoch 10 completed. Average Loss: 0.0014, Average Dice: 0.8886





In [23]:
# Save final model after training
final_model_path = os.path.join("/home/raghuram/ARPL/MR-Image-Reconstruction-Using-Deep-Learning/Data/checkpoints", "final_model.pth")
torch.save(model.state_dict(), final_model_path)
print(f"Final trained model saved: {final_model_path}")

Final trained model saved: /home/raghuram/ARPL/MR-Image-Reconstruction-Using-Deep-Learning/Data/checkpoints/final_model.pth


In [24]:
import os
import nibabel as nib
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from glob import glob
from matplotlib.backends.backend_pdf import PdfPages
from torch.utils.data import DataLoader
from torchvision.transforms import functional as TF

# =============================================================================
# 1. 2D Slice-wise Inference Function
# =============================================================================
def infer_2d_slices(volume_np, model, device):
    """
    Performs inference on each axial slice of a 3D volume using a trained 2D UNet model.

    Args:
      volume_np (numpy.ndarray): Input 3D volume (D, H, W).
      model (torch.nn.Module): Trained 2D UNet model.
      device (torch.device): The device for inference.

    Returns:
      pred_mask_np (numpy.ndarray): Predicted segmentation mask (D, H, W).
    """
    D, H, W = volume_np.shape
    pred_mask_np = np.zeros((D, H, W), dtype=np.uint8)

    # Convert to tensor and normalize
    volume_tensor = torch.from_numpy(volume_np).float()
    volume_tensor = (volume_tensor - volume_tensor.min()) / (volume_tensor.max() - volume_tensor.min())  # Normalize

    model.eval()
    with torch.no_grad():
        for z in range(D):
            slice_2d = volume_tensor[z, :, :].unsqueeze(0).unsqueeze(0).to(device)  # Shape: (1, 1, H, W)
            
            # Resize slice if necessary
            target_size = (256, 256)
            slice_2d_resized = TF.resize(slice_2d, target_size)
            
            # Inference
            logits = model(slice_2d_resized)  # Shape: (1, num_classes, H, W)
            pred_mask = torch.argmax(logits, dim=1).cpu().numpy().squeeze(0)  # Shape: (H, W)

            # Resize back to original size
            pred_mask_np[z, :, :] = TF.resize(torch.from_numpy(pred_mask).unsqueeze(0), (H, W)).squeeze(0).numpy()
    
    return pred_mask_np

# =============================================================================
# 2. Inference and Visualization Function
# =============================================================================
def visualize_2d_inference():
    """
    Loads test volumes from a test directory, runs slice-wise inference using a 
    2D UNet model, and saves a PDF showing original and predicted segmentation.
    """
    # ----- Setup paths and device -----
    test_folder = os.path.join(ROOT_DIR, "imagesTs")
    test_images = sorted(glob(os.path.join(test_folder, "*.nii.gz")))
    output_pdf = "inference_results_2d.pdf"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ----- Load the trained 2D model -----
    model = UNet2D(in_channels=1, out_channels=2, base_filters=32, bilinear=True)
    CHECKPOINT_PATH = "checkpoints/final_model.pth"  # Adjust path if needed
    model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location="cpu"))
    model.to(device)
    model.eval()
    
    # ----- Prepare PDF for visualization -----
    with PdfPages(output_pdf) as pdf:
        for idx, img_path in enumerate(test_images):
            base_name = os.path.basename(img_path)
            print(f"Processing {base_name} with slice-wise 2D inference...")

            # Load full volume (D, H, W)
            img_nifti = nib.load(img_path)
            img_np = img_nifti.get_fdata(dtype=np.float32)
            
            # Run slice-wise 2D inference
            pred_mask_np = infer_2d_slices(img_np, model, device)
            
            # Create masked image
            masked_np = img_np * pred_mask_np

            # Visualize center axial slice
            D, H, W = img_np.shape
            z_mid = D // 2
            original_slice = img_np[z_mid, :, :]
            masked_slice = masked_np[z_mid, :, :]

            # Create a figure with 2 subplots: original and masked.
            fig, axs = plt.subplots(1, 2, figsize=(10, 5))
            fig.suptitle(f"{base_name} - Center Axial Slice (z={z_mid})", fontsize=16)
            axs[0].imshow(original_slice, cmap='gray')
            axs[0].set_title("Original")
            axs[1].imshow(masked_slice, cmap='gray')
            axs[1].set_title("Masked")
            for ax in axs:
                ax.set_xticks([])
                ax.set_yticks([])
            plt.tight_layout()
            pdf.savefig(fig)
            plt.close(fig)
            print(f"Finished processing {base_name}")
            
    print(f"Inference results saved to: {output_pdf}")

# =============================================================================
# 3. Main Guard
# =============================================================================
if __name__ == "__main__":
    visualize_2d_inference()


  model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location="cpu"))


Processing la_001.nii.gz with slice-wise 2D inference...
Finished processing la_001.nii.gz
Processing la_002.nii.gz with slice-wise 2D inference...
Finished processing la_002.nii.gz
Processing la_006.nii.gz with slice-wise 2D inference...
Finished processing la_006.nii.gz
Processing la_008.nii.gz with slice-wise 2D inference...
Finished processing la_008.nii.gz
Processing la_012.nii.gz with slice-wise 2D inference...
Finished processing la_012.nii.gz
Processing la_013.nii.gz with slice-wise 2D inference...
Finished processing la_013.nii.gz
Processing la_015.nii.gz with slice-wise 2D inference...
Finished processing la_015.nii.gz
Processing la_025.nii.gz with slice-wise 2D inference...
Finished processing la_025.nii.gz
Processing la_027.nii.gz with slice-wise 2D inference...
Finished processing la_027.nii.gz
Processing la_028.nii.gz with slice-wise 2D inference...
Finished processing la_028.nii.gz
Inference results saved to: inference_results_2d.pdf


In [25]:
import os
import nibabel as nib
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from glob import glob
from matplotlib.backends.backend_pdf import PdfPages
from torch.utils.data import DataLoader
from torchvision.transforms import functional as TF

# =============================================================================
# 1. 2D Slice-wise Inference Function
# =============================================================================
def infer_2d_slices(volume_np, model, device):
    """
    Performs inference on each axial slice of a 3D volume using a trained 2D UNet model.

    Args:
      volume_np (numpy.ndarray): Input 3D volume (D, H, W).
      model (torch.nn.Module): Trained 2D UNet model.
      device (torch.device): The device for inference.

    Returns:
      pred_mask_np (numpy.ndarray): Predicted segmentation mask (D, H, W).
    """
    D, H, W = volume_np.shape
    pred_mask_np = np.zeros((D, H, W), dtype=np.uint8)

    # Convert to tensor and normalize
    volume_tensor = torch.from_numpy(volume_np).float()
    volume_tensor = (volume_tensor - volume_tensor.min()) / (volume_tensor.max() - volume_tensor.min())  # Normalize

    model.eval()
    with torch.no_grad():
        for z in range(D):
            slice_2d = volume_tensor[z, :, :].unsqueeze(0).unsqueeze(0).to(device)  # Shape: (1, 1, H, W)
            
            # Resize slice if necessary
            target_size = (256, 256)
            slice_2d_resized = TF.resize(slice_2d, target_size)
            
            # Inference
            logits = model(slice_2d_resized)  # Shape: (1, num_classes, H, W)
            pred_mask = torch.argmax(logits, dim=1).cpu().numpy().squeeze(0)  # Shape: (H, W)

            # Resize back to original size
            pred_mask_np[z, :, :] = TF.resize(torch.from_numpy(pred_mask).unsqueeze(0), (H, W)).squeeze(0).numpy()
    
    return pred_mask_np

# =============================================================================
# 2. Inference and Visualization Function
# =============================================================================
def visualize_2d_inference():
    """
    Loads test volumes from a test directory, runs slice-wise inference using a 
    2D UNet model, and saves a PDF showing original images with red overlay segmentation.
    """
    # ----- Setup paths and device -----
    test_folder = os.path.join(ROOT_DIR, "imagesTs")
    test_images = sorted(glob(os.path.join(test_folder, "*.nii.gz")))
    output_pdf = "inference_results_2d.pdf"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ----- Load the trained 2D model -----
    model = UNet2D(in_channels=1, out_channels=2, base_filters=32, bilinear=True)
    CHECKPOINT_PATH = "checkpoints/final_model.pth"  # Adjust path if needed
    model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location="cpu"))
    model.to(device)
    model.eval()
    
    # ----- Prepare PDF for visualization -----
    with PdfPages(output_pdf) as pdf:
        for idx, img_path in enumerate(test_images):
            base_name = os.path.basename(img_path)
            print(f"Processing {base_name} with slice-wise 2D inference...")

            # Load full volume (D, H, W)
            img_nifti = nib.load(img_path)
            img_np = img_nifti.get_fdata(dtype=np.float32)
            
            # Run slice-wise 2D inference
            pred_mask_np = infer_2d_slices(img_np, model, device)
            
            # Visualize center axial slice
            D, H, W = img_np.shape
            z_mid = D // 2
            original_slice = img_np[z_mid, :, :]
            mask_slice = pred_mask_np[z_mid, :, :]

            # ----- Create an overlay where the mask is red -----
            img_rgb = np.stack([original_slice] * 3, axis=-1)  # Convert grayscale to RGB
            img_rgb = (img_rgb - img_rgb.min()) / (img_rgb.max() - img_rgb.min())  # Normalize to [0, 1]
            
            mask_rgb = np.zeros_like(img_rgb)
            mask_rgb[:, :, 0] = mask_slice  # Set red channel to mask
            mask_overlay = (img_rgb * 0.5) + (mask_rgb * 0.5)  # Blend 50% red mask

            # Create a figure with 2 subplots: original and masked.
            fig, axs = plt.subplots(1, 2, figsize=(10, 5))
            fig.suptitle(f"{base_name} - Center Axial Slice (z={z_mid})", fontsize=16)
            axs[0].imshow(original_slice, cmap='gray')
            axs[0].set_title("Original")
            axs[1].imshow(mask_overlay)
            axs[1].set_title("Masked (Red Overlay)")
            for ax in axs:
                ax.set_xticks([])
                ax.set_yticks([])
            plt.tight_layout()
            pdf.savefig(fig)
            plt.close(fig)
            print(f"Finished processing {base_name}")
            
    print(f"Inference results saved to: {output_pdf}")

# =============================================================================
# 3. Main Guard
# =============================================================================
if __name__ == "__main__":
    visualize_2d_inference()


  model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location="cpu"))


Processing la_001.nii.gz with slice-wise 2D inference...
Finished processing la_001.nii.gz
Processing la_002.nii.gz with slice-wise 2D inference...
Finished processing la_002.nii.gz
Processing la_006.nii.gz with slice-wise 2D inference...
Finished processing la_006.nii.gz
Processing la_008.nii.gz with slice-wise 2D inference...
Finished processing la_008.nii.gz
Processing la_012.nii.gz with slice-wise 2D inference...
Finished processing la_012.nii.gz
Processing la_013.nii.gz with slice-wise 2D inference...
Finished processing la_013.nii.gz
Processing la_015.nii.gz with slice-wise 2D inference...
Finished processing la_015.nii.gz
Processing la_025.nii.gz with slice-wise 2D inference...
Finished processing la_025.nii.gz
Processing la_027.nii.gz with slice-wise 2D inference...
Finished processing la_027.nii.gz
Processing la_028.nii.gz with slice-wise 2D inference...
Finished processing la_028.nii.gz
Inference results saved to: inference_results_2d.pdf
