# Initialisation

In [None]:
# Mount Google Drive; the dataset root and the model weights used later in this
# notebook are read from paths under /content/drive.
from google.colab import drive
drive.mount("/content/drive")

In [None]:
!pip3 uninstall keras-nightly
!pip3 uninstall -y tensorflow
!pip3 install keras==2.1.6
!pip3 install tensorflow==1.15.0
!pip3 install h5py==2.10.0
!pip3 install opencv-contrib-python

In [None]:
# Fetch the Matterport Mask R-CNN sources; the mrcnn package imported further
# down is installed from this checkout.
!git clone https://github.com/matterport/Mask_RCNN.git

In [None]:
# Move into the cloned repository so the install commands in the next cell find
# setup.py and requirements.txt.
%cd Mask_RCNN

In [None]:
# Install the package from the current directory, then the dependencies listed
# in its requirements file.
!python setup.py install
!pip install -r requirements.txt

In [None]:
# Extra packages imported by this notebook (unpinned).
!pip install elementpath
!pip install manga109api

In [None]:
import elementpath
from xml.etree import ElementTree
import manga109api
from google.colab import files
from os import listdir
from numpy import zeros, asarray, expand_dims, mean
from numpy import asarray
from mrcnn.utils import Dataset, extract_bboxes, compute_ap
from mrcnn.config import Config
from mrcnn.visualize import display_instances
from mrcnn.model import MaskRCNN, load_image_gt, mold_image
import matplotlib.pyplot as pyplot
from matplotlib.patches import Rectangle, Arrow
import math
import cv2

In [8]:
# Root folder of the Manga109 dataset on Drive; the "images/" and "annotations/"
# sub-folders are addressed relative to it throughout the notebook.
root_dir = "/content/drive/MyDrive/NRP/Project/Manga109/"
# Parser whose `books` attribute supplies the book titles iterated below.
p = manga109api.Parser(root_dir=root_dir)

# Reformat Manga109 annotations

In [None]:
# Return to /content; the per-book folders created in the next cell use paths
# relative to the working directory.
%cd /content

In [None]:
# Split every book-level annotation file into one XML file per page.
for book in p.books:
  tree = ElementTree.parse(root_dir + "annotations/" + book + ".xml")
  root = tree.getroot()

  # Create a folder named after the book in the working directory and enter it,
  # so the files opened below (relative names) land inside it.
  %mkdir $book
  %cd /content/$book

  for page in root.findall(".//page"):
    new_xml = page
    # Serialise just this <page> element (with its children) to bytes.
    b_xml = ElementTree.tostring(new_xml)
    # File name pattern: new_<book><page index>.xml
    with open("new_" + book + str(page.attrib["index"]) + ".xml", "wb") as f:
      f.write(b_xml)
  
  # Back to /content before handling the next book.
  %cd /content

In [None]:
# Archive each book's folder of per-page XML files under /content.
# NOTE(review): the dataset class later reads these files from
# <root_dir>/annotations/<book>/ — presumably the archives are copied there
# by hand; confirm.
for book in p.books:
  !zip -r /content/$book /content/$book

In [None]:
# List the book titles known to the parser, one per line.
print("\n".join(p.books))

# Prepare Dataset

