# Whole-brain pRF and CF maps

Here you can interactively visualize both population Receptive Field (**pRF**) and Connective Field (**CF**) maps on the cortical surface.


**pRF** maps characterize the aggregate response properties of neuronal populations in visual cortex to stimuli in visual space. pRF parameters include eccentricity (distance from fixation), polar angle (angular position in the visual field), and size (extent of the receptive field), providing a functional description of retinotopic organization within individual visual areas.


**CF** maps reveal functional connectivity patterns across the brain. You can select CF parameters (e.g., eccentricity or polar angle) for models referenced to the left or right V1. Target areas for V1 connectivity include the whole brain, allowing CFs in V1 to project to the contralateral hemisphere, potentially revealing retinotopically organized connectivity in contralateral V1, as well as elsewhere in the cortex.

Use the widgets below to switch between CF and pRF maps, select different parameters, and adjust visualization settings.

*Nicolas Gravel | CEA | nicolas.gravel (at) cea.fr | 28-01-2025*

In [None]:
# Cell 1: Imports and Palette Definitions (Updated)
# Safe imports with error handling
import sys
import warnings
warnings.filterwarnings('ignore')

# Core imports
try:
    import os
    import glob
    import yaml
    from pathlib import Path
    import pickle
    import numpy as np
    import nibabel as nib
    import re
except Exception as e:
    print(f'‚ùå Core imports failed: {e}')
    raise

# Neuroimaging imports
try:
    import neuropythy as ny
    from neuropythy.geometry import Mesh, Tesselation
except Exception as e:
    print(f'‚ùå Neuropythy failed: {e}')
    raise

# Visualization imports
try:
    import pandas as pd
    import ipyvolume as ipv
    import matplotlib.pyplot as plt
    from matplotlib.patches import Patch, Wedge
    import matplotlib.colors as mcolors
    from matplotlib import cm
    from mpl_toolkits.axes_grid1.inset_locator import inset_axes
except Exception as e:
    print(f'‚ùå Visualization imports failed: {e}')
    raise

# Additional scientific imports
try:
    from nilearn.surface import vol_to_surf
    from nilearn import surface, plotting, signal
    from scipy.spatial.distance import pdist, squareform
    from scipy.spatial import cKDTree
    from scipy.stats import pearsonr
except Exception as e:
    pass  # Optional imports, silent fail

# Widget imports
try:
    from ipywidgets import FloatText, HBox, VBox, Textarea, Output, Dropdown, FloatSlider, interactive_output
    from traitlets import link
    from IPython.display import display, HTML
except Exception as e:
    print(f'‚ùå Widget imports failed: {e}')
    raise

import math
import gc

# Debug flag - set to True to see detailed progress messages
DEBUG = False

# Inline color palettes (no external cfmap dependency)
def _build_palette(rgb_values, cmap_name):
    """
    Build a palette dictionary from a list of 0-255 RGB triplets.

    Parameters:
        rgb_values (list): list of [r, g, b] integer triplets in the 0-255 range
        cmap_name (str): name given to the generated LinearSegmentedColormap

    Returns:
        dict: keys "rgb_0_255" (input list), "rgb_0_1" (normalized triplets),
        "hex" (hex strings), "named" ({"color1": hex, ...}) and
        "matplotlib_cmap" (a LinearSegmentedColormap interpolating the colors)
    """
    from matplotlib.colors import LinearSegmentedColormap
    import matplotlib.colors as mcolors

    # Normalize to 0-1 range for matplotlib
    norm_values = [[r / 255, g / 255, b / 255] for r, g, b in rgb_values]
    hex_values = [mcolors.rgb2hex(rgb) for rgb in norm_values]

    return {
        "rgb_0_255": rgb_values,
        "rgb_0_1": norm_values,
        "hex": hex_values,
        # 1-based names: color1, color2, ...
        "named": {f"color{i}": hex_value for i, hex_value in enumerate(hex_values, start=1)},
        "matplotlib_cmap": LinearSegmentedColormap.from_list(cmap_name, norm_values),
    }


def get_eccentricity_palette():
    """
    Returns a color palette with 10 colors transitioning from red to orange to yellow to green to turquoise to cyan to blue.

    Returns:
        dict: Dictionary containing different formats of the color palette
        (see ``_build_palette``); the colormap is named "eccen_cmap".
    """
    # Original RGB values (0-255)
    rgb_values = [
        [255, 40, 0],    # Red
        [255, 130, 0],   # Orange-red
        [255, 210, 0],   # Orange-yellow
        [255, 255, 0],   # Yellow
        [115, 255, 0],   # Yellow-green
        [31, 255, 0],    # Green
        [0, 255, 207],   # Turquoise
        [0, 231, 255],   # Cyan
        [20, 140, 255],  # Light blue
        [40, 60, 255]    # Blue
    ]
    return _build_palette(rgb_values, "eccen_cmap")


def get_polar_palette():
    """
    Returns a color palette with 20 colors transitioning from green to red to green to blue to green.

    The sequence starts and ends on the same green, so the colormap wraps
    around cleanly, as is needed for a cyclic quantity such as polar angle.

    Returns:
        dict: Dictionary containing different formats of the color palette
        (see ``_build_palette``); the colormap is named "polar_cmap".
    """
    # Original RGB values (0-255)
    rgb_values = [
        [106, 189, 69],   # Color1
        [203, 219, 42],   # Color2
        [254, 205, 8],    # Color3
        [242, 104, 34],   # Color4
        [237, 32, 36],    # Color5
        [237, 32, 36],    # Color6
        [242, 104, 34],   # Color7
        [254, 205, 8],    # Color8
        [203, 219, 42],   # Color9
        [106, 189, 69],   # Color10
        [106, 189, 69],   # Color11
        [110, 205, 221],  # Color12
        [50, 178, 219],   # Color13
        [62, 105, 179],   # Color14
        [57, 84, 165],    # Color15
        [57, 84, 165],    # Color16
        [62, 105, 179],   # Color17
        [50, 178, 219],   # Color18
        [110, 205, 221],  # Color19
        [106, 189, 69]    # Color20
    ]
    return _build_palette(rgb_values, "polar_cmap")


# Module-level palettes used by the colormap selection and colorbar legends below
eccen_colors = get_eccentricity_palette()
polar_colors = get_polar_palette()


# Rotate axis
def rotate_coords(coords, axis, angle_degrees):
    """
    Rotate coordinates by a given angle around the x, y, or z axis.

    The rotation matrices are the standard right-handed ones: a positive
    angle rotates counter-clockwise when looking from the positive axis
    toward the origin (e.g. +90 deg about z maps (1, 0, 0) to (0, 1, 0)).

    Parameters:
        coords (np.ndarray): shape (3, N) with x, y, z along the first
            dimension; a single point of shape (3,) also works.
        axis (str): 'x', 'y', or 'z'
        angle_degrees (float): rotation angle in degrees

    Returns:
        np.ndarray: rotated coordinates, same shape as ``coords``

    Raises:
        ValueError: if ``axis`` is not 'x', 'y', or 'z'
    """
    theta = np.deg2rad(angle_degrees)
    c, s = np.cos(theta), np.sin(theta)
    if axis == 'x':
        rot = np.array([
            [1, 0, 0],
            [0, c, -s],
            [0, s,  c]
        ])
    elif axis == 'y':
        rot = np.array([
            [ c, 0, s],
            [ 0, 1, 0],
            [-s, 0, c]
        ])
    elif axis == 'z':
        rot = np.array([
            [c, -s, 0],
            [s,  c, 0],
            [0,  0, 1]
        ])
    else:
        raise ValueError("axis must be 'x', 'y', or 'z'")
    return rot @ coords


