# NAPARI visualization of Segmentation and Tracking output

You can use this notebook to view, modify and save out training data for UNet models.

Labels:
+ 0 - background 
+ 1 - GFP/Phase 
+ 2 - RFP


Extra key bindings:
+ None yet

---

```
TODO:

- (arl): output masks for making training data more easily
- (arl): put locations and classification labels onto identified cells
- (arl): visualize tracks
- (arl): deal with datasets that do not have a segmentation

```

---

```
Authors:
- Alan R. Lowe (a.lowe@ucl.ac.uk)
```

---

## Set up the data path and options

In [1]:
import os

# Folder containing the img_channel*_position*_time*_z*.tif files and the
# segmented*.zip archive. Set the OCTOPUS_DATA_PATH environment variable to
# point at your own data instead of editing this absolute local path.
DATA_PATH = os.environ.get('OCTOPUS_DATA_PATH', '/Users/arl/Dropbox/Data/Giulia/Pos13')
# Centre-pad the segmentation up to the image size before display.
PAD_SEGMENTATION = True
# Show only the outlines of segmented objects instead of filled labels.
SHOW_OUTLINES = True

---

In [2]:
import os
import re
import enum
import json
import napari

import numpy as np

from skimage import io
from zipfile import ZipFile

import matplotlib.pyplot as plt

from scipy.ndimage import binary_erosion

In [3]:
@enum.unique
class Channels(enum.Enum):
    """ Imaging channels, keyed by the integer channel index that appears in
    the OctopusLite filenames (the `img_channel<index>_...` part parsed by
    `SimpleOctopusLiteLoader`). Values must be unique (`@enum.unique`). """
    BRIGHTFIELD = 0 
    GFP = 1
    RFP = 2
    IRFP = 3
    PHASE = 4
    # NOTE(review): 98/99 presumably mark non-microscope outputs
    # (training weights / masks) — confirm with the acquisition software.
    WEIGHTS = 98
    MASK = 99

In [4]:
class SimpleOctopusLiteLoader(object):
    """ SimpleOctopusLiteLoader

    A simple class to load OctopusLite data from a directory.
    Caches data once it is loaded to prevent excessive io to
    the data server.

    Can directly address fluorescence channels using the
    `Channels` enumerator:

        Channels.BRIGHTFIELD
        Channels.GFP
        Channels.RFP
        Channels.IRFP

    Usage:
        octopus = SimpleOctopusLiteLoader('/path/to/your/data')
        gfp = octopus[Channels.GFP]

    Args:
        path: folder containing files named like
            `img_channel<C>_position<P>_time<T>_z<Z>...tif`.
        shape: value reported by `shape` until a channel has been loaded
            (afterwards `shape` is the shape of the most recently loaded
            stack). Defaults to the previously hard-coded (0, 1352, 1688).
    """

    # group 1 is the channel index, group 3 the time index
    _FILENAME_PATTERN = re.compile(
        r"img_channel([0-9]+)_position([0-9]+)_time([0-9]+)_z([0-9]+)")

    def __init__(self, path, shape=(0, 1352, 1688)):
        self.path = path
        self._files = {}
        self._data = {}

        # parse the files
        self._parse_files()

        # placeholder until a channel is loaded; _load_channel overwrites it
        self._shape = tuple(shape)

    def __contains__(self, channel):
        return channel in self.channels

    @property
    def channels(self):
        """ Channels for which at least one file was found. """
        return list(self._files.keys())

    @property
    def shape(self):
        """ Shape of the most recently loaded stack (or the placeholder). """
        return self._shape

    def channel_name_from_index(self, channel_index):
        """ Map a channel index (int or numeric string) to a `Channels` member. """
        return Channels(int(channel_index))

    def __getitem__(self, channel_name):
        """ Return the image stack for `channel_name`, loading it on first use. """
        assert(channel_name in self.channels)

        if channel_name not in self._data:
            self._load_channel(channel_name)

        return self._data[channel_name]

    @classmethod
    def _parse_filename(cls, fn):
        """ Return (channel_index, time_index) as ints, or None when `fn`
        does not follow the expected naming scheme. """
        params = cls._FILENAME_PATTERN.match(fn)
        if params is None:
            return None
        return int(params.group(1)), int(params.group(3))

    def _parse_files(self):
        """ parse out the files from the folder, grouped by channel and
        sorted by their numeric time index """
        channels = {k: [] for k in Channels}

        for f in os.listdir(self.path):
            if not f.endswith('.tif'):
                continue
            parsed = self._parse_filename(f)
            if parsed is None:
                # skip stray .tif files that don't follow the naming scheme
                # instead of crashing on a failed regex match
                continue
            channel_index, time = parsed
            channels[self.channel_name_from_index(channel_index)].append((time, f))

        # sort on the integer time index: sorting the raw digit strings would
        # misorder indices that are not zero-padded to the same width.
        # Ties (same time) are broken by filename. Empty channels are dropped.
        self._files = {k: [f for _, f in sorted(v)]
                       for k, v in channels.items() if v}

    def _load_channel(self, channel_name):
        """ Read every file of `channel_name` into one stack (frames on axis 0)
        and cache it in `self._data`. """
        assert(channel_name in self.channels)

        filenames = self._files[channel_name]

        def load_image(fn):
            return io.imread(os.path.join(self.path, fn))

        # load the first image to discover the per-frame shape and dtype
        im = load_image(filenames[0])

        # preallocate the stack
        stack = np.zeros((len(filenames),) + im.shape, dtype=im.dtype)
        self._shape = stack.shape

        print('Loading: {} --> {} ({})...'.format(channel_name, stack.shape, stack.dtype))

        stack[0, ...] = im
        for i, fn in enumerate(filenames[1:], start=1):
            stack[i, ...] = load_image(fn)

        self._data[channel_name] = stack

    def clear_cache(self, channel_name):
        """ Drop the cached stack so the next access reloads it from disk. """
        print('Warning! You are clearing the cache for: {}'.format(channel_name))
        # remove the entry (rather than storing None) so that __getitem__
        # treats the channel as not loaded and reloads it
        self._data.pop(channel_name, None)

