# Cell Shape Classification from Time-lapse Microscopy Images

This notebook analyzes binarized time-lapse microscopy images (8-bit TIF stacks with pixel values 255 or 0) and classifies cells into five categories:

1. **Spiky**: Cells with very low circularity.
2. **Round**: Cells with high circularity and low aspect ratio.
3. **Oscillator**: Cells with rapid area changes over frames.
4. **Fans**: Keratocyte/fan-shaped cells that migrate rapidly perpendicular to their long axis.
5. **Amoeboid**: Cells that do not fit into the above special categories and display typical amoeboid migratory mode.

For more detailed criteria, please see the code below and **Figure S1** and **Figure S2**. 

## Import Required Libraries

We'll use several libraries for image processing, feature extraction, and analysis:

In [1]:
# Clear the kernel namespace (forced, i.e. without a confirmation prompt) so
# that the cells below do not pick up variables left over from earlier runs.
%reset -f

In [3]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import os
from skimage import io, measure, morphology, feature, filters, segmentation
from scipy import ndimage
from pathlib import Path
import glob
import tifffile
from tqdm.notebook import tqdm
import matplotlib

# For visualization: apply matplotlib's 'ggplot' style sheet and seaborn's
# 'notebook' plotting context to the figures created below.
plt.style.use('ggplot')
sns.set_context('notebook')


## Load and Prepare Data


In [4]:
# Interactive plotting: select the Tk backend, switch pyplot to interactive
# mode and turn the axes grid off by default.
matplotlib.use('TKAgg')
plt.ion()
plt.rcParams["axes.grid"] = False

# Input stack, resolved relative to the notebook's working directory.
folder_name = "."
file_name = "combined_mask_2rows.tif"
full_path_to_load = os.path.join(folder_name, file_name)

# Working directory and its parent.
Current_folder = os.getcwd()
Parent_folder = os.path.dirname(Current_folder)

# Output directory for the tables and images written by later cells;
# created here if it does not exist yet.
Path_to_Save = os.path.join(folder_name, 'save_folder')
os.makedirs(Path_to_Save, exist_ok=True)


## Interactive Visualization of Image Stack

In [5]:
from matplotlib.widgets import Slider
import matplotlib
matplotlib.use('TKAgg')
plt.ion()


tiff_file = full_path_to_load

# Read the stack geometry from the TIFF header. The context manager closes
# the file handle once the metadata has been read; previously the TiffFile
# object was opened and never closed.
with tifffile.TiffFile(tiff_file) as info:
    n_images = len(info.pages)            # number of pages (frames) in the file
    rows, cols = info.pages[0].shape[:2]  # first two dimensions of the first page




In [None]:


# Load all frames into a numpy array indexed as (frame, row, column)
I_tmp = np.zeros((n_images, rows, cols), dtype=np.uint16)

print(f"Loading {n_images} frames from {tiff_file}")
# Open the file once and decode it page by page, instead of re-opening and
# re-parsing the TIFF for every single frame.
with tifffile.TiffFile(tiff_file) as tif:
    for i in range(n_images):
        I_tmp[i] = tif.asarray(key=i)

# Set up the figure for interactive visualization
plt.figure(figsize=(10, 8))
ax = plt.subplot(111)
plt.subplots_adjust(bottom=0.2)  # Make room for the slider

# Display the first frame
img_display = ax.imshow(I_tmp[0], cmap='viridis')
plt.colorbar(img_display, ax=ax)
ax.set_title(f'Frame 0/{n_images-1}')
ax.grid(False)

# Add the slider
ax_slider = plt.axes([0.2, 0.05, 0.65, 0.03])  # [left, bottom, width, height]
slider = Slider(ax_slider, 'Frame', 0, n_images-1, valinit=0, valstep=1)


def _make_raw_frame_updater(image_artist, image_axes, stack):
    """Build a slider callback bound to one figure's artist, axes and stack.

    Binding these objects here (rather than reading the notebook globals
    `img_display`, `ax` and `slider` when the callback fires) keeps this
    figure working after later cells rebind those names.
    """
    last_frame = len(stack) - 1

    def _update(val):
        # The slider passes its new value to the callback.
        frame_idx = int(val)
        image_artist.set_data(stack[frame_idx])
        # Title the image axes explicitly: after plt.axes(...) the *current*
        # axes is the slider axes, which is where plt.title() would write.
        image_axes.set_title(f'Frame {frame_idx}/{last_frame}')
        image_axes.figure.canvas.draw_idle()

    return _update


# Update function for the slider
update = _make_raw_frame_updater(img_display, ax, I_tmp)
slider.on_changed(update)
plt.show()


print(f"Loaded TIFF stack with shape: {I_tmp.shape}")



In [None]:
file_name_label = "LblImg_combined_mask_2rows.tif"
full_path_to_load_label = os.path.join(folder_name, file_name_label)
tiff_file_label = full_path_to_load_label

