In [None]:
# Necessary imports - done
# ------------------------

import os
from PIL import Image
import time
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import random
import copy
import cv2
from os import listdir
from os.path import isfile, join
import h5py
import time
import multiprocessing
from multiprocessing import Pool
from multiprocessing.dummy import Pool as ThreadPool
import math
import sys
from tqdm import tqdm
import pandas as pd
import csv
import json
import imageio

import torch
from torch.autograd import Variable
import torch.autograd
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils as vutils
from torchvision.utils import save_image
import torch.optim as optim

from __future__ import print_function
from io import BytesIO


# scipy related
# -------------
import scipy
from scipy.ndimage.interpolation import map_coordinates
from scipy.ndimage.filters import gaussian_filter
from scipy.ndimage import filters
from scipy import misc

# working now
# -----------
#import skimage.io
#from skimage.transform import rotate, AffineTransform, warp
#from skimage.util import random_noise
#from skimage.filters import gaussian
#from skimage import transform as tf


# necessary imports for imgaug
# ------------------------------
import imgaug as ia
from imgaug import augmenters as iaa
from imgaug.augmentables.bbs import BoundingBox, BoundingBoxesOnImage

###
###
%matplotlib inline
%env JOBLIB_TEMP_FOLDER=/tmp


# printing platform info
# ----------------------
import platform
print(platform.python_version())

# codes

### helper functions

In [None]:
# resize pool function
# --------------------

def resize_pool(x,h,w):
    """Resize every image in ``x`` to ``(h, w)`` using a pool of 5 threads.

    Parameters
    ----------
    x : array with a 4-entry ``shape``; ``x.shape[0]`` images are processed
        and ``x.shape[3]`` is used as the channel count of the output.
    h, w : target height and width.

    Returns
    -------
    None. The result is written to the module-level global ``x_out``, a
    ``np.zeros`` array of shape ``(x.shape[0], h, w, x.shape[3])``. The inputs
    are shared with the workers through the globals ``x_in``, ``h_in``, ``w_in``.
    """

    # 0. global inits - the worker function reads these
    # -------------------------------------------------
    global x_in, h_in, w_in, x_out
    x_in = x
    h_in, w_in = h, w
    x_out = np.zeros((x.shape[0],h,w,x.shape[3]))

    # 1. calling resize function across the thread pool
    # -------------------------------------------------
    pool = ThreadPool(5)
    try:
        pool.map(resize_single, list(range(x.shape[0])))
    finally:
        # always release the worker threads, even when a worker raised
        pool.terminate()
        pool.join()

    # sanity
    # ------
    print('done with ' + str(len(x)) +' images. access at global x_out.')
    


# single function
# ----------------
def resize_single(index):
    """Worker for ``resize_pool``: resize one image and store it in ``x_out``.

    Reads image ``index`` from the global ``x_in``, resizes it to the global
    target size ``(h_in, w_in)`` with ``cv2.resize`` and writes the result to
    ``x_out[index]``. Returns None.
    """
    global x_in, h_in, w_in, x_out

    # cv2.resize takes the target size as (width, height)
    resized = cv2.resize(x_in[index],(w_in,h_in))

    # reshape back to (h, w, channels) before storing
    target_shape = (h_in, w_in, x_in.shape[3])
    x_out[index] = resized.reshape(target_shape)
    
    


In [None]:
# simple function to return filtered inds
# ---------------------------------------

def return_filtered_concepts_inds(pred_input_in,pred_db_in):
    """For each input prediction, return the indices of compatible db rows.

    Both prediction tensors are binarised at 0.5. A db row is kept for input
    ``i`` when it has no positive entry at any position where the binarised
    input prediction ``i`` is zero.

    Parameters
    ----------
    pred_input_in : torch tensor, indexed per input example along axis 0.
    pred_db_in : torch tensor, one row per db example (summed along axis 1).

    Returns
    -------
    list with one 1-D numpy index array per input example.

    The inputs are copied before thresholding: ``.numpy()`` shares memory with
    the tensor, so thresholding in place would overwrite the caller's
    predictions.
    """

    # 0. initialisations - work on private copies
    # -------------------------------------------
    pred_input = pred_input_in.detach().cpu().numpy().copy()
    pred_input[pred_input < 0.5] = 0
    pred_input[pred_input >= 0.5] = 1

    pred_db = pred_db_in.detach().cpu().numpy().copy()
    pred_db[pred_db < 0.5] = 0
    pred_db[pred_db >= 0.5] = 1

    inds_list = []

    # main iter
    # ---------
    for i in range(pred_input.shape[0]):

        # positions where this input predicts zero, broadcast against every db row;
        # a row survives only if it is zero at all of those positions
        # ------------------------------------------------------------------------
        zero_pos = (pred_input[i]==0).astype(int) * pred_db
        zero_pos_sum = np.sum(zero_pos, axis = 1)
        inds_list.append(np.argwhere(zero_pos_sum == 0)[:,0])

    # final return
    # ------------
    return inds_list

In [None]:
# GENERIC function to calculate conv outsize
# -------------------------------------------- 
def outsize_conv(n_H,n_W,f,s,pad):
    """Return the (height, width) produced by a convolution.

    Each axis is computed as ``(n - f + 2*pad)/s + 1`` for input size ``n``,
    filter size ``f``, stride ``s`` and padding ``pad``. True division is
    used and nothing is floored, so the results are floats and may be
    fractional when the stride does not divide evenly.
    """
    def _along(n):
        return ((n - f + (2*pad))/s) + 1

    return _along(n_H), _along(n_W)
    
    
# GENERIC function to calculate upconv outsize
# --------------------------------------------    
def outsize_upconv(h,w,f,s,p):
    """Return the (height, width) produced by an up-convolution.

    Each axis is computed as ``(n - 1)*s - 2*p + f`` for input size ``n``,
    filter size ``f``, stride ``s`` and padding ``p``.
    """
    return tuple((n - 1)*s - 2*p + f for n in (h, w))



# GENERIC - initialises weights for a NN
# --------------------------------------
def weights_init(m):
    """Initialise the weights of module ``m`` in place, chosen by class name.

    - class name contains 'Conv': weight ~ N(0.0, 0.02)
    - otherwise, class name contains 'BatchNorm': weight ~ N(1.0, 0.02), bias = 0
    - anything else is left untouched
    """
    layer_name = m.__class__.__name__

    if 'Conv' in layer_name:
        m.weight.data.normal_(0.0, 0.02)
        return

    if 'BatchNorm' in layer_name:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
        
        
# GENERIC - change an torch image to numpy image
# ----------------------------------------------
def to_numpy_image(xin):
    """Convert channel-first images to channel-last numpy images.

    Axis 1 is moved to position 3, i.e. ``(m, c, h, w) -> (m, h, w, c)``.

    Parameters
    ----------
    xin : torch tensor (it is detached and moved to the CPU first), a numpy
        array (used as is), or any other object exposing ``.data.numpy()`` or
        ``.numpy()``.

    Returns
    -------
    numpy array with the axes reordered; for CPU tensors and numpy arrays the
    result is a view, not a copy.
    """

    if torch.is_tensor(xin):
        # detach() drops autograd tracking, cpu() makes CUDA tensors convertible
        arr = xin.detach().cpu().numpy()
    elif isinstance(xin, np.ndarray):
        arr = xin
    else:
        # legacy fallback for other tensor-like wrappers
        try:
            arr = xin.data.numpy()
        except Exception:
            arr = xin.numpy()

    # returns axes swapped numpy images
    # ---------------------------------
    return np.moveaxis(arr, 1, 3)



# GENERIC - converts numpy images to torch tensors for training
# -------------------------------------------------------------
def setup_image_tensor(xin):
    """Convert channel-last numpy images to a channel-first float tensor.

    Axis 3 is moved to position 1, i.e. ``(m, h, w, c) -> (m, c, h, w)``, the
    array is wrapped with ``torch.from_numpy`` and returned via ``.float()``.
    """
    channel_first = np.moveaxis(xin, 3, 1)

    # returns axes swapped torch tensor
    # ---------------------------------
    return torch.from_numpy(channel_first).float()

