
# Whole-brain Connective Field maps

Here you can select a connective field (CF) parameter (e.g., eccentricity or polar angle) and plot it interactively for models referred to left or right V1 (i.e., models whose source region is the left or right V1). The target region for V1 connectivity is the whole brain. Accordingly, CFs referred to V1 can project to the contralateral hemisphere, potentially revealing retinotopically organized connectivity with contralateral V1 as well as with other cortical areas.

In [1]:
# Safe imports with error handling
import sys
import warnings
warnings.filterwarnings('ignore')

# Core imports
try:
    import os
    import glob
    import yaml
    from pathlib import Path
    import pickle
    import numpy as np
    import nibabel as nib
except Exception as e:
    print(f'‚ùå Core imports failed: {e}')
    raise

# Neuroimaging imports
try:
    import neuropythy as ny
    from neuropythy.geometry import Mesh, Tesselation
except Exception as e:
    print(f'‚ùå Neuropythy failed: {e}')
    raise

# Visualization imports
try:
    import pandas as pd
    import ipyvolume as ipv
    import matplotlib.pyplot as plt
    from matplotlib.patches import Patch, Wedge
    import matplotlib.colors as mcolors
    from matplotlib import cm
    from mpl_toolkits.axes_grid1.inset_locator import inset_axes
except Exception as e:
    print(f'‚ùå Visualization imports failed: {e}')
    raise

# Additional scientific imports
try:
    from nilearn.surface import vol_to_surf
    from nilearn import surface, plotting, signal
    from scipy.spatial.distance import pdist, squareform
    from scipy.spatial import cKDTree
    from scipy.stats import pearsonr
except Exception as e:
    pass  # Optional imports, silent fail

# Widget imports
try:
    from ipywidgets import FloatText, HBox, VBox, Textarea, Output, Dropdown, FloatSlider, interactive_output
    from traitlets import link
    from IPython.display import display
except Exception as e:
    print(f'‚ùå Widget imports failed: {e}')
    raise

import math
import gc

In [2]:
# Inline color palettes (no external cfmap dependency)
def get_eccentricity_palette():
    """
    Return the 10-colour eccentricity palette (red -> orange -> yellow ->
    green -> turquoise -> cyan -> blue).

    Returns:
        dict: the same palette in several formats:
            - "rgb_0_255": list of [r, g, b] integer triplets in 0-255
            - "rgb_0_1": the same triplets scaled to floats in [0, 1]
            - "hex": list of '#rrggbb' strings
            - "named": {"color1": hex, ..., "color10": hex}
            - "matplotlib_cmap": LinearSegmentedColormap named "eccen_cmap"
              interpolating between the colours in list order
    """
    from matplotlib.colors import LinearSegmentedColormap
    import matplotlib.colors as mcolors

    # Original RGB values (0-255), ordered from the first to the last palette entry
    rgb_values = [
        [255, 40, 0],    # Red
        [255, 130, 0],   # Orange-red
        [255, 210, 0],   # Orange-yellow
        [255, 255, 0],   # Yellow
        [115, 255, 0],   # Yellow-green
        [31, 255, 0],    # Green
        [0, 255, 207],   # Turquoise
        [0, 231, 255],   # Cyan
        [20, 140, 255],  # Light blue
        [40, 60, 255]    # Blue
    ]

    # Normalize to 0-1 range for matplotlib
    norm_values = [[r / 255, g / 255, b / 255] for r, g, b in rgb_values]

    hex_values = [mcolors.rgb2hex(rgb) for rgb in norm_values]

    # 1-based names: color1 ... color10
    named_colors = {f"color{i}": hex_code for i, hex_code in enumerate(hex_values, start=1)}

    return {
        "rgb_0_255": rgb_values,
        "rgb_0_1": norm_values,
        "hex": hex_values,
        "named": named_colors,
        "matplotlib_cmap": LinearSegmentedColormap.from_list("eccen_cmap", norm_values)
    }


def get_polar_palette():
    """
    Return the 20-colour polar-angle palette.

    The first half runs green -> yellow -> red -> yellow -> green and the
    second half green -> cyan -> blue -> cyan -> green; the first and last
    entries are the same green, so the palette wraps around.

    Returns:
        dict: the same palette in several formats:
            - "rgb_0_255": list of [r, g, b] integer triplets in 0-255
            - "rgb_0_1": the same triplets scaled to floats in [0, 1]
            - "hex": list of '#rrggbb' strings
            - "named": {"color1": hex, ..., "color20": hex}
            - "matplotlib_cmap": LinearSegmentedColormap named "polar_cmap"
              interpolating between the colours in list order
    """
    from matplotlib.colors import LinearSegmentedColormap
    import matplotlib.colors as mcolors

    # Original RGB values (0-255)
    rgb_values = [
        [106, 189, 69],   # Color1
        [203, 219, 42],   # Color2
        [254, 205, 8],    # Color3
        [242, 104, 34],   # Color4
        [237, 32, 36],    # Color5
        [237, 32, 36],    # Color6
        [242, 104, 34],   # Color7
        [254, 205, 8],    # Color8
        [203, 219, 42],   # Color9
        [106, 189, 69],   # Color10
        [106, 189, 69],   # Color11
        [110, 205, 221],  # Color12
        [50, 178, 219],   # Color13
        [62, 105, 179],   # Color14
        [57, 84, 165],    # Color15
        [57, 84, 165],    # Color16
        [62, 105, 179],   # Color17
        [50, 178, 219],   # Color18
        [110, 205, 221],  # Color19
        [106, 189, 69]    # Color20
    ]

    # Normalize to 0-1 range for matplotlib
    norm_values = [[r / 255, g / 255, b / 255] for r, g, b in rgb_values]

    hex_values = [mcolors.rgb2hex(rgb) for rgb in norm_values]

    # 1-based names: color1 ... color20
    named_colors = {f"color{i}": hex_code for i, hex_code in enumerate(hex_values, start=1)}

    return {
        "rgb_0_255": rgb_values,
        "rgb_0_1": norm_values,
        "hex": hex_values,
        "named": named_colors,
        "matplotlib_cmap": LinearSegmentedColormap.from_list("polar_cmap", norm_values)
    }


# Build both palettes once, at import time; the colour legends drawn by the
# plotting code read the "hex" lists from these dicts.
eccen_colors, polar_colors = get_eccentricity_palette(), get_polar_palette()


