In [None]:
# Necessary imports - done
# ------------------------

import os
from PIL import Image
import time
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import random
import copy
import cv2
from os import listdir
from os.path import isfile, join
import h5py
import time
import multiprocessing
from multiprocessing import Pool
from multiprocessing.dummy import Pool as ThreadPool
import math
import sys
from tqdm import tqdm
import pandas as pd
import csv
import json
import imageio

import torch
from torch.autograd import Variable
import torch.autograd
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils as vutils
from torchvision.utils import save_image
import torch.optim as optim
from torchvision import datasets, models, transforms


###
###
%matplotlib inline
%env JOBLIB_TEMP_FOLDER=/tmp


# printing platform info
# ----------------------
import platform
print(platform.python_version())

# codes

In [None]:
def cosine_similarity_multi(a,b):
    
    '''
    
    Row-wise cosine similarity between two 2-D arrays.
    
    1. a,b are of shape (m,no_latent); either one may instead have a single
       row (1,no_latent), in which case numpy broadcasts it against the other
    2. output is of shape (m,1)
    3. a pair in which either vector has zero norm gets similarity 0.0
       (instead of nan/inf and a RuntimeWarning)
    
    '''
    
    # 1. direct steps
    # ---------------
    dot_prod = np.sum(a*b, axis = 1)
    norm_a = np.sqrt(np.sum(np.square(a),axis = 1))
    norm_b = np.sqrt(np.sum(np.square(b),axis = 1))
    denom = norm_a*norm_b
    
    # 2. guarded division
    # -------------------
    # entries whose denominator is zero are skipped by the ufunc and keep the
    # 0.0 they were initialised with
    out = np.zeros(dot_prod.shape, dtype = np.result_type(dot_prod.dtype, denom.dtype))
    np.divide(dot_prod, denom, out = out, where = denom != 0)
    
    
    # final return
    # -------------
    return out.reshape(out.shape[0],1)

In [None]:
# simple function
# ----------------
def pre_process_pretrained_model_data(x_in):
    
    '''
    
    Prepares an image batch for the pretrained model.
    
    1. x_in is a float tensor indexed as (batch, channels, h, w) holding pixel
       values on a 0-255 scale, with either 1 or 3 channels
    2. returns a NEW 3 channel tensor - x_in itself is never modified:
       a. a single channel is repeated three times
       b. values are divided by the maximum of the batch passed in
       c. each channel is standardised as (x - mean)/std_dev
    3. raises AssertionError if the data does not look like 0-255 data or
       if the channel count is not 1 or 3
    
    '''
    
    # inits
    # -----
    # every op below is out-of-place, so no defensive copy is needed
    # (copy.deepcopy would also fail on tensors that are not graph leaves)
    assert torch.mean(x_in) > 1,'Error: data must be in [0,255] range.'
    x = x_in
    
    # per channel statistics, shaped to broadcast over (batch, 3, h, w)
    # NOTE: the output is standardised per channel, it is NOT confined to [-1,1]
    # NOTE(review): these look like the usual ImageNet channel statistics -
    # confirm against the pipeline the model was trained with
    # created on the same device / dtype as the data so GPU batches work too
    # ---------------------------------------------------------------------
    mean = torch.tensor([0.485, 0.456, 0.406], dtype=x.dtype, device=x.device).view(1,3,1,1)
    std_dev = torch.tensor([0.229, 0.224, 0.225], dtype=x.dtype, device=x.device).view(1,3,1,1)

    # make data 3 channel
    # -------------------
    no_channels = x.size()[1]
    if no_channels == 1:
        x = torch.cat((x,x,x), 1)
    elif no_channels != 3:
        # explicit raise so the check also survives `python -O`
        raise AssertionError('Error: number of channels in the input datamust be either 1 or 3.')

    # final ops
    # bring data to [0,1] first then apply normalise
    # ----------------------------------------------
    x = x/torch.max(x)
    x = (x - mean)/std_dev
    
    # final return
    # ------------
    return x

