# Image Retrieval using pre-trained Models

## Dataset: cifar 10
## Model: ResNet v2, cifar10-vit

## Experiments
- 기존 cifar10 dataset으로 학습한 image classification model load 
- feature 추출 모델 구성
- feature extraction
- Similarity 구하기
- retrieval test 


In [None]:
import tensorflow as tf
from tensorflow import keras

from keras.models import Model
from keras.datasets import cifar10

import numpy as np
import os
import math
from datetime import datetime 

In [None]:
# for display image and plot
from PIL import Image
import matplotlib.pyplot as plt

## Dataset Preparation
- Load cifar10 dataset
- Normalize input(x) data
- Output encoding to one-hot vector

In [None]:
# Load the CIFAR10 data as (images, labels) pairs for the train and test splits.
(x_train_data, y_train_data), (x_test_data, y_test_data) = cifar10.load_data()

# Input image dimensions: the per-sample shape, i.e. everything after the
# leading sample axis.
input_shape = x_train_data.shape[1:]

# Normalize data: convert to float32 and divide by 255 so that later cells
# work on the scaled copies (x_train / x_test) rather than the raw arrays.
x_train = x_train_data.astype('float32') / 255
x_test = x_test_data.astype('float32') / 255

In [None]:
num_classes = 10 # cifar10 classes : fixed

# Convert class vectors to binary class matrices (one column per class).
# show_similar_images later recovers the class id from the non-zero column.
y_train = keras.utils.to_categorical(y_train_data, num_classes)
y_test = keras.utils.to_categorical(y_test_data, num_classes)

# Sanity-check the prepared array shapes and split sizes.
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')
print('y_train shape:', y_train.shape)

In [None]:
!ls cifar10_*

## Hyperparameters

In [None]:
def get_config(model_type, train_dataset):
    """Return the settings dict describing one saved CIFAR-10 classifier.

    Args:
        model_type: 'resnet' selects the ResNet20v2 settings; 'vit' and any
            other value select the ViT-B/16 settings.
        train_dataset: Training images; only read when pixel-mean
            subtraction is enabled (the ResNet case), to compute the mean
            over axis 0.

    Returns:
        dict with keys 'model_type', 'model_file',
        'output_feature_layer_name' and 'subtract_pixel_mean', plus
        'x_train_mean' when 'subtract_pixel_mean' is True.
    """
    if model_type == 'resnet':
        # ResNet v2: the name encodes depth and version.
        resnet_depth, resnet_version = 20, 2
        resolved_type = f'ResNet{resnet_depth}v{resnet_version}'
        # Subtracting pixel mean improves accuracy
        use_pixel_mean = True
    else:
        # ViT is both the explicit 'vit' choice and the fallback.
        resolved_type = 'vit_b16'
        use_pixel_mean = False

    config = {
        'model_type': resolved_type,
        'model_file': f'cifar10_{resolved_type}',
        # Layer whose output is used as the feature vector; check the layer
        # names of the loaded model if this needs to change.
        'output_feature_layer_name': 'feature',
        'subtract_pixel_mean': use_pixel_mean,
    }

    if use_pixel_mean:
        config['x_train_mean'] = np.mean(train_dataset, axis=0)

    return config

In [None]:
# Build one settings dict per model; both receive the normalized training
# images, which get_config only uses when pixel-mean subtraction is enabled.
config_resnet = get_config('resnet', x_train)
config_vit = get_config('vit', x_train)
# Overrides the value derived from the loaded data in the dataset cell.
input_shape = (32, 32, 3) #cifar10 image size
# NOTE(review): not referenced by any other cell visible here.
image_size = 224 # 224 # 256 #size after resizing image

In [None]:
def get_model(config):
    """Load a saved classifier and derive a feature-extraction model from it.

    Args:
        config: Settings dict; reads 'model_file' (what to load) and
            'output_feature_layer_name' (layer whose output is the feature).

    Returns:
        (feature_extraction_model, base_model): the truncated model ending
        at the feature layer, and the full loaded classifier.
    """
    classifier = keras.models.load_model(config['model_file'])
    classifier_input = classifier.input
    feature_layer = classifier.get_layer(config['output_feature_layer_name'])
    # Same input as the classifier, but stop at the feature layer's output.
    extractor = Model(inputs=classifier_input, outputs=feature_layer.output)
    return extractor, classifier

In [None]:
# Load the ResNet and ViT models and build their feature models
# (this cell handles ResNet).
resnet_feature_model, resnet_base_model = get_model(config_resnet)

In [None]:
# Load the ViT classifier and build its feature-extraction model.
vit_feature_model, vit_base_model = get_model(config_vit)

In [None]:
# Inspect the architecture of the loaded ResNet base model:
# print a text summary and save a diagram (with tensor shapes) as a PNG.
resnet_base_model.summary()
model_arch_png = 'model_cifar10_%s_loaded.png' % config_resnet['model_type']
keras.utils.plot_model(resnet_base_model, to_file=model_arch_png, show_shapes=True )