# Rotate coordinates about one of the principal axes
def rotate_coords(coords, axis, angle_degrees):
    """
    Rotate 3-D coordinates by ``angle_degrees`` about the x, y or z axis.

    Parameters:
        coords (np.ndarray): array of shape (3, N) holding x, y, z along the first axis
        axis (str): 'x', 'y', or 'z'
        angle_degrees (float): rotation angle in degrees (right-hand rule)

    Returns:
        np.ndarray: the rotated coordinates, shape (3, N)

    Raises:
        ValueError: if ``axis`` is not 'x', 'y' or 'z'
    """
    rad = np.deg2rad(angle_degrees)
    c = np.cos(rad)
    s = np.sin(rad)

    if axis not in ('x', 'y', 'z'):
        raise ValueError("axis must be 'x', 'y', or 'z'")

    # Rows of the standard right-handed rotation matrix for each axis
    rows_by_axis = {
        'x': [[1, 0, 0], [0, c, -s], [0, s, c]],
        'y': [[c, 0, s], [0, 1, 0], [-s, 0, c]],
        'z': [[c, -s, 0], [s, c, 0], [0, 0, 1]],
    }
    rotation = np.array(rows_by_axis[axis])
    return rotation @ coords



# Plotting function (adapted from original)
def plot_and_save_brains(lh_map, rh_map, colormap, mesh_lh, mesh_rh, strips_lh, strips_rh, mask_lh, mask_rh, view, vmin=None, vmax=None, cbar_label='Value', cf_property='r2', polar_colormap='polar'):
    """
    Render both hemispheres in an interactive ipyvolume figure, then display a
    property-specific colour legend and read-only camera-position widgets.

    NOTE: despite its name, this function writes nothing to disk; the 3-D
    figure, the matplotlib legend and the widgets are only displayed.

    Parameters:
    - lh_map, rh_map: per-vertex values for the left / right hemisphere
    - colormap: matplotlib colormap used for both hemispheres and for the
      horizontal colour bar of non-angular properties
    - mesh_lh, mesh_rh: hemisphere objects passed to ny.cortex_plot; their
      ``.coordinates`` must be (3, N) arrays (used to centre the camera)
    - strips_lh, strips_rh: underlay values, drawn in 'gray' over [-5, 0]
    - mask_lh, mask_rh: masks forwarded to ny.cortex_plot
    - view: 'ventral', 'dorsal', or an (azimuth, elevation, distance) tuple
      (angles in degrees) for the camera
    - vmin, vmax: colour limits forwarded to ny.cortex_plot; when None the
      horizontal colour bar falls back to the nan-min / nan-max of both maps
    - cbar_label: label of the horizontal colour bar
    - cf_property: 'eccentricity' -> concentric-ring legend,
      'polar' / 'polar_angle' -> pie legend, anything else -> colour bar
    - polar_colormap: 'hsv' draws the pie legend from matplotlib's hsv map;
      any other value uses the custom polar palette

    Raises:
    - ValueError: if ``view`` is not one of the accepted forms
    """
    azim, elev, dist = _resolve_camera_view(view)

    # Create figure with HD 720p size
    fig = ipv.figure(width=1280, height=720)

    # Right hemisphere first, then left, with identical styling
    for mesh, surf_map, strips, mask in ((mesh_rh, rh_map, strips_rh, mask_rh),
                                         (mesh_lh, lh_map, strips_lh, mask_lh)):
        ny.cortex_plot(mesh, surface='inflated', color=surf_map, cmap=colormap,
            underlay=strips, underlay_cmap='gray', underlay_vmin=-5, underlay_vmax=0.0, mask=mask,
            vmin=vmin, vmax=vmax,
            figure=fig)

    # Centre the camera on the mean of all vertex coordinates of both hemispheres
    all_coords = np.concatenate([mesh_lh.coordinates, mesh_rh.coordinates], axis=1)
    fig.camera.center = np.mean(all_coords, axis=1)
    _set_camera_view(fig, azim, elev, dist)

    ipv.show()

    _show_property_legend(lh_map, rh_map, colormap, vmin, vmax, cbar_label, cf_property, polar_colormap)
    _show_camera_widgets(fig)


def _resolve_camera_view(view):
    """Translate a view spec ('ventral', 'dorsal' or a 3-tuple) into (azimuth, elevation, distance)."""
    if isinstance(view, tuple) and len(view) == 3:
        return view
    if view == 'ventral':
        return -172, -8, 180
    if view == 'dorsal':
        return -6.13, 31.34, 46.26
    raise ValueError("view must be 'ventral', 'dorsal', or a tuple (azim, elev, dist)")


def _set_camera_view(fig, azimuth, elevation, distance):
    """Place fig.camera on a sphere of radius ``distance`` around fig.camera.center (angles in degrees)."""
    center = fig.camera.center
    elev_rad = np.radians(elevation)
    az_rad = np.radians(azimuth)
    unit = np.array([
        np.cos(elev_rad) * np.cos(az_rad),
        np.cos(elev_rad) * np.sin(az_rad),
        np.sin(elev_rad)
    ])
    fig.camera.position = tuple(center + distance * unit)