In [None]:
# GENERIC - change a torch image batch to a numpy image batch
# -----------------------------------------------------------
def to_numpy_image(xin):
    
    '''
    
    1. xin is a torch tensor (CPU or GPU, with or without grad) or a numpy
       array, indexed as (batch, channels, h, w)
    2. returns a numpy array indexed as (batch, h, w, channels)
    
    '''
    
    if isinstance(xin, np.ndarray):
        
        # already numpy - only the axes need moving
        # -----------------------------------------
        arr = xin
    else:
        
        # detach drops autograd tracking, cpu is a no-op for CPU tensors
        # --------------------------------------------------------------
        arr = xin.detach().cpu().numpy()
    
    xout = np.swapaxes(arr,1,2)
    xout = np.swapaxes(xout,2,3)
    
    # returns axes swapped numpy images
    # ---------------------------------
    return xout       



# GENERIC - converts numpy images to torch tensors for training
# -------------------------------------------------------------
def setup_image_tensor(xin):
    
    '''
    
    1. xin is a numpy image batch indexed as (batch, h, w, channels)
    2. returns a float32 torch tensor indexed as (batch, channels, h, w)
    
    '''
    
    # moving the channel axis right behind the batch axis
    # ---------------------------------------------------
    channels_first = np.moveaxis(xin, 3, 1)
    
    # returns axes swapped torch tensor
    # ---------------------------------
    return torch.from_numpy(channels_first).float()

In [None]:
# a function to load a saved model
# --------------------------------

def load_saved_model_function(path, use_cuda):
    
    
    ''' path = /folder1/folder2/model_ae.tar format
    
    Restores a model and its Adam optimizer from two files:
    
    1. <path minus .tar>_MODEL.tar - the full model object
    2. path                        - a checkpoint dict with the keys
       model_state_dict, optimizer_state_dict, epoch, loss, loss_mode
    
    use_cuda == True keeps the storages on the devices they were saved from and
    moves the optimizer state tensors to the GPU; otherwise everything is
    mapped to the CPU while loading.
    
    returns model, optimizer, epoch, loss, loss_mode
    
    NOTE: torch.load unpickles the files, which can execute arbitrary code -
    only load model files from a trusted source.
    NOTE(review): newer torch releases may need weights_only=False to unpickle
    a full model object - confirm against the installed version.
    '''
    
    # 0. where to map storages while loading
    # --------------------------------------
    on_gpu = (use_cuda == True)
    if on_gpu:
        map_location = None
    else:
        map_location = lambda storage, loc: storage
    
    # only the extension is rewritten, a '.tar' earlier in the path is left alone
    # ---------------------------------------------------------------------------
    if path.endswith('.tar'):
        full_model_path = path[:-len('.tar')] + '_MODEL.tar'
    else:
        full_model_path = path.replace('.tar','_MODEL.tar')
    
    # 1. loading full model
    # ---------------------
    # map_location is applied here as well, so that a model saved from a GPU
    # can be opened on a CPU only machine
    model = torch.load(full_model_path, map_location=map_location)
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,model.parameters()))
    
    # 2. Applying state dict
    # ----------------------
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    
    # loading optimizer
    # -----------------
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if on_gpu:
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()
    
    # loading other stuff
    # -------------------
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    loss_mode = checkpoint['loss_mode']
    
    return model, optimizer, epoch, loss, loss_mode
    
    


In [None]:
# functions to build image datasets (rgb, grayscale & edge) from all images in a folder
# -------------------------------------------------------------------------------------

