In [2]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Dense, GlobalAveragePooling2D, Lambda, MaxPooling2D
from tensorflow.keras.models import Model
from tensorflow import keras
from tensorflow.keras import layers, models
from sklearn.neighbors import NearestNeighbors
import matplotlib.pyplot as plt
import cv2
from PIL import Image
# Seed the NumPy and TensorFlow global RNGs from a single constant so the
# value only has to be changed in one place.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

  import scipy


In [3]:
# Dataset folders: photos and sketches, each with a train and a test split.
# The test split is what the rest of the notebook uses for validation.
train_imgs_path = '../datasets/unbraid/img/train'
train_sketchs_path = '../datasets/unbraid/sketch/train'
val_imgs_path = '../datasets/unbraid/img/test'
val_sketchs_path = '../datasets/unbraid/sketch/test'

# Target size every image and sketch is resized to
img_height = img_width = 64
# Channel count used for the encoder input shapes (3 = RGB, 1 = grayscale)
channels = 3

def get_id(filename):
    """Return the identifier encoded in a file name.

    The extension is stripped and whatever follows the last underscore of the
    remaining stem is returned. A stem without any underscore is returned
    unchanged.
    """
    stem, _extension = os.path.splitext(filename)
    return stem.rsplit('_', 1)[-1]

def load_data(img_path, sketch_path):
    """Load paired images and sketches from two folders as float arrays.

    An image and a sketch are paired when ``get_id`` returns the same
    identifier for both file names. Images without a matching sketch are
    skipped. If several sketches share an identifier, the last one in sorted
    file-name order is used.

    Args:
        img_path: Folder containing the photo files.
        sketch_path: Folder containing the sketch files.

    Returns:
        A tuple ``(images, sketches, labels)``. ``images`` and ``sketches``
        are float32 arrays scaled to [0, 1]; entry ``i`` of each belongs to
        the same pair. ``labels`` holds the shared identifier strings.
        Each sample has shape ``(img_height, img_width, C)``, where ``C`` is
        1 when the module-level ``channels`` is 1 and 3 otherwise.
    """
    extensions = ('.jpg', '.jpeg', '.png')
    # Honor the module-level `channels` setting so the data matches the
    # encoder input shape; any value other than 1 keeps the RGB behavior.
    mode = 'L' if channels == 1 else 'RGB'

    def list_images(folder):
        # os.listdir returns entries in arbitrary order; sorting makes the
        # sample order (and therefore seeded training) reproducible.
        return sorted(f for f in os.listdir(folder) if f.lower().endswith(extensions))

    def read_image(path):
        # The context manager releases the file handle as soon as the pixel
        # data has been converted into a new in-memory image.
        with Image.open(path) as handle:
            resized = handle.convert(mode).resize((img_width, img_height))
        arr = np.array(resized, dtype=np.float32) / 255.0
        if arr.ndim == 2:
            # Grayscale images come back as (H, W); add the channel axis.
            arr = arr[..., np.newaxis]
        return arr

    sketch_dict = {get_id(sf): sf for sf in list_images(sketch_path)}

    images, sketches, labels = [], [], []

    for img_filename in list_images(img_path):
        img_id = get_id(img_filename)
        if img_id not in sketch_dict:
            continue
        sketches.append(read_image(os.path.join(sketch_path, sketch_dict[img_id])))
        images.append(read_image(os.path.join(img_path, img_filename)))
        labels.append(img_id)

    return np.array(images, dtype=np.float32), np.array(sketches, dtype=np.float32), np.array(labels)

# Read both splits; each call returns index-aligned (images, sketches, labels)
train_images, train_sketches, train_labels = load_data(train_imgs_path, train_sketchs_path)
val_images, val_sketches, val_labels = load_data(val_imgs_path, val_sketchs_path)

# Report how many pairs were found and the per-sample shape
print("Training samples:", np.shape(train_images))
print("Validation samples:", np.shape(val_images))

Training samples: (3000, 64, 64, 3)
Validation samples: (466, 64, 64, 3)


In [4]:

def embedding_network(input_shape=(64, 64, 1), embedding_dim=128, model_name="cnn_encoder"):
    """Build a small residual CNN that maps an image to a unit-length embedding.

    The network is a two-convolution stem (64 filters) followed by two
    residual stages (128 and 256 filters). Each of the three stages ends in a
    2x2 max-pool, so a 64x64 input is reduced to 8x8 before global average
    pooling. A bias-free dense layer projects to ``embedding_dim`` and the
    result is L2-normalised along the feature axis.

    Args:
        input_shape: Shape of one input sample, ``(height, width, channels)``.
        embedding_dim: Length of the output embedding vector.
        model_name: Name given to the returned Keras model.

    Returns:
        A ``keras`` ``Model`` mapping a batch of inputs to a batch of
        L2-normalised embeddings.
    """

    def conv_bn(tensor, filters, kernel_size):
        # Bias-free convolution followed by batch normalisation
        tensor = layers.Conv2D(filters, kernel_size, padding="same", use_bias=False)(tensor)
        return layers.BatchNormalization()(tensor)

    def residual_stage(tensor, filters):
        # Main path: conv-BN-ReLU-conv-BN
        main = conv_bn(tensor, filters, (3, 3))
        main = layers.Activation("relu")(main)
        main = conv_bn(main, filters, (3, 3))
        # The 1x1 projection brings the skip path to the same filter count
        skip = conv_bn(tensor, filters, (1, 1))
        merged = layers.Add()([main, skip])
        merged = layers.Activation("relu")(merged)
        return layers.MaxPooling2D((2, 2))(merged)

    inputs = layers.Input(shape=input_shape)

    # Stem: two conv-BN-ReLU units with 64 filters, then downsample
    features = inputs
    for _ in range(2):
        features = layers.Activation("relu")(conv_bn(features, 64, (3, 3)))
    features = layers.MaxPooling2D((2, 2))(features)

    # Two residual stages, each doubling the filter count and halving the size
    for filters in (128, 256):
        features = residual_stage(features, filters)

    # Collapse the spatial dimensions, project, and normalise to unit length
    pooled = layers.GlobalAveragePooling2D()(features)
    projected = layers.Dense(embedding_dim, use_bias=False)(pooled)
    normalised = tf.math.l2_normalize(projected, axis=1)

    return models.Model(inputs, normalised, name=model_name)
# Photos and sketches are loaded at the same resolution and channel count,
# so both encoders take the same input shape.
img_shape = sketch_shape = (img_height, img_width, channels)

# Separate encoder (separate weights) per modality
image_encoder = embedding_network(input_shape=img_shape, model_name='image_encoder')
sketch_encoder = embedding_network(input_shape=sketch_shape, model_name='sketch_encoder')

In [5]:
from tensorflow.keras.optimizers import Adam

# Two-tower model: each modality goes through its own encoder and the two
# embeddings are concatenated so one Keras loss function can see both.
img_input = Input(img_shape)
sketch_input = Input(sketch_shape)

img_emb = image_encoder(img_input)
sketch_emb = sketch_encoder(sketch_input)

combined_output = tf.keras.layers.concatenate([img_emb, sketch_emb])


def cosine_similarity_loss(y_true, y_pred):
    """Per-sample ``1 - cosine similarity`` between paired embeddings.

    Args:
        y_true: Ignored; present only because Keras requires a target.
        y_pred: Concatenation ``[image_embedding, sketch_embedding]`` along
            axis 1. Both halves must have the same width.

    Returns:
        A tensor with one loss value per sample. Because the encoders emit
        L2-normalised vectors, the dot product is the cosine similarity.
    """
    # Split in half rather than at a hard-coded 128 so the loss keeps working
    # when the encoders are built with a different embedding_dim.
    img_half, sketch_half = tf.split(y_pred, num_or_size_splits=2, axis=1)
    cosine_sim = tf.reduce_sum(img_half * sketch_half, axis=1)
    return 1 - cosine_sim  # minimize distance => maximize similarity


