In [None]:
!pip install 'git+https://github.com/salesforce/LAVIS.git'

In [None]:
import torch

from PIL import Image
from lavis.models import load_model_and_preprocess

In [None]:
# Check for a CUDA device: the next cell loads the model with device="cuda".
torch.cuda.is_available()

In [None]:
# Load the "blip_diffusion" model (type "base") in eval mode on the GPU,
# along with the image and text preprocessors that are returned with it.
model, vis_preprocess, txt_preprocess = load_model_and_preprocess(
    "blip_diffusion",
    "base",
    device="cuda",
    is_eval=True,
)

### Description
This demo shows how to generate different renditions of a given subject in a zero-shot setup. It works in two steps:

(1) Extracting BLIP-2 embeddings from ``cond_subject`` and ``cond_image``.

(2) Generating on prompts: "A ``${BLIP-2 embedding} ${tgt_subject} ${text_prompt}``".

### Tips
``tgt_subject`` can be a different subject from ``cond_subject``. For example, if ``cond_subject="dog"`` (and you use a dog image as the condition) and ``tgt_subject="tiger"``, you'd expect the model to generate a tiger that looks like this particular dog.

In [None]:
# Subject shown in the conditioning image, subject to generate, and the
# text describing the target scene.
cond_subject = "dog"
tgt_subject = "dog"
# Alternative prompt to try: text_prompt = "painting by van gogh"
text_prompt = "swimming underwater"

# Run the "eval" text preprocessor on each string and wrap each result in a
# one-element list.
cond_subjects = [txt_preprocess["eval"](cond_subject)]
tgt_subjects = [txt_preprocess["eval"](tgt_subject)]
# NOTE: `text_prompt` is rebound from a str to a list here; the `samples`
# cell below reads the list form.
text_prompt = [txt_preprocess["eval"](text_prompt)]

# Open the file inside a `with` block so its handle is closed as soon as the
# RGB copy has been made, rather than whenever the image is garbage-collected.
with Image.open("../images/dog.png") as image_file:
    cond_image = image_file.convert("RGB")
display(cond_image.resize((256, 256)))

# Preprocess the image, add a leading dimension with unsqueeze(0), and move
# the result to the GPU.
cond_images = vis_preprocess["eval"](cond_image).unsqueeze(0).cuda()


In [None]:
# Bundle the conditioning inputs that are passed to `model.generate`.
samples = dict(
    cond_images=cond_images,
    cond_subject=cond_subjects,
    tgt_subject=tgt_subjects,
    prompt=text_prompt,
)

In [None]:
# Generation settings.
num_output = 4  # number of images to generate
iter_seed = 88888  # seed for the first image; later images use consecutive seeds
guidance_scale = 7.5
num_inference_steps = 50
negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate"

# Options that are the same for every call to `model.generate`.
generate_kwargs = dict(
    guidance_scale=guidance_scale,
    num_inference_steps=num_inference_steps,
    neg_prompt=negative_prompt,
    height=512,
    width=512,
)

# One generation per seed; only the seed changes between calls.
for seed in range(iter_seed, iter_seed + num_output):
    output = model.generate(samples, seed=seed, **generate_kwargs)
    display(output[0])