def create_dataset_from_folder_all(infolder,resize,gray_mode,n_h,n_w):
    
    '''
    
    Loads every .jpg / .png file of a folder (non recursive) into global arrays.
    
    1. inputs
    
    a. infolder  - folder to scan
    b. resize    - True to have every image resized to (n_h,n_w) by the worker
    c. gray_mode - 'bw'     : grayscale thresholded at each image's own mean, 1 channel
                   'gray_3' : grayscale repeated over 3 channels
                   anything else : grayscale, 1 channel
    d. n_h, n_w  - image height & width of the output arrays
    
    2. nothing is returned, results are published as globals
    
    a. x_images_dataset      - (m,n_h,n_w,3) uint8
    b. x_images_dataset_gray - (m,n_h,n_w,1) or (m,n_h,n_w,3) uint8
    c. x_images_dataset_edge - (m,n_h,n_w,1)
    d. image_list, index_list, counter - file names, indices of the files that
       loaded successfully (ascending) and their count
    
    3. files the worker fails on are dropped; the remaining images keep the
       order of image_list
    4. raises AssertionError if the folder holds no image file
    
    '''
    
    # 0. global initialisations
    # -------------------------
    global index_list, resize_flag, gb_in_folder, counter, new_h, new_w
    global image_list, x_images_dataset, x_images_dataset_gray, x_images_dataset_edge
    
    index_list = []
    resize_flag = resize
    gb_in_folder = infolder
    counter = 0
    new_h = n_h
    new_w = n_w
    
    # 1. listing images - a single scan, jpgs first then pngs
    # a name containing both extensions is only listed once
    # -------------------------------------------------------
    folder_files = [f for f in listdir(infolder) if isfile(join(infolder, f))]
    image_list_jpg = [f for f in folder_files if '.jpg' in f.lower()]
    image_list_png = [f for f in folder_files if '.png' in f.lower() and '.jpg' not in f.lower()]
    image_list = image_list_jpg + image_list_png
    
    # 1.1 sanity assertion
    # -------------------
    assert len(image_list) > 0, 'No images in the folder'
    
    x_images_dataset = np.zeros((len(image_list),new_h,new_w,3), dtype='uint8')
    x_images_dataset_gray = np.zeros((len(image_list),new_h,new_w), dtype='uint8')
    x_images_dataset_edge = np.zeros((len(image_list),new_h,new_w,1))
    
    
    # 2. calling resize function across multiprocessing pool
    # ------------------------------------------------------
    pool = ThreadPool(5) 
    try:
        pool.map(create_dataset_from_folder_single, list(range(len(image_list))))
    finally:
        
        # closing pools - also when a worker raised
        # -----------------------------------------
        pool.terminate()
        pool.join()
    
    # workers append in completion order - sorting makes the dataset order
    # deterministic; counter is re-derived as the threaded += is not atomic
    # ----------------------------------------------------------------------
    index_list.sort()
    counter = len(index_list)
    
    # 2.1 sanity assert
    # -----------------
    assert x_images_dataset.shape[0] == x_images_dataset_gray.shape[0], 'RGB and Grayscale images have different numbers of images!'
    
    # 3. filtering out the dataset
    # ----------------------------
    x_images_dataset = x_images_dataset[index_list]
    x_images_dataset_gray = x_images_dataset_gray[index_list]
    x_images_dataset_edge = x_images_dataset_edge[index_list]
    
    
    # 4. grayscale post processing
    # ----------------------------
    if gray_mode == 'bw':
        
        # hard b/w single channel, thresholded at each image's own mean
        # -------------------------------------------------------------
        mn = np.mean(np.mean(x_images_dataset_gray, axis = 1), axis = 1)
        mn = mn.reshape(mn.shape[0],1,1)
        x_images_dataset_gray = np.where(x_images_dataset_gray >= mn, 255, 0).astype('uint8')
    
    # every mode ends up with an explicit channel axis
    # ------------------------------------------------
    x_images_dataset_gray = x_images_dataset_gray.reshape(x_images_dataset_gray.shape[0],new_h,new_w,1)
    
    if gray_mode == 'gray_3':
        
        # grayscale 3 channel
        # -------------------
        x_images_dataset_gray = np.repeat(x_images_dataset_gray, 3, axis = 3)
        

    print('Done creating dataset of around ' + str(counter) + ' images. Access them at global x_images_dataset, x_images_dataset_gray, x_images_dataset_edge.')
    

