# Setup

In [None]:
# Mount Google Drive so the kanji .npy files and the project folder under
# /content/drive become readable from this runtime.
from google.colab import drive
drive.mount(mountpoint="/content/drive")

In [None]:
# Fetch the KMNIST helper repository; it contains download_data.py, which is run below.
!git clone https://github.com/rois-codh/kmnist.git

In [None]:
# Work inside the cloned repo so download_data.py saves its files there.
# load_hiragana_dataset() reads them from /content/kmnist, which presumably assumes
# this relative %cd runs from /content (Colab's usual working directory) — confirm.
%cd kmnist

In [None]:
# Download Kuzushiji-49 (K-49). load_hiragana_dataset() expects
# /content/kmnist/k49-{train,test}-{imgs,labels}.npz.
# NOTE(review): download_data.py may ask which dataset/format to fetch — confirm
# that the K-49 .npz files are what ends up on disk.
!python download_data.py

In [5]:
import numpy as np
import tensorflow as tf
from google.colab.patches import cv2_imshow
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import SGD
from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from imutils import build_montages
from PIL import Image
import matplotlib.pyplot as plt
import argparse
import cv2
import matplotlib
import os
import tarfile
matplotlib.use("Agg")

In [None]:
# Switch to the project's Drive folder (requires the Drive mount at the top).
# NOTE(review): the loaders below use absolute paths, and the save step changes back
# to /content first, so no visible cell depends on this working directory.
%cd /content/drive/MyDrive/NRP/Project/OCRData

# Prepare Dataset

In [7]:
def load_hiragana_dataset(data_dir="/content/kmnist", size=(32, 32)):
    """Load the complete Kuzushiji-49 (K-49) dataset as one image/label pair.

    The official train and test splits are concatenated into a single array;
    the notebook performs its own stratified split afterwards.

    Args:
        data_dir: Directory holding ``k49-train-imgs.npz``, ``k49-train-labels.npz``,
            ``k49-test-imgs.npz`` and ``k49-test-labels.npz``, each storing its
            array under the default ``arr_0`` key.
        size: ``(width, height)`` target passed to ``cv2.resize`` for every image.

    Returns:
        Tuple ``(data, labels)``:

        - ``data`` is a float32 array of shape ``(N, height, width, 1)``, divided
          by 255 (i.e. scaled to [0, 1] assuming 8-bit source pixels).
        - ``labels`` is the concatenation of the stored train and test label
          arrays, in the same order as ``data``.
    """
    def _load_npz(filename):
        # An NpzFile keeps its file handle open until closed. The ``with`` block
        # releases it as soon as the (fully materialised) array has been read.
        with np.load(os.path.join(data_dir, filename)) as archive:
            return archive["arr_0"]

    train_data = _load_npz("k49-train-imgs.npz")
    train_labels = _load_npz("k49-train-labels.npz")
    test_data = _load_npz("k49-test-imgs.npz")
    test_labels = _load_npz("k49-test-labels.npz")

    data = np.vstack([train_data, test_data])
    data = np.array([cv2.resize(image, size) for image in data], dtype="float32")
    # Add a trailing channel axis: the model is built for single-channel input.
    data = np.expand_dims(data, axis=-1)
    data /= 255.0

    labels = np.hstack([train_labels, test_labels])

    return data, labels

In [8]:
def load_kanji_dataset(data_dir="/content/drive/MyDrive/NRP/Project/OCRData",
                       size=(32, 32), label_offset=49):
    """Load the Kuzushiji kanji subset stored on Google Drive as one image/label pair.

    Processing steps:

    - Train and test splits are concatenated.
    - Images are resized, given a channel axis and divided by 255.
    - Labels are shifted by ``label_offset`` so they do not collide with the
      hiragana class ids once both datasets are merged.

    Args:
        data_dir: Directory holding ``kuzushiji50_{train,test}_{imgs,labels}.npy``.
        size: ``(width, height)`` target passed to ``cv2.resize`` for every image.
        label_offset: Added to every stored label. The default 49 places the kanji
            classes right after the 49 K-49 hiragana classes (assumed to be ids 0-48).

    Returns:
        Tuple ``(data, labels)``:

        - ``data`` is a float32 array of shape ``(N, height, width, 1)``, divided by 255.
        - ``labels`` is an integer array of the shifted class ids, in the same
          order as ``data``.
    """
    def _load_npy(filename):
        return np.load(os.path.join(data_dir, filename))

    train_data = _load_npy("kuzushiji50_train_imgs.npy")
    train_labels = _load_npy("kuzushiji50_train_labels.npy")
    test_data = _load_npy("kuzushiji50_test_imgs.npy")
    test_labels = _load_npy("kuzushiji50_test_labels.npy")

    data = np.vstack([train_data, test_data])
    data = np.array([cv2.resize(image, size) for image in data], dtype="float32")
    # Add a trailing channel axis: the model is built for single-channel input.
    data = np.expand_dims(data, axis=-1)
    data /= 255.0

    # Cast before shifting. Adding the offset in a small unsigned dtype could wrap
    # around, and the vectorized form avoids a Python-level loop over every label.
    labels = np.hstack([train_labels, test_labels]).astype(int) + label_offset

    return data, labels

