In [None]:
import sys
import os
sys.path.append('..')
from datetime import datetime

# Capture a single "run started" moment; `timestamp` is used to give this run's
# checkpoint file a unique name.
now = datetime.now()
timestamp = f'{now:%Y_%m_%d_%H%M%S}'
print(timestamp)

# Load/Prepare Model

In [None]:
import torch
from comic_ocr.models import localization
from comic_ocr.utils.files import get_path_project_dir, load_image
from comic_ocr.utils.pytorch_model import get_total_parameters_count

# Set to True to ignore any existing checkpoint at `model_path` and always start
# from a freshly created model. It is assigned here because the condition below
# reads it and nothing earlier in this notebook assigns it, so a fresh-kernel
# "Run All" previously failed with NameError.
forcing_new_model = False

# NOTE(review): the checkpoint name embeds this run's `timestamp`, so the file
# normally does not exist yet and a new model is created. To resume training,
# point `model_path` at an existing checkpoint instead.
model_path = get_path_project_dir(
    f'data/output/models/localization_{timestamp}.bin'
)
if os.path.exists(model_path) and not forcing_new_model:
    print('Loading an existing model...')
    # torch.load unpickles the whole model object, so only load trusted checkpoints.
    # NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which refuses to
    # unpickle a full model object; pass weights_only=False on those versions.
    model = torch.load(model_path)
else:
    print('Creating a new model...')
    model = localization.create_new_model()

def save_model(model, model_path):
    """Serialize ``model`` to ``model_path`` with ``torch.save``.

    The parent directory is created first (if the path has one), because
    ``torch.save`` raises when the target directory does not exist, e.g. when
    the checkpoint output folder has not been created yet.

    Args:
        model: Object to persist (in this notebook, the whole localization model).
        model_path: Destination file path (str or path-like).
    """
    model_dir = os.path.dirname(model_path)
    # A bare filename has no directory part; os.makedirs('') would raise.
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)
    torch.save(model, model_path)


# Summarize where the model will be saved and what is about to be trained.
print('Training [{}]'.format(model_path))

print(f'- preferred_image_size: {model.preferred_image_size}')
print(f'- get_total_parameters_count {get_total_parameters_count(model)}')
print(f'- model {model}')

# Load/Prepare Dataset

In [None]:
from comic_ocr.models.localization.localization_dataset import LocalizationDataset
from comic_ocr.utils.ploting import show_images
from comic_ocr.utils import image_with_annotations, concatenated_images

In [None]:
# Load the synthetically generated dataset and preview the first two samples
# (image | character mask | line mask, side by side).
dataset_generated_path = get_path_project_dir('data/output/generate_manga_dataset')
dataset_generated = LocalizationDataset.load_generated_dataset(model, dataset_generated_path)

print('dataset_generated', len(dataset_generated))
show_images(
    [
        concatenated_images((
            dataset_generated.get_image(sample_idx),
            dataset_generated.get_mask_char(sample_idx),
            dataset_generated.get_mask_line(sample_idx),
        ))
        for sample_idx in (0, 1)
    ],
    figsize=(30, 15),
    num_col=1,
)

In [None]:
# Load the hand-annotated (line-level) dataset and preview the first two samples
# (image | character mask | line mask, side by side).
dataset_annotated_path = get_path_project_dir('data/manga_line_annotated')
dataset_annotated = LocalizationDataset.load_line_annotated_dataset(model, dataset_annotated_path)

print('dataset_annotated', len(dataset_annotated))
show_images(
    [
        concatenated_images((
            dataset_annotated.get_image(sample_idx),
            dataset_annotated.get_mask_char(sample_idx),
            dataset_annotated.get_mask_line(sample_idx),
        ))
        for sample_idx in (0, 1)
    ],
    figsize=(30, 15),
    num_col=1,
)

In [None]:
import random

# Seed Python's `random` so the shuffles below, and therefore the split, are
# reproducible across runs.
# NOTE(review): assumes LocalizationDataset.shuffle() draws from Python's
# `random` module. Confirm this; torch's RNG is not seeded here.
random.seed('abc')

