In [None]:
!pip install 'git+https://github.com/salesforce/LAVIS.git'

In [None]:
import torch

from PIL import Image
from lavis.models import load_model_and_preprocess

In [None]:
# Sanity check: show whether PyTorch can see a CUDA device. The model-loading
# cell below requests device="cuda", so this should display True.
torch.cuda.is_available()

In [None]:
# Load the "blip_diffusion" / "base" model on the CUDA device with
# is_eval=True. The call also returns two preprocessor collections, bound here
# as `vis_preprocess` and `txt_preprocess`.
model, vis_preprocess, txt_preprocess = load_model_and_preprocess(
    "blip_diffusion",
    "base",
    is_eval=True,
    device="cuda",
)

### Description
This demo shows how to edit a **real** image using a checkpoint finetuned on a given subject. It works in the following steps:

(1) load the finetuned checkpoint.

(2) run DDIM inversion on the given image using the prompt ``A ${src_subject} ${prompt}.``.

(3) extract BLIP-2 embeddings of the condition subject image, using ``cond_subject`` and ``cond_images``.

(4) edit the real image with the subject visuals, using the prompt ``A ${BLIP-2 embedding} ${tgt_subject} ${prompt}`` and the DDIM-inverted latents.

In [None]:
# Text inputs, each passed through the "eval" text preprocessor.
# The prompt is wrapped in a one-element list; the subjects are passed as-is.
cond_subject = txt_preprocess["eval"]("dog")  # subject of the conditioning visuals
src_subject = txt_preprocess["eval"]("cat")  # subject named in the inversion prompt
tgt_subject = txt_preprocess["eval"]("dog")  # subject named in the editing prompt
text_prompt = [txt_preprocess["eval"]("sit on sofa")]

# Real image to be edited, converted to RGB; preview it at 256x256.
src_image = Image.open("../images/cat-sofa.png").convert("RGB")
display(src_image.resize((256, 256)))

In [None]:
# Location of the subject-finetuned checkpoint. A local file works as well,
# e.g. finetuned_ckpt = "../checkpoints/db-dog/checkpoint_40.pth"
finetuned_ckpt = (
    "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/"
    "BLIP-Diffusion/db-dog/checkpoint_40.pth"
)

# Load the finetuned checkpoint into the model created earlier.
model.load_checkpoint(finetuned_ckpt)

In [None]:
# Bundle every input of the editing call into a single mapping.
# NOTE(review): "cond_images" is deliberately None here; presumably the
# finetuned checkpoint already carries the subject representation (confirm).
samples = dict(
    cond_images=None,
    cond_subject=cond_subject,
    src_subject=src_subject,
    tgt_subject=tgt_subject,
    prompt=text_prompt,
    raw_image=src_image,
)

In [None]:
# Settings forwarded to model.edit().
iter_seed = 8887
guidance_scale = 7.5
num_inference_steps = 50
num_inversion_steps = 50  # raise to improve DDIM inversion quality
lb_threshold = 0.3  # raise to edit fewer pixels
negative_prompt = (
    "over-exposure, under-exposure, saturated, duplicate, out of frame, "
    "lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, "
    "mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, "
    "blurry, duplicate"
)

# Run the edit on the prepared samples with the settings above.
output = model.edit(
    samples,
    seed=iter_seed,
    guidance_scale=guidance_scale,
    neg_prompt=negative_prompt,
    lb_threshold=lb_threshold,
    num_inversion_steps=num_inversion_steps,
    num_inference_steps=num_inference_steps,
)

# Show the two results: output[0] under "Before editing", output[1] under
# "After editing".
print("=" * 30, "Before editing:", sep="\n")
display(output[0])

print("After editing:")
display(output[1])

In [None]:
# Inspect the size of the first returned image (the one displayed above as
# "Before editing"). Presumably a PIL image, i.e. (width, height); confirm.
output[0].size