In [None]:
import datetime
import glob
import io
import os
import random
import zipfile

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import requests
import tensorflow as tf
import torch
import torchio as tio
from scipy.ndimage import gaussian_filter
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from sklearn.model_selection import train_test_split
from tensorflow.keras import Model, layers
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.optimizers import Adam

In [None]:
def load_nifti(file_path):
    """
    Load a NIfTI file and min-max normalize its voxel values to [0, 1].

    Args:
        file_path: path to a .nii / .nii.gz file readable by nibabel.

    Returns:
        float32 array with the image's shape. A constant-valued image
        (max == min) is returned as all zeros. Dividing by a zero range would
        otherwise fill the volume with NaNs that then poison every patch
        and the training loss.
    """
    data = nib.load(file_path).get_fdata()
    min_val = np.min(data)
    value_range = np.max(data) - min_val
    if value_range == 0:
        return np.zeros(data.shape, dtype=np.float32)
    return ((data - min_val) / value_range).astype(np.float32)

def extract_patches_corrected(volume, patch_size=(64, 64, 64), stride=32):
    """
    Extract 3D patches from an MRI volume on a regular grid using TorchIO.

    Args:
        volume: 3D array (H, W, D).
        patch_size: patch extent along each axis.
        stride: step between neighbouring patch origins, in voxels.
            TorchIO's GridSampler is parameterized by *overlap* rather than
            stride, so the overlap passed to it is ``patch_size - stride``
            per axis. Passing ``stride`` directly as the overlap was only
            correct when stride happened to equal patch_size / 2.
            NOTE(review): TorchIO may require the resulting overlap to be
            even — confirm for non-default values.

    Returns:
        float array of shape (N, *patch_size, 1) (channel axis last).

    Raises:
        ValueError: if ``stride`` is not in (0, patch_size] along every axis.
    """
    patch_size = tuple(int(p) for p in patch_size)
    if any(not 0 < stride <= p for p in patch_size):
        raise ValueError(f"stride={stride} must be in (0, patch size] for patch_size={patch_size}")
    patch_overlap = tuple(p - stride for p in patch_size)

    volume = np.expand_dims(volume, axis=0)  # (1, H, W, D): TorchIO wants a leading channel axis
    subject = tio.Subject(mri=tio.ScalarImage(tensor=volume))

    patch_sampler = tio.GridSampler(
        subject,
        patch_size=patch_size,
        patch_overlap=patch_overlap,
    )

    # Each sampled patch tensor is (1, *patch_size) -> stacked: (N, 1, *patch_size).
    patches = np.stack([patch['mri'][tio.DATA].numpy() for patch in patch_sampler])

    # Move the single channel from axis 1 to the end -> (N, *patch_size, 1).
    return np.moveaxis(patches, 1, -1)

def save_patches_to_memmap(file_path, patches, total_patches, patch_size):
    """
    Write ``patches`` into a float32 memmap file right after the first
    ``total_patches`` entries.

    The file is opened with shape ``(total_patches + len(patches),) + patch_size``:
    mode 'w+' creates it when missing, mode 'r+' reopens an existing file.
    The trailing channel axis of ``patches`` (expected size 1) is dropped
    before writing.

    Returns:
        The new running total, ``total_patches + len(patches)``.
    """
    batch_len = patches.shape[0]
    new_total = total_patches + batch_len
    open_mode = 'r+' if os.path.exists(file_path) else 'w+'

    store = np.memmap(file_path, dtype=np.float32, mode=open_mode, shape=(new_total,) + patch_size)
    store[total_patches:new_total] = patches.squeeze(axis=-1)
    del store  # release the mapping so the written data reaches disk

    return new_total

