In [11]:
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import Lasso, LassoCV, ElasticNet
from sklearn.feature_selection import SelectKBest, f_regression
from sklearn.metrics import mean_squared_error, r2_score
from scipy.optimize import minimize
import time

import ipywidgets as widgets
from IPython.display import display, clear_output

# Add project root to path
current_dir = os.getcwd()
if 'notebooks' in current_dir:
    project_root = os.path.abspath(os.path.join(current_dir, '..'))
else:
    project_root = current_dir

if project_root not in sys.path:
    sys.path.append(project_root)

from reward_func.evo_devo import somitogenesis_sol_func, weights_to_matrix

print(f"Project root: {project_root}")


Project root: /Users/dannyhuang/Developer/gflownet2/discrete-gflownet


In [None]:
# Interactive VISUALIZATION FUNCTIONS FOR ALL ROWS
def plot_all_rows_lasso(cell_idx=0, alpha=0.01, max_rows_display=7):
    """Plot LASSO results for all rows of the matrix, recomputing if needed, and print a detailed summary."""
    # Always recompute results for the given alpha using the defined function
    results = run_lasso_all_rows_7x7(alpha, verbose=False)
    
    n_timepoints = len(t_sim_7x7)
    start = cell_idx * n_timepoints
    end = (cell_idx + 1) * n_timepoints
    
    # Create subplots - adjust grid based on number of rows
    n_rows_plot = min(max_rows_display, n_nodes_7x7)
    cols = 3 if n_rows_plot > 6 else 2
    rows = int(np.ceil(n_rows_plot / cols))
    
    fig, axes = plt.subplots(rows, cols, figsize=(6*cols, 4*rows))
    if rows == 1:
        axes = axes.reshape(1, -1)
    if cols == 1:
        axes = axes.reshape(-1, 1)
    
    for row_idx in range(n_rows_plot):
        ax = axes[row_idx // cols, row_idx % cols]
        
        # Get data for this row
        y_sample = y_all_by_row[row_idx][start:end]
        y_pred_sample = results['row_results'][row_idx]['y_pred'][start:end]
        
        ax.plot(t_sim_7x7, y_sample, 'b-', label='Original', linewidth=2)
        ax.plot(t_sim_7x7, y_pred_sample, 'r--', label='LASSO', linewidth=2)
        
        row_res = results['row_results'][row_idx]
        ax.set_title(
            f'Row {row_idx}: {row_res["n_zeros"]}/{n_nodes_7x7} zeros\n'
            f'R²: {row_res["r2"]:.3f}'
        )
        ax.set_xlabel('Time')
        ax.set_ylabel('y(t)')
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)
    
    # Hide unused subplots
    for i in range(n_rows_plot, rows * cols):
        axes[i // cols, i % cols].set_visible(False)
    
    plt.suptitle(f'Cell {cell_idx}, α={alpha:.4g} - All Matrix Rows\nOverall Sparsity: {results["sparsity_percent"]:.1f}%', fontsize=14)
    plt.tight_layout()
    plt.show()
    
    # Print comprehensive summary statistics 
    print(f"{'-'*60}")
    print(f"\nALPHA = {alpha}")
    print(f"\nStatistics for α={alpha}, Cell {cell_idx}:")
    print(f"Overall sparsity: {results['total_zeros']}/{n_nodes_7x7*n_nodes_7x7} ({results['sparsity_percent']:.1f}%)")
    print(f"Average MSE: {results['avg_mse']:.2e}")
    print(f"Average R²: {results['avg_r2']:.4f}")
    print(f"\nOriginal 7x7 Weight Matrix:")
    print(_fmt_arr(W_original_7x7, 2))
    print(f"\nSparse Weight Matrix:")
    print(_fmt_arr(results['W_sparse'], 2))
    print("\nOriginal test_state_7x7:")
    print(_fmt_arr(test_state_7x7, 2))
    print("\nLASSO test_state after all-rows LASSO:")
    print(_fmt_arr(results['lasso_test_state'], 2))


def plot_all_rows_lasso_per_row_alpha(cell_idx=0, alpha_0=0.1, alpha_1=0.1, alpha_2=0.1, alpha_3=0.1, alpha_4=0.1, alpha_5=0.1, alpha_6=0.1, max_rows_display=7):
    """Plot LASSO results with per-row alpha control."""
    # Collect alpha values for each row
    alphas = [alpha_0, alpha_1, alpha_2, alpha_3, alpha_4, alpha_5, alpha_6]
    
    # Compute results with per-row alphas
    results = run_lasso_all_rows_7x7(alphas, verbose=False)
    
    n_timepoints = len(t_sim_7x7)
    start = cell_idx * n_timepoints
    end = (cell_idx + 1) * n_timepoints
    
    # Create subplots - adjust grid based on number of rows
    n_rows_plot = min(max_rows_display, n_nodes_7x7)
    cols = 3 if n_rows_plot > 6 else 2
    rows = int(np.ceil(n_rows_plot / cols))
    
    fig, axes = plt.subplots(rows, cols, figsize=(6*cols, 4*rows))
    if rows == 1:
        axes = axes.reshape(1, -1)
    if cols == 1:
        axes = axes.reshape(-1, 1)
    
    for row_idx in range(n_rows_plot):
        ax = axes[row_idx // cols, row_idx % cols]
        
        # Get data for this row
        y_sample = y_all_by_row[row_idx][start:end]
        y_pred_sample = results['row_results'][row_idx]['y_pred'][start:end]
        
        ax.plot(t_sim_7x7, y_sample, 'b-', label='Original', linewidth=2)
        ax.plot(t_sim_7x7, y_pred_sample, 'r--', label='LASSO', linewidth=2)
        
        row_res = results['row_results'][row_idx]
        ax.set_title(
            f'Row {row_idx}: α={alphas[row_idx]:.3g}\n'
            f'{row_res["n_zeros"]}/{n_nodes_7x7} zeros, R²: {row_res["r2"]:.3f}'
        )
        ax.set_xlabel('Time')
        ax.set_ylabel('y(t)')
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)
    
    # Hide unused subplots
    for i in range(n_rows_plot, rows * cols):
        axes[i // cols, i % cols].set_visible(False)
    
    plt.suptitle(f'Cell {cell_idx} - Per-Row Alpha Control\nOverall Sparsity: {results["sparsity_percent"]:.1f}%', fontsize=14)
    plt.tight_layout()
    plt.show()
    
    # Print comprehensive summary statistics 
    print(f"{'-'*80}")
    print(f"\nPER-ROW ALPHA CONTROL")
    print(f"\nStatistics for Cell {cell_idx}:")
    print(f"Alpha values: {[f'{a:.3g}' for a in alphas]}")
    print(f"Overall sparsity: {results['total_zeros']}/{n_nodes_7x7*n_nodes_7x7} ({results['sparsity_percent']:.1f}%)")
    print(f"Average MSE: {results['avg_mse']:.2e}")
    print(f"Average R²: {results['avg_r2']:.4f}")
    print(f"\nPer-row statistics:")
    for row_idx in range(n_nodes_7x7):
        row_res = results['row_results'][row_idx]
        print(f"  Row {row_idx}: α={alphas[row_idx]:.3g}, {row_res['n_zeros']:2d}/{n_nodes_7x7} zeros, MSE={row_res['mse']:.2e}, R²={row_res['r2']:.4f}")
    
    print(f"\nOriginal 7x7 Weight Matrix:")
    print(_fmt_arr(W_original_7x7, 2))
    print(f"\nSparse Weight Matrix (per-row α):")
    print(_fmt_arr(results['W_sparse'], 2))
    print("\nOriginal test_state_7x7:")
    print(_fmt_arr(test_state_7x7, 2))
    print("\nLASSO test_state after per-row LASSO:")
    print(_fmt_arr(results['lasso_test_state'], 2))










# # Interactive widget for ALL-ROW alpha control
# cell_slider_all = widgets.IntSlider(value=6, min=0, max=N_CELLS-1, step=1, description='Cell:')
# alpha_slider_all = widgets.FloatLogSlider(value=1.0, base=10, min=-3, max=1, step=0.01, description='Alpha:', continuous_update=False, readout_format='.4g')
# print("\nInteractive plot for ALL matrix rows (single alpha):")
# out_all_rows = widgets.interactive_output(
#     plot_all_rows_lasso, 
#     {'cell_idx': cell_slider_all, 'alpha': alpha_slider_all}
# )
# display(widgets.HBox([cell_slider_all, alpha_slider_all]), out_all_rows)









# Interactive widget for PER-ROW alpha control
print("\n" + "="*80)
print("🎯 PER-ROW ALPHA CONTROL - Each row gets its own alpha parameter!")
print("="*80)

# Create cell slider for per-row control
cell_slider_per_row = widgets.IntSlider(value=6, min=0, max=N_CELLS-1, step=1, description='Cell:')

# Set default alpha values for each row (customize as needed)
default_alpha_values = [1.0, 0.0, 0.05, 0.0, 0.0, 0.0873, 0.9]  # Example: 7 different defaults

# Create individual alpha sliders for each row, each with its own default value
alpha_sliders = []
for i in range(n_nodes_7x7):
    slider = widgets.FloatLogSlider(
        value=default_alpha_values[i] if i < len(default_alpha_values) else 0.0,  # Use per-row default, fallback to 0.01
        base=10,
        min=-7,  # 0.001
        max=1,   # 10
        step=0.001,
        description=f'α{i}:',
        continuous_update=False,
        readout_format='.3g',
        style={'description_width': '30px'},
        layout=widgets.Layout(width='200px')
    )
    alpha_sliders.append(slider)

# Create the interactive output
out_per_row = widgets.interactive_output(
    plot_all_rows_lasso_per_row_alpha,
    {
        'cell_idx': cell_slider_per_row,
        'alpha_0': alpha_sliders[0],
        'alpha_1': alpha_sliders[1], 
        'alpha_2': alpha_sliders[2],
        'alpha_3': alpha_sliders[3],
        'alpha_4': alpha_sliders[4],
        'alpha_5': alpha_sliders[5],
        'alpha_6': alpha_sliders[6]
    }
)

# Display the widgets in a nice layout
print("\nPer-row alpha control interface:")
control_box = widgets.VBox([
    widgets.HBox([cell_slider_per_row]),
    widgets.HTML("<b>Alpha values for each row:</b>"),
    widgets.HBox(alpha_sliders[:4]),  # First 4 sliders 
    widgets.HBox(alpha_sliders[4:])   # Last 3 sliders
])
display(control_box, out_per_row)



🎯 PER-ROW ALPHA CONTROL - Each row gets its own alpha parameter!



Per-row alpha control interface:


VBox(children=(HBox(children=(IntSlider(value=6, description='Cell:', max=99),)), HTML(value='<b>Alpha values …

Output()

In [None]:
# Interactive VISUALIZATION FUNCTIONS FOR ALL ROWS
def plot_all_rows_lasso(cell_idx=0, alpha=0.01, max_rows_display=7):
    """Plot LASSO results for all rows of the matrix, recomputing if needed, and print a detailed summary."""
    # Always recompute results for the given alpha using the defined function
    results = run_lasso_all_rows_7x7(alpha, verbose=False)
    
    n_timepoints = len(t_sim_7x7)
    start = cell_idx * n_timepoints
    end = (cell_idx + 1) * n_timepoints
    
    # Create subplots - adjust grid based on number of rows
    n_rows_plot = min(max_rows_display, n_nodes_7x7)
    cols = 3 if n_rows_plot > 6 else 2
    rows = int(np.ceil(n_rows_plot / cols))
    
    fig, axes = plt.subplots(rows, cols, figsize=(6*cols, 4*rows))
    if rows == 1:
        axes = axes.reshape(1, -1)
    if cols == 1:
        axes = axes.reshape(-1, 1)
    
    for row_idx in range(n_rows_plot):
        ax = axes[row_idx // cols, row_idx % cols]
        
        # Get data for this row
        y_sample = y_all_by_row[row_idx][start:end]
        y_pred_sample = results['row_results'][row_idx]['y_pred'][start:end]
        
        ax.plot(t_sim_7x7, y_sample, 'b-', label='Original', linewidth=2)
        ax.plot(t_sim_7x7, y_pred_sample, 'r--', label='LASSO', linewidth=2)
        
        row_res = results['row_results'][row_idx]
        ax.set_title(
            f'Row {row_idx}: {row_res["n_zeros"]}/{n_nodes_7x7} zeros\n'
            f'R²: {row_res["r2"]:.3f}'
        )
        ax.set_xlabel('Time')
        ax.set_ylabel('y(t)')
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)
    
    # Hide unused subplots
    for i in range(n_rows_plot, rows * cols):
        axes[i // cols, i % cols].set_visible(False)
    
    plt.suptitle(f'Cell {cell_idx}, α={alpha:.4g} - All Matrix Rows\nOverall Sparsity: {results["sparsity_percent"]:.1f}%', fontsize=14)
    plt.tight_layout()
    plt.show()
    
    # Print comprehensive summary statistics 
    print(f"{'-'*60}")
    print(f"\nALPHA = {alpha}")
    print(f"\nStatistics for α={alpha}, Cell {cell_idx}:")
    print(f"Overall sparsity: {results['total_zeros']}/{n_nodes_7x7*n_nodes_7x7} ({results['sparsity_percent']:.1f}%)")
    print(f"Average MSE: {results['avg_mse']:.2e}")
    print(f"Average R²: {results['avg_r2']:.4f}")
    print(f"\nOriginal 7x7 Weight Matrix:")
    print(_fmt_arr(W_original_7x7, 2))
    print(f"\nSparse Weight Matrix:")
    print(_fmt_arr(results['W_sparse'], 2))
    print("\nOriginal test_state_7x7:")
    print(_fmt_arr(test_state_7x7, 2))
    print("\nLASSO test_state after all-rows LASSO:")
    print(_fmt_arr(results['lasso_test_state'], 2))


def plot_all_rows_lasso_per_row_alpha(cell_idx=0, alpha_0=0.1, alpha_1=0.1, alpha_2=0.1, alpha_3=0.1, alpha_4=0.1, alpha_5=0.1, alpha_6=0.1, max_rows_display=7):
    """Plot LASSO results with per-row alpha control."""
    # Collect alpha values for each row
    alphas = [alpha_0, alpha_1, alpha_2, alpha_3, alpha_4, alpha_5, alpha_6]
    
    # Compute results with per-row alphas
    results = run_lasso_all_rows_7x7(alphas, verbose=False)
    
    n_timepoints = len(t_sim_7x7)
    start = cell_idx * n_timepoints
    end = (cell_idx + 1) * n_timepoints
    
    # Create subplots - adjust grid based on number of rows
    n_rows_plot = min(max_rows_display, n_nodes_7x7)
    cols = 3 if n_rows_plot > 6 else 2
    rows = int(np.ceil(n_rows_plot / cols))
    
    fig, axes = plt.subplots(rows, cols, figsize=(6*cols, 4*rows))
    if rows == 1:
        axes = axes.reshape(1, -1)
    if cols == 1:
        axes = axes.reshape(-1, 1)
    
    for row_idx in range(n_rows_plot):
        ax = axes[row_idx // cols, row_idx % cols]
        
        # Get data for this row
        y_sample = y_all_by_row[row_idx][start:end]
        y_pred_sample = results['row_results'][row_idx]['y_pred'][start:end]
        
        ax.plot(t_sim_7x7, y_sample, 'b-', label='Original', linewidth=2)
        ax.plot(t_sim_7x7, y_pred_sample, 'r--', label='LASSO', linewidth=2)
        
        row_res = results['row_results'][row_idx]
        ax.set_title(
            f'Row {row_idx}: α={alphas[row_idx]:.3g}\n'
            f'{row_res["n_zeros"]}/{n_nodes_7x7} zeros, R²: {row_res["r2"]:.3f}'
        )
        ax.set_xlabel('Time')
        ax.set_ylabel('y(t)')
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)
    
    # Hide unused subplots
    for i in range(n_rows_plot, rows * cols):
        axes[i // cols, i % cols].set_visible(False)
    
    plt.suptitle(f'Cell {cell_idx} - Per-Row Alpha Control\nOverall Sparsity: {results["sparsity_percent"]:.1f}%', fontsize=14)
    plt.tight_layout()
    plt.show()
    
    # Print comprehensive summary statistics 
    print(f"{'-'*80}")
    print(f"\nPER-ROW ALPHA CONTROL")
    print(f"\nStatistics for Cell {cell_idx}:")
    print(f"Alpha values: {[f'{a:.3g}' for a in alphas]}")
    print(f"Overall sparsity: {results['total_zeros']}/{n_nodes_7x7*n_nodes_7x7} ({results['sparsity_percent']:.1f}%)")
    print(f"Average MSE: {results['avg_mse']:.2e}")
    print(f"Average R²: {results['avg_r2']:.4f}")
    print(f"\nPer-row statistics:")
    for row_idx in range(n_nodes_7x7):
        row_res = results['row_results'][row_idx]
        print(f"  Row {row_idx}: α={alphas[row_idx]:.3g}, {row_res['n_zeros']:2d}/{n_nodes_7x7} zeros, MSE={row_res['mse']:.2e}, R²={row_res['r2']:.4f}")
    
    print(f"\nOriginal 7x7 Weight Matrix:")
    print(_fmt_arr(W_original_7x7, 2))
    print(f"\nSparse Weight Matrix (per-row α):")
    print(_fmt_arr(results['W_sparse'], 2))
    print("\nOriginal test_state_7x7:")
    print(_fmt_arr(test_state_7x7, 2))
    print("\nLASSO test_state after per-row LASSO:")
    print(_fmt_arr(results['lasso_test_state'], 2))










# # Interactive widget for ALL-ROW alpha control
# cell_slider_all = widgets.IntSlider(value=6, min=0, max=N_CELLS-1, step=1, description='Cell:')
# alpha_slider_all = widgets.FloatLogSlider(value=1.0, base=10, min=-3, max=1, step=0.01, description='Alpha:', continuous_update=False, readout_format='.4g')
# print("\nInteractive plot for ALL matrix rows (single alpha):")
# out_all_rows = widgets.interactive_output(
#     plot_all_rows_lasso, 
#     {'cell_idx': cell_slider_all, 'alpha': alpha_slider_all}
# )
# display(widgets.HBox([cell_slider_all, alpha_slider_all]), out_all_rows)









# Interactive widget for PER-ROW alpha control
print("\n" + "="*80)
print("🎯 PER-ROW ALPHA CONTROL - Each row gets its own alpha parameter!")
print("="*80)

# Create cell slider for per-row control
cell_slider_per_row = widgets.IntSlider(value=6, min=0, max=N_CELLS-1, step=1, description='Cell:')

# Set default alpha values for each row (customize as needed)
default_alpha_values = [1.0, 0.0, 0.05, 0.0, 0.0, 0.0873, 0.9]  # Example: 7 different defaults

# Create individual alpha sliders for each row, each with its own default value
alpha_sliders = []
for i in range(n_nodes_7x7):
    slider = widgets.FloatLogSlider(
        value=default_alpha_values[i] if i < len(default_alpha_values) else 0.0,  # Use per-row default, fallback to 0.01
        base=10,
        min=-7,  # 0.001
        max=1,   # 10
        step=0.001,
        description=f'α{i}:',
        continuous_update=False,
        readout_format='.3g',
        style={'description_width': '30px'},
        layout=widgets.Layout(width='200px')
    )
    alpha_sliders.append(slider)

# Create the interactive output
out_per_row = widgets.interactive_output(
    plot_all_rows_lasso_per_row_alpha,
    {
        'cell_idx': cell_slider_per_row,
        'alpha_0': alpha_sliders[0],
        'alpha_1': alpha_sliders[1], 
        'alpha_2': alpha_sliders[2],
        'alpha_3': alpha_sliders[3],
        'alpha_4': alpha_sliders[4],
        'alpha_5': alpha_sliders[5],
        'alpha_6': alpha_sliders[6]
    }
)

# Display the widgets in a nice layout
print("\nPer-row alpha control interface:")
control_box = widgets.VBox([
    widgets.HBox([cell_slider_per_row]),
    widgets.HTML("<b>Alpha values for each row:</b>"),
    widgets.HBox(alpha_sliders[:4]),  # First 4 sliders 
    widgets.HBox(alpha_sliders[4:])   # Last 3 sliders
])
display(control_box, out_per_row)



🎯 PER-ROW ALPHA CONTROL - Each row gets its own alpha parameter!



Per-row alpha control interface:


VBox(children=(HBox(children=(IntSlider(value=6, description='Cell:', max=99),)), HTML(value='<b>Alpha values …

Output()

In [None]:
# Interactive VISUALIZATION FUNCTIONS FOR ALL ROWS
def plot_all_rows_lasso(cell_idx=0, alpha=0.01, max_rows_display=7):
    """Plot LASSO results for all rows of the matrix, recomputing if needed, and print a detailed summary."""
    # Always recompute results for the given alpha using the defined function
    results = run_lasso_all_rows_7x7(alpha, verbose=False)
    
    n_timepoints = len(t_sim_7x7)
    start = cell_idx * n_timepoints
    end = (cell_idx + 1) * n_timepoints
    
    # Create subplots - adjust grid based on number of rows
    n_rows_plot = min(max_rows_display, n_nodes_7x7)
    cols = 3 if n_rows_plot > 6 else 2
    rows = int(np.ceil(n_rows_plot / cols))
    
    fig, axes = plt.subplots(rows, cols, figsize=(6*cols, 4*rows))
    if rows == 1:
        axes = axes.reshape(1, -1)
    if cols == 1:
        axes = axes.reshape(-1, 1)
    
    for row_idx in range(n_rows_plot):
        ax = axes[row_idx // cols, row_idx % cols]
        
        # Get data for this row
        y_sample = y_all_by_row[row_idx][start:end]
        y_pred_sample = results['row_results'][row_idx]['y_pred'][start:end]
        
        ax.plot(t_sim_7x7, y_sample, 'b-', label='Original', linewidth=2)
        ax.plot(t_sim_7x7, y_pred_sample, 'r--', label='LASSO', linewidth=2)
        
        row_res = results['row_results'][row_idx]
        ax.set_title(
            f'Row {row_idx}: {row_res["n_zeros"]}/{n_nodes_7x7} zeros\n'
            f'R²: {row_res["r2"]:.3f}'
        )
        ax.set_xlabel('Time')
        ax.set_ylabel('y(t)')
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)
    
    # Hide unused subplots
    for i in range(n_rows_plot, rows * cols):
        axes[i // cols, i % cols].set_visible(False)
    
    plt.suptitle(f'Cell {cell_idx}, α={alpha:.4g} - All Matrix Rows\nOverall Sparsity: {results["sparsity_percent"]:.1f}%', fontsize=14)
    plt.tight_layout()
    plt.show()
    
    # Print comprehensive summary statistics 
    print(f"{'-'*60}")
    print(f"\nALPHA = {alpha}")
    print(f"\nStatistics for α={alpha}, Cell {cell_idx}:")
    print(f"Overall sparsity: {results['total_zeros']}/{n_nodes_7x7*n_nodes_7x7} ({results['sparsity_percent']:.1f}%)")
    print(f"Average MSE: {results['avg_mse']:.2e}")
    print(f"Average R²: {results['avg_r2']:.4f}")
    print(f"\nOriginal 7x7 Weight Matrix:")
    print(_fmt_arr(W_original_7x7, 2))
    print(f"\nSparse Weight Matrix:")
    print(_fmt_arr(results['W_sparse'], 2))
    print("\nOriginal test_state_7x7:")
    print(_fmt_arr(test_state_7x7, 2))
    print("\nLASSO test_state after all-rows LASSO:")
    print(_fmt_arr(results['lasso_test_state'], 2))


def plot_all_rows_lasso_per_row_alpha(cell_idx=0, alpha_0=0.1, alpha_1=0.1, alpha_2=0.1, alpha_3=0.1, alpha_4=0.1, alpha_5=0.1, alpha_6=0.1, max_rows_display=7):
    """Plot LASSO results with per-row alpha control."""
    # Collect alpha values for each row
    alphas = [alpha_0, alpha_1, alpha_2, alpha_3, alpha_4, alpha_5, alpha_6]
    
    # Compute results with per-row alphas
    results = run_lasso_all_rows_7x7(alphas, verbose=False)
    
    n_timepoints = len(t_sim_7x7)
    start = cell_idx * n_timepoints
    end = (cell_idx + 1) * n_timepoints
    
    # Create subplots - adjust grid based on number of rows
    n_rows_plot = min(max_rows_display, n_nodes_7x7)
    cols = 3 if n_rows_plot > 6 else 2
    rows = int(np.ceil(n_rows_plot / cols))
    
    fig, axes = plt.subplots(rows, cols, figsize=(6*cols, 4*rows))
    if rows == 1:
        axes = axes.reshape(1, -1)
    if cols == 1:
        axes = axes.reshape(-1, 1)
    
    for row_idx in range(n_rows_plot):
        ax = axes[row_idx // cols, row_idx % cols]
        
        # Get data for this row
        y_sample = y_all_by_row[row_idx][start:end]
        y_pred_sample = results['row_results'][row_idx]['y_pred'][start:end]
        
        ax.plot(t_sim_7x7, y_sample, 'b-', label='Original', linewidth=2)
        ax.plot(t_sim_7x7, y_pred_sample, 'r--', label='LASSO', linewidth=2)
        
        row_res = results['row_results'][row_idx]
        ax.set_title(
            f'Row {row_idx}: α={alphas[row_idx]:.3g}\n'
            f'{row_res["n_zeros"]}/{n_nodes_7x7} zeros, R²: {row_res["r2"]:.3f}'
        )
        ax.set_xlabel('Time')
        ax.set_ylabel('y(t)')
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)
    
    # Hide unused subplots
    for i in range(n_rows_plot, rows * cols):
        axes[i // cols, i % cols].set_visible(False)
    
    plt.suptitle(f'Cell {cell_idx} - Per-Row Alpha Control\nOverall Sparsity: {results["sparsity_percent"]:.1f}%', fontsize=14)
    plt.tight_layout()
    plt.show()
    
    # Print comprehensive summary statistics 
    print(f"{'-'*80}")
    print(f"\nPER-ROW ALPHA CONTROL")
    print(f"\nStatistics for Cell {cell_idx}:")
    print(f"Alpha values: {[f'{a:.3g}' for a in alphas]}")
    print(f"Overall sparsity: {results['total_zeros']}/{n_nodes_7x7*n_nodes_7x7} ({results['sparsity_percent']:.1f}%)")
    print(f"Average MSE: {results['avg_mse']:.2e}")
    print(f"Average R²: {results['avg_r2']:.4f}")
    print(f"\nPer-row statistics:")
    for row_idx in range(n_nodes_7x7):
        row_res = results['row_results'][row_idx]
        print(f"  Row {row_idx}: α={alphas[row_idx]:.3g}, {row_res['n_zeros']:2d}/{n_nodes_7x7} zeros, MSE={row_res['mse']:.2e}, R²={row_res['r2']:.4f}")
    
    print(f"\nOriginal 7x7 Weight Matrix:")
    print(_fmt_arr(W_original_7x7, 2))
    print(f"\nSparse Weight Matrix (per-row α):")
    print(_fmt_arr(results['W_sparse'], 2))
    print("\nOriginal test_state_7x7:")
    print(_fmt_arr(test_state_7x7, 2))
    print("\nLASSO test_state after per-row LASSO:")
    print(_fmt_arr(results['lasso_test_state'], 2))










# # Interactive widget for ALL-ROW alpha control
# cell_slider_all = widgets.IntSlider(value=6, min=0, max=N_CELLS-1, step=1, description='Cell:')
# alpha_slider_all = widgets.FloatLogSlider(value=1.0, base=10, min=-3, max=1, step=0.01, description='Alpha:', continuous_update=False, readout_format='.4g')
# print("\nInteractive plot for ALL matrix rows (single alpha):")
# out_all_rows = widgets.interactive_output(
#     plot_all_rows_lasso, 
#     {'cell_idx': cell_slider_all, 'alpha': alpha_slider_all}
# )
# display(widgets.HBox([cell_slider_all, alpha_slider_all]), out_all_rows)









# Interactive widget for PER-ROW alpha control
print("\n" + "="*80)
print("🎯 PER-ROW ALPHA CONTROL - Each row gets its own alpha parameter!")
print("="*80)

# Create cell slider for per-row control
cell_slider_per_row = widgets.IntSlider(value=6, min=0, max=N_CELLS-1, step=1, description='Cell:')

# Set default alpha values for each row (customize as needed)
default_alpha_values = [1.0, 0.0, 0.05, 0.0, 0.0, 0.0873, 0.9]  # Example: 7 different defaults

# Create individual alpha sliders for each row, each with its own default value
alpha_sliders = []
for i in range(n_nodes_7x7):
    slider = widgets.FloatLogSlider(
        value=default_alpha_values[i] if i < len(default_alpha_values) else 0.0,  # Use per-row default, fallback to 0.01
        base=10,
        min=-7,  # 0.001
        max=1,   # 10
        step=0.001,
        description=f'α{i}:',
        continuous_update=False,
        readout_format='.3g',
        style={'description_width': '30px'},
        layout=widgets.Layout(width='200px')
    )
    alpha_sliders.append(slider)

# Create the interactive output
out_per_row = widgets.interactive_output(
    plot_all_rows_lasso_per_row_alpha,
    {
        'cell_idx': cell_slider_per_row,
        'alpha_0': alpha_sliders[0],
        'alpha_1': alpha_sliders[1], 
        'alpha_2': alpha_sliders[2],
        'alpha_3': alpha_sliders[3],
        'alpha_4': alpha_sliders[4],
        'alpha_5': alpha_sliders[5],
        'alpha_6': alpha_sliders[6]
    }
)

# Display the widgets in a nice layout
print("\nPer-row alpha control interface:")
control_box = widgets.VBox([
    widgets.HBox([cell_slider_per_row]),
    widgets.HTML("<b>Alpha values for each row:</b>"),
    widgets.HBox(alpha_sliders[:4]),  # First 4 sliders 
    widgets.HBox(alpha_sliders[4:])   # Last 3 sliders
])
display(control_box, out_per_row)



🎯 PER-ROW ALPHA CONTROL - Each row gets its own alpha parameter!



Per-row alpha control interface:


VBox(children=(HBox(children=(IntSlider(value=6, description='Cell:', max=99),)), HTML(value='<b>Alpha values …

Output()

In [12]:
# 7x7 system from test_grid.ipynb
# test_state_7x7 = [165, -120, -75, 175, 155, -185, 200, -165, 120, -110, 20, -105, -15, -55, 200, 160, 5, -15, -10, 160, 105, 55, 100, -150, 155, -150, -155, 55, 55, 5, -5, 10, -100, 0, 10, 50, -50, 50, 5, -5, -5, 50, 10, 50, 50, 0, 0, -50, 5, -200, 175, 125, -130, -50, 50, -5]
test_state_7x7 = [165, -120, -75, 175, 155, -185, 200, -180, 120, -110, 20, -105, -15, -55, 200, 160, 5, -15, -10, 160, 105, 55, 100, -145, 155, -150, -155, 55, 55, 5, -5, 10, -100, 0, 10, 50, -50, 50, 5, -5, -5, 50, 10, 50, 50, 0, 0, -50, 5, -200, 175, 125, -130, -50, 50, -5]

# Calculate system parameters for 7x7
n_nodes_7x7 = int((-1 + (1 + 4*len(test_state_7x7))**0.5) / 2)
n_weights_7x7 = n_nodes_7x7 * n_nodes_7x7

print(f"7x7 System: {n_nodes_7x7} nodes, {n_weights_7x7} weights")
print(f"State length: {len(test_state_7x7)}")

# Get original weight matrix and first row
W_original_7x7 = weights_to_matrix(test_state_7x7[:n_weights_7x7])
w_original_7x7 = W_original_7x7[0, :]  # First row

print(f"Original 7x7 weight matrix:")
print(W_original_7x7)
print(f"\nFirst row weights: {w_original_7x7}")
print(f"Number of zeros in original: {np.sum(w_original_7x7 == 0)}")


7x7 System: 7 nodes, 49 weights
State length: 56
Original 7x7 weight matrix:
[[ 165  175  200 -105  -10   55    5]
 [ -75 -120  120  -55  105    5   -5]
 [-185 -180  155  160  100   10   10]
 [  20  -15  200 -110  155    0   50]
 [ -15  160   55 -145    5   50    0]
 [-155   55   -5 -100   10 -150    5]
 [  50   -5   50   50    0  -50  -50]]

First row weights: [ 165  175  200 -105  -10   55    5]
Number of zeros in original: 0


In [13]:
# Generate trajectories for ALL 100 cell positions and ALL matrix rows
print("Generating trajectories for all 100 cell positions and all matrix rows...")
N_CELLS = 100

# Storage for all data - now we need data for each row
all_X_data = []
all_y_data_by_row = [[] for _ in range(n_nodes_7x7)]  # List of lists, one per row
cell_positions = []

start_time = time.time()

for cell_pos in range(N_CELLS):
    # Get trajectories for this cell position
    t_sim_7x7, cell_trajectory_7x7, _ = somitogenesis_sol_func(test_state_7x7, cell_position=cell_pos)
    
    # Calculate weighted sum for ALL rows of this cell
    for row_idx in range(n_nodes_7x7):
        w_row = W_original_7x7[row_idx, :]  # Current row
        y_cell_row = cell_trajectory_7x7 @ w_row
        all_y_data_by_row[row_idx].append(y_cell_row)
    
    # Store X data (same for all rows)
    all_X_data.append(cell_trajectory_7x7)
    cell_positions.append(cell_pos)
    
    if (cell_pos + 1) % 20 == 0:
        print(f"Processed {cell_pos + 1}/{N_CELLS} cells...")

end_time = time.time()
print(f"Data generation completed in {end_time - start_time:.2f} seconds")

# Stack all data for regression
X_all = np.vstack(all_X_data)  # [N_CELLS * time_points, n_genes]
y_all_by_row = [np.concatenate(y_row_data) for y_row_data in all_y_data_by_row]  # List of [N_CELLS * time_points] arrays

print(f"\nStacked data shape: X_all={X_all.shape}")
for row_idx in range(n_nodes_7x7):
    print(f"Row {row_idx}: y_shape={len(y_all_by_row[row_idx])}")
    original_mse = mean_squared_error(y_all_by_row[row_idx], X_all @ W_original_7x7[row_idx, :])
    print(f"Row {row_idx} original reconstruction MSE: {original_mse:.2e}")

# Keep original y_all for backward compatibility (first row)
y_all = y_all_by_row[0]
w_original_7x7 = W_original_7x7[0, :]  # First row for compatibility


Generating trajectories for all 100 cell positions and all matrix rows...
Processed 20/100 cells...
Processed 40/100 cells...
Processed 60/100 cells...
Processed 80/100 cells...
Processed 100/100 cells...
Data generation completed in 8.49 seconds

Stacked data shape: X_all=(20000, 7)
Row 0: y_shape=20000
Row 0 original reconstruction MSE: 0.00e+00
Row 1: y_shape=20000
Row 1 original reconstruction MSE: 0.00e+00
Row 2: y_shape=20000
Row 2 original reconstruction MSE: 0.00e+00
Row 3: y_shape=20000
Row 3 original reconstruction MSE: 0.00e+00
Row 4: y_shape=20000
Row 4 original reconstruction MSE: 0.00e+00
Row 5: y_shape=20000
Row 5 original reconstruction MSE: 0.00e+00
Row 6: y_shape=20000
Row 6 original reconstruction MSE: 0.00e+00


In [14]:
# Helpers

def reconstruct_test_state_with_all_lasso_weights(W_lasso_all_rows, original_test_state, n_nodes):
    """
    Reconstruct test_state by replacing ALL rows of weight matrix with LASSO weights
    and keeping everything else the same.
    """
    # Get original weight matrix
    n_weights = n_nodes * n_nodes
    original_weights = original_test_state[:n_weights]
    W_original = weights_to_matrix(original_weights)
    
    # Replace ALL rows with LASSO weights
    W_new = W_lasso_all_rows.copy()  # Use the entire sparse matrix
    
    # Convert back to flattened weight vector format
    new_weights = matrix_to_weights(W_new)
    
    # Reconstruct full test_state (weights + d_values)
    new_test_state = original_test_state.copy()
    new_test_state[:n_weights] = new_weights
    
    return new_test_state


def reconstruct_test_state_with_lasso_weights(w_lasso, original_test_state, n_nodes):
    """
    Reconstruct test_state by replacing the first row of weight matrix with LASSO weights
    and keeping everything else the same.
    """
    # Get original weight matrix
    n_weights = n_nodes * n_nodes
    original_weights = original_test_state[:n_weights]
    W_original = weights_to_matrix(original_weights)
    
    # Replace first row with LASSO weights
    W_new = W_original.copy()
    W_new[0, :] = w_lasso  # Replace first row
    
    # Convert back to flattened weight vector format
    # We need to reverse the weights_to_matrix transformation
    new_weights = matrix_to_weights(W_new)
    
    # Reconstruct full test_state (weights + d_values)
    new_test_state = original_test_state.copy()
    new_test_state[:n_weights] = new_weights
    
    return new_test_state

def matrix_to_weights(W_matrix):
    """
    Convert weight matrix back to flattened weights vector (reverse of weights_to_matrix)
    """
    n_nodes = W_matrix.shape[0]
    
    if n_nodes == 1:
        return [W_matrix[0, 0]]
    
    if n_nodes == 2:
        # For 2x2: [w1,w2,w3,w4] -> [[w1,w4],[w3,w2]]
        return [W_matrix[0,0], W_matrix[1,1], W_matrix[1,0], W_matrix[0,1]]
    
    # For larger matrices, we need to extract in the same order as weights_to_matrix builds them
    weights = []
    
    # First, add weights for the (n-1)x(n-1) submatrix recursively
    if n_nodes > 2:
        sub_matrix = W_matrix[:n_nodes-1, :n_nodes-1]
        weights.extend(matrix_to_weights(sub_matrix))
    
    # Add diagonal element for last node
    weights.append(W_matrix[n_nodes-1, n_nodes-1])
    
    # Add off-diagonal elements for last row and column
    for i in range(n_nodes-1):
        weights.append(W_matrix[n_nodes-1, i])  # Last row
        weights.append(W_matrix[i, n_nodes-1])  # Last column
    
    return weights


def _fmt_arr(arr, decimals=2):
    """Format a numpy array or list to string with limited decimals."""
    arr = np.asarray(arr)
    fmt = f"{{:.{decimals}f}}"
    if arr.ndim == 1:
        return "[" + ", ".join(fmt.format(x) for x in arr) + "]"
    elif arr.ndim == 2:
        return "[" + "\n ".join("[" + ", ".join(fmt.format(x) for x in row) + "]" for row in arr) + "]"
    else:
        # fallback for higher dims
        return np.array2string(arr, formatter={'float_kind':lambda x: fmt.format(x)})



In [15]:
### Sparse Regression on Combined Dataset by LASSO  (first ROW)

In [16]:
# LASSO analysis on 7x7 system (first row) with all cell data, with interactive cell selection and alpha control
print("Running LASSO analysis on 7x7 system (first row) ...")


# Print original weights before LASSO
print("Original weights (w_original_7x7):")
print(_fmt_arr(w_original_7x7, 2))
print("Original test_state_7x7:")
print(_fmt_arr(test_state_7x7, 2))


# Try different alpha values for 7x7 system
alphas_7x7 = [0.01, 0.1, 1.0, 5.0, 10.0] 
lasso_results_7x7 = {}

# Fit LASSO models and store results (do this once for efficiency)
for alpha in alphas_7x7:
    print(f"Fitting LASSO with α={alpha}...")
    
    # Fit LASSO
    lasso = Lasso(alpha=alpha, max_iter=10000)
    lasso.fit(X_all, y_all)
    
    w_lasso = lasso.coef_
    y_pred = X_all @ w_lasso

    # Reconstruct test_state after LASSO using provided function
    lasso_test_state = reconstruct_test_state_with_lasso_weights(
        w_lasso, test_state_7x7, n_nodes_7x7
    )

    # Print LASSO weights and reconstructed test_state after fitting, rounded to 2 decimals
    print(f"LASSO weights (alpha={alpha}):")
    print(_fmt_arr(w_lasso, 2))
    # print(f"test_state_7x7 after LASSO (alpha={alpha}):")
    # print(_fmt_arr(lasso_test_state, 2))
    
    # Store results
    lasso_results_7x7[alpha] = {
        'weights': w_lasso,
        'n_zeros': np.sum(np.abs(w_lasso) < 1e-6),
        'mse': mean_squared_error(y_all, y_pred),
        'r2': r2_score(y_all, y_pred),
        'y_pred': y_pred,
        'lasso_test_state': lasso_test_state
    }



# Interactive plotting function for fixed alphas
def plot_lasso_cell(cell_idx=0):
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    axes = axes.flatten()
    n_timepoints = len(t_sim_7x7)
    # Compute the slice for the selected cell
    start = cell_idx * n_timepoints
    end = (cell_idx + 1) * n_timepoints
    y_sample = y_all[start:end]
    for i, alpha in enumerate(alphas_7x7):
        y_pred = lasso_results_7x7[alpha]['y_pred']
        y_pred_sample = y_pred[start:end]
        axes[i].plot(t_sim_7x7, y_sample, 'b-', label='Original', linewidth=2)
        axes[i].plot(t_sim_7x7, y_pred_sample, 'r--', label='LASSO', linewidth=2)
        axes[i].set_title(
            f'α={alpha}, Zeros: {lasso_results_7x7[alpha]["n_zeros"]}/{n_nodes_7x7}\n'
            f'R²: {lasso_results_7x7[alpha]["r2"]:.4f}'
        )
        axes[i].set_xlabel('Time')
        axes[i].set_ylabel('y(t)')
        axes[i].legend()
        axes[i].grid(True, alpha=0.3)
    # Remove extra subplot if needed
    if len(axes) > len(alphas_7x7):
        fig.delaxes(axes[-1])
    plt.tight_layout()
    plt.show()

# Interactive plotting function for controlled alpha
def plot_lasso_cell_alpha(cell_idx=0, alpha=0.01):
    # Fit LASSO for the given alpha (if not already fit)
    lasso = Lasso(alpha=alpha, max_iter=10000)
    lasso.fit(X_all, y_all)
    w_lasso = lasso.coef_
    y_pred = X_all @ w_lasso
    n_zeros = np.sum(np.abs(w_lasso) < 1e-6)
    mse = mean_squared_error(y_all, y_pred)
    r2 = r2_score(y_all, y_pred)
    n_timepoints = len(t_sim_7x7)
    start = cell_idx * n_timepoints
    end = (cell_idx + 1) * n_timepoints
    y_sample = y_all[start:end]
    y_pred_sample = y_pred[start:end]
    # Reconstruct test_state after LASSO using provided function
    lasso_test_state = reconstruct_test_state_with_lasso_weights(
        w_lasso, test_state_7x7, n_nodes_7x7
    )
    # Print weights and reconstructed test_state for this alpha, rounded to 2 decimals
    print(f"\nOriginal weights (w_original_7x7):")
    print(_fmt_arr(w_original_7x7, 2))
    print(f"LASSO weights (alpha={alpha}):")
    print(_fmt_arr(w_lasso, 2))
    print("Original test_state_7x7:")
    print(_fmt_arr(test_state_7x7, 2))
    print(f"LASSO test_state_7x7 (alpha={alpha}):")
    print(_fmt_arr(lasso_test_state, 2))
    fig, ax = plt.subplots(figsize=(7, 5))
    ax.plot(t_sim_7x7, y_sample, 'b-', label='Original', linewidth=2)
    ax.plot(t_sim_7x7, y_pred_sample, 'r--', label='LASSO', linewidth=2)
    ax.set_title(
        f'α={alpha:.4g}, Zeros: {n_zeros}/{n_nodes_7x7}\n'
        f'R²: {r2:.4f}, MSE: {mse:.2e}'
    )
    ax.set_xlabel('Time')
    ax.set_ylabel('y(t)')
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

# Create slider for cell index
cell_slider = widgets.IntSlider(
    value=0,
    min=0,
    max=N_CELLS-1,
    step=1,
    description='Cell idx:',
    continuous_update=False
)

# Create slider for alpha (log scale)
alpha_slider = widgets.FloatLogSlider(
    value=0.01,
    base=10,
    min=-5,  # 1e-5
    max=2,   # 10^2 = 100
    step=0.01,
    description='Alpha:',
    continuous_update=False,
    readout_format='.4g'
)







# Display interactive widget for fixed alphas
print("\n\nInteractive plot for fixed alpha values:")
out_fixed = widgets.interactive_output(plot_lasso_cell, {'cell_idx': cell_slider})
display(widgets.HBox([cell_slider]), out_fixed)

# Display interactive widget for controlled alpha
print("\nInteractive plot for user-controlled alpha:")
out_alpha = widgets.interactive_output(plot_lasso_cell_alpha, {'cell_idx': cell_slider, 'alpha': alpha_slider})
display(widgets.HBox([cell_slider, alpha_slider]), out_alpha)

# Print results summary, including after-LASSO test_state for each alpha
print("\n7x7 LASSO Results Summary:")
print("Alpha\t\tZeros\t\tMSE\t\tR²\t\tSparsity")
print("-" * 70)
for alpha in alphas_7x7:
    res = lasso_results_7x7[alpha]
    sparsity = res['n_zeros'] / n_nodes_7x7 * 100
    print(f"{alpha:g}\t\t{res['n_zeros']}/{n_nodes_7x7}\t\t{res['mse']:.2e}\t{res['r2']:.4f}\t{sparsity:.1f}%")
    print(f"test_state_7x7 after LASSO (alpha={alpha}):")
    print(_fmt_arr(res['lasso_test_state'], 2))


Running LASSO analysis on 7x7 system (first row) ...
Original weights (w_original_7x7):
[165.00, 175.00, 200.00, -105.00, -10.00, 55.00, 5.00]
Original test_state_7x7:
[165.00, -120.00, -75.00, 175.00, 155.00, -185.00, 200.00, -180.00, 120.00, -110.00, 20.00, -105.00, -15.00, -55.00, 200.00, 160.00, 5.00, -15.00, -10.00, 160.00, 105.00, 55.00, 100.00, -145.00, 155.00, -150.00, -155.00, 55.00, 55.00, 5.00, -5.00, 10.00, -100.00, 0.00, 10.00, 50.00, -50.00, 50.00, 5.00, -5.00, -5.00, 50.00, 10.00, 50.00, 50.00, 0.00, 0.00, -50.00, 5.00, -200.00, 175.00, 125.00, -130.00, -50.00, 50.00, -5.00]
Fitting LASSO with α=0.01...
LASSO weights (alpha=0.01):
[164.80, 175.63, 197.01, -100.17, -9.87, 50.54, 3.16]
Fitting LASSO with α=0.1...
LASSO weights (alpha=0.1):
[160.14, 176.89, 178.66, -75.64, -0.00, 0.00, 0.00]
Fitting LASSO with α=1.0...
LASSO weights (alpha=1.0):
[150.68, 177.39, 110.91, 0.00, -0.00, 0.00, 0.00]
Fitting LASSO with α=5.0...
LASSO weights (alpha=5.0):
[138.60, 150.60, 101.32, 

HBox(children=(IntSlider(value=0, continuous_update=False, description='Cell idx:', max=99),))

Output()


Interactive plot for user-controlled alpha:


HBox(children=(IntSlider(value=0, continuous_update=False, description='Cell idx:', max=99), FloatLogSlider(va…

Output()


7x7 LASSO Results Summary:
Alpha		Zeros		MSE		R²		Sparsity
----------------------------------------------------------------------
0.01		0/7		1.82e-01	1.0000	0.0%
test_state_7x7 after LASSO (alpha=0.01):
[164.00, -120.00, -75.00, 175.00, 155.00, -185.00, 197.00, -180.00, 120.00, -110.00, 20.00, -100.00, -15.00, -55.00, 200.00, 160.00, 5.00, -15.00, -9.00, 160.00, 105.00, 55.00, 100.00, -145.00, 155.00, -150.00, -155.00, 50.00, 55.00, 5.00, -5.00, 10.00, -100.00, 0.00, 10.00, 50.00, -50.00, 50.00, 3.00, -5.00, -5.00, 50.00, 10.00, 50.00, 50.00, 0.00, 0.00, -50.00, 5.00, -200.00, 175.00, 125.00, -130.00, -50.00, 50.00, -5.00]
0.1		3/7		1.59e+01	0.9991	42.9%
test_state_7x7 after LASSO (alpha=0.1):
[160.00, -120.00, -75.00, 176.00, 155.00, -185.00, 178.00, -180.00, 120.00, -110.00, 20.00, -75.00, -15.00, -55.00, 200.00, 160.00, 5.00, -15.00, 0.00, 160.00, 105.00, 55.00, 100.00, -145.00, 155.00, -150.00, -155.00, 0.00, 55.00, 5.00, -5.00, 10.00, -100.00, 0.00, 10.00, 50.00, -50.00, 50.00, 0

In [17]:
# COMPREHENSIVE SPARSE MATRIX ANALYSIS (LASSO) - ALL ROWS

In [21]:

print("="*80)
print("COMPREHENSIVE SPARSE MATRIX ANALYSIS (LASSO)  - ALL ROWS")
print("="*80)

def run_lasso_all_rows_7x7(alphas, verbose=True):
    """Fit LASSO for all rows of the 7x7 system with per-row alpha values.
    
    Args:
        alphas: Either a single alpha value (applied to all rows) or a list/array of alpha values (one per row)
        verbose: Whether to print progress information
    """
    from sklearn.linear_model import LinearRegression  # Import here for the fix
    
    # Handle both single alpha and per-row alpha cases
    if np.isscalar(alphas):
        alphas = [alphas] * n_nodes_7x7
    else:
        alphas = list(alphas)
        if len(alphas) != n_nodes_7x7:
            raise ValueError(f"Number of alphas ({len(alphas)}) must match number of rows ({n_nodes_7x7})")
    
    W_lasso_all_rows = np.zeros((n_nodes_7x7, n_nodes_7x7))
    row_results = {}

    for row_idx in range(n_nodes_7x7):
        if verbose and (row_idx == 0 or row_idx % 2 == 0):
            print(f"  Processing row {row_idx} with α={alphas[row_idx]:.4g}...")

        # 🔧 FIX: Use LinearRegression for alpha=0 or very small alpha
        if alphas[row_idx] <= 1e-6:  # Threshold for "essentially zero"
            if verbose and alphas[row_idx] == 0:
                print(f"    Using LinearRegression (OLS) for α={alphas[row_idx]}")
            model = LinearRegression()
            model.fit(X_all, y_all_by_row[row_idx])
            w_lasso_row = model.coef_
            method_used = 'OLS'
        else:
            # Use LASSO for non-zero alpha
            model = Lasso(alpha=alphas[row_idx], max_iter=10000)
            model.fit(X_all, y_all_by_row[row_idx])
            w_lasso_row = model.coef_
            method_used = 'LASSO'

        y_pred_row = X_all @ w_lasso_row

        W_lasso_all_rows[row_idx, :] = w_lasso_row
        row_results[row_idx] = {
            'weights': w_lasso_row,
            'n_zeros': np.sum(np.abs(w_lasso_row) < 1e-6),
            'mse': mean_squared_error(y_all_by_row[row_idx], y_pred_row),
            'r2': r2_score(y_all_by_row[row_idx], y_pred_row),
            'y_pred': y_pred_row,
            'alpha_used': alphas[row_idx],
            'method_used': method_used
        }

    lasso_test_state_all = reconstruct_test_state_with_all_lasso_weights(
        W_lasso_all_rows, test_state_7x7, n_nodes_7x7
    )

    total_weights = n_nodes_7x7 * n_nodes_7x7
    total_zeros = np.sum(np.abs(W_lasso_all_rows) < 1e-6)
    avg_mse = np.mean([row_results[i]['mse'] for i in range(n_nodes_7x7)])
    avg_r2 = np.mean([row_results[i]['r2'] for i in range(n_nodes_7x7)])

    results = {
        'W_sparse': W_lasso_all_rows,
        'row_results': row_results,
        'total_zeros': total_zeros,
        'sparsity_percent': (total_zeros / total_weights) * 100,
        'avg_mse': avg_mse,
        'avg_r2': avg_r2,
        'lasso_test_state': lasso_test_state_all,
        'alphas_used': alphas
    }

    if verbose:
        alpha_str = f"α=[{', '.join(f'{a:.3g}' for a in alphas)}]"
        print(f"  {alpha_str}: {total_zeros}/{total_weights} zeros ({total_zeros/total_weights*100:.1f}% sparse)")
        print(f"  Avg MSE: {avg_mse:.2e}, Avg R²: {avg_r2:.4f}")

    return results

# LASSO analysis on 7x7 system - ALL ROWS
print("Running LASSO analysis on 7x7 system for ALL ROWS...")

alphas_all_rows = [0.01, 0.1, 1.0, 5.0, 10.0] 
lasso_results_all_rows_7x7 = {}

for alpha in alphas_all_rows:
    print(f"\nFitting LASSO with α={alpha} for all {n_nodes_7x7} rows...")
    lasso_results_all_rows_7x7[alpha] = run_lasso_all_rows_7x7(alpha, verbose=True)

print("\nComprehensive LASSO analysis completed!")


COMPREHENSIVE SPARSE MATRIX ANALYSIS (LASSO)  - ALL ROWS
Running LASSO analysis on 7x7 system for ALL ROWS...

Fitting LASSO with α=0.01 for all 7 rows...
  Processing row 0 with α=0.01...
  Processing row 2 with α=0.01...
  Processing row 4 with α=0.01...
  Processing row 6 with α=0.01...
  α=[0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01]: 1/49 zeros (2.0% sparse)
  Avg MSE: 2.34e-01, Avg R²: 0.9999

Fitting LASSO with α=0.1 for all 7 rows...
  Processing row 0 with α=0.1...
  Processing row 2 with α=0.1...
  Processing row 4 with α=0.1...
  Processing row 6 with α=0.1...
  α=[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]: 10/49 zeros (20.4% sparse)
  Avg MSE: 1.06e+01, Avg R²: 0.9971

Fitting LASSO with α=1.0 for all 7 rows...
  Processing row 0 with α=1...
  Processing row 2 with α=1...
  Processing row 4 with α=1...
  Processing row 6 with α=1...
  α=[1, 1, 1, 1, 1, 1, 1]: 22/49 zeros (44.9% sparse)
  Avg MSE: 9.81e+01, Avg R²: 0.9749

Fitting LASSO with α=5.0 for all 7 rows...
  Processing row 0

In [35]:
# Interactive VISUALIZATION FUNCTIONS FOR ALL ROWS
def plot_all_rows_lasso(cell_idx=0, alpha=0.01, max_rows_display=7):
    """Plot LASSO results for all rows of the matrix, recomputing if needed, and print a detailed summary."""
    # Always recompute results for the given alpha using the defined function
    results = run_lasso_all_rows_7x7(alpha, verbose=False)
    
    n_timepoints = len(t_sim_7x7)
    start = cell_idx * n_timepoints
    end = (cell_idx + 1) * n_timepoints
    
    # Create subplots - adjust grid based on number of rows
    n_rows_plot = min(max_rows_display, n_nodes_7x7)
    cols = 3 if n_rows_plot > 6 else 2
    rows = int(np.ceil(n_rows_plot / cols))
    
    fig, axes = plt.subplots(rows, cols, figsize=(6*cols, 4*rows))
    if rows == 1:
        axes = axes.reshape(1, -1)
    if cols == 1:
        axes = axes.reshape(-1, 1)
    
    for row_idx in range(n_rows_plot):
        ax = axes[row_idx // cols, row_idx % cols]
        
        # Get data for this row
        y_sample = y_all_by_row[row_idx][start:end]
        y_pred_sample = results['row_results'][row_idx]['y_pred'][start:end]
        
        ax.plot(t_sim_7x7, y_sample, 'b-', label='Original', linewidth=2)
        ax.plot(t_sim_7x7, y_pred_sample, 'r--', label='LASSO', linewidth=2)
        
        row_res = results['row_results'][row_idx]
        ax.set_title(
            f'Row {row_idx}: {row_res["n_zeros"]}/{n_nodes_7x7} zeros\n'
            f'R²: {row_res["r2"]:.3f}'
        )
        ax.set_xlabel('Time')
        ax.set_ylabel('y(t)')
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)
    
    # Hide unused subplots
    for i in range(n_rows_plot, rows * cols):
        axes[i // cols, i % cols].set_visible(False)
    
    plt.suptitle(f'Cell {cell_idx}, α={alpha:.4g} - All Matrix Rows\nOverall Sparsity: {results["sparsity_percent"]:.1f}%', fontsize=14)
    plt.tight_layout()
    plt.show()
    
    # Print comprehensive summary statistics 
    print(f"{'-'*60}")
    print(f"\nALPHA = {alpha}")
    print(f"\nStatistics for α={alpha}, Cell {cell_idx}:")
    print(f"Overall sparsity: {results['total_zeros']}/{n_nodes_7x7*n_nodes_7x7} ({results['sparsity_percent']:.1f}%)")
    print(f"Average MSE: {results['avg_mse']:.2e}")
    print(f"Average R²: {results['avg_r2']:.4f}")
    print(f"\nOriginal 7x7 Weight Matrix:")
    print(_fmt_arr(W_original_7x7, 2))
    print(f"\nSparse Weight Matrix:")
    print(_fmt_arr(results['W_sparse'], 2))
    print("\nOriginal test_state_7x7:")
    print(_fmt_arr(test_state_7x7, 2))
    print("\nLASSO test_state after all-rows LASSO:")
    print(_fmt_arr(results['lasso_test_state'], 2))


def plot_all_rows_lasso_per_row_alpha(cell_idx=0, alpha_0=0.1, alpha_1=0.1, alpha_2=0.1, alpha_3=0.1, alpha_4=0.1, alpha_5=0.1, alpha_6=0.1, max_rows_display=7):
    """Plot LASSO results with per-row alpha control."""
    # Collect alpha values for each row
    alphas = [alpha_0, alpha_1, alpha_2, alpha_3, alpha_4, alpha_5, alpha_6]
    
    # Compute results with per-row alphas
    results = run_lasso_all_rows_7x7(alphas, verbose=False)
    
    n_timepoints = len(t_sim_7x7)
    start = cell_idx * n_timepoints
    end = (cell_idx + 1) * n_timepoints
    
    # Create subplots - adjust grid based on number of rows
    n_rows_plot = min(max_rows_display, n_nodes_7x7)
    cols = 3 if n_rows_plot > 6 else 2
    rows = int(np.ceil(n_rows_plot / cols))
    
    fig, axes = plt.subplots(rows, cols, figsize=(6*cols, 4*rows))
    if rows == 1:
        axes = axes.reshape(1, -1)
    if cols == 1:
        axes = axes.reshape(-1, 1)
    
    for row_idx in range(n_rows_plot):
        ax = axes[row_idx // cols, row_idx % cols]
        
        # Get data for this row
        y_sample = y_all_by_row[row_idx][start:end]
        y_pred_sample = results['row_results'][row_idx]['y_pred'][start:end]
        
        ax.plot(t_sim_7x7, y_sample, 'b-', label='Original', linewidth=2)
        ax.plot(t_sim_7x7, y_pred_sample, 'r--', label='LASSO', linewidth=2)
        
        row_res = results['row_results'][row_idx]
        ax.set_title(
            f'Row {row_idx}: α={alphas[row_idx]:.3g}\n'
            f'{row_res["n_zeros"]}/{n_nodes_7x7} zeros, R²: {row_res["r2"]:.3f}'
        )
        ax.set_xlabel('Time')
        ax.set_ylabel('y(t)')
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)
    
    # Hide unused subplots
    for i in range(n_rows_plot, rows * cols):
        axes[i // cols, i % cols].set_visible(False)
    
    plt.suptitle(f'Cell {cell_idx} - Per-Row Alpha Control\nOverall Sparsity: {results["sparsity_percent"]:.1f}%', fontsize=14)
    plt.tight_layout()
    plt.show()
    
    # Print comprehensive summary statistics 
    print(f"{'-'*80}")
    print(f"\nPER-ROW ALPHA CONTROL")
    print(f"\nStatistics for Cell {cell_idx}:")
    print(f"Alpha values: {[f'{a:.3g}' for a in alphas]}")
    print(f"Overall sparsity: {results['total_zeros']}/{n_nodes_7x7*n_nodes_7x7} ({results['sparsity_percent']:.1f}%)")
    print(f"Average MSE: {results['avg_mse']:.2e}")
    print(f"Average R²: {results['avg_r2']:.4f}")
    print(f"\nPer-row statistics:")
    for row_idx in range(n_nodes_7x7):
        row_res = results['row_results'][row_idx]
        print(f"  Row {row_idx}: α={alphas[row_idx]:.3g}, {row_res['n_zeros']:2d}/{n_nodes_7x7} zeros, MSE={row_res['mse']:.2e}, R²={row_res['r2']:.4f}")
    
    print(f"\nOriginal 7x7 Weight Matrix:")
    print(_fmt_arr(W_original_7x7, 2))
    print(f"\nSparse Weight Matrix (per-row α):")
    print(_fmt_arr(results['W_sparse'], 2))
    print("\nOriginal test_state_7x7:")
    print(_fmt_arr(test_state_7x7, 2))
    print("\nLASSO test_state after per-row LASSO:")
    print(_fmt_arr(results['lasso_test_state'], 2))










# # Interactive widget for ALL-ROW alpha control
# cell_slider_all = widgets.IntSlider(value=6, min=0, max=N_CELLS-1, step=1, description='Cell:')
# alpha_slider_all = widgets.FloatLogSlider(value=1.0, base=10, min=-3, max=1, step=0.01, description='Alpha:', continuous_update=False, readout_format='.4g')
# print("\nInteractive plot for ALL matrix rows (single alpha):")
# out_all_rows = widgets.interactive_output(
#     plot_all_rows_lasso, 
#     {'cell_idx': cell_slider_all, 'alpha': alpha_slider_all}
# )
# display(widgets.HBox([cell_slider_all, alpha_slider_all]), out_all_rows)









# Interactive widget for PER-ROW alpha control
print("\n" + "="*80)
print("🎯 PER-ROW ALPHA CONTROL - Each row gets its own alpha parameter!")
print("="*80)

# Create cell slider for per-row control
cell_slider_per_row = widgets.IntSlider(value=6, min=0, max=N_CELLS-1, step=1, description='Cell:')

# Set default alpha values for each row (customize as needed)
default_alpha_values = [1.0, 0.0, 0.05, 0.0, 0.0, 0.0873, 0.9]  # Example: 7 different defaults

# Create individual alpha sliders for each row, each with its own default value
alpha_sliders = []
for i in range(n_nodes_7x7):
    slider = widgets.FloatLogSlider(
        value=default_alpha_values[i] if i < len(default_alpha_values) else 0.0,  # Use per-row default, fallback to 0.01
        base=10,
        min=-7,  # 0.001
        max=1,   # 10
        step=0.001,
        description=f'α{i}:',
        continuous_update=False,
        readout_format='.3g',
        style={'description_width': '30px'},
        layout=widgets.Layout(width='200px')
    )
    alpha_sliders.append(slider)

# Create the interactive output
out_per_row = widgets.interactive_output(
    plot_all_rows_lasso_per_row_alpha,
    {
        'cell_idx': cell_slider_per_row,
        'alpha_0': alpha_sliders[0],
        'alpha_1': alpha_sliders[1], 
        'alpha_2': alpha_sliders[2],
        'alpha_3': alpha_sliders[3],
        'alpha_4': alpha_sliders[4],
        'alpha_5': alpha_sliders[5],
        'alpha_6': alpha_sliders[6]
    }
)

# Display the widgets in a nice layout
print("\nPer-row alpha control interface:")
control_box = widgets.VBox([
    widgets.HBox([cell_slider_per_row]),
    widgets.HTML("<b>Alpha values for each row:</b>"),
    widgets.HBox(alpha_sliders[:4]),  # First 4 sliders 
    widgets.HBox(alpha_sliders[4:])   # Last 3 sliders
])
display(control_box, out_per_row)



🎯 PER-ROW ALPHA CONTROL - Each row gets its own alpha parameter!



Per-row alpha control interface:


VBox(children=(HBox(children=(IntSlider(value=6, description='Cell:', max=99),)), HTML(value='<b>Alpha values …

Output()

In [None]:
## Summary and Applications

### Key Findings from 7x7 System Analysis

🔬 **Biological Insights:**
- **Gene Redundancy**: Many genes contribute minimally to the weighted sum pattern
- **Critical Interactions**: Only 2-4 genes are typically needed to reconstruct complex dynamics
- **Spatial Consistency**: Sparse patterns work across all cell positions

🎯 **Method Performance:**
- **LASSO**: Automatically finds sparse solutions, good for exploration
- **OMP**: Precise sparsity control, cleaner solutions
- **Trade-off**: ~3-4 genes can achieve >95% reconstruction quality

### Applications in Gene Network Design

**1. Network Compression 🗜️**
- Reduce 7-gene system to 3-4 essential genes
- Maintain pattern formation capability
- Lower complexity for biological implementation

**2. Drug Target Identification 🎯**
- Identify which genes are critical vs. redundant
- Focus therapeutic interventions on key regulators
- Predict system robustness to perturbations

**3. Evolutionary Biology 🧬**
- Understand which genetic interactions are evolutionary conserved
- Predict which mutations would be most disruptive
- Design minimal viable gene regulatory circuits

### Mathematical Insights

The sparse regression problem reveals the **effective dimensionality** of the gene regulatory system:

$$\text{Effective Genes} \ll \text{Total Genes}$$

This suggests that:
- Complex patterns can emerge from simple rules
- Gene networks have inherent redundancy 
- Evolutionary pressure favors robust, sparse architectures

### Next Steps 🚀

1. **Extend to other rows** of the weight matrix
2. **Time-varying sparsity** analysis  
3. **Nonlinear sparse methods** (sparse neural networks)
4. **Multi-objective optimization** (sparsity + biological constraints)

The notebook provides a foundation for understanding which gene interactions are truly necessary for pattern formation! 🧬✨


In [None]:
# Test with LinearRegression instead of Lasso(alpha=0.0)
from sklearn.linear_model import LinearRegression

print("Testing with LinearRegression (true OLS):")
lr_results = []
for row_idx in range(n_nodes_7x7):
    lr = LinearRegression()
    lr.fit(X_all, y_all_by_row[row_idx])
    lr_results.append(lr.coef_)

W_lr = np.array(lr_results)
print("LinearRegression weights:")
print(_fmt_arr(W_lr, 2))

print("Difference (LinearRegression - Original):")
print(_fmt_arr(W_lr - W_original_7x7, 4))

Testing with LinearRegression (true OLS):
LinearRegression weights:
[[165.00, 175.00, 200.00, -105.00, -10.00, 55.00, 5.00]
 [-75.00, -120.00, 120.00, -55.00, 105.00, 5.00, -5.00]
 [-185.00, -180.00, 155.00, 160.00, 100.00, 10.00, 10.00]
 [20.00, -15.00, 200.00, -110.00, 155.00, -0.00, 50.00]
 [-15.00, 160.00, 55.00, -145.00, 5.00, 50.00, -0.00]
 [-155.00, 55.00, -5.00, -100.00, 10.00, -150.00, 5.00]
 [50.00, -5.00, 50.00, 50.00, -0.00, -50.00, -50.00]]
Difference (LinearRegression - Original):
[[0.0000, 0.0000, -0.0000, 0.0000, 0.0000, -0.0000, 0.0000]
 [-0.0000, -0.0000, 0.0000, 0.0000, 0.0000, 0.0000, -0.0000]
 [0.0000, 0.0000, -0.0000, 0.0000, 0.0000, 0.0000, -0.0000]
 [0.0000, 0.0000, -0.0000, 0.0000, -0.0000, -0.0000, 0.0000]
 [0.0000, -0.0000, 0.0000, -0.0000, 0.0000, 0.0000, -0.0000]
 [0.0000, 0.0000, -0.0000, 0.0000, -0.0000, -0.0000, -0.0000]
 [-0.0000, 0.0000, 0.0000, 0.0000, -0.0000, 0.0000, -0.0000]]


In [None]:
# Example: Demonstrating per-row alpha control with different strategies

print("🧪 DEMONSTRATION: Per-Row Alpha Control Strategies")
print("="*60)

# Strategy 1: Uniform alpha (equivalent to old behavior)
print("\n1. Uniform Alpha Strategy (α=0.1 for all rows)")
uniform_alphas = [0.1] * n_nodes_7x7
results_uniform = run_lasso_all_rows_7x7(uniform_alphas, verbose=False)

# Strategy 2: Gradient alpha (decreasing sparsity)
print("\n2. Gradient Alpha Strategy (decreasing from 1.0 to 0.01)")
gradient_alphas = [1.0, 0.5, 0.2, 0.1, 0.05, 0.02, 0.01]
results_gradient = run_lasso_all_rows_7x7(gradient_alphas, verbose=False)

# Strategy 3: Selective sparsity (odd rows sparse, even rows dense)
print("\n3. Selective Sparsity Strategy (odd rows sparse, even rows dense)")
selective_alphas = [0.01, 1.0, 0.01, 1.0, 0.01, 1.0, 0.01]  # Even indices=dense, odd indices=sparse
results_selective = run_lasso_all_rows_7x7(selective_alphas, verbose=False)

# Compare strategies
strategies = [
    ("Uniform", uniform_alphas, results_uniform),
    ("Gradient", gradient_alphas, results_gradient), 
    ("Selective", selective_alphas, results_selective)
]

print(f"\n{'Strategy':<12} {'Alphas':<35} {'Sparsity':<10} {'Avg R²':<8} {'Avg MSE':<10}")
print("-" * 80)
for name, alphas, results in strategies:
    alpha_str = f"[{', '.join(f'{a:.2g}' for a in alphas)}]"
    sparsity = f"{results['sparsity_percent']:.1f}%"
    avg_r2 = f"{results['avg_r2']:.4f}"
    avg_mse = f"{results['avg_mse']:.2e}"
    print(f"{name:<12} {alpha_str:<35} {sparsity:<10} {avg_r2:<8} {avg_mse:<10}")

print(f"\nExample sparse matrices (showing non-zero structure):")
for name, alphas, results in strategies:
    print(f"\n{name} Strategy:")
    sparse_matrix = results['W_sparse']
    # Show pattern of zeros and non-zeros
    pattern = np.where(np.abs(sparse_matrix) > 1e-6, '●', '○')
    for row in pattern:
        print('  ' + ' '.join(row))
    print(f"  Total zeros: {results['total_zeros']}/{sparse_matrix.size}")

print(f"\n💡 Tip: Use the interactive widgets above to experiment with your own alpha combinations!")


In [None]:
## Summary and Applications

### Key Findings from 7x7 System Analysis

🔬 **Biological Insights:**
- **Gene Redundancy**: Many genes contribute minimally to the weighted sum pattern
- **Critical Interactions**: Only 2-4 genes are typically needed to reconstruct complex dynamics
- **Spatial Consistency**: Sparse patterns work across all cell positions

🎯 **Method Performance:**
- **LASSO**: Automatically finds sparse solutions, good for exploration
- **OMP**: Precise sparsity control, cleaner solutions
- **Trade-off**: ~3-4 genes can achieve >95% reconstruction quality

### Applications in Gene Network Design

**1. Network Compression 🗜️**
- Reduce 7-gene system to 3-4 essential genes
- Maintain pattern formation capability
- Lower complexity for biological implementation

**2. Drug Target Identification 🎯**
- Identify which genes are critical vs. redundant
- Focus therapeutic interventions on key regulators
- Predict system robustness to perturbations

**3. Evolutionary Biology 🧬**
- Understand which genetic interactions are evolutionary conserved
- Predict which mutations would be most disruptive
- Design minimal viable gene regulatory circuits

### Mathematical Insights

The sparse regression problem reveals the **effective dimensionality** of the gene regulatory system:

$$\text{Effective Genes} \ll \text{Total Genes}$$

This suggests that:
- Complex patterns can emerge from simple rules
- Gene networks have inherent redundancy 
- Evolutionary pressure favors robust, sparse architectures

### Next Steps 🚀

1. **Extend to other rows** of the weight matrix
2. **Time-varying sparsity** analysis  
3. **Nonlinear sparse methods** (sparse neural networks)
4. **Multi-objective optimization** (sparsity + biological constraints)

The notebook provides a foundation for understanding which gene interactions are truly necessary for pattern formation! 🧬✨
