## 準備

In [1]:
import json
import math
import os
from collections import namedtuple
from dataclasses import dataclass
from glob import glob
from typing import Optional

import IPython

from matplotlib import pyplot as plt, patches
from tqdm import tqdm

import numpy as np

from utility import pathstr, char2code

In [2]:
import io

import IPython
import ipywidgets

import numpy as np

import PIL


def _as_image_widget(image):
    """Convert one supported image value into an ``ipywidgets.Image``.

    Supported values are a file path (``str``), a ``numpy.ndarray`` and a
    ``PIL.Image.Image``. Arrays and PIL images are converted to RGB and
    encoded as PNG in memory.

    Raises:
        FileNotFoundError: a ``str`` was given but IPython did not load it
            (``data`` is still a ``str`` after construction).
        TypeError: the value is none of the supported kinds.
    """
    # ``import PIL`` by itself does not guarantee the ``PIL.Image`` submodule
    # is loaded, so import it explicitly where it is used.
    import PIL.Image

    if isinstance(image, str):
        # path
        loaded = IPython.display.Image(image)
        # NOTE(review): relies on IPython leaving ``data`` as the given str
        # when the path does not exist - confirm for the IPython version used.
        if isinstance(loaded.data, str):
            raise FileNotFoundError(f"image not found: {loaded.data}")
        payload = loaded.data
    elif isinstance(image, (np.ndarray, PIL.Image.Image)):
        # ndarray or PIL image: both end up as an in-memory RGB PNG
        if isinstance(image, np.ndarray):
            image = PIL.Image.fromarray(image)
        buffer = io.BytesIO()
        image.convert("RGB").save(buffer, format="png")
        payload = buffer.getvalue()
    else:
        raise TypeError(f"unsupported image: {image}")

    return ipywidgets.Image(
        value=payload,
        layout=ipywidgets.Layout(margin="0", width="100%", object_fit="contain"),
    )


def render_images(images, columns=None, scroll=False):
    """Lay out images in an ``ipywidgets.GridBox``.

    Args:
        images: list whose entries are either an image or an
            ``(image, title)`` tuple. An image is a file path, a
            ``numpy.ndarray`` or a ``PIL.Image.Image``.
        columns: number of grid columns; defaults to one column per image.
        scroll: accepted for compatibility; currently unused.

    Returns:
        The ``GridBox`` containing one cell per entry; titled entries are
        wrapped in a ``VBox`` with a ``Label`` above the image.

    Raises:
        TypeError: ``images`` is not a list, or an entry is unsupported.
        FileNotFoundError: a path entry could not be loaded.
    """
    if not isinstance(images, list):
        raise TypeError(f"images must be a list, got {type(images).__name__}")

    # ``repeat(0, 1fr)`` is not a valid grid template, so never go below 1.
    columns = columns or len(images) or 1

    children = []
    for entry in images:
        if isinstance(entry, tuple):
            image, title = entry
        else:
            image, title = entry, None

        widget = _as_image_widget(image)

        if title is None:
            children.append(widget)
        else:
            children.append(ipywidgets.VBox(
                [ipywidgets.Label(title), widget],
                layout=ipywidgets.Layout(align_items="center"),
            ))

    return ipywidgets.GridBox(
        children=children,
        layout=ipywidgets.Layout(
            width="100%",
            height="fit-content",
            grid_template_columns=f"repeat({columns}, 1fr)",
            align_items="flex-end",
            grid_gap="8px",
        ),
    )

In [3]:
from matplotlib import pyplot as plt

from PIL import Image as PilImage, ImageDraw as PilImageDraw, ImageFont as PilImageFont

from radical import Radical, BoundingBox, ClusteringLabel


