In [None]:
import os
import io
import glob
import datetime
import zipfile
import random
import requests
import numpy as np
import nibabel as nib
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import torchio as tio
import torch
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

In [None]:
def load_nifti(file_path, scale=4095.0):
    """
    Load a NIfTI file and scale its intensities by a fixed divisor.

    Every voxel is divided by `scale` (default 4095.0). This is a fixed scaling, not a min-max
    normalization, so the result lies in [0, 1] only when the stored intensities lie in [0, scale].

    Args:
        file_path: path to a .nii / .nii.gz file.
        scale: positive divisor applied to every voxel.

    Returns:
        np.ndarray of dtype float32 with the same shape as the stored volume.

    Raises:
        ValueError: if `scale` is not positive.
    """
    if scale <= 0:
        raise ValueError(f"scale must be positive, got {scale}")
    data = nib.load(file_path).get_fdata()
    return (data / scale).astype(np.float32)

def extract_all_slices(volume, resize_to=(256, 256)):
    """
    Extract every non-empty 2D slice of a 3D volume from all three planes, resized to `resize_to`.

    Planes are emitted in a fixed order: axial (slices along axis 2, the Z-axis), then coronal
    (axis 1, Y), then sagittal (axis 0, X). Within a plane slices keep their index order. Slices whose
    maximum is 0 are skipped.

    Args:
        volume: 3D array.
        resize_to: (height, width) passed to tf.image.resize.

    Returns:
        Array of shape (N, height, width, 1). When no slice is non-empty, an empty float32 array of
        shape (0, height, width, 1) is returned, so it still concatenates with non-empty results.
    """
    slices = []
    for axis in (2, 1, 0):  # axial, coronal, sagittal
        for i in range(volume.shape[axis]):
            s = np.take(volume, i, axis=axis)
            if np.max(s) != 0:
                slices.append(tf.image.resize(s[..., np.newaxis], resize_to).numpy())

    if not slices:
        return np.empty((0, *resize_to, 1), dtype=np.float32)
    return np.array(slices)

def _paired_nonempty_slices(syn_vol, gt_vol, resize_to):
    """Return two lists of resized (H, W, 1) slices, keeping an index only when both slices are non-empty."""
    syn_out, gt_out = [], []
    # Plane order: axial (axis 2), coronal (axis 1), sagittal (axis 0).
    for axis in (2, 1, 0):
        for i in range(syn_vol.shape[axis]):
            syn_s = np.take(syn_vol, i, axis=axis)
            gt_s = np.take(gt_vol, i, axis=axis)
            # Filter on the pair, not on each volume separately: independent filtering shifts the
            # indices of one volume whenever only it has an empty slice, misaligning every later pair.
            if np.max(syn_s) == 0 or np.max(gt_s) == 0:
                continue
            syn_out.append(tf.image.resize(syn_s[..., np.newaxis], resize_to).numpy())
            gt_out.append(tf.image.resize(gt_s[..., np.newaxis], resize_to).numpy())
    return syn_out, gt_out


