In [None]:
import copy
import os
import shutil

import ipywidgets as ipyw
import pandas as pd
import qgrid
from matplotlib import pyplot as plt

import paulssonlab.deaton.trenchripper.trenchripper as tr
from paulssonlab.deaton.trenchripper.trenchripper.utils import pandas_hdf5_handler

In [None]:
import h5py
import matplotlib
import numpy as np
import skimage as sk
from matplotlib import pyplot as plt

matplotlib.rcParams["figure.figsize"] = [20, 10]
from scipy.ndimage import convolve1d

In [None]:
def get_flows(labeled, eps=0.00001):
    """Compute Cellpose-style unit flow fields for every cell in a label image.

    For each positive label, a unit source is added at the cell's rounded
    centroid and smoothed with a separable 3-tap box filter for
    ``2 * (bbox height + bbox width)`` iterations, zeroing everything outside
    the cell after each step. The spatial gradient of the resulting map is
    normalized to unit length inside the cell and summed into the outputs.

    Args:
        labeled: 2-D integer label image; 0 is background and every positive
            value is one cell. Labels need not be consecutive.
        eps: added to the gradient norm to avoid division by zero.

    Returns:
        ``(y_grad_arr, x_grad_arr)``: float32 arrays with ``labeled.shape``
        holding the y and x flow components (0 outside cells).
    """
    kernel = np.ones(3, float) / 3.0

    y_grad_arr = np.zeros(labeled.shape, dtype=np.float32)
    x_grad_arr = np.zeros(labeled.shape, dtype=np.float32)

    for cell_label in np.unique(labeled):
        if cell_label <= 0:
            continue  # background
        # Select the cell by its own label, not by its rank among the labels,
        # so label images with gaps in their numbering work.
        cell_mask = labeled == cell_label
        coords = np.argwhere(cell_mask)
        centroid = coords.mean(axis=0)
        y_len, x_len = coords.max(axis=0) - coords.min(axis=0) + 1
        n_iter = 2 * (int(y_len) + int(x_len))

        # Source pixel: the rounded centroid. For curved / non-convex cells it
        # can fall outside the mask, where it would be zeroed every step, so
        # fall back to the cell pixel closest to the centroid.
        center = tuple(int(c) for c in np.round(centroid))
        if not cell_mask[center]:
            nearest = np.argmin(np.sum((coords - centroid) ** 2, axis=1))
            center = tuple(int(c) for c in coords[nearest])

        diffusion_arr = np.zeros(labeled.shape, dtype=np.float32)
        for _ in range(n_iter):
            # A tuple index hits exactly one pixel; indexing with a length-2
            # integer array would instead select two whole rows.
            diffusion_arr[center] += 1.0
            diffusion_arr = convolve1d(
                convolve1d(diffusion_arr, kernel, axis=0), kernel, axis=1
            )
            diffusion_arr[~cell_mask] = 0.0

        y_grad, x_grad = np.gradient(diffusion_arr)
        norm = np.sqrt(y_grad**2 + x_grad**2)
        y_grad, x_grad = (y_grad / (norm + eps)), (x_grad / (norm + eps))
        y_grad[~cell_mask] = 0.0
        x_grad[~cell_mask] = 0.0

        y_grad_arr += y_grad
        x_grad_arr += x_grad

    return y_grad_arr, x_grad_arr


def get_two_class(labeled):
    """Build a background/mask/border segmentation target and its weight map.

    Both come from the trenchripper helpers. If either output contains NaN,
    prints ``"two_class"`` and returns an all-zero ``uint8`` segmentation with
    an all-ones float32 weight map instead.
    """
    seg_modes = ["background", "mask", "border"]
    segmentation = tr.get_segmentation(labeled, mode_list=seg_modes)
    weightmap = tr.get_standard_weightmap(segmentation)
    has_nan = np.isnan(segmentation).any() or np.isnan(weightmap).any()
    if not has_nan:
        return segmentation, weightmap
    print("two_class")
    fallback_seg = np.zeros(labeled.shape, dtype="uint8")
    return fallback_seg, np.ones(fallback_seg.shape, dtype=np.float32)


def get_one_class(labeled, W0=5.0, Wsigma=2.0):
    """Build a boolean foreground mask (no border class) and a U-Net weight map.

    Args:
        labeled: labeled cell image passed to the trenchripper helpers.
        W0, Wsigma: forwarded unchanged to ``tr.get_unet_weightmap``.

    Returns:
        ``(segmentation, weightmap)``. If either contains NaN, prints
        ``"one_class"`` and returns an all-zero ``uint8`` segmentation and an
        all-ones float32 weight map instead.
    """
    mask = tr.get_segmentation(labeled, mode_list=["background", "masknoborder"])
    segmentation = mask.astype(bool)
    weightmap = tr.get_unet_weightmap(labeled, W0=W0, Wsigma=Wsigma)
    # A bool array can never hold NaN, so in practice only the weight map can
    # trigger the fallback below.
    if not (np.isnan(segmentation).any() or np.isnan(weightmap).any()):
        return segmentation, weightmap
    print("one_class")
    segmentation = np.zeros(labeled.shape, dtype="uint8")
    weightmap = np.ones(segmentation.shape, dtype=np.float32)
    return segmentation, weightmap


