# Import libraries

In [None]:
import json
import math
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from scipy import stats
from scipy.stats import ttest_ind, wilcoxon, mannwhitneyu
from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec
from matplotlib.patches import Patch
from itertools import combinations
from scipy.stats.mstats import winsorize
from scipy.ndimage import gaussian_filter1d
from matplotlib.gridspec import GridSpec
from sklearn.metrics import r2_score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.model_selection import KFold, cross_val_score
from sklearn.linear_model import LinearRegression, Ridge

# Data loading

In [None]:
def extract_value(value):
    """Unwrap a ``{"col_1": ...}`` wrapper, leaving anything else untouched.

    Parameters
    ----------
    value : object
        Any decoded JSON value.

    Returns
    -------
    object
        ``value["col_1"]`` when ``value`` is a dict that contains the key
        ``"col_1"`` (any other keys are ignored); otherwise ``value`` itself.
    """
    is_wrapped = isinstance(value, dict) and "col_1" in value
    return value["col_1"] if is_wrapped else value

# Load JSON data
json_file_path = 'data.json'  # Replace with actual file path
# JSON text is UTF-8 by specification; decode it explicitly instead of relying
# on the platform's default locale encoding.
with open(json_file_path, 'r', encoding='utf-8') as file:
    data = json.load(file)

# Generic top-level column keys -> descriptive column names
column_mapping = {
    "col_1": "Rat_Session",
    "col_2": "Tetrode_ID",
    "col_3": "Neuron_ID",
    "col_4": "Firing_Rate_Data",
    "col_5": "Speed_Vector",
    "col_6": "Neuron_Classification",
    "col_7": "Mean_Firing_Rate",
    "col_8": "Locomotion_Cell",
    "col_9": "CCG_Value"
}

# Generic keys of the nested Firing_Rate_Data dictionary -> descriptive names
firing_rate_mapping = {
    "col_1": "Firing_Rate",
    "col_2": "Position_Bins",
    "col_3": "Correct_Choices",
    "col_4": "Cue_Type",
    "col_5": "Arm_Choice",
    "col_6": "Trial_Number"
}

# Flatten every JSON entry into one row. The nested Firing_Rate_Data dictionary
# is expanded into its own top-level columns; keys missing from the mappings
# keep their original name.
rows = []
for row_key, row_data in data.items():
    row = {}
    for col_key, col_value in row_data.items():
        descriptive_name = column_mapping.get(col_key, col_key)
        if descriptive_name == "Firing_Rate_Data" and isinstance(col_value, dict):
            for nested_key, nested_value in col_value.items():
                nested_name = firing_rate_mapping.get(nested_key, nested_key)
                row[nested_name] = extract_value(nested_value)
        else:
            row[descriptive_name] = extract_value(col_value)
    rows.append(row)

df = pd.DataFrame(rows)

# Recode the neuron classification: 21 -> 1 and 22 -> -1.
# Any other code is absent from the mapping and therefore becomes NaN.
df['Neuron_Classification'] = df['Neuron_Classification'].map({21: 1, 22: -1})

# Columns whose cells are converted to numpy arrays
numeric_columns = [
    'Neuron_ID',
    'Firing_Rate',
    'Position_Bins',
    'Correct_Choices',
    'Cue_Type',
    'Arm_Choice',
    'Trial_Number',
    'CCG_Value'
]

# Convert each cell of the listed columns with np.array
for col in numeric_columns:
    df[col] = df[col].apply(np.array)

## Functions

In [None]:
def winsorize_data(X, y, limits=(0.05, 0.05)):
    """
    Winsorize the predictor and target arrays independently.

    Winsorization clips the tails of a sample: values below the lower
    percentile or above the upper percentile are replaced by the value found
    at that percentile, which limits the influence of outliers without
    dropping any sample.

    Parameters
    ----------
    X : array-like
        Predictor values to be winsorized.
    y : array-like
        Target values to be winsorized.
    limits : sequence of two floats, optional (default=(0.05, 0.05))
        (lower, upper) fractions to clip on each tail, e.g. (0.05, 0.05)
        clips 5% on both tails. The default is an immutable tuple so it
        cannot be modified between calls; a list is still accepted.

    Returns
    -------
    X_win : numpy.ma.MaskedArray
        Winsorized copy of X.
    y_win : numpy.ma.MaskedArray
        Winsorized copy of y.

    Notes
    -----
    Uses scipy.stats.mstats.winsorize. X and y are processed separately with
    the same limits; the inputs are not modified.
    """
    X_win = winsorize(X, limits=limits)
    y_win = winsorize(y, limits=limits)
    return X_win, y_win

# Data pre-processing

In [None]:
def process_dataframe(df):
    """Add standardized and smoothed firing-rate / speed columns to ``df``.

    For both ``Firing_Rate`` and ``Speed_Vector`` every row's vector is
    z-scored on its own (``<name>_Standardized``) and the z-scored vector is
    then passed through a 1-D Gaussian filter with ``sigma=10``
    (``<name>_Smoothed``). ``Speed_Vector`` cells are converted to numpy
    arrays first.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``Firing_Rate`` and ``Speed_Vector``.
        The frame is modified in place.

    Returns
    -------
    pandas.DataFrame
        The same ``df`` object, with the four derived columns added.
    """
    signal_names = ('Firing_Rate', 'Speed_Vector')

    def _smooth(trace):
        # Same kernel width for the firing rate and the speed trace
        return gaussian_filter1d(trace, sigma=10)

    df['Speed_Vector'] = df['Speed_Vector'].apply(np.array)

    # Per-row z-score first (both signals) ...
    for name in signal_names:
        df[f'{name}_Standardized'] = df[name].apply(stats.zscore)

    # ... then Gaussian smoothing of the standardized traces
    for name in signal_names:
        df[f'{name}_Smoothed'] = df[f'{name}_Standardized'].apply(_smooth)

    return df

# Run the pre-processing. process_dataframe adds its derived columns to the
# frame it receives and returns that same object, so `df` is simply rebound.
df = process_dataframe(df)

## Single speed cell surrogates

