In [None]:
import os
import gc
import cv2
import sys
import tqdm
import random
import numpy as np
from itertools import product
from skimage.io import imsave
from multiprocessing import Pool

# Restrict this kernel to a single GPU (see issue #152). PCI_BUS_ID ordering makes the
# index below follow PCI bus order. Both variables are set before the Keras import
# below, presumably so they take effect before CUDA is initialised.
os.environ.update(CUDA_DEVICE_ORDER='PCI_BUS_ID', CUDA_VISIBLE_DEVICES='2')

from keras.models import model_from_json, load_model
from segnet import SegnetBuilder

In [None]:
# \ Globals ============
# Spatial size that images are resized to and predictions are reshaped to.
PIC_H = PIC_W = 192
# Number of segmentation classes (one-hot depth of the masks).
LABELS_NUMBER = 12
# Worker count, presumably for multiprocessing.Pool; not referenced by the visible cells.
PROCESSES = 30

In [None]:
# Model under evaluation: load_model restores the saved model from the HDF5 file in ./models.
# (An alternative is to build it with SegnetBuilder.build(...) and call load_weights on the same file.)
_jfname = 'Segnet 073.json'                      # NOTE(review): not referenced by any visible cell
_wfname = 'Segnet 044 s[1]b[0_16] best.hdf5'

segnet = load_model(f'models/{_wfname}')

In [None]:
def get_filenames(datapath='lane_marking_examples'):
    """Recursively list files under ``datapath`` and split them into inputs and masks.

    Every file path found by ``os.walk`` is sorted as a whole. Paths ending
    in ``'bin.png'`` are treated as label masks; everything else is treated
    as an input image.

    Returns:
        tuple ``(x_paths, y_paths)`` of sorted path lists (images, masks).
    """
    all_paths = sorted(
        os.path.join(root, name)
        for root, _subdirs, names in os.walk(datapath)
        for name in names
    )

    images, masks = [], []
    for path in all_paths:
        bucket = masks if path.endswith('bin.png') else images
        bucket.append(path)

    return images, masks

def predict_to_label(predictions, height=None, width=None):
    """Convert per-pixel class scores into colour-coded label images.

    Args:
        predictions: iterable of per-image score arrays. Each is reshaped to
            ``(height, width, n_classes)``, so both a flat
            ``(height * width, n_classes)`` layout and an already-spatial one
            are accepted.
        height, width: output size; default to the notebook globals
            ``PIC_H`` / ``PIC_W``.

    Returns:
        list of ``uint8`` arrays of shape ``(height, width, 3)``, one per input,
        where each pixel holds the colour triple of its highest-scoring class
        (ties resolve to the lowest class index, as with ``np.argmax``).
    """
    # Class index -> colour triple. Must stay in sync with the colour table in load_one.
    palette = np.array([
        (0, 0, 0),
        (8, 35, 142),
        (43, 173, 180),
        (153, 102, 153),
        (234, 168, 160),
        (192, 0, 0),
        (8, 32, 128),
        (12, 51, 204),
        (70, 25, 100),
        (14, 57, 230),
        (75, 47, 190),
        (255, 255, 255),
    ], dtype=np.uint8)

    if height is None:
        height = PIC_H
    if width is None:
        width = PIC_W

    masks = []
    for pred in predictions:
        scores = np.reshape(pred, (height, width, len(palette)))
        # Vectorised argmax over the class axis: same result and tie-breaking as
        # apply_along_axis(np.argmax, ...), without a Python call per pixel.
        class_map = np.argmax(scores, axis=-1)
        # Fancy indexing maps every class index to its colour in one step.
        masks.append(palette[class_map])

    return masks

