In [None]:
import inspect
import os
import shutil
import subprocess
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from skimage.filters import unsharp_mask
from skimage.io import imread
from skimage.morphology import remove_small_objects, remove_small_holes, dilation
from skimage.segmentation import slic, mark_boundaries, find_boundaries
from skimage.transform import rescale

from src.train_predict import train_sp_wnet, predict_sp_wnet
from src.utils import im_crop_multiple, sample_patch_datasets
from src.wnet import WNet

In [None]:
# Notebook-wide matplotlib defaults: switch off all axis spines, ticks and
# tick labels, and render figures at 200 dpi.
rc = dict.fromkeys(
    ("axes.spines.left",
     "axes.spines.right",
     "axes.spines.bottom",
     "axes.spines.top",
     "xtick.bottom",
     "xtick.labelbottom",
     "ytick.labelleft",
     "ytick.left"),
    False)
rc["figure.dpi"] = 200
plt.rcParams.update(rc)

# Data

In [None]:
def _rescale_color(img, scale):
    """Rescale `img` by `scale`, treating its last axis as colour channels.

    scikit-image renamed the `multichannel=True` flag of `rescale` to
    `channel_axis=-1`; the keyword accepted by the installed version is
    detected from the function signature so both old and new releases work.
    """
    if 'channel_axis' in inspect.signature(rescale).parameters:
        return rescale(img, scale, channel_axis=-1)
    return rescale(img, scale, multichannel=True)


def read_image(fname, cut=0, scale=1, unit=1, unsharp_par=None):
    """Load an image file and preprocess it.

    Steps, in order: read the file, strip a border, rescale, crop with
    `im_crop_multiple`, optionally sharpen, and cast to float32.

    Parameters
    ----------
    fname : str
        Path of the image file, passed to `skimage.io.imread`.
    cut : int
        Number of pixels removed from each of the four borders; nothing is
        removed when `cut` is 0 (or negative).
    scale : float
        Scale factor handed to `skimage.transform.rescale`; the last axis is
        treated as the colour channels and is not rescaled.
    unit : int
        Second argument of the project helper `im_crop_multiple`
        (presumably the image sides are cropped to a multiple of `unit` --
        confirm against `src.utils`).
    unsharp_par : tuple or None
        `(radius, amount)` for `skimage.filters.unsharp_mask`; `None` skips
        the sharpening step.

    Returns
    -------
    numpy.ndarray
        The preprocessed image as float32.
    """
    img = imread(fname)
    if cut > 0:
        img = img[cut:-cut, cut:-cut]
    img = _rescale_color(img, scale)
    img = im_crop_multiple(img, unit)
    if unsharp_par is not None:
        radius, amount = unsharp_par[0], unsharp_par[1]
        # NOTE(review): no channel argument is passed here, so unsharp_mask
        # uses its default channel handling -- confirm that filtering across
        # the colour axis is intended before changing it (the trained model
        # depends on this preprocessing).
        img = unsharp_mask(img, radius=radius, amount=amount)
    return img.astype('float32')


def read_image_nn(fname):
    """Read `fname` with the fixed preprocessing used for the network input.

    Delegates to `read_image` with a 10-pixel border cut, a 1/3 rescale,
    `unit=16` and unsharp masking with radius 5 and amount 1.
    """
    nn_preprocessing = dict(cut=10, scale=1/3, unit=16, unsharp_par=(5, 1))
    return read_image(fname, **nn_preprocessing)


# select image for training
# `folder` is the root of the clip; `iids` are the frame numbers used for training.
folder = 'data/videos/cheetah'
iids = [10, 123, 270, 370]

# Load every training frame, report its shape and display it.
image_in_list = []
for iid in iids:
    frame_path = f'{folder}/clips/{iid:04d}.jpg'
    image_in = read_image_nn(frame_path)
    image_in_list.append(image_in)
    print('Shape of input image:', image_in.shape)
    plt.imshow(image_in)
    plt.show()

In [None]:
# Patch sampling configuration.
n_batches = 8
side_sizes = [64, 96]
count_per_batch_for_max_size = 4
# Pixel budget of one batch: `count_per_batch_for_max_size` patches of the
# largest square size.
total_pixels_per_batch = count_per_batch_for_max_size * max(side_sizes) ** 2

