# Part 1: Data Loading and Preprocessing - Phase 1 Dataset

**Goal:** Load and inspect the BIOQIC FEM Box simulation (`four_target_phantom.mat`)

**Dataset:** Phase 1 - Simple box geometry with 4 target inclusions of different stiffness

## Objectives:
1. Load `.mat` file and understand its structure
2. Extract displacement fields (u_x, u_y, u_z)
3. Extract ground truth stiffness (complex shear modulus)
4. Visualize displacement fields and stiffness distribution
5. Build preprocessing utilities (normalization, coordinate grids)
6. Generate collocation points for physics loss

In [None]:
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns
from pathlib import Path

# Notebook-wide plot appearance: seaborn grid theme, then a larger default canvas
sns.set_style('whitegrid')
plt.rcParams.update({'figure.figsize': (12, 8)})

print("Libraries imported successfully!")

## 1. Load BIOQIC Box Dataset

In [None]:
# Location of the raw phantom file, relative to the current working directory
data_path = Path('data/raw/bioqic/four_target_phantom.mat')

# Stop right away with an explicit message when the file is missing
if not data_path.exists():
    raise FileNotFoundError(f"Dataset not found at {data_path}")

# Read every MATLAB variable into a dict keyed by variable name
print(f"Loading dataset from: {data_path}")
data = sio.loadmat(data_path)

print("\n=== Dataset Keys ===")
for key, value in data.items():
    # Double-underscore entries are loadmat metadata, not dataset variables
    if key.startswith('__'):
        continue
    shape_text = getattr(value, 'shape', 'N/A')
    print(f"  {key}: {type(value)} - Shape: {shape_text}")

## 2. Inspect Data Structure

In [None]:
# Explore the main data structures
# BIOQIC datasets typically contain:
# - u: displacement field (complex)
# - Nodes: spatial coordinates
# - stiffness/modulus: ground truth

# List all non-metadata keys (loadmat metadata entries start with '__')
data_keys = [k for k in data.keys() if not k.startswith('__')]

print("\n=== Detailed Inspection ===")
for key in data_keys:
    item = data[key]
    print(f"\n{key}:")
    print(f"  Type: {type(item)}")
    if not hasattr(item, 'shape'):
        # Nothing array-like to summarise
        continue

    print(f"  Shape: {item.shape}")
    print(f"  Dtype: {item.dtype}")
    if np.prod(item.shape) < 100:  # Only print if small
        print(f"  Values: {item}")
    elif item.dtype.kind in {'i', 'u', 'f', 'c'}:
        # min/max/mean and the '.6f' format are only meaningful for numeric
        # (integer, unsigned, float, complex) arrays.
        print(f"  Min: {item.min():.6f}, Max: {item.max():.6f}, Mean: {item.mean():.6f}")
    else:
        # Char, object (cell) and struct arrays would raise in the reductions
        # or in the float format above, so report them without statistics.
        print(f"  Summary statistics skipped for non-numeric dtype {item.dtype}")

## 3. Extract Key Variables

Based on the BIOQIC description, we expect:
- **Nodes**: Spatial coordinates (N x 3) for x, y, z
- **Displacements**: Complex displacement field (N x 3) for u_x, u_y, u_z
- **Stiffness**: Ground truth complex shear modulus (N x 1)

In [None]:
# Locate the coordinate array by probing commonly used key names, in priority order
coord_keys = ['Nodes', 'nodes', 'coordinates', 'coords', 'x', 'X']
coord_key = next((name for name in coord_keys if name in data), None)
coordinates = None if coord_key is None else data[coord_key]

if coordinates is None:
    # Nothing matched: show what the file does contain so the list can be extended
    print("\nCoordinates not found with standard keys. Available keys:")
    print(data_keys)
else:
    print(f"Found coordinates under key: '{coord_key}'")
    print(f"Coordinates shape: {coordinates.shape}")
    # Columns 0, 1, 2 are treated as the x, y, z axes
    xs, ys, zs = coordinates[:, 0], coordinates[:, 1], coordinates[:, 2]
    print(f"Coordinate range: X=[{xs.min():.3f}, {xs.max():.3f}], "
          f"Y=[{ys.min():.3f}, {ys.max():.3f}], "
          f"Z=[{zs.min():.3f}, {zs.max():.3f}]")

