In [None]:
import os
import io
import glob
import datetime
import zipfile
import random

import requests
import numpy as np
import nibabel as nib
import tensorflow as tf
import matplotlib.pyplot as plt
import torch
import torchio as tio
from scipy.ndimage import gaussian_filter
from sklearn.model_selection import train_test_split
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from tensorflow.keras import Model, layers
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

In [None]:
def load_nifti(file_path, max_intensity=4095.0):
    """
    Load a NIfTI file and scale its intensities by ``1 / max_intensity``.

    With the default of 4095 (presumably the scanner's 12-bit maximum — confirm),
    non-negative intensities in [0, 4095] map to [0, 1]. That is the range of the
    autoencoder's sigmoid output and of the ``max_val=1.0`` used by the SSIM loss.
    Values outside that range are not clipped.

    Args:
        file_path: Path to a ``.nii`` / ``.nii.gz`` file.
        max_intensity: Divisor applied to the raw voxel values.

    Returns:
        numpy.ndarray: float32 array holding the scaled voxel data.
    """
    data = nib.load(file_path).get_fdata()
    return (data / max_intensity).astype(np.float32)

def extract_patches_corrected(volume, patch_size=(64, 64, 64), stride=32):
    """
    Extract overlapping 3D patches from an MRI volume with a TorchIO GridSampler.

    Args:
        volume (numpy array): Input MRI volume of shape (H, W, D).
        patch_size (tuple or int): Size of the 3D patch (default: 64x64x64).
        stride (int): Step between the starts of neighbouring patches along each
            axis. Must be in [1, min(patch_size)]. TorchIO additionally requires the
            resulting overlap (patch_size - stride) to be even.

    Returns:
        numpy array: Patches with shape (num_patches, 1, pH, pW, pD, 1). TorchIO's
        leading channel axis is kept and a trailing channel axis is appended, so
        callers squeeze axis 1 to obtain the (num_patches, pH, pW, pD, 1) layout
        used by the model.

    Raises:
        ValueError: If ``stride`` is outside [1, min(patch_size)].
    """
    sizes = (patch_size,) * 3 if isinstance(patch_size, int) else tuple(patch_size)
    if not 1 <= stride <= min(sizes):
        raise ValueError(
            f"stride must be between 1 and the smallest patch dimension {min(sizes)}, got {stride}"
        )

    # TorchIO expects a channel-first tensor (C, H, W, D) with C = 1.
    image = tio.ScalarImage(tensor=np.expand_dims(volume, axis=0))
    subject = tio.Subject(mri=image)

    # GridSampler's `patch_overlap` is the number of voxels shared by neighbouring
    # patches, not the step between them: step = patch_size - overlap.
    overlap = tuple(size - stride for size in sizes)
    patch_sampler = tio.GridSampler(subject, patch_size=sizes, patch_overlap=overlap)

    # Each patch tensor is (1, pH, pW, pD); stacking gives (N, 1, pH, pW, pD).
    patches = np.stack([patch['mri'][tio.DATA].numpy() for patch in patch_sampler])

    # Append a trailing channel axis to match TensorFlow's channels-last convention.
    return np.expand_dims(patches, axis=-1)

def load_and_extract_patches(synthetic_files, ground_truth_files, patch_size=(64, 64, 64), stride=32):
    """
    Load paired NIfTI volumes and extract aligned patches from each pair.

    Args:
        synthetic_files (list): Paths to the input ("synthetic") NIfTI files.
        ground_truth_files (list): Paths to the target NIfTI files, in the same
            order as ``synthetic_files``.
        patch_size (tuple): Dimensions of the 3D patch.
        stride (int): Step size for the sliding window.

    Returns:
        tuple: (synthetic_patches, ground_truth_patches), the per-file arrays from
        ``extract_patches_corrected`` concatenated along axis 0.

    Raises:
        ValueError: If the file lists are empty or differ in length, or if a pair
            yields patch arrays of different shape. In either case input patch i
            would no longer correspond to target patch i.
    """
    if len(synthetic_files) != len(ground_truth_files):
        raise ValueError(
            f"Got {len(synthetic_files)} synthetic files but "
            f"{len(ground_truth_files)} ground-truth files; they must be paired."
        )
    if len(synthetic_files) == 0:
        raise ValueError("No input files were given.")

    synthetic_patches = []
    ground_truth_patches = []

    for synthetic_path, ground_truth_path in zip(synthetic_files, ground_truth_files):
        print(f"Processing: {synthetic_path} and {ground_truth_path}")

        # Load and scale volumes (see load_nifti)
        synthetic_volume = load_nifti(synthetic_path)
        ground_truth_volume = load_nifti(ground_truth_path)

        syn_pair = extract_patches_corrected(synthetic_volume, patch_size, stride)
        gt_pair = extract_patches_corrected(ground_truth_volume, patch_size, stride)

        # Supervised training relies on patch i of the input matching patch i of the target.
        if syn_pair.shape != gt_pair.shape:
            raise ValueError(
                f"Patch arrays differ for {synthetic_path} {syn_pair.shape} and "
                f"{ground_truth_path} {gt_pair.shape}; the volumes are not aligned."
            )

        synthetic_patches.append(syn_pair)
        ground_truth_patches.append(gt_pair)

    # Concatenate patches from all files
    return np.concatenate(synthetic_patches, axis=0), np.concatenate(ground_truth_patches, axis=0)