def _show_property_legend(lh_map, rh_map, colormap, vmin, vmax, cbar_label, cf_property, polar_colormap):
    """Draw the matplotlib legend matching ``cf_property`` (see plot_and_save_brains)."""
    import matplotlib.pyplot as plt
    from matplotlib import cm
    from mpl_toolkits.axes_grid1.inset_locator import inset_axes
    from matplotlib.patches import Wedge

    if cf_property == 'eccentricity':
        # Concentric rings, one per eccentricity palette colour (first colour innermost)
        _, ax_main = plt.subplots(figsize=(3, 3))
        ax_main.set_aspect('equal')
        ax_main.set_xlim(-1.5, 1.5)
        ax_main.set_ylim(-1.5, 1.5)
        ax_main.set_axis_off()
        ax_main.text(0.5, -0.05, r'$\mathit{r}\ (\mathrm{deg})$', ha='center', va='top',
                     fontsize=14, transform=ax_main.transAxes)

        n_rings = len(eccen_colors["hex"])
        for i, color in enumerate(eccen_colors["hex"]):
            inner_r = i / n_rings
            outer_r = (i + 1) / n_rings
            ax_main.add_patch(Wedge((0, 0), outer_r, 0, 360, width=outer_r - inner_r, color=color))

    elif cf_property in ('polar', 'polar_angle'):
        # Equal-sized pie wedges, starting at 180 degrees and running clockwise
        _, ax_main = plt.subplots(figsize=(3, 3))
        ax_main.set_aspect('equal')
        ax_main.set_axis_off()

        if polar_colormap == 'hsv':
            n_segments = 20
            pie_colors = [plt.cm.hsv(i / n_segments) for i in range(n_segments)]
        else:
            pie_colors = polar_colors["hex"]
        ax_main.pie([1] * len(pie_colors), colors=pie_colors,
                    startangle=180, counterclock=False)

        ax_main.text(0.5, -0.05, r'$\theta\ (\mathrm{rad})$', ha='center', va='top',
                     fontsize=14, transform=ax_main.transAxes)

    else:
        # Only the linear colour bar needs numeric limits, so derive them here
        vmin_val = vmin if vmin is not None else np.nanmin([np.nanmin(lh_map), np.nanmin(rh_map)])
        vmax_val = vmax if vmax is not None else np.nanmax([np.nanmax(lh_map), np.nanmax(rh_map)])

        _, ax_main = plt.subplots(figsize=(8, 2))
        ax_main.set_axis_off()

        # Horizontal colour bar in an inset centred on an otherwise empty axis
        cbar_inset = inset_axes(ax_main, width="70%", height="30%", loc="center", borderpad=0)
        norm = plt.Normalize(vmin=vmin_val, vmax=vmax_val)
        cb = plt.colorbar(cm.ScalarMappable(norm=norm, cmap=colormap),
                          cax=cbar_inset, orientation='horizontal')
        cb.set_label(cbar_label, fontsize=14)
        cb.ax.tick_params(labelsize=12)

    plt.tight_layout()
    plt.show()


def _show_camera_widgets(fig):
    """Display read-only widgets that follow fig.camera and print the matching view triple.

    The printed (azimuth, elevation, distance) can be passed back to
    plot_and_save_brains as ``view=(azimuth, elevation, distance)``.
    """
    from IPython.display import display

    azimuth_widget = FloatText(description='Azimuth:', step=0.1, disabled=True)
    elevation_widget = FloatText(description='Elevation:', step=0.1, disabled=True)
    distance_widget = FloatText(description='Distance:', step=0.1, disabled=True)

    set_view_widget = Textarea(
        description='set_view call:',
        value='set_view(fig, 0.00, 0.00, 0.00)',
        disabled=True,
        layout={'width': '400px', 'height': '50px'}
    )

    def update_widgets(change):
        # Inverse of _set_camera_view: recover angles/distance from the camera offset
        pos = fig.camera.position
        center = fig.camera.center
        v = np.array(pos) - np.array(center)
        dist = np.linalg.norm(v)
        if dist > 0:
            elevation = np.degrees(np.arcsin(v[2] / dist))
            azimuth = np.degrees(np.arctan2(v[1], v[0]))
        else:
            azimuth = 0
            elevation = 0
        azimuth_widget.value = azimuth
        elevation_widget.value = elevation
        distance_widget.value = dist
        set_view_widget.value = f"set_view(fig, {azimuth:.2f}, {elevation:.2f}, {dist:.2f})"
        print(set_view_widget.value)

    fig.camera.observe(update_widgets, names=['position'])
    update_widgets(None)

    display(VBox([HBox([azimuth_widget, elevation_widget, distance_widget]), set_view_widget]))


In [3]:
# Interactive CF Results Plotting
# anterior_threshold: None keeps the full whole-brain surface; a number (e.g., -30)
# requests a posterior cut. NOTE(review): the axis and units this threshold is
# compared against are not visible in this cell — confirm in the plotting code.
anterior_threshold = None
from ipywidgets import Dropdown, FloatSlider, interact
import glob
import re
from pathlib import Path

# Import necessary libraries
import numpy as np

# Get the current notebook's directory and define paths
notebook_dir = Path().resolve()

# Dataset configurations.
# task_pattern groups, in order: subject id, session, task, source hemisphere,
# minimum eccentricity, maximum eccentricity (consumed by scan_cf_models).
DATASET_CONFIGS = {
    'LPP7T': {
        'name': 'Le Petit Prince (7T)',
        'task_pattern': r'cf_results_sub-(\d+)_(ses-\d+)_(\w+)_(lh|rh)-source_ecc([\d.]+)-([\d.]+)\.npz',
        'fs_subject_format': 'sub-{subject_id}_ses-01_iso'
    },
    'CB3T': {
        'name': 'Congenital Blindness (3T)',
        'task_pattern': r'cf_results_sub-(\d+)_(ses-\d+)_([\w-]+)_(lh|rh)-source_ecc([\d.]+)-([\d.]+)\.npz',
        'fs_subject_format': 'sub-{subject_id}_ses-01_iso'
    },
    'iCRTX7T': {
        'name': 'iCORTEX (7T)',
        'task_pattern': r'cf_results_sub-(\d+)_(ses-\d+)_([\w-]+)_(lh|rh)-source_ecc([\d.]+)-([\d.]+)\.npz',
        'fs_subject_format': 'sub-{subject_id}_ses-01_iso',
        'has_prf': True,  # Flag for pRF availability
    },
}

# Initialize with LPP7T dataset
current_dataset_key = 'LPP7T'
base_data_path = notebook_dir / 'data' / current_dataset_key

if not base_data_path.exists():
    print(f"⚠️ No data found for {DATASET_CONFIGS[current_dataset_key]['name']}")
    print(f"Expected path: {base_data_path}")
    print("\nPlease ensure data is organized as:")
    print("  data/")
    # Describe every configured dataset (not a hard-coded subset); the generator
    # keeps its loop variables out of the module namespace.
    print("\n".join(
        f"    {key}/\n"
        "      derivatives/cf-models/\n"
        + ("      derivatives/pRF-maps/\n" if cfg.get('has_prf') else "")
        + "      fs_subjects/"
        for key, cfg in DATASET_CONFIGS.items()
    ))

# Define CF models directory and FreeSurfer subjects directory
cf_models_dir = base_data_path / 'derivatives' / 'cf-models'
fs_subjects_dir = base_data_path / 'fs_subjects'


