In [None]:
from pathlib import Path
import numpy as np
import random
import pandas as pd

from aicsimageio import AICSImage
from aicsimageio.readers.ome_tiff_reader import OmeTiffReader
from aicsimageio.writers import OmeTiffWriter

import sys
# Make the directory two levels above the current working directory importable,
# so that `import src...` below resolves. Only added once, even on re-runs.
project_root = Path.cwd().parent.parent
src_path = str(project_root)
if not (src_path in sys.path):
    sys.path.append(src_path)

import src.d00_utils.utilities as utils

In [None]:
# Top-level folder that contains the per-condition '*cellch' image folders (searched recursively below).
overall_dirpath = Path(input('Top-level directory to search for *cellch folders:').strip())
assert overall_dirpath.is_dir(), f'Not a directory: {overall_dirpath}'

In [None]:
# Every directory, at any depth, whose name ends in 'cellch'. Files are excluded, and
# the list is sorted so that its order (and hence the downstream random selection
# of images) does not depend on filesystem iteration order.
dirpaths = sorted(path for path in overall_dirpath.rglob('*cellch') if path.is_dir())
print(len(dirpaths))
dirpaths


In [None]:
# CSV describing each well (read below; must include 'Drug tx', 'cellch', 'Replicate', 'Experiment', 'Well').
well_conditions_csvpath = Path(input('Path to well-conditions CSV:').strip())
assert well_conditions_csvpath.is_file(), f'CSV not found: {well_conditions_csvpath}'

In [None]:
# Load the well table and keep only wells whose 'Drug tx' is not the string 'none'.
conditions_df = (
    pd.read_csv(well_conditions_csvpath)
    .loc[lambda frame: frame['Drug tx'] != 'none']
)
conditions_df

In [None]:
grouping_vars = ['Drug tx', 'cellch', 'Replicate']

# Pick one random well per (drug, cellch, replicate) combination and build its
# '<Experiment>-<Well>' ID. astype(str) lets non-string (e.g. numeric) columns be joined.
# NOTE(review): groupby drops rows with NaN in any grouping column by default.
sample_df = (
    conditions_df
    .groupby(grouping_vars)
    .sample(n=1)
    .assign(wellID=lambda frame: frame['Experiment'].astype(str) + '-' + frame['Well'].astype(str))
)
print(len(sample_df))
sample_df

In [None]:
# dirpaths = []

# dirpath = None
# while dirpath != 'DONE':
#     dirpath = input('Dirpath (or type DONE if done):')
#     if dirpath == 'DONE':
#         break
#     else:
#         dirpaths.append(Path(dirpath))

In [None]:
# Sampling configuration:
#   n           - number of randomly chosen images to use per purpose from each well
#   img_purpose - one output stack is built per entry, in this order
n, img_purpose = 1, ['test', 'train']

In [None]:
# For each sampled well, find its images and randomly assign n of them to each
# purpose (disjoint: an image is never used for both test and train).
img_list_d = {purpose: [] for purpose in img_purpose}
num_to_select = n * len(img_purpose)

for _, row in sample_df.iterrows():
    wellID = row['wellID']

    # '-' becomes a wildcard so '<Experiment>-<Well>' matches filenames with any separator.
    # NOTE(review): '*A1*' also matches 'A10', 'A11', ... - tighten if the filename scheme allows.
    w_search = '*' + wellID.replace('-', '*') + '*.ome.tif'
    # Sorted so the random draw below is independent of filesystem order.
    imgpaths = sorted(path for dirpath in dirpaths for path in dirpath.glob(w_search))

    grouping_vars_str = ', '.join(f'{var}: {row[var]}' for var in grouping_vars)
    print(f'{wellID} [{grouping_vars_str}]: {len(imgpaths)} images found')

    # A well with exactly num_to_select images has enough; only skip when it has fewer.
    if len(imgpaths) < num_to_select:
        print(f'  skipped: need at least {num_to_select} images')
        continue

    selected_bywell = random.sample(imgpaths, num_to_select)
    # Consecutive, non-overlapping chunks of n images per purpose.
    for k, purpose in enumerate(img_purpose):
        img_list_d[purpose].extend(selected_bywell[k * n:(k + 1) * n])

for purpose, selected in img_list_d.items():
    print(f'{purpose}: {len(selected)} images selected')

# One row per selected image, tagged with the purpose it was assigned to.
stack_df = pd.DataFrame({
    'imgpath': [path for selected in img_list_d.values() for path in selected],
    'purpose': [purpose for purpose, selected in img_list_d.items() for _ in selected],
})

stack_df

In [None]:
prev_train_imgpath = input('Input path for previous training image (if using). Otherwise, type NONE.').strip()

# Blank input or any casing of "none" means no previous image. It is normalised to the
# exact string 'NONE' because the stacking cell tests for that sentinel.
if prev_train_imgpath == '' or prev_train_imgpath.upper() == 'NONE':
    prev_train_imgpath = 'NONE'
    print('No previous training image.')
else:
    assert Path(prev_train_imgpath).is_file(), f'Previous training image not found: {prev_train_imgpath}'

In [None]:
# NOTE(review): user-specific absolute path - consider moving it to a config cell before sharing.
savedir = Path('/Users/kwu2/Library/CloudStorage/GoogleDrive-kwu2@stanford.edu/My Drive/Lab/ImageJ/training_imgs')
assert savedir.is_dir(), f'Output folder not found: {savedir}'