def load_and_extract_slices(synthetic_files, ground_truth_files, resize_to=(256, 256)):
    """
    Load paired NIfTI volumes and return aligned 2D (input, target) slice pairs from all three planes.

    Volumes are loaded with load_nifti. For every axial, coronal and sagittal index, the pair is kept
    only if both the synthetic and the ground-truth slice contain non-zero data. Row k of both
    returned arrays therefore always comes from the same file pair, plane and slice index.

    Args:
        synthetic_files: paths of the generator-input volumes.
        ground_truth_files: paths of the target volumes; entry k must pair with synthetic_files[k].
        resize_to: (height, width) every slice is resized to.

    Returns:
        (synthetic_slices, ground_truth_slices), each of shape (N, height, width, 1). With no usable
        pairs (e.g. empty file lists) both are empty float32 arrays of shape (0, height, width, 1).

    Raises:
        ValueError: if the file lists differ in length, or a volume pair differs in shape.
    """
    synthetic_files = list(synthetic_files)
    ground_truth_files = list(ground_truth_files)
    if len(synthetic_files) != len(ground_truth_files):
        raise ValueError(
            f"Got {len(synthetic_files)} synthetic files but {len(ground_truth_files)} ground-truth files"
        )

    synthetic_slices = []
    ground_truth_slices = []

    for synthetic_path, gt_path in zip(synthetic_files, ground_truth_files):
        print(f"Processing: {synthetic_path} and {gt_path}")
        syn_vol = load_nifti(synthetic_path)
        gt_vol = load_nifti(gt_path)
        if syn_vol.shape != gt_vol.shape:
            raise ValueError(
                f"Volume shapes differ: {synthetic_path} {syn_vol.shape} vs {gt_path} {gt_vol.shape}"
            )

        syn_pairs, gt_pairs = _paired_nonempty_slices(syn_vol, gt_vol, resize_to)
        synthetic_slices.extend(syn_pairs)
        ground_truth_slices.extend(gt_pairs)

    if not synthetic_slices:
        empty = np.empty((0, *resize_to, 1), dtype=np.float32)
        return empty, empty.copy()

    return np.stack(synthetic_slices), np.stack(ground_truth_slices)

# === File Paths ===
# Session-1 scans are the generator input ("synthetic"); the registered session-2 scan of the same
# subject is the target. Entry k of both lists must belong to the same subject.
DATA_DIR = "/notebooks/data/ALLT1w"
SUBJECTS = [f"sub-{n:02d}" for n in range(1, 10)]  # sub-01 ... sub-09

synthetic_files = [
    os.path.join(DATA_DIR, f"{sub}_ses-1_T1w_defaced_registered.nii.gz") for sub in SUBJECTS
]
ground_truth_files = [
    os.path.join(DATA_DIR, f"{sub}_ses-2_T1w_defaced_registered.nii.gz") for sub in SUBJECTS
]

# === Split by subject, then extract slices ===
# Splitting whole subjects keeps every slice of a subject on one side of the split. That includes
# the orthogonal slices that share its voxels. Splitting individual slices instead would put the
# same anatomy in both the train and the test set.
(train_syn_files, test_syn_files,
 train_gt_files, test_gt_files) = train_test_split(
    synthetic_files, ground_truth_files, test_size=0.2, random_state=42
)
print("Test subjects:", [os.path.basename(p) for p in test_syn_files])

train_input, train_target = load_and_extract_slices(
    train_syn_files, train_gt_files, resize_to=(256, 256)
)
test_input, test_target = load_and_extract_slices(
    test_syn_files, test_gt_files, resize_to=(256, 256)
)

print("Train input shape:", train_input.shape, "| Train target shape:", train_target.shape)
print("Test input shape:", test_input.shape, "| Test target shape:", test_target.shape)

# Slices come out grouped by subject and plane. Permute the training pairs once so the 100-element
# shuffle buffer below mixes subjects instead of drawing consecutive slices of one volume.
perm = np.random.default_rng(42).permutation(len(train_input))
train_input, train_target = train_input[perm], train_target[perm]

# === Convert to TensorFlow Datasets ===
# Shuffle before batching (the idiomatic order; with batch size 1 it is equivalent to the reverse).
train_dataset = tf.data.Dataset.from_tensor_slices((train_input, train_target)).shuffle(100).batch(1)
test_dataset = tf.data.Dataset.from_tensor_slices((test_input, test_target)).batch(1)

# === Shape check ===
for input_img, target_img in train_dataset.take(1):
    print("Input slice shape:", input_img.shape)
    print("Target slice shape:", target_img.shape)

In [None]:
# Number of channels produced by the generator's final Conv2DTranspose (single-channel MRI slices).
OUTPUT_CHANNELS = 1

