In [None]:
import tensorflow as tf
import numpy as np
from tensorflow import keras
from keras import layers, models
from transformers import ViTFeatureExtractor, TFViTModel
from PIL import Image
from datetime import datetime
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import cv2
import pickle
import re, math
import shutil
import random
import tensorflow_addons as tfa

import pandas as pd

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [2]:
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Currently, memory growth needs to be the same across GPUs
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        # Select GPU number 1
        tf.config.experimental.set_visible_devices(gpus[0], 'GPU')
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized
        print(e)

1 Physical GPUs, 1 Logical GPUs


In [3]:
from transformers import CLIPProcessor, TFCLIPModel, CLIPTokenizer

# Load the CLIP model and processor
model_name = "openai/clip-vit-base-patch32"
clip_model = TFCLIPModel.from_pretrained(model_name)
processor = CLIPProcessor.from_pretrained(model_name)
tokenizer = CLIPTokenizer.from_pretrained(model_name)

All model checkpoint layers were used when initializing TFCLIPModel.

All the layers of TFCLIPModel were initialized from the model checkpoint at openai/clip-vit-base-patch32.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFCLIPModel for predictions without further training.


## Set up dataset

In [5]:
# File paths
caption_file_path = "./2024-datalab-cup3-reverse-image-caption/dataset/text2ImgData.pkl"
id2word_file_path = "./2024-datalab-cup3-reverse-image-caption/dictionary/id2Word.npy"

# Load the .pkl file
with open(caption_file_path, 'rb') as file:
    captions_data = pickle.load(file)

# Convert to DataFrame
captions_df = pd.DataFrame(captions_data)

# Load the id2Word mapping
data = np.load(id2word_file_path)

# Convert array to dictionary
id2word = {int(row[0]): row[1] for row in data}


In [6]:
# Batch size for processing
BATCH_SIZE = 1024  # Adjust based on your GPU memory

# Prepare captions
all_captions = []
all_image_paths = []

for index, row in tqdm(captions_df.iterrows(), total=captions_df.shape[0], desc="Preparing captions"):
    image_path = row['ImagePath']
    captions = row['Captions']
    all_image_paths.extend([image_path] * len(captions))
    all_captions.extend(captions)

# Function to process text embeddings
def batch_text_embeddings(captions_batch):
    texts = [" ".join([id2word.get(int(word_id), "<UNK>") for word_id in caption]) for caption in captions_batch]
    text_inputs = tokenizer(texts, return_tensors="tf", padding=True, truncation=True)
    text_embeddings = clip_model.get_text_features(input_ids=text_inputs["input_ids"])
    text_embeddings = tf.nn.l2_normalize(text_embeddings, axis=-1)
    return text_embeddings

print(all_image_paths)
# Batched processing
image_text_pairs = []
for i in tqdm(range(0, len(all_captions), BATCH_SIZE), desc="Processing batches"):
    batch_captions = all_captions[i:i + BATCH_SIZE]
    batch_image_paths = all_image_paths[i:i + BATCH_SIZE]
    batch_embeddings = batch_text_embeddings(batch_captions)
    image_text_pairs.extend(zip(batch_image_paths, batch_embeddings.numpy()))

# Create TensorFlow Dataset
dataset = tf.data.Dataset.from_generator(
    lambda: iter(image_text_pairs),
    output_signature=(
        tf.TensorSpec(shape=(), dtype=tf.string),  # Image path
        tf.TensorSpec(shape=(512,), dtype=tf.float32)  # Text embeddings
    )
)

# Preview the dataset
for image_path, text_embedding in dataset.take(2):
    print("Image Path:", image_path.numpy().decode())
    print("Text Embedding Shape:", text_embedding.shape)

Preparing captions:   0%|          | 0/7370 [00:00<?, ?it/s]

