In [None]:
import copy
import os
import shutil

import ipywidgets as ipyw
import pandas as pd
import qgrid
from matplotlib import pyplot as plt

import paulssonlab.deaton.trenchripper.trenchripper as tr
from paulssonlab.deaton.trenchripper.trenchripper.utils import pandas_hdf5_handler

In [None]:
# NOTE(review): the source path is an empty string, and the function is reached as
# `tr.transferjob` here but as `tr.trcluster.transferjob` in the next cell — confirm
# that this call is intended and that both names resolve.
tr.transferjob("", "/n/scratch3/users/d/de64")

In [None]:
# Copy the 20x phase/YFP segmentation data set into scratch space.
# A raw string is used because "\ " is not a valid escape sequence in a normal string
# literal (Python warns about it); the resulting value, backslashes included, is unchanged.
# NOTE(review): the backslash-escaped spaces are presumably consumed by a shell command
# inside transferjob — confirm.
sourcedir = r"/n/files/SysBio/PAULSSON\ LAB/Personal\ Folders/Noah/190922/20x_segmentation_data/190925_20x_phase_yfp_segmentation"
targetdir = "/n/scratch3/users/d/de64/20X_seg_yfp"
tr.trcluster.transferjob(sourcedir, targetdir, single_file=False)

In [None]:
# Shell escape: list the SLURM queue entries belonging to user de64.
!squeue --user=de64

In [None]:
# Build a GridSearch object for the training directory, with numepochs=100.
grid = tr.GridSearch("/n/scratch2/de64/nntest7", numepochs=100)

In [None]:
# Presumably renders the interactive hyperparameter grid — confirm in trenchripper.
grid.display_grid()

In [None]:
# Presumably reads the parameters chosen in the grid shown above — confirm in trenchripper.
grid.get_grid_params()

In [None]:
# Launch the grid search with mem="16G" (presumably the per-job memory request — confirm).
grid.run_grid_search(mem="16G")

In [None]:
# Shell escape: list the SLURM queue entries belonging to user de64.
!squeue --user=de64

