In [None]:
%load_ext autoreload
%autoreload 2

import os

# Thread budget for the numerical backends. These are set *before* numpy /
# tensorflow are imported below, since those libraries typically read them
# when they load.
os.environ['OMP_NUM_THREADS'] = '16'
os.environ['MKL_NUM_THREADS'] = '16'
os.environ['OPENBLAS_NUM_THREADS'] = '16'
os.environ['NUMEXPR_MAX_THREADS'] = '16'
# Expose only GPU 0 to CUDA for this kernel.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import sys
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import glob
import pickle
import copy
import imageio
import tensorflow as tf
# Keep TensorFlow's own thread pools within the same 16-thread budget.
tf.config.threading.set_intra_op_parallelism_threads(16)
tf.config.threading.set_inter_op_parallelism_threads(16)

import wellmap

# Make the project-local helpers (live_utils) importable; the path is
# relative to the notebook's current working directory.
sys.path.append('utils/')
import live_utils

from deepcell.applications import NuclearSegmentation, CellTracking

In [None]:
# Instantiate the deepcell segmentation and tracking applications once, so the
# processing loop further down can reuse the same objects for every well.
deepcell_model_version = "1.1"
app = NuclearSegmentation.from_version(deepcell_model_version)
tracker = CellTracking.from_version(deepcell_model_version)

In [None]:
# The experiment name selects the wellmap layout file under data/wellmap/.
experiment_date = 'example'
wellmap_file = f'data/wellmap/{experiment_date}.toml'

# Per-well settings table; each row is read by the cells below.
metadata = wellmap.load(wellmap_file)

In [None]:
import shutil


def _is_missing(value):
    """Return True when a metadata value is unset (``None`` or a float NaN)."""
    # Blank cells in a DataFrame usually come back as NaN rather than None,
    # so an ``is None`` test alone would let missing settings through.
    return value is None or (isinstance(value, float) and np.isnan(value))


# Process every well in the layout: copy its .nd2 next to the outputs,
# register the stack, segment nuclei (or reuse a saved mask), then track.
for _, row in tqdm(metadata.iterrows(), total=len(metadata)):

    nuc_area = row['nuc_area']
    dilation_size = row['dilation_size']

    image_directory = row['path_to_live']
    output_directory = row['output_dir']
    dirname = output_directory + '/processed'

    nuc_channel = row['live_nuc_channel']
    channel_to_register = row['registration_channel']
    registration_type = row['registration_type']

    # Validate the per-well settings before touching the filesystem.
    if _is_missing(nuc_channel):
        raise ValueError("Please indicate the nuclear channel.")
    if _is_missing(nuc_area):
        raise ValueError("Please indicate the nuclear area to filter out.")
    if _is_missing(channel_to_register):
        raise ValueError("Please indicate the channel you would like to use for registration.")

    # Check the directory *before* globbing, and check the glob result, so a bad
    # path gives a clear FileNotFoundError instead of an IndexError.
    if not os.path.isdir(image_directory):
        raise FileNotFoundError("That image directory doesn't exist. Try again.")
    matches = sorted(glob.glob(f"{image_directory}/Well{row['well0']}*.nd2"))
    if not matches:
        raise FileNotFoundError("That image file doesn't exist. Try again.")
    source_file = matches[0]

    # Also creates output_directory itself if it does not exist yet.
    os.makedirs(dirname, exist_ok=True)

    # Work from a local copy of the raw file kept alongside the outputs.
    file = f'{output_directory}/{os.path.basename(source_file)}'
    if not os.path.exists(file):
        shutil.copy2(source_file, file)

    # Get input and output names for writing the files
    filename = live_utils.get_outdirs(file)

    # Get microns per pixel and image sizes from metadata
    mpp, image_sizes = live_utils.get_mpp_from_nd2(file)

    # Make numpy array of the image and move channel axis to the end
    full_image = live_utils.get_npy(file, row)

    # Register the image stack.
    # NOTE(review): channel_to_register and registration_type are read and
    # validated above but not passed here -- confirm whether
    # register_image_stack is meant to use them.
    registered_image = live_utils.register_image_stack(full_image)

    mask_path = f'{output_directory}/{filename}_nuc_mask.npy'
    if os.path.exists(mask_path):
        # Reuse the mask from a previous run instead of re-segmenting.
        registered_mask = np.load(mask_path)
    else:
        # Make and clean the nuclear mask
        registered_mask = live_utils.segment_and_clean(registered_image, mpp, nuc_area=nuc_area, nuc_channel=nuc_channel, segment_app=app)
        # Only write freshly computed masks; a mask loaded from disk is already saved.
        if row['save_mask']:
            np.save(mask_path, registered_mask)

    # Track nuclei through the registered image/mask stacks and pickle the
    # results into ``dirname``.
    live_utils.track_and_pickle(registered_image, registered_mask, nuc_channel=nuc_channel, tracker=tracker, dirname=dirname, filename=filename, dilation_size=dilation_size)

