In [22]:
import json
from datasets import load_dataset, Dataset, load_from_disk
from transformers import AutoTokenizer
from tqdm import tqdm
import numpy as np

In [23]:
# Directory passed as `cache_dir` when the tokenizer is downloaded from the Hub.
CACHE_DIR = "/huggingface/cache"

# NOTE(review): not referenced by any cell visible in this notebook; the
# tokenizer is loaded from a Hub repo id instead. Confirm before relying on it.
TOKENIZER_NAME = "./dart-tokenizer-20240219"

# NOTE(review): not referenced by any visible cell; the dataset is read from a
# local directory with `load_from_disk`. Presumably the upstream source — verify.
DATASET_NAME = "isek-ai/danbooru-tags-2023"

In [24]:
# NOTE(review): `trust_remote_code=True` runs tokenization code fetched from the
# Hub repo and no `revision` is pinned, so that code can change between runs.
# Consider pinning `revision=` to a reviewed commit.
tokenizer = AutoTokenizer.from_pretrained(
    "p1atdev/dart-tokenizer-v1", trust_remote_code=True, cache_dir=CACHE_DIR
)

A new version of the following files was downloaded from https://huggingface.co/p1atdev/dart-tokenizer-v1:
- tokenization_dart.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


tokenizer.json:   0%|          | 0.00/2.25M [00:00<?, ?B/s]

## Load filtered tags


In [25]:
# Load the filtered tag dataset from its on-disk copy in the working directory.
ds = load_from_disk("danbooru-tags-filtered-20240219")

In [26]:
ds

Dataset({
    features: ['id', 'copyright', 'character', 'artist', 'general', 'meta', 'rating', 'score', 'created_at'],
    num_rows: 5293004
})

# Filter by date


In [27]:
from dateutil.parser import parse as parse_date

In [28]:
ds["created_at"][:10]

['2005-10-16T00:19:32.000+09:00',
 '2005-10-28T12:38:48.000+09:00',
 '2005-10-03T09:55:56.000+09:00',
 '2005-10-26T07:26:19.000+09:00',
 '2005-10-04T05:05:17.000+09:00',
 '2005-10-03T12:19:24.000+09:00',
 '2005-10-16T07:18:57.000+09:00',
 '2005-10-23T18:52:41.000+09:00',
 '2005-10-18T10:17:52.000+09:00',
 '2005-10-01T05:25:38.000+09:00']

In [29]:
parse_date("2005-10-16T00:19:32.000+09:00").year

2005

In [30]:
def _is_from_2020_onwards(example):
    """Return True when the row's `created_at` timestamp has a year of 2020 or later."""
    created = parse_date(example["created_at"])
    return created.year >= 2020


# Keep only posts created in 2020 or later (2020 itself is included).
ds = ds.filter(_is_from_2020_onwards, batched=False)
ds

Dataset({
    features: ['id', 'copyright', 'character', 'artist', 'general', 'meta', 'rating', 'score', 'created_at'],
    num_rows: 2531560
})

## Concat tags


In [31]:
# Maps the dataset's one-letter `rating` code to the rating tag written into the text.
# NOTE(review): "e" maps to the same tag as "q" ("rating:questionable"), so the
# two codes become indistinguishable downstream. Confirm this is intentional and
# not a copy-paste slip (e.g. a dedicated tag for "e" was intended).
rating_map = {
    "g": "rating:general",
    "s": "rating:sensitive",
    "q": "rating:questionable",
    "e": "rating:questionable",
}

# Maps the same one-letter rating code to its coarse parent tag (sfw / nsfw).
rating_parent_tag_map = {
    "g": "rating:sfw",
    "s": "rating:sfw",
    "q": "rating:nsfw",
    "e": "rating:nsfw",
}

In [32]:
# Every rating tag must be a real vocabulary entry.
# The lookup returns an id, so it has to be compared with `unk_token_id`;
# comparing it with `unk_token` (a string) is always unequal and can never fail.
_rating_tokens = list(rating_map.values()) + list(rating_parent_tag_map.values())
_unknown_rating_tokens = [
    token
    for token in _rating_tokens
    if tokenizer.convert_tokens_to_ids(token) == tokenizer.unk_token_id
]
assert (
    not _unknown_rating_tokens
), f"tokens missing from the tokenizer vocabulary: {_unknown_rating_tokens}"

