In [None]:
import h5py
from collections import Counter
import os
import matplotlib.pyplot as plt
from collections import defaultdict
import pickle
import numpy as np
import imageio
import re
from matplotlib.colors import ListedColormap
import pandas as pd
from skimage import measure, morphology

### Load and Analyze Simulation Data

In [None]:
def load_stats_from_jld2(filename):
    """
    Read one simulation snapshot from a JLD2 (HDF5) file and summarise it.

    Args:
        filename (str): Path to the .jld2 file.

    Returns:
        dict: Summary with the keys 'step', 'num_cells', 'avg_volume',
            'max_volume' and 'state_counts' (a Counter mapping each state
            name to the number of cells in that state).
    """
    with h5py.File(filename, 'r') as snapshot:
        step = snapshot["step"][()]
        voxel_group = snapshot["cell_voxels"]
        state_group = snapshot["cell_states"]
        # Group keys are converted to int so cells are addressed by numeric ID.
        voxels_by_cell = {int(cid): voxel_group[cid][()] for cid in voxel_group}
        # States are read as bytes and decoded to text.
        state_by_cell = {
            int(cid): state_group[cid][()].decode("utf-8") for cid in state_group
        }

    # The volume of a cell is the number of voxel entries stored for it.
    volumes = [len(cell_voxels) for cell_voxels in voxels_by_cell.values()]

    return {
        'step': step,
        'num_cells': len(volumes),
        'avg_volume': np.mean(volumes),
        'max_volume': np.max(volumes),
        'state_counts': Counter(state_by_cell.values()),
    }

In [None]:

# --- Function to load and analyze data ---
def analyze_sim_folder(folder_path, max_step=None):
    """
    Analyze all simulation snapshots in a folder and aggregate them over time.

    Only files whose name ends in ".jld2" and contains "step_" are considered.
    A file whose step number cannot be parsed, or that cannot be read, is
    reported and skipped so that one bad snapshot does not abort the analysis.

    Args:
        folder_path (str): Path to the folder containing simulation .jld2 files.
        max_step (int, optional): Snapshots whose step number (taken from the
            file name) is greater than this value are ignored. ``None`` keeps
            every snapshot.

    Returns:
        dict: Time series ordered by step, with the keys
            "steps", "num_cells", "avg_volumes", "max_volumes",
            "state_trends" (state name -> list of counts, one per step) and
            "all_states" (sorted list of every state seen).
    """
    all_stats = []

    for fname in sorted(os.listdir(folder_path)):
        if not (fname.endswith(".jld2") and "step_" in fname):
            continue
        try:
            step_num = int(fname.split("_")[-1].replace(".jld2", ""))
            if max_step is not None and step_num > max_step:
                continue
            all_stats.append(load_stats_from_jld2(os.path.join(folder_path, fname)))
        except (KeyError, ValueError, OSError) as e:
            # KeyError: expected dataset missing; ValueError: unparsable step
            # number or empty snapshot; OSError: file could not be opened/read.
            print(f"Skipping file {fname} due to error: {e}")

    # File names sort lexicographically, so order by the recorded step instead.
    all_stats.sort(key=lambda d: d["step"])

    steps = [s["step"] for s in all_stats]
    num_cells = [s["num_cells"] for s in all_stats]
    avg_volumes = [s["avg_volume"] for s in all_stats]
    max_volumes = [s["max_volume"] for s in all_stats]

    # Every state gets an entry for every step (0 when absent) so all series
    # have the same length as "steps".
    all_states = sorted({state for s in all_stats for state in s["state_counts"].keys()})
    state_trends = {
        state: [s["state_counts"].get(state, 0) for s in all_stats]
        for state in all_states
    }

    return {
        "steps": steps,
        "num_cells": num_cells,
        "avg_volumes": avg_volumes,
        "max_volumes": max_volumes,
        "state_trends": state_trends,
        "all_states": all_states
    }

### Plot Simulation Statistics

