In [None]:
import tensorflow as tf
import csv
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import itertools
from tensorflow.python.framework.errors_impl import DataLossError
from collections import defaultdict
import os, csv
import time

Preprocessing: Convert TensorBoard event files to CSV

In [None]:
def tensorboard_to_csv(event_file, csv_file):
    """
    Dump every scalar summary in a TensorBoard event file into a single CSV.

    Each output row is ``[step, tag, value]`` under a ``Step,Tag,Value``
    header. Summaries without a ``simple_value`` field (non-scalars) are ignored.

    Args:
        event_file (str): Path to the TensorBoard event file (e.g., events.out.tfevents.xxx).
        csv_file (str): Destination path for the CSV; an existing file is overwritten.
    """
    # summary_iterator is only reachable through tf.compat.v1 in TensorFlow 2.x.
    rows = [
        [event.step, entry.tag, entry.simple_value]
        for event in tf.compat.v1.train.summary_iterator(event_file)
        for entry in event.summary.value
        if entry.HasField('simple_value')
    ]

    with open(csv_file, 'w', newline='') as handle:
        out = csv.writer(handle)
        out.writerow(['Step', 'Tag', 'Value'])
        out.writerows(rows)

    print(f"Data from {event_file} has been written to {csv_file}")


In [None]:
def tensorboard_to_separate_csv(event_file, output_dir):
    """
    Convert TensorBoard event file data to separate CSV files for each tag.

    Every scalar summary (``simple_value``) is grouped by tag and written to
    ``<output_dir>/<tag with '/' replaced by '_'>_data.csv`` with a
    ``Step,Value`` header. Existing files with the same name are overwritten.

    Reading is best-effort:
    - A record that raises while being unpacked is skipped, with a message.
    - A ``DataLossError`` (e.g. a truncated file) stops reading, but the
      records read before it are still written out.

    Args:
        event_file (str): Path to the TensorBoard event file.
        output_dir (str): Directory for the per-tag CSV files. It is created if
            missing and there is at least one tag to write.
    """
    tag_data = defaultdict(list)

    try:
        for e in tf.compat.v1.train.summary_iterator(event_file):
            try:
                for v in e.summary.value:
                    if v.HasField('simple_value'):
                        tag_data[v.tag].append([e.step, v.simple_value])
            except Exception as record_error:
                # Best effort: skip the problematic record, but report why.
                print(f"Skipped a corrupt record in file: {event_file} ({record_error})")
    except DataLossError:
        # Keep whatever was read before the truncation and continue.
        print(f"Encountered DataLossError. Possibly due to incomplete writes in file: {event_file}")

    if tag_data:
        os.makedirs(output_dir, exist_ok=True)

    for tag, data_rows in tag_data.items():
        # Tags like "train/loss" contain '/', which would otherwise be read as
        # a sub-directory component.
        filename = os.path.join(output_dir, f"{tag.replace('/', '_')}_data.csv")
        with open(filename, 'w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(['Step', 'Value'])
            writer.writerows(data_rows)
        print(f"Data for tag '{tag}' has been written to {filename}")


In [None]:
def is_file_stable(file_path, wait_time=1.0):
    """
    Heuristically check whether a file has stopped growing.

    Samples the file size, sleeps ``wait_time`` seconds, samples it again, and
    reports whether the two sizes agree. ``os.path.getsize`` raises ``OSError``
    if the file is missing or inaccessible.
    """
    size_before = os.path.getsize(file_path)
    time.sleep(wait_time)
    return os.path.getsize(file_path) == size_before

def _has_csv_output(directory):
    """Return True if ``directory`` exists and already holds at least one ``.csv`` file."""
    return os.path.isdir(directory) and any(
        name.endswith('.csv') for name in os.listdir(directory)
    )


def process_tensorboard_results(parent_dir, output_parent_dir):
    """
    Loop through all subdirectories of a parent directory, process TensorBoard files,
    and save results to a corresponding directory structure.

    Skipping and error handling:
    - A run directory is skipped as already processed when its mirrored output
      directory already contains CSV files. This is judged when the directory
      is first encountered in this call.
    - Event files whose size is still changing (``is_file_stable``) are skipped,
      so a later call can pick them up.
    - Failures on individual files are printed and do not stop the walk.

    Args:
        parent_dir (str): Parent directory containing TensorBoard event files.
        output_parent_dir (str): Parent directory where CSV files should be saved.
    """
    # Per output directory, cache the "already processed" verdict. It is
    # computed before this call writes anything there. Checking after creating
    # the directory, or after converting a sibling event file from the same run
    # directory, would wrongly flag fresh output as already processed.
    processed = {}

    for root, _dirs, files in os.walk(parent_dir):
        for file in files:
            if "tfevents" not in file:  # only TensorBoard event files
                continue
            event_file = os.path.join(root, file)
            relative_path = os.path.relpath(root, parent_dir)
            output_dir = os.path.join(output_parent_dir, relative_path)

            if output_dir not in processed:
                processed[output_dir] = _has_csv_output(output_dir)
            if processed[output_dir]:
                print(f"Skipping already processed file: {event_file}")
                continue

            os.makedirs(output_dir, exist_ok=True)

            try:
                # Inside the try: the file may vanish between listing and stat.
                if not is_file_stable(event_file):
                    print(f"File is still being written: {event_file}, skipping.")
                    continue
                tensorboard_to_separate_csv(event_file, output_dir)
                print(f"Processed {event_file} -> {output_dir}")
            except Exception as e:
                # Best effort: report and continue with the remaining files.
                print(f"Failed to process {event_file}: {e}")

Utils: Plotting and Analysis

In [None]:
def exponential_moving_average(data, alpha):
    """
    Calculate the exponential moving average (EMA) of a 1D array.

    Uses the recursion ``ema[0] = data[0]`` and
    ``ema[i] = alpha * data[i] + (1 - alpha) * ema[i - 1]``, so a larger
    ``alpha`` weights recent points more heavily.

    Args:
        data (array-like): The input data.
        alpha (float): The smoothing factor (0 < alpha <= 1).

    Returns:
        numpy.ndarray: The EMA values, same length as ``data`` (empty for empty input).

    Raises:
        ValueError: If ``alpha`` is outside (0, 1].
    """
    if not (0 < alpha <= 1):
        raise ValueError("Alpha must be between 0 and 1.")
    if len(data) == 0:
        # Nothing to smooth; there is no first point to seed the recursion.
        return np.array([])

    ema = [data[0]]  # seed with the first data point
    for i in range(1, len(data)):
        ema.append(alpha * data[i] + (1 - alpha) * ema[-1])
    return np.array(ema)


def rolling_average(data, window_size):
    """
    Return the simple moving average of ``data`` over a sliding window.

    Only full windows are averaged (``np.convolve`` in ``'valid'`` mode), so the
    result has ``len(data) - window_size + 1`` elements.

    Args:
        data (array-like): The input data.
        window_size (int): Number of consecutive points averaged per output value.

    Returns:
        numpy.ndarray: The rolling average values.

    Raises:
        ValueError: If ``window_size`` is below 1 or longer than ``data``.
    """
    if window_size < 1:
        raise ValueError("Window size must be at least 1.")
    if window_size > len(data):
        raise ValueError("Data length must be at least equal to the window size.")

    # A uniform kernel turns convolution into a windowed mean.
    kernel = np.full(window_size, 1.0 / window_size)
    return np.convolve(data, kernel, mode='valid')

def plot_average_trajectory(time_series, error_type='std', time_points=None, xlabel='Time', ylabel='Value', title='Average Trajectory'):
    """
    Draw the across-series mean of ``time_series`` with a shaded error band.

    Parameters:
    - time_series (2D array-like): Shape (n_series, n_time_points); one row per series.
    - error_type (str): 'std' shades +/- one standard deviation; 'sem' shades
      +/- that standard deviation divided by sqrt(n_series).
    - time_points (1D array-like, optional): x-values; defaults to 0..n_time_points-1.
    - xlabel (str): Label for the x-axis.
    - ylabel (str): Label for the y-axis.
    - title (str): Title of the plot.

    Raises:
    - ValueError: if ``time_points`` has the wrong length or ``error_type`` is unknown.
    """
    series = np.array(time_series)
    n_series, n_steps = series.shape[0], series.shape[1]

    x = np.arange(n_steps) if time_points is None else np.array(time_points)
    if len(x) != n_steps:
        raise ValueError("Length of time_points must match the number of columns in time_series.")

    mean = series.mean(axis=0)
    if error_type not in ('std', 'sem'):
        raise ValueError("error_type must be 'std' or 'sem'.")
    spread = series.std(axis=0)
    if error_type == 'sem':
        spread = spread / np.sqrt(n_series)

    plt.figure(figsize=(10, 6))
    plt.plot(x, mean, label='Mean Trajectory', color='blue')
    plt.fill_between(x, mean - spread, mean + spread, alpha=0.3, color='blue', label='Error')
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.grid(False)
    plt.show()

def trim_and_calculate_mean(array_list):
    """
    Truncate a collection of 1D series to a common length and stack them.

    Despite the name, no mean is computed here. The function returns the
    stacked 2D array of shape (n_series, min_length), so callers can reduce it,
    e.g. with ``np.mean(..., axis=0)`` across series.

    Args:
        array_list (iterable of array-like): The series to align. Each is cut
            to the length of the shortest one.

    Returns:
        numpy.ndarray: The stacked, truncated series.

    Raises:
        ValueError: If ``array_list`` contains no series.
    """
    # Materialize once so one-shot iterables are not exhausted by min().
    series = list(array_list)
    if not series:
        raise ValueError("array_list must contain at least one series.")

    min_length = min(len(arr) for arr in series)
    return np.array([arr[:min_length] for arr in series])


In [None]:
import numpy as np
from scipy.stats import ttest_ind, mannwhitneyu

def compare_binned_time_series(data_group1, data_group2, num_bins=10, test_func=ttest_ind, **test_kwargs):
    """
    Compare two groups' time-series data by dividing the time course into bins,
    averaging values in each bin, and performing a specified statistical test.

    The time axis is split into ``num_bins`` contiguous, non-empty bins. Within
    each bin every instance is reduced to its mean, and ``test_func`` compares
    the two groups of per-instance means. Significance uses a Bonferroni-
    corrected threshold of ``0.05 / num_bins``.

    Args:
        data_group1 (list of lists or np.array): Time-series data for group 1 (each row is an instance).
        data_group2 (list of lists or np.array): Time-series data for group 2 (each row is an instance).
        num_bins (int): Number of bins to divide the time course. Must be between
            1 and the number of time steps, so that no bin is empty.
        test_func (callable): Statistical test function (default: t-test).
                              Must accept two arrays and return a statistic and p-value.
        **test_kwargs: Additional keyword arguments to pass to the statistical test.

    Returns:
        dict: Maps "Bin 1" .. "Bin N" to the test statistic, p-value, the
        Bonferroni threshold, and whether the bin is significant.

    Raises:
        ValueError: If either group is not 2D, the groups have different numbers
            of time steps, or ``num_bins`` is out of range.
    """
    data_group1 = np.asarray(data_group1)
    data_group2 = np.asarray(data_group2)

    # Explicit checks (not assert) so validation survives `python -O`.
    if data_group1.ndim != 2 or data_group2.ndim != 2:
        raise ValueError("Both groups must be 2D arrays of shape (n_instances, n_timesteps).")
    if data_group1.shape[1] != data_group2.shape[1]:
        raise ValueError("Mismatched time steps")

    num_timesteps = data_group1.shape[1]
    # More bins than time steps would produce empty bins (NaN means/p-values);
    # zero bins would divide by zero in the Bonferroni correction.
    if not 1 <= num_bins <= num_timesteps:
        raise ValueError(
            f"num_bins must be between 1 and the number of time steps ({num_timesteps}), got {num_bins}."
        )

    bin_edges = np.linspace(0, num_timesteps, num_bins + 1, dtype=int)
    alpha_corrected = 0.05 / num_bins  # Bonferroni correction

    results = {}
    for i, (start, end) in enumerate(zip(bin_edges[:-1], bin_edges[1:]), start=1):
        # Per-instance mean within the bin.
        mean_group1 = data_group1[:, start:end].mean(axis=1)
        mean_group2 = data_group2[:, start:end].mean(axis=1)

        test_stat, p_val = test_func(mean_group1, mean_group2, **test_kwargs)

        results[f"Bin {i}"] = {
            "test_statistic": test_stat,
            "p-value": p_val,
            "threshold after Bonferroni correction": alpha_corrected,
            "significant (Bonferroni)": p_val < alpha_corrected
        }

    return results


In [None]:
def read_group_data_with_specified_tag(folder, tag_keyword, has_child_dir=False):
    """
    Load the ``Value`` column of every CSV selected by ``tag_keyword``.

    Entries are looked up under ``res/<folder>``, relative to the current
    working directory. They are visited in sorted name order, so the order of
    the returned series is reproducible across runs and platforms.

    Args:
        folder (str): Sub-folder of ``res`` to search.
        tag_keyword (str): Substring that selects entries by name.
        has_child_dir (bool): How ``tag_keyword`` is applied:
            - True: it selects sub-directories, and one CSV is read from each
              (the first ``.csv`` file in sorted order, or the first entry if
              there is none). Empty directories and non-directory matches are
              skipped.
            - False: it selects files directly inside ``res/<folder>``.

    Returns:
        list[numpy.ndarray]: One array of ``Value`` entries per selected file.

    Raises:
        FileNotFoundError: If nothing matching ``tag_keyword`` could be read.
    """
    parent_dir = os.path.join('res', folder)
    group_data = []

    # os.listdir order is arbitrary; sort for deterministic results.
    for name in sorted(os.listdir(parent_dir)):
        if tag_keyword not in name:
            continue
        path = os.path.join(parent_dir, name)

        if has_child_dir:
            if not os.path.isdir(path):
                continue  # a matching stray file, not a run directory
            entries = sorted(os.listdir(path))
            # Prefer CSVs so hidden/system files (e.g. ".DS_Store") are not picked.
            candidates = [entry for entry in entries if entry.endswith('.csv')] or entries
            if not candidates:
                continue
            path = os.path.join(path, candidates[0])

        group_data.append(pd.read_csv(path)['Value'].to_numpy())

    if not group_data:
        raise FileNotFoundError(f"No file with tag '{tag_keyword}' found in folder '{folder}'.")

    return group_data

In [None]:
def plot_data_per_condition(folders, wsize, tag_keyword, labels, has_child_dir):
    """
    Plot the data for each condition.

    For each folder (one per condition), the matching series are loaded with
    ``read_group_data_with_specified_tag``. Two figures are drawn:

    1. The across-series mean of each condition, computed after truncating the
       series to the shortest one, smoothed with a rolling window of ``wsize``.
    2. One subplot per condition showing every individual series, smoothed the
       same way.

    Args:
        folders (list[str]): Condition folders under ``res``.
        wsize (int): Rolling-average window size.
        tag_keyword (str): Substring used to select the CSV files / run dirs.
        labels (list[str]): Display name per folder, in the same order.
        has_child_dir (bool): Passed through to ``read_group_data_with_specified_tag``.

    Raises:
        ValueError: If ``labels`` and ``folders`` differ in length.
    """
    if len(labels) != len(folders):
        raise ValueError("labels must have exactly one entry per folder.")

    collective_d = [
        read_group_data_with_specified_tag(folder, tag_keyword, has_child_dir)
        for folder in folders
    ]
    collective_d_avg = [np.mean(trim_and_calculate_mean(group), axis=0) for group in collective_d]

    # Start a fresh figure, so the mean curves do not land on whatever axes
    # happen to be current (e.g. the last subplot of a previous call).
    plt.figure()
    for label, avg in zip(labels, collective_d_avg):
        plt.plot(rolling_average(avg, wsize), label=label)
    plt.legend()
    plt.title(f'Mean {tag_keyword}')

    # squeeze=False keeps `axes` 2D even for a single condition; otherwise
    # plt.subplots returns a bare Axes that cannot be indexed.
    fig, axes = plt.subplots(1, len(labels), figsize=(4 * len(labels), 4), squeeze=False)
    for ax, label, group in zip(axes[0], labels, collective_d):
        for series in group:
            ax.plot(rolling_average(series, wsize))
        ax.set_title(label)
    fig.suptitle(f'Individual {tag_keyword}')