In [1]:
!pip install faiss-cpu

Collecting faiss-cpu
  Downloading faiss_cpu-1.10.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (4.4 kB)
Downloading faiss_cpu-1.10.0-cp310-cp310-manylinux_2_28_x86_64.whl (30.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m30.7/30.7 MB[0m [31m54.6 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hInstalling collected packages: faiss-cpu
Successfully installed faiss-cpu-1.10.0


In [2]:
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
import tensorflow as tf
from tensorflow.keras.applications import EfficientNetB3
from tensorflow.keras.applications.efficientnet import preprocess_input
import faiss
import joblib
from tensorflow.keras.datasets import cifar10
import numpy as np
import cv2
import matplotlib.pyplot as plt
import time

N_CLUSTERS = 40
PCA_COMPONENTS = 256
BATCH_SIZE = 64
SEED = 149
batch_size  = BATCH_SIZE
cifar10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 
               'dog', 'frog', 'horse', 'ship', 'truck']
subset_size = 1000

rng = np.random.default_rng(seed=42) 
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [None]:
import tensorflow as tf
from tensorflow.keras.applications import EfficientNetB3
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.applications.efficientnet import preprocess_input
from tensorflow.keras import mixed_precision
mixed_precision.set_global_policy('mixed_float16')

# Hyperparameters
IMG_SIZE = 224
BATCH_SIZE = 64
EPOCHS_HEAD = 3
EPOCHS_FINE_TUNE = 5

# Load CIFAR-10
print("Loading CIFAR-10 dataset...")
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

# Resize and preprocess in a tf.data pipeline
def preprocess(image, label):
    image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE])
    image = preprocess_input(image)
    return image, label

train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = train_ds.shuffle(10000).map(preprocess).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_ds = test_ds.map(preprocess).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

# Load EfficientNetB3 base model
print("Loading EfficientNetB3 model...")
base_model = EfficientNetB3(weights='imagenet', include_top=False, input_tensor=Input(shape=(IMG_SIZE, IMG_SIZE, 3)))
base_model.trainable = False  # Freeze initially

# Classification head
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(256, activation='relu')(x)
outputs = Dense(10, activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=outputs)

# Compile and train head
model.compile(optimizer=Adam(learning_rate=0.001), loss='categorical_crossentropy', metrics=['accuracy'])
print("Training classification head...")
model.fit(train_ds, epochs=EPOCHS_HEAD, validation_data=test_ds)

# Fine-tune full model
base_model.trainable = True
model.compile(optimizer=Adam(learning_rate=1e-5), loss='categorical_crossentropy', metrics=['accuracy'])
print("Fine-tuning full EfficientNetB3 model...")
model.fit(train_ds, epochs=EPOCHS_FINE_TUNE, validation_data=test_ds)

# Evaluate
loss, acc = model.evaluate(test_ds)
print(f"Test Accuracy: {acc:.4f}")


