In [None]:
# Interactive tab-completion preference only; no code below depends on it.
# NOTE(review): per IPython docs, greedy completion evaluates expressions on TAB.
%config IPCompleter.greedy=True

In [None]:
import cv2
import numpy as np 

import pandas as pd
import tensorflow as tf

# from efficientnet.keras import *
import keras
from keras.preprocessing.image import ImageDataGenerator
from keras import backend as K
from keras import Input
from keras.models import Model
from keras.utils import *
from keras.layers import *

from tensorflow import set_random_seed
from tqdm import tnrange, tqdm_notebook
import matplotlib.pyplot as plt

# Seed TensorFlow and NumPy's global RNG with one shared value.
# NOTE(review): `tensorflow.set_random_seed` only exists in TF 1.x (TF 2.x: tf.random.set_seed).
_SEED = 2020
set_random_seed(_SEED)
np.random.seed(_SEED)

import os
import gc

In [None]:
# Image / training settings
IMG_SIZE = 64  # side of the square model input, in pixels
BATCH_SIZE = 64  # mini-batch size handed to model.fit

# Paths
DATA_PATH = '../data'  # holds train.csv and the train_image* parquet files
CHECKPOINT = DATA_PATH + '/model_weights/mobilenetv2.h5'  # ModelCheckpoint target file

In [None]:
def crop_image_from_gray(img, tol=10):
    """Crop a 2-D grayscale image to the bounding box of pixels brighter than `tol`.

    Args:
        img: 2-D array of pixel intensities.
        tol: pixels with value > tol count as foreground.

    Returns:
        The sub-array spanning every row and column that contains at least one
        foreground pixel. If no pixel exceeds `tol`, `img` is returned unchanged
        rather than an empty array, which downstream resizing cannot handle.
    """
    mask = img > tol
    rows = mask.any(axis=1)
    cols = mask.any(axis=0)
    if not rows.any():
        # Blank image: nothing to crop to, keep it as-is.
        return img
    return img[np.ix_(rows, cols)]
def resize(img):
    """Cast `img` to uint8 and rescale it to an IMG_SIZE x IMG_SIZE image via OpenCV."""
    as_uint8 = img.astype(np.uint8)
    # cv2 expects (width, height); the target is square so the order is moot.
    target_size = (IMG_SIZE, IMG_SIZE)
    return cv2.resize(as_uint8, target_size)
def preprocessing(train_image, shape=(137, 236)):
    """Turn one flattened grapheme image into a cropped IMG_SIZE x IMG_SIZE uint8 image.

    Args:
        train_image: 1-D array (or sequence) of pixel values holding exactly
            shape[0] * shape[1] elements.
        shape: (height, width) of the raw image; the default matches the size
            hard-coded by the original pipeline.

    Returns:
        The inverted (255 - pixel) image, cropped by `crop_image_from_gray`
        and scaled by `resize`.

    Raises:
        ValueError: if `train_image` does not hold shape[0] * shape[1] values.
    """
    inverted = 255 - np.asarray(train_image)
    # reshape rather than ndarray.resize: a wrong pixel count now raises instead
    # of being silently zero-padded or truncated in place.
    image_2d = inverted.reshape(shape)
    cropped = crop_image_from_gray(image_2d)
    return resize(cropped)

In [None]:
class Dataset:
    """In-memory training set: preprocessed images plus their three label lists.

    Attributes (filled by `get_data`):
        x:  list of preprocessed images (outputs of `preprocessing`).
        y1: grapheme_root labels as ints, aligned with `x`.
        y2: vowel_diacritic labels as ints, aligned with `x`.
        y3: consonant_diacritic labels as ints, aligned with `x`.
    """

    def __init__(self):
        self.x = []
        self.y1 = []
        self.y2 = []
        self.y3 = []

    def get_data(self, max_files=1, skip_last_rows=49000):
        """Load images from the train_image* parquet files and labels from train.csv.

        Args:
            max_files: number of parquet files (sorted by filename) to read;
                None reads all of them. Default 1 keeps the original debug subset.
            skip_last_rows: number of trailing rows of each parquet file to ignore.
                Default 49000 keeps the original debug subset; pass 0 for all rows.

        Appends to self.x / self.y1 / self.y2 / self.y3 (existing entries are kept).

        Raises:
            KeyError: if an image_id from a parquet file is missing from train.csv.
        """
        train_csv = pd.read_csv(os.path.join(DATA_PATH, 'train.csv'))
        # Index labels by image_id once instead of scanning the whole frame per image.
        labels = train_csv.set_index('image_id')[
            ['grapheme_root', 'vowel_diacritic', 'consonant_diacritic']]

        # Sort so the chosen subset is deterministic (os.listdir order is arbitrary).
        train_parquet = sorted(f for f in os.listdir(DATA_PATH) if 'train_image' in f)
        if max_files is not None:
            train_parquet = train_parquet[:max_files]

        for file in tqdm_notebook(train_parquet, desc='train_parquet'):
            parquet = pd.read_parquet(os.path.join(DATA_PATH, file))
            n_rows = max(len(parquet) - skip_last_rows, 0)
            # First column is the image id; the remaining columns are the flat pixels.
            image_ids = parquet.iloc[:n_rows, 0].to_numpy()
            pixels = parquet.iloc[:n_rows, 1:].to_numpy()
            # Release the full frame before the per-image loop.
            del parquet
            gc.collect()

            for i in tnrange(n_rows, desc='parquet'):
                # Row i (not a fixed row) so each sample is a different image.
                self.x.append(preprocessing(pixels[i]))
                grapheme_root, vowel_diacritic, consonant_diacritic = labels.loc[image_ids[i]]
                self.y1.append(int(grapheme_root))
                self.y2.append(int(vowel_diacritic))
                self.y3.append(int(consonant_diacritic))

