In [62]:
import math
import os
import random
import time
import warnings

import cv2
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from PIL import Image
from scipy.spatial.distance import cdist
from sklearn.neighbors import NearestNeighbors
from tensorflow import keras
from tensorflow.keras import layers, models, Input
from tensorflow.keras import layers, models
from tensorflow.keras.layers import GlobalAveragePooling2D, Reshape, Dense, multiply, Activation, LeakyReLU, Dropout, BatchNormalization, Conv2D, MaxPooling2D, Add
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

In [63]:
# Dataset locations: photos and their paired sketches, for the train and test splits.
train_imgs_path    = '../datasets/unbraid/img/train'
train_sketchs_path = '../datasets/unbraid/sketch/train'
val_imgs_path      = '../datasets/unbraid/img/test'
val_sketchs_path   = '../datasets/unbraid/sketch/test'

# Every image is resized to a square of this side length, with three colour channels.
img_height = img_width = 128
channels = 3
input_shape = (img_height, img_width, channels)

batch_size = 16  # Adjust as needed
latent_dim = 32  # Size of each encoder head: z_inv, z_mean and z_log_var

In [64]:
def get_id(filename):
    """Return the ID encoded in a file name.

    The extension is dropped, then the text after the last underscore of the
    remaining stem is returned. A stem with no underscore is returned whole.
    """
    stem, _extension = os.path.splitext(filename)
    # rsplit with maxsplit=1 gives [stem] when there is no underscore, so
    # taking the last element covers both cases.
    return stem.rsplit('_', 1)[-1]

# --- Modified load_data Function ---
def load_data(img_path, sketch_path, img_width, img_height): # Added img_width, img_height as args
    """
    Loads image-sketch pairs, normalizes them to [0, 1],
    and checks for/handles NaN/inf values.

    An image and a sketch are paired when get_id() returns the same ID for
    both file names. Files are visited in sorted name order so the returned
    arrays (and the sketch chosen when two sketches share an ID) are the same
    on every run. Pairs that fail to load are reported and skipped.

    Args:
        img_path (str): Path to the image directory.
        sketch_path (str): Path to the sketch directory.
        img_width (int): Target width for resizing.
        img_height (int): Target height for resizing.

    Returns:
        tuple: (np.array(images), np.array(sketches), np.array(labels))
               Images/sketches are float32 arrays normalized to [0, 1].
               Three empty arrays are returned when no pair could be loaded.
    """
    extensions = ('.jpg', '.jpeg', '.png')
    img_files = sorted(f for f in os.listdir(img_path) if f.lower().endswith(extensions))
    sketch_files = sorted(f for f in os.listdir(sketch_path) if f.lower().endswith(extensions))

    # ID -> sketch file name; a later file with the same ID replaces an earlier one.
    sketch_dict = {get_id(sf): sf for sf in sketch_files}

    images, sketches, labels = [], [], []
    skipped_count = 0
    target_size = (img_width, img_height)  # PIL's resize takes (width, height)

    print(f"Loading data from {img_path} and {sketch_path}...")
    for i, img_filename in enumerate(img_files):
        if i % 500 == 0 and i > 0: # Print progress occasionally
             print(f" Processing image {i}/{len(img_files)}")

        img_id = get_id(img_filename)
        if img_id not in sketch_dict:
            continue  # image without a matching sketch: not a usable pair

        full_img_path = os.path.join(img_path, img_filename)
        full_sketch_path = os.path.join(sketch_path, sketch_dict[img_id])

        try:
            sketch_arr = _load_rgb_array(full_sketch_path, target_size, "sketch")
            img_arr = _load_rgb_array(full_img_path, target_size, "image")
        except Exception as e:
            # Catch potential errors during image opening/processing
            print(f"Error processing pair: img={img_filename}, sketch={sketch_dict[img_id]}. Error: {e}")
            skipped_count += 1
            continue

        # Append only after both files loaded, so the three lists stay aligned.
        images.append(img_arr)
        sketches.append(sketch_arr)
        labels.append(img_id)

    print(f"Data loading complete. Processed {len(labels)} pairs. Skipped {skipped_count} due to errors.")

    if not images: # Handle case where no valid pairs were found
        warnings.warn("No valid image-sketch pairs found.")
        return np.array([]), np.array([]), np.array([])

    return np.array(images, dtype=np.float32), np.array(sketches, dtype=np.float32), np.array(labels)


