# 🧨 Dreambooth-Keras + WandB 🪄🐝

[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/soumik12345/dreambooth-keras/blob/main/notebooks/generate_class_priors.ipynb)

<!--- @wandbcode{dreambooth-keras-inference} -->

This notebook shows how to generate class-prior images with a pre-trained Stable Diffusion model, for use in DreamBooth fine-tuning.

We use [soumik12345/dreambooth-keras](https://github.com/soumik12345/dreambooth-keras), a fork of [sayakpaul/dreambooth-keras](https://github.com/sayakpaul/dreambooth-keras) developed by [**Sayak Paul**](https://github.com/sayakpaul) and [**Chansung Park**](https://github.com/deep-diver).
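For context, DreamBooth uses class-prior images like these in its prior-preservation loss: while the model learns the new subject, it is also trained to keep denoising generic images of the class, which helps it retain what it already knows about that class. The NumPy sketch below only illustrates the shape of that combined objective. It assumes the noise predictions and targets are already available as arrays, and the function name `dreambooth_loss` and the default `prior_loss_weight=1.0` are illustrative assumptions, not the dreambooth-keras training code.

```python
import numpy as np


def dreambooth_loss(instance_pred, instance_target, class_pred, class_target,
                    prior_loss_weight=1.0):
    """Hypothetical sketch of a DreamBooth objective with prior preservation.

    ``*_pred`` are the model's noise predictions and ``*_target`` the noise that
    was actually added. The instance term fits the new subject; the prior term
    keeps the model reproducing the generic class-prior images.
    """
    instance_loss = np.mean((instance_pred - instance_target) ** 2)
    prior_loss = np.mean((class_pred - class_target) ** 2)
    return instance_loss + prior_loss_weight * prior_loss


rng = np.random.default_rng(0)
shape = (2, 64, 64, 4)  # illustrative latent batch shape
noise = rng.standard_normal(shape)
# Perfect predictions give zero loss for both terms.
print(dreambooth_loss(noise, noise, noise, noise))  # 0.0
```

Raising `prior_loss_weight` puts more emphasis on preserving the class prior relative to fitting the instance images.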

In [None]:
!pip install -q git+https://github.com/soumik12345/dreambooth-keras.git

In [None]:
from tqdm import tqdm
import numpy as np
import hashlib
import shutil
import PIL.Image  # `import PIL` alone does not guarantee that `PIL.Image` is loaded
import os

import wandb
import keras_cv
import tensorflow as tf

In [None]:
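wandb.init(project="dreambooth-keras", job_type="inference")

# Generation settings, tracked in the W&B run config.
config = wandb.config
config.image_resolution = 512  # width and height of generated images, in pixels
config.class_prompt = "a photo of monkey"
config.num_imgs_to_generate = 300  # number of class images to keep
config.batch_size = 3  # images sampled per `text_to_image` call
config.wandb_artifact_name = "monkey-instance"

# Compute in float16 (variables stay float32) to cut memory use and, on supported GPUs, speed up inference.
tf.keras.mixed_precision.set_global_policy("mixed_float16")
model = keras_cv.models.StableDiffusion(
    img_width=config.image_resolution,
    img_height=config.image_resolution,
    jit_compile=True,  # compile the model with XLA
)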
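With these settings, each iteration of the generation loop samples `batch_size` images but saves only one, so the sampling cost is `num_imgs_to_generate * batch_size` images. A small pure-Python helper makes that arithmetic explicit; the name `generation_budget` is hypothetical, and "kept" assumes no two saved images are pixel-identical (identical images would share a hash-based filename).

```python
def generation_budget(num_imgs_to_generate, batch_size):
    """Return (text_to_image calls, images sampled, images kept) for the loop."""
    calls = num_imgs_to_generate  # one `text_to_image` call per kept image
    sampled = calls * batch_size  # every call samples a full batch
    kept = num_imgs_to_generate  # one randomly chosen image is saved per call
    return calls, sampled, kept


print(generation_budget(300, 3))  # (300, 900, 300)
```

So the default configuration samples 900 images to save 300.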

In [None]:
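os.makedirs("class-images", exist_ok=True)

for _ in tqdm(range(config.num_imgs_to_generate)):
    # Sample a batch for the class prompt: a uint8 array of shape (batch_size, H, W, 3).
    images = model.text_to_image(
        config.class_prompt,
        batch_size=config.batch_size,
    )
    # Keep one randomly chosen image from the batch.
    idx = np.random.choice(len(images))
    selected_image = PIL.Image.fromarray(images[idx])

    # Name the file after a SHA-1 hash of its pixels, so each distinct image gets its own filename.
    hash_image = hashlib.sha1(selected_image.tobytes()).hexdigest()
    image_filename = os.path.join("class-images", f"{hash_image}.jpg")
    selected_image.save(image_filename)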
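The saving logic can be exercised without a GPU by swapping the Stable Diffusion model for a stand-in. In the sketch below, `fake_text_to_image` and `save_class_images` are hypothetical helpers: the stub only mimics the shape and `uint8` dtype of the batch that `text_to_image` returns, while the loop mirrors the cell above (random pick per batch, SHA-1 filename, JPEG output).

```python
import hashlib
import os
import tempfile

import numpy as np
from PIL import Image


def fake_text_to_image(prompt, batch_size, resolution=64, rng=None):
    """Hypothetical stand-in for `model.text_to_image`: random uint8 RGB images."""
    rng = rng if rng is not None else np.random.default_rng()
    shape = (batch_size, resolution, resolution, 3)
    return rng.integers(0, 256, size=shape, dtype=np.uint8)


def save_class_images(generate_fn, prompt, num_images, batch_size, out_dir, rng=None):
    """Mirror of the notebook loop: keep one random image per batch, named by SHA-1."""
    rng = rng if rng is not None else np.random.default_rng()
    os.makedirs(out_dir, exist_ok=True)
    paths = []
    for _ in range(num_images):
        images = generate_fn(prompt, batch_size=batch_size)
        idx = rng.choice(len(images))  # pick one image from the batch at random
        selected_image = Image.fromarray(images[idx])
        # Hash the raw pixels so each distinct image gets its own filename.
        hash_image = hashlib.sha1(selected_image.tobytes()).hexdigest()
        path = os.path.join(out_dir, f"{hash_image}.jpg")
        selected_image.save(path)
        paths.append(path)
    return paths


rng = np.random.default_rng(0)
with tempfile.TemporaryDirectory() as tmp:
    paths = save_class_images(
        lambda p, batch_size: fake_text_to_image(p, batch_size, rng=rng),
        "a photo of monkey", num_images=5, batch_size=3, out_dir=tmp, rng=rng,
    )
    print(len(os.listdir(tmp)))  # 5
```

Passing the generator in as a function keeps the loop testable; in the notebook it would be `model.text_to_image`.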

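# Version the generated class images as a W&B dataset artifact.
artifact = wandb.Artifact(config.wandb_artifact_name, type="dataset")
artifact.add_dir("class-images")
wandb.log_artifact(artifact)

# `wandb.finish()` uploads any remaining data, so the local copy can be removed afterwards.
wandb.finish()

shutil.rmtree("class-images")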
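A later DreamBooth training run can pull the class images back from W&B instead of regenerating them. The sketch below assumes the artifact name and project used above and relies on `Run.use_artifact` and `Artifact.download` from the W&B Python SDK. The helper names `fetch_class_images` and `list_class_images` are hypothetical, and `wandb` is imported inside the function so the file-listing part runs without it.

```python
import os
import tempfile


def list_class_images(directory):
    """Return the sorted paths of the .jpg class images in `directory`."""
    return sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.lower().endswith(".jpg")
    )


def fetch_class_images(artifact_name="monkey-instance", project="dreambooth-keras"):
    """Hypothetical helper: download the latest class-image artifact in a new run."""
    import wandb  # deferred import: only needed when actually talking to W&B

    run = wandb.init(project=project, job_type="train")
    artifact = run.use_artifact(f"{artifact_name}:latest", type="dataset")
    directory = artifact.download()  # local path of the downloaded files
    return run, list_class_images(directory)


# Offline check of the listing helper on a throwaway directory.
with tempfile.TemporaryDirectory() as tmp:
    for name in ("b.jpg", "a.jpg", "notes.txt"):
        open(os.path.join(tmp, name), "w").close()
    found = [os.path.basename(p) for p in list_class_images(tmp)]
    print(found)  # ['a.jpg', 'b.jpg']
```

The run is returned alongside the paths so the caller can call `run.finish()` once training is done.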