Loading CIFAR-10 dataset...
Loading EfficientNetB3 model...
Downloading data from https://storage.googleapis.com/keras-applications/efficientnetb3_notop.h5
[1m43941136/43941136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step
Training classification head...
Epoch 1/3
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m137s[0m 126ms/step - accuracy: 0.8196 - loss: 0.5556 - val_accuracy: 0.8775 - val_loss: 0.3489
Epoch 2/3
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m71s[0m 90ms/step - accuracy: 0.8869 - loss: 0.3252 - val_accuracy: 0.8882 - val_loss: 0.3297
Epoch 3/3
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m70s[0m 89ms/step - accuracy: 0.9011 - loss: 0.2815 - val_accuracy: 0.8935 - val_loss: 0.3202
Fine-tuning full EfficientNetB3 model...
Epoch 1/5
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m497s[0m 441ms/step - accuracy: 0.7158 - loss: 0.8605 - val_accuracy: 0.8891 - val_loss: 0.3427
Epoch 2/5
[1m118/782[0m [32

In [5]:
model.save("efficientnetb3_finetuned.keras") 
# from tensorflow.keras.models import load_model
# print("Loading fine-tuned EfficientNetB3 base model...")
# base_model = load_model('efficientnetb3_finetuned_base')

In [None]:
from tensorflow.keras.models import load_model
# Load the full model
model = load_model("efficientnetb3_finetuned.keras", compile=False)
model = Model(inputs=model.input, outputs=model.layers[-3].output)

In [3]:
# import tensorflow as tf
# from tensorflow.keras.models import load_model

# # Define a proper custom layer for casting operations
# class CastLayer(tf.keras.layers.Layer):
#     def __init__(self, **kwargs):
#         super(CastLayer, self).__init__(**kwargs)
    
#     def call(self, inputs):
#         return tf.cast(inputs, tf.float32)
    
#     def get_config(self):
#         return super().get_config()

# # Load the model with the custom layer
# custom_objects = {'Cast': CastLayer}
# base_model = load_model('efficientnetb3_finetuned_base.h5', custom_objects=custom_objects)


In [5]:
# # Save the model in SavedModel format
# base_model.save('efficientnetb3_model.keras')
# # Later, you can load it without custom objects
# # loaded_model = tf.keras.models.load_model('efficientnetb3_model.keras')


In [6]:
### print("Loading CIFAR-10 dataset...")
from tensorflow.keras.models import load_model, Model

(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
y_train = y_train.flatten()
y_test = y_test.flatten()

# Load EfficientNetB0 model
print("Loading EfficientNetB0 model...")
model = model #EfficientNetB3(weights='imagenet', include_top=False, pooling='avg')

def extract_cnn_features(images, batch_size=BATCH_SIZE):
    """Extract CNN features from images using EfficientNetB0"""
    features = []
    
    for i in range(0, len(images), batch_size):
        batch = images[i:i+batch_size]
        # Resize images to 224x224 as required by EfficientNetB0
        batch_resized = np.array([cv2.resize(img, (224, 224)) for img in batch])
        # Preprocess images
        batch_preprocessed = preprocess_input(batch_resized)
        # Extract features
        batch_features = model.predict(batch_preprocessed, verbose=0)
        features.append(batch_features)
    
    return np.vstack(features)


Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
[1m170498071/170498071[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 0us/step
Loading EfficientNetB0 model...


In [7]:

class ImageRetriever:
    def __init__(self, n_clusters=N_CLUSTERS, pca_components=PCA_COMPONENTS):
        self.n_clusters = n_clusters
        self.pca_components = pca_components
        self.kmeans = None
        self.pca = None
        self.faiss_index = None
        self.features = None
        self.image_ids = None
        self.labels = None
        
    def fit(self, images, labels=None):
        """Build the retrieval model with KMeans clustering on CNN features"""
        # Extract CNN features
        print("Extracting CNN features...")

        features = []
        for i in range(0, len(images), batch_size):
            batch_images = images[i:i+batch_size]
            batch_descriptors = extract_cnn_features(batch_images, batch_size)
            features.extend(batch_descriptors)
        features = np.array(features)
        from tensorflow.keras import backend as K
        K.clear_session()
        # Apply PCA for dimensionality reduction
        print(f"Applying PCA with {self.pca_components} components...")
        self.pca = PCA(n_components=self.pca_components)
        reduced_features = self.pca.fit_transform(features)
        
        # Apply KMeans clustering
        print(f"Applying KMeans with {self.n_clusters} clusters...")
        self.kmeans = KMeans(n_clusters=self.n_clusters, random_state=SEED)
        clusters = self.kmeans.fit_predict(reduced_features)
        
        # Create histogram features
        print("Creating histogram features...")
        self.features = np.zeros((len(images), self.n_clusters))
        for i, cluster in enumerate(clusters):
            self.features[i, cluster] += 1
        
        # Normalize histograms
        row_sums = self.features.sum(axis=1)
        self.features = self.features / row_sums[:, np.newaxis]
        
        # Build FAISS index for fast similarity search
        print("Building FAISS index...")
        d = self.features.shape[1]  # Dimension of feature vectors
        self.faiss_index = faiss.IndexFlatL2(d)
        self.faiss_index.add(self.features.astype('float32'))
        
        # Store image IDs and labels
        self.image_ids = np.arange(len(images))
        self.labels = labels
        
        return self
    
    def add_to_index(self, new_images, new_labels=None):
        """Add new images to the existing index"""
        # Extract and process features
        new_features = extract_cnn_features(new_images)
        new_reduced = self.pca.transform(new_features)
        new_clusters = self.kmeans.predict(new_reduced)
        
        new_histograms = np.zeros((len(new_images), self.n_clusters))
        for i, cluster in enumerate(new_clusters):
            new_histograms[i, cluster] += 1
        
        # Normalize
        row_sums = new_histograms.sum(axis=1)
        new_histograms = new_histograms / row_sums[:, np.newaxis]
        self.faiss_index.add(new_histograms.astype('float32'))
        start_id = len(self.image_ids)
        new_ids = np.arange(start_id, start_id + len(new_images))
        
        self.image_ids = np.append(self.image_ids, new_ids)
        self.features = np.vstack([self.features, new_histograms])
        
        if new_labels is not None and self.labels is not None:
            self.labels = np.append(self.labels, new_labels)
        
        print(f"Added {len(new_images)} images to index. Total index size: {len(self.image_ids)}")
        return self
    
    def process_query(self, query_image):
        """Process a query image to get its feature histogram"""
        # Handle both single image and batch
        is_batch = len(query_image.shape) == 4
        query_images = query_image if is_batch else np.expand_dims(query_image, axis=0)
        
        # Extract features
        query_features = extract_cnn_features(query_images)
        query_reduced = self.pca.transform(query_features)
        query_clusters = self.kmeans.predict(query_reduced)
        
        # Create histogram
        query_hist = np.zeros((len(query_images), self.n_clusters))
        for i, cluster in enumerate(query_clusters):
            query_hist[i, cluster] += 1
        
        # Normalize
        row_sums = query_hist.sum(axis=1)
        query_hist = query_hist / row_sums[:, np.newaxis]
        
        return query_hist
    
    def query(self, query_image, top_k=5):
        """Query the index with an image and return top_k matches"""
        query_hist = self.process_query(query_image)
        
        # Search using FAISS
        distances, indices = self.faiss_index.search(
            query_hist.astype('float32'), top_k
        )
        
        # Map indices to original image IDs
        result_ids = [[int(self.image_ids[idx]) for idx in row] for row in indices]
        
        return result_ids, distances
    
    def plot_results(self, query_image, retrieved_ids, distances, all_images):
        """Plot query image and retrieval results"""
        top_k = len(retrieved_ids[0])
        fig, axes = plt.subplots(1, top_k + 1, figsize=(3 * (top_k + 1), 3))
        
        # Plot query image
        axes[0].imshow(query_image)
        axes[0].set_title("Query Image")
        axes[0].axis('off')
        
        # Plot retrieved images
        for i, (idx, dist) in enumerate(zip(retrieved_ids[0], distances[0])):
            img = all_images[idx]
            axes[i+1].imshow(img)
            
            title = f"Rank {i+1}\nDist: {dist:.4f}"
            if self.labels is not None:
                title += f"\nLabel: {self.labels[idx]}"
            
            axes[i+1].set_title(title)
            axes[i+1].axis('off')
        
        plt.tight_layout()
        plt.show()
    
    def save(self, filepath):
        """Save model using joblib"""
        # Serialize FAISS index
        faiss_bytes = faiss.serialize_index(self.faiss_index)
        
        # Prepare data to save
        data = {
            'n_clusters': self.n_clusters,
            'pca_components': self.pca_components,
            'pca': self.pca,
            'kmeans': self.kmeans,
            'features': self.features,
            'image_ids': self.image_ids,
            'labels': self.labels,
            'faiss_bytes': faiss_bytes
        }
        
        joblib.dump(data, filepath)
        print(f"Model saved to {filepath}")
    
    @classmethod
    def load(cls, filepath):
        """Load model from joblib file"""
        data = joblib.load(filepath)
        
        # Create instance
        instance = cls(n_clusters=data['n_clusters'], pca_components=data['pca_components'])
        
        # Load components
        instance.pca = data['pca']
        instance.kmeans = data['kmeans']
        instance.features = data['features']
        instance.image_ids = data['image_ids']
        instance.labels = data['labels']
        
        # Deserialize FAISS index
        instance.faiss_index = faiss.deserialize_index(data['faiss_bytes'])
        
        print(f"Model loaded from {filepath}")
        return instance

def analyze_cluster_distribution(retriever, images, labels, class_names):
    """Analyze cluster distributions and plot histograms with sample images"""
    # Extract features
    cnn_features = extract_cnn_features(images)
    
    # Apply PCA
    pca_features = retriever.pca.transform(cnn_features)
    
    # Get cluster assignments
    assignments = retriever.kmeans.predict(pca_features)
    
    unique_labels = np.unique(labels)
    cluster_frequency = {label: np.zeros(retriever.n_clusters) for label in unique_labels}
    
    class_images = {label: [] for label in unique_labels}
    class_assignments = {label: [] for label in unique_labels}
    
    for i, (image, label, assignment) in enumerate(zip(images, labels, assignments)):
        # Count frequencies
        cluster_frequency[label][assignment] += 1
        
        if len(class_images[label]) < 3:
            class_images[label].append(image)
            class_assignments[label].append(assignment)
    
    for label in unique_labels:
        plt.figure(figsize=(15, 8))
        
        # Plot histogram
        plt.subplot(2, 1, 1)
        plt.bar(range(retriever.n_clusters), cluster_frequency[label])
        plt.title(f"Cluster Frequency for Class: {class_names[label]}")
        plt.xlabel("Cluster Index")
        plt.ylabel("Frequency")
        plt.grid(True, alpha=0.3)
        
        # Plot sample images
        for i, (img, cluster_id) in enumerate(zip(class_images[label], class_assignments[label])):
            plt.subplot(2, 3, i+4)
            plt.imshow(img.astype(np.uint8))
            plt.title(f"Cluster {cluster_id}")
            plt.axis('off')
        
        plt.tight_layout()
        plt.savefig(f"cluster_histogram_{class_names[label]}.png")
        plt.show()

def inference_pipeline(query_image, model_path=None, retriever=None, all_images=None, top_k=5):
    if retriever is None and model_path is not None:
        retriever = ImageRetriever.load(model_path)
    
    if retriever is None:
        raise ValueError("Either retriever or model_path must be provided")
    
    result_ids, distances = retriever.query(query_image, top_k=top_k)
    if all_images is not None:
        retriever.plot_results(query_image, result_ids, distances, all_images)
    
    return result_ids, distances


In [8]:
subset_indices = rng.choice(len(x_train), 1000, replace=False)

In [9]:
retriever = ImageRetriever(n_clusters=N_CLUSTERS, pca_components=PCA_COMPONENTS)
retriever.fit(x_train[subset_indices], y_train[subset_indices])

retriever.save('image_retriever.joblib')

# Analyze cluster distribution
print("Analyzing cluster distribution...")
analyze_cluster_distribution(
    retriever, 
    x_train[:1000],
    y_train[:1000],
    cifar10_classes
)

# Query with a test image
query_idx = np.random.randint(0, len(x_test))
query_image = x_test[query_idx]

# Use inference pipeline to get and display results
all_images = np.vstack([x_train, x_test])
all_labels = np.hstack([y_train, y_test])
result_ids, distances = inference_pipeline(
    query_image,
    retriever=retriever,
    all_images=all_images
)

print(f"Query image label: {cifar10_classes[y_test[query_idx]]}")
print("Retrieved image labels:", [cifar10_classes[all_labels[idx]] for idx in result_ids[0]])

# Example: Add more images to the index
print(f"Original index size: {len(retriever.image_ids)}")
new_images = x_test[:100]  # Add first 100 test images
new_labels = y_test[:100]
retriever.add_to_index(new_images, new_labels)

Extracting CNN features...
Applying PCA with 256 components...


ValueError: Found array with dim 4. PCA expected <= 2.

In [None]:
def plot_retrieved_images(query_image, retrieved_images, query_label=None, retrieved_labels=None, distances=None):
    n_retrieved = len(retrieved_images)
    fig, axes = plt.subplots(1, n_retrieved + 1, figsize=(3 * (n_retrieved + 1), 3))
    
    # Plot query image
    axes[0].imshow(query_image.astype(np.uint8))
    title = "Query Image"
    if query_label is not None:
        title += f"\nLabel: {query_label}"
    axes[0].set_title(title)
    axes[0].axis('off')
    
    # Plot retrieved images
    for i in range(n_retrieved):
        axes[i+1].imshow(retrieved_images[i].astype(np.uint8))
        
        title = f"Rank {i+1}"
        
        # Add distance if available
        if distances is not None and i < len(distances):
            title += f"\nDist: {distances[i]:.4f}"
        
        # Add label and correctness if available
        if retrieved_labels is not None and i < len(retrieved_labels):
            title += f"\nLabel: {retrieved_labels[i]}"
            
            # Add correctness indicator
            if query_label is not None:
                is_correct = retrieved_labels[i] == query_label
                title += f"\nCorrect: {'✓' if is_correct else '✗'}"
        
        axes[i+1].set_title(title)
        axes[i+1].axis('off')
    
    plt.tight_layout()
    plt.savefig(f'retrieval_results_{int(time.time())}.png')
    plt.show()

def evaluate_precision(retriever, query_images, query_labels, k_values=[1, 5, 10], batch_size=32):
    precision_at_k = {k: [] for k in k_values}
    average_precision = []
    
    max_k = max(k_values)
    
    for i in range(len(query_images)):
        query_img = query_images[i]
        query_label = query_labels[i]
        
        # Get retrieval results
        result_ids, distances = retriever.query(query_img, top_k=max_k)
        result_ids = result_ids[0]  # First (and only) query in batch
        distances = distances[0]
        
        # Get labels of retrieved images
        retrieved_labels = [retriever.labels[id] for id in result_ids]
        
        # Calculate relevance (1 if match, 0 if not)
        relevance = [1 if label == query_label else 0 for label in retrieved_labels]
        
        # Calculate precision@k
        for k in k_values:
            if k <= len(relevance):
                precision_k = sum(relevance[:k]) / k
                precision_at_k[k].append(precision_k)
        
        # Calculate average precision
        if sum(relevance) > 0:
            ap = 0.0
            running_sum = 0
            for j in range(len(relevance)):
                if relevance[j] == 1:
                    running_sum += sum(relevance[:j+1]) / (j+1)
            ap = running_sum / sum(relevance)
            average_precision.append(ap)
    
    # Compile results
    results = {
        'mean_average_precision': np.mean(average_precision) if average_precision else 0
    }
    
    # Add precision@k
    for k in k_values:
        if precision_at_k[k]:
            results[f'precision@{k}'] = np.mean(precision_at_k[k])
    
    return results

def retrieve_and_visualize(retriever, query_image, all_images, all_labels=None, class_names=None, top_k=5):
    result_ids, distances = retriever.query(query_image, top_k=top_k)
    result_ids = result_ids[0] 
    distances = distances[0]
    
    # Get retrieved images
    retrieved_images = [all_images[id] for id in result_ids]
    
    # Get labels if available
    query_label = None
    retrieved_labels = None
    precision = None
    
    if all_labels is not None:
        # Find the query image label (assuming it's part of the test set)
        # This would typically be passed in directly, but we're estimating it here
        query_label = None
        if hasattr(query_image, 'shape') and len(query_image.shape) == 3:
            # Find the most similar image in all_images to identify the label
            for i, img in enumerate(all_images):
                if np.array_equal(query_image, img):
                    query_label = all_labels[i]
                    break

        retrieved_labels = [all_labels[id] for id in result_ids]
        if class_names is not None:
            if query_label is not None:
                query_label = class_names[query_label]
            retrieved_labels = [class_names[label] for label in retrieved_labels]

        if query_label is not None:
            relevant = sum(1 for label in retrieved_labels if label == query_label)
            precision = relevant / len(retrieved_labels)
    
    plot_retrieved_images(
        query_image,
        retrieved_images,
        query_label,
        retrieved_labels,
        distances
    )
    
    return {
        'result_ids': result_ids,
        'retrieved_labels': retrieved_labels,
        'distances': distances,
        'precision': precision
    }


In [None]:
print("Evaluating on test set...")
test_size = min(1000, len(x_test))  # Use a subset for faster evaluation
evaluation_results = evaluate_precision(
    retriever,
    x_test[:test_size],
    y_test[:test_size],
    k_values=[1, 5, 10]
)

print("\nEvaluation Results:")
for metric, value in evaluation_results.items():
    print(f"{metric}: {value:.4f}")

# Visualize results for a few random examples
print("\nVisualizing retrieval examples...")
all_images = np.vstack([x_train, x_test])
all_labels = np.hstack([y_train, y_test])

for i in range(3):  # Show 3 examples
    query_idx = np.random.randint(0, len(x_test))
    query_image = x_test[query_idx]
    query_label = y_test[query_idx]
    
    print(f"\nExample {i+1}:")
    print(f"Query class: {cifar10_classes[query_label]}")
    
    results = retrieve_and_visualize(
        retriever,
        query_image,
        all_images,
        all_labels,
        cifar10_classes,
        top_k=5
    )
    
    if results['precision'] is not None:
        print(f"Precision for this query: {results['precision']:.4f}")


In [None]:
analyze_cluster_distribution(
    retriever, 
    x_train[:1000],
    y_train[:1000],
    cifar10_classes
)

# inference

In [None]:
from tensorflow.keras.models import load_model
model = load_model("efficientnetb3_finetuned.keras", compile=False)
model = Model(inputs=model.input, outputs=model.layers[-3].output)

In [None]:
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
y_train = y_train.flatten()
y_test = y_test.flatten()

In [None]:
r1 = ImageRetriever.load("/kaggle/working/image_retriever.joblib")
inference_pipeline(x_test[10],retriever=r1)

In [None]:
test_size = min(1000, len(x_test))  # Use a subset for faster evaluation
evaluation_results = evaluate_precision(
    r1,
    x_test[:test_size],
    y_test[:test_size],
    k_values=[1, 5, 10]
)
print("\nEvaluation Results:")
for metric, value in evaluation_results.items():
    print(f"{metric}: {value:.4f}")
