# SDXL Image Generator with LoRA & Refiner Support
A memory-efficient Stable Diffusion XL image generator for Google Colab/Jupyter notebooks
with 8-bit quantization, LoRA model support, and optional refiner enhancement.

## Features
- SDXL Base 1.0 with 8-bit quantization for low VRAM usage
- Optional SDXL Refiner for enhanced details
- LoRA model support for style customization
- Gradio web interface with share link
- Optimized for Google Colab free tier

## Requirements
- GPU with at least 12GB VRAM (T4 or better)
- Python 3.8+

## 1. Environment Setup
Run this cell first to install all required dependencies. This will take 2-3 minutes on first run.

In [None]:
# Check GPU availability
import subprocess
import sys

def check_gpu(command=("nvidia-smi",)):
    """Verify that an NVIDIA GPU is visible, exiting the process if not.

    Runs ``command`` (``nvidia-smi`` by default) and prints a short excerpt of
    its output when it succeeds.

    Args:
        command: Argument vector of the probe to run. Defaults to
            ``("nvidia-smi",)``; overridable mainly for testing.

    Exits:
        ``sys.exit(1)`` if the probe binary cannot be started, or if it exits
        with a non-zero status. Exactly one error message is printed in
        either case.
    """
    try:
        result = subprocess.run(list(command), capture_output=True, text=True)
    except OSError:
        # Binary missing or not executable (e.g. a CPU-only runtime). This is
        # caught narrowly on purpose: a bare ``except:`` would also swallow the
        # SystemExit raised below and print a second, misleading message.
        print("❌ nvidia-smi not found. Please enable GPU in Runtime > Change runtime type")
        sys.exit(1)

    if result.returncode != 0:
        print("❌ No GPU detected. This notebook requires a GPU.")
        sys.exit(1)

    print("✅ GPU detected!")
    # Lines 8-10 of the output; presumably the per-GPU summary rows of the
    # default nvidia-smi table. Printed as lines, not as a Python list repr.
    print("\n".join(result.stdout.split('\n')[8:11]))

check_gpu()

# Install dependencies
print("\n📦 Installing dependencies...")
!pip install -q --upgrade pip
!pip cache purge -q

# Install packages - let pip resolve torch version to avoid conflicts
!pip install -q \
    gradio==4.44.1 \
    diffusers==0.29.1 \
    transformers==4.41.0 \
    accelerate==0.29.0 \
    huggingface_hub==0.25.0 \
    safetensors==0.4.2 \
    xformers \
    bitsandbytes==0.42.0 \
    websockets>=13.0 \
    pillow \
    tqdm

# Verify torch installation
import torch
print(f"\n✅ Torch version: {torch.__version__}")

print("✅ Dependencies installed successfully!")

## 2. Imports and Configuration
Import required libraries and set up configuration.

In [None]:
import os
import gc
import torch
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
from transformers import BitsAndBytesConfig
import gradio as gr
from PIL import Image
import warnings
warnings.filterwarnings('ignore')