def plot_forecast(ax, image_size, char, radicals):
    """Draw the expected radical layout of a character onto ``ax``.

    A white ``image_size`` x ``image_size`` canvas with a black border is
    created; for every radical whose position is a ``BoundingBox`` a red
    rectangle is drawn and the radical name is annotated at its top-left
    corner. Radicals with no position or a ``ClusteringLabel`` position are
    skipped.

    Args:
        ax: matplotlib axes to draw on; its axis decorations are turned off.
        image_size: side length of the square canvas in pixels.
        char: accepted for compatibility; not used by the drawing code.
        radicals: iterable of objects exposing ``name`` and ``position``.
            Bounding-box coordinates are multiplied by ``image_size``, so
            they are presumably fractions of the canvas side - TODO confirm.

    Raises:
        Exception: a radical has a position of an unknown type.
    """
    canvas = PilImage.new("RGB", (image_size, image_size), color=(255, 255, 255))
    draw = PilImageDraw.Draw(canvas)

    draw.rectangle((0, 0, image_size - 1, image_size - 1), outline=(0, 0, 0), width=1)

    for radical in radicals:
        position = radical.position

        if position is None:
            continue

        if isinstance(position, BoundingBox):
            left = int(position.left * image_size)
            right = int(position.right * image_size)
            top = int(position.top * image_size)
            bottom = int(position.bottom * image_size)

            draw.rectangle((left, top, right, bottom), outline=(255, 0, 0))
            ax.annotate(radical.name, (left, top), ha="left", va="top", fontsize=16, fontweight="bold", color="red")
            continue

        if isinstance(position, ClusteringLabel):
            # Clustering labels carry no geometry to draw.
            continue

        raise Exception(f"unknown position: {position}")

    ax.set_axis_off()
    ax.imshow(canvas, cmap="gray")


# plot_forecast(plt.gca(), 512, "倹", [
#     Radical("亻", BoundingBox(left=0.078125, right=0.3125, top=0.140625, bottom=0.875)),
#     Radical("人", BoundingBox(left=0.296875, right=0.859375, top=0.109375, bottom=0.453125)),
#     Radical("一", BoundingBox(left=0.4375, right=0.703125, top=0.34375, bottom=0.390625)),
#     Radical("口", BoundingBox(left=0.375, right=0.75, top=0.453125, bottom=0.671875)),
#     Radical("人", BoundingBox(left=0.328125, right=0.84375, top=0.390625, bottom=0.890625)),
# ])

In [4]:
from radical import BoundingBox


def parse_radical(dct):
    """Build a ``Radical`` from a dict, accepting the legacy flat layout.

    ``Radical.from_dict`` is tried first. If it raises, the dict is read as
    the legacy format with top-level ``name``, ``left``, ``right``, ``top``,
    ``bottom`` and ``idx_`` keys.

    Raises:
        KeyError: the dict matches neither format (a legacy key is missing).
    """
    try:
        return Radical.from_dict(dct)
    except Exception:  # legacy; do not swallow KeyboardInterrupt/SystemExit
        return Radical(
            name=dct["name"],
            position=BoundingBox(
                left=dct["left"],
                right=dct["right"],
                top=dct["top"],
                bottom=dct["bottom"],
            ),
            idx_=dct["idx_"],
        )


def _load_generated_image(save_path, index, epoch_label):
    """Load ``generated/test_<index>_<epoch>`` as PNG, falling back to JPG."""
    stem = f"test_{index:0>2}_{epoch_label}"
    for extension in (".png", ".jpg"):
        image = IPython.display.Image(pathstr(save_path, "generated", stem + extension))
        # ``data`` is still the path string when nothing was loaded.
        if not isinstance(image.data, str):
            return image
    raise FileNotFoundError(f"image not found: {image.data[:-len(extension)]}.(png|jpg)")


def _render_forecast_cell(name, radicals):
    """Build the left-hand cell: a caption plus the expected radical layout."""
    forecast_output = ipywidgets.Output(layout=ipywidgets.Layout(width="50%"))

    # Use an explicit figure so an unrelated current figure is never drawn
    # on or closed, and always release it even if drawing fails.
    fig, ax = plt.subplots()
    try:
        plot_forecast(ax, 512, name, radicals)
        fig.tight_layout()
        forecast_output.append_display_data(fig)
    finally:
        plt.close(fig)

    label = ipywidgets.Label(f"{name} = {' + '.join(r.name for r in radicals)}")
    return ipywidgets.VBox([label, forecast_output])