In [None]:
# a function to load a saved model
# --------------------------------

def load_saved_model_function(path, use_cuda):
    ''' Load a saved model, its Adam optimizer and the training metadata.

    path = /folder1/folder2/model_ae.tar format. Two files are read:
    - ``path`` with '.tar' replaced by '_MODEL.tar': the full pickled model
    - ``path`` itself: a checkpoint dict with the keys 'model_state_dict',
      'optimizer_state_dict', 'epoch', 'loss' and 'loss_mode'

    use_cuda = True keeps the default device mapping and moves the optimizer
    state tensors to the GPU; otherwise BOTH files are mapped onto the CPU, so
    a model saved from a GPU can still be opened on a CPU-only machine.

    Returns (model, optimizer, epoch, loss, loss_mode).

    NOTE(security): torch.load unpickles arbitrary objects - only use this on
    checkpoint files you trust.
    '''

    # 0. device mapping shared by both loads
    # --------------------------------------
    if use_cuda == True:
        map_location = None
    else:
        map_location = lambda storage, loc: storage

    # 1. loading full model
    # ---------------------
    model = torch.load(path.replace('.tar','_MODEL.tar'), map_location=map_location)
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,model.parameters()))

    # 2. Applying state dict
    # ----------------------
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])

    # loading optimizer
    # -----------------
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if use_cuda == True:
        # optimizer state tensors must be moved to the GPU explicitly
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()

    # loading other stuff
    # -------------------
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    loss_mode = checkpoint['loss_mode']

    return model, optimizer, epoch, loss, loss_mode
    
    


In [None]:
# main function steps
# -------------------
def score_cam(model_mpn,model,x_in_scam_in,x_in_print,layer,norm_mode):
    """Build per-class activation heat-maps by re-scoring masked inputs.

    Steps
    -----
    1. The input is divided by its max, masked with the 0.5-thresholded output
       of ``model_mpn`` and passed through ``model`` to get feature maps
       (``model.score_cam_fmaps(x, layer)``) and 0.5-thresholded predictions.
    2. For every image that has at least one predicted class, each feature-map
       channel is normalised, resized to the input size and used to mask the
       image. The masked image is scored by ``model``; if the arg-max label is
       one of the predicted classes, the channel is added to that label's
       activation map, weighted by the max score.
    3. Each accumulated map is normalised by its max, colour-mapped (JET) and
       blended 50/50 with ``x_in_print[i]``.

    Parameters
    ----------
    model_mpn : module producing the input mask.
    model : module exposing ``score_cam_fmaps(x, layer)`` and ``__call__``.
    x_in_scam_in : image tensor whose mean must be > 1 (0-255 range expected).
        The masked image is reshaped to ``(1, h, w, 1)``, which only succeeds
        for single-channel inputs.
    x_in_print : per-image arrays used as the heat-map background.
    layer : forwarded to ``model.score_cam_fmaps``.
    norm_mode : 'max' (value / max) or 'minmax' ((value - min) / (max - min)).

    Returns
    -------
    dict: ``out[i][label] = {'act': raw map, 'heatmap': blended image}``;
    ``out[i]`` is empty when nothing was predicted for image ``i``.

    Raises
    ------
    AssertionError if the input mean is not > 1; ValueError for an unknown
    ``norm_mode``.

    Side effects: prints a summary and shows every heat-map with matplotlib.
    Class names are read from the module-level ``class_labels``, which is not
    defined inside this function.
    """

    # sanity
    # ------
    assert torch.mean(x_in_scam_in) > 1,'Error: images in needs to be in 0-255 range.'
    if norm_mode not in ('max', 'minmax'):
        # fail early instead of hitting an unbound variable inside the loop
        raise ValueError("Error: norm_mode must be 'max' or 'minmax', got " + repr(norm_mode))


    # 0. initialisations
    # ------------------
    eps = 0.0001
    x_in_scam = x_in_scam_in/torch.max(x_in_scam_in)
    model = model.eval()
    model_mpn = model_mpn.eval()
    x_in_scam_np = to_numpy_image(x_in_scam)
    out_dict = {}
    img_in_h, img_in_w = x_in_scam.size()[2], x_in_scam.size()[3]


    # 1. get the feature maps & predictions
    # --------------------------------------
    x_in_masks = model_mpn(x_in_scam)
    x_in_masks[x_in_masks < 0.5] = 0
    x_in_masks[x_in_masks >= 0.5] = 1
    masked_input = x_in_scam*x_in_masks
    f_maps = model.score_cam_fmaps(masked_input,layer)
    pred = model(masked_input)
    pred[pred < 0.5] = 0
    pred[pred >= 0.5] = 1


    # 2. itering each input image
    # ---------------------------
    for i in range(f_maps.size()[0]):

        # 2.1 local inits
        # ---------------
        out_dict[i] = {}
        curr_fmap = f_maps[i]
        curr_class_inds = list(np.argwhere(pred[i].detach().numpy()==1)[:,0])

        # continue only if there's a prediction
        # -------------------------------------
        if len(curr_class_inds) == 0:
            continue

        # 2.2 per-channel max and min over the spatial axes
        # -------------------------------------------------
        n_channels = curr_fmap.size()[0]
        mxvals = torch.max(torch.max(curr_fmap,1).values,1).values.view(n_channels,1,1)
        mnvals = torch.min(torch.min(curr_fmap,1).values,1).values.view(n_channels,1,1)

        # 2.3 final normalisation
        # ------------------------
        if norm_mode == 'max':
            curr_fmap_norm = curr_fmap/(mxvals + eps)
        else:
            curr_fmap_norm = (curr_fmap - mnvals)/ (mxvals - mnvals + eps)

        # 2.4 resizing ops - result is (img_h, img_w, no_channels)
        # --------------------------------------------------------
        curr_fmap_norm = curr_fmap_norm.view(1,n_channels,curr_fmap_norm.size()[1],curr_fmap_norm.size()[2])
        curr_fmap_norm_np = to_numpy_image(curr_fmap_norm.data)
        curr_fmap_norm_np = cv2.resize(curr_fmap_norm_np[0],(img_in_w,img_in_h))
        if curr_fmap_norm_np.ndim == 2:
            # cv2.resize drops the channel axis of single-channel input
            curr_fmap_norm_np = curr_fmap_norm_np[:,:,np.newaxis]

        # 2.5 itering through each channel and masking input image
        # only labels the model predicted for this image are accumulated
        # ---------------------------------------------------------------
        for each_ch in range(curr_fmap_norm_np.shape[2]):

            # mask ops
            # --------
            channel_map = curr_fmap_norm_np[:,:,each_ch].reshape(img_in_h,img_in_w,1)
            curr_masked_image = channel_map * x_in_scam_np[i]
            curr_masked_image = curr_masked_image.reshape(1,img_in_h,img_in_w,1)
            curr_masked_image = curr_masked_image/(np.max(curr_masked_image) + eps)

            # prediction pass - only the scores are used, so skip autograd
            # -------------------------------------------------------------
            with torch.no_grad():
                masked_pred = model(Variable(setup_image_tensor(curr_masked_image)).float())
            max_label = torch.argmax(masked_pred).item()
            max_weight = torch.max(masked_pred).item()

            # adding weight & map to dict
            # ---------------------------
            if max_label in curr_class_inds:
                weighted_map = max_weight * channel_map
                if max_label in out_dict[i]:
                    out_dict[i][max_label]['act'] += weighted_map
                else:
                    out_dict[i][max_label] = {'act': weighted_map}

        # 2.6 turning every accumulated map into a heat-map overlay
        # ---------------------------------------------------------
        for keys_labels in out_dict[i]:

            act_map = out_dict[i][keys_labels]['act']
            norm_act_map = act_map/(np.max(act_map) + eps)

            # heatmap ops
            # -----------
            heat_map = cv2.applyColorMap(np.uint8(255*(1-norm_act_map)), cv2.COLORMAP_JET)
            heat_map = heat_map/np.max(heat_map)
            curr_visual_image = x_in_print[i]/np.max(x_in_print[i])

            # adding to dict
            # --------------
            out_dict[i][keys_labels]['heatmap'] = heat_map * 0.5 + curr_visual_image * 0.5


    # printing results by default
    # ---------------------------
    for i in out_dict.keys():

        # sanity
        # ------
        print('image # ' + str(i))
        print('no classes here : ' + str(len(out_dict[i].keys())))

        # iter thru classes
        # -----------------
        for class_inds in out_dict[i]:

            # print class
            # -----------
            print('class: ' + str(class_labels[class_inds]))
            plt.imshow(out_dict[i][class_inds]['heatmap'])
            plt.show()

        # sanity
        # ------
        print('**************')


    # final return
    # ------------
    return out_dict