I_lbl = np.zeros((n_images, rows, cols), dtype=np.uint32)


# Load the labeled image stack. The file is opened once and decoded page by
# page instead of being re-opened for every frame.
with tifffile.TiffFile(tiff_file_label) as tif_label:
    for i in range(n_images):
        I_lbl[i] = tif_label.asarray(key=i)


print(I_lbl.shape)
print(I_lbl.dtype)
print(np.unique(I_lbl))

# Set up the figure for interactive visualization
plt.figure(figsize=(10, 8))
ax = plt.subplot(111)
plt.subplots_adjust(bottom=0.2)  # Make room for the slider

# Display the first frame. plt.get_cmap replaces plt.cm.get_cmap, which is
# deprecated and no longer available in recent Matplotlib releases.
tab20_cmap = plt.get_cmap('tab20').copy()
tab20_cmap.set_under('black')  # Values below vmin (the 0 pixels) are drawn black
img_display = ax.imshow(I_lbl[0], cmap=tab20_cmap, vmin=0.5)
plt.colorbar(img_display, ax=ax)
ax.set_title(f'Frame 0/{n_images-1}')
ax.grid(False)

# Add the slider
ax_slider = plt.axes([0.2, 0.05, 0.65, 0.03])  # [left, bottom, width, height]
slider = Slider(ax_slider, 'Frame', 0, n_images-1, valinit=0, valstep=1)


def _make_label_frame_updater(image_artist, image_axes, stack):
    """Build a slider callback bound to one figure's artist, axes and stack.

    Binding these objects here (rather than reading the notebook globals
    `img_display`, `ax` and `slider` when the callback fires) keeps this
    figure working after later cells rebind those names.
    """
    last_frame = len(stack) - 1

    def _update(val):
        # The slider passes its new value to the callback.
        frame_idx = int(val)
        image_artist.set_data(stack[frame_idx])
        # Title the image axes explicitly: after plt.axes(...) the *current*
        # axes is the slider axes, which is where plt.title() would write.
        image_axes.set_title(f'Frame {frame_idx}/{last_frame}')
        image_axes.figure.canvas.draw_idle()

    return _update


# Update function for the slider
update = _make_label_frame_updater(img_display, ax, I_lbl)
slider.on_changed(update)
plt.show()


print(f"Loaded TIFF stack with shape: {I_lbl.shape}")



In [None]:
import time

# Process the label image stack to extract region properties for each cell
print("Extracting region properties from labeled images...")
min_cell_area = 120  # Minimum cell area to filter out noise
preview_pause_s = 1.5  # Seconds each per-frame preview stays on screen

# Columns of the output table, in order. Passing them explicitly keeps the
# table well-formed (empty, but with all columns) when no region passes the
# area filter, so the summary below cannot fail with a KeyError.
region_columns = [
    'frame', 'label', 'area', 'perimeter', 'centroid_y', 'centroid_x', 'bbox',
    'eccentricity', 'orientation', 'major_axis_length', 'minor_axis_length',
    'aspect_ratio', 'circularity', 'solidity',
]


def _region_record(prop, frame_idx):
    """Collect the shape descriptors of one labelled region into a dict."""
    area = prop.area
    perimeter = prop.perimeter
    major_axis_length = prop.major_axis_length
    minor_axis_length = prop.minor_axis_length
    return {
        'frame': frame_idx,
        'label': prop.label,
        'area': area,
        'perimeter': perimeter,
        'centroid_y': prop.centroid[0],
        'centroid_x': prop.centroid[1],
        'bbox': prop.bbox,
        'eccentricity': prop.eccentricity,
        'orientation': prop.orientation,
        'major_axis_length': major_axis_length,
        'minor_axis_length': minor_axis_length,
        # Major/minor axis ratio; 1 when the minor axis has zero length
        'aspect_ratio': major_axis_length / minor_axis_length if minor_axis_length > 0 else 1,
        # 4*pi*area / perimeter^2; 0 when the perimeter is zero
        'circularity': (4 * np.pi * area) / (perimeter**2) if perimeter > 0 else 0,
        'solidity': prop.solidity,
    }


# Initialize lists to store results
all_region_props = []

for frame_idx in tqdm(range(n_images), desc="Processing frames"):
    # Get current frame
    label_img = I_lbl[frame_idx]

    # Get region properties and drop very small regions
    props = measure.regionprops(label_img)
    cell_props = [prop for prop in props if prop.area >= min_cell_area]

    # Extract and store features for each retained cell
    all_region_props.extend(_region_record(prop, frame_idx) for prop in cell_props)

    # Show each frame briefly, with the centroids of the retained cells marked
    plt.figure(figsize=(8, 6))
    plt.imshow(label_img, cmap='viridis')
    for prop in cell_props:
        y, x = prop.centroid
        plt.plot(x, y, 'ro', markersize=5)
    plt.colorbar()
    plt.title(f"Label Image - Frame {frame_idx}")
    plt.show(block=False)
    plt.pause(preview_pause_s)
    plt.close()