def render_output_test_images(save_path_and_epoch_list):
    """Show generated test images of one or more runs next to the expected layout.

    The test definitions and the total epoch count (used to zero-pad epoch
    numbers in file names) are read from ``train_info.json`` of the FIRST
    entry only; every run is assumed to share them.

    Args:
        save_path_and_epoch_list: non-empty sequence of
            ``(save_path, epoch)`` pairs.

    Returns:
        A two-column ``GridBox``. For each test entry the first row holds the
        forecast cell and the first run's image; each further run gets a row
        with an empty label in the left column.

    Raises:
        FileNotFoundError: a generated image exists as neither PNG nor JPG.
    """
    # utf-8 is given explicitly so non-ASCII radical names load regardless of
    # the platform default encoding.
    with open(pathstr(save_path_and_epoch_list[0][0], "train_info.json"), encoding="utf-8") as f:
        train_info = json.load(f)

    num_epochs_digit = len(str(train_info["epochs"]))

    cells = []
    for i, el in enumerate(train_info["test"]["radicallists"]):
        radicalname = el["name"]
        radicallist = [parse_radical(r) for r in el["elements"]]

        for j, (save_path, epoch) in enumerate(save_path_and_epoch_list):
            if j == 0:
                cells.append(_render_forecast_cell(radicalname, radicallist))
            else:
                # Keep the grid aligned: the forecast is shown once per test entry.
                cells.append(ipywidgets.Label())

            image = _load_generated_image(save_path, i, str(epoch).zfill(num_epochs_digit))

            cells.append(ipywidgets.VBox(
                children=[
                    ipywidgets.Label(
                        f"{save_path.split('/')[-1]} ({epoch=})",
                        # ``overflow`` is a Layout trait, not a Label trait.
                        layout=ipywidgets.Layout(overflow="hidden"),
                    ),
                    ipywidgets.Image(
                        value=image.data,
                        layout=ipywidgets.Layout(
                            margin="0",
                            object_fit="cover",
                        ),
                    ),
                ],
                layout=ipywidgets.Layout(
                    width="100%",
                    height="100%",
                ),
            ))

    return ipywidgets.GridBox(
        cells,
        layout=ipywidgets.Layout(
            width="min(1024px, 100%)",
            grid_template_columns="2fr 8fr",
            grid_gap="8px",
        ),
    )


# render_output_test_images([
#     ("output/rs test encode_type=3 depth=binary-random", 1000),
# ])

In [5]:
def render_train_output(save_path: str, test_epochs: list[int]):
    """Summarise one training run as a widget.

    Args:
        save_path: run directory containing ``train_info.json``.
        test_epochs: epochs whose generated test images should be shown.

    Returns:
        A ``VBox`` holding, top to bottom, the run path as a title, the loss
        curve read from ``train_info.json`` and the test-image grid produced
        by ``render_output_test_images``.
    """
    with open(pathstr(save_path, "train_info.json"), encoding="utf-8") as f:
        loss = json.load(f)["loss"]

    title_el = ipywidgets.Label(
        save_path,
        layout=ipywidgets.Layout(
            width="calc(100% - 2 * var(--jp-widgets-margin))",
            height="fit-content",
            white_space="normal",
        ),
    )

    loss_el = ipywidgets.Output(layout=ipywidgets.Layout(width="50%"))
    # Use an explicit figure so an unrelated current figure is never drawn on
    # or closed, and always release it even if plotting fails.
    fig, ax = plt.subplots()
    try:
        ax.set_title("loss")
        ax.plot(range(len(loss)), loss)
        loss_el.append_display_data(fig)
    finally:
        plt.close(fig)

    test_images = render_output_test_images([(save_path, e) for e in test_epochs])

    return ipywidgets.VBox(
        [title_el, loss_el, test_images],
        layout=ipywidgets.Layout(width="min(1024px, 100%)"),
    )