# Paired training data: session-1 scans are the network input ("synthetic") and
# session-2 scans of the same subject are the target ("ground truth").
# sub-08 is deliberately left out here; it is the held-out test subject used at the end.
DATA_DIR = "/notebooks/data/ALLT1w"
TRAIN_SUBJECTS = ["01", "02", "03", "04", "05", "06", "07", "09", "10"]

synthetic_files = [
    f"{DATA_DIR}/sub-{sid}_ses-1_T1w_defaced_registered.nii.gz" for sid in TRAIN_SUBJECTS
]
ground_truth_files = [
    f"{DATA_DIR}/sub-{sid}_ses-2_T1w_defaced_registered.nii.gz" for sid in TRAIN_SUBJECTS
]

# Extract patches
patch_size = (64, 64, 64)
stride = 32

synthetic_patches, ground_truth_patches = load_and_extract_patches(
    synthetic_files, ground_truth_files, patch_size, stride
)

# extract_patches_corrected returns (N, 1, X, Y, Z, 1). Drop TorchIO's channel axis
# so the arrays have the model's input layout (N, X, Y, Z, 1).
synthetic_patches = np.squeeze(synthetic_patches, axis=1)
ground_truth_patches = np.squeeze(ground_truth_patches, axis=1)

# Inputs and targets must pair up one-to-one for supervised training.
assert synthetic_patches.shape == ground_truth_patches.shape, "patch arrays are not paired"

# Print the shapes of the extracted patches
print("Synthetic patches shape:", synthetic_patches.shape)
print("Ground truth patches shape:", ground_truth_patches.shape)

In [None]:
# Visualize one input and ground-truth patch: a single slice along the first spatial axis.
patch_idx = min(1000, len(synthetic_patches) - 1)  # fall back to the last patch for small datasets
slice_idx = 20

fig, axes = plt.subplots(1, 2)
axes[0].imshow(synthetic_patches[patch_idx, slice_idx, :, :, 0], cmap='gray')
axes[0].set_title(f"Synthetic Patch (Slice {slice_idx})")

axes[1].imshow(ground_truth_patches[patch_idx, slice_idx, :, :, 0], cmap='gray')
axes[1].set_title(f"Ground Truth Patch (Slice {slice_idx})")

plt.show()

In [None]:
def residual_block(x, filters):
    """
    Two-convolution residual unit: Conv3D -> PReLU -> Conv3D, added back onto the input.

    If the input's channel count differs from `filters`, the skip path is first
    projected with a linear 1x1x1 Conv3D so the two tensors can be summed.

    Returns:
        The output of the Add layer, which has `filters` channels.
    """
    branch = layers.Conv3D(filters, (3, 3, 3), padding='same')(x)
    branch = layers.PReLU()(branch)
    branch = layers.Conv3D(filters, (3, 3, 3), padding='same')(branch)

    needs_projection = x.shape[-1] != filters
    shortcut = (
        layers.Conv3D(filters, (1, 1, 1), padding='same', activation='linear')(x)
        if needs_projection
        else x
    )

    return layers.Add()([shortcut, branch])

def multi_scale_fusion(x, filters):
    """
    Apply parallel 1x1x1, 3x3x3 and 5x5x5 'same'-padded convolutions to `x`.

    The three branch outputs are concatenated along the last axis. The result
    therefore has 3 * `filters` channels and mixes features from three
    receptive-field sizes.
    """
    branches = []
    for kernel_edge in (1, 3, 5):
        conv = layers.Conv3D(filters, (kernel_edge,) * 3, padding='same')
        branches.append(conv(x))
    return layers.concatenate(branches, axis=-1)