In [None]:
def load_one(pathes):
    """Load one (image, label) pair and one-hot encode the label image.

    Meant to be mapped over many pairs (e.g. with ``multiprocessing.Pool``).
    Any failure is therefore printed together with the offending paths and
    turned into ``(None, None)`` instead of being raised.

    Args:
        pathes: two-item sequence ``(image_path, label_path)``.

    Returns:
        ``(img, mask)``:

        - ``img`` is the image from ``cv2.imread``, resized to 192x192 with
          nearest-neighbour interpolation.
        - ``mask`` is a float array of shape ``(192, 192, 12)`` holding a
          one-hot class vector per pixel. It is all zeros where the pixel
          colour is not in the table below.

        Returns ``(None, None)`` if anything goes wrong.
    """
    pic_h = 192
    pic_w = 192

    # Colour triple -> class index (its list position). The triples are compared
    # with the label image exactly as cv2.imread returns it (BGR channel order).
    # Must stay in sync with the colour table in predict_to_label.
    colors = [
        (0, 0, 0),
        (8, 35, 142),
        (43, 173, 180),
        (153, 102, 153),
        (234, 168, 160),
        (192, 0, 0),
        (8, 32, 128),
        (12, 51, 204),
        (70, 25, 100),
        (14, 57, 230),
        (75, 47, 190),
        (255, 255, 255),
    ]
    # Row i is the one-hot vector for class i.
    one_hot = np.eye(len(colors))

    try:
        xfilepath, yfilepath = pathes

        raw_img = cv2.imread(xfilepath)
        raw_lbl = cv2.imread(yfilepath)
        # cv2.imread signals failure by returning None rather than raising;
        # fail here with a readable message instead of inside cv2.resize.
        if raw_img is None or raw_lbl is None:
            unreadable = xfilepath if raw_img is None else yfilepath
            raise IOError(f'cv2.imread could not read {unreadable!r}')

        img = cv2.resize(raw_img, (pic_w, pic_h), interpolation=cv2.INTER_NEAREST)
        # Nearest-neighbour keeps label colours exact (no blending between classes).
        lbl = cv2.resize(raw_lbl, (pic_w, pic_h), interpolation=cv2.INTER_NEAREST)

        mask = np.zeros((pic_h, pic_w, len(colors)))
        for class_idx, color in enumerate(colors):
            mask[(lbl == color).all(2)] = one_hot[class_idx]

        return img, mask
    except Exception as e:
        # Best-effort loader: report which pair failed, then let the caller skip it.
        print(f'!!! Exception while loading {pathes!r}:', e)
        return None, None

In [None]:
x_paths, y_paths = get_filenames()
# Images and masks are paired purely by position in the two sorted lists, so a
# missing or extra file would silently misalign every later pair.
assert len(x_paths) == len(y_paths), (
    f'{len(x_paths)} images but {len(y_paths)} masks found')

In [None]:
pred_index = 210  # which (image, mask) pair from x_paths / y_paths to inspect

img, mask = load_one([x_paths[pred_index], y_paths[pred_index]])
# load_one prints errors and returns (None, None) instead of raising; stop here
# rather than failing later inside segnet.predict with an unrelated error.
assert img is not None and mask is not None, f'could not load pair #{pred_index}'

print('Making predictions...')
predictions = segnet.predict(np.expand_dims(img, axis=0))  # add a batch axis of size 1

print('Converting pred to human mask...')
human_pred = predict_to_label(predictions)[0]
print('Converting mask to human mask...')
human_mask = predict_to_label([mask])[0]

print('Saving result...')
# Ground truth on top, prediction below.
# NOTE(review): the colour triples are matched against cv2.imread output (BGR)
# in load_one, while skimage's imsave treats channel 0 as red, so colours in
# img.png may appear R/B-swapped relative to the source mask files. Confirm.
imsave('img.png', np.concatenate((human_mask, human_pred), axis=0))
print('Goto-vo')

In [None]:
# Pixel accuracy of the prediction above: the fraction of pixels whose arg-max
# class matches the ground-truth mask. The reshape mirrors predict_to_label, so
# this works whether the model emits an (H, W, C) or a flat (H*W, C) score map.
pred = np.argmax(predictions[0].reshape(PIC_H, PIC_W, -1), axis=-1)
pred_true = np.argmax(mask, axis=-1)
(pred == pred_true).mean()