In [None]:
# Locate the displacement field by probing commonly used key names, in priority order
disp_keys = ['u', 'U', 'displacement', 'Displacement', 'u_wave', 'u_field']
disp_key = next((name for name in disp_keys if name in data), None)
displacement = None if disp_key is None else data[disp_key]

if displacement is not None:
    print(f"Found displacement under key: '{disp_key}'")
    print(f"Displacement shape: {displacement.shape}")
    print(f"Displacement dtype: {displacement.dtype}")
    if not np.iscomplexobj(displacement):
        print("Displacement is real-valued")
        print(f"  Min: {displacement.min():.6e}, Max: {displacement.max():.6e}")
    else:
        # Report the real and imaginary components separately
        print("Displacement is complex (real + imaginary)")
        for part_name, part in (("Real", displacement.real), ("Imag", displacement.imag)):
            print(f"  {part_name} part - Min: {part.min():.6e}, Max: {part.max():.6e}")

In [None]:
# Locate the ground-truth stiffness by probing commonly used key names, in priority order
stiff_keys = ['G', 'g', 'mu', 'stiffness', 'Stiffness', 'shear_modulus', 'modulus']
stiff_key = next((name for name in stiff_keys if name in data), None)
stiffness = data[stiff_key] if stiff_key is not None else None

if stiffness is not None:
    print(f"Found stiffness under key: '{stiff_key}'")
    print(f"Stiffness shape: {stiffness.shape}")
    print(f"Stiffness dtype: {stiffness.dtype}")
    if np.iscomplexobj(stiffness):
        # Real part is reported as storage modulus, imaginary part as loss modulus
        storage, loss = stiffness.real, stiffness.imag
        print("Stiffness is complex (storage + loss modulus)")
        print(f"  Storage modulus (real) - Min: {storage.min():.3f} Pa, Max: {storage.max():.3f} Pa")
        print(f"  Loss modulus (imag) - Min: {loss.min():.3f} Pa, Max: {loss.max():.3f} Pa")
    else:
        print("Stiffness is real-valued")
        print(f"  Min: {stiffness.min():.3f} Pa, Max: {stiffness.max():.3f} Pa")

## 4. Visualize Ground Truth Stiffness Distribution

The box phantom should show 4 distinct target regions with different stiffness values.

In [None]:
if stiffness is not None and coordinates is not None:
    # Complex stiffness is plotted through its real part (storage modulus);
    # real stiffness is plotted as-is.
    stiffness_is_complex = np.iscomplexobj(stiffness)
    stiffness_mag = np.abs(stiffness) if stiffness_is_complex else stiffness
    stiffness_plot = stiffness.real if stiffness_is_complex else stiffness
    label = 'Storage Modulus (Pa)' if stiffness_is_complex else 'Stiffness (Pa)'

    # One value per scatter point / histogram sample
    stiffness_values = stiffness_plot.flatten()
    node_x, node_y, node_z = coordinates[:, 0], coordinates[:, 1], coordinates[:, 2]

    fig = plt.figure(figsize=(14, 6))

    # Left panel: nodes in 3D, coloured by stiffness
    ax1 = fig.add_subplot(121, projection='3d')
    scatter = ax1.scatter(node_x, node_y, node_z,
                          c=stiffness_values, cmap='viridis', s=2)
    ax1.set_xlabel('X (m)')
    ax1.set_ylabel('Y (m)')
    ax1.set_zlabel('Z (m)')
    ax1.set_title('Ground Truth Stiffness Distribution (3D)')
    plt.colorbar(scatter, ax=ax1, label=label)

    # Right panel: histogram of the same values
    ax2 = fig.add_subplot(122)
    ax2.hist(stiffness_values, bins=50, edgecolor='black')
    ax2.set_xlabel(label)
    ax2.set_ylabel('Frequency')
    ax2.set_title('Stiffness Distribution')
    ax2.grid(True)

    plt.tight_layout()
    plt.show()

    # Distinct values after rounding to 0.1 Pa, used to spot the separate regions
    unique_stiff = np.unique(np.round(stiffness_values, 1))
    print(f"\nNumber of distinct stiffness regions: {len(unique_stiff)}")
    print(f"Stiffness values: {unique_stiff} Pa")

## 5. Visualize Displacement Field