def get_cellpose(labeled):
    """Build a boolean cell mask plus y/x flow fields (from ``get_flows``).

    Returns:
        ``(segmentation, y_grad_arr, x_grad_arr)``. If any of them contains
        NaN, prints ``"cellpose"`` and returns an all-zero ``uint8``
        segmentation and two separate all-zero float32 flow arrays instead.
    """
    raw_seg = tr.get_segmentation(labeled, mode_list=["background", "mask"])
    segmentation = raw_seg.astype(bool)
    y_grad_arr, x_grad_arr = get_flows(labeled)
    outputs = (segmentation, y_grad_arr, x_grad_arr)
    if not any(np.isnan(arr).any() for arr in outputs):
        return outputs
    print("cellpose")
    shape = labeled.shape
    return (
        np.zeros(shape, dtype="uint8"),
        np.zeros(shape, dtype=np.float32),
        np.zeros(shape, dtype=np.float32),
    )

In [None]:
# Every experiment lives at <scratch>/<name>/<name>. Order matters: later
# cells select an experiment by its index in this list.
path_opts = [
    f"/n/scratch3/users/d/de64/{exp_name}/{exp_name}"
    for exp_name in (
        "190917_20x_phase_gfp_segmentation002",
        "190922_20x_phase_gfp_segmentation",
        "190925_20x_phase_yfp_segmentation",
        "ezrdm training sb7",
        "mbm training sb7",
        "Sb7_L35",
        "MM_DVCvecto_TOP_1_9",
        "Vibrio_2_1_TOP",
        "Vibrio_A_B_VZRDM--04--RUN_80ms",
        "RpoSOutliers_WT_hipQ_100X",
        "Main_Experiment",
        "bde17_gotime",
    )
]

In [None]:
# Diagnostic sweep: regenerate the training targets for every labeled image in
# segmentation files 0-2 of path_opts[2]. The target helpers print their mode
# name ("two_class" / "one_class") whenever an output contains NaN and they fall
# back to a blank target, so the printed output flags the problem images.
# Results are overwritten on every iteration; only the printed side effect matters.
path_idx = 2

for j in range(3):
    img_idx = j

    with h5py.File(
        path_opts[path_idx]
        + "/fluorsegmentation/segmentation_"
        + str(img_idx)
        + ".hdf5",
        "r",
    ) as segfile:
        # Keep the first 3 entries along axis 1, then merge the first two axes
        # so each seg_arr[i] is a single 2-D labeled image (assumes "data" is
        # 4-D — TODO confirm the axis meanings).
        seg_arr = segfile["data"][:, :3]
        seg_arr = seg_arr.reshape(-1, seg_arr.shape[2], seg_arr.shape[3])
    for i in range(seg_arr.shape[0]):
        segmentation, weightmap = get_two_class(seg_arr[i])
        segmentation, weightmap = get_one_class(seg_arr[i], W0=5.0, Wsigma=2.0)
#         segmentation,y_grad_arr,x_grad_arr = get_cellpose(seg_arr[i])

In [None]:
seg_arr.shape

In [None]:
# Re-run only the one-class targets on the last file loaded by the sweep.
for i in range(seg_arr.shape[0]):
    segmentation, weightmap = get_one_class(seg_arr[i], W0=5.0, Wsigma=2.0)

In [None]:
# Scratch cell: bare expressions kept for reference; only the final list is
# displayed when run. NOTE: the list entries use shell-escaped spaces ("\\ " in
# the literals is a backslash followed by a space), so those entries are not
# usable as Python filesystem paths without unescaping.
"/n/scratch3/users/d/de64/190922_20x_phase_gfp_segmentation/190922_20x_phase_gfp_segmentation/fluorsegmentation/segmentation_1.hdf5"

['/n/scratch3/users/d/de64/190917_20x_phase_gfp_segmentation002',
 '/n/scratch3/users/d/de64/190922_20x_phase_gfp_segmentation',
 '/n/scratch3/users/d/de64/190925_20x_phase_yfp_segmentation',
 '/n/scratch3/users/d/de64/ezrdm\\ training\\ sb7',
 '/n/scratch3/users/d/de64/mbm\\ training\\ sb7',
 '/n/scratch3/users/d/de64/Sb7_L35',
 '/n/scratch3/users/d/de64/MM_DVCvecto_TOP_1_9',
 '/n/scratch3/users/d/de64/Vibrio_2_1_TOP',
 '/n/scratch3/users/d/de64/Vibrio_A_B_VZRDM--04--RUN_80ms',
 '/n/scratch3/users/d/de64/RpoSOutliers_WT_hipQ_100X',
 '/n/scratch3/users/d/de64/Main_Experiment',
 '/n/scratch3/users/d/de64/bde17_gotime']

In [None]:
# data_loader = tr.UNet_Training_DataLoader(nndatapath="/n/scratch3/users/d/de64/2020-06-14_NN",experimentname="2020-06-14 Neural Net",\
#                            input_paths=["/n/scratch3/users/d/de64/190917_20x_phase_gfp_segmentation002/190917_20x_phase_gfp_segmentation002",\
#                                         "/n/scratch3/users/d/de64/190922_20x_phase_gfp_segmentation/190922_20x_phase_gfp_segmentation",\
#                                         "/n/scratch3/users/d/de64/190925_20x_phase_yfp_segmentation/190925_20x_phase_yfp_segmentation",\
#                                         "/n/scratch3/users/d/de64/ezrdm training sb7/ezrdm training sb7",\
#                                         "/n/scratch3/users/d/de64/mbm training sb7/mbm training sb7",\
#                                         "/n/scratch3/users/d/de64/Sb7_L35/Sb7_L35",\
#                                         "/n/scratch3/users/d/de64/MM_DVCvecto_TOP_1_9/MM_DVCvecto_TOP_1_9",\
#                                         "/n/scratch3/users/d/de64/Vibrio_2_1_TOP/Vibrio_2_1_TOP",\
#                                         "/n/scratch3/users/d/de64/Vibrio_A_B_VZRDM--04--RUN_80ms/Vibrio_A_B_VZRDM--04--RUN_80ms",\
#                                         "/n/scratch3/users/d/de64/RpoSOutliers_WT_hipQ_100X/RpoSOutliers_WT_hipQ_100X",\
#                                         "/n/scratch3/users/d/de64/Main_Experiment/Main_Experiment",\
#                                         "/n/scratch3/users/d/de64/bde17_gotime/bde17_gotime"],\
#                        )