# Scan available CF model files
def scan_cf_models(cf_models_dir, task_pattern):
    """Scan a CF-models directory and index the result files it contains.

    Every ``cf_results_sub-*.npz`` file directly inside ``cf_models_dir`` whose
    name matches ``task_pattern`` is recorded; other files are ignored. The
    pattern must define six groups, in order: subject id, session, task,
    source hemisphere, minimum eccentricity, maximum eccentricity.

    Parameters:
        cf_models_dir (str | Path): directory to scan (a missing directory yields no files)
        task_pattern (str | re.Pattern): dataset-specific filename regex

    Returns:
        tuple: (subjects, tasks, models) where ``subjects`` and ``tasks`` are
        sorted lists of unique ids and ``models`` is a list of dicts with keys
        'file', 'subject', 'session', 'task', 'source_hemi', 'min_ecc',
        'max_ecc' (eccentricities as floats), ordered by file path.
    """
    regex = re.compile(task_pattern)
    pattern = str(Path(cf_models_dir) / 'cf_results_sub-*.npz')

    available_models = []
    subjects = set()
    tasks = set()

    # glob's order depends on the filesystem; sort so the model list — and
    # therefore which file a "first match" lookup picks — is deterministic.
    for f in sorted(glob.glob(pattern)):
        match = regex.match(Path(f).name)
        if match is None:
            continue
        subject_id, session_id, task, source_hemi, min_ecc, max_ecc = match.groups()
        subjects.add(subject_id)
        tasks.add(task)
        available_models.append({
            'file': f,
            'subject': subject_id,
            'session': session_id,
            'task': task,
            'source_hemi': source_hemi,
            'min_ecc': float(min_ecc),
            'max_ecc': float(max_ecc)
        })

    return sorted(subjects), sorted(tasks), available_models

# Index the CF-model files of the active dataset using that dataset's filename regex
available_subjects, available_tasks, cf_models_info = scan_cf_models(
    cf_models_dir=cf_models_dir,
    task_pattern=DATASET_CONFIGS[current_dataset_key]['task_pattern'],
)


# Per-property colour-range defaults: whether the range is adaptive, plus the
# fallback vmin / vmax. Built from a compact (name, adaptive, vmin, vmax) table.
property_config = {
    name: {'adaptive': adaptive, 'vmin': lo, 'vmax': hi}
    for name, adaptive, lo, hi in (
        ('r2', True, 0, 1),
        ('polar', True, -np.pi, np.pi),
        ('eccentricity', False, 0.5, 6.5),
        ('cf_size', False, 0.5, 5.0),
        ('polar_angle', True, -np.pi, np.pi),
        ('size', False, 0, 5.0),
        ('baseline', True, 0, 1),
        ('amplitude', True, 0, 1),
        ('hrf_delay', True, 0, 1),
        ('hrf_dispersion', True, 0, 1),
    )
}

# Memo of loaded CF / pRF results; keys embed the dataset so datasets cannot collide
loaded_results_cache = {}

def load_cf_results(subject_id, task, source_hemi, dataset_key='LPP7T'):
    """Load CF results for one subject, task and source hemisphere.

    The model file is looked up in the module-level ``cf_models_info`` list
    (entries with 'file', 'subject', 'task' and 'source_hemi' keys, as built by
    ``scan_cf_models``); the first matching entry is used. The ``.npz`` file
    must contain ``lh_results`` and ``rh_results`` stored as object arrays,
    which are unpacked with ``.item()``.

    Returns:
        (results_lh, results_rh), or (None, None) if no matching model exists
        (a message is printed in that case; misses are not cached).

    Successful loads are memoised in ``loaded_results_cache`` under a key that
    includes ``dataset_key``.
    """
    cache_key = f"{dataset_key}_sub-{subject_id}_{task}_{source_hemi}"
    if cache_key in loaded_results_cache:
        return loaded_results_cache[cache_key]

    model_file = next(
        (m['file'] for m in cf_models_info
         if m['subject'] == subject_id
         and m['task'] == task
         and m['source_hemi'] == source_hemi),
        None,
    )
    if model_file is None:
        print(f"❌ No model found for sub-{subject_id}, task={task}, source={source_hemi}")
        return None, None

    # allow_pickle is needed for the object arrays; only load trusted model files.
    # The context manager closes the file handle that the NpzFile keeps open.
    with np.load(model_file, allow_pickle=True) as results:
        results_lh = results['lh_results'].item()
        results_rh = results['rh_results'].item()

    loaded_results_cache[cache_key] = (results_lh, results_rh)
    return results_lh, results_rh