In [33]:
# Sequence delimiters placed at the very start and end of each `tag_text`.
BOS = "<|bos|>"
EOS = "<|eos|>"

# Paired open/close markers wrapping each tag category inside `tag_text`.
RATING_BOS = "<rating>"
RATING_EOS = "</rating>"

COPYRIGHT_BOS = "<copyright>"
COPYRIGHT_EOS = "</copyright>"

CHARACTER_BOS = "<character>"
CHARACTER_EOS = "</character>"

GENERAL_BOS = "<general>"
GENERAL_EOS = "</general>"

# Placed between the input (prompt) tags and the output (completion) tags
# inside the <general> section.
INPUT_END = "<|input_end|>"  # boundary of input and output

# Detail-level markers; `get_detail_level_tag` picks one from the number of general tags.
VERY_VAGUE = "<|very_vague|>"
VAGUE = "<|vague|>"
DETAILED = "<|detailed|>"
VERY_DETAILED = "<|very_detailed|>"

In [34]:
# Every special token must be a real vocabulary entry.
# The lookup returns an id, so it has to be compared with `unk_token_id`;
# comparing it with `unk_token` (a string) is always unequal and can never fail.
_special_tokens = [
    BOS,
    EOS,
    RATING_BOS,
    RATING_EOS,
    COPYRIGHT_BOS,
    COPYRIGHT_EOS,
    CHARACTER_BOS,
    CHARACTER_EOS,
    GENERAL_BOS,
    GENERAL_EOS,
    INPUT_END,
    VERY_VAGUE,
    VAGUE,
    DETAILED,
    VERY_DETAILED,
]
_unknown_special_tokens = [
    token
    for token in _special_tokens
    if tokenizer.convert_tokens_to_ids(token) == tokenizer.unk_token_id
]
assert (
    not _unknown_special_tokens
), f"tokens missing from the tokenizer vocabulary: {_unknown_special_tokens}"

In [35]:
ds

Dataset({
    features: ['id', 'copyright', 'character', 'artist', 'general', 'meta', 'rating', 'score', 'created_at'],
    num_rows: 2531560
})

In [36]:
from tag_manager import TagManger

In [37]:
# NOTE: `tag_manaer` is a misspelling of "tag_manager"; the name is kept because
# the following cell reads the object under this exact name.
tag_manaer = TagManger()

In [38]:
# Subject-count tags (see the displayed list). `split_general_tags` checks
# membership in this list to decide which tags stay on the input side.
people_tags = tag_manaer.tags["PEOPLE"]
people_tags

['1girl',
 '2girls',
 '3girls',
 '4girls',
 '5girls',
 '6+girls',
 '1boy',
 '2boys',
 '3boys',
 '4boys',
 '5boys',
 '6+boys',
 '1other',
 '2others',
 '3others',
 '4others',
 '5others',
 '6+others',
 'no humans']

In [None]:
def shuffle_tags(tag_text: str):
    """Return `tag_text` with its ", "-separated tags in a random order.

    The order comes from the global `np.random` state. An empty string or a
    single tag is returned unchanged.
    """
    pieces = tag_text.split(", ")
    np.random.shuffle(pieces)  # in-place permutation of the local list
    shuffled_text = ", ".join(pieces)
    return shuffled_text


# RNG used only for the people-tag dropout decision in `split_general_tags`.
# NOTE(review): it is unseeded, and the shuffles below use the global
# `np.random` state, so the split differs from run to run.
generator = np.random.default_rng()


