In [None]:
import shutil 
from __future__ import print_function, unicode_literals, absolute_import, division
import sys
import numpy as np
import matplotlib
matplotlib.rcParams["image.interpolation"] = 'none'
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from glob import glob
from tqdm import tqdm
from tifffile import imread
from csbdeep.utils import Path, normalize

from stardist import fill_label_holes, random_label_cmap, calculate_extents, gputools_available
from stardist.matching import matching, matching_dataset
from stardist.models import Config2D, StarDist2D, StarDistData2D

np.random.seed(42)
lbl_cmap = random_label_cmap()

In [None]:
# name of your model
name = '2D_versatile_fluo_Rings_V2'

# load pretrained model and make a copy to local folder
model_pretrained = StarDist2D(None, name='2D_versatile_fluo', basedir='models')
# Copy only once: shutil.copytree raises FileExistsError when the target already exists,
# and blindly re-copying would overwrite fine-tuned weights with the pretrained ones.
if Path(name).exists():
    print("Model folder '%s' already exists; reusing it (delete it to restart from the pretrained weights)." % name)
else:
    shutil.copytree(model_pretrained.logdir, name)

# load your duplicate of the pretrained model
model = StarDist2D(None, name)


In [None]:
# Load the sorted lists of training patch paths; images and masks are paired by file name.
# NOTE(review): absolute local path — consider moving this root into a config cell.
training_data_dir = 'E:/Burton_Lab/Umanshi/OIC-181_Mitochondria_Morphology/OIC-181_Mitochondrial_Analysis/Training_Data'
X = sorted(glob(training_data_dir + '/Images/Patches/*.tif'))
Y = sorted(glob(training_data_dir + '/Masks/Patches/*.tif'))
# zip() stops at the shorter list, so the name check alone would not catch a missing
# mask/image (or an empty glob); check the counts explicitly first.
assert len(X) > 0, "no training images found under %s" % training_data_dir
assert len(X) == len(Y), "found %d images but %d masks" % (len(X), len(Y))
assert all(Path(x).name == Path(y).name for x, y in zip(X, Y)), "image and mask file names do not match"

In [None]:
# Number of training patches found on disk (displayed as the cell output).
n_patches = len(X)
n_patches

In [None]:
# Read the matched image/mask pairs into memory (X/Y now hold arrays instead of paths).
X = list(map(imread, X))
Y = list(map(imread, Y))
assert len(X) > 0, "no training images loaded"
# Derive the channel count from the data instead of hard-coding it:
# 2-D images are single-channel; otherwise channels are assumed to be the last axis (YXC).
n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]
# Fail early with a clear message rather than deep inside model.train().
assert n_channel == model.config.n_channel_in, (
    "images have %d channel(s) but the pretrained model expects %d"
    % (n_channel, model.config.n_channel_in))

In [None]:
# Percentile-normalize every image (1st..99.8th percentile) and fill holes in the label masks.
axis_norm = (0, 1)   # normalize channels independently
# axis_norm = (0,1,2) # normalize channels jointly
if n_channel > 1:
    joint = axis_norm is None or 2 in axis_norm
    print("Normalizing image channels %s." % ('jointly' if joint else 'independently'))
    sys.stdout.flush()

X = list(map(lambda img: normalize(img, 1, 99.8, axis=axis_norm), tqdm(X)))
Y = list(map(fill_label_holes, tqdm(Y)))

In [None]:
# Hold out ~15% of the patches (at least one) for validation, using a fixed-seed
# permutation so the split is reproducible.
assert len(X) > 1, "not enough training data"
rng = np.random.RandomState(42)
ind = rng.permutation(len(X))
n_val = max(1, int(round(len(ind) * 0.15)))
ind_train = ind[:-n_val]
ind_val = ind[-n_val:]
X_trn = [X[i] for i in ind_train]
Y_trn = [Y[i] for i in ind_train]
X_val = [X[i] for i in ind_val]
Y_val = [Y[i] for i in ind_val]
for template, subset in (('number of images: %3d', X),
                         ('- training:       %3d', X_trn),
                         ('- validation:     %3d', X_val)):
    print(template % len(subset))

In [None]:
def random_fliprot(img, mask):
    """Apply the same random axis permutation and random flips to an image and its mask.

    The first ``mask.ndim`` axes of ``img`` are treated as spatial and transformed
    together with ``mask``; any trailing axes of ``img`` keep their position.

    Args:
        img: input image with at least as many dimensions as ``mask``.
        mask: label image.

    Returns:
        Tuple ``(img, mask)`` after the joint permutation/flips.
    """
    assert img.ndim >= mask.ndim
    spatial = tuple(range(mask.ndim))
    order = tuple(np.random.permutation(spatial))
    trailing = tuple(range(mask.ndim, img.ndim))
    img, mask = img.transpose(order + trailing), mask.transpose(order)
    for axis in spatial:
        # One coin flip per spatial axis, drawn in axis order.
        if not np.random.rand() > 0.5:
            continue
        img, mask = np.flip(img, axis=axis), np.flip(mask, axis=axis)
    return img, mask