def _encoder_stage(x, filters, first_kernel=(3, 3, 3)):
    """One encoder level: two (Conv3D -> PReLU) pairs followed by LayerNormalization."""
    x = layers.Conv3D(filters, first_kernel, padding='same')(x)
    x = layers.PReLU()(x)
    x = layers.Conv3D(filters, (3, 3, 3), padding='same')(x)
    x = layers.PReLU()(x)
    return layers.LayerNormalization()(x)


def _decoder_stage(x, skip, filters):
    """One decoder level: 2x transposed-conv upsampling, skip concat, residual + multi-scale refinement, LayerNorm."""
    x = layers.Conv3DTranspose(filters, (3, 3, 3), strides=(2, 2, 2), padding='same')(x)
    # The upsampled features and the encoder skip both carry `filters` channels,
    # so the concatenation has 2 * filters channels.
    x = layers.concatenate([x, skip])
    x = residual_block(x, filters)       # projected back to `filters` channels
    x = multi_scale_fusion(x, filters)   # 3 * filters channels
    return layers.LayerNormalization()(x)


def build_optimized_unet_autoencoder(input_shape=(64, 64, 64, 1)):
    """
    Build a 3D U-Net-style autoencoder that maps an input patch to a refined patch.

    Architecture (sizes for the default 64^3 input):
      - Encoder: four levels with 32, 64, 128 and 256 filters. The first level
        starts with a 5x5x5 convolution. The first three levels are each followed
        by 2x max-pooling (64 -> 32 -> 16 -> 8 voxels per side).
      - Decoder: three levels. Each upsamples with a strided Conv3DTranspose,
        concatenates the matching encoder output, then applies ``residual_block``
        and ``multi_scale_fusion``.
      - Output: a 1x1x1 Conv3D with sigmoid activation, i.e. one channel in (0, 1).

    Args:
        input_shape: (X, Y, Z, channels). Each spatial size must be divisible by 8
            (or be None) so the three poolings and upsamplings line up with the skips.

    Returns:
        An uncompiled Keras ``Model``.

    Raises:
        ValueError: If a spatial dimension is not divisible by 8.
    """
    for size in input_shape[:-1]:
        if size is not None and size % 8 != 0:
            raise ValueError(
                "Spatial dimensions of input_shape must be divisible by 8 "
                f"(three 2x poolings), got {tuple(input_shape)}."
            )

    inputs = layers.Input(shape=input_shape)

    # Encoder
    c1 = _encoder_stage(inputs, 32, first_kernel=(5, 5, 5))
    p1 = layers.MaxPooling3D((2, 2, 2))(c1)  # 32x32x32
    c2 = _encoder_stage(p1, 64)
    p2 = layers.MaxPooling3D((2, 2, 2))(c2)  # 16x16x16
    c3 = _encoder_stage(p2, 128)
    p3 = layers.MaxPooling3D((2, 2, 2))(c3)  # 8x8x8
    c4 = _encoder_stage(p3, 256)

    # Decoder with Residual Blocks and Multi-Scale Fusion
    u3 = _decoder_stage(c4, c3, 128)  # 16x16x16
    u2 = _decoder_stage(u3, c2, 64)   # 32x32x32
    u1 = _decoder_stage(u2, c1, 32)   # 64x64x64

    # Sigmoid keeps the output in (0, 1), the same range as the scaled inputs/targets.
    outputs = layers.Conv3D(1, (1, 1, 1), activation='sigmoid', padding='same')(u1)

    return Model(inputs, outputs)

# Hybrid Loss Function with Higher SSIM Weight
def hybrid_loss(y_true, y_pred):
    """
    Weighted sum of pixel-wise MSE and a structural (1 - SSIM) penalty.

        loss = MSE(y_true, y_pred) + 0.7 * (1 - mean SSIM)

    SSIM uses ``max_val=1.0``, so both tensors are expected to lie in [0, 1],
    the range of the model's sigmoid output.

    NOTE(review): ``tf.image.ssim`` treats the last three axes as
    (height, width, channels). For 5-D patch batches this presumably gives a 2-D
    SSIM per slice along the first spatial axis — confirm this is intended.
    """
    ssim_weight = 0.7  # raised relative to MSE to favour structure preservation
    pixel_term = tf.keras.losses.MeanSquaredError()(y_true, y_pred)
    structure_term = 1 - tf.reduce_mean(tf.image.ssim(y_true, y_pred, max_val=1.0))
    return pixel_term + ssim_weight * structure_term


