In [2]:
#Imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from mpl_toolkits.mplot3d import Axes3D
from tqdm import tqdm
import time
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor
from threading import Lock
from collections import Counter
import networkx as nx
from scipy.spatial.distance import pdist, squareform
from scipy.spatial import ConvexHull
import ast
from itertools import chain

#Skeleton
import pcg_skel

# CloudVolume and Cave setup
from cloudvolume import CloudVolume
from caveclient import CAVEclient
sv = CloudVolume('graphene://https://minnie.microns-daf.com/segmentation/table/zheng_ca3', use_https=True, lru_bytes=int(1e8))
client = CAVEclient('zheng_ca3')
auth = client.auth

import warnings

warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)

In [3]:
def update_segids_df(df, super_voxel_col):
    # Get the current date
    current_date = datetime.now().strftime('%Y%m%d')  # Format: YYYYMMDD
    
    # Update segids
    updated_segid_list = client.chunkedgraph.get_roots(df[super_voxel_col])
    
    # Add the updated segids to the DataFrame with the date in the column name
    updated_col_name = f"updated_segids_{current_date}"
    df[updated_col_name] = updated_segid_list
    
    print(f"Number of updated segids: {len(df)}")
    return df


def compute_xyz_metrics(data):
    """
    Compute pairwise distance variance and mean for a single or nested list of 3D points.

    Parameters:
        data (list): A list or nested list of 3D point lists.
        
    Returns:
        tuple: A tuple containing the variance and mean of pairwise distances.
    """
    if not data or not isinstance(data, list):
        var = 0
        dis = 0
        return var, dis

    # Flatten the nested list if necessary
    if isinstance(data[0], list) and isinstance(data[0][0], list):  # Nested list
        flattened_data = [point for sublist in data for point in sublist]
    elif isinstance(data[0], list):  # Single list of points
        flattened_data = data
    else:
        raise ValueError("Input data must contain 3D points as lists.")

    # Check if the points have 3 coordinates
    for point in flattened_data:
        if not isinstance(point, list) or len(point) != 3:
            raise ValueError("All points must be lists of three numeric coordinates.")

    # Convert to numpy array
    points_array = np.array(flattened_data)

    # Compute pairwise distances
    distances = pdist(points_array, metric='euclidean')
    var = np.var(distances)
    dis = np.mean(distances)

    return var, dis


def bouton_detector(dataframe, synapse_threshold, voxel_threshold):
    # List to store results for creating the DataFrame
    results = []
    
    # Convert 'pre_pt_position' and 'pre_pt_root_id' to NumPy arrays for faster operations
    positions = np.array(list(dataframe['pre_pt_position']))
    pre_pt_root_ids = dataframe['pre_pt_root_id'].values
    post_pt_root_ids = dataframe['post_pt_root_id'].values
    
    # Iterate over each unique 'pre_pt_root_id'
    unique_ids = np.unique(pre_pt_root_ids)
    
    for unique_id in unique_ids:
        # Filter the dataframe for the current 'pre_pt_root_id'
        temp_df = dataframe[dataframe['pre_pt_root_id'] == unique_id]
        
        # Extract the positions and corresponding post IDs
        temp_positions = np.array(list(temp_df['pre_pt_position']))
        temp_post_ids = temp_df['post_pt_root_id'].values
        
        # Initialize a bouton counter and lists for bouton positions and bouton partners
        bouton_counter = 0
        bouton_positions = []
        bouton_post_partners = []
        visited = np.zeros(len(temp_positions), dtype=bool)
        
        for i in range(len(temp_positions)):
            if visited[i]:
                continue
            
            # Calculate distances from the current position to all others
            distances = np.sqrt(np.sum((temp_positions - temp_positions[i]) ** 2, axis=1))
            
            # Find indices of rows within the voxel threshold
            close_points = np.where((distances <= voxel_threshold) & (~visited))[0]
            
            # If 4 or more rows are within the voxel threshold, count it as a bouton
            if len(close_points) >= synapse_threshold:
                bouton_counter += 1
                visited[close_points] = True
                
                # Save the positions and unique post IDs of this bouton group
                bouton_positions.append(temp_positions[close_points].tolist())
                bouton_post_partners.append(list(set(temp_post_ids[close_points])))
        
        # Combine groups that are within 1000 of each other
        combined_positions = []
        combined_post_partners = []
        
        while bouton_positions:
            group = bouton_positions.pop(0)
            group_post_ids = bouton_post_partners.pop(0)
            group = np.array(group)
            to_merge = []
            
            for idx, other_group in enumerate(bouton_positions):
                other_group = np.array(other_group)
                other_post_ids = bouton_post_partners[idx]
                # Check if the current group is within 1000 of another group
                distances = np.sqrt(np.sum((group[:, None] - other_group[None, :]) ** 2, axis=2))
                #if np.any(distances <= 1000):
                if np.any(distances <= 300):
                    to_merge.append((other_group, other_post_ids))
            
            # Merge all nearby groups into the current group
            for merge_group, merge_post_ids in to_merge:
                group = np.vstack((group, merge_group))
                group_post_ids = list(set(group_post_ids + merge_post_ids))
                bouton_positions.remove(merge_group.tolist())
                bouton_post_partners.remove(merge_post_ids)
            
            combined_positions.append(group.tolist())
            combined_post_partners.append(group_post_ids)
        
        # Update bouton count after merging
        bouton_counter = len(combined_positions)
        
        # Find bouton volume (area contianed by a single bouton - axon will be massive)
        bouton_volume = find_volume(combined_positions)

        # Find the variance to eliminate axons, they are evenly disributed. 
        synapse_variance, synapse_mean_distance = compute_xyz_metrics(combined_positions)
        
        # Append results to the list
        for i in range(bouton_counter):
            results.append({
                "pre_pt_root_id": unique_id,
                "bouton_id": i + 1,
                "bouton_positions": combined_positions[i],
                "bouton_partners": combined_post_partners[i],
                "total_boutons": bouton_counter,  # Added total bouton count
                "bouton_volume": bouton_volume,   # Added bouton volume
                "synapse_variance": synapse_variance, # Added synapse variance
                "synapse_mean_distance": synapse_mean_distance
            })
    
    # Create a DataFrame from the results
    boutons_df = pd.DataFrame(results) 

    return boutons_df