In [None]:
def compute_chance_r2(cell_idx, df, model, kf, n_shuffles=100):
    """Compute the mean chance-level R2 of one cell using circular time shifts.

    The cell's smoothed firing rate is circularly shifted against its smoothed
    speed vector, which breaks their temporal alignment while keeping the
    autocorrelation of both signals. The shifted firing rate is then used to
    decode speed with the same cross-validated model as the real analysis.

    Parameters
    ----------
    cell_idx : hashable
        Row label of the cell in ``df``.
    df : pandas.DataFrame
        Must contain ``Firing_Rate_Smoothed`` and ``Speed_Vector_Smoothed``.
    model : estimator
        Estimator passed to ``cross_val_score``.
    kf : cross-validation splitter
        Must provide ``get_n_splits()``; passed as ``cv``.
    n_shuffles : int, default=100
        Number of surrogate (shifted) decodings.

    Returns
    -------
    float
        Mean over shuffles of the mean cross-validated R2, or ``np.nan`` when
        the data are empty, mismatched in length, shorter than the number of
        folds, or when every shuffle failed.

    Notes
    -----
    Shifts are drawn with ``np.random.randint`` from 20%-50% of the recording
    length, so results depend on the global NumPy random state.
    """
    firing = np.array(df.loc[cell_idx, 'Firing_Rate_Smoothed']).ravel()
    # The decoding target is the speed trace. Reading the firing rate here
    # (as the previous version did) decodes the firing rate from a shifted copy
    # of itself, which is not a chance level for speed decoding.
    speed = np.array(df.loc[cell_idx, 'Speed_Vector_Smoothed']).ravel()

    data_len = len(firing)
    if data_len == 0 or data_len != len(speed) or data_len < kf.get_n_splits():
        return np.nan

    # Winsorize once: clipping depends only on the set of values, not on their
    # order, so clipping before the roll gives the same result as clipping
    # every shifted copy.
    firing_win, speed_win = winsorize_data(firing, speed)
    firing_win = np.asarray(firing_win)
    speed_win = np.asarray(speed_win)

    min_shift = int(data_len * 0.20)
    max_shift = int(data_len * 0.50)

    if min_shift >= max_shift:  # Handle short data
        min_shift = 1
        max_shift = max(1, data_len - 1)
        if min_shift >= max_shift and data_len <= 1:
            return np.nan  # Cannot shift

    shuffle_scores = []
    for _ in range(n_shuffles):
        shift = np.random.randint(min_shift, max_shift + 1)
        X_shifted = np.roll(firing_win, shift).reshape(-1, 1)
        try:
            scores = cross_val_score(model, X_shifted, speed_win, cv=kf, scoring='r2')
        except ValueError:
            # Skip a shuffle the estimator cannot fit; the remaining ones
            # still contribute to the mean.
            continue
        shuffle_scores.append(np.mean(scores))

    return np.mean(shuffle_scores) if shuffle_scores else np.nan


# Single-cell decoding pipeline

In [None]:
# Decoder: cubic polynomial expansion of the firing rate followed by ordinary
# least-squares regression
model = Pipeline([
    ('poly', PolynomialFeatures(degree=3)),
    ('regressor', LinearRegression())
])

# One dictionary per cell; assembled into results_df after the loop
results = []

# Shuffled 10-fold cross-validation with a fixed seed
kf = KFold(n_splits=10, shuffle=True, random_state=777)

# Decode speed from each cell's firing rate, one cell at a time
for idx in df.index:
    # Predictor: smoothed firing rate as a single feature column;
    # target: smoothed speed
    X = np.array(df.loc[idx, 'Firing_Rate_Smoothed']).reshape(-1, 1)
    y = np.array(df.loc[idx, 'Speed_Vector_Smoothed'])

    # Clip the tails of both signals, then restore the column shape
    X_win, y_win = winsorize_data(X.ravel(), y)
    X_win = np.asarray(X_win).reshape(-1, 1)

    cv_scores = cross_val_score(model, X_win, y_win, cv=kf, scoring='r2')

    # Identification fields first, then per-fold scores, then the summary
    cell_info = {'Cell_Index': idx}
    cell_info.update(
        {column: df.loc[idx, column] for column in ('Rat_Session', 'Tetrode_ID', 'Neuron_ID')}
    )
    cell_info.update(
        {f'Fold_{fold}_R2': score for fold, score in enumerate(cv_scores, 1)}
    )
    cell_info['Mean_R2'] = cv_scores.mean()
    cell_info['Std_R2'] = cv_scores.std()

    results.append(cell_info)

# Collect the per-cell dictionaries into a table
results_df = pd.DataFrame(results)

# Short report
print("Results DataFrame with Winsorized Data:")
print(results_df.head())
print("\nSummary Statistics:")
print(results_df[['Mean_R2', 'Std_R2']].describe())

## Compute the single speed cell surrogates

In [None]:
# --- Calculate chance R2 for speed cells ---
speed_cell_indices = df[df['Locomotion_Cell'] == True].index

# Look the observed score up by Cell_Index rather than by results_df's own row
# labels, so the check does not depend on results_df keeping a default index
# that happens to coincide with df.index.
observed_mean_r2 = results_df.set_index('Cell_Index')['Mean_R2']

chance_r2_results = {}
print("\nCalculating chance R2 scores for speed cells...")
for cell_idx in tqdm(speed_cell_indices):
    if pd.notna(observed_mean_r2.get(cell_idx, np.nan)):
        # Same decoder (`model`) and folds (`kf`) as the observed scores
        chance_mean = compute_chance_r2(cell_idx, df, model, kf, n_shuffles=100)
        chance_r2_results[cell_idx] = chance_mean
    else:
        # No observed score for this cell, so no chance score either
        chance_r2_results[cell_idx] = np.nan

results_df['Chance_Mean_R2'] = results_df['Cell_Index'].map(chance_r2_results)

# --- Stats for speed cells vs chance ---
# Keep only speed cells for which BOTH the observed and the chance score are
# available, so the two series are paired row by row.
speed_rows = results_df['Cell_Index'].isin(speed_cell_indices)
paired_scores = results_df.loc[speed_rows, ['Mean_R2', 'Chance_Mean_R2']].dropna()
common_indices = paired_scores.index
speed_cell_actual_r2 = paired_scores['Mean_R2']
speed_cell_chance_r2 = paired_scores['Chance_Mean_R2']

chance_comparison_stat, chance_comparison_pvalue = np.nan, np.nan  # Initialize
if len(speed_cell_actual_r2) > 0:
    diff = speed_cell_actual_r2 - speed_cell_chance_r2
    if np.all(np.isclose(diff, 0)):
        # All paired differences are (numerically) zero
        print("Actual and Chance R2 scores are identical. Wilcoxon test not applicable.")
        chance_comparison_pvalue = 1.0
    else:
        try:
            chance_comparison_stat, chance_comparison_pvalue = wilcoxon(speed_cell_actual_r2, speed_cell_chance_r2)
            print(f"\nWilcoxon test: Speed Cells Actual R2 vs. Chance R2")
            print(f"Statistic: {chance_comparison_stat:.3f}, P-value: {chance_comparison_pvalue:.3e}")
        except ValueError as e:  # Fallback if Wilcoxon fails
            print(f"Wilcoxon failed ('{e}'), falling back to paired t-test.")
            try:
                chance_comparison_stat, chance_comparison_pvalue = stats.ttest_rel(speed_cell_actual_r2, speed_cell_chance_r2)
                print(f"Paired t-test: Stat={chance_comparison_stat:.3f}, P-value={chance_comparison_pvalue:.3e}")
            except Exception as te:
                # Best effort: report the failure and leave the NaN defaults
                print(f"Paired t-test also failed: {te}")
else:
    print("Not enough valid paired data points for Speed Cell vs Chance comparison.")

# Single-cell decoding stats