# Improved state management class
class CFVisualizationState:
    """
    Manages the state of the visualization, including caches for data and intermediate results.

    Data files are looked up under ``<notebook_dir>/data/<dataset_key>/``, where
    ``notebook_dir`` is the working directory resolved when the instance is created.
    """
    def __init__(self):
        self.notebook_dir = Path().resolve()
        # CF and pRF results, keyed by the string cache keys built in the load_* methods
        self.loaded_results_cache = {}
        # Mesh + curvature dicts, keyed by the cache_key passed to load_meshes_and_curv
        self.loaded_mesh_cache = {}
        # Cache for processed data (masked, transformed, etc.) based on parameters
        # Key: (dataset_key, subject_id, task, source_hemi, map_type, cf_property, r2_threshold)
        # NOTE(review): not read or written by any method of this class.
        self.processed_data_cache = {}
        # Cache for transformed meshes (rotated, separated) based on separation
        # Key: (dataset_key, subject_id, hemi_separation)
        # NOTE(review): unused; get_or_create_transformed_meshes recomputes on every call.
        self.transformed_mesh_cache = {}
        # Keep track of the current plot object for potential updates (though full redraw is common)
        self.current_fig = None
        self.current_plot_data = {} # Store latest plot params and data

    def get_paths_and_configs(self, dataset_key):
        """
        Build the standard paths for a dataset.

        Returns:
            tuple: (base_data_path, cf_models_dir, fs_subjects_dir) as Path objects.
        """
        base_data_path = self.notebook_dir / 'data' / dataset_key
        cf_models_dir = base_data_path / 'derivatives' / 'cf-models'
        fs_subjects_dir = base_data_path / 'fs_subjects'
        return base_data_path, cf_models_dir, fs_subjects_dir

    def load_meshes_and_curv(self, fs_subject_path, cache_key):
        """
        Load (or fetch from cache) the inflated surfaces and curvature maps of a FreeSurfer subject.

        Reads ``surf/{lh,rh}.inflated`` and ``surf/{lh,rh}.curv`` under ``fs_subject_path``.

        Returns:
            dict | None: keys 'lh_mesh', 'rh_mesh' (neuropythy Mesh objects with
            coordinates of shape (3, n_vertices)) and 'lh_curv_map', 'rh_curv_map';
            None if loading fails (the error is printed).
        """
        if cache_key in self.loaded_mesh_cache:
            if DEBUG:
                print(f"✓ Using cached mesh data for {cache_key}")
            return self.loaded_mesh_cache[cache_key]

        if DEBUG:
            print(f"⏳ Loading mesh from disk: {fs_subject_path}")
        try:
            import nibabel.freesurfer as fs
            surf_dir = fs_subject_path / 'surf'

            # Load geometry: coords are (n_vertices, 3), faces (n_faces, 3)
            lh_coords, lh_faces = fs.read_geometry(str(surf_dir / 'lh.inflated'))
            rh_coords, rh_faces = fs.read_geometry(str(surf_dir / 'rh.inflated'))

            # neuropythy expects the coordinate/face axis first, hence the transposes
            lh_mesh = Mesh(Tesselation(lh_faces.T), lh_coords.T)
            rh_mesh = Mesh(Tesselation(rh_faces.T), rh_coords.T)

            # Load curvature maps
            lh_curv_map = fs.read_morph_data(str(surf_dir / 'lh.curv'))
            rh_curv_map = fs.read_morph_data(str(surf_dir / 'rh.curv'))

            mesh_data = {
                'lh_mesh': lh_mesh,
                'rh_mesh': rh_mesh,
                'lh_curv_map': lh_curv_map,
                'rh_curv_map': rh_curv_map
            }
            self.loaded_mesh_cache[cache_key] = mesh_data
            if DEBUG:
                print(f"✓ Loaded and cached mesh: LH={lh_coords.shape[0]}, RH={rh_coords.shape[0]} vertices")
            return mesh_data

        except Exception as e:
            # Best-effort: report and let the caller decide what to do with None.
            print(f"❌ Error loading mesh: {e}")
            return None

    def load_cf_results(self, subject_id, task, source_hemi, dataset_key):
        """
        Load (or fetch from cache) the CF results for one subject / task / source hemisphere.

        The model file is looked up in the module-level ``cf_models_info`` list
        (entries with 'subject', 'task', 'source_hemi' and 'file' keys), which must
        be defined before this method is called.

        Returns:
            tuple: (results_lh, results_rh), or (None, None) when no matching
            model is registered in ``cf_models_info``.
        """
        cache_key = f"{dataset_key}_sub-{subject_id}_{task}_{source_hemi}"

        if cache_key in self.loaded_results_cache:
            if DEBUG:
                print(f"Using cached CF results for {cache_key}")
            return self.loaded_results_cache[cache_key]

        # NOTE(review): dataset_key is part of the cache key but is not used to
        # filter cf_models_info, which is presumably built for the active dataset.
        matching_files = [m for m in cf_models_info
                          if m['subject'] == subject_id
                          and m['task'] == task
                          and m['source_hemi'] == source_hemi]

        if not matching_files:
            print(f"❌ No CF model found for sub-{subject_id}, task={task}, source={source_hemi}")
            return None, None

        model_file = matching_files[0]['file']
        if DEBUG:
            print(f"Loading CF results from: {model_file}")

        # The per-hemisphere results are stored as pickled object arrays, so
        # allow_pickle is required: only load model files from a trusted source.
        # The context manager closes the archive handle that np.load keeps open.
        with np.load(model_file, allow_pickle=True) as results:
            results_lh = results['lh_results'].item()
            results_rh = results['rh_results'].item()

        self.loaded_results_cache[cache_key] = (results_lh, results_rh)
        return results_lh, results_rh

    def load_prf_results(self, subject_id, dataset_key):
        """
        Load pRF parameters for one subject from the averaged both-hemispheres CSV.

        Column names are matched case-insensitively. Each output list holds
        ``(vertex_index, value)`` tuples; rows with an unparsable vertex index or an
        unknown hemisphere label are skipped, and NaN values are not stored.

        Returns:
            tuple: (lh_results, rh_results) dicts with keys 'polar_angle',
            'eccentricity', 'size', 'baseline', 'amplitude', 'hrf_delay',
            'hrf_dispersion', 'r2'. All lists are empty when the file or a
            required column (hemisphere / vertex) is missing.
        """
        cache_key = f"prf_{dataset_key}_sub-{subject_id}"

        # The cache lookup is disabled, so the CSV is re-read on every call
        # (results are still stored in loaded_results_cache).
        #if cache_key in self.loaded_results_cache:
        #    print(f"Using cached pRF results for {cache_key}")
        #    return self.loaded_results_cache[cache_key]

        dataset_root = self.notebook_dir / 'data' / dataset_key
        prf_file = dataset_root / 'derivatives' / 'pRF-maps' / f'pRF_parameters_both_hemispheres_averaged_sub-{subject_id}_ses-01.csv'

        prf_keys = ('polar_angle', 'eccentricity', 'size', 'baseline',
                    'amplitude', 'hrf_delay', 'hrf_dispersion', 'r2')

        def _empty_results():
            return {key: [] for key in prf_keys}

        def _cache_and_return(lh_results, rh_results):
            self.loaded_results_cache[cache_key] = (lh_results, rh_results)
            return lh_results, rh_results

        if not prf_file.exists():
            print(f"⚠️ pRF file not found: {prf_file}")
            return _cache_and_return(_empty_results(), _empty_results())

        df = pd.read_csv(prf_file, sep=',', engine='python')

        cols_lower = {c.lower(): c for c in df.columns}
        hemi_col = cols_lower.get('hemisphere') or cols_lower.get('hemi')
        vertex_col = cols_lower.get('vertex_index') or cols_lower.get('vertex')
        # Output key -> CSV column actually present in the file (None when absent).
        # The 'size' output is read from the 'sd' column.
        value_cols = {
            'polar_angle': cols_lower.get('polar_angle') or cols_lower.get('polar'),
            'eccentricity': cols_lower.get('eccentricity'),
            'size': cols_lower.get('sd'),
            'r2': cols_lower.get('r2'),
            'baseline': cols_lower.get('baseline'),
            'amplitude': cols_lower.get('amplitude'),
            'hrf_delay': cols_lower.get('hrf_delay'),
            'hrf_dispersion': cols_lower.get('hrf_dispersion'),
        }
        present_cols = {key: col for key, col in value_cols.items() if col}

        if hemi_col is None:
            print(f"⚠️ pRF CSV missing 'hemisphere' column: {prf_file.name}")
            return _cache_and_return(_empty_results(), _empty_results())

        if vertex_col is None:
            # Without a vertex column no value can be placed on the mesh.
            print(f"⚠️ pRF CSV missing 'vertex_index' column: {prf_file.name}")
            return _cache_and_return(_empty_results(), _empty_results())

        lh_results = _empty_results()
        rh_results = _empty_results()

        for _, row in df.iterrows():
            try:
                vidx = int(row[vertex_col])
            except (TypeError, ValueError, OverflowError):
                # NaN / inf / non-numeric vertex index
                continue

            hemi = str(row[hemi_col]).lower()
            if hemi.startswith('lh'):
                target_dict = lh_results
            elif hemi.startswith('rh'):
                target_dict = rh_results
            else:
                continue

            for key, col in present_cols.items():
                value = row[col]
                if not pd.isna(value):
                    target_dict[key].append((vidx, value))

        return _cache_and_return(lh_results, rh_results)

    @staticmethod
    def _prf_pairs_to_arrays(results_dict, n_vertices, cf_property):
        """Scatter one hemisphere's (vertex, value) pairs into NaN-filled arrays of length n_vertices."""
        data = np.full(n_vertices, np.nan)
        r2 = np.full(n_vertices, np.nan)
        for vidx, val in results_dict[cf_property]:
            if 0 <= vidx < n_vertices:
                # 2.355 ~= 2*sqrt(2*ln 2), the Gaussian sigma-to-FWHM factor; 'size'
                # comes from the 'sd' column, so this presumably reports FWHM.
                # NOTE(review): get_colormap_and_label labels 'size' as sigma.
                data[vidx] = abs(val) * 2.355 if cf_property == 'size' else val
        for vidx, val in results_dict['r2']:
            if 0 <= vidx < n_vertices:
                r2[vidx] = val
        return data, r2

    def prepare_prf_data_arrays(self, lh_results_dict, rh_results_dict, lh_mesh_shape, rh_mesh_shape, cf_property):
        """
        Convert pRF (vertex, value) lists into full per-vertex numpy arrays for plotting.

        ``*_mesh_shape`` is the mesh coordinate shape (3, n_vertices); vertices without
        a value are NaN and out-of-range vertex indices are ignored.

        Returns:
            tuple: (lh_data, rh_data, lh_r2, rh_r2)
        """
        lh_data, lh_r2 = self._prf_pairs_to_arrays(lh_results_dict, lh_mesh_shape[1], cf_property)
        rh_data, rh_r2 = self._prf_pairs_to_arrays(rh_results_dict, rh_mesh_shape[1], cf_property)
        return lh_data, rh_data, lh_r2, rh_r2

    def apply_r2_mask(self, lh_data, rh_data, lh_r2, rh_r2, r2_threshold):
        """
        Set data values to NaN where R2 is NaN or below ``r2_threshold``.

        The input arrays are not modified.

        Returns:
            tuple: (lh_data_masked, rh_data_masked, lh_mask, rh_mask), where the
            masks are True at the vertices that were removed.
        """
        lh_mask = np.isnan(lh_r2) | (lh_r2 < r2_threshold)
        rh_mask = np.isnan(rh_r2) | (rh_r2 < r2_threshold)

        lh_data_masked = lh_data.copy()
        rh_data_masked = rh_data.copy()
        lh_data_masked[lh_mask] = np.nan
        rh_data_masked[rh_mask] = np.nan

        return lh_data_masked, rh_data_masked, lh_mask, rh_mask

    def get_property_specific_config(self, cf_property, map_type):
        """
        Return the default colour-range config {'adaptive', 'vmin', 'vmax'} for a property.

        Unknown properties get {'adaptive': True, 'vmin': 0, 'vmax': 1}.
        """
        property_config = {
            'r2': {'adaptive': True, 'vmin': 0, 'vmax': 1},
            'polar': {'adaptive': True, 'vmin': -np.pi, 'vmax': np.pi},
            'eccentricity': {'adaptive': False, 'vmin': 0.5, 'vmax': 6.5},
            'cf_size': {'adaptive': False, 'vmin': 0.5, 'vmax': 5.0},
            'polar_angle': {'adaptive': True, 'vmin': -np.pi, 'vmax': np.pi},
            'size': {'adaptive': False, 'vmin': 0, 'vmax': 5.0},
            'baseline': {'adaptive': True, 'vmin': 0, 'vmax': 1},
            'amplitude': {'adaptive': True, 'vmin': 0, 'vmax': 1},
            'hrf_delay': {'adaptive': True, 'vmin': 0, 'vmax': 1},
            'hrf_dispersion': {'adaptive': True, 'vmin': 0, 'vmax': 1}
        }

        # pRF eccentricity uses a different fixed range than CF eccentricity
        if map_type == 'pRF' and cf_property == 'eccentricity':
            return {'adaptive': False, 'vmin': 0, 'vmax': 5}

        return property_config.get(cf_property, {'adaptive': True, 'vmin': 0, 'vmax': 1})

    def get_colormap_and_label(self, cf_property):
        """
        Select the colormap and colorbar label for a property.

        Eccentricity and polar angle use the custom module-level palettes;
        everything else uses viridis.

        Returns:
            tuple: (matplotlib colormap, label string)
        """
        if cf_property == 'eccentricity':
            grad_cmap = eccen_colors['matplotlib_cmap']
            cbar_label = r'Eccentricity $r$ (deg)'
        elif cf_property in ('polar', 'polar_angle'):
            grad_cmap = polar_colors['matplotlib_cmap']
            cbar_label = r'Polar angle $\theta$ (rad)'
        elif cf_property == 'r2':
            grad_cmap = plt.colormaps['viridis']
            cbar_label = r'$R^2$'
        elif cf_property == 'cf_size':
            grad_cmap = plt.colormaps['viridis']
            cbar_label = r'CF size $\sigma$ (mm)'
        elif cf_property == 'size':
            grad_cmap = plt.colormaps['viridis']
            cbar_label = r'pRF size $\sigma$ (deg)'
        elif cf_property in ['baseline', 'amplitude', 'hrf_delay', 'hrf_dispersion']:
            grad_cmap = plt.colormaps['viridis']
            labels = {
                'baseline': 'Baseline (a.u.)',
                'amplitude': 'Amplitude (a.u.)',
                'hrf_delay': r'HRF delay ($s$)',
                'hrf_dispersion': 'HRF dispersion'
            }
            cbar_label = labels.get(cf_property, 'Value')
        else:
            grad_cmap = plt.colormaps['viridis']
            cbar_label = 'Value'
        return grad_cmap, cbar_label

    def apply_coordinate_transformations(self, lh_mesh, rh_mesh, hemi_separation, angle=0):
        """
        Rotate both hemispheres about z and shift the right one along x.

        Returns:
            tuple: (lh_mesh_shifted, rh_mesh_shifted), new Mesh objects; the
            input meshes are not modified.
        """
        lh_coords_plot = lh_mesh.coordinates
        rh_coords_plot = rh_mesh.coordinates
        lh_faces_plot = lh_mesh.tess.faces
        rh_faces_plot = rh_mesh.tess.faces

        # NOTE(review): LH is rotated by -angle but RH by 2*angle; a mirror-symmetric
        # layout would use +angle. Irrelevant while callers pass angle=0 — confirm intent.
        lh_coords_medial = rotate_coords(lh_coords_plot, axis='z', angle_degrees=-angle)
        rh_coords_medial = rotate_coords(rh_coords_plot, axis='z', angle_degrees=angle * 2)
        rh_coords_medial[0, :] += hemi_separation

        lh_mesh_shifted = Mesh(Tesselation(lh_faces_plot), lh_coords_medial)
        rh_mesh_shifted = Mesh(Tesselation(rh_faces_plot), rh_coords_medial)

        return lh_mesh_shifted, rh_mesh_shifted

    def prepare_underlay_and_masks(self, lh_curv_map, rh_curv_map, lh_data, rh_data):
        """
        Build the curvature underlays and the valid-data masks for plotting.

        Returns:
            tuple: (lh_strips, rh_strips, lh_mask, rh_mask). The strips are float
            copies of the curvature with positive values set to NaN; the masks are
            True where the data is not NaN.
        """
        lh_strips_plot = lh_curv_map.astype(float)
        lh_strips_plot[lh_strips_plot > 0] = np.nan
        rh_strips_plot = rh_curv_map.astype(float)
        rh_strips_plot[rh_strips_plot > 0] = np.nan

        lh_mask_plot = ~np.isnan(lh_data)
        rh_mask_plot = ~np.isnan(rh_data)

        return lh_strips_plot, rh_strips_plot, lh_mask_plot, rh_mask_plot

    def get_or_create_transformed_meshes(self, lh_mesh, rh_mesh, hemi_separation):
        """
        Return the hemisphere meshes separated by ``hemi_separation`` (no rotation).

        The transformation is recomputed on every call; no cache is consulted.
        Only loading of the base meshes is cached (see load_meshes_and_curv).
        """
        return self.apply_coordinate_transformations(lh_mesh, rh_mesh, hemi_separation, angle=0)