If you want to visualize how well the tracking worked, run the cell below. It iterates over every well and saves an animated GIF of the tracking results. Each frame has two subplots: the raw image (summed over channels) on the left and the tracked masks on the right.

In [None]:
def adjust_contrast_limits(image, vmin=None, vmax=None,
                           lower_percentile=2, upper_percentile=99.99,
                           clip_min=0.1, clip_max=1.0):
    """
    Linearly rescale ``image`` so that ``vmin`` maps to 0 and ``vmax`` maps to 1.

    Parameters
    ----------
    image : array_like
        Intensity image to rescale.
    vmin, vmax : float, optional
        Intensity limits. When omitted they are taken from the
        ``lower_percentile`` / ``upper_percentile`` percentiles of ``image``.
    lower_percentile, upper_percentile : float
        Percentiles used for the default limits (defaults 2 and 99.99).
    clip_min, clip_max : float
        Range the scaled values are clipped to (defaults 0.1 and 1.0).

    Returns
    -------
    numpy.ndarray
        Float array with values in ``[clip_min, clip_max]``. When
        ``vmax == vmin`` (e.g. a constant image) there is no intensity range to
        scale by, so every pixel is set to ``clip_min`` instead of NaN.
    """
    image = np.asarray(image)
    if vmin is None:
        vmin = np.percentile(image, lower_percentile)
    if vmax is None:
        vmax = np.percentile(image, upper_percentile)
    if vmax == vmin:
        # Avoid 0/0 -> NaN for flat images.
        return np.full(image.shape, clip_min, dtype=float)
    # Perform linear scaling
    return np.clip((image - vmin) / (vmax - vmin), clip_min, clip_max)

def plot(x, y, cmap=None, vmax=None, raw_vmax=5000):
    """
    Draw a raw frame next to its tracked label image and return the rendered
    figure as an RGBA pixel array (e.g. one frame of a GIF).

    Parameters
    ----------
    x : 2-D array
        Raw intensity image, shown in grayscale with display upper limit ``raw_vmax``.
    y : 2-D array
        Label image; pixels equal to 0 are masked so the background is left uncoloured.
    cmap : colormap, optional
        Colormap for the label panel.
    vmax : float, optional
        Upper display limit for the label panel.
    raw_vmax : float
        Upper display limit for the raw panel (default 5000).

    Returns
    -------
    numpy.ndarray
        ``uint8`` array of shape ``(height, width, 4)``.
    """
    # masked_equal copies its input by default, so ``y`` is never modified
    # and no defensive deep copies of the (potentially large) frames are needed.
    labels = np.ma.masked_equal(y, 0)

    fig, ax = plt.subplots(1, 2, figsize=(12, 6))
    ax[0].imshow(x, cmap='gray', vmax=raw_vmax)
    ax[0].axis('off')
    ax[0].set_title('Raw')
    ax[1].imshow(labels, cmap=cmap, vmax=vmax)
    ax[1].set_title('Tracked')
    ax[1].axis('off')

    fig.canvas.draw()  # draw the canvas, cache the renderer
    # buffer_rgba() already exposes a (rows, cols, 4) buffer in physical pixels;
    # reshaping with get_width_height() (logical size) breaks on HiDPI canvases.
    # np.array copies the pixels so they outlive the closed figure.
    image = np.array(fig.canvas.buffer_rgba(), dtype=np.uint8)
    plt.close(fig)

    return image