def create_dataset_from_folder_single(i):
    
    '''
    
    Worker for create_dataset_from_folder_all - loads image number i of the
    global image_list and writes it into row i of the global datasets.
    
    1. reads the file with cv2 and builds an RGB and a grayscale version
    2. resizes both to (new_h,new_w) when resize_flag is True
    3. builds an edge image (7x7 gaussian blur then Canny with thresholds 50,150)
    4. on success increments counter and appends i to index_list
    
    Best effort by design: any failure on a single file (unreadable file,
    unexpected size, ...) skips that file without raising; it is then simply
    missing from index_list.
    
    '''
    
    # 0. calling global variables
    # ---------------------------
    global gb_in_folder, counter, new_h, new_w, image_list, resize_flag, index_list
    global x_images_dataset, x_images_dataset_gray, x_images_dataset_edge
    
    
    # 1. ops
    # ------
    try:
        name = image_list[i]
        img_main = cv2.imread(join(gb_in_folder, name))
        
        # imread is expected to hand back None for a file it cannot decode
        # ----------------------------------------------------------------
        if img_main is None:
            return
        
        # cvtColor returns new arrays, img_main needs no copying
        # ------------------------------------------------------
        img = cv2.cvtColor(img_main, cv2.COLOR_BGR2RGB)
        img_gray = cv2.cvtColor(img_main, cv2.COLOR_BGR2GRAY)
        
        # resizing ops
        # -----------
        if resize_flag == True:
            img = cv2.resize(img, (new_w,new_h))
            img_gray = cv2.resize(img_gray, (new_w,new_h))
            

        # 5. by default building edge images
        # the reshape fails (and the file is skipped) if the image is not new_h x new_w
        # -----------------------------------------------------------------------------
        blurred = cv2.GaussianBlur(img_gray.reshape(new_h,new_w).astype('uint8'), (7, 7), 0)
        edged = cv2.Canny(blurred, 50, 150)
        edged = edged.reshape(new_h,new_w,1)
        
        # final assignments
        # ------------------
        x_images_dataset[i] = img
        x_images_dataset_gray[i] = img_gray
        x_images_dataset_edge[i] = edged
        
        counter += 1
        index_list.append(i)
    
    except Exception:
        
        # deliberate best effort - skip this file
        # (KeyboardInterrupt / SystemExit are no longer swallowed)
        # --------------------------------------------------------
        return
    


In [None]:
# a simple forward pass function
# ------------------------------


def simple_forward_pass_pool(use_pretrained_in,xin,input_is_image,model):
    
    '''
    
    Runs model over every example of xin using a pool of 10 threads.
    
    1. inputs
    
    a. use_pretrained_in - True to pass every example through
       pre_process_pretrained_model_data before the model
    b. xin               - input tensor, examples along axis 0
    c. input_is_image    - True asserts xin is not already in 0-1 range and,
       when use_pretrained_in is False, divides it by its overall maximum
    d. model             - callable applied to batches of one example
    
    2. nothing is returned - outputs are written to the global y_out_global,
       of shape (no of examples,) + output shape of a single example
    3. xin itself is never modified
    
    '''
    
    # 0. initialisations
    # ------------------
    global use_pretrained_in_func, model_global, x_in_global, y_out_global
    
    use_pretrained_in_func = use_pretrained_in
    model_global = model
    
    # no defensive copy needed - the input is only read or replaced by a new tensor
    # ----------------------------------------------------------------------------
    x_in_global = xin
    
    # 1. setting up y_out size
    # ------------------------
    # a probe pass over (up to) two examples is only used for its output shape,
    # so no autograd graph is recorded for it
    with torch.no_grad():
        if use_pretrained_in == True:
            probe_out = model(pre_process_pretrained_model_data(xin[0:2]))
        else:
            probe_out = model(xin[0:2])
    final_size = tuple([xin.size()[0]] + list(probe_out.size())[1:])
    y_out_global = torch.zeros((final_size))
    
    
    # 1.1 sanity
    # ----------
    if input_is_image == True:
        assert torch.mean(x_in_global) > 1, 'Error: Data already in 0-1 range'
        if use_pretrained_in == False:
            x_in_global = x_in_global/torch.max(x_in_global)
        
    
    # 2. calling pool function
    # ------------------------
    pool = ThreadPool(10)
    try:
        pool.map(simple_forward_pass_single, list(range(xin.size()[0])))
    finally:
        
        # releasing the worker threads - also when a worker raised
        # --------------------------------------------------------
        pool.close()
        pool.join()
    #print('Done with forward pass of around ' + str(xin.size()[0]) + ' examples. Access them at global y_out_global.')
    
    
    
    
