# 4-Channel MR Latent Dataset Creation

This notebook creates a dataset using all 4 channels from the MR latent representation for richer MR→CT synthesis.

In [None]:
import numpy as np
import SimpleITK as sitk
import os
import pickle
from pathlib import Path

# Confirm the imports above succeeded and show the working directory, since
# the data paths used by the following cells are relative to it.
print("📦 Libraries imported successfully")
print(f"📍 Current directory: {os.getcwd()}")

## 📊 Load and Inspect Original Data

In [None]:
# Input volumes, resolved relative to the current working directory.
mr_latent_path = "1ABA009_latent_mr.nii"
ct_path = "CT_normalised_latent.mha"

# Report the existence of each input next to its path.
print("🔍 Checking file existence:")
print(f"MR latent file: {os.path.exists(mr_latent_path)} - {mr_latent_path}")
print(f"CT file: {os.path.exists(ct_path)} - {ct_path}")

# Emit an explicit hint for every input that is missing. This only warns;
# execution continues and the load cell will fail if a file is absent.
for _path, _hint in (
    (mr_latent_path, "❌ MR latent file not found. Please check the path."),
    (ct_path, "❌ CT file not found. Please check the path."),
):
    if not os.path.exists(_path):
        print(_hint)

In [None]:
# Read the MR latent volume and convert it to a NumPy array.
# NOTE(review): the checks below treat the LAST axis as the channel axis.
# SimpleITK reverses the axis order on conversion, so a scalar 4D image may
# come back as (t, z, y, x) instead of (z, y, x, c) — confirm against the
# shape printed below.
mr_sitk = sitk.ReadImage(mr_latent_path)
mr_array = sitk.GetArrayFromImage(mr_sitk)

# Read the CT volume the same way.
ct_sitk = sitk.ReadImage(ct_path)
ct_array = sitk.GetArrayFromImage(ct_sitk)

print("📊 Data shapes:")
print(f"MR latent shape: {mr_array.shape}")
print(f"CT shape: {ct_array.shape}")

# Report whether the MR array carries a channel axis. Arrays that are
# neither 3D nor 4D produce no message here.
if mr_array.ndim == 3:
    print(f"⚠️ MR data is 3D - might need to add channel dimension")
elif mr_array.ndim == 4:
    print(f"✅ MR data has {mr_array.shape[-1]} channels - Perfect for 4-channel processing!")
    print(f"   Spatial dimensions: {mr_array.shape[:-1]}")

# Value ranges of both volumes, as a quick sanity check.
print(f"📏 MR data range: [{np.min(mr_array):.3f}, {np.max(mr_array):.3f}]")
print(f"📏 CT data range: [{np.min(ct_array):.3f}, {np.max(ct_array):.3f}]")

## 🧩 Create 4-Channel Dataset

In [None]:
# Destination folder for the generated dataset; an existing folder is reused.
output_dir = "3D/datasets/all_channels"
os.makedirs(output_dir, exist_ok=True)

print(f"📁 Created output directory: {output_dir}")

# Derive the spatial extent (and channel count) from the MR array. A 4D array
# is read as (z, y, x, channels); anything else must unpack as (z, y, x) and
# is treated as single-channel.
if mr_array.ndim == 4:
    z_size, y_size, x_size, channels = mr_array.shape
else:
    z_size, y_size, x_size = mr_array.shape
    channels = 1

# The crop covers the full volume: [z0, z1, y0, y1, x0, x1]. Adjust these
# bounds to crop a sub-volume instead.
crop_bounds = [0, z_size, 0, y_size, 0, x_size]

print(f"📐 Data dimensions: {z_size}×{y_size}×{x_size} with {channels} channels")
print(f"📋 Crop bounds: {crop_bounds}")

In [None]:
# Persist the MR and CT volumes as compressed .npz archives (array key: 'data')
# and read them back to verify what was written.
# NOTE(review): the input file is named "1ABA009_latent_mr.nii" but the
# outputs use the "1HNA001" prefix — confirm which sample ID is intended.
mr_filename = "1HNA001_mr_all_channels.npz"
mr_output_path = os.path.join(output_dir, mr_filename)

# The MR array is stored as 4D data with the channel axis last (D, H, W, C).
if mr_array.ndim == 4:
    # Already has a channel axis; save as-is.
    mr_data_to_save = mr_array
    print(f"💾 Saving 4-channel MR data: {mr_array.shape}")
elif mr_array.ndim == 3:
    # Append a singleton channel axis; the result has 1 channel, not 4.
    mr_data_to_save = mr_array[..., np.newaxis]
    print(f"⚠️ Only 3D data available, adding single channel: {mr_data_to_save.shape}")
else:
    # Fail with a clear message instead of a NameError on the save below.
    raise ValueError(
        f"Unsupported MR array shape {mr_array.shape}: expected a 3D or 4D array"
    )

np.savez_compressed(mr_output_path, data=mr_data_to_save)
print(f"✅ Saved MR data to: {mr_output_path}")

# Save CT data
ct_filename = "1HNA001_ct.npz"
ct_output_path = os.path.join(output_dir, ct_filename)

np.savez_compressed(ct_output_path, data=ct_array)
print(f"✅ Saved CT data to: {ct_output_path}")

# Verify saved data. np.load on an .npz returns an archive object that holds
# the file open, so use it as a context manager to release the handle; the
# extracted arrays remain valid after the archive is closed.
with np.load(mr_output_path) as mr_npz:
    mr_loaded = mr_npz['data']
with np.load(ct_output_path) as ct_npz:
    ct_loaded = ct_npz['data']

print(f"\n🔍 Verification:")
print(f"MR saved shape: {mr_loaded.shape}")
print(f"CT saved shape: {ct_loaded.shape}")
print(f"MR channels available: {mr_loaded.shape[-1] if len(mr_loaded.shape) == 4 else 1}")

## 📋 Create Dataset Metadata

In [None]:
# Describe the single sample of this dataset; the list of such records is
# what gets written to crops.pkl.
sample_data = {
    'name': '1HNA001',
    'mr_path': mr_filename,
    'ct_path': ct_filename,
    'bounds': crop_bounds,
    # NOTE(review): this label is fixed even when channel_count below is 1 —
    # confirm that consumers rely on channel_count rather than this label.
    'channels': 'all_4_channels',
    'channel_count': mr_loaded.shape[-1] if mr_loaded.ndim == 4 else 1
}

samples_list = [sample_data]

# Write the metadata next to the .npz files.
crops_pkl_path = os.path.join(output_dir, "crops.pkl")
with open(crops_pkl_path, 'wb') as pkl_out:
    pickle.dump(samples_list, pkl_out)

print(f"✅ Saved dataset metadata to: {crops_pkl_path}")
print(f"📊 Sample data: {sample_data}")

# Read the file back to verify it. This unpickles a file this cell just
# wrote; pickle must not be used on files from untrusted sources.
with open(crops_pkl_path, 'rb') as pkl_in:
    loaded_samples = pickle.load(pkl_in)

print(f"\n🔍 Verification - loaded {len(loaded_samples)} samples:")
for i, sample in enumerate(loaded_samples):
    print(f"Sample {i}: {sample}")