def load_and_extract_patches_memmap(synthetic_files, ground_truth_files, unpaired_7T_files,
                                    patch_size=(64, 64, 64), stride=32, save_dir="/notebooks/memmap"):
    """
    Extract patches from every volume and append them to raw float32 memmap files.

    Paired synthetic / ground-truth volumes are processed together, so patch ``n``
    in ``synthetic_patches.dat`` corresponds to patch ``n`` in
    ``ground_truth_patches.dat``. Unpaired 7T volumes go to
    ``unpaired_7T_patches.dat``. Patches are stored without a channel axis:
    each file holds ``(num_patches, *patch_size)`` float32 values.

    Each file is first created with a fixed expected patch count if it does not
    exist yet. ``save_patches_to_memmap`` then reopens it with the running total
    as its shape. Existing files are reused, not truncated: delete ``save_dir``
    to start from a clean slate.

    Args:
        synthetic_files: synthetic (3T-side) NIfTI paths.
        ground_truth_files: ground-truth NIfTI paths, aligned one-to-one with
            ``synthetic_files``.
        unpaired_7T_files: unpaired 7T NIfTI paths.
        patch_size: patch extent per axis.
        stride: step between patch origins, in voxels.
        save_dir: directory for the ``.dat`` files (created if missing).

    Returns:
        (synthetic_memmap_path, ground_truth_memmap_path, unpaired_7T_memmap_path)

    Raises:
        ValueError: if the paired file lists differ in length (``zip`` would
            otherwise silently drop the extras), or if a pair yields different
            patch counts (which would misalign every later pair).
    """
    synthetic_files = list(synthetic_files)
    ground_truth_files = list(ground_truth_files)
    if len(synthetic_files) != len(ground_truth_files):
        raise ValueError(
            f"Got {len(synthetic_files)} synthetic files but {len(ground_truth_files)} "
            "ground-truth files; paired inputs must match one-to-one."
        )
    patch_size = tuple(patch_size)  # the memmap shapes below are built by tuple concatenation

    os.makedirs(save_dir, exist_ok=True)

    synthetic_memmap_path = os.path.join(save_dir, "synthetic_patches.dat")
    ground_truth_memmap_path = os.path.join(save_dir, "ground_truth_patches.dat")
    unpaired_7T_memmap_path = os.path.join(save_dir, "unpaired_7T_patches.dat")

    # Expected patch counts, used only to pre-allocate files that do not exist yet.
    # They depend on the input volumes, patch_size and stride.
    expected_counts = {
        synthetic_memmap_path: 5103,
        ground_truth_memmap_path: 5103,
        unpaired_7T_memmap_path: 11340,
    }

    for memmap_path, num_patches in expected_counts.items():
        if not os.path.exists(memmap_path):
            preallocated = np.memmap(memmap_path, dtype=np.float32, mode='w+', shape=(num_patches,) + patch_size)
            del preallocated  # close the newly created file

    total_synthetic_patches = 0
    total_ground_truth_patches = 0
    total_unpaired_7T_patches = 0

    # Paired volumes: both sides must advance in lock-step.
    for synthetic_path, ground_truth_path in zip(synthetic_files, ground_truth_files):
        print(f"Processing: {synthetic_path} and {ground_truth_path}")

        synthetic_patches = extract_patches_corrected(load_nifti(synthetic_path), patch_size, stride)
        ground_truth_patches = extract_patches_corrected(load_nifti(ground_truth_path), patch_size, stride)

        if synthetic_patches.shape[0] != ground_truth_patches.shape[0]:
            raise ValueError(
                f"{synthetic_path} gave {synthetic_patches.shape[0]} patches but "
                f"{ground_truth_path} gave {ground_truth_patches.shape[0]}; paired volumes "
                "must have the same shape for their patches to stay aligned."
            )

        total_synthetic_patches = save_patches_to_memmap(
            synthetic_memmap_path, synthetic_patches, total_synthetic_patches, patch_size)
        total_ground_truth_patches = save_patches_to_memmap(
            ground_truth_memmap_path, ground_truth_patches, total_ground_truth_patches, patch_size)

    # Unpaired 7T volumes.
    for unpaired_7T_path in unpaired_7T_files:
        print(f"Processing unpaired 7T: {unpaired_7T_path}")

        unpaired_7T_patches = extract_patches_corrected(load_nifti(unpaired_7T_path), patch_size, stride)
        total_unpaired_7T_patches = save_patches_to_memmap(
            unpaired_7T_memmap_path, unpaired_7T_patches, total_unpaired_7T_patches, patch_size)

    print(f"Total synthetic patches saved: {total_synthetic_patches}")
    print(f"Total ground truth patches saved: {total_ground_truth_patches}")
    print(f"Total unpaired 7T patches saved: {total_unpaired_7T_patches}")

    # Slots beyond the written total (pre-allocated or left over from an earlier run)
    # are still in the file and will be read by anything that uses the file size.
    for label, memmap_path, total in (
        ("synthetic", synthetic_memmap_path, total_synthetic_patches),
        ("ground truth", ground_truth_memmap_path, total_ground_truth_patches),
        ("unpaired 7T", unpaired_7T_memmap_path, total_unpaired_7T_patches),
    ):
        expected = expected_counts[memmap_path]
        if total < expected:
            print(f"Warning: {label} file has room for {expected} patches but only {total} were written this run.")

    return synthetic_memmap_path, ground_truth_memmap_path, unpaired_7T_memmap_path

# File paths: session-1 scans are the synthetic inputs, session-2 scans the paired targets.
synthetic_files = [
    "/notebooks/data/ALLT1w/sub-01_ses-1_T1w_defaced_registered.nii.gz",
    "/notebooks/data/ALLT1w/sub-03_ses-1_T1w_defaced_registered.nii.gz",
    "/notebooks/data/ALLT1w/sub-04_ses-1_T1w_defaced_registered.nii.gz",
    "/notebooks/data/ALLT1w/sub-05_ses-1_T1w_defaced_registered.nii.gz",
    "/notebooks/data/ALLT1w/sub-06_ses-1_T1w_defaced_registered.nii.gz",
    "/notebooks/data/ALLT1w/sub-07_ses-1_T1w_defaced_registered.nii.gz",
    "/notebooks/data/ALLT1w/sub-08_ses-1_T1w_defaced_registered.nii.gz",
    "/notebooks/data/ALLT1w/sub-09_ses-1_T1w_defaced_registered.nii.gz",
    "/notebooks/data/ALLT1w/sub-10_ses-1_T1w_defaced_registered.nii.gz"
]

