In [3]:
#!/usr/bin/env python3
"""
Advanced NIfTI to PNG Converter with GIF Animation, Cubic Resampling and Rotation

This script converts 3D NIfTI brain MRI images (.nii or .nii.gz) to:
1. A series of 2D PNG images in all three orientations (axial, sagittal, and coronal)
2. Composite views showing all three planes simultaneously
3. GIF animations that show slicing through each plane

The script includes:
- Resampling data to make it cubic (same dimensions in all axes) for synchronized animations
- Options to rotate slices in each orientation by 0, 90, 180, or 270 degrees

Requirements:
- nibabel (for reading NIfTI files)
- matplotlib (for image handling and saving)
- numpy (for array operations)
- pillow (for advanced image processing and GIF creation)
- scipy (for resampling with zoom)
"""

import os
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
from matplotlib import cm
from PIL import Image, ImageOps
from scipy.ndimage import zoom

def normalize_data(data):
    """Normalize data to 0-1 range for better visualization."""
    data_min = np.min(data)
    data_max = np.max(data)
    
    # Check if the data has a range to avoid division by zero
    if data_max > data_min:
        return (data - data_min) / (data_max - data_min)
    else:
        return data

def rotate_slice(slice_data, rotation_angle):
    """
    Rotate a 2D slice by the specified angle.
    
    Args:
        slice_data (numpy.ndarray): 2D array representing an image slice
        rotation_angle (int): Rotation angle in degrees (0, 90, 180, or 270)
        
    Returns:
        numpy.ndarray: Rotated slice
    """
    # Convert angle to number of 90-degree rotations (0, 1, 2, or 3)
    k = (rotation_angle // 90) % 4
    
    if k == 0:
        return slice_data  # No rotation
    else:
        # Use numpy's rot90 function for efficient 90-degree rotations
        return np.rot90(slice_data, k=k)
    
def resample_to_cubic(data):
    """
    Resample data to make it cubic (same dimensions in all axes).
    Uses the maximum dimension as the target size.
    
    Args:
        data (numpy.ndarray): Input 3D volume
        
    Returns:
        numpy.ndarray: Resampled cubic volume
    """
    # Get the maximum dimension
    max_dim = max(data.shape)
    print(f"Resampling to cubic volume with dimensions: {max_dim}x{max_dim}x{max_dim}")
    
    # Calculate zoom factors for each dimension
    zoom_factors = [max_dim / dim for dim in data.shape]
    print(f"Zoom factors: {zoom_factors}")
    
    # Resample the data (order=1 for linear interpolation)
    # Use order=0 for nearest-neighbor if memory issues occur
    resampled_data = zoom(data, zoom_factors, order=1)
    
    print(f"Original shape: {data.shape}, Resampled shape: {resampled_data.shape}")
    return resampled_data
    
def create_gif_from_pngs(image_dir, gif_path, duration=100, loop=0, skip_factor=1):
    """
    Create a GIF animation from a directory of PNG images.
    
    Args:
        image_dir (str): Directory containing PNG images
        gif_path (str): Path to save the GIF
        duration (int): Duration of each frame in milliseconds
        loop (int): Number of loops (0 for infinite)
        skip_factor (int): Only use every nth image to reduce GIF size
    """
    # Make sure the output directory exists
    os.makedirs(os.path.dirname(gif_path), exist_ok=True)
    
    # Get all PNG files
    png_files = [f for f in os.listdir(image_dir) if f.endswith('.png')]
    png_files.sort()  # Sort files to ensure correct order
    
    # Use only every nth file if skip_factor > 1
    if skip_factor > 1:
        png_files = png_files[::skip_factor]
    
    if not png_files:
        print(f"No PNG files found in {image_dir}")
        return
    
    # Load images
    images = []
    for png_file in png_files:
        img_path = os.path.join(image_dir, png_file)
        img = Image.open(img_path)
        images.append(img)
    
    # Save as GIF
    if images:
        images[0].save(
            gif_path,
            save_all=True,
            append_images=images[1:],
            duration=duration,
            loop=loop,
            optimize=True  # Reduce file size
        )
        print(f"GIF created: {gif_path}")
    else:
        print(f"Failed to create GIF: {gif_path}")

def create_all_gifs(nifti_file, output_dir, views, skip_factor=2):
    """
    Create GIF animations for all views of a NIfTI file.
    
    Args:
        nifti_file (str): Path to the input NIfTI file
        output_dir (str): Directory where PNG images are saved
        views (list): List of views to generate GIFs for
        skip_factor (int): Use every nth image to reduce GIF size
    """
    gif_dir = os.path.join(output_dir, "gifs")
    os.makedirs(gif_dir, exist_ok=True)
    
    # Create a GIF for each view
    for view in views:
        view_dir = os.path.join(output_dir, view)
        if os.path.exists(view_dir) and os.listdir(view_dir):
            gif_path = os.path.join(gif_dir, f"{view}_slices.gif")
            create_gif_from_pngs(view_dir, gif_path, duration=100, skip_factor=skip_factor)
    
    # Create a GIF for composite views if they exist
    composite_dir = os.path.join(output_dir, "composite")
    if os.path.exists(composite_dir) and os.listdir(composite_dir):
        gif_path = os.path.join(gif_dir, "composite_slices.gif")
        create_gif_from_pngs(composite_dir, gif_path, duration=100, skip_factor=skip_factor)

def nifti_to_png(nifti_file, output_dir, views=None, rotations=None):
    """
    Convert a NIfTI file to a series of PNG images with rotation options.
    
    Args:
        nifti_file (str): Path to the input NIfTI file
        output_dir (str): Directory to save the output PNG images
        views (list): List of views to generate ('axial', 'sagittal', 'coronal')
        rotations (dict): Dictionary mapping view names to rotation angles (in degrees)
    """
    # Set default views if not specified
    if views is None:
        views = ['axial']
    
    # Set default rotations (no rotation) if not specified
    if rotations is None:
        rotations = {'axial': 0, 'sagittal': 0, 'coronal': 0}
    
    # Ensure all views have rotation values
    for view in views:
        if view not in rotations:
            rotations[view] = 0
            
    print(f"Using rotations: {rotations}")
    
    # Load NIfTI file
    print(f"Loading {nifti_file}...")
    img = nib.load(nifti_file)
    
    # Get data
    data = img.get_fdata()
    
    # Print some information about the data
    print(f"Image dimensions: {data.shape}")
    print(f"Data type: {data.dtype}")
    
    if len(data.shape) == 4: # Skip 4D images
        print(f"Skipping 4D image: {nifti_file}")
        return 
    
    # Resample the data to make it cubic
    data = resample_to_cubic(data)
    
    # Normalize data for better visualization
    data_normalized = normalize_data(data)

    # Process each view
    slices = {
        'axial': {'axis': 2, 'shape': (data.shape[0], data.shape[1])},
        'coronal': {'axis': 1, 'shape': (data.shape[0], data.shape[2])},
        'sagittal': {'axis': 0, 'shape': (data.shape[1], data.shape[2])}
    }
    
    # Reference slices for composite view (middle of each dimension)
    ref_slices = {
        'axial': data.shape[2] // 2,
        'coronal': data.shape[1] // 2,
        'sagittal': data.shape[0] // 2
    }
    
    # Extract and save slices for each requested view
    for view in views:
        axis = slices[view]['axis']
        n_slices = data.shape[axis]
        rotation_angle = rotations[view]
        
        for i in range(n_slices):
            # Extract slice based on the view
            if axis == 0:  # sagittal
                slice_data = data_normalized[i, :, :]
            elif axis == 1:  # coronal
                slice_data = data_normalized[:, i, :]
            else:  # axial
                slice_data = data_normalized[:, :, i]
                
            # Apply rotation if specified
            if rotation_angle != 0:
                slice_data = rotate_slice(slice_data, rotation_angle)
            
            # Create filename with padding for proper sorting
            # filepath = f"{output_dir}/{view}/slice_{i:04d}.png"
            # os.makedirs(os.path.dirname(filepath), exist_ok=True)
            
            # Save the slice as a PNG file
            plt.figure(figsize=(10, 10), dpi=100)
            plt.imshow(slice_data, cmap='gray')
            plt.axis('off')
            plt.tight_layout(pad=0)
            # plt.savefig(filepath, bbox_inches='tight', pad_inches=0)
            plt.close()
            
            # Print progress
            # if (i + 1) % 20 == 0 or i == 0 or i == n_slices - 1:
            #     print(f"Saved {view} slice {i+1}/{n_slices} (rotation: {rotation_angle}°) to {filepath}")
    
    # Create composite images
    print("Creating composite views...")
    
    # Determine how many composite images to create
    # We'll use the axial view as reference if available
    reference_view = 'axial' if 'axial' in views else views[0]
    reference_axis = slices[reference_view]['axis']
    n_composites = data.shape[reference_axis]
    
    for i in range(n_composites):
        # Get reference slice for the current position
        if reference_axis == 0:  # sagittal
            ref_slice_data = data_normalized[i, :, :]
            # Update reference positions for other views
            ref_slices['coronal'] = data.shape[1] // 2
            ref_slices['axial'] = data.shape[2] // 2
        elif reference_axis == 1:  # coronal
            ref_slice_data = data_normalized[:, i, :]
            # Update reference positions for other views
            ref_slices['sagittal'] = data.shape[0] // 2
            ref_slices['axial'] = data.shape[2] // 2
        else:  # axial
            ref_slice_data = data_normalized[:, :, i]
            # Update reference positions for other views
            ref_slices['sagittal'] = data.shape[0] // 2
            ref_slices['coronal'] = data.shape[1] // 2
        
        # Apply rotation to reference view
        if rotations[reference_view] != 0:
            ref_slice_data = rotate_slice(ref_slice_data, rotations[reference_view])
        
        # Create a composite image
        plt.figure(figsize=(15, 5))
        
        # Add the reference view
        plt.subplot(1, 3, 1)
        plt.imshow(ref_slice_data, cmap='gray')
        plt.title(f"{reference_view.capitalize()} - Slice {i}")
        plt.axis('off')
        
        # Add other views at the reference position
        plot_idx = 2
        for view in views:
            if view != reference_view:
                if view == 'axial':
                    slice_data = data_normalized[:, :, ref_slices['axial']]
                elif view == 'coronal':
                    slice_data = data_normalized[:, ref_slices['coronal'], :]
                else:  # sagittal
                    slice_data = data_normalized[ref_slices['sagittal'], :, :]
                
                # Apply rotation to this view
                if rotations[view] != 0:
                    slice_data = rotate_slice(slice_data, rotations[view])
                
                plt.subplot(1, 3, plot_idx)
                plt.imshow(slice_data, cmap='gray')
                plt.title(f"{view.capitalize()} - Reference")
                plt.axis('off')
                plot_idx += 1
        
        # Save the composite image
        # composite_filepath = f"{output_dir}/composite/slice_{i:04d}.png"
        # os.makedirs(os.path.dirname(composite_filepath), exist_ok=True)
        plt.tight_layout()
        # plt.savefig(composite_filepath, bbox_inches='tight')
        plt.close()
        
        # Print progress
        # if (i + 1) % 20 == 0 or i == 0 or i == n_composites - 1:
        #     print(f"Saved composite {i+1}/{n_composites} to {composite_filepath}")
    
    # Create GIF animations for all views
    print("Creating GIF animations...")
    create_all_gifs(nifti_file, output_dir, views, skip_factor=2)

def create_advanced_gif(nifti_file, output_dir, views, rotations=None):
    """
    Create an advanced multi-plane GIF animation.
    
    This function creates a more sophisticated GIF that shows slices
    moving through all three planes simultaneously.
    
    Args:
        nifti_file (str): Path to the input NIfTI file
        output_dir (str): Directory to save the output
        views (list): List of views to include
        rotations (dict): Dictionary mapping view names to rotation angles (in degrees)
    """
    print(f"Creating advanced multi-plane GIF for {nifti_file}...")
    
    # Set default rotations (no rotation) if not specified
    if rotations is None:
        rotations = {'axial': 0, 'sagittal': 0, 'coronal': 0}
    
    # Ensure all views have rotation values
    for view in views:
        if view not in rotations:
            rotations[view] = 0
    
    # Load NIfTI file
    img = nib.load(nifti_file)
    data = img.get_fdata()
    
    if len(data.shape) == 4:  # Skip 4D images
        print(f"Skipping 4D image: {nifti_file}")
        return
    
    # Resample the data to make it cubic
    data = resample_to_cubic(data)
    
    # Normalize data for better visualization
    data_normalized = normalize_data(data)
    
    # Create directory for frames
    frames_dir = f"{output_dir}/advanced_gif_frames"
    os.makedirs(frames_dir, exist_ok=True)
    
    # All dimensions are now the same since we resampled to a cube
    total_frames = data.shape[0]  # All dimensions should be equal now
    
    # Create frames with synchronized movement through all planes
    for i in range(total_frames):
        # Create a figure with three subplots (one for each plane)
        plt.figure(figsize=(15, 5))
        
        # For each view, get the corresponding slice (now perfectly synchronized)
        for idx, view in enumerate(views):
            plt.subplot(1, 3, idx + 1)
            
            # Extract the slice based on view - all should have the same number of slices now
            if view == 'axial':
                slice_data = data_normalized[:, :, i]
                plt.title(f"Axial - Slice {i+1}/{total_frames}")
            elif view == 'coronal':
                slice_data = data_normalized[:, i, :]
                plt.title(f"Coronal - Slice {i+1}/{total_frames}")
            elif view == 'sagittal':
                slice_data = data_normalized[i, :, :]
                plt.title(f"Sagittal - Slice {i+1}/{total_frames}")
            
            # Apply rotation if specified
            if rotations[view] != 0:
                slice_data = rotate_slice(slice_data, rotations[view])
            
            plt.imshow(slice_data, cmap='gray')
            plt.axis('off')
        
        # Save the frame
        frame_path = os.path.join(frames_dir, f"frame_{i:04d}.png")
        plt.tight_layout()
        plt.savefig(frame_path, bbox_inches='tight')
        plt.close()
        
        # Print progress
        if (i + 1) % 20 == 0 or i == 0 or i == total_frames - 1:
            print(f"Created frame {i+1}/{total_frames}")
    
    # Create the GIF from frames
    gif_path = f"{output_dir}/gifs/multiplane_animation.gif"
    os.makedirs(os.path.dirname(gif_path), exist_ok=True)
    
    # Create GIF from frames
    create_gif_from_pngs(frames_dir, gif_path, duration=80, skip_factor=1)
                
if __name__ == "__main__":
    
    # data = "BCBM-RadioGenomics_Images_Masks_Dec2024/BCBM-RadioGenomics-69-2"
    # data = ""
    # input_dir = f'./Experimental_data/image/{data}'
    # output_dir = f"./Experimental_data/image_png/{data}"
    
    data = ""
    input_dir = f'./save_img/nii/{data}'
    output_dir = f"./save_img/png/{data}"
    
    # Define views
    views = ['axial', 'sagittal', 'coronal']
    
    # Define rotations for each view (0, 90, 180, or 270 degrees)
    # Adjust these values based on your needs
    rotations = {
        'axial': 270,      # N degrees clockwise
        'sagittal': 90,  # N degrees clockwise
        'coronal': 90     # N degrees clockwise
    }
    
    # Convert NIfTI to PNG and create GIFs
    for i, f in enumerate(os.listdir(input_dir)):
        if "BCBM-RadioGenomics" in data and "mask" in f:
            continue
        
        input_file = os.path.join(input_dir, f)
        output_file = os.path.join(output_dir, f)
        if output_file.endswith('.gz'):
            output_file = output_file[:-3]
        if output_file.endswith('.nii'):
            output_file = output_file[:-4]
        print(f"\nProcessing file {i+1}: {f}")
        nifti_to_png(input_file, output_file, views=views, rotations=rotations)
        
        # Create advanced multi-plane GIF
        create_advanced_gif(input_file, output_file, views, rotations=rotations)
        
        if i>10:
            break


Processing file 1: stitched_volume.nii.gz
Using rotations: {'axial': 270, 'sagittal': 90, 'coronal': 90}
Loading ./save_img/nii/stitched_volume.nii.gz...
Image dimensions: (227, 272, 227)
Data type: float64
Resampling to cubic volume with dimensions: 272x272x272
Zoom factors: [1.198237885462555, 1.0, 1.198237885462555]
Original shape: (227, 272, 227), Resampled shape: (272, 272, 272)
Creating composite views...
Creating GIF animations...
Creating advanced multi-plane GIF for ./save_img/nii/stitched_volume.nii.gz...
Resampling to cubic volume with dimensions: 272x272x272
Zoom factors: [1.198237885462555, 1.0, 1.198237885462555]
Original shape: (227, 272, 227), Resampled shape: (272, 272, 272)
Created frame 1/272
Created frame 20/272
Created frame 40/272
Created frame 60/272
Created frame 80/272
Created frame 100/272
Created frame 120/272
Created frame 140/272
Created frame 160/272
Created frame 180/272
Created frame 200/272
Created frame 220/272
Created frame 240/272
Created frame 260/

In [2]:
import SimpleITK as sitk
import sys

# Get file path from command line or use default
file_path = "/home/brain-mri/TUMSyn/Experimental_data/image/test_HCPD_T1w.nii.gz"

# Read the image and print its size
img = sitk.ReadImage(file_path)
print(f"Image dimensions (width, height, depth): {img.GetSize()}")
print(f"Voxel spacing (mm): {img.GetSpacing()}")
print(f"Pixel type: {img.GetPixelIDTypeAsString()}")

Image dimensions (width, height, depth): (227, 272, 227)
Voxel spacing (mm): (0.800000011920929, 0.800000011920929, 0.800000011920929)
Pixel type: 32-bit float
