In [None]:
import gc
import os

import torch
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, StableDiffusionXLPipeline
from numpy.random import randint
from PIL import Image

from ip_adapter import IPAdapterXL

# Model locations and target device.
# The SDXL base model defaults to a cluster-specific absolute path; set the
# SDXL_BASE_MODEL_PATH environment variable to point at a local copy instead.
base_model_path = os.environ.get(
    "SDXL_BASE_MODEL_PATH",
    "/ML-A100/team/mm/wangtao/share/models/stable-diffusion-xl-base-1.0",
)
# The two IP-Adapter paths below are relative to the notebook's working directory.
image_encoder_path = "models/image_encoder"
ip_ckpt = "sdxl_models/ip-adapter_sdxl_vit-h.bin"
device = "cuda"

def image_grid(imgs, rows, cols):
    """Tile PIL images into a single ``rows`` x ``cols`` RGB grid image.

    Images are placed row-major (left to right, then top to bottom). Every cell
    is sized from the first image, so all images are expected to share that
    size; images of other sizes will overlap neighbouring cells or leave gaps.

    Args:
        imgs: Sequence of PIL images; must contain exactly ``rows * cols`` items.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB ``PIL.Image.Image`` of size ``(cols * w, rows * h)``, where
        ``(w, h)`` is the size of ``imgs[0]``.

    Raises:
        AssertionError: If ``len(imgs) != rows * cols``.
        ValueError: If ``imgs`` is empty (there is no image to size cells from).
    """
    assert len(imgs) == rows * cols, f"expected {rows * cols} images, got {len(imgs)}"
    if not imgs:
        raise ValueError("image_grid needs at least one image")

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid

# load SDXL pipeline
pipe = StableDiffusionXLPipeline.from_pretrained(
    base_model_path,
    torch_dtype=torch.float16,
    add_watermarker=False,
)

# load ip-adapter
ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device)

In [None]:
# Read the image prompt and normalise it to 3-channel RGB.
# Any non-RGB mode (RGBA, palette "P", grayscale "L", CMYK, ...) is converted,
# not only RGBA, so later cells always receive an RGB image.
image = Image.open("/ML-A100/team/mm/shuyu/examples/woman.jpg")
if image.mode != 'RGB':
    image = image.convert('RGB')

# Preview only: resize() returns a new image, so `image` itself keeps its original size.
image.resize((512, 512))

In [None]:
# Image-only prompting: sample `num_samples` variations of the prompt image.
# The seed is fixed (4), so re-running this cell should reproduce the same samples.
num_samples = 3
images = ip_model.generate(
    pil_image=image,
    num_samples=num_samples,
    num_inference_steps=30,
    seed=4,
)
grid = image_grid(images, 1, num_samples)
grid

In [None]:
# Multimodal prompting: condition on both the image prompt and a text prompt.
# The seed is drawn at random and printed, so an interesting result can be
# reproduced later by passing the printed value back in as `seed`.
num_samples = 3
# dtype="int64" keeps the full 32-bit seed range valid even where numpy's default
# integer type is 32-bit (e.g. Windows), which would otherwise reject this bound.
# int() hands a plain Python int (not a numpy scalar) to generate().
random_seed = int(randint(0, 2**32, dtype="int64"))
print(f"seed: {random_seed}")
# NOTE(review): `scale` presumably weights the image prompt against the text prompt
# inside IPAdapterXL — confirm against the ip_adapter implementation.
images = ip_model.generate(
    pil_image=image,
    num_samples=num_samples,
    num_inference_steps=30,
    seed=random_seed,
    prompt="A girl stands on the top of a mountain with her back to the camera. Surrounded by rolling mountains and misty clouds. The figure's back looks firm and lonely",
    scale=0.6,
)
grid = image_grid(images, 1, num_samples)
# Explicit display() so the grid renders even if more code follows in this cell.
display(grid)

# ControlNet Depth

# Drop the plain SDXL pipeline and its IP-Adapter wrapper before loading the
# ControlNet variant, so the two do not have to coexist in GPU memory.
del pipe, ip_model
# gc.collect() reclaims any reference cycles still holding the deleted models;
# empty_cache() then releases the now-unused cached CUDA blocks.
gc.collect()
torch.cuda.empty_cache()

# Depth ControlNet checkpoint; override with the CONTROLNET_DEPTH_PATH env var.
controlnet_path = os.environ.get(
    "CONTROLNET_DEPTH_PATH",
    "/ML-A100/team/mm/wangtao/share/models/controlnet/controlnet-depth-sdxl-1.0",
)
controlnet = ControlNetModel.from_pretrained(controlnet_path, variant="fp16", use_safetensors=True, torch_dtype=torch.float16).to(device)
# load SDXL pipeline with the depth ControlNet attached
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    base_model_path,
    controlnet=controlnet,
    use_safetensors=True,
    torch_dtype=torch.float16,
    add_watermarker=False,
).to(device)
# load ip-adapter on top of the ControlNet pipeline
ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device)

In [None]:
# Read the image prompt and the depth map that conditions the ControlNet.
image = Image.open("assets/images/statue.png")
# Normalise the image prompt to RGB, as is done for the first image prompt.
if image.mode != 'RGB':
    image = image.convert('RGB')
depth_map = Image.open("assets/structure_controls/depth.png").resize((1024, 1024))
# Side-by-side 256x256 preview of the image prompt and the depth map.
image_grid([image.resize((256, 256)), depth_map.resize((256, 256))], 1, 2)

In [None]:
# Structural control: the depth map is passed as the ControlNet `image`, alongside
# the IP-Adapter image prompt. controlnet_conditioning_scale=0.7 is presumably the
# ControlNet strength — tune to trade structure adherence vs. prompt fidelity.
num_samples = 3
images = ip_model.generate(
    pil_image=image,
    image=depth_map,
    controlnet_conditioning_scale=0.7,
    num_samples=num_samples,
    num_inference_steps=30,
    seed=42,
)
grid = image_grid(images, 1, num_samples)
grid