ground_truth_files = [
    "/notebooks/data/ALLT1w/sub-01_ses-2_T1w_defaced_registered.nii.gz",
    "/notebooks/data/ALLT1w/sub-03_ses-2_T1w_defaced_registered.nii.gz",
    "/notebooks/data/ALLT1w/sub-04_ses-2_T1w_defaced_registered.nii.gz",
    "/notebooks/data/ALLT1w/sub-05_ses-2_T1w_defaced_registered.nii.gz",
    "/notebooks/data/ALLT1w/sub-06_ses-2_T1w_defaced_registered.nii.gz",
    "/notebooks/data/ALLT1w/sub-07_ses-2_T1w_defaced_registered.nii.gz",
    "/notebooks/data/ALLT1w/sub-08_ses-2_T1w_defaced_registered.nii.gz",
    "/notebooks/data/ALLT1w/sub-09_ses-2_T1w_defaced_registered.nii.gz",
    "/notebooks/data/ALLT1w/sub-10_ses-2_T1w_defaced_registered.nii.gz"
]

# Unpaired 7T volumes: every regular file in the folder. The list is sorted because
# os.listdir() returns entries in arbitrary order, which would make the patch order in
# the memmap (and thus which unpaired patches are used for training) non-reproducible.
folder_path = "/notebooks/pre_training_full"
unpaired_7T_files = sorted(
    os.path.join(folder_path, file_name)
    for file_name in os.listdir(folder_path)
    if os.path.isfile(os.path.join(folder_path, file_name))
)

# Extract patches and save to memmap
patch_size = (64, 64, 64)
stride = 32

synthetic_memmap, ground_truth_memmap, unpaired_7T_memmap = load_and_extract_patches_memmap(
    synthetic_files, ground_truth_files, unpaired_7T_files, patch_size, stride
)

print(f"Memmap synthetic patches saved at: {synthetic_memmap}")
print(f"Memmap ground truth patches saved at: {ground_truth_memmap}")
print(f"Memmap unpaired 7T patches saved at: {unpaired_7T_memmap}")

In [None]:
# Paired 3T->7T patches and unpaired 7T patches are read back from the raw float32
# memmap files written by the extraction cell (stored without a channel axis).
patch_size = (64, 64, 64)

synthetic_memmap_path = "/notebooks/memmap/synthetic_patches.dat"
ground_truth_memmap_path = "/notebooks/memmap/ground_truth_patches.dat"
unpaired_7T_memmap_path = "/notebooks/memmap/unpaired_7T_patches.dat"


def _count_memmap_patches(path, patch_shape):
    """Return how many whole float32 patches of ``patch_shape`` the raw file at ``path`` holds."""
    bytes_per_patch = int(np.prod(patch_shape)) * np.dtype(np.float32).itemsize
    return os.path.getsize(path) // bytes_per_patch


# Derive patch counts from the files instead of hard-coding them. The shapes used here
# then always match what the extraction cell actually wrote: np.memmap in 'r' mode
# cannot map more data than the file contains, and a too-small count silently ignores data.
synthetic_num_patches = _count_memmap_patches(synthetic_memmap_path, patch_size)
ground_truth_num_patches = _count_memmap_patches(ground_truth_memmap_path, patch_size)
unpaired_7T_num_patches = _count_memmap_patches(unpaired_7T_memmap_path, patch_size)
assert synthetic_num_patches == ground_truth_num_patches, (
    f"Paired memmaps disagree: {synthetic_num_patches} synthetic vs "
    f"{ground_truth_num_patches} ground-truth patches"
)

synthetic_patches_memmap = np.memmap(synthetic_memmap_path, dtype=np.float32, mode='r', shape=(synthetic_num_patches,) + patch_size)
ground_truth_patches_memmap = np.memmap(ground_truth_memmap_path, dtype=np.float32, mode='r', shape=(ground_truth_num_patches,) + patch_size)
unpaired_7T_patches_memmap = np.memmap(unpaired_7T_memmap_path, dtype=np.float32, mode='r', shape=(unpaired_7T_num_patches,) + patch_size)

# Hold out the first 10% of paired patches (in file order) for validation.
# NOTE(review): patches are written subject by subject, so this holds out roughly the
# first subject(s). If the cut falls inside a subject, overlapping patches from that
# subject land in both splits — consider splitting by subject instead.
validation_size = 0.1
validation_count = int(synthetic_num_patches * validation_size)
train_count = synthetic_num_patches - validation_count

train_indices = slice(validation_count, synthetic_num_patches)
val_indices = slice(0, validation_count)

# Dataset.zip below stops at the shorter input, so require one unpaired patch per
# paired training patch rather than silently training on fewer paired patches.
assert unpaired_7T_num_patches >= train_count, (
    f"Only {unpaired_7T_num_patches} unpaired 7T patches for {train_count} paired training patches"
)

