In [1]:
import gradio as gr
# import spaces
import torch
from diffusers import AutoencoderKL, TCDScheduler
from diffusers.models.model_loading_utils import load_state_dict
from gradio_imageslider import ImageSlider
from huggingface_hub import hf_hub_download

from src.controlnet_union import ControlNetModel_Union
from src.pipeline_fill_sd_xl import StableDiffusionXLFillPipeline

from PIL import Image, ImageDraw
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
config_file = hf_hub_download(
    "xinsir/controlnet-union-sdxl-1.0",
    filename="config_promax.json",
)

config = ControlNetModel_Union.load_config(config_file)
controlnet_model = ControlNetModel_Union.from_config(config)
model_file = hf_hub_download(
    "xinsir/controlnet-union-sdxl-1.0",
    filename="diffusion_pytorch_model_promax.safetensors",
)
state_dict = load_state_dict(model_file)
model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
    controlnet_model, state_dict, model_file, "xinsir/controlnet-union-sdxl-1.0"
)
model.to(device="cuda", dtype=torch.float16)

vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
).to("cuda")

pipe = StableDiffusionXLFillPipeline.from_pretrained(
    "SG161222/RealVisXL_V5.0_Lightning",
    torch_dtype=torch.float16,
    vae=vae,
    controlnet=model,
    variant="fp16",
).to("cuda")

pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)