In [9]:
# Build one combined corpus. load_kanji_dataset() already shifts the kanji ids
# past the hiragana ids, so the labels can be concatenated directly.
data_hiragana, labels_hiragana = load_hiragana_dataset()
data_kanji, labels_kanji = load_kanji_dataset()

data = np.concatenate([data_hiragana, data_kanji], axis=0)
labels = np.concatenate([labels_hiragana, labels_kanji], axis=0)

In [10]:
# One-hot encode the integer labels; column j corresponds to le.classes_[j].
le = LabelBinarizer()
labels = le.fit_transform(labels)

# Inverse-frequency class weights for model.fit: the most frequent class gets 1.0,
# and rarer classes get proportionally larger weights to counter any imbalance.
class_totals = labels.sum(axis=0)
max_total = class_totals.max()
class_weight = {i: max_total / total for i, total in enumerate(class_totals)}

# Stratify on the one-hot rows so each class keeps its proportion in both splits.
train_x, test_x, train_y, test_y = train_test_split(
    data, labels, test_size=0.20, stratify=labels, random_state=42
)

In [11]:
# On-the-fly geometric augmentation for the training batches only. Horizontal
# flipping is explicitly disabled (presumably because a mirrored character is not
# a valid glyph). Pixels exposed by the transforms are filled with the nearest value.
aug = ImageDataGenerator(
    fill_mode="nearest",
    horizontal_flip=False,
    rotation_range=10,
    shear_range=0.15,
    zoom_range=0.05,
    width_shift_range=0.1,
    height_shift_range=0.1,
)

# Train Model

In [12]:
# Training hyper-parameters
EPOCHS = 35     # number of training epochs
INIT_LR = 0.1   # initial SGD learning rate (decayed per step by the optimizer cell)
BS = 128        # mini-batch size for training and prediction

In [13]:
# ResNet50 trained from scratch (no pretrained weights). The input shape and the
# number of output classes are taken from the prepared arrays instead of being
# hard-coded, so they always match the resized images and the one-hot targets.
model = tf.keras.applications.resnet50.ResNet50(
    input_shape=train_x.shape[1:],
    weights=None,
    classes=train_y.shape[1],
)

In [14]:
# The legacy `SGD(decay=d)` used lr = INIT_LR / (1 + d * step), updated every
# optimizer step. The TF >= 2.11 / Keras 3 optimizers reject or ignore `decay`, so
# the same inverse-time schedule is expressed explicitly (decay_steps=1 means the
# decay is applied per step).
lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
    initial_learning_rate=INIT_LR,
    decay_steps=1,
    decay_rate=INIT_LR / EPOCHS,
)
opt = SGD(learning_rate=lr_schedule)
model.compile(loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"])

In [None]:
# Train on augmented batches and validate on the untouched held-out split. Each
# "epoch" is len(train_x) // BS generator steps; class_weight rebalances the loss.
H = model.fit(
    aug.flow(train_x, train_y, batch_size=BS),
    epochs=EPOCHS,
    steps_per_epoch=len(train_x) // BS,
    validation_data=(test_x, test_y),
    class_weight=class_weight,
    verbose=1,
)

In [None]:
# Return to /content so the relative save path below resolves to
# /content/manga_ocr.h5 rather than the Drive folder selected earlier.
%cd /content/

In [None]:
# Persist the trained network; the ".h5" suffix selects the HDF5 save format.
model.save(filepath="manga_ocr.h5")

# Evaluate Model

In [None]:
# Class names in the one-hot column order produced by the LabelBinarizer. This
# replaces the hard-coded range(99), so names stay aligned with the model's outputs.
label_names = [str(cls) for cls in le.classes_]
predictions = model.predict(test_x, batch_size=BS)

# Pass explicit column indices so target_names always lines up, even if some class
# never appears among the predictions or the ground truth.
print(classification_report(
    test_y.argmax(axis=1),
    predictions.argmax(axis=1),
    labels=np.arange(len(label_names)),
    target_names=label_names,
))

# Analyse Model

In [None]:
# Visual spot-check: annotate a grid of random test images with the predicted
# label, drawn in green when correct and red when wrong (BGR colour order).
grid_cols, grid_rows = 7, 7
tile_size = (96, 96)

# Sample without replacement so no test image appears twice in the grid.
sample_idxs = np.random.choice(len(test_y), size=grid_cols * grid_rows, replace=False)
# One batched predict call instead of one call per image.
sample_preds = model.predict(test_x[sample_idxs], batch_size=BS).argmax(axis=1)
sample_truth = test_y[sample_idxs].argmax(axis=1)

images = []
for i, pred, truth in zip(sample_idxs, sample_preds, sample_truth):
    label = label_names[pred]
    color = (0, 255, 0) if pred == truth else (0, 0, 255)

    # Back to 8-bit, replicate the single channel to 3, enlarge for legibility.
    image = (test_x[i] * 255).astype("uint8")
    image = cv2.merge([image] * 3)
    image = cv2.resize(image, tile_size, interpolation=cv2.INTER_LINEAR)
    cv2.putText(image, label, (5, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.75, color, 2)
    images.append(image)

montage = build_montages(images, tile_size, (grid_cols, grid_rows))[0]

# cv2_imshow renders the image inline in Colab. cv2.waitKey only applies to native
# HighGUI windows, of which there are none here, so it is not called.
cv2_imshow(montage)