def _load_rgb_array(path, size, kind):
    """Load one file as a float32 RGB array in [0, 1], replacing NaN/inf with 0.

    Args:
        path (str): File to open.
        size (tuple): Target (width, height) for resizing.
        kind (str): Word used in the warning message ("sketch" or "image").
    """
    # The context manager closes the file handle once the pixels are copied out.
    with Image.open(path) as pil_img:
        arr = np.array(pil_img.convert('RGB').resize(size), dtype=np.float32) / 255.0

    if not np.isfinite(arr).all():
        warnings.warn(f"NaN or Inf found in {kind}: {path}. Replacing with 0.")
        arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0) # Replace NaN/inf with 0
    return arr


# Read both splits into memory, resized to the same target size.
train_images, train_sketches, train_labels = load_data(
    img_path=train_imgs_path,
    sketch_path=train_sketchs_path,
    img_width=img_width,
    img_height=img_height,
)
val_images, val_sketches, val_labels = load_data(
    img_path=val_imgs_path,
    sketch_path=val_sketchs_path,
    img_width=img_width,
    img_height=img_height,
)
print("Training samples:", train_images.shape)
print("Validation samples:", val_images.shape)

Loading data from ../datasets/unbraid/img/train and ../datasets/unbraid/sketch/train...
 Processing image 500/3000
 Processing image 1000/3000
 Processing image 1500/3000
 Processing image 2000/3000
 Processing image 2500/3000
Data loading complete. Processed 3000 pairs. Skipped 0 due to errors.
Loading data from ../datasets/unbraid/img/test and ../datasets/unbraid/sketch/test...
Data loading complete. Processed 466 pairs. Skipped 0 due to errors.
Training samples: (3000, 128, 128, 3)
Validation samples: (466, 128, 128, 3)


In [65]:
def normalize_to_tanh(image):
    """Apply the affine map x -> 2x - 1, which sends [0, 1] onto [-1, 1]."""
    return 2.0 * image - 1.0

# --- Data Generator Function ---
def triplet_generator(sketches, images, labels):
    """
    Endless generator of single training triplets.

    Each item is ((anchor_sketch, positive_photo, negative_photo), target_image):
    the anchor is sketches[i], the positive is images[i], the negative is a
    randomly chosen photo whose label differs from labels[i], and the target
    is the anchor sketch itself. All arrays are passed through
    normalize_to_tanh before being yielded. Sample order is reshuffled on
    every pass over the data. Batching is left to the caller.

    Args:
        sketches (np.array): Array of sketch images (normalized 0-1).
        images (np.array): Array of photo images (normalized 0-1).
        labels (np.array): Array of corresponding labels (IDs).

    Raises:
        ValueError: On first iteration, if there are fewer than two distinct
            labels (this includes an empty dataset), because no negative
            could ever be found.
    """
    num_samples = sketches.shape[0]

    # The rejection sampling below only stops once it draws a sample with a
    # different label; without this guard it would spin forever.
    if np.unique(labels).size < 2:
        raise ValueError(
            "triplet_generator needs at least two distinct labels to sample negatives."
        )

    indices = np.arange(num_samples)

    while True: # Generators used with from_generator should loop indefinitely
        np.random.shuffle(indices)

        for i in indices:
            # --- Anchor and Positive ---
            anchor_label = labels[i]

            # --- Negative Sampling ---
            # Keep sampling until we find an image with a different label
            while True:
                negative_idx = random.choice(indices)
                if labels[negative_idx] != anchor_label:
                    break

            # --- Normalization ---
            anchor_sketch = normalize_to_tanh(sketches[i])
            positive_photo = normalize_to_tanh(images[i]) # Direct correspondence from load_data
            negative_photo = normalize_to_tanh(images[negative_idx])

            # Target for reconstruction is the anchor sketch itself
            yield (anchor_sketch, positive_photo, negative_photo), anchor_sketch


# --- Create tf.data Datasets ---