# Build Model
# Seed the Python, NumPy and TensorFlow RNGs before building, so weight initialisation
# and later random steps (e.g. the training shuffle) are reproducible across restarts.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

autoencoder = build_optimized_unet_autoencoder()
autoencoder.compile(optimizer=Adam(learning_rate=2e-5), loss=hybrid_loss, metrics=['mse'])

# Summary
autoencoder.summary()

In [None]:
# Define callbacks
checkpoint = ModelCheckpoint("best_7T.keras", save_best_only=True, monitor="val_loss", verbose=1)
early_stopping = EarlyStopping(monitor="val_loss", patience=10, restore_best_weights=True)

# The model is compiled (Adam lr=2e-5, hybrid_loss, 'mse' metric) in the cell that
# builds it. It is not recompiled here, so the training configuration lives in one place.

# Train the model with fine-tuning.
# NOTE: Keras takes `validation_split` from the *last* 10% of the arrays, before
# shuffling. Patches are concatenated in file-list order, so the validation set is
# roughly the patches of the last subject(s) in the list.
history = autoencoder.fit(
    synthetic_patches,  # 3T MRI patches
    ground_truth_patches,  # 7T MRI patches
    epochs=40,
    batch_size=8,
    validation_split=0.1,
    callbacks=[checkpoint, early_stopping],
)

In [None]:
# Predict the first patches and compare one of them slice-by-slice.
n_predict = 150
sample_idx = 100
slice_idx = 32

# synthetic_patches already has the model's input layout (N, 64, 64, 64, 1), so the
# NumPy slice is passed directly. Wrapping it in tf.expand_dims would copy the whole
# array into a 6-D tensor that does not match the model's 5-D input.
predicted_patches = autoencoder.predict(synthetic_patches[:n_predict])
assert sample_idx < len(predicted_patches), "sample_idx must index a predicted patch"

fig, axes = plt.subplots(1, 3)
panels = [
    (synthetic_patches, "Synthetic"),
    (ground_truth_patches, "Ground Truth"),
    (predicted_patches, "Predicted"),
]
for ax, (panel_data, label) in zip(axes, panels):
    ax.imshow(panel_data[sample_idx, slice_idx, :, :, 0], cmap='gray')
    ax.set_title(f"{label} Patch (Slice {slice_idx})")

plt.show()

In [None]:
def _mean_psnr(psnr_values):
    """Mean of the finite PSNR values; identical slices give +inf and are skipped."""
    if not psnr_values:
        return 0
    finite = [p for p in psnr_values if np.isfinite(p)]
    if finite:
        return np.mean(finite)
    # No finite value left: if every slice was identical, report inf instead of nan.
    return np.inf if all(np.isposinf(p) for p in psnr_values) else np.nan


def compute_psnr_ssim_nmse(ground_truth_volume, refined_volume, data_range=4095):
    """
    Compute slice-wise PSNR, SSIM and NMSE between two 3D volumes.

    The metrics are computed on the 2D slices along each of the three axes, and
    the per-axis means are printed. The returned tuple holds the means for the
    slices along axis 2 (the last axis). Axes 0 and 1 are only reported in the
    printed output.

    Args:
        ground_truth_volume: Reference 3D array.
        refined_volume: 3D array to evaluate; must have the same shape.
        data_range: Intensity range passed to PSNR and SSIM. The default of 4095
            matches the intensity scale the volumes are compared in.

    Returns:
        tuple: (mean_psnr, mean_ssim, mean_nmse) for the axis-2 slices.
        - Slices with infinite PSNR (identical slices) are excluded from the PSNR
          mean. If every slice is identical, the mean PSNR is ``inf``.
        - Slices whose ground truth is all zeros contribute an NMSE of 0.
    """
    # Ensure both volumes have the same shape
    assert ground_truth_volume.shape == refined_volume.shape, "Volume shapes do not match!"
    assert ground_truth_volume.ndim == 3, "Expected 3D volumes!"

    mean_psnr = mean_ssim = mean_nmse = 0
    for axis in range(3):
        psnr_values, ssim_values, nmse_values = [], [], []

        # Moving `axis` to the front makes iteration yield the 2D slices along it.
        gt_slices = np.moveaxis(ground_truth_volume, axis, 0)
        pred_slices = np.moveaxis(refined_volume, axis, 0)
        for gt_slice, pred_slice in zip(gt_slices, pred_slices):
            mse = np.mean((gt_slice - pred_slice) ** 2)
            psnr_values.append(psnr(gt_slice, pred_slice, data_range=data_range))
            ssim_values.append(ssim(gt_slice, pred_slice, data_range=data_range, gaussian_weights=True))

            # NMSE normalises by the ground-truth energy; all-zero slices count as 0.
            norm_factor = np.mean(gt_slice ** 2)
            nmse_values.append(mse / norm_factor if norm_factor > 0 else 0)

        mean_psnr = _mean_psnr(psnr_values)
        mean_ssim = np.mean(ssim_values) if ssim_values else 0
        mean_nmse = np.mean(nmse_values) if nmse_values else 0

        print(f'Axis {axis}: Mean PSNR: {mean_psnr:.2f} dB, Mean SSIM: {mean_ssim:.4f}, Mean NMSE: {mean_nmse:.6f}')

    # After the loop these hold the axis-2 metrics.
    return mean_psnr, mean_ssim, mean_nmse

