In [None]:
# imports
import numpy as np
from skimage import io
import skimage.transform
import os
from tqdm import tqdm

import matplotlib.pyplot as plt

from data_utils import *

from keras_fcn import FCN

%load_ext autoreload
%autoreload 2

from keras import optimizers

In [None]:
import os

from tensorflow.python.client import device_lib  # device_lib.list_local_devices() lists visible devices

# Restrict the kernel to GPU index 1. CUDA_VISIBLE_DEVICES only takes effect if
# it is set before TensorFlow initialises its GPU devices.
# NOTE(review): without os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID", CUDA
# orders devices fastest-first, so index 1 may not match nvidia-smi numbering.
os.environ.update(CUDA_VISIBLE_DEVICES="1")

In [None]:
# Dataset root. Set the DATASET_DIR environment variable to point elsewhere
# instead of editing this machine-specific absolute path.
DATASET_DIR = os.environ.get(
    'DATASET_DIR',
    '/mnt/82db778e-0496-450c-9b25-d1e50a90e476/data/data4stas/02_data_segm/')

# Each mask file is named after its image with a '_msk.png' suffix.
TRAIN_IMG = 'poligon_minsk_1_yandex_z18_train.jpg'
TRAIN_MASK = TRAIN_IMG + '_msk.png'

TEST_IMG = 'poligon_minsk_1_yandex_z18_val.jpg'
TEST_MASK = TEST_IMG + '_msk.png'

In [None]:
# Load the training/test tiles and their masks. os.path.join works whether or
# not DATASET_DIR ends with a path separator (plain `+` silently glues names).
train_img = io.imread(os.path.join(DATASET_DIR, TRAIN_IMG))
train_mask = io.imread(os.path.join(DATASET_DIR, TRAIN_MASK))

test_img = io.imread(os.path.join(DATASET_DIR, TEST_IMG))
test_mask = io.imread(os.path.join(DATASET_DIR, TEST_MASK))

In [None]:
print(train_img.shape)
print(train_mask.shape)

print(test_img.shape)
print(test_mask.shape)

# Images and masks are cut into patches independently, so patch i of an image
# only lines up with patch i of its mask if both share height and width.
assert train_img.shape[:2] == train_mask.shape[:2], 'train image/mask size mismatch'
assert test_img.shape[:2] == test_mask.shape[:2], 'test image/mask size mismatch'

In [None]:
def sliding_window(image, stride=10, window_size=(20,20)):
    """Extract patches according to a sliding window.

    Windows start at every multiple of `stride` along both axes. Windows that
    would run past the bottom or right edge are skipped, so a border narrower
    than one window may be left uncovered.

    Args:
        image (numpy array): The image to be processed (at least 2-D; any
            trailing channel axes are kept in every patch).
        stride (int, optional): The sliding window stride (defaults to 10px).
        window_size ((int, int), optional): Patch height and width (defaults
            to (20,20)). Any 2-element sequence, e.g. a list, is accepted.

    Returns:
        list: patches of exactly window_size height/width, in row-major
        order (top-to-bottom, then left-to-right). Patches are views into
        `image`, not copies.

    Raises:
        ValueError: if either window dimension is not positive.
    """
    win_h, win_w = window_size
    if win_h <= 0 or win_w <= 0:
        raise ValueError('window_size must be positive, got %r' % (window_size,))
    # Iterate only over start positions where a full window fits. This yields
    # the same patches as slicing at every stride and discarding short ones.
    # Unpacking also means a list window_size works; the old
    # `shape[:2] == window_size` check was always False for a list.
    return [image[x:x + win_h, y:y + win_w]
            for x in range(0, image.shape[0] - win_h + 1, stride)
            for y in range(0, image.shape[1] - win_w + 1, stride)]