def processor_guts(idx, presyn_segids_chunk, results_list, counter_error, start_time, second_time=False):
    global api_call_counter, api_call_start_time  # Use global variables for API call tracking
    results = []

    # Check API rate limit before making the synapse query
    manage_api_rate_limit()

    try:
        # Query synapse data for the current chunk of 'presyn_segids'
        df_synapse = client.materialize.synapse_query(
            pre_ids=presyn_segids_chunk,  # Pass the chunk of presyn_segids
            post_ids=None,              
            synapse_table="synapses_ca3_v1",
            desired_resolution=[18, 18, 45]
        )
        api_call_counter += 1  # Increment API call counter

    except Exception as e:
        print(f"Error querying synapse for chunk {presyn_segids_chunk}: {e}")
        print("Retrying after 20 seconds...")
        counter_error += 1
        print(f"Number of Errors: {counter_error}")
        time.sleep(20)
        manage_api_rate_limit()  # Ensure rate limit before retrying
        try:
            df_synapse = client.materialize.synapse_query(
                pre_ids=presyn_segids_chunk, 
                post_ids=None,              
                synapse_table="synapses_ca3_v1",
                desired_resolution=[18, 18, 45]
            )
            api_call_counter += 1  # Increment API call counter

        except Exception as e2:
            print(f"Second attempt failed for chunk {presyn_segids_chunk}: {e2}")
            print("Skipping this chunk.")
            return results_list, counter_error, start_time, idx  # Return early if failed

    # Process the DataFrame to find boutons
    try:
        if not df_synapse.empty:
            unique_presyn_ids = df_synapse["pre_pt_root_id"].unique()
            for presyn_id in unique_presyn_ids:
                temp_df = df_synapse[df_synapse["pre_pt_root_id"] == presyn_id]
                postsyn_list = temp_df["post_pt_root_id"].unique()
                postsyn_num = len(postsyn_list)

                if second_time:
                    bouton_result = bouton_detector_2(temp_df, 8, 500)
                    #bouton_result = bouton_detector_2(temp_df, 10, 1000)

                else:
                    bouton_result = bouton_detector(temp_df, 8, 500)
                    #bouton_result = bouton_detector(temp_df, 10, 1000)


                # Iterate over the bouton results and save them
                for _, row in bouton_result.iterrows():
                    results.append({
                        "presyn_segid": presyn_id,
                        "postsyn_num": postsyn_num,
                        "postsyn_id_list": postsyn_list.tolist(),
                        "pre_pt_root_id": row['pre_pt_root_id'],
                        "bouton_id": row['bouton_id'],
                        "bouton_positions": row['bouton_positions'],
                        "bouton_partners": row['bouton_partners'],
                        "bouton_volume": row['bouton_volume'],
                        "synapse_variance": row['synapse_variance'],
                        "synapse_mean_distance": row['synapse_mean_distance']
                    })
        else:
            print("No synapses in dataframe, dataframe empty.")
    except Exception as bouton_error:
        print(f"Error processing boutons for chunk {presyn_segids_chunk}: {bouton_error}")

    # Append results to the results list
    results_list.extend(results)
    return results_list, counter_error, start_time, idx


def bouton_processor(df, limit_TF=False, limit=5, chunk_size=300):
    # List to store results
    results_list = []
    start_time = time.time()
    counter_error = 0

    # Process each chunk of 'updated_segids_20241227' column with tqdm progress bar
    current_date = datetime.now().strftime('%Y%m%d')  # Format: YYYYMMDD
    updated_col_name = f"updated_segids_{current_date}" 
    segids = df[updated_col_name].tolist()
    
    # Create chunks of presyn_segids
    chunks = [segids[i:i + chunk_size] for i in range(0, len(segids), chunk_size)]
    
    for idx, presyn_segids_chunk in enumerate(tqdm(chunks, desc="Processing Mossy Fiber Synapses")):
        # Limit iterations for testing purposes
        if limit_TF and idx > limit:
            print(f"Maximum index of {limit} reached. Stopping processing.")
            break

        # Process the current chunk of presyn_segids
        results_list, counter_error, start_time, error_prevent_count = processor_guts(idx, presyn_segids_chunk, \
                                                results_list, counter_error, start_time, second_time=False)

    # Combine all results into a single DataFrame if there are any results
    if results_list:
        MF_POSTSYN_DF = pd.DataFrame(results_list)
    else:
        MF_POSTSYN_DF = pd.DataFrame()
        
    # Count up the number of boutons
    MF_POSTSYN_DF['total_boutons'] = MF_POSTSYN_DF.groupby('presyn_segid')['bouton_id'].transform('count')

    return MF_POSTSYN_DF


def find_bouton_partners(POSTSYN_DATA):  
    # Create a new DataFrame with only one row per unique presyn_segid
    unique_presyn_df = POSTSYN_DATA.drop_duplicates(subset='presyn_segid', keep='first')
    # Initialize a dictionary to store all bouton partners and their corresponding presyn_segids
    all_bouton_to_presyn = {}

    # Iterate through the DataFrame
    for index, row in unique_presyn_df.iterrows():
        if isinstance(row['bouton_partners'], str):
            try:
                bouton_list = eval(row['bouton_partners'])  # Convert string to list
            except:
                continue
        elif isinstance(row['bouton_partners'], list):
            bouton_list = row['bouton_partners']
        else:
            continue

        # Map each bouton partner to its corresponding presyn_segid
        for bouton in bouton_list:
            if bouton not in all_bouton_to_presyn:
                all_bouton_to_presyn[bouton] = []
            all_bouton_to_presyn[bouton].append(row['presyn_segid'])

    # Remove duplicate presyn_segids for each bouton_partner
    all_bouton_to_presyn = {key: list(set(value)) for key, value in all_bouton_to_presyn.items()}

    # Convert the dictionary to a DataFrame
    all_bouton_df = pd.DataFrame({
        'bouton_partner': list(all_bouton_to_presyn.keys()),
        'presyn_segids': list(all_bouton_to_presyn.values())
    })

    # Add a column for the number of unique presyn_segids
    all_bouton_df['num_presyn_segids'] = all_bouton_df['presyn_segids'].apply(len)
    
    # Sort the DataFrame by 'num_presyn_segids' in descending order
    all_bouton_df_sorted = all_bouton_df.sort_values(by='num_presyn_segids', ascending=False)

    # Display the resulting DataFrame
    return all_bouton_df_sorted