def load_prf_results(subject_id, dataset_key='iCRTX7T'):
    """Load pRF results for one subject from the combined-hemisphere CSV.

    Reads ``<notebook_dir>/data/<dataset_key>/derivatives/pRF-maps/
    pRF_parameters_both_hemispheres_averaged_sub-<subject_id>_ses-01.csv`` and
    splits its rows by hemisphere (column 'hemisphere' or 'hemi', matched
    case-insensitively; rows whose value does not start with 'lh'/'rh' are
    ignored).

    Returns (lh_results, rh_results): dicts with the keys 'polar_angle',
    'eccentricity', 'size', 'baseline', 'amplitude', 'hrf_delay',
    'hrf_dispersion' and 'r2'. Each value is a list of (vertex_index, value)
    tuples in CSV row order; a row is left out of a property's list when that
    value is missing, and out of all lists when its vertex index is not an
    integer. Arrays are created later with the correct mesh sizes.

    If the file is missing or cannot be split into hemispheres / vertices, a
    warning is printed and every list is empty. All outcomes are memoised in
    ``loaded_results_cache``.
    """
    cache_key = f"prf_{dataset_key}_sub-{subject_id}"
    if cache_key in loaded_results_cache:
        return loaded_results_cache[cache_key]

    # Result key -> accepted CSV column names (lower-case, first present wins).
    # pRF size is stored in the 'sd' column.
    column_candidates = {
        'polar_angle': ('polar_angle', 'polar'),
        'eccentricity': ('eccentricity',),
        'size': ('sd',),
        'baseline': ('baseline',),
        'amplitude': ('amplitude',),
        'hrf_delay': ('hrf_delay',),
        'hrf_dispersion': ('hrf_dispersion',),
        'r2': ('r2',),
    }

    def _empty_results():
        # Fresh dict of empty lists, one per pRF property
        return {key: [] for key in column_candidates}

    def _cache_empty(message):
        # Report the problem, then memoise and return empty results for both hemispheres
        print(message)
        empty = (_empty_results(), _empty_results())
        loaded_results_cache[cache_key] = empty
        return empty

    prf_file = (notebook_dir / 'data' / dataset_key / 'derivatives' / 'pRF-maps'
                / f'pRF_parameters_both_hemispheres_averaged_sub-{subject_id}_ses-01.csv')

    if not prf_file.exists():
        return _cache_empty(f"⚠️ pRF file not found: {prf_file}")

    df = pd.read_csv(prf_file, sep=',', engine='python')

    # Case-insensitive column lookup: lower-case name -> actual column name
    cols_lower = {c.lower(): c for c in df.columns}

    def _find_column(*names):
        return next((cols_lower[n] for n in names if n in cols_lower), None)

    hemi_col = _find_column('hemisphere', 'hemi')
    vertex_col = _find_column('vertex_index', 'vertex')
    present_columns = [
        (key, col)
        for key, col in ((key, _find_column(*names)) for key, names in column_candidates.items())
        if col
    ]

    # Without a hemisphere column, try to infer a single hemisphere from the
    # filename. NOTE: the filename above contains 'both', so this normally ends
    # in the 'both' branch unless subject_id itself contains a hemisphere tag.
    hemi_default = None
    if hemi_col is None:
        fname = prf_file.name.lower()
        if any(tag in fname for tag in ('_lh', '.lh', 'lh_')):
            hemi_default = 'lh'
        elif any(tag in fname for tag in ('_rh', '.rh', 'rh_')):
            hemi_default = 'rh'
        elif 'both' in fname:
            return _cache_empty(f"⚠️ pRF CSV appears to be 'both' but lacks a hemisphere column: {prf_file.name}")
        else:
            return _cache_empty(f"⚠️ pRF CSV missing 'hemisphere' column and cannot infer hemisphere from filename: {prf_file.name}")

    if vertex_col is None:
        # Previously every row was silently dropped in this situation
        return _cache_empty(f"⚠️ pRF CSV missing 'vertex_index' column: {prf_file.name}")

    lh_results = _empty_results()
    rh_results = _empty_results()

    for _, row in df.iterrows():
        hemi = str(row[hemi_col]).lower() if hemi_col is not None else hemi_default

        # Skip rows whose vertex index is missing (NaN), infinite or non-numeric
        try:
            vidx = int(row[vertex_col])
        except (TypeError, ValueError, OverflowError):
            continue

        if hemi.startswith('lh'):
            target = lh_results
        elif hemi.startswith('rh'):
            target = rh_results
        else:
            continue

        for key, col in present_columns:
            value = row[col]
            if not pd.isna(value):
                target[key].append((vidx, value))

    loaded_results_cache[cache_key] = (lh_results, rh_results)
    return lh_results, rh_results

def _ensure_mesh_loaded(subject_id, dataset_key):
    """Make sure the inflated surfaces and curvature of ``subject_id`` are loaded.

    The meshes are cached in the module globals ``lh_mesh``/``rh_mesh`` and
    ``lh_curv_map``/``rh_curv_map``, keyed by ``current_loaded_subject`` and
    ``current_loaded_dataset``; they are only re-read from the subject's
    FreeSurfer ``surf`` directory when that key changes.

    Returns True when the meshes are available, False (after printing the
    reason) when the subject directory is missing or the surfaces cannot be read.
    """
    global current_loaded_subject, current_loaded_dataset, lh_mesh, rh_mesh, lh_curv_map, rh_curv_map

    current_sub_id = f"{int(subject_id):02d}"
    # The `not in globals()` checks short-circuit before the names are read,
    # so the first call works before any mesh has ever been loaded.
    need_load = (
        'current_loaded_subject' not in globals() or
        'current_loaded_dataset' not in globals() or
        current_loaded_subject != current_sub_id or
        current_loaded_dataset != dataset_key
    )
    if not need_load:
        return True

    fs_subject_format = DATASET_CONFIGS[dataset_key]['fs_subject_format']
    fs_subject_name = fs_subject_format.format(subject_id=current_sub_id)
    fs_subject_path = fs_subjects_dir / fs_subject_name
    print(f"üß† Loading mesh for {DATASET_CONFIGS[dataset_key]['name']}: {fs_subject_name}")
    if not fs_subject_path.exists():
        print(f"‚ùå FreeSurfer subject not found: {fs_subject_path}")
        return False
    try:
        import nibabel.freesurfer as fs
        surf_dir = fs_subject_path / 'surf'
        lh_coords, lh_faces = fs.read_geometry(str(surf_dir / 'lh.inflated'))
        rh_coords, rh_faces = fs.read_geometry(str(surf_dir / 'rh.inflated'))
        new_lh_mesh = Mesh(Tesselation(lh_faces.T), lh_coords.T)
        new_rh_mesh = Mesh(Tesselation(rh_faces.T), rh_coords.T)
        new_lh_curv = fs.read_morph_data(str(surf_dir / 'lh.curv'))
        new_rh_curv = fs.read_morph_data(str(surf_dir / 'rh.curv'))
    except Exception as e:
        print(f"‚ùå Error loading mesh: {e}")
        return False

    # Commit everything together so a failed read never leaves surfaces of one
    # subject paired with curvature (or cache key) of another.
    lh_mesh, rh_mesh = new_lh_mesh, new_rh_mesh
    lh_curv_map, rh_curv_map = new_lh_curv, new_rh_curv
    current_loaded_subject = current_sub_id
    current_loaded_dataset = dataset_key
    print(f"‚úì Loaded mesh: LH={lh_coords.shape[0]} vertices, RH={rh_coords.shape[0]} vertices")
    return True


def _scatter_vertex_values(pairs, n_vertices, transform=None):
    """Return a float array of length ``n_vertices`` filled from ``(vertex, value)`` pairs.

    Vertices without a pair stay NaN; pairs whose index is outside
    ``[0, n_vertices)`` are ignored; a later pair for the same vertex overwrites
    an earlier one. ``transform``, if given, is applied to each stored value.
    """
    out = np.full(n_vertices, np.nan)
    for vidx, val in pairs:
        if 0 <= vidx < n_vertices:
            out[vidx] = val if transform is None else transform(val)
    return out