def make_triplet_dataset(sketches, images, labels):
    """Wrap triplet_generator in a batched, prefetching tf.data.Dataset.

    Each element is ((anchor_sketch, positive_photo, negative_photo), target_image),
    where every tensor is float32 with shape
    (batch, img_height, img_width, channels). The underlying generator never
    ends, so callers must bound iteration themselves.

    Args:
        sketches (np.array): Sketch images handed to triplet_generator.
        images (np.array): Photo images handed to triplet_generator.
        labels (np.array): Labels handed to triplet_generator.
    """
    # All four tensors of an element share one per-sample spec.
    image_spec = tf.TensorSpec(shape=(img_height, img_width, channels), dtype=tf.float32)
    dataset = tf.data.Dataset.from_generator(
        lambda: triplet_generator(sketches, images, labels),
        output_signature=((image_spec, image_spec, image_spec), image_spec),
    )
    # Apply batching and prefetching
    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)


print("Creating tf.data Datasets...")

train_dataset = make_triplet_dataset(train_sketches, train_images, train_labels)
# Validation uses the same generator structure as training.
val_dataset = make_triplet_dataset(val_sketches, val_images, val_labels)

print("Training and validation datasets created.")

Creating tf.data Datasets...
Training and validation datasets created.


In [66]:


# --- Encoder ---
# InceptionV3 backbone without its classification head; global average
# pooling turns the final feature map into one feature vector per image.
base_encoder = keras.applications.InceptionV3(
    weights='imagenet', # Or None if you want to train from scratch
    include_top=False,
    pooling='avg',
    input_shape=input_shape,
)

# The backbone's layers are NOT frozen here. To freeze them, uncomment:
# for layer in base_encoder.layers:
#     layer.trainable = False

encoder_input = keras.Input(name="encoder_input", shape=input_shape)
# training=False is fixed at graph-construction time for the backbone call.
x = base_encoder(encoder_input, training=False)

# Three linear heads on the pooled features: the invariant code, and the
# mean / log-variance that parameterise the variable code.
z_inv, z_mean, z_log_var = [
    layers.Dense(latent_dim, name=head_name)(x)
    for head_name in ("z_invariant", "z_mean", "z_log_var")
]

# Instantiate the encoder model
encoder = keras.Model(
    inputs=encoder_input,
    outputs=[z_inv, z_mean, z_log_var],
    name="encoder",
)
print("--- Encoder Summary ---")
encoder.summary()

# --- Reparameterization Trick (Sampling Layer) ---
class Sampling(layers.Layer):
    """Reparameterisation step: returns z_mean + exp(0.5 * z_log_var) * eps, eps ~ N(0, 1)."""

    def call(self, inputs):
        mean, log_var = inputs
        dims = tf.shape(mean)
        # One standard-normal draw per (sample, latent dimension).
        noise = tf.random.normal(shape=(dims[0], dims[1]))
        # exp(0.5 * log_var) converts a log-variance into a standard deviation.
        std_dev = tf.exp(0.5 * log_var)
        return mean + std_dev * noise


