In [1]:
import os
import numpy as np
import nibabel as nib
import pandas as pd
from glob import glob
from tqdm import tqdm
from scipy.ndimage import zoom, rotate
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt


In [2]:
ROOT_DIR = "/home/raghuram/ARPL/MR-Image-Reconstruction-Using-Deep-Learning/Task02_Heart" 


In [3]:
def load_nifti(path):
    """
    Load a .nii.gz file using nibabel and return:
      - data (as a np.ndarray)
      - affine (4x4 transform matrix)
      - header (NIfTI header)
    """
    nifti = nib.load(path)
    data = nifti.get_fdata(dtype=np.float32)  # or np.float64 if you prefer
    affine = nifti.affine
    header = nifti.header
    return data, affine, header


In [4]:
def get_spacing_from_affine(affine):
    """
    Given a 4x4 affine matrix from nibabel, extract voxel spacing
    along each dimension: (z, y, x).
    """
    # Typically, the diagonal or row vectors of the affine (in absolute value)
    # correspond to voxel size. A simple approach:
    sx = np.sqrt(affine[0, 0]**2 + affine[0, 1]**2 + affine[0, 2]**2)
    sy = np.sqrt(affine[1, 0]**2 + affine[1, 1]**2 + affine[1, 2]**2)
    sz = np.sqrt(affine[2, 0]**2 + affine[2, 1]**2 + affine[2, 2]**2)
    return np.array([sz, sy, sx])


In [5]:
# Jupyter Cell 4

def get_spacing_from_affine(affine):
    """
    Given a 4x4 affine matrix from nibabel, extract the voxel spacing
    along each dimension (z, y, x).
    """
    sx = np.sqrt(affine[0, 0]**2 + affine[0, 1]**2 + affine[0, 2]**2)
    sy = np.sqrt(affine[1, 0]**2 + affine[1, 1]**2 + affine[1, 2]**2)
    sz = np.sqrt(affine[2, 0]**2 + affine[2, 1]**2 + affine[2, 2]**2)
    return np.array([sz, sy, sx])


In [6]:
# Jupyter Cell 5

def resample_volume(volume, current_spacing, new_spacing, is_label=False, order=1):
    """
    Resample a 3D volume (or label) to new_spacing using scipy.ndimage.zoom.
      - volume: np.ndarray of shape (z, y, x).
      - current_spacing: array-like, e.g. [sz, sy, sx].
      - new_spacing: array-like, e.g. [nz, ny, nx].
      - is_label: If True, use nearest neighbor for labels (order=0).
      - order: interpolation order for images (1=linear by default).

    Returns the resampled volume as np.ndarray.
    """
    if is_label:
        order = 0  # nearest-neighbor for segmentations

    zoom_factors = current_spacing / new_spacing
    resampled = zoom(volume, zoom=zoom_factors, order=order)
    return resampled


In [7]:
# Jupyter Cell 6

def min_max_scale_intensity(img, min_val=-57, max_val=164, clip=True):
    """
    Scale intensity to [0, 1] range, assuming raw intensities
    lie roughly in [min_val, max_val]. Optionally clip to [0,1].
    """
    img = img.astype(np.float32)
    img = (img - min_val) / (max_val - min_val)
    if clip:
        img = np.clip(img, 0.0, 1.0)
    return img


In [8]:
# Jupyter Cell 8

def random_rotate90_3d(img, prob=0.5, max_k=3):
    """
    Randomly rotate the 3D image by 90, 180, or 270 degrees (k=1..max_k)
    around a random plane of axes, with probability prob.
    """
    if np.random.rand() < prob:
        k = np.random.randint(1, max_k + 1)  # 1, 2, or 3
        axis_pairs = [(0,1), (1,2), (0,2)]    # e.g., rotate in XY, YZ, or XZ plane
        axes = axis_pairs[np.random.randint(len(axis_pairs))]
        img = np.rot90(img, k=k, axes=axes)
    return img


In [9]:
# Jupyter Cell 9