def update_plot(dataset_key, subject_id, task, source_hemi, map_type, cf_property,
                r2_threshold, use_adaptive_range, vmin, vmax, polar_colormap='polar',
                hemi_separation=80):
    """Redraw both hemispheres for the current widget selection (CF or pRF maps).

    Parameters
    ----------
    dataset_key : str
        Key into ``DATASET_CONFIGS``.
    subject_id : str
        Subject identifier; must be int-convertible (it is zero-padded to two digits).
    task, source_hemi : str
        CF model selection. Ignored for pRF, where they are recorded as
        ``'averaged'`` / ``'both'`` in ``current_plot_data``.
    map_type : {'CF', 'pRF'}
        Which kind of model results to display.
    cf_property : str
        Parameter to display. CF: eccentricity, polar, cf_size, r2.
        pRF: eccentricity, polar_angle, size, baseline, amplitude, hrf_delay,
        hrf_dispersion, r2.
    r2_threshold : float
        Vertices whose R² is NaN or below this value are blanked.
    use_adaptive_range : bool
        If True, vmin/vmax are replaced by the 2nd/98th percentiles of the
        visible data (or the ``property_config`` defaults if nothing is visible).
    vmin, vmax : float
        Colour range used when ``use_adaptive_range`` is False.
    polar_colormap : {'polar', 'hsv'}
        Colormap used for polar-angle maps.
    hemi_separation : float
        Offset added to the right hemisphere's first coordinate row.

    Returns
    -------
    None. Errors are printed and abort the redraw; on success the plotted arrays
    and settings are stored in the global ``current_plot_data``.
    """
    global current_plot_data

    try:
        if map_type == 'CF':
            results_lh, results_rh = load_cf_results(subject_id, task, source_hemi, dataset_key)
            prop_map = {
                'eccentricity': 'inherited_eccen',
                'polar': 'inherited_polar',
                'cf_size': 'cf_size',
                'r2': 'r2'
            }
            if results_lh is None or results_rh is None:
                return
            if cf_property not in prop_map:
                print(f"Unknown CF property: {cf_property}")
                return
            lh_data = results_lh[prop_map[cf_property]].copy()
            rh_data = results_rh[prop_map[cf_property]].copy()
            # np.full rather than np.full_like: full_like would inherit an integer
            # dtype from the data and fail to hold NaN.
            lh_r2 = results_lh.get('r2', np.full(np.shape(lh_data), np.nan))
            rh_r2 = results_rh.get('r2', np.full(np.shape(rh_data), np.nan))
        elif map_type == 'pRF':
            results_lh, results_rh = load_prf_results(subject_id, dataset_key)
            prop_map_prf = {
                'eccentricity': 'eccentricity',
                'polar_angle': 'polar_angle',
                'size': 'size',
                'baseline': 'baseline',
                'amplitude': 'amplitude',
                'hrf_delay': 'hrf_delay',
                'hrf_dispersion': 'hrf_dispersion',
                'r2': 'r2'
            }
            if results_lh is None or results_rh is None:
                return
            if cf_property not in prop_map_prf:
                print(f"Unknown pRF property: {cf_property}")
                return

            # pRF results are sparse (vertex, value) pairs, so the mesh is needed
            # first to know how long the per-vertex arrays must be.
            if not _ensure_mesh_loaded(subject_id, dataset_key):
                return
            n_lh = lh_mesh.coordinates.shape[1]
            n_rh = rh_mesh.coordinates.shape[1]

            # NOTE(review): 2.355 = 2*sqrt(2*ln 2) converts a Gaussian sigma to its
            # FWHM, yet the colorbar label below reads "sigma" - confirm which
            # quantity is intended.
            transform = (lambda v: abs(v) * 2.355) if cf_property == 'size' else None
            prf_key = prop_map_prf[cf_property]
            lh_data = _scatter_vertex_values(results_lh[prf_key], n_lh, transform)
            lh_r2 = _scatter_vertex_values(results_lh['r2'], n_lh)
            rh_data = _scatter_vertex_values(results_rh[prf_key], n_rh, transform)
            rh_r2 = _scatter_vertex_values(results_rh['r2'], n_rh)

            # pRF doesn't need task/source
            task = 'averaged'
            source_hemi = 'both'

        else:
            print(f"Unknown map_type: {map_type}")
            return
    except Exception as e:
        print(f"‚ùå Error loading results: {e}")
        return

    # Apply the R² mask (same for CF and pRF once the arrays exist)
    lh_mask = np.isnan(lh_r2) | (lh_r2 < r2_threshold)
    rh_mask = np.isnan(rh_r2) | (rh_r2 < r2_threshold)
    lh_data = lh_data.copy()
    rh_data = rh_data.copy()
    lh_data[lh_mask] = np.nan
    rh_data[rh_mask] = np.nan

    # Load meshes if needed (a no-op when the pRF branch already loaded them)
    if not _ensure_mesh_loaded(subject_id, dataset_key):
        return

    # Adaptive vmin/vmax
    if use_adaptive_range:
        valid = np.concatenate([lh_data[~np.isnan(lh_data)], rh_data[~np.isnan(rh_data)]])
        if valid.size > 0:
            vmin = np.percentile(valid, 2)
            vmax = np.percentile(valid, 98)
        else:
            cfg = property_config.get(cf_property, {'vmin': 0, 'vmax': 1})
            vmin, vmax = cfg['vmin'], cfg['vmax']

    # Curvature "strips": keep only non-positive curvature, NaN elsewhere
    lh_strips_plot = lh_curv_map.astype(float)
    lh_strips_plot[lh_strips_plot > 0] = np.nan
    rh_strips_plot = rh_curv_map.astype(float)
    rh_strips_plot[rh_strips_plot > 0] = np.nan
    lh_mask_plot = ~np.isnan(lh_data)
    rh_mask_plot = ~np.isnan(rh_data)

    # Apply rotations and separation. np.array(...) copies the rotated
    # coordinates so the in-place shift below can never touch the cached mesh,
    # even if rotate_coords were to return a view of its input.
    angle = 0
    lh_coords_medial = np.array(rotate_coords(lh_mesh.coordinates, axis='z', angle_degrees=-angle))
    rh_coords_medial = np.array(rotate_coords(rh_mesh.coordinates, axis='z', angle_degrees=angle * 2))
    rh_coords_medial[0, :] += hemi_separation
    lh_mesh_shifted = Mesh(Tesselation(lh_mesh.tess.faces), lh_coords_medial)
    rh_mesh_shifted = Mesh(Tesselation(rh_mesh.tess.faces), rh_coords_medial)

    # Select colormap and label
    if cf_property == 'eccentricity':
        grad_cmap = eccen_colors['matplotlib_cmap']
        cbar_label = r'Eccentricity $r$ (deg)'
    elif cf_property in ('polar', 'polar_angle'):
        grad_cmap = polar_colors['matplotlib_cmap'] if polar_colormap == 'polar' else plt.colormaps['hsv']
        cbar_label = r'Polar angle $\theta$ (rad)'
    elif cf_property == 'r2':
        grad_cmap = plt.colormaps['viridis']
        cbar_label = r'$R^2$'
    elif cf_property == 'cf_size':
        grad_cmap = plt.colormaps['viridis']
        cbar_label = r'CF size $\sigma$ (mm)'
    elif cf_property == 'size':
        grad_cmap = plt.colormaps['viridis']
        cbar_label = r'pRF size $\sigma$ (deg)'
    elif cf_property == 'baseline':
        grad_cmap = plt.colormaps['hot']
        cbar_label = 'Baseline (a.u.)'
    elif cf_property == 'amplitude':
        grad_cmap = plt.colormaps['viridis']
        cbar_label = 'Amplitude (a.u.)'
    elif cf_property == 'hrf_delay':
        grad_cmap = plt.colormaps['viridis']
        cbar_label = r'HRF delay ($s$)'
    elif cf_property == 'hrf_dispersion':
        grad_cmap = plt.colormaps['viridis']
        cbar_label = 'HRF dispersion'
    else:
        grad_cmap = plt.colormaps['viridis']
        cbar_label = 'Value'

    # Plot
    try:
        plot_and_save_brains(lh_data, rh_data, grad_cmap,
                             lh_mesh_shifted, rh_mesh_shifted,
                             lh_strips_plot, rh_strips_plot,
                             lh_mask_plot, rh_mask_plot,
                             (-122, -27, 80), vmin=vmin, vmax=vmax,
                             cbar_label=cbar_label, cf_property=cf_property,
                             polar_colormap=polar_colormap)
    except Exception as e:
        print(f"‚ùå Error plotting brains: {e}")
        return

    # Save current plot state
    current_plot_data = {
        'lh_data': lh_data,
        'rh_data': rh_data,
        'lh_mesh': lh_mesh_shifted,
        'rh_mesh': rh_mesh_shifted,
        'colormap': grad_cmap,
        'vmin': vmin,
        'vmax': vmax,
        'cbar_label': cbar_label,
        'subject_id': subject_id,
        'task': task,
        'source_hemi': source_hemi,
        'cf_property': cf_property,
        'r2_threshold': r2_threshold,
        'map_type': map_type
    }