['./102flowers/image_06734.jpg', './102flowers/image_06734.jpg', './102flowers/image_06734.jpg', './102flowers/image_06734.jpg', './102flowers/image_06734.jpg', './102flowers/image_06734.jpg', './102flowers/image_06734.jpg', './102flowers/image_06734.jpg', './102flowers/image_06734.jpg', './102flowers/image_06736.jpg', './102flowers/image_06736.jpg', './102flowers/image_06736.jpg', './102flowers/image_06736.jpg', './102flowers/image_06736.jpg', './102flowers/image_06736.jpg', './102flowers/image_06736.jpg', './102flowers/image_06736.jpg', './102flowers/image_06736.jpg', './102flowers/image_06736.jpg', './102flowers/image_06737.jpg', './102flowers/image_06737.jpg', './102flowers/image_06737.jpg', './102flowers/image_06737.jpg', './102flowers/image_06737.jpg', './102flowers/image_06737.jpg', './102flowers/image_06737.jpg', './102flowers/image_06737.jpg', './102flowers/image_06737.jpg', './102flowers/image_06738.jpg', './102flowers/image_06738.jpg', './102flowers/image_06738.jpg', './102f

Processing batches:   0%|          | 0/69 [00:00<?, ?it/s]

Image Path: ./102flowers/image_06734.jpg
Text Embedding Shape: (512,)
Image Path: ./102flowers/image_06734.jpg
Text Embedding Shape: (512,)


In [7]:
output_file = "./2024-datalab-cup3-reverse-image-caption/image_text_pairs.pkl"

with open(output_file, "wb") as f:
    pickle.dump(image_text_pairs, f)

print(f"Image-text pairs saved to {output_file}")

Image-text pairs saved to ./2024-datalab-cup3-reverse-image-caption/image_text_pairs.pkl


### Load from pkl

In [10]:
import pickle
import tensorflow as tf

# Path to the stored file
input_file = "./2024-datalab-cup3-reverse-image-caption/dataset/image_text_pairs.pkl"

# Load the image-text pairs
with open(input_file, "rb") as f:
    image_text_pairs = pickle.load(f)

print(f"Loaded {len(image_text_pairs)} image-text pairs.")

# Create a TensorFlow Dataset from the loaded data
dataset = tf.data.Dataset.from_generator(
    lambda: iter(image_text_pairs),
    output_signature=(
        tf.TensorSpec(shape=(), dtype=tf.string),  # Image path
        tf.TensorSpec(shape=(512,), dtype=tf.float32)  # Text embeddings
    )
)

# Preview the dataset
for image_path, text_embedding in dataset.take(3):
    print("Image Path:", image_path.numpy().decode())
    print("Text Embedding Shape:", text_embedding.shape)


Loaded 70504 image-text pairs.
Image Path: ./102flowers/image_06734.jpg
Text Embedding Shape: (512,)
Image Path: ./102flowers/image_06734.jpg
Text Embedding Shape: (512,)
Image Path: ./102flowers/image_06734.jpg
Text Embedding Shape: (512,)


In [8]:
# Feel free to change these parameters according to your system's configuration
# data
dataset_name = "oxford_flowers102"
dataset_repetitions = 1
num_epochs = 50
image_size = 224
suffle_times = 10

# KID = Kernel Inception Distance, see related section
kid_image_size = 75
kid_diffusion_steps = 20
plot_diffusion_steps = 50

# sampling
min_signal_rate = 0.02
max_signal_rate = 0.95

# architecture
embedding_dims = 32
embedding_max_frequency = 1000.0
widths = [32, 64, 96, 128]
block_depth = 2

# optimization
batch_size = 32
ema = 0.999
learning_rate = 1e-3
weight_decay = 1e-4

In [9]:
from sklearn.model_selection import train_test_split

# Train-validation split
train_pairs, val_pairs = train_test_split(image_text_pairs, test_size=0.05, random_state=42)

# Convert to TensorFlow datasets
def create_dataset(pairs):
    return tf.data.Dataset.from_generator(
        lambda: iter(pairs),
        output_signature=(
            tf.TensorSpec(shape=(), dtype=tf.string),  # Image path
            tf.TensorSpec(shape=(512,), dtype=tf.float32)   # Caption
        )
    )

train_dataset = create_dataset(train_pairs)
val_dataset = create_dataset(val_pairs)

# Preprocess function with augmentation for training
def preprocess_image_with_augmentation(image_path, caption):
    # Load and preprocess image
    image = tf.io.read_file(image_path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, [224, 224]) / 255.0  # Resize and normalize

    # Data augmentation
    image = tf.image.random_flip_left_right(image)  # Random horizontal flip
    image = tf.image.random_flip_up_down(image)  # Random vertical flip
    image = tf.image.random_brightness(image, max_delta=0.1)  # Random brightness

    return image, caption

# Preprocess function for validation (no augmentation)
def preprocess_image_without_augmentation(image_path, caption):
    image = tf.io.read_file(image_path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, [224, 224]) / 255.0  # Resize and normalize
    return image, caption

# Apply preprocessing
train_dataset = (
    train_dataset
    .map(preprocess_image_with_augmentation, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(buffer_size=128)
    .batch(batch_size, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)

val_dataset = (
    val_dataset
    .map(preprocess_image_without_augmentation, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(buffer_size=128)
    .batch(batch_size, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)

# Preview the training dataset
for images, captions in train_dataset.take(1):
    print("Image Batch Shape:", images.shape)
    print("Caption Batch:", captions)


Image Batch Shape: (32, 224, 224, 3)
Caption Batch: tf.Tensor(
[[-4.0519927e-02 -3.6175391e-03  4.7331288e-02 ...  3.2521818e-02
   1.0116545e-02  6.4305763e-04]
 [-1.4634279e-02 -1.1754892e-02  2.8505048e-03 ... -6.4958604e-03
   3.9020576e-02  1.7932557e-02]
 [-1.7231159e-02  3.9212320e-02  3.3312727e-02 ...  1.5258783e-03
  -8.7921387e-03  1.6692292e-02]
 ...
 [-2.1876303e-02  7.4624244e-05  2.6791889e-02 ...  9.0900036e-03
   2.6402872e-02 -4.8067416e-03]
 [-4.4363417e-02  3.3815790e-02  3.7432138e-02 ...  3.8585838e-02
   1.0856384e-02 -6.3694427e-03]
 [-3.5952233e-02  1.5792623e-02  2.4134180e-02 ...  2.6423601e-02
   4.5691095e-02  3.5487693e-02]], shape=(32, 512), dtype=float32)


### Toy dataset

In [10]:
# Create a toy training dataset with the first 2 batches
toy_train_dataset = (
    train_dataset.take(2)  # Take the first 2 batches
)

# Preview the toy dataset
for images, captions in toy_train_dataset:
    print("Image Batch Shape:", images.shape)
    print("Caption Batch Shape:", captions.shape)


Image Batch Shape: (32, 224, 224, 3)
Caption Batch Shape: (32, 512)
Image Batch Shape: (32, 224, 224, 3)
Caption Batch Shape: (32, 512)


# Model

In [11]:
class KID(keras.metrics.Metric):
    def __init__(self, name, **kwargs):
        super().__init__(name=name, **kwargs)

        # KID is estimated per batch and is averaged across batches
        self.kid_tracker = keras.metrics.Mean(name="kid_tracker")

        # a pretrained InceptionV3 is used without its classification layer
        # transform the pixel values to the 0-255 range, then use the same
        # preprocessing as during pretraining
        self.encoder = keras.Sequential(
            [
                keras.Input(shape=(image_size, image_size, 3)),
                layers.Rescaling(255.0),
                layers.Resizing(height=kid_image_size, width=kid_image_size),
                layers.Lambda(keras.applications.inception_v3.preprocess_input),
                keras.applications.InceptionV3(
                    include_top=False,
                    input_shape=(kid_image_size, kid_image_size, 3),
                    weights="imagenet",
                ),
                layers.GlobalAveragePooling2D(),
            ],
            name="inception_encoder",
        )

    def polynomial_kernel(self, features_1, features_2):
        # Use TensorFlow functions instead of ops
        feature_dimensions = tf.cast(tf.shape(features_1)[1], dtype="float32")
        return (tf.matmul(features_1, tf.transpose(features_2)) / feature_dimensions + 1.0) ** 3.0

    def update_state(self, real_images, generated_images, sample_weight=None):
        real_features = self.encoder(real_images, training=False)
        generated_features = self.encoder(generated_images, training=False)

        # compute polynomial kernels using the two sets of features
        kernel_real = self.polynomial_kernel(real_features, real_features)
        kernel_generated = self.polynomial_kernel(generated_features, generated_features)
        kernel_cross = self.polynomial_kernel(real_features, generated_features)

        # estimate the squared maximum mean discrepancy using the average kernel values
        batch_size = tf.shape(real_features)[0]
        batch_size_f = tf.cast(batch_size, dtype="float32")
        mean_kernel_real = tf.reduce_sum(kernel_real * (1.0 - tf.eye(batch_size))) / (
            batch_size_f * (batch_size_f - 1.0)
        )
        mean_kernel_generated = tf.reduce_sum(kernel_generated * (1.0 - tf.eye(batch_size))) / (
            batch_size_f * (batch_size_f - 1.0)
        )
        mean_kernel_cross = tf.reduce_mean(kernel_cross)
        kid = mean_kernel_real + mean_kernel_generated - 2.0 * mean_kernel_cross

        # update the average KID estimate
        self.kid_tracker.update_state(kid)

    def result(self):
        return self.kid_tracker.result()

    def reset_state(self):
        self.kid_tracker.reset_state()

In [12]:
def sinusoidal_embedding(x, embedding_dims, embedding_max_frequency=1000.0):
    embedding_min_frequency = 1.0
    frequencies = tf.math.exp(
        tf.linspace(
            tf.math.log(embedding_min_frequency),
            tf.math.log(embedding_max_frequency),
            embedding_dims // 2
        )
    )
    angular_speeds = tf.cast(2.0 * math.pi * frequencies, "float32")
    embeddings = tf.concat(
        [tf.sin(angular_speeds * x), tf.cos(angular_speeds * x)], axis=-1
    )
    return embeddings  # Shape: (batch_size, 1, 1, embedding_dims)



def ResidualBlock(width, context=None):
    def apply(x):
        residual = x
        # Adjust residual channels if necessary
        if residual.shape[-1] != width:
            residual = layers.Conv2D(width, kernel_size=1, padding='same')(residual)
        x = layers.BatchNormalization(center=False, scale=False)(x)
        x = layers.Activation('swish')(x)
        x = layers.Conv2D(width, kernel_size=3, padding="same")(x)
        if context is not None:
            x = CrossAttention(embed_dim=width)(x, context)
        x = layers.BatchNormalization(center=False, scale=False)(x)
        x = layers.Activation('swish')(x)
        x = layers.Conv2D(width, kernel_size=3, padding="same")(x)
        x = layers.Add()([x, residual])
        return x

    return apply


def DownBlock(width, block_depth):
    def apply(x):
        x, skips = x
        for _ in range(block_depth):
            x = ResidualBlock(width)(x)
            skips.append(x)
        x = layers.AveragePooling2D(pool_size=2)(x)
        return x

    return apply


def UpBlock(width, block_depth):
    def apply(x):
        x, skips = x
        x = layers.UpSampling2D(size=2, interpolation="bilinear")(x)
        for _ in range(block_depth):
            x = layers.Concatenate()([x, skips.pop()])
            x = ResidualBlock(width)(x)
        return x

    return apply

class CrossAttention(layers.Layer):
    def __init__(self, embed_dim, num_heads=4, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.attention = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim // num_heads
        )
        self.context_proj = layers.Dense(embed_dim)

    def call(self, x, context):
        # x: (batch_size, height, width, channels)
        batch_size = tf.shape(x)[0]
        height = tf.shape(x)[1]
        width = tf.shape(x)[2]
        channels = x.shape[-1]
        x_reshaped = tf.reshape(x, (batch_size, height * width, channels))
        
        # Project context to match embed_dim
        context_proj = self.context_proj(context)  # (batch_size, embed_dim)
        context_proj = tf.expand_dims(context_proj, axis=1)  # (batch_size, 1, embed_dim)
        
        # Apply multi-head attention
        attention_output = self.attention(
            query=x_reshaped, key=context_proj, value=context_proj
        )
        
        # Reshape back to (batch_size, height, width, channels)
        attention_output = tf.reshape(attention_output, (batch_size, height, width, channels))
        return attention_output
    
    def get_config(self):
        """Serialization method to save and reload the layer."""
        config = super().get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
        })
        return config


def get_conditional_network(image_size, widths, block_depth, text_embedding_dim, embedding_dims=32):
    noisy_images = keras.Input(shape=(image_size, image_size, 3))
    noise_variances = keras.Input(shape=(1, 1, 1))
    text_conditions = keras.Input(shape=(text_embedding_dim,))  # CLIP embeddings

    # Sinusoidal embedding for noise variances
    e = layers.Lambda(lambda nv: sinusoidal_embedding(nv, embedding_dims))(noise_variances)
    # Broadcast to match image dimensions
    e = tf.tile(e, [1, image_size, image_size, 1])

    # Initial processing
    x = layers.Conv2D(widths[0], kernel_size=1)(noisy_images)
    x = layers.Concatenate()([x, e])

    skips = []
    # Encoder
    for width in widths[:-1]:
        for _ in range(block_depth):
            x = ResidualBlock(width, context=text_conditions)(x)
        skips.append(x)
        x = layers.AveragePooling2D()(x)

    # Bottleneck
    for _ in range(block_depth):
        x = ResidualBlock(widths[-1], context=text_conditions)(x)

    # Decoder
    for width in reversed(widths[:-1]):
        x = layers.UpSampling2D()(x)
        x = layers.Concatenate()([x, skips.pop()])
        for _ in range(block_depth):
            x = ResidualBlock(width, context=text_conditions)(x)

    # Output layer
    x = layers.Conv2D(3, kernel_size=1, kernel_initializer="zeros")(x)

    return keras.Model(
        inputs=[noisy_images, noise_variances, text_conditions],
        outputs=x,
        name="conditional_residual_unet",
    )


In [13]:
class ConditionalDiffusionModel(keras.Model):
    def __init__(self, image_size, widths, block_depth, text_embedding_dim):
        super().__init__()
        
        self.text_embedding_dim = text_embedding_dim
        self.normalizer = layers.Normalization(axis=-1)
        self.network = get_conditional_network(image_size, widths, block_depth, text_embedding_dim)
        self.ema_network = keras.models.clone_model(self.network)

    def compile(self, **kwargs):
        super().compile(**kwargs)
        
        self.noise_loss_tracker = keras.metrics.Mean(name="n_loss")
        self.image_loss_tracker = keras.metrics.Mean(name="i_loss")
        self.kid = KID(name="kid")

    @property
    def metrics(self):
        return [self.noise_loss_tracker, self.image_loss_tracker, self.kid]

    def denormalize(self, images):
        images = self.normalizer.mean + images * self.normalizer.variance**0.5
        return tf.clip_by_value(images, 0.0, 1.0)

    def diffusion_schedule(self, diffusion_times):
        start_angle = tf.acos(max_signal_rate)
        end_angle = tf.acos(min_signal_rate)
        diffusion_angles = start_angle + diffusion_times * (end_angle - start_angle)
        signal_rates = tf.cos(diffusion_angles) # α開根號
        noise_rates = tf.sin(diffusion_angles) # (1-α)開根號
        return noise_rates, signal_rates

    def denoise(self, noisy_images, noise_rates, signal_rates, conditions, training):
        model = self.network if training else self.ema_network
        # Predict noise using the network with conditions
        pred_noises = model([noisy_images, noise_rates ** 2, conditions], training=training)
        pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates
        return pred_noises, pred_images

    def reverse_diffusion(self, initial_noise, diffusion_steps, conditions):
        num_images = initial_noise.shape[0]
        step_size = 1.0 / diffusion_steps
        next_noisy_images = initial_noise
        # 接著一步一步去噪
        for step in range(diffusion_steps):
            noisy_images = next_noisy_images
            diffusion_times = tf.ones((num_images, 1, 1, 1)) - step * step_size
            noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
            pred_noises, pred_images = self.denoise(noisy_images, noise_rates, signal_rates, conditions, training=False)
            next_diffusion_times = diffusion_times - step_size
            next_noise_rates, next_signal_rates = self.diffusion_schedule(next_diffusion_times)
            next_noisy_images = (next_signal_rates * pred_images + next_noise_rates * pred_noises)
        return pred_images

    def generate(self, num_images, diffusion_steps, conditions):
        initial_noise = tf.random.normal(shape=(num_images, image_size, image_size, 3))
        generated_images = self.reverse_diffusion(initial_noise, diffusion_steps, conditions)
        generated_images = self.denormalize(generated_images)
        return generated_images

    def train_step(self, data):
        images, labels = data
        # print(images[0].shape, labels[0].shape)
        images = self.normalizer(images, training=True)
        noises = tf.random.normal(shape=(batch_size, image_size, image_size, 3))
        
        # Generate class embeddings for conditions
        conditions = labels

        diffusion_times = tf.random.uniform(shape=(batch_size, 1, 1, 1), minval=0.0, maxval=1.0)
        noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
        noisy_images = signal_rates * images + noise_rates * noises

        with tf.GradientTape() as tape:
            pred_noises, pred_images = self.denoise(
                noisy_images, noise_rates, signal_rates, conditions, training=True
            )
            noise_loss = self.loss(noises, pred_noises)
            image_loss = self.loss(images, pred_images)

        gradients = tape.gradient(noise_loss, self.network.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.network.trainable_weights))

        self.noise_loss_tracker.update_state(noise_loss)
        self.image_loss_tracker.update_state(image_loss)

        for weight, ema_weight in zip(self.network.weights, self.ema_network.weights):
            ema_weight.assign(ema * ema_weight + (1 - ema) * weight)

        return {m.name: m.result() for m in self.metrics[:-1]}

    def test_step(self, data):
        images, labels = data

        images = self.normalizer(images, training=False)

        noises = tf.random.normal(shape=(batch_size, image_size, image_size, 3))
        conditions = labels

        diffusion_times = tf.random.uniform(shape=(batch_size, 1, 1, 1), minval=0.0, maxval=1.0)
        noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
        noisy_images = signal_rates * images + noise_rates * noises

        pred_noises, pred_images = self.denoise(
            noisy_images, noise_rates, signal_rates, conditions, training=True
        )
        noise_loss = self.loss(noises, pred_noises)
        image_loss = self.loss(images, pred_images)

        self.image_loss_tracker.update_state(image_loss)
        self.noise_loss_tracker.update_state(noise_loss)

        images = self.denormalize(images)
        generated_images = self.generate(batch_size, kid_diffusion_steps, conditions) 
        self.kid.update_state(images, generated_images)

        return {m.name: m.result() for m in self.metrics}
    
    def plot_images(self, epoch=None, logs=None, num_rows=3, num_cols=6, val_dataset=val_dataset):
        # Randomly select a batch from the validation dataset
        for val_images, val_labels in val_dataset.take(1):
            conditions = val_labels  # Use the labels from the validation set as conditions
            break  # Take only the first batch

        # Select a random subset of conditions
        num_images = num_rows * num_cols
        indices = tf.random.shuffle(tf.range(tf.shape(conditions)[0]))[:num_images]
        random_conditions = tf.gather(conditions, indices)

        # Generate images using the random conditions
        generated_images = self.generate(
            num_images=num_images,
            diffusion_steps=plot_diffusion_steps,
            conditions=random_conditions,
        )


        # Plot the generated images
        plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0))
        for row in range(num_rows):
            for col in range(num_cols):
                index = row * num_cols + col
                plt.subplot(num_rows, num_cols, index + 1)
                plt.imshow(generated_images[index])
                plt.axis("off")
        plt.tight_layout()
        plt.show()
        plt.close()