In [None]:
data_loader = tr.UNet_Training_DataLoader(
    nndatapath="/n/scratch3/users/d/de64/2020-07-05_NN",
    experimentname="2020-07-05 Neural Net",
    # Each input experiment is stored as <scratch>/<name>/<name>.
    input_paths=[
        f"/n/scratch3/users/d/de64/{exp_name}/{exp_name}"
        for exp_name in (
            "ezrdm_training_sb7",
            "mbm_training_sb7",
            "Sb7_L35",
            "MM_DVCvecto_TOP_1_9",
            "Vibrio_2_1_TOP",
            "Vibrio_A_B_VZRDM--04--RUN_80ms",
            "RpoSOutliers_WT_hipQ_100X",
            "Main_Experiment",
            "bde17_gotime",
        )
    ],
)

In [None]:
# Interactive selection widget of the data loader (presumably used to choose
# which images are exported — TODO confirm).
data_loader.inter_get_selection()

E. coli Ti6
2) /n/files/SysBio/PAULSSON\ LAB/Personal\ Folders/!!Jacob\ Quinn\ Shenker/190917/190917_20x_phase_gfp_segmentation002    11GB (seg: ok)
3) /n/files/SysBio/PAULSSON\ LAB/Personal\ Folders/Noah/190922/20x_segmentation_data/190922_20x_phase_gfp_segmentation    358GB (seg: ok)
4) /n/files/SysBio/PAULSSON\ LAB/Personal\ Folders/Noah/190922/20x_segmentation_data/190925_20x_phase_yfp_segmentation    ?? (seg: ok)
5) /n/files/SysBio/PAULSSON\ LAB/Personal\ Folders/Carlos/Data_Ti6/SB7_trainingdata_Unet/ezrdm\ training\ sb7.nd2         85GB (seg: ok)
6) /n/files/SysBio/PAULSSON\ LAB/Personal\ Folders/Carlos/Data_Ti6/SB7_trainingdata_Unet/mbm\ training\ sb7.nd2           85GB (seg: ok)

E. coli Ti5
7) /n/files/SysBio/PAULSSON\ LAB/Personal\ Folders/Carlos/Data_Ti5/before\ bl2+/SB7_trainingdata_NN_MM/Sb7_L35.nd2        66GB (seg: good)

E. coli Ti3
11) /n/files/SysBio/PAULSSON\ LAB/SILVIA/Ti3Data/2020_01_29/RpoSOutliers_WT_hipQ_100X.nd2                                  476GB (seg: good)

In [None]:
data_loader.get_import_params()

In [None]:
daskcont = tr.trcluster.dask_controller(
    walltime="01:00:00",
    local=False,
    cores=2,
    n_workers=20,
    memory="8GB",
    working_directory="/n/scratch3/users/d/de64/2020-06-14_NN" + "/dask",
)
daskcont.startdask()

In [None]:
daskcont.daskclient

In [None]:
data_loader.export_data(daskcont, chunk_size=250)

In [None]:
daskcont.shutdown()

In [None]:
# Hyperparameter grid search over the 2020-07-05 dataset (100 epochs per run).
grid = tr.GridSearch("/n/scratch3/users/d/de64/2020-07-05_NN/", numepochs=100)

In [None]:
grid.display_grid()

In [None]:
grid.get_grid_params()

In [None]:
# Submit the grid-search runs (16G memory and 10 hours each, presumably as
# cluster jobs — see the squeue check below).
grid.run_grid_search(mem="16G", hours=10)

In [None]:
# List this user's SLURM jobs to monitor the submitted runs.
!squeue --user=de64

In [None]:
# Train one small "class"-mode UNet on GPU with the 2020-07-05 dataset.
# Re-importing `tr` is redundant with the first cell but harmless.
import paulssonlab.deaton.trenchripper.trenchripper as tr

nntrainer = tr.unet.UNet_Trainer(
    "/n/scratch3/users/d/de64/2020-07-05_NN/",
    0,  # presumably the model number used to name outputs — TODO confirm
    "class",
    gpuon=True,
    numepochs=100,
    batch_size=10,
    layers=2,
    hidden_size=32,
    lr=0.01,
    momentum=0.9,
    weight_decay=0.0001,
    dropout=0.0,
    W0=5.0,
    Wsigma=2.0,
    warm_epochs=10,
    cool_epochs=50,
)
nntrainer.train_model()

In [None]:
# Build a CPU trainer only to reuse its settings (mode, W0, Wsigma, batch size)
# for a non-shuffled DataLoader over train.hdf5, inspected in the next cell.
# NOTE(review): this points at the older 2020-06-14_NN dataset (the doubled
# "//" is harmless on POSIX paths); switch to 2020-07-05_NN if the current
# dataset is intended.
from torch.utils.data import DataLoader, Dataset