def transform(patch, flip=False, mirror=False, rotations=()):
    """Perform data augmentation on a patch.

    Args:
        patch (numpy array): The patch to be processed.
        flip (bool, optional): Add an up/down (vertical) flip.
        mirror (bool, optional): Add a left/right (horizontal) mirror.
        rotations (iterable of numbers, optional): Rotation angles in degrees,
            counter-clockwise as in skimage.transform.rotate. Each angle adds
            one augmented copy.

    Returns:
        list: [patch, *rotated copies (in `rotations` order), flipped (if
        flip), mirrored (if mirror)].

    Notes:
        A rotation by a multiple of 90 degrees on a square patch is done with
        np.rot90. That only permutes pixels, so label masks keep their exact
        values and dtype, and float patches with values outside [0, 1] work.
        Any other angle, or a non-square patch (which must keep its shape),
        falls back to skimage.transform.rotate. That path interpolates and
        converts the result to uint8.
    """
    transformed_patches = [patch]
    is_square = patch.ndim >= 2 and patch.shape[0] == patch.shape[1]
    for angle in rotations:
        if is_square and angle % 90 == 0:
            # Exact right-angle rotation; k is normalised so negative angles work.
            transformed_patches.append(np.rot90(patch, k=int(angle // 90) % 4))
        else:
            transformed_patches.append(
                skimage.img_as_ubyte(skimage.transform.rotate(patch, angle)))
    if flip:
        transformed_patches.append(np.flipud(patch))
    if mirror:
        transformed_patches.append(np.fliplr(patch))
    return transformed_patches


def augmented_sliding_window(patches, flip=False, mirror=False, rotations=[]):
    """Apply `transform` to every patch and flatten the results into one list.

    Args:
        patches (iterable of numpy arrays): Patches, e.g. from `sliding_window`.
        flip, mirror, rotations: Passed through unchanged to `transform`.

    Returns:
        list: For each input patch, in input order, the list `transform`
        returns for it, all concatenated.
    """
    return [augmented
            for patch in patches
            for augmented in transform(patch, flip, mirror, rotations)]

In [None]:
# Square patch side in pixels. Other values previously used here:
# patch_size = 224; rotations = [90] or every 45 degrees.
patch_size = 128
# Neighbouring windows overlap by a quarter of a patch (128 px -> stride 96).
stride = (3 * patch_size) // 4

# Training-time augmentation: vertical flip, horizontal mirror and the three
# non-identity right-angle rotations.
flip, mirror = True, True
rotations = list(range(90, 360, 90))

In [None]:
# Images and masks go through identical calls, so index i of an image-patch
# list matches index i of the corresponding mask-patch list.
# Only the training patches are augmented.
train_patches_img, train_patches_mask = (
    augmented_sliding_window(
        sliding_window(source, stride=stride, window_size=(patch_size, patch_size)),
        flip, mirror, rotations)
    for source in (train_img, train_mask))

test_patches_img, test_patches_mask = (
    sliding_window(source, stride=stride, window_size=(patch_size, patch_size))
    for source in (test_img, test_mask))

In [None]:
print(len(train_patches_img))
print(len(train_patches_mask))

print(len(test_patches_img))
print(len(test_patches_mask))

# Image/mask patches are paired by position, so the counts must agree.
assert len(train_patches_img) == len(train_patches_mask), 'train image/mask patch counts differ'
assert len(test_patches_img) == len(test_patches_mask), 'test image/mask patch counts differ'

In [None]:
def show(image, title=None):
    """Display one image or patch in its own figure.

    Args:
        image (numpy array): Anything matplotlib's `imshow` accepts.
        title (str, optional): Axes title; no title when None.
    """
    # A fresh figure per call, so nothing is drawn into a figure left open
    # by an earlier cell.
    _, ax = plt.subplots()
    ax.imshow(image)
    ax.set_axis_off()  # pixel tick labels add nothing for patches
    if title is not None:
        ax.set_title(title)
    plt.show()

In [None]:
# Eyeball the first few (augmented) training patches. A plain loop avoids
# echoing a list of Nones as cell output, and slicing tolerates fewer than 11
# patches where indexing raised IndexError.
for patch in train_patches_img[:11]:
    show(patch)

In [None]:
# `preprocess` comes from the data_utils star import. Judging by its outputs it
# splits the training patches into train/val and builds model-ready arrays.
# TODO confirm in data_utils.
(X_train, y_train,
 X_val, y_val,
 X_test, y_test) = preprocess(train_patches_img, train_patches_mask,
                              test_patches_img, test_patches_mask)

In [None]:
print(X_train.shape)
print(y_train.shape)

print(X_val.shape)
print(y_val.shape)

print(X_test.shape)
print(y_test.shape)

# Every split must pair each input patch with exactly one target.
assert len(X_train) == len(y_train), 'train inputs/targets count mismatch'
assert len(X_val) == len(y_val), 'val inputs/targets count mismatch'
assert len(X_test) == len(y_test), 'test inputs/targets count mismatch'

In [None]:
# NOTE(review): X_train_all / y_train_all (train + validation) are not read by
# any later cell in this notebook; the model is fit on X_train only.
# Presumably intended for a final fit on all labelled data. Otherwise this cell
# only costs memory.
X_train_all, y_train_all = (np.concatenate([X_train, X_val]),
                            np.concatenate([y_train, y_val]))

In [None]:
# FCN segmentation model trained from scratch. weights=None (the Python value)
# is the conventional "no pretrained weights" flag. The previous string 'None'
# is not None and may be treated as a weights name or file path by the encoder.
# NOTE(review): confirm against keras_fcn's weight-loading logic.
fcn_vgg16 = FCN(input_shape=(patch_size, patch_size, 3), classes=3,
                weights=None, trainable_encoder=True)

# NOTE(review): `lr`/`decay` are the legacy Keras SGD argument names; newer
# Keras expects `learning_rate` and learning-rate schedules instead.
sgd = optimizers.SGD(lr=0.001, decay=1e-6, momentum=0.9, nesterov=True)

fcn_vgg16.compile(optimizer=sgd,
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

history = fcn_vgg16.fit(X_train, y_train, batch_size=32, epochs=5,
                        validation_data=(X_val, y_val))

In [None]:
def plot_history(history):
    """Plot training (and, if recorded, validation) loss and accuracy per epoch.

    Args:
        history: The object returned by Keras `Model.fit`. Only its `.history`
            dict is read.

    Keras before 2.3 logs accuracy as 'acc'/'val_acc' and later versions log it
    as 'accuracy'/'val_accuracy'; both spellings are accepted. If no accuracy
    was logged, only the loss figure is drawn.
    """
    logs = history.history
    loss = logs['loss']
    val_loss = logs.get('val_loss')
    acc = logs.get('acc', logs.get('accuracy'))
    val_acc = logs.get('val_acc', logs.get('val_accuracy'))

    epochs = range(1, len(loss) + 1)

    _plot_curves(epochs, loss, val_loss, 'Loss')
    if acc is not None:
        _plot_curves(epochs, acc, val_acc, 'Accuracy')


def _plot_curves(epochs, train_values, val_values, ylabel):
    """Draw one metric's train (and optional validation) curve in a labelled figure."""
    plt.plot(epochs, train_values, label='train')
    if val_values is not None:
        plt.plot(epochs, val_values, label='validation')
    plt.xlabel('Epochs')
    plt.ylabel(ylabel)
    plt.legend()
    plt.show()


plot_history(history)

In [None]:
# Print the model's layer-by-layer architecture and parameter counts.
fcn_vgg16.summary()

In [None]:
# Per-pixel class scores for every test patch. Presumably shaped
# (N, patch_size, patch_size, 3) given classes=3; TODO confirm against keras_fcn.
y_pred = fcn_vgg16.predict(X_test)

In [None]:
# Hard class label per pixel: argmax over the last (class) axis of the model
# output. axis=-1 equals axis=3 for 4-D (N, H, W, classes) output and keeps
# working if the output rank differs.
y_pred_cls = np.argmax(y_pred, axis=-1)

In [None]:
# Shape of the predicted label maps. Presumably (N, patch_size, patch_size);
# TODO confirm.
np.shape(y_pred_cls)

In [None]:
# Ground-truth class label per pixel. y_test is presumably one-hot with classes
# on the last axis (the model is trained with categorical_crossentropy).
# axis=-1 matches the prediction cell and equals axis=3 for 4-D arrays.
y_test_cls = np.argmax(y_test, axis=-1)

In [None]:
# Overall pixel accuracy on the test patches. Check shapes first: `==` on
# mismatched arrays broadcasts (or collapses to a scalar) instead of failing,
# which would silently produce a meaningless number.
assert y_test_cls.shape == y_pred_cls.shape, (y_test_cls.shape, y_pred_cls.shape)
pixel_accuracy = np.mean(y_test_cls == y_pred_cls)
pixel_accuracy