In [None]:
# a simple forward pass function
# ------------------------------


def simple_forward_pass_pool(xin,input_is_image,model):
    """Run ``model`` over every example of ``xin`` using a pool of 10 threads.

    Parameters
    ----------
    xin : torch tensor; ``xin.size()[0]`` examples are processed. It is
        deep-copied, so the caller's tensor is never modified.
    input_is_image : when True the copied input must have a mean > 1 and is
        divided by its max before the forward passes.
    model : callable applied to one-example batches; it is also called once
        on ``xin[0:2]`` to work out the output shape.

    Returns
    -------
    None. Outputs are written to the module-level global ``y_out_global``
    (one row per example). ``model_global`` and ``x_in_global`` are set as
    well because the worker function reads them.
    """

    # 0. initialisations
    # ------------------
    global model_global, x_in_global, y_out_global
    model_global = model
    x_in_global = copy.deepcopy(xin)

    # 1. setting up y_out size from a probe forward pass
    # --------------------------------------------------
    probe_size = list(model(xin[0:2]).size())
    final_size = tuple([xin.size()[0]] + probe_size[1:])
    y_out_global = torch.zeros((final_size))

    # 1.1 sanity
    # ----------
    if input_is_image == True:
        assert torch.mean(x_in_global) > 1, 'Error: Data already in 0-1 range'
        x_in_global = x_in_global/torch.max(x_in_global)

    # 2. calling pool function
    # ------------------------
    pool = ThreadPool(10)
    try:
        pool.map(simple_forward_pass_single, list(range(xin.size()[0])))
    finally:
        # release the worker threads, also when a forward pass raised
        pool.terminate()
        pool.join()

    print('Done with forward pass of around ' + str(xin.size()[0]) + ' examples. Access them at global y_out_global.')
    
    
    
    
def simple_forward_pass_single(i):
    """Worker for ``simple_forward_pass_pool``: forward pass of example ``i``.

    Takes ``x_in_global[i]``, views it as a batch of one with its first three
    dimensions, runs ``model_global`` on it and stores the first output row in
    ``y_out_global[i]``. Returns None.
    """
    global model_global, x_in_global, y_out_global

    # 1. ops
    # ------
    example = x_in_global[i]
    dims = example.size()
    batch_of_one = example.view(1, dims[0], dims[1], dims[2])
    y_out_global[i] = model_global(batch_of_one)[0]
    
    

In [None]:
# builds RGB, grayscale and edge datasets from all jpg/png images in a folder
# --------------------------------------

def create_dataset_from_folder_all(infolder,resize,gray_mode,n_h,n_w):
    """Load every jpg/png file of a folder into global image datasets.

    Parameters
    ----------
    infolder : folder to scan; files whose lower-cased name contains '.jpg'
        or '.png' are used (jpg matches first, then png matches).
    resize : when True each image is resized to ``(n_h, n_w)`` by the worker.
    gray_mode : 'bw' binarises each grayscale image at its own mean (0 / 255),
        'gray_3' repeats the grayscale channel three times, anything else
        keeps a single grayscale channel.
    n_h, n_w : height and width of the stored images.

    Returns
    -------
    None. Results are written to the module-level globals
    ``x_images_dataset`` (RGB, uint8), ``x_images_dataset_gray`` and
    ``x_images_dataset_edge``. ``image_list`` holds all candidate file names
    and ``index_list`` the (sorted) positions in ``image_list`` that loaded
    successfully, so ``image_list[index_list[k]]`` is the file behind row k.

    Raises
    ------
    AssertionError when the folder contains no matching file.
    """

    # 0. global initialisations - the worker function reads these
    # -----------------------------------------------------------
    global index_list, resize_flag, gb_in_folder, counter, new_h, new_w
    global image_list, x_images_dataset, x_images_dataset_gray, x_images_dataset_edge

    index_list = []
    resize_flag = resize
    gb_in_folder = infolder
    counter = 0
    new_h = n_h
    new_w = n_w

    # one directory listing, split into the jpg and png matches
    # ---------------------------------------------------------
    folder_files = [f for f in listdir(infolder) if isfile(join(infolder, f))]
    image_list_jpg = [f for f in folder_files if '.jpg' in f.lower()]
    image_list_png = [f for f in folder_files if '.png' in f.lower()]
    image_list = image_list_jpg + image_list_png

    x_images_dataset = np.zeros((len(image_list),new_h,new_w,3), dtype='uint8')
    x_images_dataset_gray = np.zeros((len(image_list),new_h,new_w), dtype='uint8')
    x_images_dataset_edge = np.zeros((len(image_list),new_h,new_w,1))


    # 1.1 sanity assertion
    # -------------------
    assert len(image_list) > 0, 'No images in the folder'


    # 2. loading the images across the thread pool
    # --------------------------------------------
    pool = ThreadPool(5)
    try:
        pool.map(create_dataset_from_folder_single, list(range(len(image_list))))
    finally:
        # closing pools - also when a worker raised
        # -----------------------------------------
        pool.terminate()
        pool.join()

    # 2.1 sanity assert
    # -----------------
    assert x_images_dataset.shape[0] == x_images_dataset_gray.shape[0], 'RGB and Grayscale images have different numbers of images!'

    # 3. filtering out the images that failed to load
    # workers append to index_list in completion order; sorting keeps the rows
    # in image_list order and makes repeated runs reproducible
    # ------------------------------------------------------------------------
    index_list.sort()
    counter = len(index_list)   # the workers' unsynchronised += can miss updates
    print('Len at start: ' + str(x_images_dataset.shape))
    x_images_dataset = x_images_dataset[index_list]
    x_images_dataset_gray = x_images_dataset_gray[index_list]
    x_images_dataset_edge = x_images_dataset_edge[index_list]


    # 4. grayscale post-processing
    # ----------------------------
    if gray_mode == 'bw':

        # hard b/w: threshold every image at its own mean
        # -----------------------------------------------
        mn = np.mean(np.mean(x_images_dataset_gray, axis = 1), axis = 1)
        mn = mn.reshape(mn.shape[0],1,1)
        x_images_dataset_gray[x_images_dataset_gray < mn] = 0
        x_images_dataset_gray[x_images_dataset_gray >= mn] = 255

    # every mode ends up with an explicit channel axis
    # ------------------------------------------------
    x_images_dataset_gray = x_images_dataset_gray.reshape(x_images_dataset_gray.shape[0],new_h,new_w,1)

    if gray_mode == 'gray_3':

        # grayscale 3 channel
        # -------------------
        x_images_dataset_gray = np.concatenate((x_images_dataset_gray,x_images_dataset_gray,x_images_dataset_gray), axis= 3)


    print('Len after filtering: ' + str(x_images_dataset.shape))
    print('Done creating dataset of around ' + str(counter) + ' images. Access them at global x_images_dataset, x_images_dataset_gray, x_images_dataset_edge.')
    