In [9]:
class MangaDataset(Dataset):
  """Manga109 pages exposed as a Mask R-CNN dataset with face, text and frame classes.

  Relies on the notebook-level ``root_dir`` and ``p`` (the manga109api parser) and on
  per-page annotation files named ``new_<book><page number>.xml`` stored under
  ``<root_dir>/annotations/<book>/``.
  """

  # Annotation tags in class-id order: face -> 1, text -> 2, frame -> 3.
  _TAGS = ("face", "text", "frame")
  # Pages numbered below this value form the training split, the rest the test/val split.
  _TRAIN_PAGE_LIMIT = 50

  def load_dataset(self, is_train=True):
    """Register the classes and every usable page of the requested split.

    A page is registered only when its annotation contains at least one face,
    one text and one frame element.

    is_train: True selects pages numbered below 50, False selects the others.
    """
    for class_id, tag in enumerate(self._TAGS, start=1):
      self.add_class("dataset", class_id, tag)

    last_image_id = 0

    for book in sorted(p.books):
      images_dir = root_dir + "images/" + book + "/"
      annotations_dir = root_dir + "annotations/" + book + "/"
      image_id = None

      for img in sorted(listdir(images_dir)):
        # Image files are named after their page number; drop the extension.
        og_image_id = int(img.rsplit(".", 1)[0])
        # Offset by the ids already consumed by previous books to keep ids unique.
        image_id = og_image_id + last_image_id

        # Check the split first so pages of the other split are never parsed.
        if bool(is_train) != (og_image_id < self._TRAIN_PAGE_LIMIT):
          continue

        ann_path = annotations_dir + "new_" + book + str(og_image_id) + ".xml"
        root = ElementTree.parse(ann_path).getroot()
        if any(root.find(".//" + tag) is None for tag in self._TAGS):
          continue

        img_path = images_dir + img
        self.add_image("dataset", image_id=image_id, path=img_path, annotation=ann_path, class_ids=[0, 1, 2, 3])

      # An empty book folder leaves the running offset untouched instead of
      # failing on an undefined (or stale) image_id.
      if image_id is not None:
        last_image_id = image_id + 1


  def extract_boxes(self, filename):
    """Read all face, text and frame boxes from one per-page annotation file.

    Returns (boxes, width, height). Each box is [xmin, ymin, xmax, ymax, tag];
    the coordinates are the raw attribute strings from the XML and the boxes
    are grouped by tag in the order face, text, frame. width and height are
    the page dimensions as ints.
    """
    root = ElementTree.parse(filename).getroot()
    boxes = []

    for tag in self._TAGS:
      for box in root.findall(".//" + tag):
        att = box.attrib
        boxes.append([att["xmin"], att["ymin"], att["xmax"], att["ymax"], tag])

    page_att = root.attrib
    width = int(page_att["width"])
    height = int(page_att["height"])

    return boxes, width, height


  def load_mask(self, image_id):
    """Build one rectangular binary mask per annotated box of an image.

    image_id: position of the image in ``self.image_info``.

    Returns (masks, class_ids): masks is a uint8 array of shape
    [height, width, number of boxes] holding 1 inside each box and 0 elsewhere;
    class_ids is an int32 array with the class index of every mask.
    """
    info = self.image_info[image_id]
    boxes, w, h = self.extract_boxes(info["annotation"])

    masks = zeros([h, w, len(boxes)], dtype="uint8")
    class_ids = []

    for i, (xmin, ymin, xmax, ymax, tag) in enumerate(boxes):
      # Always mark with 1: the class lives in class_ids, and the visualisation
      # helpers only tint pixels whose mask value equals 1.
      masks[int(ymin):int(ymax), int(xmin):int(xmax), i] = 1
      class_ids.append(self.class_names.index(tag))

    return masks, asarray(class_ids, dtype="int32")


  def image_reference(self, image_id):
    """Return the file path of the image stored at position image_id."""
    info = self.image_info[image_id]
    return info["path"]

In [None]:
# train set: register the training split, then call prepare() before
# image_ids is read.
train_set = MangaDataset()
train_set.load_dataset(is_train=True)
train_set.prepare()
print("Train: %d" % len(train_set.image_ids))

In [None]:
# test/val set: same steps with is_train=False, which selects the remaining pages.
test_set = MangaDataset()
test_set.load_dataset(is_train=False)
test_set.prepare()
print("Test: %d" % len(test_set.image_ids))

In [None]:
# load an image and mask
# image_id is a position in test_set (load_mask indexes image_info with it),
# not a page number.
image_id = 1
image = test_set.load_image(image_id)
print(image.shape)

mask, class_ids = test_set.load_mask(image_id)
print(mask.shape)

In [None]:
# display image with masks and bounding boxes
# The boxes are derived from the masks loaded in the previous cell.
bbox = extract_bboxes(mask)
display_instances(image, bbox, mask, class_ids, test_set.class_names)

# Train Model

