In [27]:
import os
import re
import pandas as pd
import numpy as np
import shutil
from joblib import delayed, Parallel
from towbintools.foundation.image_handling import read_tiff_file
from tifffile import imwrite

# set random seed for reproducibility
# np.random.seed(42)
np.random.seed(387799)

In [28]:
storage_cluster_path = "/mnt/towbin.data/shared"
valid_experimentalists = ["plenart"]

valid_scopes = ["squid"]
body_strains = []
pharynx_strains = []
germline_strains = ["530"]

keywords_to_exclude = ["exclude", "fail", "failure", "crash"]
# experient_to_always_include = ["20250314_squid_10x_yap_aid_160_438_492_493"]
experient_to_always_include = []
experiments_to_exclude = [
    "20241111_squid_10x_wBT318_NaCl",
    "20241707_souvik_w318_ev_mex3_squid",
    "20240515_squid1_10x_wbt160_25C_2024-05-15_15-31-45.506118",
    "20240212_squid_wbt318_Nacl",
    "20241912_squid_10x_wBT_344_415",
    "20252401_squid_10x_wBT318_reproduction",
    "20252001_squid_10x_wBT318_NaCl",
]

In [29]:
database_path = "/mnt/towbin.data/shared/spsalmon/towbinlab_segmentation_database/crosstalk_germline"

db_organs = ['germline']

extra_adulthood_time = 40

In [30]:
os.makedirs(database_path, exist_ok=True)

if "body" in db_organs:
    body_database_path = os.path.join(database_path, "body")
    os.makedirs(body_database_path, exist_ok=True)
    body_images = os.path.join(body_database_path, "images")
    os.makedirs(body_images, exist_ok=True)
    body_masks = os.path.join(body_database_path, "masks")
    os.makedirs(body_masks, exist_ok=True)
    body_images_bf = os.path.join(body_database_path, "images_bf")
    os.makedirs(body_images_bf, exist_ok=True)

if "pharynx" in db_organs:
    pharynx_database_path = os.path.join(database_path, "pharynx")
    os.makedirs(pharynx_database_path, exist_ok=True)
    pharynx_images = os.path.join(pharynx_database_path, "images")
    os.makedirs(pharynx_images, exist_ok=True)
    pharynx_masks = os.path.join(pharynx_database_path, "masks")
    os.makedirs(pharynx_masks, exist_ok=True)

if "germline" in db_organs:
    germline_database_path = os.path.join(database_path, "germline")
    os.makedirs(germline_database_path, exist_ok=True)
    germline_images = os.path.join(germline_database_path, "images")
    os.makedirs(germline_images, exist_ok=True)
    germline_masks = os.path.join(germline_database_path, "masks")
    os.makedirs(germline_masks, exist_ok=True)

In [31]:
# Database configurations
# database_configs = {
#     'body': {
#         'size': 3000,
#         'stage_proportions': {'egg': 0.1, 'adult': 0.1, 'L1': 0.2, 'L2': 0.2, 'L3': 0.2, 'L4': 0.2},
#         'scope_proportions': {'crest': 0.2, 'squid': 0.2, 'ti2': 0.6},
#         'strains': body_strains,
#         'output_path': body_database_path
#     },
#     'pharynx': {
#         'size': 3000,
#         'stage_proportions': {'adult': 0.1, 'L1': 0.225, 'L2': 0.225, 'L3': 0.225, 'L4': 0.225},
#         'scope_proportions': {'ti2': 1.0},
#         'strains': pharynx_strains,
#         'output_path': pharynx_database_path
#     },
#     'germline': {
#         'size': 3000,
#         'stage_proportions': {'adult': 0.2, 'L1': 0.2, 'L2': 0.2, 'L3': 0.2, 'L4': 0.2},
#         'scope_proportions': {'crest': 0.6, 'squid': 0.4},
#         'strains': germline_strains,
#         'output_path': germline_database_path
#     }
# }

database_configs = {
    'germline': {
        'size': 1000,
        'stage_proportions': {'adult': 0.2, 'L1': 0.2, 'L2': 0.2, 'L3': 0.2, 'L4': 0.2},
        'scope_proportions': {'squid': 1.0},
        'strains': germline_strains,
        'output_path': germline_database_path
    }
}