# Convert to DataFrame
region_props_df = pd.DataFrame(all_region_props, columns=region_columns)

# Display summary
print(f"Extracted properties for {len(region_props_df)} cell instances across {n_images} frames")
print(f"Number of unique cell labels: {region_props_df['label'].nunique()}")

excel_path = os.path.join(Path_to_Save, 'region_properties.xlsx')

region_props_df.to_excel(excel_path, index=False)


## Visualize Fitted Ellipses for Cells in Each Frame
def visualize_cells_with_ellipses():
    """
    Visualize the labeled cells with fitted ellipses based on region properties.

    Opens a figure showing one frame of the notebook-level label stack
    `I_lbl` with, for every row of `region_props_df` belonging to that frame,
    a centroid marker, the cell label and the fitted ellipse. A slider selects
    the frame (0 .. n_images - 1).

    Returns
    -------
    matplotlib.widgets.Slider
        The frame slider. The widget only stays responsive while a reference
        to it exists, so it is also stored on the figure.
    """
    from matplotlib.patches import Ellipse

    last_frame = n_images - 1

    # Create a figure for visualization
    fig = plt.figure(figsize=(10, 8))
    ax = plt.subplot(111)
    plt.subplots_adjust(bottom=0.2)  # Make room for the slider

    # Set up initial display (frame 0)
    img_display = ax.imshow(I_lbl[0], cmap='viridis')
    plt.colorbar(img_display, ax=ax)
    ax.grid(False)

    # Artists drawn for the current frame, tracked so they can be removed
    overlays = []

    # Function to update the display
    def update(val):
        frame_idx = int(val)

        # Update image and title
        img_display.set_data(I_lbl[frame_idx])
        ax.set_title(f'Cell Ellipses - Frame {frame_idx}/{last_frame}')

        # Clear the overlays of the previously shown frame
        for artist in overlays:
            artist.remove()
        overlays.clear()

        # Get properties for cells in this frame
        frame_props = region_props_df[region_props_df['frame'] == frame_idx]

        for _, prop in frame_props.iterrows():
            x, y = prop['centroid_x'], prop['centroid_y']

            # Centroid marker
            overlays.append(ax.plot(x, y, 'ro', markersize=5)[0])

            # Cell ID label
            overlays.append(
                ax.text(x - 15, y - 15, str(prop['label']), color='white', fontsize=10,
                        bbox=dict(facecolor='black', alpha=0.7))
            )

            # Fitted ellipse. skimage (>= 0.16) reports `orientation` as the
            # angle between the row axis and the major axis, whereas the
            # Ellipse `angle` rotates the width axis away from the x axis in
            # data coordinates, so the major axis lies at 90 deg - orientation.
            ellipse = Ellipse(
                (x, y),
                prop['major_axis_length'],
                prop['minor_axis_length'],
                angle=90.0 - np.degrees(prop['orientation']),
                fill=False,
                edgecolor='yellow',
                linewidth=2
            )
            ax.add_patch(ellipse)
            overlays.append(ellipse)

        fig.canvas.draw_idle()

    # Add the slider. Frames are 0-based, so the valid range is
    # 0 .. n_images - 1 (a range of 1 .. n_images would index past the end of
    # the stack and make frame 0 unreachable).
    ax_slider = plt.axes([0.2, 0.05, 0.65, 0.03])  # [left, bottom, width, height]
    slider = Slider(ax_slider, 'Frame', 0, last_frame, valinit=0, valstep=1)

    # Register the update function with the slider, then draw frame 0
    slider.on_changed(update)
    update(0)

    # Keep the slider alive for as long as the figure exists
    fig._frame_slider = slider

    plt.show()
    return slider



## Oscillator Cells


In [None]:
# Every cell instance starts out unclassified
region_props_df['cell_type'] = "Unknown"

# Order rows by cell and then by frame so each cell's rows are chronological
region_props_df_sorted = region_props_df.sort_values(['label', 'frame'])

# Number of frames per rolling window
window_size = 5

# Rolling coefficient of variation (COV, in percent) of the area, keyed by
# the row index of the sorted table
cov_results = {}

for label, cell_data in region_props_df_sorted.groupby('label'):
    # Cells observed in fewer frames than one window get no rolling COV here
    if len(cell_data) < window_size:
        continue

    area_window = cell_data['area'].rolling(window=window_size, min_periods=1)
    rolling_cov = (area_window.std() / area_window.mean()) * 100
    cov_results.update(rolling_cov.to_dict())