In [5]:
# def make_folder(foldername):
#     if os.path.exists(foldername):
#         return
#     os.mkdir(foldername)        

In [6]:
def outline_labels(seg, iterations=2):
    """ Keep only the edge pixels of the labelled foreground in `seg`.

    The combined foreground (all non-zero pixels) is binary-eroded
    `iterations` times; pixels that survive the erosion are zeroed and the
    remaining edge pixels keep their original label value. Because the
    foreground is eroded as a whole, an edge shared by two touching labels
    is not outlined.
    """
    mask = seg.astype(bool)
    return seg * np.logical_xor(binary_erosion(mask, iterations=iterations), mask)


def load_segmentation(path, show_outlines=None):
    """ Load the most recent segmentation stack found in `path`.

    Looks for `segmented*.zip` archives in `path`, picks the one with the
    newest `os.path.getctime`, reads every image inside it in frame order
    and stacks them along a new leading axis.

    Args:
        path: directory to search.
        show_outlines: if True, reduce each label image to its outline (see
            `outline_labels`); if None (default) the notebook-level
            `SHOW_OUTLINES` flag is used.

    Returns:
        np.ndarray with the segmentation images stacked along axis 0.

    Raises:
        FileNotFoundError: if `path` contains no `segmented*.zip` archive.
    """
    if show_outlines is None:
        show_outlines = SHOW_OUTLINES

    segmentations = [z for z in os.listdir(path)
                     if z.startswith('segmented') and z.endswith('.zip')]
    if not segmentations:
        raise FileNotFoundError(
            'No segmented*.zip archive found in: {}'.format(path))

    # newest archive wins. NOTE: getctime is the metadata-change time on Unix
    # and the creation time on Windows.
    latest = max(segmentations,
                 key=lambda z: os.path.getctime(os.path.join(path, z)))

    outline = outline_labels if show_outlines else (lambda s: s)

    with ZipFile(os.path.join(path, latest)) as segzip:
        # NOTE(review): member names presumably look like '<2 chars><frame>.tif';
        # the slice strips those to sort numerically by frame — confirm.
        seg_files = sorted(segzip.namelist(), key=lambda f: int(f[2:-4]))
        segmentation = []
        for s in seg_files:
            # close each member handle once the image has been decoded
            with segzip.open(s) as member:
                segmentation.append(outline(io.imread(member)))

    return np.stack(segmentation, axis=0)
        

In [7]:
# Index the raw image files; channel stacks are only read from disk on first access.
data = SimpleOctopusLiteLoader(DATA_PATH)
# Load the newest segmented*.zip archive from the same folder.
segmentation = load_segmentation(DATA_PATH)