# Create widgets
dataset_widget = Dropdown(
    options=[(config['name'], key) for key, config in DATASET_CONFIGS.items()],
    value='LPP7T',
    description='Dataset:'
)

# An ipywidgets Dropdown rejects a value that is not one of its options. When the
# scan found no subjects/tasks, use None (an empty selection); a placeholder such as
# '01' or 'LPP1' would raise a TraitError here.
subject_widget = Dropdown(options=available_subjects,
                          value=available_subjects[0] if available_subjects else None,
                          description='Subject:')
task_widget = Dropdown(options=available_tasks,
                       value=available_tasks[0] if available_tasks else None,
                       description='Task:')
source_hemi_widget = Dropdown(options=['lh', 'rh'], value='lh', description='Source:')
map_type_widget = Dropdown(options=['CF', 'pRF'], value='CF', description='Map type:')
cf_property_widget = Dropdown(options=['eccentricity', 'polar', 'cf_size', 'r2'],
                              value='eccentricity', description='Property:')
r2_threshold_widget = FloatSlider(value=0.1, min=0.0, max=1.0, step=0.01, description='R¬≤ threshold:')

# Adaptive range toggle; its default depends on the selected property
adaptive_range_widget = Dropdown(options=[True, False], value=False,
                                description='Range:')

# Colour-range sliders; defaults depend on the property, initialized for eccentricity
vmin_widget = FloatSlider(value=0.5, min=-10, max=10, step=0.01, description='min:')
vmax_widget = FloatSlider(value=6.5, min=-10, max=10, step=0.01, description='max:')

# Polar colormap widget - only meant to be visible when plotting polar angle
polar_colormap_widget = Dropdown(options=['polar', 'hsv'], value='polar',
                                 description='Polar cmap:')

# Hemisphere separation slider (offset added to the right hemisphere in update_plot)
hemi_separation_widget = FloatSlider(
    value=80,
    min=0,
    max=200,
    step=1,
    description='Hemi gap:',
    tooltip='Distance between left and right hemispheres'
)

# Function to update widget defaults when CF property changes
def update_widget_defaults(cf_property):
    """Reset the range widgets to the defaults for ``cf_property`` under the current map type.

    pRF eccentricity uses a fixed 0-5 range with adaptive scaling off. Every other
    property takes ``adaptive``/``vmin``/``vmax`` from its ``property_config``
    entry. Properties missing from ``property_config``, and keys missing from an
    entry, fall back to adaptive scaling with a 0-1 range.

    Parameters
    ----------
    cf_property : str
        The newly selected property name (CF or pRF).
    """
    fallback = {'adaptive': True, 'vmin': 0, 'vmax': 1}

    if map_type_widget.value == 'pRF' and cf_property == 'eccentricity':
        # Special case for pRF eccentricity: fixed range 0-5, adaptive scaling off
        settings = {'adaptive': False, 'vmin': 0, 'vmax': 5}
    else:
        # Merge over the fallback so a partial property_config entry cannot raise KeyError
        settings = {**fallback, **property_config.get(cf_property, {})}

    adaptive_range_widget.value = settings['adaptive']
    vmin_widget.value = settings['vmin']
    vmax_widget.value = settings['vmax']

# Function to update property options and widget visibility based on map type
def update_map_type_options(change):
    """React to a map-type switch by swapping the property list and task/source visibility.

    'CF' shows the task and source-hemisphere selectors and offers the CF
    parameters. 'pRF' hides those selectors and offers the pRF parameters. Both
    preselect 'eccentricity'. Any other value leaves the widgets as they are. In
    every case the range widgets are then reset for the selected property.

    Parameters
    ----------
    change : dict
        traitlets change notification; ``change['new']`` is the new map type.
    """
    # map type -> (property options, task/source display mode, property label)
    layouts = {
        'CF': (
            ['eccentricity', 'polar', 'cf_size', 'r2'],
            'flex',
            'CF parameter:',
        ),
        'pRF': (
            ['eccentricity', 'polar_angle', 'size', 'baseline', 'amplitude', 'hrf_delay', 'hrf_dispersion', 'r2'],
            'none',
            'pRF parameter:',
        ),
    }

    selected = layouts.get(change['new'])
    if selected is not None:
        property_options, display_mode, label = selected
        cf_property_widget.options = property_options
        cf_property_widget.value = 'eccentricity'
        for selector in (task_widget, source_hemi_widget):
            selector.layout.display = display_mode
        cf_property_widget.description = label

    # Refresh range defaults for whichever property is now selected
    update_widget_defaults(cf_property_widget.value)