def create_dataset_from_folder_single(i):
    """Worker for ``create_dataset_from_folder_all``: load image ``i``.

    Reads ``image_list[i]`` from the folder ``gb_in_folder``, converts it to
    RGB and grayscale, optionally resizes both to ``(new_h, new_w)`` (when
    ``resize_flag`` is True), builds a Canny edge image from the blurred
    grayscale image and stores all three in row ``i`` of the global datasets.

    On success ``counter`` is incremented and ``i`` is appended to
    ``index_list``. Loading stays best-effort: an image that cannot be
    processed is skipped with a printed message instead of being dropped
    silently, and its index is not added to ``index_list``. Returns None.
    """

    # only `counter` is rebound here; the other globals are just read / indexed
    # -------------------------------------------------------------------------
    global counter

    name = None
    try:
        name = image_list[i]
        img_main = cv2.imread(join(gb_in_folder, name))
        if img_main is None:
            # cv2.imread signals an unreadable file by returning None
            raise ValueError('cv2.imread could not read the file')

        # cvtColor returns new arrays, so no defensive copies are needed
        img = cv2.cvtColor(img_main, cv2.COLOR_BGR2RGB)
        img_gray = cv2.cvtColor(img_main, cv2.COLOR_BGR2GRAY)

        # resizing ops - cv2.resize takes (width, height)
        # -----------------------------------------------
        if resize_flag == True:
            img = cv2.resize(img, (new_w,new_h))
            img_gray = cv2.resize(img_gray, (new_w,new_h))

        # by default building edge images
        # -------------------------------
        blurred = cv2.GaussianBlur(img_gray.reshape(new_h,new_w).astype('uint8'), (7, 7), 0)
        edged = cv2.Canny(blurred, 50, 150)
        edged = edged.reshape(new_h,new_w,1)

        # final assignments
        # ------------------
        x_images_dataset[i] = img
        x_images_dataset_gray[i] = img_gray
        x_images_dataset_edge[i] = edged

        counter += 1
        index_list.append(i)

    except Exception as err:

        # best effort: report and skip this image
        # ---------------------------------------
        print('Skipping image ' + str(i) + ' (' + str(name) + '): ' + str(err))
    


In [None]:
# function that will return final latents for similarity function 
# ---------------------------------------------------------------

def final_latents(f_maps,kernel_stride_dims,pool_mode,aggregate_pool_maps):
    '''
    Compute one pooled latent per (feature map, kernel setting) pair.

    a. f_maps - iterable of feature maps; each one is passed on its own to return_ms_rmac
    b. kernel_stride_dims - iterable of kernel settings, each forwarded as `kernel_dims`
    c. pool_mode - forwarded to return_ms_rmac
    d. aggregate_pool_maps - forwarded to return_ms_rmac

    Returns a list of numpy arrays ordered feature map first, kernel setting
    second. The latents are returned as computed - no L2 normalisation is
    applied here. Prints the kernel setting and latent shape after each step.
    '''

    latents_out = []

    # outer loop over feature maps, inner loop over kernel settings
    # -------------------------------------------------------------
    pairs = ((fmap, dims) for fmap in f_maps for dims in kernel_stride_dims)
    for fmap, dims in pairs:
        latent = return_ms_rmac([fmap],pool_mode,dims,aggregate_pool_maps).data.numpy()
        latents_out.append(latent)
        print('done at kernel ' + str(dims) + '. latent size: ' + str(latent.shape))

    # final return
    # ------------
    return latents_out



In [None]:
# main function
# -------------

def _rmac_spatial_sum(pool_map):
    """Sum a (m, c, h, w) pool map over its two spatial axes -> (m, c)."""
    return torch.sum(torch.sum(pool_map, 2), 2)


def _rmac_pool_classes(pool_mode):
    """Return the pooling layer classes used by `pool_mode`, in application order."""
    if pool_mode == 'max':
        return (nn.MaxPool2d,)
    if pool_mode == 'avg':
        return (nn.AvgPool2d,)
    if pool_mode == 'both':
        return (nn.MaxPool2d, nn.AvgPool2d)
    raise AssertionError('Error: invalid pool mode')


def return_ms_rmac(fmaps_list,pool_mode,kernel_dims,aggregate_pool_maps):

    '''
    ref: https://www.researchgate.net/publication/313465134_MS-RMAC_Multiscale_Regional_Maximum_Activation_of_Convolutions_for_Image_Retrieval

    Pools every feature map in `fmaps_list` and concatenates the per-layer
    results along axis 1.

    1. inputs
    -- fmaps_list = [(m,k1,h1,w1), (m,k2,h2,w2), ...] - must be a python list
    -- pool_mode = 'max', 'avg' or 'both' ('both' ADDS the max and avg results)
    -- kernel_dims = None or (f,s)
    -- aggregate_pool_maps - only used when kernel_dims is given: if True the
       pool map is summed over h,w, else it is flattened as it is

    2. kernel_dims = None uses three preset regions per feature map, each
       pooled, summed over h,w and added together:
    -- scale 1 : kernel (h/2, w),      stride = (h/2)/2
    -- scale 2 : kernel (h/2, 2w/3),   stride = (2w/3)/2
    -- scale 3 : kernel (2h/5, w/2),   stride = (w/2)/2
       (all values truncated with int())

    3. kernel_dims = (f,s) pools with a square (f,f) kernel and stride s

    4. returns a (m,K) tensor, K being the sum of the per-layer widths

    raises AssertionError for a non-list input or an invalid pool mode,
    ValueError for an empty list; feature maps that cannot be concatenated
    (e.g. different batch sizes) now raise the torch error instead of silently
    discarding the earlier layers.
    '''

    # 0. initialisations
    # -------------------
    assert type(fmaps_list) == list, 'Type error: input feature maps must be a list'
    if len(fmaps_list) == 0:
        raise ValueError('Error: fmaps_list is empty, nothing to pool.')
    pool_classes = _rmac_pool_classes(pool_mode)
    per_layer_maps = []


    # 1. looping through the feature maps
    # -----------------------------------
    for fmap in fmaps_list:

        if kernel_dims is None:

            # 1.1 preset scaled regions - (kernel, stride) per scale
            # ------------------------------------------------------
            h,w = fmap.size()[2],fmap.size()[3]
            half_h = int(h/2)
            l2_w = int(2*w/3)
            l3_w = int(w/2)
            scales = (
                ((half_h, w), int(half_h/2)),
                ((half_h, l2_w), int(l2_w/2)),
                ((int(2*h/5), l3_w), int(l3_w/2)),
            )

            # NOT concatenating -- summing every pooled scale
            # -----------------------------------------------
            layer_map = None
            for pool_cls in pool_classes:
                for kernel, stride in scales:
                    term = _rmac_spatial_sum(pool_cls(kernel, stride=stride)(fmap))
                    layer_map = term if layer_map is None else layer_map + term

        else:

            # 1.2 dims is given as (f,s)
            # --------------------------
            curr_f = kernel_dims[0]
            curr_s = kernel_dims[1]

            pool_map = None
            for pool_cls in pool_classes:
                term = pool_cls((curr_f,curr_f), stride=curr_s)(fmap)
                pool_map = term if pool_map is None else pool_map + term

            # summing along h,w axises - pool_map will be of shape (m,no_channels)
            # --------------------------------------------------------------------
            if aggregate_pool_maps == True:
                pool_map = _rmac_spatial_sum(pool_map)

            # final resize before concat
            # --------------------------
            layer_map = pool_map.view(pool_map.size()[0],-1)

        per_layer_maps.append(layer_map)


    # final concat along channel axis & return
    # ----------------------------------------
    if len(per_layer_maps) == 1:
        return per_layer_maps[0]
    return torch.cat(per_layer_maps, 1)
            


In [None]:
# function to simply return similar images based on whole latents
# ---------------------------------------------------------------

