## Cell 1: Install Required Libraries
Note: If you've already installed the required libraries using requirements.txt, you can skip running this cell.

In [None]:
# Install the libraries listed in requirements.txt.
# Skip this cell (or comment out the command below) if they are already installed.

!pip install -r requirements.txt


## Cell 2: Import Libraries and Load Configurations

In [None]:
import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from PIL import Image
import matplotlib.pyplot as plt
import os
import datetime
from realesrgan import RealESRGAN
from tqdm.notebook import tqdm
import json
import sys

# Enable inline plotting
%matplotlib inline

# For progress bars in Jupyter
from tqdm.notebook import tqdm

# Ensure that the config.json is in the same directory as the notebook
CONFIG_PATH = "config.json"

# Load configuration
try:
    with open(CONFIG_PATH, 'r') as f:
        config = json.load(f)
    print("Configuration loaded successfully.")
except FileNotFoundError:
    print(f"Configuration file '{CONFIG_PATH}' not found. Please ensure it exists in the notebook directory.")
    sys.exit(1)
except json.JSONDecodeError as e:
    print(f"Error decoding JSON from '{CONFIG_PATH}': {e}")
    sys.exit(1)

# Extract configurations
models_config = config.get("models", {})
stages_config = config.get("stages", {})
controlnet_config = config.get("controlnet", {})
image_savers_config = config.get("image_savers", {})


## Cell 3: Define Utility Functions for Image Preview and Saving

In [None]:
def preview_image(image, title="Image"):
    """Show ``image`` in a new 8x8-inch matplotlib figure, titled and without axes.

    Args:
        image: Anything ``imshow`` accepts (e.g. a PIL image or array).
        title: Caption drawn above the image.
    """
    _, axes = plt.subplots(figsize=(8, 8))
    axes.imshow(image)
    axes.set_title(title)
    axes.axis("off")
    plt.show()

def save_image(image, stage, savers_config=None):
    """Save ``image`` under the directory configured for ``stage``.

    The saver entry must provide ``path`` (a template that may contain a
    ``{date}`` placeholder, filled with YYYY-MM-DD) and ``extension``. The file
    is named ``<stage>_<HH-MM-SS>.<extension>``; if that name already exists
    (two saves within one second), a ``_1``, ``_2``, ... suffix is added instead
    of overwriting.

    Args:
        image: Object with a ``save(path)`` method (e.g. a PIL image).
        stage: Key of the saver entry to use.
        savers_config: Mapping of stage -> saver settings. Defaults to the
            module-level ``image_savers_config``.

    Returns:
        The full path written, or ``None`` if the configuration was missing or
        saving failed (the problem is printed, not raised).
    """
    if savers_config is None:
        savers_config = image_savers_config

    # Take one timestamp so the date folder and file time always agree.
    now = datetime.datetime.now()

    try:
        saver = savers_config[stage]
        save_path_template = saver["path"]
        # Accept both "png" and ".png" in the configuration.
        extension = str(saver["extension"]).lstrip(".")
    except KeyError as e:
        print(f"Missing configuration for image saver '{stage}': {e}")
        return None

    try:
        # Create directories based on date if they don't exist
        save_path = save_path_template.format(date=now.strftime("%Y-%m-%d"))
        os.makedirs(save_path, exist_ok=True)

        # Generate filename with timestamp, de-duplicating same-second saves.
        stem = f"{stage}_{now.strftime('%H-%M-%S')}"
        full_path = os.path.join(save_path, f"{stem}.{extension}")
        counter = 1
        while os.path.exists(full_path):
            full_path = os.path.join(save_path, f"{stem}_{counter}.{extension}")
            counter += 1

        image.save(full_path)
    except Exception as e:
        print(f"Error saving image for stage '{stage}': {e}")
        return None

    print(f"Image saved at {full_path}")
    return full_path


## Cell 4: Load ControlNet Models

