# Kohya's Scripts + LoRA Training

**Another Choice**: https://github.com/ddPn08/kohya-sd-scripts-webui

More info: https://github.com/wibus-wee/stable_diffusion_chilloutmix_ipynb

> Created by [@wibus-wee](https://github.com/wibus-wee)
>
> Reference: [camenduru/stable-diffusion-webui-colab](https://github.com/camenduru/stable-diffusion-webui-colab)

In [None]:
from IPython.display import display
import ipywidgets as widgets
import requests

#@title Choose checkpoint { display-mode: "form" }
endpoint = 'https://civitai.com/api/v1/models'
#@markdown Checkpoint（ If you want to use your own model, please select "others" and fill in the "Other checkpoint" below.）
checkpointName = 'Chilloutmix' #@param ["Chilloutmix", "Sunshinemix", "grapefruit_hentai", "others"]
#@markdown Other checkpoint （Checkpoint ID in civitai , only support one model)
checkpointID = '' #@param {type:"string"}
#@markdown Other checkpoint （Checkpoint download URL, only support one model)
checkpointURL = '' #@param {type:"string"}


def parseIdList(raw):
    """Split a comma-separated form value into clean entries.

    Surrounding whitespace is stripped and empty entries (stray or trailing
    commas, an empty field) are dropped, so every returned item can be put
    into a URL as-is.
    """
    return [part.strip() for part in raw.split(',') if part.strip()]


# Model IDs used for the built-in checkpoint choices of the form above.
defaultCheckpoint = {
    'Chilloutmix': '6424',
    'Sunshinemix': '9291',
    'grapefruit_hentai': '2583',
}

# IDs whose versions will be offered in dropdowns: the user-supplied ones
# first, then the selected built-in checkpoint (unless "others" was chosen).
downloadIds = parseIdList(checkpointID)
if checkpointName != 'others':
    downloadIds.append(defaultCheckpoint[checkpointName])

# Parallel lists, one entry per model ID, filled while the dropdowns are built.
globalDropdowns = []
globalVerions = []
globalNames = []
globalTexts = []
# Parallel lists of chosen file names and their download URLs.
checkpoints = []
downloadLinks = []

def text_on_submit(change):
    """Mirror an edited file-name text box into ``checkpoints``.

    ``change`` is the change mapping handed to a widget observer; its
    ``'old'`` value is looked up in ``checkpoints`` and replaced in place by
    its ``'new'`` value.

    Raises:
        ValueError: if the old value is not present in ``checkpoints``.
    """
    renamed = change['new']
    position = checkpoints.index(change['old'])
    checkpoints[position] = renamed

if checkpointURL != '':
    # Accept "url1, url2" style input: trim each entry and ignore blanks so a
    # stray comma or space never becomes a bogus checkpoint.
    _downloadLinks = [link.strip() for link in checkpointURL.split(',') if link.strip()]
    for _downloadLink in _downloadLinks:
        # Default file name is the last path segment of the URL; the text box
        # below lets the user rename it.
        _fileName = _downloadLink.split('/')[-1]
        checkpoints.append(_fileName)
        downloadLinks.append(_downloadLink)
        text = widgets.Text(value=_fileName, description=_fileName, disabled=False)
        # Keep `checkpoints` in sync with whatever the user types.
        text.observe(text_on_submit, names='value')
        form = widgets.VBox([text])
        display(form)

def showVerionOptions(downloadId):
    """Fetch a model's versions and display a dropdown to pick one.

    Requests ``endpoint/<downloadId>``, then records the model name, its
    versions and the created dropdown in ``globalNames``, ``globalVerions``
    and ``globalDropdowns`` (same index in all three). Each dropdown option
    is the name of the first file of a version.

    Args:
        downloadId: model ID as a string, appended to ``endpoint``.

    Raises:
        requests.RequestException: on network failure, timeout or a non-2xx
            HTTP status (e.g. an unknown model ID).
    """
    # A timeout keeps the cell from hanging indefinitely on a stalled request.
    response = requests.get(endpoint + '/' + downloadId, timeout=30)
    response.raise_for_status()
    res = response.json()

    # Read everything from the payload before touching the global lists so a
    # malformed response cannot leave them out of step with each other.
    modelName = res['name']
    # Versions without any file cannot be offered (nothing to download).
    versions = [version for version in res['modelVersions'] if version.get('files')]
    options = [version['files'][0]['name'] for version in versions]

    globalNames.append(modelName)
    globalVerions.append(versions)
    dropdown = widgets.Dropdown(options=options, description=modelName)
    globalDropdowns.append(dropdown)
    form = widgets.VBox([dropdown])
    display(form)

# Show one version dropdown per requested model ID.
for downloadId in downloadIds:
    showVerionOptions(downloadId)

def on_button_clicked(b):
    downloadLink = None
    for dropdown in globalDropdowns:
        checkpoint = dropdown.value
        versions = globalVerions[globalDropdowns.index(dropdown)]
        for version in versions:
            if version['files'][0]['name'] == checkpoint:
                downloadLink = version['files'][0]['downloadUrl']
                break
        if downloadLink is None:
            print('Error: downloadLink not assigned')
            return
        checkpoints.append(checkpoint)
        downloadLinks.append(downloadLink)

    print("Checkpoint " + str(checkpoints) + " <===> " + str(downloadLinks))
    %store checkpoint
    %store downloadLink


# Confirmation button: a click runs on_button_clicked.
button = widgets.Button(description='Use it!')
button.on_click(on_button_clicked)
display(button)

In [None]:
#@title 0. Check GPU & Check Environment { display-mode: "form" }

#@markdown This step will check if your GPU is supported by xformers, and check if you are using Paperspace (only M4000 GPU is checked, so paid GPUs may have logical errors here, you may need to check the "isPaperspace" checkbox manually).

import os, subprocess
paperspace_m4000 = False
#@markdown ⬇️ True: Paperspace | False: Colab
isPaperspace = False #@param {type:"boolean"}
try:
    subprocess.run(['nvidia-smi', '--query-gpu=name', '--format=csv,noheader'], stdout=subprocess.PIPE)
    if 'M4000' in subprocess.run(['nvidia-smi', '--query-gpu=name', '--format=csv,noheader'], stdout=subprocess.PIPE).stdout.decode('utf-8'):
        print("WARNING: You are using Quadro M4000 GPU, which will not be able to use xformers.")
        paperspace_m4000 = True
        isPaperspace = True
    else:
        print("You are using a suitable GPU - " + subprocess.run(['nvidia-smi', '--query-gpu=name', '--format=csv,noheader'], stdout=subprocess.PIPE).stdout.decode('utf-8') + "."
except:
    print("The GPU is not available. Please check your runtime type.")
    exit()

rootDir = isPaperspace and '/tmp' or '/content'

%store rootDir
%store paperspace_m4000 
%store isPaperspace

In [None]:
#@title 1. Install Dependencies

#@markdown ⬇️ True: Install xformers | False: Skip (Only for Paperspace M4000 GPU, if you choose this option, xformers will be installed from source code)
xformersInstall = True #@param {type:"boolean"}

%store -r rootDir 
%store -r checkpoint
%store -r downloadLink
%store -r paperspace_m4000

!apt-get -y install -qq aria2
ariaInstalled = False

try:
    subprocess.run(['aria2c', '--version'], stdout=subprocess.PIPE)
    ariaInstalled = True
except:
    pass

!git clone --recurse-submodules https://github.com/Akegarasu/lora-scripts {rootDir}/lora-scripts

%cd {rootDir}/lora-scripts
if ariaInstalled:
    !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M {downloadLink} -d {rootDir}/lora-scripts/sd-models -o {checkpoint}
else:
    !wget -c {downloadLink} -P {rootDir}/stable-diffusion-webui/models/Stable-diffusion -O {rootDir}/lora-scripts/sd-models/{checkpoint}


!echo "Installing deps..."
%cd ./sd-scripts
!pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
!pip install --upgrade -r requirements.txt
if paperspace_m4000:
  if xformersInstall:
    !pip install ninja
    !pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
    !pip install --pre triton
else:
  !pip install -q https://github.com/camenduru/stable-diffusion-webui-colab/releases/download/0.0.16/xformers-0.0.16+814314d.d20230118-cp38-cp38-linux_x86_64.whl
  !pip install --pre triton
!pip install --upgrade lion-pytorch

echo "Install completed"

In [None]:
#@title 2. Setup Training Options
%store -r rootDir 
%store -r checkpoint 

!source venv/bin/activate

#@markdown 底膜路径 - 已跟随 Checkpoint 设置 | Pretrained model path - Followed by Checkpoint
pretrained_model = rootDir + "/lora-scripts/sd-models/" + checkpoint
#@markdown 训练数据集路径 | Training dataset path
#@markdown PaperSpace 如果是在 Files 上传文件请加入 /notebooks/ 前缀，在 Google Colab 请加入 /content/ 前缀
#@markdown If you upload the file to Files in PaperSpace, please add the /notebooks/ prefix. If you upload the file to Google Colab, please add the /content/ prefix.
train_data_dir = "<your_prefix>/<img_parent_dir>" #@param {type:"string"}
#@markdown 导出模型路径 | Export model path
#@markdown PaperSpace 如果是在 Files 上传文件请加入 /notebooks/ 前缀，在 Google Colab 请加入 /content/ 前缀
#@markdown If you upload the file to Files in PaperSpace, please add the /notebooks/ prefix. If you upload the file to Google Colab, please add the /content/ prefix.
export_model_dir = "<your_prefix>/<model_export_dir>" #@param {type:"string"}
#@markdown 模型保存名称 | output model name
output_name = "lora_v1.0" #@param {type:"string"}
#@markdown  模型保存格式 | model save ext (ckpt, pt, safetensors)
save_model_as = "safetensors" #@param ["ckpt", "pt", "safetensors"] {type:"string"}
#@markdown 图片分辨率 | image resolution
#@markdown (宽,高). 支持非正方形，但必须是 64 倍数。 | (width, height). Non-square images are supported, but must be a multiple of 64.
resolution = "512,512" #@param ["256,256", "512,512", "1024,1024", "2048,2048"]
#@markdown batch size | 批大小
batch_size = 1 #@param {type:"number"}
#@markdown max train epoches | 最大训练 epoch
max_train_epoches = 10 #@param {type:"number"}
#@markdown save every n epochs | 每 N 个 epoch 保存一次
save_every_n_epochs = 2 #@param {type:"number"}
#@markdown network dim | 常用 4~128，不是越大越好 | A value between 4 and 128. Not necessarily the larger the better.
network_dim = 32 #@param {type:"number"}
#@markdown network alpha
#@markdown 常用与 network_dim 相同的值或者采用较小的值，如 network_dim的一半 防止下溢。默认值为 1，使用较小的 alpha 需要提升学习率。
#@markdown A value between 1 and network_dim. Not necessarily the larger the better. Default value is 1. Use a smaller alpha value to prevent underflow. You may need to increase the learning rate.
network_alpha = 32 #@param {type:"number"}
#@markdown clip skip | 玄学 一般用 2 | Usually 2.
clip_skip = 2 #@param {type:"number"}
#@markdown train U-Net only | 仅训练 U-Net，开启这个会牺牲效果大幅减少显存使用。6G显存可以开启 | This will sacrifice the quality of the model, but it will reduce the memory usage. 6G GPU memory can be enabled.
train_unet_only = False #@param {type:"boolean"}
#@markdown train Text Encoder only | 仅训练 文本编码器
train_text_encoder_only = False #@param {type:"boolean"}
#@markdown 学习率 | Learning rate
lr = "1e-4" #@param {type:"string"}
#@markdown unet_lr | U-Net 学习率 | Learning rate of the U-Net
unet_lr = "1e-4" #@param {type:"string"}
#@markdown text_encoder_lr | 文本编码器学习率 | Learning rate of the text encoder
text_encoder_lr = "1e-5" #@param {type:"string"}
#@markdown lr_scheduler | "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"
lr_scheduler = "cosine_with_restarts" #@param ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"] {type:"string"}
#@markdown lr_warmup_steps | 仅在 lr_scheduler 为 constant_with_warmup 时需要填写这个值 | Only needed when lr_scheduler is constant_with_warmup
lr_warmup_steps = 0 #@param {type:"number"}
#@markdown 若需要从已有的 LoRA 模型上继续训练，请填写 LoRA 模型路径。 | If you want to continue training from an existing LoRA model, please fill in the LoRA model path.
network_weights = "" #@param {type:"string"}
#@markdown arb min resolution | arb 最小分辨率
min_bucket_reso = 256 #@param {type:"number"}
#@markdown arb max resolution | arb 最大分辨率
max_bucket_reso = 1024 #@param {type:"number"}
#@markdown persistent dataloader workers | 容易爆内存，保留加载训练集的worker，减少每个 epoch 之间的停顿 | Persistent dataloader workers, reduce the pause between each epoch
persistent_data_loader_workers = 0 #@param {type:"number"}
#@markdown use 8bit adam optimizer | 使用 8bit adam 优化器节省显存，默认启用。部分 10 系老显卡无法使用，修改为 0 禁用。 | Use 8bit adam optimizer to save memory. Default is enabled. Some 10 series old GPUs cannot use it. Set to 0 to disable.
use_8bit_adam = True #@param {type:"boolean"}
#@markdown use lion optimizer | 使用 Lion 优化器 | Use Lion optimizer
use_lion = False #@param {type:"boolean"}

In [None]:
#@title 4. Start training
!accelerate launch --num_cpu_threads_per_process=8 "./sd-scripts/train_network.py" \
  --enable_bucket \
  --pretrained_model_name_or_path=$pretrained_model \
  --train_data_dir=$train_data_dir \
  --output_dir=$export_model_dir \
  --logging_dir="/tmp/logs" \
  --resolution=$resolution \
  --network_module=networks.lora \
  --max_train_epochs=$max_train_epoches \
  --learning_rate=$lr \
  --unet_lr=$unet_lr \
  --text_encoder_lr=$text_encoder_lr \
  --lr_scheduler=$lr_scheduler \
  --lr_warmup_steps=$lr_warmup_steps \
  --network_dim=$network_dim \
  --network_alpha=$network_alpha \
  --output_name=$output_name \
  --train_batch_size=$batch_size \
  --save_every_n_epochs=$save_every_n_epochs \
  --mixed_precision="fp16" \
  --save_precision="fp16" \
  --seed="1337" \
  --cache_latents \
  --clip_skip=$clip_skip \
  --prior_loss_weight=1 \
  --max_token_length=225 \
  --caption_extension=".txt" \
  --save_model_as=$save_model_as \
  --min_bucket_reso=$min_bucket_reso \
  --max_bucket_reso=$max_bucket_reso \
  --use_8bit_adam=$use_8bit_adam \
  --shuffle_caption --xformer