In [None]:
def plot_simulation_stats(stats_list, states_to_include=None, labels=None, max_step=None):
    """
    Plot cell count, cell volume and per-state cell counts over time.

    Three separate figures are drawn for one or several simulations:
    total cell count, average/maximum cell volume, and the number of cells in
    each selected state.

    Args:
        stats_list (dict | list): One aggregated stats dictionary (keys
            "steps", "num_cells", "avg_volumes", "max_volumes",
            "state_trends", "all_states") or a list of them.
        states_to_include (str | list, optional): State name(s) to show in the
            state plot. When omitted, every state of each run is shown.
        labels (list, optional): One legend label per simulation. Defaults to
            "Run 1", "Run 2", ...
        max_step (int, optional): Largest step to show in the plots.
    """
    if not isinstance(stats_list, list):
        stats_list = [stats_list]
    if labels is None:
        labels = [f"Run {i+1}" for i in range(len(stats_list))]
    if isinstance(states_to_include, str):
        # A bare string would otherwise be iterated character by character.
        states_to_include = [states_to_include]

    def filter_by_max_step(stats):
        # Keep only the entries whose step is <= max_step, in every series.
        if max_step is None:
            return stats
        idxs = [i for i, s in enumerate(stats["steps"]) if s <= max_step]
        return {
            "steps": [stats["steps"][i] for i in idxs],
            "num_cells": [stats["num_cells"][i] for i in idxs],
            "avg_volumes": [stats["avg_volumes"][i] for i in idxs],
            "max_volumes": [stats["max_volumes"][i] for i in idxs],
            "state_trends": {
                state: [stats["state_trends"][state][i] for i in idxs]
                for state in stats["all_states"]
            },
            "all_states": stats["all_states"]
        }

    stats_list = [filter_by_max_step(stats) for stats in stats_list]

    # --- Plot 1: Total Cell Count ---
    plt.figure()
    for stats, label in zip(stats_list, labels):
        plt.plot(stats["steps"], stats["num_cells"], label=label)
    plt.xlabel("Step")
    plt.ylabel("Number of Cells")
    plt.title("Total Cell Count Over Time")
    plt.legend()
    plt.grid(False)
    plt.tight_layout()

    # --- Plot 2: Avg and Max Volume ---
    plt.figure()
    for stats, label in zip(stats_list, labels):
        plt.plot(stats["steps"], stats["avg_volumes"], label=f"{label} - Avg")
        plt.plot(stats["steps"], stats["max_volumes"], '--', label=f"{label} - Max")
    plt.xlabel("Step")
    plt.ylabel("Volume (voxels)")
    plt.title("Cell Volume Over Time")
    plt.legend()
    plt.grid(False)
    plt.tight_layout()

    # --- Plot 3: Cell States ---
    plt.figure()
    for stats, label in zip(stats_list, labels):
        trends = stats["state_trends"]
        steps = stats["steps"]
        states_to_plot = states_to_include if states_to_include else stats["all_states"]
        # With several states per run the run label alone is ambiguous in the
        # legend, so the state name is appended in that case.
        several_states = len(states_to_plot) > 1
        for state in states_to_plot:
            if state in trends:
                line_label = f"{label} - {state}" if several_states else f"{label}"
                plt.plot(steps, trends[state], label=line_label)
    plt.xlabel("Step")
    plt.ylabel("Cells in State")
    if states_to_include and len(states_to_include) == 1:
        # Name the single requested state in the title (first letter upper-cased).
        only_state = str(states_to_include[0])
        state_title = f"{only_state[:1].upper()}{only_state[1:]} State Counts Over Time"
    else:
        state_title = "Cell State Counts Over Time"
    plt.title(state_title)
    plt.legend()
    plt.grid(False)
    plt.tight_layout()

    plt.show()

### Save and Reload Stats

In [None]:

def save_stats(stats, filename):
    """
    Save computed statistics to a pickle file.

    The parent directory of ``filename`` is created when it does not exist
    yet, so nested output paths work on a first run.

    Args:
        stats (dict): Computed statistics.
        filename (str): Output file name for the pickle file.
    """
    parent_dir = os.path.dirname(filename)
    if parent_dir:  # empty for a bare file name; os.makedirs("") would raise
        os.makedirs(parent_dir, exist_ok=True)
    with open(filename, "wb") as f:
        pickle.dump(stats, f)
def load_stats(filename):
    """
    Load statistics from a pickle file.

    Args:
        filename (str): Path of the pickle file to read.

    Returns:
        The object stored in the file.
    """
    # NOTE: unpickling can execute arbitrary code; only open trusted files.
    with open(filename, "rb") as handle:
        restored = pickle.load(handle)
    return restored

### Growth Rate Computation and Visualization

