# **Dataset Trainer**

In [None]:
import os
import re
import toml
import shutil
import zipfile
from time import time
from IPython.display import Markdown, display

# Commit that the cloned trainer repository is reset to in setup_trainer().
TRAINER_COMMIT_HASH = "89c30334016a27f1d890466151a61a46d8d1545e"


# State surviving from earlier executions of this cell in the same runtime.
# `model_url` is sampled here, before the form below reassigns it, so a
# changed URL can be detected later; it is None on the very first run.
old_model_url = globals().get("model_url")
# Only initialise these flags when no previous run has already set them.
globals().setdefault("dependencies_installed", False)
globals().setdefault("model_file", None)


#@markdown ## Setup
# dataset_name / output_name are joined onto the Drive folders built further
# down (datasets/<dataset_name>, output/<output_name>, config/<output_name>).
dataset_name  = "" #@param {type:"string"}
output_name = "" #@param {type:"string"}
# Optional path to an existing .safetensors file to continue from; it is
# validated in validate_dataset() and passed on as "network_weights".
continue_from = "" #@param {type:"string"}
model_url = "https://civitai.com/api/download/models/127207?type=Model&format=SafeTensor&size=full&fp=fp32" #@param {type:"string"}
resolution = 768 #@param {type:"slider", min:512, max:1024, step:64}

#@markdown ## Steps
# Steps per epoch are computed in validate_dataset() as
# images * num_repeats / train_batch_size.
num_repeats = 5 #@param {type:"number"}
num_epochs = 30 #@param {type:"number"}
save_every_n_epochs = 2 #@param {type:"number"}
keep_only_last_n_epochs = 20 #@param {type:"number"}
train_batch_size = 3 #@param {type:"slider", min:1, max:8, step:1}

#@markdown ## Learning
unet_lr = 2e-4 #@param {type:"number"}
text_encoder_lr = 1e-4 #@param {type:"number"}

optimizer = "AdamW8bit" #@param ["AdamW8bit", "DAdaptation"]
# Stays None unless the DAdaptation branch below fills it in.
optimizer_args = None
lr_scheduler = "cosine_with_restarts" #@param ["constant", "cosine", "cosine_with_restarts", "constant_with_warmup", "linear", "polynomial"]
# One number shared by two schedulers: it is the cycle count for
# "cosine_with_restarts" and the power for "polynomial"; otherwise both are 0.
lr_scheduler_number = 3 #@param {type:"number"}
lr_scheduler_num_cycles = lr_scheduler_number if lr_scheduler == "cosine_with_restarts" else 0
lr_scheduler_power = lr_scheduler_number if lr_scheduler == "polynomial" else 0
lr_warmup_ratio = 0.05 #@param {type:"slider", min:0.0, max:0.5, step:0.01}
# Placeholder; validate_dataset() overwrites it with total_steps * lr_warmup_ratio.
lr_warmup_steps = 0
min_snr_gamma = True #@param {type:"boolean"}
# None means the key is dropped from the generated config (None entries are
# filtered out in create_config()).
min_snr_gamma_value = 5.0 if min_snr_gamma else None

#@markdown ## Network Structure
lora_type = "LoRA" #@param ["LoRA", "LoCon Lycoris", "LoHa Lycoris"]
network_dim = 128 #@param {type:"slider", min:1, max:128, step:1}
network_alpha = 64 #@param {type:"slider", min:1, max:128, step:1}
# The conv_* settings only reach the trainer when lora_type is not "LoRA"
# (see how network_args is built below).
conv_dim = 8 #@param {type:"slider", min:1, max:64, step:1}
conv_alpha = 1 #@param {type:"slider", min:1, max:64, step:1}
conv_compression = False #@param {type:"boolean"}

# DAdaptation overrides the learning-rate settings chosen in the form with
# fixed values and forces a warmup-based constant schedule.
if optimizer == "DAdaptation":
  optimizer_args = ["decouple=True", "weight_decay=0.2"]
  unet_lr = text_encoder_lr = 0.5
  lr_scheduler = "constant_with_warmup"

# Pick the network implementation from the selected LoRA flavour.
if "Lycoris" in lora_type:
  network_module = "lycoris.kohya"
else:
  network_module = "networks.lora"

# Plain LoRA passes no extra network arguments; every other type forwards the
# convolution settings.
if lora_type == "LoRA":
  network_args = None
else:
  network_args = [f"conv_dim={conv_dim}", f"conv_alpha={conv_alpha}"]