def simple_forward_pass_single(i):
    
    '''
    
    Worker for simple_forward_pass_pool - runs example i of the global
    x_in_global through the global model_global as a batch of one and stores
    the result in row i of the global y_out_global.
    
    When use_pretrained_in_func is True the example is first passed through
    pre_process_pretrained_model_data (which scales by the maximum of that
    single example).
    
    '''
    
    # 0. global inits
    # ---------------
    global model_global, x_in_global, y_out_global, use_pretrained_in_func
    
    # 1. ops
    # ------
    # adding the batch axis - works for examples of any rank
    curr_example = x_in_global[i].unsqueeze(0)
    
    # pure inference: grad mode is per thread, so it has to be switched off here
    # in the worker; this also keeps the shared output tensor free of autograd
    # history written from several threads
    # --------------------------------------------------------------------------
    with torch.no_grad():
        if use_pretrained_in_func == True:
            curr_out = model_global(pre_process_pretrained_model_data(curr_example))
        else:
            curr_out = model_global(curr_example)
        y_out_global[i] = curr_out[0]
    
    

In [None]:
# function to simply return similar images based on whole latents
# ---------------------------------------------------------------

def _similarity_scores(example, db_latent, similarity_check_mode, epsilon):
    
    ''' scores one query vector (no_latent,) against every row of db_latent (m,no_latent) - returns shape (m,) '''
    
    if similarity_check_mode == 'l2':
        
        # euclidean distance - smaller means more similar
        # -----------------------------------------------
        return np.sqrt(np.sum((example-db_latent)**2, axis = 1))
    
    if similarity_check_mode in ('cosine_ratio', 'ratio'):
        
        # average element wise min/max ratio
        # ----------------------------------
        ratio = np.average((np.minimum(example,db_latent)/(np.maximum(example,db_latent) + epsilon)), axis = 1)
        if similarity_check_mode == 'ratio':
            return ratio
    
    cosine = cosine_similarity_multi(example.reshape(1,example.shape[0]),db_latent)
    cosine = cosine.reshape(cosine.shape[0],)
    if similarity_check_mode == 'cosine_ratio':
        return cosine + ratio
    return cosine