In [None]:
def compute_growth_rate_over_time(stats, smooth_window=1):
    """
    Compute the cell-count growth rate between consecutive recorded steps.

    The raw rate for each interval is (change in "num_cells") / (change in
    "steps") and is positioned at the midpoint of that interval. With
    ``smooth_window > 1`` the rates are averaged with a moving window; the
    returned x positions are averaged with the same window so every smoothed
    value sits at the centre of the intervals it covers.

    Args:
        stats (dict): Statistics dictionary with "steps" and "num_cells".
        smooth_window (int): Number of consecutive intervals to average.
            Values larger than the number of intervals are clamped to it.

    Returns:
        tuple: (mid_steps, growth_rates) arrays of equal length. Both are
            empty when fewer than two steps are available.
    """
    steps = np.asarray(stats["steps"])
    counts = np.asarray(stats["num_cells"])

    # Raw growth rate per interval, positioned at the interval midpoint.
    growth_rates = np.diff(counts) / np.diff(steps)
    mid_steps = (steps[:-1] + steps[1:]) / 2

    # Clamp the window: np.convolve swaps its arguments when the kernel is
    # longer than the signal (producing more rates than x positions), and it
    # raises on empty input.
    window = min(int(smooth_window), growth_rates.size)
    if window > 1:
        kernel = np.ones(window) / window
        growth_rates = np.convolve(growth_rates, kernel, mode="valid")
        # Same moving average on the x axis -> centre of each window.
        mid_steps = np.convolve(mid_steps, kernel, mode="valid")

    return mid_steps, growth_rates

In [None]:

def plot_all_growth_rates(stats_list, labels, smooth_window=3, max_step=None):
    """
    Draw the growth-rate curve of several simulations in a single figure.

    Args:
        stats_list (list): Stats dictionaries, one per simulation.
        labels (list): Legend label for each simulation, in the same order.
        smooth_window (int): Window size forwarded to
            ``compute_growth_rate_over_time``.
        max_step (int): Optional cutoff; points beyond it are not drawn.
    """
    plt.figure(figsize=(8, 6))

    apply_cutoff = max_step is not None
    for run_stats, run_label in zip(stats_list, labels):
        xs, ys = compute_growth_rate_over_time(run_stats, smooth_window=smooth_window)

        if apply_cutoff:
            # Boolean selection of the points at or before the cutoff.
            keep = xs <= max_step
            xs, ys = xs[keep], ys[keep]

        plt.plot(xs, ys, label=run_label)

    plt.xlabel("Step")
    plt.ylabel("Growth Rate (Δcells / step)")
    plt.title("Instantaneous Cell Growth Rate Over Time")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

### Extract and Plot Area Metrics

In [None]:
    