def random_zoom_3d(img, prob=0.2, min_zoom=0.9, max_zoom=1.1, order=1, is_label=False):
    """
    Randomly zoom a 3D volume in/out with a factor in [min_zoom, max_zoom],
    with probability prob. If is_label=True, uses nearest neighbor.
    """
    if np.random.rand() < prob:
        zf = np.random.uniform(min_zoom, max_zoom, size=3)
        if is_label:
            order = 0  # nearest for labels
        img = zoom(img, zoom=zf, order=order)
    return img


In [10]:
# Jupyter Cell 10

# 1) Gather file paths
images = sorted(glob(os.path.join(ROOT_DIR, "imagesTr", "*.nii.gz")))
labels = sorted(glob(os.path.join(ROOT_DIR, "labelsTr", "*.nii.gz")))

# 2) Set a desired target spacing (z, y, x)
target_spacing = np.array([1.25, 1.25, 1.25])

# 3) Prepare a list for stats
stats_list = []

# 4) Loop over all images/labels
for idx, (img_path, lbl_path) in enumerate(tqdm(zip(images, labels), 
                                                total=len(images),
                                                desc="Processing")):
    # --- Load
    img, aff_img, _ = load_nifti(img_path)
    lbl, aff_lbl, _ = load_nifti(lbl_path)
    
    # --- Original spacings
    orig_spacing_img = get_spacing_from_affine(aff_img)
    orig_spacing_lbl = get_spacing_from_affine(aff_lbl)  # often the same as img
    
    # --- Resample
    img_rs = resample_volume(img, orig_spacing_img, target_spacing, is_label=False, order=1)
    lbl_rs = resample_volume(lbl, orig_spacing_lbl, target_spacing, is_label=True)
    
    # --- Intensity normalization (image only)
    img_scaled = min_max_scale_intensity(img_rs, min_val=-57, max_val=164, clip=True)
    
    # --- Data augmentation
    # Must apply the same random transforms to both image & label

    # Random flip
    if np.random.rand() < 0.5:
        flip_axis = np.random.choice([0,1,2])
        img_scaled = np.flip(img_scaled, axis=flip_axis)
        lbl_rs = np.flip(lbl_rs, axis=flip_axis)

    # Random rotate 90
    if np.random.rand() < 0.5:
        k = np.random.randint(1, 4)  # 1, 2, or 3
        axis_pairs = [(0,1), (1,2), (0,2)]
        ax = axis_pairs[np.random.randint(len(axis_pairs))]
        img_scaled = np.rot90(img_scaled, k=k, axes=ax)
        lbl_rs = np.rot90(lbl_rs, k=k, axes=ax)

    # Random zoom
    if np.random.rand() < 0.2:
        zf = np.random.uniform(0.9, 1.1, size=3)
        img_scaled = zoom(img_scaled, zoom=zf, order=1)
        lbl_rs = zoom(lbl_rs, zoom=zf, order=0)

    # --- Collect stats
    img_min, img_max = float(img_scaled.min()), float(img_scaled.max())
    lbl_min, lbl_max = float(lbl_rs.min()), float(lbl_rs.max())

    stats_list.append({
        "index": idx,
        "image_path": img_path,
        "label_path": lbl_path,
        "shape_img": img_scaled.shape,
        "shape_lbl": lbl_rs.shape,
        "img_min": img_min,
        "img_max": img_max,
        "lbl_min": lbl_min,
        "lbl_max": lbl_max
    })

# 5) Convert stats to a DataFrame
df_stats = pd.DataFrame(stats_list)
df_stats.to_csv("heart_dataset_stats.csv", index=False)
df_stats.head(10)


Processing: 100%|██████████| 20/20 [00:23<00:00,  1.20s/it]


Unnamed: 0,index,image_path,label_path,shape_img,shape_lbl,img_min,img_max,lbl_min,lbl_max
0,0,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,"(351, 326, 129)","(351, 326, 129)",0.257919,1.0,0.0,1.0
1,1,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,"(110, 320, 351)","(110, 320, 351)",0.257919,1.0,0.0,1.0
2,2,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,"(351, 320, 120)","(351, 320, 120)",0.257919,1.0,0.0,1.0
3,3,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,"(130, 320, 351)","(130, 320, 351)",0.257919,1.0,0.0,1.0
4,4,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,"(351, 320, 100)","(351, 320, 100)",0.257919,1.0,0.0,1.0
5,5,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,"(351, 120, 320)","(351, 120, 320)",0.257919,1.0,0.0,1.0
6,6,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,"(351, 120, 320)","(351, 120, 320)",0.257919,1.0,0.0,1.0
7,7,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,"(351, 320, 120)","(351, 320, 120)",0.257919,1.0,0.0,1.0
8,8,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,"(351, 320, 90)","(351, 320, 90)",0.257919,1.0,0.0,1.0
9,9,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,/home/raghuram/ARPL/MR-Image-Reconstruction-Us...,"(383, 334, 131)","(383, 334, 131)",0.257919,1.0,0.0,1.0