def plot_and_save_brains(lh_map, rh_map, colormap, mesh_lh, mesh_rh, strips_lh, strips_rh, mask_lh, mask_rh, view, vmin=None, vmax=None, cbar_label='Value', cf_property='r2'):
    """
    Plot both hemispheres in an ipyvolume figure, set the camera view, and show a colorbar.

    Nothing is written to disk: the 3-D figure is displayed with ``ipv.show()`` and
    the colorbar figure with ``plt.show()``.

    Parameters:
    - lh_map / rh_map: per-vertex values for the left / right hemisphere
    - colormap: matplotlib colormap used for the surface and the linear colorbar
    - mesh_lh / mesh_rh: mesh objects; their ``.coordinates`` are averaged to centre the camera
    - strips_lh / strips_rh: underlay data, drawn with a gray colormap over [-5, 0]
    - mask_lh / mask_rh: vertex masks passed to ``ny.cortex_plot``
    - view: 'ventral', 'dorsal', or an (azimuth, elevation, distance) tuple/list,
      with azimuth and elevation in degrees
    - vmin / vmax: colormap limits (optional); when None, the linear colorbar uses
      the finite range of the data (0..1 if there is no finite value)
    - cbar_label: label for the linear colorbar (default: 'Value')
    - cf_property: 'eccentricity' -> concentric-ring legend, 'polar' / 'polar_angle'
      -> pie legend, anything else -> horizontal colorbar

    Raises:
    - ValueError: if ``view`` is not recognised
    """
    if isinstance(view, (tuple, list)) and len(view) == 3:
        azim, elev, dist = view
    elif view == 'ventral':
        azim, elev, dist = -172, -8, 180
    elif view == 'dorsal':
        azim, elev, dist = -6.13, 31.34, 46.26
    else:
        raise ValueError("view must be 'ventral', 'dorsal', or a tuple (azim, elev, dist)")

    fig = ipv.figure(width=640, height=480)

    # Right hemisphere is drawn first, then the left one.
    for mesh, data, strips, mask in ((mesh_rh, rh_map, strips_rh, mask_rh),
                                     (mesh_lh, lh_map, strips_lh, mask_lh)):
        ny.cortex_plot(mesh, surface='inflated', color=data, cmap=colormap,
            underlay=strips, underlay_cmap='gray', underlay_vmin=-5, underlay_vmax=0.0, mask=mask,
            vmin=vmin, vmax=vmax,
            figure=fig)

    # Centre the camera on the mean of all vertex coordinates (coordinates are (3, N)).
    all_coords = np.concatenate([mesh_lh.coordinates, mesh_rh.coordinates], axis=1)
    center = np.mean(all_coords, axis=1)
    fig.camera.center = center

    # Place the camera on a sphere of radius `dist` around the centre.
    elev_rad = np.radians(elev)
    az_rad = np.radians(azim)
    unit = np.array([
        np.cos(elev_rad) * np.cos(az_rad),
        np.cos(elev_rad) * np.sin(az_rad),
        np.sin(elev_rad)
    ])
    fig.camera.position = tuple(center + dist * unit)
    ipv.show()

    import matplotlib.pyplot as plt
    from matplotlib import cm
    from mpl_toolkits.axes_grid1.inset_locator import inset_axes
    from matplotlib.patches import Wedge

    if cf_property == 'eccentricity':
        # Concentric rings, one per palette colour, from the centre outwards
        fig_cb, ax_main = plt.subplots(figsize=(3, 3))
        ax_main.set_aspect('equal')
        ax_main.set_xlim(-1.5, 1.5)
        ax_main.set_ylim(-1.5, 1.5)
        ax_main.set_axis_off()
        ax_main.text(0.5, -0.05, r'$\mathit{r}\ (\mathrm{deg})$', ha='center', va='top',
                     fontsize=14, transform=ax_main.transAxes)

        num_ecc_colors = len(eccen_colors["hex"])
        for i, color in enumerate(eccen_colors["hex"]):
            inner_r = i / num_ecc_colors
            outer_r = (i + 1) / num_ecc_colors
            ring = Wedge((0, 0), outer_r, 0, 360, width=outer_r - inner_r, color=color)
            ax_main.add_patch(ring)

        plt.tight_layout()
        plt.show()

    elif cf_property in ('polar', 'polar_angle'):
        # Equal-sized pie wedges in polar-palette order, starting at 180 deg, clockwise
        fig_cb, ax_main = plt.subplots(figsize=(3, 3))
        ax_main.set_aspect('equal')
        ax_main.set_axis_off()
        ax_main.pie([1] * len(polar_colors["hex"]), colors=polar_colors["hex"],
                    startangle=180, counterclock=False)

        ax_main.text(0.5, -0.05, r'$\theta\ (\mathrm{rad})$', ha='center', va='top',
                     fontsize=14, transform=ax_main.transAxes)

        plt.tight_layout()
        plt.show()

    else:
        # Limits are only needed here. Use the finite data range, so that empty or
        # all-NaN maps do not raise or produce NaN limits.
        values = np.concatenate([np.ravel(lh_map), np.ravel(rh_map)]).astype(float)
        finite = values[np.isfinite(values)]
        vmin_val = vmin if vmin is not None else (finite.min() if finite.size else 0.0)
        vmax_val = vmax if vmax is not None else (finite.max() if finite.size else 1.0)

        fig_cb, ax_main = plt.subplots(figsize=(8, 2))
        ax_main.set_axis_off()

        cbar_inset = inset_axes(ax_main, width="70%", height="30%", loc="center", borderpad=0)
        norm = plt.Normalize(vmin=vmin_val, vmax=vmax_val)

        cb = plt.colorbar(cm.ScalarMappable(norm=norm, cmap=colormap),
                          cax=cbar_inset, orientation='horizontal')
        cb.set_label(cbar_label, fontsize=14)
        cb.ax.tick_params(labelsize=12)

        plt.tight_layout()
        plt.show()