nntrainer = tr.unet.UNet_Trainer(
    "/n/scratch3/users/d/de64/2020-06-14_NN//",
    0,
    "class",
    gpuon=False,
    numepochs=100,
    batch_size=25,
    layers=4,
    hidden_size=64,
    lr=0.01,
    momentum=0.9,
    weight_decay=0.0001,
    dropout=0.0,
    W0=5.0,
    Wsigma=2.0,
    warm_epochs=10,
    cool_epochs=50,
)

train_data = tr.SegmentationDataset(
    nntrainer.nndatapath + "train.hdf5",
    mode=nntrainer.mode,
    W0=nntrainer.W0,
    Wsigma=nntrainer.Wsigma,
)
train_iter = DataLoader(
    train_data,
    batch_size=nntrainer.batch_size,
    shuffle=False,
    collate_fn=tr.numpy_collate,
)

In [None]:
# Preview the first training batch only: unwrap channel 0 of its images into a
# single image with tr.kymo_handle and display it. Without the `break` this
# would emit one figure per batch for the entire training set.
for item in train_iter:
    hand = tr.kymo_handle()
    hand.import_wrap(item["img"][:, 0])
    img = hand.return_unwrap()
    plt.imshow(img)
    plt.show()
    print(img.shape)
    break

In [None]:
# Open TODOs for this training round:
# diagnose and fix loss function errors
# test all three models for 1 epoch
# run 3 layer, 64 hidden state NNs with minimal augmentation (overfit)
# recompute examples (accuracy seems low...)

import paulssonlab.deaton.trenchripper.trenchripper as tr

# Small "class"-mode model (2 layers, hidden_size 16) on the 2020-06-14 dataset;
# the second positional argument is presumably the model number — TODO confirm.
trainer = tr.UNet_Trainer(
    "/n/scratch3/users/d/de64/2020-06-14_NN/",
    100,
    "class",
    numepochs=100,
    batch_size=100,
    gpuon=True,
    lr=0.05,
    cool_epochs=30,
    layers=2,
    hidden_size=16,
)

In [None]:
trainer.train_model()

In [None]:
# Line-profile trainer.perepoch over a full training run.
%load_ext line_profiler
%lprun -f trainer.perepoch trainer.train_model()

In [None]:
# NOTE(review): `test_data` is only created in a later cell, so this and the
# next cell fail on a fresh kernel (Restart & Run All).
test_data.chunk_ranges

In [None]:
test_data.chunk_dsets

In [None]:
# Rebuild the training loader with the trainer's settings and stop at batch
# index 245, leaving that batch in `item` for the loss debugging below. If the
# loader has fewer than 246 batches, `item` is simply the last batch.
train_data = tr.SegmentationDataset(
    trainer.nndatapath + "train.hdf5",
    mode=trainer.mode,
    W0=trainer.W0,
    Wsigma=trainer.Wsigma,
)
train_iter = DataLoader(
    train_data,
    batch_size=trainer.batch_size,
    shuffle=False,
    collate_fn=tr.numpy_collate,
)
for i, item in enumerate(train_iter):
    if i == 245:
        break

In [None]:
import numpy as np
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt

In [None]:
# Recompute the flow + mask loss by hand for one batch (`item`) to locate NaNs
# (presumably mirroring trainer.cellpose_train — TODO confirm). Targets are
# stacked as (y_grad, x_grad, mask) along axis 1, matching how the model
# output channels are compared below: fx[:, :2] vs flows, fx[:, 2] vs mask.
img_arr, seg_arr, y_grad_arr, x_grad_arr = (
    item["img"],
    item["seg"],
    item["y_grad"],
    item["x_grad"],
)
y_grad_arr, x_grad_arr, seg_arr = y_grad_arr[:, 0], x_grad_arr[:, 0], seg_arr[:, 0]
x = torch.Tensor(img_arr.astype(float))
y = np.stack([y_grad_arr, x_grad_arr, seg_arr], axis=1)
y = torch.Tensor(y)
if trainer.gpuon:
    x = x.cuda()
    y = y.cuda()

fx = trainer.model.forward(x)
# torch.sigmoid is the supported spelling of the deprecated F.sigmoid.
mask_pred = torch.sigmoid(fx[:, 2])

mse = F.mse_loss(fx[:, :2], y[:, :2], reduction="none")  ## N,* to N,*
cross_entropy = F.binary_cross_entropy(mask_pred, y[:, 2], reduction="none")
# NOTE(review): F.binary_cross_entropy_with_logits(fx[:, 2], y[:, 2]) is the
# numerically stable equivalent if the BCE term turns out to be the culprit.

loss = cross_entropy + 5.0 * mse[:, 0] + 5.0 * mse[:, 1]
# loss = trainer.cellpose_train(x,y)
# loss.detach().cpu().numpy()
# loss = trainer.cellpose_train(x,y)
# loss.detach().cpu().numpy()

In [None]:
# True if any element of the per-pixel loss computed above is NaN.
np.any(np.isnan(loss.detach().cpu().numpy()))