def split_general_tags(
    general_tags: list[str],
    range_min: float = 0,
    range_max: float = 0.75,
    people_dropout_rate: float = 0.05,
):
    """Randomly split general tags into an input (prompt) part and an output part.

    Tags found in the module-level `people_tags` list are normally pinned to
    the input part. With probability `people_dropout_rate` they are not pinned
    and are shuffled and split together with every other tag.

    Args:
        general_tags: Tags of one post. The list is never modified.
        range_min: Lower bound of the random fraction of non-pinned tags that
            goes to the input part.
        range_max: Upper bound of that fraction.
        people_dropout_rate: Probability of not pinning the people tags.

    Returns:
        A tuple `(input_text, output_text)` of ", "-joined strings. The input
        part is in random order; the output part is sorted.
    """
    pinned_tags, other_tags = [], []

    if generator.random() <= people_dropout_rate:
        # Copy: the in-place shuffle below must not reorder the caller's list.
        other_tags = list(general_tags)
    else:
        for tag in general_tags:
            if tag in people_tags:
                pinned_tags.append(tag)
            else:
                other_tags.append(tag)

    # Draw the fraction of non-pinned tags that becomes input.
    ratio = np.random.uniform(range_min, range_max)

    # Randomize which non-pinned tags land on each side of the split.
    np.random.shuffle(other_tags)

    # Translate the fraction into a cut position (rounded down).
    split_index = int(len(other_tags) * ratio)

    input_group = pinned_tags + other_tags[:split_index]
    output_group = other_tags[split_index:]

    # Only the input side is shuffled; the output side is sorted.
    np.random.shuffle(input_group)
    output_group = sorted(output_group)

    return ", ".join(input_group), ", ".join(output_group)


def get_detail_level_tag(count: int):
    """Return the detail-level token for a post with `count` general tags.

    Thresholds are inclusive upper bounds: up to 10 -> VERY_VAGUE,
    up to 20 -> VAGUE, up to 40 -> DETAILED, anything larger -> VERY_DETAILED.
    """
    level_by_upper_bound = (
        (10, VERY_VAGUE),
        (20, VAGUE),
        (40, DETAILED),
    )
    for upper_bound, level_tag in level_by_upper_bound:
        if count <= upper_bound:
            return level_tag
    return VERY_DETAILED


def concat_tags(examples):
    """Build the `tag_text` string for every row of a batch.

    For each row, the rating, copyright and character tags are shuffled and
    the general tags are split into an input and an output part. The pieces
    are then wrapped in their marker tokens and concatenated.

    Args:
        examples: Batch mapping with list-valued "id", "rating", "copyright",
            "character" and "general" columns.

    Returns:
        A dict with one key, "tag_text", holding one string per row.

    Raises:
        AssertionError: If a row's rating or general value is None.
        KeyError: If a rating code is missing from the rating maps.
    """
    tag_texts = []

    rows = zip(
        examples["id"],
        examples["rating"],
        examples["copyright"],
        examples["character"],
        examples["general"],
    )
    for _, rating_code, copyright_text, character_text, general_text in rows:
        assert rating_code is not None
        assert general_text is not None

        # Fine-grained rating tag plus its sfw/nsfw parent tag.
        rating_text = ", ".join(
            [rating_map[rating_code], rating_parent_tag_map[rating_code]]
        )

        # Missing copyright / character values become empty sections.
        copyright_text = "" if copyright_text is None else copyright_text
        character_text = "" if character_text is None else character_text

        # Random draws happen in this fixed order: rating, copyright, character, general.
        rating_text = shuffle_tags(rating_text)
        copyright_text = shuffle_tags(copyright_text)
        character_text = shuffle_tags(character_text)

        general_tags = general_text.split(", ")

        # The detail level reflects the full tag count, taken before the split.
        detail_level = get_detail_level_tag(len(general_tags))

        input_tags, output_tags = split_general_tags(
            general_tags, range_min=0, range_max=0.75
        )

        tag_texts.append(
            f"{BOS}"
            f"{RATING_BOS}{rating_text}{RATING_EOS}"
            f"{COPYRIGHT_BOS}{copyright_text}{COPYRIGHT_EOS}"
            f"{CHARACTER_BOS}{character_text}{CHARACTER_EOS}"
            f"{GENERAL_BOS}{detail_level}{input_tags}{INPUT_END}{output_tags}{GENERAL_EOS}"
            f"{EOS}"
        )

    return {"tag_text": tag_texts}


# Add the `tag_text` column to every row, processing rows in batches.
# NOTE(review): `concat_tags` draws from unseeded RNGs, so re-running this cell
# produces different text for the same rows.
ds = ds.map(concat_tags, batched=True)
ds

In [None]:
ds["tag_text"][100:110]

