#### Importing from multiple sources

Intermediate goal: multiple inputs -> single train, test, and val files
(no augmentation yet)

In [None]:
import os
import h5py
import torch
import copy
import ipywidgets as ipyw
import scipy
import pandas as pd
import datetime
import time
import itertools
import qgrid
import shutil
import subprocess

from random import shuffle
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import precision_recall_curve
from scipy.ndimage.interpolation import map_coordinates
from scipy.interpolate import RectBivariateSpline
from scipy import interpolate, ndimage
import paulssonlab.deaton.trenchripper.trenchripper as tr

import skimage as sk
import pickle as pkl
import skimage.morphology
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from paulssonlab.deaton.trenchripper.trenchripper import (
    pandas_hdf5_handler,
    kymo_handle,
    writedir,
)
from paulssonlab.deaton.trenchripper.trenchripper import hdf5lock
from paulssonlab.deaton.trenchripper.trenchripper import object_f_scores

from matplotlib import pyplot as plt

In [None]:
# Launch a non-local dask cluster (40 workers x 2GB, 4h walltime) for the export jobs.
dask_params = dict(
    walltime="04:00:00",
    local=False,
    n_workers=40,
    memory="2GB",
    working_directory="/n/scratch2/de64/nntest7/dask",
)
dask_controller = tr.dask_controller(**dask_params)
dask_controller.startdask()

In [None]:
# Evaluate the controller's dask client so the notebook shows it as this cell's output.
dask_controller.daskclient

In [None]:
# Shut the cluster down through the controller; presumably releases the workers
# launched by startdask() -- confirm against trenchripper's dask_controller.
dask_controller.shutdown()