# Lycoris variants additionally name their algorithm and CP toggle.
if "Lycoris" in lora_type:
  network_args += [
    f"algo={'loha' if 'LoHa' in lora_type else 'lora'}",
    f"disable_conv_cp={not conv_compression}",
  ]

# Runtime-local locations.
root_dir = "/content"
deps_dir = os.path.join(root_dir, "deps")
repo_dir = os.path.join(root_dir, "kohya-trainer")
accelerate_config_file = os.path.join(repo_dir, "accelerate_config/config.yaml")

# Google Drive locations, all below a single training directory.
main_dir = os.path.join(root_dir, "drive/MyDrive/lora_training")
log_folder = os.path.join(main_dir, "log")
images_folder = os.path.join(main_dir, "datasets", dataset_name)
metadata_file = os.path.join(images_folder, "metadata.json")
output_folder = os.path.join(main_dir, "output", output_name)
config_folder = os.path.join(main_dir, "config", output_name)

# Files generated by create_config().
config_file = os.path.join(config_folder, "training_config.toml")
dataset_config_file = os.path.join(config_folder, "dataset_config.toml")


def setup_trainer():
  """Clone, patch and install the training scripts and their dependencies.

  Clones the trainer repository into ``repo_dir``, resets it to
  ``TRAINER_COMMIT_HASH``, applies the patches below, installs system and
  Python packages, writes a basic accelerate config if none exists, and sets
  a few environment variables. The working directory is left at ``repo_dir``.
  """
  # chdir is used so the shell escapes that follow run in the intended
  # directory (presumably they inherit the notebook's working directory).
  os.chdir(root_dir)
  !git clone https://github.com/kohya-ss/sd-scripts {repo_dir}
  os.chdir(repo_dir)
  # Pin the checkout so the patches below apply to a known revision.
  !git reset --hard {TRAINER_COMMIT_HASH}

  # apply custom patches
  # The patch is streamed from wget straight into `git apply`; the
  # requirements file from the same repository replaces the cloned one.
  !wget -O - https://raw.githubusercontent.com/recris/kohya-colab/main/masked_loss_blur.patch | git apply -v
  !wget https://raw.githubusercontent.com/recris/kohya-colab/main/requirements.txt -q -O requirements.txt
  !sed -i "s@cpu@cuda@" library/model_util.py # low ram
  !sed -i 's/from PIL import Image/from PIL import Image, ImageFile\nImageFile.LOAD_TRUNCATED_IMAGES=True/g' library/train_util.py # fix truncated jpegs error

  !apt -y update -qq
  !apt -y install aria2 unzip
  # NOTE(review): the download appears to land under a hash-like file name,
  # which the `mv` then renames to the real wheel name — confirm this still
  # matches what aria2c writes, otherwise the xformers install below fails.
  !aria2c https://huggingface.co/munsy0227/kohya-colab-file/resolve/main/xformers-0.0.21.dev564-cp310-cp310-manylinux2014_x86_64.whl
  !mv bd656faa4ec396d1bb365f7b58d2831eaac129ac82076d20084f74dcd0399b08 xformers-0.0.21.dev564-cp310-cp310-manylinux2014_x86_64.whl
  !pip -q install --upgrade -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118
  !pip -q install --upgrade xformers-0.0.21.dev564-cp310-cp310-manylinux2014_x86_64.whl

  # Imported here rather than at the top of the file, presumably because
  # accelerate is only guaranteed to be present after the installs above.
  from accelerate.utils import write_basic_config
  if not os.path.exists(accelerate_config_file):
    write_basic_config(save_location=accelerate_config_file)

  # Environment flags for the libraries used during training; the exact
  # effect of each is defined by the respective library — TODO confirm.
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
  os.environ["BITSANDBYTES_NOWELCOME"] = "1"
  os.environ["SAFETENSORS_FAST_GPU"] = "1"