In [None]:
# Per test dataset: floor(n_rows / 1001) + 1 from the first dimension of each
# shape (presumably the expected number of chunks — TODO confirm against
# SegmentationDataset). This is one more than ceil when n_rows is an exact
# multiple of 1001. NOTE: `test_data` is created in a later cell.
(np.array([value[0] for value in test_data.dset_shapes.values()]) // 1001) + 1

In [None]:
# Smoke-test data loading: build the train/test/val datasets with the trainer's
# settings and iterate every loader for each epoch, recreating the DataLoaders
# per epoch, without running the model. Batches are unpacked and discarded, so
# any error here comes from data loading alone.
# SegmentationDataset is reached through `tr`, like in the other cells (the
# bare name is not defined anywhere in this notebook).
# NOTE(review): assumes each batch has a "weight" key, as for the "class" mode
# `trainer` was built with — TODO confirm for other modes.
train_data = tr.SegmentationDataset(
    trainer.nndatapath + "train.hdf5",
    mode=trainer.mode,
    W0=trainer.W0,
    Wsigma=trainer.Wsigma,
)
test_data = tr.SegmentationDataset(
    trainer.nndatapath + "test.hdf5",
    mode=trainer.mode,
    W0=trainer.W0,
    Wsigma=trainer.Wsigma,
)
val_data = tr.SegmentationDataset(
    trainer.nndatapath + "val.hdf5",
    mode=trainer.mode,
    W0=trainer.W0,
    Wsigma=trainer.Wsigma,
)

train_data_size = train_data.size
test_data_size = test_data.size
val_data_size = val_data.size

for e in range(0, trainer.numepochs):
    train_iter = DataLoader(
        train_data,
        batch_size=trainer.batch_size,
        shuffle=False,
        collate_fn=tr.numpy_collate,
    )
    test_iter = DataLoader(
        test_data,
        batch_size=trainer.batch_size,
        shuffle=False,
        collate_fn=tr.numpy_collate,
    )
    val_iter = DataLoader(
        val_data,
        batch_size=trainer.batch_size,
        shuffle=False,
        collate_fn=tr.numpy_collate,
    )
    for i, b in enumerate(train_iter):
        img_arr, seg_arr, weight_arr = (b["img"], b["seg"], b["weight"])
    for i, b in enumerate(test_iter):
        img_arr, seg_arr, weight_arr = (b["img"], b["seg"], b["weight"])
    for i, b in enumerate(val_iter):
        img_arr, seg_arr, weight_arr = (b["img"], b["seg"], b["weight"])

In [None]:
# Inspect how the test dataset is split into chunks.
test_data.chunk_dsets

In [None]:
test_data.dset_shapes

In [None]:
test_data.chunk_ranges

In [None]:
# Full training run with the trainer configured earlier.
trainer.train_model()

In [None]:
class TrainingVisualizer:
    """Browse, plot and export U-Net training results stored under ``trainpath``.

    Reads per-epoch metrics from ``<trainpath>/models/training_metadata_<n>.hdf5``
    (via ``pandas_hdf5_handler``) and, if present, per-model metadata from
    ``<trainpath>/model_metadata.hdf5``. Selected models can be exported into
    the model database directory ``modeldbpath``.

    Args:
        trainpath: training output directory (contains ``models/``).
        modeldbpath: model database directory; ``export_models`` copies model
            weights into its ``Parameters`` subdirectory.
    """

    def __init__(self, trainpath, modeldbpath):
        self.trainpath = trainpath
        self.modelpath = trainpath + "/models"
        self.modeldfpath = trainpath + "/model_metadata.hdf5"
        self.modeldbpath = modeldbpath
        self.paramdbpath = modeldbpath + "/Parameters"
        self.update_dfs()
        if os.path.exists(self.modeldfpath):
            # Filterable per-model table; export_models() exports its
            # currently filtered rows.
            self.models_widget = qgrid.show_grid(self.model_df.sort_index())

    def update_dfs(self):
        """(Re)load ``self.train_df`` and, when available, ``self.model_df``.

        ``self.train_df`` concatenates every ``training_metadata_<n>.hdf5`` in
        ``self.modelpath`` in increasing ``n``. ``self.model_df`` is the
        index-sorted ``data`` table of ``model_metadata.hdf5`` if it exists.

        Raises:
            ValueError: if ``self.modelpath`` holds no training metadata files.
        """
        prefix, suffix = "training_metadata_", ".hdf5"
        df_idx_list = sorted(
            int(name[len(prefix) : -len(suffix)])
            for name in os.listdir(self.modelpath)
            if name.startswith(prefix) and name.endswith(suffix)
        )
        if not df_idx_list:
            raise ValueError(
                "No training_metadata_<n>.hdf5 files found in " + self.modelpath
            )
        df_list = []
        for df_idx in df_idx_list:
            dfpath = self.modelpath + "/" + prefix + str(df_idx) + suffix
            df_handle = pandas_hdf5_handler(dfpath)
            df_list.append(copy.deepcopy(df_handle.read_df("data")))
        self.train_df = pd.concat(df_list)
        if os.path.exists(self.modeldfpath):
            modeldfhandle = pandas_hdf5_handler(self.modeldfpath)
            self.model_df = modeldfhandle.read_df("data").sort_index()

    def select_df_columns(self, selected_columns):
        """Show ``model_df`` restricted to ``selected_columns`` as a qgrid.

        Columns keep their ``model_df`` order; the widget is stored in
        ``self.model_widget``.
        """
        # Select by a boolean column mask instead of DataFrame.drop(label, 1):
        # a positional axis argument to drop() is rejected by pandas >= 2.0.
        keep = self.model_df.columns.isin(list(selected_columns))
        self.model_widget = qgrid.show_grid(self.model_df.loc[:, keep].copy())

    def inter_df_columns(self):
        """Display a manual-apply multi-select for the model_df columns to show."""
        column_list = self.model_df.columns.tolist()
        inter = ipyw.interactive(
            self.select_df_columns,
            {"manual": True},
            selected_columns=ipyw.SelectMultiple(
                options=column_list, description="Columns to Display:", disabled=False
            ),
        )
        display(inter)

    def handle_filter_changed(self, event, widget):
        """qgrid ``filter_changed`` callback: show only the filtered models' curves.

        Models still present in ``widget``'s filtered frame get their
        ``self.losskey``-vs-epoch data and a legend label; all others are
        emptied and hidden from the legend. Axis limits are rescaled to the
        visible data when any remains.
        """
        df = widget.get_changed_df().sort_index()

        all_model_indices = (
            self.train_df.index.get_level_values("Model #").unique().tolist()
        )
        current_model_indices = set(
            df.index.get_level_values("Model #").unique().tolist()
        )

        all_epochs = []
        all_loss = []
        for model_idx in all_model_indices:
            line = self.line_dict[model_idx]
            if model_idx in current_model_indices:
                filter_df = df.loc[model_idx]
                epochs = filter_df.index.get_level_values("Epoch").tolist()
                loss = filter_df[self.losskey].tolist()
                all_epochs += epochs
                all_loss += loss
                line.set_data(epochs, loss)
                line.set_label(str(model_idx))
            else:
                line.set_data([], [])
                line.set_label("_nolegend_")

        # An empty filter result leaves nothing to scale to; min()/max() of an
        # empty list would raise ValueError, so keep the current limits.
        if all_epochs:
            self.ax.set_xlim(min(all_epochs), max(all_epochs) + 1)
            self.ax.set_ylim(0, max(all_loss) * 1.1)
        self.ax.legend()
        self.fig.canvas.draw()

    def inter_plot_loss(self, losskey):
        """Plot ``losskey`` vs. epoch for every model and build ``self.grid_widget``.

        One line per model is stored in ``self.line_dict``; register
        ``handle_filter_changed`` on the grid's ``filter_changed`` event to
        keep the plot in sync with the grid filters.
        """
        self.losskey = losskey
        self.fig, self.ax = plt.subplots()
        self.grid_widget = qgrid.show_grid(self.train_df.sort_index())
        current_df = self.grid_widget.get_changed_df()

        self.line_dict = {}
        for model_idx in current_df.index.get_level_values("Model #").unique().tolist():
            filter_df = current_df.loc[model_idx]
            epochs, loss = (
                filter_df.index.get_level_values("Epoch").tolist(),
                filter_df[losskey].tolist(),
            )
            (line,) = self.ax.plot(epochs, loss, label=str(model_idx))
            self.line_dict[model_idx] = line

        self.ax.set_xlabel("Epoch")
        self.ax.set_ylabel(losskey)
        self.ax.legend()

    def export_models(self):
        """Append the models currently shown in ``self.models_widget`` to the database.

        Extends (or creates) ``<modeldbpath>/Models.hdf5`` with the widget's
        filtered rows and copies each of those models' ``<n>.pt`` weights to
        ``<modeldbpath>/Parameters/<exp>_<n>_<date>.pt``. Assumes the index
        holds (experiment name, model number) pairs.
        """
        # Create the database directories if needed (never wipes existing ones).
        os.makedirs(self.modeldbpath, exist_ok=True)
        os.makedirs(self.paramdbpath, exist_ok=True)
        modeldbhandle = pandas_hdf5_handler(self.modeldbpath + "/Models.hdf5")
        new_df = self.models_widget.get_changed_df()
        if "Models.hdf5" in os.listdir(self.modeldbpath):
            old_df = modeldbhandle.read_df("data")
            db_df = pd.concat([old_df, new_df])
        else:
            db_df = new_df
        modeldbhandle.write_df("data", db_df)

        # Copy weights only for the newly exported rows: rows already in the
        # database may come from other training directories whose .pt files
        # are not under self.modelpath.
        dates = [item.replace(" ", "_") for item in new_df["Date/Time"].tolist()]
        for index_item, date in zip(new_df.index.tolist(), dates):
            exp_name, model_number = str(index_item[0]), str(index_item[1])
            shutil.copyfile(
                self.modelpath + "/" + model_number + ".pt",
                self.paramdbpath
                + "/"
                + exp_name
                + "_"
                + model_number
                + "_"
                + date
                + ".pt",
            )

In [None]:
%matplotlib widget
import os

In [None]:
# Browse the 2020-07-05 training run; exports go to the nndb model database.
vis = TrainingVisualizer(
    "/n/scratch3/users/d/de64/2020-07-05_NN", "/n/scratch3/users/d/de64/nndb"
)

In [None]:
%matplotlib widget
# Plot validation-loss curves for every model and redraw them whenever the
# grid's filters change.
vis.inter_plot_loss("Val Loss")
vis.grid_widget.on("filter_changed", vis.handle_filter_changed)

In [None]:
vis.grid_widget

In [None]:
# Pick which model-metadata columns to display.
vis.inter_df_columns()

In [None]:
# Only exists after the column selection has been applied at least once.
vis.model_widget

In [None]:
vis.model_df

In [None]:
import matplotlib
import numpy as np
from matplotlib import pyplot as plt

%matplotlib inline

plt.hist(vis.model_df["Val F1 Cell Scores"][1], bins=50)
plt.xlabel("F-Score")
plt.ylabel("Occurances")
plt.xticks(np.arange(0, 1.01, step=0.5))
plt.draw()

In [None]:
# Same visualizer from the trenchripper package, pointed at an older run
# (nntest7). NOTE: this rebinds `vis`, replacing the visualizer built earlier.
vis = tr.TrainingVisualizer("/n/scratch2/de64/nntest7", "/n/scratch2/de64/nndb")

# Reference copy of the notebook class's __init__ (commented out):
#     def __init__(self,trainpath,modeldbpath):
#         self.trainpath = trainpath
#         self.modelpath = trainpath + "/models"
#         self.modeldfpath = trainpath + "/model_metadata.hdf5"
#         self.modeldbpath = modeldbpath
#         self.paramdbpath = modeldbpath+"/Parameters"
#         self.update_dfs()
#         if os.path.exists(self.modeldfpath):
#             self.models_widget = qgrid.show_grid(self.model_df.sort_index())

In [None]:
vis.model_df

In [None]:
import copy
import datetime
import itertools
import os
import pickle as pkl
import shutil
import subprocess
import time

import h5py
import ipywidgets as ipyw
import numpy as np
import pandas as pd
import qgrid
import skimage as sk
import skimage.morphology
import sklearn as skl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from imgaug import augmenters as iaa
from imgaug.augmentables.heatmaps import HeatmapsOnImage
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
from matplotlib import pyplot as plt
from scipy import ndimage
from scipy.ndimage import convolve1d
from torch._six import container_abcs, int_classes, string_classes
from torch.utils.data import DataLoader, Dataset

In [None]:
# Scratch: inspect the trainer's learning-rate schedule on an old run (nntest7).
# With lr=1.0 the reported LR equals the schedule multiplier (assuming a
# LambdaLR-style scheduler — TODO confirm).
test = tr.UNet_Trainer("/n/scratch2/de64/nntest7", 100, "class", lr=1.0)

In [None]:
print(test.scheduler.get_last_lr())

In [None]:
# Stand-alone SGD optimizer built from the trainer's hyperparameters; it is not
# the optimizer that test.scheduler drives.
optimizer = optim.SGD(
    test.model.parameters(),
    lr=test.lr,
    momentum=test.momentum,
    weight_decay=test.weight_decay,
)

In [None]:
# Schedule constants at notebook scope (annealfn below has its own copies).
warm_epochs = 10
cool_epochs = 100
warm_lambda = 1.0 / warm_epochs
cool_lambda = 1.0 / cool_epochs

In [None]:
def annealfn(epoch, numepochs=500, warm_epochs=10, cool_epochs=100):
    """LR multiplier for ``torch.optim.lr_scheduler.LambdaLR``: warm-up, plateau, cool-down.

    - ``epoch < warm_epochs``: linear ramp ``epoch / warm_epochs`` (0 at epoch 0).
    - ``epoch > numepochs - cool_epochs``: linear decay
      ``(numepochs - epoch) / cool_epochs``, clipped at 0.
    - otherwise: 1.0.

    With the defaults, ``LambdaLR(optimizer, lr_lambda=annealfn)`` follows a
    500-epoch schedule with 10 warm-up and 100 cool-down epochs. Pass other
    values (e.g. via ``functools.partial``) to reuse it for other run lengths.
    A ``warm_epochs`` or ``cool_epochs`` of 0 disables that phase instead of
    dividing by zero.
    """
    if warm_epochs > 0 and epoch < warm_epochs:
        return (1.0 / warm_epochs) * epoch
    if cool_epochs > 0 and epoch > (numepochs - cool_epochs):
        return max(0.0, (1.0 / cool_epochs) * (numepochs - epoch))
    return 1.0


# NOTE(review): this LambdaLR is attached to test.optimizer in addition to the
# trainer's own test.scheduler; both update the same optimizer's learning rate.
scheduler = torch.optim.lr_scheduler.LambdaLR(test.optimizer, lr_lambda=annealfn)

In [None]:
annealfn(460)

In [None]:
# Duplicate of the scheduler cell above (re-creates `scheduler`).
scheduler = torch.optim.lr_scheduler.LambdaLR(test.optimizer, lr_lambda=annealfn)

In [None]:
scheduler.get_last_lr()

In [None]:
# Step the trainer's own scheduler (not `scheduler`) 100 times, printing the LR.
for i in range(100):
    test.scheduler.step()
    print(test.scheduler.get_last_lr())

In [None]:
# NOTE(review): this cell is a fragment of a trainer method pasted at notebook
# level. It references names that do not exist here (e.g. `self`, `start`, and
# the unqualified `numpy_collate`), so running it raises NameError. Kept for
# reference only.
for e in range(0, self.numepochs):
    # Fresh, non-shuffled loaders for every epoch.
    train_iter = DataLoader(
        train_data, batch_size=self.batch_size, shuffle=False, collate_fn=numpy_collate
    )
    test_iter = DataLoader(
        test_data, batch_size=self.batch_size, shuffle=False, collate_fn=numpy_collate
    )
    val_iter = DataLoader(
        val_data, batch_size=self.batch_size, shuffle=False, collate_fn=numpy_collate
    )
    df_out = self.perepoch(
        e,
        train_iter,
        test_iter,
        val_iter,
        train_data_size,
        test_data_size,
        val_data_size,
    )

    # Write this model's training metadata after every epoch (mode "w").
    self.write_metadata(
        self.nndatapath
        + "/models/training_metadata_"
        + str(self.model_number)
        + ".hdf5",
        "w",
        df_out,
    )
end = time.time()
time_elapsed = (end - start) / 60.0  # minutes
torch.save(
    self.model.state_dict(),
    self.nndatapath + "/models/" + str(self.model_number) + ".pt",
)

try:
    # The F-score helper depends on the training mode.
    if self.mode == "class" or self.mode == "multiclass":
        val_f = self.get_class_fscore(val_iter)
        test_f = self.get_class_fscore(test_iter)
    elif self.mode == "cellpose":
        val_f = self.get_cellpose_fscore(val_iter)
        test_f = self.get_cellpose_fscore(test_iter)
except:
    # NOTE(review): a bare except also catches KeyboardInterrupt/SystemExit;
    # `except Exception:` would be safer. `np.NaN` was removed in NumPy 2.0
    # (use `np.nan`).
    print("Failed to compute F-scores")
    val_f = [np.NaN]
    test_f = [np.NaN]

In [None]:
# CPU "multiclass" trainer whose architecture (3 layers, hidden_size 32) must
# match the saved weights loaded below; the following cells use it for
# inference.
train = tr.UNet_Trainer(
    "/n/scratch3/users/d/de64/2020-06-14_NN",
    100,
    "multiclass",
    numepochs=1,
    batch_size=50,
    layers=3,
    hidden_size=32,
    lr=0.2,
    gpuon=False,
)

In [None]:
matplotlib.rcParams["figure.figsize"] = [20, 10]

In [None]:
import torch

device = torch.device("cpu")
# Load model 4's weights onto the CPU. load_state_dict is strict by default, so
# it raises if the architecture above does not match. torch.load unpickles the
# file — only load checkpoints from trusted locations.
train.model.load_state_dict(
    torch.load(
        "/n/scratch3/users/d/de64/2020-06-14_NN/models/4.pt", map_location=device
    )
)
train.model.to(device)

In [None]:
from torch.utils.data import DataLoader, Dataset

In [None]:
# Non-shuffled loader over the test split, using the trainer's mode, weight-map
# parameters (W0, Wsigma) and batch size.
test_data = tr.SegmentationDataset(
    train.nndatapath + "/test.hdf5", mode=train.mode, W0=train.W0, Wsigma=train.Wsigma
)
test_iter = DataLoader(
    test_data, batch_size=train.batch_size, shuffle=False, collate_fn=tr.numpy_collate
)

In [None]:
def process_pred(pred, thr, border_buffer=2):
    """Threshold channel 1 of each prediction and clear objects at the border.

    Args:
        pred: batch of network outputs indexed as ``pred[i, channel, ...]``.
        thr: threshold applied to channel 1 (strictly greater-than).
        border_buffer: ``buffer_size`` passed to
            ``skimage.segmentation.clear_border``.

    Returns:
        numpy array stacking one cleaned boolean mask per batch element.
    """
    masks = [
        sk.segmentation.clear_border(pred[batch_idx, 1] > thr, buffer_size=border_buffer)
        for batch_idx in range(pred.shape[0])
    ]
    return np.array(masks)

In [None]:
# Predict on the first test batch only (note the `break`). eval() switches any
# dropout/batch-norm layers the UNet may have to inference behavior, and
# no_grad() avoids building an autograd graph. The model is left in eval mode.
train.model.eval()
with torch.no_grad():
    for item in test_iter:
        img = item["img"]
        x = torch.Tensor(img.astype(float))
        pred = train.model.forward(x).data.numpy()
        proc_pred = process_pred(pred, 0.5, border_buffer=0)
        break

In [None]:
# Input images (channel 0) of the batch, unwrapped into one image via kymo_handle.
handle = tr.kymo_handle()
handle.import_wrap(img[:, 0])
imgs = handle.return_unwrap()
plt.imshow(imgs)
plt.show()

In [None]:
# Raw network output, channel 1 (the channel process_pred thresholds).
# NOTE: `preds` is reassigned by each of the next cells as well; the scaling
# cell further down uses whichever ran last.
handle = tr.kymo_handle()
handle.import_wrap(pred[:, 1])
preds = handle.return_unwrap()
plt.imshow(preds)
plt.show()

In [None]:
# Raw network output, channel 2.
handle = tr.kymo_handle()
handle.import_wrap(pred[:, 2])
preds = handle.return_unwrap()
plt.imshow(preds)
plt.show()

In [None]:
# Combined mask: channel 2 below 0.3 AND channel 1 above 0.5.
handle = tr.kymo_handle()
handle.import_wrap((pred[:, 2] < 0.3) * (pred[:, 1] > 0.5))
preds = handle.return_unwrap()
plt.imshow(preds)
plt.show()

In [None]:
# Post-processed masks from process_pred (threshold + border clearing).
handle = tr.kymo_handle()
handle.import_wrap(proc_pred)
proc_preds = handle.return_unwrap()
plt.imshow(proc_preds)
plt.show()

In [None]:
# Min-max scale images and predictions to [0, 1] for display (divides by zero if
# an array is constant). NOTE(review): `preds` is whatever the last
# prediction-plotting cell assigned (top to bottom: the combined-mask cell) —
# TODO confirm which prediction is meant.
scaled_imgs = (imgs - np.min(imgs)) / (np.max(imgs) - np.min(imgs))
scaled_preds = (preds - np.min(preds)) / (np.max(preds) - np.min(preds))

In [None]:
# Scaled image, thresholded prediction, and the same with border objects cleared.
plt.imshow(scaled_imgs)
plt.show()
plt.imshow(scaled_preds > 0.7)
plt.show()
plt.imshow(sk.segmentation.clear_border(scaled_preds > 0.7, buffer_size=2))
plt.show()

In [None]:
# NOTE(review): `train` was built in "multiclass" mode, while the training code
# uses get_class_fscore for class/multiclass and get_cellpose_fscore only for
# cellpose — confirm this call is intended.
test_f = train.get_cellpose_fscore(test_iter)

In [None]:
%matplotlib inline
# Distribution of test-set F-scores over [0, 1].
plt.hist(test_f, range=(0, 1), bins=50)
plt.show()