In [None]:
def load_controlnets(controlnet_paths):
    """Load every ControlNet checkpoint in float16, skipping those that fail.

    Args:
        controlnet_paths: Iterable of model ids / local paths for
            ``ControlNetModel.from_pretrained``.

    Returns:
        The successfully loaded models, in input order. Failures are printed
        and omitted, so the list may be shorter than ``controlnet_paths``.
    """
    loaded_models = []
    for model_path in controlnet_paths:
        try:
            model = ControlNetModel.from_pretrained(model_path, torch_dtype=torch.float16)
        except Exception as err:
            print(f"Error loading ControlNet model from {model_path}: {err}")
            continue
        loaded_models.append(model)
        print(f"Loaded ControlNet model from {model_path}")
    return loaded_models

controlnet_model_paths = models_config.get("controlnet_model_paths", [])
if not controlnet_model_paths:
    print("No ControlNet model paths found in configuration.")
    sys.exit(1)

controlnets = load_controlnets(controlnet_model_paths)
if not controlnets:
    print("Failed to load any ControlNet models. Exiting.")
    sys.exit(1)

# load_controlnets drops models that fail to load. Control images are later
# matched to models by list position, so a gap shifts every later pairing.
if len(controlnets) != len(controlnet_model_paths):
    print(
        f"Warning: only {len(controlnets)} of {len(controlnet_model_paths)} ControlNet models loaded; "
        "control images are matched to models by position and may now be misaligned."
    )


## Cell 5: Load Stable Diffusion Pipeline

In [None]:
def load_pipeline(base_model_path, controlnets):
    """Build the float16 Stable Diffusion ControlNet pipeline on CUDA.

    The safety checker is disabled, and the scheduler is swapped for
    ``UniPCMultistepScheduler`` built from the original scheduler's config.

    Args:
        base_model_path: Model id or local path for ``from_pretrained``.
        controlnets: ControlNet models (e.g. from ``load_controlnets``).

    Returns:
        The configured pipeline. If the pipeline itself cannot be loaded the
        error is printed and ``sys.exit(1)`` is called; a scheduler failure is
        only printed and the original scheduler is kept.
    """
    try:
        pipeline = StableDiffusionControlNetPipeline.from_pretrained(
            base_model_path,
            controlnet=controlnets,
            torch_dtype=torch.float16,
            safety_checker=None,  # intentionally disabled
        ).to("cuda")
        print("Loaded Stable Diffusion ControlNet Pipeline.")
    except Exception as err:
        print(f"Error loading Stable Diffusion Pipeline: {err}")
        sys.exit(1)

    try:
        pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
    except Exception as err:
        print(f"Error setting scheduler: {err}")
    else:
        print("Set scheduler to UniPCMultistepScheduler.")

    return pipeline

# The base checkpoint is mandatory: stop the workflow when the config omits it.
base_model_path = models_config.get("base_model_path", "")
if not base_model_path:
    print("Base model path not specified in configuration.")
    sys.exit(1)

pipe = load_pipeline(base_model_path=base_model_path, controlnets=controlnets)


## Cell 6: Define Image Generation Functions