def validate_dataset():
  """Check the dataset folder and the step budget before training.

  Reads the module-level settings (``dataset_name``, ``images_folder``,
  ``metadata_file``, ``continue_from``, ``num_repeats``, ``num_epochs``,
  ``train_batch_size``, ``lr_warmup_ratio``), prints a summary of the
  resulting step counts and stores the warmup length in the module-level
  ``lr_warmup_steps``.

  Returns:
    True when everything looks usable. None, after printing an error
    message, when any check fails.
  """
  global lr_warmup_steps
  supported_types = (".png", ".jpg", ".jpeg", ".webp")

  print("\n💿 Checking dataset...")
  if not dataset_name.strip() or any(c in dataset_name for c in " .()\"'\\/"):
    print("💥 Error: Please choose a valid project name.")
    return None

  # The folder has to be verified *before* it is listed: os.listdir raises
  # on a missing path, which would bypass the friendly error message.
  if not os.path.isdir(images_folder):
    print(f"💥 Error: The folder {images_folder.replace('/content/drive/', '')} doesn't exist.")
    return None

  image_count = sum(
    1 for name in os.listdir(images_folder)
    if name.lower().endswith(supported_types)
  )
  if not image_count:
    print(f"💥 Error: Your {images_folder.replace('/content/drive/', '')} folder is empty.")
    return None

  if not os.path.exists(metadata_file):
    print(f"💥 Error: Images metadata file does not exist.")
    return None

  if continue_from and not (continue_from.endswith(".safetensors") and os.path.exists(continue_from)):
    print(f"💥 Error: Invalid path to existing Lora. Example: /content/drive/MyDrive/lora_training/checkpoints/example.safetensors")
    return None

  # One "step" here is one image shown once; batching divides that count.
  pre_steps_per_epoch = image_count * num_repeats
  steps_per_epoch = pre_steps_per_epoch/train_batch_size
  total_steps = int(num_epochs*steps_per_epoch)
  lr_warmup_steps = int(total_steps*lr_warmup_ratio)

  print("📁"+images_folder.replace("/content/drive/", ""))
  print(f"📈 Found {image_count} images with {num_repeats} repeats, equaling {pre_steps_per_epoch} steps.")

  print(f"📉 Divide {pre_steps_per_epoch} steps by {train_batch_size} batch size to get {steps_per_epoch} steps per epoch.")
  print(f"🔮 There will be {num_epochs} epochs, for around {total_steps} total training steps.")

  # Guard against accidental settings that would train for far too long.
  if total_steps > 20000:
    print("💥 Error: Your total steps are too high. You probably made a mistake. Aborting...")
    return None

  return True


def create_config():
  """Write the training and dataset TOML files used by the trainer.

  Both files are assembled from the module-level settings. Entries whose
  value is None are removed from every top-level dict section before the
  data is serialised; list-valued sections are written unchanged. The
  results go to ``config_file`` and ``dataset_config_file``.
  """

  def drop_unset(sections):
    # Strip None entries from dict sections only; anything else passes through.
    cleaned = {}
    for name, body in sections.items():
      if isinstance(body, dict):
        body = {k: v for k, v in body.items() if v is not None}
      cleaned[name] = body
    return cleaned

  # Scheduler extras only apply to the scheduler that understands them.
  uses_restarts = lr_scheduler == "cosine_with_restarts"
  uses_polynomial = lr_scheduler == "polynomial"

  network_section = {
    "unet_lr": unet_lr,
    "text_encoder_lr": text_encoder_lr,
    "network_dim": network_dim,
    "network_alpha": network_alpha,
    "network_module": network_module,
    "network_args": network_args,
    "network_train_unet_only": True if text_encoder_lr == 0 else None,
    "network_weights": continue_from or None,
    "network_dropout": 0.15,
    "scale_weight_norms": 2.0
  }

  optimizer_section = {
    "learning_rate": unet_lr,
    "lr_scheduler": lr_scheduler,
    "lr_scheduler_num_cycles": lr_scheduler_num_cycles if uses_restarts else None,
    "lr_scheduler_power": lr_scheduler_power if uses_polynomial else None,
    "lr_warmup_steps": lr_warmup_steps if lr_scheduler != "constant" else None,
    "optimizer_type": optimizer,
    "optimizer_args": optimizer_args or None,
  }

  training_section = {
    "max_train_steps": None,
    "max_train_epochs": num_epochs,
    "save_every_n_epochs": save_every_n_epochs,
    "save_last_n_epochs": keep_only_last_n_epochs,
    "train_batch_size": train_batch_size,
    "noise_offset": None,
    "clip_skip": 1,
    "min_snr_gamma": min_snr_gamma_value,
    "weighted_captions": False,
    "seed": 42,
    "max_token_length": 225,
    "xformers": True,
    "lowram": True,
    "max_data_loader_n_workers": 8,
    "persistent_data_loader_workers": True,
    "save_precision": "fp16",
    "mixed_precision": "fp16",
    "output_dir": output_folder,
    "logging_dir": log_folder,
    "output_name": output_name,
    "log_prefix": output_name,
    "save_state": False,
    "save_last_n_epochs_state": None,
    "resume": None,
    "masked_loss": True
  }

  model_section = {
    "pretrained_model_name_or_path": model_file,
    "v2": False,
    "v_parameterization": None,
  }

  training_config = drop_unset({
    "additional_network_arguments": network_section,
    "optimizer_arguments": optimizer_section,
    "training_arguments": training_section,
    "model_arguments": model_section,
    "saving_arguments": {"save_model_as": "safetensors"},
    "dreambooth_arguments": {"prior_loss_weight": 1.0},
    "dataset_arguments": {"cache_latents": True},
  })

  training_toml = toml.dumps(training_config)
  with open(config_file, "w") as f:
    f.write(training_toml)

  print(f"\n📄 Config saved to {config_file}")

  # Bucket limits widen for resolutions above 640.
  high_res = resolution > 640
  general_section = {
    "resolution": resolution,
    "shuffle_caption": False,
    "keep_tokens": 1,
    "flip_aug": False,
    "caption_extension": ".txt",
    "enable_bucket": True,
    "bucket_reso_steps": 64,
    "bucket_no_upscale": False,
    "min_bucket_reso": 320 if high_res else 256,
    "max_bucket_reso": 1280 if high_res else 1024,
  }

  subset = {
    "num_repeats": num_repeats,
    "image_dir": images_folder,
    "metadata_file": metadata_file,
    "class_tokens": None
  }

  dataset_config = drop_unset({
    "general": general_section,
    "datasets": [{"subsets": [subset]}],
  })

  dataset_toml = toml.dumps(dataset_config)
  with open(dataset_config_file, "w") as f:
    f.write(dataset_toml)

  print(f"📄 Dataset config saved to {dataset_config_file}")