# NOTE(review): from_tensor_slices materializes these memmap slices in memory
# (several GB each for 64^3 float32 patches) — consider a generator if RAM is tight.
train_paired = tf.data.Dataset.from_tensor_slices((
    synthetic_patches_memmap[train_indices],
    ground_truth_patches_memmap[train_indices],
))

train_unpaired_7T = tf.data.Dataset.from_tensor_slices(
    unpaired_7T_patches_memmap[:train_count]  # same length as the paired training set
)

val_dataset = tf.data.Dataset.from_tensor_slices((
    synthetic_patches_memmap[val_indices],
    ground_truth_patches_memmap[val_indices],
))

# Each training element is ((synthetic, ground_truth), unpaired_7T).
train_dataset = tf.data.Dataset.zip((train_paired, train_unpaired_7T))

batch_size = 8
train_dataset = train_dataset.shuffle(buffer_size=1000).batch(batch_size)
val_dataset = val_dataset.batch(batch_size)

In [None]:
def residual_block(x, filters):
    """
    Two 3x3x3 convolutions (PReLU in between) added back onto the block input.

    When the input's channel count differs from ``filters``, the shortcut is
    first projected with a linear 1x1x1 convolution so the element-wise Add
    receives matching shapes.
    """
    branch = layers.Conv3D(filters, (3, 3, 3), padding='same')(x)
    branch = layers.PReLU()(branch)
    branch = layers.Conv3D(filters, (3, 3, 3), padding='same')(branch)

    shortcut = x
    if shortcut.shape[-1] != filters:
        shortcut = layers.Conv3D(filters, (1, 1, 1), padding='same', activation='linear')(shortcut)

    return layers.Add()([shortcut, branch])

def multi_scale_fusion(x, filters):
    """
    Run parallel 1x1x1, 3x3x3 and 5x5x5 convolutions over ``x`` and concatenate
    them along the channel axis (output has 3 * ``filters`` channels).
    """
    branches = [
        layers.Conv3D(filters, (kernel, kernel, kernel), padding='same')(x)
        for kernel in (1, 3, 5)
    ]
    return layers.concatenate(branches, axis=-1)

def build_optimized_unet_autoencoder(input_shape=(64, 64, 64, 1)):
    """
    Build a 3D U-Net-style autoencoder that maps an input patch to a same-sized output.

    Encoder: four stages of [Conv3D -> PReLU -> Conv3D -> PReLU -> LayerNormalization]
    with 32/64/128/256 filters. The first stage opens with a 5x5x5 kernel; all other
    convolutions are 3x3x3. The first three stages are followed by 2x2x2 max-pooling,
    and their outputs are kept as skip connections.

    Decoder: three stages of [stride-2 Conv3DTranspose -> concat skip -> residual_block
    -> multi_scale_fusion -> LayerNormalization] with 128/64/32 filters. Because of the
    three-way fusion, each stage outputs 3x its filter count (384, 192, 96 channels).

    Output: a 1x1x1 Conv3D with sigmoid activation (single channel in (0, 1)).

    Args:
        input_shape: (X, Y, Z, C). Spatial sizes should be divisible by 8 so the three
            pool / upsample pairs line up for the skip concatenations.

    Returns:
        An uncompiled keras Model.
    """
    inputs = layers.Input(shape=input_shape)

    # ---- Encoder ----
    skips = []
    x = inputs
    for stage, n_filters in enumerate((32, 64, 128, 256)):
        entry_kernel = (5, 5, 5) if stage == 0 else (3, 3, 3)
        x = layers.Conv3D(n_filters, entry_kernel, padding='same')(x)
        x = layers.PReLU()(x)
        x = layers.Conv3D(n_filters, (3, 3, 3), padding='same')(x)
        x = layers.PReLU()(x)
        x = layers.LayerNormalization()(x)
        if stage < 3:  # the 256-filter bottleneck is not pooled
            skips.append(x)
            x = layers.MaxPooling3D((2, 2, 2))(x)

    # ---- Decoder: deepest skip first ----
    for n_filters, skip in zip((128, 64, 32), reversed(skips)):
        x = layers.Conv3DTranspose(n_filters, (3, 3, 3), strides=(2, 2, 2), padding='same')(x)
        # Concat gives 2 * n_filters channels; residual_block projects back to n_filters.
        x = layers.concatenate([x, skip])
        x = residual_block(x, n_filters)
        x = multi_scale_fusion(x, n_filters)
        x = layers.LayerNormalization()(x)

    outputs = layers.Conv3D(1, (1, 1, 1), activation='sigmoid', padding='same')(x)
    return Model(inputs, outputs)