# Residual block used at the generator bottleneck.
def resnet_block(filters, kernel_size=3, strides=1):
    """
    Build a two-convolution residual block and return it as a callable ``block(x)``.

    Layout: Conv(strides) -> BN -> ReLU -> Conv(stride 1) -> BN -> add shortcut -> ReLU.

    Only the first convolution is strided, so a strided block downsamples exactly once. When the
    block changes the spatial size (``strides != 1``) or the channel count (``filters`` differs from
    the input's channels), the shortcut is projected with a 1x1 Conv + BN so the residual addition
    has matching shapes. With stride 1 and equal channels the shortcut is the identity.

    Args:
        filters: number of output channels.
        kernel_size: kernel size of both main-path convolutions.
        strides: stride of the first convolution (and of the shortcut projection, if any).

    Returns:
        A function mapping a Keras tensor to a Keras tensor. New layers are created each time the
        returned function is called.
    """
    initializer = tf.random_normal_initializer(0., 0.02)

    def block(x):
        shortcut = x
        x = tf.keras.layers.Conv2D(filters, kernel_size, strides=strides, padding='same',
                                   kernel_initializer=initializer, use_bias=False)(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.ReLU()(x)

        # Second conv keeps the resolution; applying `strides` here as well would downsample twice.
        x = tf.keras.layers.Conv2D(filters, kernel_size, strides=1, padding='same',
                                   kernel_initializer=initializer, use_bias=False)(x)
        x = tf.keras.layers.BatchNormalization()(x)

        needs_projection = strides not in (1, (1, 1)) or shortcut.shape[-1] != filters
        if needs_projection:
            shortcut = tf.keras.layers.Conv2D(filters, 1, strides=strides, padding='same',
                                              kernel_initializer=initializer, use_bias=False)(shortcut)
            shortcut = tf.keras.layers.BatchNormalization()(shortcut)

        x = tf.keras.layers.Add()([x, shortcut])  # Residual connection
        x = tf.keras.layers.ReLU()(x)
        return x

    return block

# Encoder building block shared by the generator and the discriminator.
def downsample(filters, size, apply_batchnorm=True):
    """
    Return a Sequential block: Conv2D(stride 2, 'same' padding, no bias) -> [BatchNorm] -> LeakyReLU.

    The stride-2 convolution halves an even height/width.

    Args:
        filters: number of output channels.
        size: convolution kernel size.
        apply_batchnorm: insert BatchNormalization after the convolution when True.
    """
    block_layers = [
        tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',
                               kernel_initializer=tf.random_normal_initializer(0., 0.02),
                               use_bias=False),
    ]
    if apply_batchnorm:
        block_layers.append(tf.keras.layers.BatchNormalization())
    block_layers.append(tf.keras.layers.LeakyReLU())
    return tf.keras.Sequential(block_layers)

def upsample(filters, size, apply_dropout=False):
    """
    Return a Sequential decoder block:
    Conv2DTranspose(stride 2, 'same', no bias) -> BatchNorm -> [Dropout(0.5)] -> ReLU.

    The stride-2 transposed convolution doubles height and width.

    Args:
        filters: number of output channels.
        size: kernel size of the transposed convolution.
        apply_dropout: insert Dropout(0.5) before the activation when True.
    """
    block_layers = [
        tf.keras.layers.Conv2DTranspose(filters, size, strides=2, padding='same',
                                        kernel_initializer=tf.random_normal_initializer(0., 0.02),
                                        use_bias=False),
        tf.keras.layers.BatchNormalization(),
    ]
    if apply_dropout:
        block_layers.append(tf.keras.layers.Dropout(0.5))
    block_layers.append(tf.keras.layers.ReLU())
    return tf.keras.Sequential(block_layers)