stack_basename = input("Enter basename for stacked image (excluding suffix):").strip()

# Choose the smallest index i for which none of this run's outputs exist yet
# (one stacked image per purpose plus the CSV), so earlier runs are never overwritten.
i = 0
while True:
    savepaths = [savedir / f'{stack_basename}_{p}_{i}.ome.tif' for p in img_purpose]
    csv_savename = f'{stack_basename}_{i}.csv'
    if not any(path.is_file() for path in [*savepaths, savedir / csv_savename]):
        break
    i += 1

for path in savepaths:
    print(path.name)

In [None]:
def add_imgs_to_list(imgpaths, img_list, df, num_tps=3):
    """Load each image, keep a few timepoints, and append it to ``img_list``.

    Frames are chosen as follows:

    - Multi-timepoint image: up to ``num_tps`` frames are drawn at random from
      the candidates ``[0, <random interior frame>, last]``. When the image has
      exactly two frames, the candidates are ``[0, 1]``.
    - Single-timepoint image: appended whole.

    The chosen timepoints are written as a comma-separated string to
    ``df['timepoints']`` on the row(s) whose ``imgpath`` equals the path.

    Parameters
    ----------
    imgpaths : iterable of path-like
        Images to load with ``AICSImage``.
    img_list : list
        Arrays are appended to this list (mutated in place and returned).
    df : pandas.DataFrame
        Table with an ``imgpath`` column. A ``timepoints`` column is written
        (created if missing; mutated in place).
    num_tps : int, default 3
        Number of timepoints to keep from multi-timepoint images. Capped at the
        number of candidate frames.

    Returns
    -------
    tuple
        ``(img_list, df, physical_pixel_sizes)``. ``physical_pixel_sizes``
        comes from the first image in ``imgpaths``, or is ``None`` when
        ``imgpaths`` is empty.
    """
    physical_pixel_sizes = None
    if 'timepoints' not in df.columns:
        # Object column up front, so assigning strings below never upcasts a float NaN column.
        df['timepoints'] = None

    for i, imgpath in enumerate(imgpaths):
        img_file = AICSImage(imgpath)
        print(img_file.shape)
        # assumes time is axis 0 of AICSImage.data (TCZYX) - TODO confirm for the reader in use
        img = img_file.data

        if i == 0:
            physical_pixel_sizes = img_file.physical_pixel_sizes

        row_mask = df['imgpath'] == imgpath
        n_frames = img.shape[0]
        if n_frames > 1:
            if n_frames == 2:
                # No interior frame exists; randint(1, 0) would raise ValueError.
                candidates = [0, 1]
            else:
                candidates = [0, random.randint(1, n_frames - 2), n_frames - 1]
            tps = random.sample(candidates, min(num_tps, len(candidates)))
            print(f'timepoints: {tps}')
            df.loc[row_mask, 'timepoints'] = ', '.join(str(tp) for tp in tps)
            # Integer-list indexing keeps axis 0: shape (len(tps), *img.shape[1:]).
            img_list.append(img[tps])
        else:
            img_list.append(img)
            df.loc[row_mask, 'timepoints'] = str(0)

    return img_list, df, physical_pixel_sizes

def crop_imgs_to_match_size(img_list):
    """Crop every image to the smallest extent on axes 3 and 4 found in ``img_list``.

    The crop keeps the leading (top-left) corner: ``img[:, :, :, :y_min, :x_min]``.
    This lets the results be concatenated along axis 0. Assumes 5D arrays with
    Y on axis 3 and X on axis 4 (TCZYX) - TODO confirm.

    Returns a new list with one cropped array per input; an empty input gives ``[]``.
    """
    smallest_y = min((arr.shape[3] for arr in img_list), default=None)
    smallest_x = min((arr.shape[4] for arr in img_list), default=None)

    cropped = []
    for arr in img_list:
        cropped.append(arr[:, :, :, :smallest_y, :smallest_x])
    return cropped

In [None]:
# Build one stacked OME-TIFF per purpose, then save the table of selected images/timepoints.
for purpose, savepath in zip(img_purpose, savepaths):
    print(purpose)
    img_list = []
    imgpaths_subset = img_list_d[purpose]
    physical_pixel_sizes = None

    # Only the training stack gets the previous training image prepended.
    if purpose == 'train' and prev_train_imgpath != 'NONE':
        prev_img_file = AICSImage(Path(prev_train_imgpath), reader=OmeTiffReader)
        print(prev_img_file.shape)
        img_list.append(prev_img_file.data)
        # Fallback metadata if no new images were selected for this purpose.
        physical_pixel_sizes = prev_img_file.physical_pixel_sizes

    if imgpaths_subset:
        # Pixel sizes of the first newly selected image take precedence.
        img_list, stack_df, physical_pixel_sizes = add_imgs_to_list(
            imgpaths_subset, img_list, stack_df, num_tps=2)

    if not img_list:
        print(f'No images for {purpose}; {savepath.name} not written')
        continue

    img_list = crop_imgs_to_match_size(img_list)
    img_stacked = np.concatenate(img_list, axis=0)
    ome_metadata = utils.construct_ome_metadata(img_stacked, physical_pixel_sizes)

    OmeTiffWriter.save(img_stacked, savepath, ome_xml=ome_metadata)

stack_df.to_csv(savedir / csv_savename, index=False)