In [None]:
# Row selectors are built from df and applied to results_df.
# NOTE(review): this relies on df and results_df sharing the same row labels.
speed_cell_mask = df['Locomotion_Cell'] == True
non_speed_cell_mask = df['Locomotion_Cell'] == False
fsi_mask = df['Neuron_Classification'] == -1  # fast-spiking interneurons
msn_mask = df['Neuron_Classification'] == 1   # medium spiny neurons

speed_cell_rows = results_df[speed_cell_mask]

# Mean R2 of speed cells, split by neuron class
fsi_data = results_df[fsi_mask & speed_cell_mask]['Mean_R2'].dropna().values
msn_data = results_df[msn_mask & speed_cell_mask]['Mean_R2'].dropna().values

# Mean R2 of non-speed cells vs speed cells, plus the speed cells' chance R2
other_data = results_df[non_speed_cell_mask]['Mean_R2'].dropna().values
speed_data = speed_cell_rows['Mean_R2'].dropna().values
speed_chance_data = speed_cell_rows['Chance_Mean_R2'].dropna().values

# Independent-samples t-tests
neuron_ttest = stats.ttest_ind(fsi_data, msn_data, nan_policy='omit')  # FSI vs MSN
loco_ttest = stats.ttest_ind(other_data, speed_data, nan_policy='omit')  # non-speed vs speed
# chance_comparison_stat / chance_comparison_pvalue come from the previous cell

In [None]:
# Comparison of Positive vs Negative CCG Cells

# Mean cross-validated R2 per cell, keyed by Cell_Index
mean_r2_by_cell = results_df.set_index('Cell_Index')['Mean_R2']

# Speed cells are split by the sign of CCG_Value (zero counts as positive)
is_speed_cell = df['Locomotion_Cell'] == True
ccg_values = df['CCG_Value']
pos_idx = df[is_speed_cell & (ccg_values >= 0)].index
neg_idx = df[is_speed_cell & (ccg_values < 0)].index

# Scores of each group; cells without a score are dropped
pos_r2 = mean_r2_by_cell.reindex(pos_idx).dropna().values
neg_r2 = mean_r2_by_cell.reindex(neg_idx).dropna().values

# Mann-Whitney U (nonparametric, two independent groups); the results stay
# NaN when either group is empty
pos_neg_stat = np.nan
pos_neg_p = np.nan
if len(pos_r2) == 0 or len(neg_r2) == 0:
    print("Positive vs Negative CCG comparison: not enough data (pos:", len(pos_r2), "neg:", len(neg_r2), ")")
else:
    pos_neg_stat, pos_neg_p = mannwhitneyu(pos_r2, neg_r2, alternative='two-sided')
    print(f"Positive vs Negative CCG Mann-Whitney U p-value: {pos_neg_p:.3e}, U={pos_neg_stat:.3f}")
    print(f"Positive CCG cells: {len(pos_r2)}, Negative CCG cells: {len(neg_r2)}")

# Multi-cell decoding pipeline

## Helper functions

In [None]:
def validate_combination(cell_indices, df):
    """
    Check that a group of cells can be stacked into one decoding problem.

    Two conditions must hold:
    1. every selected cell's 'Firing_Rate_Smoothed' vector has the same length;
    2. the 'Speed_Vector_Smoothed' vector of the first selected cell has that
       same length.
    A message is printed for the first condition that fails.

    Parameters
    ----------
    cell_indices : sequence
        Row labels of the cells in ``df``. Must support ``[0]`` indexing.
    df : pandas.DataFrame
        Frame with the columns 'Firing_Rate_Smoothed' and
        'Speed_Vector_Smoothed', each cell holding a sized vector.

    Returns
    -------
    bool
        True when both conditions hold, False otherwise (including when
        ``cell_indices`` is empty).

    Examples
    --------
    >>> cell_ids = [0, 1, 2]
    >>> if validate_combination(cell_ids, neural_data_df):
    ...     print("Data lengths are compatible")
    ... else:
    ...     print("Data lengths mismatch")
    """
    # Distinct firing-rate lengths across the selection; exactly one is valid
    distinct_lengths = {
        len(df.loc[cell_idx, 'Firing_Rate_Smoothed']) for cell_idx in cell_indices
    }
    if len(distinct_lengths) != 1:
        print(f"Error: Firing rate vectors have different lengths for cells {cell_indices}")
        return False

    (common_length,) = distinct_lengths

    # Only the first cell's speed vector is compared against that length
    if len(df.loc[cell_indices[0], 'Speed_Vector_Smoothed']) != common_length:
        print(f"Error: Speed vector has different length for cells {cell_indices}")
        return False

    return True

In [None]:
def compute_r2_scores(cell_indices, df, cv=kf):
    """
    Cross-validated R² for decoding speed from a group of cells, without
    time-lag features.

    Each cell contributes one feature column (its winsorized smoothed firing
    rate). The target is the winsorized smoothed speed vector of the first
    cell in ``cell_indices``, trimmed to the number of feature rows. Scoring
    uses the module-level ``model`` through ``cross_val_score``.

    Parameters
    ----------
    cell_indices : sequence
        Row labels of the cells in ``df``.
    df : pandas.DataFrame
        Frame with 'Firing_Rate_Smoothed' and 'Speed_Vector_Smoothed'.
    cv : cross-validation splitter, optional
        Passed to ``cross_val_score``; defaults to the module-level ``kf``.

    Returns
    -------
    numpy.ndarray
        One R² per fold, or ``array([nan])`` when there are fewer samples
        than folds or the scoring call raises.
    """
    def _winsorized_rate(cell_idx):
        rate = np.array(df.loc[cell_idx, 'Firing_Rate_Smoothed']).ravel()
        clipped = winsorize_data(rate, rate)[0]
        return np.asarray(clipped).ravel()

    # Feature matrix of shape (n_samples, n_cells)
    X_final = np.column_stack([_winsorized_rate(cell_idx) for cell_idx in cell_indices])

    # Target: the first cell's speed vector. Callers are expected to have
    # checked the lengths beforehand; the slice only trims a longer target.
    speed = np.array(df.loc[cell_indices[0], 'Speed_Vector_Smoothed']).ravel()
    speed_win = winsorize_data(speed, speed)[1]
    y_final = speed_win[:len(X_final)]

    # A splitter without get_n_splits is treated as needing one sample
    n_splits = cv.get_n_splits() if hasattr(cv, "get_n_splits") else 1
    if len(y_final) < n_splits or X_final.shape[0] < n_splits:
        return np.array([np.nan])

    try:
        return cross_val_score(model, X_final, y_final, cv=cv, scoring='r2', n_jobs=None)
    except Exception:
        # Any scoring failure is reported as "not computable"
        return np.array([np.nan])

## Compute the decoding results

In [None]:
# Cubic polynomial features followed by ridge regression
linear_pipeline = Pipeline([
    ('poly', PolynomialFeatures(degree=3)),
    ('regressor', Ridge(alpha=1.0))
])

# Median cross-validated R² of every valid combination, keyed by group size
r2_data = {n: [] for n in range(1, 6)}

# Attach each cell's single-cell mean R² to df (aligned on the row labels)
mean_r2_series = results_df.set_index('Cell_Index')['Mean_R2']
df['SingleCell_R2'] = mean_r2_series
session_groups = df.groupby('Rat_Session').groups

