### 準備

In [1]:
import json
import math
import os
from collections import namedtuple
from dataclasses import dataclass
from glob import glob
from typing import Optional

import IPython

from matplotlib import pyplot as plt, patches
from tqdm import tqdm

import numpy as np


def pathstr(*s):
    """Join path components, expand ``~`` and return the absolute path as a string."""
    return os.path.abspath(os.path.expanduser(os.path.join(*s)))


def char2code(c):
    """Return the code point of character *c* as a ``0x``-prefixed hex string.

    At least four hex digits are emitted, e.g. ``"A"`` -> ``"0x0041"``.
    """
    return format(ord(c), "#06x")


def code2char(c):
    """Inverse of ``char2code``: parse a hex code (``"0x"`` prefix optional) into a character."""
    return chr(int(c, base=16))

In [5]:
from dataclasses import dataclass, asdict as dc_asdict
from typing import Optional


@dataclass
class Radical:
    """A radical (component) of a character together with its bounding box.

    The box is stored as centre + size (``center_x``, ``center_y``, ``width``,
    ``height``); ``left``/``right``/``top``/``bottom`` are derived properties.
    ``top`` is ``center_y - height / 2``, i.e. y grows downwards as in image
    coordinates. ``plot_forecast`` multiplies these values by the image size,
    so they are presumably normalised to [0, 1] (TODO confirm).

    ``idx_`` holds the index assigned from a radical-name -> index mapping
    (``None`` until registered); read it through the ``idx`` property.
    """

    name: str
    idx_: Optional[int]
    center_x: float
    center_y: float
    width: float
    height: float

    # NOTE: @dataclass keeps this hand-written __init__ (it never overwrites an
    # __init__ defined in the class body); the fields above still drive the
    # generated __repr__, __eq__ and dataclasses.asdict().
    def __init__(self, name, **kwargs):
        """Create a radical from either a centre/size box or an edge box.

        Accepted keyword forms:

        * ``center_x``, ``center_y``, ``width``, ``height``
        * ``left``, ``right``, ``top``, ``bottom``

        The index may be passed as ``idx`` or ``idx_`` (``idx`` wins if both
        are given).

        Raises:
            ValueError: if neither geometry form is present.
        """
        self.name = name
        # "idx_" is the key written by to_dict(); accepting it makes
        # Radical.from_dict(r.to_dict()) keep the index instead of dropping it.
        self.idx = kwargs.get("idx", kwargs.get("idx_"))

        if "center_x" in kwargs:
            self.center_x = kwargs["center_x"]
            self.center_y = kwargs["center_y"]
            self.width = kwargs["width"]
            self.height = kwargs["height"]

        elif "left" in kwargs:
            left, right = kwargs["left"], kwargs["right"]
            top, bottom = kwargs["top"], kwargs["bottom"]

            self.center_x = (left + right) / 2
            self.center_y = (top + bottom) / 2
            self.width = right - left
            self.height = bottom - top

        else:
            raise ValueError(f"invalid arguments: {kwargs}")

    @staticmethod
    def from_radicaljson(dct):
        """Create a Radical from one node of the radical JSON tree.

        Reads ``dct["name"]``, ``dct["part"]`` and the ``left``/``right``/
        ``top``/``bottom`` entries of ``dct["bounding"]``. When ``part`` is not
        None, the name becomes ``"<name>_<part>"`` (e.g. ``三_1``).
        """
        name = dct["name"]
        if dct["part"] is not None:
            name += f'_{dct["part"]}'

        bounding = dct["bounding"]
        return Radical(
            name,
            left=bounding["left"],
            right=bounding["right"],
            top=bounding["top"],
            bottom=bounding["bottom"],
        )

    @staticmethod
    def from_dict(dct):
        """Inverse of ``to_dict``: rebuild a Radical from its dict form."""
        return Radical(**dct)

    def to_dict(self):
        """Return the dataclass fields as a dict (keys: name, idx_, center_x, center_y, width, height)."""
        return dc_asdict(self)

    @property
    def idx(self):
        """Registered radical index; raises if it has not been set yet."""
        if self.idx_ is None:
            raise Exception("idx is not registered")
        return self.idx_

    @idx.setter
    def idx(self, idx):
        self.idx_ = idx

    @property
    def left(self):
        return self.center_x - self.width / 2

    @property
    def right(self):
        return self.center_x + self.width / 2

    @property
    def top(self):
        return self.center_y - self.height / 2

    @property
    def bottom(self):
        return self.center_y + self.height / 2