In [None]:
import os
from glob import glob
import nibabel as nib
import torch
from torch.utils.data import Dataset, DataLoader

class HeartDataset(Dataset):
    def __init__(self, image_paths, label_paths, target_spacing=None):
        """
        Args:
            image_paths  (list): list of paths to the NIfTI images.
            label_paths  (list): list of paths to the corresponding NIfTI labels.
            target_spacing (np.array): optional, if you plan to resample images.
        """
        self.image_paths = image_paths
        self.label_paths = label_paths
        self.target_spacing = target_spacing  # not used here, but can be used for resampling

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        lbl_path = self.label_paths[idx]

        # Load NIfTI images (shape: [Z, Y, X] typically)
        img_nifti = nib.load(img_path)
        lbl_nifti = nib.load(lbl_path)

        img = img_nifti.get_fdata(dtype=np.float32)
        lbl = lbl_nifti.get_fdata(dtype=np.float32)

        # If you have specific transforms (normalization, cropping, etc.), apply them here
        # e.g. min-max normalization, etc.

        # Expand dims to have shape (1, Z, Y, X) for PyTorch
        img = np.expand_dims(img, axis=0)
        lbl = np.expand_dims(lbl, axis=0)

        return torch.from_numpy(img), torch.from_numpy(lbl)


In [12]:
# Create dataset
dataset = HeartDataset(
    image_paths=images,
    label_paths=labels,
    target_spacing=np.array([1.25, 1.25, 1.25]),)

# Create DataLoader
loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)

In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv3D(nn.Module):
    """
    
    A helper module that performs a 3D convolution -> ReLU -> 3D convolution -> ReLU
    """
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)


class Down3D(nn.Module):
    """
    Downscaling (via MaxPool3D) followed by a DoubleConv3D.
    """
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)
        self.conv = DoubleConv3D(in_ch, out_ch)

    def forward(self, x):
        x = self.pool(x)
        x = self.conv(x)
        return x