In [14]:
model = ConditionalDiffusionModel(image_size, widths, block_depth, 512)

model.compile(
    optimizer=tfa.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    ),
    loss=keras.losses.mean_absolute_error,
)
# pixelwise mean absolute error is used as loss

In [15]:
# save the best model based on the validation KID metric
checkpoint_path = f"checkpoints/DDIM/tf_checkpoint"
checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    monitor="val_kid",
    mode="min",
    save_best_only=True,
)

In [16]:
# calculate mean and variance of training dataset for normalization
# model.normalizer.adapt(train_dataset)
batched_dataset = train_dataset.map(lambda x, y: x).prefetch(tf.data.AUTOTUNE)

# Adapt the normalization layer using the batched dataset
model.normalizer.adapt(batched_dataset)


# try:
#     model.load_weights(checkpoint_path)
#     print("Checkpoint loaded successfully!")
# except Exception as e:
#     print(f"Failed to load checkpoint: {e}")

In [None]:
def plot_images_every_2_epochs(epoch, logs):
    # Only plot images if the epoch is a multiple of 5
    if (epoch + 1) % 2 == 0:
        model.plot_images(epoch=epoch, logs=logs, val_dataset=val_dataset)

# run training and plot generated images periodically
model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=num_epochs,
    callbacks=[
        keras.callbacks.LambdaCallback(on_epoch_end=plot_images_every_2_epochs),
        # keras.callbacks.LambdaCallback(on_epoch_start=model.plot_images(val_dataset=val_dataset) ),
        checkpoint_callback,
    ],
)

Epoch 1/50


2024-12-05 11:53:30.477167: I tensorflow/stream_executor/cuda/cuda_blas.cc:1614] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
2024-12-05 11:53:30.496617: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8907

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.
2024-12-05 11:53:30.538162: W tensorflow/stream_executor/gpu/asm_compiler.cc:230] Falling back to the CUDA driver for PTX compilation; ptxas does not support CC 8.9
2024-12-05 11:53:30.538170: W tensorflow/stream_executor/gpu/asm_compiler.cc:233] Used ptxas at ptxas
2024-12-05 11:53:30.538224: W tensorflow/stream_executor/gpu/redzone_allocator.cc:314] UNIMPLEMENTED: ptxas ptxas too old. Falling back to the driver to compile.
Relying on driver to perform ptx compilation. 
Modify $PATH to customize ptxas location.
This message will be only logged once.


In [None]:
for images, labels in val_dataset.take(1):
    conditions = labels  # Use the labels from the validation set as conditions
    model.test_step((images, labels))
    model.plot_images(val_dataset=val_dataset)
    break  # Take only the first+-