# NOTE(review): this objective only pulls matched pairs together; nothing
# pushes non-matching pairs apart. If retrieval recall stays near chance,
# consider adding negatives (e.g. a contrastive or triplet loss) — TODO confirm.
siamese_model = Model([img_input, sketch_input], combined_output)
siamese_model.compile(optimizer=Adam(1e-4), loss=cosine_similarity_loss)

# Dummy labels for compatibility: the loss never reads y_true
dummy_labels = np.zeros((train_images.shape[0], 1))

# Train the model
siamese_model.fit(
    [train_images, train_sketches],
    dummy_labels,
    epochs=20,
    batch_size=8,
    validation_data=([val_images, val_sketches], np.zeros((val_images.shape[0], 1)))
)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<tensorflow.python.keras.callbacks.History at 0x2dbb65e9370>

In [6]:
# Embed every photo and sketch with its own encoder, first for the training
# split and then for the validation split.
train_image_emb, train_sketch_emb = (
    encoder.predict(samples, batch_size=32)
    for encoder, samples in ((image_encoder, train_images), (sketch_encoder, train_sketches))
)

val_image_emb, val_sketch_emb = (
    encoder.predict(samples, batch_size=32)
    for encoder, samples in ((image_encoder, val_images), (sketch_encoder, val_sketches))
)

In [7]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Input

def retrieval_network(input_dim=128, output_dim=128, model_name="retrieval_model"):
    """Build an MLP with two residual dense blocks that maps one embedding to another.

    Layout: a 512-unit dense stem, a 512-unit residual block, a 256-unit
    residual block, and a final linear projection to ``output_dim``. Hidden
    dense layers are bias-free and followed by batch normalisation; dropout
    (rate 0.3) follows the ReLU in the stem and inside each block.

    Args:
        input_dim: Length of the input vector.
        output_dim: Length of the output vector.
        model_name: Name given to the returned Keras model.

    Returns:
        A ``keras`` ``Model`` from ``(input_dim,)`` vectors to
        ``(output_dim,)`` vectors with no output activation.
    """

    def dense_bn(tensor, units):
        # Bias-free dense layer followed by batch normalisation
        tensor = layers.Dense(units, use_bias=False)(tensor)
        return layers.BatchNormalization()(tensor)

    def relu_dropout(tensor):
        tensor = layers.Activation("relu")(tensor)
        return layers.Dropout(0.3)(tensor)

    def residual_block(tensor, units):
        # The skip path is a learned dense+BN projection, not an identity
        skip = dense_bn(tensor, units)
        main = relu_dropout(dense_bn(tensor, units))
        main = dense_bn(main, units)
        merged = layers.Add()([main, skip])
        return layers.Activation("relu")(merged)

    inputs = layers.Input(shape=(input_dim,))

    # Stem
    hidden = relu_dropout(dense_bn(inputs, 512))

    # Residual dense blocks
    for units in (512, 256):
        hidden = residual_block(hidden, units)

    # Linear projection into the target embedding space
    outputs = layers.Dense(output_dim)(hidden)

    return models.Model(inputs, outputs, name=model_name)
# Size the retrieval head from the encoder outputs: it takes a sketch
# embedding and regresses the matching image embedding.
retrieval_model = retrieval_network(
    input_dim=train_sketch_emb.shape[1],
    output_dim=train_image_emb.shape[1],
)

# Mean squared error between predicted and actual image embeddings
retrieval_model.compile(loss='mse', optimizer=tf.keras.optimizers.Adam(1e-4))

In [8]:
# Fit the retrieval head on the training embeddings, validating on the held-out ones
retrieval_model.fit(
    x=train_sketch_emb,
    y=train_image_emb,
    batch_size=16,
    epochs=50,
    validation_data=(val_sketch_emb, val_image_emb),
)

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50


<tensorflow.python.keras.callbacks.History at 0x2dd672a3d60>

In [9]:
# Map every validation sketch embedding into image-embedding space
predicted_img_emb = retrieval_model.predict(val_sketch_emb)

# Pairwise distances between each prediction (rows) and each gallery image (columns)
from sklearn.metrics.pairwise import euclidean_distances