def extract_area_metrics(
    input_folder,
    output_file,
    voxel_size_um=2.0,
    projection_axis='z'
):
    """
    Extract projected spheroid shape metrics from JLD2 snapshots and save as CSV.

    For every file in ``input_folder`` whose name ends in ".jld2" and contains
    "step_", all cell voxels are projected along ``projection_axis`` onto a 2D
    occupancy mask. The largest connected region of that mask is measured and
    one CSV row (columns: step, area_um2, perimeter_um, circularity) is
    written per snapshot, sorted by step. Files that fail to parse or load are
    reported and skipped.

    Args:
        input_folder (str): Folder containing the .jld2 snapshot files.
        output_file (str): Path to output CSV; its parent directory is
            created when missing.
        voxel_size_um (float): Edge length of one voxel, used to convert pixel
            area and perimeter to µm² and µm.
        projection_axis (str): Axis that is collapsed by the projection:
            'x', 'y' or 'z'.
    """
    def load_cell_voxels_from_jld2(file_path):
        # Returns {cell_id: voxel array} plus a grid shape just large enough
        # to hold the largest coordinate found in the file.
        with h5py.File(file_path, "r") as f:
            cell_voxels = {}
            max_coord = np.array([0, 0, 0])
            for cell_id in f["cell_voxels"]:
                raw = f["cell_voxels"][cell_id][()]
                if raw.dtype.names:
                    # Structured records: stack the fields into one column each.
                    voxels = np.stack([raw[name] for name in raw.dtype.names], axis=-1)
                else:
                    voxels = raw
                if voxels.size > 0:
                    max_coord = np.maximum(max_coord, voxels.max(axis=0))
                cell_voxels[int(cell_id)] = voxels
            grid_shape = tuple(max_coord + 1)
        return cell_voxels, grid_shape

    def project_voxels_to_2D(cell_voxels, grid_shape):
        # Marks every (ax1, ax2) position occupied by at least one voxel.
        axis_map = {'x': (1, 2), 'y': (0, 2), 'z': (0, 1)}
        ax1, ax2 = axis_map[projection_axis]
        mask = np.zeros((grid_shape[ax1], grid_shape[ax2]), dtype=np.uint8)
        all_voxels = np.concatenate(list(cell_voxels.values()))
        rows = all_voxels[:, ax1]
        cols = all_voxels[:, ax2]
        # Ignore coordinates outside the mask (negative ones would wrap around).
        inside = (rows >= 0) & (rows < mask.shape[0]) & (cols >= 0) & (cols < mask.shape[1])
        mask[rows[inside], cols[inside]] = 1
        return mask

    def compute_area_metrics(mask):
        # Measures the largest connected region; (0, 0, 0) for an empty mask.
        labeled = measure.label(mask)
        regions = measure.regionprops(labeled)
        if not regions:
            return 0, 0, 0
        region = max(regions, key=lambda r: r.area)
        area_um2 = region.area * voxel_size_um**2
        perimeter_um = region.perimeter * voxel_size_um
        # 4*pi*A / P^2 equals 1 for a perfect circle.
        circularity = (4 * np.pi * area_um2) / (perimeter_um**2) if perimeter_um > 0 else 0
        return area_um2, perimeter_um, circularity

    output_dir = os.path.dirname(output_file)
    if output_dir:  # empty for a bare file name; os.makedirs("") would raise
        os.makedirs(output_dir, exist_ok=True)

    records = []
    for fname in sorted(os.listdir(input_folder)):
        if not (fname.endswith(".jld2") and "step_" in fname):
            continue
        file_path = os.path.join(input_folder, fname)

        try:
            # Parsed inside the try so an oddly named file is skipped, not fatal.
            step_num = int(fname.split("_")[-1].replace(".jld2", ""))
            cell_voxels, grid_shape = load_cell_voxels_from_jld2(file_path)
            mask = project_voxels_to_2D(cell_voxels, grid_shape)
            area, perimeter, circ = compute_area_metrics(mask)
            records.append({
                "step": step_num,
                "area_um2": area,
                "perimeter_um": perimeter,
                "circularity": circ
            })
        except Exception as e:
            print(f"Error processing {file_path}: {e}")

    if records:
        df = pd.DataFrame(records).sort_values("step")
        df.to_csv(output_file, index=False)
        print(f"Saved area metrics to: {output_file}")
    else:
        print("No valid JLD2 files found or all failed.")

In [None]:

def plot_all_metrics_from_folder(folder="area_results", max_step=None):
    """
    Plot projected area, perimeter and circularity over time for every
    "*_metrics.csv" file in a folder.

    One figure with three side-by-side panels is drawn; each CSV contributes
    one line per panel. The legend label is derived from the part of the file
    name in front of "_metrics.csv" (e.g. "10_1_metrics.csv" -> "10:1").

    Args:
        folder (str): Folder containing area CSV files with the columns
            "step", "area_um2", "perimeter_um" and "circularity".
        max_step (int): Optional step cutoff for plotting.
    """

    label_map = {
        "10_1": "10:1",
        "3_1": "3:1",
        "1_1": "1:1",
        "1_3": "1:3",
        "2_1": "2:1",
        "5_1": "5:1",
        "20_1": "20:1",
        "GBM": "GBM",
        "MSC": "MSC"
    }
    suffix = "_metrics.csv"

    def label_for(fname):
        # Condition names such as "10_1" contain "_" themselves, so the name
        # is taken as everything before the suffix rather than the first token.
        stem = fname[:-len(suffix)]
        if stem in label_map:
            return label_map[stem]
        # Allow extra tokens after the condition, e.g. "10_1_area_metrics.csv";
        # longest keys first so "10_1" is not shadowed by a shorter key.
        for key in sorted(label_map, key=len, reverse=True):
            if stem.startswith(key + "_"):
                return label_map[key]
        return stem

    files = sorted([f for f in os.listdir(folder) if f.endswith(suffix)])
    if not files:
        print("No CSV files found.")
        return

    # Read and filter every CSV once; all three panels reuse the same frames.
    curves = []
    for fname in files:
        df = pd.read_csv(os.path.join(folder, fname))
        if max_step is not None:
            df = df[df["step"] <= max_step]
        curves.append((label_for(fname), df))

    # (CSV column, panel title, y-axis label) for each subplot, left to right.
    panels = [
        ("area_um2", "Projected Area (µm²)", "Area"),
        ("perimeter_um", "Perimeter (µm)", "Perimeter"),
        ("circularity", "Circularity", "Circularity"),
    ]

    plt.figure(figsize=(15, 5))

    for position, (column, title, ylabel) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        for label, df in curves:
            plt.plot(df["step"], df[column], label=label)
        plt.title(title)
        plt.xlabel("Step")
        plt.ylabel(ylabel)
        plt.grid(True)

    plt.tight_layout()
    # The legend is attached to the last panel and placed outside of it.
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.suptitle("Spheroid Shape Metrics Over Time", fontsize=14, y=1.05)
    plt.show()