In [None]:
def generate_image(pipe, stage_config, control_images, seed):
    """Generate one image from a text prompt, conditioned on ControlNet inputs.

    Args:
        pipe: The ``StableDiffusionControlNetPipeline`` from ``load_pipeline``.
        stage_config: Stage settings. Reads ``prompt``, ``negative_prompt``,
            ``steps`` (default 30), ``width``/``height`` (default 512) and
            ``cfg_scale`` (default 7.0). ``sampler``/``scheduler`` entries are
            not applied by this function.
        control_images: Control images matched to the pipeline's ControlNet
            models by list position. ``None`` entries (and missing trailing
            entries) disable the corresponding ControlNet; extra entries are
            ignored.
        seed: Seed for the CUDA ``torch.Generator``.

    Returns:
        The first generated image, or ``None`` if generation failed (the error
        is printed).
    """
    prompt = stage_config.get("prompt", "")
    negative_prompt = stage_config.get("negative_prompt", "")
    steps = stage_config.get("steps", 30)
    width = stage_config.get("width", 512)
    height = stage_config.get("height", 512)
    cfg_scale = stage_config.get("cfg_scale", 7.0)
    # diffusers type-checks the conditioning scale as a float.
    strength = float(controlnet_config.get("strength", 1.0))

    generator = torch.Generator("cuda").manual_seed(seed)

    try:
        # Per the diffusers docs, the pipeline has no ``controlnet=`` call
        # argument: control images go in ``image`` and per-model weights in
        # ``controlnet_conditioning_scale``. A list of ControlNets is wrapped in
        # a MultiControlNetModel (models in ``.nets``, not iterable itself),
        # which needs exactly one image and one scale per model.
        nets = getattr(pipe.controlnet, "nets", None)
        num_nets = 1 if nets is None else len(nets)
        slots = list(control_images or [])[:num_nets]
        slots += [None] * (num_nets - len(slots))

        cond_images, cond_scales = [], []
        for control_image in slots:
            if control_image is None:
                print("Warning: One of the ControlNet control images is None. Disabling this ControlNet input.")
                # A blank image at scale 0.0 keeps the positional pairing
                # intact while contributing nothing to the result.
                cond_images.append(Image.new("RGB", (width, height)))
                cond_scales.append(0.0)
            else:
                cond_images.append(control_image)
                cond_scales.append(strength)
        if nets is None:
            # A single ControlNet takes one image and a float scale.
            cond_images, cond_scales = cond_images[0], cond_scales[0]

        with torch.autocast("cuda"):
            output = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=cond_images,
                controlnet_conditioning_scale=cond_scales,
                num_inference_steps=steps,
                guidance_scale=cfg_scale,
                width=width,
                height=height,
                generator=generator,
            )
        image = output.images[0]
    except Exception as e:
        print(f"Error during image generation: {e}")
        return None

    print("Image generation completed.")
    return image


## Cell 7: Define Upscaling Function

In [None]:
def upscale_image(image, stage_config):
    """Upscale ``image`` with Real-ESRGAN, returning the input unchanged on failure.

    Args:
        image: Image handed to the upscaler's ``predict``.
        stage_config: Stage settings; reads ``scale`` (default 2). ``denoise``
            and ``tiles`` are reported and ignored (see note).

    Returns:
        The upscaled image, or the very same ``image`` object when upscaling is
        skipped (no ``realesrgan_model_path`` in ``models_config``) or fails.

    NOTE(review): the ``RealESRGAN(device, scale)`` / ``load_weights`` /
    ``predict`` API used here matches the ai-forever Real-ESRGAN wrapper, whose
    ``predict`` takes no ``denoise``/``tile`` keywords; passing them raises a
    TypeError, which previously made upscaling fail silently every time. That
    wrapper is usually imported as ``from RealESRGAN import RealESRGAN`` —
    confirm which package is actually installed.
    """
    model_path = models_config.get("realesrgan_model_path", "")
    if not model_path:
        print("Real-ESRGAN model path not specified in configuration. Skipping upscaling.")
        return image

    ignored = [key for key in ("denoise", "tiles") if key in stage_config]
    if ignored:
        print(f"Note: Real-ESRGAN upscaling does not use {', '.join(ignored)}; ignoring.")

    try:
        # Real-ESRGAN checkpoints use integer factors; coerce values such as 2.0
        # (assumed requirement of the wrapper — confirm).
        scale = int(stage_config.get("scale", 2))
        upsampler = RealESRGAN(device='cuda', scale=scale)
        upsampler.load_weights(model_path)
        print("Loaded Real-ESRGAN for upscaling.")
    except Exception as e:
        print(f"Error loading Real-ESRGAN model from {model_path}: {e}")
        return image  # Return original image if upsampler fails

    try:
        upscaled_image = upsampler.predict(image)
    except Exception as e:
        print(f"Error during upscaling: {e}")
        return image  # Return original image if upscaling fails

    print("Upscaling completed.")
    return upscaled_image


## Cell 8: Define AfterDetailer Function