In [4]:
# Global variable to track API calls within a minute
api_call_counter = 0
api_call_start_time = time.time()

def nonMossy_partner_analysis(df, chunk_size=100):
    global api_call_counter, api_call_start_time

    # Reset index and initialize columns
    df.reset_index(inplace=True, drop=True)
    df["Potential_Mossy_Partners"] = None  # Initialize the column with None

    # Initialize additional columns for bouton attributes
    bouton_columns = [
        "pre_pt_root_id", "bouton_id", "bouton_positions", 
        "bouton_partners", "bouton_volume", "synapse_variance", 
        "synapse_mean_distance"
    ]
    for col in bouton_columns:
        df[col] = None

    error_prevent_count = 0

    for idx, bouton_partner in enumerate(tqdm(df['bouton_partner'], desc="Processing Mossy Fiber Synapses")):
        while True:  # Retry logic
            try:
                # Handle rate limiting
                manage_api_rate_limit()

                # Query the synapse data for the current bouton partner
                df_synapse_temp = client.materialize.synapse_query(
                    pre_ids=None,
                    post_ids=[bouton_partner],
                    synapse_table="synapses_ca3_v1",
                    desired_resolution=[18, 18, 45]
                )
                api_call_counter += 1

                if not df_synapse_temp.empty:
                    postsyn_list = df_synapse_temp["post_pt_root_id"].unique()
                    postsyn_num = len(postsyn_list)

                    # Filter and deduplicate results
                    filtered_df = df_synapse_temp[
                        df_synapse_temp['pre_pt_root_id'].map(df_synapse_temp['pre_pt_root_id'].value_counts()) > 5
                    ]
                    filtered_df = filtered_df.drop_duplicates(subset='pre_pt_root_id')

                    # Run the bouton detector function with chunking
                    bouton_result, error_prevent_count, _ = bouton_processor_2(
                        filtered_df, chunk_size=chunk_size, idx=idx, 
                        error_prevent_count=error_prevent_count, start_time=api_call_start_time
                    )

                    # If bouton results are found, assign them to the respective columns
                    if not bouton_result.empty:
                        for col in bouton_columns:
                            df.at[idx, col] = bouton_result[col].tolist()

                    # Save Potential Mossy Partners
                    df.at[idx, 'Potential_Mossy_Partners'] = filtered_df['pre_pt_root_id'].tolist()

                break  # Exit the retry loop if successful

            except Exception as e:
                print(f"Error processing bouton partner {bouton_partner}: {e}")
                print("Retrying after 20 seconds...")
                time.sleep(20)

    # Remove rows where 'Potential_Mossy_Partners' is empty
    df = df[df['Potential_Mossy_Partners'].apply(lambda x: x is not None and len(x) > 0)]
    
    return df


def bouton_processor_2(df, chunk_size=300, limit_TF=False, limit=5, idx=0, error_prevent_count=0, start_time=0):
    global api_call_counter, api_call_start_time

    # List to store results
    results_list = []
    counter_error = 0

    # Get unique presyn_segids and create chunks
    presyn_segids = df['pre_pt_root_id'].unique()
    chunks = [presyn_segids[i:i + chunk_size] for i in range(0, len(presyn_segids), chunk_size)]

    for idx_2, presyn_segids_chunk in enumerate(chunks):
        while True:  # Retry logic
            try:
                # Handle rate limiting
                manage_api_rate_limit()

                # Process the current chunk of presyn_segids
                results_list, counter_error, start_time, error_prevent_count = processor_guts(
                    error_prevent_count, presyn_segids_chunk, results_list, 
                    counter_error, start_time, second_time=True
                )
                api_call_counter += 1

                break  # Exit the retry loop if successful

            except Exception as e:
                print(f"Error processing chunk {presyn_segids_chunk}: {e}")
                print("Retrying after 20 seconds...")
                time.sleep(20)

    # Combine all results into a single DataFrame if there are any results
    if results_list:
        MF_POSTSYN_DF = pd.DataFrame(results_list)
    else:
        MF_POSTSYN_DF = pd.DataFrame()

    return MF_POSTSYN_DF, error_prevent_count, start_time


def manage_api_rate_limit():
    """Ensures that the number of API calls does not exceed the limit."""
    global api_call_counter, api_call_start_time

    elapsed_time = time.time() - api_call_start_time
    if elapsed_time < 60 and api_call_counter >= 300:
        time_to_sleep = 60 - elapsed_time
        print(f"Rate limit reached. Sleeping for {time_to_sleep:.2f} seconds...")
        time.sleep(time_to_sleep)
        api_call_start_time = time.time()  # Reset the timer
        api_call_counter = 0
    elif elapsed_time >= 60:
        # Reset counter and timer after a minute
        api_call_start_time = time.time()
        api_call_counter = 0