In [None]:
class TrainingVisualizer:
    """Notebook helper for browsing training metadata and exporting trained models.

    Reads every ``training_metadata_<N>.hdf5`` file under ``<trainpath>/models`` into
    ``train_df`` and, when ``<trainpath>/model_metadata.hdf5`` exists, the per-model
    table into ``model_df``. The frames are shown through qgrid widgets, a loss curve
    per model can be plotted, and selected models can be copied into a model database
    directory.

    Parameters
    ----------
    trainpath : str
        Training directory; must contain a ``models`` sub-directory.
    modeldbpath : str
        Directory of the model database (``Models.hdf5`` plus ``Parameters/``).
    """

    def __init__(self, trainpath, modeldbpath):
        self.trainpath = trainpath
        self.modelpath = trainpath + "/models"
        self.modeldfpath = trainpath + "/model_metadata.hdf5"
        self.modeldbpath = modeldbpath
        self.paramdbpath = modeldbpath + "/Parameters"
        self.update_dfs()
        # model_df (and therefore models_widget) only exists when the metadata file does.
        if os.path.exists(self.modeldfpath):
            self.models_widget = qgrid.show_grid(self.model_df.sort_index())

    def update_dfs(self):
        """Reload ``train_df`` (and ``model_df`` if its file exists) from disk.

        Raises
        ------
        ValueError
            From ``pd.concat`` when no training metadata file is found.
        """
        # File names look like "training_metadata_<N>.hdf5"; "[:-5]" strips ".hdf5".
        # Sorting makes the concatenation order independent of os.listdir ordering.
        df_idx_list = sorted(
            int(fname.split("_")[-1][:-5])
            for fname in os.listdir(self.modelpath)
            if "training_metadata" in fname
        )
        df_list = []
        for df_idx in df_idx_list:
            dfpath = self.modelpath + "/training_metadata_" + str(df_idx) + ".hdf5"
            df_list.append(copy.deepcopy(pandas_hdf5_handler(dfpath).read_df("data")))
        self.train_df = pd.concat(df_list)
        if os.path.exists(self.modeldfpath):
            modeldfhandle = pandas_hdf5_handler(self.modeldfpath)
            self.model_df = modeldfhandle.read_df("data").sort_index()

    def select_df_columns(self, selected_columns):
        """Show a copy of ``model_df`` restricted to ``selected_columns``.

        Column order follows ``model_df``; the result is stored in ``model_widget``.
        """
        # A boolean column mask replaces the former `df.drop(column, 1)`, whose
        # positional `axis` argument is rejected by pandas 2.x.
        keep_mask = self.model_df.columns.isin(list(selected_columns))
        df = self.model_df.loc[:, keep_mask].copy()
        self.model_widget = qgrid.show_grid(df)

    def inter_df_columns(self):
        """Display a multi-select widget that drives ``select_df_columns``."""
        column_list = self.model_df.columns.tolist()
        inter = ipyw.interactive(
            self.select_df_columns,
            {"manual": True},
            selected_columns=ipyw.SelectMultiple(
                options=column_list, description="Columns to Display:", disabled=False
            ),
        )
        display(inter)

    def handle_filter_changed(self, event, widget):
        """qgrid ``filter_changed`` callback: redraw the loss curves for the visible rows.

        Lines of models that are filtered out are emptied and hidden from the legend.
        Requires ``inter_plot_loss`` to have been called first (uses ``line_dict``,
        ``ax``, ``fig`` and ``losskey``).
        """
        df = widget.get_changed_df().sort_index()

        all_model_indices = (
            self.train_df.index.get_level_values("Model #").unique().tolist()
        )
        current_model_indices = set(
            df.index.get_level_values("Model #").unique().tolist()
        )

        all_epochs = []
        all_loss = []
        for model_idx in all_model_indices:
            line = self.line_dict[model_idx]
            if model_idx in current_model_indices:
                filter_df = df.loc[model_idx]
                epochs = filter_df.index.get_level_values("Epoch").tolist()
                loss = filter_df[self.losskey].tolist()
                all_epochs += epochs
                all_loss += loss
                line.set_data(epochs, loss)
                line.set_label(str(model_idx))
            else:
                line.set_data([], [])
                # Labels starting with "_" are skipped by matplotlib legends.
                line.set_label("_nolegend_")

        # When the filter hides every row there is nothing to scale to; min()/max()
        # on the empty lists used to raise ValueError here.
        if all_epochs:
            self.ax.set_xlim(min(all_epochs), max(all_epochs) + 1)
            self.ax.set_ylim(0, max(all_loss) * 1.1)
        self.ax.legend()
        self.fig.canvas.draw()

    def inter_plot_loss(self, losskey):
        """Create the loss figure and the qgrid widget over ``train_df``.

        One line per "Model #" index value is plotted (x: "Epoch" level,
        y: column ``losskey``) and remembered in ``line_dict``.
        """
        self.losskey = losskey
        self.fig, self.ax = plt.subplots()
        self.grid_widget = qgrid.show_grid(self.train_df.sort_index())
        current_df = self.grid_widget.get_changed_df()

        self.line_dict = {}
        for model_idx in current_df.index.get_level_values("Model #").unique().tolist():
            filter_df = current_df.loc[model_idx]
            epochs = filter_df.index.get_level_values("Epoch").tolist()
            loss = filter_df[losskey].tolist()
            (line,) = self.ax.plot(epochs, loss, label=str(model_idx))
            self.line_dict[model_idx] = line

        self.ax.set_xlabel("Epoch")
        self.ax.set_ylabel(losskey)
        self.ax.legend()

    def export_models(self):
        """Append the rows shown in ``models_widget`` to the model database.

        The rows are appended to ``<modeldbpath>/Models.hdf5`` and each model's
        ``<modelpath>/<model number>.pt`` is copied to
        ``<modeldbpath>/Parameters/<exp name>_<model number>_<date>.pt``.
        """
        # os.makedirs(..., exist_ok=True) replaces `writedir(..., overwrite=False)`,
        # a helper that is not imported in this notebook.
        os.makedirs(self.modeldbpath, exist_ok=True)
        os.makedirs(self.paramdbpath, exist_ok=True)

        modeldbfile = self.modeldbpath + "/Models.hdf5"
        modeldbhandle = pandas_hdf5_handler(modeldbfile)
        new_df = self.models_widget.get_changed_df()
        if os.path.exists(modeldbfile):
            # NOTE(review): exporting the same model twice leaves duplicate index
            # entries in Models.hdf5 — decide whether they should be de-duplicated.
            combined_df = pd.concat([modeldbhandle.read_df("data"), new_df])
        else:
            combined_df = new_df
        modeldbhandle.write_df("data", combined_df)

        # Copy parameter files only for the rows exported now. Rows that were already
        # in the database may belong to other training directories, so their ".pt"
        # files must not be looked up in (or overwritten from) this run's models.
        dates = [item.replace(" ", "_") for item in new_df["Date/Time"].tolist()]
        for index_entry, date in zip(new_df.index.tolist(), dates):
            exp_name = str(index_entry[0])
            model_number = str(index_entry[1])
            shutil.copyfile(
                self.modelpath + "/" + model_number + ".pt",
                self.paramdbpath + "/" + exp_name + "_" + model_number + "_" + date + ".pt",
            )