A mixture of fp16 and non-fp16 filenames will be loaded.
Loaded fp16 filenames:
[unet/diffusion_pytorch_model.fp16.safetensors, text_encoder_2/model.fp16.safetensors, vae/diffusion_pytorch_model.fp16.safetensors, text_encoder/model.fp16.safetensors]
Loaded non-fp16 filenames:
[unet/diffusion_pytorch_model-00002-of-00002.safetensors, unet/diffusion_pytorch_model-00001-of-00002.safetensors
If this behavior is not expected, please check your folder structure.
Loading pipeline components...: 100%|██████████| 7/7 [00:00<00:00,  7.31it/s]


In [4]:
def can_expand(source_width, source_height, target_width, target_height, alignment):
    """Checks if the image can be expanded based on the alignment."""
    if alignment in ("Left", "Right") and source_width >= target_width:
        return False
    if alignment in ("Top", "Bottom") and source_height >= target_height:
        return False
    return True

In [3]:
from io import BytesIO
import base64
def base64_to_img(img_str: str) -> Image.Image:
    return Image.open(BytesIO(base64.b64decode(img_str)))

def img_to_base64(img: Image.Image) -> str:
    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode()

In [5]:
def infer(image_path, width, height, overlap_width, num_inference_steps, resize_option, custom_resize_size, prompt_input=None, alignment="Middle"):
    source = Image.open(image_path)
    print(f"Source Image : {source.size}")
    target_size = (width, height)
    overlap = overlap_width

    # Upscale if source is smaller than target in both dimensions
    # if source.width < target_size[0] and source.height < target_size[1]:
    #     scale_factor = min(target_size[0] / source.width, target_size[1] / source.height)
    #     new_width = int(source.width * scale_factor)
    #     new_height = int(source.height * scale_factor)
    #     source = source.resize((new_width, new_height), Image.LANCZOS)

    # if source.width > target_size[0] or source.height > target_size[1]:
    #     scale_factor = min(target_size[0] / source.width, target_size[1] / source.height)
    #     new_width = int(source.width * scale_factor)
    #     new_height = int(source.height * scale_factor)
    #     source = source.resize((new_width, new_height), Image.LANCZOS)

    if resize_option == "Full":
        resize_size = max(source.width, source.height)
    elif resize_option == "1/2":
        resize_size = max(source.width, source.height) // 2
    elif resize_option == "1/3":
        resize_size = max(source.width, source.height) // 3
    elif resize_option == "1/4":
        resize_size = max(source.width, source.height) // 4
    else:  # Custom
        resize_size = custom_resize_size
    print(f"Original Size : {source.size}")
    aspect_ratio = source.height / source.width
    print(f"Aspect Ratio : {aspect_ratio}")
    new_width = resize_size
    print(f"New Width : {new_width}")
    new_height = int(resize_size * aspect_ratio)
    print(f"New Height : {new_height}")
    source = source.resize((new_width, new_height), Image.LANCZOS)

    # if not can_expand(source.width, source.height, target_size[0], target_size[1], alignment):
    #     alignment = "Middle"

    # Calculate margins based on alignment
    # if alignment == "Middle":
    #     margin_x = (target_size[0] - source.width) // 2
    #     margin_y = (target_size[1] - source.height) // 2
    # elif alignment == "Left":
    #     margin_x = 0
    #     margin_y = (target_size[1] - source.height) // 2
    # elif alignment == "Right":
    #     margin_x = target_size[0] - source.width
    #     margin_y = (target_size[1] - source.height) // 2
    # elif alignment == "Top":
    #     margin_x = (target_size[0] - source.width) // 2
    #     margin_y = 0
    # elif alignment == "Bottom":
    #     margin_x = (target_size[0] - source.width) // 2
    #     margin_y = target_size[1] - source.height

    margin_x = (target_size[0] - source.width) // 2
    margin_y = (target_size[1] - source.height) // 2


    background = Image.new('RGB', target_size, (255, 255, 255))
    background.paste(source, (margin_x, margin_y))

    mask = Image.new('L', target_size, 255)
    mask_draw = ImageDraw.Draw(mask)

    # Adjust mask generation based on alignment
    # if alignment == "Middle":
    #     mask_draw.rectangle([
    #         (margin_x + overlap, margin_y + overlap),
    #         (margin_x + source.width - overlap, margin_y + source.height - overlap)
    #     ], fill=0)
    # elif alignment == "Left":
    #     mask_draw.rectangle([
    #         (margin_x, margin_y),
    #         (margin_x + source.width - overlap, margin_y + source.height)
    #     ], fill=0)
    # elif alignment == "Right":
    #     mask_draw.rectangle([
    #         (margin_x + overlap, margin_y),
    #         (margin_x + source.width, margin_y + source.height)
    #     ], fill=0)
    # elif alignment == "Top":
    #     mask_draw.rectangle([
    #         (margin_x, margin_y),
    #         (margin_x + source.width, margin_y + source.height - overlap)
    #     ], fill=0)
    # elif alignment == "Bottom":
    #     mask_draw.rectangle([
    #         (margin_x, margin_y + overlap),
    #         (margin_x + source.width, margin_y + source.height)
    #     ], fill=0)

    mask_draw.rectangle([
        (margin_x + overlap, margin_y + overlap),
        (margin_x + source.width - overlap, margin_y + source.height - overlap)
    ], fill=0)

    cnet_image = background.copy()
    cnet_image.paste(0, (0, 0), mask)

    final_prompt = f"high quality, 4k"

#     import requests
#     import json
#     url = "http://0.0.0.0:8000/generate-image/"

#     order_id = "94r804t8t-e830r8r48t"
#     cnet_image.save("cnet_image_original.png")
#     c_net_base64 = img_to_base64(cnet_image)
#     mask_base64 = img_to_base64(mask)
#     headers = {'Content-Type': 'application/json'}

#     payload = {
#         "order_id": order_id,
#         "c_net_base64": c_net_base64,
#         "mask_base64": mask_base64,
#         "num_inference_steps": num_inference_steps,
#         "prompt_input": prompt_input
# }

#     response = requests.post(url, json=payload, headers=headers)
#     output_image_b64 = response.json()["output_image"]
#     output_image = Image.open(BytesIO(base64.b64decode(output_image_b64)))
#     return output_image

    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = pipe.encode_prompt(final_prompt, "cuda", True)

    for image in pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        image=cnet_image,
        num_inference_steps=num_inference_steps
    ):
        cnet_image, image = cnet_image, image

    image = image.convert("RGBA")
    mask = mask.resize(image.size)
    cnet_image.paste(image, (0, 0), mask)

    return background , cnet_image