In [None]:

def plot_normalized_area_from_folder(folder="area_results", target_start_area=1000.0, max_step=None):
    """
    Plot normalized spheroid area over time from CSV files in a folder.

    Each "*_metrics.csv" curve is rescaled so that its first plotted value
    equals ``target_start_area``. The legend label is derived from the part of
    the file name in front of "_metrics.csv" (e.g. "10_1_metrics.csv" ->
    "10:1"). Files that are empty or start with an area of 0 are skipped with
    a warning.

    Args:
        folder (str): Folder with area CSVs (columns "step" and "area_um2").
        target_start_area (float): Area to normalize the start to (e.g., 1.0).
        max_step (int): Optional step cutoff for plotting.
    """

    label_map = {
        "10_1": "10:1",
        "3_1": "3:1",
        "1_1": "1:1",
        "1_3": "1:3",
        "2_1": "2:1",
        "5_1": "5:1",
        "20_1": "20:1",
        "GBM": "GBM",
        "MSC": "MSC"
    }
    suffix = "_metrics.csv"

    def label_for(fname):
        # Condition names such as "10_1" contain "_" themselves, so the name
        # is taken as everything before the suffix rather than the first token.
        stem = fname[:-len(suffix)]
        if stem in label_map:
            return label_map[stem]
        # Allow extra tokens after the condition, e.g. "10_1_area_metrics.csv";
        # longest keys first so "10_1" is not shadowed by a shorter key.
        for key in sorted(label_map, key=len, reverse=True):
            if stem.startswith(key + "_"):
                return label_map[key]
        return stem

    files = sorted([f for f in os.listdir(folder) if f.endswith(suffix)])
    if not files:
        print("No CSV files found.")
        return

    plt.figure(figsize=(8, 6))

    for fname in files:
        df = pd.read_csv(os.path.join(folder, fname))

        # A zero start area cannot be used as the normalization divisor.
        if df.empty or df["area_um2"].iloc[0] == 0:
            print(f"Warning: {fname} skipped due to missing or invalid start area.")
            continue

        if max_step is not None:
            df = df[df["step"] <= max_step]
            if df.empty:
                print(f"Warning: no data within max_step for {fname}")
                continue

        steps = df["step"].values
        areas = df["area_um2"].values
        normalized_areas = (areas / areas[0]) * target_start_area

        plt.plot(steps, normalized_areas, label=label_for(fname))

    plt.xlabel("Step")
    plt.ylabel("Normalized Area (µm²)")
    plt.title(f"Normalized Spheroid Area Growth (Start = {target_start_area:.0f} µm²)")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.show()


In [None]:

