# Image Retrieval using pre-trained Models

## Dataset: cifar 10
## Model: ResNet v2, cifar10-vit

## Experiments
- 기존 cifar10 dataset으로 학습한 image classification model load 
- feature 추출 모델 구성
- feature extraction
- Similarity 구하기
- retrieval test 


In [None]:
import tensorflow as tf
from tensorflow import keras

from keras.models import Model
from keras.datasets import cifar10

import numpy as np
import os
import math
from datetime import datetime 

In [None]:
# for display image and plot
from PIL import Image
import matplotlib.pyplot as plt

## Hyperparameters

In [None]:
# Number of target classes, passed to the one-hot label encoding below.
num_classes = 10 # CIFAR-10 has a fixed set of 10 classes

In [None]:
# 1. ResNet v2 configuration
depth, version = 20, 2

# Architecture name and the saved-model name derived from it,
# e.g. 'ResNet20v2' -> 'cifar10_ResNet20v2'.
model_type = f'ResNet{depth}v{version}'
model_file = f'cifar10_{model_type}'

# Layer whose output is used as the retrieval feature vector.
output_feature_layer_name = 'flatten'

# Subtracting pixel mean improves accuracy
subtract_pixel_mean = True

In [None]:
# 2. ViT configuration
# Running this cell reassigns model_type, model_file, subtract_pixel_mean and
# output_feature_layer_name, so it replaces any earlier model configuration.
model_type = 'vit_b16'
model_file = f'cifar10_{model_type}'

input_shape = (32, 32, 3)  # CIFAR-10 image size
image_size = 224  # image size after resizing (alternative tried: 256)

# Layer whose output is used as the retrieval feature vector (a dense layer).
output_feature_layer_name = 'feature'
subtract_pixel_mean = False

## Dataset Preparation
- Load cifar10 dataset
- Normalize input(x) data
- Output encoding to one-hot vector

In [None]:
# Fetch the CIFAR-10 train/test splits (raw images and integer labels).
(x_train_data, y_train_data), (x_test_data, y_test_data) = cifar10.load_data()

# Per-sample input dimensions: everything after the leading sample axis.
input_shape = x_train_data.shape[1:]

# Convert both splits to float32 and scale the pixel values by 1/255.
x_train, x_test = (
    split.astype('float32') / 255 for split in (x_train_data, x_test_data)
)

In [None]:
# Centre the inputs on the training-set mean image when enabled (ResNet case).
# The subtraction is in place, so re-running this cell subtracts the mean again.
if subtract_pixel_mean:
    x_train_mean = x_train.mean(axis=0)
    x_train -= x_train_mean
    x_test -= x_train_mean

# One-hot encode the integer class labels of both splits.
y_train, y_test = (
    keras.utils.to_categorical(labels, num_classes)
    for labels in (y_train_data, y_test_data)
)

print(f'x_train shape: {x_train.shape}')
print(f'{x_train.shape[0]} train samples')
print(f'{x_test.shape[0]} test samples')
print(f'y_train shape: {y_train.shape}')

In [None]:
!ls cifar10*

In [None]:
# Display the model path chosen by the configuration cell that was run last.
model_file

In [None]:
# Load the trained classifier stored under `model_file` and print its layers.
# Note from the original author: stop other running kernels/jobs first
# (presumably to free memory before loading — TODO confirm).
reloaded_model = keras.models.load_model(model_file)
reloaded_model.summary()

In [None]:
# Write a diagram of the loaded classifier (including tensor shapes) to a PNG.
model_arch_png = f'model_cifar10_{model_type}_loaded.png'
keras.utils.plot_model(
    reloaded_model,
    show_shapes=True,
    to_file=model_arch_png,
)


In [None]:
# Retrieval model: shares the classifier's input, but its output is taken from
# the intermediate layer named by `output_feature_layer_name`.
base_model = reloaded_model
r_model = Model(
    inputs=base_model.input,
    outputs=base_model.get_layer(output_feature_layer_name).output,
)

r_model.summary()

In [None]:
# Write a diagram of the feature-extraction model (including tensor shapes) to a PNG.
model_arch_png = f'model_cifar10_{model_type}_feature.png'
keras.utils.plot_model(
    r_model,
    show_shapes=True,
    to_file=model_arch_png,
)

## cifar10 test dataset에 대한 features 분석

In [None]:
def extract_features(model, x_dataset, y_label, start, num):
    """Extract feature vectors for a contiguous window of a dataset.

    Runs ``model.predict`` on ``x_dataset[start:start + num]``. The window is
    clipped to the dataset bounds, so a request that runs past the end returns
    every remaining sample instead of a truncated or empty slice.

    Args:
        model: Feature-extraction model exposing ``predict(batch)``.
        x_dataset: Sliceable collection of input samples.
        y_label: Labels aligned index-for-index with ``x_dataset``.
        start: Index of the first sample to use; negative values are treated
            as 0.
        num: Maximum number of samples to take, starting at ``start``.

    Returns:
        A tuple ``(predictions, out_x_dataset, out_y_labels)``: the model
        output (feature vectors) for the window, the input samples of the
        window, and the matching labels.
    """
    start = max(start, 0)
    # Exclusive end index, clipped to the dataset length. The clip must use the
    # length itself: a value relative to `start` would drop trailing samples.
    end = min(start + num, len(x_dataset))

    # Take the requested window of samples from the dataset.
    data_batch = x_dataset[start:end]

    # The model output for the window is the set of feature vectors.
    predictions = model.predict(data_batch)

    out_x_dataset = data_batch
    out_y_labels = y_label[start:end]
    return predictions, out_x_dataset, out_y_labels

