### **Outpainting inference notebook**

- **Source** : https://www.kaggle.com/code/pojesh/sdxl-output

In [None]:
%pwd

In [None]:
import os

In [None]:
# Optional one-time setup, kept for reference only (disabled).
# It copies the custom pipeline module from the Kaggle input dataset into the
# working directory. Real `#` comments are used here because the previous
# triple-quoted "comment" ended in four quotes, which left a stray quote
# behind and made the cell a SyntaxError.
#
# import shutil
#
# src_path = r'/kaggle/input/mysdxlcomps/pipeline_fill_sd_xl.py'
# dst_path = r'/kaggle/working/'
#
# shutil.copy(src_path, dst_path)
# print('Copied')

In [None]:
import os
import torch
from PIL import Image
import numpy as np
from diffusers import AutoencoderKL, TCDScheduler
from diffusers.models.model_loading_utils import load_state_dict
from huggingface_hub import hf_hub_download
import logging

# modules (keep them in the same folder)
from controlnet_union import ControlNetModel_Union
from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline

In [None]:
from PIL import ImageDraw

In [None]:
# Ask PyTorch to release cached GPU memory it is no longer using
# (utility cell, meant to be run manually between experiments).
torch.cuda.empty_cache()

In [None]:
# Drop the loaded pipeline, if there is one, so its memory can be reclaimed.
# pop() is used instead of `del` so this cell does not raise NameError when it
# runs before the model-loading cell (e.g. "Run All" on a fresh kernel).
globals().pop("pipe", None)

In [None]:
# Drop the VAE and ControlNet references, if present, so their memory can be
# reclaimed. pop() is used instead of `del` so this cell does not raise
# NameError when the models have not been loaded yet.
globals().pop("vae", None)
globals().pop("model", None)

In [None]:
# ------------------------------------------------------------------
# CONFIGURATION – edit only these two lines if you want other sizes
# ------------------------------------------------------------------
TARGET_WIDTH  = 1800         # << hard-coded width
TARGET_HEIGHT = 1200         # << hard-coded height
INPUT_IMAGE   = r"/kaggle/input/flux-images-set1/robot_1200x1200.webp"  # source file
OUTPUT_DIR    = r"/kaggle/working/outputs"
# Derived from the target size so the file name cannot drift out of sync with
# the two dimensions above (default value: "robot_op_1800x1200.webp").
OUTPUT_NAME   = f"robot_op_{TARGET_WIDTH}x{TARGET_HEIGHT}.webp"

# INFO level so the progress messages emitted by the cells below are visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [None]:


# ------------------------------------------------------------------
# Model loading 
# ------------------------------------------------------------------
# Builds the module-level objects used by the inference code below:
# `device`, `dtype`, `model` (ControlNet), `vae` and `pipe`.
logger.info("Loading models…")

# Use the GPU with half precision when CUDA is available; otherwise fall back
# to the CPU with full precision.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype  = torch.float16 if device == "cuda" else torch.float32

# ControlNet
# Download the "promax" config and instantiate an (uninitialised) model from it.
config_file = hf_hub_download(
    "xinsir/controlnet-union-sdxl-1.0",
    filename="config_promax.json",
)
config = ControlNetModel_Union.load_config(config_file)
controlnet_model = ControlNetModel_Union.from_config(config)

# Download the matching weights and read them into a state dict.
model_file = hf_hub_download(
    "xinsir/controlnet-union-sdxl-1.0",
    filename="diffusion_pytorch_model_promax.safetensors",
)
state_dict = load_state_dict(model_file)
loaded_keys = list(state_dict.keys())

# NOTE(review): `_load_pretrained_model` is a private loader; its signature and
# return shape may differ between diffusers versions — confirm against the
# installed version. Only the first element of its result is used here.
model = ControlNetModel_Union._load_pretrained_model(
    controlnet_model, state_dict, model_file,
    "xinsir/controlnet-union-sdxl-1.0", loaded_keys
)[0]
model = model.to(device=device, dtype=dtype)

# VAE & pipeline
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype
).to(device)

# The fp16 weight variant is requested only when running in half precision.
pipe = StableDiffusionXLFillPipeline.from_pretrained(
    "SG161222/RealVisXL_V5.0_Lightning",
    torch_dtype=dtype,
    vae=vae,
    controlnet=model,
    variant="fp16" if dtype == torch.float16 else None,
).to(device)

# Replace the pipeline's scheduler with a TCDScheduler built from the existing
# scheduler's configuration.
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
logger.info("Models loaded.")