# Per session: speed cells whose single-cell R² is available and at least 0.1.
# Built once and reused for the progress-bar total and for the main loop.
eligible_by_session = {
    session: [
        idx for idx in cell_indices
        if df.loc[idx, 'Locomotion_Cell']
        and pd.notna(df.loc[idx, 'SingleCell_R2'])
        and df.loc[idx, 'SingleCell_R2'] >= 0.1
    ]
    for session, cell_indices in session_groups.items()
}

# Number of combinations of 1 to 5 cells (capped by the session's cell count)
total_combinations = sum(
    math.comb(len(filtered_speed), n_cells)
    for filtered_speed in eligible_by_session.values()
    for n_cells in range(1, min(6, len(filtered_speed) + 1))
)

with tqdm(total=total_combinations, desc="Computing R² scores") as pbar:
    for filtered_speed in eligible_by_session.values():
        for n_cells in range(1, min(6, len(filtered_speed) + 1)):
            for cell_combination in combinations(filtered_speed, n_cells):
                if validate_combination(cell_combination, df):
                    scores = compute_r2_scores(cell_combination, df, cv=kf)
                    # Keep NaN when no fold could be scored, else the median
                    r2_data[n_cells].append(
                        np.nan if np.all(np.isnan(scores)) else np.median(scores)
                    )
                pbar.update(1)

print("\nNumber of valid combinations by group size:")
print("-" * 40)
for n_cells, scores in r2_data.items():
    print(f"Number of combinations for {n_cells} cells: {len(scores)}")

## Multi-cell decoding surrogates

In [None]:
def compute_multi_cell_chance_r2(cell_indices, df, model, kf, n_shuffles=100):
    """
    Mean chance-level R² for a group of cells, estimated with circular shifts.

    The winsorized firing rates of all cells are stacked into one matrix and
    rolled along the time axis by a single random shift per surrogate, so the
    relationship between the cells is kept while their alignment with the
    speed vector is broken.

    Parameters
    ----------
    cell_indices : sequence
        Row labels of the cells to combine.
    df : pandas.DataFrame
        Frame with 'Firing_Rate_Smoothed' and 'Speed_Vector_Smoothed'.
    model : estimator
        Estimator passed to ``cross_val_score``.
    kf : cross-validation splitter
        Must provide ``get_n_splits()``; passed as ``cv``.
    n_shuffles : int, default=100
        Number of surrogate iterations.

    Returns
    -------
    float
        Mean over surrogates of the mean fold R², or np.nan when the
        combination is invalid, the data are too short, or every surrogate
        failed.
    """
    if not validate_combination(cell_indices, df):
        return np.nan

    def _winsorized_rate(cell_idx):
        rate = np.array(df.loc[cell_idx, 'Firing_Rate_Smoothed']).ravel()
        return np.asarray(winsorize_data(rate, rate)[0]).ravel()

    # Feature matrix of shape (n_samples, n_cells)
    X_orig = np.column_stack([_winsorized_rate(cell_idx) for cell_idx in cell_indices])

    # Target: the first cell's winsorized speed, trimmed to the feature rows
    speed = np.array(df.loc[cell_indices[0], 'Speed_Vector_Smoothed']).ravel()
    y_final = winsorize_data(speed, speed)[1][:len(X_orig)]

    data_len = len(X_orig)
    usable = (
        data_len > 0
        and len(y_final) == data_len
        and data_len >= kf.get_n_splits()
    )
    if not usable:
        return np.nan

    # Shifts are drawn from 20%-50% of the recording length
    min_shift = int(data_len * 0.20)
    max_shift = int(data_len * 0.50)
    if min_shift >= max_shift:
        # Short recordings: fall back to any non-zero shift
        min_shift = 1
        max_shift = max(1, data_len - 1)
        if min_shift >= max_shift and data_len <= 1:
            return np.nan

    surrogate_means = []
    for _ in range(n_shuffles):
        # One shift for the whole matrix, applied along the time axis
        shift = np.random.randint(min_shift, max_shift + 1)
        X_shifted = np.roll(X_orig, shift, axis=0)
        try:
            fold_scores = cross_val_score(model, X_shifted, y_final, cv=kf, scoring='r2')
            surrogate_means.append(np.mean(fold_scores))
        except Exception:
            # A failed surrogate is skipped
            continue

    if not surrogate_means:
        return np.nan
    return np.mean(surrogate_means)

In [None]:
# Compute surrogate R² for multi-cell combinations
print("Computing surrogate R² scores for multi-cell combinations...")

# Chance R² of every valid combination, keyed by group size
multi_cell_chance_r2 = {n: [] for n in range(1, 6)}
session_groups = df.groupby('Rat_Session').groups

# Per session: speed cells whose single-cell R² is available and at least 0.1.
# Built once and reused for the progress-bar total and for the main loop.
eligible_by_session = {
    session: [
        idx for idx in cell_indices
        if df.loc[idx, 'Locomotion_Cell']
        and pd.notna(df.loc[idx, 'SingleCell_R2'])
        and df.loc[idx, 'SingleCell_R2'] >= 0.1
    ]
    for session, cell_indices in session_groups.items()
}

# Calculate total combinations for progress bar
total_combinations_chance = sum(
    math.comb(len(filtered_speed), n_cells)
    for filtered_speed in eligible_by_session.values()
    for n_cells in range(1, min(6, len(filtered_speed) + 1))
)

with tqdm(total=total_combinations_chance, desc="Computing multi-cell surrogate R² scores") as pbar:
    for filtered_speed in eligible_by_session.values():
        for n_cells in range(1, min(6, len(filtered_speed) + 1)):
            for cell_combination in combinations(filtered_speed, n_cells):
                if validate_combination(cell_combination, df):
                    if n_cells == 1:
                        # For single cells, reuse the single-cell surrogate if available
                        cell_idx = cell_combination[0]
                        if cell_idx in chance_r2_results and pd.notna(chance_r2_results[cell_idx]):
                            chance_score = chance_r2_results[cell_idx]
                        else:
                            chance_score = compute_chance_r2(cell_idx, df, model, kf, n_shuffles=100)
                    else:
                        # For multi-cell combinations, compute a new surrogate
                        chance_score = compute_multi_cell_chance_r2(cell_combination, df, model, kf, n_shuffles=100)

                    multi_cell_chance_r2[n_cells].append(chance_score)
                pbar.update(1)

print("\nSurrogate R² computation completed!")
print("Number of surrogate combinations by group size:")
print("-" * 50)
for n_cells, scores in multi_cell_chance_r2.items():
    valid_scores = [s for s in scores if not np.isnan(s)]
    print(f"{n_cells} cells: {len(valid_scores)} valid surrogates out of {len(scores)} total")

# Statistical comparison between actual and surrogate multi-cell decoding
print("\nStatistical comparisons - Actual vs Surrogate:")
print("=" * 60)