In [None]:
class MangaConfig(Config):
  """Training configuration for the Manga109 face/text/frame model."""
  # Name given to this configuration.
  NAME = "manga_cfg"
  # The three annotated classes (face, text, frame) plus one more —
  # presumably the background class; confirm against mrcnn's Config.
  NUM_CLASSES = 1 + 3
  # NOTE(review): 131 is presumably derived from the training-set size and
  # batch size — confirm before changing the dataset split.
  STEPS_PER_EPOCH = 131

In [None]:
# Build the model in training mode; "/content" is passed as the model directory.
config = MangaConfig()
model = MaskRCNN(mode="training", model_dir="/content", config=config)

# Start from saved weights, matching layers by name and skipping the four
# listed head layers.
# NOTE(review): model_6.h5 looks like one of this project's own checkpoints;
# if so, excluding the heads discards their trained weights — confirm intent.
model.load_weights("/content/drive/MyDrive/NRP/Project/Working/model_6.h5",
                   by_name=True,
                   exclude=["mrcnn_class_logits", "mrcnn_bbox_fc",  "mrcnn_bbox", "mrcnn_mask"])

# Train all layers; test_set is passed as the validation dataset.
model.train(train_set, test_set, learning_rate=0.000005, epochs=40, layers="all")

# config.LEARNING_RATE = 0.001

# Evaluate Model

In [None]:
class PredictionConfig(Config):
  """Inference configuration: same name and class count as training, one image per step."""
  NAME = "manga_cfg"
  # Three annotated classes (face, text, frame) plus one more —
  # presumably the background class; confirm against mrcnn's Config.
  NUM_CLASSES = 1 + 3
  # One GPU with one image each, so detection runs on a single image at a time.
  GPU_COUNT = 1
  IMAGES_PER_GPU = 1

In [None]:
# Rebuild the model in inference mode; this rebinds `model`, replacing the
# training-mode instance created earlier.
cfg = PredictionConfig()
model = MaskRCNN(mode="inference", model_dir="/content", config=cfg)

In [None]:
# Load the weights used for evaluation and for the demos below.
# NOTE(review): this is model_5.h5 while the training cell starts from
# model_6.h5 — confirm this is the intended checkpoint.
model.load_weights("/content/drive/MyDrive/NRP/Project/Working/model_5.h5", by_name=True)

In [None]:
# evaluate model using Manga109 dataset
def evaluate_model(dataset, model, cfg):
  """Return the mean average precision of model over every image in dataset.

  dataset: prepared dataset whose image_ids are iterated.
  model: Mask R-CNN model in inference mode.
  cfg: configuration used to load the ground truth of each image.

  The AP of each image is computed at an IoU threshold of 0.5 and the values
  are averaged; an empty dataset yields NaN.
  """
  APs = []
  for image_id in dataset.image_ids:
    image, image_meta, gt_class_id, gt_bbox, gt_mask = load_image_gt(dataset, cfg, image_id, use_mini_mask=False)

    # detect() molds its inputs itself (resize + mean-pixel subtraction), so the
    # image is handed over as-is; molding it here as well would subtract the
    # mean pixel twice.
    r = model.detect([image], verbose=0)[0]

    # change IoU threshold
    AP, _, _, _ = compute_ap(gt_bbox, gt_class_id, gt_mask, r["rois"], r["class_ids"], r["scores"], r["masks"], iou_threshold=0.5)
    APs.append(AP)

  return mean(APs)

In [None]:
# evaluate model on training dataset
# Uses the inference-mode model and cfg created in the cells above.
train_mAP = evaluate_model(train_set, model, cfg)
print("Train mAP: %.3f" % train_mAP)

In [None]:
# evaluate model on test dataset
# Same procedure on the split that was passed as validation data during training.
test_mAP = evaluate_model(test_set, model, cfg)
print("Test mAP: %.3f" % test_mAP)

# Face-Text Association

In [None]:
def _centers_inside_frame(frame, centers):
  """Return the [x, y] centers lying strictly inside frame, given as [x1, x2, y1, y2]."""
  x1, x2, y1, y2 = frame
  return [c for c in centers if x1 < c[0] < x2 and y1 < c[1] < y2]


