# Training a Mask R-CNN Segmentation Model on Manga109

In [None]:
import manga109api

In [None]:
# Location of the Manga109 release on disk; page images are read from
# "<root>/images/<book>/<page>.jpg".
manga109_root = "../datasets/Manga109/Manga109_released_2021_12_30"

# Annotation parser shared by the dataset class below.
dataset = manga109api.Parser(manga109_root)

In [None]:
from mrcnn.config import Config
from mrcnn import utils
import mrcnn.model as modellib
from mrcnn import visualize
from mrcnn.model import log
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf

%matplotlib inline

# Output directory for training logs and model checkpoints.
MODEL_DIR = "model"

# Optional (currently disabled): local path to COCO-trained weights, which
# are downloaded from the Mask R-CNN releases when the file is missing.
# COCO_MODEL_PATH = "../model-weights/mask_rcnn_coco.h5"
# if not os.path.exists(COCO_MODEL_PATH):
#     utils.download_trained_weights(COCO_MODEL_PATH)

In [None]:
def get_ax(rows=1, cols=1, size=8):
    """Create the Matplotlib axes used by every plot in this notebook.

    Routing all figure creation through one function gives a single place
    to control how large rendered images are.

    Args:
        rows: number of subplot rows.
        cols: number of subplot columns.
        size: edge length of each subplot cell (Matplotlib figsize units).

    Returns:
        Whatever ``plt.subplots`` returns as its axes: one Axes object for
        a 1x1 grid, otherwise an array of Axes.
    """
    figure_size = (cols * size, rows * size)
    _figure, axes = plt.subplots(rows, cols, figsize=figure_size)
    return axes

In [None]:
import cv2
class MangaDataset(utils.Dataset):
    """Mask R-CNN dataset built from the Manga109 annotations.

    Two foreground classes are registered: "frame" and "text". Each
    annotated bounding box becomes one rectangular, binary instance mask.
    Pages come from the module-level ``dataset`` parser, and images are read
    from ``manga109_root``.
    """

    def load_manga(self, is_train=True):
        """Register the page images of one split.

        Pages are numbered by a running counter across all books. A page
        whose counter is divisible by 5 belongs to the test split; every
        other page belongs to the training split (an 80/20 split).

        Args:
            is_train: True to register the training pages, False to
                register the test pages.
        """
        self.add_class("manga", 1, "frame")
        self.add_class("manga", 2, "text")

        count = -1
        for book in dataset.books:
            for page in dataset.get_annotation(book)["page"]:
                # Deterministic train/test split on the running page counter.
                count += 1
                is_test_page = count % 5 == 0
                if is_test_page == is_train:
                    continue

                index = page["@index"]
                path = os.path.join(
                    manga109_root, "images", book, str(index).zfill(3) + ".jpg"
                )
                self.add_image(
                    "manga",
                    image_id=count,
                    path=path,
                    book=book,
                    page=index,
                    # Record returned by image_reference().
                    manga={"book": book, "page": index},
                    # Keep the already-fetched page annotation so load_mask()
                    # does not have to fetch the whole book annotation again.
                    annotation=page,
                )

    def image_reference(self, image_id):
        """Return the {"book": ..., "page": ...} record of an image.

        Images that were not registered by ``load_manga`` defer to the base
        class implementation.
        """
        info = self.image_info[image_id]
        if info["source"] == "manga":
            return info["manga"]
        return super().image_reference(image_id)

    def load_mask(self, image_id):
        """Build one rectangular instance mask per frame and text box.

        Returns:
            masks: bool array of shape [height, width, instance_count], with
                True inside the instance's bounding box.
            class_ids: int32 array of shape [instance_count] holding the
                class ID of each mask.
            A page with no frame or text box yields a single all-False mask
            with class ID 0 (background) as a placeholder.
        """
        info = self.image_info[image_id]
        page = info.get("annotation")
        if page is None:
            # NOTE(review): assumes a page's "@index" equals its position in
            # the book's "page" list - confirm against the annotations.
            page = dataset.get_annotation(info["book"])["page"][info["page"]]
        height = page["@height"]
        width = page["@width"]

        masks = []
        class_ids = []
        for tag in ("frame", "text"):
            class_id = self.class_names.index(tag)
            for box in page.get(tag, []):
                # Instance masks are binary; the class is carried by class_ids.
                mask = np.zeros((height, width), dtype=bool)
                mask[box["@ymin"]:box["@ymax"], box["@xmin"]:box["@xmax"]] = True
                masks.append(mask)
                class_ids.append(class_id)

        if not masks:
            return (
                np.zeros((height, width, 1), dtype=bool),
                np.zeros(1, dtype=np.int32),
            )
        return np.dstack(masks), np.array(class_ids, dtype=np.int32)

In [None]:
def _build_split(is_train):
    """Create, populate and prepare one MangaDataset split."""
    split = MangaDataset()
    split.load_manga(is_train=is_train)
    # prepare() must run before the split's class_names / image_ids are used.
    split.prepare()
    return split


train_set = _build_split(True)
test_set = _build_split(False)

In [None]:
# Sanity check: render the ground-truth masks of 4 randomly drawn training
# pages (np.random.choice samples with replacement, so repeats can occur).
image_ids = np.random.choice(train_set.image_ids, 4)
for image_id in image_ids:
    image = train_set.load_image(image_id)
    mask, class_ids = train_set.load_mask(image_id)
    visualize.display_top_masks(
        image,
        mask,
        class_ids,
        train_set.class_names,
    )

In [None]:
class MangaConfig(Config):
    """Mask R-CNN configuration for training on the Manga109 dataset.

    Derives from the base Config class and overrides only the values that
    are specific to Manga109: two foreground classes (frame, text), one
    image per GPU on a single GPU, and an epoch that covers the training
    split once.
    """

    # Give the configuration a recognizable name.
    NAME = "manga"

    # Number of classes including background: background + frame + text.
    NUM_CLASSES = 1 + 2

    GPU_COUNT = 1
    IMAGES_PER_GPU = 1

    # One epoch = one pass over the training split. Each step consumes
    # GPU_COUNT * IMAGES_PER_GPU images, so divide by that to keep the epoch
    # length right if the batch size changes; max(1, ...) avoids a
    # zero-step epoch on a tiny split.
    STEPS_PER_EPOCH = max(1, len(train_set.image_ids) // (GPU_COUNT * IMAGES_PER_GPU))

    TRAIN_ROIS_PER_IMAGE = 128


config = MangaConfig()
config.display()

## Training

In [None]:
# Create the model in training mode. Eager execution is switched off first;
# presumably the mrcnn model code requires TF graph mode - confirm if the
# TensorFlow version changes.
tf.compat.v1.disable_eager_execution()
model = modellib.MaskRCNN(mode="training", config=config, model_dir=MODEL_DIR)

# Which weights to start with: "random" or "last".
init_with = "last"

if init_with == "random":
    # Keep the weights the model was constructed with.
    pass
elif init_with == "last":
    # Load the last model you trained (located via find_last) and continue.
    model.load_weights(model.find_last(), by_name=True)
else:
    # Fail loudly instead of silently training from the initial weights.
    raise ValueError(f"Unknown init_with option: {init_with!r}")

In [None]:
# train all layers
model.train(train_set, test_set,
            learning_rate=config.LEARNING_RATE,
            epochs=2,
            layers="all")