# ------------------------------------------------------------------
# Helper functions
# ------------------------------------------------------------------
def prepare_image_and_mask(image, width, height, overlap_percentage=10,
                           resize_option="Full", custom_resize_percentage=50,
                           alignment="Middle", overlap_left=True,
                           overlap_right=True, overlap_top=True,
                           overlap_bottom=True):
    """Place ``image`` on a white canvas and build the matching outpaint mask.

    The image is first scaled (aspect ratio preserved) to fit inside
    ``width`` x ``height``, then optionally shrunk further according to
    ``resize_option``, and finally pasted onto a white canvas at the position
    selected by ``alignment``.

    Args:
        image: Source PIL image.
        width: Canvas width in pixels.
        height: Canvas height in pixels.
        overlap_percentage: Percentage of the resized image's width/height
            that is left inside the masked area on each enabled side.
        resize_option: ``"Full"``, ``"50%"``, ``"33%"`` or ``"25%"``; any
            other value selects ``custom_resize_percentage``.
        custom_resize_percentage: Resize percentage used when
            ``resize_option`` is not one of the presets.
        alignment: ``"Middle"``, ``"Left"``, ``"Right"``, ``"Top"`` or
            ``"Bottom"``.
        overlap_left, overlap_right, overlap_top, overlap_bottom: Whether the
            overlap band is applied on that side. When disabled, only a small
            fixed inset is used (none on the side the image is aligned to).

    Returns:
        A ``(background, mask)`` tuple. ``background`` is an RGB canvas with
        the resized image pasted in. ``mask`` is an ``L`` image that is 255
        everywhere except for a 0-filled rectangle over the kept part of the
        source image.

    Raises:
        ValueError: If ``alignment`` is not one of the supported values
            (previously this surfaced as an ``UnboundLocalError``).
    """
    supported_alignments = ("Middle", "Left", "Right", "Top", "Bottom")
    if alignment not in supported_alignments:
        raise ValueError(
            f"Unsupported alignment {alignment!r}; expected one of "
            f"{', '.join(supported_alignments)}."
        )

    target_size = (width, height)

    # Fit the source inside the target canvas while keeping its aspect ratio.
    scale_factor = min(target_size[0] / image.width, target_size[1] / image.height)
    fitted_size = (int(image.width * scale_factor), int(image.height * scale_factor))
    source = image.resize(fitted_size, Image.LANCZOS)

    # Presets map to fixed percentages; anything else uses the custom value.
    preset_percentages = {"Full": 100, "50%": 50, "33%": 33, "25%": 25}
    resize_percentage = preset_percentages.get(resize_option, custom_resize_percentage)

    # Never shrink below 64 px on either side.
    resize_factor = resize_percentage / 100
    new_width = max(int(source.width * resize_factor), 64)
    new_height = max(int(source.height * resize_factor), 64)
    source = source.resize((new_width, new_height), Image.LANCZOS)

    # Overlap band is at least one pixel wide.
    overlap_x = max(int(new_width * (overlap_percentage / 100)), 1)
    overlap_y = max(int(new_height * (overlap_percentage / 100)), 1)

    # Top-left corner of the pasted image for the requested alignment.
    centered_x = (target_size[0] - new_width) // 2
    centered_y = (target_size[1] - new_height) // 2
    if alignment == "Left":
        margin_x, margin_y = 0, centered_y
    elif alignment == "Right":
        margin_x, margin_y = target_size[0] - new_width, centered_y
    elif alignment == "Top":
        margin_x, margin_y = centered_x, 0
    elif alignment == "Bottom":
        margin_x, margin_y = centered_x, target_size[1] - new_height
    else:  # "Middle"
        margin_x, margin_y = centered_x, centered_y

    # Keep the paste position inside the canvas.
    margin_x = max(0, min(margin_x, target_size[0] - new_width))
    margin_y = max(0, min(margin_y, target_size[1] - new_height))

    background = Image.new('RGB', target_size, (255, 255, 255))
    background.paste(source, (margin_x, margin_y))

    mask = Image.new('L', target_size, 255)
    mask_draw = ImageDraw.Draw(mask)

    # With overlap disabled on a side, the kept rectangle is still inset by a
    # couple of pixels so that no white gap remains along that edge.
    white_gaps_patch = 2
    left_overlap = margin_x + (overlap_x if overlap_left else white_gaps_patch)
    right_overlap = margin_x + new_width - (overlap_x if overlap_right else white_gaps_patch)
    top_overlap = margin_y + (overlap_y if overlap_top else white_gaps_patch)
    bottom_overlap = margin_y + new_height - (overlap_y if overlap_bottom else white_gaps_patch)

    # On the side the image is aligned to, a disabled overlap uses no inset.
    if alignment == "Left" and not overlap_left:
        left_overlap = margin_x
    elif alignment == "Right" and not overlap_right:
        right_overlap = margin_x + new_width
    elif alignment == "Top" and not overlap_top:
        top_overlap = margin_y
    elif alignment == "Bottom" and not overlap_bottom:
        bottom_overlap = margin_y + new_height

    mask_draw.rectangle([(left_overlap, top_overlap),
                         (right_overlap, bottom_overlap)], fill=0)
    return background, mask