@dataclass
class Char:
    """A character (``name``, may be None) and the flat list of its radicals."""

    name: Optional[str]
    radicals: list[Radical]

    @staticmethod
    def from_radicaljson(dct):
        """Build a Char from a radical-decomposition tree.

        Every node is a dict with ``"name"`` and ``"children"`` (a list of
        nodes); named nodes are converted with ``Radical.from_radicaljson``.
        For each node the radicals gathered from its children are preferred.
        The node's own radical is used instead when:

        * the children give no usable decomposition, or
        * the children yield exactly one radical.

        An unnamed node (name None or empty) without a usable child
        decomposition yields None. That makes its parent fall back to the
        parent's own radical. ``radicals`` of the result is None if the root
        itself yields None.
        """
        def get_radicals(node):
            """Return the flattened radical list for *node*, or None if it has none."""
            from_children = []
            for child in node["children"]:
                child_radicals = get_radicals(child)
                if child_radicals is None:
                    # One child without radicals invalidates the whole child decomposition.
                    from_children = None
                    break
                from_children += child_radicals

            # Treat an empty name like None: the former `name and [...]` idiom
            # produced "" here, which the `is not None` check below accepted as
            # a (empty) radical list and returned, discarding the child radical.
            from_name = [Radical.from_radicaljson(node)] if node["name"] else None

            if from_children:
                # e.g. 三 = 三_1 + 三_2 where 三_1 = 一: for the node 三_1 keep
                # 三_1 itself rather than its single child 一, so the result
                # is 三_1 + 三_2.
                if len(from_children) == 1 and from_name is not None:
                    return from_name
                return from_children

            return from_name

        return Char(dct["name"], get_radicals(dct))

    @staticmethod
    def from_dict(dct):
        """Inverse of ``to_dict``."""
        name = dct["name"]
        radicals = [Radical.from_dict(r) for r in dct["radicals"]]
        return Char(name, radicals)

    def to_dict(self):
        """Return ``{"name": ..., "radicals": [radical dicts]}``."""
        name = self.name
        radicals = [r.to_dict() for r in self.radicals]
        return {"name": name, "radicals": radicals}

    def to_formula_string(self):
        """Return e.g. ``"何 = 亻 + 丁_1 + 口 + 丁_2"``."""
        return f"{self.name} = {' + '.join(r.name for r in self.radicals)}"

    def register_radicalidx(self, radicalname2idx):
        """Set ``idx`` of every radical from the name -> index mapping.

        Raises:
            Exception: if a radical name is missing from *radicalname2idx*.
        """
        for r in self.radicals:
            if r.name not in radicalname2idx:
                raise Exception(f"radicalname2idx does not support radical '{r.name}' of '{self.name}'")
            r.idx = radicalname2idx[r.name]

In [2]:
import cv2


def get_image(path):
    """Load an image with OpenCV and return it with the first two axes swapped.

    The array from ``cv2.imread`` is indexed ``(height, width, bgr)``; the
    returned array is ``(width, height, bgr)``, i.e. ``image[x, y]``.

    Raises:
        FileNotFoundError: if OpenCV cannot read *path*. ``cv2.imread``
            signals failure by returning None instead of raising.
    """
    image = cv2.imread(path)  # BGR, or None on failure
    if image is None:
        raise FileNotFoundError(f"image not found or unreadable: {path}")
    return np.swapaxes(image, 0, 1)  # (height, width, bgr) -> (width, height, bgr)

In [3]:
import ipywidgets