--- Encoder Summary ---
Model: "encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
encoder_input (InputLayer)      [(None, 128, 128, 3) 0                                            
__________________________________________________________________________________________________
inception_v3 (Functional)       (None, 2048)         21802784    encoder_input[0][0]              
__________________________________________________________________________________________________
z_invariant (Dense)             (None, 32)           65568       inception_v3[0][0]               
__________________________________________________________________________________________________
z_mean (Dense)                  (None, 32)           65568       inception_v3[0][0]               
____________________________________________________________________

In [67]:
# Decoder: latent vector -> start_dim x start_dim feature map -> a stack of
# stride-2 transposed convolutions (each doubles the spatial size) -> tanh image.
latent_inputs = keras.Input(shape=(latent_dim,), name="latent_input")

# Spatial side length of the first feature map.
start_dim = 8
# The doubling scheme below only reaches the target if it is a square power of two.
if img_height != img_width or not math.log2(img_height).is_integer() or img_height < start_dim:
    raise ValueError(f"img_height ({img_height}) and img_width ({img_width}) must be equal,"
                     f" powers of 2, and >= start_dim ({start_dim}) for this structure.")

# Channel count of the first feature map.
initial_channels = 128

# Project the latent vector and fold it into the first feature map.
initial_dense_units = start_dim * start_dim * initial_channels
x = layers.Dense(initial_dense_units, activation="relu")(latent_inputs)
x = layers.Reshape((start_dim, start_dim, initial_channels))(x)

# Number of x2 upsampling steps needed to go from start_dim to the target size.
num_upsample_layers = int(np.log2(img_height / start_dim))
print(f"Target size: {img_height}x{img_height}, Start dim: {start_dim}x{start_dim}")
print(f"Calculated number of Conv2DTranspose layers: {num_upsample_layers}")

# One filter count per upsampling step (this list fits a 128x128 target).
decoder_filters = [128, 64, 64, 32]

if len(decoder_filters) != num_upsample_layers:
     raise ValueError(f"Length of decoder_filters ({len(decoder_filters)}) "
                      f"must match num_upsample_layers ({num_upsample_layers})")

current_channels = initial_channels
for i, layer_filters in enumerate(decoder_filters):
    print(f" Adding Upsample Layer {i+1}/{num_upsample_layers} with {layer_filters} filters")
    x = layers.Conv2DTranspose(layer_filters, 3, strides=2, padding="same")(x)
    # BatchNormalization keeps its default epsilon.
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

# Output layer: stride 1, so the spatial size is unchanged; tanh keeps values in [-1, 1].
final_activation = "tanh"
decoder_outputs = layers.Conv2DTranspose(
    filters=channels,
    kernel_size=3,
    padding="same",
    activation=final_activation,
)(x)

# Instantiate the decoder model
decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
print(f"\n--- Decoder Summary (Input: {latent_dim}, Output: {img_height}x{img_width}x{channels}) ---")
decoder.summary()


Target size: 128x128, Start dim: 8x8
Calculated number of Conv2DTranspose layers: 4
 Adding Upsample Layer 1/4 with 128 filters
 Adding Upsample Layer 2/4 with 64 filters
 Adding Upsample Layer 3/4 with 64 filters
 Adding Upsample Layer 4/4 with 32 filters

--- Decoder Summary (Input: 32, Output: 128x128x3) ---
Model: "decoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
latent_input (InputLayer)    [(None, 32)]              0         
_________________________________________________________________
dense_5 (Dense)              (None, 8192)              270336    
_________________________________________________________________
reshape_5 (Reshape)          (None, 8, 8, 128)         0         
_________________________________________________________________
conv2d_transpose_25 (Conv2DT (None, 16, 16, 128)       147584    
_________________________________________________________________
batch_norm

In [68]:
class VAE(keras.Model):
    """VAE over sketches with an added triplet loss against photos.

    The encoder returns (z_inv, z_mean, z_log_var). A variable code is sampled
    from (z_mean, z_log_var) and added element-wise to z_inv; the decoder
    reconstructs the anchor sketch from that sum. The training objective is

        reconstruction MSE + kl_weight * KL + triplet_weight * triplet loss

    where the triplet loss compares the anchor sketch with a positive and a
    negative photo on both z_inv and the combined code.

    Args:
        encoder: Model mapping an image batch to [z_inv, z_mean, z_log_var].
        decoder: Model mapping a latent batch to an image batch.
        kl_weight (float): Weight of the KL term.
        triplet_weight (float): Weight of the triplet term.
        **kwargs: Forwarded to keras.Model.
    """

    def __init__(self, encoder, decoder, kl_weight=0.001, triplet_weight=1.0, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.sampling = Sampling()
        self.kl_weight = kl_weight
        self.triplet_weight = triplet_weight
        # Running means of each loss term, reported by fit()/evaluate().
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")
        self.triplet_loss_tracker = keras.metrics.Mean(name="triplet_loss")

    @property
    def metrics(self):
        # Listing the trackers here lets Keras reset them between epochs and
        # between the training and validation passes.
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
            self.triplet_loss_tracker,
        ]

    def call(self, inputs):
        """ Forward pass for inference: encode, sample, decode. """
        z_inv, z_mean, z_log_var = self.encoder(inputs)
        z_var = self.sampling([z_mean, z_log_var])
        # Invariant and variable parts are combined by element-wise sum.
        z_combined = z_inv + z_var
        reconstruction = self.decoder(z_combined)
        return reconstruction

    def _encode(self, batch, training):
        """Encode a batch; return (z_inv, z_mean, z_log_var, z_inv + sampled z_var)."""
        z_inv, z_mean, z_log_var = self.encoder(batch, training=training)
        z_combined = z_inv + self.sampling([z_mean, z_log_var])
        return z_inv, z_mean, z_log_var, z_combined

    def _compute_losses(self, data, training):
        """Return (total, reconstruction, kl, triplet) losses for one batch.

        Args:
            data: ((anchor_sketches, positive_photos, negative_photos), target_images).
            training (bool): Passed to the encoder and decoder calls.
        """
        (anchor_sketches, positive_photos, negative_photos), target_images = data

        z_inv_anchor, z_mean_anchor, z_log_var_anchor, z_combined_anchor = self._encode(
            anchor_sketches, training
        )
        reconstruction = self.decoder(z_combined_anchor, training=training)

        # --- Reconstruction Loss (mean over batch and pixels) ---
        reconstruction_loss = tf.reduce_mean(
            keras.losses.mean_squared_error(target_images, reconstruction)
        )

        # --- KL Loss: summed over latent dimensions, averaged over the batch ---
        kl_per_dim = -0.5 * (
            1 + z_log_var_anchor - tf.square(z_mean_anchor) - tf.exp(z_log_var_anchor)
        )
        kl_loss = tf.reduce_mean(tf.reduce_sum(kl_per_dim, axis=1))

        # --- Triplet Loss on invariant and combined codes ---
        z_inv_pos, _, _, z_combined_pos = self._encode(positive_photos, training)
        z_inv_neg, _, _, z_combined_neg = self._encode(negative_photos, training)
        triplet_loss = calculate_triplet_loss(
            z_inv_anchor, z_inv_pos, z_inv_neg,
            z_combined_anchor, z_combined_pos, z_combined_neg
        )

        total_loss = (
            reconstruction_loss
            + self.kl_weight * kl_loss
            + self.triplet_weight * triplet_loss
        )
        return total_loss, reconstruction_loss, kl_loss, triplet_loss

    def _update_trackers(self, total_loss, reconstruction_loss, kl_loss, triplet_loss):
        """Feed one batch's losses to the trackers and return the current means."""
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        self.triplet_loss_tracker.update_state(triplet_loss)
        return {m.name: m.result() for m in self.metrics}

    def train_step(self, data):
        """One optimisation step on a triplet batch; returns the running loss means."""
        with tf.GradientTape() as tape:
            # training=True is passed explicitly: a custom train_step does not
            # set it for sub-model calls, and without it the decoder's
            # BatchNormalization layers would run in inference mode.
            losses = self._compute_losses(data, training=True)

        # --- Gradients & Updates ---
        total_loss = losses[0]
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        return self._update_trackers(*losses)

    def test_step(self, data):
        """Evaluate the same losses as train_step on a validation batch, without updates."""
        losses = self._compute_losses(data, training=False)
        return self._update_trackers(*losses)

In [69]:
# Build the full model from the encoder and decoder defined above. No loss is
# passed to compile(): VAE.train_step / test_step compute the losses themselves.
vae = VAE(encoder=encoder, decoder=decoder)

vae.compile(optimizer=Adam(learning_rate=5e-4))


In [70]:
# --- Triplet loss used by VAE.train_step / VAE.test_step ---

def calculate_triplet_loss(anchor_inv, positive_inv, negative_inv,
                           anchor_f, positive_f, negative_f,
                           margin_inv=0.5, margin_f=0.3):
    """
    Sum of two hinge triplet losses: one on the invariant features (z_inv)
    and one on the combined features (z_f).

    Each term is mean_over_batch(max(0, margin + d(a, p) - d(a, n))), where d
    is the squared Euclidean distance taken over the last axis.

    Args:
        anchor_inv, positive_inv, negative_inv: Batch of invariant features.
        anchor_f, positive_f, negative_f: Batch of combined features.
        margin_inv: Margin for the invariant feature triplet loss.
        margin_f: Margin for the combined feature triplet loss.

    Returns:
        Total triplet loss for the batch.
    """
    def _hinge_triplet(anchor, positive, negative, margin):
        # Inputs may be arrays or tensors; work in float32 tensors throughout.
        anchor, positive, negative = (
            tf.convert_to_tensor(t, dtype=tf.float32)
            for t in (anchor, positive, negative)
        )
        dist_pos = tf.reduce_sum(tf.square(anchor - positive), axis=-1)
        dist_neg = tf.reduce_sum(tf.square(anchor - negative), axis=-1)
        per_sample = tf.maximum(0.0, margin + dist_pos - dist_neg)
        return tf.reduce_mean(per_sample)  # average over the batch

    loss_inv = _hinge_triplet(anchor_inv, positive_inv, negative_inv, margin_inv)
    loss_f = _hinge_triplet(anchor_f, positive_f, negative_f, margin_f)
    return loss_inv + loss_f

# Status summary. The note reflects the code: the triplet loss is implemented
# in calculate_triplet_loss and is already part of the VAE training objective.
print("\n VAE model created and compiled.")
print(" Note: Triplet loss is implemented in calculate_triplet_loss and used by VAE.train_step/test_step.")
print("       Meta-learning components (FT Layers, Regulariser) are not included here.")



 VAE model created and compiled.
 Note: Triplet loss calculation needs to be implemented and integrated into train_step.
       Meta-learning components (FT Layers, Regulariser) are not included here.


In [None]:
# Train the VAE. Both datasets come from endless generators, so the number of
# batches per epoch must be given explicitly.
try:
    # Use math.ceil to include the last partial batch
    steps_per_epoch = math.ceil(len(train_sketches) / batch_size)
    validation_steps = math.ceil(len(val_sketches) / batch_size)
    print(f"  Steps per epoch: {steps_per_epoch}")
    print(f"  Validation steps: {validation_steps}")

    print("\nStarting training...")
    epochs = 20 # Adjust as needed, maybe start lower (e.g., 5) to check stability first
    history = vae.fit(
        train_dataset,
        epochs=epochs,
        validation_data=val_dataset,
        steps_per_epoch=steps_per_epoch,     # Use calculated value
        validation_steps=validation_steps    # Use calculated value
        # Add callbacks here if needed (e.g., ModelCheckpoint, EarlyStopping)
        # callbacks=[...]
    )
    print("\nTraining complete.")

except NameError as e:
    print(f"Error: A required variable is not defined before calculating steps or starting training.")
    print(f"       Please ensure train_sketches, val_sketches, and batch_size are defined.")
    print(f"       Original error: {e}")
    # Re-raise so "Run All" stops here instead of going on to evaluate a
    # model that was never trained.
    raise
except Exception as e:
    print(f"An error occurred during training: {e}")
    # Re-raise: later cells rely on a trained model and on `history`.
    raise

# You can now analyze the 'history' object
# print(history.history)

  Steps per epoch: 188
  Validation steps: 30

Starting training...
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20

In [None]:
def normalize_to_tanh(image):
    """Apply the affine map x -> 2x - 1, which sends [0, 1] onto [-1, 1].

    Same definition as the one in the data-generator cell; the input is
    expected to be in the [0, 1] range produced by load_data.
    """
    return 2.0 * image - 1.0

# --- Revised Evaluation Function (No tqdm) ---
def evaluate_hit_at_k(encoder_model,
                       test_sketches_arr,   # NumPy array of test sketches
                       test_labels_sketch,  # NumPy array of test sketch labels
                       gallery_photos_arr,  # NumPy array of gallery photos
                       gallery_labels_photo,# NumPy array of gallery photo labels
                       k_values=(1, 5, 10),
                       batch_size_eval=32): # Optional batching for prediction
    """Sketch-to-photo retrieval accuracy (Hit@k) using the encoder's first output.

    Every sketch and photo is rescaled with normalize_to_tanh and encoded;
    only output index 0 of encoder_model.predict (z_inv) is used. Gallery
    photos are ranked per query by squared Euclidean distance, and a query
    counts as a hit at k when its label appears among the k nearest photos.

    Args:
        encoder_model: Object whose predict(x, batch_size=..., verbose=...)
            returns a sequence with the retrieval features at index 0.
        test_sketches_arr: Query sketches, values in [0, 1].
        test_labels_sketch: One label per query sketch.
        gallery_photos_arr: Gallery photos, values in [0, 1].
        gallery_labels_photo: One label per gallery photo.
        k_values: Iterable of cut-offs k.
        batch_size_eval (int): Batch size passed to predict.

    Returns:
        dict: {"Hit@k": percentage of queries with a hit at k} for each k.
        Every value is 0.0 when the query set or the gallery is empty.
    """
    k_values = list(k_values)
    num_queries = len(test_sketches_arr)

    # With no queries the percentage is undefined (division by zero), and with
    # no gallery nothing can be retrieved; report zeros instead of failing.
    if num_queries == 0 or len(gallery_photos_arr) == 0:
        print("Nothing to evaluate: the query set or the gallery is empty.")
        return {f"Hit@{k}": 0.0 for k in k_values}

    # Index-array lookups below need ndarrays, so accept plain sequences too.
    query_labels = np.asarray(test_labels_sketch)
    gallery_labels = np.asarray(gallery_labels_photo)

    print("Normalizing gallery photos for evaluation...")
    gallery_photos_norm = normalize_to_tanh(np.asarray(gallery_photos_arr))

    print(f"Extracting features for {len(gallery_photos_norm)} gallery photos...")
    gallery_features = encoder_model.predict(gallery_photos_norm,
                                             batch_size=batch_size_eval,
                                             verbose=0)[0] # Index 0 corresponds to z_inv

    print("Normalizing test sketches for evaluation...")
    test_sketches_norm = normalize_to_tanh(np.asarray(test_sketches_arr))

    print(f"Extracting features for {len(test_sketches_norm)} test sketches...")
    query_features_inv = encoder_model.predict(test_sketches_norm,
                                               batch_size=batch_size_eval,
                                               verbose=0)[0] # Index 0 corresponds to z_inv

    print("Calculating distances...")
    # Rows are queries, columns are gallery photos.
    all_distances = cdist(query_features_inv, gallery_features, metric='sqeuclidean')

    print("Ranking and calculating Hit@k...")
    hits = {k: 0 for k in k_values}

    for i in range(num_queries):
        if (i + 1) % 100 == 0: # Print every 100 queries
            print(f"  Processing query {i+1}/{num_queries}")

        # Gallery labels ordered from nearest to farthest for this query.
        ranked_labels = gallery_labels[np.argsort(all_distances[i])]
        true_label = query_labels[i]

        for k in k_values:
            # Slicing past the end is safe, so k may exceed the gallery size.
            if true_label in ranked_labels[:k]:
                hits[k] += 1

    print("Calculation complete.")

    hit_results = {f"Hit@{k}": (hits[k] / num_queries) * 100.0 for k in k_values}
    return hit_results

# --- Example Usage ---
# The validation split serves as both query set (sketches) and gallery (photos),
# so the same label array is passed for both.

print("\nRunning evaluation...")
results = evaluate_hit_at_k(
    encoder_model=encoder,
    test_sketches_arr=val_sketches,
    test_labels_sketch=val_labels,
    gallery_photos_arr=val_images,
    gallery_labels_photo=val_labels,
    k_values=[1, 5, 10],
    batch_size_eval=64,
)

print("\n--- Evaluation Results ---")
if not results:
    print("Evaluation did not produce results.")
else:
    for metric, value in results.items():
        print(f"{metric}: {value:.2f}%")


Running evaluation...
Normalizing gallery photos for evaluation...
Extracting features for 466 gallery photos...
Normalizing test sketches for evaluation...
Extracting features for 466 test sketches...
Calculating distances...
Ranking and calculating Hit@k...
  Processing query 100/466
  Processing query 200/466
  Processing query 300/466
  Processing query 400/466
Calculation complete.

--- Evaluation Results ---
Hit@1: 1.50%
Hit@5: 4.94%
Hit@10: 8.37%