for n_cells in range(1, 6):
    actual_scores = [s for s in r2_data[n_cells] if not np.isnan(s)]
    surrogate_scores = [s for s in multi_cell_chance_r2[n_cells] if not np.isnan(s)]

    if len(actual_scores) > 0 and len(surrogate_scores) > 0:
        # Mann-Whitney U test (independent samples)
        stat, p_val = mannwhitneyu(actual_scores, surrogate_scores, alternative='two-sided')

        # Cohen's d: mean difference divided by the POOLED within-group SD.
        # Dividing by the SD of the concatenated samples would mix the
        # between-group difference into the denominator and shrink d.
        actual_arr = np.asarray(actual_scores, dtype=float)
        surrogate_arr = np.asarray(surrogate_scores, dtype=float)
        pooled_dof = actual_arr.size + surrogate_arr.size - 2
        pooled_ss = (
            np.sum((actual_arr - actual_arr.mean()) ** 2)
            + np.sum((surrogate_arr - surrogate_arr.mean()) ** 2)
        )
        if pooled_dof > 0 and pooled_ss > 0:
            effect_size = (actual_arr.mean() - surrogate_arr.mean()) / np.sqrt(pooled_ss / pooled_dof)
        else:
            # Undefined with one sample per group or zero within-group variance
            effect_size = np.nan

        print(f"{n_cells} cells:")
        print(f"  Actual mean ± std: {np.mean(actual_scores):.4f} ± {np.std(actual_scores):.4f}")
        print(f"  Surrogate mean ± std: {np.mean(surrogate_scores):.4f} ± {np.std(surrogate_scores):.4f}")
        print(f"  Mann-Whitney U p-value: {p_val:.3e}")
        print(f"  Effect size (Cohen's d): {effect_size:.3f}")
        print(f"  Significant: {'Yes' if p_val < 0.05 else 'No'}")
        print()
    else:
        print(f"{n_cells} cells: Insufficient data for comparison")
        print()

# Representative cell decoding

In [None]:
# Cell and sample window used for the example trace
cell_idx = 74
start_idx = 7000
end_idx = 8000

# Raw (non-standardized) signals of the chosen cell, Gaussian-smoothed with
# sigma=10
original_firing = gaussian_filter1d(np.array(df.loc[cell_idx, 'Firing_Rate']), sigma=10)
original_speed = gaussian_filter1d(np.array(df.loc[cell_idx, 'Speed_Vector']), sigma=10)

# Cut out the window [start_idx, end_idx)
firing_segment = original_firing[start_idx:end_idx]
speed_segment = original_speed[start_idx:end_idx]

# Winsorize the window only
firing_win, speed_win = winsorize_data(firing_segment, speed_segment)

# Fit on the window and predict the same window (in-sample decoding)
firing_column = firing_win.reshape(-1, 1)
speed_decoded = linear_pipeline.fit(firing_column, speed_win).predict(firing_column)

# Plain numpy arrays for plotting and export
y_win = np.asarray(speed_win)
predictions = np.asarray(speed_decoded)


# Multi-cell decoding stats

In [None]:
# Single-cell mean R² split by the sign of CCG_Value (speed cells only).
# This cell recomputes the same quantities as the earlier
# "Comparison of Positive vs Negative CCG Cells" cell.
mean_r2_by_cell = results_df.set_index('Cell_Index')['Mean_R2']

# CCG_Value >= 0 -> positive group, CCG_Value < 0 -> negative group
speed_cell_rows = df['Locomotion_Cell'] == True
pos_idx = df[speed_cell_rows & (df['CCG_Value'] >= 0)].index
neg_idx = df[speed_cell_rows & (df['CCG_Value'] < 0)].index

pos_r2, neg_r2 = (
    mean_r2_by_cell.reindex(group_idx).dropna().values
    for group_idx in (pos_idx, neg_idx)
)

# Mann-Whitney U (nonparametric, two independent groups)
pos_neg_p = np.nan
pos_neg_stat = np.nan
both_groups_present = len(pos_r2) > 0 and len(neg_r2) > 0
if both_groups_present:
    pos_neg_stat, pos_neg_p = mannwhitneyu(pos_r2, neg_r2, alternative='two-sided')
    print(f"Positive vs Negative CCG Mann-Whitney U p-value: {pos_neg_p:.3e}, U={pos_neg_stat:.3f}")
    print(f"Positive CCG cells: {len(pos_r2)}, Negative CCG cells: {len(neg_r2)}")
else:
    print("Positive vs Negative CCG comparison: not enough data (pos:", len(pos_r2), "neg:", len(neg_r2), ")")


In [None]:
cell_numbers = np.arange(1, 6)
# Raw per-group scores (may contain NaN for combinations that could not be scored)
actual_box_data = [r2_data[i] for i in cell_numbers]

# Statistical comparisons between consecutive group sizes (n vs n+1 cells).
# NaN entries are removed first: mannwhitneyu would otherwise propagate them
# into a NaN p-value. An empty group yields NaN instead of raising, so
# actual_p_values always has one entry per comparison.
actual_p_values = []
for i in range(len(actual_box_data) - 1):
    smaller_group = [s for s in actual_box_data[i] if not np.isnan(s)]
    larger_group = [s for s in actual_box_data[i + 1] if not np.isnan(s)]
    if len(smaller_group) > 0 and len(larger_group) > 0:
        # Two-sided, stated explicitly to match the other tests in this file
        stat, p_val = mannwhitneyu(smaller_group, larger_group, alternative='two-sided')
    else:
        p_val = np.nan
    actual_p_values.append(p_val)
    print(f"Actual multi-cell {i+1} vs {i+2} cells p-value: {p_val:.3e}")

# Actual vs surrogate comparison for each group size
multi_cell_comparisons_p = []
for n_cells in cell_numbers:
    actual_scores = [s for s in r2_data[n_cells] if not np.isnan(s)]
    surrogate_scores = [s for s in multi_cell_chance_r2[n_cells] if not np.isnan(s)]

    if len(actual_scores) > 0 and len(surrogate_scores) > 0:
        stat, p_val = mannwhitneyu(actual_scores, surrogate_scores, alternative='two-sided')
        multi_cell_comparisons_p.append(p_val)
        print(f"Multi-cell {n_cells} cells - Actual vs Surrogate p-value: {p_val:.3e}")
    else:
        multi_cell_comparisons_p.append(np.nan)
        

# Export the data for figure 7

