In [3]:
import sys
import os
sys.path.append('..')


import torch

from manga_ocr.utils import get_path_project_dir
from manga_ocr.utils.pytorch_model import get_total_parameters_count

# Load/Prepare Model

In [4]:
from manga_ocr.models.localization.conv_unet.conv_unet import ConvUnet

# Location of the serialized localization model inside the project tree.
path_output_model = get_path_project_dir('data/output/models/localization.bin')
print('path_output_model', os.path.abspath(path_output_model))

# Build a fresh ConvUnet only when nothing is stored at that path yet;
# otherwise reuse whatever object torch.load returns from the file.
if not os.path.exists(path_output_model):
    print('Creating a new model...')
    model = ConvUnet()
else:
    print('Loading an existing model...')
    model = torch.load(path_output_model)

print(model)

path_output_model /Users/wanasit/Dropbox/Workspace_Personal/manga-ocr/data/output/models/localization.bin
Loading an existing model...
ConvUnet(
  (down_conv_3): ConvWithPoolingToHalfSize(
    (conv): Sequential(
      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
  )
  (down_conv_2): ConvWithPoolingToHalfSize(
    (conv): Sequential(
      (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
  )
  (down_conv_1): ConvWithPoolingToHalfSize(
    (conv): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): 

In [5]:
# Show the model's `image_size` attribute, then the value returned by
# get_total_parameters_count(model).
print(model.image_size)
print(get_total_parameters_count(model))

(500, 500)
399490


## Load/Prepare Dataset

In [6]:
from manga_ocr.models.localization.localization_dataset import LocalizationDataset

# Directory holding the generated manga dataset, resolved inside the project.
path_dataset = get_path_project_dir('data/output/generate_manga_dataset')
print('path_dataset', os.path.abspath(path_dataset))

# Load the dataset using the model's own `image_size`, then report its length.
dataset = LocalizationDataset.load_generated_manga_dataset(
    path_dataset,
    image_size=model.image_size,
)
print('dataset', len(dataset))

path_dataset /Users/wanasit/Dropbox/Workspace_Personal/manga-ocr/data/output/generate_manga_dataset
dataset 4000


In [7]:
# Split boundary shared by both subsets: everything before this index goes to
# validation, everything from it onward goes to training. Keeping it in one
# constant guarantees the two slices stay in sync (no overlap, no gap) when
# the split size is tuned.
VALIDATION_SIZE = 100

validation_dataset = dataset.subset(to_idx=VALIDATION_SIZE)
training_dataset = dataset.subset(from_idx=VALIDATION_SIZE)
print('validation_dataset', len(validation_dataset))
print('training_dataset', len(training_dataset))

validation_dataset 100
training_dataset 3900


# Training

In [8]:
from manga_ocr.models.localization.train import train
from manga_ocr.utils.ploting import plot_losses, show_images
from IPython.display import clear_output

def show_example(model, img):
    """Display `img` beside the result of `model.create_image_mark_lines(img)`.

    Args:
        model: Object exposing `create_image_mark_lines(img)`.
        img: Image handed both to the model and to `show_images`.
    """
    marked = model.create_image_mark_lines(img)
    show_images(
        images=[img, marked],
        texts=['input', 'output'],
        figsize=(10, 10),
    )


def save_and_report(i_epoch, train_losses, val_losses):
    """Epoch callback: checkpoint the model, then refresh the progress display.

    Args:
        i_epoch: Epoch index supplied by the caller; not used here.
        train_losses: Training losses forwarded to `plot_losses`.
        val_losses: Validation losses forwarded to `plot_losses`.

    Reads the notebook-level `model`, `training_dataset`,
    `validation_dataset` and `path_output_model`.
    """
    # Persist first so that a failure while plotting or rendering the examples
    # cannot discard the training progress made so far.
    output_dir = os.path.dirname(path_output_model)
    if output_dir:
        # torch.save does not create missing parent directories; without this
        # a brand-new model would only fail after a full epoch of training.
        os.makedirs(output_dir, exist_ok=True)
    torch.save(model, path_output_model)

    # Replace the previous epoch's output with fresh plots and examples.
    clear_output()
    plot_losses(train_losses, val_losses)
    show_example(model, training_dataset.images[0])
    show_example(model, validation_dataset.images[0])

In [9]:
# Start training; `save_and_report` is handed over as the epoch callback
# (presumably invoked once per epoch — confirm in `train`).
train(
    model,
    training_dataset=training_dataset,
    validation_dataset=validation_dataset,
    epoch_count=10,
    epoch_callback=save_and_report,
)

Epoch 0:   0%| | 10/3900 [00:20<2:15:54,  2.10s/it, training_batch_loss=0.00283]


KeyboardInterrupt: 