# Function to update available subjects and tasks when dataset changes
def update_dataset_options(change):
    """Re-scan and repopulate the subject, task and map-type widgets for a new dataset.

    Observer for ``dataset_widget``'s ``value``. Points ``base_data_path``,
    ``cf_models_dir`` and ``fs_subjects_dir`` at ``notebook_dir / 'data' / <key>``.
    Re-runs ``scan_cf_models`` with the dataset's ``task_pattern`` to refresh
    ``available_subjects``, ``available_tasks`` and ``cf_models_info``. Offers
    'pRF' in ``map_type_widget`` only when the dataset config sets ``has_prf``.

    If the dataset folder is missing, a warning is printed and the path globals
    are left untouched. ``current_dataset_key`` still follows the widget.

    Parameters
    ----------
    change : dict
        traitlets change notification; ``change['new']`` is the dataset key.
    """
    global available_subjects, available_tasks, cf_models_info, base_data_path, cf_models_dir, fs_subjects_dir, current_dataset_key

    new_dataset_key = change['new']
    current_dataset_key = new_dataset_key

    # Check the folder before touching any path global. Otherwise a missing dataset
    # would leave base_data_path pointing at it while cf_models_dir/fs_subjects_dir
    # still point at the previous dataset.
    new_base_path = notebook_dir / 'data' / new_dataset_key
    if not new_base_path.exists():
        print(f"‚ö†Ô∏è No data found for {DATASET_CONFIGS[new_dataset_key]['name']}")
        print(f"Expected path: {new_base_path}")
        return

    # Update paths
    base_data_path = new_base_path
    cf_models_dir = base_data_path / 'derivatives' / 'cf-models'
    fs_subjects_dir = base_data_path / 'fs_subjects'

    # Rescan with the dataset's own task naming pattern
    task_pattern = DATASET_CONFIGS[new_dataset_key]['task_pattern']
    available_subjects, available_tasks, cf_models_info = scan_cf_models(cf_models_dir, task_pattern)

    # A Dropdown only accepts values among its options. With an empty scan result
    # the old '01'/'LPP1' placeholders raised a TraitError and aborted this
    # callback, so only select a first entry when there is one.
    subject_widget.options = available_subjects
    if available_subjects:
        subject_widget.value = available_subjects[0]
    task_widget.options = available_tasks
    if available_tasks:
        task_widget.value = available_tasks[0]

    # Offer pRF maps only for datasets that have them; always default to CF
    has_prf = DATASET_CONFIGS[new_dataset_key].get('has_prf', False)
    map_type_widget.options = ['CF', 'pRF'] if has_prf else ['CF']
    map_type_widget.value = 'CF'

    print(f"‚úì Switched to {DATASET_CONFIGS[new_dataset_key]['name']}")
    print(f"  Found {len(available_subjects)} subjects, {len(available_tasks)} tasks")
    if has_prf:
        print("  pRF data available")

# Link dataset widget to update function.
# These observers are registered before interactive_output (below). traitlets calls
# observers in registration order, so they run before the plot refresh triggered by
# the same change. Keep this order.
dataset_widget.observe(update_dataset_options, names='value')

# Link map type widget to update function (swaps property options and task/source visibility)
map_type_widget.observe(update_map_type_options, names='value')

# Link CF property widget to update defaults (range toggle and min/max sliders)
cf_property_widget.observe(lambda change: update_widget_defaults(change['new']), names='value')

# Interactive plot with conditional polar colormap widget
# NOTE: these names are already imported at the top of the notebook; re-importing is redundant but harmless.
from ipywidgets import interactive_output, VBox, HBox
from IPython.display import display

# Create interactive output.
# Keys must match update_plot's parameter names: interactive_output calls
# update_plot with each widget's current value as a keyword argument, re-calls it
# whenever any of these widgets changes, and captures the output in `out`.
ui_controls = {
    'dataset_key': dataset_widget,
    'subject_id': subject_widget,
    'task': task_widget, 
    'source_hemi': source_hemi_widget,
    'map_type': map_type_widget,
    'cf_property': cf_property_widget,
    'r2_threshold': r2_threshold_widget,
    'use_adaptive_range': adaptive_range_widget,
    'vmin': vmin_widget,
    'vmax': vmax_widget,
    'polar_colormap': polar_colormap_widget,
    'hemi_separation': hemi_separation_widget
}

out = interactive_output(update_plot, ui_controls)

# Function to update widget visibility based on CF property
def update_widget_visibility(change):
    """Show the polar colormap selector only while a polar-angle property is selected.

    CF models call the property 'polar' and pRF models call it 'polar_angle'.
    ``update_plot`` applies ``polar_colormap`` to both, so both must reveal the
    selector; the original check only matched 'polar', hiding it for pRF.

    Parameters
    ----------
    change : dict
        traitlets change notification; ``change['new']`` is the selected property.
    """
    is_polar = change['new'] in ('polar', 'polar_angle')
    polar_colormap_widget.layout.display = 'flex' if is_polar else 'none'

# Initially hide the polar colormap widget unless a polar-angle property is selected
# ('polar' for CF, 'polar_angle' for pRF - update_plot uses the polar colormap for both)
if cf_property_widget.value not in ('polar', 'polar_angle'):
    polar_colormap_widget.layout.display = 'none'

# Initially set task/source visibility for CF (the initial map type)
task_widget.layout.display = 'flex'
source_hemi_widget.layout.display = 'flex'

# Link visibility to CF property changes
cf_property_widget.observe(update_widget_visibility, names='value')

# Display widgets and output in compact 2-column layout
display(VBox([
    HBox([
        VBox([dataset_widget, subject_widget, map_type_widget, task_widget, source_hemi_widget]),
        VBox([cf_property_widget, polar_colormap_widget, r2_threshold_widget, hemi_separation_widget])
    ]),
    HBox([
        adaptive_range_widget,
        vmin_widget,
        vmax_widget
    ]),
    out
]))



VBox(children=(HBox(children=(VBox(children=(Dropdown(description='Dataset:', options=(('Le Petit Prince (7T)'‚Ä¶