def make_cut_view_gif(slice_x, batch, condition, grid_size=(70, 70, 70)):
    """
    Creates an animated GIF showing the evolution of a 2D slice (in the x-plane)
    through a 3D spheroid simulation across time.

    Parameters:
    - slice_x (int): The x-coordinate at which to slice the 3D grid.
    - batch (str): Name of the batch (used to locate the simulation output).
    - condition (str): Simulation condition (e.g., "gbm_msc_25") used as folder name.
    - grid_size (tuple): (x, y, z) extent of the simulation grid in voxels. The
      slice image has shape (grid_size[1], grid_size[2]); voxels outside of it
      are ignored.

    Workflow:
    - Loads `.jld2` step files from the corresponding simulation output directory,
      ordered by their numeric step.
    - Extracts voxel positions and cell states at the specified slice.
    - Assigns each cell state a unique color and renders the slice at each timepoint.
    - Assembles the resulting PNGs into a GIF stored in `cut_views/`.

    Output:
    - Saves the frame images in `slice_frames/` as `slice_{step:05d}.png`
      (frames from an earlier call with the same steps are overwritten).
    - Saves the animated GIF in `cut_views/cut_view_{batch}_{condition}.gif`.

    Requirements:
    - Assumes the input files are structured in `sim_output/{batch}/{condition}/sim_output/`
      and follow the `step_XXX.jld2` naming convention.
    """
    # --- Settings ---
    print(f"Creating cut view GIF for batch {batch}, condition {condition} at slice x={slice_x}")
    input_dir = f"sim_output/{batch}/{condition}/sim_output"
    output_dir = "slice_frames"
    gif_output_path = f"cut_views/cut_view_{batch}_{condition}.gif"

    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(os.path.dirname(gif_output_path), exist_ok=True)

    # --- Automatically get all available steps ---
    # fullmatch rejects names with trailing text (e.g. "step_1.jld2.bak");
    # sorting (step, name) tuples orders frames numerically, so "step_10"
    # comes after "step_2" even without zero padding.
    step_pattern = re.compile(r"step_(\d+)\.jld2")
    matches = (step_pattern.fullmatch(f) for f in os.listdir(input_dir))
    steps_and_files = sorted((int(m.group(1)), m.group(0)) for m in matches if m)
    if not steps_and_files:
        print(f"No step files found in {input_dir}; no GIF created.")
        return

    # --- Build color map for cell states ---
    # First pass over all files so every frame uses the same state -> color mapping.
    unique_states = set()
    for _, fname in steps_and_files:
        with h5py.File(os.path.join(input_dir, fname), "r") as f:
            for k in f["cell_states"]:
                unique_states.add(f["cell_states"][k][()].decode("utf-8"))

    unique_states = sorted(unique_states)
    state_to_idx = {s: i + 1 for i, s in enumerate(unique_states)}  # background = 0

    # Create a custom colormap with black background
    base_cmap = plt.get_cmap("tab10")
    colors = [(0, 0, 0)] + [base_cmap(i / len(unique_states))[:3] for i in range(len(unique_states))]
    custom_cmap = ListedColormap(colors)

    # --- Generate slice images ---
    frame_paths = []

    for step, fname in steps_and_files:
        with h5py.File(os.path.join(input_dir, fname), "r") as f:
            cell_voxels = {
                int(k): np.array([list(t) for t in f["cell_voxels"][k][()]]).astype(int)
                for k in f["cell_voxels"]
            }
            cell_states = {
                int(k): f["cell_states"][k][()].decode("utf-8")
                for k in f["cell_states"]
            }

        slice_image = np.zeros((grid_size[1], grid_size[2]), dtype=int)  # Background = 0

        for cid, voxels in cell_voxels.items():
            if voxels.size == 0:
                continue  # cell without voxels contributes nothing
            color_idx = state_to_idx[cell_states[cid]]
            # Voxels of this cell lying in the requested x-plane.
            in_plane = voxels[voxels[:, 0] == slice_x]
            ys, zs = in_plane[:, 1], in_plane[:, 2]
            # Skip coordinates outside the image (negative ones would wrap around).
            inside = (ys >= 0) & (ys < grid_size[1]) & (zs >= 0) & (zs < grid_size[2])
            slice_image[ys[inside], zs[inside]] = color_idx

        # Plot the slice
        plt.figure(figsize=(5, 5))
        im = plt.imshow(slice_image.T, origin='lower', cmap=custom_cmap, vmin=0, vmax=len(unique_states))
        plt.title(f"Step {step} at x={slice_x}")
        plt.axis("off")

        # Create colorbar skipping background
        cbar = plt.colorbar(im, ticks=range(1, len(unique_states)+1))
        cbar.ax.set_yticklabels(unique_states)

        frame_path = os.path.join(output_dir, f"slice_{step:05d}.png")
        plt.savefig(frame_path, bbox_inches='tight')
        plt.close()  # release the figure; one is created per frame
        frame_paths.append(frame_path)

    # --- Create GIF ---
    images = [imageio.imread(f) for f in frame_paths]
    imageio.mimsave(gif_output_path, images, fps=2)

    print(f"✅ Done! GIF saved at: {gif_output_path}")