In [None]:
# Export all the necessary data for plotting
figure_data = {
    # Example decoding data (representative cell: actual vs decoded speed)
    'y_win': y_win,
    'predictions': predictions,

    # Cell type comparison data
    'fsi_data': fsi_data,
    'msn_data': msn_data,
    'neuron_ttest_pvalue': neuron_ttest.pvalue,

    # Locomotion cell comparison data
    'other_data': other_data,
    'speed_data': speed_data,
    'loco_ttest_pvalue': loco_ttest.pvalue,

    # Speed cells vs surrogate data
    # NOTE(review): 'speed_r2' holds every speed cell with a score, while
    # 'surrogate_data' holds only cells that also have a chance score; the two
    # arrays can differ in length — confirm before plotting them as pairs.
    'speed_r2': speed_data,
    'surrogate_data': speed_cell_chance_r2.values,
    'chance_comparison_pvalue': chance_comparison_pvalue,

    # Positive vs negative CCG data
    'pos_r2': pos_r2,
    'neg_r2': neg_r2,
    'pos_neg_p': pos_neg_p,

    # Multi-cell decoding data
    'r2_data': r2_data,
    'multi_cell_chance_r2': multi_cell_chance_r2,
    'actual_p_values': actual_p_values,
    'multi_cell_comparisons_p': multi_cell_comparisons_p,

    # Additional metadata (group sizes and x positions for paired boxes)
    'cell_numbers': cell_numbers.tolist(),
    'positions_actual': (np.arange(1, 11, 2) - 0.3).tolist(),
    'positions_surrogate': (np.arange(1, 11, 2) + 0.3).tolist()
}

# Save as pickle file for easy loading. The name must match the file the
# figure cell opens ('figure_data_final.pkl'); the previous name did not, so
# the figure cell either failed or read a stale file.
with open('figure_data_final.pkl', 'wb') as f:
    pickle.dump(figure_data, f)

# Figure (before processing in image editor)

In [None]:
# Load the data written by the export cell.
# NOTE: pickle can execute arbitrary code while loading; only open files
# produced by this notebook.
with open('figure_data_final.pkl', 'rb') as f:
    data = pickle.load(f)

# Extract all variables
# Example decoding trace
y_win = np.array(data['y_win'])
predictions = np.array(data['predictions'])
# Cell type comparison
fsi_data = np.array(data['fsi_data'])
msn_data = np.array(data['msn_data'])
neuron_ttest_pvalue = data['neuron_ttest_pvalue']
# Speed vs non-speed cells
other_data = np.array(data['other_data'])
speed_data = np.array(data['speed_data'])
loco_ttest_pvalue = data['loco_ttest_pvalue']
# Speed cells vs surrogate
speed_r2 = np.array(data['speed_r2'])
surrogate_data = np.array(data['surrogate_data'])
chance_comparison_pvalue = data['chance_comparison_pvalue']
# Positive vs negative CCG
pos_r2 = np.array(data['pos_r2'])
neg_r2 = np.array(data['neg_r2'])
pos_neg_p = data['pos_neg_p']
# Multi-cell decoding
r2_data = data['r2_data']
multi_cell_chance_r2 = data['multi_cell_chance_r2']
actual_p_values = data['actual_p_values']
multi_cell_comparisons_p = data['multi_cell_comparisons_p']
# Plot layout metadata
cell_numbers = np.array(data['cell_numbers'])
positions_actual = np.array(data['positions_actual'])
positions_surrogate = np.array(data['positions_surrogate'])

# Summary (printed once; the block was previously duplicated verbatim)
print("Data loaded successfully!")
print(f"FSI vs MSN p-value: {neuron_ttest_pvalue}")
print(f"Speed vs Other p-value: {loco_ttest_pvalue}")
print(f"Speed vs Surrogate p-value: {chance_comparison_pvalue}")

# Create figure with the new layout
fig = plt.figure(figsize=(12, 8))  # Reduced height

# Outer GridSpec has 2 rows: top and bottom with increased spacing
outer_gs = GridSpec(2, 1, height_ratios=[1, 1], figure=fig, hspace=0.3)  # Increased hspace for more whitespace


###############################################################################
# First row: 2 columns (75% left, 25% right)
###############################################################################
gs_top = GridSpecFromSubplotSpec(1, 2, subplot_spec=outer_gs[0], width_ratios=[3, 1])

# --- Speed Decoding (Row 1, Column 1) ---
# Panel A: actual (black) vs decoded (red) speed of the representative cell
ax_speed = fig.add_subplot(gs_top[0, 0])
ax_speed.plot(y_win, label='Actual', alpha=0.7, color='black', linewidth=2.5)
ax_speed.plot(predictions, label='Decoded', alpha=0.7, color='red', linewidth=2.5)
ax_speed.set_xlabel('Time (s)', fontname='Arial', fontsize=12)
ax_speed.set_ylabel('Speed (cm/s)', fontname='Arial', fontsize=12)
ax_speed.legend(frameon=False, prop={'family': 'Arial', 'size': 12})
ax_speed.spines['top'].set_visible(False)
ax_speed.spines['right'].set_visible(False)
ax_speed.set_ylim(-0, 50)
# Ticks are placed at sample positions 0..1000 and labelled 0..100.
# NOTE(review): this presumes 10 samples per second — TODO confirm.
xticks = np.linspace(0, 1000, 6)
xticklabels = [f'{x:g}' for x in np.linspace(0, 100, 6)]
ax_speed.set_xticks(xticks)
ax_speed.set_xticklabels(xticklabels, fontsize=12)
# Panel letter, positioned in axes coordinates
ax_speed.text(-0.1, 1.05, 'A', fontsize=14, fontweight='bold', fontname='Arial', transform=ax_speed.transAxes)

# --- Cell Type Analysis (Row 1, Column 2) ---
# Panel B: decoding accuracy of FSI vs MSN speed cells (unfilled boxes plus
# the individual points drawn next to each box)
ax_cell = fig.add_subplot(gs_top[0, 1])
bp_cell = ax_cell.boxplot([fsi_data, msn_data], showfliers=False, notch=False, patch_artist=True)
bp_cell['boxes'][0].set_facecolor('none')
bp_cell['boxes'][0].set_edgecolor('#FC4366')
bp_cell['boxes'][1].set_facecolor('none')
bp_cell['boxes'][1].set_edgecolor('#AEB2FF')
ax_cell.scatter([1.2] * len(fsi_data), fsi_data, alpha=0.5, color='#FC4366')
ax_cell.scatter([1.8] * len(msn_data), msn_data, alpha=0.5, color='#AEB2FF')
ax_cell.set_ylabel('Decoding Accuracy', fontname='Arial', fontsize=12)
ax_cell.set_xticks([1, 2])
ax_cell.set_xticklabels(['FSI', 'MSN'], fontname='Arial', fontsize=12)
ax_cell.spines['top'].set_visible(False)
ax_cell.spines['right'].set_visible(False)
for median in bp_cell['medians']:
    median.set_color('black')
# NOTE(review): the next two loops re-apply each artist's current colour, so
# they do not change the appearance.
for whisker in bp_cell['whiskers']:
    whisker.set_color(whisker.get_color())
for cap in bp_cell['caps']:
    cap.set_color(cap.get_color())
# Significance bracket drawn 0.02 above the largest data point of either group.
# NOTE(review): np.max raises on an empty array, so both groups must be non-empty.
y_max_cell = max(np.max(fsi_data), np.max(msn_data))
bar_height_cell = y_max_cell + 0.02
ax_cell.plot([1, 2], [bar_height_cell, bar_height_cell], '-k', linewidth=1)
ax_cell.plot([1, 1], [bar_height_cell - 0.01, bar_height_cell], '-k', linewidth=1)
ax_cell.plot([2, 2], [bar_height_cell - 0.01, bar_height_cell], '-k', linewidth=1)
# '*' when the FSI vs MSN t-test p-value is below 0.05, 'n.s.' otherwise
if neuron_ttest_pvalue < 0.05:
    ax_cell.text(1.5, bar_height_cell + 0.01, '*', ha='center', va='bottom', fontsize=14)