In [None]:
if displacement is not None and coordinates is not None:
    # Per-node displacement magnitude.
    # The ndim check keeps a 1-D array from raising IndexError on shape[1].
    if displacement.ndim == 2 and displacement.shape[1] == 3:  # 3 components
        # Euclidean norm over the component axis. For complex input this is
        # sqrt(|ux|^2 + |uy|^2 + |uz|^2), so real and complex data share one path.
        disp_mag = np.linalg.norm(displacement, axis=1)
    else:
        # Any other layout: fall back to the element-wise magnitude, flattened
        disp_mag = np.abs(displacement).flatten()

    # 3D visualization
    fig = plt.figure(figsize=(14, 6))

    # Left panel: nodes in 3D, coloured by displacement magnitude
    ax1 = fig.add_subplot(121, projection='3d')
    scatter = ax1.scatter(coordinates[:, 0], coordinates[:, 1], coordinates[:, 2],
                         c=disp_mag, cmap='plasma', s=2)
    ax1.set_xlabel('X (m)')
    ax1.set_ylabel('Y (m)')
    ax1.set_zlabel('Z (m)')
    ax1.set_title('Displacement Field Magnitude')
    plt.colorbar(scatter, ax=ax1, label='|u| (m)')

    # Right panel: histogram with a log count axis so sparse bins stay visible
    ax2 = fig.add_subplot(122)
    ax2.hist(disp_mag, bins=50, edgecolor='black', alpha=0.7)
    ax2.set_xlabel('Displacement Magnitude (m)')
    ax2.set_ylabel('Frequency')
    ax2.set_title('Displacement Distribution')
    ax2.set_yscale('log')
    ax2.grid(True)

    plt.tight_layout()
    plt.show()

## 6. Data Preprocessing Utilities