In [32]:
all_strains = body_strains + pharynx_strains + germline_strains
# remove repeated strains
all_strains = list(set(all_strains))
strains_to_exclude = []

# to be valid, the strain name needs to be followed by either a dash or an underscore
correct_body_strains = []
correct_pharynx_strains = []
correct_germline_strains = []
for strain in body_strains:
    correct_body_strains.append(strain + "-")
    correct_body_strains.append(strain + "_")
    correct_body_strains.append(strain + " ")

for strain in pharynx_strains:
    correct_pharynx_strains.append(strain + "-")
    correct_pharynx_strains.append(strain + "_")
    correct_pharynx_strains.append(strain + " ")

for strain in germline_strains:
    correct_germline_strains.append(strain + "-")
    correct_germline_strains.append(strain + "_")
    correct_germline_strains.append(strain + " ")

In [33]:
# 1. the experiment must have a filemap in the new python format somewhere in its directory
# 2. the experiment must be from 2022-TODAY
# 3. the experiment must have the name of the scope in its name
# 4. the experiment must contain "10x" in its name
# 5. the experiment must be from Klement, Ioana or Peter
# 6. the experiment must include in its name at least one of the relevant strain names from a list
# 7. the experiment must not have some keyword in its name (e.g. "exclude", "fail")

# when taking data for different organs, different images must be used (as to not bias the network)
# when taking in data for the body, also take the brightfield channel


def get_analysis_filemap(experiment_path):
    directories = [
        os.path.join(experiment_path, d)
        for d in os.listdir(experiment_path)
        if os.path.isdir(os.path.join(experiment_path, d))
    ]

    analysis_directories = [d for d in directories if "analysis" in d]
    report_directories = [os.path.join(d, "report") for d in analysis_directories]

    report_directories = [d for d in report_directories if os.path.isdir(d)]

    for report_dir in report_directories:
        files = [os.path.join(report_dir, f) for f in os.listdir(report_dir)]
        filemap_files = [f for f in files if "analysis_filemap_annotated" in f]
        mat_files = [f for f in files if ".mat" in f]

        # return the filemap that was created last
        if filemap_files:
            filemap_files.sort(key=lambda x: os.path.getctime(x))
            return os.path.join(report_dir, filemap_files[-1])
        # check for converted experiments
        elif mat_files:
            filemap_files = [f for f in files if "analysis_filemap" in f]
            if filemap_files:
                filemap_files.sort(key=lambda x: os.path.getctime(x))
                filemap = pd.read_csv(filemap_files[-1], low_memory=False)
                if "HatchTime" in filemap.columns and "raw" in filemap.columns:
                    return os.path.join(report_dir, filemap_files[-1])
    return None

# create a lot of possible variation of the scope names
valid_scopes_variations = []
for scope in valid_scopes:
    valid_scopes_variations.append(scope)
    valid_scopes_variations.append(scope.upper())
    valid_scopes_variations.append(scope.capitalize())

filemaps = []

valid_experimentalists_dir = [
    os.path.join(storage_cluster_path, exp) for exp in valid_experimentalists
]

for exp_dir in valid_experimentalists_dir:
    experiment_directories = [
        os.path.join(exp_dir, d)
        for d in os.listdir(exp_dir)
        if os.path.isdir(os.path.join(exp_dir, d))
    ]

    for exp in experiment_directories:
        experiment_name = os.path.basename(os.path.normpath(exp))

        # check if the experiment is in the list of experiments to always include
        if experiment_name in experient_to_always_include:
            filemap = get_analysis_filemap(exp)
            if filemap:
                filemaps.append(filemap)
                continue
        # check if the experiment is in the list of experiments to exclude
        if experiment_name in experiments_to_exclude:
            continue
        try:
            year = int(experiment_name[:4])
            if year < 2023:
                continue
        except ValueError:
            continue

        if "10x" not in experiment_name and "10X" not in experiment_name:
            continue

        if not any(scope in experiment_name for scope in valid_scopes_variations):
            continue

        if any(keyword in experiment_name for keyword in keywords_to_exclude):
            continue

        if not any(strain in experiment_name for strain in all_strains):
            continue

        if any(strain in experiment_name for strain in strains_to_exclude):
            continue

        filemap = get_analysis_filemap(exp)

        if filemap:
            filemaps.append(filemap)