def bouton_detector_2(dataframe, synapse_threshold, voxel_threshold):
    # Dictionary to store results, keyed by `pre_pt_root_id`
    results_by_bouton = {}

    # Convert 'pre_pt_position' and 'pre_pt_root_id' to NumPy arrays for faster operations
    positions = np.array(list(dataframe['pre_pt_position']))
    pre_pt_root_ids = dataframe['pre_pt_root_id'].values
    post_pt_root_ids = dataframe['post_pt_root_id'].values

    # Iterate over each unique 'pre_pt_root_id'
    unique_ids = np.unique(pre_pt_root_ids)

    for unique_id in unique_ids:
        # Filter the dataframe for the current 'pre_pt_root_id'
        temp_df = dataframe[dataframe['pre_pt_root_id'] == unique_id]

        # Extract the positions and corresponding post IDs
        temp_positions = np.array(list(temp_df['pre_pt_position']))
        temp_post_ids = temp_df['post_pt_root_id'].values

        # Initialize a bouton counter and lists for bouton positions and bouton partners
        bouton_counter = 0
        bouton_positions = []
        bouton_post_partners = []
        visited = np.zeros(len(temp_positions), dtype=bool)

        for i in range(len(temp_positions)):
            if visited[i]:
                continue

            # Calculate distances from the current position to all others
            distances = np.sqrt(np.sum((temp_positions - temp_positions[i]) ** 2, axis=1))

            # Find indices of rows within the voxel threshold
            close_points = np.where((distances <= voxel_threshold) & (~visited))[0]

            # If 4 or more rows are within the voxel threshold, count it as a bouton
            if len(close_points) >= synapse_threshold:
                bouton_counter += 1
                visited[close_points] = True

                # Save the positions and unique post IDs of this bouton group
                bouton_positions.append(temp_positions[close_points].tolist())
                bouton_post_partners.append(list(set(temp_post_ids[close_points])))

        # Combine groups that are within 1000 of each other
        combined_positions = []
        combined_post_partners = []

        while bouton_positions:
            group = bouton_positions.pop(0)
            group_post_ids = bouton_post_partners.pop(0)
            group = np.array(group)
            to_merge = []

            for idx, other_group in enumerate(bouton_positions):
                other_group = np.array(other_group)
                other_post_ids = bouton_post_partners[idx]
                # Check if the current group is within 1000 of another group
                distances = np.sqrt(np.sum((group[:, None] - other_group[None, :]) ** 2, axis=2))
                if np.any(distances <= 300):
                    to_merge.append((other_group, other_post_ids))

            # Merge all nearby groups into the current group
            for merge_group, merge_post_ids in to_merge:
                group = np.vstack((group, merge_group))
                group_post_ids = list(set(group_post_ids + merge_post_ids))
                bouton_positions.remove(merge_group.tolist())
                bouton_post_partners.remove(merge_post_ids)

            combined_positions.append(group.tolist())
            combined_post_partners.append(group_post_ids)

        # Update bouton count after merging
        bouton_counter = len(combined_positions)

        # Initialize the results list for this `pre_pt_root_id`
        results_by_bouton[unique_id] = []

        # Create separate lists for each unique bouton's results
        for i in range(bouton_counter):
            synapse_variance, synapse_mean_distance = compute_xyz_metrics(combined_positions[i])
            volume_ = find_volume(combined_positions[i])

            results_by_bouton[unique_id].append({
                "pre_pt_root_id": unique_id,
                "bouton_id": i + 1,
                "bouton_positions": combined_positions[i],
                "bouton_partners": combined_post_partners[i],
                "total_boutons": bouton_counter,  # Added total bouton count
                "bouton_volume": volume_,   # Added bouton volume
                "synapse_variance": synapse_variance, # Added synapse variance
                "synapse_mean_distance": synapse_mean_distance
            })

    # Flatten results into a single list of dictionaries for DataFrame creation
    flattened_results = [item for bouton_list in results_by_bouton.values() for item in bouton_list]
    boutons_df = pd.DataFrame(flattened_results)

    return boutons_df


def find_volume(bouton_positions):
    # Ensure bouton_positions is a NumPy array
    bouton_positions = np.array(bouton_positions)
    
    # Check if bouton_positions has enough points for a ConvexHull
    if bouton_positions.shape[0] < 4:
        #print("Not enough points for ConvexHull. Returning volume=0.")
        return 0
    
    try:
        hull = ConvexHull(bouton_positions)
        volume = hull.volume
    except Exception as e:
        #print(f"ConvexHull failed with error: {e}. Returning volume=0.")
        volume = 0
    
    return volume


def split_lists_by_bouton_id(df):
    """
    Splits the values in the specified columns of the DataFrame based on occurrences of "1" in the "bouton_id" column,
    including preceding non-1 values and creating a new sublist for each "1".

    Args:
        df (pd.DataFrame): The DataFrame containing columns with values to be split.

    Returns:
        pd.DataFrame: The transformed DataFrame with nested lists for specified columns.
    """
    # Remove all none type columns
    df = df.dropna(subset=['bouton_id'])
    
    # Columns to transform
    columns_to_split = ["bouton_id", "bouton_positions", "bouton_partners", "bouton_volume", "synapse_variance", \
                        "synapse_mean_distance", "pre_pt_root_id"]

    def split_row(row, columns_to_split):
        # Lists to store the split values
        new_column_values = {col: [] for col in columns_to_split}
        
        # Temporary storage for the current sublist
        temp_values = {col: [] for col in columns_to_split}

        for i, val in enumerate(row["bouton_id"]):
            if val == 1:
                # Add the current sublist before the 1 (if any)
                if temp_values["bouton_id"]:  # Only add if temp list is not empty
                    for col in columns_to_split:
                        new_column_values[col].append(temp_values[col])
                
                # Start a new sublist for the 1 and include it
                for col in columns_to_split:
                    temp_values[col] = [row[col][i]]

            else:
                # Add the current value to the ongoing sublist
                for col in columns_to_split:
                    temp_values[col].append(row[col][i])

        # Add any remaining values in temp_values to the final lists
        if temp_values["bouton_id"]:
            for col in columns_to_split:
                new_column_values[col].append(temp_values[col])

        # Update the row for the specified columns
        for col in columns_to_split:
            row[col] = new_column_values[col]

        return row

    # Apply the splitting function to each row in the DataFrame
    df = df.apply(lambda row: split_row(row, columns_to_split), axis=1)

    return df


