In [1]:
import random
from colorutils import Color
from generation_utils import get_random_color_pair, get_tiled_option_cfgs
from configs import CFG1, TextBoxCFG, ImageCFG, Point, Margin
from generators import generate_data
from PIL import ImageFont
from random import randint, choice, uniform
from utils import split_sentence


# データ準備

In [2]:
from pathlib import Path
import pandas as pd

# Hold out a fixed validation subset of messages. A dedicated, seeded RNG keeps
# the split identical across kernel restarts (so validation messages never end
# up in training images generated in another session) without seeding the
# global `random` state that the generators below draw from.
VAL_SIZE = 100
VAL_SPLIT_SEED = 0

all_messages = pd.read_csv("texts/chatgpt_ruby.csv")

val_idx = random.Random(VAL_SPLIT_SEED).sample(all_messages.index.values.tolist(), VAL_SIZE)
val_idx_set = set(val_idx)  # O(1) membership instead of scanning a list per row
train_idx = [idx for idx in all_messages.index.values if idx not in val_idx_set]

messages = all_messages.loc[train_idx]
val_messages = all_messages.loc[val_idx]
names = pd.read_csv("texts/names.csv")

# rglob yields files in filesystem order; sort so the candidate lists (and any
# seeded draws from them) are the same on every machine.
bgimage_paths = sorted(Path("bgimages").rglob(pattern="*.png"))
fgimage_paths = sorted(Path("fgimages").rglob(pattern="*.png"))
messages.shape, val_messages.shape, names.shape, len(bgimage_paths), len(fgimage_paths)

((8984, 9), (100, 9), (6480, 5), 976, 354)

In [3]:
import glob

# glob returns entries in filesystem order, which varies between machines;
# sort so font selection by `choice` is reproducible.
message_fonts = sorted(glob.glob("fonts/fonts_message/*"))
ruby_fonts = sorted(glob.glob("fonts/fonts_ruby/*"))
len(message_fonts), len(ruby_fonts)

(11, 10)

# Patterns

## Niji

In [4]:
def generate_niji():
    """Generate one synthetic "Niji"-style dialogue screenshot.

    Layout: an optional white-text name plate at (180, 650), a wide translucent
    message window spanning x 60..1860 / y 700..990, 0-2 character sprites and,
    sometimes, a grid of option boxes tiled into x 270..1660 / y 100..600.
    Only the plain ``"text"`` column is used (no ruby variants).

    Returns:
        tuple: ``(image, gt_obj)`` where ``image`` is ``generate_data(cfg).image``
        and ``gt_obj`` is ``{"gt_parse": output.to_gt_parse()}``.
    """
    cfg = CFG1()

    # --- Name plate: present in about half of the samples ---
    has_name = choice([0, 1]) > 0
    if not has_name:
        cfg.namebox = None
    else:
        name_box = cfg.namebox
        name_box.text = names.sample(1)["text"].values[0]
        name_box.font = ImageFont.truetype(font=choice(message_fonts), size=60)
        name_box.font_hex = "#ffffff"
        name_box.bg_hex = Color(hsv=(uniform(0, 360), 0.7, 0.7)).hex
        name_box.bg_alpha = randint(160, 220)
        name_box.tl = Point(180, 650)
        # Height is the box's minheight, read after the font has been set.
        name_box.br = Point(840, name_box.tl.y + name_box.minheight)

    # --- Message window ---
    msg_box = cfg.msgbox
    msg_box.text = messages.sample(1)["text"].values[0]
    msg_font_path = choice(message_fonts)
    msg_box.font = ImageFont.truetype(font=msg_font_path, size=randint(50, 60))
    msg_box.bg_alpha = randint(160, 220)
    # NOTE(review): s=0 presumably requests a zero-saturation (grey) pair — confirm in generation_utils.
    msg_box.font_hex, msg_box.bg_hex = get_random_color_pair(s=0)
    msg_box.margin = Margin(top=80, right=200, left=200, bottom=10)
    msg_box.tl = Point(60, 700)
    msg_box.br = Point(1920 - 60, 990)

    # --- Background image ---
    cfg.bg_cfg = ImageCFG(path=choice(bgimage_paths))

    # --- Character sprites: 0-2, spread across the frame width ---
    n_chars = choice([0, 1, 2])
    slot_width = cfg.W // (1 + n_chars)
    characters = []
    for slot in range(n_chars):
        characters.append(ImageCFG(path=choice(fgimage_paths), tl=Point(int(slot_width * (slot + 0.5)), 0)))
    cfg.character_cfg_list = characters

    # --- Option (choice) boxes ---
    nrow = choice([0, 1, 2, 3, 4])
    ncol = choice([0, 1, 2])
    grid_cells = nrow * ncol
    # Either fill the whole grid or leave one cell empty.
    n_options = max(0, choice([grid_cells, grid_cells - 1]))

    optionbox_list = []
    for _ in range(n_options):
        # Options show only the first sentence of a sampled message.
        first_sentence = split_sentence(messages.sample(1)["text"].values[0])[0]
        option_box = TextBoxCFG()
        option_box.text = first_sentence
        option_box.bg_alpha = randint(180, 255)
        option_box.tl = Point(270, 200)
        option_box.br = Point(1660, 600)
        option_box.font_hex = Color(hsv=(uniform(0, 360), uniform(0.6, 0.8), 1)).hex
        optionbox_list.append(option_box)

    # --- Tile the options on an nrow x ncol grid in the upper area ---
    if optionbox_list:
        cfg.optionbox_list = get_tiled_option_cfgs(
            nrow=nrow,
            ncol=ncol,
            tl=Point(x=270, y=100),
            br=Point(x=1660, y=600),
            cfgs=optionbox_list,
        )

    output = generate_data(cfg)
    return output.image, {"gt_parse": output.to_gt_parse()}