# Write one animated GIF per well. Each frame shows the registered image summed
# over its last (channel) axis next to the tracked labels, rendered by ``plot``.
for _, row in tqdm(metadata.iterrows(), total=len(metadata)):

    folder_to_open = row['output_dir'] + '/processed'
    well = row['well0']

    # Sort so the choice is deterministic if several pickles match this well.
    matching_pickles = sorted(glob.glob(f'{folder_to_open}/Well{well}*.pickle'))
    if not matching_pickles:
        raise FileNotFoundError(f"No tracking pickle found for well {well} in {folder_to_open}")
    pickle_path = matching_pickles[0]
    pickle_name = os.path.splitext(os.path.basename(pickle_path))[0]

    # NOTE: pickle.load can execute arbitrary code -- only open pickles written
    # by this pipeline or another trusted source.
    with open(pickle_path, 'rb') as fh:
        tracks = pickle.load(fh)

    y_tracked = tracks['y_tracked']
    raw_stack = tracks['registered_image']
    ymax = np.max(y_tracked)
    cmap = live_utils.shuffle_colors(ymax, 'tab20')

    # Frames are indexed along axis 0; the GIF is written to the current
    # working directory, named after the pickle.
    frames = [
        plot(np.sum(raw_stack[i], axis=-1), y_tracked[i, ..., 0], cmap=cmap, vmax=ymax)
        for i in range(y_tracked.shape[0])
    ]
    imageio.mimsave(f'{pickle_name}.gif', frames)

An example script for plotting per-channel cytoplasm-to-nucleus (C/N) intensity ratios over time. It will most likely not work with your data as-is, but it can be used as a jumping-off point!

In [None]:
# For each well: mean cytoplasm/nucleus (C/N) intensity ratio over time for the
# first n_channels - 1 channels, baseline-subtracted by the pre-treatment mean.
# Paths come from each row, so this cell does not depend on variables left over
# from the processing loop.
treatment_frame = int(metadata['treatment_frame'].iloc[0])
n_channels = 4
frames_per_hour = 10  # x-axis in hours = frame index / frames_per_hour

fp_ylims = (1, 1, 1)  # upper y-limit per plotted channel

for _, row in metadata.iterrows():

    processed_dir = os.path.join(row['output_dir'], 'processed')
    pickle_name = os.path.join(processed_dir, f"{row['well0']}_tracks.pickle")
    alphanum = row['well0']

    print(row['treatment'])

    # NOTE: only load pickles from a trusted source (pickle can run code).
    with open(pickle_name, 'rb') as fh:
        track_dict = pickle.load(fh)

    # squeeze=False keeps ``axes`` indexable even when only one channel is plotted.
    fig, axes = plt.subplots(1, n_channels - 1, figsize=(8, 2), squeeze=False)
    axes = axes[0]
    axes[0].set_ylabel('C/N')

    for channel in range(n_channels - 1):

        axes[channel].set_ylim((-0.2, fp_ylims[channel]))
        axes[channel].set_xlabel('Time (h)')
        axes[channel].set_title(row['live_channels'][channel])

        channel_store = []
        for track in track_dict.values():
            # NOTE(review): tracks are kept when channel_exclude[channel] is
            # truthy -- confirm against live_utils that this flag means "keep".
            if not track['exclude'] and track['channel_exclude'][channel]:
                cn_ratio = track['cyto_intensity'][channel, :] / track['nuc_intensity'][channel, :]
                channel_store.append(cn_ratio)

        if not channel_store:
            # No qualifying tracks: leave this panel empty instead of crashing.
            continue

        channel_store = np.array(channel_store).transpose()  # (frames, tracks)
        n_frames = channel_store.shape[0]
        frame_vec = np.arange(n_frames) / frames_per_hour
        mean_store = np.mean(channel_store, axis=1)
        baseline = np.mean(mean_store[:treatment_frame])
        axes[channel].plot(frame_vec, mean_store - baseline, linewidth=0.5, c='k')

    sensor = row['sensor']
    treatment = row['treatment']
    # Save next to this well's own outputs (not the last processed well's).
    fig.savefig(os.path.join(processed_dir, f'{alphanum}_{sensor}_{treatment}'))
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across wells