def filter_bouton_data(df, columns_to_filter, threshold_upper=100000000, threshold_lower=900):
    df = df.copy()
    for index in df.index:
        bouton_volumes = df.at[index, 'bouton_volume']
        flat_bouton_volumes = [
            item for sublist in bouton_volumes 
            for item in (sublist if isinstance(sublist, list) else [sublist])
        ]
        flat_bouton_volumes = [
            v for v in flat_bouton_volumes if isinstance(v, (int, float)) and not isinstance(v, bool)
        ]
        if not flat_bouton_volumes or all(v == 0 for v in flat_bouton_volumes):
            continue
        if any(v > threshold_upper for v in flat_bouton_volumes) or all(v < threshold_lower for v in flat_bouton_volumes):
            for col in columns_to_filter:
                if col in df.columns:
                    column_values = df.at[index, col]
                    #if len(flat_bouton_volumes) != len(column_values):
                    #    print(f"Skipping row {index} due to length mismatch: bouton_volume={len(flat_bouton_volumes)}, {col}={len(column_values)}")
                    #   continue
                    df.at[index, col] = [
                        val for idx, val in enumerate(column_values) 
                        if threshold_lower <= flat_bouton_volumes[idx] <= threshold_upper
                    ]
        pre_pt_root_id_list = df.at[index, 'pre_pt_root_id']
        # Flatten the list
        pre_pt_root_id_flat = list(chain.from_iterable(pre_pt_root_id_list))

        # Remove duplicates while preserving order
        seen = set()
        pre_pt_root_id_unique = [x for x in pre_pt_root_id_flat if not (x in seen or seen.add(x))]
        df.at[index, 'Potential_Mossy_Partners'] = list(pre_pt_root_id_unique)                 
    
    return df


def filter_non_mossy_fibers_advanced(df):
    """
    Advanced filtering for non-mossy fibers using trait-specific thresholds and combined conditions.
    
    Parameters:
    df (pd.DataFrame): Dataframe containing mossy fiber data.
    
    Returns:
    pd.DataFrame: Filtered dataframe with non-mossy fibers removed.
    """
    df = df.copy()
    
    # Define trait-specific thresholds (customize these based on domain knowledge)
    thresholds = {               
        'bouton_volume': (50000, 20000000),  # Example range
        'skeleton_length': (1000, 1000000),
        'skeleton_branch_num': (0, 10),
        'synapse_variance': (0, 10000),
        'synapse_mean_distance': (50, 200)
    }
    
    # Parse stringified lists
    trait_columns = list(thresholds.keys())
    for col in trait_columns:
        if col in df.columns:
            df[col] = df[col].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)
    
    # Apply filtering logic
    for col, (lower, upper) in thresholds.items():
        if col in df.columns:
            # Flatten nested lists
            df[f"flat_{col}"] = df[col].apply(
                lambda x: [item for sublist in x for item in (sublist if isinstance(sublist, list) else [sublist])]
                if isinstance(x, list) else [x]
            )
            
            # Remove entries with values outside the defined range
            def filter_row(trait_values, mossy_fibers):
                return [
                    mossy_fiber
                    for idx, mossy_fiber in enumerate(mossy_fibers)
                    if idx < len(trait_values) and all(lower <= val <= upper for val in (trait_values[idx] if isinstance(trait_values[idx], list) else [trait_values[idx]]))
                ]

            df[f"filtered_{col}"] = df.apply(
                lambda row: filter_row(row[f"flat_{col}"], row["MOSSY_FIBER"])
                if len(row["MOSSY_FIBER"]) == len(row[f"flat_{col}"])
                else row["MOSSY_FIBER"], axis=1
            )

            # Update `MOSSY_FIBER` column
            df["MOSSY_FIBER"] = df[f"filtered_{col}"]
    
    # Drop intermediate columns
    
    intermediate_cols = [f"flat_{col}" for col in trait_columns] + [f"filtered_{col}" for col in trait_columns]
    df = df.drop(columns=[col for col in intermediate_cols if col in df.columns], errors='ignore')
    
    return df


def reorder_dataframe_columns(df):
    """
    Reorders the columns of the DataFrame to match a specified order.
    
    Parameters:
        df (pd.DataFrame): The input DataFrame.
    
    Returns:
        pd.DataFrame: A DataFrame with reordered columns.
    """
    # Desired column order
    new_order = [
        "bouton_partner", "Potential_Mossy_Partners", "pre_pt_root_id", 
        "bouton_id", "bouton_positions", "bouton_partners", "bouton_volume",
        "synapse_variance", "synapse_mean_distance", 
        "presyn_segids", "num_presyn_segids"
    ]
    
    # Reorder columns based on the new order
    return df[new_order]