# Split boundaries. Validation takes the first N items of each shuffled source.
# Training starts at the same N, so each boundary is defined once and the two
# splits cannot drift into overlapping or leaving gaps. Generated training data is
# capped at GENERATED_TRAINING_END. Assumes subset(to_idx=...) is exclusive.
VALIDATION_GENERATED_COUNT = 90
VALIDATION_ANNOTATED_COUNT = 10
GENERATED_TRAINING_END = 3000

# NOTE(review): these rebind the names to shuffled datasets. Re-running this cell
# on its own presumably shuffles the already-shuffled data again and yields a
# different split; re-run the loading cells first.
dataset_annotated = dataset_annotated.shuffle()
dataset_generated = dataset_generated.shuffle()

# Visual sanity check: a few shuffled samples from each source.
show_images([
    dataset_generated.get_image(0),
    dataset_generated.get_image(1),
    dataset_generated.get_image(2),
    dataset_annotated.get_image(0),
    dataset_annotated.get_image(1),
    dataset_annotated.get_image(2),
])

validation_dataset = LocalizationDataset.merge(
    dataset_generated.subset(to_idx=VALIDATION_GENERATED_COUNT),
    dataset_annotated.subset(to_idx=VALIDATION_ANNOTATED_COUNT))

training_dataset = LocalizationDataset.merge(
    dataset_generated.subset(from_idx=VALIDATION_GENERATED_COUNT, to_idx=GENERATED_TRAINING_END),
    dataset_annotated.subset(from_idx=VALIDATION_ANNOTATED_COUNT)
).shuffle()

print('validation_dataset', len(validation_dataset))
print('training_dataset', len(training_dataset))

# Training

In [None]:
from comic_ocr.utils.ploting import plot_losses, show_images
from IPython.display import clear_output

# A fixed example page, used to eyeball model output before and during training.
example = load_image(
    get_path_project_dir('example/manga_annotated/normal_01.jpg')
)

def show_example(model, img):
    """Display ``img`` next to the model's predicted character and line masks.

    Args:
        model: Localization model. Its private ``_create_output_mark_images``
            is called and must return a (character mask, line mask) pair.
        img: Input image to run the model on.
    """
    # NOTE(review): the masks are assumed to be PIL-style images (they are
    # converted with .convert('RGB') for display). Confirm against the model.
    char_mask, line_mask = model._create_output_mark_images(img)
    panels = [img] + [mask.convert('RGB') for mask in (char_mask, line_mask)]
    show_images(
        images=panels,
        texts=['Input', 'Characters', 'Lines'],
        figsize=(10, 10),
        num_col=3,
    )

# Preview the model's output (freshly created or loaded above) before training starts.
for preview_dataset, preview_index in (
    (training_dataset, 0),
    (dataset_annotated, 0),
    (dataset_annotated, 1),
):
    show_example(model, preview_dataset.get_image(preview_index))
show_example(model, example)

In [None]:
def save_and_report(i_epoch, train_losses, val_losses):
    """Training progress callback: checkpoint the model, then redraw progress.

    Passed to ``train`` as ``update_callback``.

    Args:
        i_epoch: Presumably the current epoch index, as passed by ``train``
            (not used here).
        train_losses: Training losses so far, forwarded to ``plot_losses``.
        val_losses: Validation losses so far, forwarded to ``plot_losses``.
    """
    # Checkpoint first, so a failure while rendering the plots or previews below
    # cannot lose the latest weights. This reuses the notebook's save_model helper
    # instead of duplicating the torch.save call.
    save_model(model, model_path)

    # wait=True keeps the previous output until new output replaces it, which
    # avoids flicker when this callback fires repeatedly during training.
    clear_output(wait=True)
    plot_losses(train_losses, val_losses)
    show_example(model, training_dataset.get_image(0))
    show_example(model, dataset_annotated.get_image(0))
    show_example(model, dataset_annotated.get_image(1))
    show_example(model, example)

In [None]:
from comic_ocr.models.localization.train import train
# Training-loop options. Validation and the save/report callback
# (save_and_report) both run every 300 units, as counted by `train`.
training_options = dict(
    batch_size=5,
    epoch_count=5,
    validate_every_n=300,
    update_every_n=300,
)
_ = train(
    model,
    train_dataset=training_dataset,
    validate_dataset=validation_dataset,
    update_callback=save_and_report,
    **training_options
)