print(f"Found {len(filemaps)} valid experiments")

print(filemaps)

Found 1 valid experiments
['/mnt/towbin.data/shared/plenart/20252305_squid_10x_wBT530_IAA_20_degrees/analysis_Peter/report/analysis_filemap_annotated.csv']


In [34]:
# get the images from the filemaps
def pick_within_larval_stage(images, times, ls_beg, ls_end, n_picks=1):
    try:
        if np.isnan(ls_beg) or np.isnan(ls_end):
            return None, None
        
        valid_stacks = [(s, t) for s, t in zip(images, times) if ls_beg <= t <= ls_end]
        
        if valid_stacks:
            # Handle case where npicks is larger than available stacks
            npicks = min(n_picks, len(valid_stacks))
            
            # Randomly select npicks number of stacks without replacement
            selected_indices = np.random.choice(len(valid_stacks), size=n_picks, replace=False)
            selected_stacks = [valid_stacks[i] for i in selected_indices]
            
            # Unzip the selected stacks into separate lists
            selected_images, selected_times = zip(*selected_stacks)
            
            # If npicks is 1, return single values to maintain backward compatibility
            if npicks == 1:
                return selected_images[0], selected_times[0]
            
            return list(selected_images), list(selected_times)
        else:
            if n_picks == 1:
                return None, None
            return [], []
    except:
        print(f'Error in picking image within larval stage: {ls_beg}, {ls_end}')
        if npicks == 1:
            return None, None
        return [], []
        
database_filemap = pd.DataFrame()

for filemap in filemaps:
    print(filemap)
    experiment_name = filemap.split("/")[-4]
    filemap_df = pd.read_csv(filemap)

    strain = [strain for strain in all_strains if strain in experiment_name][0]

    # volume_column = [col for col in filemap_df.columns if "volume" in col][0]
    # print(volume_column)

    worm_type_column = [col for col in filemap_df.columns if "worm_type" in col][0]
    # print(worm_type_column)
    
    # get the name of the microscope by matching the filemap path with the valid scopes
    microscope = [scope for scope in valid_scopes_variations if scope in filemap][0]

    strain

    if microscope in ['crest', 'Crest', 'CREST', 'squid']:
        n_picks = 10
    else:
        n_picks = 2

    for point in filemap_df['Point'].unique():
        rows = []
        point_df = filemap_df[filemap_df['Point'] == point]
        hatch_time, m1, m2, m3, m4 = point_df['HatchTime'].values[0], point_df['M1'].values[0], point_df['M2'].values[0], point_df['M3'].values[0], point_df['M4'].values[0]
        raw_images = point_df['raw'].values
        time = point_df['Time'].values
        worm_type = point_df[worm_type_column].values
        
        # ignore this point if there are more than 50% non worm/egg images
        if len([wt for wt in worm_type if wt in ['worm', 'egg']]) < 0.5 * len(worm_type):
            continue
            
        try:
            stacks = [os.path.join(os.path.dirname(filemap), img) for img in raw_images]
            stacks_time = time
        except:
            continue

        # Handle egg stage
        egg_images, egg_image_times = pick_within_larval_stage(stacks, stacks_time, 0, hatch_time, n_picks=n_picks)
        if egg_images:  # Will be a list when n_picks > 1
            if not isinstance(egg_images, list):  # Handle single pick case
                egg_images = [egg_images]
            for egg_image in egg_images:
                egg_image = egg_image.replace("external.data/TowbinLab", "towbin.data/shared")
                row = {'Point': point, 'Image': egg_image, 'Stage': 0, 'Microscope': microscope, 'Experiment': experiment_name, 'Strain': strain}
                rows.append(row)

        # Handle larval stages
        for i, (ls_beg, ls_end) in enumerate([(hatch_time, m1), (m1, m2), (m2, m3), (m3, m4)]):
            stage_images, stage_times = pick_within_larval_stage(stacks, stacks_time, ls_beg, ls_end, n_picks=n_picks)
            if stage_images:  # Will be a list when n_picks > 1
                if not isinstance(stage_images, list):  # Handle single pick case
                    stage_images = [stage_images]
                for stage_image in stage_images:
                    stage_image = stage_image.replace("external.data/TowbinLab", "towbin.data/shared")
                    row = {'Point': point, 'Image': stage_image, 'Stage': i+1, 'Microscope': microscope, 'Experiment': experiment_name, 'Strain': strain}
                    rows.append(row)

        # Handle adult stage
        try:
            adult_images, adult_image_times = pick_within_larval_stage(
                stacks, 
                stacks_time, 
                m4, 
                np.min([m4 + extra_adulthood_time, np.max(stacks_time)]),
                n_picks=n_picks
            )
        except:
            adult_images = None

        if adult_images:  # Will be a list when n_picks > 1
            if not isinstance(adult_images, list):  # Handle single pick case
                adult_images = [adult_images]
            for adult_image in adult_images:
                adult_image = adult_image.replace("external.data/TowbinLab", "towbin.data/shared")
                row = {'Point': point, 'Image': adult_image, 'Stage': 5, 'Microscope': microscope, 'Experiment': experiment_name, 'Strain': strain}
                rows.append(row)

        database_filemap = pd.concat([database_filemap, pd.DataFrame(rows)])