Note: An AfterDetailer step normally masks specific areas and refines only those. This implementation does not perform any masking. The `mask_prompt` and `mask_detection_area` settings are not applied, so the refinement pass covers the whole image.

In [None]:
def after_detailer(pipe, stage_config, control_images, seed, base_image):
    """Refine ``base_image`` with a ControlNet img2img pass.

    ``StableDiffusionControlNetPipeline`` is text-to-image: its ``image``
    argument is the ControlNet conditioning input, not an init image. This
    function therefore wraps the already-loaded components of ``pipe`` in a
    ``StableDiffusionControlNetImg2ImgPipeline`` (no weights are reloaded) and
    re-noises ``base_image`` by ``denoise`` before denoising it again.

    Masking is not implemented: ``mask_prompt`` / ``mask_detection_area`` are
    not used, and the whole image is refined.

    Args:
        pipe: Pipeline from ``load_pipeline``.
        stage_config: Stage settings. Reads ``prompt``, ``negative_prompt``,
            ``steps`` (default 30), ``denoise`` (img2img strength, default 0.4)
            and ``cfg_scale`` (default 6.0). ``sampler``/``scheduler`` entries
            are not applied.
        control_images: Control images matched to the ControlNet models by
            list position; ``None`` (or missing) entries disable that model.
        seed: Seed for the CUDA ``torch.Generator``.
        base_image: Image to refine; presumably a PIL image (its ``.size`` is
            read). The output size is rounded down to a multiple of 8.

    Returns:
        The refined image, or ``base_image`` itself if refinement fails (the
        error is printed).
    """
    prompt = stage_config.get("prompt", "")
    negative_prompt = stage_config.get("negative_prompt", "")
    steps = stage_config.get("steps", 30)
    denoise = stage_config.get("denoise", 0.4)
    cfg_scale = stage_config.get("cfg_scale", 6.0)  # previously hard-coded 6.0
    # diffusers type-checks the conditioning scale as a float.
    strength = float(controlnet_config.get("strength", 1.0))

    generator = torch.Generator("cuda").manual_seed(seed)

    try:
        # Imported here because only this stage needs the img2img variant.
        from diffusers import StableDiffusionControlNetImg2ImgPipeline

        # diffusers' Stable Diffusion pipelines require sizes divisible by 8.
        width, height = (max(8, dim - dim % 8) for dim in base_image.size)

        # One control image and one scale per ControlNet; a list of ControlNets
        # is wrapped in a MultiControlNetModel whose models live in ``.nets``.
        nets = getattr(pipe.controlnet, "nets", None)
        num_nets = 1 if nets is None else len(nets)
        slots = list(control_images or [])[:num_nets]
        slots += [None] * (num_nets - len(slots))

        cond_images, cond_scales = [], []
        for control_image in slots:
            if control_image is None:
                print("Warning: One of the ControlNet control images is None. Disabling this ControlNet input.")
                # Blank image at scale 0.0: keeps pairing, adds no guidance.
                cond_images.append(Image.new("RGB", (width, height)))
                cond_scales.append(0.0)
            else:
                cond_images.append(control_image)
                cond_scales.append(strength)
        if nets is None:
            cond_images, cond_scales = cond_images[0], cond_scales[0]

        refiner = StableDiffusionControlNetImg2ImgPipeline(**pipe.components)
        with torch.autocast("cuda"):
            output = refiner(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=base_image,
                control_image=cond_images,
                controlnet_conditioning_scale=cond_scales,
                strength=denoise,
                num_inference_steps=steps,
                guidance_scale=cfg_scale,
                width=width,
                height=height,
                generator=generator,
            )
        refined_image = output.images[0]
    except Exception as e:
        print(f"Error during AfterDetailer refinement: {e}")
        return base_image  # Return base image if refinement fails

    print("AfterDetailer refinement completed.")
    return refined_image


## Cell 9: Load ControlNet Images