def download_model():
  global old_model_url, model_url, model_file
  real_model_url = model_url.strip()

  if real_model_url.lower().endswith((".ckpt", ".safetensors")):
    model_file = f"/content{real_model_url[real_model_url.rfind('/'):]}"
  else:
    model_file = "/content/downloaded_model.safetensors"
    if os.path.exists(model_file):
      !rm "{model_file}"

  if m := re.search(r"(?:https?://)?(?:www\.)?huggingface\.co/[^/]+/[^/]+/blob", model_url):
    real_model_url = real_model_url.replace("blob", "resolve")
  elif m := re.search(r"(?:https?://)?(?:www\.)?civitai\.com/models/([0-9]+)", model_url):
    real_model_url = f"https://civitai.com/api/download/models/{m.group(1)}"

  !aria2c "{real_model_url}" --console-log-level=warn -c -s 16 -x 16 -k 10M -d / -o "{model_file}"

  if model_file.lower().endswith(".safetensors"):
    from safetensors.torch import load_file as load_safetensors
    try:
      test = load_safetensors(model_file)
      del test
    except Exception as e:
      #if "HeaderTooLarge" in str(e):
      new_model_file = os.path.splitext(model_file)[0]+".ckpt"
      !mv "{model_file}" "{new_model_file}"
      model_file = new_model_file
      print(f"Renamed model to {os.path.splitext(model_file)[0]}.ckpt")

  if model_file.lower().endswith(".ckpt"):
    from torch import load as load_ckpt
    try:
      test = load_ckpt(model_file)
      del test
    except Exception as e:
      return False

  return True


def main():
  global dependencies_installed

  if not os.path.exists('/content/drive'):
    from google.colab import drive
    print("📂 Connecting to Google Drive...")
    drive.mount('/content/drive')

  for dir in (main_dir, deps_dir, repo_dir, log_folder, images_folder, output_folder, config_folder):
    os.makedirs(dir, exist_ok=True)

  if not validate_dataset():
    return

  if not dependencies_installed:
    print("\n🏭 Installing dependencies...\n")
    t0 = time()
    setup_trainer()
    t1 = time()
    dependencies_installed = True
    print(f"\n✅ Installation finished in {int(t1-t0)} seconds.")
  else:
    print("\n✅ Dependencies already installed.")

  if old_model_url != model_url or not model_file or not os.path.exists(model_file):
    print("\n🔄 Downloading model...")
    if not download_model():
      print("\n💥 Error: The model you selected is invalid or corrupted, or couldn't be downloaded. You can use a civitai or huggingface link, or any direct download link.")
      return
    print()
  else:
    print("\n🔄 Model already downloaded.\n")

  create_config()

  print("\n⭐ Starting trainer...\n")
  os.chdir(repo_dir)

  !accelerate launch --config_file={accelerate_config_file} --num_cpu_threads_per_process=1 train_network.py --dataset_config={dataset_config_file} --config_file={config_file}

  if not get_ipython().__dict__['user_ns']['_exit_code']:
    display(Markdown("### ✅ Done!"))

# Run the whole pipeline as soon as the cell is executed.
main()