def _nearest_center(origin, candidates):
  """Return the candidate [x, y] closest to origin; the first one wins ties."""
  return min(candidates, key=lambda c: math.hypot(origin[0] - c[0], origin[1] - c[1]))


def arrow_face_text(dataset, image_id, cfg, model):
  """Detect faces, texts and frames on one image and draw face-to-text arrows.

  dataset: prepared dataset providing load_image / load_mask.
  image_id: position of the image inside dataset.
  cfg: unused here; kept so existing calls stay valid.
  model: Mask R-CNN model in inference mode.

  Detected faces are outlined in red, texts in yellow and frames in violet.
  Inside every detected frame that contains at least one face center and one
  text center, each member of the larger group (faces on a tie) is linked to
  its nearest member of the other group; every link is drawn as an arrow
  pointing from the face center to the text center. The figure is shown and
  nothing is returned.
  """
  image = dataset.load_image(image_id)
  mask, _ = dataset.load_mask(image_id)

  # detect() molds its inputs itself, so the raw image is passed; molding it
  # here as well would subtract the mean pixel twice.
  yhat = model.detect([image], verbose=0)[0]
  rois = list(yhat["rois"])
  class_ids = list(yhat["class_ids"])

  # NOTE(review): the page image is drawn after these masks and at full
  # opacity, so the masks are presumably hidden in the final figure — confirm
  # whether they are meant to be visible.
  for j in range(mask.shape[2]):
    pyplot.imshow(mask[:, :, j], cmap="gray", alpha=0.3)

  pyplot.subplot(111)
  pyplot.imshow(image)
  pyplot.title("Face to Text")

  ax = pyplot.gca()

  # Outline colour per detected class id (1 face, 2 text, 3 frame).
  box_colors = {1: "red", 2: "yellow", 3: "violet"}
  face_centers = []
  text_centers = []
  frame_corners = []

  for class_id, box in zip(class_ids, rois):
    class_id = int(class_id)
    if class_id not in box_colors:
      continue

    y1, x1, y2, x2 = box
    center = [(x1 + x2)//2, (y1 + y2)//2]

    if class_id == 1:
      face_centers.append(center)
    elif class_id == 2:
      text_centers.append(center)
    else:
      frame_corners.append([x1, x2, y1, y2])

    ax.add_patch(Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, color=box_colors[class_id]))

  faces_to_texts = []

  for frame in frame_corners:
    faces_in_frame = _centers_inside_frame(frame, face_centers)
    texts_in_frame = _centers_inside_frame(frame, text_centers)

    # Nothing to link unless the frame holds both a face and a text.
    if not faces_in_frame or not texts_in_frame:
      continue

    if len(faces_in_frame) >= len(texts_in_frame):
      for face in faces_in_frame:
        faces_to_texts.append([face, _nearest_center(face, texts_in_frame)])
    else:
      for text in texts_in_frame:
        faces_to_texts.append([_nearest_center(text, faces_in_frame), text])

  for face, text in faces_to_texts:
    face_x, face_y = face
    text_x, text_y = text

    # Signed offsets so the arrow starts at the face and ends at the text.
    arrow = Arrow(face_x, face_y, text_x - face_x, text_y - face_y, color="cornflowerblue")
    ax.add_patch(arrow)

  pyplot.show()

# Text Order Determination

In [None]:
def order_frame_halves(unordered_frames):
  """Order frames top to bottom, keeping the input order among frames at the same level.

  unordered_frames: list of [x1, y1, x2, y2, x_center] frames; it is left untouched.

  Returns the frames as [x1, y1, x2, y2] (last element dropped). A frame is
  emitted only once no remaining frame lies entirely above it, i.e. no
  remaining frame has its y2 smaller than this frame's y1.
  """
  remaining = list(unordered_frames)
  ordered_frames = []

  while remaining:
    progressed = False

    # Walk a snapshot: removing from the list being iterated would skip the
    # element right after each removal and scramble the order.
    for frame in list(remaining):
      if any(other_frame[3] < frame[1] for other_frame in remaining):
        continue # a frame above this one is still waiting to be read

      ordered_frames.append(frame[:-1]) # the current frame is the next frame
      remaining.remove(frame)
      progressed = True

    if not progressed:
      # Only malformed boxes (y2 < y1) can block every frame; emit the rest
      # as-is rather than looping forever.
      ordered_frames.extend(frame[:-1] for frame in remaining)
      break

  return ordered_frames