In [None]:
# Inspect the architecture of the ResNet feature model:
# print a text summary and save a diagram (with tensor shapes) as a PNG.
resnet_feature_model.summary()
model_arch_png = 'model_cifar10_%s_feature.png' % config_resnet['model_type']
keras.utils.plot_model(resnet_feature_model, to_file=model_arch_png, show_shapes=True )

In [None]:
# Inspect the architecture of the loaded ViT base model:
# print a text summary and save a diagram (with tensor shapes) as a PNG.
vit_base_model.summary()
model_arch_png = 'model_cifar10_%s_loaded.png' % config_vit['model_type']
keras.utils.plot_model(vit_base_model, to_file=model_arch_png, show_shapes=True )

In [None]:
# Inspect the architecture of the ViT feature model:
# print a text summary and save a diagram (with tensor shapes) as a PNG.
vit_feature_model.summary()
model_arch_png = 'model_cifar10_%s_feature.png' % config_vit['model_type']
keras.utils.plot_model(vit_feature_model, to_file=model_arch_png, show_shapes=True )

## cifar10 test dataset에 대한 features 분석

In [None]:
def extract_features(config, model, x_dataset, y_label, start, num):
    """Run the feature model on a contiguous slice of a dataset.

    Args:
        config: Settings dict; reads 'subtract_pixel_mean' and, when that is
            true, 'x_train_mean'.
        model: Object exposing ``predict(batch)``; its output is returned as
            the feature vectors.
        x_dataset: Normalized input samples (values in [0, 1]).
        y_label: Labels aligned index-for-index with ``x_dataset``.
        start: Index of the first sample; negative values are treated as 0.
        num: Maximum number of samples to take. The slice is clipped at the
            end of the dataset, so fewer may be returned.

    Returns:
        (predictions, out_x_dataset, out_y_labels): the model output, the
        batch exactly as it was fed to the model (mean-subtracted when
        enabled), and the matching label slice.
    """
    start = max(start, 0)
    # Clip at the dataset length so a request that reaches or passes the end
    # still returns every remaining sample, including the last one.
    end = min(start + max(num, 0), len(x_dataset))

    # Retrieve a number of images from the dataset.
    data_batch = x_dataset[start:end]

    # If subtract pixel mean is enabled (ResNet case), work on a copy so the
    # caller's dataset is left untouched.
    if config['subtract_pixel_mean']:
        data_batch = np.copy(data_batch)
        data_batch -= config['x_train_mean']

    # Get predictions (feature vectors) from the model.
    predictions = model.predict(data_batch)

    return predictions, data_batch, y_label[start:end]

In [None]:
# Extract ResNet features for a window of the test set starting at index 0
# (up to 1000 samples); also keep the preprocessed inputs and their labels.
resnet_test_features, resnet_x_data, resnet_y_labels = \
    extract_features (config_resnet, resnet_feature_model, x_test, y_test, 0, 1000)

In [None]:
# Extract ViT features for the same window of the test set so the two
# models can be compared on identical samples.
vit_test_features, vit_x_data, vit_y_labels = \
    extract_features (config_vit, vit_feature_model, x_test, y_test, 0, 1000)

In [None]:
# Compare the shapes of the feature arrays produced by the two models.
resnet_test_features.shape, vit_test_features.shape

In [None]:
from sklearn.metrics.pairwise import cosine_similarity

In [None]:
# Pairwise cosine similarity between every pair of ResNet feature vectors.
resnet_sim = cosine_similarity(resnet_test_features)

In [None]:
# Pairwise cosine similarity between every pair of ViT feature vectors.
vit_sim = cosine_similarity(vit_test_features)

In [None]:
# Inspect the similarity matrix shape and its first row
# (sample 0 compared against every sample).
resnet_sim.shape, resnet_sim[0]

In [None]:
# check similarity values: the matrix entry for (idx1, idx2) should agree
# with the cosine similarity computed by hand as dot(a, b) / (|a| * |b|).
idx1=0
idx2=17
resnet_sim[idx1][idx2], \
np.dot(resnet_test_features[idx1], resnet_test_features[idx2])/ \
      (np.linalg.norm(resnet_test_features[idx1])\
       *np.linalg.norm(resnet_test_features[idx2]))

In [None]:
# sort : np.argsort orders ascending by default along the last axis, so each
# row lists sample indices from least to most similar; reverse a row for
# descending order.
# NOTE(review): not referenced by any other cell visible here.
resnet_score_ind = np.argsort(resnet_sim)# [::,-1] 

In [None]:
# Per-row ascending ordering of the ViT similarities (least similar first).
# NOTE(review): not referenced by any other cell visible here.
vit_score_ind = np.argsort(vit_sim)# [::,-1] 

