In [None]:
import torch
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, StableDiffusionXLPipeline
from PIL import Image

from ip_adapter import IPAdapterXL
from numpy.random import randint

# Locations of the model weights used by this notebook, plus the target device.
# NOTE(review): base_model_path is an absolute, machine-specific path while the other two
# are relative to the working directory — confirm all three resolve where this is run.
base_model_path = "/ML-A100/team/mm/wangtao/share/models/stable-diffusion-xl-base-1.0"
image_encoder_path = "models/image_encoder"
ip_ckpt = "sdxl_models/ip-adapter_sdxl_vit-h.bin"
device = "cuda"

# Load the SDXL pipeline from base_model_path with float16 weights.
# add_watermarker=False is forwarded to from_pretrained — presumably it disables the
# output watermark; confirm against the diffusers documentation.
pipe = StableDiffusionXLPipeline.from_pretrained(
    base_model_path,
    torch_dtype=torch.float16,
    add_watermarker=False,
)

# Load the IP-Adapter: constructed from the pipeline, the image-encoder path, the adapter
# checkpoint and the device string. `ip_model` is the object the generation cell calls.
# NOTE(review): the pipeline is not moved to `device` here — presumably IPAdapterXL does
# that itself; verify against the ip_adapter package.
ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device)



In [None]:
def image_grid(imgs, rows, cols):
    """Tile images into a single ``rows`` x ``cols`` RGB contact sheet.

    Images are placed left to right, top to bottom. The cell size is taken
    from the first image; every image is pasted at a multiple of that size,
    so inputs are expected to share the same dimensions.

    Args:
        imgs: Sequence of PIL images; its length must equal ``rows * cols``.
        rows: Number of rows in the grid.
        cols: Number of columns in the grid.

    Returns:
        A new RGB image of size ``(cols * w, rows * h)`` where ``(w, h)`` is
        the size of ``imgs[0]``.

    Raises:
        AssertionError: If ``len(imgs)`` does not equal ``rows * cols``.
    """
    assert len(imgs) == rows * cols, (
        f"expected {rows * cols} images for a {rows}x{cols} grid, got {len(imgs)}"
    )

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        # Flat index -> (row, column) position of the cell inside the grid.
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid

# read image prompt
# Each reference image is loaded, converted to RGB when it carries an alpha channel,
# and resized to 512x512 before being handed to the IP-Adapter.
image_paths=["/ML-A100/team/mm/shuyu/IP-Adapter-fox/assets/images/woman.png","/ML-A100/team/mm/shuyu/IP-Adapter-fox/assets/images/woman1.jpg"]
input_images = []
for path in image_paths:
    # The context manager closes the underlying file; convert()/resize() return
    # new in-memory images, so nothing depends on the file after this block.
    with Image.open(path) as raw_image:
        # Resize the *converted* image. Resizing the original instead would silently
        # discard the RGB conversion and keep RGBA inputs as RGBA.
        rgb_image = raw_image.convert('RGB') if raw_image.mode == 'RGBA' else raw_image
        input_images.append(rgb_image.resize((512, 512)))
# Preview the prepared prompts side by side (last expression is displayed by the notebook).
grid = image_grid(input_images, 1, len(input_images))
grid

In [None]:

# multimodal prompts
# Generate `num_samples` images per reference image, guided by both the image prompts
# and the text prompt.
num_samples = 3
# Draw a fresh seed for every run, but keep it as a plain Python int and print it so a
# given result can be reproduced by hard-coding the value. dtype="int64" keeps the
# upper bound valid on platforms where NumPy's default integer is 32-bit.
random_seed = int(randint(0, 2**32 - 1, dtype="int64"))
print(f"seed: {random_seed}")
images = ip_model.generate(pil_image=input_images, num_samples=num_samples, num_inference_steps=30, seed=random_seed,
        prompt="A girl looks back", scale=0.6)
print(len(images))
# Show all generated images in a single row.
grid = image_grid(images, 1, len(images))
grid