In [None]:
def _patch_starts(size, patch, stride):
    """Start offsets along one axis so that windows of length `patch` cover [0, size)."""
    if size <= patch:
        # The volume is no larger than a patch along this axis: one (padded) window.
        return [0]
    starts = list(range(0, size - patch + 1, stride))
    if starts[-1] != size - patch:
        # Shift the final window back so it ends exactly at the border, instead of
        # running past it and being filled mostly with reflected padding.
        starts.append(size - patch)
    return starts


def extract_patches(volume, patch_size, stride):
    """
    Extract overlapping 3D patches that fully cover a volume.

    Patch starts are placed every `stride` voxels along each axis. The last patch
    on each axis is shifted inward so it ends at the volume border, so every patch
    lies inside the volume whenever the volume is at least as large as the patch.
    Reflect-padding is only used along axes where the volume is smaller than the patch.

    Args:
        volume: 3D array of shape (H, W, D).
        patch_size: (pH, pW, pD).
        stride: Step between patch starts; must be in [1, min(patch_size)] so that
            no voxels are skipped.

    Returns:
        tuple: (patches, positions). ``patches`` has shape (N, pH, pW, pD, 1);
        ``positions`` is a list of the (i, j, k) start corners, in the same order.

    Raises:
        ValueError: If ``stride`` is outside [1, min(patch_size)].
    """
    if not 1 <= stride <= min(patch_size):
        raise ValueError(
            f"stride must be between 1 and the smallest patch dimension {min(patch_size)}, got {stride}"
        )

    h, w, d = volume.shape
    ps_h, ps_w, ps_d = patch_size

    patches = []
    positions = []

    for i in _patch_starts(h, ps_h, stride):
        for j in _patch_starts(w, ps_w, stride):
            for k in _patch_starts(d, ps_d, stride):
                patch = volume[i:i + ps_h, j:j + ps_w, k:k + ps_d]

                # Only needed when the volume is smaller than the patch on some axis.
                pad = [(0, full - got) for full, got in zip(patch_size, patch.shape)]
                if any(after for _, after in pad):
                    patch = np.pad(patch, pad, mode='reflect')

                patches.append(patch)
                positions.append((i, j, k))

    patches = np.expand_dims(np.array(patches), axis=-1)  # Add channel dimension
    return patches, positions

def gaussian_weight_map(shape, sigma=0.5):
    """
    Generate a weighting map for smooth blending of overlapping patches.

    A box of ones is smoothed with a Gaussian of standard deviation `sigma` (in
    voxels), treating everything outside the box as zero. Voxels near the patch
    faces therefore get lower weight than the interior. The map is normalised so
    its maximum is 1, and every entry is strictly positive.

    With scipy's default 'reflect' boundary, an all-ones array would come back
    unchanged (uniform weights, i.e. no blending), hence ``mode='constant'``.
    A larger `sigma` feathers patch borders more strongly.
    """
    mask = np.ones(shape, dtype=np.float32)
    weights = gaussian_filter(mask, sigma=sigma, mode='constant', cval=0.0)
    return weights / np.max(weights)  # Normalize