In [None]:
class data_augmentation:
    """Random augmentation of paired image / segmentation stacks.

    Every method works on 4-D arrays indexed ``[sample, channel, y, x]`` and only
    reads/writes channel 0.

    Modes:
        "a" -- the segmentation is handled as a binary cell mask (value 1).
        "m" -- a border class (value 2) is added up front and carried through
               every geometric transform alongside the cell mask.
        "w" -- accepted; this class treats it like "a" (only "m" is special-cased).
    """

    def __init__(self, mode="a", p_flip=0.5, max_rot=10, min_padding=20):
        """
        Args:
            mode: one of "a", "m", "w" (see class docstring).
            p_flip: per-sample probability of each of the two random flips.
            max_rot: maximum absolute rotation angle in degrees.
            min_padding: minimum padding (pixels per side) added before the
                random translation.

        Raises:
            ValueError: if ``mode`` is not a recognised code.
        """
        if mode not in ["a", "m", "w"]:
            raise ValueError("Not a valid augmentation mode")
        self.mode = mode
        self.p_flip = p_flip
        self.max_rot = max_rot
        self.min_padding = min_padding

    def random_crop(self, img_arr, seg_arr):
        """Cut a random window from each sample, rescale it and re-centre it.

        Per sample, a window spanning 30-100% of the height and 50-100% of the
        width is taken at a random offset. It is rescaled by independent per-axis
        factors drawn from N(1.0, 0.1), each clipped to [0.8, factor that fills
        the frame], and pasted into the centre of the frame. All pixels outside
        the pasted window are set to 0.

        Returns:
            ``(img, seg)`` stacks of shape ``(n, 1, y, x)``.
        """
        n_samples = img_arr.shape[0]
        frame_shape = np.array(img_arr.shape[2:4])[:, np.newaxis]
        false_arr = np.zeros(img_arr.shape[2:4], dtype=bool)

        # Fractional crop lengths: row 0 = y, row 1 = x, one column per sample.
        random_crop_len_y = np.random.uniform(low=0.3, high=1.0, size=(1, n_samples))
        random_crop_len_x = np.random.uniform(low=0.5, high=1.0, size=(1, n_samples))
        random_crop_len = np.concatenate([random_crop_len_y, random_crop_len_x], axis=0)

        # Start offsets are scaled by the leftover fraction so start + length <= 1.
        random_crop_remainder = 1.0 - random_crop_len
        random_crop_start = (
            np.random.uniform(low=0.0, high=1.0, size=(2, n_samples))
            * random_crop_remainder
        )
        low_crop = np.floor(random_crop_start * frame_shape).astype("int32")
        high_crop = np.floor(low_crop + random_crop_len * frame_shape).astype("int32")

        out_arr = []
        out_seg_arr = []
        center = (img_arr.shape[2] // 2, img_arr.shape[3] // 2)
        for t in range(n_samples):
            mask = copy.copy(false_arr)
            working_arr = copy.copy(img_arr[t, 0, :, :])
            working_seg_arr = copy.copy(seg_arr[t, 0, :, :])
            crop = (
                slice(low_crop[0, t], high_crop[0, t]),
                slice(low_crop[1, t], high_crop[1, t]),
            )

            dim_0_range = high_crop[0, t] - low_crop[0, t]
            dim_1_range = high_crop[1, t] - low_crop[1, t]
            # Largest scale at which the rescaled crop still fits in the frame.
            dim_0_maxscale = img_arr.shape[2] / dim_0_range
            dim_1_maxscale = img_arr.shape[3] / dim_1_range
            dim_0_scale = np.clip(
                np.random.normal(loc=1.0, scale=0.1), 0.8, dim_0_maxscale
            )
            dim_1_scale = np.clip(
                np.random.normal(loc=1.0, scale=0.1), 0.8, dim_1_maxscale
            )
            scale = (dim_0_scale, dim_1_scale)

            rescaled_img = sk.transform.rescale(
                working_arr[crop], scale, preserve_range=True
            ).astype(int)
            rescaled_seg = (
                sk.transform.rescale(working_seg_arr[crop] == 1, scale) > 0.5
            ).astype("int8")
            if self.mode == "m":
                rescaled_border = (
                    sk.transform.rescale(working_seg_arr[crop] == 2, scale) > 0.5
                )
                rescaled_seg[rescaled_border] = 2

            top_left = (
                center[0] - rescaled_img.shape[0] // 2,
                center[1] - rescaled_img.shape[1] // 2,
            )
            paste = (
                slice(top_left[0], top_left[0] + rescaled_img.shape[0]),
                slice(top_left[1], top_left[1] + rescaled_img.shape[1]),
            )
            working_arr[paste] = rescaled_img
            working_seg_arr[paste] = rescaled_seg

            # Blank everything outside the pasted window.
            mask[paste] = True
            working_arr[~mask] = 0
            working_seg_arr[~mask] = False

            out_arr.append(working_arr)
            out_seg_arr.append(working_seg_arr)
        out_arr = np.expand_dims(np.array(out_arr), 1)
        out_seg_arr = np.expand_dims(np.array(out_seg_arr), 1)
        return out_arr, out_seg_arr

    def random_x_flip(self, img_arr, seg_arr, p=0.5):
        """Mirror each sample with probability ``p``.

        Note: despite the name, this reverses axis 1 of the ``(n, y, x)`` slice,
        i.e. the y (row) axis; ``random_y_flip`` reverses x.
        """
        choices = np.random.choice(
            np.array([True, False]), size=img_arr.shape[0], p=np.array([p, 1.0 - p])
        )
        out_img_arr = copy.copy(img_arr)
        out_seg_arr = copy.copy(seg_arr)
        out_img_arr[choices, 0, :, :] = np.flip(img_arr[choices, 0, :, :], axis=1)
        out_seg_arr[choices, 0, :, :] = np.flip(seg_arr[choices, 0, :, :], axis=1)
        return out_img_arr, out_seg_arr

    def random_y_flip(self, img_arr, seg_arr, p=0.5):
        """Mirror each sample with probability ``p``.

        Note: despite the name, this reverses axis 2 of the ``(n, y, x)`` slice,
        i.e. the x (column) axis.
        """
        choices = np.random.choice(
            np.array([True, False]), size=img_arr.shape[0], p=np.array([p, 1.0 - p])
        )
        out_img_arr = copy.copy(img_arr)
        out_seg_arr = copy.copy(seg_arr)
        out_img_arr[choices, 0, :, :] = np.flip(img_arr[choices, 0, :, :], axis=2)
        out_seg_arr[choices, 0, :, :] = np.flip(seg_arr[choices, 0, :, :], axis=2)
        return out_img_arr, out_seg_arr

    def change_brightness(self, img_arr, num_control_points=3):
        """Apply a random monotone intensity remapping to each sample.

        The intensity range 0..65535 gets ``num_control_points`` evenly spaced
        interior knots. Each knot is moved to a sorted random value in that
        range, and a PCHIP interpolant through the knots is applied to the
        image. The end points 0 and 65535 stay fixed.
        """
        out_img_arr = copy.copy(img_arr)
        # Knots 0, 1/(n+1), ..., 1 scaled to 0..65535; identical for every sample.
        control_points = (
            np.add.accumulate(np.ones(num_control_points + 2)) - 1.0
        ) / (num_control_points + 1)
        orig_locations = (control_points * 65535).astype(int)
        for t in range(img_arr.shape[0]):
            random_points = np.random.uniform(
                low=0, high=65535, size=num_control_points
            ).astype(int)
            control_point_locations = orig_locations.copy()
            control_point_locations[1:-1] = np.sort(random_points)
            mapping = interpolate.PchipInterpolator(
                orig_locations, control_point_locations
            )
            out_img_arr[t, 0, :, :] = mapping(img_arr[t, 0, :, :])
        return out_img_arr

    def add_padding(self, img_arr, seg_arr, max_rot=20, min_padding=20):
        """Zero-pad y and x symmetrically to leave room for rotation/translation.

        The padding per axis is derived from the frame diagonal and ``max_rot``
        (degrees). It is rounded up to an even total and is at least
        ``min_padding`` pixels on each side.

        Returns:
            ``(padded_img, padded_seg)`` with the same dtypes as the inputs.
        """
        hyp_length = np.ceil(
            (img_arr.shape[2] ** 2 + img_arr.shape[3] ** 2) ** (1 / 2)
        ).astype(int)
        max_rads = ((90 - max_rot) / 360) * (2 * np.pi)
        min_rads = (90 / 360) * (2 * np.pi)
        max_y = np.maximum(
            np.ceil(hyp_length * np.sin(max_rads)),
            np.ceil(hyp_length * np.sin(min_rads)),
        ).astype(int)
        max_x = np.maximum(
            np.ceil(hyp_length * np.cos(max_rads)),
            np.ceil(hyp_length * np.cos(min_rads)),
        ).astype(int)
        delta_y = max_y - img_arr.shape[2]
        delta_x = max_x - img_arr.shape[3]
        # Keep totals even so both sides receive the same padding.
        if delta_x % 2 == 1:
            delta_x += 1
        if delta_y % 2 == 1:
            delta_y += 1
        delta_y = np.maximum(delta_y, 2 * min_padding)
        delta_x = np.maximum(delta_x, 2 * min_padding)
        pad_width = (
            (0, 0),
            (0, 0),
            (delta_y // 2, delta_y // 2),
            (delta_x // 2, delta_x // 2),
        )
        padded_img_arr = np.pad(img_arr, pad_width, "constant", constant_values=0)
        padded_seg_arr = np.pad(seg_arr, pad_width, "constant", constant_values=0)
        return padded_img_arr, padded_seg_arr

    @staticmethod
    def _random_offset(half_range):
        """Random integer shift in [-half_range, half_range), or 0 if there is no slack."""
        if half_range <= 0:
            # np.random.randint(0, 0) raises; with no padding there is nothing to shift.
            return 0
        return np.random.randint(-half_range, high=half_range)

    def translate(self, pad_img_arr, pad_seg_arr, img_arr, seg_arr):
        """Randomly shift the original frame inside its padded canvas.

        ``img_arr``/``seg_arr`` only provide the unpadded frame size. The frame
        is assumed to sit centred in ``pad_*_arr``, as produced by
        ``add_padding``. Each sample's frame is moved by an independent random
        offset that keeps it inside the canvas.

        Returns:
            ``(trans_img, trans_seg)`` with the padded shape.
        """
        trans_img_arr = copy.copy(pad_img_arr)
        trans_seg_arr = copy.copy(pad_seg_arr)
        height, width = img_arr.shape[2], img_arr.shape[3]
        off_y = (pad_img_arr.shape[2] - height) // 2
        off_x = (pad_img_arr.shape[3] - width) // 2
        src_y = slice(off_y, off_y + height)
        src_x = slice(off_x, off_x + width)
        for t in range(pad_img_arr.shape[0]):
            trans_y = self._random_offset(off_y)
            trans_x = self._random_offset(off_x)
            dst_y = slice(off_y + trans_y, off_y + height + trans_y)
            dst_x = slice(off_x + trans_x, off_x + width + trans_x)
            # Clear the original location, then paste the frame at its new offset.
            trans_img_arr[t, 0, src_y, src_x] = 0
            trans_seg_arr[t, 0, src_y, src_x] = 0
            trans_img_arr[t, 0, dst_y, dst_x] = pad_img_arr[t, 0, src_y, src_x]
            trans_seg_arr[t, 0, dst_y, dst_x] = pad_seg_arr[t, 0, src_y, src_x]
        return trans_img_arr, trans_seg_arr

    def rotate(self, img_arr, seg_arr, max_rot=20):
        """Rotate each sample by a uniform random angle in [-max_rot, max_rot) degrees.

        Segmentation classes are rotated as separate binary masks and
        re-thresholded at 0.5. The border class (2) is only handled in mode "m".
        """
        rot_img_arr = copy.copy(img_arr)
        rot_seg_arr = copy.copy(seg_arr)
        for t in range(img_arr.shape[0]):
            r = np.random.uniform(low=-max_rot, high=max_rot)
            rot_img_arr[t, 0, :, :] = sk.transform.rotate(
                img_arr[t, 0, :, :], r, preserve_range=True
            ).astype("int32")
            rot_seg = (sk.transform.rotate(seg_arr[t, 0, :, :] == 1, r) > 0.5).astype(
                "int8"
            )
            if self.mode == "m":
                rot_border = sk.transform.rotate(seg_arr[t, 0, :, :] == 2, r) > 0.5
                rot_seg[rot_border] = 2
            rot_seg_arr[t, 0, :, :] = rot_seg
        return rot_img_arr, rot_seg_arr

    def deform_img_arr(self, img_arr, seg_arr):
        """Apply a smooth random elastic deformation to each sample.

        A 4x4 grid of N(0, 1) displacements per axis is upsampled with a
        bicubic spline to a per-pixel displacement field. The image and the
        segmentation classes are resampled with linear interpolation.
        """
        def_img_arr = copy.copy(img_arr)
        def_seg_arr = copy.copy(seg_arr)
        frame_shape = img_arr.shape[2:4]
        # Per-sample invariants: evaluation points and the base pixel grid.
        # NOTE(review): spline knots span [0, 3] but are evaluated on [0, 4],
        # so the last quarter of each axis is extrapolated -- confirm intent.
        y_steps = np.linspace(0.0, 4.0, num=frame_shape[0])
        x_steps = np.linspace(0.0, 4.0, num=frame_shape[1])
        knots = np.arange(4)
        y, x = np.meshgrid(
            np.arange(frame_shape[0]), np.arange(frame_shape[1]), indexing="ij"
        )
        for t in range(img_arr.shape[0]):
            grid = np.random.normal(scale=1.0, size=(2, 4, 4))
            dx = RectBivariateSpline(knots, knots, grid[0]).ev(
                y_steps[:, np.newaxis], x_steps[np.newaxis, :]
            )
            dy = RectBivariateSpline(knots, knots, grid[1]).ev(
                y_steps[:, np.newaxis], x_steps[np.newaxis, :]
            )
            indices = np.reshape(y + dy, (-1, 1)), np.reshape(x + dx, (-1, 1))
            elastic_img = map_coordinates(
                img_arr[t, 0, :, :], indices, order=1
            ).reshape(frame_shape)
            def_img_arr[t, 0, :, :] = elastic_img

            elastic_cell = map_coordinates(
                seg_arr[t, 0, :, :] == 1, indices, order=1
            ).reshape(seg_arr.shape[2:4])
            elastic_cell = sk.morphology.binary_closing(elastic_cell)
            # Replace the labels rather than OR-ing into them: the image was fully
            # warped, so the undeformed mask must not survive alongside the new one.
            def_seg_arr[t, 0, :, :] = 0
            def_seg_arr[t, 0, elastic_cell] = 1
            if self.mode == "m":
                elastic_border = map_coordinates(
                    seg_arr[t, 0, :, :] == 2, indices, order=1
                ).reshape(seg_arr.shape[2:4])
                def_seg_arr[t, 0, elastic_border] = 2
        return def_img_arr, def_seg_arr

    def add_borders(self, seg_arr):
        """Mark the one-pixel ring just outside each cell region (value 1) as 2."""
        output_arr = copy.copy(seg_arr)
        for k in range(seg_arr.shape[0]):
            working_mask = seg_arr[k, 0] == 1
            expanded = sk.morphology.binary_dilation(working_mask)
            border = expanded ^ working_mask
            output_arr[k, 0, border] = 2
        return output_arr

    def repair_borders(self, seg_arr):
        """Drop existing border pixels (2) and recompute them from the cell mask."""
        output_arr = copy.copy(seg_arr)
        output_arr[output_arr == 2] = 0
        output_arr = self.add_borders(output_arr)
        return output_arr

    def get_augmented_data(self, img_arr, seg_arr):
        """Run the full augmentation pipeline on an ``(n, 1, y, x)`` image/seg pair.

        Order: (mode "m": add borders) -> random crop/rescale -> random flips ->
        brightness remap -> pad + random translate -> random rotate -> elastic
        deform -> recompute borders.

        Returns:
            ``(img, seg)`` as int32 / int8 arrays. The spatial size is the padded
            size produced by ``add_padding``, not the input size.
        """
        img_arr, seg_arr = (img_arr.astype("int32"), seg_arr.astype("int8"))
        if self.mode == "m":
            seg_arr = self.add_borders(seg_arr)
        img_arr, seg_arr = self.random_crop(img_arr, seg_arr)
        img_arr, seg_arr = self.random_x_flip(img_arr, seg_arr, p=self.p_flip)
        img_arr, seg_arr = self.random_y_flip(img_arr, seg_arr, p=self.p_flip)
        img_arr = self.change_brightness(img_arr)
        # Extra 5 degrees of padding headroom beyond the rotation actually applied.
        pad_img_arr, pad_seg_arr = self.add_padding(
            img_arr, seg_arr, max_rot=self.max_rot + 5, min_padding=self.min_padding
        )
        img_arr, seg_arr = self.translate(pad_img_arr, pad_seg_arr, img_arr, seg_arr)
        del pad_img_arr
        del pad_seg_arr
        img_arr, seg_arr = self.rotate(img_arr, seg_arr, max_rot=self.max_rot)
        img_arr, seg_arr = self.deform_img_arr(img_arr, seg_arr)
        # NOTE(review): this runs in every mode, so "a"/"w" output also gains
        # border pixels (value 2) -- confirm that is intended.
        seg_arr = self.repair_borders(seg_arr)
        return img_arr, seg_arr

In [None]:
# Augmenter in mode "m" (cell mask plus border class) for the single-sample visual check below.
test = data_augmentation(mode="m")

In [None]:
# Load one phase image and its segmentation (trench index k, timepoint t) for a visual check.
k = 22
t = 10
kymo_path = "/n/scratch2/de64/2019-05-31_validation_data/kymograph/kymograph_0.hdf5"
seg_path = "/n/scratch2/de64/2019-05-31_validation_data/fluorsegmentation/segmentation_0.hdf5"

with h5py.File(kymo_path, "r") as kymo_file:
    img = kymo_file["Phase"][k, t]

with h5py.File(seg_path, "r") as seg_file:
    seg = seg_file["data"][k, t]

In [None]:
# Side-by-side: raw image, raw segmentation, augmented image, augmented segmentation.
img_arr, seg_arr = test.get_augmented_data(
    img[np.newaxis, np.newaxis], seg[np.newaxis, np.newaxis]
)
panels = [img, seg, img_arr[0, 0], seg_arr[0, 0]]
for panel_idx, panel in enumerate(panels, start=1):
    ax = plt.subplot(1, 4, panel_idx)
    ax.imshow(panel)
plt.show()

In [None]:
class weightmap_generator:
    """Per-pixel loss weight maps derived from integer segmentation masks.

    Class weights are inverse pixel frequencies ``ttl / (count + 1)`` normalised
    to sum to 1. ``make_unet_weight_map`` additionally adds a U-Net style
    separation term ``w0 * exp(-(d1 + d2)**2 / (2 * wm_sigma**2))``. Here d1 and
    d2 are the distances to the outer border rings of the nearest and
    second-nearest cells.
    """

    def __init__(self, mode, w0=0.0, wm_sigma=0.0):
        """
        Args:
            mode: one of "a", "m", "w"; validated but not otherwise used here.
            w0: amplitude of the separation term.
            wm_sigma: width of the separation term. 0 is treated as the
                sigma -> 0 limit (``w0`` where d1 + d2 == 0, else 0).

        Raises:
            ValueError: if ``mode`` is not a recognised code.
        """
        if mode not in ["a", "m", "w"]:
            raise ValueError("Not a valid augmentation mode")
        self.mode = mode
        self.w0 = w0
        self.wm_sigma = wm_sigma

    @staticmethod
    def _class_weights(counts, ttl_count):
        """Return normalised inverse-frequency weights ``ttl / (count + 1)``."""
        class_weight = np.array([ttl_count / (count + 1) for count in counts])
        return class_weight / np.sum(class_weight)

    def make_one_class_weightmap(self, single_mask):
        """Weight map for a background(0)/cell(1) mask; other values get weight 0."""
        background_mask = single_mask == 0
        cell_mask = single_mask == 1
        class_weight = self._class_weights(
            [np.sum(background_mask), np.sum(cell_mask)], single_mask.size
        )

        weight_map = np.zeros(single_mask.shape, dtype=float)
        weight_map[background_mask] = class_weight[0]
        weight_map[cell_mask] = class_weight[1]
        return weight_map

    def make_two_class_weightmap(self, single_mask):
        """Weight map for a background(0)/cell(1)/border(2) mask; other values get 0."""
        background_mask = single_mask == 0
        cell_mask = single_mask == 1
        border_mask = single_mask == 2
        class_weight = self._class_weights(
            [np.sum(background_mask), np.sum(cell_mask), np.sum(border_mask)],
            single_mask.size,
        )

        weight_map = np.zeros(single_mask.shape, dtype=float)
        weight_map[background_mask] = class_weight[0]
        weight_map[cell_mask] = class_weight[1]
        weight_map[border_mask] = class_weight[2]
        return weight_map

    def make_unet_weight_map(self, single_mask):
        """U-Net style weight map for the cell class (value 1) of ``single_mask``.

        Returns:
            A float array of the mask's shape. Each pixel holds its class weight
            (background vs. cell), plus the separation term when the mask
            contains at least two connected cells.
        """
        binary_mask = single_mask == 1
        ttl_count = binary_mask.size
        cell_count = np.sum(binary_mask)
        class_weight = self._class_weights(
            [ttl_count - cell_count, cell_count], ttl_count
        )
        weight_map = np.where(binary_mask, class_weight[1], class_weight[0])

        # Full connectivity (all 3**ndim neighbours), matching the default of
        # skimage.measure.label.
        structure = np.ones((3,) * binary_mask.ndim, dtype=bool)
        labeled, num_labels = ndimage.label(binary_mask, structure=structure)
        if num_labels < 2:
            # The separation term needs a nearest and a second-nearest cell.
            return weight_map

        dist_maps = []
        for label_id in range(1, num_labels + 1):
            cell = labeled == label_id
            # One-pixel ring just outside this cell (cross-shaped dilation minus the cell).
            border = ndimage.binary_dilation(cell) ^ cell
            dist_maps.append(ndimage.distance_transform_edt(~border))
        dist_maps = np.sort(np.array(dist_maps), axis=0)
        separation = dist_maps[0] + dist_maps[1]

        if self.wm_sigma > 0:
            separation_term = self.w0 * np.exp(
                -(separation ** 2) / (2 * (self.wm_sigma ** 2))
            )
        else:
            # sigma -> 0 limit; avoids 0/0 = NaN with the default wm_sigma=0.
            separation_term = self.w0 * (separation == 0)
        return separation_term + weight_map

In [None]:
class UNet_Training_DataLoader:
    def __init__(
        self,
        nndatapath="",
        experimentname="",
        output_names=["train", "test", "val"],
        output_modes=["a", "m"],
        num_epochs=10,
        input_paths=[],
    ):
        self.nndatapath = nndatapath
        self.experimentname = experimentname
        self.output_names = output_names
        self.output_modes = output_modes
        self.num_epochs = num_epochs
        self.input_paths = input_paths

        self.metapath = self.nndatapath + "/metadata.hdf5"

    def get_metadata(self, headpath):
        """Summarise one dataset's kymograph metadata.

        Reads ``<headpath>/metadata.hdf5`` and returns the 6-tuple
        ``(channel_list, fov_list, t_len, trench_dict, ttl_trenches,
        kymograph_img_shape)``. ``trench_dict`` maps each fov to its number of
        distinct ``trenchid`` values. ``kymograph_img_shape`` is
        ``(ttl_len_y, trench_width_x)`` taken from the kymograph parameters.
        """
        handle = pandas_hdf5_handler(headpath + "/metadata.hdf5")
        global_df = handle.read_df("global", read_metadata=True)
        # Named kymo_df so the local does not shadow the imported kymo_handle.
        kymo_df = handle.read_df("kymograph", read_metadata=True)

        # Re-index by (fov, row, trench) so per-fov trench counts can use .loc.
        by_trench = (
            kymo_df.reset_index(inplace=False)
            .set_index(["fov", "row", "trench"], drop=True, append=False, inplace=False)
            .sort_index()
        )

        fov_list = kymo_df["fov"].unique().tolist()
        trench_dict = {}
        for fov in fov_list:
            trench_dict[fov] = len(by_trench.loc[fov]["trenchid"].unique())

        kymo_params = kymo_df.metadata["kymograph_params"]
        return (
            global_df.metadata["channels"],
            fov_list,
            len(kymo_df.index.get_level_values("timepoints").unique()),
            trench_dict,
            len(by_trench["trenchid"].unique()),
            (kymo_params["ttl_len_y"], kymo_params["trench_width_x"]),
        )

    def inter_get_selection(self):
        """Build the nested tab widget for choosing per-dataset import parameters.

        The outer tabs are the output splits (``self.output_names``). Each holds
        one inner tab per input dataset, titled with the path's last component.
        Each inner tab contains a feature-channel dropdown, a "maximum samples"
        box and a timepoint range slider. The widget is stored on ``self.tab``
        (read back by ``get_import_params``) and returned for display.
        """
        # Metadata depends only on the input path: read it once per dataset
        # instead of once per (output split, dataset) pair.
        selection_metadata = []
        for input_path in self.input_paths:
            metadata = self.get_metadata(input_path)
            channel_list, t_len = metadata[0], metadata[2]
            selection_metadata.append((channel_list, t_len))
        dset_titles = [input_path.split("/")[-1] for input_path in self.input_paths]

        output_tabs = []
        for _ in self.output_names:
            # Fresh widgets per output split so each split's settings are independent.
            dset_tabs = []
            for channel_list, t_len in selection_metadata:
                feature_dropdown = ipyw.Dropdown(
                    options=channel_list,
                    value=channel_list[0],
                    description="Feature Channel:",
                    disabled=False,
                )
                max_samples = ipyw.IntText(
                    value=0, description="Maximum Samples per Dataset:", disabled=False
                )
                t_range = ipyw.IntRangeSlider(
                    value=[0, t_len - 1],
                    description="Timepoint Range:",
                    min=0,
                    max=t_len - 1,
                    step=1,
                    disabled=False,
                    continuous_update=False,
                )
                dset_tabs.append(
                    ipyw.VBox(children=[feature_dropdown, max_samples, t_range])
                )

            dset_ipy_tabs = ipyw.Tab(children=dset_tabs)
            for j, title in enumerate(dset_titles):
                dset_ipy_tabs.set_title(j, title)
            output_tabs.append(dset_ipy_tabs)

        output_ipy_tabs = ipyw.Tab(children=output_tabs)
        for i, output_name in enumerate(self.output_names):
            output_ipy_tabs.set_title(i, output_name)
        self.tab = output_ipy_tabs

        return self.tab

    def get_import_params(self):
        self.import_param_dict = {}
        for i, output_name in enumerate(self.output_names):
            self.import_param_dict[output_name] = {}
            for j, input_path in enumerate(self.input_paths):
                working_vbox = self.tab.children[i].children[j]
                self.import_param_dict[output_name][input_path] = {
                    child.description: child.value for child in working_vbox.children
                }

        print("======== Import Params ========")
        for i, output_name in enumerate(self.output_names):
            print(str(output_name))
            for j, input_path in enumerate(self.input_paths):
                (
                    channel_list,
                    fov_list,
                    t_len,
                    trench_dict,
                    ttl_trenches,
                    kymograph_img_shape,
                ) = self.get_metadata(input_path)
                ttl_possible_samples = t_len * ttl_trenches
                param_dict = self.import_param_dict[output_name][input_path]
                requested_samples = param_dict["Maximum Samples per Dataset:"]
                if requested_samples > 0:
                    print(str(input_path))
                    for key, val in param_dict.items():
                        print(key + " " + str(val))
                    print(
                        "Requested Samples / Total Samples: "
                        + str(requested_samples)
                        + "/"
                        + str(ttl_possible_samples)
                    )

        del self.tab

    def export_chunk(self, output_name, init_idx, chunk_size, chunk_idx):
        """Write augmented copies of one chunk of an output split to its own HDF5 file.

        Takes rows ``init_idx : init_idx + chunk_size`` (by position) of the
        ``output_name`` table in ``self.metapath`` and loads them from each
        source dataset's kymograph / fluorsegmentation files. Each dataset is
        then augmented ``self.num_epochs`` times per entry in
        ``self.output_modes``. The results go to
        ``<nndatapath>/<output_name>_<chunk_idx>.hdf5`` under
        ``<dataset basename>/<mode>/epoch_<k>/{img,seg}``.

        Requires ``self.import_param_dict`` (filled by ``get_import_params``).

        Returns:
            ``init_idx``, unchanged.
        """
        output_meta_handle = pandas_hdf5_handler(self.metapath)
        output_df = output_meta_handle.read_df(output_name)
        # Positional row slice (what ``output_df[a:b]`` does for integer slices).
        working_df = output_df.iloc[init_idx : init_idx + chunk_size]
        nndatapath = (
            self.nndatapath + "/" + output_name + "_" + str(chunk_idx) + ".hdf5"
        )

        dset_paths = working_df.index.get_level_values(0).unique().tolist()
        # Create/truncate the chunk file ONCE and write every source dataset into
        # its own group. Opening it with "w" inside the per-dataset loop would wipe
        # all previously written datasets, keeping only the last one.
        with h5py.File(nndatapath, "w") as outfile:
            for dset_path in dset_paths:
                dset_path_key = dset_path.split("/")[-1]
                dset_df = working_df.loc[dset_path]
                param_dict = self.import_param_dict[output_name][dset_path]
                feature_channel = param_dict["Feature Channel:"]

                img_arr_list = []
                seg_arr_list = []

                file_indices = dset_df.index.get_level_values(0).unique().tolist()
                for file_idx in file_indices:
                    file_df = dset_df.loc[file_idx]

                    img_path = (
                        dset_path + "/kymograph/kymograph_" + str(file_idx) + ".hdf5"
                    )
                    seg_path = (
                        dset_path
                        + "/fluorsegmentation/segmentation_"
                        + str(file_idx)
                        + ".hdf5"
                    )

                    with h5py.File(img_path, "r") as imgfile:
                        working_arr = imgfile[feature_channel][:]
                    for trench_idx, row in file_df.iterrows():
                        # (1, 1, y, x) image for this trench / timepoint.
                        img_arr = working_arr[trench_idx, row["timepoints"]][
                            np.newaxis, np.newaxis, :, :
                        ]
                        img_arr_list.append(img_arr.astype("int32"))

                    with h5py.File(seg_path, "r") as segfile:
                        working_arr = segfile["data"][:]
                    for trench_idx, row in file_df.iterrows():
                        seg_arr = working_arr[trench_idx, row["timepoints"]][
                            np.newaxis, np.newaxis, :, :
                        ]
                        seg_arr_list.append(seg_arr.astype("int8"))

                output_img_arr = np.concatenate(img_arr_list, axis=0)
                output_seg_arr = np.concatenate(seg_arr_list, axis=0)
                chunk_shape = (1, 1, output_img_arr.shape[2], output_img_arr.shape[3])

                for output_mode in self.output_modes:
                    augmenter = data_augmentation(mode=output_mode)
                    for epoch in range(self.num_epochs):
                        img_arr, seg_arr = augmenter.get_augmented_data(
                            output_img_arr, output_seg_arr
                        )
                        group_key = (
                            dset_path_key + "/" + output_mode + "/epoch_" + str(epoch)
                        )
                        outfile.create_dataset(
                            group_key + "/img",
                            data=img_arr,
                            chunks=chunk_shape,
                            dtype="int32",
                        )
                        outfile.create_dataset(
                            group_key + "/seg",
                            data=seg_arr,
                            chunks=chunk_shape,
                            dtype="int8",
                        )

        return init_idx

    #         if augment:
    #             img_arr,seg_arr = self.data_augmentation.get_augmented_data(img_arr,seg_arr)

    #         chunk_shape = (1,1,img_arr.shape[2],img_arr.shape[3])

    #         with h5py.File(nndatapath,"w") as outfile:
    #             img_handle = outfile.create_dataset("img",data=img_arr,chunks=chunk_shape,dtype='int32')
    #             seg_handle = outfile.create_dataset("seg",data=seg_arr,chunks=chunk_shape,dtype='int8')

    #         for item in weight_grid_list:
    #             w0,wm_sigma = item
    #             weightmap_gen = weightmap_generator(self.nndatapath,w0,wm_sigma)
    #             weightmap_arr = weightmap_gen.make_weightmaps(seg_arr)
    #             with h5py.File(nndatapath,"a") as outfile:
    #                 weightmap_handle = outfile.create_dataset("weight_" + str(item),data=weightmap_arr,chunks=chunk_shape,dtype='int32')

    #         return file_idx

    def gather_chunks(self, output_name, init_idx_list, chunk_idx_list, chunk_size):

        #                       outputdf,output_metadata,selectionname,file_idx_list,weight_grid_list):
        nnoutputpath = self.nndatapath + "/" + output_name + ".hdf5"
        output_meta_handle = pandas_hdf5_handler(self.metapath)
        output_df = output_meta_handle.read_df(output_name)

        dset_paths = output_df.index.get_level_values(0).unique().tolist()
        for dset_path in dset_paths:
            dset_path_key = dset_path.split("/")[-1]
            dset_df = output_df.loc[dset_path]

            tempdatapath = self.nndatapath + "/" + output_name + "_0.hdf5"
            with h5py.File(tempdatapath, "r") as infile:
                img_shape = infile[
                    dset_path_key + "/" + self.output_modes[0] + "/epoch_0/img"
                ].shape

            output_shape = (len(dset_df.index), 1, img_shape[2], img_shape[3])
            chunk_shape = (1, 1, img_shape[2], img_shape[3])

            with h5py.File(nnoutputpath, "w") as outfile:
                for output_mode in self.output_modes:
                    for epoch in range(self.num_epochs):
                        img_handle = outfile.create_dataset(
                            dset_path_key
                            + "/"
                            + output_mode
                            + "/epoch_"
                            + str(epoch)
                            + "/img",
                            output_shape,
                            chunks=chunk_shape,
                            dtype="int32",
                        )
                        seg_handle = outfile.create_dataset(
                            dset_path_key
                            + "/"
                            + output_mode
                            + "/epoch_"
                            + str(epoch)
                            + "/seg",
                            output_shape,
                            chunks=chunk_shape,
                            dtype="int8",
                        )

        #             for item in weight_grid_list:
        #                 weightmap_handle = outfile.create_dataset("weight_" + str(item),output_shape,chunks=chunk_shape,dtype='int32')
        current_dset_path = ""
        for i, init_idx in enumerate(init_idx_list):
            chunk_idx = chunk_idx_list[i]
            nndatapath = (
                self.nndatapath + "/" + output_name + "_" + str(chunk_idx) + ".hdf5"
            )
            working_df = output_df[init_idx : init_idx + chunk_size]
            dset_paths = working_df.index.get_level_values(0).unique().tolist()

            with h5py.File(nndatapath, "r") as infile:

                for dset_path in dset_paths:
                    if dset_path != current_dset_path:
                        current_idx = 0
                        current_dset_path = dset_path

                    dset_path_key = dset_path.split("/")[-1]
                    dset_df = output_df.loc[dset_path]

                    with h5py.File(nnoutputpath, "a") as outfile:

                        for output_mode in self.output_modes:
                            for epoch in range(self.num_epochs):
                                img_arr = infile[
                                    dset_path_key
                                    + "/"
                                    + output_mode
                                    + "/epoch_"
                                    + str(epoch)
                                    + "/img"
                                ][
                                    :
                                ]  # k,1,y,x
                                seg_arr = infile[
                                    dset_path_key
                                    + "/"
                                    + output_mode
                                    + "/epoch_"
                                    + str(epoch)
                                    + "/seg"
                                ][
                                    :
                                ]  # k,1,y,x
                                num_indices = img_arr.shape[0]

                                outfile[
                                    dset_path_key
                                    + "/"
                                    + output_mode
                                    + "/epoch_"
                                    + str(epoch)
                                    + "/img"
                                ][current_idx : current_idx + num_indices] = img_arr
                                outfile[
                                    dset_path_key
                                    + "/"
                                    + output_mode
                                    + "/epoch_"
                                    + str(epoch)
                                    + "/seg"
                                ][current_idx : current_idx + num_indices] = seg_arr
                    current_idx += num_indices

            os.remove(nndatapath)

    #                 img_arr = infile["img"][:]
    #                 seg_arr = infile["seg"][:]
    #                 weight_arr_list = []
    #                 for item in weight_grid_list:
    #                     weight_arr_list.append(infile["weight_" + str(item)][:])
    #             num_indices = img_arr.shape[0]
    #             with h5py.File(nnoutputpath,"a") as outfile:
    #                 outfile["img"][current_idx:current_idx+num_indices] = img_arr
    #                 outfile["seg"][current_idx:current_idx+num_indices] = seg_arr
    #                 for i,item in enumerate(weight_grid_list):
    #                     outfile["weight_" + str(item)][current_idx:current_idx+num_indices] = weight_arr_list[i]
    #             current_idx += num_indices
    #             os.remove(nndatapath)

    #     def export_data(self,dask_controller):
    def export_data(self, dask_controller, chunk_size=250):
        """Sample kymograph metadata for every output split and export it via dask.

        For each output name in ``self.import_param_dict`` and each input
        dataset path listed under it, this method:

        1. Reads the ``"kymograph"`` table from ``<input_path>/metadata.hdf5``.
        2. Keeps rows whose timepoint lies in the inclusive range given by
           ``"Timepoint Range:"``.
        3. Randomly draws up to ``"Maximum Samples per Dataset:"`` rows.

        The per-input samples are concatenated and written to ``self.metapath``
        under the output name. The combined table is then split into
        ``chunk_size``-row chunks. Each chunk is processed by
        ``self.export_chunk`` on the dask cluster, and the results are merged
        with ``self.gather_chunks``.

        Args:
            dask_controller: Object exposing ``daskclient`` (used for
                ``submit``/``gather``) and a ``futures`` dict. The dict is reset
                here and filled with one future per chunk.
            chunk_size: Number of sampled metadata rows handled per dask task.
        """
        dask_controller.futures = {}
        output_meta_handle = pandas_hdf5_handler(self.metapath)
        all_output_dfs = {}

        for output_name, input_param_dicts in self.import_param_dict.items():
            sampled_dfs = []

            for input_path, param_dict in input_param_dicts.items():
                input_meta_handle = pandas_hdf5_handler(input_path + "/metadata.hdf5")

                max_samples = param_dict["Maximum Samples per Dataset:"]
                # NOTE(review): assumes "Timepoint Range:" is an inclusive
                # (start, end) pair - confirm against get_import_params.
                t_range = param_dict["Timepoint Range:"]

                kymodf = input_meta_handle.read_df("kymograph", read_metadata=True)
                kymodf["filepath"] = input_path
                trenchdf = kymodf.reset_index(inplace=False)
                trenchdf = trenchdf.set_index(
                    ["filepath", "trenchid", "timepoints"],
                    drop=True,
                    append=False,
                    inplace=False,
                )
                # MultiIndex label slicing below requires a sorted index.
                trenchdf = trenchdf.sort_index()
                # .loc label slices include the stop label, so the inclusive
                # range needs no "+ 1" (that would add one extra timepoint).
                trenchdf = trenchdf.loc[
                    pd.IndexSlice[:, :, t_range[0] : t_range[1]], :
                ]

                # The setting is a maximum. DataFrame.sample without
                # replacement raises if n > len(df), so cap n at the rows
                # available.
                n_samples = min(max_samples, len(trenchdf))
                trenchdf_subset = trenchdf.sample(n=n_samples)

                # Re-index by file location so each chunk task reads
                # contiguous entries per file.
                filedf_subset = trenchdf_subset.reset_index(inplace=False)
                filedf_subset = filedf_subset.set_index(
                    ["filepath", "File Index", "File Trench Index"],
                    drop=True,
                    append=False,
                    inplace=False,
                )
                filedf_subset = filedf_subset.sort_index()
                sampled_dfs.append(filedf_subset)

            output_df = pd.concat(sampled_dfs)
            output_meta_handle.write_df(output_name, output_df)
            all_output_dfs[output_name] = output_df

        for output_name, output_df in all_output_dfs.items():
            # Split the sampled rows into fixed-size positional chunks, one
            # dask task per chunk. Future keys are reused for each output name
            # because each split is gathered before the next one is submitted.
            chunk_idx_list = []
            for chunk_idx, init_idx in enumerate(range(0, len(output_df), chunk_size)):
                future = dask_controller.daskclient.submit(
                    self.export_chunk,
                    output_name,
                    init_idx,
                    chunk_size,
                    chunk_idx,
                    retries=1,
                )
                dask_controller.futures["Chunk Number: " + str(chunk_idx)] = future
                chunk_idx_list.append(chunk_idx)

            # gather() preserves submission order. gather_chunks treats each
            # result as the starting row offset of its chunk.
            init_idx_list = dask_controller.daskclient.gather(
                [
                    dask_controller.futures["Chunk Number: " + str(chunk_idx)]
                    for chunk_idx in chunk_idx_list
                ]
            )
            self.gather_chunks(output_name, init_idx_list, chunk_idx_list, chunk_size)


#         outputdf = filedf.reset_index(inplace=False)
#         outputdf = outputdf.set_index(["trenchid","timepoints"], drop=True, append=False, inplace=False)
#         outputdf = outputdf.sort_index()

#         del outputdf["File Index"]
#         del outputdf["File Trench Index"]

#         selection_keys = ["channel", "fov_list", "t_subsample_step", "t_range", "max_trenches", "ttl_imgs", "kymograph_img_shape"]
#         selection = {selection_keys[i]:item for i,item in enumerate(selection)}
#         selection["experiment_name"],selection["data_name"] = (self.experimentname, dataname)
#         selection["W0 List"], selection["Wm Sigma List"] = (self.grid_dict['W0 (Border Region Weight):'],self.grid_dict['Wm Sigma (Border Region Spread):'])

#         output_metadata = {"nndataset" : selection}

#         segparampath = datapath + "/fluorescent_segmentation.par"
#         with open(segparampath, 'rb') as infile:
#             seg_param_dict = pkl.load(infile)

#         output_metadata["segmentation"] = seg_param_dict

#         input_meta_handle = pandas_hdf5_handler(datapath + "/metadata.hdf5")
#         for item in ["global","kymograph"]:
#             indf = input_meta_handle.read_df(item,read_metadata=True)
#             output_metadata[item] = indf.metadata

#         output_meta_handle.write_df(selectionname,outputdf,metadata=output_metadata)

#         file_idx_list = dask_controller.daskclient.gather([dask_controller.futures["File Number: " + str(file_idx)] for file_idx in filelist])
#         self.gather_chunks(outputdf,output_metadata,selectionname,file_idx_list,weight_grid_list)


#     def display_grid(self):
#         tab_dict = {'W0 (Border Region Weight):':[1., 3., 5., 10.],'Wm Sigma (Border Region Spread):':[1., 2., 3., 4., 5.]}
#         children = [ipyw.SelectMultiple(options=val,value=(val[1],),description=key,disabled=False) for key,val in tab_dict.items()]
#         self.tab = ipyw.Tab()
#         self.tab.children = children
#         for i,key in enumerate(tab_dict.keys()):
#             self.tab.set_title(i, key[:-1])
#         return self.tab

#     def get_grid_params(self):
#         if hasattr(self,'tab'):
#             self.grid_dict = {child.description:child.value for child in self.tab.children}
#             delattr(self, 'tab')
#         elif hasattr(self,'grid_dict'):
#             pass
#         else:
#             raise "No selection defined."
#         print("======== Grid Params ========")
#         for key,val in self.grid_dict.items():
#             print(key + " " + str(val))

#     def export_all_data(self,n_workers=20,memory='4GB'):
#         writedir(self.nndatapath,overwrite=True)

#         grid_keys = self.grid_dict.keys()
#         grid_combinations = list(itertools.product(*list(self.grid_dict.values())))

#         self.data_augmentation = data_augmentation()

#         dask_cont = dask_controller(walltime='01:00:00',local=False,n_workers=n_workers,memory=memory)
#         dask_cont.startdask()
# #         dask_cont.daskcluster.start_workers()
#         dask_cont.displaydashboard()

#         try:
#             for selectionname in ["train","test","val"]:
#                 if selectionname == "train":
#                     self.export_data(selectionname,dask_cont,grid_combinations,augment=True)
#                 else:
#                     self.export_data(selectionname,dask_cont,grid_combinations,augment=False)
#             dask_cont.shutdown()
#         except:
#             dask_cont.shutdown()
#             raise

In [None]:
# Build the training-data loader for this experiment.
dataloader = UNet_Training_DataLoader(
    # Identification and working directory for exported chunks/outputs.
    experimentname="First NN",
    nndatapath="/n/scratch2/de64/nntest7",
    # Source datasets to sample kymographs from.
    input_paths=["/n/scratch2/de64/2019-05-31_validation_data"],
    # Output modes and epochs: one img/seg dataset is written per
    # (mode, epoch) pair in the exported HDF5 files.
    num_epochs=10,
    output_modes=["a", "m"],
)

In [None]:
# Presumably shows an interactive selection of what to export (method not
# visible here) - confirm against the class definition.
dataloader.inter_get_selection()

In [None]:
# Presumably populates dataloader.import_param_dict, which export_data reads
# (per output name -> per input path -> parameter dict) - TODO confirm.
dataloader.get_import_params()

In [None]:
# Sample metadata for every output split, then process it on the dask cluster
# in 50-row chunks that gather_chunks merges into one file per split.
dataloader.export_data(dask_controller, chunk_size=50)

In [None]:
# Pull the first 1000 images/segmentations of mode "m" for epochs 0 and 1 for
# a quick visual check of the exported training data.
with h5py.File("/n/scratch2/de64/nntest7/train.hdf5", "r") as infile:
    data, segdata, data1, segdata1 = (
        infile[dataset_key][:1000]
        for dataset_key in (
            "2019-05-31_validation_data/m/epoch_0/img",
            "2019-05-31_validation_data/m/epoch_0/seg",
            "2019-05-31_validation_data/m/epoch_1/img",
            "2019-05-31_validation_data/m/epoch_1/seg",
        )
    )
#     img_arr = infile["2019-05-31_validation_data/img"][:]

In [None]:
import matplotlib

matplotlib.rcParams["figure.figsize"] = [20, 10]
idx = 110
# One figure each: epoch-0 image, epoch-0 segmentation, epoch-1 image,
# epoch-1 segmentation, all for sample `idx` (channel 0).
for preview_arr in (data, segdata, data1, segdata1):
    plt.imshow(preview_arr[idx, 0])
    plt.show()

In [None]:
# Display the import parameters export_data uses: output name -> input path ->
# {"Maximum Samples per Dataset:", "Timepoint Range:", ...}.
dataloader.import_param_dict