# Attach the rolling COV to the table (aligned on the row index)
region_props_df_sorted['rolling_area_cov'] = pd.Series(cov_results)

# Whole-track COV of the area for each cell
area_stats = region_props_df_sorted.groupby('label')['area'].agg(['mean', 'std']).reset_index()
area_stats['area_cov'] = (area_stats['std'] / area_stats['mean']) * 100

# Bring the whole-track COV into the table
region_props_df_sorted = region_props_df_sorted.merge(
    area_stats[['label', 'area_cov']], on='label', how='left'
)

# Where no rolling COV is available, fall back to the whole-track COV
region_props_df_sorted['rolling_area_cov'] = region_props_df_sorted['rolling_area_cov'].fillna(
    region_props_df_sorted['area_cov']
)

# Instances whose rolling area COV exceeds 18 % are labelled Oscillator
oscillator_mask = region_props_df_sorted['rolling_area_cov'] > 18
region_props_df_sorted.loc[oscillator_mask, 'cell_type'] = "Oscillator"

# Summary of the classification so far
n_instances = len(region_props_df_sorted)
n_oscillator = (region_props_df_sorted['cell_type'] == 'Oscillator').sum()
n_unknown = (region_props_df_sorted['cell_type'] == 'Unknown').sum()
print(f"Total cells: {n_instances}")
print(f"Oscillator cells: {n_oscillator} ({n_oscillator/n_instances*100:.2f}%)")
print(f"Unclassified cells: {n_unknown} ({n_unknown/n_instances*100:.2f}%)")




Total cells: 586
Oscillator cells: 107 (18.26%)
Unclassified cells: 479 (81.74%)


## Spiky Cells

In [None]:


# Spiky cells: instances that are still unclassified and whose circularity
# is below 0.4
is_unclassified = region_props_df_sorted['cell_type'] == "Unknown"
is_low_circularity = region_props_df_sorted['circularity'] < 0.4
spiky_mask = is_unclassified & is_low_circularity
region_props_df_sorted.loc[spiky_mask, 'cell_type'] = "Spiky"

# Summary of the classification so far
n_instances = len(region_props_df_sorted)
n_spiky = (region_props_df_sorted['cell_type'] == 'Spiky').sum()
n_unknown = (region_props_df_sorted['cell_type'] == 'Unknown').sum()
print(f"Total cells: {n_instances}")
print(f"Spiky cells: {n_spiky} ({n_spiky/n_instances*100:.2f}%)")

print(f"Unclassified cells: {n_unknown} ({n_unknown/n_instances*100:.2f}%)")

## Round cells

In [None]:
# Round cells: instances that are still unclassified after the earlier
# passes and whose aspect ratio is below 1.2
is_unclassified = region_props_df_sorted['cell_type'] == "Unknown"
rounded_mask = is_unclassified & (region_props_df_sorted['aspect_ratio'] < 1.2)
region_props_df_sorted.loc[rounded_mask, 'cell_type'] = "Round"

# Summary of the classification so far
n_instances = len(region_props_df_sorted)
print(f"Total cells: {n_instances}")

for cell_type_name in ("Round", "Spiky", "Oscillator"):
    n_of_type = (region_props_df_sorted['cell_type'] == cell_type_name).sum()
    print(f"{cell_type_name} cells: {n_of_type} ({n_of_type/n_instances*100:.2f}%)")

n_unknown = (region_props_df_sorted['cell_type'] == 'Unknown').sum()
print(f"Unclassified cells: {n_unknown} ({n_unknown/n_instances*100:.2f}%)")

## Fan (Keratocyte-like) Cells -- First compute the direction (angle) of migration

In [None]:
# Calculate migration vectors between consecutive rows of each cell and
# compare their direction with the cell's fitted-ellipse orientation.
print("Calculating migration vectors and orientation angles...")

# Rows are already ordered by label and frame, so "previous row within the
# same label" is the previous observation of the same cell.
cells_by_label = region_props_df_sorted.groupby('label')

# True for every row that has an earlier row of the same cell; the first row
# of each cell (and therefore every single-frame cell) gets no migration data.
has_previous = cells_by_label.cumcount() > 0

# Displacement vector from the previous observation (NaN on first rows)
dx = cells_by_label['centroid_x'].diff()
dy = cells_by_label['centroid_y'].diff()

# Displacement direction in degrees, from arctan2(dy, dx) in image
# coordinates (dy is a difference of row coordinates)
displacement_angle_deg = np.degrees(np.arctan2(dy, dx))

# regionprops orientation, converted from radians to degrees.
# NOTE(review): scikit-image >= 0.16 documents `orientation` as measured from
# the row axis, not the x axis; confirm that subtracting it directly from the
# displacement angle compares against the intended axis (see Figure S2).
orientation_deg = np.degrees(region_props_df_sorted['orientation'])