# Hybrid loss: MSE plus an SSIM-dissimilarity term weighted 0.7.
def hybrid_loss(y_true, y_pred):
    """
    Return ``MSE(y_true, y_pred) + 0.7 * (1 - mean SSIM(y_true, y_pred))``.

    SSIM uses ``max_val=1.0``, i.e. intensities are expected in [0, 1].
    NOTE(review): tf.image.ssim works on the last three axes as (height, width,
    channels); for 5-D (batch, X, Y, Z, 1) patches it presumably computes SSIM per
    2-D slice along X — confirm this is the intended structural term.
    """
    ssim_weight = 0.7
    mse_term = tf.keras.losses.MeanSquaredError()(y_true, y_pred)
    mean_ssim = tf.reduce_mean(tf.image.ssim(y_true, y_pred, max_val=1.0))
    return mse_term + ssim_weight * (1 - mean_ssim)


# Build the supervised-only model and show its architecture.
autoencoder = build_optimized_unet_autoencoder()
autoencoder.compile(
    optimizer=Adam(learning_rate=2e-5),
    loss=hybrid_loss,
    metrics=['mse'],
)
autoencoder.summary()

In [None]:
# Display one 2D slice of a 3D patch.
def show_slice(image, slice_index=None, title="Image Slice"):
    """
    Plot the slice ``image[:, :, slice_index]`` in grayscale.

    Args:
        image: 3D patch (X, Y, Z), optionally with a trailing singleton channel
            axis (X, Y, Z, 1). Accepts tensors exposing ``.numpy()`` (e.g. tf.Tensor)
            as well as NumPy arrays / memmaps; previously only tensors worked, and a
            channel axis made imshow fail on an (X, Y, 1) slice.
        slice_index: index along the third axis; defaults to the middle slice.
        title: text shown above the image, followed by the slice index.
    """
    array = image.numpy() if hasattr(image, "numpy") else np.asarray(image)
    if array.ndim == 4 and array.shape[-1] == 1:
        array = array[..., 0]

    if slice_index is None:
        slice_index = array.shape[2] // 2  # middle slice by default

    _, ax = plt.subplots()
    ax.imshow(array[:, :, slice_index], cmap='gray')
    ax.set_title(f"{title} (slice {slice_index})")
    ax.axis('off')
    plt.show()

# Sanity-check the input pipeline by plotting the first sample of one training batch.
for paired_batch, unpaired_batch in train_dataset.take(1):
    x_3T_batch, y_7T_batch = paired_batch
    training_previews = [
        (x_3T_batch, "Synthetic (3T) Patch"),
        (y_7T_batch, "Ground Truth (7T) Patch"),
        (unpaired_batch, "Unpaired 7T Patch"),
    ]
    for batch, caption in training_previews:
        show_slice(batch[0], title=caption)

# ...and the first sample of one validation batch.
for x_val, y_val in val_dataset.take(1):
    validation_previews = [
        (x_val, "Validation Synthetic (3T)"),
        (y_val, "Validation Ground Truth (7T)"),
    ]
    for batch, caption in validation_previews:
        show_slice(batch[0], title=caption)

In [None]:
# Supervised loss for paired data. This redefinition shadows the earlier identical hybrid_loss.
def hybrid_loss(y_true, y_pred):
    """
    Supervised loss: mean squared error plus 0.7 times the SSIM dissimilarity
    ``1 - mean(SSIM)``, with SSIM computed for intensities in [0, 1].
    """
    squared_error = tf.keras.losses.MeanSquaredError()
    structural_penalty = 1 - tf.reduce_mean(tf.image.ssim(y_true, y_pred, max_val=1.0))
    return squared_error(y_true, y_pred) + 0.7 * structural_penalty

# Consistency loss for unpaired 7T data.
def consistency_loss(y_true, y_pred):
    """
    Mean absolute error plus 0.1 times the SSIM dissimilarity ``1 - mean(SSIM)``.

    Used with ``y_true`` = an unpaired 7T patch and ``y_pred`` = the autoencoder's
    output for that same patch, penalizing changes to real 7T input.
    """
    absolute_error = tf.abs(y_true - y_pred)
    structural_penalty = 1 - tf.reduce_mean(tf.image.ssim(y_true, y_pred, max_val=1.0))
    return tf.reduce_mean(absolute_error) + 0.1 * structural_penalty

# Start semi-supervised training from a freshly built (untrained) autoencoder.
# NOTE(review): this compile() only configures the inner model. Training below runs
# through SemiSupervisedAutoencoder.train_step, which uses the wrapper's own optimizer.
autoencoder = build_optimized_unet_autoencoder()
autoencoder.compile(
    optimizer=Adam(learning_rate=2e-5),
    loss=hybrid_loss,
    metrics=['mse'],
)