In [None]:
# Select the interactive `widget` matplotlib backend for the cells below.
%matplotlib widget

In [None]:
# Instantiate the notebook-local TrainingVisualizer: first argument is the training
# directory, second the model database directory.
vis = TrainingVisualizer("/n/scratch2/de64/nntest7", "/n/scratch2/de64/nndb")

In [None]:
%matplotlib widget
# Plot the "Val Loss" column per model and create the qgrid widget over train_df.
vis.inter_plot_loss("Val Loss")
# Redraw the curves whenever the grid's filter changes.
vis.grid_widget.on("filter_changed", vis.handle_filter_changed)

In [None]:
# Display the qgrid widget created by inter_plot_loss.
vis.grid_widget

In [None]:
# Show the multi-select widget for choosing which model_df columns to display.
vis.inter_df_columns()

In [None]:
# Display the column-filtered grid; `model_widget` only exists after
# select_df_columns has run (triggered from the selector above).
vis.model_widget

In [None]:
# Display the per-model metadata frame.
vis.model_df

In [None]:
import matplotlib
import numpy as np
from matplotlib import pyplot as plt

%matplotlib inline

plt.hist(vis.model_df["Val F1 Cell Scores"][2], bins=50)
plt.xlabel("F-Score")
plt.ylabel("Occurances")
plt.xticks(np.arange(0, 1.01, step=0.5))
plt.draw()

In [None]:
import matplotlib
from matplotlib import pyplot as plt

%matplotlib inline

plt.hist(vis.model_df["Val F1 Cell Scores"][1], bins=50)
plt.xlabel("F-Score")
plt.ylabel("Occurances")
plt.xticks(np.arange(0, 1.01, step=0.5))
plt.draw()

In [None]:
import matplotlib
from matplotlib import pyplot as plt

%matplotlib inline

plt.hist(vis.model_df["Val F1 Cell Scores"][2], bins=50)
plt.xlabel("F-Score")
plt.ylabel("Occurances")
plt.xticks(np.arange(0, 1.01, step=0.5))
plt.draw()

In [None]:
# Rebind `vis` to the TrainingVisualizer exposed by the trenchripper package; from here
# on `vis` no longer refers to an instance of the class defined in this notebook.
vis = tr.TrainingVisualizer("/n/scratch2/de64/nntest7", "/n/scratch2/de64/nndb")

# Leftover, commented-out copy of a TrainingVisualizer.__init__ (dead code):
#     def __init__(self,trainpath,modeldbpath):
#         self.trainpath = trainpath
#         self.modelpath = trainpath + "/models"
#         self.modeldfpath = trainpath + "/model_metadata.hdf5"
#         self.modeldbpath = modeldbpath
#         self.paramdbpath = modeldbpath+"/Parameters"
#         self.update_dfs()
#         if os.path.exists(self.modeldfpath):
#             self.models_widget = qgrid.show_grid(self.model_df.sort_index())