In [8]:
# Centre-pad the segmentation so its (Y, X) size matches the image data.
# NOTE: no channel has been loaded yet at this point, so `data.shape` is the
# loader's placeholder shape — TODO confirm it matches the raw images.
if PAD_SEGMENTATION:
    dy = data.shape[1] - segmentation.shape[1]
    dx = data.shape[2] - segmentation.shape[2]
    if dy < 0 or dx < 0:
        raise ValueError('Segmentation {} is larger than the image data {}'.format(
            segmentation.shape, data.shape))
    # any odd leftover pixel goes on the trailing edge so the padded
    # segmentation matches the image size exactly
    seg = np.pad(segmentation,
                 ((0, 0), (dy // 2, dy - dy // 2), (dx // 2, dx - dx // 2)),
                 constant_values=0)
else:
    # without padding the labels are displayed as loaded
    seg = segmentation

In [9]:
def normalize_images(stack, lower_percentile=5, upper_percentile=99.5):
    """ Percentile-normalise each frame of `stack` into the range [0, 1].

    For every frame along axis 0, the `lower_percentile` value maps to 0 and
    the `upper_percentile` value maps to 1; values outside are clipped.

    Args:
        stack: array whose first axis indexes frames.
        lower_percentile: percentile used as the black point (default 5).
        upper_percentile: percentile used as the white point (default 99.5).

    Returns:
        float32 array with the same shape as `stack`.
    """
    normed = stack.astype(np.float32)
    for i in range(normed.shape[0]):
        frame = normed[i, ...]
        p_lo, p_hi = np.percentile(frame, [lower_percentile, upper_percentile])
        # scale by the percentile *range* (not by p_hi) so that p_hi maps to 1;
        # fall back to 1 for (near-)constant frames to avoid dividing by zero
        scale = p_hi - p_lo if p_hi > p_lo else 1.0
        normed[i, ...] = np.clip((frame - p_lo) / scale, 0., 1.)
    return normed

In [10]:
def bounding_boxes(seg):
    """ Find the connected objects in `seg` and their bounding boxes.

    Any non-zero pixel counts as foreground; objects are the connected
    components found by `scipy.ndimage.label`, numbered in scan order.

    Returns:
        class_label: array with the minimum `seg` value inside each object
            (its class, when an object contains a single class value).
        minxy: list with one tuple per object holding the lowest index along
            every axis of the object's bounding box.
        maxxy: list with one tuple per object holding the highest (inclusive)
            index along every axis of the object's bounding box.
    """
    # `ndimage` itself is not imported at the top of the notebook
    from scipy import ndimage

    seg = np.asarray(seg)
    lbl, nlbl = ndimage.label(seg)
    if nlbl == 0:
        return np.zeros(0, dtype=seg.dtype), [], []

    class_label = ndimage.extrema(seg, lbl, index=np.arange(1, nlbl + 1))[0]
    # extrema's positions are where the min/max *value* sits (identical for a
    # uniform label), not the box corners; find_objects gives the real boxes,
    # one tuple of slices per label in label order
    boxes = ndimage.find_objects(lbl)
    minxy = [tuple(s.start for s in box) for box in boxes]
    maxxy = [tuple(s.stop - 1 for s in box) for box in boxes]
    return class_label, minxy, maxxy

In [11]:
# Layers to show, in display order:
# (channel, layer name, extra add_image kwargs, blending mode or None)
CHANNEL_LAYERS = [
    (Channels.BRIGHTFIELD, 'Brightfield', {'colormap': 'gray'}, None),
    (Channels.PHASE, 'Phase', {'colormap': 'gray'}, None),
    (Channels.GFP, 'GFP', {'colormap': 'green', 'contrast_limits': (0., 1.)}, 'additive'),
    (Channels.RFP, 'RFP', {'colormap': 'magenta', 'contrast_limits': (0., 1.)}, 'additive'),
]

# start napari
# NOTE(review): napari.gui_qt() is deprecated in newer napari releases
# (napari.run() / `%gui qt` replace it) — confirm against the installed version.
with napari.gui_qt():
    viewer = napari.Viewer()

    # add every channel that is present in the data, normalised to [0, 1]
    for channel, layer_name, image_kwargs, blending in CHANNEL_LAYERS:
        if channel not in data:
            continue
        viewer.add_image(normalize_images(data[channel]), name=layer_name, **image_kwargs)
        if blending is not None:
            viewer.layers[layer_name].blending = blending

    viewer.add_labels(seg, name='labels', opacity=1.0)

Loading: Channels.BRIGHTFIELD --> (31, 1352, 1688) (uint8)...
Loading: Channels.GFP --> (31, 1352, 1688) (uint8)...
Loading: Channels.RFP --> (31, 1352, 1688) (uint8)...