/mnt/towbin.data/shared/plenart/20252305_squid_10x_wBT530_IAA_20_degrees/analysis_Peter/report/analysis_filemap_annotated.csv
Error in picking image within larval stage: 0, 1.0
Error in picking image within larval stage: 0, 7.0
Error in picking image within larval stage: 0, 7.0
Error in picking image within larval stage: 0, 6.0
Error in picking image within larval stage: 0, 4.0
Error in picking image within larval stage: 0, 7.0
Error in picking image within larval stage: 0, 0.0
Error in picking image within larval stage: 0, 5.0
Error in picking image within larval stage: 0, 5.0
Error in picking image within larval stage: 0, 2.0
Error in picking image within larval stage: 0, 7.0
Error in picking image within larval stage: 0, 5.0
Error in picking image within larval stage: 0, 7.0
Error in picking image within larval stage: 0, 8.0
Error in picking image within larval stage: 0, 7.0
Error in picking image within larval stage: 0, 3.0
Error in picking image within larval stage: 0, 7.0
Error i

In [35]:
# reset index
database_filemap.reset_index(drop=True, inplace=True)
database_filemap.to_csv(os.path.join(database_path, "database_filemap.csv"), index=False)

In [38]:
# Define microscope name mappings
microscope_names = {'crest': ['Crest', 'crest', 'CREST'], 
                    'squid': ['Squid', 'squid', 'SQUID'], 
                    'ti2': ['Ti2', 'ti2', 'TI2', 'orca', 'Orca', 'ORCA']}

# Create mapping from variations to standard names
variation_to_unified_name = {}
for microscope, variations in microscope_names.items():
    for variation in variations:
        variation_to_unified_name[variation] = microscope

# Add standardized microscope names to database
database_filemap['CorrectMicroscopeName'] = database_filemap['Microscope'].apply(
    lambda x: variation_to_unified_name[x]
)

def calculate_image_combinations(
    database_size,
    scope_proportions,
    stage_proportions
):
    """Calculate number of images needed for each scope and stage combination."""
    stage_to_number = {'egg': 0, 'L1': 1, 'L2': 2, 'L3': 3, 'L4': 4, 'adult': 5}
    
    combinations = {}
    
    # Calculate for each scope and stage combination
    for scope, scope_prop in scope_proportions.items():
        combinations[scope] = {}
        for stage, stage_prop in stage_proportions.items():
            n_images = int(database_size * scope_prop * stage_prop)
            stage_number = stage_to_number[stage]
            combinations[scope][stage_number] = n_images
    
    return combinations