In [None]:
# Display the per-model metadata frame of the package-provided visualizer.
vis.model_df

In [None]:
import copy
import datetime
import itertools
import os
import pickle as pkl
import shutil
import subprocess
import time

import h5py
import ipywidgets as ipyw
import numpy as np
import pandas as pd
import qgrid
import skimage as sk
import skimage.morphology
import sklearn as skl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from imgaug import augmenters as iaa
from imgaug.augmentables.heatmaps import HeatmapsOnImage
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
from matplotlib import pyplot as plt
from scipy import ndimage
from scipy.ndimage import convolve1d
from torch._six import container_abcs, int_classes, string_classes
from torch.utils.data import DataLoader, Dataset

In [None]:
# Build a UNet_Trainer in "class" mode with lr=1.0.
# NOTE(review): the meaning of the positional 100 is not visible here — confirm
# against the UNet_Trainer signature.
test = tr.UNet_Trainer("/n/scratch2/de64/nntest7", 100, "class", lr=1.0)

In [None]:
# Print the learning rate(s) last computed by the trainer's own scheduler.
print(test.scheduler.get_last_lr())

In [None]:
# Stand-alone SGD optimizer over the trainer's model, configured from the trainer's
# own hyperparameters.
sgd_settings = {
    "lr": test.lr,
    "momentum": test.momentum,
    "weight_decay": test.weight_decay,
}
optimizer = optim.SGD(test.model.parameters(), **sgd_settings)

In [None]:
# Lengths (in epochs) of the warm-up and cool-down phases, and the per-epoch
# learning-rate factor increments derived from them.
warm_epochs, cool_epochs = 10, 100
warm_lambda, cool_lambda = 1.0 / warm_epochs, 1.0 / cool_epochs

In [None]:
def annealfn(epoch, numepochs=500, warm_epochs=10, cool_epochs=100):
    """Learning-rate multiplier with linear warm-up and linear cool-down.

    The factor rises linearly from 0 to 1 over the first ``warm_epochs`` epochs,
    stays at 1, then falls linearly to 0 over the last ``cool_epochs`` epochs of a
    ``numepochs``-long run (never below 0).

    Parameters
    ----------
    epoch : int
        Current epoch, as passed by ``LambdaLR``.
    numepochs : int, optional
        Total number of epochs in the schedule (default 500).
    warm_epochs : int, optional
        Length of the warm-up phase; 0 disables it (default 10).
    cool_epochs : int, optional
        Length of the cool-down phase; 0 disables it (default 100).

    Returns
    -------
    float
        Multiplicative factor applied to the base learning rate.

    Notes
    -----
    ``LambdaLR`` calls ``lr_lambda(epoch)`` with a single argument; use
    ``functools.partial(annealfn, numepochs=...)`` for a non-default schedule.
    """
    # The `> 0` guards avoid a division by zero when a phase is disabled.
    if warm_epochs > 0 and epoch < warm_epochs:
        warm_lambda = 1.0 / warm_epochs
        return warm_lambda * epoch
    if cool_epochs > 0 and epoch > (numepochs - cool_epochs):
        cool_lambda = 1.0 / cool_epochs
        return max(0.0, cool_lambda * (numepochs - epoch))
    return 1.0


# Wrap the trainer's optimizer (`test.optimizer`) in a LambdaLR driven by annealfn.
scheduler = torch.optim.lr_scheduler.LambdaLR(test.optimizer, lr_lambda=annealfn)

In [None]:
# Spot-check the multiplier at epoch 460.
annealfn(460)

In [None]:
# Re-create the LambdaLR scheduler on the trainer's optimizer (same statement as in
# the cell that defines annealfn).
scheduler = torch.optim.lr_scheduler.LambdaLR(test.optimizer, lr_lambda=annealfn)

In [None]:
# Learning rate(s) last computed by the notebook-level `scheduler`.
scheduler.get_last_lr()

In [None]:
# Advance the trainer's own scheduler (`test.scheduler`, not the notebook-level
# `scheduler`) 100 times, printing the learning rate after every step.
n_steps = 100
for step_idx in range(n_steps):
    test.scheduler.step()
    print(test.scheduler.get_last_lr())