# Main update function (reconstructed with state management)
def update_plot_refactored_with_state(state_manager, dataset_key, subject_id, task, source_hemi, map_type,
                           cf_property, r2_threshold, use_adaptive_range, vmin, vmax,
                           hemi_separation):
    """Load, mask and render a CF or pRF map for one subject on both hemispheres.

    Parameters
    ----------
    state_manager : CFVisualizationState
        Supplies paths, cached data and meshes, colormaps and defaults. On a
        successful plot its ``current_plot_data`` attribute is overwritten with a
        dict describing what was drawn.
    dataset_key : str
        Key into ``DATASET_CONFIGS``.
    subject_id : str
        Subject identifier as found in the CF file names (e.g. ``'01'``). It is
        converted with ``int`` and zero-padded to two digits to build the
        FreeSurfer subject name.
    task, source_hemi : str
        Only used for CF maps: the task name and the source hemisphere.
    map_type : str
        ``'CF'`` or ``'pRF'``; anything else prints a message and returns.
    cf_property : str
        Parameter to display. For CF it must be one of eccentricity / polar /
        cf_size / r2; for pRF it is passed through to
        ``state_manager.prepare_prf_data_arrays``.
    r2_threshold : float
        Passed to ``state_manager.apply_r2_mask``.
    use_adaptive_range : bool
        If True, the colour range is the 2nd-98th percentile of the non-NaN masked
        values (falling back to the property defaults when none remain).
    vmin, vmax : float
        Fixed colour range used when ``use_adaptive_range`` is False.
    hemi_separation : float
        Passed to ``state_manager.get_or_create_transformed_meshes``.

    Returns
    -------
    None. Problems are reported with ``print`` and cause an early return.
    """
    # --- 1. Resolve the FreeSurfer subject path ---
    # get_paths_and_configs also returns the base-data and CF-model paths; only the
    # FreeSurfer subjects directory is needed here.
    _, _, fs_subjects_dir = state_manager.get_paths_and_configs(dataset_key)
    current_sub_id_formatted = f"{int(subject_id):02d}"
    fs_subject_name = DATASET_CONFIGS[dataset_key]['fs_subject_format'].format(
        subject_id=current_sub_id_formatted
    )
    fs_subject_path = fs_subjects_dir / fs_subject_name
    mesh_cache_key = f"{dataset_key}_{current_sub_id_formatted}"

    # Filled by whichever step needs the meshes first, so they are fetched only once.
    mesh_data = None

    # --- 2. Load data (CF or pRF) ---
    if map_type == 'CF':
        # Tasks that actually have a CF model for this subject and source hemisphere.
        tasks_for_source = sorted({
            info['task'] for info in cf_models_info
            if info['subject'] == subject_id and info['source_hemi'] == source_hemi
        })
        if task not in tasks_for_source:
            print(f"‚ö†Ô∏è Task '{task}' not available for sub-{subject_id}, source-{source_hemi}. Available: {tasks_for_source}")
            return  # Exit early if invalid task
        cf_results_lh, cf_results_rh = state_manager.load_cf_results(subject_id, task, source_hemi, dataset_key)
        if cf_results_lh is None or cf_results_rh is None:
            return  # Error already printed in load function

        # Map the widget property name to the key stored in the CF results.
        prop_map = {
            'eccentricity': 'inherited_eccen', 'polar': 'inherited_polar',
            'cf_size': 'cf_size', 'r2': 'r2'
        }
        if cf_property not in prop_map:
            print(f"Unknown CF property: {cf_property}")
            return

        lh_data_raw = cf_results_lh[prop_map[cf_property]]
        rh_data_raw = cf_results_rh[prop_map[cf_property]]
        # Without an 'r2' entry use an all-NaN array of the same shape; dtype=float
        # so NaN is representable even if the property data is integer-typed.
        lh_r2_raw = cf_results_lh.get('r2', np.full_like(lh_data_raw, np.nan, dtype=float))
        rh_r2_raw = cf_results_rh.get('r2', np.full_like(rh_data_raw, np.nan, dtype=float))

    elif map_type == 'pRF':
        prf_results_lh, prf_results_rh = state_manager.load_prf_results(subject_id, dataset_key)
        if prf_results_lh is None or prf_results_rh is None:
            return  # Error already printed in load function

        # The pRF arrays are sized from the mesh coordinates, so load the meshes now.
        mesh_data = state_manager.load_meshes_and_curv(fs_subject_path, mesh_cache_key)
        if mesh_data is None:
            return  # Error already printed in load function

        lh_data_raw, rh_data_raw, lh_r2_raw, rh_r2_raw = state_manager.prepare_prf_data_arrays(
            prf_results_lh, prf_results_rh,
            mesh_data['lh_mesh'].coordinates.shape,
            mesh_data['rh_mesh'].coordinates.shape,
            cf_property,
        )
    else:
        print(f"Unknown map_type: {map_type}")
        return

    # --- 3. Apply the R2 mask ---
    lh_data_masked, rh_data_masked, _, _ = state_manager.apply_r2_mask(
        lh_data_raw, rh_data_raw, lh_r2_raw, rh_r2_raw, r2_threshold
    )

    # --- 4. Load meshes and curvature (cached by state_manager) if not done yet ---
    if mesh_data is None:
        mesh_data = state_manager.load_meshes_and_curv(fs_subject_path, mesh_cache_key)
        if mesh_data is None:
            return  # Error already printed in load function
    lh_mesh, rh_mesh = mesh_data['lh_mesh'], mesh_data['rh_mesh']
    lh_curv_map, rh_curv_map = mesh_data['lh_curv_map'], mesh_data['rh_curv_map']

    # --- 5. Colour range: robust percentiles of the non-NaN values, or fixed ---
    if use_adaptive_range:
        valid_data = np.concatenate([
            lh_data_masked[~np.isnan(lh_data_masked)],
            rh_data_masked[~np.isnan(rh_data_masked)],
        ])
        if valid_data.size > 0:
            vmin_calc = np.percentile(valid_data, 2)
            vmax_calc = np.percentile(valid_data, 98)
        else:
            # Nothing survived the mask: fall back to the property defaults.
            fallback_cfg = state_manager.get_property_specific_config(cf_property, map_type)
            vmin_calc = fallback_cfg['vmin']
            vmax_calc = fallback_cfg['vmax']
    else:
        vmin_calc = vmin
        vmax_calc = vmax

    # --- 6. Prepare underlay and masks for plotting ---
    lh_strips_plot, rh_strips_plot, lh_mask_plot, rh_mask_plot = state_manager.prepare_underlay_and_masks(
        lh_curv_map, rh_curv_map, lh_data_masked, rh_data_masked
    )

    # --- 7. Position the hemispheres according to the requested separation ---
    lh_mesh_final, rh_mesh_final = state_manager.get_or_create_transformed_meshes(
        lh_mesh, rh_mesh, hemi_separation
    )

    # --- 8. Colormap and colorbar label for the property ---
    grad_cmap, cbar_label = state_manager.get_colormap_and_label(cf_property)

    # --- 9. Render and record what was drawn ---
    try:
        # NOTE(review): (-122, -27, 80) is presumably the camera view
        # (azimuth, elevation, distance); confirm against plot_and_save_brains.
        plot_and_save_brains(
            lh_data_masked, rh_data_masked, grad_cmap,
            lh_mesh_final, rh_mesh_final,
            lh_strips_plot, rh_strips_plot,
            lh_mask_plot, rh_mask_plot,
            (-122, -27, 80), vmin=vmin_calc, vmax=vmax_calc,
            cbar_label=cbar_label, cf_property=cf_property
        )
        state_manager.current_plot_data = {
            'lh_data': lh_data_masked, 'rh_data': rh_data_masked,
            'lh_mesh': lh_mesh_final, 'rh_mesh': rh_mesh_final,
            'colormap': grad_cmap, 'vmin': vmin_calc, 'vmax': vmax_calc,
            'cbar_label': cbar_label, 'subject_id': subject_id, 'task': task,
            'source_hemi': source_hemi, 'cf_property': cf_property,
            'r2_threshold': r2_threshold, 'map_type': map_type
        }
    except Exception as e:
        print(f"‚ùå Error plotting brains: {e}")