## FGO

In [5]:
def generate_fgo():
    """Generate one synthetic "FGO"-style dialogue screenshot.

    Layout: an optional name plate (near-white text on a blue panel) above a
    translucent dark-blue message window near the bottom of the frame, 0-2
    character sprites and, half the time, a single column of 1-3 near-black
    option boxes. Name, message and option texts are drawn at random from the
    plain / hiragana-ruby / katakana-ruby text columns.

    Returns:
        tuple: ``(image, gt_obj)`` where ``image`` is ``generate_data(cfg).image``
        and ``gt_obj`` is ``{"gt_parse": output.to_gt_parse()}``.
    """
    # Text variants (plain / hiragana ruby / katakana ruby) to draw from.
    text_columns = ["text", "text_ruby_hiragana", "text_ruby_katakana"]
    cfg = CFG1()

    # Name plate: present in about half of the samples
    n_names = choice([0, 1])
    if n_names > 0:
        cfg.namebox.text = names.sample(1)[choice(text_columns)].values[0]
        cfg.namebox.font = ImageFont.truetype(font=choice(message_fonts), size=60)
        cfg.namebox.font_hex = Color(hsv=(0, 0, uniform(0.85, 1.0))).hex
        cfg.namebox.bg_hex = Color(hsv=(uniform(180, 220), uniform(0.75, 1.0), uniform(0.6, 0.8))).hex
        cfg.namebox.bg_alpha = randint(180, 220)
        cfg.namebox.tl = Point(170, 700)
        cfg.namebox.br = Point(830, cfg.namebox.tl.y + cfg.namebox.minheight)
    else:
        cfg.namebox = None

    # Message window: near-white text on a dark blue panel
    cfg.msgbox.text = messages.sample(1)[choice(text_columns)].values[0]
    cfg.msgbox.font = ImageFont.truetype(font=choice(message_fonts), size=randint(50, 60))
    cfg.msgbox.bg_alpha = randint(180, 220)
    cfg.msgbox.font_hex = Color(hsv=(0, 0, uniform(0.85, 1.0))).hex
    cfg.msgbox.bg_hex = Color(hsv=(uniform(190, 230), uniform(0.75, 1.0), uniform(0.15, 0.35))).hex

    cfg.msgbox.margin = Margin(top=45, right=90, left=90, bottom=10)
    cfg.msgbox.tl = Point(190, 790)
    cfg.msgbox.br = Point(1920-190, 1060)

    # Background image
    cfg.bg_cfg = ImageCFG(path=choice(bgimage_paths))

    # Character sprites: 0-2, spread across the frame width
    n_fg_images = choice([0, 1, 2])
    cfg.character_cfg_list = [
        ImageCFG(path=choice(fgimage_paths), tl=Point(int((cfg.W // (1+n_fg_images))*(i+0.5)), 0))
        for i in range(n_fg_images)]

    # Options: either none (ncol == 0) or one column of 1-3 boxes
    optionbox_list = []
    nrow = choice([1, 2, 3])
    ncol = choice([0, 1])

    for _ in range(nrow*ncol):
        # Options show only the first sentence of a sampled message.
        option_text = split_sentence(messages.sample(1)[choice(text_columns)].values[0])[0]
        option_cfg = TextBoxCFG()
        option_cfg.text = option_text
        option_cfg.bg_alpha = randint(220, 255)
        option_cfg.tl = Point(350, 0)
        option_cfg.font_hex = Color(hsv=(0, 0, uniform(0.85, 1.0))).hex
        option_cfg.font = ImageFont.truetype(font=choice(message_fonts), size=randint(45, 55))
        option_cfg.ruby_font = ImageFont.truetype(font=choice(ruby_fonts), size=randint(12, 18))
        option_cfg.bg_hex = Color(hsv=(0, 0, uniform(0.0, 0.15))).hex
        # Read minheight only after font/ruby_font are assigned, as the namebox code
        # does. It used to be read before the fonts were set.
        # NOTE(review): assumes minheight is derived from the box's fonts.
        option_cfg.br = Point(1550, option_cfg.minheight*3)
        optionbox_list.append(option_cfg)

    # Tile layout
    if len(optionbox_list) > 0:
        # Either shrink the font to fit (1/3 of the time) or keep lines unwrapped.
        fit_font = choice([True, False, False])
        cfg.optionbox_list = get_tiled_option_cfgs(
            nrow=nrow, ncol=ncol,
            tl=Point(x=randint(300, 400), y=randint(80, 160)),
            br=Point(x=randint(1520, 1620), y=randint(600, 700)),
            cfgs=optionbox_list,
            fit_font=fit_font,
            nowrap=not fit_font
            )

    output = generate_data(cfg)
    gt_obj = {"gt_parse": output.to_gt_parse()}
    return output.image, gt_obj

## Random1
名前box左上パターン

In [6]:
def generate_random1():
    """Generate one synthetic dialogue screenshot with a randomised layout ("Random1").

    Almost every geometric and colour parameter is drawn at random:

    * name plate (about 50%): top-left at x in [50, 200] or [740, 810],
      y in [500, 700], random colour pair, optional centring;
    * message window (7 in 8 samples): starts near the name plate's bottom
      edge (-20..+50 px) when there is one, otherwise at y in [600, 800];
    * 0-2 character sprites;
    * options: an ``nrow x ncol`` grid (``ncol`` may be 0) tiled into the area
      above the name plate / message window.

    Name, message and option texts are drawn at random from the plain /
    hiragana-ruby / katakana-ruby text columns.

    Returns:
        tuple: ``(image, gt_obj)`` where ``image`` is ``generate_data(cfg).image``
        and ``gt_obj`` is ``{"gt_parse": output.to_gt_parse()}``.
    """
    # Text variants (plain / hiragana ruby / katakana ruby) to draw from.
    text_columns = ["text", "text_ruby_hiragana", "text_ruby_katakana"]
    cfg = CFG1()

    # Name plate
    n_names = choice([0, 1])
    if n_names > 0:
        cfg.namebox.text = names.sample(1)[choice(text_columns)].values[0]
        cfg.namebox.font = ImageFont.truetype(font=choice(message_fonts), size=randint(45, 60))
        font_hex, bg_hex = get_random_color_pair()
        cfg.namebox.font_hex = font_hex
        cfg.namebox.bg_hex = bg_hex
        cfg.namebox.centering = choice([True, False])
        cfg.namebox.bg_alpha = randint(180, 220)
        cfg.namebox.tl = Point(choice([randint(50, 200), randint(740, 810)]), randint(500, 700))
        cfg.namebox.br = Point(
            cfg.namebox.tl.x + randint(300, 500),
            cfg.namebox.tl.y + cfg.namebox.minheight + randint(0, 20)
            )
    else:
        cfg.namebox = None

    # Message window (omitted in 1 of 8 samples)
    n_messages = choice([0, 1, 1, 1, 1, 1, 1, 1])
    if n_messages > 0:
        cfg.msgbox.text = messages.sample(1)[choice(text_columns)].values[0]
        cfg.msgbox.font = ImageFont.truetype(font=choice(message_fonts), size=randint(45, choice([60, 60, 60, 80])))
        cfg.msgbox.bg_alpha = randint(180, 220)
        font_hex, bg_hex = get_random_color_pair()
        cfg.msgbox.font_hex = font_hex
        cfg.msgbox.bg_hex = bg_hex
        cfg.msgbox.centering = choice([True, False])
        cfg.msgbox.margin = Margin(top=randint(10, 50), right=randint(20, 100), left=randint(30, 100), bottom=10)
        cfg.msgbox.tl = Point(randint(50, 400), randint(600, 800) if cfg.namebox is None else cfg.namebox.br.y + randint(-20, 50))
        cfg.msgbox.br = Point(randint(1520, 1870), randint(cfg.msgbox.tl.y + cfg.msgbox.minheight, 1060))
    else:
        cfg.msgbox = None

    # Background image
    cfg.bg_cfg = ImageCFG(path=choice(bgimage_paths))

    # Character sprites: 0-2, spread across the frame width
    n_fg_images = choice([0, 1, 2])
    cfg.character_cfg_list = [
        ImageCFG(path=choice(fgimage_paths), tl=Point(int((cfg.W // (1+n_fg_images))*(i+0.5)), 0))
        for i in range(n_fg_images)]

    # Options
    optionbox_list = []
    nrow = choice([1, 2, 2, 3, 3, 4])
    ncol = choice([0, 1, 1, 1, 2])

    # One shared colour pair, unless every option gets its own (colorful_option_box).
    font_hex, bg_hex = get_random_color_pair()
    colorful_option_box = choice([True, False])
    centering_option_box = choice([True, False])
    for _ in range(nrow*ncol):
        # Options show only the first sentence of a sampled message.
        option_text = split_sentence(messages.sample(1)[choice(text_columns)].values[0])[0]
        option_cfg = TextBoxCFG()
        option_cfg.text = option_text
        option_cfg.bg_alpha = randint(180, 220)
        if colorful_option_box:
            font_hex, bg_hex = get_random_color_pair()
        option_cfg.centering = centering_option_box
        option_cfg.font_hex = font_hex
        option_cfg.bg_hex = bg_hex
        option_cfg.tl = Point(randint(50, 300), 0)
        right_x = randint(1620, 1870)  # drawn here to keep the original RNG draw order
        option_cfg.font = ImageFont.truetype(font=choice(message_fonts), size=randint(45, 55))
        option_cfg.ruby_font = ImageFont.truetype(font=choice(ruby_fonts), size=randint(12, 18))
        option_cfg.margin = Margin(
            top=randint(10, 20), right=randint(10, 50), left=randint(20, 50), bottom=5)
        # Read minheight only after font, ruby_font and margin are assigned, as the
        # namebox/msgbox code does. It used to be read before they were set.
        # NOTE(review): assumes minheight is derived from fonts and margin.
        option_cfg.br = Point(right_x, option_cfg.minheight*3)
        optionbox_list.append(option_cfg)

    # Tile layout: options go above the topmost of name plate / message window.
    if len(optionbox_list) > 0:
        if cfg.namebox is not None:
            bottomy = cfg.namebox.tl.y
        elif cfg.msgbox is not None:
            bottomy = cfg.msgbox.tl.y
        else:
            bottomy = 1080
        # Either shrink the font to fit (1/3 of the time) or keep lines unwrapped.
        fit_font = choice([True, False, False])
        cfg.optionbox_list = get_tiled_option_cfgs(
            nrow=nrow, ncol=ncol,
            tl = Point(randint(50, 300), randint(10, 20 if nrow > 2 else 200)),
            br = Point(randint(1620, 1870), bottomy - randint(10, 20 if nrow > 2 else 200)),
            cfgs=optionbox_list,
            fit_font=fit_font,
            nowrap=not fit_font
            )

    output = generate_data(cfg)
    gt_obj = {"gt_parse": output.to_gt_parse()}
    return output.image, gt_obj

# 実行

In [7]:
import json
import datasets
from datasets import Dataset
from pathlib import Path
from tqdm import tqdm

# Layout generator to sample from: generate_niji, generate_fgo or generate_random1.
generator = generate_random1
n_samples = 100
output_image_dir = Path("output_images")
# Set to an int (e.g. 200) to pack every `flush_every` generated images into a
# Hugging Face dataset saved as output_datasets/dataset_{i}; None keeps only
# the PNG files (images that do not fill a final batch are not flushed).
flush_every = None

# image.save() does not create missing parent directories.
output_image_dir.mkdir(parents=True, exist_ok=True)

images, gts = [], []
for i in tqdm(range(1, n_samples + 1)):
    try:
        image, gt = generator()
        output_image_path = f"{output_image_dir}/{i:05d}.png"
        image.save(output_image_path)
        images.append(output_image_path)
        gts.append(gt)
    except Exception as e:
        # Report which sample failed, then stop instead of silently skipping it.
        print(i, e)
        raise

    if flush_every is not None and len(images) == flush_every:
        df = pd.DataFrame(
            [[img, json.dumps(gt, ensure_ascii=False)] for img, gt in zip(images, gts)],
            columns=["image", "ground_truth"],
        )
        ds = Dataset.from_pandas(df)
        ds = ds.cast_column("image", datasets.Image())
        ds.save_to_disk(f"output_datasets/dataset_{i}")
        images, gts = [], []

100%|██████████| 100/100 [01:52<00:00,  1.13s/it]


In [16]:
# Merge the 200-image shards written by the generation loop
# (output_datasets/dataset_200 ... dataset_2000) into one dataset.
shard_dirs = [f"output_datasets/dataset_{end}" for end in range(200, 2001, 200)]
dss = [datasets.load_from_disk(shard_dir) for shard_dir in shard_dirs]
ds = datasets.concatenate_datasets(dss)
ds.save_to_disk("output_datasets/dataset_fgo")

Saving the dataset (0/9 shards):   0%|          | 0/2000 [00:00<?, ? examples/s]

In [59]:
# ds = datasets.load_from_disk("./output_datasets/dataset_niji_val/")
# ds[7]

# OLD

# Random(OLD)

In [None]:

def generate_random():
    """Legacy ("OLD") random-layout generator, kept for reference only.

    Written against an older, flat configuration API (``cfg.name_text``,
    ``cfg.namebox_tl``, ``cfg.namebox_minheight()``, ``cfg.optionbox_tl_list``,
    ...) and renders through ``create_image`` rather than ``generate_data``.

    NOTE(review): ``create_image`` is not imported or defined in any visible
    cell, and this cell was never executed (``In [None]``), so calling the
    function as-is presumably raises ``NameError``. The flat ``cfg.*``
    attributes may also not exist on the current ``CFG1`` (the newer
    generators use ``cfg.namebox`` / ``cfg.msgbox`` objects). Confirm before
    reuse.

    Returns:
        tuple: ``(image, gt_obj)`` where ``gt_obj`` contains ``gt_parse``
        (``messages``, ``names``, ``options``) and a ``meta`` dict with a fixed
        version ``"0.1.0"``, split ``"train"`` and the canvas size.
    """
    cfg = CFG1()

    # Name: set for ~75% of samples. The namebox geometry below is computed
    # either way, because the message box is positioned relative to it.
    if random.random() < .75:
        name, name_ruby_hira, name_ruby_kata = names.sample(1)[["text", "text_ruby_hiragana", "text_ruby_katakana"]].values[0]
        cfg.name_text = choice([name, name_ruby_hira, name_ruby_kata])
        cfg.name_font_path = choice(message_fonts)
        cfg.name_font_color, cfg.namebox_hex = get_random_color_pair()
        cfg.namebox_alpha = randint(80, 255)
    minheight = cfg.namebox_minheight()
    is_namebox_above = random.random() < 0.5  # whether the namebox goes above (True) or below (False) the msgbox
    # y: mid-screen when above; otherwise near the bottom of a 1080-px-tall frame,
    # leaving room for the box height.
    cfg.namebox_tl = (
        choice([randint(10, 100), randint(540, 700)]),
        randint(600, 750) if is_namebox_above else randint(1080-(minheight+100), 1080-(minheight+10))
        )
    cfg.namebox_br = (cfg.namebox_tl[0] + randint(400, 800), cfg.namebox_tl[1] + randint(minheight, minheight+10))

    # Message: text drawn from one of the three text columns. When the namebox
    # is above, the box starts around the namebox's bottom edge (offset by
    # msg_margin.top, +0..50 px). Otherwise its bottom edge lands around the
    # namebox's top (-30..+10 px).
    text, text_ruby_hira, text_ruby_kata = messages.sample(1)[["text", "text_ruby_hiragana", "text_ruby_katakana"]].values[0]
    cfg.text = choice([text, text_ruby_hira, text_ruby_kata])
    cfg.msg_font_path = choice(message_fonts)
    cfg.msg_font_color, cfg.msgbox_hex = get_random_color_pair()
    cfg.msgbox_alpha = randint(160, 255)
    cfg.msg_font_size = randint(36, 55)
    cfg.msg_ruby_font_size = randint(16, 24)
    cfg.msgbox_tl = (
        randint(20, 300),
        randint(cfg.namebox_br[1]-cfg.msg_margin.top, cfg.namebox_br[1]-cfg.msg_margin.top+50) if is_namebox_above else randint(540, 700)
        )
    cfg.msgbox_br = (
        randint(1600, 1880),
        randint(1000, 1055) if is_namebox_above else randint(cfg.namebox_tl[1]-30, cfg.namebox_tl[1]+10)
    )
    # Options: 0 (50%), 1 (30%) or 2 (20%) boxes. Two vertically stacked slots
    # are laid out around a random horizontal centre (`optionbox_centor`).
    # A single option uses the midpoint of the two slots; zero options clears them.
    n_options = random.choice([0, 0, 0, 0, 0, 1, 1, 1, 2, 2])
    optionbox_centor = randint(800, 1100)
    optionbox_halfwidth = randint(500, 700)
    optionbox_height = randint(cfg.optionbox_minheight(), cfg.optionbox_minheight()+10)
    optionbox_top = randint(120, 250)
    cfg.optionbox_tl_list = [
        (optionbox_centor-optionbox_halfwidth, optionbox_top),
        (optionbox_centor-optionbox_halfwidth, optionbox_top + optionbox_height+ randint(10, 50))
    ]
    cfg.optionbox_br_list = [
        (optionbox_centor+optionbox_halfwidth, cfg.optionbox_tl_list[0][1]+optionbox_height),
        (optionbox_centor+optionbox_halfwidth, cfg.optionbox_tl_list[1][1]+optionbox_height)
    ]
    if n_options == 1:
        cfg.optionbox_tl_list = [(optionbox_centor-optionbox_halfwidth, (cfg.optionbox_tl_list[0][1]+cfg.optionbox_tl_list[1][1])//2)]
        cfg.optionbox_br_list = [(optionbox_centor+optionbox_halfwidth, (cfg.optionbox_br_list[0][1]+cfg.optionbox_br_list[1][1])//2)]
    elif n_options == 0:
        cfg.optionbox_tl_list = []
        cfg.optionbox_br_list = []
    option_texts = []
    for _ in range(n_options):
        op_text, op_text_ruby_hira, op_text_ruby_kata = messages.sample(1)[["text", "text_ruby_hiragana", "text_ruby_katakana"]].values[0]
        op_target_text = random.choice([op_text, op_text_ruby_hira, op_text_ruby_kata])
        op_target_text = random.choice(split_sentence(op_target_text)) # an option uses just one (random) sentence of a message text
        option_texts.append(op_target_text)
    cfg.option_texts = option_texts
    cfg.option_font_path = random.choice(message_fonts)
    cfg.option_ruby_font_path = random.choice(ruby_fonts)
    cfg.option_font_color, cfg.optionbox_hex = get_random_color_pair()
    cfg.optionbox_alpha = random.randint(160, 255)
    cfg.option_font_size = random.randint(36, 60)
    cfg.option_ruby_font_size = random.randint(16, 24)

    # Background image
    bg_path = random.choice(bgimage_paths)
    cfg.bg_path = bg_path

    # Foreground (character) images: 0-2 distinct images at fixed positions.
    # NOTE(review): fg_tl_list is not assigned when n_fg_images == 0; presumably
    # CFG1 provides a default. Confirm.
    n_fg_images = random.choice([0, 1, 2])
    if n_fg_images == 1:
        cfg.fg_tl_list = [(500, 100)]
    elif n_fg_images == 2:
        cfg.fg_tl_list = [(200, 100), (1000, 100)]

    cfg.fg_pathlist = random.sample(fgimage_paths, n_fg_images)

    # NOTE(review): create_image is not defined/imported in any visible cell.
    output = create_image(cfg)
    gt_obj = {
        "gt_parse":{"messages":[output.text], "names":[output.name_text], "options":output.option_texts},
        "meta":{"version":"0.1.0", "split":"train", "image_size":{"width":cfg.W, "height":cfg.H}}
        }
    return output.image, gt_obj