# U-Net generator with residual blocks at the bottleneck.
def Generator():
    """
    Build the generator: 8 stride-2 encoder levels (256 -> 1), three 512-filter residual blocks,
    7 decoder levels with U-Net skip connections (1 -> 128), and a final stride-2 Conv2DTranspose
    with tanh producing a (256, 256, OUTPUT_CHANNELS) image from a (256, 256, 1) input.
    """
    inputs = tf.keras.layers.Input(shape=[256, 256, 1])

    # (filters, apply_batchnorm) per encoder level; spatial size 256 -> 128 -> ... -> 1.
    encoder_spec = [(64, False), (128, True), (256, True)] + [(512, True)] * 5
    down_stack = [downsample(n_filters, 4, apply_batchnorm=use_bn) for n_filters, use_bn in encoder_spec]

    # NOTE(review): these blocks run on the 1x1 bottleneck. With the batch size of 1 used for training
    # here and training=True, BatchNormalization normalizes each channel over a single value — confirm
    # the residual blocks contribute as intended.
    resnet_blocks = [resnet_block(512) for _ in range(3)]

    # (filters, apply_dropout) per decoder level; spatial size 1 -> 2 -> ... -> 128.
    decoder_spec = [(512, True)] * 3 + [(512, False), (256, False), (128, False), (64, False)]
    up_stack = [upsample(n_filters, 4, apply_dropout=use_dropout) for n_filters, use_dropout in decoder_spec]

    last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4, strides=2, padding='same',
                                           kernel_initializer=tf.random_normal_initializer(0., 0.02),
                                           activation='tanh')

    x = inputs
    encoder_outputs = []
    for down in down_stack:
        x = down(x)
        encoder_outputs.append(x)

    for block in resnet_blocks:
        x = block(x)

    # The deepest encoder output is the bottleneck itself, so skips start one level up and run
    # from deep to shallow to match the decoder order.
    for up, skip in zip(up_stack, encoder_outputs[-2::-1]):
        x = up(x)
        x = tf.keras.layers.Concatenate()([x, skip])

    return tf.keras.Model(inputs=inputs, outputs=last(x))

def Discriminator():
    """
    Build the conditional discriminator.

    The input slice and a (real or generated) target slice, both (256, 256, 1), are stacked along
    channels and mapped to a spatial grid of raw logits (30x30 per the shape comments below) rather
    than a single score. The final Conv2D has no activation.
    """
    init = tf.random_normal_initializer(0., 0.02)

    input_image = tf.keras.layers.Input(shape=[256, 256, 1], name='input_image')
    target_image = tf.keras.layers.Input(shape=[256, 256, 1], name='target_image')

    h = tf.keras.layers.concatenate([input_image, target_image])  # (batch_size, 256, 256, 2)
    h = downsample(64, 4, False)(h)   # (batch_size, 128, 128, 64)
    h = downsample(128, 4)(h)         # (batch_size, 64, 64, 128)
    h = downsample(256, 4)(h)         # (batch_size, 32, 32, 256)

    h = tf.keras.layers.ZeroPadding2D()(h)  # (batch_size, 34, 34, 256)
    h = tf.keras.layers.Conv2D(512, 4, strides=1, kernel_initializer=init,
                               use_bias=False)(h)  # (batch_size, 31, 31, 512)
    h = tf.keras.layers.BatchNormalization()(h)
    h = tf.keras.layers.LeakyReLU()(h)

    h = tf.keras.layers.ZeroPadding2D()(h)  # (batch_size, 33, 33, 512)
    logits = tf.keras.layers.Conv2D(1, 4, strides=1, kernel_initializer=init)(h)  # (batch_size, 30, 30, 1)

    return tf.keras.Model(inputs=[input_image, target_image], outputs=logits)

# Loss functions. The discriminator's final Conv2D has no activation, so it outputs logits.
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
LAMBDA = 100  # weight of the L1 reconstruction term in the generator loss

def generator_loss(disc_generated_output, gen_output, target):
    """
    Generator objective: make the discriminator label generated pairs as real, and stay close to the target.

    Args:
        disc_generated_output: discriminator logits for (input, generated) pairs.
        gen_output: generated images.
        target: ground-truth images.

    Returns:
        (total_loss, adversarial_loss, l1_loss) with total = adversarial + LAMBDA * l1.
    """
    adversarial = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
    l1 = tf.reduce_mean(tf.abs(target - gen_output))
    return adversarial + (LAMBDA * l1), adversarial, l1

def discriminator_loss(disc_real_output, disc_generated_output):
    """
    Discriminator objective: label real pairs 1 and generated pairs 0; returns the summed BCE.
    """
    real_term = loss_object(tf.ones_like(disc_real_output), disc_real_output)
    fake_term = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)
    return real_term + fake_term