# Widget setup and main loop (cell 3 equivalent)
# Per-dataset configuration. Keys of each entry:
#   name              -- label shown in the dataset dropdown
#   task_pattern      -- regex matched against CF result file names; its six groups
#                        are (subject, session, task, source hemi, min ecc, max ecc)
#   fs_subject_format -- FreeSurfer subject folder name, formatted with subject_id
#   has_prf           -- whether the pRF map type is offered for this dataset
# iCRTX7T is listed first so it is the first entry of the dataset dropdown.
DATASET_CONFIGS = {
    'iCRTX7T': {
        'name': 'iCORTEX (7T)',
        # Task names in this dataset may contain hyphens, hence [\w-]+.
        'task_pattern': r'cf_results_sub-(\d+)_(ses-\d+)_([\w-]+)_(lh|rh)-source_ecc([\d.]+)-([\d.]+)\.npz',
        'fs_subject_format': 'sub-{subject_id}_ses-01_iso',
        'has_prf': True,
    },
    'LPP7T': {
        'name': 'Le Petit Prince (7T)',
        'task_pattern': r'cf_results_sub-(\d+)_(ses-\d+)_(\w+)_(lh|rh)-source_ecc([\d.]+)-([\d.]+)\.npz',
        'fs_subject_format': 'sub-{subject_id}_ses-01_iso',
        'has_prf': False,
    },
    'CB3T': {
        'name': 'Congenital Blindness (3T)',
        'task_pattern': r'cf_results_sub-(\d+)_(ses-\d+)_(\w+)_(lh|rh)-source_ecc([\d.]+)-([\d.]+)\.npz',
        'fs_subject_format': 'sub-{subject_id}_ses-01_iso',
        'has_prf': False,
    },
}