import cv2
import IPython


def _image_to_widget(image):
    """Convert an image path or an ``np.ndarray`` into a borderless ``ipywidgets.Image``.

    Raises:
        FileNotFoundError: if *image* is a path IPython could not load.
        ValueError: if OpenCV reports that JPEG encoding failed.
        TypeError: for any other input type.
    """
    layout = ipywidgets.Layout(margin="0")

    # path
    if isinstance(image, str):
        loaded = IPython.display.Image(image)
        # Presumably IPython leaves the argument string in .data when it could
        # not read the file; that is the condition used to detect a missing image.
        if isinstance(loaded.data, str):
            raise FileNotFoundError(f"image not found: {loaded.data}")
        return ipywidgets.Image(value=loaded.data, layout=layout)

    # ndarray: JPEG-encoded by OpenCV, so 3-channel arrays are read as BGR
    if isinstance(image, np.ndarray):
        ok, encoded = cv2.imencode(".jpg", image)
        if not ok:
            raise ValueError(f"failed to encode image of shape {image.shape} as JPEG")
        return ipywidgets.Image(value=encoded.tobytes(), layout=layout)

    raise TypeError(f"unsupported image: {image}")


def render_images(images, columns=None, scroll=False):
    """Lay out images in a grid widget.

    Args:
        images: list whose items are an image path, an ``np.ndarray``
            (encoded with OpenCV, so BGR channel order), or an
            ``(image, title)`` tuple of either.
        columns: number of grid columns; by default all images go in one row.
        scroll: currently unused; kept for API compatibility.

    Returns:
        ipywidgets.GridBox with one cell per image. Titled images are wrapped
        in a VBox with the title label centred above the image.

    Raises:
        TypeError: if *images* is not a list or an item has an unsupported type.
        FileNotFoundError: if an image path cannot be loaded.
    """
    if not isinstance(images, list):
        raise TypeError(f"images must be a list, got {type(images).__name__}")

    # `or` also replaces an explicit 0; max(..., 1) avoids the invalid CSS
    # "repeat(0, 1fr)" when the list is empty.
    columns = columns or max(len(images), 1)

    children = []
    for item in images:
        image, title = item if isinstance(item, tuple) else (item, None)
        widget = _image_to_widget(image)

        if title is None:
            children.append(widget)
        else:
            children.append(ipywidgets.VBox(
                [ipywidgets.Label(title), widget],
                layout=ipywidgets.Layout(align_items="center"),
            ))

    return ipywidgets.GridBox(
        children=children,
        layout=ipywidgets.Layout(
            width="100%",
            height="fit-content",
            grid_template_columns=f"repeat({columns}, 1fr)",
            align_items="flex-end",
            grid_gap="8px",
        ),
    )

In [6]:
from matplotlib import pyplot as plt

from PIL import Image as PilImage, ImageDraw as PilImageDraw, ImageFont as PilImageFont