class SemiSupervisedAutoencoder(tf.keras.Model):
    """
    Semi-supervised training wrapper around an image-to-image autoencoder.

    Each training step minimizes
        hybrid_loss(y_7T, autoencoder(x_3T))
        + lambda_consistency * consistency_loss(unpaired_7T, autoencoder(unpaired_7T)).
    The second term asks the model to leave real (unpaired) 7T patches unchanged.

    Expected batches:
        training:   ((x_3T, y_7T), unpaired_7T)
        validation: (x_3T, y_7T)
    Each tensor arrives without a channel axis; train_step / test_step add it.
    """

    def __init__(self, autoencoder, lambda_consistency=0.5):
        super().__init__()
        self.autoencoder = autoencoder
        self.lambda_consistency = lambda_consistency

        # Running means reported by fit() / evaluate().
        self.total_loss_tracker = tf.keras.metrics.Mean(name="loss")
        self.supervised_loss_tracker = tf.keras.metrics.Mean(name="supervised_loss")
        self.consistency_loss_tracker = tf.keras.metrics.Mean(name="consistency_loss")

    @property
    def metrics(self):
        # Exposing the trackers here lets Keras reset them between epochs/evaluations.
        return [
            self.total_loss_tracker,
            self.supervised_loss_tracker,
            self.consistency_loss_tracker,
        ]

    def call(self, inputs, training=False):
        """
        Forward pass: delegate to the wrapped autoencoder.

        Without this, ``predict()`` or calling the wrapper directly raises
        NotImplementedError. Unlike train_step/test_step, ``inputs`` must already
        carry the trailing channel axis.
        """
        return self.autoencoder(inputs, training=training)

    def train_step(self, data):
        ((x_3T, y_7T), unpaired_7T) = data

        # Add the trailing channel axis expected by the Conv3D model.
        x_3T = tf.expand_dims(x_3T, axis=-1)
        y_7T = tf.expand_dims(y_7T, axis=-1)
        unpaired_7T = tf.expand_dims(unpaired_7T, axis=-1)

        with tf.GradientTape() as tape:
            y_pred = self.autoencoder(x_3T, training=True)
            supervised_loss = hybrid_loss(y_7T, y_pred)

            unpaired_pred = self.autoencoder(unpaired_7T, training=True)
            consistency_loss_value = consistency_loss(unpaired_7T, unpaired_pred)

            total_loss = supervised_loss + self.lambda_consistency * consistency_loss_value

        # Only the wrapped autoencoder's weights are optimized.
        grads = tape.gradient(total_loss, self.autoencoder.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.autoencoder.trainable_weights))

        self.total_loss_tracker.update_state(total_loss)
        self.supervised_loss_tracker.update_state(supervised_loss)
        self.consistency_loss_tracker.update_state(consistency_loss_value)

        return {
            "loss": self.total_loss_tracker.result(),
            "supervised_loss": self.supervised_loss_tracker.result(),
            "consistency_loss": self.consistency_loss_tracker.result(),
        }

    def test_step(self, data):
        # Validation only has paired data, so only the supervised loss is reported
        # (it appears in fit() logs as "val_supervised_loss").
        (x_3T, y_7T) = data
        x_3T = tf.expand_dims(x_3T, axis=-1)
        y_7T = tf.expand_dims(y_7T, axis=-1)
        y_pred = self.autoencoder(x_3T, training=False)
        supervised_loss = hybrid_loss(y_7T, y_pred)

        self.supervised_loss_tracker.update_state(supervised_loss)
        return {
            "supervised_loss": self.supervised_loss_tracker.result()
        }

# Wrap the autoencoder for semi-supervised training.
semi_supervised_autoencoder = SemiSupervisedAutoencoder(autoencoder, lambda_consistency=0.5)
semi_supervised_autoencoder.compile(optimizer=Adam(learning_rate=2e-5), loss=hybrid_loss, metrics=['mse'])

# Checkpoint and early-stop on the validation supervised loss reported by test_step.
monitored_metric = "val_supervised_loss"
checkpoint = ModelCheckpoint("best_7T.keras", save_best_only=True, monitor=monitored_metric, verbose=1)
early_stopping = EarlyStopping(monitor=monitored_metric, patience=10, restore_best_weights=True)

semi_supervised_autoencoder.fit(
    train_dataset,
    epochs=30,
    validation_data=val_dataset,
    callbacks=[checkpoint, early_stopping],
)

In [None]:
# Predict a few synthetic patches and compare one input/output pair side by side.
preview_count = 150   # patches sent through the model
preview_index = 100   # which of those to display
preview_slice = 32    # slice along the patch's first spatial axis

# Add the channel axis only to the patches we need, instead of materializing the
# whole memmap as a tensor.
preview_inputs = np.expand_dims(np.asarray(synthetic_patches_memmap[:preview_count]), axis=-1)

# Predict with the inner autoencoder (it shares weights with the semi-supervised wrapper).
predicted_patches = semi_supervised_autoencoder.autoencoder.predict(preview_inputs)

# Both panels on one 1x2 grid (the second panel previously used a 1x3 grid by mistake).
fig, (ax_input, ax_output) = plt.subplots(1, 2, figsize=(8, 4))
ax_input.imshow(preview_inputs[preview_index, preview_slice, :, :, 0], cmap='gray')
ax_input.set_title(f"Synthetic Patch (Slice {preview_slice})")
ax_output.imshow(predicted_patches[preview_index, preview_slice, :, :, 0], cmap='gray')
ax_output.set_title(f"Predicted Patch (Slice {preview_slice})")
plt.show()

