Skip to content

Commit

Permalink
Migrate StableDiffusion to Keras Core (keras-team#1982)
Browse files Browse the repository at this point in the history
* Migrate StableDiffusion to Keras Core

* Runs but wrong answer

* format

* Add tests for text encoder

* Fix softmax op

* Use arange for pos ids

* Simplify seed handling

* Fix predict call for text encoder

* Avoid use of `rank` attribute

* Fix dtype of pos ids

* Update internal layers to Keras Core

* Flatten nested folders to match other models

* Move `predict_on_batch` to dict inputs

* Fixed import for image_encoder

* Fix `None` tensor issue

* Working in JAX+Torch

* Attempt to fix GPU issue for Torch

* Another Torch GPU fix attempt

* Increase tolerance for image encoder golden value test

* Add e2e test to GCB (tf only)

* Add stable diffusion to GCB unit tests

* Fix pytest flags and create multiframework test

* format

* Fix pytest flags

* lint

---------

Co-authored-by: ianjjohnson <3072903+ianstenbit@users.noreply.github.com>
  • Loading branch information
jbischof and ianstenbit committed Sep 1, 2023
1 parent e83f229 commit 9c18f56
Show file tree
Hide file tree
Showing 14 changed files with 214 additions and 184 deletions.
1 change: 1 addition & 0 deletions cloudbuild/unit_test_jobs.jsonnet
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ local unittest = base.BaseTest {
'keras_cv/models/object_detection/yolo_v8',
'keras_cv/models/object_detection_3d',
'keras_cv/models/segmentation',
'keras_cv/models/stable_diffusion',
],
};

Expand Down
13 changes: 12 additions & 1 deletion keras_cv/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import keras_core
import pytest
import tensorflow as tf
from packaging import version
Expand Down Expand Up @@ -44,7 +45,11 @@ def pytest_configure(config):
)
config.addinivalue_line(
"markers",
"tf_keras_only: mark test as a tf only test",
"tf_keras_only: mark test as a tf.keras-only test",
)
config.addinivalue_line(
"markers",
"tf_only: mark test as a Tensorflow-only test",
)


Expand All @@ -68,6 +73,10 @@ def pytest_collection_modifyitems(config, items):
multi_backend(),
reason="This test is only supported on tf.keras",
)
skip_tf_only = pytest.mark.skipif(
multi_backend() and keras_core.backend.backend() != "tensorflow",
reason="This test is only supported on TensorFlow",
)
for item in items:
if "keras_format" in item.name:
item.add_marker(skip_keras_saving_test)
Expand All @@ -79,3 +88,5 @@ def pytest_collection_modifyitems(config, items):
item.add_marker(skip_extra_large)
if "tf_keras_only" in item.keywords:
item.add_marker(skip_tf_keras_only)
if "tf_only" in item.keywords:
item.add_marker(skip_tf_only)
13 changes: 0 additions & 13 deletions keras_cv/models/stable_diffusion/__internal__/layers/__init__.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
from tensorflow import keras

from keras_cv.models.stable_diffusion.__internal__.layers.padded_conv2d import (
PaddedConv2D,
)
from keras_cv.backend import keras
from keras_cv.backend import ops
from keras_cv.models.stable_diffusion.padded_conv2d import PaddedConv2D


class AttentionBlock(keras.layers.Layer):
Expand All @@ -35,20 +32,20 @@ def call(self, inputs):
q, k, v = self.q(x), self.k(x), self.v(x)

# Compute attention
shape = tf.shape(q)
shape = ops.shape(q)
h, w, c = shape[1], shape[2], shape[3]
q = tf.reshape(q, (-1, h * w, c)) # b, hw, c
k = tf.transpose(k, (0, 3, 1, 2))
k = tf.reshape(k, (-1, c, h * w)) # b, c, hw
q = ops.reshape(q, (-1, h * w, c)) # b, hw, c
k = ops.transpose(k, (0, 3, 1, 2))
k = ops.reshape(k, (-1, c, h * w)) # b, c, hw
y = q @ k
y = y * 1 / tf.sqrt(tf.cast(c, self.compute_dtype))
y = y * 1 / ops.sqrt(ops.cast(c, self.compute_dtype))
y = keras.activations.softmax(y)