# Absolute angular difference. This equals the original sign-dependent
# "subtract / add absolute values" rule, which reduces to |a - b| in every
# case.
angle_diff_deg = (orientation_deg - displacement_angle_deg).abs()

# Fold differences above 90 degrees back with |180 - diff| (the value is
# never negative, so no "< -90" case exists).
angle_diff_deg = angle_diff_deg.where(~(angle_diff_deg > 90), (180 - angle_diff_deg).abs())

# One row per cell observation that has a predecessor; all values in degrees
migration_angles = pd.DataFrame({
    'displacement_angle': displacement_angle_deg,
    'major_axis_angle': orientation_deg,
    'angle_difference': angle_diff_deg,
})[has_previous]

# Same structure as before: {row index: {column name: value}}
migration_data = migration_angles.to_dict('index')

# Add migration data to the main dataframe (rows without data become NaN)
for col in ['displacement_angle', 'major_axis_angle', 'angle_difference']:
    region_props_df_sorted[col] = migration_angles[col]


print("COMPLETE")






## Fan (Keratocyte-like) Cells -- Next compute the (rolling) angle differences between orientation and displacement


In [None]:
# Create a copy of the original dataframe
region_props_df_sorted_updated = region_props_df_sorted.copy()

# Number of frames averaged in each rolling window
angle_rolling_window = 7

# First fill NaN values in angle_difference with a backward fill within each
# cell group. Series.bfill() replaces fillna(method='bfill'), which is
# deprecated in pandas.
region_props_df_sorted_updated['angle_difference'] = region_props_df_sorted_updated.groupby('label')['angle_difference'].transform(
    lambda x: x.bfill()
)

# Now apply a rolling average to each cell group separately
region_props_df_sorted_updated['rolling_angle_difference'] = region_props_df_sorted_updated.groupby('label')['angle_difference'].transform(
    lambda x: x.rolling(window=angle_rolling_window, min_periods=1).mean()
)

# Fill any remaining NaN values in the rolling average column
region_props_df_sorted_updated['rolling_angle_difference'] = region_props_df_sorted_updated.groupby('label')['rolling_angle_difference'].transform(
    lambda x: x.bfill()
)





## Fan (Keratocyte-like) Cells -- Finally compute the rolling coefficient of variation (COV) of the displacement angle and label the fan cells using the combined criteria

In [18]:
# Rolling coefficient of variation (COV, in percent) of the displacement
# angle, keyed by row index
cov_results_angle = {}
window_size = 7

for label, cell_data_angle in region_props_df_sorted_updated.groupby('label'):
    # Cells observed in fewer frames than one window get no rolling COV here
    if len(cell_data_angle) < window_size:
        continue

    angle_window = cell_data_angle['displacement_angle'].rolling(window=window_size, min_periods=1)
    rolling_cov_angle = (angle_window.std() / angle_window.mean()) * 100
    cov_results_angle.update(rolling_cov_angle.to_dict())

# Attach the rolling COV to the table (aligned on the row index)
region_props_df_sorted_updated['rolling_angle_cov'] = pd.Series(cov_results_angle)

# Whole-track COV of the displacement angle for each cell
angle_stats = region_props_df_sorted_updated.groupby('label')['displacement_angle'].agg(['mean', 'std']).reset_index()
angle_stats['angle_cov'] = (angle_stats['std'] / angle_stats['mean']) * 100

# Remove any angle_cov columns left over from a previous run of this cell so
# the merge below does not produce suffixed duplicates
region_props_df_sorted_updated = region_props_df_sorted_updated.drop(
    columns=['angle_cov', 'angle_cov_x', 'angle_cov_y'], errors='ignore'
)

# Bring the whole-track COV into the table
region_props_df_sorted_updated = region_props_df_sorted_updated.merge(
    angle_stats[['label', 'angle_cov']],
    on='label',
    how='left'
)

# Where no rolling angle COV is available, fall back to the whole-track value
region_props_df_sorted_updated['rolling_angle_cov'] = region_props_df_sorted_updated['rolling_angle_cov'].fillna(
    region_props_df_sorted_updated['angle_cov']
)

# Fan cells: still unclassified, rolling angle difference above 40 degrees,
# and an absolute rolling angle COV below 100 %
is_unclassified = region_props_df_sorted_updated['cell_type'] == "Unknown"
is_large_angle_difference = region_props_df_sorted_updated['rolling_angle_difference'] > 40
is_steady_direction = region_props_df_sorted_updated['rolling_angle_cov'].abs() < 100
fan_mask = is_unclassified & is_large_angle_difference & is_steady_direction

region_props_df_sorted_updated.loc[fan_mask, 'cell_type'] = "Fan"


## Replace "Unknown" with "Amoeboid" in the cell_type column