def similarity(input_latents_list,db_latents_list,xin,xdb,sim_weights,no_suggestions,similarity_check_mode,print_result):

    
    '''
    
    Ranks db items by similarity to every input item.
    
    1. input includes
    
    a. input_latents_list / db_latents_list - lists of 2-D numpy arrays of equal
       length; entry k of both lists holds the same kind of vector, one row per
       input image / db image
    b. xin - input images, only used for display when print_result is True
    c. xdb - db images, indexed with the ranked indices to build the result
    d. sim_weights - 'equal' or one weight per entry of the latent lists
    e. no_suggestions - number of top ranked db items kept per input
    f. similarity_check_mode - 'l2', 'cosine', 'ratio' or 'cosine_ratio'
    g. print_result - True to plot every input with its suggestions
    
    2. per input, the scores of all latent kinds are multiplied by their weight
       and summed; 'l2' ranks ascending (it is a distance), all other modes
       rank descending
    
    3. returns similar_products, all_similarity_values, final_all_indices -
       three lists with one entry per input: the selected rows of xdb, their
       weighted total scores and their db indices, best match first
    
    '''
    
    # 1. initialisations
    # ------------------
    valid_modes = ('l2', 'cosine_ratio', 'cosine', 'ratio')
    weighted_totals = {}
    final_all_indices = []
    similar_products = []
    all_similarity_values = []
    epsilon = 0.0001
    
    # 2. sanity assertions
    # --------------------
    assert len(input_latents_list) == len(db_latents_list), 'Error: length of input & db lists dont match.'
    assert similarity_check_mode in valid_modes, 'Error: invalid similarity check mode. This can be "l2", "cosine", "ratio" or "cosine_ratio" only.'
    if isinstance(sim_weights, str) and sim_weights == 'equal':
        
        # setting sim weight values to 1
        # ------------------------------
        sim_weights = [1.0 for _ in range(len(input_latents_list))]
    else:
        assert len(sim_weights) == len(input_latents_list), 'Error: number of weights & length of latent lists dont match.'
    
    # 3. looping through the INPUT latents list
    # -----------------------------------------
    print('1. looping through latent list..')
    for l_counter in range(len(input_latents_list)):
        
        curr_input_latent = input_latents_list[l_counter]
        db_input_latent = db_latents_list[l_counter]
        curr_weight = sim_weights[l_counter]
        
        # 3.1 itering through every image in curr_input_latent
        # ----------------------------------------------------
        for i in range(curr_input_latent.shape[0]):
            
            # weighted score of example i against the whole db for this latent kind
            # ---------------------------------------------------------------------
            curr_weighted_array = _similarity_scores(curr_input_latent[i], db_input_latent, similarity_check_mode, epsilon) * curr_weight
            
            # weighted_totals[i] = running sum over all latent kinds, shape (m,)
            # ------------------------------------------------------------------
            if i in weighted_totals:
                weighted_totals[i] = weighted_totals[i] + curr_weighted_array
            else:
                weighted_totals[i] = curr_weighted_array
    
    
    # 4. itering thru dict & ranking
    # ------------------------------
    print('2. applying weights & appending final results..')
    for keys in weighted_totals:
        
        similarity_array_final = weighted_totals[keys]
        
        if similarity_check_mode == 'l2':
            
            # since in l2 distances are calculated and smaller dis = higher sim
            # ------------------------------------------------------------------
            sorted_indices = np.argsort(similarity_array_final)
        else:
            sorted_indices = np.argsort(-1*similarity_array_final)
        
        final_indices = [int(fi) for fi in sorted_indices[0:no_suggestions]]
        final_all_indices.append(final_indices)
        
        # values come from the weighted total of THIS input
        # -------------------------------------------------
        all_similarity_values.append(similarity_array_final[final_indices])
        similar_products.append(xdb[final_indices])
        
    
    # 5. showing if required
    # ----------------------
    if print_result == True:
        
        # itering
        # -------
        for i in range(xin.shape[0]):
            
            print('>> At image ' + str(i) + ' of around ' + str(xin.shape[0]) + '..')
            print('** Input Image - ')
            plt.figure(figsize=(2,2))
            plt.imshow(xin[i])
            plt.show()
           
            print('** Showing result images..')
            fig=plt.figure(figsize=(25, 25))
            no_results = similar_products[i].shape[0]
            columns = 5
            
            # 10 rows as before, more only if the suggestions would not fit
            # -------------------------------------------------------------
            rows = max(10, -(-no_results // columns))
            
            for i_1 in range(no_results):
                
                fig.add_subplot(rows, columns, i_1+1)
                if xdb.shape[3] > 1:
                    plt.imshow(similar_products[i][i_1])
                else:
                    plt.imshow(similar_products[i][i_1,:,:,0], cmap='gray')
            
            plt.show()
            
            print('\n#########################################\n')
                
    
    # 6. final return
    # ---------------
    return similar_products,all_similarity_values,final_all_indices

# models

In [None]:
# none for now

# execution

## 1. setups

In [None]:
# set up all the variables here
# -----------------------------
img_h, img_w = 224,224
use_cuda = False
use_pretrained = True

# NOTE: machine specific absolute paths - point these to your own model & folders before running
# -----------------------------------------------------------------------------------------------
model_path = '/Users/venkateshmadhava/Documents/pmate2/pmate2_env/models/resnet50_pretrained_poc_5_class_GRAY_bird_apple_fish_skull_star_imgaug_and_text_aug.tar'
db_folder = '/Users/venkateshmadhava/Desktop/temp_db'
input_folder = '/Users/venkateshmadhava/Desktop/temp_input'

# class labels -- DO NOT CHANGE ORDER AS MODEL PREDICTIONS ARE DEPENDENT
# (a module level name is already global, no `global` statement is needed)
# ------------------------------------------------------------------------
class_labels = ['Bird','Apples','Fish','Skulls','Star']

In [None]:
# load the model
# --------------
# only the model is needed here - the optimizer, epoch, loss & loss_mode that
# load_saved_model_function also returns are discarded
model_mimgcls,_,_,_,_ = load_saved_model_function(model_path, use_cuda)
# eval() switches the model to inference mode
model_mimgcls = model_mimgcls.eval()

# setting up latent if using pretrained
# -------------------------------------
if use_pretrained == True:
    
    # slicing FC to get latent
    # ------------------------
    # a deep copy is truncated so that the classifier model itself stays intact
    model_mimgcls_for_latent = copy.deepcopy(model_mimgcls)
    # keeps only entries 0,1,2 of the fc head
    # NOTE(review): assumes fc is a sliceable container (e.g. nn.Sequential) - confirm
    model_mimgcls_for_latent.fc = model_mimgcls_for_latent.fc[0:3]
else:
    
    # incase we use a custom model, latent will be defined in it
    # ----------------------------------------------------------
    model_mimgcls_for_latent = model_mimgcls.latent
    

## 2. loading data & extracting feature vectors

### 2.1 db data

In [None]:
# loading DB data
# ---------------
# create_dataset_from_folder_all publishes its results as module level globals
# (x_images_dataset, x_images_dataset_gray, ...); they can be read directly,
# a `global` statement has no effect at module level

create_dataset_from_folder_all(db_folder,True,None,img_h,img_w)
x_db_cls_trn = Variable(setup_image_tensor(x_images_dataset_gray)).float()
x_db_rgb = x_images_dataset
print(x_db_cls_trn.size())

In [None]:
# extraction
# ----------
# simple_forward_pass_pool writes its result to the module level name
# y_out_global; it is moved to a local name and deleted after every pass.
# (no `global` statements: they do nothing at module level and, placed after
# the name has been used, stop this cell from compiling as a plain script)

# 1. getting prediction
# ---------------------   
simple_forward_pass_pool(use_pretrained,x_db_cls_trn,True,model_mimgcls)
db_pred = y_out_global
del y_out_global

# sanity
# ------
print(db_pred.size())

# 2. getting latent
# -----------------
simple_forward_pass_pool(use_pretrained,x_db_cls_trn,True,model_mimgcls_for_latent)
db_pred_latent = y_out_global
del y_out_global

# sanity
# ------
print(db_pred_latent.size())


### 2.2 input queries data

In [None]:
# loading input queries data
# ---------------------------
# create_dataset_from_folder_all publishes its results as module level globals
# (x_images_dataset, x_images_dataset_gray, ...); they can be read directly,
# a `global` statement has no effect at module level

create_dataset_from_folder_all(input_folder,True,None,img_h,img_w)
x_input_cls_trn = Variable(setup_image_tensor(x_images_dataset_gray)).float()
x_input_rgb = x_images_dataset
print(x_input_cls_trn.size())

In [None]:
# extraction
# ----------
# simple_forward_pass_pool writes its result to the module level name
# y_out_global; it is moved to a local name and deleted after every pass.
# (no `global` statements: they do nothing at module level and, placed after
# the name has been used, stop this cell from compiling as a plain script)

# 1. getting prediction
# ---------------------   
simple_forward_pass_pool(use_pretrained,x_input_cls_trn,True,model_mimgcls)
input_pred = y_out_global
del y_out_global

# sanity
# ------
print(input_pred.size())

# 2. getting latent
# -----------------
simple_forward_pass_pool(use_pretrained,x_input_cls_trn,True,model_mimgcls_for_latent)
input_pred_latent = y_out_global
del y_out_global

# sanity
# ------
print(input_pred_latent.size())


### 2.3 setting up latent list

In [None]:
# numpy views of the model outputs, in the order [class predictions, latents]
# the same order is used for db & input so that entries pair up in similarity()
# ----------------------------------------------------------------------------

db_latents_list = [model_out.detach().data.numpy() for model_out in (db_pred, db_pred_latent)]
input_latents_list = [model_out.detach().data.numpy() for model_out in (input_pred, input_pred_latent)]

## 3. similarity

In [None]:
# set up weightages of input latents lists
# [1,1] means equal weightage for similarity between all vectors in the list
# one weight per entry of input_latents_list / db_latents_list
# ---------------------------------------------------------------------------
sim_w = [1,1]

# actual function
# ---------------
# positional arguments after sim_w: no_suggestions = 50,
# similarity_check_mode = 'cosine', print_result = True (plots every query with its matches)
# the returned (similar_products, all_similarity_values, final_all_indices) tuple is discarded
_ = similarity(input_latents_list,db_latents_list,x_input_rgb,x_db_rgb,sim_w,50,'cosine',True)

# rough