# Axon IA: Data Preparation

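This notebook demonstrates how to prepare data for use with Axon IA. We'll cover:

1. Obtaining sample medical images (NIfTI format; here a synthetic volume stands in for real scans)
2. Preprocessing the data (resampling, intensity normalization, foreground cropping)
3. Creating the expected directory structure
4. Loading and visualizing the prepared data
5. Loading the prepared data with the `AxonDataset` class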

In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import shutil

# Make the parent directory importable (added to sys.path only once)
module_path = os.path.abspath(os.pardir)
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

In [None]:
from axon_ia.utils.nifti_utils import load_nifti, save_nifti
from axon_ia.utils.visualization import plot_slices
from axon_ia.data.preprocessing import (
    resample_to_spacing,
    normalize_intensity,
    crop_foreground,
    standardize_orientation
)

## 1. Load Sample Data

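First, we need some sample medical images. In a real workflow you would read them from disk with `load_nifti`; for this example, we'll create a synthetic volume containing a spherical "lesion" instead.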

In [None]:
# Create a synthetic volume with a spherical "lesion"
def create_synthetic_volume(size=(128, 128, 64), seed=None):
    """Create a noisy 3D volume with a brighter sphere at its center.

    Parameters
    ----------
    size : tuple of int
        Shape of the generated volume along its three axes.
    seed : int or None, optional
        Seed for a dedicated random generator, making the background noise
        reproducible. With ``None`` (the default) the noise is drawn from
        NumPy's global random state, exactly as before this parameter existed.

    Returns
    -------
    volume : np.ndarray
        float32 array of shape ``size``: Gaussian noise (mean 100, std 20)
        with 50 added to every voxel inside the sphere.
    mask : np.ndarray
        float32 array of shape ``size``: 1.0 inside the sphere, 0.0 elsewhere.
    """
    # Both the np.random module and a Generator expose .normal(loc, scale, size)
    rng = np.random if seed is None else np.random.default_rng(seed)

    # Create background
    volume = rng.normal(100, 20, size).astype(np.float32)

    # Open (broadcastable) index grids keep the distance computation cheap
    x, y, z = np.ogrid[:size[0], :size[1], :size[2]]
    center_x, center_y, center_z = size[0] // 2, size[1] // 2, size[2] // 2

    # The sphere radius is one eighth of the first axis length
    radius = size[0] // 8
    squared_distance = (x - center_x) ** 2 + (y - center_y) ** 2 + (z - center_z) ** 2
    sphere = squared_distance <= radius ** 2

    # Raise the intensity inside the sphere
    volume[sphere] += 50

    # Create mask
    mask = sphere.astype(np.float32)

    return volume, mask

# Generate the example volume together with its lesion mask
volume, mask = create_synthetic_volume(size=(128, 128, 64))

# Plot the volume with the mask passed alongside it (n_slices=3)
fig = plot_slices(volume, n_slices=3, mask=mask)
plt.show()

## 2. Preprocess Data

Next, let's apply common preprocessing steps:

In [None]:
# Step 1: resample to isotropic spacing
original_spacing = (1.0, 1.0, 2.0)  # example where Z has lower resolution
target_spacing = (1.0,) * 3         # isotropic 1mm target

print(f"Original volume shape: {volume.shape}")

# The image and the mask are resampled with the same spacing arguments
spacing_args = {
    "original_spacing": original_spacing,
    "target_spacing": target_spacing,
}

# Image: linear interpolation
resampled_volume = resample_to_spacing(volume, interpolation='linear', **spacing_args)

# Mask: nearest-neighbour interpolation, flagged as a mask
resampled_mask = resample_to_spacing(
    mask,
    interpolation='nearest',
    is_mask=True,
    **spacing_args,
)

print(f"Resampled volume shape: {resampled_volume.shape}")

# Plot the resampled volume together with the resampled mask
fig = plot_slices(resampled_volume, n_slices=3, mask=resampled_mask)
plt.show()

In [None]:
# Step 2: intensity normalization in 'z_score' mode, passing the resampled mask
normalized_volume = normalize_intensity(resampled_volume, mask=resampled_mask, mode='z_score')

# Report statistics over the whole normalized array
print(f"Normalized volume stats - Mean: {normalized_volume.mean():.4f}, Std: {normalized_volume.std():.4f}")

# Plot the normalized volume together with the resampled mask
fig = plot_slices(normalized_volume, n_slices=3, mask=resampled_mask)
plt.show()

In [None]:
# Step 3: crop volume and mask to the foreground region (margin=10)
crop_result = crop_foreground(normalized_volume, resampled_mask, margin=10)
cropped_volume, cropped_mask, crop_indices = crop_result

print(f"Cropped volume shape: {cropped_volume.shape}")

# Plot the cropped volume together with the cropped mask
fig = plot_slices(cropped_volume, n_slices=3, mask=cropped_mask)
plt.show()

## 3. Create Directory Structure

Now, let's create the expected directory structure for Axon IA:

In [None]:
def simulate_contrast(base, scale, noise_std):
    """Return ``base * scale`` plus zero-mean Gaussian noise with std ``noise_std``."""
    return base * scale + np.random.normal(0, noise_std, base.shape)


# Root of the example dataset, with one sub-directory per split
data_dir = Path("../data_example")
for split in ("train", "val", "test"):
    split_dir = data_dir / split
    split_dir.mkdir(parents=True, exist_ok=True)

# A single sample patient goes into the train split
patient_dir = data_dir / "train" / "patient_001"
patient_dir.mkdir(exist_ok=True)

# The preprocessed volume is saved as several simulated modalities.
# With real data, each modality would come from its own acquisition.

# FLAIR: the preprocessed volume itself (main contrast)
save_nifti(cropped_volume, patient_dir / "flair.nii.gz")

# T1: scaled copy with a little added noise
t1_volume = simulate_contrast(cropped_volume, scale=0.8, noise_std=0.1)
save_nifti(t1_volume, patient_dir / "t1.nii.gz")

# T2: scaled copy with a little added noise
t2_volume = simulate_contrast(cropped_volume, scale=1.2, noise_std=0.1)
save_nifti(t2_volume, patient_dir / "t2.nii.gz")

# DWI: scaled copy with slightly stronger noise
dwi_volume = simulate_contrast(cropped_volume, scale=0.7, noise_std=0.15)
save_nifti(dwi_volume, patient_dir / "dwi.nii.gz")

# The cropped mask is stored next to the modalities
save_nifti(cropped_mask, patient_dir / "mask.nii.gz")

print(f"Data saved to {data_dir}")

## 4. Load and Visualize the Prepared Data

Finally, let's load and visualize the data from the created directory structure:

In [None]:
# Load every modality saved for the example patient
modalities = ["flair", "t1", "t2", "dwi"]
loaded_volumes = {
    modality: load_nifti(patient_dir / f"{modality}.nii.gz")
    for modality in modalities
}

# Load mask
loaded_mask = load_nifti(patient_dir / "mask.nii.gz")

# One panel per modality, arranged in a 2x2 grid
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
axes = axes.flatten()

# The loop variable is deliberately not named `volume`: that name holds the
# synthetic volume created in section 1, and overwriting it would break
# re-running the earlier preprocessing cells.
for ax, (modality, modality_volume) in zip(axes, loaded_volumes.items()):
    # Get middle slice along the third axis
    slice_idx = modality_volume.shape[2] // 2
    img_slice = modality_volume[:, :, slice_idx]
    mask_slice = loaded_mask[:, :, slice_idx]

    # Plot the image slice
    ax.imshow(img_slice, cmap='gray')

    # Overlay the mask. 'Reds' is a registered matplotlib colormap, whereas
    # 'red' is not and makes imshow raise a ValueError. Zero-valued pixels are
    # masked out so only the labelled region is tinted; vmin/vmax assume mask
    # values lie in [0, 1].
    overlay = np.ma.masked_where(mask_slice == 0, mask_slice)
    ax.imshow(overlay, alpha=0.3, cmap='Reds', vmin=0, vmax=1)

    ax.set_title(modality.upper())
    ax.axis('off')

plt.tight_layout()
plt.show()

## 5. Use the AxonDataset Class

Now let's use the `AxonDataset` class to load our prepared data:

In [None]:
from axon_ia.data.dataset import AxonDataset
from axon_ia.data.transforms import get_train_transform
from axon_ia.utils.visualization import plot_multiple_slices

# Create dataset
dataset = AxonDataset(
    root_dir=data_dir,
    split="train",
    modalities=modalities,
    target="mask",
    transform=get_train_transform()
)

print(f"Dataset length: {len(dataset)}")

# Get a sample
sample = dataset[0]

# Print info
print(f"Sample ID: {sample['sample_id']}")
print(f"Image shape: {sample['image'].shape}")
print(f"Mask shape: {sample['mask'].shape}")

# Convert PyTorch tensors to numpy. Distinct names are used so the synthetic
# `mask` from section 1 is not overwritten, which would break re-running the
# earlier preprocessing cells.
sample_image = sample['image'].numpy()  # expected shape: (C, D, H, W)
sample_mask = sample['mask'].numpy()    # expected shape: (1, D, H, W)

# Three slices at 1/4, 1/2 and 3/4 of axis 1
depth = sample_image.shape[1]
slice_indices = [depth // 4, depth // 2, 3 * depth // 4]

# Plot all channels: one row per modality, one column per slice.
# squeeze=False keeps `axes` two-dimensional even for a single modality.
fig, axes = plt.subplots(len(modalities), 3, figsize=(15, 12), squeeze=False)

for i, modality in enumerate(modalities):
    for j, slice_idx in enumerate(slice_indices):
        ax = axes[i, j]
        ax.imshow(sample_image[i, slice_idx], cmap='gray')

        # 'Reds' is a registered matplotlib colormap, whereas 'red' is not and
        # makes imshow raise a ValueError. Zero-valued pixels are masked out so
        # only the labelled region is tinted; vmin/vmax assume mask values lie
        # in [0, 1].
        mask_slice = sample_mask[0, slice_idx]
        overlay = np.ma.masked_where(mask_slice == 0, mask_slice)
        ax.imshow(overlay, alpha=0.3, cmap='Reds', vmin=0, vmax=1)

        ax.set_title(f"{modality.upper()} - Slice {slice_idx}")
        ax.axis('off')

plt.tight_layout()
plt.show()