class Up3D(nn.Module):
    """
    Upscaling then a DoubleConv3D. We can do either:
      - Transposed Conv if bilinear=False
      - nn.Upsample (trilinear) if bilinear=True
    """
    def __init__(self, in_ch, out_ch, bilinear=True):
        super().__init__()
        self.bilinear = bilinear
        # If using bilinear upsampling, we keep the number of parameters lower:
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        else:
            # Deconvolution / Transposed Conv
            self.up = nn.ConvTranspose3d(in_ch // 2, in_ch // 2, kernel_size=2, stride=2)

        # A DoubleConv3D after concatenation
        self.conv = DoubleConv3D(in_ch, out_ch)

    def forward(self, x1, x2):
        # x1 is the "decoder" feature map, x2 is the "skip connection" from encoder
        x1 = self.up(x1)

        # We may need to pad x1 to match x2's size
        diffZ = x2.size()[2] - x1.size()[2]
        diffY = x2.size()[3] - x1.size()[3]
        diffX = x2.size()[4] - x1.size()[4]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2,
                        diffZ // 2, diffZ - diffZ // 2])

        # Concatenate along channel dimension
        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x


class UNet3D(nn.Module):
    """
    A 3D UNet with 4 down-sampling / up-sampling levels.
    """
    def __init__(self, in_channels, out_channels, base_filters=32, bilinear=True):
        super().__init__()
        self.bilinear = bilinear

        # Encoder
        self.inc = DoubleConv3D(in_channels, base_filters)
        self.down1 = Down3D(base_filters, base_filters * 2)
        self.down2 = Down3D(base_filters * 2, base_filters * 4)
        self.down3 = Down3D(base_filters * 4, base_filters * 8)
        # If bilinear upsampling, we can reduce the number of filters by a factor of 2 in the bottleneck
        factor = 2 if bilinear else 1
        self.down4 = Down3D(base_filters * 8, base_filters * 16 // factor)

        # Decoder
        self.up1 = Up3D(base_filters * 16, base_filters * 8 // factor, bilinear)
        self.up2 = Up3D(base_filters * 8, base_filters * 4 // factor, bilinear)
        self.up3 = Up3D(base_filters * 4, base_filters * 2 // factor, bilinear)
        self.up4 = Up3D(base_filters * 2, base_filters, bilinear)

        # Final 1x1 convolution
        self.outc = nn.Conv3d(base_filters, out_channels, kernel_size=1)

    def forward(self, x):
        # Encoder path
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        # Decoder path
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits


In [14]:
def random_crop_3d(img, lbl, crop_size=(64, 128, 128)):
    """
    Randomly crop a 3D patch of size crop_size from img and lbl.
    Assume img, lbl shape: (Z, Y, X).
    """
    z, y, x = img.shape
    cz, cy, cx = crop_size
    
    # pick random start
    z0 = np.random.randint(0, z - cz) if z > cz else 0
    y0 = np.random.randint(0, y - cy) if y > cy else 0
    x0 = np.random.randint(0, x - cx) if x > cx else 0
    
    img_patch = img[z0:z0+cz, y0:y0+cy, x0:x0+cx]
    lbl_patch = lbl[z0:z0+cz, y0:y0+cy, x0:x0+cx]
    return img_patch, lbl_patch


In [15]:
import os
import numpy as np
import nibabel as nib
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from glob import glob

In [16]:
def main():
    # Modify ROOT_DIR to your own directory containing imagesTr/labelsTr
    images = sorted(glob(os.path.join(ROOT_DIR, "imagesTr", "*.nii.gz")))
    labels = sorted(glob(os.path.join(ROOT_DIR, "labelsTr", "*.nii.gz")))

    # Instantiate dataset & dataloader
    train_dataset = HeartDataset(image_paths=images, label_paths=labels)
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0)

    # Create model
    model = UNet3D(in_channels=1, out_channels=2, base_filters=32, bilinear=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Define optimizer & loss
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()

    # --- Mixed Precision Components ---
    scaler = GradScaler()  # handles dynamic loss scaling for FP16

    epochs = 2  # Example small number
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0

        for step, (images, labels) in enumerate(train_loader):
            # images shape: (B, 1, Z, Y, X)
            # labels shape: (B, 1, Z, Y, X)
            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)

            # For nn.CrossEntropyLoss, labels should be shape (B, Z, Y, X)
            labels = labels.squeeze(1)  # remove the channel dim => (B, Z, Y, X)

            optimizer.zero_grad()

            # Use autocast for forward + loss computation
            with autocast():
                outputs = model(images)  # shape: (B, 2, Z, Y, X)
                loss = criterion(outputs, labels)

            # Backprop with scaled gradients
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()

            if step % 5 == 0:
                print(f"[Epoch {epoch+1}, Step {step}] Loss: {loss.item():.4f}")

        epoch_loss = running_loss / len(train_loader)
        print(f"Epoch {epoch+1} finished. Average loss: {epoch_loss:.4f}")


if __name__ == "__main__":
    main()

  scaler = GradScaler()  # handles dynamic loss scaling for FP16
  with autocast():


OutOfMemoryError: CUDA out of memory. Tried to allocate 2.93 GiB. GPU 0 has a total capacity of 7.75 GiB of which 2.41 GiB is free. Including non-PyTorch memory, this process has 5.23 GiB memory in use. Of the allocated memory 4.96 GiB is allocated by PyTorch, and 133.42 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
def show_3d_slices(img_3d, title="", figsize=(12, 4)):
    """
    Show 3 slices from a 3D volume:
      1) A slice in the axial plane (z fixed)
      2) A slice in the coronal plane (y fixed)
      3) A slice in the sagittal plane (x fixed)
    img_3d: np.ndarray with shape (Z, Y, X)
    title: optional string for figure title
    """
    z, y, x = img_3d.shape

    # pick the middle indices in each dimension
    z_mid = z // 2
    y_mid = y // 2
    x_mid = x // 2

    fig, axes = plt.subplots(1, 3, figsize=figsize)

    # Axial plane: fix z index
    axes[0].imshow(img_3d[z_mid, :, :], cmap="gray")
    axes[0].set_title(f"Axial (z={z_mid})")

    # Coronal plane: fix y index
    # coronal slice is shape [z, x], so we do img_3d[:, y_mid, :]
    axes[1].imshow(img_3d[:, y_mid, :], cmap="gray")
    axes[1].set_title(f"Coronal (y={y_mid})")

    # Sagittal plane: fix x index
    # sagittal slice is shape [z, y], so we do img_3d[:, :, x_mid]
    # We might transpose so that it doesn't appear rotated, but that's optional
    axes[2].imshow(img_3d[:, :, x_mid].T, cmap="gray", origin="lower")
    axes[2].set_title(f"Sagittal (x={x_mid})")

    fig.suptitle(title, fontsize=16)
    plt.tight_layout()
    plt.show()


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Let's assume you already have:
#   dataset = HeartDataset(...)       # your custom dataset
#   OR a DataLoader: loader = DataLoader(dataset, ...)
# For simplicity, we'll just pull an item directly from the dataset.

# 1) Get the first sample from the dataset
img_tensor, lbl_tensor = dataset[0]  # shape: (1, Z, Y, X) each

# 2) Convert to NumPy (removing the channel dimension)
img_3d = img_tensor.squeeze(0).numpy()  # now shape = (Z, Y, X)

# We'll pick the middle slice (axial view) along Z
z_mid = img_3d.shape[0] // 2
slice_2d = img_3d[z_mid]  # shape = (Y, X)

# 3) Window/Level using percentile-based clipping to improve contrast
p1, p99 = np.percentile(slice_2d, [1, 99])         # 1st & 99th percentile
slice_clipped = np.clip(slice_2d, p1, p99)         # clamp intensities
slice_normalized = (slice_clipped - p1) / (p99 - p1)  # scale to [0..1]

# 4) Display this slice
plt.figure(figsize=(6,6))
plt.imshow(slice_normalized, cmap='gray')
plt.title(f"Middle Slice (z={z_mid}) with Percentile Windowing")
plt.axis('off')
plt.show()


In [None]:
import os
import numpy as np
import nibabel as nib
import torch
from torch.utils.data import Dataset
from scipy.ndimage import zoom

# ------------ Utility functions ------------
def load_nifti(path):
    nifti = nib.load(path)
    data = nifti.get_fdata(dtype=np.float32)
    affine = nifti.affine
    header = nifti.header
    return data, affine, header

def get_spacing_from_affine(affine):
    sx = np.sqrt(affine[0, 0]**2 + affine[0, 1]**2 + affine[0, 2]**2)
    sy = np.sqrt(affine[1, 0]**2 + affine[1, 1]**2 + affine[1, 2]**2)
    sz = np.sqrt(affine[2, 0]**2 + affine[2, 1]**2 + affine[2, 2]**2)
    return np.array([sz, sy, sx])

def resample_volume(volume, current_spacing, new_spacing, is_label=False, order=1):
    if is_label:
        order = 0  # nearest-neighbor for labels
    zoom_factors = current_spacing / new_spacing
    return zoom(volume, zoom=zoom_factors, order=order)

def min_max_scale_intensity(img, min_val=-57, max_val=164, clip=True):
    """
    Scale intensity to [0, 1] range, assuming raw intensities
    are roughly in [min_val, max_val]. Optionally clip to [0..1].
    """
    img = (img - min_val) / (max_val - min_val)
    if clip:
        img = np.clip(img, 0.0, 1.0)
    return img

# -------------- Dataset class --------------
class HeartAugCompareDataset(Dataset):
    """
    Returns both the pre-augmentation version and post-augmentation version.
    For each sample, you get (img_pre, lbl_pre, img_aug, lbl_aug).
    """
    def __init__(
        self,
        image_paths,
        label_paths,
        target_spacing=np.array([1.25, 1.25, 1.25]),
        do_augment=True
    ):
        self.image_paths = image_paths
        self.label_paths = label_paths
        self.target_spacing = target_spacing
        self.do_augment = do_augment

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # 1) Load image and label
        img_path = self.image_paths[idx]
        lbl_path = self.label_paths[idx]
        img, aff_img, _ = load_nifti(img_path)
        lbl, aff_lbl, _ = load_nifti(lbl_path)
        
        # 2) Resample
        spacing_img = get_spacing_from_affine(aff_img)
        spacing_lbl = get_spacing_from_affine(aff_lbl)
        
        img_rs = resample_volume(img, spacing_img, self.target_spacing, is_label=False)
        lbl_rs = resample_volume(lbl, spacing_lbl, self.target_spacing, is_label=True)

        # 3) Intensity normalization (for the image)
        img_rs = img_rs.astype(np.float32)  # ensure float32
        img_norm = min_max_scale_intensity(img_rs, min_val=-57, max_val=164, clip=True)
        
        # ---- Save a copy *before* augmentation ----
        img_pre = img_norm.copy()
        lbl_pre = lbl_rs.copy()

        # 4) (Optional) augmentation
        img_aug = img_norm
        lbl_aug = lbl_rs

        if self.do_augment:
            # Random flip
            if np.random.rand() < 0.5:
                axis = np.random.choice([0,1,2])
                img_aug = np.flip(img_aug, axis=axis).copy()
                lbl_aug = np.flip(lbl_aug, axis=axis).copy()

            # Random rotate 90
            if np.random.rand() < 0.5:
                k = np.random.randint(1, 4)
                ax_pairs = [(0,1), (1,2), (0,2)]
                axes = ax_pairs[np.random.randint(len(ax_pairs))]
                img_aug = np.rot90(img_aug, k=k, axes=axes).copy()
                lbl_aug = np.rot90(lbl_aug, k=k, axes=axes).copy()

            # Random zoom
            if np.random.rand() < 0.2:
                zf = np.random.uniform(0.9, 1.1, size=3)
                img_aug = zoom(img_aug, zoom=zf, order=1)
                lbl_aug = zoom(lbl_aug, zoom=zf, order=0)

        # 5) Convert everything to torch tensors
        # shape (1, Z, Y, X)
        img_pre_tensor = torch.from_numpy(img_pre).unsqueeze(0).float()
        lbl_pre_tensor = torch.from_numpy(lbl_pre).unsqueeze(0).long()

        img_aug_tensor = torch.from_numpy(img_aug).unsqueeze(0).float()
        lbl_aug_tensor = torch.from_numpy(lbl_aug).unsqueeze(0).long()

        # Return the "before" + "after" versions
        return (img_pre_tensor, lbl_pre_tensor, img_aug_tensor, lbl_aug_tensor)


In [None]:
from glob import glob
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

root_dir = "/home/raghuram/ARPL/MR-Image-Reconstruction-Using-Deep-Learning/Task02_Heart"
images = sorted(glob(root_dir + "/imagesTr/*.nii.gz"))
labels = sorted(glob(root_dir + "/labelsTr/*.nii.gz"))

# Create the dataset
dataset = HeartAugCompareDataset(
    image_paths=images,
    label_paths=labels,
    target_spacing=np.array([1.25, 1.25, 1.25]),
    do_augment=True  # set True if you want random flips/rotations
)

loader = DataLoader(dataset, batch_size=1, shuffle=False)

# Fetch one batch
img_pre, lbl_pre, img_aug, lbl_aug = next(iter(loader))
# shapes: (1, 1, Z, Y, X) each

# Convert to numpy, remove the batch & channel dimension
img_pre_3d = img_pre[0,0].cpu().numpy()  # shape (Z, Y, X)
img_aug_3d = img_aug[0,0].cpu().numpy()  # shape (Z, Y, X)

# Let's pick the middle slice along z
z_mid = img_pre_3d.shape[0] // 2
slice_pre = img_pre_3d[z_mid]
slice_aug = img_aug_3d[z_mid]

# Plot side-by-side
plt.figure(figsize=(10,5))

plt.subplot(1,2,1)
plt.imshow(slice_pre, cmap='gray')
plt.title("Before Augmentation")
plt.axis('off')

plt.subplot(1,2,2)
plt.imshow(slice_aug, cmap='gray')
plt.title("After Augmentation")
plt.axis('off')

plt.tight_layout()
plt.show()