def can_expand(source_width, source_height, target_width, target_height, alignment):
    """Report whether a source of the given size can grow along the aligned axis.

    Returns ``False`` when the alignment is "Left"/"Right" and the source is at
    least as wide as the target, or when it is "Top"/"Bottom" and the source is
    at least as tall as the target. Every other combination returns ``True``.
    """
    side_anchored = alignment == "Left" or alignment == "Right"
    edge_anchored = alignment == "Top" or alignment == "Bottom"

    no_horizontal_room = side_anchored and source_width >= target_width
    no_vertical_room = edge_anchored and source_height >= target_height

    return not (no_horizontal_room or no_vertical_room)

def process_outpaint(image, width, height, num_inference_steps=8, prompt_input=""):
    """Outpaint ``image`` to ``width`` x ``height`` with the module-level ``pipe``.

    The image is centred on a canvas of the requested size (10 % overlap on
    every side, no extra downscaling), the area to generate is blacked out and
    the result produced by the pipeline is pasted back through the mask.

    Relies on the module-level ``pipe``, ``device`` and ``dtype`` created by
    the model-loading cell.

    Args:
        image: Source PIL image.
        width: Output width in pixels.
        height: Output height in pixels.
        num_inference_steps: Forwarded to the pipeline call.
        prompt_input: Optional text placed in front of the fixed
            "high quality, 4k" suffix.

    Returns:
        The composed PIL image: the canvas with the generated content pasted
        in wherever the mask is set.

    Raises:
        RuntimeError: If the pipeline yields no image at all.
    """
    background, mask = prepare_image_and_mask(
        image, width, height,
        overlap_percentage=10,
        resize_option="Full",
        custom_resize_percentage=50,
        alignment="Middle",
        overlap_left=True,
        overlap_right=True,
        overlap_top=True,
        overlap_bottom=True,
    )
    # NOTE: the former `can_expand(...)` check was removed. It only reassigned
    # a local `alignment` that was never read again, so it had no effect.

    # Black out every pixel the mask selects for generation.
    cnet_image = background.copy()
    cnet_image.paste(0, (0, 0), mask)

    final_prompt = f"{prompt_input} , high quality, 4k" if prompt_input else "high quality, 4k"

    with torch.no_grad(), torch.autocast(device_type=device, dtype=dtype):
        (prompt_embeds, negative_prompt_embeds,
         pooled_prompt_embeds, negative_pooled_prompt_embeds) = pipe.encode_prompt(final_prompt, device, True)

        # The pipeline call is iterated; only the last yielded image is kept.
        result_image = None
        for img in pipe(prompt_embeds=prompt_embeds,
                        negative_prompt_embeds=negative_prompt_embeds,
                        pooled_prompt_embeds=pooled_prompt_embeds,
                        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
                        image=cnet_image,
                        num_inference_steps=num_inference_steps):
            result_image = img

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    if result_image is None:
        # Fail with a clear message instead of an AttributeError on None below.
        raise RuntimeError("The pipeline did not yield any image.")

    result_image = result_image.convert("RGBA")
    cnet_image.paste(result_image, (0, 0), mask)
    return cnet_image

In [None]:
# ------------------------------------------------------------------
# Main execution
# ------------------------------------------------------------------
# Loads INPUT_IMAGE, outpaints it to TARGET_WIDTH x TARGET_HEIGHT and writes
# the result to OUTPUT_DIR/OUTPUT_NAME.
if __name__ == "__main__":
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    if not os.path.isfile(INPUT_IMAGE):
        logger.error(f"Input file '{INPUT_IMAGE}' not found.")
        # `exit()` is an interactive helper that is not guaranteed to stop
        # execution in every environment; raising SystemExit always does.
        raise SystemExit(1)

    # Image.open() is lazy, so the pixels must be read while the file is still
    # open. convert() loads the data and returns an independent image (a copy
    # when the mode is already RGB), which stays usable after the file closes.
    with Image.open(INPUT_IMAGE) as src:
        im = src.convert("RGB")

    logger.info(f"Outpainting {INPUT_IMAGE} -> {TARGET_WIDTH}x{TARGET_HEIGHT}")
    result = process_outpaint(im, TARGET_WIDTH, TARGET_HEIGHT)
    out_path = os.path.join(OUTPUT_DIR, OUTPUT_NAME)
    result.save(out_path)
    logger.info(f"Saved: {out_path}")