# 🧨 Dreambooth-Keras + WandB 🪄🐝

[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/soumik12345/dreambooth-keras/blob/main/notebooks/inference_wandb.ipynb)

<!--- @wandbcode{dreambooth-keras-inference} -->

This notebook shows how to perform inference with a DreamBooth fine-tuned Stable Diffusion model.

## 🌈 Install Dreambooth-Keras

We use [soumik12345/dreambooth-keras](https://github.com/soumik12345/dreambooth-keras), a fork of [sayakpaul/dreambooth-keras](https://github.com/sayakpaul/dreambooth-keras) developed by [**Sayak Paul**](https://github.com/sayakpaul) and [**Chansung Park**](https://github.com/deep-diver).

In [None]:
!pip install -q git+https://github.com/soumik12345/dreambooth-keras.git

In [None]:
import wandb
from PIL import Image
from dreambooth_keras.utils import load_model_from_wandb_artifact

## 🐝 Initialize WandB run

We initialize a [Weights & Biases run](https://docs.wandb.ai/guides/runs) to store the generated images in a [Weights & Biases table](https://docs.wandb.ai/guides/data-vis).

In [None]:
wandb.init(project="dreambooth-keras", job_type="inference")

config = wandb.config
config.model_artifact_address = "geekyrakshit/dreambooth-keras/run_n5oakq7c_model:v0"
config.image_resolution = 512
config.num_diffusion_steps = 500
config.batch_size = 5
config.unique_id = "sks"
config.class_category = "monkey"
config.prompt = "a painting of sks monkey in the style of Michelangelo"
config.unconditional_guidance_scale = 15


wandb_table = wandb.Table(columns=[
    "prompt", "images", "unique-id", "class-category", "image-resolution", "num-diffusion-steps"
])
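The prompt in the config contains both the unique identifier (`sks`) and the class category (`monkey`). As a minimal sketch, the prompt could be composed from those two fields so the three values stay consistent. Note that `build_prompt` and its `template` argument are hypothetical names used for illustration and are not part of dreambooth-keras.

```python
def build_prompt(
    unique_id: str,
    class_category: str,
    template: str = "a painting of {unique_id} {class_category} in the style of Michelangelo",
) -> str:
    """Fill the identifier and class into a prompt template (hypothetical helper)."""
    return template.format(unique_id=unique_id, class_category=class_category)


# Reproduces the prompt used in this notebook from its two components.
prompt = build_prompt("sks", "monkey")
print(prompt)
```

With a helper like this, `config.prompt` could be assigned from `build_prompt(config.unique_id, config.class_category)` instead of being typed by hand.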

## 🧑‍🎨 Perform Inference

First, we load our model from the Weights & Biases artifacts created by [`dreambooth_keras.utils.DreamBoothCheckpointCallback`](https://github.com/soumik12345/dreambooth-keras/blob/main/dreambooth_keras/utils.py#L93), which automatically logs model checkpoints as [Weights & Biases artifacts](https://docs.wandb.ai/guides/data-and-model-versioning) at the end of each epoch during training. We load these checkpoints with the utility [`dreambooth_keras.utils.load_model_from_wandb_artifact`](https://github.com/soumik12345/dreambooth-keras/blob/main/dreambooth_keras/utils.py#L23).

In [None]:
dreambooth_model = load_model_from_wandb_artifact(
    artifact_address=config.model_artifact_address,
    image_resolution=config.image_resolution
)

Now we run inference with our *dreamboothed* Stable Diffusion model.

In [None]:
images = dreambooth_model.text_to_image(
    config.prompt,
    batch_size=config.batch_size,
    num_steps=config.num_diffusion_steps,
    unconditional_guidance_scale=config.unconditional_guidance_scale,
)
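The `unconditional_guidance_scale` argument controls how strongly the prompt steers each denoising step. The sketch below illustrates the usual classifier-free guidance rule on plain Python lists; it assumes the model combines its unconditional and prompt-conditioned noise predictions this way and is not taken from the library's source. The function name `apply_guidance` is hypothetical.

```python
def apply_guidance(unconditional, conditional, scale):
    """Classifier-free guidance: move from the unconditional prediction
    toward the conditional one, amplified by `scale`."""
    return [u + scale * (c - u) for u, c in zip(unconditional, conditional)]


# Toy noise predictions for two latent values.
unconditional = [0.0, 1.0]
conditional = [1.0, 1.0]

# scale=1 returns the conditional prediction unchanged;
# larger values (this notebook uses 15) exaggerate the difference.
print(apply_guidance(unconditional, conditional, 1))
print(apply_guidance(unconditional, conditional, 15))
```

Under this rule, higher scales tend to follow the prompt more closely, usually at some cost to image diversity.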

Next, we log our images to a [Weights & Biases table](https://docs.wandb.ai/guides/data-vis), which makes them easier to visualize and keeps them easily accessible for future reference.

In [None]:
images = [
    wandb.Image(Image.fromarray(image), caption=f"{i}: {config.prompt}")
    for i, image in enumerate(images)
]
wandb_table.add_data(
    config.prompt,
    images,
    config.unique_id,
    config.class_category,
    config.image_resolution,
    config.num_diffusion_steps
)
wandb.log({"Inference-Results": wandb_table})
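Each call to `add_data` passes one value per table column, in the same order as the column list. As a small sketch of keeping that order explicit, the hypothetical helper `make_row` below builds a row from a dictionary keyed by column name; the placeholder string for `images` stands in for the list of `wandb.Image` objects.

```python
COLUMNS = [
    "prompt", "images", "unique-id", "class-category",
    "image-resolution", "num-diffusion-steps",
]


def make_row(values: dict) -> list:
    """Order `values` to match COLUMNS, failing loudly on a missing column."""
    missing = [column for column in COLUMNS if column not in values]
    if missing:
        raise KeyError(f"missing table columns: {missing}")
    return [values[column] for column in COLUMNS]


row = make_row({
    "images": "<list of wandb.Image objects>",  # placeholder for illustration
    "prompt": "a painting of sks monkey in the style of Michelangelo",
    "unique-id": "sks",
    "class-category": "monkey",
    "num-diffusion-steps": 500,
    "image-resolution": 512,
})
print(row)
```

A row built this way could then be passed on as `wandb_table.add_data(*row)`.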

In [None]:
wandb.finish()