In [None]:
def compute_psnr_ssim_nmse(ground_truth_volume, refined_volume, data_range=4095):
    """
    Compute slice-wise PSNR, SSIM and NMSE between two 3D volumes.

    For each of the three axes, both volumes are cut into 2D slices along that axis,
    the metrics are computed per slice, and the per-axis means are printed.

    Args:
        ground_truth_volume: reference 3D array.
        refined_volume: 3D array of the same shape.
        data_range: intensity range passed to PSNR and SSIM (default 4095, matching
            the intensity scale used by process_nifti_with_autoencoder).

    Returns:
        (mean_psnr, mean_ssim, mean_nmse) for slices along the LAST axis (axis 2) only.
        The other two axes are printed but not returned.

    Notes:
        - Slices with infinite PSNR (zero error) are left out of the PSNR mean. If no
          slice has a finite PSNR, the mean is reported as inf (previously NaN).
        - NMSE is MSE / mean(gt^2); an all-zero ground-truth slice scores 0.

    Raises:
        AssertionError: if the volume shapes differ.
    """
    assert ground_truth_volume.shape == refined_volume.shape, "Volume shapes do not match!"

    mean_psnr = mean_ssim = mean_nmse = 0
    for axis in range(3):
        psnr_values, ssim_values, nmse_values = [], [], []

        for index in range(ground_truth_volume.shape[axis]):
            gt_slice = np.take(ground_truth_volume, index, axis=axis)
            pred_slice = np.take(refined_volume, index, axis=axis)

            mse = np.mean((gt_slice - pred_slice) ** 2)
            norm_factor = np.mean(gt_slice ** 2)

            psnr_values.append(psnr(gt_slice, pred_slice, data_range=data_range))
            ssim_values.append(ssim(gt_slice, pred_slice, data_range=data_range, gaussian_weights=True))
            nmse_values.append(mse / norm_factor if norm_factor > 0 else 0)

        finite_psnr = [p for p in psnr_values if np.isfinite(p)]
        if finite_psnr:
            mean_psnr = np.mean(finite_psnr)
        elif psnr_values:
            mean_psnr = np.inf  # no finite PSNR, e.g. every slice identical
        else:
            mean_psnr = 0
        mean_ssim = np.mean(ssim_values) if ssim_values else 0
        mean_nmse = np.mean(nmse_values) if nmse_values else 0

        print(f'Axis {axis}: Mean PSNR: {mean_psnr:.2f} dB, Mean SSIM: {mean_ssim:.4f}, Mean NMSE: {mean_nmse:.6f}')

    return mean_psnr, mean_ssim, mean_nmse

In [None]:
def extract_patches(volume, patch_size, stride):
    """
    Cut a 3D volume into patches on a regular grid, reflect-padding edge patches.

    Patch origins are placed every ``stride`` voxels along each axis, from 0 up to
    (but not including) the axis length. They are visited with the first axis
    outermost and the last innermost. Patches that run past the volume boundary are
    padded at their far end with ``mode='reflect'``, so every patch has exactly
    ``patch_size``. Every voxel is covered as long as ``stride <= patch_size`` on
    each axis.

    Returns:
        patches: array of shape (N, ps_h, ps_w, ps_d, 1).
        positions: list of N (i, j, k) origins, aligned with ``patches``.
    """
    h, w, d = volume.shape
    ps_h, ps_w, ps_d = patch_size

    origins = [
        (i, j, k)
        for i in range(0, h, stride)
        for j in range(0, w, stride)
        for k in range(0, d, stride)
    ]

    collected = []
    for i, j, k in origins:
        # NumPy slicing stops at the array edge, so crops near the border are smaller.
        crop = volume[i:i + ps_h, j:j + ps_w, k:k + ps_d]
        shortfall = (ps_h - crop.shape[0], ps_w - crop.shape[1], ps_d - crop.shape[2])
        collected.append(np.pad(crop, [(0, amount) for amount in shortfall], mode='reflect'))

    return np.expand_dims(np.array(collected), axis=-1), origins

def gaussian_weight_map(shape, sigma=0.5):
    """
    Build a separable, centred Gaussian weight map for blending overlapping patches.

    The weight peaks at the patch centre and decays towards the borders, so
    overlapping patches blend smoothly instead of being averaged with equal weight.
    (The previous version ran ``gaussian_filter`` over an all-ones array, which
    returns ones again, so the blending was silently uniform.)

    Args:
        shape: patch shape, e.g. (64, 64, 64).
        sigma: Gaussian standard deviation as a FRACTION of each axis length.
            For example, 0.5 on a 64-voxel axis gives 32 voxels, and the per-axis
            border weight is about 0.6. ``sigma <= 0`` returns a uniform map of ones.

    Returns:
        float32 array of ``shape`` with maximum 1. Entries are strictly positive
        unless ``sigma`` is extremely small.
    """
    shape = tuple(int(n) for n in shape)
    if sigma <= 0:
        return np.ones(shape, dtype=np.float32)

    weights = np.ones((), dtype=np.float64)
    for n in shape:
        coords = np.arange(n, dtype=np.float64) - (n - 1) / 2.0  # centred coordinates
        profile = np.exp(-0.5 * (coords / (sigma * n)) ** 2)
        weights = np.multiply.outer(weights, profile)  # separable: outer product per axis

    if weights.size:
        weights = weights / weights.max()
    return weights.astype(np.float32)