In [None]:
# Extract features for 1000 test images starting at index 0, keeping the
# matching inputs and one-hot labels for the retrieval display further down.
cifar10_test_features, feature_x_data, feature_y_labels = extract_features(r_model, x_test, y_test, 0, 1000)

In [None]:
# Inspect the shape of the extracted feature array.
cifar10_test_features.shape

In [None]:
# Pairwise cosine similarity between all extracted feature vectors;
# sim[i][j] is the similarity between sample i and sample j.
from sklearn.metrics.pairwise import cosine_similarity
sim = cosine_similarity(cifar10_test_features)

In [None]:
# Inspect the similarity matrix shape and the similarities of sample 0 to all samples.
sim.shape, sim[0]

In [None]:
# Sanity check: the scikit-learn value for one pair should match the cosine
# similarity computed by hand as dot(a, b) / (||a|| * ||b||).
idx1=0
idx2=17
sim[idx1][idx2], np.dot(cifar10_test_features[idx1], cifar10_test_features[idx2])/(np.linalg.norm(cifar10_test_features[idx1])*np.linalg.norm(cifar10_test_features[idx2]))

In [None]:
# Per-row indices that sort the similarities in ascending order (argsort default).
# For descending order reverse each row with [:, ::-1].
score_ind = np.argsort(sim)# note: the previously suggested [::,-1] would select only the last column

In [None]:
def show_similar_images(similarity, query_index, top_k, x_dataset, y_label, cls_model):
    """Plot a query image together with its ``top_k`` most similar images.

    The query image is drawn alone in the first grid row. The retrieved images
    follow from the second row on, ordered by descending similarity to the
    query. Each retrieved image is titled
    ``index: true class:predicted class, similarity``; the title is green when
    ``cls_model`` predicts the true class and red otherwise. The fraction of
    correct predictions among the retrieved images is printed at the end.

    Reads the notebook globals ``subtract_pixel_mean`` and, when that flag is
    set, ``x_train_mean`` to undo the mean subtraction before display.

    Args:
        similarity: Square similarity matrix; ``similarity[i][j]`` is the
            similarity between samples ``i`` and ``j`` of ``x_dataset``.
        query_index: Index of the query sample.
        top_k: Number of most-similar samples to show. Clamped to the range
            ``[0, number of samples]``. The ranking is taken over all samples,
            so the query itself is not excluded from the results.
        x_dataset: Array of images in the form expected by ``cls_model``.
        y_label: Labels aligned with ``x_dataset``; expected to be one-hot
            (the class is taken as the first non-zero position).
        cls_model: Classifier exposing ``predict(batch)`` that returns
            per-class scores; used only to annotate the retrieved images.

    Returns:
        None. The figure is drawn with matplotlib and the accuracy is printed.
    """
    class_names = ['airplane',
                   'automobile',
                   'bird',
                   'cat',
                   'deer',
                   'dog',
                   'frog',
                   'horse',
                   'ship',
                   'truck']
    num_cols = 7

    # Rank all samples by similarity to the query, most similar first.
    scores = np.asarray(similarity[query_index])
    top_k = max(0, min(top_k, len(scores)))
    score_idx = np.argsort(scores)[::-1][:top_k]

    # One row for the query plus enough rows to hold top_k images.
    num_rows = (top_k // num_cols) + 2

    images = np.asarray(x_dataset)

    # Only the retrieved images need a class prediction, so the classifier is
    # run on that subset; predictions[i] belongs to sample score_idx[i].
    predictions = cls_model.predict(images[score_idx]) if len(score_idx) else []

    # Add the training mean back so the displayed pixels return to range [0, 1].
    display_batch = images + x_train_mean if subtract_pixel_mean else images

    plt.figure(figsize=(20, 2 * num_rows))

    # Query image in the top row.
    plt.subplot(num_rows, num_cols, 1)
    plt.axis("off")
    plt.imshow(display_batch[query_index])
    query_cls = int(np.nonzero(y_label[query_index])[0][0])
    title = f"{query_index}: {class_names[query_cls]}, {similarity[query_index][query_index]:.2f}"
    title_obj = plt.title(title, fontdict={'fontsize': 13})
    plt.setp(title_obj, color='g')

    # Retrieved images, starting at the first cell of the second row.
    num_matches = 0
    for i, idx in enumerate(score_idx):
        plt.subplot(num_rows, num_cols, i + num_cols + 1)
        plt.axis("off")
        plt.imshow(display_batch[idx])

        # Compare plain integers rather than a scalar against np.nonzero's tuple.
        pred_cls = int(np.argmax(predictions[i]))
        true_cls = int(np.nonzero(y_label[idx])[0][0])

        title = f"{idx}: {class_names[true_cls]}:{class_names[pred_cls]}, {similarity[query_index][idx]:.2f}"
        title_obj = plt.title(title, fontdict={'fontsize': 13})

        if pred_cls == true_cls:
            num_matches += 1
            plt.setp(title_obj, color='g')
        else:
            plt.setp(title_obj, color='r')

    # Computed once after the loop, and defined even when nothing was retrieved.
    acc = num_matches / len(score_idx) if len(score_idx) else 0.0
    print("Prediction accuracy: ", int(100*acc)/100)

    return


In [None]:
# Retrieval demo: show the 64 samples most similar to sample 10 among the
# extracted test subset, annotated with the base classifier's predictions.
query_index = 10
top_k = 64
show_similar_images(sim, query_index, top_k, feature_x_data, feature_y_labels, base_model)