def random_intensity_change(img):
    """Randomly rescale and shift intensities: ``img * U(0.6, 2) + U(-0.2, 0.2)``.

    The gain is drawn before the offset from the global NumPy RNG.
    """
    gain = np.random.uniform(0.6, 2)
    offset = np.random.uniform(-0.2, 0.2)
    return img * gain + offset


def augmenter(x, y):
    """Augment a single (input image, ground-truth label image) training pair.

    Applies a joint random permutation/flip to ``x`` and ``y``, a random intensity
    change to ``x``, then additive Gaussian noise with a random std in [0, 0.02).
    The label image ``y`` only undergoes the geometric transform.
    """
    x, y = random_fliprot(x, y)
    x = random_intensity_change(x)
    noise_sigma = 0.02 * np.random.uniform(0, 1)
    noise = np.random.normal(0, 1, x.shape)
    return x + noise_sigma * noise, y

In [None]:
# Fine-tuning hyper-parameters, overriding the values copied from the pretrained config.
training_overrides = dict(
    train_patch_size=(128, 128),
    train_batch_size=16,
    train_learning_rate=1e-5,
    train_epochs=200,
)
for key, value in training_overrides.items():
    setattr(model.config, key, value)

# finetune on new data
model.train(X_trn, Y_trn, validation_data=(X_val, Y_val), augmenter=augmenter)

In [None]:
# Tune the probability / NMS thresholds of the fine-tuned model on the validation set.
thresholds = model.optimize_thresholds(X_val, Y_val)
thresholds

In [None]:
# Predict instance label images for the validation set with the fine-tuned `model`
# (the original referenced an undefined `my_model`, a NameError on a fresh kernel).
# predict_instances returns a tuple; element [0] is the label image.
# NOTE(review): `_guess_n_tiles` is a private StarDist helper — presumably picks a tiling
# for large inputs; confirm it is still available in the installed version.
Y_val_pred = [model.predict_instances(x, n_tiles=model._guess_n_tiles(x), show_tile_progress=False)[0]
              for x in tqdm(X_val)]

In [None]:
import napari

In [None]:
# Visual QC of one validation patch: raw image, ground-truth mask and model prediction
# as separately named layers (the prediction computed above was otherwise never inspected).
val_idx = 2
assert val_idx < len(X_val), "only %d validation images available" % len(X_val)
viewer = napari.view_image(X_val[val_idx], name='image')
viewer.add_labels(Y_val[val_idx], name='ground truth')
viewer.add_labels(Y_val_pred[val_idx], name='prediction')

In [None]:
# Full-size images to segment with the fine-tuned model. Keep the paths in their own
# variable so results can be traced back to their source files.
# NOTE(review): this pattern matches only *.tiff, while the training patches are *.tif — confirm.
full_img_paths = sorted(glob("E:/Burton_Lab/Umanshi/OIC-181_Mitochondria_Morphology/OIC-181_Mitochondrial_Analysis/Training_Data/Images/*.tiff"))
assert len(full_img_paths) > 0, "no full-size images found"
full_imgs = list(map(imread, full_img_paths))

In [None]:
# Apply the same percentile normalization that was used for the training patches.
full_imgs = list(map(lambda img: normalize(img, 1, 99.8, axis=axis_norm), tqdm(full_imgs)))

In [None]:
# Segment every full-size image with the fine-tuned `model` (the original referenced an
# undefined `my_model`, a NameError on a fresh kernel); [0] keeps the label image.
preds = [model.predict_instances(x, n_tiles=model._guess_n_tiles(x), show_tile_progress=False)[0]
         for x in tqdm(full_imgs)]

In [None]:
# Visual QC of the prediction on one full-size image, with named layers.
full_idx = 5
assert full_idx < len(full_imgs), "only %d full-size images available" % len(full_imgs)
viewer = napari.view_image(full_imgs[full_idx], name='image')
viewer.add_labels(preds[full_idx], name='prediction')

In [None]:
# Output folder for the predicted label images and the matching normalized inputs.
save_path = ('E:/Burton_Lab/Umanshi/OIC-181_Mitochondria_Morphology/'
             'OIC-181_Mitochondrial_Analysis/Training_Data/'
             'Transfer_Learning_Results/')

In [None]:
import skimage as sk
import os

In [None]:
# Import the io submodule explicitly: `import skimage as sk` does not load `skimage.io`
# on older scikit-image releases, so `sk.io.imsave` can raise AttributeError there.
from skimage.io import imsave

# Write each predicted label image next to the normalized input it was computed from.
os.makedirs(save_path, exist_ok=True)
for c, (pred, img) in enumerate(zip(preds, full_imgs)):
    # %02d keeps the original names for c < 10 ('img_00'..'img_09') but yields 'img_10'
    # instead of 'img_010' afterwards, so files sort correctly.
    # check_contrast=False: label images would otherwise trigger low-contrast warnings.
    imsave(os.path.join(save_path, 'Transfer_learning_img_%02d.tif' % c), pred, check_contrast=False)
    imsave(os.path.join(save_path, 'img_%02d.tif' % c), img, check_contrast=False)