In [19]:
# Every instance that no earlier rule classified is labelled Amoeboid
still_unknown = region_props_df_sorted_updated['cell_type'] == "Unknown"
region_props_df_sorted_updated.loc[still_unknown, 'cell_type'] = "Amoeboid"

## Plot labeled cells with slider

In [None]:

# Marker colour per cell type; any other type is drawn in purple
cell_type_colors = {
    'Spiky': 'red',
    'Round': 'yellow',
    'Oscillator': 'blue',
    'Fan': 'green',
}

# Create a figure for the visualization with slider
fig, ax = plt.figure(figsize=(10, 8)), plt.subplot(111)
plt.subplots_adjust(bottom=0.2)  # Make room for the slider

# Set up initial display (frame 0)
frame_idx = 0
label_img = I_lbl[frame_idx]

img_display = ax.imshow(label_img, cmap='viridis')
ax.set_title(f"Labeled Cells - Frame {frame_idx}")

# Add colorbar
plt.colorbar(img_display, ax=ax)

# Artists drawn for the currently displayed frame
cell_markers = []
cell_type_labels = []
cell_id_labels = []


def draw_cell_annotations(frame_number):
    """Draw marker, type label and ID label for every cell of one frame.

    The created artists are appended to `cell_markers`, `cell_type_labels`
    and `cell_id_labels` so they can be removed when the frame changes.
    """
    frame_cells = region_props_df_sorted_updated[region_props_df_sorted_updated['frame'] == frame_number]

    for _, cell in frame_cells.iterrows():
        x, y = cell['centroid_x'], cell['centroid_y']
        cell_type = cell['cell_type']
        cell_id = cell['label']

        # Choose color based on cell type
        marker_color = cell_type_colors.get(cell_type, 'purple')

        # Add marker
        cell_markers.append(ax.plot(x, y, 'o', color=marker_color, markersize=5)[0])

        # Add cell type label
        cell_type_labels.append(
            ax.text(x + 40, y + 40, cell_type, color='white', fontsize=8,
                    bbox=dict(facecolor='black', alpha=0.5))
        )

        # Add cell ID label
        cell_id_labels.append(ax.text(x - 40, y, str(cell_id), color='red', fontsize=12))


# Plot initial cell markers and labels
draw_cell_annotations(frame_idx)

# Create the slider
ax_slider = plt.axes([0.2, 0.05, 0.65, 0.03])  # [left, bottom, width, height]
slider = Slider(ax_slider, 'Frame', 0, n_images-1, valinit=0, valstep=1)


# Update function for the slider
def update(val):
    """Redraw the image and the cell annotations for the selected frame."""
    # The slider passes its new value to the callback
    selected_frame = int(val)

    # Update image
    img_display.set_data(I_lbl[selected_frame])
    ax.set_title(f"Labeled Cells - Frame {selected_frame}")

    # Clear previous cell markers and labels
    for artist in cell_markers + cell_type_labels + cell_id_labels:
        artist.remove()
    cell_markers.clear()
    cell_type_labels.clear()
    cell_id_labels.clear()

    # Add new markers and labels
    draw_cell_annotations(selected_frame)

    fig.canvas.draw_idle()


# Register the update function with the slider. This happens only once, after
# this cell's own slider and callback exist; an earlier registration placed
# before their definition re-attached the previous cell's stale callback.
slider.on_changed(update)

# Display the figure
plt.show()

## Save the updated DataFrame to Excel

In [21]:
# Write the fully classified table to the output folder
updated_excel_path = os.path.join(Path_to_Save, 'region_properties_updated.xlsx')
region_props_df_sorted_updated.to_excel(updated_excel_path, index=False)

## Plotting the distributions

In [None]:
sns.set_style("ticks", {'axes.grid' : False})


def plot_feature_kde(data, feature, cell_type, xlabel, title,
                     threshold=None, threshold_label=None, xlim=None):
    """Overlay the KDE of one feature for all cells and for one cell type.

    Parameters
    ----------
    data : pandas.DataFrame
        Table holding the `feature` column and a 'cell_type' column.
    feature : str
        Column whose distribution is plotted.
    cell_type : str
        Value of 'cell_type' selecting the subset drawn on top of all cells;
        it is also used for the legend entry ('<cell_type> Cells').
    xlabel, title : str
        Axis label and figure title.
    threshold : float, optional
        If given, a dashed vertical line is drawn at this x value.
    threshold_label : str, optional
        Legend entry for the threshold line.
    xlim : tuple of float, optional
        If given, the x-axis limits.

    Returns
    -------
    matplotlib.axes.Axes
        The axes the distributions were drawn on.
    """
    kde_style = dict(fill=True, bw_adjust=0.75, alpha=0.3)

    fig, ax = plt.subplots(figsize=(10, 6))

    # KDE for all cells
    sns.kdeplot(data=data, x=feature, label='All Cells', ax=ax, **kde_style)

    # KDE for the selected cell type only
    sns.kdeplot(data=data[data['cell_type'] == cell_type], x=feature,
                label=f'{cell_type} Cells', ax=ax, **kde_style)

    if threshold is not None:
        ax.axvline(x=threshold, color='black', linestyle='--', alpha=0.7, label=threshold_label)
    if xlim is not None:
        ax.set_xlim(*xlim)

    # Labels and legend
    ax.set_xlabel(xlabel, fontsize=14)
    ax.set_ylabel('Density', fontsize=14)
    ax.set_title(title, fontsize=16)
    ax.legend()
    fig.tight_layout()
    sns.despine(fig=fig)
    plt.show()
    return ax