## 結果

### クラスタリング

In [6]:
# render_train_output(pathstr("./output/test writer_mode=dataset/ETL8G_400+KVG_radical(legacy) (encode_type=3, radical_depth=max)"), [1000])

In [7]:
# render_train_output(pathstr("./output/test writer_mode=dataset/ETL8G_400+KVG_radical(pad=4,sw=2) encode_type=bbox"), [1000])

In [8]:
# render_train_output(pathstr("./output/test writer_mode=dataset/ETL8G+KVG_radical(pad=4,sw=2) (encode_type=3, radical_depth=max)"), [500])

In [9]:
# Of the 768 radical-embedding dimensions, 128 are used as a learned embedding of the clustering label
# render_train_output(pathstr("./output/test writer_mode=dataset/ETL8G_400+KVG_radical(pad=4,sw=2) encode_type=cl_0 clustering=256(legacy)"), [1000])

In [10]:
# Switched KVG to four padding variants
render_train_output(pathstr("./output/test writer_mode=dataset/ETL8G_400+KVG_radical(pad={4,8,12,16},sw=2) encode_type=cl_0 clustering=256(legacy)"), [1000])

VBox(children=(Label(value='/data/nosaka/project/RadicalStylist/output/test writer_mode=dataset/ETL8G_400+KVG_…

### SVG 合成

In [11]:
# English gloss of the note string below:
#   ETL8G: 141841 samples
#   Regular SVG: 2672 samples * 4
#     only characters made solely of radicals that appear in ETL8G
#     (close to the number of joyo kanji)
#   Composited SVG: 139169 samples * 4
#     only those with no pixel overlap at all at padding=4px, stroke_width=2px
#   The "* 4" for SVG counts the four margin variants (4px, 8px, 12px, 16px).
"""
ETL8G: 141841 件

SVG 正規: 2672 件 * 4
ETL8G に含まれる部首しか持たない字のみ (常用漢字の個数に近い)

SVG 合成: 139169 件 * 4
ピクセルの被りが padding=4px, stroke_width=2px で 1 つも起こらないもののみ

SVG の * 4 は 4px, 8px, 12px, 16px の 4 種の余白を持つ画像を用意した分
"""
# render_train_output(pathstr("./output/test writer_mode=dataset/ETL8G(no_bg)+KVG_radical(pad={4,8,12,16},sw=2)+KVG_C(pad={4,8,12,16},sw=2,n_limit=140000) encode_type=cl_0"), [100])
render_train_output(pathstr("./output/test writer_mode=dataset/ETL8G+KVG_radical(pad={4,8,12,16},sw=2)+KVG_C(pad={4,8,12,16},sw=2,n_limit=139169) encode_type=cl_0"), [120])

VBox(children=(Label(value='/data/nosaka/project/RadicalStylist/output/test writer_mode=dataset/ETL8G+KVG_radi…

In [12]:
# English gloss of the note string below: ETL8G multiplied by 4.
"""
ETL8G を 4 倍
"""
# render_train_output(pathstr("./output/test writer_mode=dataset/ETL8G(no_bg)*4+KVG_radical(pad={4,8,12,16},sw=2)+KVG_C(pad={4,8,12,16},sw=2,n_limit=139169) encode_type=cl_0"), [50])

'\nETL8G を 4 倍\n'

### MSE だと小さい部首より大きい部首のほうが優先されるかもなので変えてみる

In [13]:
# Try per-pixel weights: pass them through the VAE and multiply them into the loss
# (the string below is the training-loop snippet kept as a note; it is not executed)
"""
for images, radicallists, writerindices in pbar:
    weights = torch.ones_like(images)
    weights = self.vae.encode(weights)

    # # https://stackoverflow.com/questions/59831211/neighbours-of-a-cell-in-matrix-pytorch
    # box = torch.ones((3, 3), dtype=images.dtype, device=images.device, requires_grad=False)  
    # box = box / box.sum()
    # box = box[None, None, ...].repeat(images.size(1), 1, 1, 1)
    # weights = F.conv2d(images + (3 / 255), box, padding=1, groups=images.size(1))
    # weights = F.batch_norm(weights, torch.zeros(3), torch.ones(3))
    # weights = self.vae.encode(weights)

    ...

    loss = torch.mean(((noise - predicted_noise) ** 2) * weights)
"""

# render_train_output(pathstr("./output/test writer_mode=dataset/ETL8G_400+KVG_radical(pad={4,8,12,16},sw=2) encode_type=cl_0 loss=w_mse"), [1])

'\nfor images, radicallists, writerindices in pbar:\n    weights = torch.ones_like(images)\n    weights = self.vae.encode(weights)\n\n    # # https://stackoverflow.com/questions/59831211/neighbours-of-a-cell-in-matrix-pytorch\n    # box = torch.ones((3, 3), dtype=images.dtype, device=images.device, requires_grad=False)  \n    # box = box / box.sum()\n    # box = box[None, None, ...].repeat(images.size(1), 1, 1, 1)\n    # weights = F.conv2d(images + (3 / 255), box, padding=1, groups=images.size(1))\n    # weights = F.batch_norm(weights, torch.zeros(3), torch.ones(3))\n    # weights = self.vae.encode(weights)\n\n    ...\n\n    loss = torch.mean(((noise - predicted_noise) ** 2) * weights)\n'

In [14]:
# Weight the loss by the reciprocal of the pixel sum
# (the snippet below sums over dims 1-3, giving one weight per batch element;
# the string is a note and is not executed)
"""
for images, radicallists, writerindices in pbar:
    weights = sum(images.size()) / images.sum(dim=(1, 2, 3)).to(device=self.device)
    
    ...

    loss = torch.sum(torch.mean(((noise - predicted_noise) ** 2), dim=(1, 2, 3)) * weights) / weights.sum()
"""

# render_train_output(pathstr("./output/test writer_mode=dataset/ETL8G_400+KVG_radical(pad={4,8,12,16},sw=2) encode_type=cl_0 loss=mse_w"), [1000])

'\nfor images, radicallists, writerindices in pbar:\n    weights = sum(images.size()) / images.sum(dim=(1, 2, 3)).to(device=self.device)\n    \n    ...\n\n    loss = torch.sum(torch.mean(((noise - predicted_noise) ** 2), dim=(1, 2, 3)) * weights) / weights.sum()\n'

In [15]:
# Side-by-side comparison at epoch 1000: the clustering=256(legacy) run vs. the loss=mse_w run.
# train_info.json is read from the first entry only.
render_output_test_images((
    (pathstr("./output/test writer_mode=dataset/ETL8G_400+KVG_radical(pad={4,8,12,16},sw=2) encode_type=cl_0 clustering=256(legacy)"), 1000),
    (pathstr("./output/test writer_mode=dataset/ETL8G_400+KVG_radical(pad={4,8,12,16},sw=2) encode_type=cl_0 loss=mse_w"), 1000),
))

GridBox(children=(VBox(children=(Label(value='何 = 亻 + 可'), Output(layout=Layout(width='50%'), outputs=({'outpu…

In [16]:
# English gloss of the note string below: increased the embedding dimension.
"""
埋め込みの次元を増やした
"""
# render_train_output(pathstr("./output/test writer_mode=dataset/ETL8G_400+KVG_radical(pad={4,8,12,16},sw=2) encode_type=cl_0 dim=1536"), [200])

'\n埋め込みの次元を増やした\n'

### VAE

#### NotoSans, NotoSerif

In [17]:
# render_train_output(pathstr("./output/test writer_mode=dataset vae=noto*2(n_random=65536,epoch=1)/ETL8G_400+KVG_radical(pad={4,8,12,16},sw=2) encode_type=cl_0"), [300])

In [18]:
# render_train_output(pathstr("./output/test writer_mode=dataset vae=noto*2(n_random=65536,epoch=80)/ETL8G*4+KVG_radical(pad={4,8,12,16},sw=2)+KVG_C(pad={4,8,12,16},sw=2,n_limit=139169) encode_type=cl_0"), [120])

#### さなりフォント

In [19]:
# render_train_output(pathstr("./output/test writer_mode=dataset vae=SanariFont001(n_random=262140,epoch=30)/ETL8G_400+KVG_radical(pad={4,8,12,16},sw=2) encode_type=cl_0"), [100])

In [20]:
# render_output_test_images((
#     (pathstr("./output/test writer_mode=dataset/ETL8G_400+KVG_radical(pad={4,8,12,16},sw=2) encode_type=cl_0 clustering=256(legacy)"), 1000),
#     (pathstr("./output/test writer_mode=dataset vae=SanariFont001(n_random=262140,epoch=30)/ETL8G_400+KVG_radical(pad={4,8,12,16},sw=2) encode_type=cl_0"), 1000),
# ))

### こどもの字を入れる

#### 試験

In [21]:
# English gloss of the note string below:
#   ETL8G_onlykanji_train (127657 samples)
#   Binarized children's kanji (3719 samples)
#   KVG (2672 * 4 samples)
"""
ETL8G_onlykanji_train (127657 件)
こどもの漢字の 2 値化 (3719 件)
KVG (2672 * 4 件)
"""
# render_train_output(pathstr("./output/nlp2024/nlp2024+KVG(pad={4,8,12,16},sw=2) radenc=cl_0"), [400])

'\nETL8G_onlykanji_train (127657 件)\nこどもの漢字の 2 値化 (3719 件)\nKVG (2672 * 4 件)\n'

In [22]:
# English gloss of the note string below:
#   ETL8G_onlykanji_train (127657 samples)
#   Binarized children's kanji (3719 samples)
#   KVG (2672 * 4 samples)
#   Composited KVG (30000 * 4 samples)
"""
ETL8G_onlykanji_train (127657 件)
こどもの漢字の 2 値化 (3719 件)
KVG (2672 * 4 件)
KVG 合成 (30000 * 4 件)
"""
# render_train_output(pathstr("./output/nlp2024/nlp2024+KVG(pad={4,8,12,16},sw=2)+KVG_C(pad={4,8,12,16},sw=2,n_limit=30000) radenc=cl_0"), [150])

'\nETL8G_onlykanji_train (127657 件)\nこどもの漢字の 2 値化 (3719 件)\nKVG (2672 * 4 件)\nKVG 合成 (30000 * 4 件)\n'

#### 本番

In [23]:
# render_train_output(pathstr("./output/nlp2024/character nlp2024(kana,kanji)"), [1])

In [34]:
# Show the title, loss curve and epoch-1000 test images for this run.
render_train_output(pathstr("./output/nlp2024/character(d=384)_2 nlp2024(kana,kanji)"), [1000])

VBox(children=(Label(value='/data/nosaka/project/RadicalStylist/output/nlp2024/character(d=384)_2 nlp2024(kana…

In [28]:
# English gloss of the note string below:
#   ETL9G_onlykanji_train(sheet=1-1000): 133281 samples
#   Children's kanji: 3719 samples
#   KVG: 9964 * 4 samples
"""
ETL9G_onlykanji_train(sheet=1-1000): 133281 件
こどもの漢字: 3719 件
KVG: 9964 * 4 件
"""
# render_train_output(pathstr("./output/nlp2024/radical nlp2024(kana,kanji)+KVG(pad={4,8,12,16},sw=2)"), [750])

'\nETL9G_onlykanji_train(sheet=1-1000): 133281 件\nこどもの漢字: 3719 件\nKVG: 9964 * 4 件\n'

In [35]:
# Show the title, loss curve and epoch-1000 test images for this run.
render_train_output(pathstr("./output/nlp2024/radical384 nlp2024(kana,kanji)+KVG(pad={4,8,12,16},sw=2)"), [1000])

VBox(children=(Label(value='/data/nosaka/project/RadicalStylist/output/nlp2024/radical384 nlp2024(kana,kanji)+…