In [ ]:
import cv2
import os
from tensorflow.python.keras.callbacks import ModelCheckpoint
from tensorflow.python.keras.callbacks import CSVLogger
from tensorflow.python.keras.callbacks import EarlyStopping
from tensorflow.python.keras.optimizers import Adam
from models import *
from data_generator import *

# --- Dataset locations -------------------------------------------------------
# Frames and their masks live in parallel directories, one pair per split.
train_frames_dir = './data/train_frames/'
train_masks_dir = './data/train_masks/'
val_frames_dir = './data/val_frames/'
val_masks_dir = './data/val_masks/'

# --- Model / training settings -----------------------------------------------
# Passed to both the model builder and the data generators.
# Presumably (height, width, channels) -- TODO confirm against `models`.
INPUT_SHAPE = (256, 256, 3)
NO_CLASSES = 20
NO_OF_EPOCHS = 10
BATCH_SIZE = 4

# Dataset sizes are taken as the number of directory entries, so every entry
# in the frame folders is counted (whatever its file type).
NO_OF_TRAINING_FRAMES = len(os.listdir(train_frames_dir))
NO_OF_VAL_FRAMES = len(os.listdir(val_frames_dir))

# Class index -> 3-channel color, used to paint predicted label maps.
# The position of each triple in the list below is its class index (0..19).
# NOTE(review): the channel order of these triples is not stated anywhere in
# this file -- confirm it (OpenCV reads and writes images as BGR).
color_mapping = dict(enumerate([
    (128, 0, 128),
    (244, 35, 232),
    (70, 70, 70),
    (102, 102, 156),
    (190, 153, 153),
    (153, 153, 153),
    (250, 170, 30),
    (220, 220, 0),
    (107, 142, 35),
    (152, 251, 152),
    (70, 130, 180),
    (220, 20, 60),
    (255, 0, 0),
    (0, 0, 142),
    (0, 0, 70),
    (0, 60, 100),
    (0, 80, 100),
    (0, 0, 230),
    (119, 11, 32),
    (180, 165, 180),
]))

In [ ]:
# Build the segmentation network. `resnet50_encoder_unet_decoder` is not
# defined in this file; it presumably comes from the `models` star-import.
model = resnet50_encoder_unet_decoder(
    input_size=INPUT_SHAPE,
    num_classes=NO_CLASSES,
)

In [ ]:
# Settings shared by the training and validation generators; only the
# frame/mask folders differ between the two splits.
_generator_settings = {
    'batch_size': BATCH_SIZE,
    'num_classes': NO_CLASSES,
    'input_shape': INPUT_SHAPE,
}

# `data_gen` is not defined in this file; it presumably comes from the
# `data_generator` star-import.
train_gen = data_gen(img_folder=train_frames_dir,
                     mask_folder=train_masks_dir,
                     **_generator_settings)

val_gen = data_gen(img_folder=val_frames_dir,
                   mask_folder=val_masks_dir,
                   **_generator_settings)

In [ ]:
# Only whole batches are counted: floor division drops any remainder frames
# that do not fill a complete batch.
train_steps = NO_OF_TRAINING_FRAMES // BATCH_SIZE
val_steps = NO_OF_VAL_FRAMES // BATCH_SIZE

# Train from the generators, validating after each epoch. The return value
# is kept in `results` for later inspection.
results = model.fit_generator(
    train_gen,
    steps_per_epoch=train_steps,
    epochs=NO_OF_EPOCHS,
    validation_data=val_gen,
    validation_steps=val_steps,
)

In [ ]:
# Persist the trained model. The target directory is created first so that a
# missing 'trainedModels/' folder cannot make the save fail after the whole
# training run has already completed.
os.makedirs('trainedModels', exist_ok=True)
model.save('trainedModels/main_model.h5')

In [ ]:
# Pick one validation frame to run a sample prediction on.
# os.listdir returns entries in arbitrary order, so the listing is sorted to
# make the selected frame reproducible between runs and machines.
val_frames = sorted(os.listdir(val_frames_dir))
sample_frame_path = os.path.join(val_frames_dir, val_frames[1])

# cv2.imread signals failure by returning None rather than raising; check it
# here so the error names the offending file instead of surfacing as an
# opaque TypeError on the division below.
sample_frame = cv2.imread(sample_frame_path)
if sample_frame is None:
    raise IOError('Could not read image: {}'.format(sample_frame_path))

# Scale pixel values from [0, 255] to [0, 1].
# The name `train_img` is kept because the following cells read it, even
# though the image comes from the validation set.
train_img = sample_frame / 255

In [ ]:
# The model is fed a batch, so prepend a batch axis of length 1 to the
# single image before predicting.
input_batch = train_img[np.newaxis, ...]
prediction = model.predict(input_batch)

In [ ]:
# Per-pixel class index: for the first (only) image in the batch, take the
# position of the highest score along the last axis.
labels = np.argmax(prediction[0], axis=-1)

# Paint every pixel with the color of its class.
# The canvas is zero-initialised (rather than np.empty) so that a class index
# missing from color_mapping shows up as black instead of uninitialised
# memory, and it is uint8 because it holds 8-bit color values.
labels_colored = np.zeros((labels.shape[0], labels.shape[1], 3), dtype=np.uint8)
for lab, color in color_mapping.items():
    labels_colored[labels == lab] = color

# cv2.imwrite interprets 3-channel arrays as BGR. The triples in
# color_mapping appear to be RGB (they match a Cityscapes-style palette), so
# the image is converted before writing to keep the saved colors correct.
# NOTE(review): if color_mapping is actually BGR, drop the cvtColor call.
cv2.imwrite('segmentation_example.png',
            cv2.cvtColor(labels_colored, cv2.COLOR_RGB2BGR))