

# **Kohya Dreambooth Training**



# I. Install

In [None]:
# @title ## 1.1. Install Dependencies
# @markdown Clone Kohya Trainer from GitHub and check for updates. Use the textbox below if you want to check out another branch or an older commit. Leave it empty to stay at the HEAD of main. This will also install the required libraries.
import os
import zipfile
import shutil
import time
from subprocess import getoutput
from IPython.utils import capture
from google.colab import drive

# Restore every variable previously persisted with `%store`; the path
# variables below are then assigned again and re-stored.
%store -r

# root_dir
# Working directories located under the Colab root.
root_dir = "/content"
deps_dir = os.path.join(root_dir, "deps")
repo_dir = os.path.join(root_dir, "kohya-trainer")
training_dir = os.path.join(root_dir, "dreambooth")
pretrained_model = os.path.join(root_dir, "pretrained_model")
vae_dir = os.path.join(root_dir, "vae")
config_dir = os.path.join(training_dir, "config")

# repo_dir
# Paths inside the cloned kohya-trainer repository.
accelerate_config = os.path.join(repo_dir, "accelerate_config/config.yaml")
tools_dir = os.path.join(repo_dir, "tools")
finetune_dir = os.path.join(repo_dir, "finetune")

# Persist each path variable by name so that cells starting with `%store -r`
# can read them back. The magic's output is captured to keep the cell quiet.
for store in [
    "root_dir",
    "deps_dir",
    "repo_dir",
    "training_dir",
    "pretrained_model",
    "vae_dir",
    "accelerate_config",
    "tools_dir",
    "finetune_dir",
    "config_dir",
]:
    with capture.capture_output() as cap:
        %store {store}
        del cap

repo_url = "https://github.com/Linaqruf/kohya-trainer"
# NOTE(review): this path is pinned to Python 3.10's dist-packages, and the
# variable name is spelled "bitsandytes"; main() references it by that spelling.
bitsandytes_main_py = "/usr/local/lib/python3.10/dist-packages/bitsandbytes/cuda_setup/main.py"
# Colab form fields read by clone_repo(), install_dependencies() and main().
branch = ""  # @param {type: "string"}
install_xformers = True  # @param {'type':'boolean'}
mount_drive = True  # @param {type: "boolean"}
verbose = False # @param {type: "boolean"}

def read_file(filename):
    """Return the full text content of ``filename``."""
    with open(filename, "r") as handle:
        return handle.read()


def write_file(filename, contents):
    """Replace the content of ``filename`` with the text ``contents``."""
    with open(filename, mode="w") as out_file:
        out_file.write(contents)


def clone_repo(url):
    if not os.path.exists(repo_dir):
        os.chdir(root_dir)
        !git clone {url} {repo_dir}
    else:
        os.chdir(repo_dir)
        !git pull origin {branch} if branch else !git pull


def install_dependencies():
    s = getoutput('nvidia-smi')

    if 'T4' in s:
        !sed -i "s@cpu@cuda@" library/model_util.py

    !pip install {'-q' if not verbose else ''} --upgrade -r requirements.txt
    !pip install {'-q' if not verbose else ''} torch==2.0.0+cu118 torchvision==0.15.1+cu118 torchaudio==2.0.1+cu118 torchtext==0.15.1 torchdata==0.6.0 --extra-index-url https://download.pytorch.org/whl/cu118 -U

    if install_xformers:
        !pip install {'-q' if not verbose else ''} xformers==0.0.19 triton==2.0.0 -U

    from accelerate.utils import write_basic_config

    if not os.path.exists(accelerate_config):
        write_basic_config(save_location=accelerate_config)


def remove_bitsandbytes_message(filename):
    """Patch bitsandbytes so its welcome banner honours ``BITSANDBYTES_NOWELCOME``.

    The start of ``evaluate_cuda_setup`` in ``filename`` is replaced by a
    version that prints the banner only when the environment variable is unset
    or ``'0'``. The patch is cosmetic, so a missing file or a file that does
    not contain the expected text is skipped instead of raising.

    Args:
        filename: Path to bitsandbytes' ``cuda_setup/main.py``.
    """
    welcome_message = """
def evaluate_cuda_setup():
    print('')
    print('='*35 + 'BUG REPORT' + '='*35)
    print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues')
    print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link')
    print('='*80)"""

    new_welcome_message = """
def evaluate_cuda_setup():
    import os
    if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0':
        print('')
        print('=' * 35 + 'BUG REPORT' + '=' * 35)
        print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues')
        print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link')
        print('To hide this message, set the BITSANDBYTES_NOWELCOME variable like so: export BITSANDBYTES_NOWELCOME=1')
        print('=' * 80)"""

    # The caller passes a path pinned to one Python version; when the package
    # lives elsewhere there is nothing to patch, and the install must not fail.
    if not os.path.isfile(filename):
        return

    contents = read_file(filename)
    # Leave the file untouched when the expected banner code is not present
    # (already patched, or a different bitsandbytes release).
    if welcome_message not in contents:
        return
    write_file(filename, contents.replace(welcome_message, new_welcome_message))