In [None]:
def order_text(dataset, image_id, img_width, cfg, model):
  """Detect texts and frames on one image and number the texts in reading order.

  dataset: prepared dataset providing load_image / load_mask.
  image_id: position of the image inside dataset.
  img_width: page width in pixels; frames whose horizontal center lies right
    of img_width // 2 are read before the others.
  cfg: unused here; kept so existing calls stay valid.
  model: Mask R-CNN model in inference mode.

  Frames of the right half are ordered first, then those of the left half
  (each half through order_frame_halves). Within a frame, the texts whose
  center lies strictly inside it are numbered from right to left, the upper
  one first when two share the same x. A text inside no detected frame gets
  no number. The figure is shown and nothing is returned.
  """
  image = dataset.load_image(image_id)
  mask, _ = dataset.load_mask(image_id)

  # detect() molds its inputs itself, so the raw image is passed; molding it
  # here as well would subtract the mean pixel twice.
  yhat = model.detect([image], verbose=0)[0]
  rois = list(yhat["rois"])
  class_ids = list(yhat["class_ids"])

  # NOTE(review): the page image is drawn after these masks and at full
  # opacity, so the masks are presumably hidden in the final figure — confirm
  # whether they are meant to be visible.
  for j in range(mask.shape[2]):
    pyplot.imshow(mask[:, :, j], cmap="gray", alpha=0.3)

  pyplot.subplot(111)
  pyplot.imshow(image)
  pyplot.title("Order Text")

  ax = pyplot.gca()

  text_centers = []
  total_unordered_frames = []

  for class_id, box in zip(class_ids, rois):
    y1, x1, y2, x2 = box

    if class_id == 2:
      text_centers.append([(x1 + x2)//2, (y1 + y2)//2])
      ax.add_patch(Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, color="yellow"))
    elif class_id == 3:
      # x1 comes first so the reverse sort below runs from right to left.
      total_unordered_frames.append([x1, y1, x2, y2, (x1 + x2)//2])

  total_unordered_frames.sort(reverse=True)

  half_line = img_width//2
  right_half = [frame for frame in total_unordered_frames if frame[-1] > half_line]
  left_half = [frame for frame in total_unordered_frames if frame[-1] <= half_line]

  # Right half first, then the left half.
  frames_in_reading_order = order_frame_halves(right_half) + order_frame_halves(left_half)

  index = 1

  for x1, y1, x2, y2 in frames_in_reading_order:
    # y is negated so one reverse sort gives: larger x first, then smaller y first.
    text_centers_filtered = [[text[0], -text[1]] for text in text_centers
                             if x1 < text[0] < x2 and y1 < text[1] < y2]
    text_centers_filtered.sort(reverse=True)

    for text_x, negated_text_y in text_centers_filtered:
      pyplot.text(text_x, -negated_text_y, index)
      index += 1

  pyplot.show()

# Use Model

In [None]:
dataset = test_set

# Walk the prepared dataset by its own positions: load_image / load_mask index
# dataset.image_info directly, so page-number-based ids would select the wrong
# entries (or run out of range). Pages lacking faces, texts or frames were
# already left out when the dataset was loaded.
for image_id in dataset.image_ids:
  info = dataset.image_info[image_id]

  # Page width comes from the same annotation file the masks are built from.
  page = ElementTree.parse(info["annotation"]).getroot()
  img_width = int(page.attrib["width"]) # get width of page

  arrow_face_text(dataset, image_id, cfg, model)
  order_text(dataset, image_id, img_width, cfg, model)

  # Image paths end in <book>/<file>, so the book is the second-to-last path
  # component; info["id"] is the book-offset id assigned when the page was registered.
  print(info["path"].split("/")[-2], info["id"])