In [None]:
class MREDataPreprocessor:
    """Normalization and collocation-point utilities for MRE data.

    Each ``normalize_*`` method learns its statistics when called with
    ``fit=True`` (the default) and reuses the stored statistics when called
    with ``fit=False``. Using a transform before its statistics exist raises
    ``RuntimeError``.
    """

    # Added to every divisor so constant or all-zero inputs cannot divide by zero.
    _EPS = 1e-10

    def __init__(self):
        # Per-axis mean and standard deviation set by normalize_coordinates(fit=True)
        self.coord_mean = None
        self.coord_std = None
        # Largest absolute value seen by normalize_displacement(fit=True)
        self.disp_scale = None
        # Largest absolute value seen by normalize_stiffness(fit=True)
        self.stiff_scale = None

    @staticmethod
    def _require_fitted(value, name):
        """Return ``value``, or raise RuntimeError if the statistic ``name`` is unset."""
        if value is None:
            raise RuntimeError(
                f"'{name}' has not been fitted yet; call the matching "
                f"normalize_* method with fit=True first"
            )
        return value

    def normalize_coordinates(self, coords, fit=True):
        """Standardize coordinates to zero mean and unit variance per column.

        This is a z-score transform, so the output is NOT confined to [-1, 1].

        Args:
            coords: Coordinate array; statistics are taken along axis 0.
            fit: If True, (re)compute and store the mean and std from ``coords``;
                if False, reuse the stored values.

        Returns:
            Array of the same shape as ``coords``.

        Raises:
            RuntimeError: If ``fit`` is False and no statistics are stored.
        """
        if fit:
            self.coord_mean = coords.mean(axis=0)
            self.coord_std = coords.std(axis=0)

        mean = self._require_fitted(self.coord_mean, 'coord_mean')
        std = self._require_fitted(self.coord_std, 'coord_std')
        return (coords - mean) / (std + self._EPS)

    def normalize_displacement(self, disp, fit=True):
        """Scale the displacement field by its largest absolute value.

        ``np.abs`` yields the modulus for complex input, so real and complex
        fields share one code path.

        Args:
            disp: Displacement array (real or complex).
            fit: If True, (re)compute and store the scale from ``disp``.

        Returns:
            ``disp`` divided by the stored scale (plus a small epsilon).

        Raises:
            RuntimeError: If ``fit`` is False and no scale is stored.
        """
        if fit:
            self.disp_scale = np.max(np.abs(disp))

        scale = self._require_fitted(self.disp_scale, 'disp_scale')
        return disp / (scale + self._EPS)

    def normalize_stiffness(self, stiff, fit=True):
        """Scale stiffness values by their largest absolute value.

        Args:
            stiff: Stiffness array (real or complex).
            fit: If True, (re)compute and store the scale from ``stiff``.

        Returns:
            ``stiff`` divided by the stored scale (plus a small epsilon).

        Raises:
            RuntimeError: If ``fit`` is False and no scale is stored.
        """
        if fit:
            self.stiff_scale = np.max(np.abs(stiff))

        scale = self._require_fitted(self.stiff_scale, 'stiff_scale')
        return stiff / (scale + self._EPS)

    def denormalize_stiffness(self, stiff_normalized):
        """Convert normalized stiffness back to physical units.

        Multiplies by the same ``scale + epsilon`` that normalize_stiffness
        divides by, so the two methods are inverses of each other.

        Raises:
            RuntimeError: If no stiffness scale is stored.
        """
        scale = self._require_fitted(self.stiff_scale, 'stiff_scale')
        return stiff_normalized * (scale + self._EPS)

    def generate_collocation_points(self, coords, n_points=1000, method='random', rng=None):
        """Generate collocation points for physics loss.

        Args:
            coords: Original coordinate array (N x 3)
            n_points: Number of collocation points to generate
            method: 'random' for random sampling, 'grid' for uniform grid
            rng: Optional seed or ``numpy.random.Generator`` for reproducible
                sampling. ``None`` (default) uses NumPy's global random state.

        Returns:
            Array of collocation points with 3 columns and at most
            ``n_points`` rows.

        Raises:
            ValueError: If ``method`` is not 'random' or 'grid', or if
                'random' is asked for more points than ``coords`` contains.
        """
        if method not in ('random', 'grid'):
            raise ValueError(
                f"Unknown method {method!r}; expected 'random' or 'grid'"
            )

        # Both the np.random module and a Generator expose a compatible .choice()
        sampler = np.random if rng is None else np.random.default_rng(rng)

        if method == 'random':
            n_available = coords.shape[0]
            if n_points > n_available:
                raise ValueError(
                    f"Cannot draw {n_points} distinct collocation points "
                    f"from only {n_available} coordinates"
                )
            # Sample existing nodes without repetition
            indices = sampler.choice(n_available, size=n_points, replace=False)
            return coords[indices]

        # method == 'grid': uniform lattice spanning the bounding box of coords
        x_min, x_max = coords[:, 0].min(), coords[:, 0].max()
        y_min, y_max = coords[:, 1].min(), coords[:, 1].max()
        z_min, z_max = coords[:, 2].min(), coords[:, 2].max()

        # Smallest cube side whose point count reaches n_points
        n_per_dim = int(np.ceil(n_points**(1/3)))
        x = np.linspace(x_min, x_max, n_per_dim)
        y = np.linspace(y_min, y_max, n_per_dim)
        z = np.linspace(z_min, z_max, n_per_dim)

        xx, yy, zz = np.meshgrid(x, y, z)
        collocation_pts = np.column_stack([xx.ravel(), yy.ravel(), zz.ravel()])

        # The cube usually overshoots; thin it randomly down to n_points
        if collocation_pts.shape[0] > n_points:
            indices = sampler.choice(collocation_pts.shape[0], size=n_points, replace=False)
            collocation_pts = collocation_pts[indices]

        return collocation_pts

# Exercise the preprocessor and fit every statistic that is saved later
preprocessor = MREDataPreprocessor()

if coordinates is not None:
    coords_norm = preprocessor.normalize_coordinates(coordinates)
    print("Normalized coordinates:")
    print(f"  Range: X=[{coords_norm[:, 0].min():.3f}, {coords_norm[:, 0].max():.3f}], "
          f"Y=[{coords_norm[:, 1].min():.3f}, {coords_norm[:, 1].max():.3f}], "
          f"Z=[{coords_norm[:, 2].min():.3f}, {coords_norm[:, 2].max():.3f}]")

    # Generate collocation points. Sampling is without replacement, so never
    # request more points than there are nodes.
    n_colloc = min(500, coordinates.shape[0])
    colloc_pts = preprocessor.generate_collocation_points(coordinates, n_points=n_colloc, method='random')
    print(f"\nGenerated {colloc_pts.shape[0]} collocation points")