distances = euclidean_distances(predicted_img_emb, val_image_emb)
top_k = 5  # choose top-5 or any K you prefer

# Column indices of the K closest gallery images for each sketch
top_k_indices = np.argsort(distances, axis=1)[:, :top_k]

# Sketch i is paired with gallery image i, so the correct index for row i is i.
# Broadcasting the (N, 1) column against the (N, K) matrix marks the hits.
correct_matches = np.arange(len(val_sketch_emb)).reshape(-1,1)
recall_at_k = np.mean((top_k_indices == correct_matches).any(axis=1))

print(f"Recall@{top_k}: {recall_at_k:.4f}")

Recall@5: 0.0086


In [None]:
# Persist the two encoders and the retrieval head. The target folder is
# created first because it may not exist yet in a fresh checkout.
os.makedirs('models', exist_ok=True)
image_encoder.save('models/image_encoder.keras')
sketch_encoder.save('models/sketch_encoder.keras')
retrieval_model.save('models/retrieval_model.keras')



In [12]:
# Rank-based evaluation of the raw encoders (no retrieval head): each
# validation sketch embedding queries the gallery of validation image embeddings.
n_samples_fit = val_image_emb.shape[0]
n_neighbors = min(1250, n_samples_fit)  # cannot ask for more neighbours than gallery items

nbrs = NearestNeighbors(n_neighbors=n_neighbors).fit(val_image_emb)

num_queries = val_sketches.shape[0]

# Compute embeddings for sketches (queries)
query_embeddings = sketch_encoder.predict(val_sketches, batch_size=32)

# One batched search for all queries instead of one kneighbors call per query;
# row i of `indices` lists gallery indices for query i, nearest first.
distances, indices = nbrs.kneighbors(query_embeddings)

# Assuming data is matched by index: the correct gallery item for query i is i.
hits = indices == np.arange(num_queries).reshape(-1, 1)
found = hits.any(axis=1)
ranks = hits.argmax(axis=1)  # 0-based rank of the match; only meaningful where found

# count_recall[r] = number of queries whose match sits at 0-based rank r
count_recall = np.bincount(ranks[found], minlength=n_neighbors)

for idx in range(num_queries):
    if found[idx]:
        print(f"Query {idx}: Found correct match at rank {ranks[idx]+1}")
    else:
        print(f"Query {idx}: Not found in top {n_neighbors}.")

# Calculate cumulative recall: entry r is the number of matches within rank r+1
cumulative_recall = np.cumsum(count_recall)

# Recall@K with K clamped to the gallery size: when fewer than K items exist,
# Recall@K equals the recall over the whole gallery rather than 0.
recall_at_1, recall_at_5, recall_at_10 = (
    cumulative_recall[min(k, n_neighbors) - 1] / num_queries if n_neighbors >= 1 else 0
    for k in (1, 5, 10)
)

print(f"\nRecall@1: {recall_at_1:.4f}")
print(f"Recall@5: {recall_at_5:.4f}")
print(f"Recall@10: {recall_at_10:.4f}")


Query 0: Found correct match at rank 276
Query 1: Found correct match at rank 439
Query 2: Found correct match at rank 431
Query 3: Found correct match at rank 166
Query 4: Found correct match at rank 234
Query 5: Found correct match at rank 62
Query 6: Found correct match at rank 67
Query 7: Found correct match at rank 291
Query 8: Found correct match at rank 74
Query 9: Found correct match at rank 422
Query 10: Found correct match at rank 317
Query 11: Found correct match at rank 262
Query 12: Found correct match at rank 461
Query 13: Found correct match at rank 161
Query 14: Found correct match at rank 146
Query 15: Found correct match at rank 261
Query 16: Found correct match at rank 341
Query 17: Found correct match at rank 220
Query 18: Found correct match at rank 399
Query 19: Found correct match at rank 348
Query 20: Found correct match at rank 8
Query 21: Found correct match at rank 146
Query 22: Found correct match at rank 363
Query 23: Found correct match at rank 455
Query 2