In [1]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models, Input
from tensorflow.keras.layers import GlobalAveragePooling2D, Reshape, Dense, multiply, Activation, LeakyReLU, Dropout, BatchNormalization, Conv2D, MaxPooling2D, Add # Ensure all layers are imported
import math # Needed for ceiling division if desired, or use // for floor
from tensorflow.keras.models import Model
from tensorflow import keras
from tensorflow.keras import layers, models
from sklearn.neighbors import NearestNeighbors
import matplotlib.pyplot as plt
import cv2
from PIL import Image
import time
from tensorflow.keras.optimizers import Adam

  import scipy


In [2]:
train_imgs_path    = '../datasets/unbraid/img/train'
train_sketchs_path = '../datasets/unbraid/sketch/train'
val_imgs_path      = '../datasets/unbraid/img/test'
val_sketchs_path   = '../datasets/unbraid/sketch/test'

img_height, img_width = 224, 224  # resize shape
channels = 3  # use 3 for RGB, or 1 for grayscale

def get_id(filename):
    base = os.path.splitext(filename)[0]
    return base.split('_')[-1]

def load_data(img_path, sketch_path):
    img_files = [f for f in os.listdir(img_path) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
    sketch_files = [f for f in os.listdir(sketch_path) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
    
    sketch_dict = { get_id(sf): sf for sf in sketch_files }
    
    images, sketches, labels = [], [], []
    
    for img_filename in img_files:
        img_id = get_id(img_filename)
        if img_id in sketch_dict:
            full_img_path = os.path.join(img_path, img_filename)
            full_sketch_path = os.path.join(sketch_path, sketch_dict[img_id])
            
            sketch_pil = Image.open(full_sketch_path).convert('RGB')  # 'L' for grayscale
            sketch_pil = sketch_pil.resize((img_width, img_height))
            sketch_arr = np.array(sketch_pil, dtype=np.float32) / 255.0

            img_pil = Image.open(full_img_path).convert('RGB')
            img_pil = img_pil.resize((img_width, img_height))
            img_arr = np.array(img_pil, dtype=np.float32) / 255.0
            
            images.append(img_arr)
            sketches.append(sketch_arr)
            labels.append(img_id)
    
    return np.array(images, dtype=np.float32), np.array(sketches, dtype=np.float32), np.array(labels)

# Load training and validation data
train_images, train_sketches, train_labels = load_data(train_imgs_path, train_sketchs_path)
val_images, val_sketches, val_labels = load_data(val_imgs_path, val_sketchs_path)
print("Training samples:", train_images.shape)
print("Validation samples:", val_images.shape)

Training samples: (3000, 224, 224, 3)
Validation samples: (466, 224, 224, 3)


In [3]:
def se_block(input_tensor, ratio=16):
    """Creates a Squeeze-and-Excitation block."""
    channel_axis = -1 # Assuming channels_last format
    filters = input_tensor.shape[channel_axis]
    # Handle cases where filters might be None during model building
    if filters is None:
        # Cannot apply SE block if channel dimension is unknown
        # In a functional model, this usually resolves, but as a fallback:
        return input_tensor
    # Ensure intermediate filters are at least 1
    intermediate_filters = max(1, filters // ratio)
    se_shape = (1, 1, filters)

    se = GlobalAveragePooling2D()(input_tensor)
    se = Reshape(se_shape)(se)
    # Use intermediate_filters calculation
    se = Dense(intermediate_filters, activation='relu', kernel_initializer='he_normal', use_bias=False)(se)
    se = Dense(filters, activation='sigmoid', kernel_initializer='he_normal', use_bias=False)(se)

    # Excitation: Multiply the original input tensor by the learned scaling factors
    x = multiply([input_tensor, se])
    return x

# --- Embedding Network with Filters Scaled by Embedding Dimension ---
def embedding_net_scaled(input_shape=(64, 64, 1), embedding_dim=128, dropout_rate=0.3, min_base_filters=32, model_name="scaled_cnn_encoder"):
    """
    Creates a CNN encoder where the number of filters in Conv layers
    scales based on the embedding_dim.
    """
    inputs = layers.Input(shape=input_shape)

    # --- Calculate Filter Sizes based on embedding_dim ---
    # Set base filters relative to embedding_dim, but ensure a minimum width
    # Example strategy: Base filters = embedding_dim / 4, minimum 32
    base_filters = max(min_base_filters, embedding_dim // 4)

    filters_0 = base_filters       # e.g., if embedding_dim=128, filters_0=32. If embedding_dim=256, filters_0=64.
    filters_1 = base_filters * 2   # e.g., 64 / 128
    filters_2 = base_filters * 4   # e.g., 128 / 256

    print(f"Using filter structure based on embedding_dim={embedding_dim}: {filters_0} -> {filters_1} -> {filters_2}")

    # --- Initial convolution block ---
    x = Conv2D(filters_0, (3,3), padding="same", use_bias=False, kernel_initializer='he_normal')(inputs)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.1)(x)
    x = Conv2D(filters_0, (3,3), padding="same", use_bias=False, kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.1)(x)
    x = MaxPooling2D((2,2))(x)  # Output shape: (32, 32, filters_0)

    # --- Residual block 1 with SE ---
    shortcut = x
    x = Conv2D(filters_1, (3,3), padding="same", use_bias=False, kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.1)(x)
    x = Conv2D(filters_1, (3,3), padding="same", use_bias=False, kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)

    # Add SE block here
    x = se_block(x, ratio=16) # Add Squeeze-and-Excitation

    # Adjust shortcut filters
    shortcut = Conv2D(filters_1, (1,1), padding="same", use_bias=False, kernel_initializer='he_normal')(shortcut)
    shortcut = BatchNormalization()(shortcut)

    x = Add()([x, shortcut])
    x = LeakyReLU(alpha=0.1)(x)
    x = MaxPooling2D((2,2))(x)  # Output shape: (16, 16, filters_1)

    # --- Residual block 2 with SE ---
    shortcut = x
    x = Conv2D(filters_2, (3,3), padding="same", use_bias=False, kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.1)(x)
    x = Conv2D(filters_2, (3,3), padding="same", use_bias=False, kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)

    # Add SE block here
    x = se_block(x, ratio=16) # Add Squeeze-and-Excitation

    # Adjust shortcut filters
    shortcut = Conv2D(filters_2, (1,1), padding="same", use_bias=False, kernel_initializer='he_normal')(shortcut)
    shortcut = BatchNormalization()(shortcut)

    x = Add()([x, shortcut])
    x = LeakyReLU(alpha=0.1)(x)
    x = MaxPooling2D((2,2))(x)  # Output shape: (8, 8, filters_2)

    # --- Pooling and Embedding Head ---
    x = GlobalAveragePooling2D()(x)
    x = Dropout(dropout_rate)(x) # Add Dropout before final dense layer
    # The final dense layer projects features down to the desired embedding_dim
    x = Dense(embedding_dim, use_bias=False, kernel_initializer='he_normal')(x)
    outputs = tf.math.l2_normalize(x, axis=1) # L2 normalize the final embedding

    return models.Model(inputs, outputs, name=model_name)

In [4]:
embedding_dim = 32 # <--- Change this to see filter sizes adapt
dropout_rate_for_encoder = 0.3
img_shape = (img_height, img_width, channels)
sketch_shape = (img_height, img_width, channels)

img_input = Input(img_shape, name="image_input")
sketch_input = Input(sketch_shape, name="sketch_input")

# Instantiate the SCALED encoders
# Pass the desired embedding_dim here
image_encoder = embedding_net_scaled(
    img_shape,
    embedding_dim=embedding_dim, # Pass the dimension
    dropout_rate=dropout_rate_for_encoder,
    model_name='scaled_image_encoder'
)
sketch_encoder = embedding_net_scaled(
    sketch_shape,
    embedding_dim=embedding_dim, # Pass the dimension
    dropout_rate=dropout_rate_for_encoder,
    model_name='scaled_sketch_encoder'
)
#img_emb, sketch_emb = image_encoder(img_input), sketch_encoder(sketch_input)

shared_encoder = embedding_net_scaled( # Or improved_embedding_network, or your original one
    input_shape=img_shape,             # Use the common shape
    embedding_dim=embedding_dim,
    dropout_rate=dropout_rate_for_encoder,
    model_name='shared_encoder'        # Give it a new name
)

img_emb, sketch_emb = shared_encoder(img_input), shared_encoder(sketch_input)


Using filter structure based on embedding_dim=32: 32 -> 64 -> 128
Using filter structure based on embedding_dim=32: 32 -> 64 -> 128
Using filter structure based on embedding_dim=32: 32 -> 64 -> 128


In [5]:
""# Define your cosine similarity loss (as before)
def cosine_similarity_loss(y_true, y_pred):
    img_emb_loss = y_pred[:, :embedding_dim]
    sketch_emb_loss = y_pred[:, embedding_dim:]
    cosine_sim = tf.reduce_sum(img_emb_loss * sketch_emb_loss, axis=1)
    return 1.0 - cosine_sim

combined_output = tf.keras.layers.concatenate([img_emb, sketch_emb], name="combined_embedding")

siamese_model = Model([img_input, sketch_input], combined_output, name="siamese_trainer")

# Consider adjusting learning rate, maybe using a scheduler or different optimizer
siamese_model.compile(optimizer=Adam(1e-4), loss=cosine_similarity_loss)

print("Improved Siamese Training Model Summary:")
siamese_model.summary()
# Dummy labels for compatibility with Keras fit()
dummy_labels_train = np.zeros((train_images.shape[0], 1)) # Correct shape
dummy_labels_val = np.zeros((val_images.shape[0], 1))     # Correct shapez

print("\n--- Starting Training ---")
siamese_model.fit(
    [train_images, train_sketches],
    dummy_labels_train,
    epochs=20, # Or your desired epochs
    batch_size=8, # <--- TRY REDUCING THIS VALUE
    validation_data=([val_images, val_sketches], dummy_labels_val)
)
print("--- Training Finished ---")

Improved Siamese Training Model Summary:
Model: "siamese_trainer"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
image_input (InputLayer)        [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
sketch_input (InputLayer)       [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
shared_encoder (Functional)     (None, 32)           306016      image_input[0][0]                
                                                                 sketch_input[0][0]               
__________________________________________________________________________________________________
combined_embedding (Concatenate (None, 64) 

In [None]:
print("\n--- Building Retrieval Database ---")
# 1. Generate Embeddings for the Image Database (using validation images here)
# Ensure encoders are not in training mode if using layers like BatchNormalization differently
# (though predict handles this automatically)
start_time = time.time()
image_embeddings = shared_encoder.predict(val_images, batch_size=32)
end_time = time.time()
print(f"Generated {len(image_embeddings)} image embeddings in {end_time - start_time:.2f} seconds.")

# 2. Generate Embeddings for the Sketch Queries (using validation sketches)
start_time = time.time()
sketch_embeddings = shared_encoder.predict(val_sketches, batch_size=32)
end_time = time.time()
print(f"Generated {len(sketch_embeddings)} sketch embeddings in {end_time - start_time:.2f} seconds.")

# Ensure embeddings are numpy arrays for easier manipulation if needed
if isinstance(image_embeddings, tf.Tensor):
    image_embeddings = image_embeddings.numpy()
if isinstance(sketch_embeddings, tf.Tensor):
    sketch_embeddings = sketch_embeddings.numpy()

# --- Hit Rate Analysis ---

print("\n--- Performing Retrieval Analysis ---")

def calculate_hit_rate(sketch_embeddings, image_embeddings, k_values):
    """
    Calculates Hit@K for given K values.
    Assumes sketch_embeddings[i] corresponds to image_embeddings[i].
    """
    num_queries = sketch_embeddings.shape[0]
    num_images = image_embeddings.shape[0]
    hits = {k: 0 for k in k_values}
    max_k = max(k_values)

    # Calculate all pairwise cosine similarities (dot product since embeddings are L2 normalized)
    # Shape: (num_queries, num_images)
    start_time = time.time()
    similarity_matrix = np.dot(sketch_embeddings, image_embeddings.T)
    end_time = time.time()
    print(f"Calculated similarity matrix ({num_queries}x{num_images}) in {end_time - start_time:.2f} seconds.")
    partition_indices = np.argpartition(-similarity_matrix, kth=max_k-1, axis=1)[:, :max_k]
    # Now, sort only the top K indices based on their actual similarity values
    top_k_indices = np.array([
        p_indices[np.argsort(-similarity_matrix[i, p_indices])]
        for i, p_indices in enumerate(partition_indices)
    ])


    # Evaluate hits for each query
    for i in range(num_queries):
        # The ground truth image index is 'i'
        ground_truth_index = i
        retrieved_indices = top_k_indices[i]

        # Check if the ground truth is within the top K results for each K
        for k in k_values:
            if ground_truth_index in retrieved_indices[:k]:
                hits[k] += 1

    # Calculate final hit rates
    hit_rates = {k: (hits[k] / num_queries) * 100 for k in k_values}
    return hit_rates

# Define the K values for Hit Rate calculation
k_values_to_check = [1, 5, 10]

# Calculate the hit rates
hit_rates = calculate_hit_rate(sketch_embeddings, image_embeddings, k_values_to_check)

print("\n--- Retrieval Performance ---")
for k, rate in hit_rates.items():
    print(f"Hit@{k}: {rate:.2f}%")

print("\n--- Example Retrieval for First Sketch ---")
# Retrieve top 10 images for the first sketch query
query_index = 0
query_sketch_embedding = sketch_embeddings[query_index:query_index+1] # Keep dimensions

# Calculate similarities for this single query
similarities = np.dot(query_sketch_embedding, image_embeddings.T)[0] # Result shape (1, num_images), take first row

# Get top 10 indices
top_10_indices = np.argsort(-similarities)[:10]

print(f"Query Sketch Index: {query_index}")
print(f"Ground Truth Image Index: {query_index}")
print(f"Top 10 Retrieved Image Indices: {top_10_indices}")
print(f"Similarities of Top 10: {similarities[top_10_indices]}")
# Check if the ground truth was retrieved in the top 10
if query_index in top_10_indices:
    position = np.where(top_10_indices == query_index)[0][0] + 1
    print(f"Correct image found at position {position}.")
else:
    print("Correct image not found in the top 10.")


--- Building Retrieval Database ---
Generated 466 image embeddings in 1.42 seconds.
Generated 466 sketch embeddings in 0.48 seconds.

--- Performing Retrieval Analysis ---
Calculated similarity matrix (466x466) in 0.00 seconds.

--- Retrieval Performance ---
Hit@1: 0.00%
Hit@5: 1.50%
Hit@10: 2.36%

--- Example Retrieval for First Sketch ---
Query Sketch Index: 0
Ground Truth Image Index: 0
Top 10 Retrieved Image Indices: [453 419  81 415 328 436 461 355  12 366]
Similarities of Top 10: [0.99994576 0.9999439  0.99993986 0.99993825 0.9999382  0.9999362
 0.9999359  0.9999355  0.9999341  0.9999323 ]
Correct image not found in the top 10.