def main():
    os.chdir(root_dir)

    if mount_drive:
        if not os.path.exists("/content/drive"):
            drive.mount("/content/drive")

    for dir in [
        deps_dir,
        training_dir,
        config_dir,
        pretrained_model,
        vae_dir
    ]:
        os.makedirs(dir, exist_ok=True)

    clone_repo(repo_url)

    if branch:
        os.chdir(repo_dir)
        status = os.system(f"git checkout {branch}")
        if status != 0:
            raise Exception("Failed to checkout branch or commit")

    os.chdir(repo_dir)

    !apt install aria2 {'-qq' if not verbose else ''}

    install_dependencies()
    time.sleep(3)

    remove_bitsandbytes_message(bitsandytes_main_py)

    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    os.environ["BITSANDBYTES_NOWELCOME"] = "1"
    os.environ["SAFETENSORS_FAST_GPU"] = "1"

    cuda_path = "/usr/local/cuda-11.8/targets/x86_64-linux/lib/"
    ld_library_path = os.environ.get("LD_LIBRARY_PATH", "")
    os.environ["LD_LIBRARY_PATH"] = f"{ld_library_path}:{cuda_path}"

main()


Mounted at /content/drive
Cloning into '/content/kohya-trainer'...
remote: Enumerating objects: 2500, done.[K
remote: Counting objects: 100% (953/953), done.[K
remote: Compressing objects: 100% (248/248), done.[K
remote: Total 2500 (delta 806), reused 707 (delta 705), pack-reused 1547[K
Receiving objects: 100% (2500/2500), 4.93 MiB | 12.92 MiB/s, done.
Resolving deltas: 100% (1670/1670), done.
The following additional packages will be installed:
  libaria2-0 libc-ares2
The following NEW packages will be installed:
  aria2 libaria2-0 libc-ares2
0 upgraded, 3 newly installed, 0 to remove and 6 not upgraded.
Need to get 1,513 kB of archives.
After this operation, 5,441 kB of additional disk space will be used.
Selecting previously unselected package libc-ares2:amd64.
(Reading database ... 120872 files and directories currently installed.)
Preparing to unpack .../libc-ares2_1.18.1-1ubuntu0.22.04.2_amd64.deb ...
Unpacking libc-ares2:amd64 (1.18.1-1ubuntu0.22.04.2) ...
Selecting previous

In [None]:
# @title ## 1.2. Start `File Explorer`
# @markdown This will work in real-time even when you run other cells
import threading
from google.colab import output
from imjoy_elfinder.app import main

open_in_new_tab = False  # @param {type:"boolean"}

def start_file_explorer(root_dir=root_dir, port=8765):
    """Run the elFinder app for ``root_dir`` on ``port``.

    Any exception raised while building the arguments or running the app is
    reported on stdout instead of propagating.
    """
    try:
        cli_args = ["--root-dir=" + root_dir, f"--port={port}"]
        main(cli_args)
    except Exception as err:
        print("Error starting file explorer:", str(err))


def open_file_explorer(open_in_new_tab=False, root_dir=root_dir, port=8765):
    """Start the file explorer in a background thread and expose its port.

    The port is served as a separate window when ``open_in_new_tab`` is true,
    otherwise as an iframe of height "500" in the cell output.
    """
    server_thread = threading.Thread(target=start_file_explorer, args=[root_dir, port])
    server_thread.start()

    if not open_in_new_tab:
        output.serve_kernel_port_as_iframe(port, height="500")
        return
    output.serve_kernel_port_as_window(port)


# Example usage
open_file_explorer(open_in_new_tab=open_in_new_tab, root_dir=root_dir, port=8765)


# II. Select Pretrained Model

In [None]:
# @title ## 2.2. Download Custom Model
import os

%store -r

os.chdir(root_dir)

# @markdown ### Custom model
modelUrls = "https://huggingface.co/cyberdelia/CyberRealistic/blob/main/CyberRealistic_V3.2-FP16.safetensors"  # @param {'type': 'string'}

def install(url):
    base_name = os.path.basename(url)

    if "drive.google.com" in url:
        os.chdir(pretrained_model)
        !gdown --fuzzy {url}
    elif "huggingface.co" in url:
        if "/blob/" in url:
            url = url.replace("/blob/", "/resolve/")
        # @markdown Change this part with your own huggingface token if you need to download your private model
        hf_token = "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"  # @param {type:"string"}
        user_header = f'"Authorization: Bearer {hf_token}"'
        !aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 -d {pretrained_model} -o {base_name} {url}
    else:
        !aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {pretrained_model} {url}

if modelUrls:
    urls = modelUrls.split(",")
    for url in urls:
        install(url.strip())


In [None]:
# @title ## 2.3. Download Available VAE (Optional)
import os

%store -r

os.chdir(root_dir)

# Download URL for every selectable VAE; "none" maps to an empty URL.
vaes = {
    "none": "",
    "anime.vae.pt": "https://huggingface.co/Linaqruf/personal-backup/resolve/main/vae/animevae.pt",
    "waifudiffusion.vae.pt": "https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/vae/kl-f8-anime.ckpt",
    "stablediffusion.vae.pt": "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt",
}

# @markdown Select one of the VAEs to download, select `none` for not download VAE:
vae_name = "stablediffusion.vae.pt"  # @param ["none", "anime.vae.pt", "waifudiffusion.vae.pt", "stablediffusion.vae.pt"]

# Queue a (file name, URL) pair only when the selection has a non-empty URL.
vae_url = vaes.get(vae_name, "")
install_vaes = [(vae_name, vae_url)] if vae_url else []


def install(vae_name, url):
    hf_token = "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"
    user_header = f'"Authorization: Bearer {hf_token}"'
    !aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 -d {vae_dir} -o {vae_name} "{url}"


def install_vae():
    """Download every (file name, URL) pair queued in ``install_vaes``."""
    for target_name, source_url in install_vaes:
        install(target_name, source_url)


install_vae()


# III. Data Acquisition

You have three options for acquiring your dataset:

1. Uploading it to Colab's local files.
2. Bulk downloading images from Danbooru using the `Simple Booru Scraper`.
3. Locating your dataset from Google Drive.


In [None]:
# @title ## 3.1. Locating Train Data Directory
# @markdown Define the location of your training data. This cell will also create a folder based on your input. Regularization Images is optional and can be skipped.
import os
from IPython.utils import capture

%store -r

train_data_dir = "/content/drive/MyDrive/train_data"  # @param {type:'string'}
reg_data_dir = "/content/drive/MyDrive/reg_data"  # @param {type:'string'}

for dir in [train_data_dir, reg_data_dir]:
    if dir:
        with capture.capture_output() as cap:
            os.makedirs(dir, exist_ok=True)
            %store dir
            del cap

print(f"Your train data directory : {train_data_dir}")
if reg_data_dir:
    print(f"Your reg data directory : {reg_data_dir}")

Your train data directory : /content/dreambooth/train_data
Your reg data directory : /content/dreambooth/reg_data


# IV. Data Preprocessing

## 4.2. Data Annotation
You can choose to train a model using captions. We're using [BLIP](https://huggingface.co/spaces/Salesforce/BLIP) for image captioning and [Waifu Diffusion 1.4 Tagger](https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags) for image tagging similar to Danbooru.
- Use BLIP Captioning for: `General Images`
- Use Waifu Diffusion 1.4 Tagger V2 for: `Anime and Manga-style Images`

In [None]:
#@title ### 4.2.1. BLIP Captioning
#@markdown BLIP is a pre-training framework for unified vision-language understanding and generation, which achieves state-of-the-art results on a wide range of vision-language tasks. It can be used as a tool for image captioning, for example, `astronaut riding a horse in space`.
import os

os.chdir(finetune_dir)

batch_size = 8 #@param {type:'number'}
max_data_loader_n_workers = 2 #@param {type:'number'}
beam_search = True #@param {type:'boolean'}
min_length = 10 #@param {type:"slider", min:0, max:100, step:5.0}
max_length = 75 #@param {type:"slider", min:0, max:100, step:5.0}
#@markdown Use the `recursive` option to process subfolders as well, useful for multi-concept training.
recursive = False #@param {type:"boolean"}
#@markdown Debug while captioning, it will print your image file with generated captions.
verbose_logging = True #@param {type:"boolean"}

config = {
    "_train_data_dir" : train_data_dir,
    "batch_size" : batch_size,
    "beam_search" : beam_search,
    "min_length" : min_length,
    "max_length" : max_length,
    "debug" : verbose_logging,
    "caption_extension" : ".caption",
    "max_data_loader_n_workers" : max_data_loader_n_workers,
    "recursive" : recursive
}

args = ""
for k, v in config.items():
    if k.startswith("_"):
        args += f'"{v}" '
    elif isinstance(v, str):
        args += f'--{k}="{v}" '
    elif isinstance(v, bool) and v:
        args += f"--{k} "
    elif isinstance(v, float) and not isinstance(v, bool):
        args += f"--{k}={v} "
    elif isinstance(v, int) and not isinstance(v, bool):
        args += f"--{k}={v} "

final_args = f"python make_captions.py {args}"

os.chdir(finetune_dir)
!{final_args}

# V. Training Model



In [None]:
# @title ## 5.1. Model Config
from google.colab import drive

v2 = False  # @param {type:"boolean"}
v_parameterization = False  # @param {type:"boolean"}
project_name = "-Digital-Realistic-Trainer"  # @param {type:"string"}
if not project_name:
    project_name = "last"
%store project_name
pretrained_model_name_or_path = "/content/pretrained_model/CyberRealistic_V3.2-FP16.safetensors"  # @param {type:"string"}
vae = "/content/vae/stablediffusion.vae.pt"  # @param {type:"string"}
output_dir = "/content/dreambooth/output"  # @param {'type':'string'}
resume_path = ""  # @param {'type':'string'}

# @markdown `output_to_drive` sets default `output_dir` to `/content/drive/MyDrive/dreambooth/output`. This will override the `output_dir` variable defined above.
output_to_drive = True  # @param {'type':'boolean'}

if output_to_drive:
    output_dir = "/content/drive/MyDrive/dreambooth/output"

    if not os.path.exists("/content/drive"):
        drive.mount("/content/drive")

sample_dir = os.path.join(output_dir, "sample")
for dir in [output_dir, sample_dir]:
    os.makedirs(dir, exist_ok=True)

print("Project Name: ", project_name)
print("Model Version: Stable Diffusion V1.x") if not v2 else ""
print("Model Version: Stable Diffusion V2.x") if v2 and not v_parameterization else ""
print("Model Version: Stable Diffusion V2.x 768v") if v2 and v_parameterization else ""
print(
    "Pretrained Model Path: ", pretrained_model_name_or_path
) if pretrained_model_name_or_path else print("No Pretrained Model path specified.")
print("VAE Path: ", vae) if vae else print("No VAE path specified.")
print("Output Path: ", output_dir)
print("Resume Path: ", resume_path) if resume_path else print(
    "No resume path specified."
)

In [None]:
# @title ## 5.2. Dataset Config
import toml
import glob

# @markdown This configuration is designed for `one concept` training. Refer to this [guide](https://rentry.org/kohyaminiguide#b-multi-concept-training) for multi-concept training.
dataset_repeats = 15  # @param {type:"number"}
# @markdown `activation_word` is not used in training if you train with captions/tags, but it is still printed to metadata.
activation_word = "raw photo"  # @param {type:"string"}
caption_extension = ".txt"  # @param ["none", ".txt", ".caption"]
# @markdown Please refer to `4.2.3. Custom Caption/Tag (Optional)` if you want to append `activation_word` to captions/tags  **OR** specify `token_to_captions` below.
token_to_captions = False  # @param {'type':'boolean'}
resolution = 512  # @param {type:"slider", min:512, max:1024, step:128}
flip_aug = False  # @param {type:"boolean"}
keep_tokens = 0  # @param {type:"number"}

# The class token is the last word of `activation_word` (commas count as
# separators). It is always defined, including for a single-word activation
# word, because get_class_token_from_folder_name() reads it unconditionally.
# An empty activation word yields None.
words = activation_word.replace(',', ' ').split()
class_token = words[-1] if words else None


def read_file(filename):
    """Read ``filename`` in text mode and return everything it contains."""
    with open(filename, "r") as source:
        text = source.read()
        return text


def write_file(filename, contents):
    """Write the text ``contents`` to ``filename``, replacing what was there."""
    with open(filename, "w") as target:
        target.write(contents)


def get_supported_images(folder):
    """List the image files located directly inside ``folder``.

    Matching is by file-name suffix (.png, .jpg, .jpeg, .webp, .bmp); results
    are grouped by extension in that order. Subfolders are not searched.
    """
    matches = []
    for extension in (".png", ".jpg", ".jpeg", ".webp", ".bmp"):
        matches.extend(glob.glob(f"{folder}/*{extension}"))
    return matches


def get_subfolders_with_supported_images(folder):
    """Return the direct subfolders of ``folder`` that contain image files.

    An empty path or a path that is not an existing directory yields an empty
    list, so an optional data directory (for example a blank regularization
    directory) can be passed without raising.
    """
    if not folder or not os.path.isdir(folder):
        return []

    with_images = []
    for entry in os.listdir(folder):
        candidate = os.path.join(folder, entry)
        if os.path.isdir(candidate) and get_supported_images(candidate):
            with_images.append(candidate)
    return with_images


def process_tags(filename, custom_tag, remove_tag):
    """Add or remove comma-separated tags in the caption file ``filename``.

    ``custom_tag`` may hold several tags separated by commas; underscores in
    them are replaced by spaces. With ``remove_tag`` true every occurrence of
    each tag is dropped, otherwise each missing tag is inserted at the front.
    The file is rewritten with the tags joined by ", ".
    """
    tags = [piece.strip() for piece in read_file(filename).split(',')]

    for raw_tag in custom_tag.split(','):
        wanted = raw_tag.strip().replace("_", " ")
        if remove_tag:
            tags = [existing for existing in tags if existing != wanted]
        elif wanted not in tags:
            tags.insert(0, wanted)

    write_file(filename, ', '.join(tags))


def process_folder_recursively(folder):
    """Apply the concept tag to every caption file found under ``folder``.

    For each directory the tag is the class token derived from the folder name.
    When that is empty, ``activation_word`` is used if ``train_data_dir``
    itself holds images, otherwise the files are left alone. The tag is
    inserted when ``token_to_captions`` is set and removed otherwise.
    """
    # Loop-invariant: list the images in the training root once instead of
    # once per caption file.
    fallback_tag = activation_word if get_supported_images(train_data_dir) else ""

    for root, _, files in os.walk(folder):
        caption_files = [name for name in files if name.endswith(caption_extension)]
        if not caption_files:
            continue

        # The folder-derived token is the same for every file in `root`.
        tag = get_class_token_from_folder_name(root, folder) or fallback_tag
        if not tag:
            continue

        for name in caption_files:
            process_tags(os.path.join(root, name), tag, remove_tag=(not token_to_captions))


def get_num_repeats(folder):
    """Return the repeat count encoded in a ``<repeats>_<name>`` folder name.

    Falls back to 1 when the last path component has no underscore or when the
    part before the first underscore is not an integer.
    """
    prefix, separator, _ = os.path.basename(folder).partition('_')
    if not separator:
        return 1
    try:
        return int(prefix)
    except ValueError:
        return 1


def get_class_token_from_folder_name(folder, parent_folder):
    """Return the concept name for ``folder``.

    The dataset root (``folder == parent_folder``) maps to the global
    ``class_token``. Any other folder maps to the text after the first
    underscore of its name, or "" when the name has no underscore.
    """
    if folder == parent_folder:
        return class_token

    _, separator, concept = os.path.basename(folder).partition('_')
    return concept if separator else ""

train_supported_images = get_supported_images(train_data_dir)
train_subfolders = get_subfolders_with_supported_images(train_data_dir)
# Regularization images are optional: with a blank directory there is nothing
# to scan, and listing "" would raise.
reg_supported_images = get_supported_images(reg_data_dir) if reg_data_dir else []
reg_subfolders = get_subfolders_with_supported_images(reg_data_dir) if reg_data_dir else []

subsets = []

# Adjust `keep_tokens` before `config` is built; the dict captures the value at
# construction time, so a later change would never reach the TOML file.
if token_to_captions and keep_tokens < 2:
    keep_tokens = 1

config = {
    "general": {
        "enable_bucket": True,
        "caption_extension": caption_extension,
        "shuffle_caption": True,
        "keep_tokens": keep_tokens,
        "bucket_reso_steps": 64,
        "bucket_no_upscale": False,
    },
    "datasets": [
        {
            "resolution": resolution,
            "min_bucket_reso": 320 if resolution > 640 else 256,
            "max_bucket_reso": 1280 if resolution > 640 else 1024,
            "caption_dropout_rate": 0,
            "caption_tag_dropout_rate": 0,
            "caption_dropout_every_n_epochs": 0,
            "flip_aug": flip_aug,
            "color_aug": False,
            "face_crop_aug_range": None,
            # Same list object as `subsets`; entries appended below show up here.
            "subsets": subsets,
        }
    ],
}

if caption_extension != "none":
    process_folder_recursively(train_data_dir)

# Images placed directly in the training root form one subset.
if train_supported_images:
    subsets.append({
        "image_dir": train_data_dir,
        "class_tokens": activation_word,
        "num_repeats": dataset_repeats,
    })

# Each training subfolder is its own subset; repeats and concept name come
# from the "<repeats>_<concept>" folder name.
for subfolder in train_subfolders:
    extracted_class_token = get_class_token_from_folder_name(subfolder, train_data_dir)
    subsets.append({
        "image_dir": subfolder,
        "class_tokens": extracted_class_token if extracted_class_token else None,
        "num_repeats": get_num_repeats(subfolder),
    })

if reg_supported_images:
    subsets.append({
        "is_reg": True,
        "image_dir": reg_data_dir,
        "class_tokens": class_token if 'class_token' in globals() else None,
        "num_repeats": 1,
    })

for subfolder in reg_subfolders:
    extracted_class_token = get_class_token_from_folder_name(subfolder, reg_data_dir)
    subsets.append({
        "is_reg": True,
        "image_dir": subfolder,
        "class_tokens": extracted_class_token if extracted_class_token else None,
        # Read the repeat count from this folder's own name rather than reusing
        # whatever the training loop above happened to leave behind.
        "num_repeats": get_num_repeats(subfolder),
    })

# Subsets without any .txt file fall back to the activation word.
# NOTE(review): this checks "*.txt" even when `caption_extension` is
# ".caption" - confirm whether it should follow `caption_extension`.
for subset in subsets:
    if not glob.glob(f"{subset['image_dir']}/*.txt"):
        subset["class_tokens"] = activation_word

dataset_config = os.path.join(config_dir, "dataset_config.toml")

# Turn empty strings into None, one level deep for dict sections.
for key in config:
    if isinstance(config[key], dict):
        for sub_key in config[key]:
            if config[key][sub_key] == "":
                config[key][sub_key] = None
    elif config[key] == "":
        config[key] = None

config_str = toml.dumps(config)

with open(dataset_config, "w") as f:
    f.write(config_str)

print(config_str)


In [None]:
# @title ## 5.3. Optimizer Config
from IPython.utils import capture

# @markdown `NEW` Gamma for reducing the weight of high-loss timesteps. Lower numbers have a stronger effect. The paper recommends 5. Read the paper [here](https://arxiv.org/abs/2303.09556).
min_snr_gamma = -1 #@param {type:"number"}
# @markdown `AdamW8bit` was the old `--use_8bit_adam`.
optimizer_type = "AdamW8bit"  # @param ["AdamW", "AdamW8bit", "Lion", "SGDNesterov", "SGDNesterov8bit", "DAdaptation", "AdaFactor"]
# @markdown Additional arguments for optimizer, e.g: `["decouple=true","weight_decay=0.6"]`
optimizer_args = ""  # @param {'type':'string'}
# @markdown Set `learning_rate` to `1.0` if you use `DAdaptation` optimizer, as it's a [free learning rate](https://github.com/facebookresearch/dadaptation) algorithm.
# @markdown You probably need to specify `optimizer_args` for custom optimizer, like using `["decouple=true","weight_decay=0.6"]` for `DAdaptation`.
learning_rate = 2e-6  # @param {'type':'number'}
stop_train_text_encoder = -1 #@param {'type':'number'}
lr_scheduler = "constant"  # @param ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup", "adafactor"] {allow-input: false}
lr_warmup_steps = 0  # @param {'type':'number'}
# @markdown You can define `num_cycles` value for `cosine_with_restarts` or `power` value for `polynomial` in the field below.
lr_scheduler_num_cycles = 0  # @param {'type':'number'}
lr_scheduler_power = 0  # @param {'type':'number'}

# Echo the chosen settings; optional values are printed only when they apply.
if min_snr_gamma != -1:
    print(f"  - Min-SNR Weighting: {min_snr_gamma}")
print(f"Using {optimizer_type} as Optimizer")
if optimizer_args:
    print("Optimizer Args :", optimizer_args)
print("Learning rate: ", learning_rate)
if stop_train_text_encoder > 0:
    print(f"Text Encoder training stopped at {stop_train_text_encoder} steps")
print("Learning rate warmup steps: ", lr_warmup_steps)
print("Learning rate Scheduler:", lr_scheduler)

# Only these two schedulers have an extra setting worth reporting.
scheduler_detail = {
    "cosine_with_restarts": ("- lr_scheduler_num_cycles: ", lr_scheduler_num_cycles),
    "polynomial": ("- lr_scheduler_power: ", lr_scheduler_power),
}.get(lr_scheduler)
if scheduler_detail is not None:
    print(*scheduler_detail)

In [None]:
# @title ## 5.3. Optimizer Config
from IPython.utils import capture

# NOTE(review): this cell repeats the optimizer settings of the preceding
# "5.3. Optimizer Config" cell; whichever of the two runs last wins.
# @markdown `NEW` Gamma for reducing the weight of high-loss timesteps. Lower numbers have a stronger effect. The paper recommends 5. Read the paper [here](https://arxiv.org/abs/2303.09556).
min_snr_gamma = -1 #@param {type:"number"}
# @markdown `AdamW8bit` was the old `--use_8bit_adam`.
optimizer_type = "AdamW8bit"  # @param ["AdamW", "AdamW8bit", "Lion", "SGDNesterov", "SGDNesterov8bit", "DAdaptation", "AdaFactor"]
# @markdown Additional arguments for optimizer, e.g: `["decouple=true","weight_decay=0.6"]`
optimizer_args = ""  # @param {'type':'string'}
# @markdown Set `learning_rate` to `1.0` if you use `DAdaptation` optimizer, as it's a [free learning rate](https://github.com/facebookresearch/dadaptation) algorithm.
# @markdown You probably need to specify `optimizer_args` for custom optimizer, like using `["decouple=true","weight_decay=0.6"]` for `DAdaptation`.
learning_rate = 2e-6  # @param {'type':'number'}
stop_train_text_encoder = -1 #@param {'type':'number'}
lr_scheduler = "constant"  # @param ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup", "adafactor"] {allow-input: false}
lr_warmup_steps = 0  # @param {'type':'number'}
# @markdown You can define `num_cycles` value for `cosine_with_restarts` or `power` value for `polynomial` in the field below.
lr_scheduler_num_cycles = 0  # @param {'type':'number'}
lr_scheduler_power = 0  # @param {'type':'number'}

# Report the selection; lines for unset or irrelevant options are skipped.
if min_snr_gamma != -1:
    print(f"  - Min-SNR Weighting: {min_snr_gamma}")

print(f"Using {optimizer_type} as Optimizer")

if optimizer_args:
    print("Optimizer Args :", optimizer_args)

print("Learning rate: ", learning_rate)

if stop_train_text_encoder > 0:
    print(f"Text Encoder training stopped at {stop_train_text_encoder} steps")

print("Learning rate warmup steps: ", lr_warmup_steps)
print("Learning rate Scheduler:", lr_scheduler)

if lr_scheduler == "cosine_with_restarts":
    print("- lr_scheduler_num_cycles: ", lr_scheduler_num_cycles)
elif lr_scheduler == "polynomial":
    print("- lr_scheduler_power: ", lr_scheduler_power)

In [None]:
# @title ## 5.4. Training Config
import toml
import os

%store -r
enable_sample_prompt = True  # @param {type:"boolean"}
sampler = "ddim"  # @param ["ddim", "pndm", "lms", "euler", "euler_a", "heun", "dpm_2", "dpm_2_a", "dpmsolver","dpmsolver++", "dpmsingle", "k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a"]
noise_offset = 0.0  # @param {type:"number"}
max_train_steps = 3000  # @param {type:"number"}
vae_batch_size = 1  # @param {type:"number"}
train_batch_size = 4  # @param {type:"number"}
mixed_precision = "fp16"  # @param ["no","fp16","bf16"] {allow-input: false}
save_state = True  # @param {type:"boolean"}
save_precision = "fp16"  # @param ["float", "fp16", "bf16"] {allow-input: false}
save_n_epoch_ratio = 1  # @param {type:"number"}
save_model_as = "safetensors"  # @param ["ckpt", "safetensors", "diffusers", "diffusers_safetensors"] {allow-input: false}
max_token_length = 225  # @param {type:"number"}
clip_skip = 2  # @param {type:"number"}
gradient_checkpointing = False  # @param {type:"boolean"}
gradient_accumulation_steps = 1  # @param {type:"number"}
seed = -1  # @param {type:"number"}
logging_dir = "/content/dreambooth/logs"
prior_loss_weight = 1.0

os.chdir(repo_dir)

sample_str = f"""
  woman,reality,highres,realistic,photo_(medium),mid_shot,front view,cinematic_angle,skinny,adult,kind_smile,raw photo,8k uhd,dslr,soft lighting,high quality,film grain,hyperrealistic,looking_at_viewer,underwear,standing,black_background  \
  --n lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry \
  --w 512 \
  --h 768 \
  --l 7 \
  --s 28
"""

config = {
    "model_arguments": {
        "v2": v2,
        "v_parameterization": v_parameterization if v2 and v_parameterization else False,
        "pretrained_model_name_or_path": pretrained_model_name_or_path,
        "vae": vae,
    },
    "optimizer_arguments": {
        "min_snr_gamma": min_snr_gamma if not min_snr_gamma == -1 else None,
        "optimizer_type": optimizer_type,
        "learning_rate": learning_rate,
        "max_grad_norm": 1.0,
        "stop_train_text_encoder": stop_train_text_encoder if stop_train_text_encoder > 0 else None,
        "optimizer_args": eval(optimizer_args) if optimizer_args else None,
        "lr_scheduler": lr_scheduler,
        "lr_warmup_steps": lr_warmup_steps,
        "lr_scheduler_num_cycles": lr_scheduler_num_cycles if lr_scheduler == "cosine_with_restarts" else None,
        "lr_scheduler_power": lr_scheduler_power if lr_scheduler == "polynomial" else None,
    },
    "dataset_arguments": {
        "cache_latents": True,
        "debug_dataset": False,
        "vae_batch_size": vae_batch_size,
    },
    "training_arguments": {
        "output_dir": output_dir,
        "output_name": project_name,
        "save_precision": save_precision,
        "save_every_n_epochs": None,
        "save_n_epoch_ratio": save_n_epoch_ratio,
        "save_last_n_epochs": None,
        "save_state": save_state,
        "save_last_n_epochs_state": None,
        "resume": resume_path,
        "train_batch_size": train_batch_size,
        "max_token_length": 225,
        "mem_eff_attn": False,
        "xformers": True,
        "max_train_steps": max_train_steps,
        "max_data_loader_n_workers": 8,
        "persistent_data_loader_workers": True,
        "seed": seed if seed > 0 else None,
        "gradient_checkpointing": gradient_checkpointing,
        "gradient_accumulation_steps": gradient_accumulation_steps,
        "mixed_precision": mixed_precision,
        "clip_skip": clip_skip if not v2 else None,
        "logging_dir": logging_dir,
        "log_prefix": project_name,
        "noise_offset": noise_offset if noise_offset > 0 else None,
    },
    "sample_prompt_arguments": {
        "sample_every_n_steps": 100 if enable_sample_prompt else 999999,
        "sample_every_n_epochs": None,
        "sample_sampler": sampler,
    },
    "dreambooth_arguments": {
        "prior_loss_weight": 1.0,
    },
    "saving_arguments": {
        "save_model_as": save_model_as
    },
}

config_path = os.path.join(config_dir, "config_file.toml")
prompt_path = os.path.join(config_dir, "sample_prompt.txt")

for key in config:
    if isinstance(config[key], dict):
        for sub_key in config[key]:
            if config[key][sub_key] == "":
                config[key][sub_key] = None
    elif config[key] == "":
        config[key] = None

config_str = toml.dumps(config)

def write_file(filename, contents):
    with open(filename, "w") as f:
        f.write(contents)

write_file(config_path, config_str)
write_file(prompt_path, sample_str)

print(config_str)

In [None]:
#@title ## 5.5. Start Training

#@markdown Check your config here if you want to edit something:
#@markdown - `sample_prompt` : /content/dreambooth/config/sample_prompt.txt
#@markdown - `config_file` : /content/dreambooth/config/config_file.toml
#@markdown - `dataset_config` : /content/dreambooth/config/dataset_config.toml

#@markdown Generated sample can be seen here: /content/dreambooth/output/sample

#@markdown You can import config from another session if you want.
sample_prompt = "/content/dreambooth/config/sample_prompt.txt" #@param {type:'string'}
config_file = "/content/dreambooth/config/config_file.toml" #@param {type:'string'}
dataset_config = "/content/dreambooth/config/dataset_config.toml" #@param {type:'string'}

# Options placed after `accelerate launch`.
accelerate_conf = dict(
    config_file=accelerate_config,
    num_cpu_threads_per_process=1,
)

# Options placed after `train_db.py`.
train_conf = dict(
    sample_prompts=sample_prompt,
    dataset_config=dataset_config,
    config_file=config_file,
)

def train(config):
    """Render ``config`` as a command-line argument string.

    Each rendered entry ends with a space. A key starting with "_" yields the
    quoted value as a positional argument, strings yield ``--key="value"``,
    True booleans yield ``--key``, and ints/floats yield ``--key=value``.
    False booleans and values of any other type are left out.
    """
    fragments = []
    for option, setting in config.items():
        if option.startswith("_"):
            fragments.append(f'"{setting}" ')
        elif isinstance(setting, str):
            fragments.append(f'--{option}="{setting}" ')
        elif isinstance(setting, bool):
            if setting:
                fragments.append(f"--{option} ")
        elif isinstance(setting, (int, float)):
            fragments.append(f"--{option}={setting} ")

    return "".join(fragments)

accelerate_args = train(accelerate_conf)
train_args = train(train_conf)
final_args = f"accelerate launch {accelerate_args} train_db.py {train_args}"

os.chdir(repo_dir)
!{final_args}

# VI. Testing

In [None]:
# @title ## 6.1. Visualize loss graph (Optional)
import os

# Directory TensorBoard reads from; the default equals `logging_dir` in the
# "5.4. Training Config" cell.
training_logs_path = "/content/dreambooth/logs"  # @param {type : "string"}

os.chdir(repo_dir)
# Load the TensorBoard notebook extension, then start it on the log directory.
%load_ext tensorboard
%tensorboard --logdir {training_logs_path}

In [None]:
# @title ## 6.2. Inference
%store -r

v2 = False  # @param {type:"boolean"}
v_parameterization = False  # @param {type:"boolean"}
prompt = "masterpiece, best quality, 1girl, aqua eyes, baseball cap, blonde hair, closed mouth, earrings, green background, hat, hoop earrings, jewelry, looking at viewer, shirt, short hair, simple background, solo, upper body, yellow shirt"  # @param {type: "string"}
negative = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry"  # @param {type: "string"}
model = ""  # @param {type: "string"}
vae = ""  # @param {type: "string"}
outdir = "/content/tmp"  # @param {type: "string"}
scale = 7  # @param {type: "slider", min: 1, max: 40}
sampler = "ddim"  # @param ["ddim", "pndm", "lms", "euler", "euler_a", "heun", "dpm_2", "dpm_2_a", "dpmsolver","dpmsolver++", "dpmsingle", "k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a"]
steps = 28  # @param {type: "slider", min: 1, max: 100}
precision = "fp16"  # @param ["fp16", "bf16"] {allow-input: false}
width = 512  # @param {type: "integer"}
height = 768  # @param {type: "integer"}
images_per_prompt = 4  # @param {type: "integer"}
batch_size = 4  # @param {type: "integer"}
clip_skip = 2  # @param {type: "slider", min: 1, max: 40}
seed = -1  # @param {type: "integer"}

final_prompt = f"{prompt} --n {negative}"

config = {
    "v2": v2,
    "v_parameterization": v_parameterization,
    "ckpt": model,
    "outdir": outdir,
    "xformers": True,
    "vae": vae if vae else None,
    "fp16": True,
    "W": width,
    "H": height,
    "seed": seed if seed > 0 else None,
    "scale": scale,
    "sampler": sampler,
    "steps": steps,
    "max_embeddings_multiples": 3,
    "batch_size": batch_size,
    "images_per_prompt": images_per_prompt,
    "clip_skip": clip_skip if not v2 else None,
    "prompt": final_prompt,
}

args = ""
for k, v in config.items():
    if isinstance(v, str):
        args += f'--{k}="{v}" '
    if isinstance(v, bool) and v:
        args += f"--{k} "
    if isinstance(v, float) and not isinstance(v, bool):
        args += f"--{k}={v} "
    if isinstance(v, int) and not isinstance(v, bool):
        args += f"--{k}={v} "

final_args = f"python gen_img_diffusers.py {args}"

os.chdir(repo_dir)
!{final_args}