In [None]:
class ClassificationModel:
    """Randomly initialised MobileNetV2 backbone with three softmax heads:
    grapheme root (168 classes), vowel diacritic (11) and consonant diacritic (7)."""

    def __init__(self):
        self.model = self.build_model()

    def build_model(self):
        """Build the (uncompiled) Keras model for IMG_SIZE x IMG_SIZE x 1 inputs."""
        base_model = keras.applications.MobileNetV2(weights=None, include_top=False, input_shape=(IMG_SIZE, IMG_SIZE, 1))
        x = GlobalAveragePooling2D()(base_model.output)
        o1 = Dense(168, activation='softmax', kernel_initializer='he_normal', name='grapheme')(x)
        o2 = Dense(11, activation='softmax', kernel_initializer='he_normal', name='vowel')(x)
        o3 = Dense(7, activation='softmax', kernel_initializer='he_normal', name='consonant')(x)
        return Model(inputs=[base_model.input], outputs=[o1, o2, o3])

    def load_weight(self, model_weight_path):
        """Load weights saved from a model with the same architecture."""
        self.model.load_weights(model_weight_path)

    def fit_dataset(self, dataset, epochs=20, callbacks=None):
        """Compile the model with Adam + sparse categorical cross-entropy and train it.

        Args:
            dataset: object exposing `x` (list of IMG_SIZE x IMG_SIZE images) and
                label lists `y1`, `y2`, `y3` aligned with `x` (e.g. `Dataset`).
            epochs: number of training epochs.
            callbacks: optional list of Keras callbacks forwarded to `fit`.

        Raises:
            ValueError: if the dataset is empty or its label lists are not
                the same length as `x`.
        """
        n_samples = len(dataset.x)
        if n_samples == 0:
            raise ValueError('dataset is empty; call get_data() first')
        if not (len(dataset.y1) == len(dataset.y2) == len(dataset.y3) == n_samples):
            raise ValueError('image and label counts differ')

        # Infer the sample count (previously hard-coded to 1210) and add the
        # trailing channel axis expected by the single-channel input layer.
        x = np.asarray(dataset.x, dtype='float32').reshape((-1, IMG_SIZE, IMG_SIZE, 1))
        y1, y2, y3 = (np.asarray(labels, dtype='float32')
                      for labels in (dataset.y1, dataset.y2, dataset.y3))

        self.model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['acc'])
        self.model.fit(x, [y1, y2, y3],
                       batch_size=BATCH_SIZE,
                       epochs=epochs,
                       callbacks=callbacks)

In [None]:
# Save the whole model (not just weights) to CHECKPOINT after every epoch.
# save_best_only=False means `monitor` does not gate saving; the fixed filepath
# is overwritten each epoch.
# NOTE(review): `period` is the Keras 2.x spelling; newer versions use save_freq.
callbacks = [
    keras.callbacks.ModelCheckpoint(
        filepath=CHECKPOINT,
        monitor='loss',
        verbose=0,
        save_best_only=False,
        save_weights_only=False,
        mode='auto',
        period=1,
    ),
]

In [None]:
# Load preprocessed images and their three label lists into memory.
dataset = Dataset()
dataset.get_data()

In [None]:
# Sanity check before training: every image must have all three labels.
assert len(dataset.x) == len(dataset.y1) == len(dataset.y2) == len(dataset.y3), 'images and labels are misaligned'
len(dataset.x)

In [None]:
# Build the three-headed MobileNetV2 classifier and train it, checkpointing each epoch.
model = ClassificationModel()
model.fit_dataset(dataset, epochs=20, callbacks=callbacks)