In [8]:
def process_row(index, row, progress_bar):
    """Process a single row to calculate skeleton properties."""
    results = []  # To store results for all skeletons in this row
    filtered_sids = []  # To store valid sids with branch numbers <= 10

    for sid in row['Potential_Mossy_Partners']:
        if sid > 0:
            retries = 0
            max_retries = 5  # Maximum number of retries
            wait_time = 120  # Wait time in seconds between retries

            while retries < max_retries:
                try:
                    # Fetch the skeleton
                    skel = pcg_skel.pcg_skeleton(root_id=sid, client=client, root_point_resolution=[1, 1, 1])
                    vertices = skel.vertices  # Skeleton vertices
                    edges = skel.edges  # Skeleton edges

                    # Calculate total skeleton length
                    edge_lengths = np.linalg.norm(vertices[edges[:, 0]] - vertices[edges[:, 1]], axis=1)
                    total_length = edge_lengths.sum()

                    # Build a graph from edges to represent the skeleton
                    graph = nx.Graph()
                    for i, edge in enumerate(edges):
                        graph.add_edge(edge[0], edge[1], length=edge_lengths[i])

                    # Identify branch points (nodes with degree > 2)
                    branch_points = [node for node in graph.nodes if graph.degree[node] > 2]

                    # Traverse all branches to measure their lengths
                    branch_lengths = []
                    visited_edges = set()

                    for branch_point in branch_points:
                        for neighbor in graph.neighbors(branch_point):
                            edge = tuple(sorted((branch_point, neighbor)))
                            if edge not in visited_edges:
                                visited_edges.add(edge)
                                # Perform DFS to measure the branch length
                                branch_length = 0
                                current_node = neighbor
                                previous_node = branch_point
                                while True:
                                    branch_length += graph[previous_node][current_node]['length']
                                    visited_edges.add(tuple(sorted((previous_node, current_node))))
                                    neighbors = list(graph.neighbors(current_node))
                                    neighbors.remove(previous_node)  # Remove the previous node from consideration

                                    if len(neighbors) == 1:  # Continue to the next node
                                        previous_node = current_node
                                        current_node = neighbors[0]
                                    else:  # Reached a branch point or end of branch
                                        break
                                branch_lengths.append(branch_length)

                    # Only keep skeletons with branch numbers <= 10 and valid branch lengths
                    num_branches = len(branch_lengths)
                    if num_branches <= 15:
                        # Check if more than 6 branch lengths exceed 100,000
                        large_branch_count = sum(1 for length in branch_lengths if length > 1000)
                        if large_branch_count <= 6:
                            results.append((total_length, num_branches, branch_lengths))
                            filtered_sids.append(sid)
                    break  # Successfully processed this skeleton, exit retry loop

                except Exception as e:
                    retries += 1
                    if retries < max_retries:
                        print(f"Error processing skeleton for ID {sid}: {e}. Retrying in {wait_time} seconds...")
                        time.sleep(wait_time)
                    else:
                        print(f"Error processing skeleton for ID {sid} after {max_retries} retries: {e}. Skipping...")

        else:
            # Handle invalid skeleton IDs
            print(f"Invalid skeleton ID {sid}. Skipping...")

        # Update progress bar for each skeleton ID
        progress_bar.update(1)

    # Return all results for this row and the filtered sids
    return index, results, filtered_sids


# Main function
def mossy_fiber_skeleton_sorter(df):
    df['skeleton_length'] = None  # List of total lengths for each SID
    df['skeleton_branch_num'] = None  # List of branch numbers for each SID
    df['skeleton_branch_length'] = None  # List of lists of branch lengths for each SID
    df['mossy_fibers'] = None  # Filtered list of valid sids

    results = []
    total_tasks = sum(len(row['Potential_Mossy_Partners']) for _, row in df.iterrows())
    progress_bar = tqdm(total=total_tasks, desc="Processing skeletons sequentially", position=0)

    for index, row in df.iterrows():
        results.append(process_row(index, row, progress_bar))

    progress_bar.close()

    # Update the DataFrame with results
    for index, result_list, filtered_sids in results:
        skeleton_length = [result[0] for result in result_list]  # List of total lengths
        skeleton_branch_num = [result[1] for result in result_list]  # List of branch numbers
        skeleton_branch_length = [result[2] for result in result_list]  # List of lists of branch lengths

        # Update the DataFrame
        df.at[index, 'skeleton_length'] = skeleton_length
        df.at[index, 'skeleton_branch_num'] = skeleton_branch_num
        df.at[index, 'skeleton_branch_length'] = skeleton_branch_length
        df.at[index, 'mossy_fibers'] = filtered_sids
    
    # Remove the Potential_Mossy_Partners column
    df = df.drop(columns=['Potential_Mossy_Partners'])
    df.rename(columns={'bouton_partner': 'POSTSYNAPTIC_CELL'}, inplace=True)
    df.rename(columns={'mossy_fibers': 'MOSSY_FIBER'}, inplace=True)

    return df

def add_row_value_counts(df, column='MOSSY_FIBER', new_column_name='total_mossy_fiber_num'):
    """
    Adds a new column to the DataFrame with the count of values in the list within a specified column for each row.

    Parameters:
        df (pd.DataFrame): The input DataFrame.
        column (str): The column containing lists (or strings of lists) to count values from.
        new_column_name (str): Name of the new column to add with the counts.

    Returns:
        pd.DataFrame: DataFrame with the new column added.
    """
    if column not in df.columns:
        raise ValueError(f"The specified column '{column}' is not in the DataFrame.")
    
    df[new_column_name] = df[column].apply(len)
    
    # Sort the DataFrame by 'total_mossy_fiber_num' in descending order
    df_sorted = df.sort_values(by=new_column_name, ascending=False)

    # Reset the index if needed (optional)
    df_sorted = df_sorted.reset_index(drop=True)
    
    # Filter rows where 'total_mossy_fiber_num' is not 0
    df_filtered = df_sorted.loc[df_sorted['total_mossy_fiber_num'] != 0]

    return df_filtered

def reorder_dataframe_columns_2(df):
    """
    Reorders the columns of the DataFrame to match a specified order.
    
    Parameters:
        df (pd.DataFrame): The input DataFrame.
    
    Returns:
        pd.DataFrame: A DataFrame with reordered columns.
    """
    # Desired column order
    new_order = [
        "POSTSYNAPTIC_CELL", "MOSSY_FIBER", "total_mossy_fiber_num", 
        "bouton_id", "bouton_positions", "bouton_partners", "bouton_volume",
        "skeleton_length", "skeleton_branch_num", "skeleton_branch_length", 
        "synapse_variance", "synapse_mean_distance", 
        "presyn_segids", "num_presyn_segids", "pre_pt_root_id"
    ]
    
    # Reorder columns based on the new order
    return df[new_order]


