# SDXL LoRA Trainer for Likeness

This notebook trains a LoRA (Low-Rank Adaptation) on SDXL to capture your likeness for use in image generation.

**Requirements:**
- Google Colab with GPU runtime (T4 compatible)
- Training images stored in Google Drive with matching .txt caption files
- ~15-30 high-quality photos of the subject

**Training Data Format:**
```
your_folder/
  image001.jpg
  image001.txt  (caption describing the image)
  image002.png
  image002.txt
  ...
```

## 1. Setup & Installation

Install kohya_ss sd-scripts and dependencies. This takes a few minutes.

In [None]:
#@title 1.1 Install Dependencies
#@markdown Run this cell first. Takes ~3-5 minutes.

import subprocess
import sys
import os

# Check GPU
gpu_info = subprocess.run(['nvidia-smi', '--query-gpu=name,memory.total', '--format=csv,noheader'], 
                          capture_output=True, text=True)
print(f"GPU: {gpu_info.stdout.strip()}")

# Install base requirements
# xformers 0.0.23.post1 is built against torch 2.1.2, so pin matching versions
!pip install -q torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121
!pip install -q xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu121

# Clone kohya_ss sd-scripts
KOHYA_DIR = "/content/sd-scripts"
if not os.path.exists(KOHYA_DIR):
    !git clone https://github.com/kohya-ss/sd-scripts.git {KOHYA_DIR}
    %cd {KOHYA_DIR}
    !git checkout sd3  # Development branch (adds SD3/FLUX support); SDXL training also works on main
else:
    %cd {KOHYA_DIR}
    !git pull

# Install kohya requirements
!pip install -q -r requirements.txt
!pip install -q accelerate==0.25.0 transformers==4.36.2 diffusers==0.25.1
!pip install -q safetensors bitsandbytes==0.41.3 prodigyopt lion-pytorch
!pip install -q lycoris-lora

print("\n" + "="*50)
print("Installation complete!")
print("="*50)

In [None]:
#@title 1.2 Mount Google Drive
#@markdown Connect to your Google Drive to access training images.

from google.colab import drive
drive.mount('/content/drive')

# Verify mount
!ls /content/drive/MyDrive/ | head -10
print("\nGoogle Drive mounted successfully!")

## 2. Configuration

Configure your training parameters. The defaults are optimized for a T4 GPU (16 GB VRAM, about 15 GB usable).

In [None]:
#@title 2.1 Training Configuration
#@markdown ### Paths
TRAINING_IMAGES_FOLDER = "/content/drive/MyDrive/lora_training/images" #@param {type:"string"}
OUTPUT_FOLDER = "/content/drive/MyDrive/lora_training/output" #@param {type:"string"}
LORA_NAME = "my_likeness" #@param {type:"string"}

#@markdown ### Base Model
#@markdown Choose an SDXL base model for training
BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0" #@param ["stabilityai/stable-diffusion-xl-base-1.0", "John6666/wai-ani-nsfw-ponyxl-v11-sdxl", "cagliostrolab/animagine-xl-3.1"]

#@markdown ### Training Parameters (T4 Optimized)
NETWORK_DIM = 32 #@param {type:"slider", min:4, max:128, step:4}
NETWORK_ALPHA = 16 #@param {type:"slider", min:1, max:128, step:1}
LEARNING_RATE = 1e-4 #@param {type:"number"}
UNET_LR = 1e-4 #@param {type:"number"}
TEXT_ENCODER_LR = 5e-5 #@param {type:"number"}
BATCH_SIZE = 1 #@param {type:"slider", min:1, max:4, step:1}
MAX_TRAIN_EPOCHS = 10 #@param {type:"slider", min:1, max:50, step:1}
SAVE_EVERY_N_EPOCHS = 2 #@param {type:"slider", min:1, max:10, step:1}

#@markdown ### Image Settings
RESOLUTION = 1024 #@param [512, 768, 1024] {type:"raw"}
ENABLE_BUCKET = True #@param {type:"boolean"}
MIN_BUCKET_RESO = 512 #@param {type:"integer"}
MAX_BUCKET_RESO = 1536 #@param {type:"integer"}

#@markdown ### Trigger Word
#@markdown A unique token to activate your likeness (e.g., "ohwx", "sks", your initials)
TRIGGER_WORD = "ohwx" #@param {type:"string"}

#@markdown ### Optimizer (Prodigy recommended for automatic LR)
#@markdown With Prodigy, the learning rates above are ignored and set to 1.0.
OPTIMIZER = "Prodigy" #@param ["AdamW8bit", "Prodigy", "Lion", "AdaFactor"]

#@markdown ### Advanced
GRADIENT_CHECKPOINTING = True #@param {type:"boolean"}
GRADIENT_ACCUMULATION = 4 #@param {type:"slider", min:1, max:16, step:1}
#@markdown Keep `fp16` on a T4; `bf16` requires an Ampere or newer GPU.
MIXED_PRECISION = "fp16" #@param ["fp16", "bf16", "no"]
CACHE_LATENTS = True #@param {type:"boolean"}
CACHE_TEXT_ENCODER = True #@param {type:"boolean"}

# Create output directory
import os
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

print(f"Training Configuration:")
print(f"  - Images: {TRAINING_IMAGES_FOLDER}")
print(f"  - Output: {OUTPUT_FOLDER}")
print(f"  - LoRA Name: {LORA_NAME}")
print(f"  - Base Model: {BASE_MODEL}")
print(f"  - Network Dim: {NETWORK_DIM}, Alpha: {NETWORK_ALPHA}")
print(f"  - Trigger Word: {TRIGGER_WORD}")
print(f"  - Epochs: {MAX_TRAIN_EPOCHS}, Save every: {SAVE_EVERY_N_EPOCHS}")

## 3. Data Preparation

Validate and prepare your training dataset.

In [None]:
#@title 3.1 Validate Training Data
#@markdown Check that your images and captions are properly formatted.

import os
from pathlib import Path
from PIL import Image

def validate_dataset(folder_path):
    """Validate training dataset structure and content."""
    folder = Path(folder_path)
    
    if not folder.exists():
        print(f"ERROR: Folder not found: {folder_path}")
        print("Please check your TRAINING_IMAGES_FOLDER path.")
        return False
    
    # Find all images
    image_extensions = {'.jpg', '.jpeg', '.png', '.webp', '.bmp'}
    images = [f for f in folder.iterdir() if f.suffix.lower() in image_extensions]
    
    if not images:
        print(f"ERROR: No images found in {folder_path}")
        return False
    
    print(f"Found {len(images)} images\n")
    
    valid_pairs = []
    missing_captions = []
    invalid_images = []
    
    for img_path in sorted(images):
        # Check for matching caption file
        caption_path = img_path.with_suffix('.txt')
        
        # Validate image can be opened
        try:
            with Image.open(img_path) as img:
                width, height = img.size
        except Exception as e:
            invalid_images.append((img_path.name, str(e)))
            continue
        
        if caption_path.exists():
            caption = caption_path.read_text().strip()
            valid_pairs.append({
                'image': img_path.name,
                'caption': caption[:80] + '...' if len(caption) > 80 else caption,
                'size': f"{width}x{height}"
            })
        else:
            missing_captions.append(img_path.name)
    
    # Report results
    print("=" * 60)
    print(f"VALID IMAGE-CAPTION PAIRS: {len(valid_pairs)}")
    print("=" * 60)
    
    for pair in valid_pairs[:5]:  # Show first 5
        print(f"  {pair['image']} ({pair['size']})")
        print(f"    Caption: {pair['caption']}")
    
    if len(valid_pairs) > 5:
        print(f"  ... and {len(valid_pairs) - 5} more")
    
    if missing_captions:
        print(f"\nWARNING: {len(missing_captions)} images missing captions:")
        for name in missing_captions[:5]:
            print(f"  - {name}")
        if len(missing_captions) > 5:
            print(f"  ... and {len(missing_captions) - 5} more")
    
    if invalid_images:
        print(f"\nERROR: {len(invalid_images)} invalid/corrupted images:")
        for name, err in invalid_images:
            print(f"  - {name}: {err}")
    
    # Recommendations
    print("\n" + "=" * 60)
    print("RECOMMENDATIONS")
    print("=" * 60)
    
    if len(valid_pairs) < 15:
        print("  - Consider adding more images (15-30 recommended for likeness)")
    elif len(valid_pairs) > 50:
        print("  - Large dataset detected. Consider reducing to 30-50 best images.")
    else:
        print(f"  - Dataset size ({len(valid_pairs)}) is good for likeness training")
    
    return len(valid_pairs) > 0

# Run validation
dataset_valid = validate_dataset(TRAINING_IMAGES_FOLDER)

if dataset_valid:
    print("\nDataset validation PASSED. Ready to proceed.")
else:
    print("\nDataset validation FAILED. Please fix the issues above.")

In [None]:
#@title 3.2 Prepare Dataset (Add Trigger Word to Captions)
#@markdown This cell prepares the dataset by creating a proper folder structure
#@markdown and optionally prepending the trigger word to all captions.

import shutil
from pathlib import Path

PREPEND_TRIGGER = True #@param {type:"boolean"}
REPEATS = 10 #@param {type:"slider", min:1, max:50, step:1}

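# Create kohya-compatible folder structure
# Format: <repeats>_<trigger_word>
# Start from a clean directory so folders left by earlier runs (different
# REPEATS or TRIGGER_WORD) are not picked up as extra training data.
shutil.rmtree("/content/training_data", ignore_errors=True)
PREPARED_FOLDER = f"/content/training_data/{REPEATS}_{TRIGGER_WORD}"
os.makedirs(PREPARED_FOLDER, exist_ok=True)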

source_folder = Path(TRAINING_IMAGES_FOLDER)
dest_folder = Path(PREPARED_FOLDER)

image_extensions = {'.jpg', '.jpeg', '.png', '.webp', '.bmp'}
images = [f for f in source_folder.iterdir() if f.suffix.lower() in image_extensions]

processed = 0
for img_path in images:
    caption_path = img_path.with_suffix('.txt')
    if not caption_path.exists():
        continue
    
    # Copy image
    dest_img = dest_folder / img_path.name
    shutil.copy2(img_path, dest_img)
    
    # Process caption
    caption = caption_path.read_text().strip()
    
    if PREPEND_TRIGGER and not caption.lower().startswith(TRIGGER_WORD.lower()):
        # Prepend trigger word
        caption = f"{TRIGGER_WORD} person, {caption}"
    
    # Write processed caption
    dest_caption = dest_folder / img_path.with_suffix('.txt').name
    dest_caption.write_text(caption)
    
    processed += 1

print(f"Prepared {processed} image-caption pairs")
print(f"Dataset folder: {PREPARED_FOLDER}")
print(f"Repeats per image: {REPEATS}")
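# Samples seen per epoch; the step count also depends on batch size
# and gradient accumulation.
print(f"\nTraining samples per epoch: {processed * REPEATS}")

# Show a sample caption (only if at least one pair was prepared)
caption_files = sorted(dest_folder.glob('*.txt'))
if caption_files:
    sample_caption = caption_files[0]
    print(f"\nSample caption ({sample_caption.name}):")
    print(f"  {sample_caption.read_text()[:200]}")
else:
    print("\nNo captions were prepared. Check that each image has a matching .txt file.")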

## 4. Training

Run the LoRA training. This will take a while depending on your dataset size and epochs.
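
As a rough guide to how long a run will be, the number of optimizer steps can be estimated from the dataset size and the settings above. The helper below is a minimal sketch, not part of sd-scripts: `estimate_steps` is a hypothetical name, and it assumes a single GPU and that every image ends up in a full batch (bucketing can change the count slightly).

```python
import math

def estimate_steps(num_images, repeats, epochs, batch_size=1, grad_accum=1):
    """Estimate sample and step counts (assumes one process, no bucketing effects)."""
    samples_per_epoch = num_images * repeats
    batches_per_epoch = math.ceil(samples_per_epoch / batch_size)
    # One optimizer step happens every `grad_accum` batches
    optimizer_steps_per_epoch = math.ceil(batches_per_epoch / grad_accum)
    return {
        "samples_per_epoch": samples_per_epoch,
        "batches_per_epoch": batches_per_epoch,
        "optimizer_steps_total": optimizer_steps_per_epoch * epochs,
    }

# 20 images with this notebook's defaults:
# REPEATS=10, MAX_TRAIN_EPOCHS=10, BATCH_SIZE=1, GRADIENT_ACCUMULATION=4
print(estimate_steps(20, 10, 10, batch_size=1, grad_accum=4))
```

With these defaults, each epoch processes 200 samples and the whole run takes about 500 optimizer steps. Lowering `REPEATS` or `MAX_TRAIN_EPOCHS` shortens the run proportionally.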

In [None]:
#@title 4.1 Generate Training Config
#@markdown Creates the configuration file for kohya training.

import json
from pathlib import Path

# Build training arguments
train_args = {
    # Model
    "pretrained_model_name_or_path": BASE_MODEL,
    "v2": False,
    "v_parameterization": False,
    
    # Dataset
    "train_data_dir": "/content/training_data",
    "resolution": f"{RESOLUTION},{RESOLUTION}",
    "enable_bucket": ENABLE_BUCKET,
    "min_bucket_reso": MIN_BUCKET_RESO,
    "max_bucket_reso": MAX_BUCKET_RESO,
    "bucket_reso_steps": 64,
    
    # Output
    "output_dir": OUTPUT_FOLDER,
    "output_name": LORA_NAME,
    "save_model_as": "safetensors",
    "save_every_n_epochs": SAVE_EVERY_N_EPOCHS,
    "save_precision": "fp16",
    
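    # Network (LoRA)
    "network_module": "networks.lora",
    "network_dim": NETWORK_DIM,
    "network_alpha": NETWORK_ALPHA,
    # sd-scripts cannot train the text encoders while their outputs are cached,
    # so train the U-Net only whenever CACHE_TEXT_ENCODER is enabled.
    "network_train_unet_only": CACHE_TEXT_ENCODER,
    "network_train_text_encoder_only": False,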
    
    # Training
    "max_train_epochs": MAX_TRAIN_EPOCHS,
    "train_batch_size": BATCH_SIZE,
    "gradient_checkpointing": GRADIENT_CHECKPOINTING,
    "gradient_accumulation_steps": GRADIENT_ACCUMULATION,
    "mixed_precision": MIXED_PRECISION,
    
    # Learning rates
    "learning_rate": LEARNING_RATE,
    "unet_lr": UNET_LR,
    "text_encoder_lr": TEXT_ENCODER_LR,
    "lr_scheduler": "cosine_with_restarts" if OPTIMIZER != "Prodigy" else "constant",
    "lr_warmup_steps": 0 if OPTIMIZER == "Prodigy" else 100,
    "lr_scheduler_num_cycles": 3,
    
    # Optimizer
    "optimizer_type": OPTIMIZER,
    
    # Memory optimization
    "cache_latents": CACHE_LATENTS,
    "cache_latents_to_disk": False,
    "cache_text_encoder_outputs": CACHE_TEXT_ENCODER,
    "cache_text_encoder_outputs_to_disk": False,
    
    # xformers for memory efficiency
    "xformers": True,
    
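    # Shuffling and augmentation
    # Caption shuffling changes the text every step, so it is disabled when
    # text encoder outputs are cached.
    "shuffle_caption": not CACHE_TEXT_ENCODER,
    "keep_tokens": 1,  # Keep the first comma-separated tag (the trigger phrase) in place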
    "caption_extension": ".txt",
    
    # SDXL specific
    "no_half_vae": True,  # Prevent VAE issues with SDXL
    
    # Logging
    "logging_dir": "/content/logs",
    "log_with": "tensorboard",
    
    # Other
    # clip_skip is omitted: the SDXL training script warns that it has no effect.
    "seed": 42,
    "max_token_length": 225,
}

# Add Prodigy-specific settings
if OPTIMIZER == "Prodigy":
    train_args["optimizer_args"] = [
        "decouple=True",
        "weight_decay=0.01",
        "d_coef=2",
        "use_bias_correction=True",
        "safeguard_warmup=True",
    ]
    # Prodigy works best with LR=1
    train_args["learning_rate"] = 1.0
    train_args["unet_lr"] = 1.0
    train_args["text_encoder_lr"] = 1.0

# Save config
config_path = "/content/training_config.json"
with open(config_path, 'w') as f:
    json.dump(train_args, f, indent=2)

print("Training configuration saved!")
print(f"\nKey settings:")
print(f"  - Base model: {BASE_MODEL}")
print(f"  - Network dim: {NETWORK_DIM}, alpha: {NETWORK_ALPHA}")
print(f"  - Optimizer: {OPTIMIZER}")
print(f"  - Batch size: {BATCH_SIZE}, Gradient accumulation: {GRADIENT_ACCUMULATION}")
print(f"  - Effective batch: {BATCH_SIZE * GRADIENT_ACCUMULATION}")
print(f"  - Epochs: {MAX_TRAIN_EPOCHS}")
print(f"  - Resolution: {RESOLUTION}")

In [None]:
#@title 4.2 Start Training
#@markdown This will train your LoRA. Monitor the loss values: they should trend downward overall, though individual steps are noisy.
#@markdown 
#@markdown **Expected training time on T4:**
#@markdown - ~20 images, 10 epochs: 30-60 minutes
#@markdown - ~30 images, 10 epochs: 45-90 minutes

import subprocess
import json

# Load config
with open('/content/training_config.json', 'r') as f:
    config = json.load(f)

# Build command line arguments
cmd = ["accelerate", "launch", "--num_cpu_threads_per_process=2", "sdxl_train_network.py"]

for key, value in config.items():
    if isinstance(value, bool):
        if value:
            cmd.append(f"--{key}")
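    elif isinstance(value, list):
        # argparse keeps only the last occurrence of a repeated option, so pass
        # all items after a single flag (e.g. --optimizer_args a=1 b=2)
        cmd.append(f"--{key}")
        cmd.extend(str(item) for item in value)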
    else:
        cmd.extend([f"--{key}", str(value)])

# Change to kohya directory
%cd /content/sd-scripts

print("Starting training...")
print("="*60)
print(f"Output will be saved to: {OUTPUT_FOLDER}")
print("="*60 + "\n")

# Run training (shlex.join quotes arguments such as paths containing spaces)
import shlex
!{shlex.join(cmd)}

print("\n" + "="*60)
print("Training complete!")
print("="*60)

In [None]:
#@title 4.3 View Training Logs (TensorBoard)
#@markdown Monitor training progress with TensorBoard.

%load_ext tensorboard
%tensorboard --logdir /content/logs

## 5. Testing Your LoRA

Generate test images using your newly trained LoRA.

In [None]:
#@title 5.1 Load Pipeline with LoRA
#@markdown Load the SDXL pipeline and apply your trained LoRA.

import torch
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
from safetensors.torch import load_file
import os

# Find the latest LoRA file
lora_files = sorted(
    [f for f in os.listdir(OUTPUT_FOLDER) if f.endswith('.safetensors')],
    key=lambda x: os.path.getmtime(os.path.join(OUTPUT_FOLDER, x)),
    reverse=True
)

if not lora_files:
    print("ERROR: No LoRA files found in output folder!")
else:
    LORA_PATH = os.path.join(OUTPUT_FOLDER, lora_files[0])
    print(f"Using LoRA: {lora_files[0]}")
    
    # Load base pipeline
    print("\nLoading SDXL pipeline...")
    pipe = StableDiffusionXLPipeline.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.float16,
        use_safetensors=True,
        variant="fp16"
    ).to("cuda")
    
    # Set scheduler
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(
        pipe.scheduler.config,
        use_karras_sigmas=True
    )
    
    # Load LoRA
    print(f"Loading LoRA from {LORA_PATH}...")
    pipe.load_lora_weights(LORA_PATH)
    
    # Memory optimization
    pipe.enable_xformers_memory_efficient_attention()
    
    print("\nPipeline ready!")

In [None]:
#@title 5.2 Generate Test Images
#@markdown Generate images using your trained LoRA. Remember to include your trigger word!

import matplotlib.pyplot as plt
from datetime import datetime

#@markdown ### Prompt
PROMPT = "ohwx person, professional portrait photo, studio lighting, neutral background" #@param {type:"string"}
NEGATIVE_PROMPT = "deformed, ugly, bad anatomy, disfigured, poorly drawn face, mutation, mutated, extra limb, bad hands, poorly drawn hands, fused fingers, too many fingers" #@param {type:"string"}

#@markdown ### Generation Settings
NUM_IMAGES = 4 #@param {type:"slider", min:1, max:8, step:1}
WIDTH = 1024 #@param {type:"slider", min:512, max:1536, step:64}
HEIGHT = 1024 #@param {type:"slider", min:512, max:1536, step:64}
GUIDANCE_SCALE = 7.5 #@param {type:"slider", min:1, max:20, step:0.5}
NUM_STEPS = 30 #@param {type:"slider", min:10, max:50, step:5}
LORA_SCALE = 0.8 #@param {type:"slider", min:0.1, max:1.5, step:0.1}

# Set LoRA scale
pipe.fuse_lora(lora_scale=LORA_SCALE)

print(f"Generating {NUM_IMAGES} images...")
print(f"Prompt: {PROMPT}")
print(f"LoRA Scale: {LORA_SCALE}\n")

images = []
for i in range(NUM_IMAGES):
    seed = torch.randint(0, 2**32, (1,)).item()
    generator = torch.Generator(device="cuda").manual_seed(seed)
    
    image = pipe(
        prompt=PROMPT,
        negative_prompt=NEGATIVE_PROMPT,
        width=WIDTH,
        height=HEIGHT,
        guidance_scale=GUIDANCE_SCALE,
        num_inference_steps=NUM_STEPS,
        generator=generator,
    ).images[0]
    
    images.append((image, seed))
    print(f"  Generated image {i+1}/{NUM_IMAGES} (seed: {seed})")

# Unfuse for next generation with different scale
pipe.unfuse_lora()

# Display results
cols = min(NUM_IMAGES, 4)
rows = (NUM_IMAGES + cols - 1) // cols
fig, axes = plt.subplots(rows, cols, figsize=(5*cols, 5*rows))

if NUM_IMAGES == 1:
    axes = [[axes]]
elif rows == 1:
    axes = [axes]

for idx, (image, seed) in enumerate(images):
    row, col = idx // cols, idx % cols
    axes[row][col].imshow(image)
    axes[row][col].set_title(f"Seed: {seed}", fontsize=10)
    axes[row][col].axis('off')

# Hide empty subplots
for idx in range(NUM_IMAGES, rows * cols):
    row, col = idx // cols, idx % cols
    axes[row][col].axis('off')

plt.tight_layout()
plt.show()

# Save to Drive
SAVE_TEST_IMAGES = True #@param {type:"boolean"}
if SAVE_TEST_IMAGES:
    test_output_dir = os.path.join(OUTPUT_FOLDER, "test_images")
    os.makedirs(test_output_dir, exist_ok=True)
    
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    for idx, (image, seed) in enumerate(images):
        filename = f"{timestamp}_test_{idx:02d}_seed{seed}.png"
        image.save(os.path.join(test_output_dir, filename))
    
    print(f"\nTest images saved to: {test_output_dir}")

In [None]:
#@title 5.3 Copy LoRA to Final Location
#@markdown Copy your best LoRA to a convenient location in Google Drive.

import shutil

FINAL_LORA_FOLDER = "/content/drive/MyDrive/sdxl_loras" #@param {type:"string"}

os.makedirs(FINAL_LORA_FOLDER, exist_ok=True)

# List available LoRA checkpoints
print("Available LoRA checkpoints:")
lora_files = sorted(
    [f for f in os.listdir(OUTPUT_FOLDER) if f.endswith('.safetensors')]
)
for i, f in enumerate(lora_files):
    size_mb = os.path.getsize(os.path.join(OUTPUT_FOLDER, f)) / (1024*1024)
    print(f"  [{i}] {f} ({size_mb:.1f} MB)")

#@markdown Select which checkpoint to copy (index number from list above)
CHECKPOINT_INDEX = 0 #@param {type:"integer"}

if 0 <= CHECKPOINT_INDEX < len(lora_files):
    src = os.path.join(OUTPUT_FOLDER, lora_files[CHECKPOINT_INDEX])
    dst = os.path.join(FINAL_LORA_FOLDER, lora_files[CHECKPOINT_INDEX])
    
    shutil.copy2(src, dst)
    print(f"\nCopied to: {dst}")
elif not lora_files:
    print("No LoRA checkpoints found in the output folder.")
else:
    print(f"Invalid index. Please choose 0-{len(lora_files)-1}")

## 6. Cleanup

Free up resources when done.

In [None]:
#@title 6.1 Cleanup GPU Memory

import gc
import torch

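# Delete pipeline (it may not exist if section 5 was skipped)
try:
    del pipe
except NameError:
    pass

# Collect garbage first so freed tensors can be released, then clear the CUDA cache
gc.collect()
torch.cuda.empty_cache()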

# Show memory status
!nvidia-smi

print("\nGPU memory cleared!")

---

## Tips for Better Results

### Training Data Quality
- Use **15-30 high-quality images** of the subject
- Include variety: different angles, lighting, expressions, backgrounds
- Avoid low-resolution, blurry, or heavily filtered images
- Crop faces consistently if doing face-focused training
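
One way to catch low-resolution images before training is to flag any file whose shorter side falls below the training resolution. The sketch below is illustrative only: `flag_small_images` is a hypothetical helper, the 1024-pixel threshold is an assumption matching this notebook's default `RESOLUTION`, and the sizes are made-up values that could in practice come from PIL's `Image.open(path).size`.

```python
def flag_small_images(sizes, min_side=1024):
    """Return names whose shorter side is below min_side, sorted for stable output."""
    return sorted(name for name, (w, h) in sizes.items() if min(w, h) < min_side)

# Hypothetical (width, height) values for three training images
sizes = {
    "image001.jpg": (3024, 4032),
    "image002.png": (800, 1200),
    "image003.jpg": (1024, 1024),
}
print(flag_small_images(sizes))
```

Flagged images are not necessarily unusable, since bucketing can still place them, but they are worth reviewing or replacing first.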

### Captioning Tips
- Start each caption with the trigger word (e.g., "ohwx person")
- Be descriptive: "ohwx person, professional headshot, studio lighting, wearing blue suit"
- Include relevant details: clothing, background, lighting, expression
- Be consistent with terminology across captions
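
To check terminology consistency, it can help to count how often each comma-separated tag appears across captions; tags that occur only once are often near-duplicates of a more common wording. This is a minimal sketch with a hypothetical `tag_counts` helper and made-up captions, not a feature of the notebook.

```python
from collections import Counter

def tag_counts(captions):
    """Count how many captions contain each comma-separated tag (case-insensitive)."""
    counts = Counter()
    for caption in captions:
        # Use a set so a tag repeated within one caption is counted once
        tags = {t.strip().lower() for t in caption.split(",") if t.strip()}
        counts.update(tags)
    return counts

captions = [
    "ohwx person, professional headshot, studio lighting, wearing blue suit",
    "ohwx person, outdoor photo, natural light, smiling",
    "ohwx person, headshot, studio lighting",
]
counts = tag_counts(captions)

# Tags used only once are candidates for inconsistent wording
rare = sorted(tag for tag, n in counts.items() if n == 1)
print(counts["ohwx person"], rare)
```

Here "headshot" and "professional headshot" both show up as single-use tags, which suggests settling on one of them.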

### Training Parameters
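- **Network dim 32** is good for likeness; increase to 64 for more detail
- **10-15 epochs** is usually sufficient for likeness
- Lower the **learning rate** (e.g., 1e-5) if you see artifacts; this applies to optimizers other than Prodigy, which this notebook runs at a fixed rate of 1.0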
- Increase **repeats** if you have few images (<15)
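
To make the effect of network dim and alpha more concrete, the sketch below computes two quantities for a single linear layer: the number of parameters a LoRA pair adds, and the `alpha / dim` factor that scales the low-rank update. The function names and the 1280-wide layer are illustrative assumptions, not values read from SDXL or sd-scripts.

```python
def lora_linear_params(in_features, out_features, rank):
    """Parameters added by one LoRA pair: down (rank x in) plus up (out x rank)."""
    return rank * (in_features + out_features)

def lora_update_scale(network_alpha, network_dim):
    """Scaling applied to the low-rank update, computed as alpha / dim."""
    return network_alpha / network_dim

# Hypothetical 1280 -> 1280 linear layer, alpha fixed at 16
for dim in (32, 64):
    print(dim, lora_linear_params(1280, 1280, dim), lora_update_scale(16, dim))
```

Doubling the dim doubles the added parameters, and with alpha left unchanged it halves the update scale, so alpha is often raised together with dim.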

### Using the LoRA
- Always include the trigger word in prompts
- Start with **LoRA scale 0.7-0.8**, adjust as needed
- Higher scale = stronger likeness but may reduce flexibility
- Lower scale = more stylistic freedom but weaker likeness
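
Conceptually, the LoRA scale multiplies the learned low-rank update before it is added to the base weights. The toy example below sketches that idea in plain Python with tiny matrices; `merge_lora` and all numbers are illustrative assumptions, and real pipelines apply the same idea per layer with tensors.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def merge_lora(weight, up, down, alpha, lora_scale):
    """Return weight + lora_scale * (alpha / rank) * (up @ down)."""
    rank = len(down)  # down has shape (rank, in_features)
    factor = lora_scale * alpha / rank
    delta = matmul(up, down)
    return [
        [w + factor * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(weight, delta)
    ]

# Toy 2x2 layer with a rank-1 update
weight = [[1.0, 0.0], [0.0, 1.0]]
up = [[1.0], [2.0]]    # shape (out_features, rank)
down = [[0.5, 0.5]]    # shape (rank, in_features)

for scale in (0.0, 0.8):
    print(scale, merge_lora(weight, up, down, alpha=1.0, lora_scale=scale))
```

At scale 0 the base weights are returned unchanged, and larger scales move the weights further toward what the LoRA learned, which is the trade-off between likeness and flexibility described above.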