# One (sx, sy, count) entry per pair of side sizes whose ratio stays within
# the bounds below; `count` is how many sx-by-sy patches fit in the budget.
shapes_and_counts = [
    (sx, sy, total_pixels_per_batch // (sx * sy))
    for sx in side_sizes
    for sy in side_sizes
    if sy // 2 <= sx <= sy * 2
]

# For every training image keep its sampled patches and a tensor view of the
# full image; the image index doubles as the sampling seed.
image_tensor_list = []
image_patches_list = []
for i, image_in in enumerate(image_in_list):
    image_patches_list.append(sample_patch_datasets(
        image_in, shapes_and_counts=shapes_and_counts,
        n_batches=n_batches, seed=i, returns_image_tensor=False))
    # move the last (channel) axis to third-from-last before wrapping as a tensor
    channels_first = np.moveaxis(image_in, -1, -3)
    image_tensor_list.append(torch.from_numpy(channels_first))

In [None]:
# plot patches
# One row per training image: a randomly picked patch for every patch shape,
# followed by the full image in the last panel.
for image_tensor, image_patches in zip(image_tensor_list, image_patches_list):
    n_panels = len(image_patches) + 1
    fig, ax = plt.subplots(1, n_panels, dpi=400)
    for i, (dx, dy, count) in enumerate(shapes_and_counts):
        patches = image_patches[i]
        idx = np.random.choice(len(patches))
        panel = ax[i]
        # permute(1, 2, 0) puts the first axis last, as imshow expects
        panel.imshow(patches[idx].permute(1, 2, 0))
        panel.set_title(f'[{dx} x {dy}], {count} imgs/batch', fontsize=3)
    last_panel = ax[-1]
    last_panel.imshow(image_tensor.permute(1, 2, 0))
    last_panel.set_title('original', fontsize=3)
    plt.show()

# Superpixel

In [None]:
def make_sp(img):
    """Return the SLIC superpixel label map of `img`.

    The image is cast to float and passed to `skimage.segmentation.slic`
    with 12000 requested segments, compactness 10, sigma 1 and labels
    starting at 0.
    """
    slic_params = dict(n_segments=12000, compactness=10, sigma=1,
                       start_label=0)
    return slic(img.astype(float), **slic_params)

# Compute the superpixel map of every training image and show its
# boundaries drawn over the image.
image_sp_list = []
for image_in in image_in_list:
    image_sp = make_sp(image_in)
    image_sp_list.append(image_sp)
    n_superpixels = len(np.unique(image_sp))
    overlay = mark_boundaries(image_in, image_sp)
    plt.figure(dpi=200)
    plt.imshow(overlay)
    plt.title(f'superpixels, N={n_superpixels}', fontsize=5)
    plt.show()

# Training


In [None]:
# Set to False to skip training; the prediction cells load the weights from
# `{folder}/results/wnet.pt`.
re_train = True
if re_train:
    # clear results
    # Done with shutil instead of a shell `rm -rf` so it is portable and is
    # not affected by spaces or shell metacharacters in `folder`.
    results_dir = Path(f'{folder}/results')
    if results_dir.exists():
        shutil.rmtree(results_dir)
    results_dir.mkdir(parents=True, exist_ok=True)
    # create wnet
    # Seed torch before building the model so the initial weights are
    # reproducible.
    seed = 84
    torch.manual_seed(seed)
    wnet = WNet(3, 32, ch_mul=64, n_blocks=4)
    # train
    hist = train_sp_wnet(
        wnet,
        # image data
        image_tensor_list, image_patches_list=image_patches_list, n_batches=n_batches,
        # superpixel
        image_sp_list=image_sp_list, sp_seg_mode='argmax_of_mean',
        tau_cut='unused', tau_sim_kmeans='unused', k_kmeans='unused', tau_con=1.,
        use_sparse_adj=True,
        # beta values
        beta_rc_image=.5, beta_rc_patches=1.,
        beta_cut=None, beta_sim=.1, beta_con=.1,
        # results after each epoch
        plot_epoch=True, save_epoch_results_to='screen',
        # others
        epochs=15, lr=0.001, device='cuda', progress_bar=True)
    # save weights
    torch.save(wnet.state_dict(), f'{folder}/results/wnet.pt')

# Predict

In [None]:
def predict(wnet_model, iid_pred, merge_list=None):
    """Segment one frame of the clip with a trained W-Net.

    Parameters
    ----------
    wnet_model : WNet
        Model passed to `predict_sp_wnet`.
    iid_pred : int
        Frame number; the file read is `{folder}/clips/{iid_pred:04d}.jpg`.
    merge_list : iterable of iterables of labels, optional
        Each inner group of labels is collapsed onto the smallest label of
        that group. `None` leaves the predicted labels untouched.

    Returns
    -------
    (img_pred, img_label)
        The preprocessed frame and its label image (modified in place by the
        merging step).
    """
    # predict
    img_pred = read_image_nn(f'{folder}/clips/{iid_pred:04d}.jpg')
    frame_tensor = torch.from_numpy(np.moveaxis(img_pred, -1, -3))
    frame_sp = make_sp(img_pred)
    # only the label images of the three returned lists are used here
    _, _, label_img_list = predict_sp_wnet(
        wnet_model, [frame_tensor], image_sp_list=[frame_sp],
        sp_seg_mode='argmax_of_mean', device='cuda',
        make_label_continuous=True, returns_in_numpy=True)
    img_label = label_img_list[0]

    # post-processing on labels
    if merge_list is None:
        return img_pred, img_label
    for merge in merge_list:
        group = list(merge)
        if group:
            img_label[np.isin(img_label, group)] = min(group)
    return img_pred, img_label

In [None]:
# show some predictions
# One panel per frame: the predicted label map, with every label id written
# at the mean position of its pixels.
iids = [9, 90, 233, 252]
fig, ax = plt.subplots(1, len(iids), dpi=200, figsize=(10, 2))
wnet = WNet(3, 32, ch_mul=64, n_blocks=4)
wnet.load_state_dict(torch.load(f'{folder}/results/wnet.pt'))
for j, iid in enumerate(iids):
    image_pred, label_pred = predict(wnet, iid)
    panel = ax[j]
    panel.imshow(label_pred, cmap='rainbow')
    for lab in np.unique(label_pred):
        rows, cols = np.nonzero(label_pred == lab)
        # text() takes (x, y), i.e. (column, row)
        panel.text(cols.mean(), rows.mean(), str(lab))
plt.show()

In [None]:
# Output folders for the per-frame label maps and boundary overlays.
# parents=True lets this cell run even when `results/` does not exist yet.
Path(f'{folder}/results/label').mkdir(parents=True, exist_ok=True)
Path(f'{folder}/results/mark').mkdir(parents=True, exist_ok=True)

# load model
wnet = WNet(3, 32, ch_mul=64, n_blocks=4)
wnet.load_state_dict(torch.load(f'{folder}/results/wnet.pt'))

# labels to be merged for background
merges = [[0, 1, 3]]
# `predict` collapses each merge group onto its smallest label, so use the
# same rule here instead of relying on the order inside the group.
bkg_label = min(merges[0])
obj_label = 2
sky_label = 4

for iid in range(1, 403):
    print(iid)
    # label images
    image_pred, label_pred = predict(wnet, iid, merge_list=merges)
    # remove small holes in background
    b = remove_small_holes(label_pred == bkg_label, 100)
    label_pred[b] = bkg_label
    # change small sky patches to object: sky pixels that do not survive
    # remove_small_objects are relabelled as object
    sky_mask = label_pred == sky_label
    large_sky = remove_small_objects(sky_mask, 10000)
    label_pred[sky_mask & ~large_sky] = obj_label
    plt.imshow(label_pred, cmap='rainbow')
    plt.savefig(f'{folder}/results/label/{iid:04d}.jpg', 
                bbox_inches='tight', pad_inches=0)
    plt.close()
    
    # boundary images
    # upscale label image to original size
    # Shape of the rescaled frame before the crop done for the network
    # (unit=1 is used here instead of the unit=16 of read_image_nn).
    size_before_round = read_image(f'{folder}/clips/{iid:04d}.jpg', 
                                   cut=10, scale=1/3, unit=1, unsharp_par=None).shape
    label_up = np.full((size_before_round[0], size_before_round[1]), 
                       fill_value=bkg_label, dtype=int)
    # Paste the prediction into the centre of the uncropped canvas
    # (assumes im_crop_multiple crops symmetrically -- TODO confirm).
    start0 = (label_up.shape[0] - label_pred.shape[0]) // 2
    start1 = (label_up.shape[1] - label_pred.shape[1]) // 2
    label_up[start0:start0 + label_pred.shape[0], 
             start1:start1 + label_pred.shape[1]] = label_pred
    # NOTE(review): rescale is called with its defaults on an integer label
    # map, so per the scikit-image docs the result is an interpolated float
    # image rather than a label image. That appears to be what makes the
    # boundaries below thicker -- confirm before switching to order=0.
    label_up = rescale(label_up, 3)
    # find boundaries and make them thicker
    boundaries = find_boundaries(label_up, mode='thick')
    # plot
    image_show = read_image(f'{folder}/clips/{iid:04d}.jpg', 
                            cut=10, scale=1, unit=1, 
                            unsharp_par=None)[:label_up.shape[0], :label_up.shape[1]]
    # The 1/3 downscale followed by the x3 upscale can round to a size one
    # pixel larger than the frame; crop the mask so both shapes always match.
    boundaries = boundaries[:image_show.shape[0], :image_show.shape[1]]
    image_show[boundaries] = [0, 0, 1]  # blue
    image_show = im_crop_multiple(image_show, 64)[4:-4, 4:-4]
    plt.imshow(image_show)
    plt.show()
    plt.imsave(f'{folder}/results/mark/{iid:04d}.jpg', image_show)

In [None]:
# Assemble the saved frames of each output folder into an mp4 with ffmpeg.
# An argument list is used instead of a shell string so paths need no
# quoting, and check=True raises if ffmpeg exits with a non-zero status
# instead of failing silently.
for name in ('label', 'mark'):
    subprocess.run(
        ['ffmpeg', '-start_number', '1', '-y',
         '-i', f'{folder}/results/{name}/%04d.jpg',
         # pad width and height up to even numbers
         '-vf', 'pad=width=ceil(iw/2)*2:height=ceil(ih/2)*2',
         f'{folder}/results/{name}.mp4'],
        check=True)