['<|bos|><rating>rating:nsfw, rating:questionable</rating><copyright>idolmaster, idolmaster shiny colors</copyright><character>mayuzumi fuyuko, producer (idolmaster)</character><general><|detailed|>1boy, 1girl, brown eyes<|input_end|>3koma, barefoot, black hair, black ribbon, blunt bangs, blush, breasts, comic, faceless, faceless male, heart, hetero, leg lock, long hair, looking at viewer, mating press, medium breasts, missionary, nude, open mouth, ribbon, sex, short hair, spoken heart, toe scrunch, tsundere, two side up</general><|eos|>',
 '<|bos|><rating>rating:questionable, rating:nsfw</rating><copyright>neon genesis evangelion, end of evangelion</copyright><character>souryuu asuka langley, mass production eva</character><general><|detailed|>mecha on girl, breasts, long fingers, 1girl, interspecies, highleg<|input_end|>adapted costume, alternate size, blue eyes, clothes lift, dress, drooling, fingering, hetero, highleg dress, interface headset, lips, long hair, monster, muscular, no

In [22]:
# NOTE(review): the recorded progress output of this cell reports 808012
# examples, while the filtered dataset displayed earlier has 2531560 rows. The
# stored output appears to come from a different run; re-execute top to bottom.
ds.save_to_disk("./danbooru-sfttext-20240221")

Saving the dataset (0/2 shards):   0%|          | 0/808012 [00:00<?, ? examples/s]

In [21]:
# Upload the current `ds` to the Hub dataset repo named below.
ds.push_to_hub("p1atdev/dart-sft-20240221")

Uploading the dataset shards:   0%|          | 0/6 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/422 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/422 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/422 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/422 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/422 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/422 [00:00<?, ?ba/s]

CommitInfo(commit_url='https://huggingface.co/datasets/p1atdev/dart-sft-20240221/commit/69b7ff78c91a26b07653df96455d59aa13f8827c', commit_message='Upload dataset', commit_description='', oid='69b7ff78c91a26b07653df96455d59aa13f8827c', pr_url=None, pr_revision=None, pr_num=None)

## Tokenize


In [4]:
# NOTE(review): this loads "./danbooru-tagtext-20240219", not the
# "./danbooru-sfttext-20240221" directory written earlier in this notebook, and
# rebinds `ds`. Confirm which dataset the tokenize step is meant to consume.
ds = load_from_disk("./danbooru-tagtext-20240219")
ds

Dataset({
    features: ['id', 'copyright', 'character', 'artist', 'general', 'meta', 'rating', 'score', 'created_at', 'tag_text'],
    num_rows: 6524302
})

In [39]:
tokenizer(
    "<|bos|><rating>rating:nsfw, rating:questionable</rating><copyright>dragon quest, dragon quest v</copyright><character>hero's daughter (dq5), king slime, slime (dragon quest)</character><general>1girl, :d, ^_^, ass, bestiality, black eyes, blob, blonde hair, blue skin, blush, boots, bow, closed eyes, clothes lift, colored skin, crown, dress, dress lift, female orgasm, full body, gloves, hair bow, hat, head back, loli, no panties, open mouth, orgasm, rape, see-through, see-through body, short dress, simple background, sketch, smile, solo, tears, top-down bottom-up, vaginal, wince</general><|eos|>",
)

{'input_ids': [0, 4, 49, 46, 5, 6, 3252, 931, 7, 8, 12129, 3, 14631, 9, 10, 60268, 61285, 38967, 39211, 47069, 50914, 58886, 55184, 42984, 32162, 55932, 45343, 25530, 28649, 59997, 24239, 66120, 51088, 46454, 18581, 19188, 65308, 65393, 27019, 67294, 53925, 31442, 21535, 67510, 29954, 32664, 49077, 40333, 45530, 43500, 40044, 32039, 32815, 67386, 36057, 11, 1], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [40]:
# Every column except `tag_text`; passed as `remove_columns` to the tokenize
# `map` call so that only `tag_text` and the tokenizer outputs remain.
column_names = ds.column_names
column_names.remove("tag_text")  # raises ValueError if the column is absent
column_names

['id',
 'copyright',
 'character',
 'artist',
 'general',
 'meta',
 'rating',
 'score',
 'created_at']

In [42]:
def tokenize_text(examples):
    """Tokenize a batch of `tag_text` strings and drop unknown tokens.

    Unknown tokens are tolerated only inside the general section. An unknown
    token before the `<general>` marker (in the rating, copyright or character
    section) aborts the run.

    The filtered ids are what gets returned. Previously the filtered list was
    built and then discarded, so unknown tokens stayed in the output.

    Args:
        examples: Batch mapping with a list-valued "tag_text" column.

    Returns:
        A dict with the same fields the tokenizer produced, each row stripped
        of the positions whose input id is the unknown-token id.

    Raises:
        Exception: If an unknown token appears before the `<general>` marker.
        ValueError: If a row contains no `<general>` marker (from `list.index`).
    """
    tokenized = tokenizer(examples["tag_text"])

    unk_id = tokenizer.unk_token_id
    # Look the marker id up instead of hard-coding its vocabulary index.
    general_bos_id = tokenizer.convert_tokens_to_ids(GENERAL_BOS)

    cleaned = {field: [] for field in tokenized.keys()}

    for row_idx, token_ids in enumerate(tokenized["input_ids"]):
        general_bos_idx = token_ids.index(general_bos_id)

        # Unknown copyright / character / rating tags are not allowed.
        if any(token_id == unk_id for token_id in token_ids[:general_bos_idx]):
            raise Exception("unk before general!")

        # Apply one keep-mask to every field so the columns stay aligned.
        # Assumes each tokenizer output field has one entry per token, as in
        # the input_ids / token_type_ids / attention_mask output shown above.
        keep = [token_id != unk_id for token_id in token_ids]
        for field in cleaned:
            cleaned[field].append(
                [
                    value
                    for value, kept in zip(tokenized[field][row_idx], keep)
                    if kept
                ]
            )

    return cleaned


# Tokenize in batches and drop every original column except `tag_text`.
ds = ds.map(tokenize_text, batched=True, remove_columns=column_names)
ds

Map:   0%|          | 0/5293004 [00:00<?, ? examples/s]

Dataset({
    features: ['tag_text', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 5293004
})

In [43]:
ds[1000]

{'tag_text': '<|bos|><rating>rating:sfw, rating:sensitive</rating><copyright>to heart (series), to heart 2</copyright><character>himeyuri ruri, himeyuri sango, ilfa (to heart)</character><general>3girls, animal ears, bun cover, buruma, cat ears, cat tail, comic, double bun, greyscale, gym uniform, hair bun, monochrome, mouse ears, mouse tail, multiple girls, tail, thighhighs</general><|eos|>',
 'input_ids': [0,
  4,
  48,
  45,
  5,
  6,
  499,
  2647,
  7,
  8,
  15158,
  7845,
  4179,
  9,
  10,
  47498,
  24847,
  23653,
  63183,
  34279,
  29591,
  46121,
  36792,
  43736,
  31133,
  33982,
  56065,
  17149,
  30195,
  18572,
  45564,
  19477,
  11,
  1],
 'token_type_ids': [0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0],
 'attention_mask': [1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,


In [44]:
# Persist the tokenized dataset locally; the "push to hub" section reloads this path.
ds.save_to_disk("./danbooru-tokenized-20240219")

Saving the dataset (0/9 shards):   0%|          | 0/5293004 [00:00<?, ? examples/s]

### Push to hub


In [None]:
# Reload the tokenized dataset from disk so the upload can run without
# repeating the tokenize step. Rebinds `ds`.
ds = load_from_disk("./danbooru-tokenized-20240219")
ds

In [45]:
# Upload the current `ds` to the Hub dataset repo named below.
ds.push_to_hub("p1atdev/dart-tokenized-pretrain-20240219")

Uploading the dataset shards:   0%|          | 0/9 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/589 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/589 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/589 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/589 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/589 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/589 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/589 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/589 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/589 [00:00<?, ?ba/s]

CommitInfo(commit_url='https://huggingface.co/datasets/p1atdev/dart-tokenized-pretrain-20240219/commit/013e3afc021ce12f418ba510be15cc56716caa85', commit_message='Upload dataset', commit_description='', oid='013e3afc021ce12f418ba510be15cc56716caa85', pr_url=None, pr_revision=None, pr_num=None)