# Biot Poroelasticity Visualization Hub

**Comprehensive visualization and validation notebook for 2D Biot poroelasticity physics-informed neural networks**

This notebook serves as the central hub for visualizing and validating all aspects of the Biot poroelasticity project:

## Contents Overview
1. **Environment Setup & Imports** - Import libraries and check dependencies
2. **Physics-Only Trainer Validation** - Validate core physics implementation
3. **Data-Enhanced Training** - Integrate experimental VTK data
4. **Comparative Analysis** - Compare physics-only vs data-enhanced approaches
5. **Interactive Parameter Studies** - Sensitivity analysis and optimization
6. **Future Extensions** - 3D visualization and advanced capabilities

---

**Note:** This notebook is designed to work with your existing Python environment without conflicts. All visualizations are self-contained and modular.

## Environment Setup & Imports

Import all necessary libraries and check if the custom modules are available.
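The import cells below detect missing modules by attempting the import and catching `ImportError`. A lighter way to check availability first is `importlib.util.find_spec`, which locates a module without running its code. The helper `module_available` is a hypothetical sketch and is not part of the project. For a dotted name such as `trainers.biot_trainer_2d`, the parent package `trainers` is still imported.

```python
import importlib.util


def module_available(name: str) -> bool:
    """Return True if `name` can be located on sys.path without importing it."""
    try:
        return importlib.util.find_spec(name) is not None
    except (ModuleNotFoundError, ValueError):
        # ModuleNotFoundError: a parent package of a dotted name is missing.
        # ValueError: the module is already imported but has no __spec__.
        return False


REQUIRED = ["numpy", "jax", "matplotlib"]
status = {name: module_available(name) for name in REQUIRED}
for name, ok in status.items():
    print(f"{name:12s} {'available' if ok else 'missing'}")
```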

In [None]:
# Standard library imports
import sys
import os
import warnings
from pathlib import Path

# Add parent directory to Python path for importing our modules
# Since we're in poroelasticity/notebooks/, go up one level to poroelasticity/
parent_dir = Path(os.getcwd()).parent
if str(parent_dir) not in sys.path:
    sys.path.insert(0, str(parent_dir))

# Scientific computing imports
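try:
    import numpy as np
    import jax.numpy as jnp
    import jax
    print("✅ JAX libraries loaded")
except ImportError as e:
    print(f"⚠️ JAX not available: {e}")
    import numpy as np
    jnp = np  # NumPy stands in for jax.numpy so array conversions below still work
    print("✅ NumPy fallback loaded")

# Visualization imports
try:
    import matplotlib.pyplot as plt
    print("✅ Matplotlib loaded")
except ImportError as e:
    print(f"❌ Matplotlib not available: {e}")
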
## Physics-Only Trainer Validation

Before proceeding with data-enhanced training, let's validate that our physics-only Biot trainer works correctly. This is **crucial** to ensure the underlying physics implementation is sound.
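For reference, here is one standard statement of the quasi-static Biot equations. It uses the material parameters that appear later in the notebook: $E$, $\nu$, $\alpha$, $k$ and $\mu$. The Biot modulus $M$ and the source terms $\mathbf{f}$ and $g$ are assumptions and are not defined in this notebook. The module `trainers.biot_trainer_2d` may use a steady-state, nondimensionalized, or differently signed variant. Treat this as a checklist for reading its residuals, not as a statement of what it implements.

$$
\begin{aligned}
-\nabla\cdot\big(2G\,\boldsymbol{\varepsilon}(\mathbf{u}) + \lambda\,(\nabla\cdot\mathbf{u})\,\mathbf{I}\big) + \alpha\,\nabla p &= \mathbf{f},\\
\frac{\partial}{\partial t}\Big(\frac{p}{M} + \alpha\,\nabla\cdot\mathbf{u}\Big) - \nabla\cdot\Big(\frac{k}{\mu}\,\nabla p\Big) &= g,
\end{aligned}
$$

The remaining symbols are defined as follows:

- $\boldsymbol{\varepsilon}(\mathbf{u}) = \tfrac{1}{2}(\nabla\mathbf{u} + \nabla\mathbf{u}^{\mathsf T})$
- $G = E/(2(1+\nu))$
- $\lambda = E\nu/((1+\nu)(1-2\nu))$

In the steady-state limit the time derivative drops out, and the flow equation reduces to $-\nabla\cdot\big(\tfrac{k}{\mu}\nabla p\big) = g$.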

In [None]:
# Import Biot trainer modules
print("🔧 Loading Biot trainer modules...")

try:
    from trainers.biot_trainer_2d import BiotCoupledTrainer, BiotCoupled2D, CoupledTrainer
    print("✅ Physics-only trainer loaded successfully")
    physics_trainer_available = True
except ImportError as e:
    print(f"❌ Physics trainer not available: {e}")
    print("   Make sure you're running from the correct directory")
    physics_trainer_available = False

try:
    from trainers.biot_trainer_2d_data import BiotCoupledDataTrainer, VTKDataLoader, BiotCoupled2DData
    print("✅ Data-enhanced trainer loaded successfully")
    data_trainer_available = True
except ImportError as e:
    print(f"⚠️ Data trainer not available: {e}")
    print("   Data-enhanced training will be skipped")
    data_trainer_available = False

try:
    from utilities.visualization_tools import BiotVisualizationTools
    from utilities.validation_metrics import ValidationMetrics
    print("✅ Utility modules loaded successfully")
    utilities_available = True
except ImportError as e:
    print(f"⚠️ Utilities not available: {e}")
    print("   Some visualization features may be limited")
    utilities_available = False

if physics_trainer_available:
    print("\n🎯 Core physics training is ready!")
    if data_trainer_available:
        print("📊 Data-enhanced training is ready!")
    if utilities_available:
        print("✅ Advanced utilities are ready!")
else:
    print("\n❌ Critical modules missing. Check your installation and path setup.")

print("\n" + "="*60)
print("📋 QUICK START GUIDE:")
print("1. Run a quick physics validation")
print("2. Explore interactive training")
print("3. Visualize solution fields")
print("4. Analyze error metrics")
print("="*60)

In [None]:
# Comprehensive visualization functions for physics validation
class BiotPhysicsVisualizer:
    """Comprehensive visualization class for Biot poroelasticity results"""
    
    def __init__(self, trainer=None):
        self.trainer = trainer
        
    def create_mesh_grid(self, nx=50, ny=50, domain=((0, 1), (0, 1))):
        """Create uniform mesh grid for visualization"""
        x_min, x_max = domain[0]
        y_min, y_max = domain[1]
        
        x = np.linspace(x_min, x_max, nx)
        y = np.linspace(y_min, y_max, ny)
        X, Y = np.meshgrid(x, y)
        
        # Flatten for prediction
        x_flat = X.flatten()
        y_flat = Y.flatten()
        points = np.column_stack([x_flat, y_flat])
        
        return X, Y, points
    
    def plot_solution_fields(self, trainer=None, nx=50, ny=50, figsize=(18, 12)):
        """Plot displacement and pressure fields with exact solution comparison"""
        if trainer is None:
            trainer = self.trainer
        if trainer is None:
            print("No trainer provided, cannot plot solution fields")
            return None
            
        # Create mesh for unit square domain
        X, Y, points = self.create_mesh_grid(nx, ny)
        
        try:
            # Get predictions (jnp is jax.numpy, or NumPy if JAX is unavailable)
            points_jax = jnp.asarray(points)
                
            pred = trainer.predict(points_jax)
            exact = trainer.trainer.c.problem.exact_solution(trainer.all_params, points_jax)
            
            # Convert to NumPy for plotting (np.asarray accepts JAX and NumPy arrays)
            pred = np.asarray(pred)
            exact = np.asarray(exact)
            
            # Reshape for plotting
            ux_pred = pred[:, 0].reshape(X.shape)
            uy_pred = pred[:, 1].reshape(X.shape)
            p_pred = pred[:, 2].reshape(X.shape)
            
            ux_exact = exact[:, 0].reshape(X.shape)
            uy_exact = exact[:, 1].reshape(X.shape)
            p_exact = exact[:, 2].reshape(X.shape)
            
            # Create comprehensive plot
            fig, axes = plt.subplots(2, 3, figsize=figsize)
            fig.suptitle('Biot Poroelasticity: Predicted vs Exact Solution', fontsize=16, y=0.95)
            
            # u_x plots
            im1 = axes[0, 0].contourf(X, Y, ux_pred, levels=20, cmap='RdBu_r')
            axes[0, 0].set_title('$u_x$ (Predicted)', fontsize=14)
            axes[0, 0].set_xlabel('x')
            axes[0, 0].set_ylabel('y')
            plt.colorbar(im1, ax=axes[0, 0], shrink=0.8)
            
            im2 = axes[1, 0].contourf(X, Y, ux_exact, levels=20, cmap='RdBu_r')
            axes[1, 0].set_title('$u_x$ (Exact)', fontsize=14)
            axes[1, 0].set_xlabel('x')
            axes[1, 0].set_ylabel('y')
            plt.colorbar(im2, ax=axes[1, 0], shrink=0.8)
            
            # u_y plots
            im3 = axes[0, 1].contourf(X, Y, uy_pred, levels=20, cmap='RdBu_r')
            axes[0, 1].set_title('$u_y$ (Predicted)', fontsize=14)
            axes[0, 1].set_xlabel('x')
            axes[0, 1].set_ylabel('y')
            plt.colorbar(im3, ax=axes[0, 1], shrink=0.8)
            
            im4 = axes[1, 1].contourf(X, Y, uy_exact, levels=20, cmap='RdBu_r')
            axes[1, 1].set_title('$u_y$ (Exact)', fontsize=14)
            axes[1, 1].set_xlabel('x')
            axes[1, 1].set_ylabel('y')
            plt.colorbar(im4, ax=axes[1, 1], shrink=0.8)
            
            # Pressure plots
            im5 = axes[0, 2].contourf(X, Y, p_pred, levels=20, cmap='viridis')
            axes[0, 2].set_title('Pressure $p$ (Predicted)', fontsize=14)
            axes[0, 2].set_xlabel('x')
            axes[0, 2].set_ylabel('y')
            plt.colorbar(im5, ax=axes[0, 2], shrink=0.8)
            
            im6 = axes[1, 2].contourf(X, Y, p_exact, levels=20, cmap='viridis')
            axes[1, 2].set_title('Pressure $p$ (Exact)', fontsize=14)
            axes[1, 2].set_xlabel('x')
            axes[1, 2].set_ylabel('y')
            plt.colorbar(im6, ax=axes[1, 2], shrink=0.8)
            
            plt.tight_layout()
            plt.show()
            
            return fig
            
        except Exception as e:
            print(f"Error plotting solution fields: {e}")
            return None
    
    def compute_error_metrics(self, trainer=None, nx=50, ny=50):
        """Compute comprehensive error metrics"""
        if trainer is None:
            trainer = self.trainer
        if trainer is None:
            print("No trainer provided, cannot compute errors")
            return None
            
        try:
            # Create test points
            X, Y, points = self.create_mesh_grid(nx, ny)
            
            # jnp is jax.numpy, or NumPy if JAX is unavailable
            points_jax = jnp.asarray(points)
            
            # Get predictions and exact solution
            pred = trainer.predict(points_jax)
            exact = trainer.trainer.c.problem.exact_solution(trainer.all_params, points_jax)
            
            # Convert to NumPy for calculations (np.asarray accepts JAX and NumPy arrays)
            pred = np.asarray(pred)
            exact = np.asarray(exact)
            
            # Calculate errors
            error_ux = pred[:, 0] - exact[:, 0]
            error_uy = pred[:, 1] - exact[:, 1]
            error_p = pred[:, 2] - exact[:, 2]
            
            # L2 norms
            l2_ux = np.sqrt(np.mean(error_ux**2))
            l2_uy = np.sqrt(np.mean(error_uy**2))
            l2_p = np.sqrt(np.mean(error_p**2))
            l2_total = np.sqrt(np.mean(error_ux**2 + error_uy**2 + error_p**2))
            
            # L∞ norms (max errors)
            linf_ux = np.max(np.abs(error_ux))
            linf_uy = np.max(np.abs(error_uy))
            linf_p = np.max(np.abs(error_p))
            
            # Relative errors
            rel_ux = l2_ux / (np.sqrt(np.mean(exact[:, 0]**2)) + 1e-12)
            rel_uy = l2_uy / (np.sqrt(np.mean(exact[:, 1]**2)) + 1e-12)
            rel_p = l2_p / (np.sqrt(np.mean(exact[:, 2]**2)) + 1e-12)
            
            metrics = {
                'L2_errors': {'ux': l2_ux, 'uy': l2_uy, 'p': l2_p, 'total': l2_total},
                'Linf_errors': {'ux': linf_ux, 'uy': linf_uy, 'p': linf_p},
                'Relative_errors': {'ux': rel_ux, 'uy': rel_uy, 'p': rel_p}
            }
            
            return metrics
            
        except Exception as e:
            print(f" Error computing metrics: {e}")
            return None
    
    def print_validation_summary(self, trainer=None):
        """Print comprehensive validation summary"""
        if trainer is None:
            trainer = self.trainer
        if trainer is None:
            print("No trainer provided for validation")
            return
            
        print("=" * 60)
        print(" Biot Poroelasticity: Physics Validation")
        print("=" * 60)
        
        # Compute error metrics
        metrics = self.compute_error_metrics(trainer)
        if metrics is None:
            print(" Could not compute error metrics")
            return
        
        print(f" L2 Errors:")
        print(f"   u_x:    {metrics['L2_errors']['ux']:.2e}")
        print(f"   u_y:    {metrics['L2_errors']['uy']:.2e}")
        print(f"   p:      {metrics['L2_errors']['p']:.2e}")
        print(f"   Total:  {metrics['L2_errors']['total']:.2e}")
        print()
        print("  L∞ Errors (Max):")
        print(f"   u_x:    {metrics['Linf_errors']['ux']:.2e}")
        print(f"   u_y:    {metrics['Linf_errors']['uy']:.2e}")
        print(f"   p:      {metrics['Linf_errors']['p']:.2e}")
        print()
        print(f"   Relative Errors:")
        print(f"   u_x:    {metrics['Relative_errors']['ux']:.2e}")
        print(f"   u_y:    {metrics['Relative_errors']['uy']:.2e}")
        print(f"   p:      {metrics['Relative_errors']['p']:.2e}")
        
        # Assessment
        l2_total = metrics['L2_errors']['total']
        if l2_total < 1e-2:
            print(f"\n Excellent accuracy: L2 error = {l2_total:.2e}")
            print("    Physics implementation is working correctly")
        elif l2_total < 1e-1:
            print(f"\n Good accuracy: L2 error = {l2_total:.2e}")
            print("   Physics implementation is acceptable")
        else:
            print(f"\nConsider more training: L2 error = {l2_total:.2e}")
            print("    Try increasing training steps or checking implementation")
        
        print("=" * 60)
        return metrics

# Create global visualizer instance
visualizer = BiotPhysicsVisualizer()
print(" Visualization tools loaded")
print(" Use visualizer.plot_solution_fields(trainer) after training to see results")
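The `L2_errors` reported by `compute_error_metrics` are root-mean-square values over the grid points. On the unit square (area 1), the RMS is a simple quadrature estimate of the continuous $L^2$ norm. On any other domain it has to be scaled by the square root of the area.

The standalone helper `discrete_l2` below is a hypothetical sketch and is not part of the project's utilities. It checks this on a field whose norm is known exactly. Uniform weights count boundary points at full weight, whereas the trapezoidal rule would halve them. That bias shrinks as the grid is refined.

```python
import numpy as np


def discrete_l2(values: np.ndarray, area: float = 1.0) -> float:
    """Estimate the continuous L2 norm from samples on a uniform grid.

    sqrt(mean(e**2)) is the RMS; multiplying the mean by the domain area
    turns it into an estimate of sqrt(integral of e**2 over the domain).
    """
    return float(np.sqrt(area * np.mean(np.square(values))))


# e(x, y) = sin(pi x) sin(pi y) on the unit square has ||e||_L2 = 1/2 exactly
n = 201
x = np.linspace(0.0, 1.0, n)
X, Y = np.meshgrid(x, x)
e = np.sin(np.pi * X) * np.sin(np.pi * Y)
print(f"grid estimate: {discrete_l2(e):.4f} (exact value 0.5)")
```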

## Data-Enhanced Training with Experimental VTK Data

Once the physics-only trainer is validated, we can proceed to integrate experimental VTK data for enhanced accuracy.
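Data-enhanced PINN training commonly adds a data-misfit term to the physics residual loss. The sketch below shows that idea with a fixed weight.

`data_enhanced_loss` and `data_weight` are hypothetical names. `BiotCoupledDataTrainer` may combine its terms differently, for example through the automatic loss balancing mentioned in the roadmap. The same expression works with `jax.numpy` substituted for NumPy, so it can be differentiated with `jax.grad`.

```python
import numpy as np


def data_enhanced_loss(physics_residual, pred_at_data, observed, data_weight=1.0):
    """Hypothetical composite loss: physics residual MSE + weighted data misfit MSE.

    physics_residual: PDE residuals evaluated at collocation points
    pred_at_data:     network predictions at the measurement locations
    observed:         measured values (e.g. u_x, u_y, p read from VTK files)
    data_weight:      relative weight of the data misfit (a tuning knob)
    """
    physics_loss = np.mean(np.square(physics_residual))
    data_loss = np.mean(np.square(np.asarray(pred_at_data) - np.asarray(observed)))
    return physics_loss + data_weight * data_loss, physics_loss, data_loss


# Toy numbers: residual-free physics, predictions off by 0.1 at every data point
total, phys, data = data_enhanced_loss(
    np.zeros(100), np.full((20, 3), 0.1), np.zeros((20, 3)), data_weight=10.0
)
print(f"total={total:.3f} physics={phys:.3f} data={data:.3f}")
```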

In [None]:
# VTK Data Loading and Exploration
def explore_vtk_data():
    """Load and explore experimental VTK data"""
    print(" Exploring VTK experimental data")
    
    if not data_trainer_available:
        print(" VTK data loader not available")
        return None, None, None
    
    try:
        # Initialize VTK data loader
        data_loader = VTKDataLoader(data_dir="Data_2D")
        
        print(" VTK Files found:")
        available_files = data_loader.list_available_files()
        for i, filename in enumerate(available_files, 1):
            print(f"   {i}. {filename}")
        
        # Load initial displacement data
        print("\n Loading initial displacement data")
        initial_data = data_loader.load_vtk_file("displacement_MSAMPLE2D_RES_S0_M.vtk")
        if initial_data is not None:
            print(f"    Loaded {len(initial_data)} data points")
            print(f"    Coordinate range: x=[{initial_data[:, 0].min():.0f}, {initial_data[:, 0].max():.0f}], "
                  f"y=[{initial_data[:, 1].min():.0f}, {initial_data[:, 1].max():.0f}]")
            print(f"    Displacement range: u_x=[{initial_data[:, 3].min():.3f}, {initial_data[:, 3].max():.3f}], "
                  f"u_y=[{initial_data[:, 4].min():.3f}, {initial_data[:, 4].max():.3f}]")
        
        # Load pressure data
        print("\n Loading pressure data...")
        pressure_data = data_loader.load_vtk_file("matrix_pressure_MSAMPLE2D_RES_S100_MHm.vtk")
        if pressure_data is not None:
            print(f"   Loaded {len(pressure_data)} pressure points")
            print(f"    Pressure range: [{pressure_data[:, 3].min():.2e}, {pressure_data[:, 3].max():.2e}] Pa")
        
        return data_loader, initial_data, pressure_data
        
    except Exception as e:
        print(f" Error loading VTK data: {e}")
        return None, None, None

def visualize_experimental_data(initial_data=None, pressure_data=None):
    """Visualize experimental data if available"""
    if initial_data is None or pressure_data is None:
        print(" No experimental data to visualize")
        return
    
    print(" Creating experimental data visualization")
    
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    fig.suptitle(' Experimental VTK Data', fontsize=16)
    
    # Displacement magnitude
    u_mag = np.sqrt(initial_data[:, 3]**2 + initial_data[:, 4]**2)
    scatter1 = axes[0].scatter(initial_data[:, 0], initial_data[:, 1], c=u_mag, 
                              cmap='viridis', s=20, alpha=0.7)
    axes[0].set_title('Displacement Magnitude |u|')
    axes[0].set_xlabel('x [m]')
    axes[0].set_ylabel('y [m]')
    plt.colorbar(scatter1, ax=axes[0], label='|u| [m]')
    
    # Horizontal displacement
    scatter2 = axes[1].scatter(initial_data[:, 0], initial_data[:, 1], c=initial_data[:, 3], 
                              cmap='RdBu_r', s=20, alpha=0.7)
    axes[1].set_title('Horizontal Displacement $u_x$')
    axes[1].set_xlabel('x [m]')
    axes[1].set_ylabel('y [m]')
    plt.colorbar(scatter2, ax=axes[1], label='$u_x$ [m]')
    
    # Pressure
    scatter3 = axes[2].scatter(pressure_data[:, 0], pressure_data[:, 1], c=pressure_data[:, 3], 
                              cmap='plasma', s=20, alpha=0.7)
    axes[2].set_title('Pressure $p$')
    axes[2].set_xlabel('x [m]')
    axes[2].set_ylabel('y [m]')
    plt.colorbar(scatter3, ax=axes[2], label='p [Pa]')
    
    plt.tight_layout()
    plt.show()
    
    return fig

# Run VTK data exploration
if Path("Data_2D").exists():
    print(" Data directory found; exploring VTK files")
    # Uncomment to load and visualize data
    # data_loader, initial_data, pressure_data = explore_vtk_data()
    # if initial_data is not None and pressure_data is not None:
    #     visualize_experimental_data(initial_data, pressure_data)
else:
    print(" Data_2D directory not found, skipping VTK exploration")
    print(" Place your VTK files in a 'Data_2D' directory to enable data visualization")
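The physics-only trainer is evaluated on the unit square. The VTK coordinates printed above are in physical units, so comparing the two requires mapping one onto the other. The sketch below assumes that columns 0 and 1 of the loaded arrays hold x and y. It applies min-max scaling; `to_unit_square` is a hypothetical helper.

Scaling x and y independently changes the aspect ratio. PDE residuals written in the scaled coordinates then need the chain-rule factor $\partial/\partial x_{\text{phys}} = (1/s_x)\,\partial/\partial \hat{x}$ for each axis.

```python
import numpy as np


def to_unit_square(xy):
    """Min-max scale 2D coordinates to [0, 1] x [0, 1].

    Returns the scaled coordinates plus (offset, scale) so results can be
    mapped back with xy = unit * scale + offset.
    """
    xy = np.asarray(xy, dtype=float)
    offset = xy.min(axis=0)
    scale = xy.max(axis=0) - offset
    scale[scale == 0] = 1.0  # avoid division by zero on a degenerate axis
    return (xy - offset) / scale, offset, scale


coords = np.array([[100.0, 20.0], [300.0, 20.0], [200.0, 70.0]])
unit, offset, scale = to_unit_square(coords)
print(unit)
```

With the loaded data this would look like `unit_xy, offset, scale = to_unit_square(initial_data[:, :2])`.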

## Comparative Analysis: Physics-Only vs Data-Enhanced

Compare the performance of the physics-only and data-enhanced approaches to understand the benefits of integrating experimental data.
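A per-field relative error makes the comparison independent of the very different magnitudes of displacement and pressure. The sketch below computes it for both models and reports their ratio. `relative_l2_by_field` and `compare_models` are hypothetical helpers.

It assumes predictions are `(N, 3)` arrays ordered `u_x, u_y, p`, which matches the column order used elsewhere in this notebook. An exact solution exists only for the manufactured test problem. For experimental data, the `exact` argument would be the observations at the data points.

```python
import numpy as np

FIELDS = ("ux", "uy", "p")


def relative_l2_by_field(pred, exact, eps=1e-12):
    """Relative RMS error per output column (columns ordered u_x, u_y, p)."""
    pred, exact = np.asarray(pred), np.asarray(exact)
    err = np.sqrt(np.mean((pred - exact) ** 2, axis=0))
    ref = np.sqrt(np.mean(exact ** 2, axis=0)) + eps
    return dict(zip(FIELDS, err / ref))


def compare_models(pred_physics, pred_data, exact):
    """Per-field errors of both models; ratio < 1 means the data model is better."""
    a = relative_l2_by_field(pred_physics, exact)
    b = relative_l2_by_field(pred_data, exact)
    report = {}
    for f in FIELDS:
        # Guard against a perfect physics-only baseline (zero error)
        ratio = b[f] / a[f] if a[f] > 0 else float("inf")
        report[f] = {"physics": a[f], "data": b[f], "ratio": ratio}
    return report


exact = np.ones((4, 3))
report = compare_models(exact + 0.1, exact + 0.05, exact)
for f, r in report.items():
    print(f"{f}: physics={r['physics']:.3f} data={r['data']:.3f} ratio={r['ratio']:.2f}")
```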

In [None]:
# Comparative Analysis Tools
def compare_training_approaches(physics_trainer=None, data_trainer=None):
    """
    Compare physics-only vs data-enhanced training approaches
    
    Args:
        physics_trainer: Trained physics-only trainer
        data_trainer: Trained data-enhanced trainer
    """
    print(" Comparative Analysis: Physics-Only vs Data-Enhanced")
    print("=" * 50)
    
    if physics_trainer is None and data_trainer is None:
        print(" No trained models provided for comparison")
        print("Train both models first using the sections above")
        return
    
    # Create comparison visualization
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    fig.suptitle('Physics-Only vs Data-Enhanced Comparison', fontsize=16)
    
    # Create test points
    nx, ny = 50, 50
    X, Y, points = visualizer.create_mesh_grid(nx, ny)
    
    # jnp is jax.numpy, or NumPy if JAX is unavailable
    points_jax = jnp.asarray(points)
    
    try:
        # Physics-only results
        if physics_trainer is not None:
            pred_physics = physics_trainer.predict(points_jax)
            exact = physics_trainer.trainer.c.problem.exact_solution(physics_trainer.all_params, points_jax)
            
            # np.asarray converts JAX or NumPy arrays for plotting
            pred_physics = np.asarray(pred_physics)
            exact = np.asarray(exact)
            
            # Plot physics only results
            ux_physics = pred_physics[:, 0].reshape(X.shape)
            p_physics = pred_physics[:, 2].reshape(X.shape)
            
            im1 = axes[0, 0].contourf(X, Y, ux_physics, levels=20, cmap='RdBu_r')
            axes[0, 0].set_title('Physics Only: $u_x$')
            plt.colorbar(im1, ax=axes[0, 0], shrink=0.8)
            
            im2 = axes[0, 1].contourf(X, Y, p_physics, levels=20, cmap='viridis')
            axes[0, 1].set_title('Physics Only: Pressure $p$')
            plt.colorbar(im2, ax=axes[0, 1], shrink=0.8)
            
            # Error analysis
            error_ux = pred_physics[:, 0] - exact[:, 0]
            error_uy = pred_physics[:, 1] - exact[:, 1]
            error_p = pred_physics[:, 2] - exact[:, 2]
            error_ux_plot = np.abs(error_ux).reshape(X.shape)
            
            im3 = axes[0, 2].contourf(X, Y, error_ux_plot, levels=20, cmap='Reds')
            axes[0, 2].set_title(f'Physics Error |$u_x$| (Max: {np.max(np.abs(error_ux)):.2e})')
            plt.colorbar(im3, ax=axes[0, 2], shrink=0.8)
            
            # Combined RMS error over all three fields (same definition as compute_error_metrics)
            l2_physics = np.sqrt(np.mean(error_ux**2 + error_uy**2 + error_p**2))
            print(f" Physics-only L2 error: {l2_physics:.2e}")
        
        # Data-enhanced results (placeholder until implemented)
        if data_trainer is not None:
            print(" Data-enhanced comparison not yet implemented")
            print("   Will be added when the data trainer is fully integrated")
            
            # Placeholder plots
            axes[1, 0].text(0.5, 0.5, 'Data Enhanced\nResults\n(Coming Soon)', 
                           ha='center', va='center', transform=axes[1, 0].transAxes, fontsize=12)
            axes[1, 1].text(0.5, 0.5, ' Data Enhanced\nPressure\n(Coming Soon)', 
                           ha='center', va='center', transform=axes[1, 1].transAxes, fontsize=12)
            axes[1, 2].text(0.5, 0.5, ' Data Enhanced\nError Analysis\n(Coming Soon)', 
                           ha='center', va='center', transform=axes[1, 2].transAxes, fontsize=12)
        else:
            # Show what data-enhanced will look like
            axes[1, 0].text(0.5, 0.5, ' Data Enhanced\nDisplacement\n(Train data model first)', 
                           ha='center', va='center', transform=axes[1, 0].transAxes, fontsize=12)
            axes[1, 1].text(0.5, 0.5, ' Data Enhanced\nPressure\n(Train data model first)', 
                           ha='center', va='center', transform=axes[1, 1].transAxes, fontsize=12)
            axes[1, 2].text(0.5, 0.5, 'Improved Error\nAnalysis\n(Expected lower errors)', 
                           ha='center', va='center', transform=axes[1, 2].transAxes, fontsize=12)
    
    except Exception as e:
        print(f" Error in comparison: {e}")
        for ax in axes.flat:
            ax.text(0.5, 0.5, f' Error:\n{str(e)[:50]}...', 
                   ha='center', va='center', transform=ax.transAxes)
    
    # The bottom row holds only text placeholders in both branches, so hide its ticks
    for ax in axes[1, :]:
        ax.set_xticks([])
        ax.set_yticks([])
    
    plt.tight_layout()
    plt.show()
    
    return fig

def create_training_summary_report():
    """Create a comprehensive training summary report"""
    print(" TRAINING SUMMARY REPORT")
    print("=" * 60)
    print(" Project Status:")
    print("   Physics-only trainer implemented and tested")
    print("   VTK data loader implemented and validated") 
    print("   Data-enhanced trainer framework ready")
    print("   Full data-physics integration in progress")
    print()
    print(" Next Steps:")
    print("   1. Train and validate physics-only model")
    print("   2. Load and explore experimental VTK data")  
    print("   3. Implement full data-enhanced training")
    print("   4. Compare both approaches")
    print("   5. Parameter sensitivity studies")
    print()
    print("  Usage Instructions:")
    print("   Run physics training in the section above")
    print("   Use visualizer tools to validate results")
    print("   Proceed to data integration when physics is validated")
    print("=" * 60)

# Create the summary report
create_training_summary_report()

## Interactive Parameter Studies & Sensitivity Analysis

Explore how different material properties and training parameters affect the model performance.
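The baseline parameters below are given as Young's modulus $E$ and Poisson's ratio $\nu$. The elasticity operator, however, is usually written with the Lamé parameters $\lambda$ and $G$, using $G = E/(2(1+\nu))$ and $\lambda = E\nu/((1+\nu)(1-2\nu))$. The shear modulus is named `G` here because `mu` is already the fluid viscosity in `baseline_params`. `lame_parameters` is a hypothetical helper.

The ratio $\lambda/G = 2\nu/(1-2\nu)$ grows rapidly as $\nu \to 0.5$; at $\nu = 0.45$ it is already 9. That near-incompressible regime may be the hardest case in the Poisson's ratio sweep.

```python
def lame_parameters(E: float, nu: float) -> tuple:
    """Convert Young's modulus E and Poisson's ratio nu to (lambda, G)."""
    if not (-1.0 < nu < 0.5):
        raise ValueError("nu must lie in (-1, 0.5) for a stable isotropic solid")
    lam = E * nu / ((1.0 + nu) * (1.0 - 2.0 * nu))
    G = E / (2.0 * (1.0 + nu))  # shear modulus
    return lam, G


for nu in (0.1, 0.25, 0.35, 0.45):
    lam, G = lame_parameters(5000.0, nu)
    print(f"nu={nu:.2f}: lambda={lam:9.1f} Pa, G={G:7.1f} Pa, lambda/G={lam / G:5.2f}")
```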

In [None]:
# Parameter Studies and Sensitivity Analysis
class ParameterStudyTool:
    """Tools for parameter sensitivity analysis"""
    
    def __init__(self):
        self.baseline_params = {
            'E': 5000.0,  # Young's modulus (Pa)
            'nu': 0.25,   # Poisson's ratio
            'alpha': 0.8, # Biot coefficient  
            'k': 1.0,     # Permeability (m²)
            'mu': 1.0     # Fluid viscosity (Pa·s)
        }
    
    def visualize_parameter_effects(self):
        """Visualize how different parameters affect the solution"""
        print(" Parameter Effects Visualization")
        print("=" * 40)
        
        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        fig.suptitle(' Parameter Sensitivity Analysis', fontsize=16)
        
        # Parameter ranges for visualization
        param_studies = {
            'Young\'s Modulus (E)': {
                'values': [1000, 5000, 10000, 20000],
                'unit': 'Pa',
                'color': 'blues'
            },
            'Poisson Ratio (ν)': {
                'values': [0.1, 0.25, 0.35, 0.45],
                'unit': '',
                'color': 'reds'  
            },
            'Biot Coefficient (α)': {
                'values': [0.2, 0.5, 0.8, 1.0],
                'unit': '',
                'color': 'greens'
            },
            'Permeability (k)': {
                'values': [0.1, 1.0, 5.0, 10.0],
                'unit': 'm²',
                'color': 'purples'
            }
        }
        
        # Create parameter effect plots
        for i, (param_name, study) in enumerate(param_studies.items()):
            row = i // 2
            col = i % 2
            if i < 4:  # Only plot first 4 parameters
                ax = axes[row, col]
                
                # Illustrative synthetic response curves (analytic shapes, not results from trained models)
                x = np.array(study['values'])
                y_displacement = 1 / np.sqrt(x) if 'Modulus' in param_name else np.sqrt(x)
                y_pressure = x / np.max(x) if 'Permeability' in param_name else 1 - x / np.max(x)
                
                ax.plot(x, y_displacement, 'o-', label='Displacement Response', linewidth=2, markersize=8)
                ax.plot(x, y_pressure, 's-', label='Pressure Response', linewidth=2, markersize=8)
                ax.set_title(f'{param_name}')
                ax.set_xlabel(f'Parameter value [{study["unit"]}]' if study['unit'] else 'Parameter value')
                ax.set_ylabel('Normalized Response')
                ax.legend()
                ax.grid(True, alpha=0.3)
        
        # Illustrative training-approach comparison (placeholder values, not measured results)
        axes[1, 2].bar(['Physics\nOnly', 'Gradual\nCoupling', 'Data\nEnhanced'], 
                       [0.1, 0.05, 0.02], 
                       color=['skyblue', 'lightgreen', 'gold'],
                       alpha=0.7)
        axes[1, 2].set_title('Training Approach Comparison (illustrative)')
        axes[1, 2].set_ylabel('Expected L2 Error (placeholder)')
        axes[0, 2].axis('off')  # this panel is unused
        axes[1, 2].grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()
        
        return fig
    
    def create_parameter_study_plan(self):
        """Create a systematic parameter study plan"""
        print(" PARAMETER STUDY PLAN")
        print("=" * 50)
        print(" Material Property Studies:")
        print("   1. Young's Modulus: [1000, 5000, 15000, 25000] Pa")
        print("   2. Poisson's Ratio: [0.15, 0.25, 0.35, 0.45]")
        print("   3. Biot Coefficient: [0.2, 0.5, 0.8, 1.0]")
        print("   4. Permeability: [1e-15, 1e-14, 1e-13, 1e-12] m²")
        print()
        print("  Training Parameter Studies:")
        print("   1. Network Architecture: [3, 4, 5] hidden layers")
        print("   2. Layer Sizes: [128, 256, 512] neurons")
        print("   3. Learning Rates: [1e-4, 1e-3, 1e-2]")
        print("   4. Loss Weights: Physics vs Data balance")
        print()
        print(" Evaluation Metrics:")
        print("   • L2 error vs exact solution")
        print("   • L∞ error (maximum deviation)")
        print("   • Training convergence rate")
        print("   • Computational efficiency")
        print("=" * 50)
        
        return self.baseline_params

# Initialize parameter study tool
param_tool = ParameterStudyTool()

# Create parameter effects visualization
print("  Creating parameter sensitivity visualization...")
# Uncomment to generate parameter study plots
# param_fig = param_tool.visualize_parameter_effects()

# Display parameter study plan
study_plan = param_tool.create_parameter_study_plan()

print("\n Interactive Usage:")
print("   • Modify baseline_params to test different material properties")
print("   • Use parameter ranges to understand sensitivity")
print("   • Compare training approaches systematically")
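`visualize_parameter_effects` draws synthetic curves. A real sensitivity study would retrain the model once per configuration. The sketch below builds a one-at-a-time (OAT) design from the baseline values and the ranges printed by `create_parameter_study_plan`; `one_at_a_time` is a hypothetical helper.

Each configuration would then be passed to a training and evaluation routine that this notebook does not define. OAT sweeps are cheap but cannot reveal interactions, for example between α and k. A full factorial or Latin hypercube design would be needed for that.

```python
def one_at_a_time(baseline: dict, ranges: dict) -> list:
    """Build a one-at-a-time design: each config changes a single parameter."""
    configs = []
    for name, values in ranges.items():
        if name not in baseline:
            raise KeyError(f"unknown parameter: {name}")
        for value in values:
            # Copy the baseline and override one entry; the baseline itself is untouched
            configs.append({"varied": name, "params": {**baseline, name: value}})
    return configs


baseline = {"E": 5000.0, "nu": 0.25, "alpha": 0.8, "k": 1.0, "mu": 1.0}
ranges = {"E": [1000, 5000, 15000, 25000], "nu": [0.15, 0.25, 0.35, 0.45]}
configs = one_at_a_time(baseline, ranges)
print(f"{len(configs)} runs planned")
for cfg in configs[:2]:
    print(cfg["varied"], cfg["params"])
```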

## 🔮 Future Extensions & 3D Capabilities

Roadmap for extending the current 2D implementation to advanced 3D poroelasticity modeling and other enhancements.

In [None]:
# Future Extensions and Roadmap
class FutureCapabilities:
    """Roadmap and planning for future enhancements"""
    
    def __init__(self):
        self.current_capabilities = {
            "Implemented": [
                "2D Biot poroelasticity physics",
                "Unified mechanics-flow coupling", 
                "Automatic loss balancing",
                "VTK experimental data loading",
                "Gradual training approach",
                "Comprehensive visualization"
            ]
        }
        
        self.future_roadmap = {
            "In Progress": [
                "Full data-enhanced training integration",
                "Physics vs data loss balancing optimization",
                "Real-time training monitoring"
            ],
            "Near Term (Next Phase)": [
                "3D Biot poroelasticity implementation", 
                "Multi-material property modeling",
                "Complex boundary condition handling",
                "Inverse problem solving (parameter identification)",
                "Advanced visualization (3D interactive plots)"
            ],
            "Long Term": [
                "Multi-physics coupling (thermal, chemical)",
                "Real-time model updating with streaming data",
                "Uncertainty quantification",
                "High-performance computing integration",
                "Industrial application deployment"
            ]
        }
    
    def display_roadmap(self):
        """Display comprehensive development roadmap"""
        print("  BIOT POROELASTICITY PROJECT ROADMAP")
        print("=" * 60)
        
        for category, items in self.current_capabilities.items():
            print(f"\n{category}:")
            for item in items:
                print(f"   • {item}")
        
        for category, items in self.future_roadmap.items():
            print(f"\n{category}:")
            for item in items:
                print(f"   • {item}")
        
        print("\n" + "=" * 60)
        
    def show_3d_visualization_preview(self):
        """Show what 3D visualization will look like"""
        print(" 3D Visualization Preview")
        
        # Create 3D preview plot
        fig = plt.figure(figsize=(15, 10))
        
        # 3D scatter plot preview
        ax1 = fig.add_subplot(221, projection='3d')
        
        # Generate sample 3D data
        np.random.seed(42)
        x = np.random.rand(100) * 4 - 2
        y = np.random.rand(100) * 3 - 1.5  
        z = np.random.rand(100) * 2 - 1
        colors = np.sqrt(x**2 + y**2 + z**2)
        
        scatter = ax1.scatter(x, y, z, c=colors, cmap='viridis', s=50, alpha=0.6)
        ax1.set_title(' 3D Displacement Magnitude')
        ax1.set_xlabel('X [km]')
        ax1.set_ylabel('Y [km]')
        ax1.set_zlabel('Z [km]')
        
        # 2D slice preview
        ax2 = fig.add_subplot(222)
        X, Y = np.meshgrid(np.linspace(-2, 2, 30), np.linspace(-1.5, 1.5, 30))
        Z = np.sin(np.sqrt(X**2 + Y**2)) * np.exp(-0.1 * (X**2 + Y**2))
        
        contour = ax2.contourf(X, Y, Z, levels=20, cmap='RdBu_r', alpha=0.8)
        ax2.set_title(' 3D Pressure Slice (z=0)')
        ax2.set_xlabel('X [km]')
        ax2.set_ylabel('Y [km]')
        plt.colorbar(contour, ax=ax2, shrink=0.8)
        
        # Multi-physics coupling preview
        ax3 = fig.add_subplot(223)
        time_steps = np.linspace(0, 10, 100)
        displacement = np.exp(-0.1 * time_steps) * np.sin(2 * time_steps)
        pressure = 1 - np.exp(-0.2 * time_steps)
        temperature = 0.5 + 0.3 * np.sin(0.5 * time_steps)
        
        ax3.plot(time_steps, displacement, 'b-', label='Displacement', linewidth=2)
        ax3.plot(time_steps, pressure, 'r-', label='Pressure', linewidth=2)
        ax3.plot(time_steps, temperature, 'g-', label='Temperature', linewidth=2)
        ax3.set_title('Multi-Physics Coupling')
        ax3.set_xlabel('Time [years]')
        ax3.set_ylabel('Normalized Response')
        ax3.legend()
        ax3.grid(True, alpha=0.3)
        
        # Application domains
        ax4 = fig.add_subplot(224)
        applications = ['Geothermal', 'CO₂ Storage', 'Oil Recovery', 'Groundwater', 'Mining']
        impact_scores = [8.5, 9.2, 7.8, 8.1, 6.9]
        colors_app = ['red', 'green', 'blue', 'cyan', 'orange']
        
        bars = ax4.bar(applications, impact_scores, color=colors_app, alpha=0.7)
        ax4.set_title('Application Impact Potential')
        ax4.set_ylabel('Impact Score (1-10)')
        ax4.set_ylim(0, 10)
        ax4.tick_params(axis='x', labelrotation=45)  # rotate category labels on ax4 explicitly
        
        plt.tight_layout()
        plt.show()
        
        return fig
    
    def create_implementation_timeline(self):
        """Create development timeline"""
        print("IMPLEMENTATION TIMELINE")
        print("=" * 50)
        print("Phase 1 (Current): Physics Validation")
        print("   • Validate 2D physics-only trainer")
        print("   • Comprehensive error analysis")
        print("   • Boundary condition verification")
        print()
        print("Phase 2 (Next 2-4 weeks): Data Integration")
        print("   • Complete data-enhanced training")
        print("   • Optimize the balance between physics and data losses")
        print("   • Comparative analysis framework")
        print()
        print("Phase 3 (1-2 months): Advanced Features")
        print("   • 3D implementation")
        print("   • Parameter identification")
        print("   • Multi-material modeling")
        print()
        print("Phase 4 (3-6 months): Production Ready")
        print("   • High-performance optimization")
        print("   • Industrial application testing")
        print("   • Documentation and deployment")
        print("=" * 50)

# Initialize future capabilities
future_plan = FutureCapabilities()

# Display roadmap
future_plan.display_roadmap()

print("\n3D Capabilities Preview")
print("Uncomment the line below to see the 3D visualization preview:")
print("# future_plan.show_3d_visualization_preview()")

# Show implementation timeline
print("\n")
future_plan.create_implementation_timeline()

print("\nNext Immediate Steps:")
print("   1. Complete physics-only validation (this notebook)")
print("   2. Integrate experimental data with physics") 
print("   3. Perform comparative analysis")
print("   4. Conduct parameter sensitivity studies")
print("   5. Plan 3D extension architecture")
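The z = 0 panel in the preview is drawn from a synthetic analytic function. Once a 3D solver produces fields on a regular grid (an assumption: the 3D data layout is not defined yet), one way to obtain that panel is to linearly interpolate between the two grid planes that bracket z = 0. The sketch below uses a hypothetical helper, `slice_at_z`, and assumes the field is stored as an `(nx, ny, nz)` NumPy array with ascending `z_coords`:

```python
import numpy as np

def slice_at_z(field, z_coords, z0=0.0):
    """Linearly interpolate a gridded 3D field onto the horizontal plane z = z0.

    field    : array of shape (nx, ny, nz) sampled on a regular grid
    z_coords : ascending 1D array of the nz grid-plane heights (nz >= 2)
    Returns an (nx, ny) array.
    """
    field = np.asarray(field, dtype=float)
    z_coords = np.asarray(z_coords, dtype=float)
    if not (z_coords[0] <= z0 <= z_coords[-1]):
        raise ValueError(f"z0={z0} lies outside the grid [{z_coords[0]}, {z_coords[-1]}]")
    # Index of the first plane above z0, clipped so planes k-1 and k both exist
    k = int(np.clip(np.searchsorted(z_coords, z0, side="right"), 1, len(z_coords) - 1))
    z_lo, z_hi = z_coords[k - 1], z_coords[k]
    w = (z0 - z_lo) / (z_hi - z_lo)  # 0 at the lower plane, 1 at the upper plane
    return (1.0 - w) * field[:, :, k - 1] + w * field[:, :, k]

# Example on a synthetic field (placeholder, not solver output)
x = np.linspace(-2, 2, 30)
y = np.linspace(-1.5, 1.5, 30)
z = np.linspace(-1, 1, 11)
X3, Y3, Z3 = np.meshgrid(x, y, z, indexing="ij")
pressure_3d = np.exp(-(X3**2 + Y3**2 + Z3**2))
p_slice = slice_at_z(pressure_3d, z, z0=0.0)  # shape (30, 30)
```

To draw the result with the existing panel code, pass coordinates with the same shape, e.g. `ax2.contourf(X3[:, :, 0], Y3[:, :, 0], p_slice, levels=20, cmap='RdBu_r')`; with `indexing="ij"` the first array axis is x. Linear interpolation is exact for fields that vary linearly in z and a reasonable first choice for smooth fields on fine grids.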

## Summary & Usage Instructions

**This notebook is the central hub for all Biot poroelasticity visualization and validation throughout the project.**

### How to Use This Notebook

1. **Start with Environment Setup** (Cell 2)
   - Run it to check dependencies and module availability
   - Confirms that the physics trainer, data loader, and visualization tools are available

2. **Validate Physics-Only Trainer** (Cells 3-5)
   - Uncomment the training line in Cell 4 to train the physics model
   - Use `visualizer.plot_solution_fields(trainer)` to see results
   - Run `visualizer.print_validation_summary(trainer)` for metrics

3. **Explore Experimental Data** (Cell 6)
   - Uncomment the VTK exploration lines if the `Data_2D` directory exists
   - Visualize experimental displacement and pressure data

4. **Compare Approaches** (Cell 7)
   - Use after training both physics-only and data-enhanced models
   - Systematic comparison of accuracy and performance

5. **Parameter Studies** (Cell 8)
   - Uncomment the visualization line for parameter sensitivity plots
   - Modify material properties to study their effects

6. **Plan Future Work** (Cell 9)
   - Review roadmap and implementation timeline
   - Uncomment the 3D preview line to visualize the planned capabilities

### Key Functions Reference

- `create_and_train_physics_trainer(quick_test=True)` - Train physics model
- `visualizer.plot_solution_fields(trainer)` - Plot displacement and pressure fields
- `visualizer.print_validation_summary(trainer)` - Comprehensive error analysis
- `explore_vtk_data()` - Load and analyze experimental data
- `compare_training_approaches(physics_trainer, data_trainer)` - Compare models
- `param_tool.visualize_parameter_effects()` - Parameter sensitivity analysis

### Troubleshooting

- **"Module not available"** → Check that the module files (e.g. `trainers/biot_trainer_2d.py`) exist under the parent `poroelasticity/` directory, which the setup cell adds to `sys.path`
- **"JAX not available"** → The NumPy fallback is used automatically
- **"Data directory not found"** → Create a `Data_2D` folder containing the VTK files
- **Training errors** → Reduce the number of training steps or check the parameter values
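The first and third checks can be scripted. The sketch below is a hypothetical helper, not part of the project: `module_available` uses the standard-library `importlib.util.find_spec` to test whether a module is importable from the current `sys.path`, and `find_vtk_files` assumes that `Data_2D` is resolved relative to the working directory and that the VTK files use the `.vtk` or `.vtu` extension. Adjust both assumptions to the actual data layout.

```python
import importlib.util
from pathlib import Path

def module_available(name):
    """Return True if `name` can be imported from the current sys.path."""
    try:
        # Note: for a dotted name, find_spec imports the parent package first
        return importlib.util.find_spec(name) is not None
    except (ModuleNotFoundError, ValueError):
        # ModuleNotFoundError: a parent package of a dotted name is missing
        return False

def find_vtk_files(data_dir="Data_2D", patterns=("*.vtk", "*.vtu")):
    """List VTK files in `data_dir` (assumed extensions); raise if the directory is missing."""
    data_dir = Path(data_dir)
    if not data_dir.is_dir():
        raise FileNotFoundError(f"Data directory not found: {data_dir.resolve()}")
    files = []
    for pattern in patterns:
        files.extend(sorted(data_dir.glob(pattern)))
    return files

# Example: report which of the notebook's dependencies are importable
for name in ("jax", "matplotlib", "trainers.biot_trainer_2d"):
    status = "available" if module_available(name) else "NOT available"
    print(f"{name}: {status}")
```

Calling `find_vtk_files()` before the VTK exploration cell turns a missing folder into an explicit `FileNotFoundError` with the resolved path, which makes a wrong working directory easy to spot.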

### Expected Results

**Physics-Only Validation:**
- L2 error < 1e-2: Excellent accuracy
- 1e-2 ≤ L2 error < 1e-1: Good accuracy
- L2 error ≥ 1e-1: Needs more training
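Which band a run falls into depends on how the L2 error is normalized. The sketch below assumes the metric is the relative L2 error over all sample points, `||u_pred - u_ref||_2 / ||u_ref||_2`, evaluated per field (for example, each displacement component and the pressure separately). `print_validation_summary` may use a different definition, so the helper names and the normalization here are assumptions:

```python
import numpy as np

def relative_l2_error(pred, ref):
    """Relative L2 error ||pred - ref||_2 / ||ref||_2 over all sample points.

    For 2D arrays np.linalg.norm returns the Frobenius norm, i.e. the L2 norm
    of the flattened values, so gridded fields can be passed directly.
    """
    pred = np.asarray(pred, dtype=float)
    ref = np.asarray(ref, dtype=float)
    ref_norm = np.linalg.norm(ref)
    if ref_norm == 0.0:
        raise ValueError("reference field is identically zero; relative error is undefined")
    return float(np.linalg.norm(pred - ref) / ref_norm)

def classify_accuracy(err):
    """Map an L2 error onto the three bands listed above."""
    if err < 1e-2:
        return "Excellent accuracy"
    if err < 1e-1:
        return "Good accuracy"
    return "Needs more training"

# Example on synthetic 1D data (illustration only, not a project result)
x = np.linspace(0.0, 1.0, 101)
p_ref = np.sin(np.pi * x)
p_pred = p_ref + 1e-3 * np.cos(3 * np.pi * x)
err = relative_l2_error(p_pred, p_ref)
print(f"pressure: {err:.2e} -> {classify_accuracy(err)}")
```

A relative measure keeps the thresholds meaningful across fields with very different magnitudes (displacements and pore pressures rarely share a scale); an absolute error would need separate thresholds for each field.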

This notebook will evolve with the project - bookmark it as your visualization command center!