# Start on the iCRTX7T dataset; all data lives under <notebook dir>/data/<dataset key>.
current_dataset_key = 'iCRTX7T'
notebook_dir = Path().resolve()  # Access notebook dir
base_data_path = notebook_dir / 'data' / current_dataset_key
cf_models_dir, fs_subjects_dir = (
    base_data_path / 'derivatives' / 'cf-models',
    base_data_path / 'fs_subjects',
)

if not base_data_path.exists():
    # Tell the user which layout is expected instead of failing later on.
    print(f"‚ö†Ô∏è No data found for {DATASET_CONFIGS[current_dataset_key]['name']}")
    print(f"Expected path: {base_data_path}")
    for layout_line in (
        "\nPlease ensure data is organized as:",
        "  data/",
        "    iCRTX7T/",
        "      derivatives/cf-models/",
        "      fs_subjects/",
        "    LPP7T/",
        "      derivatives/cf-models/",
        "      fs_subjects/",
        "    CB3T/",
        "      derivatives/cf-models/",
        "      fs_subjects/",
    ):
        print(layout_line)

# Scan available CF model files
def scan_cf_models(cf_models_dir, task_pattern):
    """Index the CF result files found in a directory.

    Parameters
    ----------
    cf_models_dir : str or Path
        Directory holding ``*.npz`` CF result files. A missing directory simply
        yields empty results.
    task_pattern : str or compiled regex
        Pattern matched (from the start of the name, via ``match``) against each
        file name. It must define six groups: subject, session, task, source
        hemisphere, min eccentricity and max eccentricity.

    Returns
    -------
    subjects : list of str
        Sorted unique subject IDs.
    tasks : list of str
        Sorted unique task names.
    models : list of dict
        One entry per matching file, ordered by file path, with keys ``file``
        (str path), ``subject``, ``session``, ``task``, ``source_hemi``,
        ``min_ecc`` and ``max_ecc`` (the last two as floats).
    """
    cf_models_dir = Path(cf_models_dir)
    # Compile once; re.compile returns an already-compiled pattern unchanged.
    matcher = re.compile(task_pattern)
    # Escape the directory part so characters such as '[' in the path are taken
    # literally, and sort because glob's order is filesystem-dependent.
    pattern = str(Path(glob.escape(str(cf_models_dir))) / '*.npz')
    files = sorted(glob.glob(pattern))

    available_models = []
    subjects = set()
    tasks = set()

    for file_path in files:
        match = matcher.match(Path(file_path).name)
        if match is None:
            continue
        subject_id, session_id, task, source_hemi, min_ecc, max_ecc = match.groups()
        subjects.add(subject_id)
        tasks.add(task)
        available_models.append({
            'file': file_path,
            'subject': subject_id,
            'session': session_id,
            'task': task,
            'source_hemi': source_hemi,
            'min_ecc': float(min_ecc),
            'max_ecc': float(max_ecc)
        })

    return sorted(subjects), sorted(tasks), available_models