# Build both networks, give each its own Adam optimizer, and bundle everything into one checkpoint.
generator = Generator()
discriminator = Discriminator()

generator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator=generator,
    discriminator=discriminator,
)

# Single compiled training step; keep Python side effects such as print() out of it.
@tf.function
def train_step(input_image, target):
    """
    Update the generator and the discriminator once on one batch.

    Both gradients are computed from the same forward pass before either optimizer is applied.

    Returns:
        (total generator loss, discriminator loss) as scalar tensors.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        fake_image = generator(input_image, training=True)
        real_logits = discriminator([input_image, target], training=True)
        fake_logits = discriminator([input_image, fake_image], training=True)

        gen_total_loss, _, _ = generator_loss(fake_logits, fake_image, target)
        disc_loss = discriminator_loss(real_logits, fake_logits)

    gen_grads = gen_tape.gradient(gen_total_loss, generator.trainable_variables)
    disc_grads = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gen_grads, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(disc_grads, discriminator.trainable_variables))

    return gen_total_loss, disc_loss

# Training loop
def fit(train_ds, steps, log_every=100, checkpoint_every=10000):
    """
    Run `steps` training steps, cycling through `train_ds` as often as needed.

    Losses are printed every `log_every` steps, and a checkpoint is saved every `checkpoint_every`
    steps. After the loop, one more checkpoint is written if the last completed step was not already
    saved. This persists the final weights even when `steps` is not a multiple of `checkpoint_every`.

    Args:
        train_ds: tf.data.Dataset yielding (input_image, target) batches.
        steps: total number of training steps.
        log_every: print interval in steps (default 100).
        checkpoint_every: checkpoint interval in steps (default 10,000).
    """
    completed = 0
    last_saved = 0
    for step, (input_image, target) in train_ds.repeat().take(steps).enumerate():
        step = int(step)  # enumerate() yields a scalar int64 tensor; use a plain int
        gen_total_loss, disc_loss = train_step(input_image, target)
        completed = step + 1

        if step % log_every == 0:
            print(f"Step {step}, Generator Loss: {float(gen_total_loss):.4f}, "
                  f"Discriminator Loss: {float(disc_loss):.4f}")

        if completed % checkpoint_every == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)
            last_saved = completed
            print(f"Checkpoint saved at step {completed}")

    if completed > last_saved:
        checkpoint.save(file_prefix=checkpoint_prefix)
        print(f"Checkpoint saved at step {completed}")

# Visual check of the generator on a single batch.
def generate_images(model, test_input, tar):
    """
    Show the input slice, its ground truth and the model's prediction side by side.

    Slices in this notebook are intensities divided by 4095 (nominally [0, 1]), not [-1, 1]. All three
    panels are therefore drawn in grayscale on the same fixed [0, 1] window, so their brightness is
    directly comparable. Per-panel auto-scaling would hide intensity errors in the prediction.

    Args:
        model: generator, called as model(x, training=True).
        test_input: batch of input slices shaped (B, H, W, 1); only the first element is shown.
        tar: matching batch of target slices.
    """
    # NOTE(review): training=True keeps dropout active and makes BatchNormalization use batch
    # statistics at inference time; kept as in the rest of the notebook.
    prediction = model(test_input, training=True)
    panels = [test_input[0], tar[0], prediction[0]]
    titles = ['Input Image', 'Ground Truth', 'Predicted Image']

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, panel, title in zip(axes, panels, titles):
        ax.imshow(np.asarray(panel)[..., 0], cmap='gray', vmin=0.0, vmax=1.0)
        ax.set_title(title)
        ax.axis('off')
    plt.show()

In [None]:
# Function to display a single image with matplotlib
def display_image(image, title=None):
    """
    Show one image in grayscale on a fixed [0, 1] intensity window.

    Images in this notebook are intensities divided by 4095, so no [-1, 1] -> [0, 1] rescaling is
    applied. A trailing singleton channel, as in (H, W, 1) slices, is dropped before plotting.

    Args:
        image: array-like of shape (H, W) or (H, W, 1).
        title: optional figure title.
    """
    img = np.asarray(image)
    if img.ndim == 3 and img.shape[-1] == 1:
        img = img[..., 0]
    plt.figure()
    plt.imshow(img, cmap='gray', vmin=0.0, vmax=1.0)
    if title:
        plt.title(title)
    plt.axis('off')
    plt.show()

# Display (input, target) pairs from a batched dataset
def display_samples_from_dataset(train_dataset, num_samples=20):
    """
    Plot the first `num_samples` (input, target) batches of a dataset, one figure per batch.

    Only the first element of each batch is shown. Inputs are the session-1 T1w slices and targets
    the session-2 T1w slices. Both are intensities divided by 4095, so they are drawn in grayscale on
    a shared [0, 1] window rather than rescaled from [-1, 1].

    Args:
        train_dataset: tf.data.Dataset yielding (input, target) batches shaped (B, H, W, 1).
        num_samples: number of batches to display.
    """
    for i, (input_batch, target_batch) in enumerate(train_dataset.take(num_samples)):
        print(f"Displaying sample {i + 1}")
        fig, (ax_in, ax_tgt) = plt.subplots(1, 2, figsize=(10, 5))
        for ax, batch, title in (
            (ax_in, input_batch, "Input Image (T1w, ses-1)"),
            (ax_tgt, target_batch, "Target Image (T1w, ses-2)"),
        ):
            ax.imshow(np.asarray(batch[0])[..., 0], cmap='gray', vmin=0.0, vmax=1.0)
            ax.set_title(title)
            ax.axis('off')
        plt.show()

# Preview 10 training pairs (train_dataset is shuffled, so the pairs shown vary between runs).
display_samples_from_dataset(train_dataset=train_dataset, num_samples=10)

In [None]:
# Train for 100,000 steps; fit() prints losses periodically and saves checkpoints.
fit(train_dataset, steps=100_000)

In [None]:
# Preview the generator on 5 randomly drawn test slices.
# The shuffle buffer spans the whole (batch-of-1) test set, so any test slice can be drawn.
shuffled_dataset = test_dataset.shuffle(buffer_size=len(test_dataset))

for inp, tar in shuffled_dataset.take(5):
    generate_images(generator, inp, tar)

In [None]:
import numpy as np
import nibabel as nib
import tensorflow as tf
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

def compute_metrics_all_axes(gt_volume, pred_volume, data_range=4095):
    """
    Compute mean PSNR, SSIM and NMSE over the 2D slices along each of the three axes.

    For every axis, each slice whose ground truth is entirely zero is skipped. On the remaining slices:
      NMSE = mean((gt - pred)^2) / mean(gt^2),
      PSNR and SSIM use `data_range`; SSIM uses gaussian_weights=True.

    Args:
        gt_volume: 3D array of reference intensities.
        pred_volume: 3D array with the same shape as gt_volume.
        data_range: intensity range passed to PSNR and SSIM (default 4095, the divisor used to
            scale volumes elsewhere in this notebook).

    Returns:
        {"axis_0": {...}, "axis_1": {...}, "axis_2": {...}}, each holding the mean "psnr", "ssim"
        and "nmse". A metric is NaN when no slice along that axis has ground-truth content.

    Raises:
        AssertionError: if the two volumes differ in shape.
    """
    assert gt_volume.shape == pred_volume.shape, "Volumes must have the same shape"

    results = {}

    for axis in range(3):  # 0 = sagittal, 1 = coronal, 2 = axial
        psnr_list, ssim_list, nmse_list = [], [], []

        for i in range(gt_volume.shape[axis]):
            gt_slice = np.take(gt_volume, i, axis=axis)
            pred_slice = np.take(pred_volume, i, axis=axis)

            if np.max(gt_slice) == 0:
                continue

            mse = np.mean((gt_slice - pred_slice) ** 2)
            norm_factor = np.mean(gt_slice ** 2)
            nmse_list.append(mse / norm_factor if norm_factor > 0 else 0)
            psnr_list.append(psnr(gt_slice, pred_slice, data_range=data_range))
            ssim_list.append(ssim(gt_slice, pred_slice, data_range=data_range, gaussian_weights=True))

        # Explicit NaN for an axis with no usable slices (np.mean([]) would also warn).
        results[f"axis_{axis}"] = {
            "psnr": np.mean(psnr_list) if psnr_list else np.nan,
            "ssim": np.mean(ssim_list) if ssim_list else np.nan,
            "nmse": np.mean(nmse_list) if nmse_list else np.nan,
        }

    return results

def process_nifti_with_pix2pix(input_nifti_path, ground_truth_nifti_path, output_nifti_path, generator):
    """
    Refine a NIfTI volume slice-by-slice (axial) with a Pix2Pix generator, score it, and save it.

    Each axial slice with non-zero content goes through these steps:
      1. divide by 4095;
      2. resize to the generator's 256x256 input;
      3. predict;
      4. clip to [0, 1];
      5. resize back to the original slice size;
      6. multiply by 4095.
    All-zero slices stay zero. The reconstruction is compared with the ground-truth volume using
    compute_metrics_all_axes (printed per axis). It is then written to `output_nifti_path` with the
    input image's affine.

    Args:
        input_nifti_path: volume fed to the generator.
        ground_truth_nifti_path: reference volume with the same shape.
        output_nifti_path: destination of the reconstructed volume.
        generator: model called as generator(x, training=True) on (1, 256, 256, 1) batches.

    Returns:
        The per-axis metrics dict from compute_metrics_all_axes.

    Raises:
        ValueError: if the input and ground-truth volumes differ in shape. This is checked before any
            inference runs, so a mismatch fails fast.
    """
    nifti_img = nib.load(input_nifti_path)
    volume = (nifti_img.get_fdata() / 4095.0).astype(np.float32)

    gt_volume = nib.load(ground_truth_nifti_path).get_fdata()
    if gt_volume.shape != volume.shape:
        raise ValueError(
            f"Input volume {volume.shape} and ground truth {gt_volume.shape} must have the same shape"
        )

    reconstructed = np.zeros_like(volume)

    for i in range(volume.shape[2]):  # Axial slices
        slice_2d = volume[:, :, i]
        if np.max(slice_2d) == 0:
            continue  # empty slice: leave zeros in the output

        original_shape = slice_2d.shape
        slice_resized = tf.image.resize(tf.convert_to_tensor(slice_2d[..., np.newaxis]), (256, 256))
        slice_input = tf.expand_dims(slice_resized, axis=0)  # (1, 256, 256, 1)

        # One slice per call on purpose: with training=True, BatchNormalization normalizes with the
        # statistics of the current batch, so batching several slices would change the predictions.
        refined_slice = generator(slice_input, training=True)[0]  # (256, 256, 1)
        refined_slice = tf.clip_by_value(refined_slice, 0.0, 1.0)

        refined_slice_resized = tf.image.resize(refined_slice, original_shape)
        reconstructed[:, :, i] = refined_slice_resized[..., 0].numpy() * 4095  # Denormalize

    metrics = compute_metrics_all_axes(gt_volume, reconstructed)
    for axis, vals in metrics.items():
        print(f"{axis.upper()} — PSNR: {vals['psnr']:.2f} dB | SSIM: {vals['ssim']:.4f} | NMSE: {vals['nmse']:.6f}")

    out_nifti = nib.Nifti1Image(reconstructed, nifti_img.affine)
    nib.save(out_nifti, output_nifti_path)
    print(f"Saved Pix2Pix output: {output_nifti_path}")

    return metrics

# Example usage: evaluate on subject 10, which is not one of the training file pairs.
input_nifti_path, ground_truth_nifti_path, output_nifti_path = (
    "/notebooks/data/ALLT1w/sub-10_ses-1_T1w_defaced_registered.nii.gz",
    "/notebooks/data/ALLT1w/sub-10_ses-2_T1w_defaced_registered.nii.gz",
    "pix2pix_output_sub10.nii.gz",
)

process_nifti_with_pix2pix(
    input_nifti_path,
    ground_truth_nifti_path,
    output_nifti_path,
    generator=generator,
)