# Memory cleanup function
def cleanup_memory():
    """Free unreferenced Python objects, then release cached CUDA memory.

    Garbage collection runs first, so tensors that are no longer referenced
    are freed. Only on machines where CUDA is available is PyTorch then asked
    to release its unoccupied cached GPU memory.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()

# Summarise the runtime (torch build, CUDA, and GPU details when present).
_env_lines = [
    f"🔧 PyTorch version: {torch.__version__}",
    f"🔧 CUDA available: {torch.cuda.is_available()}",
]
if torch.cuda.is_available():
    _env_lines += [
        f"🔧 GPU: {torch.cuda.get_device_name(0)}",
        f"🔧 Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB",
    ]
print("\n".join(_env_lines))

## 3. Model Loading
Load SDXL base model with 8-bit quantization. This will download ~6.5GB on first run and take 3-5 minutes.

In [None]:
print("\n🚀 Loading SDXL base model...")

# bitsandbytes 8-bit config, shared with load_refiner().
# Only parameters that BitsAndBytesConfig defines are passed. Compute dtype
# and double quantization exist only for 4-bit (as ``bnb_4bit_*``), so the
# former ``bnb_8bit_compute_dtype`` / ``bnb_8bit_use_double_quant`` were
# ignored with an "unused kwargs" warning.
# NOTE(review): diffusers 0.29.1's DiffusionPipeline.from_pretrained does not
# appear to consume ``quantization_config`` for its components (bitsandbytes
# support in diffusers arrived in later releases). If so, the weights load in
# plain fp16 regardless. Confirm by checking the UNet parameter dtypes / VRAM.
bnb_config = BitsAndBytesConfig(load_in_8bit=True)

# Load SDXL base pipeline
try:
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16,
        device_map="balanced",
        quantization_config=bnb_config,
        use_safetensors=True,
        variant="fp16"
    )

    # xformers is optional. It is installed unpinned, so it may not match the
    # runtime's torch/CUDA build. If enabling it fails, keep diffusers'
    # default attention processor instead of aborting the whole model load.
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception as xformers_error:
        print(f"⚠️ xformers unavailable, using default attention: {xformers_error}")

    # Decode latents in slices/tiles to reduce VAE peak memory.
    pipe.enable_vae_slicing()
    pipe.enable_vae_tiling()

    print("✅ SDXL base model loaded successfully!")
    cleanup_memory()

except Exception as e:
    print(f"❌ Error loading model: {e}")
    raise

# Refiner placeholder (loaded on demand by load_refiner)
refiner = None

def load_refiner():
    """Load the SDXL refiner into the global ``refiner`` on first use.

    The refiner is only needed when the user enables it, so it is loaded
    lazily to keep VRAM free for the base model. Once loaded, later calls are
    no-ops.

    The global is assigned only after the pipeline is fully configured. A
    failed download/load is printed and re-raised, leaving ``refiner`` as
    ``None`` so the next call retries.
    """
    global refiner
    if refiner is not None:
        return

    print("🚀 Loading SDXL refiner...")
    try:
        loaded = StableDiffusionXLImg2ImgPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-refiner-1.0",
            torch_dtype=torch.float16,
            device_map="balanced",
            quantization_config=bnb_config,
            use_safetensors=True,
            variant="fp16"
        )

        # xformers is optional; an incompatible build should not make the
        # refiner unusable, so fall back to the default attention processor.
        try:
            loaded.enable_xformers_memory_efficient_attention()
        except Exception as xformers_error:
            print(f"⚠️ xformers unavailable for refiner, using default attention: {xformers_error}")

        loaded.enable_vae_slicing()
        loaded.enable_vae_tiling()
    except Exception as e:
        print(f"❌ Error loading refiner: {e}")
        raise

    refiner = loaded
    print("✅ SDXL refiner loaded successfully!")

## 4. Image Generation Functions
Core image generation logic with LoRA and refiner support.

In [None]:
# Track loaded LoRA to avoid reloading
current_lora_id = ""

# Adapter name under which the LoRA is registered with diffusers' PEFT
# integration. It is passed explicitly to ``load_lora_weights`` so that
# ``set_adapters`` can refer to it. Without an explicit name diffusers
# generates its own (e.g. "default_0"), and a hard-coded "default" in
# ``set_adapters`` names an adapter that does not exist.
_LORA_ADAPTER_NAME = "style_lora"


def _make_generator(seed):
    """Return a CUDA ``torch.Generator`` seeded with *seed*, or ``None``.

    ``None`` (an emptied Number box) or any negative value (the UI uses -1)
    means "random", signalled to diffusers by passing no generator.
    ``gr.Number`` returns floats such as ``42.0`` unless ``precision=0`` is
    set, and ``Generator.manual_seed`` accepts only integers, so the value is
    converted with ``int`` first.
    """
    if seed is None or int(seed) < 0:
        return None
    return torch.Generator(device="cuda").manual_seed(int(seed))


def _unload_lora():
    """Remove the currently loaded LoRA (if any) from the base pipeline."""
    global current_lora_id
    if not current_lora_id:
        return
    print(f"🔄 Unloading LoRA: {current_lora_id}")
    pipe.unload_lora_weights()
    current_lora_id = ""
    cleanup_memory()


def _sync_lora(lora_model, lora_scale):
    """Make the base pipeline's LoRA state match the request.

    Args:
        lora_model: Hugging Face model ID of the LoRA, as typed in the UI.
            Empty/``None`` means "no LoRA": any previously loaded LoRA is
            unloaded so its style does not leak into later images.
        lora_scale: Adapter weight applied to the loaded LoRA.

    A failed load is reported but not fatal (the notebook's existing
    best-effort behaviour): generation continues without a LoRA.
    """
    global current_lora_id
    requested = (lora_model or "").strip()

    if requested != current_lora_id:
        _unload_lora()
        if requested:
            print(f"📥 Loading LoRA: {requested}")
            try:
                pipe.load_lora_weights(requested, adapter_name=_LORA_ADAPTER_NAME)
                current_lora_id = requested
            except Exception as e:
                print(f"⚠️ LoRA loading failed: {e}")
                gr.Warning(f"LoRA loading failed: {e}")
                # Best-effort cleanup of any adapter layers injected before
                # the failure.
                pipe.unload_lora_weights()
                current_lora_id = ""

    if current_lora_id:
        pipe.set_adapters([_LORA_ADAPTER_NAME], adapter_weights=[float(lora_scale)])


def generate_image(
    prompt,
    negative_prompt="",
    num_steps=30,
    guidance_scale=7.5,
    width=1024,
    height=1024,
    seed=-1,
    use_refiner=False,
    refiner_steps=10,
    refiner_strength=0.3,
    lora_model="",
    lora_scale=0.5,
    advanced_mode=False
):
    """
    Generate an image using SDXL with optional LoRA and refiner.

    Args:
        prompt, negative_prompt: Text conditioning for both pipelines.
        num_steps, guidance_scale: Base-pipeline sampling settings.
        width, height: Output size in pixels (converted to ``int``).
        seed: Integer seed; ``-1`` (any negative) or ``None`` means random.
        use_refiner, refiner_steps, refiner_strength: Refiner pass settings.
            Used only when ``advanced_mode`` is true.
        lora_model, lora_scale: LoRA model ID and weight. Used only when
            ``advanced_mode`` is true; otherwise any loaded LoRA is unloaded.
        advanced_mode: Gate for the refiner and LoRA options.

    Returns:
        The generated ``PIL.Image``.

    Raises:
        gr.Error: if generation fails, so Gradio shows the failure message
            in the UI.
    """
    try:
        generator = _make_generator(seed)

        # The LoRA box is hidden outside advanced mode, so treat it as empty
        # there. This unloads a LoRA left over from an earlier run.
        _sync_lora(lora_model if advanced_mode else "", lora_scale)

        # Generate base image
        print("🎨 Generating image...")
        image = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=int(num_steps),
            guidance_scale=float(guidance_scale),
            width=int(width),
            height=int(height),
            generator=generator
        ).images[0]

        # Apply refiner if requested
        if advanced_mode and use_refiner:
            load_refiner()
            print("✨ Applying refiner...")
            image = refiner(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=image,
                num_inference_steps=int(refiner_steps),
                strength=float(refiner_strength),
                generator=generator
            ).images[0]

        return image

    except Exception as e:
        print(f"❌ Generation error: {str(e)}")
        raise gr.Error(f"Generation failed: {e}") from e
    finally:
        cleanup_memory()

## 5. Gradio Interface
Create the web UI using Gradio.

In [None]:
# CSS for better styling
custom_css = """
    .container {max-width: 1200px; margin: auto; padding: 20px;}
    .generate-btn {background: linear-gradient(90deg, #4CAF50 0%, #45a049 100%); color: white;}
    .generate-btn:hover {background: linear-gradient(90deg, #45a049 0%, #4CAF50 100%);}
"""

# Build the interface. Every input in the click handler's list maps
# positionally onto generate_image's parameters.
with gr.Blocks(css=custom_css, title="SDXL Generator") as demo:
    gr.Markdown("""
    # 🎨 SDXL Image Generator
    ### Generate stunning images with Stable Diffusion XL
    
    **Features:** 8-bit quantization for low VRAM • LoRA support • Optional refiner • Optimized for Colab
    """)
    
    with gr.Row():
        # Left column - Controls
        with gr.Column(scale=1):
            # Basic controls
            with gr.Group():
                gr.Markdown("### Basic Settings")
                prompt_input = gr.Textbox(
                    lines=3,
                    label="Prompt",
                    placeholder="Describe what you want to generate...\nBe specific about style, lighting, and details.",
                    value=""
                )
                
                negative_prompt = gr.Textbox(
                    lines=2,
                    label="Negative Prompt",
                    placeholder="What to avoid: blurry, low quality, distorted...",
                    value="blurry, low quality, ugly, distorted"
                )
                
                with gr.Row():
                    steps_slider = gr.Slider(
                        minimum=10, maximum=50, value=25, step=1,
                        label="Steps", info="More steps = better quality but slower"
                    )
                    guidance_slider = gr.Slider(
                        minimum=1.0, maximum=20.0, value=7.5, step=0.5,
                        label="Guidance Scale", info="How closely to follow prompt"
                    )
                
                with gr.Row():
                    width_slider = gr.Slider(
                        minimum=512, maximum=2048, value=1024, step=64,
                        label="Width"
                    )
                    height_slider = gr.Slider(
                        minimum=512, maximum=2048, value=1024, step=64,
                        label="Height"
                    )
                
                # precision=0 makes Gradio deliver an int. Without it the
                # handler receives a float (e.g. 42.0), which
                # torch.Generator.manual_seed rejects.
                seed_input = gr.Number(
                    value=-1, label="Seed", precision=0,
                    info="Use -1 for random, or specific number for reproducibility"
                )
            
            # Advanced settings
            advanced_mode = gr.Checkbox(label="🔧 Enable Advanced Mode", value=False)
            
            with gr.Group(visible=False) as advanced_group:
                gr.Markdown("### Advanced Settings")
                
                # Refiner settings
                with gr.Accordion("✨ Refiner Settings", open=True):
                    use_refiner = gr.Checkbox(label="Use Refiner", value=False)
                    refiner_steps = gr.Slider(
                        minimum=5, maximum=30, value=10, step=1,
                        label="Refiner Steps"
                    )
                    refiner_strength = gr.Slider(
                        minimum=0.1, maximum=0.5, value=0.3, step=0.05,
                        label="Refiner Strength", 
                        info="How much the refiner changes the image"
                    )
                
                # LoRA settings
                with gr.Accordion("🎭 LoRA Settings", open=True):
                    lora_model = gr.Textbox(
                        label="LoRA Model ID",
                        placeholder="e.g., TheLastBen/Papercut_SDXL",
                        info="HuggingFace model ID for style LoRAs"
                    )
                    lora_scale = gr.Slider(
                        minimum=0.0, maximum=1.5, value=0.7, step=0.1,
                        label="LoRA Weight", 
                        info="Strength of LoRA influence"
                    )
            
            generate_btn = gr.Button("🚀 Generate Image", variant="primary", elem_classes="generate-btn")
        
        # Right column - Output
        with gr.Column(scale=1):
            output_image = gr.Image(
                label="Generated Image",
                type="pil",
                interactive=False
            )
            
            with gr.Row():
                gr.Markdown("""
                **Tips:**
                - Start with 25-30 steps for good quality
                - Use negative prompts to avoid unwanted elements
                - LoRA models can dramatically change the style
                - Refiner adds details but increases generation time
                """)
    
    # Examples (fill only the basic-settings inputs)
    with gr.Row():
        gr.Examples(
            examples=[
                [
                    "A majestic lion in African savanna at golden hour, photorealistic, detailed fur, 8k quality",
                    "blurry, cartoon, anime, low quality",
                    25, 7.5, 1024, 1024, 42
                ],
                [
                    "Fantasy castle floating in clouds, magical atmosphere, detailed architecture, trending on artstation",
                    "modern, contemporary, cars, people",
                    30, 8.0, 1024, 1024, 123
                ],
                [
                    "Cyberpunk street scene, neon lights, rain, detailed, blade runner style, cinematic",
                    "daylight, sunny, medieval, ancient",
                    30, 7.0, 1344, 768, 456
                ]
            ],
            inputs=[prompt_input, negative_prompt, steps_slider, guidance_slider, 
                   width_slider, height_slider, seed_input],
            label="Example Prompts"
        )
    
    # Event handlers: show/hide the advanced group with the checkbox
    advanced_mode.change(
        fn=lambda x: gr.update(visible=x),
        inputs=advanced_mode,
        outputs=advanced_group
    )
    
    generate_btn.click(
        fn=generate_image,
        inputs=[
            prompt_input, negative_prompt,
            steps_slider, guidance_slider,
            width_slider, height_slider, seed_input,
            use_refiner, refiner_steps, refiner_strength,
            lora_model, lora_scale,
            advanced_mode
        ],
        outputs=output_image
    )

## 6. Launch the App
Launch the Gradio interface with a public share link. The share link expires after 72 hours.

In [None]:
print("\n🌐 Launching Gradio interface...")
print("📝 Note: Share links expire after 72 hours")
print("💡 Tip: For faster generation, reduce steps or image size\n")

# Re-running this cell (or the interface cell) leaves the previous app bound
# to port 7860. With an explicit server_port, Gradio only tries that single
# port, so the relaunch would fail. Close any running Gradio apps first.
gr.close_all()

# Launch with share link
demo.launch(
    share=True,
    debug=False,
    show_error=True,
    server_name="0.0.0.0",
    server_port=7860,
    quiet=False
)

## 📝 Popular LoRA Models

- `TheLastBen/Papercut_SDXL` - Paper cut art style
- `artificialguybr/LogoRedmond-LogoLoraForSDXL` - Logo design
- `KappaNeuro/studio-ghibli-style-sdxl` - Studio Ghibli style
- `alvdansen/frosting_lane_redux_sdxl` - Vintage photography

## ⚠️ Troubleshooting

- **Out of Memory**: Reduce the image size or disable the refiner (the number of steps affects generation time, not peak memory)
- **LoRA Not Loading**: Check that the model ID is correct on Hugging Face and that the LoRA was trained for SDXL
- **Slow Generation**: Normal for SDXL; 30-60 seconds per image on T4