# Build the initial subject/task catalogue for the default dataset.
(available_subjects,
 available_tasks,
 cf_models_info) = scan_cf_models(cf_models_dir,
                                  DATASET_CONFIGS[current_dataset_key]['task_pattern'])

# Shared state/cache object used by every plotting callback.
state_manager = CFVisualizationState()

# Re-entrancy guard: callbacks that change widgets programmatically set this to
# True so that observers checking it skip their work.
updating = False

# Widget functions (callbacks)
def update_dataset_options(change):
    """Switch the viewer to another dataset (observer of ``dataset_widget``).

    Rebuilds the data paths, rescans the CF models with the dataset's file-name
    pattern, and refreshes the subject, task and map-type dropdowns. The
    module-level ``available_subjects``, ``available_tasks`` and
    ``cf_models_info`` lists are updated in place, so other references to those
    list objects see the new contents.

    If the dataset folder does not exist, a warning is printed and the dropdowns
    are left unchanged.
    """
    global available_subjects, available_tasks, cf_models_info, base_data_path, cf_models_dir, fs_subjects_dir, current_dataset_key

    new_dataset_key = change['new']
    current_dataset_key = new_dataset_key

    # Update base path to new dataset folder
    base_data_path = notebook_dir / 'data' / new_dataset_key

    if not base_data_path.exists():
        print(f"‚ö†Ô∏è No data found for {DATASET_CONFIGS[new_dataset_key]['name']}")
        print(f"Expected path: {base_data_path}")
        return

    cf_models_dir = base_data_path / 'derivatives' / 'cf-models'
    fs_subjects_dir = base_data_path / 'fs_subjects'

    # Rescan with the new dataset's file-name pattern.
    task_pattern = DATASET_CONFIGS[new_dataset_key]['task_pattern']
    available_subjects_new, available_tasks_new, cf_models_info_new = scan_cf_models(cf_models_dir, task_pattern)

    # Slice assignment keeps the same list objects alive for other references.
    available_subjects[:] = available_subjects_new
    available_tasks[:] = available_tasks_new
    cf_models_info[:] = cf_models_info_new

    # With no matches, select nothing (None): a value that is not among the
    # options makes ipywidgets raise a TraitError.
    subject_widget.options = available_subjects
    subject_widget.value = available_subjects[0] if available_subjects else None
    task_widget.options = available_tasks
    task_widget.value = available_tasks[0] if available_tasks else None

    # Offer pRF only for datasets flagged as having it; pRF becomes the default.
    has_prf = DATASET_CONFIGS[new_dataset_key].get('has_prf', True)
    if has_prf:
        map_type_widget.options = ['CF', 'pRF']
        map_type_widget.value = 'pRF'  # Default to pRF
    else:
        map_type_widget.options = ['CF']
        map_type_widget.value = 'CF'
    if DEBUG:
        print(f"‚úì Switched to {DATASET_CONFIGS[new_dataset_key]['name']}")
        print(f"  Found {len(available_subjects)} subjects, {len(available_tasks)} tasks")
        if has_prf:
            print("  pRF data available")

def update_map_type_options(change):
    """Reconfigure parameter choices and widget visibility for a new map type.

    Observer of ``map_type_widget``. For 'CF' the task and source-hemisphere
    dropdowns are shown; for 'pRF' they are hidden, since only the CF loading path
    uses them. The property dropdown is reset to 'eccentricity' and the range
    widgets are reset through ``update_widget_defaults``. The module-level
    ``updating`` flag is held True meanwhile, so callbacks that check it ignore
    these programmatic widget changes.
    """
    global updating
    if updating:
        return
    updating = True

    try:
        # map type -> (property choices, task/source display mode, dropdown label)
        settings = {
            'CF': (['eccentricity', 'polar', 'cf_size', 'r2'],
                   'flex', 'CF parameter:'),
            'pRF': (['eccentricity', 'polar_angle', 'size', 'baseline', 'amplitude',
                     'hrf_delay', 'hrf_dispersion', 'r2'],
                    'none', 'pRF parameter:'),
        }.get(change['new'])

        if settings is not None:
            choices, display_mode, label = settings
            cf_property_widget.options = choices
            cf_property_widget.value = 'eccentricity'
            task_widget.layout.display = display_mode
            source_hemi_widget.layout.display = display_mode
            cf_property_widget.description = label

        # Range defaults follow whatever property is now selected.
        update_widget_defaults(cf_property_widget.value)
    finally:
        updating = False

def update_widget_defaults(cf_property):
    """Reset the adaptive-range, min and max widgets to the defaults for a property.

    The defaults come from ``state_manager.get_property_specific_config`` for the
    map type currently selected in ``map_type_widget``; that config supplies the
    'adaptive', 'vmin' and 'vmax' entries.
    Note: callers already hold the module-level ``updating`` flag, so this function
    neither checks nor modifies it.
    """
    defaults = state_manager.get_property_specific_config(cf_property, map_type_widget.value)
    for target, key in ((adaptive_range_widget, 'adaptive'),
                        (vmin_widget, 'vmin'),
                        (vmax_widget, 'vmax')):
        target.value = defaults[key]

# Removed update_widget_visibility as polar_colormap_widget is gone