def scale_xyz_column(dataframe, column_name, new_col_name):
    """
    Scales nested lists of coordinates in the specified column of a DataFrame.
    Each X and Y value is divided by 18, and each Z value is divided by 45.
    The nested structure is preserved, and the scaled values are saved in a new column.

    Parameters:
    dataframe (pd.DataFrame): The input DataFrame.
    column_name (str): The name of the column containing the nested XYZ lists.
    new_col_name (str): The name of the new column to store the scaled coordinates.

    Returns:
    pd.DataFrame: A new DataFrame with the scaled coordinates added as a new column.
    """
    def scale_coordinates(item):
        if isinstance(item, list) and all(isinstance(i, (int, float)) for i in item) and len(item) == 3:
            # If it's a list of 3 numbers, scale them
            x, y, z = item
            return [x / 18, y / 18, z / 45]
        elif isinstance(item, list):
            # If it's a list, process each element recursively
            return [scale_coordinates(sub_item) for sub_item in item]
        else:
            # If it's neither a list nor an XYZ coordinate, return it as is
            return item

    # Apply the recursive scaling function to the specified column
    dataframe[new_col_name] = dataframe[column_name].apply(scale_coordinates)
    return dataframe


def scale_bouton_volume(df, bouton_volume_col, new_bouton_col_name, scale_factor):
    """
    Scales bouton volume values in a specific column of a DataFrame.

    Parameters:
    - df (pd.DataFrame): The DataFrame to process.
    - bouton_volume_col (str): Name of the column containing bouton volumes to scale.
    - new_bouton_col_name (str): Name of the new column to store scaled bouton volumes.
    - scale_factor (float): Factor by which to scale the bouton volumes.

    Returns:
    - pd.DataFrame: The updated DataFrame with the scaled bouton volume column.
    """
    def scale_volume(nested_list):
        return [[val * scale_factor for val in sublist] for sublist in nested_list]
    
    df[new_bouton_col_name] = df[bouton_volume_col].apply(scale_volume)
    return df


def reorder_dataframe_columns_3(df):
    """
    Reorders the columns of the DataFrame to match a specified order.
    
    Parameters:
        df (pd.DataFrame): The input DataFrame.
    
    Returns:
        pd.DataFrame: A DataFrame with reordered columns.
    """
    # Desired column order
    new_order = [
        "POSTSYNAPTIC_CELL", "MOSSY_FIBER", "total_mossy_fiber_num", 
        "bouton_id", "bouton_positions", "bouton_positions_um", "bouton_partners", 
        "bouton_volume", "bouton_volume_um", "skeleton_length", "skeleton_branch_num", 
        "skeleton_branch_length", "synapse_variance", "synapse_mean_distance", 
        "presyn_segids", "num_presyn_segids", "pre_pt_root_id"
    ]
    
    # Reorder columns based on the new order
    return df[new_order]


def filter_leftover_bouton_data(df):
    # Define the columns that need to have corresponding lists removed
    columns_to_filter = [
        "bouton_id", "bouton_partners", "synapse_mean_distance",
        "synapse_variance", "bouton_volume", "bouton_volume_um",
        "bouton_positions_um", "bouton_positions"
    ]
    
    # Iterate through each row of the DataFrame with a progress bar
    for index, row in tqdm(df.iterrows(), total=len(df), desc="Filtering DataFrame"):
        # Extract the lists of lists for the current row
        pre_pt_root_id = row["pre_pt_root_id"]
        mossy_fiber = row["MOSSY_FIBER"]
        
        # Convert to Python objects if they are strings
        if isinstance(pre_pt_root_id, str):
            pre_pt_root_id = eval(pre_pt_root_id)
        if isinstance(mossy_fiber, str):
            mossy_fiber = eval(mossy_fiber)
        
        # Find the indices of the lists in pre_pt_root_id that match mossy_fiber
        valid_indices = [i for i, lst in enumerate(pre_pt_root_id) if lst[0] in mossy_fiber]
        
        # Filter the lists in pre_pt_root_id and the other columns
        df.at[index, "pre_pt_root_id"] = [pre_pt_root_id[i] for i in valid_indices]
        for col in columns_to_filter:
            col_data = row[col]
            if isinstance(col_data, str):
                col_data = eval(col_data)
            df.at[index, col] = [col_data[i] for i in valid_indices]

    return df

In [9]:
def only_known_pyr_cell(df):

    df_pyr = pd.read_csv('all_pyramidal_cells - Copy of MF-pyr.csv')

    updated_segid_list = client.chunkedgraph.get_roots(df_pyr['supervoxel'])
    
    
    print(len(df["bouton_partner"]))

    df = df[df["bouton_partner"].apply(lambda x: isinstance(x, (list, tuple, set)) and all(int(i) in updated_segid_list for i in x))]

    return df



In [10]:
"""
The basics of how this works.

This code takes in the csv file of already found mossy fibers. Then it updates all of their segids 
to what they currently are. Then it looks for boutons in all of these "mossy fibers" by finding how 
many synapses (minimum 8) are within the thresholded distance (500 voxels). Each time this is found 
this is considered a bouton and a number is added to the bouton_id as well as other information like 
the volume etc. So if a row has a bouton_id of [1,2,3,4] there are four boutons. And each bouton 
corresponds to the pre_pt_root_id in the list (would look something like this: [648518346449706895, 
648518346449706895, 648518346449706895, 648518346449706895]). We then look for all of the synaptic 
partners of these mossy fibers (using bouton_parters) and then find all cells postsynaptic to them. 
These are assumed to be CA3 cells, either inhibitory, glial, or pyramidal. Now we look at these CA3
cells' presynaptic partner and do the same bouton_detection to try to find all mossy fibers connecting
to these CA3 cells. Finally we do some sorting so the lists in each row make sense and correspond with 
the pre_pt_root_id, etc. and use the skeleton of each mossy fiber to help weed out any non mossy fibers. 
Still needs more work to weed out axons, but pretty good overall. Takes about 30 hours. 

"""
current_date = datetime.now()
formatted_date = current_date.strftime("%Y-%m-%d")

df = pd.read_csv("CA3 proofreading 2 - 16.MF_identification_EH_1stPriority_ONLY.csv")
df = df.iloc[0:10]
print("----------------------------------------------")
print(f"Read Starting CSV, DF Length: {len(df)}")
print("----------------------------------------------")
#df = update_segids_df(df, super_voxel_col="pre_pt_root_id")
print("----------------------------------------------")
print(f"Segids Updated, DF Length: {len(df)}")
print("----------------------------------------------")