else:
    ax_cell.text(1.5, bar_height_cell + 0.01, 'n.s.', ha='center', va='bottom', fontsize=14)

# Panel letter, positioned in axes coordinates
ax_cell.text(-0.4, 1.05, 'B', fontsize=14, fontweight='bold', fontname='Arial', transform=ax_cell.transAxes)

###############################################################################
# Bottom row: four panels with width ratios 1:1:1:2
###############################################################################
gs_bottom = GridSpecFromSubplotSpec(1, 4, subplot_spec=outer_gs[1], width_ratios=[1, 1, 1, 2])

# --- Panel C: decoding accuracy of 'Other' vs 'Speed' cells (bottom row, 1st) ---
ax_loco = fig.add_subplot(gs_bottom[0, 0])
bp_loco = ax_loco.boxplot([other_data, speed_data], showfliers=False, notch=False, patch_artist=True)

# Hollow boxes outlined in the group colour, medians in black.
# Whiskers and caps are left at matplotlib's default colour.
loco_colors = ('gray', '#37a259')
for box, edge_color in zip(bp_loco['boxes'], loco_colors):
    box.set_facecolor('none')
    box.set_edgecolor(edge_color)
for median_line in bp_loco['medians']:
    median_line.set_color('black')

# Individual values as a column of dots on the inner side of each box
ax_loco.scatter([1.2] * len(other_data), other_data, alpha=0.5, color=loco_colors[0])
ax_loco.scatter([1.8] * len(speed_data), speed_data, alpha=0.5, color=loco_colors[1])

ax_loco.set_ylabel('Decoding Accuracy', fontname='Arial', fontsize=12)
ax_loco.set_xticks([1, 2])
ax_loco.set_xticklabels(['Other', 'Speed'], fontname='Arial', fontsize=12)
for side in ('top', 'right'):
    ax_loco.spines[side].set_visible(False)

# Unlike the neighbouring panels, the bracket is drawn only when the
# comparison is significant; a non-significant result leaves the panel bare.
if loco_ttest_pvalue < 0.05:
    y_max_loco = max(np.max(other_data), np.max(speed_data))
    bar_height_loco = y_max_loco + 0.02
    for seg_x, seg_y in (
        ([1, 2], [bar_height_loco, bar_height_loco]),
        ([1, 1], [bar_height_loco - 0.01, bar_height_loco]),
        ([2, 2], [bar_height_loco - 0.01, bar_height_loco]),
    ):
        ax_loco.plot(seg_x, seg_y, '-k', linewidth=1)
    ax_loco.text(1.5, bar_height_loco + 0.01, '*', ha='center', va='bottom', fontsize=12)

# Panel letter, positioned in axes coordinates just outside the top-left corner
ax_loco.text(-0.1, 1.05, 'C', transform=ax_loco.transAxes,
             fontsize=14, fontweight='bold', fontname='Arial')

# --- Panel D: decoding accuracy of speed cells vs surrogates (bottom row, 2nd) ---
ax_new = fig.add_subplot(gs_bottom[0, 1])
bp_new = ax_new.boxplot([speed_r2, surrogate_data], showfliers=False, notch=False, patch_artist=True)

# Hollow boxes outlined in the group colour, medians in black.
# Whiskers and caps are left at matplotlib's default colour.
surrogate_colors = ('#37a259', 'gray')
for box, edge_color in zip(bp_new['boxes'], surrogate_colors):
    box.set_facecolor('none')
    box.set_edgecolor(edge_color)
for median_line in bp_new['medians']:
    median_line.set_color('black')

# Individual values as a column of dots on the inner side of each box
ax_new.scatter([1.2] * len(speed_r2), speed_r2, alpha=0.5, color=surrogate_colors[0])
ax_new.scatter([1.8] * len(surrogate_data), surrogate_data, alpha=0.5, color=surrogate_colors[1])

ax_new.set_ylabel('Decoding Accuracy', fontname='Arial', fontsize=12)
ax_new.set_xticks([1, 2])
ax_new.set_xticklabels(['Speed Cells', 'Surrogate'], fontname='Arial', fontsize=12)
for side in ('top', 'right'):
    ax_new.spines[side].set_visible(False)

# Significance bracket 0.02 above the largest value: one horizontal bar plus
# two short end ticks, labelled '*' (p < 0.05) or 'n.s.'
y_max_new = max(np.max(speed_r2), np.max(surrogate_data))
bar_height_new = y_max_new + 0.02
for seg_x, seg_y in (
    ([1, 2], [bar_height_new, bar_height_new]),
    ([1, 1], [bar_height_new - 0.01, bar_height_new]),
    ([2, 2], [bar_height_new - 0.01, bar_height_new]),
):
    ax_new.plot(seg_x, seg_y, '-k', linewidth=1)
ax_new.text(1.5, bar_height_new + 0.01,
            '*' if chance_comparison_pvalue < 0.05 else 'n.s.',
            ha='center', va='bottom', fontsize=12)

# Panel letter, positioned in axes coordinates just outside the top-left corner
ax_new.text(-0.1, 1.05, 'D', transform=ax_new.transAxes,
            fontsize=14, fontweight='bold', fontname='Arial')

# --- Panel E: decoding accuracy for positive vs negative Optimal_Lag (bottom row, 3rd) ---
ax_pos_neg = fig.add_subplot(gs_bottom[0, 2])
if len(pos_r2) + len(neg_r2) > 0:
    bp_pn = ax_pos_neg.boxplot([pos_r2, neg_r2], showfliers=False, patch_artist=True)
    colors_pn = ['#2ca02c', '#d62728']  # green for positive, red for negative
    for patch, color in zip(bp_pn['boxes'], colors_pn):
        patch.set_facecolor('none')
        patch.set_edgecolor(color)

    # Individual values with a small Gaussian x-jitter around each box centre.
    # The jitter is drawn from NumPy's global RNG, so the exact dot positions
    # vary between runs unless that RNG was seeded beforehand.
    for box_center, values, color in zip((1, 2), (pos_r2, neg_r2), colors_pn):
        ax_pos_neg.scatter(np.random.normal(box_center, 0.04, size=len(values)), values,
                           color=color, alpha=0.6, s=12)

    ax_pos_neg.set_xticks([1, 2])
    ax_pos_neg.set_xticklabels(['Lag > 0', 'Lag < 0'], fontname='Arial', fontsize=12)
    # Arial is set explicitly so the label matches the other panels
    ax_pos_neg.set_ylabel('Decoding Accuracy', fontname='Arial', fontsize=12)
    ax_pos_neg.spines['top'].set_visible(False)
    ax_pos_neg.spines['right'].set_visible(False)

    # Significance bracket 0.02 above the largest non-NaN value. At least one
    # group is non-empty inside this branch, so the pooled array is never empty.
    combined = np.concatenate([pos_r2, neg_r2])
    y_max_pn = np.nanmax(combined)
    bar_h = y_max_pn + 0.02
    for seg_x, seg_y in (
        ([1, 2], [bar_h, bar_h]),
        ([1, 1], [bar_h - 0.01, bar_h]),
        ([2, 2], [bar_h - 0.01, bar_h]),
    ):
        ax_pos_neg.plot(seg_x, seg_y, '-k', linewidth=1)
    # The bracket is always drawn; its label is skipped when the p-value is NaN
    if not np.isnan(pos_neg_p):
        sig_text = '*' if pos_neg_p < 0.05 else 'n.s.'
        ax_pos_neg.text(1.5, bar_h + 0.01, sig_text, ha='center', va='bottom', fontsize=12)