# Fit the displacement and stiffness scales as well; otherwise
# preprocessor.disp_scale / preprocessor.stiff_scale stay None and the saved
# preprocessing parameters carry no usable scale.
if displacement is not None:
    disp_norm = preprocessor.normalize_displacement(displacement)
    print(f"\nDisplacement scale: {preprocessor.disp_scale:.6e}")

if stiffness is not None:
    stiff_norm = preprocessor.normalize_stiffness(stiffness)
    print(f"Stiffness scale: {preprocessor.stiff_scale:.6e}")

## 7. Save Preprocessed Data

In [None]:
# Destination folder for the processed arrays (created together with any missing parents)
output_dir = Path('data/processed/phase1_box')
output_dir.mkdir(parents=True, exist_ok=True)

# Write files only when all three core arrays were found
if not (coordinates is None or displacement is None or stiffness is None):
    arrays_to_save = {
        'coordinates.npy': coordinates,
        'coordinates_normalized.npy': coords_norm,
        'displacement.npy': displacement,
        'stiffness_ground_truth.npy': stiffness,
        'collocation_points.npy': colloc_pts,
    }
    for filename, array in arrays_to_save.items():
        np.save(output_dir / filename, array)

    # Statistics needed to repeat or undo the normalization later
    preproc_params = dict(
        coord_mean=preprocessor.coord_mean,
        coord_std=preprocessor.coord_std,
        disp_scale=preprocessor.disp_scale,
        stiff_scale=preprocessor.stiff_scale,
    )
    # NOTE(review): a dict is stored by np.save as a pickled object array, so
    # readers presumably need np.load(..., allow_pickle=True).item() - confirm.
    np.save(output_dir / 'preprocessing_params.npy', preproc_params)

    print(f"Preprocessed data saved to {output_dir}")
    print("\nSaved files:")
    for file in output_dir.glob('*.npy'):
        print(f"  - {file.name}")

## 8. Summary Statistics

In [None]:
# Final report: one section per array that was successfully extracted above.
print("="*60)
print("PHASE 1 DATASET SUMMARY")
print("="*60)

if coordinates is not None:
    print("\n📍 Spatial Domain:")
    print(f"   Number of nodes: {coordinates.shape[0]:,}")
    print(f"   X range: [{coordinates[:, 0].min():.4f}, {coordinates[:, 0].max():.4f}] m")
    print(f"   Y range: [{coordinates[:, 1].min():.4f}, {coordinates[:, 1].max():.4f}] m")
    print(f"   Z range: [{coordinates[:, 2].min():.4f}, {coordinates[:, 2].max():.4f}] m")

if displacement is not None:
    print("\n🌊 Displacement Field:")
    print(f"   Shape: {displacement.shape}")
    print(f"   Type: {'Complex' if np.iscomplexobj(displacement) else 'Real'}")
    if np.iscomplexobj(displacement):
        print(f"   Magnitude range: [{np.abs(displacement).min():.6e}, {np.abs(displacement).max():.6e}] m")
    else:
        print(f"   Range: [{displacement.min():.6e}, {displacement.max():.6e}] m")

if stiffness is not None:
    stiffness_is_complex = np.iscomplexobj(stiffness)
    # Built outside the f-string: the label contains both quote characters,
    # which cannot be nested inside a double-quoted f-string expression.
    stiffness_type = 'Complex (μ\' + iμ")' if stiffness_is_complex else 'Real'

    print("\n🎯 Ground Truth Stiffness:")
    print(f"   Shape: {stiffness.shape}")
    print(f"   Type: {stiffness_type}")
    if stiffness_is_complex:
        print(f"   Storage modulus (μ'): [{stiffness.real.min():.1f}, {stiffness.real.max():.1f}] Pa")
        # Single-quoted f-string so the double-prime mark can appear literally
        print(f'   Loss modulus (μ"): [{stiffness.imag.min():.1f}, {stiffness.imag.max():.1f}] Pa')
    else:
        print(f"   Range: [{stiffness.min():.1f}, {stiffness.max():.1f}] Pa")

print("\n" + "="*60)
print("✅ Data loading and preprocessing complete!")
print("📌 Next step: Part 2 - Physics Module (Helmholtz equation)")
print("="*60)