def similarity(input_latents_list,db_latents_list,xin,xdb,sim_weights,no_suggestions,similarity_check_mode,print_result):
    '''
    Rank database images by similarity to every input image.

    For each pair of (input latent, db latent) matrices a per-image score
    vector against all db rows is computed, the vectors belonging to the same
    input image are combined as a weighted sum, and the best
    `no_suggestions` db rows are selected.

    Parameters
    ----------
    input_latents_list : list of 2-D arrays, one per latent type, rows = input images.
    db_latents_list    : list of 2-D arrays, same length as `input_latents_list`,
                         rows = db images.
    xin                : input images, only read when `print_result` is True
                         (indexed as xin[i] and passed to plt.imshow).
    xdb                : db images, indexed with a list of row indices; its 4th
                         axis length decides colour vs. gray plotting.
    sim_weights        : the string 'equal' (weight 1.0 for every latent type)
                         or a sequence with one weight per latent type.
    no_suggestions     : number of db rows to return per input image.
    similarity_check_mode : 'l2' (distance, smaller is better), 'cosine',
                         'ratio' or 'cosine_ratio' (larger is better).
    print_result       : when True the input image and its matches are plotted.

    Returns
    -------
    (similar_products, all_similarity_values, final_all_indices) - three lists
    with one entry per input image: the selected db images, their weighted
    scores, and their db row indices (best match first).

    Raises
    ------
    AssertionError on mismatching list lengths or an unknown mode.

    Note: 'cosine' and 'cosine_ratio' rely on `cosine_similarity_multi`, which
    is defined elsewhere in this notebook.
    '''

    # 1. initialisations
    # ------------------
    interim_sim_array = {}
    final_all_indices = []
    similar_products = []
    all_similarity_values = []
    epsilon = 0.0001
    valid_modes = ('l2', 'cosine_ratio', 'cosine', 'ratio')

    # 2. sanity assertions
    # --------------------
    assert len(input_latents_list) == len(db_latents_list), 'Error: length of input & db lists dont match.'
    assert similarity_check_mode in valid_modes, 'Error: invalid similarity check mode. This can be one of "l2", "cosine_ratio", "cosine" or "ratio" only.'

    # isinstance guard: comparing an array of weights with a string is ambiguous
    if isinstance(sim_weights, str) and sim_weights == 'equal':
        sim_weights = [1.0] * len(input_latents_list)
    else:
        assert len(sim_weights) == len(input_latents_list), 'Error: number of weights & length of latent lists dont match.'

    def ratio_score(example, db_latent):
        # mean over features of min/max ratio; epsilon avoids division by zero
        return np.average((np.minimum(example,db_latent)/(np.maximum(example,db_latent) + epsilon)), axis = 1)

    # 3. looping through the INPUT latents list
    # -----------------------------------------
    print('1. looping through latent list..')
    for curr_input_latent, db_input_latent in zip(input_latents_list, db_latents_list):

        # 3.1 one score vector (length = number of db rows) per input image
        # -----------------------------------------------------------------
        for i in range(curr_input_latent.shape[0]):

            curr_input_latent_example = curr_input_latent[i]

            if similarity_check_mode == 'l2':
                similarity_array = np.sqrt(np.sum((curr_input_latent_example-db_input_latent)**2, axis = 1))

            elif similarity_check_mode == 'ratio':
                similarity_array = ratio_score(curr_input_latent_example, db_input_latent)

            else:
                # 'cosine' and 'cosine_ratio' share the cosine part
                curr_input_latent_example = curr_input_latent_example.reshape(1,curr_input_latent_example.shape[0])
                similarity_array = cosine_similarity_multi(curr_input_latent_example,db_input_latent)
                similarity_array = similarity_array.reshape(similarity_array.shape[0],)
                if similarity_check_mode == 'cosine_ratio':
                    similarity_array += ratio_score(curr_input_latent_example, db_input_latent)

            # interim_sim_array[i] = [scores_from_latent_list_0, scores_from_latent_list_1, ...]
            interim_sim_array.setdefault(i, []).append(similarity_array)

    # 4. itering thru dict & applying weights
    # ---------------------------------------
    print('2. applying weights & appending final results..')
    for keys in interim_sim_array:

        # weighted sum over latent types; a shape mismatch now raises instead of
        # being swallowed and silently replacing the running total
        similarity_array_final = None
        for each_sim_array, weight in zip(interim_sim_array[keys], sim_weights):
            curr_weighted_array = each_sim_array * weight
            if similarity_array_final is None:
                similarity_array_final = curr_weighted_array
            else:
                similarity_array_final = similarity_array_final + curr_weighted_array

        # l2 is a distance (ascending order); every other mode is a similarity (descending)
        if similarity_check_mode == 'l2':
            sorted_indices = list(np.argsort(similarity_array_final))
        else:
            sorted_indices = list(np.argsort(-1*similarity_array_final))

        final_indices = [int(fi) for fi in sorted_indices[0:no_suggestions]]
        final_all_indices.append(final_indices)
        # report the weighted scores of THIS image (previously the scores of the
        # last processed image / last latent list leaked in here)
        all_similarity_values.append(similarity_array_final[final_indices])
        similar_products.append(xdb[final_indices])

    # 5. showing if required
    # ----------------------
    if print_result == True:
        _show_similarity_results(xin, xdb, similar_products)

    # 6. final return
    # ---------------
    return similar_products,all_similarity_values,final_all_indices


