## Setup


In [None]:
!pip install git+https: // github.com/keras-team/keras-cv -q

In [None]:
import time
import keras_cv
from tensorflow import keras
import matplotlib.pyplot as plt

# Report which keras_cv build is active in this runtime.
print(str(keras_cv.__version__))


def plot_images(images, title=""):
    """Draw a batch of images side by side in one row of subplots.

    Args:
        images: Sequence of images, each passed to ``plt.imshow`` (in this
            notebook, the result of ``model.text_to_image``).
        title: Caption drawn above every subplot. Defaults to an empty
            string so a call that omits it still plots instead of raising
            ``TypeError``.
    """
    plt.figure(figsize=(20, 20))
    count = len(images)
    for index, image in enumerate(images):
        # Subplot indices are 1-based: one row, `count` columns.
        plt.subplot(1, count, index + 1)
        plt.title(title)
        plt.imshow(image)
        plt.axis("off")

## Prepare the models

In [None]:
# Baseline model: no mixed-precision policy set yet and no XLA; 512x512 output.
model = keras_cv.models.StableDiffusion(
    img_height=512,
    img_width=512,
)

## 이미지 생성

In [None]:
# Short, plain prompt -> three samples.
simple_prompt = "photograph of an astronaut riding a horse"
images = model.text_to_image(simple_prompt, batch_size=3)
plot_images(images, "간단한 프롬프트")

In [None]:
# Long prompt built from many style keywords (adjacent literals are joined).
detailed_prompt = (
    "cute magical flying dog, fantasy art, "
    "golden color, high quality, highly detailed, elegant, sharp focus, "
    "concept art, character concepts, digital painting, mystery, adventure"
)
images = model.text_to_image(detailed_prompt, batch_size=3)
plot_images(images, "복잡한 프롬프트")

## 벤치마킹

### 기본 모델

In [None]:
# Baseline benchmark: the model built in "Prepare the models" (no XLA,
# no mixed-precision policy set).
benchmark_result = []  # rows of [configuration name, seconds]
# perf_counter is monotonic and high-resolution, unlike time.time().
start = time.perf_counter()
images = model.text_to_image(
    "A cute otter in a rainbow whirlpool holding shells, watercolor",
    batch_size=3,
)
end = time.perf_counter()
elapsed = end - start
benchmark_result.append(["Standard", elapsed])
# Give the figure a title, like the other benchmark sections do.
plot_images(images, "Standard")

print(f"Standard model: {elapsed:.2f} seconds")
keras.backend.clear_session()  # Clear session to preserve memory.

print("Compute dtype:", model.diffusion_model.compute_dtype)
print(
    "Variable dtype:",
    model.diffusion_model.variable_dtype,
)

### 혼합 정밀도(Mixed precision)

In [None]:
# Set the global policy before building, then print the dtypes the diffusion
# model ended up with.
keras.mixed_precision.set_global_policy("mixed_float16")
model = keras_cv.models.StableDiffusion()
for label, dtype in (
    ("Compute dtype:", model.diffusion_model.compute_dtype),
    ("Variable dtype:", model.diffusion_model.variable_dtype),
):
    print(label, dtype)

In [None]:
# Warm up model to run graph tracing before benchmarking (result discarded).
model.text_to_image("warming up the model", batch_size=3)

# perf_counter is monotonic and high-resolution, unlike time.time().
start = time.perf_counter()
images = model.text_to_image(
    "a cute magical flying dog, fantasy art, "
    "golden color, high quality, highly detailed, elegant, sharp focus, "
    "concept art, character concepts, digital painting, mystery, adventure",
    batch_size=3,
)
end = time.perf_counter()
elapsed = end - start
benchmark_result.append(["Mixed Precision", elapsed])
plot_images(images, "Mixed Precision")

print(f"Mixed precision model: {elapsed:.2f} seconds")
keras.backend.clear_session()


### XLA Compilation

`jit_compile=True`일 때 XLA 컴파일이 활성화됩니다.

In [None]:
# Restore the float32 policy so this section measures XLA on its own.
keras.mixed_precision.set_global_policy("float32")
model = keras_cv.models.StableDiffusion(jit_compile=True)

# Untimed first call: ensures the TensorFlow graph is traced before the
# benchmark cell below runs.
warmup_prompt = "An avocado armchair"
images = model.text_to_image(warmup_prompt, batch_size=3)
plot_images(images, "XLA Compilation")

In [None]:
# Timed XLA run; the warm-up call in the previous cell already traced the graph.
# perf_counter is monotonic and high-resolution, unlike time.time().
start = time.perf_counter()
images = model.text_to_image(
    "A cute otter in a rainbow whirlpool holding shells, watercolor",
    batch_size=3,
)
end = time.perf_counter()
elapsed = end - start
benchmark_result.append(["XLA", elapsed])
plot_images(images, "XLA Compilation")

print(f"With XLA: {elapsed:.2f} seconds")
keras.backend.clear_session()

### XLA + 혼합 정밀도

In [None]:
# XLA and mixed precision together: set the policy first, then build the
# model with jit_compile=True.
keras.mixed_precision.set_global_policy("mixed_float16")
model = keras_cv.models.StableDiffusion(jit_compile=True)

# Untimed warm-up call so tracing is not counted in the benchmark cell.
warmup_prompt = "Teddy bears conducting machine learning research"
images = model.text_to_image(warmup_prompt, batch_size=3)
plot_images(images, "XLA + Mixed Precision")

In [None]:
# Timed XLA + mixed-precision run (the model was warmed up in the previous cell).
# perf_counter is monotonic and high-resolution, unlike time.time().
start = time.perf_counter()
images = model.text_to_image(
    "A mysterious dark stranger visits the great pyramids of egypt, "
    "high quality, highly detailed, elegant, sharp focus, "
    "concept art, character concepts, digital painting",
    batch_size=3,
)
end = time.perf_counter()
elapsed = end - start
benchmark_result.append(["XLA + Mixed Precision", elapsed])
plot_images(images, "XLA + Mixed Precision")

print(f"XLA + mixed precision: {elapsed:.2f} seconds")

### 결과 확인

In [None]:
# Summary table: one row per configuration in benchmark_result.
print(f"{'Model':<20} {'Runtime':<20}")
for name, runtime in benchmark_result:
    # Seconds with two decimals, matching the per-section "... seconds" prints.
    print(f"{name:<20} {runtime:<20.2f}")