# Widget creation
# Dataset dropdown: (label, key) pairs in DATASET_CONFIGS order, iCRTX7T selected.
dataset_widget = Dropdown(
    options=[(config['name'], key) for key, config in DATASET_CONFIGS.items()],
    value='iCRTX7T',
    description='Dataset:'
)
# Subject/task choices come from the initial CF scan. With no matches, start with
# no selection (None): a value that is not among the options raises a TraitError.
subject_widget = Dropdown(options=available_subjects,
                          value=available_subjects[0] if available_subjects else None,
                          description='Subject:')
task_widget = Dropdown(options=available_tasks,
                       value=available_tasks[0] if available_tasks else None,
                       description='Task:')
source_hemi_widget = Dropdown(options=['lh', 'rh'], value='lh', description='Source:')

map_type_widget = Dropdown(options=['CF', 'pRF'], value='pRF', description='Map type:')

# Starts with the CF property list; update_map_type_options (called below)
# replaces it with the pRF list because the default map type is 'pRF'.
cf_property_widget = Dropdown(options=['eccentricity', 'polar', 'cf_size', 'r2'],
                              value='eccentricity', description='Property:')
r2_threshold_widget = FloatSlider(value=0.1, min=0.0, max=1.0, step=0.01, description='R¬≤ threshold:')

# Adaptive range widget - default depends on CF property
adaptive_range_widget = Dropdown(options=[True, False], value=False,
                                 description='Range:')
# Min/max widgets - defaults depend on CF property - initialized for eccentricity
vmin_widget = FloatSlider(value=0.5, min=-10, max=10, step=0.01, description='min:')
vmax_widget = FloatSlider(value=6.5, min=-10, max=10, step=0.01, description='max:')

# Hemisphere separation slider
hemi_separation_widget = FloatSlider(
    value=80,
    min=0,
    max=200,
    step=1,
    description='Hemi gap:',
    tooltip='Distance between left and right hemispheres'
)
# Apply the property list and visibility for the default (pRF) map type.
update_map_type_options({'new': map_type_widget.value})

# Link widget callbacks
dataset_widget.observe(update_dataset_options, names='value')
map_type_widget.observe(update_map_type_options, names='value')

# Protected callback for cf_property changes
def on_cf_property_change(change):
    """Apply the new property's range defaults, then redraw exactly once.

    Observer of ``cf_property_widget``. The module-level ``updating`` flag is held
    True while ``update_widget_defaults`` changes the adaptive-range/min/max
    widgets. ``update_visualization`` (their observer) returns early while the flag
    is set, so those changes cannot cause extra redraws. The single final redraw
    passes ``from_protected_callback=True``, which is what lets it run despite the
    flag. The flag is always released in ``finally``, so an error cannot leave
    every callback blocked.
    """
    global updating
    if updating:
        return

    updating = True
    try:
        update_widget_defaults(change['new'])
        update_visualization(from_protected_callback=True)
    finally:
        updating = False

cf_property_widget.observe(on_cf_property_change, names='value')

# =============================
# CLEAN RENDERING WITH MANUAL CALLBACKS
# =============================

from ipywidgets import Output
from IPython.display import HTML
# CLEAN RENDERING WITH MANUAL CALLBACKS
# Output areas: one for the transient spinner, one for the plot itself.
loading_output = Output()
plot_output = Output()

def show_loading():
    """Replace the contents of ``loading_output`` with a CSS spinner and 'Loading...'."""
    spinner = HTML("""
            <div style="text-align: center; margin: 20px;">
                <div style="width: 40px; height: 40px; border: 4px solid rgba(0,0,0,0.1);
                            border-left-color: #4A90E2; border-radius: 50%;
                            animation: spin 1s linear infinite;"></div>
                <p style="margin-top: 10px; color: #555;">Loading...</p>
            </div>
            <style>@keyframes spin { to { transform: rotate(360deg); } }</style>
        """)
    with loading_output:
        loading_output.clear_output(wait=True)
        display(spinner)

def hide_loading():
    """Remove the spinner by clearing ``loading_output``."""
    with loading_output:
        loading_output.clear_output()

# Unified update function
def update_visualization(*args, from_protected_callback=False):
    """Redraw the brain plot from the current widget values.

    Registered as the ``value`` observer of most control widgets; ``*args``
    absorbs the change notification. Unless ``from_protected_callback`` is True,
    the call does nothing while the module-level ``updating`` flag is set.

    The previous plot is cleared and a spinner is shown while
    ``update_plot_refactored_with_state`` runs inside ``plot_output``. Any
    exception it raises is printed into ``plot_output`` rather than propagated,
    and the spinner is always removed afterwards. With ``DEBUG`` on, timing
    information is printed.
    """
    import time

    global updating
    started = time.time()

    if updating and not from_protected_callback:
        return
    if DEBUG:
        print("\n" + "="*60)
        print("üîÑ UPDATE VISUALIZATION TRIGGERED")
        print("="*60)

    # Snapshot every control value before touching the output areas.
    plot_kwargs = dict(
        dataset_key=dataset_widget.value,
        subject_id=subject_widget.value,
        task=task_widget.value,
        source_hemi=source_hemi_widget.value,
        map_type=map_type_widget.value,
        cf_property=cf_property_widget.value,
        r2_threshold=r2_threshold_widget.value,
        use_adaptive_range=adaptive_range_widget.value,
        vmin=vmin_widget.value,
        vmax=vmax_widget.value,
        hemi_separation=hemi_separation_widget.value,
    )

    with plot_output:
        plot_output.clear_output()
    show_loading()

    try:
        time.sleep(0.05)
        plot_started = time.time()
        with plot_output:
            update_plot_refactored_with_state(state_manager=state_manager, **plot_kwargs)
        if DEBUG:
            print(f"\n‚è±Ô∏è Plot function took: {time.time()-plot_started:.2f}s")
    except Exception as err:
        with plot_output:
            print(f"‚ùå Error: {err}")
    finally:
        hide_loading()
        if DEBUG:
            print(f"‚úÖ Total update time: {time.time()-started:.2f}s")
            print("="*60 + "\n")

# Redraw whenever any of these widgets changes value.
# NOTE: cf_property_widget is NOT in this list because it has its own
# protected callback (on_cf_property_change) that handles widget updates
for widget in (
    dataset_widget, subject_widget, task_widget, source_hemi_widget,
    map_type_widget, r2_threshold_widget,
    adaptive_range_widget, vmin_widget, vmax_widget, hemi_separation_widget,
):
    widget.observe(update_visualization, names='value')

# Initial render; the flag bypasses the `updating` guard in update_visualization.
update_visualization(from_protected_callback=True)

# Final layout: two control columns on top, then the spinner area and the plot.
display(
    VBox(
        [
            HBox(
                [
                    VBox(
                        [
                            dataset_widget,
                            subject_widget,
                            map_type_widget,
                            task_widget,
                            source_hemi_widget,
                            cf_property_widget,
                            adaptive_range_widget,
                        ]
                    ),
                    VBox(
                        [
                            hemi_separation_widget,
                            r2_threshold_widget,
                            vmin_widget,
                            vmax_widget,
                        ]
                    ),
                ]
            ),
            loading_output,
            plot_output,
        ]
    )
)