# Attend to values
v = tf.transpose(v, (0, 3, 1, 2))
v = tf.reshape(v, (-1, c, h * w))
y = tf.transpose(y, (0, 2, 1))
v = ops.transpose(v, (0, 3, 1, 2))
v = ops.reshape(v, (-1, c, h * w))
y = ops.transpose(y, (0, 2, 1))
x = v @ y
x = tf.transpose(x, (0, 2, 1))
x = tf.reshape(x, (-1, h, w, c))
x = ops.transpose(x, (0, 2, 1))
x = ops.reshape(x, (-1, h, w, c))
return self.proj_out(x) + inputs
3 changes: 2 additions & 1 deletion keras_cv/models/stable_diffusion/clip_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from functools import lru_cache

import regex as re
from tensorflow import keras

from keras_cv.backend import keras


@lru_cache()
Expand Down
13 changes: 4 additions & 9 deletions keras_cv/models/stable_diffusion/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from tensorflow import keras

from keras_cv.models.stable_diffusion.__internal__.layers.attention_block import ( # noqa: E501
from keras_cv.backend import keras
from keras_cv.models.stable_diffusion.attention_block import ( # noqa: E501
AttentionBlock,
)
from keras_cv.models.stable_diffusion.__internal__.layers.padded_conv2d import (
PaddedConv2D,
)
from keras_cv.models.stable_diffusion.__internal__.layers.resnet_block import (
ResnetBlock,
)
from keras_cv.models.stable_diffusion.padded_conv2d import PaddedConv2D
from keras_cv.models.stable_diffusion.resnet_block import ResnetBlock


class Decoder(keras.Sequential):
Expand Down
65 changes: 34 additions & 31 deletions keras_cv/models/stable_diffusion/diffusion_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
from tensorflow import keras

from keras_cv.models.stable_diffusion.__internal__.layers.padded_conv2d import (
PaddedConv2D,
)
from keras_cv.backend import keras
from keras_cv.backend import ops
from keras_cv.models.stable_diffusion.padded_conv2d import PaddedConv2D


class DiffusionModel(keras.Model):
Expand All @@ -29,9 +26,11 @@ def __init__(
name=None,
download_weights=True,
):
context = keras.layers.Input((max_text_length, 768))
t_embed_input = keras.layers.Input((320,))
latent = keras.layers.Input((img_height // 8, img_width // 8, 4))
context = keras.layers.Input((max_text_length, 768), name="context")
t_embed_input = keras.layers.Input((320,), name="timestep_embedding")
latent = keras.layers.Input(
(img_height // 8, img_width // 8, 4), name="latent"
)

t_emb = keras.layers.Dense(1280)(t_embed_input)
t_emb = keras.layers.Activation("swish")(t_emb)
Expand Down Expand Up @@ -123,9 +122,11 @@ def __init__(
name=None,
download_weights=True,
):
context = keras.layers.Input((max_text_length, 1024))
t_embed_input = keras.layers.Input((320,))
latent = keras.layers.Input((img_height // 8, img_width // 8, 4))
context = keras.layers.Input((max_text_length, 1024), name="context")
t_embed_input = keras.layers.Input((320,), name="timestep_embedding")
latent = keras.layers.Input(
(img_height // 8, img_width // 8, 4), name="latent"
)

t_emb = keras.layers.Dense(1280)(t_embed_input)
t_emb = keras.layers.Activation("swish")(t_emb)
Expand Down Expand Up @@ -268,9 +269,9 @@ def call(self, inputs):
_, h, w, c = inputs.shape
x = self.norm(inputs)
x = self.proj1(x)
x = tf.reshape(x, (-1, h * w, c))
x = ops.reshape(x, (-1, h * w, c))
x = self.transformer_block([x, context])
x = tf.reshape(x, (-1, h, w, c))
x = ops.reshape(x, (-1, h, w, c))
return self.proj2(x) + inputs


Expand All @@ -287,8 +288,8 @@ def __init__(self, dim, num_heads, head_size, **kwargs):

def call(self, inputs):
inputs, context = inputs
x = self.attn1([self.norm1(inputs), None]) + inputs
x = self.attn2([self.norm2(x), context]) + x
x = self.attn1(self.norm1(inputs), context=None) + inputs
x = self.attn2(self.norm2(x), context=context) + x
return self.dense(self.geglu(self.norm3(x))) + x


Expand All @@ -303,31 +304,33 @@ def __init__(self, num_heads, head_size, **kwargs):
self.head_size = head_size
self.out_proj = keras.layers.Dense(num_heads * head_size)

def call(self, inputs):
inputs, context = inputs
context = inputs if context is None else context
def call(self, inputs, context=None):
if context is None:
context = inputs
q, k, v = self.to_q(inputs), self.to_k(context), self.to_v(context)
q = tf.reshape(q, (-1, inputs.shape[1], self.num_heads, self.head_size))
k = tf.reshape(
q = ops.reshape(
q, (-1, inputs.shape[1], self.num_heads, self.head_size)
)
k = ops.reshape(
k, (-1, context.shape[1], self.num_heads, self.head_size)
)
v = tf.reshape(
v = ops.reshape(
v, (-1, context.shape[1], self.num_heads, self.head_size)
)

q = tf.transpose(q, (0, 2, 1, 3)) # (bs, num_heads, time, head_size)
k = tf.transpose(k, (0, 2, 3, 1)) # (bs, num_heads, head_size, time)
v = tf.transpose(v, (0, 2, 1, 3)) # (bs, num_heads, time, head_size)
q = ops.transpose(q, (0, 2, 1, 3)) # (bs, num_heads, time, head_size)
k = ops.transpose(k, (0, 2, 3, 1)) # (bs, num_heads, head_size, time)
v = ops.transpose(v, (0, 2, 1, 3)) # (bs, num_heads, time, head_size)

score = td_dot(q, k) * self.scale
weights = keras.activations.softmax(
score
) # (bs, num_heads, time, time)
attn = td_dot(weights, v)
attn = tf.transpose(
attn = ops.transpose(
attn, (0, 2, 1, 3)
) # (bs, time, num_heads, head_size)
out = tf.reshape(
out = ops.reshape(
attn, (-1, inputs.shape[1], self.num_heads * self.head_size)
)
return self.out_proj(out)
Expand Down Expand Up @@ -359,7 +362,7 @@ def call(self, inputs):


def td_dot(a, b):
aa = tf.reshape(a, (-1, a.shape[2], a.shape[3]))
bb = tf.reshape(b, (-1, b.shape[2], b.shape[3]))
cc = keras.backend.batch_dot(aa, bb)
return tf.reshape(cc, (-1, a.shape[1], cc.shape[1], cc.shape[2]))
aa = ops.reshape(a, (-1, a.shape[2], a.shape[3]))
bb = ops.reshape(b, (-1, b.shape[2], b.shape[3]))
cc = keras.layers.Dot(axes=(2, 1))([aa, bb])
return ops.reshape(cc, (-1, a.shape[1], cc.shape[1], cc.shape[2]))
13 changes: 4 additions & 9 deletions keras_cv/models/stable_diffusion/image_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from tensorflow import keras

from keras_cv.models.stable_diffusion.__internal__.layers.attention_block import ( # noqa: E501
from keras_cv.backend import keras
from keras_cv.models.stable_diffusion.attention_block import ( # noqa: E501
AttentionBlock,
)
from keras_cv.models.stable_diffusion.__internal__.layers.padded_conv2d import (
PaddedConv2D,
)
from keras_cv.models.stable_diffusion.__internal__.layers.resnet_block import (
ResnetBlock,
)
from keras_cv.models.stable_diffusion.padded_conv2d import PaddedConv2D
from keras_cv.models.stable_diffusion.resnet_block import ResnetBlock


class ImageEncoder(keras.Sequential):
Expand Down

0 comments on commit 9c18f56

Please sign in to comment.