# Circularity: all cells vs. spiky cells, with the 0.4 classification threshold
plot_feature_kde(
    region_props_df_sorted_updated, 'circularity', 'Spiky',
    xlabel='Circularity',
    title='Circularity Distribution - All Cells vs. Spiky Cells',
    threshold=0.4,
    threshold_label='Spiky Threshold (0.4)',
)

# Aspect ratio: all cells vs. round cells
plot_feature_kde(
    region_props_df_sorted_updated, 'aspect_ratio', 'Round',
    xlabel='Aspect Ratio',
    title='Aspect Ratio Distribution - All Cells vs. Round Cells',
)

# Rolling area COV: all cells vs. oscillator cells
plot_feature_kde(
    region_props_df_sorted_updated, 'rolling_area_cov', 'Oscillator',
    xlabel='Rolling Area COV',
    title='Area COV Distribution - All Cells vs. Oscillator Cells',
)

# Rolling angle difference: all cells vs. fan cells
plot_feature_kde(
    region_props_df_sorted_updated, 'rolling_angle_difference', 'Fan',
    xlabel='Angle',
    title='Angle Distribution - All Cells vs. Fan Cells',
)

# Rolling angle COV: all cells vs. fan cells, x axis restricted to +/-2000
plot_feature_kde(
    region_props_df_sorted_updated, 'rolling_angle_cov', 'Fan',
    xlabel='Rolling Angle COV',
    title='Angle COV Distribution - All Cells vs. Fan Cells',
    xlim=(-2000, 2000),
)

## Plot a simple heatmap showing the type of each labeled cell over time

In [None]:
def _most_common_type(cell_types):
    """Return the most frequent value of a Series, or "Unknown" if there is none."""
    modes = cell_types.mode()
    return modes.iloc[0] if not modes.empty else "Unknown"


# Most common classification of each cell over all of its frames
cell_type_modes = region_props_df_sorted_updated.groupby('label')['cell_type'].agg(_most_common_type).to_dict()

# Create a pivot table for the heatmap
# First, create a numerical encoding for cell_type
cell_type_map = {
    'Spiky': 1,
    'Round': 2,
    'Oscillator': 3,
    'Fan': 4,
    'Amoeboid': 5
}

# Add a numerical representation of cell_type
region_props_df_sorted_updated['cell_type_code'] = region_props_df_sorted_updated['cell_type'].map(cell_type_map)

# Create pivot table: frames on x-axis, cell labels on y-axis, values are cell types
pivot_df = region_props_df_sorted_updated.pivot_table(
    values='cell_type_code',
    index='label',
    columns='frame',
    fill_value=0  # Use 0 for missing values (will be displayed as black)
)

# Sort the index to have cell labels in ascending order
pivot_df = pivot_df.sort_index()

# Keep only the frames below this frame number
max_frame = 30
pivot_df = pivot_df.loc[:, pivot_df.columns < max_frame]

# Create the heatmap
plt.figure(figsize=(10, 5))

# Five-colour map with black for missing values (0). plt.get_cmap replaces
# plt.cm.get_cmap, which is deprecated and no longer available in recent
# Matplotlib releases.
cmap = plt.get_cmap('Paired', 5).copy()
cmap.set_under('black')  # Set color for values below vmin

# Create heatmap
sns.heatmap(pivot_df,
            cmap=cmap,
            vmin=0.5,  # Start just above 0 to have 0 values colored black
            vmax=5.5,  # Just above max value
            cbar_kws={'ticks': [1, 2, 3, 4, 5], 'location': 'right', 'pad': 0.1},
            linewidth = 0.5)

ax = plt.gca()

# Customize colorbar labels
cbar = ax.collections[0].colorbar
cbar.set_ticklabels(['Spiky', 'Round', 'Oscillator', 'Fan', 'Amoeboid'])

# Get the actual frame numbers from the column names of pivot_df
frame_numbers = pivot_df.columns.tolist()