def reconstruct_volume(patches, positions, original_shape, patch_size):
    """
    Blend overlapping patch predictions back into one 3D volume.

    Each patch is multiplied by the map from ``gaussian_weight_map`` and added at
    its (i, j, k) start corner. The accumulated sum is then divided by the
    accumulated weights. Any part of a patch that extends past `original_shape`
    (padding added during extraction) is cropped off. Voxels covered by no patch
    stay 0.

    Args:
        patches: Array of shape (num_patches, pH, pW, pD, 1).
        positions: Start corner (i, j, k) of each patch, in the same order as `patches`.
        original_shape: (H, W, D) of the output volume.
        patch_size: (pH, pW, pD).

    Returns:
        numpy.ndarray: float32 volume of shape `original_shape`.
    """
    h, w, d = original_shape
    ps_h, ps_w, ps_d = patch_size

    accum = np.zeros(original_shape, dtype=np.float32)
    weight_sum = np.zeros(original_shape, dtype=np.float32)
    blend = gaussian_weight_map((ps_h, ps_w, ps_d))

    for n, (i, j, k) in enumerate(positions):
        # How much of this patch actually lies inside the volume along each axis.
        extent = (min(ps_h, h - i), min(ps_w, w - j), min(ps_d, d - k))
        target = (slice(i, i + extent[0]), slice(j, j + extent[1]), slice(k, k + extent[2]))
        crop = tuple(slice(0, e) for e in extent)

        weight = blend[crop]
        accum[target] += patches[n, ..., 0][crop] * weight
        weight_sum[target] += weight

    # Uncovered voxels have zero weight; dividing by 1 leaves them at 0.
    weight_sum[weight_sum == 0] = 1
    return accum / weight_sum

def process_nifti_with_autoencoder(input_nifti_path, ground_truth_nifti_path, output_nifti_path, autoencoder, patch_size=(64,64,64), stride=32):
    """
    Refine one NIfTI volume with the autoencoder, save it, and score it against ground truth.

    Steps:
      1. Load the input volume and divide it by 4095 (the same scaling ``load_nifti``
         applies to the training data).
      2. Split it into overlapping patches, predict, and blend the predictions back
         into a volume.
      3. Multiply by 4095 to undo the scaling, and save the result with the input's affine.
      4. Compute slice-wise PSNR/SSIM/NMSE against the ground-truth volume.

    Args:
        input_nifti_path: Path of the volume to refine.
        ground_truth_nifti_path: Path of the reference volume (same voxel grid).
        output_nifti_path: Where the refined NIfTI is written.
        autoencoder: Trained Keras model taking (N, pH, pW, pD, 1) patches.
        patch_size: Patch dimensions used for inference.
        stride: Step between patch starts.

    Returns:
        tuple: (mean_psnr, mean_ssim, mean_nmse) as returned by ``compute_psnr_ssim_nmse``.

    Raises:
        ValueError: If the input and ground-truth volumes have different shapes.
    """
    nifti_img = nib.load(input_nifti_path)
    volume = nifti_img.get_fdata()
    volume_gt = nib.load(ground_truth_nifti_path).get_fdata()

    # Fail fast: the metrics need voxel-aligned volumes, so check before the costly inference.
    if volume.shape != volume_gt.shape:
        raise ValueError(
            f"Input shape {volume.shape} does not match ground-truth shape {volume_gt.shape}."
        )

    # Scale to [0, 1] for intensities in [0, 4095], matching the model's sigmoid output range.
    volume = volume / 4095

    # Extract patches, predict, and blend them back into a volume.
    synthetic_patches, positions = extract_patches(volume, patch_size, stride)
    refined_patches = autoencoder.predict(synthetic_patches)
    refined_volume = reconstruct_volume(refined_patches, positions, volume.shape, patch_size)

    # Undo the scaling so the output is in the same intensity units as the inputs.
    refined_volume = refined_volume * 4095

    # Save before scoring so the refined volume is kept even if the metrics fail.
    refined_nifti = nib.Nifti1Image(refined_volume, nifti_img.affine)
    nib.save(refined_nifti, output_nifti_path)
    print(f"Saved refined MRI: {output_nifti_path}")

    return compute_psnr_ssim_nmse(volume_gt, refined_volume)

# Example Usage: refine the held-out subject sub-08 (absent from the training file
# lists) and save the result next to the notebook.
test_subject = "sub-08"
data_dir = "/notebooks/data/ALLT1w"
input_nifti_path = f"{data_dir}/{test_subject}_ses-1_T1w_defaced_registered.nii.gz"
ground_truth_nifti_path = f"{data_dir}/{test_subject}_ses-2_T1w_defaced_registered.nii.gz"
output_nifti_path = "Test_sub08_SAE.nii.gz"

process_nifti_with_autoencoder(
    input_nifti_path,
    ground_truth_nifti_path,
    output_nifti_path,
    autoencoder,
    patch_size=(64, 64, 64),
    stride=32,
)