else:
    # Axes coordinates keep the message centred whatever the data limits are
    ax_pos_neg.text(0.5, 0.5, 'No data for positive vs negative lag comparison',
                    ha='center', va='center', fontsize=10, transform=ax_pos_neg.transAxes)

# Panel letter, positioned in axes coordinates just outside the top-left corner
ax_pos_neg.text(-0.1, 1.05, 'E', fontsize=14, fontweight='bold', fontname='Arial', transform=ax_pos_neg.transAxes)

# --- Panel F: decoding accuracy vs number of cells, actual data beside surrogates (bottom row, 4th) ---
ax_multi = fig.add_subplot(gs_bottom[0, 3])

# One group of values per entry of cell_numbers, for each condition
actual_box_data = [r2_data[n_cells] for n_cells in cell_numbers]
surrogate_box_data = [multi_cell_chance_r2[n_cells] for n_cells in cell_numbers]

# Actual data: hollow green boxes with green whiskers and caps
bp_actual = ax_multi.boxplot(
    actual_box_data,
    positions=positions_actual,
    widths=0.4,
    showfliers=False,
    patch_artist=True,
    boxprops={'facecolor': 'none', 'edgecolor': '#37a259'},
    whiskerprops={'color': '#37a259'},
    capprops={'color': '#37a259'},
)

# Surrogate data: same geometry in gray
bp_surrogate = ax_multi.boxplot(
    surrogate_box_data,
    positions=positions_surrogate,
    widths=0.4,
    showfliers=False,
    patch_artist=True,
    boxprops={'facecolor': 'none', 'edgecolor': 'gray'},
    whiskerprops={'color': 'gray'},
    capprops={'color': 'gray'},
)

# Median lines of both series in black
for median_line in bp_actual['medians'] + bp_surrogate['medians']:
    median_line.set_color('black')

# Trend of the actual-data medians: a connecting line plus dots, shifted
# 0.45 to the left of the boxes so the markers do not overlap them
median_offset = -0.45
medians_actual = [median_line.get_ydata()[0] for median_line in bp_actual['medians']]
trend_x = positions_actual + median_offset
ax_multi.plot(trend_x, medians_actual, color='#37a259', alpha=0.7, linewidth=2)
ax_multi.scatter(trend_x, medians_actual, color='#37a259', zorder=5, s=30)

# Brackets between consecutive actual-data groups: entry k of actual_p_values
# annotates groups k and k + 1, 0.15 above the larger of their maxima
for pair_idx, pair_p in enumerate(actual_p_values):
    left_x = positions_actual[pair_idx]
    right_x = positions_actual[pair_idx + 1]
    pair_top = max(np.max(actual_box_data[pair_idx]), np.max(actual_box_data[pair_idx + 1]))
    bracket_y = pair_top + 0.15

    for seg_x, seg_y in (
        ([left_x, right_x], [bracket_y, bracket_y]),
        ([left_x, left_x], [bracket_y - 0.01, bracket_y]),
        ([right_x, right_x], [bracket_y - 0.01, bracket_y]),
    ):
        ax_multi.plot(seg_x, seg_y, '-k', linewidth=0.8)

    ax_multi.text((left_x + right_x) / 2, bracket_y + 0.01,
                  '*' if pair_p < 0.05 else 'n.s.',
                  ha='center', va='bottom', fontsize=10)

# Brackets between each actual group and its surrogate, 0.1 above the larger
# maximum (lower than the actual-vs-actual brackets). Comparisons with a NaN
# p-value or an empty group are skipped.
for group_idx, (actual_x, surrogate_x, group_p) in enumerate(
        zip(positions_actual, positions_surrogate, multi_cell_comparisons_p)):
    if np.isnan(group_p):
        continue

    actual_values = actual_box_data[group_idx]
    surrogate_values = surrogate_box_data[group_idx]
    if not (len(actual_values) > 0 and len(surrogate_values) > 0):
        continue

    bracket_y = max(np.max(actual_values), np.max(surrogate_values)) + 0.1
    for seg_x, seg_y in (
        ([actual_x, surrogate_x], [bracket_y, bracket_y]),
        ([actual_x, actual_x], [bracket_y - 0.01, bracket_y]),
        ([surrogate_x, surrogate_x], [bracket_y - 0.01, bracket_y]),
    ):
        ax_multi.plot(seg_x, seg_y, '-k', linewidth=0.8)

    # '*' when p < 0.05, otherwise 'n.s.'
    ax_multi.text((actual_x + surrogate_x) / 2, bracket_y + 0.01,
                  '*' if group_p < 0.05 else 'n.s.',
                  ha='center', va='bottom', fontsize=10)

# Legend built from proxy patches, anchored at the right edge below mid-height
legend_elements = [
    Patch(facecolor='none', edgecolor='#37a259', label='Actual'),
    Patch(facecolor='none', edgecolor='gray', label='Surrogate'),
]
ax_multi.legend(handles=legend_elements, loc='center right', bbox_to_anchor=(1, 0.4),
                frameon=False, fontsize=10)

ax_multi.set_xlabel('Number of Cells', fontname='Arial', fontsize=12)
ax_multi.set_ylabel('Decoding Accuracy', fontname='Arial', fontsize=12)
# Five ticks at x = 1, 3, 5, 7, 9, labelled with the entries of cell_numbers
# NOTE(review): presumes cell_numbers holds exactly five values - confirm
ax_multi.set_xticks(np.arange(1, 11, 2))
ax_multi.set_xticklabels(cell_numbers, fontsize=12)
for side in ('top', 'right'):
    ax_multi.spines[side].set_visible(False)
ax_multi.set_xlim(0, 10)

# Panel letter, positioned in axes coordinates just outside the top-left corner
ax_multi.text(-0.1, 1.05, 'F', transform=ax_multi.transAxes,
              fontsize=14, fontweight='bold', fontname='Arial')

# Finalise the layout, then write the figure to disk and display it
plt.tight_layout()
# NOTE(review): outer_gs was created with an explicit hspace=0.3; confirm that
# this figure-level hspace=0.1 has the intended effect on the row spacing
plt.subplots_adjust(hspace=0.1)

figure_path = 'figure7.png'
plt.savefig(figure_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"Figure saved as '{figure_path}'")