def reconstruct_volume(patches, positions, original_shape, patch_size):
    """
    Stitch patches back into a volume as a weighted average of overlapping predictions.

    Every patch is multiplied by ``gaussian_weight_map(patch_size)`` and accumulated
    at its origin. The sum is then divided by the accumulated weights. Patches that
    extend past the volume (padded at their far end by extract_patches) are cropped,
    along with their weights.

    Args:
        patches: array (N, ps_h, ps_w, ps_d, C); only channel 0 is used.
        positions: N (i, j, k) origins, aligned with ``patches``.
        original_shape: (H, W, D) of the output volume.
        patch_size: (ps_h, ps_w, ps_d).

    Returns:
        float32 volume of ``original_shape``. Voxels covered by no patch are 0.
    """
    h, w, d = original_shape
    ps_h, ps_w, ps_d = patch_size

    weighted_sum = np.zeros(original_shape, dtype=np.float32)
    weight_total = np.zeros(original_shape, dtype=np.float32)
    blend = gaussian_weight_map((ps_h, ps_w, ps_d))

    for n, (i, j, k) in enumerate(positions):
        target = (slice(i, min(i + ps_h, h)),
                  slice(j, min(j + ps_w, w)),
                  slice(k, min(k + ps_d, d)))
        # Portion of the patch (and its weights) that lies inside the volume.
        local = tuple(slice(0, region.stop - region.start) for region in target)

        weight = blend[local]
        weighted_sum[target] += patches[n, ..., 0][local] * weight
        weight_total[target] += weight

    weight_total[weight_total == 0] = 1  # uncovered voxels: avoid 0/0 (they stay 0)
    return weighted_sum / weight_total

def process_nifti_with_autoencoder(input_nifti_path, ground_truth_nifti_path, output_nifti_path, autoencoder, patch_size=(64,64,64), stride=32):
    """
    Refine a whole NIfTI volume with the autoencoder, score it against a reference,
    and save the result.

    Steps: load input and ground truth, scale the input by 1/4095, predict on
    grid-extracted patches, blend the patches back into a volume, rescale by 4095,
    compute slice-wise PSNR/SSIM/NMSE against the ground truth, and save the refined
    volume with the input's affine.

    Args:
        input_nifti_path: volume to refine.
        ground_truth_nifti_path: reference volume (same shape) for the metrics.
        output_nifti_path: where the refined NIfTI is written.
        autoencoder: Keras model mapping (N, *patch_size, 1) -> same shape.
        patch_size: patch extent per axis.
        stride: step between patch origins.

    Returns:
        (mean_psnr, mean_ssim, mean_nmse) from compute_psnr_ssim_nmse. These were
        previously computed and then discarded.
    """
    nifti_img = nib.load(input_nifti_path)
    volume = nifti_img.get_fdata()

    nifti_gt = nib.load(ground_truth_nifti_path)
    volume_gt = nifti_gt.get_fdata()

    # Scale intensities by 1/4095, i.e. to [0, 1] for data within 0..4095
    # (4095 is presumably the 12-bit scanner maximum).
    # NOTE(review): training patches were min-max normalized per volume (load_nifti),
    # not divided by 4095. The two agree only if each volume spans exactly 0..4095 — confirm.
    volume = volume / 4095

    synthetic_patches, positions = extract_patches(volume, patch_size, stride)
    refined_patches = autoencoder.predict(synthetic_patches)
    refined_volume = reconstruct_volume(refined_patches, positions, volume.shape, patch_size)

    # Undo the input scaling so the metrics and the saved file use the original intensity scale.
    refined_volume = refined_volume * 4095

    metrics = compute_psnr_ssim_nmse(volume_gt, refined_volume)

    refined_nifti = nib.Nifti1Image(refined_volume, nifti_img.affine)
    nib.save(refined_nifti, output_nifti_path)
    print(f"Saved refined MRI: {output_nifti_path}")

    return metrics

# Example: refine sub-08's session-1 scan and compare it with its session-2 scan.
# NOTE(review): sub-08 is also in synthetic_files / ground_truth_files used for
# training, so these metrics do not come from a held-out subject.
input_nifti_path = "/notebooks/data/ALLT1w/sub-08_ses-1_T1w_defaced_registered.nii.gz"
ground_truth_nifti_path = "/notebooks/data/ALLT1w/sub-08_ses-2_T1w_defaced_registered.nii.gz"
output_nifti_path = "Test_sub08_SAE.nii.gz"

process_nifti_with_autoencoder(
    input_nifti_path,
    ground_truth_nifti_path,
    output_nifti_path,
    autoencoder,
    patch_size=(64, 64, 64),
    stride=32,
)