def _show_similarity_results(xin, xdb, similar_products):
    '''Plot each input image followed by a 5-column grid of its retrieved db images.'''

    columns = 5

    for i in range(xin.shape[0]):

        print('>> At image ' + str(i) + ' of around ' + str(xin.shape[0]) + '..')
        print('** Input Image - ')
        plt.figure(figsize=(2,2))
        plt.imshow(xin[i])
        plt.show()

        print('** Showing result images..')
        fig = plt.figure(figsize=(25, 25))
        n_results = similar_products[i].shape[0]
        # keep the original 10-row layout, but grow it when more than 50 results are requested
        rows = max(10, -(-n_results // columns))

        for i_1 in range(n_results):
            fig.add_subplot(rows, columns, i_1+1)
            if xdb.shape[3] > 1:
                plt.imshow(similar_products[i][i_1])
            else:
                # single-channel images are drawn from their only channel in gray
                plt.imshow(similar_products[i][i_1,:,:,0], cmap='gray')

        plt.show()

        print('\n#########################################\n')

In [None]:
# a wrapper around similarity function since each and every image has different number of filtered results
# --------------------------------------------------------------------------------------------------------

def similarity_wrapper(input_latents_list,db_latents_list,xin,xdb,similarity_check_mode,concept_inds_list):
    """
    Run `similarity` once per input image against that image's own db subset.

    For input image k only the db rows listed in concept_inds_list[k] are
    searched. Every call uses equal weights, asks for 50 suggestions and plots
    the result; nothing is returned.

    xin must be 4-dimensional: each image is reshaped to a batch of one before
    being handed on.
    """

    # the three trailing axes are reused to rebuild a one-image batch below
    height, width, channels = xin.shape[1:]

    for img_idx in range(xin.shape[0]):

        allowed_db_rows = concept_inds_list[img_idx]

        # db latents restricted to the rows allowed for this image
        filtered_db_latents = []
        for db_latent in db_latents_list:
            filtered_db_latents.append(db_latent[allowed_db_rows])

        # this image's latent from every latent type, as a (1, features) row
        query_latents = []
        for input_latent in input_latents_list:
            query_latents.append(input_latent[img_idx].reshape(1, input_latent.shape[1]))

        query_image = xin[img_idx].reshape(1, height, width, channels)
        candidate_images = xdb[allowed_db_rows]

        # results are only displayed, never collected
        _, _, _ = similarity(
            query_latents,
            filtered_db_latents,
            query_image,
            candidate_images,
            'equal',
            50,
            similarity_check_mode,
            True,
        )
        
    

### models

In [None]:
# FCN class copied from image search notebook which worked
# --------------------------------------------------------

class fcn_mask_proposal_UNET(nn.Module):
    """
    U-Net style fully convolutional network that outputs a one-channel mask.

    Encoder: five stride-2 3x3 convolutions (16, 32, 64, 128, 256 channels).
    For a 95x95 input the spatial sizes are 47, 23, 11, 5 and 2.
    Bottleneck: 2x2 average pool (stride 1) to 1x1, optionally followed by Softmax2d.
    Decoder: transposed convolutions mirroring the encoder (2, 5, 11, 23, 47, 95),
    each fed with the SUM of the previous decoder output and the matching
    encoder output (additive skip connections, not concatenation).
    The last layer has a sigmoid, so outputs lie in [0, 1].

    Parameters
    ----------
    in_channels    : number of channels of the input image.
    latent_softmax : when equal to True, Softmax2d is applied after the pooling.
    """

    def __init__(self, in_channels, latent_softmax):
        super().__init__()

        kernel, stride = 3, 2

        # ReLU and Dropout2d carry no parameters, so one instance of each is
        # shared by every stage
        shared_relu = nn.ReLU()
        shared_dropout = nn.Dropout2d(p=0.2)

        def down_stage(c_in, c_out):
            # conv -> batch norm -> relu -> dropout
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel, stride=stride),
                nn.BatchNorm2d(c_out),
                shared_relu,
                shared_dropout,
            )

        def up_stage(c_in, c_out, k=kernel, s=stride):
            # transposed conv -> batch norm -> relu -> dropout
            return nn.Sequential(
                nn.ConvTranspose2d(c_in, c_out, k, stride=s),
                nn.BatchNorm2d(c_out),
                shared_relu,
                shared_dropout,
            )

        # encoder (sizes quoted for a 95x95 input)
        self.convl1 = down_stage(in_channels, 16)   # 47x47
        self.convl2 = down_stage(16, 32)            # 23x23
        self.convl3 = down_stage(32, 64)            # 11x11
        self.convl4 = down_stage(64, 128)           # 5x5
        self.convl5 = down_stage(128, 256)          # 2x2

        # bottleneck: 2x2 -> 1x1
        bottleneck_layers = [nn.AvgPool2d((2,2), stride=1)]
        if latent_softmax == True:
            bottleneck_layers.append(nn.Softmax2d())
        self.pool_net = nn.Sequential(*bottleneck_layers)

        # decoder
        self.upcl0 = up_stage(256, 256, k=2, s=1)   # 2x2
        self.upcl1 = up_stage(256, 128)             # 5x5
        self.upcl2 = up_stage(128, 64)              # 11x11
        self.upcl3 = up_stage(64, 32)               # 23x23
        self.upcl4 = up_stage(32, 16)               # 47x47

        # final layer: single channel, sigmoid, no batch norm / dropout
        self.upcl5 = nn.Sequential(
            nn.ConvTranspose2d(16, 1, kernel, stride=stride),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return the mask for image batch `x`."""

        # encoder pass, keeping every stage output for the skip connections
        enc1 = self.convl1(x)
        enc2 = self.convl2(enc1)
        enc3 = self.convl3(enc2)
        enc4 = self.convl4(enc3)
        enc5 = self.convl5(enc4)

        # bottleneck and first decoder stage
        dec = self.upcl0(self.pool_net(enc5))

        # each decoder stage receives previous output + same-sized encoder output
        dec = self.upcl1(dec + enc5)
        dec = self.upcl2(dec + enc4)
        dec = self.upcl3(dec + enc3)
        dec = self.upcl4(dec + enc2)

        return self.upcl5(dec + enc1)



In [None]:
# convolutional classifier: 5 conv stages + pooling + linear/sigmoid head
# -----------------------------------------------------------------------

# FCN class copied from image search notebook which worked
# --------------------------------------------------------

class conv_classifier(nn.Module):
    """
    Convolutional classifier with independent sigmoid outputs.

    Five stride-2 3x3 conv stages (16, 32, 64, 128, 256 channels) are followed
    by a 2x2 average pool (stride 1) and a linear layer with 256 input
    features, so the pooled map has to be 1x1 (true for a 95x95 input, whose
    stage sizes are 47, 23, 11, 5, 2).

    Parameters
    ----------
    in_channels : number of channels of the input image.
    out_size    : number of outputs of the final linear layer.
    """

    def __init__(self, in_channels, out_size):
        super().__init__()

        # parameter-free modules, one instance shared across all stages
        shared_relu = nn.ReLU()
        shared_dropout = nn.Dropout2d(p=0.10)

        def down_stage(c_in, c_out):
            # conv -> batch norm -> relu -> dropout
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, 3, stride=2),
                nn.BatchNorm2d(c_out),
                shared_relu,
                shared_dropout,
            )

        self.convl1 = down_stage(in_channels, 16)
        self.convl2 = down_stage(16, 32)
        self.convl3 = down_stage(32, 64)
        self.convl4 = down_stage(64, 128)   # 5x5 for a 95x95 input
        self.convl5 = down_stage(128, 256)  # 2x2 for a 95x95 input

        # pooling layer: 2x2 -> 1x1
        self.pool_net = nn.Sequential(nn.AvgPool2d((2,2), stride=1))

        # prediction head
        self.linear1 = nn.Sequential(nn.Linear(256, out_size), nn.Sigmoid())

    def _conv_stack(self, x):
        """Run the five conv stages and return all five outputs, shallowest first."""
        outputs = []
        feat = x
        for stage in (self.convl1, self.convl2, self.convl3, self.convl4, self.convl5):
            feat = stage(feat)
            outputs.append(feat)
        return outputs

    def forward(self, x):
        """Return sigmoid predictions of shape (batch, out_size)."""

        deepest = self._conv_stack(x)[-1]
        pooled = self.pool_net(deepest)

        # flatten everything but the batch axis before the linear head
        flat = pooled.view(pooled.size()[0],-1)
        return self.linear1(flat)

    def score_cam_fmaps(self, x, layer):
        """Return the feature maps produced by conv stage `layer` (2 to 5) for `x`."""

        # sanity
        assert 2 <= layer <= 5,'Error: please choose layer between 2 and 5 only.'

        # all five stages are always evaluated, whichever layer is requested
        stage_outputs = self._conv_stack(x)

        for stage_no in (2, 3, 4, 5):
            if layer == stage_no:
                return stage_outputs[stage_no - 1]

        # reached only for values inside the range that match no stage (e.g. 2.5)
        assert 1==2,'Error: something wrong. layer not in range 2-5.'
        
        
        
  

In [None]:
# FCN class copied from image search notebook which worked
# --------------------------------------------------------

class fcn_ae_4_layer_std(nn.Module):
    """
    Four-stage fully convolutional auto-encoder.

    Encoder: stride-2 3x3 convolutions with 64, 128, 256 and 512 channels.
    The first three stages are conv -> batch norm -> ReLU -> dropout; the
    fourth is conv -> Softmax2d -> dropout WITHOUT batch norm, so the latent
    is a per-pixel distribution over its 512 channels.
    Decoder: three transposed-conv stages (256, 128, 64 channels) and a final
    transposed conv back to `in_channels` followed by a sigmoid.

    No pooling layer is used between encoder and decoder.

    Parameters
    ----------
    in_channels : number of channels of the input (and reconstructed) image.
    """

    def __init__(self, in_channels):
        super().__init__()

        kernel, stride = 3, 2

        # parameter-free modules shared by all stages
        shared_relu = nn.ReLU()
        shared_dropout = nn.Dropout2d(p=0.1)

        def down_stage(c_in, c_out):
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel, stride=stride),
                nn.BatchNorm2d(c_out),
                shared_relu,
                shared_dropout,
            )

        def up_stage(c_in, c_out):
            return nn.Sequential(
                nn.ConvTranspose2d(c_in, c_out, kernel, stride=stride),
                nn.BatchNorm2d(c_out),
                shared_relu,
                shared_dropout,
            )

        # encoder - channel plan that worked: 3,64,128,256,(512),256,128,64,3
        self.convl1 = down_stage(in_channels, 64)
        self.convl2 = down_stage(64, 128)
        self.convl3 = down_stage(128, 256)

        # latent stage: softmax over channels, no batch norm
        self.convl4 = nn.Sequential(
            nn.Conv2d(256, 512, kernel, stride=stride),
            nn.Softmax2d(),
            shared_dropout,
        )

        # decoder
        self.upcl1 = up_stage(512, 256)
        self.upcl2 = up_stage(256, 128)
        self.upcl3 = up_stage(128, 64)

        # reconstruction layer: back to the input channel count, values in [0, 1]
        self.upcl4 = nn.Sequential(
            nn.ConvTranspose2d(64, in_channels, kernel, stride=stride),
            nn.Sigmoid(),
        )

    def _encode(self, x):
        """Run the four encoder stages."""
        return self.convl4(self.convl3(self.convl2(self.convl1(x))))

    def forward(self, x):
        """Return the reconstruction of image batch `x`."""
        recon = self._encode(x)
        for up in (self.upcl1, self.upcl2, self.upcl3, self.upcl4):
            recon = up(recon)
        return recon

    def latent(self, x):
        """Return the output of the fourth encoder stage for `x`."""
        return self._encode(x)
        



In [None]:
# END OF CODES
##

# execution

### sanity

In [None]:
# Environment switch and derived paths for the execution part
# ------------------------------------------------------------

## SELECT THE ENVIRONMENT HERE
##############################

dev_env = 'local' # 'gpu' or 'local'

##############################

# CUDA is used only in the 'gpu' environment
# ------------------------------------------
use_cuda = (dev_env == 'gpu')
if use_cuda:
    torch.cuda.empty_cache()

# model folder (save_path) and dataset root (parent_url) per environment
# ----------------------------------------------------------------------
if use_cuda:
    save_path, parent_url = (
        '/home/venkateshmadhava/codes/pmate2_localgpuenv/models/',
        '/home/venkateshmadhava/datasets/images',
    )
else:
    save_path, parent_url = (
        '/Users/venkateshmadhava/Documents/pmate2/pmate2_env/models/',
        '/Users/venkateshmadhava/Documents/pmate2/local_working_files',
    )

# displaying save path
# --------------------
print(save_path)

# 1. loading models

In [None]:
## 1. mask proposal nw
######################

# checkpoint path = models folder (save_path ends with '/') + file name
cn_file_name = 'fcn_unet_mpn_GRAY_nonsoftmax_256_5classpoc_bird_apple_fish_skull_star.tar'
cn_save_path = save_path + cn_file_name
# load_saved_model_function is defined elsewhere in the notebook; of its five
# return values the 1st, 3rd and 4th are used here as model, epoch and loss
model_mpn,_,epoch,loss,_ = load_saved_model_function(cn_save_path, use_cuda)
print('epoch: ' + str(epoch))
print('loss: ' + str(loss))
# inference mode - presumably an nn.Module, so this disables dropout and uses
# batch-norm running statistics (TODO confirm the returned type)
model_mpn = model_mpn.eval()
# bare expression: the notebook displays the model's repr
model_mpn

In [None]:
## 2. masked image classifier
##############################

# note: cn_file_name / cn_save_path / epoch / loss are reused (overwritten) by
# every model-loading cell
cn_file_name = 'conv_masked_image_classifier_GRAY_5classpoc_bird_apple_fish_skull_star.tar'
cn_save_path = save_path + cn_file_name
# load_saved_model_function is defined elsewhere in the notebook; of its five
# return values the 1st, 3rd and 4th are used here as model, epoch and loss
model_mimgcls,_,epoch,loss,_ = load_saved_model_function(cn_save_path, use_cuda)
print('epoch: ' + str(epoch))
print('loss: ' + str(loss))
# inference mode (presumably an nn.Module - TODO confirm the returned type)
model_mimgcls = model_mimgcls.eval()
# bare expression: the notebook displays the model's repr
model_mimgcls

In [None]:
## 3. AE
########

# auto-encoder checkpoint; only the model itself is kept from the 5-tuple
# returned by load_saved_model_function (defined elsewhere in the notebook)
cn_file_name = 'ae_model_tboard_4_layer_512_softmax_RGB.tar'
cn_save_path = save_path + cn_file_name
model_ae_rgb,_,_,_,_ = load_saved_model_function(cn_save_path, use_cuda)
# inference mode (presumably an nn.Module - TODO confirm the returned type)
model_ae_rgb = model_ae_rgb.eval()
# bare expression: the notebook displays the model's repr
model_ae_rgb

# 2. loading base dataset

In [None]:
# main set ups
# ------------
# image size fed to the mask-proposal / classifier networks
img_cls_h, img_cls_w = 95,95
# image size fed to the auto-encoder; the width was previously never defined
# because the height name was assigned twice on this line
img_ae_h, img_ae_w = 175,175
# which variant of the loaded images is used: 'rgb', 'gray' or 'edge'
input_image_mode = 'gray'

# setting up class labels first up
# this needs to match with exact 'classTitle' in final_dict
# ---------------------------------------------------------
# plain module-level assignment; a `global` statement is a no-op at this level
class_labels = ['Bird','Apples','Fish','Skulls','Star']

In [None]:
# reading data
# ------------
base_folder = '/Users/venkateshmadhava/Desktop/db'


# main op -- reading data for ae
################################
# create_dataset_from_folder_all is defined elsewhere in the notebook; its result
# is presumably published through the module-level arrays named on the next
# line, since they are read below without being assigned here (TODO confirm).
# NOTE(review): img_ae_h is passed for both of the last two arguments - harmless
# only while the auto-encoder input is square.
create_dataset_from_folder_all(base_folder,True,None,img_ae_h, img_ae_h)
# `global` is a no-op at module level; it only documents where the data comes from
global x_images_dataset, x_images_dataset_gray, x_images_dataset_edge

# setting up image
# ---------------
# pick the image variant that matches input_image_mode
if input_image_mode == 'rgb':
    x_db_ae = x_images_dataset
elif input_image_mode == 'gray':
    x_db_ae = x_images_dataset_gray
elif input_image_mode == 'edge':
    x_db_ae = x_images_dataset_edge
else:
    assert 1==2,'Error: invalid mode'
    
# sanity
# ------
# the x_images_dataset variant is always kept as well; it is what the final
# similarity call displays
x_db_rgb_ae = x_images_dataset
print(x_db_ae.shape)


# main op -- resizing data for classification
############################################
# resize_pool writes its result into the module-level x_out instead of returning it
resize_pool(x_db_ae,img_cls_h,img_cls_w)
global x_out
x_db_cls = x_out
print(x_db_cls.shape)

In [None]:
# sanity viewing to ensure both datasets correspond well
# ------------------------------------------------------
# two distinct random image indices (unseeded, so different on every run)
randrange = random.sample(list(range(x_db_ae.shape[0])), 2)

# showing
# -------
for i in randrange:
    
    # only channel 0 is shown, in gray; the sizes in the printed labels are
    # hard-coded and mirror the img_cls_* / img_ae_* settings
    print('95,95 image -- ')
    plt.imshow(x_db_cls[i,:,:,0], cmap='gray')
    plt.show()
    print('175,175 image -- ')
    plt.imshow(x_db_ae[i,:,:,0], cmap='gray')
    plt.show()
    print('********')

### 2.1 db - extracting fmaps & predictions

In [None]:
# setting up tensors
# ------------------
# setup_image_tensor is defined elsewhere in the notebook; presumably it turns
# the (m,h,w,c) image arrays into (m,c,h,w) tensors - TODO confirm
x_db_ae_trn = Variable(setup_image_tensor(x_db_ae)).float()
x_db_cls_trn = Variable(setup_image_tensor(x_db_cls)).float()

print(x_db_ae_trn.size())
print(x_db_cls_trn.size())

In [None]:
# extracting classification predictions
# -------------------------------------
# simple_forward_pass_pool is defined elsewhere in the notebook; its output is
# read from the module-level name y_out_global after each call.
# No `global` statements are used here: they are no-ops at module level, and
# repeating one after the name has already been used in the same cell is a
# SyntaxError ("name 'y_out_global' is used prior to global declaration").

# 1. getting mask
# ---------------
simple_forward_pass_pool(x_db_cls_trn,True,model_mpn)
db_mask = y_out_global
# drop the shared name so a stale result cannot be picked up by mistake
del y_out_global

# 2. getting prediction
# ---------------------
# the classifier sees the image multiplied element-wise by its proposed mask
simple_forward_pass_pool(x_db_cls_trn*db_mask.detach(),True,model_mimgcls)
db_pred = y_out_global
del y_out_global

# sanity
# ------
print(db_pred.size())

In [None]:
# extracting ae latents
# ---------------------
# the same tensor is concatenated three times along dim 1 before being passed to
# the auto-encoder's latent() method - presumably to turn the single gray
# channel into the 3 channels the RGB auto-encoder expects (TODO confirm)
simple_forward_pass_pool(torch.cat((x_db_ae_trn,x_db_ae_trn,x_db_ae_trn), 1),True,model_ae_rgb.latent)
# the result is read from the module-level y_out_global (`global` is a no-op here)
global y_out_global
db_ae_latent = y_out_global
del y_out_global

# sanity
# ------
print(db_ae_latent.size())

In [None]:
# score cam viz on db set
# --------------------------
# three distinct random image indices (unseeded)
randrange = random.sample(list(range(x_db_cls_trn.size()[0])), 3)

# setting up inputs
# -----------------
# tensor batch for the networks plus a copy converted by to_numpy_image
# (defined elsewhere), presumably for plotting
x_scam_in = x_db_cls_trn[randrange]
x_scam_print = to_numpy_image(x_scam_in)

# final score cam
# ---------------
# score_cam is defined elsewhere; the meaning of the trailing 3 and 'minmax'
# arguments is not visible here. Its return value is discarded.
_ = score_cam(model_mpn,model_mimgcls,x_scam_in,x_scam_print,3,'minmax')

# 3. loading input folder dataset

In [None]:
# reading data
# ------------
input_folder = '/Users/venkateshmadhava/Desktop/input'


# main op -- reading data for ae
################################
# create_dataset_from_folder_all is defined elsewhere in the notebook; its result
# is presumably published through the module-level arrays named on the next
# line, which means the db arrays loaded earlier under the same names are
# replaced here (TODO confirm).
# NOTE(review): img_ae_h is passed for both of the last two arguments - harmless
# only while the auto-encoder input is square.
create_dataset_from_folder_all(input_folder,True,None,img_ae_h, img_ae_h)
# `global` is a no-op at module level; it only documents where the data comes from
global x_images_dataset, x_images_dataset_gray, x_images_dataset_edge

# setting up image
# ---------------
# pick the image variant that matches input_image_mode
if input_image_mode == 'rgb':
    x_input_ae = x_images_dataset
elif input_image_mode == 'gray':
    x_input_ae = x_images_dataset_gray
elif input_image_mode == 'edge':
    x_input_ae = x_images_dataset_edge
else:
    assert 1==2,'Error: invalid mode'
    
# sanity
# ------
# the x_images_dataset variant is always kept as well; it is what the final
# similarity call displays
x_input_rgb_ae = x_images_dataset
print(x_input_ae.shape)


# main op -- resizing data for classification
############################################
# resize_pool writes its result into the module-level x_out instead of returning it
resize_pool(x_input_ae,img_cls_h,img_cls_w)
global x_out
x_input_cls = x_out
print(x_input_cls.shape)

In [None]:
# sanity viewing to ensure both datasets correspond well
# ------------------------------------------------------
# two distinct random image indices (unseeded, so different on every run)
randrange = random.sample(list(range(x_input_cls.shape[0])), 2)

# showing
# -------
for i in randrange:
    
    # only channel 0 is shown, in gray; the sizes in the printed labels are
    # hard-coded and mirror the img_cls_* / img_ae_* settings
    print('95,95 image -- ')
    plt.imshow(x_input_cls[i,:,:,0], cmap='gray')
    plt.show()
    print('175,175 image -- ')
    plt.imshow(x_input_ae[i,:,:,0], cmap='gray')
    plt.show()
    print('********')

### 3.1 input - extracting fmaps & predictions


In [None]:
# setting up tensors
# ------------------
# setup_image_tensor is defined elsewhere in the notebook; presumably it turns
# the (m,h,w,c) image arrays into (m,c,h,w) tensors - TODO confirm
x_input_ae_trn = Variable(setup_image_tensor(x_input_ae)).float()
x_input_cls_trn = Variable(setup_image_tensor(x_input_cls)).float()

print(x_input_ae_trn.size())
print(x_input_cls_trn.size())

In [None]:
# extracting classification predictions
# -------------------------------------
# simple_forward_pass_pool is defined elsewhere in the notebook; its output is
# read from the module-level name y_out_global after each call.
# No `global` statements are used here: they are no-ops at module level, and
# repeating one after the name has already been used in the same cell is a
# SyntaxError ("name 'y_out_global' is used prior to global declaration").

# 1. getting mask
# ---------------
simple_forward_pass_pool(x_input_cls_trn,True,model_mpn)
input_mask = y_out_global
# drop the shared name so a stale result cannot be picked up by mistake
del y_out_global

# 2. getting prediction
# ---------------------
# the classifier sees the image multiplied element-wise by its proposed mask
simple_forward_pass_pool(x_input_cls_trn*input_mask.detach(),True,model_mimgcls)
input_pred = y_out_global
del y_out_global

# sanity
# ------
print(input_pred.size())

In [None]:
# extracting ae latents
# ---------------------
# the same tensor is concatenated three times along dim 1 before being passed to
# the auto-encoder's latent() method - presumably to turn the single gray
# channel into the 3 channels the RGB auto-encoder expects (TODO confirm)
simple_forward_pass_pool(torch.cat((x_input_ae_trn,x_input_ae_trn,x_input_ae_trn), 1),True,model_ae_rgb.latent)
# the result is read from the module-level y_out_global (`global` is a no-op here)
global y_out_global
input_ae_latent = y_out_global
del y_out_global

# sanity
# ------
print(input_ae_latent.size())

In [None]:
# score cam viz on input set
# --------------------------
# three distinct random image indices (unseeded); random.sample raises
# ValueError when fewer than three input images are available
randrange = random.sample(list(range(x_input_cls_trn.size()[0])), 3)

# setting up inputs
# -----------------
# tensor batch for the networks plus a copy converted by to_numpy_image
# (defined elsewhere), presumably for plotting
x_scam_in = x_input_cls_trn[randrange]
x_scam_print = to_numpy_image(x_scam_in)

# final score cam
# ---------------
# score_cam is defined elsewhere; the meaning of the trailing 3 and 'minmax'
# arguments is not visible here. Its return value is discarded.
_ = score_cam(model_mpn,model_mimgcls,x_scam_in,x_scam_print,3,'minmax')

In [None]:
# END OF LOADING AND EXTRACTING PREDICTIONS & FMAPS
##

# 4. search related codes begin

### 4.1 getting concepts filtered

In [None]:
# getting inds of db set that have ONLY one or more concepts from input image
# ---------------------------------------------------------------------------
# return_filtered_concepts_inds is defined elsewhere in the notebook; the result
# is later indexed per input image (concept_inds_list[i]) to select db rows

concept_inds_list =  return_filtered_concepts_inds(input_pred,db_pred)

### 4.2 building rmac latents

In [None]:
################# LATENT SETTINGS #################
###################################################

# what works RGB -- RGB local - (1/3*h,1/3*w) i.e if h,w of fmap is 10,10 then local pool dim is 3,3

# latents settings
# ----------------
# these three values are passed unchanged to final_latents in the next cells.
# NOTE(review): the first entry is a tuple and the second a list; whether
# final_latents treats them differently is not visible here - confirm.
kernel_stride_dims = [(2,2),[5,2]]
# 'max', 'avg' or 'both'
pool_mode = 'both'
# False: pooled maps are not summed over their spatial axes before flattening
aggregate_pool_maps = False

In [None]:
# computing db latents from fmaps
# -------------------------------
# final_latents is defined elsewhere in the notebook; it receives a list holding
# the single auto-encoder latent, and its result is used as a list of latents

db_latents_list = final_latents([db_ae_latent],kernel_stride_dims,pool_mode,aggregate_pool_maps)

In [None]:
# computing input latents from fmaps
# ----------------------------------
# same settings as for the db latents, so both lists are comparable entry by entry

input_latents_list = final_latents([input_ae_latent],kernel_stride_dims,pool_mode,aggregate_pool_maps)

# 5. similarity

In [None]:
# running simple wrapper
# ----------------------
# 'ratio' similarity over the pooled auto-encoder latents; the *_rgb_ae image
# arrays are the ones displayed, and each input image is only compared with the
# db rows listed for it in concept_inds_list

similarity_wrapper(input_latents_list,db_latents_list,x_input_rgb_ae,x_db_rgb_ae,'ratio',concept_inds_list)

# rough