In [None]:
def load_control_images(stage, control_image_paths=None):
    """Load ControlNet control images as RGB, one entry per path.

    Args:
        stage: Name of the stage the images are for. Currently not used to
            choose files; kept for call-site readability and compatibility.
        control_image_paths: Paths to load, in ControlNet order. Defaults to
            ``control_image_1.png`` .. ``control_image_3.png`` in the working
            directory.

    Returns:
        A list the same length as the path list; entries that could not be
        loaded are ``None`` (the reason is printed).
    """
    if control_image_paths is None:
        # Adjust these paths if your control images are stored elsewhere
        control_image_paths = [
            'control_image_1.png',
            'control_image_2.png',
            'control_image_3.png'
        ]

    loaded_images = []
    for path in control_image_paths:
        try:
            # The context manager closes the file; convert() returns an
            # independent, fully loaded copy.
            with Image.open(path) as src:
                img = src.convert("RGB")
            loaded_images.append(img)
            print(f"Loaded control image: {path}")
        except FileNotFoundError:
            print(f"Control image '{path}' not found. Appending None.")
            loaded_images.append(None)
        except Exception as e:
            print(f"Error loading control image '{path}': {e}. Appending None.")
            loaded_images.append(None)
    return loaded_images

# TXT2IMG: read the control images from disk.
print("Loading ControlNet images for TXT2IMG...")
control_images_txt2img = load_control_images("txt2img")

# Upscale: nothing is read; this stage reuses a shallow copy of the TXT2IMG images.
print("Loading ControlNet images for Upscale...")
control_images_upscale = list(control_images_txt2img)

# AfterDetailer: calls the loader again with its default path list (the same
# files as TXT2IMG, since the stage name does not select different files).
print("Loading ControlNet images for AfterDetailer...")
control_images_after_detailer = load_control_images("after_detailer")


## Cell 10: Execute Workflow

In [None]:
# TXT2IMG Stage
print("=== TXT2IMG Generation ===")
txt2img_stage = stages_config.get("txt2img", {})
if not txt2img_stage:
    print("TXT2IMG stage configuration not found.")
    sys.exit(1)

base_image = generate_image(
    pipe=pipe,
    stage_config=txt2img_stage,
    control_images=control_images_txt2img.copy(),
    seed=txt2img_stage.get("seed", 42),
)

# generate_image signals failure with None; compare explicitly instead of
# relying on the image object's truthiness.
if base_image is None:
    print("TXT2IMG generation failed. Exiting workflow.")
    sys.exit(1)
preview_image(base_image, "Base Image")
save_image(base_image, "txt2img")

# Upscaling Stage
print("\n=== Upscaling ===")
upscale_stage = stages_config.get("upscale", {})
if not upscale_stage:
    print("Upscale stage configuration not found.")
    sys.exit(1)

upscaled_image = upscale_image(
    image=base_image,
    stage_config=upscale_stage,
)

# upscale_image hands back its input object unchanged when it skips or fails,
# so an identical object means nothing was upscaled; don't save it as a result.
if upscaled_image is None or upscaled_image is base_image:
    print("Upscaling failed or was skipped. Proceeding with base image.")
    upscaled_image = base_image
else:
    preview_image(upscaled_image, "Upscaled Image")
    save_image(upscaled_image, "upscale")

# AfterDetailer Stage
print("\n=== AfterDetailer Refinement ===")
after_detailer_stage = stages_config.get("after_detailer", {})
if not after_detailer_stage:
    print("AfterDetailer stage configuration not found.")
    sys.exit(1)

detailed_image = after_detailer(
    pipe=pipe,
    stage_config=after_detailer_stage,
    control_images=control_images_after_detailer.copy(),
    seed=after_detailer_stage.get("seed", 42),
    base_image=upscaled_image,
)

# after_detailer likewise returns its base_image argument when refinement fails.
refinement_ok = detailed_image is not None and detailed_image is not upscaled_image
if refinement_ok:
    preview_image(detailed_image, "Detailed Image")
    save_image(detailed_image, "after_detailer")
else:
    print("AfterDetailer refinement failed.")

if refinement_ok:
    print("\n=== Workflow Completed Successfully! ===")
else:
    print("\n=== Workflow Completed With Errors ===")