In [None]:
def extract_volume_metrics(input_folder, output_csv, voxel_volume_um3=8.0):
    """
    Extracts 3D volume (μm³) from .jld2 simulation files in the given folder.

    For every file whose name ends in ".jld2" and contains "step_", the voxel
    entries of all cells are counted and multiplied by ``voxel_volume_um3``.
    One CSV row (columns: step, volume_um3) is written per snapshot, sorted by
    step. Files that fail to parse or load are reported and skipped.

    Parameters:
    - input_folder (str): path to folder with .jld2 files
    - output_csv (str): path to output .csv file; its parent directory is
      created when missing
    - voxel_volume_um3 (float): volume of one voxel (default 2x2x2 μm = 8.0)
    """

    def load_cell_voxels_from_jld2(file_path):
        # Returns {cell_id: voxel data} for one snapshot.
        with h5py.File(file_path, "r") as f:
            cell_voxels = {
                int(k): f["cell_voxels"][k][()]
                for k in f["cell_voxels"]
            }
        return cell_voxels

    def compute_volume(cell_voxels):
        # Total voxel count over all cells, converted to μm³.
        return sum(len(voxels) for voxels in cell_voxels.values()) * voxel_volume_um3

    output_dir = os.path.dirname(output_csv)
    if output_dir:  # empty for a bare file name; os.makedirs("") would raise
        os.makedirs(output_dir, exist_ok=True)
    records = []

    for fname in sorted(os.listdir(input_folder)):
        if not (fname.endswith(".jld2") and "step_" in fname):
            continue
        file_path = os.path.join(input_folder, fname)

        try:
            # Parsed inside the try so an oddly named file is skipped, not fatal.
            step_num = int(fname.split("_")[-1].replace(".jld2", ""))
            cell_voxels = load_cell_voxels_from_jld2(file_path)
            volume_um3 = compute_volume(cell_voxels)
            records.append({"step": step_num, "volume_um3": volume_um3})
        except Exception as e:
            print(f"Error processing {file_path}: {e}")

    if records:
        df = pd.DataFrame(records).sort_values("step")
        df.to_csv(output_csv, index=False)
        print(f"Saved volume metrics to: {output_csv}")
    else:
        print("No valid volume data extracted.")


In [None]:
def process_condition(condition, folder, max_step=130, slice_x=35):
    """
    Run the post-processing pipeline for one condition of a batch.

    Reads snapshots from ``sim_output/{folder}/{condition}/sim_output`` and
    1. pickles the aggregated statistics to
       ``saved_stats/{folder}/stats_{condition}.pkl``,
    2. writes the total volume per step to
       ``volume_results/{folder}/{condition}_volume.csv``,
    3. renders the cut-view GIF for the condition.

    Args:
        condition (str): Condition name (sub-folder of the batch).
        folder (str): Batch folder name.
        max_step (int): Last step included in the aggregated statistics.
        slice_x (int): x-coordinate of the plane shown in the cut-view GIF.
    """
    sim_dir = f"sim_output/{folder}/{condition}/sim_output"
    stats_path = f"saved_stats/{folder}/stats_{condition}.pkl"

    stats = analyze_sim_folder(sim_dir, max_step=max_step)
    # Make sure the target directory exists before the pickle is written.
    os.makedirs(os.path.dirname(stats_path), exist_ok=True)
    save_stats(stats, stats_path)
    print(f"Saved stats for {condition} for batch {folder}")
    extract_volume_metrics(
        input_folder=sim_dir,
        output_csv=f"volume_results/{folder}/{condition}_volume.csv"
    )
    make_cut_view_gif(slice_x=slice_x, batch=folder, condition=condition)

In [4]:
def process_folder(folder, ratios = ["10_1", "3_1", "1_1", "1_3", "2_1", "5_1", "20_1", "GBM", "MSC"]):
    """
    Run ``process_condition`` for every condition of one batch.

    Args:
        folder (str): Batch folder name.
        ratios (list): Condition names to process, in order.
    """
    # The default list is only iterated, never modified.
    for condition_name in ratios:
        process_condition(condition_name, folder)