In [None]:
import argparse
import glob
import os

import cv2
import numpy as np
from sklearn.model_selection import train_test_split
import tensorflow as tf
import tensorflow.contrib.eager as tfe

from generator import multi_plot, preprocess_labels, generator
from model import DilatedCNN
from replay_memory import PrioritisedReplayMemory

# TF 1.x (tf.contrib) only: switch to eager mode before any other TensorFlow op runs.
# Later cells depend on eager tensors, e.g. calling `.numpy()` on the training loss.
tfe.enable_eager_execution()

In [None]:
images = []
masks = []


def _read_rgb(path):
    """Read ``path`` with OpenCV and return it as an RGB array.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
    """
    img = cv2.imread(path)
    # cv2.imread signals failure (missing/corrupt file) by returning None, not raising.
    if img is None:
        raise FileNotFoundError("could not read image: {}".format(path))
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


def load_files(rgb_dir="./Train/CameraRGB", seg_dir="./Train/CameraSeg"):
    """Load every camera frame and its segmentation mask into ``images`` / ``masks``.

    Each frame is paired with the mask that has the same file name, so that
    ``images[i]`` and ``masks[i]`` always describe the same scene. Globbing the
    two directories independently is not safe: ``glob.glob`` returns files in
    arbitrary, filesystem-dependent order.

    Args:
        rgb_dir: directory containing the RGB frames (``*.png``).
        seg_dir: directory containing the masks, named like the frames.

    Raises:
        FileNotFoundError: if a frame has no readable mask, or a frame is unreadable.
    """
    for rgb_path in sorted(glob.glob(os.path.join(rgb_dir, "*.png"))):
        seg_path = os.path.join(seg_dir, os.path.basename(rgb_path))
        # Read both files before appending, so the two lists never differ in length.
        img = _read_rgb(rgb_path)
        mask = preprocess_labels(_read_rgb(seg_path))
        # Per-image min-max scaling of the frame to [0, 1].
        img = cv2.normalize(img.astype('float'), None, 0.0, 1.0, cv2.NORM_MINMAX)
        images.append(img)
        masks.append(mask)

load_files()

In [None]:
# Fixed seed so the train/test partition is identical across kernel restarts.
SPLIT_SEED = 42
x_train, x_test, y_train, y_test = train_test_split(
    images, masks, test_size=0.2, random_state=SPLIT_SEED)
# Stack each list of per-sample arrays into a single batch array.
x_train, x_test, y_train, y_test = map(np.stack, (x_train, x_test, y_train, y_test))

In [None]:
# Cache the split arrays on disk so later sessions can skip decoding the PNGs.
# np.save appends ".npy" to each name, matching the paths the next cell loads.
for name, array in (("x_train", x_train), ("x_test", x_test),
                    ("y_train", y_train), ("y_test", y_test)):
    np.save(name, array)

In [None]:
# Restore the cached split written by the previous cell.
x_train, x_test, y_train, y_test = (
    np.load(path)
    for path in ("x_train.npy", "x_test.npy", "y_train.npy", "y_test.npy")
)

In [None]:
# Set to True to call DilatedCNN.load() after construction
# (presumably restores previously saved weights — confirm in model.py).
RESUME_FROM_CHECKPOINT = False

model = DilatedCNN()
if RESUME_FROM_CHECKPOINT:
    model.load()

In [None]:
# Adam step size. NOTE(review): TF's default for Adam is 0.001; 0.01 may need tuning.
LEARNING_RATE = 0.01
optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)

In [None]:
batch_size = 16
# Replay memory sized to hold ten batches' worth of samples.
# (PrioritisedReplayMemory is imported in the top import cell.)
MEMORY_CAPACITY = batch_size * 10
memory = PrioritisedReplayMemory(capacity=MEMORY_CAPACITY)

In [None]:
# Batch iterator over the training split; each next() yields (images, masks, indices),
# as unpacked by the cells below. The replay memory is passed in — presumably it
# drives which samples are drawn (confirm in generator.py).
gen = generator(
    x_train,
    y_train,
    memory,
    batch_size=batch_size,
)

In [None]:
# Sanity-check one batch: print shapes, then show image + both mask channels
# for the first two samples.
x, y, indices = next(gen)
print(x.shape, y.shape)
for index in range(2):
    mask = y[index]
    panels = [x[index], mask[:, :, 0], mask[:, :, 1]]
    multi_plot(panels)

In [None]:
# Train for a fixed number of steps, feeding each step's loss back into the
# replay memory for the sampled indices.
for step in range(200):
    images, masks, indices = next(gen)
    x = tf.constant(images, dtype=tf.float32)
    y = tf.constant(masks, dtype=tf.float32)
    loss = tf.squeeze(model.train(x, y, optimizer))
    # presumably one loss value per sample, used as replay priorities — confirm in model.py
    loss_values = loss.numpy()
    memory.update(indices, loss_values)

    if step % 10 == 0:
        print("run {} loss: {}".format(step, loss_values.mean()))

In [None]:
%time y_hat = model(x)
index = 4
multi_plot([images[index], y_hat[index][:, :, 0], y_hat[index][:, :, 1]])
# multi_plot([images[index], masks[index][:, :, 0], masks[index][:, :, 1]])