In [None]:
# NOTE(review): this cell looks like a scratch copy of a trainer method body. `self`,
# `train_data`, `test_data`, `val_data`, the `*_data_size` values, `numpy_collate` and
# `start` are not defined anywhere in this notebook, so it cannot run as a cell.
for e in range(0, self.numepochs):
    # The loaders are rebuilt every epoch, with shuffling disabled.
    train_iter = DataLoader(
        train_data, batch_size=self.batch_size, shuffle=False, collate_fn=numpy_collate
    )
    test_iter = DataLoader(
        test_data, batch_size=self.batch_size, shuffle=False, collate_fn=numpy_collate
    )
    val_iter = DataLoader(
        val_data, batch_size=self.batch_size, shuffle=False, collate_fn=numpy_collate
    )
    df_out = self.perepoch(
        e,
        train_iter,
        test_iter,
        val_iter,
        train_data_size,
        test_data_size,
        val_data_size,
    )

    # Written in "w" mode after every epoch, so the file always holds the latest frame.
    self.write_metadata(
        self.nndatapath
        + "/models/training_metadata_"
        + str(self.model_number)
        + ".hdf5",
        "w",
        df_out,
    )
end = time.time()
time_elapsed = (end - start) / 60.0  # minutes
torch.save(
    self.model.state_dict(),
    self.nndatapath + "/models/" + str(self.model_number) + ".pt",
)

# F-scores are best-effort: on failure both lists fall back to a single NaN.
try:
    if self.mode == "class" or self.mode == "multiclass":
        val_f = self.get_class_fscore(val_iter)
        test_f = self.get_class_fscore(test_iter)
    elif self.mode == "cellpose":
        val_f = self.get_cellpose_fscore(val_iter)
        test_f = self.get_cellpose_fscore(test_iter)
except Exception as err:
    # `except Exception` (not a bare `except:`) lets KeyboardInterrupt/SystemExit
    # propagate; the error is printed rather than silently discarded.
    print("Failed to compute F-scores", repr(err))
    # np.nan replaces the np.NaN alias, which NumPy 2.0 removed.
    val_f = [np.nan]
    test_f = [np.nan]

In [None]:
# Trainer in "cellpose" mode. The training directory keeps its trailing slash because
# later cells build file paths as `train.nndatapath + "test.hdf5"`.
trainer_options = dict(
    numepochs=1,
    batch_size=50,
    layers=4,
    hidden_size=64,
    lr=0.2,
    gpuon=True,
)
train = tr.UNet_Trainer("/n/scratch2/de64/nntest7/", 100, "cellpose", **trainer_options)

In [None]:
import torch

# Restore the saved weights of model 13 into the trainer's network, then move it to
# the GPU. The final expression is left bare so the notebook shows the module repr.
checkpoint_path = "/n/scratch2/de64/nntest7/models/13.pt"
device = torch.device("cuda")
saved_state = torch.load(checkpoint_path)
train.model.load_state_dict(saved_state)
train.model.to(device)

In [None]:
from torch.utils.data import DataLoader, Dataset

In [None]:
# Test split as a SegmentationDataset, using the trainer's mode and weighting
# parameters (W0, Wsigma), served in unshuffled batches.
test_hdf5_path = train.nndatapath + "test.hdf5"
test_data = tr.SegmentationDataset(
    test_hdf5_path,
    mode=train.mode,
    W0=train.W0,
    Wsigma=train.Wsigma,
)
loader_options = dict(
    batch_size=train.batch_size,
    shuffle=False,
    collate_fn=tr.numpy_collate,
)
test_iter = DataLoader(test_data, **loader_options)

In [None]:
# F-scores of the restored model on the test loader, via the trainer's
# cellpose-mode scorer.
test_f = train.get_cellpose_fscore(test_iter)

In [None]:
%matplotlib inline
# Histogram of the test F-scores: 50 bins over the fixed range [0, 1].
plt.hist(test_f, range=(0, 1), bins=50)
plt.show()