# Label every `step`-th column with its frame number. Heatmap column i spans
# [i, i + 1], so the tick is placed at i + 0.5 to sit under the cell centre
# rather than on its left edge.
step = 5
x_positions = np.arange(0, len(frame_numbers), step)
x_labels = [frame_numbers[i] for i in x_positions]
plt.xticks(x_positions + 0.5, x_labels, rotation=45)
plt.tick_params(axis='x', length=10, width=2, color='black')


# Add cell_type_modes labels on the right side of the heatmap
cell_labels = pivot_df.index.tolist()

# Add text for each cell label showing its most common type
for i, cell_label in enumerate(cell_labels):
    if cell_label in cell_type_modes:
        ax.text(len(frame_numbers) + 0.5, i + 0.5,
                f"{cell_type_modes[cell_label]}",
                ha='left', va='center', fontsize=9,
                bbox=dict(facecolor='white', alpha=0.7, pad=2))


# Set title and labels
plt.title('Cell Type Classification Over Time', fontsize=16)
plt.xlabel('Frame Number', fontsize=14)
plt.ylabel('Cell Label', fontsize=14)

plt.tight_layout()
plt.show()

## Create a multi-frame RGB TIFF stack from the complete visualization


In [None]:
# Create a multi-frame RGB TIFF stack from the complete visualization
print("Creating RGB TIFF stack from complete visualization...")

import tempfile
import os
from PIL import Image
# NOTE(review): this rebinds the name `io`, which the import cell takes from
# skimage; nothing in this cell uses it.
import io

n_frames, height, width = I_lbl.shape
rgb_stack = np.zeros((n_frames, height, width, 3), dtype=np.uint8)

# Marker colour per cell type; any other type is drawn in purple
cell_type_colors = {
    'Spiky': 'red',
    'Round': 'yellow',
    'Oscillator': 'blue',
    'Fan': 'green',
}

# Render every frame to a PNG in a temporary directory. The context manager
# removes the directory even if rendering fails part-way through.
with tempfile.TemporaryDirectory() as temp_dir:
    for frame_idx in tqdm(range(n_frames)):
        # Create a new figure
        frame_fig = plt.figure(figsize=(12, 12))
        ax = plt.subplot(111)

        # Plot the labeled image with viridis colormap
        label_img = I_lbl[frame_idx]
        img_display = ax.imshow(label_img, cmap='viridis')
        ax.set_title(f"Labeled Cells - Frame {frame_idx}")

        # Get cells in this frame and add markers/labels
        frame_cells = region_props_df_sorted_updated[region_props_df_sorted_updated['frame'] == frame_idx]

        for _, cell in frame_cells.iterrows():
            x, y = cell['centroid_x'], cell['centroid_y']
            cell_type = cell['cell_type']
            cell_id = cell['label']

            # Choose color based on cell type
            marker_color = cell_type_colors.get(cell_type, 'purple')

            # Add marker and labels
            ax.plot(x, y, 'o', color=marker_color, markersize=4)
            ax.text(x + 40, y + 40, cell_type, color='white', fontsize=4,
                    bbox=dict(facecolor='black', alpha=0.5))
            ax.text(x - 40, y, str(cell_id), color='red', fontsize=12)

        plt.tight_layout()

        # Save the figure to a temporary file with high DPI
        temp_file = os.path.join(temp_dir, f"frame_{frame_idx:03d}.png")
        plt.savefig(temp_file, dpi=600, bbox_inches='tight')
        plt.close(frame_fig)

        # Read the saved image back and resize it to the original frame size.
        # The file handle is closed before the file is deleted.
        with Image.open(temp_file) as img:
            img_resized = img.resize((width, height), Image.LANCZOS)

        # Keep the first three channels (drops the alpha channel of RGBA images)
        rgb_stack[frame_idx] = np.array(img_resized)[:, :, :3]

        # Remove the (large) PNG right away instead of keeping all frames on disk
        os.remove(temp_file)

# Save the RGB visualization stack
visualization_path = os.path.join(Path_to_Save, 'complete_visualization_draft_updated.tif')
tifffile.imwrite(visualization_path, rgb_stack, metadata={'axes': 'TYXC'})
print(f"Saved complete visualization stack to: {visualization_path}")

# Save previews of two rendered frames (the 10th and 15th, i.e. indices 9 and
# 14); a preview is skipped when the stack is too short to contain the frame.
for preview_number, preview_idx in enumerate((9, 14), start=1):
    if preview_idx >= n_frames:
        print(f"Skipping preview of frame {preview_idx + 1}: the stack has only {n_frames} frames")
        continue

    preview_fig = plt.figure(figsize=(12, 12))
    plt.imshow(rgb_stack[preview_idx])
    plt.title(f'Complete Visualization (Frame {preview_idx + 1})')
    plt.axis('off')
    plt.tight_layout()
    plt.savefig(os.path.join(Path_to_Save, f'complete_visualization_preview_{preview_number}.png'), dpi=600)
    plt.close(preview_fig)