def plot_forecast(ax, image_size, draw_by_font, char,
                  font_path="~/datadisk/dataset/font/NotoSansJP-Regular.ttf"):
    """Draw the radical layout of *char* onto the matplotlib axes *ax*.

    The canvas is a white ``image_size`` x ``image_size`` image with a black
    border. Every radical's bounding box is drawn in red and labelled with the
    radical name. Radical coordinates are multiplied by ``image_size``, so
    they are presumably normalised to [0, 1].

    Args:
        ax: matplotlib axes to draw into (its axis decorations are hidden).
        image_size: canvas side length in pixels.
        draw_by_font: if true, also render each radical's glyph and stretch it
            into the radical's box. The glyph is the part of the name before
            ``_``, so ``三_1`` renders ``三``.
        char: a ``Char`` (anything with a ``radicals`` list).
        font_path: font used only when ``draw_by_font`` is true; ``~`` is expanded.
    """
    root_image = PilImage.new("RGB", (image_size, image_size), color=(255, 255, 255))
    root_draw = PilImageDraw.Draw(root_image)
    root_draw.rectangle((0, 0, image_size - 1, image_size - 1), outline=(0, 0, 0), width=1)

    # Load the font only when glyphs are drawn, so box-only plots
    # (draw_by_font=False) work on machines without the font file.
    font = None
    if draw_by_font:
        font = PilImageFont.truetype(pathstr(font_path), size=image_size // 4, index=0)

    for radical in char.radicals:
        left = int(radical.left * image_size)
        right = int(radical.right * image_size)
        top = int(radical.top * image_size)
        bottom = int(radical.bottom * image_size)

        root_draw.rectangle((left, top, right, bottom), outline=(255, 0, 0))

        if draw_by_font:
            center_x = int(radical.center_x * image_size)
            center_y = int(radical.center_y * image_size)
            # Image.resize needs positive sizes; thin radicals can round to 0 px.
            width = max(1, int(radical.width * image_size))
            height = max(1, int(radical.height * image_size))

            glyph = PilImage.new("RGBA", (image_size, image_size), color=(255, 255, 255, 0))
            PilImageDraw.Draw(glyph).text(
                xy=(center_x, center_y),
                text=radical.name.split("_")[0], font=font, fill=(0, 0, 0, 255), anchor="mm",
            )
            # Crop to the drawn glyph, then stretch it to the radical's box.
            glyph = glyph.crop(glyph.getbbox())
            glyph = glyph.resize((width, height))

            # Third argument uses the glyph's alpha as the paste mask.
            root_image.paste(glyph, (left, top), glyph)

        ax.annotate(radical.name, (left, top), ha="left", va="top", fontsize=16, fontweight="bold", color="red")

    ax.set_axis_off()
    ax.imshow(root_image, cmap="gray")

# plot_forecast(plt.gca(), 512, True, Char("倹", [
#     Radical("亻", left=0.078125, right=0.3125, top=0.140625, bottom=0.875),
#     Radical("人", left=0.296875, right=0.859375, top=0.109375, bottom=0.453125),
#     Radical("一", left=0.4375, right=0.703125, top=0.34375, bottom=0.390625),
#     Radical("口", left=0.375, right=0.75, top=0.453125, bottom=0.671875),
#     Radical("人", left=0.328125, right=0.84375, top=0.390625, bottom=0.890625),
# ]))

In [17]:
def _render_forecast_cell(char):
    """Return a VBox with *char*'s formula string above its radical-box plot."""
    forecast_output = ipywidgets.Output(layout=ipywidgets.Layout(width="50%"))
    # A dedicated figure, so nothing left on pyplot's current figure leaks in.
    fig, ax = plt.subplots()
    plot_forecast(ax, 512, False, char)
    fig.tight_layout()
    forecast_output.append_display_data(fig)
    plt.close(fig)
    return ipywidgets.VBox([ipywidgets.Label(char.to_formula_string()), forecast_output])


def _load_generated_image(save_path, i, epoch):
    """Return the bytes of generated test image *i* at *epoch* (PNG preferred, JPEG fallback).

    Raises:
        FileNotFoundError: if neither ``.png`` nor ``.jpg`` could be loaded.
    """
    stem = pathstr(save_path, "generated", f"test_{i:0>2}_{epoch:0>4}")
    for ext in (".png", ".jpg"):
        image = IPython.display.Image(stem + ext)
        # A str .data means IPython could not load the file (same check as render_images).
        if not isinstance(image.data, str):
            return image.data
    raise FileNotFoundError(f"image not found: {stem}.(png|jpg)")


def _render_generated_cell(save_path, i, epoch):
    """Return a VBox with a "<run dir name> (epoch=N)" label above the generated image."""
    image = ipywidgets.Image(
        value=_load_generated_image(save_path, i, epoch),
        layout=ipywidgets.Layout(
            margin="0",
            object_fit="cover",
        ),
    )
    return ipywidgets.VBox(
        children=[ipywidgets.Label(f"{save_path.split('/')[-1]} ({epoch=})"), image],
        layout=ipywidgets.Layout(
            width="100%",
            height="100%",
        ),
    )


def render_output_test_images(save_path_and_epoch_list):
    """Build a grid comparing generated test images across runs/epochs.

    The test characters come from ``train_info.json`` of the *first* run
    (``["test"]["chars"]``). For every test character ``i`` the grid gets one
    row per ``(save_path, epoch)`` pair:

    * left column: the formula and radical-box plot (first row only, an empty
      label after that);
    * right column: ``<save_path>/generated/test_{i:0>2}_{epoch:0>4}.png``,
      with ``.jpg`` as a fallback.

    Args:
        save_path_and_epoch_list: list of ``(save_path, epoch)`` tuples.

    Returns:
        ipywidgets.GridBox with two columns sized 2fr : 8fr.

    Raises:
        FileNotFoundError: if a generated image exists in neither format.
    """
    # JSON is UTF-8 by spec and the char data contains CJK, so don't rely on the locale.
    with open(pathstr(save_path_and_epoch_list[0][0], "train_info.json"), encoding="utf-8") as f:
        train_info = json.load(f)

    test_chars = [Char.from_dict(char) for char in train_info["test"]["chars"]]

    cells = []
    for i, char in enumerate(test_chars):
        for j, (save_path, epoch) in enumerate(save_path_and_epoch_list):
            cells.append(_render_forecast_cell(char) if j == 0 else ipywidgets.Label())
            cells.append(_render_generated_cell(save_path, i, epoch))

    return ipywidgets.GridBox(
        cells,
        layout=ipywidgets.Layout(
            width="100%",
            grid_template_columns="2fr 8fr",
            grid_gap="8px",
        ),
    )

# render_output_test_images([
#     ("output/rs normal ETL8G_400", 500),
#     ("output/rs normal ETL8G_400", 1000),
#     ("output/rs ignore_writer ETL8G_400", 1000),
#     ("output/rs ignore_writer ETL8G", 1000),
# ])

In [9]:
def render_train_output(save_path: str, test_epochs: list[int]):
    """Summarise one training run: its path, loss curve and generated test images.

    Reads ``<save_path>/train_info.json`` and plots its ``"loss"`` list against
    the list index. The image grid comes from ``render_output_test_images``,
    with one ``(save_path, epoch)`` entry per epoch in *test_epochs*.

    Returns:
        ipywidgets.VBox of [title label, loss plot, test image grid].
    """
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(pathstr(save_path, "train_info.json"), encoding="utf-8") as f:
        loss = json.load(f)["loss"]

    title_el = ipywidgets.Label(save_path)

    loss_el = ipywidgets.Output(layout=ipywidgets.Layout(width="50%"))
    # A dedicated figure, so nothing left on pyplot's current figure leaks into the loss plot.
    fig, ax = plt.subplots()
    ax.set_title("loss")
    ax.plot(range(len(loss)), loss)
    loss_el.append_display_data(fig)
    plt.close(fig)

    test_images = render_output_test_images([(save_path, e) for e in test_epochs])

    return ipywidgets.VBox([title_el, loss_el, test_images])

VBox(children=(Label(value='/data/nosaka/project/RadicalStylist/output/rs test encode_type=1'), Output(layout=…

### 出力

In [None]:
# render_train_output(pathstr("output/rs test encode_type=1"), [1000])

In [24]:
# render_train_output(pathstr("output/rs test encode_type=2"), [1000])

In [25]:
# render_train_output(pathstr("output/rs test encode_type=3"), [1000])

In [31]:
# Compare the generated test images of three runs at epoch 1000.
# This call is the cell's last bare expression, so the notebook displays the
# returned GridBox.
render_output_test_images([
    ("./output/rs ignore_writer ETL8G_400", 1000),
    ("./output/test character_encode/encode_type=2", 1000),
    ("./output/test character_encode/encode_type=3", 1000),
])

GridBox(children=(VBox(children=(Label(value='何 = 亻 + 丁_1 + 口 + 丁_2'), Output(layout=Layout(width='50%'), outp…