# Preprocessing

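This notebook preprocesses the segmented DAPI images: it runs exclusion (quality-control) checks on every segmented mask, saves a 3D crop per accepted cell, and bundles the max-z-projected crops of each plate into tensors.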

In [58]:
import nd2
from skimage.measure import regionprops
import tifffile
import pandas as pd
import numpy as np
import glob
from skimage.util import view_as_windows
from multiprocessing import Pool
from tqdm.notebook import tqdm
import seaborn as sns
import torchvision.transforms as T
import torch


# Base folders: raw .nd2 acquisitions and the segmentation outputs
# (label volumes, qc tables, 3D crops) are read from / written below these.
raw_dir = '/ewsc/hschluet/pbmc5/data/pbmc5/raw/'
seg_dir = '/ewsc/hschluet/pbmc5/seg_voronoi_otsu/'


def get_normalized_3d_frame(plate, fname, i):
  """Load frame ``i`` of a raw ND2 acquisition and min-max scale it to [0, 1].

  Args:
    plate: plate number; selects the ``Plate{plate}`` sub-folder of ``raw_dir``.
    fname: name of the acquisition file without the ``.nd2`` extension.
    i: index of the frame to take along the first axis of the file.

  Returns:
    The selected frame, shifted so that its minimum is 0 and divided by the
    resulting maximum. A constant frame (maximum 0 after the shift) is
    returned as all zeros instead of NaNs.
  """
  # frame, z, c, x, y
  stacks = nd2.imread(raw_dir + f'Plate{plate}/{fname}.nd2')
  img = stacks[i]
  # min-max scaled to [0,1]
  img -= img.min()
  peak = img.max()
  # A constant (e.g. blank) frame has peak == 0; dividing by it would turn the
  # whole frame into NaN, so such frames are divided by 1 and stay all zero.
  return img / (peak if peak > 0 else 1)

## Exclusion criteria

In [53]:
# Flag columns of the table returned by check_seg_cells, in output order.
_SEG_CHECK_COLUMNS = ['tiny', 'dim', 'xy touching', 'z touching',
                      'xy border', 'upper z border', 'lower z border']


def _cell_labels(arr):
  """Return the sorted unique non-background (non-zero) labels found in ``arr``."""
  labels = np.unique(arr)
  return labels[labels != 0]


def _touching_labels(seg, axis):
  """Return the labels that have a voxel directly next to a different label along ``axis``."""
  first = np.moveaxis(seg, axis, 0)[:-1]
  second = np.moveaxis(seg, axis, 0)[1:]
  # neighbouring voxel pairs where both are foreground but carry different labels
  contact = (first != 0) & (second != 0) & (first != second)
  return np.unique(np.concatenate([first[contact], second[contact]]))


def check_seg_cells(seg, img, min_size=300, dim_thresh=0.1, min_bright_frac=0.1):
  """Compute the exclusion flags for every labelled mask of a 3D segmentation.

  Args:
    seg: 3D label volume indexed (z, x, y) with non-negative integer labels;
      background is 0. The input is not modified.
    img: intensity volume with the same shape as ``seg``.
    min_size: masks with fewer voxels than this are flagged ``tiny``.
    dim_thresh: voxels with ``img < dim_thresh`` do not count as bright.
    min_bright_frac: masks whose fraction of bright voxels is below this value
      are flagged ``dim``.

  Returns:
    A DataFrame indexed by label (index name ``id``) with one boolean column
    per check: ``tiny``, ``dim``, ``xy touching``, ``z touching``,
    ``xy border``, ``upper z border`` and ``lower z border``. Only labels that
    occur in ``seg`` get a row; a volume without labels gives an empty frame
    with the same columns. ``tiny`` and ``dim`` masks are removed before the
    touching and border checks, so those flags are always False for them and
    contact with them does not count as touching.
  """
  # voxel count per label id; index 0 is the background
  counts = np.bincount(seg.ravel())
  all_cells = np.flatnonzero(counts)
  all_cells = all_cells[all_cells != 0]  # drop background
  checks = pd.DataFrame(False, index=pd.Index(all_cells, name='id'),
                        columns=_SEG_CHECK_COLUMNS)
  if len(all_cells) == 0:
    return checks

  # check for tiny masks (not cells)
  sizes = counts[all_cells]
  tiny = all_cells[sizes < min_size]

  # dim masks (probably background effects rather than cells)
  # minlength keeps the bright counts aligned with `counts` even if the
  # highest labels have no bright voxel at all
  bright_counts = np.bincount(seg[~(img < dim_thresh)], minlength=len(counts))
  bright_frac = bright_counts[all_cells] / sizes
  low_intensity = all_cells[bright_frac < min_bright_frac]

  # small and dim masks are discarded for the next steps as they are not cells
  seg = seg.copy()
  seg[np.isin(seg, tiny)] = 0
  seg[np.isin(seg, low_intensity)] = 0

  flagged = {
      'tiny': tiny,
      'dim': low_intensity,
      'xy touching': np.union1d(_touching_labels(seg, 1), _touching_labels(seg, 2)),
      'z touching': _touching_labels(seg, 0),
      # labels on any of the four side faces of the volume
      'xy border': _cell_labels(np.concatenate([
          seg[:, 0, :].ravel(), seg[:, -1, :].ravel(),
          seg[:, :, 0].ravel(), seg[:, :, -1].ravel()])),
      'upper z border': _cell_labels(seg[-1, :, :]),
      'lower z border': _cell_labels(seg[0, :, :]),
  }
  for column, labels in flagged.items():
    checks[column] = np.isin(all_cells, labels)

  return checks


def run_checks(args):
    """Pool worker: compute and save the exclusion checks of one segmented frame.

    ``args`` is a ``(plate, fname_i)`` tuple where ``fname_i`` ends in
    ``_<i>``, the frame index inside the raw file. The checks are written to
    ``qc/<fname>_<i>_exclusion_checks.csv`` in the plate's segmentation folder.
    Always returns True.
    """
    plate, fname_i = args

    # split "<fname>_<i>" into the file name and the integer frame index
    i = int(fname_i.split('_')[-1])
    suffix_len = len(str(i)) + 1
    fname = fname_i[:-suffix_len]

    seg = tifffile.imread(seg_dir + f'Plate{plate}/segmentation/labeled_segmentation_3d/{fname}_{i}.tiff')
    img_3d = get_normalized_3d_frame(plate, fname, i)
    if img_3d.ndim == 4:
        img_3d = img_3d[:, 0]  # DAPI

    check_seg_cells(seg, img_3d).to_csv(seg_dir + f'Plate{plate}/qc/{fname}_{i}_exclusion_checks.csv')
    return True

In [54]:
# One (plate, "<fname>_<i>") task per labelled 3D segmentation file of plates 1-16.
plate_fname_is = []
for p in range(1, 17):
  plate_fname_is.extend(
      (p, fname.split('/')[-1][:-len('.tiff')])
      for fname in glob.glob(f'{seg_dir}Plate{p}/segmentation/labeled_segmentation_3d/*.tiff'))

# Run the exclusion checks in parallel, with at most 64 worker processes.
with Pool(processes=min(len(plate_fname_is), 64)) as pool:
  for ret in tqdm(pool.imap_unordered(run_checks, plate_fname_is), total=len(plate_fname_is)):
    assert ret

  0%|          | 0/18571 [00:00<?, ?it/s]

In [115]:
# add z overlap check as we will max-z project
def add_z_overlap_check(seg, checks):
  """Add a boolean ``z overlap`` column to ``checks``.

  A mask is flagged when it shares an (x, y) column with a different mask,
  i.e. the two would overlap in a projection along z. Masks flagged ``tiny``
  or ``dim`` in ``checks`` are ignored for this test and are never flagged.

  Args:
    seg: 3D label volume indexed (z, x, y); background is 0. Not modified.
    checks: DataFrame indexed by label with boolean ``tiny`` and ``dim``
      columns. It is modified in place.

  Returns:
    ``checks`` with the ``z overlap`` column added. The column is also added
    (all False) when ``seg`` contains no labels, so the result always has the
    same columns.
  """
  checks['z overlap'] = False
  if not seg.any():
    return checks

  # skip non-cells
  seg = seg.copy()
  seg[np.isin(seg, checks[checks['tiny']].index)] = 0
  seg[np.isin(seg, checks[checks['dim']].index)] = 0

  seg_max = seg.max(axis=0)  # max z project
  # smallest non-zero label per column; columns without any label get `initial`
  seg_min = seg.min(axis=0, where=seg != 0, initial=seg.max())  # min z project
  seg_min[seg_max == 0] = 0  # background
  # columns whose smallest and largest label differ contain at least two masks
  overlapping = np.unique(seg[:, seg_max != seg_min])
  # drop background by value: a column may be fully covered by masks, in
  # which case 0 is absent and the first unique value is a real label
  overlapping = overlapping[overlapping != 0]
  checks['z overlap'] = checks.index.isin(overlapping)

  return checks


def run_z_check(args):
    """Pool worker: add the ``z overlap`` flag to one saved exclusion-check table.

    ``args`` is a ``(plate, fname_i)`` tuple where ``fname_i`` ends in
    ``_<i>``. The matching ``*_exclusion_checks.csv`` is read, extended with
    the z-overlap check and written back to the same path. If the CSV does not
    exist, a message is printed and the frame is skipped. Always returns True.
    """
    plate, fname_i = args

    # split "<fname>_<i>" into the file name and the integer frame index
    i = int(fname_i.split('_')[-1])
    suffix_len = len(str(i)) + 1
    fname = fname_i[:-suffix_len]
    checks_path = seg_dir + f'Plate{plate}/qc/{fname}_{i}_exclusion_checks.csv'

    try:
        checks = pd.read_csv(checks_path, index_col=0)
    except FileNotFoundError as e:
        print('skipping empty / contaminated well', e)
        return True

    seg = tifffile.imread(seg_dir + f'Plate{plate}/segmentation/labeled_segmentation_3d/{fname}_{i}.tiff')
    add_z_overlap_check(seg, checks).to_csv(checks_path)
    return True

In [116]:
# One (plate, "<fname>_<i>") task per labelled 3D segmentation file of plates 1-16.
plate_fname_is = []
for p in range(1, 17):
  plate_fname_is.extend(
      (p, fname.split('/')[-1][:-len('.tiff')])
      for fname in glob.glob(f'{seg_dir}Plate{p}/segmentation/labeled_segmentation_3d/*.tiff'))

# Add the z-overlap flag to the saved checks in parallel (at most 64 workers).
with Pool(processes=min(len(plate_fname_is), 64)) as pool:
  for ret in tqdm(pool.imap_unordered(run_z_check, plate_fname_is), total=len(plate_fname_is)):
    assert ret

  0%|          | 0/18571 [00:00<?, ?it/s]

skipping empty / contaminated well [Errno 2] No such file or directory: '/ewsc/hschluet/pbmc5/seg_voronoi_otsu/Plate2/qc/WellE03_ChannelDAPI_Seq0048_1_exclusion_checks.csv'
skipping empty / contaminated well [Errno 2] No such file or directory: '/ewsc/hschluet/pbmc5/seg_voronoi_otsu/Plate2/qc/WellD11_ChannelDAPI_Seq0036_7_exclusion_checks.csv'
skipping empty / contaminated well [Errno 2] No such file or directory: '/ewsc/hschluet/pbmc5/seg_voronoi_otsu/Plate2/qc/WellE03_ChannelDAPI_Seq0048_2_exclusion_checks.csv'
skipping empty / contaminated well [Errno 2] No such file or directory: '/ewsc/hschluet/pbmc5/seg_voronoi_otsu/Plate2/qc/WellD11_ChannelDAPI_Seq0036_10_exclusion_checks.csv'
skipping empty / contaminated well [Errno 2] No such file or directory: '/ewsc/hschluet/pbmc5/seg_voronoi_otsu/Plate2/qc/WellE03_ChannelDAPI_Seq0048_8_exclusion_checks.csv'
skipping empty / contaminated well [Errno 2] No such file or directory: '/ewsc/hschluet/pbmc5/seg_voronoi_otsu/Plate2/qc/WellE03_Chann

## Save 3D crops

In [None]:
def save_3d_variable_size_crops(args):
  """Pool worker: save one 3D crop per accepted cell of a segmented frame.

  ``args`` is a ``(plate, fname_i)`` tuple where ``fname_i`` ends in ``_<i>``,
  the frame index inside the raw file. Masks flagged ``tiny``, ``dim`` or
  ``z overlap`` in the frame's exclusion-check CSV are skipped. Every other
  mask is cropped to the bounding box of its z-projection and written to
  ``3d_crops/<fname>_<i>_<label>.tiff`` as an array with axes
  (z, channel, row, col): channel 0 is the normalized image, channel 1 the
  binary mask of that cell. Always returns True.

  The bounding boxes are computed after the ``tiny`` / ``dim`` masks have been
  removed from the label volume. Those masks are not considered by the
  z-overlap check, so one of them lying above or below an accepted cell would
  otherwise hide part (or all) of that cell in the max projection and
  truncate (or drop) its crop.
  """
  plate, fname_i = args
  i = int(fname_i.split('_')[-1])
  fname = fname_i[:-(1 + len(str(i)))]

  img_3d = get_normalized_3d_frame(plate, fname, i)
  if len(img_3d.shape) == 4:
    img_3d = img_3d[:, 0]  # DAPI

  seg = tifffile.imread(seg_dir + f'Plate{plate}/segmentation/labeled_segmentation_3d/{fname}_{i}.tiff')
  qc = pd.read_csv(seg_dir + f'Plate{plate}/qc/{fname}_{i}_exclusion_checks.csv').set_index('id')

  # drop non-cells before projecting so they cannot occlude accepted cells
  non_cells = qc.index[qc[['tiny', 'dim']].astype(bool).any(axis=1).to_numpy()]
  seg_cells = seg.copy()
  seg_cells[np.isin(seg_cells, non_cells)] = 0

  seg_2d = seg_cells.max(axis=0)
  for cell in regionprops(seg_2d):
    if qc.loc[cell.label, 'z overlap']:
      continue

    minr, minc, maxr, maxc = cell.bbox
    # np.newaxis adds the channel axis so image and mask can be stacked
    crop = img_3d[:, np.newaxis, minr:maxr, minc:maxc]
    seg_crop = seg[:, np.newaxis, minr:maxr, minc:maxc] == cell.label
    crop = np.concatenate([crop, seg_crop], axis=1)
    tifffile.imwrite(
            seg_dir + f'Plate{plate}/3d_crops/' + fname + f'_{i}_{cell.label}.tiff',
            crop,
        )
  return True

In [57]:
plate_fname_is = []
for p in range(1, 17):
  # using the exclusion checks as these are only kept for usable images
  # there were some raw images and segmentations for empty and contaminated wells
  plate_fname_is.extend(
      (p, fname.split('/')[-1][:-len('_exclusion_checks.csv')])
      for fname in glob.glob(f'{seg_dir}Plate{p}/qc/*exclusion_checks.csv'))

# Save the per-cell 3D crops in parallel, with at most 64 worker processes.
with Pool(processes=min(len(plate_fname_is), 64)) as pool:
  for ret in tqdm(pool.imap_unordered(save_3d_variable_size_crops, plate_fname_is), total=len(plate_fname_is)):
    assert ret

  0%|          | 0/18412 [00:00<?, ?it/s]

## Bundle data

In [128]:
def bundle_data(plate, res=32):
    """Bundle all accepted cells of one plate into tensors plus an info table.

    For every exclusion-check CSV of the plate, cells with any of the flags
    ``dim``, ``tiny``, ``z overlap``, ``xy border``, ``z touching`` or
    ``xy touching`` are discarded. For each remaining cell the saved 3D crop
    is max-projected along z (image and mask channel) and center-cropped to
    ``res`` x ``res``.

    Written to the bundled-data folder, one set per plate:
      * ``plate_<plate>_info.csv``: plate, well, series, cell, patient, time
        and ``qc`` (True when the cell touches neither z border),
      * ``plate_<plate>_imgs.pt`` / ``plate_<plate>_masks.pt``: stacked
        projections,
      * ``plate_<plate>_labels.pt``: True where the patient id starts with 'H'.

    Returns True.
    """
    save_dir = '/ewsc/hschluet/pbmc5/bundled_data/'
    layout = pd.read_csv(f'/ewsc/hschluet/pbmc5/data/pbmc5/layout/plate_{plate}_layout.csv', index_col=0)
    crop_to_res = T.CenterCrop(res)

    discard_flags = ['dim', 'tiny', 'z overlap', 'xy border', 'z touching', 'xy touching']
    z_border_flags = ['upper z border', 'lower z border']

    records, img_list, mask_list, healthy = [], [], [], []

    qc_paths = glob.glob(f'{seg_dir}Plate{plate}/qc/*exclusion_checks.csv')
    stems = [path.split('/')[-1][:-len('_exclusion_checks.csv')] for path in qc_paths]
    for stem in tqdm(np.sort(stems)):
        qc = pd.read_csv(f'{seg_dir}Plate{plate}/qc/{stem}_exclusion_checks.csv', index_col=0)
        flagged = qc[discard_flags].any(axis=1)
        qc = qc[~flagged].copy()  # discard

        # metadata encoded in the file name and the plate layout
        well = stem.split('Well')[1][:3]
        sample = layout.loc[well, 'sample']
        sample_parts = sample.split('_')
        patient = sample_parts[0]
        if sample[0] == 'H':
            time = 0
        else:
            time = sample_parts[1]
        series = int(stem.split('_')[-1])

        for cell in qc.index:
            strict = not (qc.loc[cell, z_border_flags].any())
            volume = tifffile.imread(seg_dir + f'Plate{plate}/3d_crops/{stem}_{cell}.tiff',).astype(np.float32)
            # max z projection; channel 0 is the image, channel 1 the mask
            projected = volume.max(axis=0)

            img_list.append(crop_to_res(torch.tensor(projected[0])))
            mask_list.append(crop_to_res(torch.tensor(projected[1])))
            healthy.append(patient[0] == 'H')  # whether they're healthy
            records.append((plate, well, series, cell, patient, time, strict))

    info_table = pd.DataFrame.from_records(
        records, columns=['plate', 'well', 'series', 'cell', 'patient', 'time', 'qc'])
    img_stack = torch.stack(img_list, dim=0).float()
    mask_stack = torch.stack(mask_list, dim=0)
    label_tensor = torch.tensor(np.array(healthy))

    info_table.to_csv(f'{save_dir}plate_{plate}_info.csv', index=False)
    torch.save(img_stack, f'{save_dir}plate_{plate}_imgs.pt')
    torch.save(mask_stack, f'{save_dir}plate_{plate}_masks.pt')
    torch.save(label_tensor, f'{save_dir}plate_{plate}_labels.pt')
    return True

In [None]:
plates = list(range(1, 17))

# Bundle every plate in its own worker process.
with Pool(processes=len(plates)) as pool:
  for ret in tqdm(pool.imap_unordered(bundle_data, plates), total=len(plates)):
    assert ret

  0%|          | 0/16 [00:00<?, ?it/s]