In [None]:
from functools import partial

import torch
from diffusers import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline
from wandb_addons.diffusers import StableDiffusionXLCallback

# Release PyTorch's cached-but-unused CUDA memory (e.g. left over from an earlier
# run in this kernel) before the large SDXL models are loaded below.
torch.cuda.empty_cache()

In [None]:
# SDXL base model, loaded as half-precision fp16 safetensors weights.
base_pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    variant="fp16",
    use_safetensors=True,
    torch_dtype=torch.float16,
)

# Use diffusers' model CPU offloading to reduce peak GPU memory use.
base_pipeline.enable_model_cpu_offload()

In [None]:
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
num_inference_steps = 50

# W&B logging callback; the base pass is recorded under the stage name "base".
# NOTE(review): wandb_entity is a personal account — change it to your own entity
# (or team) before running, otherwise logging will likely fail.
callback = StableDiffusionXLCallback(
    pipeline=base_pipeline,
    prompt=prompt,
    num_inference_steps=num_inference_steps,
    initial_stage_name="base",
    wandb_project="diffusers-sdxl",
    wandb_entity="geekyrakshit",
    weave_mode=True,
)

# end_experiment=False presumably keeps the W&B run open, so the refiner stage
# (which receives the same `callback` later) can keep logging into it.
base_step_callback = partial(callback, end_experiment=False)

# output_type="latent" asks the base pipeline for latents rather than decoded
# images; the first one is kept as the refiner's input.
image = base_pipeline(
    prompt=prompt,
    num_inference_steps=num_inference_steps,
    output_type="latent",
    callback=base_step_callback,
).images[0]

In [None]:
# SDXL refiner (img2img). It reuses the base pipeline's VAE and second text
# encoder instances instead of loading separate copies of those components.
refiner_pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    vae=base_pipeline.vae,
    text_encoder_2=base_pipeline.text_encoder_2,
    variant="fp16",
    use_safetensors=True,
    torch_dtype=torch.float16,
)

# Same diffusers model CPU offloading as the base pipeline.
refiner_pipeline.enable_model_cpu_offload()

In [None]:
num_inference_steps = 50
strength = 0.3

# Register the refiner as a second stage on the W&B callback, using the same
# step count / strength that the refiner run below actually uses.
callback.add_refiner_stage(refiner_pipeline, num_inference_steps=num_inference_steps, strength=strength)

# Pass num_inference_steps and strength explicitly so the refiner run and the
# callback's bookkeeping cannot drift apart if the values above are edited
# (relying on the pipeline's own defaults only matched by coincidence).
# `image` is the latent kept from the base stage via `.images[0]`; `[None, :]`
# re-adds the leading axis that indexing dropped.
image = refiner_pipeline(
    prompt=prompt,
    image=image[None, :],
    num_inference_steps=num_inference_steps,
    strength=strength,
    callback=callback,
).images[0]