# Let's perform inference!
<!--- @wandbcode{ethz-hackathon} -->


In [3]:
from pathlib import Path
from types import SimpleNamespace
import torch, wandb
from miniminiai import show_images
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

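First, configure the base model, the target device, and the W&B settings. Replace `LORA_WEIGHTS_AT` with the artifact reference of your trained LoRA weights: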

In [None]:

# Base Stable Diffusion checkpoint and the device inference runs on.
config = SimpleNamespace()
config.model_base = "runwayml/stable-diffusion-v1-5"
config.device = "cuda"

# W&B project for this run, and the logged artifact holding the trained LoRA weights.
WANDB_PROJECT_NAME = "ethz-hackathon"
# Placeholder: replace with the artifact reference of your own LoRA weights.
LORA_WEIGHTS_AT = "your_lora_weights_artifact"

Create the diffusion pipeline in half precision and swap in the DPM-Solver multistep scheduler.

In [None]:

# Load the base model in half precision.
pipe = DiffusionPipeline.from_pretrained(
    config.model_base,
    torch_dtype=torch.float16,
)
# Replace the default scheduler with DPMSolverMultistepScheduler, built from the existing scheduler's config.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

Let's create a W&B run

In [None]:


# Start an inference run in the "lora" group, recording the config alongside it.
run = wandb.init(
    project=WANDB_PROJECT_NAME,
    group="lora",
    job_type="inference",
    config=config,
)

Let's create a Table to keep track of all our generations

In [None]:
# One row per generated image, together with the sampling settings that produced it.
table = wandb.Table(
    columns=[
        "image",
        "num_inference_steps",
        "guidance_scale",
        "cross_attn_scale",
    ],
)

Load the LoRA attention weights on top of the pretrained Stable Diffusion model

In [None]:
# Fetch the LoRA artifact and locate its weights file.
at = run.use_artifact(LORA_WEIGHTS_AT)
lora_dir = Path(at.download())
# Path.glob returns a generator (not a list), so it cannot be indexed with [0];
# take the first match explicitly and fail loudly if there is none.
lora_model_path = next(lora_dir.glob("*.bin"), None)
if lora_model_path is None:
    raise FileNotFoundError(f"No *.bin LoRA weights found in artifact directory {lora_dir}")

# DiffusionPipeline has no `load_attn_procs`; that method lives on the UNet.
# Prefer the pipeline-level LoRA loader when this diffusers version provides it,
# otherwise fall back to the older UNet attention-processor loader.
if hasattr(pipe, "load_lora_weights"):
    pipe.load_lora_weights(str(lora_model_path.parent), weight_name=lora_model_path.name)
else:
    pipe.unet.load_attn_procs(str(lora_model_path))
pipe.to(config.device)

Define some prompts

In [None]:
# Prompts generated in every experiment below; each produces one image.
prompts = [
    "A pokemon with blue eyes.",
    "A pokemon that looks like a dog",
    "a pokemon that smiles a lot",
    "a sea pokemon in the form of a star",
]

Generate with a LoRA scale of 0.5, halfway between the base model's weights and the full LoRA weights.

In [None]:
def generate_images(prompts, num_inference_steps=25, guidance_scale=7.5, cross_attn_scale=0.5, seed=None):
    """Generate images with the LoRA-loaded ``pipe`` and record each one in ``table``.

    Args:
        prompts: Prompt(s) passed straight to ``pipe``.
        num_inference_steps: Number of inference steps passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.
        cross_attn_scale: Forwarded as ``cross_attention_kwargs["scale"]``; this
            notebook uses it to weight the LoRA contribution (0.5 = halfway,
            1.0 = full LoRA).
        seed: Optional integer seed. When given, a seeded ``torch.Generator`` is
            passed to the pipeline so the same call reproduces the same images.
            ``None`` (default) keeps the previous unseeded behaviour.

    Returns:
        The images from ``pipe(...).images``. Each image is also appended to the
        global W&B ``table`` together with the settings used to generate it.
    """
    # A CPU generator keeps seeded results independent of the pipeline's device.
    generator = None if seed is None else torch.Generator(device="cpu").manual_seed(seed)
    images = pipe(
        prompts,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        cross_attention_kwargs={"scale": cross_attn_scale},
        generator=generator,
    ).images
    for img in images:
        table.add_data(wandb.Image(img), num_inference_steps, guidance_scale, cross_attn_scale)
    return images

In [None]:
# Spell out the default LoRA scale (0.5) so the cell documents what it runs.
images = generate_images(prompts, cross_attn_scale=0.5)
show_images(images, titles=prompts)

Generate with the full LoRA weights (scale 1.0), then log all generations to W&B.

In [None]:

# Full LoRA weight: scale 1.0.
images = generate_images(prompts, cross_attn_scale=1.0)
show_images(images, titles=prompts)

# Without this, every row added to `table` is discarded when the kernel stops:
# push the accumulated generations to W&B, then close the run.
run.log({"generations": table})
run.finish()