In [5]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt

num_filters_encoder = [128, 256, 512, 512, 512]
num_filters_decoder = [512, 512, 512, 256, 128]
latent_dim = 1024

(x_train, _), (x_test, _) = keras.datasets.cifar10.load_data()
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

print("x_train.shape =", x_train.shape)
print("x_test.shape  =", x_test.shape)

########################################
# 2. Побудова енкодера
########################################
encoder_inputs = keras.Input(shape=(32, 32, 3))

x = encoder_inputs
for filters in num_filters_encoder:
    x = layers.Conv2D(
        filters, kernel_size=3, strides=2, padding="same", use_bias=False
    )(x)
    x = layers.LayerNormalization()(x)  # Changed from InstanceNormalization
    x = layers.LeakyReLU(alpha=0.2)(x)
    x = layers.Dropout(0.3)(x)

x = layers.Flatten()(x)
x = layers.Dense(1024, activation='relu')(x)

mu = layers.Dense(latent_dim, name='mu')(x)
logvar = layers.Dense(latent_dim, name='logvar')(x)

def sampling(args):
    mu_, logvar_ = args
    epsilon = tf.random.normal(shape=tf.shape(mu_))
    return mu_ + tf.exp(0.5 * logvar_) * epsilon

z = layers.Lambda(sampling, name='z')([mu, logvar])

encoder = keras.Model(encoder_inputs, [mu, logvar, z], name='encoder')
encoder.summary()

########################################
# 3. Побудова декодера
########################################
latent_inputs = keras.Input(shape=(latent_dim,))

x = layers.Dense(1 * 1 * num_filters_decoder[0], activation='relu')(latent_inputs)
x = layers.Reshape((1, 1, num_filters_decoder[0]))(x)

for filters in num_filters_decoder:
    x = layers.Conv2DTranspose(
        filters, kernel_size=3, strides=2, padding="same", use_bias=False
    )(x)
    x = layers.LayerNormalization()(x)  # Changed from InstanceNormalization
    x = layers.LeakyReLU(alpha=0.2)(x)
    x = layers.Dropout(0.3)(x)

decoder_outputs = layers.Conv2DTranspose(3, kernel_size=3, padding='same', activation='sigmoid')(x)

decoder = keras.Model(latent_inputs, decoder_outputs, name='decoder')
decoder.summary()

class VAE(keras.Model):
    def __init__(self, encoder, decoder, alpha=0.5, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.alpha = alpha

        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def call(self, inputs, training=None, mask=None):
        mu_, logvar_, z_ = self.encoder(inputs)
        return self.decoder(z_)

    def train_step(self, data):
        if isinstance(data, tuple):
            data = data[0]

        with tf.GradientTape() as tape:
            mu_, logvar_, z_ = self.encoder(data, training=True)
            reconstruction = self.decoder(z_, training=True)

            reconstruction_loss_mse = tf.reduce_mean(
                tf.reduce_sum(
                    tf.keras.losses.mean_squared_error(data, reconstruction),
                    axis=(1,2)
                )
            )

            reconstruction_loss_l1 = tf.reduce_mean(
                tf.reduce_sum(
                    tf.abs(data - reconstruction),
                    axis=(1,2,3)
                )
            )

            reconstruction_loss = (1.0 - self.alpha)*reconstruction_loss_mse + self.alpha*reconstruction_loss_l1

            kl_loss = 1 + logvar_ - tf.square(mu_) - tf.exp(logvar_)
            kl_loss = -0.5 * tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))

            total_loss = reconstruction_loss + kl_loss

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)

        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

vae_model = VAE(encoder, decoder)

initial_learning_rate = 1e-4
lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=initial_learning_rate,
    decay_steps=5000
)

optimizer = keras.optimizers.AdamW(learning_rate=lr_schedule, weight_decay=1e-4)

vae_model.compile(optimizer=optimizer, loss=lambda y_true, y_pred: 0.0)

class GenerateImagesCallback(keras.callbacks.Callback):
    def __init__(self, model, latent_dim=512, interval=10, n_images=5):
        super().__init__()
        self._model = model
        self.latent_dim = latent_dim
        self.interval = interval
        self.n_images = n_images

    def on_epoch_end(self, epoch, logs=None):
        if (epoch + 1) % self.interval == 0:
            z_random = tf.random.normal(shape=(self.n_images, self.latent_dim))
            generated_images = self._model.decoder(z_random)

            print(f"\n[Callback] Згенеровані зображення після {epoch+1} епох:")
            plt.figure(figsize=(15,3))
            for i in range(self.n_images):
                ax = plt.subplot(1, self.n_images, i+1)
                plt.imshow(generated_images[i])
                plt.axis("off")
            plt.suptitle(f"Epoch {epoch+1}")
            plt.show()

history = vae_model.fit(
    x_train, x_train,
    epochs=100,            
    batch_size=1024,
    validation_data=(x_test, x_test),
    callbacks=[GenerateImagesCallback(vae_model, latent_dim=latent_dim, interval=10, n_images=5)]
)

n_to_generate = 10
z_sample = tf.random.normal(shape=(n_to_generate, latent_dim))
generated_images = vae_model.decoder(z_sample)

def plot_images(images, n=10):
    plt.figure(figsize=(20, 4))
    for i in range(n):
        ax = plt.subplot(2, n, i + 1)
        plt.imshow(images[i])
        plt.axis("off")
    plt.show()

plot_images(generated_images.numpy(), n=n_to_generate)

x_test_subset = x_test[:10]
mu, logvar, z = encoder(x_test_subset)
x_test_reconstructed = decoder(z)

def plot_original_vs_recon(original, reconstructed):
    n = len(original)
    plt.figure(figsize=(20, 4))
    for i in range(n):
        ax = plt.subplot(2, n, i + 1)
        plt.imshow(original[i])
        plt.title("Original")
        plt.axis("off")

        ax = plt.subplot(2, n, n + i + 1)
        plt.imshow(reconstructed[i])
        plt.title("Reconstructed")
        plt.axis("off")
    plt.show()

plot_original_vs_recon(x_test_subset, x_test_reconstructed.numpy())

x_train.shape = (50000, 32, 32, 3)
x_test.shape  = (10000, 32, 32, 3)




Epoch 1/100


AttributeError: module 'keras._tf_keras.keras.losses' has no attribute 'mean_squared_error'