In [None]:
timm_path = "../input/timm-pytorch-image-models/pytorch-image-models-master"
import sys
sys.path.append(timm_path)
import timm
import time

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sklearn
import os
from tqdm.notebook import tqdm

import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader
from torch import optim

import warnings
warnings.filterwarnings('ignore')

# Run on the GPU when one is visible; the model and query tensors are moved here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Competition test split. Use the same relative "../input" root as every other
# path in this notebook rather than mixing in an absolute "/kaggle/input" path.
competition_dir = "../input/shopee-product-matching/"
test_df = pd.read_csv(os.path.join(competition_dir, "test.csv"))
test_dataset_dir = os.path.join(competition_dir, "test_images/")

In [None]:
# Directory of the training images; find_similar builds matched image paths from it.
dataset_dir = "../input/shopee-product-matching/train_images/"


class Shopee:
    """Image-similarity search over precomputed Shopee product embeddings.

    The constructor loads a CSV of stored embeddings (one row per image, with an
    ``image`` file-name column and an ``embedding`` column holding the vector as
    a ``"[a, b, ...]"`` string) and a ``dm_nfnet_f0`` model whose ``head.fc``
    is replaced by a linear layer with ``num_embeddings`` outputs.
    ``find_similar`` embeds a query image with that model and plots the stored
    images whose cosine similarity to the query reaches a threshold.
    """

    def __init__(self,
                 image_size=128,
                 num_embeddings=512,
                 weights_path="../input/shopee-embedding-df/NFNet_f0_10.pth",
                 embedding_path='../input/shopee-embedding-df/train_df_embeddings.csv',
                 image_dir=None):
        """Load the embedding table, the fine-tuned model and the query transform.

        Args:
            image_size: Side length (pixels) of the square the query image is
                resized and padded to.
            num_embeddings: Output width of the replacement ``head.fc`` layer;
                must match the checkpoint at ``weights_path``.
            weights_path: Checkpoint loaded with ``torch.load`` and passed to
                ``load_state_dict``.
            embedding_path: Path to the CSV of stored embeddings.
            image_dir: Directory containing the images named in the CSV's
                ``image`` column. ``None`` (default) uses the notebook-level
                ``dataset_dir``.
        """
        self.image_dir = dataset_dir if image_dir is None else image_dir

        self.df = pd.read_csv(embedding_path)
        # Strip the surrounding brackets and split on ", " -> one row per CSV row.
        self.embeddings = np.array([
            [float(value) for value in emb_str[1:-1].split(', ')]
            for emb_str in self.df['embedding']
        ])

        self.model = timm.create_model('dm_nfnet_f0', pretrained=False)
        num_features = self.model.head.fc.in_features
        # Replace the classifier with the projection the checkpoint presumably
        # contains, so the state_dict keys and shapes line up.
        self.model.head.fc = nn.Linear(num_features, num_embeddings)
        # map_location lets a checkpoint saved from a GPU load on a CPU-only kernel.
        state_dict = torch.load(weights_path, map_location=device)
        self.model.load_state_dict(state_dict)
        self.model.eval()
        self.model.to(device)

        # Letterbox to image_size x image_size (resize the longest side, then pad
        # with a constant border), normalise, and convert to a tensor.
        self.valid_aug = A.Compose([
            A.LongestMaxSize(max_size=image_size, p=1.0),
            A.PadIfNeeded(min_height=image_size, min_width=image_size,
                          border_mode=cv2.BORDER_CONSTANT, p=1.0),
            A.Normalize(p=1.0),
            ToTensorV2(p=1.0),
        ])

    def display_image(self, image):
        """Show one image inline without axis ticks."""
        plt.xticks([])
        plt.yticks([])
        plt.imshow(image)
        plt.show()

    def find_similar(self, image_path, top_n=100, threshold=0.95):
        """Show the query image, then a grid of its most similar stored images.

        Args:
            image_path: Path to the query image file.
            top_n: Maximum number of stored images to consider.
            threshold: Minimum cosine similarity for a stored image to be shown.

        Raises:
            FileNotFoundError: If the query image or a matched image cannot be read.
        """
        image = self._read_rgb(image_path)
        print('query image')
        self.display_image(image)

        aug_image = self.valid_aug(image=image)['image']
        # Inference only: no_grad avoids building an autograd graph per query.
        with torch.no_grad():
            query_embedding = self.model(aug_image.unsqueeze(0).to(device)).cpu().numpy()

        cos_sims = self._cosine_similarities(self.embeddings, query_embedding)
        # Highest similarity first. Because the order is descending, keeping the
        # entries >= threshold equals stopping at the first one below it.
        top_indices = np.argsort(-cos_sims)[:top_n]
        top_indices = top_indices[cos_sims[top_indices] >= threshold]
        plot_list = [os.path.join(self.image_dir, name)
                     for name in self.df['image'].iloc[top_indices]]

        if not plot_list:
            print(f'no matches with cosine similarity >= {threshold}')
            return

        # plt.subplot needs integer grid dimensions.
        side = self._grid_side(len(plot_list))
        plt.figure(figsize=(20, 20))
        for position, match_path in enumerate(plot_list, start=1):
            plt.subplot(side, side, position)
            plt.imshow(self._read_rgb(match_path))
            plt.title(match_path, fontsize=6)
            plt.axis("off")
        plt.show()

    @staticmethod
    def _read_rgb(path):
        """Read an image file and return it in RGB channel order."""
        image = cv2.imread(path)
        # cv2.imread signals a missing/unreadable file by returning None.
        if image is None:
            raise FileNotFoundError(f'could not read image: {path}')
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    @staticmethod
    def _cosine_similarities(embeddings, query):
        """Cosine similarity between each row of ``embeddings`` and ``query``.

        ``embeddings`` is (N, D) and ``query`` is (D,) or (1, D); returns (N,).
        """
        numerator = np.sum(embeddings * query, axis=1)
        denominator = np.sqrt(np.sum(embeddings ** 2, axis=1)) * np.sqrt(np.sum(query ** 2))
        return numerator / denominator

    @staticmethod
    def _grid_side(count):
        """Smallest integer ``side`` such that ``side * side >= count``."""
        side = int(np.sqrt(count))
        if side * side < count:
            side += 1
        return side
                
# Build the finder once: the constructor parses the embeddings CSV and loads the
# model weights, so the cells below only run queries against it.
finder = Shopee()

In [None]:
# Sanity check with a training image; presumably it is part of the embedding
# index, in which case it should show up among its own matches.
%time finder.find_similar(os.path.join(dataset_dir , '027478fc15b3caf7d9be5465ad7bdf5c.jpg'), top_n=100,threshold=0.5)

In [None]:
%time finder.find_similar(os.path.join(test_dataset_dir , test_df.loc[0,"image"]), top_n=60,threshold=0.5)

In [None]:
%time finder.find_similar(os.path.join(test_dataset_dir , test_df.loc[1,"image"]), top_n=60,threshold=0.5)

In [None]:
%time finder.find_similar(os.path.join(test_dataset_dir , test_df.loc[2,"image"]), top_n=60,threshold=0.5)