In [8]:
if __name__ == "__main__":
    # Define input parameters here
    image_path = "/home/ubuntu/Varad/diffusion-extender/assets/beach.jpg"  # Change this to your input image path
    width = 2048
    height = 2048
    overlap_width = 42
    num_inference_steps = 10
    resize_option = "Full"  # Options: "Full", "1/2", "1/3", "1/4", "Custom"
    custom_resize_size = 512
    prompt_input = "High quality, 4k"  # Optional prompt
    alignment = "Middle" 

    background ,result_image = infer(image_path, width, height, overlap_width, num_inference_steps, resize_option, custom_resize_size, prompt_input, alignment)


    # background.save("background_image.jpg")
    # result_image.save("result_image.png")

Source Image : (1280, 854)
Original Size : (1280, 854)
Aspect Ratio : 0.6671875
New Width : 1280
New Height : 854


100%|██████████| 10/10 [00:06<00:00,  1.53it/s]


In [17]:
torch.cuda.empty_cache()

In [4]:
url = "http://164.52.212.87:8000/generate-image/"

margin_x = 0 
margin_y = 0
height = 640
width = 640 
image_url = "https://ai-image-editor-wasabi-bucket.apyhi.com/assets/AI_PHOTO_UNCROP/image_extender_new_6.webp"
num_inference_steps = 10
overlap_width = 42
resize_option = "Full"

payload = {         
"order_id" : "1234567890",
"image_url" : image_url,
"width" : width,
"height" : height,
"overlap_width" : overlap_width,
"num_inference_steps" : num_inference_steps,
"resize_option" : resize_option,
# "prompt_input" : prompt_input,
"margin_x" : margin_x,
"margin_y" : margin_y,
}

import requests
response = requests.post(url, json=payload)
# print(response.json())

In [5]:
print(response.json())



{'message': 'success', 'status_code': 200, 'output_image': 'iVBORw0KGgoAAAANSUhEUgAAAoAAAAKACAIAAACDr150AAEAAElEQVR4nKz90YJkSa4jiAHkieyr1S/pi/SiP9eLpN3pjGOEHgCae/WMtBpJcftWZUVGuB83o5EgCNL4f/q//J+nSpBYkEQImhlBAwxwZgaQJAmDoorsagy6mpgCuykWiiRAsGsEQCQlACAh+X+iRAnvlFRSS8+ZH+iZ8wP9C/qD+YPzo/PgNE5RhSlOFTCSpooENKcInYEONJiBjmbOmXMw0u87Z/S+54xe6ffMmZnREfKHwQgjjqjxR9QZHeEd/P6OoN/3+GcAgAAkwQ/QZBHNIkGgyIKaFDhANeec56EESFUlaQRUn4FIkgQKkEA2SQCCiiCEwRn0TxVwfg+L7wGbZAHyjxKUhlVV1ExVjYQiIUozI0EYAX/Pe3T+vucA7xGznyry6frz5/mvn/pffp7/5b/+/B//6+d/+dfP//G//vzXn59/Pf3neZ6nf7p/frq6xWIVqkT+jkD+FV/Wv4/+Dv7b6H8d/t9f/a8H/9vg7fp9OaxzRsRofn/PKwzxan5/zxCjGUE2ERKCCALwKgyqGAMCQBAiWKT/2RLJmUGRxRmhKFKQAFWBgChBXUXKryx6PwHmRT/mKpAg6E0ViuWdKRBCFbuLAItgwWbQjQGL7SeTX7WqCKG7KUJ4fgpQQf4BiN01YjUJ1HhTKYAsCCxC4pplNWKlwAhV9Mr5dMmHw0d1IEAApBFGGkmUz/KZkXBmCPhnZyarIYCY0XgdvDDZHWgkQKMiNeOP4F+CFxaSXYCQD+OnmPHa2iF4n/dAgVUjTXZZBe52kyxILEoCKJBVkqrqPq9/RX5GMW8LFQlNzpnYRc6AKIKid7X8P6+zTU0gQKnyfsLY73lTBqCAyepqNAKOcKAzGnJGI0ycRVbQ66nJUgECdzmBL2+p

In [4]:
from PIL import Image
import base64
from io import BytesIO
output_image_b64 = response.json()["output_image"]
output_image = Image.open(BytesIO(base64.b64decode(output_image_b64)))
# output_image #.save("output_image.png")

In [10]:
# result_image

In [67]:
# result_image

In [68]:
# result_image[1]

In [69]:

# result_image