In [None]:
def show_similar_images(config, similarity, query_index, top_k,
                        x_dataset, y_label,
                        cls_model, do_preprocess=False):
    """Plot a query image together with its top-k most similar images.

    The query image is drawn alone in the first grid row with a green title
    ``"<index>: <true class>, <self similarity>"``. The retrieved images
    follow from the second row on, each titled
    ``"<index>: <true class>:<predicted class>, <similarity>"`` in green when
    the classifier's prediction matches the ground truth and red otherwise.
    The fraction of matching predictions is printed at the end.

    Args:
        config: Settings dict; reads 'subtract_pixel_mean' and, when that is
            true, 'x_train_mean'.
        similarity: Square similarity matrix indexed as
            ``similarity[query][candidate]``, aligned with ``x_dataset``.
        query_index: Row of ``similarity`` / index of the query image.
        top_k: Number of most-similar images to show (the query itself is a
            candidate too).
        x_dataset: Images aligned with ``similarity``. When
            ``config['subtract_pixel_mean']`` is true and ``do_preprocess``
            is False, these must already be mean-subtracted, because the
            mean is added back before display.
        y_label: One-hot labels aligned with ``x_dataset``.
        cls_model: Classifier exposing ``predict(batch)``; used only to show
            the predicted class of each retrieved image.
        do_preprocess: When True, subtract the pixel mean here (if enabled in
            ``config``) before calling the classifier.

    Returns:
        None. The figure is drawn on the current matplotlib backend.
    """
    class_names = ['airplane',
                   'automobile',
                   'bird',
                   'cat',
                   'deer',
                   'dog',
                   'frog',
                   'horse',
                   'ship',
                   'truck']
    num_cols = 7

    # Candidate indices ordered from most to least similar, cut to top-k.
    similarity_row = similarity[query_index]
    score_idx = np.argsort(similarity_row)[::-1][:top_k]
    num_results = len(score_idx)

    # One row for the query plus just enough rows for the results.
    num_rows = 1 + max(1, -(-num_results // num_cols))

    # Out-of-place subtraction: the caller's array must not be modified.
    model_input = np.asarray(x_dataset)
    if do_preprocess and config['subtract_pixel_mean']:
        model_input = model_input - config['x_train_mean']

    # Only the retrieved images need a class prediction.
    if num_results:
        predictions = cls_model.predict(model_input[score_idx])
        predicted_classes = np.argmax(predictions, axis=1)
    else:
        predicted_classes = []

    # Add the mean back so the images return to the [0, 1] display range.
    if config['subtract_pixel_mean']:
        display_images = model_input + config['x_train_mean']
    else:
        display_images = model_input

    plt.figure(figsize=(20, 2*num_rows))

    # display query image in the top row
    plt.subplot(num_rows, num_cols, 1)
    plt.axis("off")
    plt.imshow(display_images[query_index])
    query_class = int(np.flatnonzero(y_label[query_index])[0])
    title = f"{query_index}: {class_names[query_class]}, {similarity[query_index][query_index]:.2f}"
    title_obj = plt.title(title, fontdict={'fontsize':13})
    plt.setp(title_obj, color='g')

    # display similar images
    # displayed text format:
    #   image number: ground truth class name:predicted class name, similarity
    num_matches = 0
    for position, (idx, pred_idx) in enumerate(zip(score_idx, predicted_classes)):
        # Results start on the second grid row.
        plt.subplot(num_rows, num_cols, position + num_cols + 1)
        plt.axis("off")
        plt.imshow(display_images[idx])

        pred_idx = int(pred_idx)
        truth_idx = int(np.flatnonzero(y_label[idx])[0])

        title = f"{idx}: {class_names[truth_idx]}:{class_names[pred_idx]}, {similarity[query_index][idx]:.2f}"
        title_obj = plt.title(title, fontdict={'fontsize':13})

        if pred_idx == truth_idx:
            num_matches += 1
            plt.setp(title_obj, color='g')
        else:
            plt.setp(title_obj, color='r')

    # Guard against an empty result set (e.g. top_k == 0).
    acc = num_matches / num_results if num_results else 0.0
    print("Prediction accuracy: ", int(100*acc)/100)


In [None]:
# Retrieval settings shared by the ResNet and ViT runs below:
# index of the query image within the extracted window, and how many of
# its most similar images to display.
query_index = 10 # 10
top_k = 64

In [None]:
# Retrieval with ResNet features. resnet_x_data is the batch returned by
# extract_features, so do_preprocess is left at its default (False).
show_similar_images(config_resnet, resnet_sim, query_index, top_k, 
                    resnet_x_data, resnet_y_labels, resnet_base_model)

In [None]:
# Retrieval with ViT features for the same query, for side-by-side
# comparison with the ResNet result.
show_similar_images(config_vit, vit_sim, query_index, top_k, 
                    vit_x_data, vit_y_labels, vit_base_model)

In [None]:
!nvidia-smi