current_date = datetime.now().strftime('%Y%m%d') 
new_name = f"updated_segids_{current_date}"
df.rename(columns={"Potential mossy fibers": new_name}, inplace=True)

# Takes about 2.5 hours. 
df = bouton_processor(df, limit_TF=False, limit=5)
df.to_csv(f'V2_Step_1_BoutonProcessor_{formatted_date}.csv', index=False)
print("----------------------------------------------")
print(f"Boutons Processed, DF Length: {len(df)}")
print("----------------------------------------------")
df = find_bouton_partners(df)
df.to_csv(f'V2_Step_2_BoutonPartners_{formatted_date}.csv', index=False)


#print("----------------------------------------------")
#print("NEW NEW NEW - - - - - ONLY KNOWN PYR CELLS")
#print("----------------------------------------------")
#df = only_known_pyr_cell(df)
#print("----------------------------------------------")
#print(f"Pyr Only Partners Left, DF Length: {len(df)}")
#print("----------------------------------------------")


print("----------------------------------------------")
print(f"Bouton Partners Found, DF Length: {len(df)}")
print("----------------------------------------------")
df = nonMossy_partner_analysis(df)
# Seems to take about 6-8 hours 
df.to_csv(f'V2_Step_3_NonMossyPartners_{formatted_date}.csv', index=False)
df_1 = df.head(1)  
print("----------------------------------------------")
print(f"Partner Cell Analysis, DF Length: {len(df)}")
print("----------------------------------------------")
df = split_lists_by_bouton_id(df)
df_2 = df.head(1)
df.to_csv(f'V2_Step_4_CellsSplit_{formatted_date}.csv', index=False)
print("----------------------------------------------")
print(f"Splitting boutons in dataframe, DF Length: {len(df)}")
print("----------------------------------------------")
df = reorder_dataframe_columns(df)
df_3 = df.head(1)
df = filter_bouton_data(df, columns_to_filter=['bouton_id', 'pre_pt_root_id', \
                        'bouton_positions', 'bouton_partners', 'synapse_variance', \
                        'synapse_mean_distance', 'bouton_volume'])
df.to_csv(f'V2_Step_5_BoutonFilteredMossyFibers_{formatted_date}.csv', index=False)
df_post = df
print("----------------------------------------------")
print(f"Bouton Volumes Processed, DF Length: {len(df)}")
print("----------------------------------------------")
df = mossy_fiber_skeleton_sorter(df)
print("----------------------------------------------")
print(f"Partner Skeleton Analysis, DF Length: {len(df)}")
print("----------------------------------------------")
df = add_row_value_counts(df, column='MOSSY_FIBER', new_column_name='total_mossy_fiber_num')
df = reorder_dataframe_columns_2(df)
df.to_csv(f'V2_Step_6_SkeletonFilteredMossyFibers_{formatted_date}.csv', index=False)
print("----------------------------------------------")
print(f"Redorder and Count total MF, DF Length: {len(df)}")
print("----------------------------------------------")
df = filter_non_mossy_fibers_advanced(df)
df = add_row_value_counts(df, column='MOSSY_FIBER', new_column_name='total_mossy_fiber_num')
df = scale_xyz_column(df, column_name='bouton_positions', new_col_name='bouton_positions_um')
df = scale_bouton_volume(df, bouton_volume_col='bouton_volume', new_bouton_col_name='bouton_volume_um', scale_factor=14.58)
df = reorder_dataframe_columns_3(df)
df.to_csv(f'V2_Step_7_FULL_MOSSYFIBER_DATASET_{formatted_date}.csv', index=False)
print("----------------------------------------------")
print(f"Filtered by Z score, DF Length: {len(df)}")
print("----------------------------------------------")
print("----------------------------------------------")
df = filter_leftover_bouton_data(df)
df.to_csv(f'V2_Step_8_FULL_MOSSYFIBER_DATASET_{formatted_date}.csv', index=False)
print("----------------------------------------------")
print(f"Fix Left Over Bouton Data, DF Length: {len(df)}")
print("----------------------------------------------")
print("PROCESS FINISHED")
print("----------------------------------------------")

----------------------------------------------
Read Starting CSV, DF Length: 10
----------------------------------------------
----------------------------------------------
Segids Updated, DF Length: 10
----------------------------------------------


Processing Mossy Fiber Synapses: 100%|██████████| 1/1 [00:00<00:00,  3.20it/s]


Error processing boutons for chunk [648518346439917587, 648518346429693973, 648518346440572966, 648518346442407984, 648518346446209074, 648518346454335564, 648518346439524444, 648518346436247690, 648518346451583119, 648518346438869147]: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (3,) + inhomogeneous part.
----------------------------------------------
Boutons Processed, DF Length: 1
----------------------------------------------
----------------------------------------------
Bouton Partners Found, DF Length: 9
----------------------------------------------


Processing Mossy Fiber Synapses: 100%|██████████| 9/9 [00:02<00:00,  3.68it/s]


----------------------------------------------
Partner Cell Analysis, DF Length: 2
----------------------------------------------
----------------------------------------------
Splitting boutons in dataframe, DF Length: 2
----------------------------------------------
----------------------------------------------
Bouton Volumes Processed, DF Length: 2
----------------------------------------------


Processing skeletons sequentially: 100%|██████| 95/95 [01:58<00:00,  1.25s/it]


----------------------------------------------
Partner Skeleton Analysis, DF Length: 2
----------------------------------------------
----------------------------------------------
Redorder and Count total MF, DF Length: 2
----------------------------------------------
----------------------------------------------
Filtered by Z score, DF Length: 1
----------------------------------------------
----------------------------------------------


Filtering DataFrame: 100%|████████████████████| 1/1 [00:00<00:00, 1499.57it/s]

----------------------------------------------
Fix Left Over Bouton Data, DF Length: 1
----------------------------------------------
PROCESS FINISHED
----------------------------------------------