def create_database(
    database_name,
    database_size,
    stage_proportions,
    scope_proportions,
    strains,
    output_path,
    available_data
):
    """Create a database with specified proportions and parameters.
    
    Args:
        database_name: Name of the database (for logging)
        database_size: Total number of images to include
        stage_proportions: Dictionary mapping stages to their proportion
        scope_proportions: Dictionary mapping microscopes to their proportion
        strains: List of strains to include
        output_path: Path to save the resulting CSV
        available_data: DataFrame containing available images to sample from
    
    Returns:
        tuple: (created database, remaining available data)
    """
    print(f"Creating {database_name} database...")
    
    # Create empty database
    database = pd.DataFrame()
    
    # Calculate combinations
    combinations = calculate_image_combinations(
        database_size, 
        stage_proportions=stage_proportions, 
        scope_proportions=scope_proportions
    )
    
    # For each scope and stage, sample the required number of images
    for scope, stages in combinations.items():
        for stage, n_images in stages.items():
            try:
                stage_images = available_data[
                    (available_data['CorrectMicroscopeName'] == scope) & 
                    (available_data['Stage'] == stage) &
                    (available_data['Strain'].isin(strains))
                ].sample(n=n_images)
                
                database = pd.concat([database, stage_images])
            except:
                print(f"Warning: Could not sample {n_images} images for {scope}, stage {stage}")
                continue
    
    # Reset index and save
    database.reset_index(drop=True, inplace=True)
    database.to_csv(os.path.join(output_path, f"{database_name}_database_filemap.csv"), index=False)
    
    # Remove selected images from available data
    remaining_data = available_data[~available_data['Image'].isin(database['Image'])]
    
    return database, remaining_data

# Create all databases sequentially
available_data = database_filemap.copy()
databases = {}

for db_name, config in database_configs.items():
    databases[db_name], available_data = create_database(
        database_name=db_name,
        database_size=config['size'],
        stage_proportions=config['stage_proportions'],
        scope_proportions=config['scope_proportions'],
        strains=config['strains'],
        output_path=config['output_path'],
        available_data=available_data
    )
    
    print(f"Created {db_name} database with {len(databases[db_name])} images")

Creating germline database...
Created germline database with 1000 images


In [39]:
def process_row(row, output_dir, channel):
    image_path = row['Image']
    image = read_tiff_file(image_path)
    channel = channel[0] if isinstance(channel, list) else channel
    if channel >= image.shape[0]:
        channel = image.shape[0] - 1
    image = image[channel, :, :]  # select the channel
    image_name = row['OutputName']
    # save the image
    imwrite(os.path.join(output_dir, image_name), image, compression="zlib")

def get_masks(row, output_dir, channel):
    image_path = row['Image']
    mask_path = image_path.replace("raw", f"analysis/ch{channel[0]+1}_seg")
    if not os.path.exists(mask_path):
        print(f"Mask not found for {image_path}, skipping.")
        return
    mask = read_tiff_file(mask_path)
    mask_name = row['OutputName']
    imwrite(os.path.join(output_dir, mask_name), mask, compression="zlib")

# extract the images from the database
def extract_images(database, output_dir, channel):
    Parallel(n_jobs=32, prefer="threads")(delayed(process_row)(row, output_dir, channel) for _, row in database.iterrows())

def extract_masks(database, output_dir, channel):
    Parallel(n_jobs=32, prefer="threads")(delayed(get_masks)(row, output_dir, channel) for _, row in database.iterrows())

In [40]:
def get_output_name(row):
    return f'image_{row.name}_{row["CorrectMicroscopeName"]}_{row["Stage"]}.tif'

if "body" in db_organs:
    body_database = databases['body']
    body_database['OutputName'] = body_database.apply(get_output_name, axis=1)
    body_database.to_csv(os.path.join(body_database_path, "body_database_filemap.csv"), index=False)
    body_database = body_database.sample(frac=1).reset_index(drop=True)

    extract_images(body_database, body_images, [1])

if "germline" in db_organs:
    germline_database = databases['germline']
    germline_database['OutputName'] = germline_database.apply(get_output_name, axis=1)
    germline_database.to_csv(os.path.join(germline_database_path, "germline_database_filemap.csv"), index=False)
    germline_database = germline_database.sample(frac=1).reset_index(drop=True)
    
    extract_images(germline_database, germline_images, [0])

if "pharynx" in db_organs:
    pharynx_database = databases['pharynx']
    pharynx_database['OutputName'] = pharynx_database.apply(get_output_name, axis=1)
    pharynx_database.to_csv(os.path.join(pharynx_database_path, "pharynx_database_filemap.csv"), index=False)
    pharynx_database = pharynx_database.sample(frac=1).reset_index(drop=True)

    extract_images(pharynx_database, pharynx_images, [0])