In [None]:
%%capture
!pip install ../input/pytorchimagemodels

In [None]:
import sys
sys.path.append('../input/facebookdeit')

In [None]:
import os
import gc
import sys

import math
import random

from tqdm.notebook import tqdm

import re

import cv2
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset ,RandomSampler

import timm

import albumentations

from swav_resnet import resnet50w2

import transformers
from transformers import AutoModel, AutoTokenizer

from models import deit_base_distilled_patch16_224



import cudf
import cuml
import cupy

In [None]:
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators.

    Also switches cuDNN into deterministic mode.

    Args:
        seed: Integer seed handed unchanged to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    torch.backends.cudnn.deterministic = True


# Fix the seed once for the whole notebook run.
set_seed(0)

In [None]:
# One entry per fine-tuned text encoder:
#   transformer_model : directory given to AutoModel / AutoTokenizer.from_pretrained
#   model_path        : checkpoint read with torch.load and fed to load_state_dict
#   MAX_LEN           : tokenizer max_length (titles are padded/truncated to it)
#   params            : keyword arguments splatted into TextEncoder(...)
text_models = [
    {
        'transformer_model': '../input/shopee-paraphrase-xlm-r-multilingual/paraphrase-xlm-r-multilingual-v1_pp_7_14',
        'model_path': '../input/shopee-paraphrase-xlm-r-multilingual/paraphrase-xlm-r-multilingual-v1_pp_7_14/ckpt.pt',
        'MAX_LEN': 128,
        'params': {
            'embed_dim': 1024,
            'out_dim': 11014,
        },
    },
    {
        'transformer_model': '../input/shopee-bert-base-indonesian/bert-base-indonesian_pp_7_8',
        'model_path': '../input/shopee-bert-base-indonesian/bert-base-indonesian_pp_7_8/ckpt.pt',
        'MAX_LEN': 128,
        'params': {
            'embed_dim': 1024,
            'out_dim': 11014,
        },
    },
]

# Input width of the Projection head inside TextEncoder.
TRANSFORMER_EMBED_DIM = 768

In [None]:
# One entry per joint image+text model:
#   transformer_model : directory given to AutoModel / AutoTokenizer.from_pretrained
#   model_path        : checkpoint read with torch.load and fed to load_state_dict
#   MAX_LEN           : tokenizer max_length
#   IMAGE_SIZE        : side length passed to get_transforms
#   params            : keyword arguments splatted into JointTransformer(...)
joint_models = [
    {
        'transformer_model': '../input/shopee-paraphrase-xlm-r-multilingual/paraphrase-xlm-r-multilingual-v1_pp_7_14',
        'model_path': '../input/deit-xlm-joint-weights/finetuned_deit_xlm_joint_7_11.pt',
        'MAX_LEN': 128,
        'IMAGE_SIZE': 224,
        'params': {
            'embed_dim': 2048,
            'out_dim': 11014,
        },
    },
]

In [None]:
# One entry per fine-tuned vision encoder, in the order the driver indexes them
# (0: SwAV ResNet, 1: EfficientNet, 2: DeiT):
#   vision_model : checkpoint read with torch.load and fed to load_state_dict
#   IMAGE_SIZE   : side length passed to get_transforms
#   params       : keyword arguments splatted into the encoder class
image_models = [
    {
        'vision_model': '../input/shopee-swav-finetuned-resnet50w2/swav_resnet50w2_224_2048_5_epochs_loss_12_beef.pt',
        'IMAGE_SIZE': 224,
        'params': {
            'embed_dim': 2048,
            'out_dim': 11014,
        },
    },
    {
        'vision_model': '../input/effnet-weights/finetuned_tf_efficientnet_b3_ns_300_2048_7_9.pt',
        'IMAGE_SIZE': 300,
        'params': {
            'backbone_name': 'tf_efficientnet_b3_ns',
            'embed_dim': 2048,
            'out_dim': 11014,
        },
    },
    {
        'vision_model': '../input/deit-base-distilled-patch16-224/finetuned_deit_base_distilled_patch16_224_4_18.pt',
        'IMAGE_SIZE': 224,
        'params': {
            'embed_dim': 2048,
            'out_dim': 11014,
        },
    },
]

In [None]:
# Worker-process count passed to every DataLoader built by the generate_*_embeddings functions.
NUM_WORKERS = 4
# Batch size passed to every DataLoader built by the generate_*_embeddings functions.
BATCH_SIZE = 128

In [None]:
# When True (and GET_CV is True) read_dataset() concatenates the training frame with itself.
CHECK_SUB = False
# When True the notebook reads train.csv and sweeps similarity thresholds against the
# ground-truth matches; the following cell switches it off when test.csv has more than 3 rows.
GET_CV = True

In [None]:
# More than 3 rows in test.csv turns cross-validation off; otherwise GET_CV
# keeps its value and a notice is printed.
test = pd.read_csv('../input/shopee-product-matching/test.csv')
if len(test) > 3:
    GET_CV = False
else:
    print('this submission notebook will compute CV score, but commit notebook will not')


In [None]:
def read_dataset():
    """Load the dataframe the rest of the notebook works on.

    With ``GET_CV`` false, reads test.csv and adds a ``filepath`` column that
    points into ``test_images``. With ``GET_CV`` true, reads train.csv, adds a
    ``matches`` column (space-joined posting_ids sharing the row's
    ``label_group``) and a ``filepath`` column pointing into ``train_images``;
    if ``CHECK_SUB`` is also true the frame is concatenated with itself and
    re-indexed.

    Returns:
        pandas.DataFrame: the prepared frame.
    """
    if not GET_CV:
        df = pd.read_csv('../input/shopee-product-matching/test.csv')
        df['filepath'] = df['image'].apply(
            lambda name: os.path.join('../input/shopee-product-matching/test_images', name)
        )
        return df

    df = pd.read_csv('../input/shopee-product-matching/train.csv')

    # label_group -> array of the posting_ids that belong to it
    ids_by_group = df.groupby(['label_group'])['posting_id'].unique().to_dict()
    df['matches'] = df['label_group'].map(ids_by_group)
    df['matches'] = df['matches'].apply(' '.join)

    df['filepath'] = df['image'].apply(
        lambda name: os.path.join('../input/shopee-product-matching/train_images', name)
    )

    if CHECK_SUB:
        df = pd.concat([df, df], axis = 0).reset_index(drop = True)

    return df

In [None]:
def f1_score(y_true, y_pred):
    """Row-wise F1 between two Series of space-separated id strings.

    Args:
        y_true: pandas Series of ground-truth id strings.
        y_pred: pandas Series of predicted id strings, aligned with ``y_true``.

    Returns:
        numpy array with one value per row:
        ``2 * |true & pred| / (|true| + |pred|)`` over the de-duplicated ids.
    """
    true_sets = y_true.apply(lambda ids: set(ids.split()))
    pred_sets = y_pred.apply(lambda ids: set(ids.split()))

    overlap = np.array([len(truth & guess) for truth, guess in zip(true_sets, pred_sets)])
    pred_sizes = pred_sets.apply(len).values
    true_sizes = true_sets.apply(len).values

    return 2 * overlap / (pred_sizes + true_sizes)

## ArcFace utils

In [None]:
class DenseCrossEntropy(nn.Module):
    """Cross-entropy against dense (e.g. one-hot) target distributions."""

    def forward(self, x, target):
        """Return the batch mean of ``-sum(target * log_softmax(x))``.

        Args:
            x: Logits; softmax is taken over the last dimension.
            target: Target weights with the same shape as ``x``.
        """
        log_probs = F.log_softmax(x.float(), dim=-1)
        per_sample = (-log_probs * target.float()).sum(-1)
        return per_sample.mean()


class ArcMarginProduct_subcenter(nn.Module):
    """Cosine classifier with ``k`` sub-centres per class.

    Holds ``out_features * k`` weight vectors. For each class the forward pass
    reports the largest cosine similarity among that class's ``k`` vectors.
    """

    def __init__(self, in_features, out_features, k=3):
        super().__init__()
        self.k = k
        self.out_features = out_features
        self.weight = nn.Parameter(
            torch.empty(out_features * k, in_features, dtype=torch.float32)
        )
        self.reset_parameters()

    def reset_parameters(self):
        """Fill the weights uniformly in ``[-1/sqrt(in_features), 1/sqrt(in_features)]``."""
        bound = 1. / math.sqrt(self.weight.size(1))
        with torch.no_grad():
            self.weight.uniform_(-bound, bound)

    def forward(self, features):
        """Return a ``(batch, out_features)`` tensor of per-class cosines."""
        unit_features = F.normalize(features)
        unit_centres = F.normalize(self.weight)
        per_centre = F.linear(unit_features, unit_centres)
        # Rows of `weight` are grouped by class, k consecutive rows each.
        per_centre = per_centre.view(-1, self.out_features, self.k)
        return per_centre.max(dim=2).values


class ArcFaceLossAdaptiveMargin(nn.modules.Module):
    """ArcFace loss whose angular margin is looked up per target class.

    Args:
        margins: Array indexable by a NumPy integer array of class ids, giving
            the margin (radians) for each class.
        s: Scale applied to the margin-adjusted cosines before the
            cross-entropy.
    """

    def __init__(self, margins, s=30.0):
        super().__init__()
        self.crit = DenseCrossEntropy()
        self.s = s
        self.margins = margins

    def forward(self, logits, labels, out_dim):
        """Compute the loss.

        Args:
            logits: ``(batch, out_dim)`` cosine similarities.
            labels: ``(batch,)`` integer class ids.
            out_dim: Number of classes, used for the one-hot encoding.

        Returns:
            Scalar loss tensor.
        """
        # Build the margin tensors on the device of the logits rather than
        # unconditionally on CUDA, so the loss also works on CPU or on a
        # non-default GPU.
        device = logits.device
        ms = self.margins[labels.detach().cpu().numpy()]

        def per_sample(values):
            # One value per sample, shaped (batch, 1) to broadcast over classes.
            return torch.from_numpy(np.asarray(values)).float().to(device).view(-1, 1)

        cos_m = per_sample(np.cos(ms))
        sin_m = per_sample(np.sin(ms))
        th = per_sample(np.cos(math.pi - ms))
        mm = per_sample(np.sin(math.pi - ms) * ms)

        labels = F.one_hot(labels, out_dim).float()
        cosine = logits.float()
        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))

        # cos(theta + m); where theta + m would exceed pi, use cosine - mm instead.
        phi = cosine * cos_m - sine * sin_m
        phi = torch.where(cosine > th, phi, cosine - mm)

        # The margin is applied to the target class only.
        output = (labels * phi) + ((1.0 - labels) * cosine)
        output = output * self.s
        return self.crit(output, labels)



## Common Projection Util

In [None]:
class Projection(nn.Module):
    """Residual projection head: linear -> (GELU -> linear -> dropout) skip -> LayerNorm.

    Args:
        d_in: Width of the incoming features.
        d_out: Width of the projected embedding.
        p: Dropout probability applied to the residual branch.
    """

    def __init__(self, d_in: int, d_out: int, p: float=0.5) -> None:
        super().__init__()
        self.linear1 = nn.Linear(d_in, d_out, bias=False)
        self.linear2 = nn.Linear(d_out, d_out, bias=False)
        self.layer_norm = nn.LayerNorm(d_out)
        self.drop = nn.Dropout(p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` and return the layer-normalised sum of both branches."""
        shortcut = self.linear1(x)
        residual = self.linear2(F.gelu(shortcut))
        residual = self.drop(residual)
        return self.layer_norm(shortcut + residual)

# **Vision Model**

In [None]:
class VisionEncoderResnet(nn.Module):
    """``resnet50w2`` backbone followed by a Projection head and a sub-centre cosine classifier.

    Args:
        embed_dim: Width of the projected embedding.
        out_dim: Number of classes of the cosine classifier.
    """

    def __init__(self, embed_dim , out_dim):
        super().__init__()
        # Same architecture as torch.hub.load('facebookresearch/swav', 'resnet50w2').
        self.base = resnet50w2()
        # The projection head is built for 4096-wide backbone features.
        backbone_width = 4096
        self.projection = Projection(backbone_width, embed_dim)
        self.metric_classify = ArcMarginProduct_subcenter(embed_dim, out_dim)

        for param in self.base.parameters():
            param.requires_grad_(True)

    def forward(self, x):
        """Return ``(embedding, class_cosines)`` for a batch of images."""
        embedding = self.projection(self.base(x))
        class_cosines = self.metric_classify(embedding)
        return embedding, class_cosines

In [None]:
class VisionEncoderEffnet(nn.Module):
    """timm backbone (classifier removed) + Projection head + sub-centre cosine classifier.

    Args:
        backbone_name: Model name passed to ``timm.create_model``.
        embed_dim: Width of the projected embedding.
        out_dim: Number of classes of the cosine classifier.
    """

    def __init__(self,backbone_name, embed_dim , out_dim):
        super().__init__()
        backbone = timm.create_model(backbone_name, pretrained=False)
        # Read the feature width before replacing the classifier with a pass-through.
        feature_width = backbone.classifier.in_features
        backbone.classifier = nn.Identity()

        self.base = backbone
        self.projection = Projection(feature_width, embed_dim)
        self.metric_classify = ArcMarginProduct_subcenter(embed_dim, out_dim)

        for param in self.base.parameters():
            param.requires_grad_(True)

    def forward(self, x):
        """Return ``(embedding, class_cosines)`` for a batch of images."""
        embedding = self.projection(self.base(x))
        class_cosines = self.metric_classify(embedding)
        return embedding, class_cosines

In [None]:
class DeitWrapperVision(nn.Module):
    """Distilled DeiT backbone that returns the mean of its class and distillation tokens."""

    def __init__(self):
        super().__init__()
        self.backbone = deit_base_distilled_patch16_224(pretrained=False)
        # Both classification heads become pass-throughs.
        self.backbone.head = nn.Identity()
        self.backbone.head_dist = nn.Identity()

    def forward_features(self, x):
        """Return the normalised hidden states ``[cls, dist, patches...]``.

        Adapted from https://github.com/facebookresearch/deit/blob/main/models.py
        """
        net = self.backbone
        batch_size = x.shape[0]

        patches = net.patch_embed(x)
        # The learned tokens are broadcast across the batch before being prepended.
        cls_tokens = net.cls_token.expand(batch_size, -1, -1)
        dist_tokens = net.dist_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls_tokens, dist_tokens, patches), dim=1)

        tokens = net.pos_drop(tokens + net.pos_embed)
        for block in net.blocks:
            tokens = block(tokens)

        return net.norm(tokens)

    def forward(self, x):
        """Return the average of the class token and the distillation token."""
        states = self.forward_features(x)
        return (states[:, 0] + states[:, 1]) / 2
    
    
class VisionEncoderDeit(nn.Module):
    """DeitWrapperVision backbone + Projection head + sub-centre cosine classifier.

    Args:
        embed_dim: Width of the projected embedding.
        out_dim: Number of classes of the cosine classifier.
    """

    def __init__(self, embed_dim , out_dim):
        super().__init__()
        self.base = DeitWrapperVision()
        # The projection head is built for 768-wide backbone features.
        token_width = 768
        self.projection = Projection(token_width, embed_dim)
        self.metric_classify = ArcMarginProduct_subcenter(embed_dim, out_dim)

        for param in self.base.parameters():
            param.requires_grad_(True)

    def forward(self, x):
        """Return ``(embedding, class_cosines)`` for a batch of images."""
        embedding = self.projection(self.base(x))
        class_cosines = self.metric_classify(embedding)
        return embedding, class_cosines

In [None]:
class ShopeeDatasetVision(Dataset):
    """Image dataset over a dataframe that has a ``filepath`` column.

    Args:
        csv: DataFrame with one row per product; ``filepath`` is read for every
            item and ``label`` additionally when ``mode`` is not ``'test'``.
        mode: ``'test'`` yields images only; any other value yields
            ``(image, label)``.
        transform: Optional callable invoked as ``transform(image=...)`` that
            returns a dict holding the result under ``'image'``.
    """

    def __init__(self, csv, mode, transform):
        self.csv = csv.reset_index()
        self.mode = mode
        self.transform = transform

    def __len__(self):
        return self.csv.shape[0]

    def __getitem__(self, index):
        """Return the channel-first float32 image tensor (and label unless in test mode).

        Raises:
            OSError: if the image file is missing or cannot be decoded.
        """
        row = self.csv.iloc[index]

        image = cv2.imread(row.filepath)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising;
            # fail here with the offending path rather than on the slice below.
            raise OSError(f'Could not read image (missing or unreadable): {row.filepath}')

        # Reverse the channel axis (OpenCV loads BGR).
        image = image[:, :, ::-1]

        if self.transform is not None:
            res = self.transform(image=image)
            image = res['image'].astype(np.float32)
        else:
            image = image.astype(np.float32)

        # HWC -> CHW
        image = image.transpose(2, 0, 1)

        if self.mode == 'test':
            return torch.tensor(image)
        else:
            return torch.tensor(image), torch.tensor(row.label)

In [None]:
def get_transforms(image_size=512):
    """Build the albumentations pipelines for training and for evaluation.

    Args:
        image_size: Side length images are resized to (square output).

    Returns:
        ``(transforms_train, transforms_val)``. The training pipeline adds
        flip / JPEG-compression / shift-scale-rotate / cutout augmentation
        around the resize; the validation pipeline only resizes and normalises.
    """
    # The cutout hole is capped at 40% of the image side in both directions.
    hole_side = int(image_size * 0.4)

    train_steps = [
        albumentations.HorizontalFlip(p=0.5),
        albumentations.JpegCompression(quality_lower=99, quality_upper=100),
        albumentations.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=10, border_mode=0, p=0.7),
        albumentations.Resize(image_size, image_size),
        albumentations.Cutout(max_h_size=hole_side, max_w_size=hole_side, num_holes=1, p=0.5),
        albumentations.Normalize(),
    ]
    transforms_train = albumentations.Compose(train_steps)

    eval_steps = [
        albumentations.Resize(image_size, image_size),
        albumentations.Normalize(),
    ]
    transforms_val = albumentations.Compose(eval_steps)

    return transforms_train, transforms_val

In [None]:
def generate_image_embeddings(df, config, VisionEncoder):
    """Embed every image in ``df`` with one fine-tuned vision encoder.

    Args:
        df: DataFrame with a ``filepath`` column.
        config: One entry of ``image_models`` (keys ``vision_model``,
            ``params``, ``IMAGE_SIZE``).
        VisionEncoder: Encoder class, instantiated as
            ``VisionEncoder(**config['params'])``.

    Returns:
        CPU tensor with the (un-normalised) embeddings of all rows, in row order.
    """
    print(config)

    checkpoint_path = config['vision_model']
    encoder_kwargs = config['params']
    side = config['IMAGE_SIZE']

    state_dict = torch.load(checkpoint_path, map_location='cpu')
    encoder = VisionEncoder(**encoder_kwargs)
    encoder.load_state_dict(state_dict)
    encoder = encoder.cuda()

    # Evaluation pipeline only (resize + normalise).
    _, eval_transform = get_transforms(side)
    loader = DataLoader(
        ShopeeDatasetVision(df, 'test', transform=eval_transform),
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
    )

    encoder.eval()
    chunks = []
    with torch.no_grad():
        for images in tqdm(loader):
            embedding, _ = encoder(images.cuda())
            chunks.append(embedding.detach().cpu())

    image_embeddings = torch.cat(chunks, dim=0)
    #image_embeddings= F.normalize(image_embeddings)

    print(f'image_embeddings shape : {image_embeddings.shape}')

    # Release the model before the next one is loaded.
    del state_dict
    del encoder
    gc.collect()
    torch.cuda.empty_cache()

    return image_embeddings

# **Text Model**

In [None]:
class TextEncoder(nn.Module):
    """Transformer backbone + masked mean pooling + Projection head + cosine classifier.

    Args:
        embed_dim: Width of the projected embedding.
        out_dim: Number of classes of the cosine classifier.
        transformer_model: Name or path passed to ``AutoModel.from_pretrained``.
    """

    def __init__(self, embed_dim , out_dim, transformer_model):
        super().__init__()

        self.base = AutoModel.from_pretrained(transformer_model)
        self.projection = Projection(TRANSFORMER_EMBED_DIM, embed_dim)
        self.metric_classify = ArcMarginProduct_subcenter(embed_dim, out_dim)

        for param in self.base.parameters():
            param.requires_grad_(True)

    @staticmethod
    def mean_pooling(model_output, attention_mask):
        """Average the token embeddings of each sequence, ignoring masked positions.

        Args:
            model_output: Sequence whose first element holds the token embeddings.
            attention_mask: Mask with one entry per token (non-zero = keep).
        """
        tokens = model_output[0]
        weights = attention_mask.unsqueeze(-1).expand(tokens.size()).float()
        weighted_sum = (tokens * weights).sum(1)
        # Clamp so a fully masked sequence does not divide by zero.
        token_count = weights.sum(1).clamp(min=1e-9)
        return weighted_sum / token_count

    def forward(self, input_ids,attention_mask):
        """Return ``(embedding, class_cosines)`` for a batch of tokenised titles."""
        hidden = self.base(input_ids=input_ids, attention_mask=attention_mask)
        pooled = TextEncoder.mean_pooling(hidden, attention_mask)

        embedding = self.projection(pooled)
        class_cosines = self.metric_classify(embedding)
        return embedding, class_cosines


In [None]:
class Tokenizer:
    """Wrapper around a pretrained tokenizer that pads/truncates to a fixed length.

    Args:
        max_length: Length every encoded text is padded or truncated to.
        transformer_model: Name or path passed to ``AutoTokenizer.from_pretrained``.
    """

    def __init__(self,max_length, transformer_model):
        self.tokenizer = AutoTokenizer.from_pretrained(transformer_model)
        self.max_length = max_length

    def __call__(self, x) :
        """Encode ``x`` and return the tokenizer output as PyTorch tensors."""
        encode_options = {
            'max_length': self.max_length,
            'truncation': True,
            'padding': 'max_length',
            'return_tensors': "pt",
        }
        return self.tokenizer(x, **encode_options)

In [None]:
class ShopeeDatasetText(Dataset):
    """Text dataset over a dataframe that has a ``title`` column.

    Args:
        csv: DataFrame with one row per product; ``title`` is read for every
            item and ``label`` additionally when ``mode`` is not ``'test'``.
        mode: ``'test'`` yields ``(input_ids, attention_mask)``; any other
            value also appends the label tensor.
        tokenizer: Callable returning a mapping with ``'input_ids'`` and
            ``'attention_mask'`` whose first element is the encoded title.
    """

    def __init__(self, csv, mode, tokenizer):
        self.csv = csv.reset_index()
        self.mode = mode
        self.tokenizer = tokenizer

    def __len__(self):
        return self.csv.shape[0]

    @staticmethod
    def string_escape(s, encoding='utf-8'):
        """Decode literal escape sequences (e.g. ``\\xc3\\xa9``) embedded in ``s``.

        If any step of the round trip fails (characters outside latin-1, a
        dangling backslash, bytes that are not valid ``encoding``), ``s`` is
        returned unchanged, so a single odd title cannot abort a whole
        inference run.
        """
        try:
            return (
                s.encode('latin1')  # To bytes, required by 'unicode-escape'
                .decode('unicode-escape')  # Perform the actual octal-escaping decode
                .encode('latin1')  # 1:1 mapping back to bytes
                .decode(encoding)  # Decode original encoding
            )
        except UnicodeError:
            return s

    @staticmethod
    def preprocess_text(x):
        """Un-escape the title, then replace every non-word, non-space character with a space."""
        x = ShopeeDatasetText.string_escape(x)
        x = re.sub(r'[^\w\s]',' ', x)
        return x

    def __getitem__(self, index):
        row = self.csv.iloc[index]
        text = ShopeeDatasetText.preprocess_text(row.title)

        encoded_text = self.tokenizer(text)

        # The tokenizer output carries a leading batch axis of size one.
        input_ids = encoded_text['input_ids'][0]
        attention_mask = encoded_text['attention_mask'][0]

        if self.mode == 'test':
            return input_ids, attention_mask
        else:
            return input_ids, attention_mask, torch.tensor(row.label)

In [None]:
def generate_text_embeddings(df, config):
    """Embed every title in ``df`` with one fine-tuned text encoder.

    Args:
        df: DataFrame with a ``title`` column.
        config: One entry of ``text_models`` (keys ``transformer_model``,
            ``model_path``, ``MAX_LEN``, ``params``).

    Returns:
        CPU tensor with the (un-normalised) embeddings of all rows, in row order.
    """
    print(config)

    backbone_name = config['transformer_model']
    max_len = config['MAX_LEN']

    state_dict = torch.load(config['model_path'], map_location='cpu')
    encoder = TextEncoder(**config['params'], transformer_model=backbone_name)
    encoder.load_state_dict(state_dict)
    encoder = encoder.cuda()

    loader = DataLoader(
        ShopeeDatasetText(df, 'test', Tokenizer(max_len, backbone_name)),
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
    )

    encoder.eval()
    chunks = []
    with torch.no_grad():
        for input_ids, attention_mask in tqdm(loader):
            embedding, _ = encoder(input_ids.cuda(), attention_mask.cuda())
            chunks.append(embedding.detach().cpu())

    text_embeddings = torch.cat(chunks, dim=0)
    #text_embeddings= F.normalize(text_embeddings)

    print(f'text_embeddings shape : {text_embeddings.shape}')

    # Release the model before the next one is loaded.
    del state_dict
    del encoder
    gc.collect()
    torch.cuda.empty_cache()

    return text_embeddings

# **Joint Transformer**

In [None]:
class ShopeeDatasetJoint(Dataset):
    """Dataset yielding the image and the tokenised title of each product.

    Args:
        csv: DataFrame with one row per product; ``title`` and ``filepath`` are
            read for every item and ``label`` additionally when ``mode`` is not
            ``'test'``.
        mode: ``'test'`` yields ``(image, input_ids, attention_mask)``; any
            other value also appends the label tensor.
        tokenizer: Callable returning a mapping with ``'input_ids'`` and
            ``'attention_mask'`` whose first element is the encoded title.
        transform: Optional callable invoked as ``transform(image=...)`` that
            returns a dict holding the result under ``'image'``.
    """

    def __init__(self, csv, mode, tokenizer, transform):
        self.csv = csv.reset_index()
        self.mode = mode
        self.tokenizer = tokenizer
        self.transform = transform

    def __len__(self):
        return self.csv.shape[0]

    @staticmethod
    def string_escape(s, encoding='utf-8'):
        """Decode literal escape sequences (e.g. ``\\xc3\\xa9``) embedded in ``s``.

        If any step of the round trip fails (characters outside latin-1, a
        dangling backslash, bytes that are not valid ``encoding``), ``s`` is
        returned unchanged, so a single odd title cannot abort a whole
        inference run.
        """
        try:
            return (
                s.encode('latin1')  # To bytes, required by 'unicode-escape'
                .decode('unicode-escape')  # Perform the actual octal-escaping decode
                .encode('latin1')  # 1:1 mapping back to bytes
                .decode(encoding)  # Decode original encoding
            )
        except UnicodeError:
            return s

    @staticmethod
    def preprocess_text(x):
        """Un-escape the title, then replace every non-word, non-space character with a space."""
        x = ShopeeDatasetJoint.string_escape(x)
        x = re.sub(r'[^\w\s]',' ', x)
        return x

    def __getitem__(self, index):
        """Return ``(image, input_ids, attention_mask[, label])`` for one row.

        The image is a channel-first float32 NumPy array.

        Raises:
            OSError: if the image file is missing or cannot be decoded.
        """
        row = self.csv.iloc[index]

        # ---- text ----
        text = ShopeeDatasetJoint.preprocess_text(row.title)
        encoded_text = self.tokenizer(text)
        # The tokenizer output carries a leading batch axis of size one.
        input_ids = encoded_text['input_ids'][0]
        attention_mask = encoded_text['attention_mask'][0]

        # ---- image ----
        image = cv2.imread(row.filepath)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising;
            # fail here with the offending path rather than on the slice below.
            raise OSError(f'Could not read image (missing or unreadable): {row.filepath}')

        # Reverse the channel axis (OpenCV loads BGR).
        image = image[:, :, ::-1]

        if self.transform is not None:
            res = self.transform(image=image)
            image = res['image'].astype(np.float32)
        else:
            image = image.astype(np.float32)

        # HWC -> CHW
        image = image.transpose(2, 0, 1)

        if self.mode == 'test':
            return image, input_ids, attention_mask
        else:
            return image, input_ids, attention_mask, torch.tensor(row.label)

In [None]:
class DeitWrapperJoint(nn.Module):
    """Distilled DeiT backbone that returns a token sequence for the joint transformer."""

    def __init__(self):
        super().__init__()
        self.backbone = deit_base_distilled_patch16_224(pretrained=False)
        # Both classification heads become pass-throughs.
        self.backbone.head = nn.Identity()
        self.backbone.head_dist = nn.Identity()

    def forward_features(self, x):
        """Return the normalised hidden states ``[cls, dist, patches...]``.

        Adapted from https://github.com/facebookresearch/deit/blob/main/models.py
        """
        net = self.backbone
        batch_size = x.shape[0]

        patches = net.patch_embed(x)
        # The learned tokens are broadcast across the batch before being prepended.
        cls_tokens = net.cls_token.expand(batch_size, -1, -1)
        dist_tokens = net.dist_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls_tokens, dist_tokens, patches), dim=1)

        tokens = net.pos_drop(tokens + net.pos_embed)
        for block in net.blocks:
            tokens = block(tokens)

        return net.norm(tokens)

    def forward(self, x):
        """Return ``[mean(cls, dist), hidden_states[1:-1]]`` along the token axis."""
        states = self.forward_features(x)
        merged_cls = ((states[:, 0] + states[:, 1]) / 2).unsqueeze(1)

        # NOTE(review): the 1:-1 slice keeps the distillation token and drops
        # the last patch token. Presumably the checkpoint was trained with
        # exactly this layout, so it is preserved; confirm before changing.
        remainder = states[:, 1:-1, :]
        return torch.cat([merged_cls, remainder], dim=1)

In [None]:
def get_text_encoder(transformer_model):
    """Load a pretrained transformer and freeze all of its parameters.

    Args:
        transformer_model: Name or path passed to ``AutoModel.from_pretrained``.

    Returns:
        The loaded model with ``requires_grad`` disabled on every parameter.
    """
    frozen_encoder = AutoModel.from_pretrained(transformer_model)

    for param in frozen_encoder.parameters():
        param.requires_grad_(False)

    return frozen_encoder

In [None]:
class JointTransformer(nn.Module):
    """Fuses DeiT image tokens and text-transformer tokens with a small transformer encoder.

    The image and text token sequences are concatenated behind a shared
    (averaged) class token, passed through a 2-layer ``nn.TransformerEncoder``,
    and the class token is projected into the embedding space.

    Args:
        embed_dim: Width of the projected embedding.
        out_dim: Number of classes of the cosine classifier.
        transformer_model: Name or path of the (frozen) text encoder.
    """

    def __init__(self, embed_dim, out_dim, transformer_model):
        super().__init__()

        self.vision_encoder= DeitWrapperJoint()
        self.text_encoder= get_text_encoder(transformer_model)

        # The fusion encoder and the projection head are both built for
        # 768-wide token states.
        d_in=768
        encoder_layer = nn.TransformerEncoderLayer(d_model =768, nhead=8, dim_feedforward=2048, dropout=0.1, activation='gelu')
        encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)

        self.transformer_encoder = encoder

        self.projection = Projection(d_in, embed_dim)
        self.metric_classify = ArcMarginProduct_subcenter(embed_dim, out_dim)

    @staticmethod
    def merge_hidden_states(vision_output, text_output):
        """Concatenate both token sequences behind one averaged class token.

        Returns ``[mean(vision[0], text[0]), vision[1:], text[1:]]`` along the
        token axis (dim 1).
        """
        cls_img= vision_output[:,0,: ]
        cls_text= text_output[:,0,:]
        cls_joint= (cls_img+cls_text)/2
        cls_joint= cls_joint.unsqueeze(1)

        merge_embeddings = torch.cat([cls_joint , vision_output[:,1: ,:] ,  text_output[:,1: ,:]], dim=1)
        return merge_embeddings

    @staticmethod
    def get_pad_mask(image_embed_len , text_attention_mask):
        """Build the key-padding mask for the merged sequence.

        Args:
            image_embed_len: Number of leading positions that are never padded.
            text_attention_mask: ``(batch, text_len)`` mask where 1 marks a
                real token.

        Returns:
            Bool tensor ``(batch, image_embed_len + text_len)`` on the same
            device as ``text_attention_mask``; True marks positions to ignore.
        """
        _batch_size= text_attention_mask.shape[0]

        # Create the mask directly as bool on the input's device instead of
        # forcing CUDA and relying on float/int promotion inside torch.cat.
        image_pad_mask = torch.zeros(
            (_batch_size, image_embed_len),
            dtype=torch.bool,
            device=text_attention_mask.device,
        )
        # Same truth table as (1 - mask).bool(): padded wherever mask != 1.
        text_pad_mask = text_attention_mask != 1
        return torch.cat([image_pad_mask, text_pad_mask], dim=1)

    def forward_features(self, image, input_ids, attention_mask):
        """Return the fused hidden states, batch-first."""
        vision_output = self.vision_encoder(image)
        text_output = self.text_encoder(input_ids, attention_mask)[0]

        text_attention_mask = attention_mask

        # One vision position is consumed by the shared class token.
        image_embed_len = vision_output.shape[1]-1

        joint_embeddings=  JointTransformer.merge_hidden_states(vision_output, text_output)
        # nn.TransformerEncoder is used in its default sequence-first layout.
        joint_embeddings= joint_embeddings.permute(1,0,2)

        joint_pad_mask = JointTransformer.get_pad_mask(image_embed_len , text_attention_mask)

        out= self.transformer_encoder(joint_embeddings, src_key_padding_mask = joint_pad_mask)
        out= out.permute(1,0,2)

        return out

    def forward(self, image, input_ids, attention_mask):
        """Return ``(embedding, class_cosines)`` computed from the fused class token."""
        x= self.forward_features(image, input_ids, attention_mask)
        x= x[:,0,:] #cls token

        projected_vec = self.projection(x)
        cosine_feat= self.metric_classify(projected_vec)
        return projected_vec ,cosine_feat

In [None]:
def generate_joint_embeddings(df, config):
    """Embed every (image, title) pair in ``df`` with the joint transformer.

    Args:
        df: DataFrame with ``title`` and ``filepath`` columns.
        config: One entry of ``joint_models`` (keys ``transformer_model``,
            ``MAX_LEN``, ``model_path``, ``IMAGE_SIZE``, ``params``).

    Returns:
        CPU tensor with the L2-normalised embeddings of all rows, in row order.
    """
    print(config)

    backbone_name = config['transformer_model']
    max_len = config['MAX_LEN']
    checkpoint_path = config['model_path']
    side = config['IMAGE_SIZE']

    state_dict = torch.load(checkpoint_path, map_location='cpu')
    fusion_model = JointTransformer(**config['params'], transformer_model=backbone_name)
    fusion_model.load_state_dict(state_dict)
    fusion_model = fusion_model.cuda()

    # Evaluation pipeline only (resize + normalise).
    _, eval_transform = get_transforms(side)

    loader = DataLoader(
        ShopeeDatasetJoint(
            df,
            'test',
            tokenizer=Tokenizer(max_len, backbone_name),
            transform=eval_transform,
        ),
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
    )

    fusion_model.eval()
    chunks = []
    with torch.no_grad():
        for images, input_ids, attention_mask in tqdm(loader):
            embedding, _ = fusion_model(images.cuda(), input_ids.cuda(), attention_mask.cuda())
            chunks.append(embedding.detach().cpu())

    # Unlike the image/text helpers, this one normalises before returning.
    joint_embeddings = F.normalize(torch.cat(chunks, dim=0))

    print(f'joint_embeddings shape : {joint_embeddings.shape}')

    # Release the model before anything else is loaded.
    del state_dict
    del fusion_model
    gc.collect()
    torch.cuda.empty_cache()

    return joint_embeddings

# **Similarity Search**

In [None]:
def _neighbours_above_threshold(df, embeddings, threshold, chunk=1024*4, verbose=False):
    """Collect, for every row, the posting_ids whose dot-product similarity exceeds ``threshold``.

    The similarity matrix is computed ``chunk`` query rows at a time to bound
    GPU memory.
    """
    preds = []
    total = len(embeddings)

    for a in range(0, total, chunk):
        b = min(a + chunk, total)
        if verbose:
            print('chunk',a,'to',b)

        cts = cupy.matmul(embeddings,embeddings[a:b].T).T

        for k in range(b-a):
            IDX = cupy.where(cts[k,]>threshold)[0]
            preds.append(df.iloc[cupy.asnumpy(IDX)].posting_id.values)

    return preds


def get_neighbours_cos_sim(df,embeddings, tuned_threshold, mode='image'):
    '''
    Threshold-based neighbour search on dot-product similarity.

    When using cos_sim use normalized features else use normal features

    With GET_CV true, a grid of thresholds (chosen by `mode`) is scored against
    df['matches'] and the predictions of the BEST-scoring threshold are
    returned; `tuned_threshold` is not used. df['pred_matches'] and df['f1']
    are left holding the values for that best threshold.

    With GET_CV false, predictions are produced once with `tuned_threshold`.

    df: frame with a 'posting_id' column (and 'matches' when GET_CV is true)
    embeddings: one row per df row
    tuned_threshold: similarity threshold used when GET_CV is false
    mode: 'image', 'text', 'joint'

    Returns (df, preds) where preds is a list with one array of posting_ids
    per row. Raises ValueError for an unknown mode when GET_CV is true.
    '''

    print(f'Setting mode : {mode}')

    embeddings = cupy.array(embeddings)

    if not GET_CV:
        threshold = tuned_threshold
        print('Finding similar texts...for threshold :',threshold)
        preds = _neighbours_above_threshold(df, embeddings, threshold, verbose=True)
        return df, preds

    # (start, stop) of the threshold sweep for each mode; the step is 0.05.
    sweep_ranges = {
        'image': (0.3, 0.5),
        'text': (0.3, 0.5),
        'joint': (0.3, 0.6),
    }
    if mode not in sweep_ranges:
        raise ValueError(f"mode must be one of {sorted(sweep_ranges)}, got {mode!r}")
    start, stop = sweep_ranges[mode]
    thresholds = list(np.arange(start, stop, 0.05))

    scores = []
    best_score = float('-inf')
    best_threshold = None
    best_preds = None
    preds = []

    for threshold in thresholds:
        print('Finding similar titles...for threshold :',threshold)
        preds = _neighbours_above_threshold(df, embeddings, threshold)

        df['pred_matches'] = [' '.join(o) for o in preds]
        df['f1'] = f1_score(df['matches'], df['pred_matches'])
        score = df['f1'].mean()
        print(f'Our f1 score for threshold {threshold} is {score}')
        scores.append(score)

        # Strict '>' keeps the first threshold on ties.
        if score > best_score:
            best_score, best_threshold, best_preds = score, threshold, preds

    if best_preds is None:
        # No score was comparable (e.g. all NaN): fall back to the last sweep.
        best_preds, best_threshold, best_score = preds, thresholds[-1], scores[-1]

    print(f'Our best score is {best_score} and has a threshold {best_threshold}')

    # Leave the frame consistent with the predictions that are returned.
    df['pred_matches'] = [' '.join(o) for o in best_preds]
    df['f1'] = f1_score(df['matches'], df['pred_matches'])

    return df, best_preds

In [None]:
def combine_predictions(row):
    """Union the four prediction arrays of a row into one space-separated string.

    Args:
        row: Mapping with ``image_predictions``, ``text_predictions``,
            ``mean_predictions`` and ``joint_predictions`` arrays.

    Returns:
        The sorted unique ids joined by single spaces.
    """
    sources = ('image_predictions', 'text_predictions', 'mean_predictions', 'joint_predictions')
    pooled = np.concatenate([row[column] for column in sources])
    unique_ids = np.unique(pooled)
    return ' '.join(unique_ids)

In [None]:
class SimilarityFinder:
    """Neighbour search on a weighted blend of text, image and joint similarities.

    Args:
        text_embeddings: One row per product.
        image_embeddings: One row per product, same order.
        joint_embeddings: One row per product, same order.
        weights: ``(text, image, joint)`` weights of the blend; defaults to
            ``(0.25, 0.25, 0.5)``.
    """

    def __init__(self,text_embeddings, image_embeddings, joint_embeddings, weights=(0.25, 0.25, 0.5)):
        self.text_embeddings = cupy.array(text_embeddings)
        self.image_embeddings = cupy.array(image_embeddings)
        self.joint_embeddings = cupy.array(joint_embeddings)
        self.weights = tuple(weights)

    def get_cosine_similarity(self, idx_a, idx_b):
        """Blended similarity of rows ``idx_a:idx_b`` against every row.

        Plain dot products are used, so this is a cosine similarity only if
        the embeddings passed in were L2-normalised.
        """
        text_w, image_w, joint_w = self.weights

        query1= self.text_embeddings[idx_a : idx_b]
        query2= self.image_embeddings[idx_a : idx_b]
        query3= self.joint_embeddings[idx_a : idx_b]

        dot_similarity1 = cupy.matmul(query1, self.text_embeddings.T)
        dot_similarity2 = cupy.matmul(query2, self.image_embeddings.T)
        dot_similarity3 = cupy.matmul(query3, self.joint_embeddings.T)

        dot_similarity= text_w*dot_similarity1+ image_w*dot_similarity2 + joint_w*dot_similarity3

        return dot_similarity

    def generate_preds(self,df, threshold):
        """Return, per row, the array of posting_ids whose blended similarity exceeds ``threshold``."""
        preds= []

        # Query rows are processed in chunks to bound GPU memory.
        CHUNK = 1024*4
        total = len(self.text_embeddings)

        for a in range(0, total, CHUNK):
            b = min(a + CHUNK, total)

            sims = self.get_cosine_similarity(a, b)

            for idx in range(len(sims)):
                indices = cupy.where(sims[idx,]>threshold)[0]
                _preds= df.iloc[cupy.asnumpy(indices)].posting_id.values
                preds.append(_preds)

        return preds

    def threshold_tuner(self,df):
        """Sweep thresholds 0.30-0.45 and return the predictions of the best one.

        Each threshold is scored with the mean row-wise F1 against
        ``df['matches']``. ``df['pred_matches']`` and ``df['f1']`` are left
        holding the values for the best threshold.

        Returns:
            ``(df, preds)`` with the predictions of the best-scoring threshold.
        """
        thresholds = list(np.arange(0.3,0.5,0.05))
        scores= []

        best_score = float('-inf')
        best_threshold = None
        best_preds = None
        preds = []

        for threshold in thresholds:
            print('Finding joint sim threshold :',threshold)

            preds= self.generate_preds(df,threshold)

            df['pred_matches'] = [' '.join(x) for x in preds]
            df['f1'] = f1_score(df['matches'], df['pred_matches'])
            score = df['f1'].mean()
            scores.append(score)

            print(f'Our f1 score for threshold {threshold} is {score}')

            # Strict '>' keeps the first threshold on ties.
            if score > best_score:
                best_score, best_threshold, best_preds = score, threshold, preds

        if best_preds is None:
            # No score was comparable (e.g. all NaN): fall back to the last sweep.
            best_preds, best_threshold, best_score = preds, thresholds[-1], scores[-1]

        print(f'Our best score is {best_score} and has a threshold {best_threshold}')

        # Leave the frame consistent with the predictions that are returned.
        df['pred_matches'] = [' '.join(x) for x in best_preds]
        df['f1'] = f1_score(df['matches'], df['pred_matches'])

        return df,best_preds

    def get_preds(self, df, tuned_threshold):
        """Return ``(df, preds)``: tuned on the fly when GET_CV is true, else at ``tuned_threshold``."""
        if GET_CV:
            df,preds= self.threshold_tuner(df)
        else:
            preds= self.generate_preds(df,tuned_threshold)

        return df,preds

# **Driver**

In [None]:
# Load train.csv (GET_CV) or test.csv, with the image `filepath` column added.
df = read_dataset()

# **Image Predictions**

In [None]:
# One embedding matrix per vision model; each config entry is paired with its encoder class.
image_embeddings_swav = generate_image_embeddings(df, image_models[0], VisionEncoderResnet)
image_embeddings_effnet = generate_image_embeddings(df, image_models[1], VisionEncoderEffnet)
image_embeddings_deit = generate_image_embeddings(df, image_models[2], VisionEncoderDeit)

In [None]:
# Blend the three (un-normalised) image embeddings with near-equal weights ...
meta_image_embedding = 0.33*image_embeddings_deit+ 0.33*image_embeddings_swav+ 0.34*image_embeddings_effnet

# image_embeddings_deit = F.normalize(image_embeddings_deit)
# image_embeddings_swav = F.normalize(image_embeddings_swav)
# image_embeddings_effnet = F.normalize(image_embeddings_effnet)
# ... and L2-normalise the blend so dot products act as cosine similarities.
meta_image_embedding  = F.normalize(meta_image_embedding)

In [None]:
# Image-only neighbours; tuned_threshold is only applied when GET_CV is False.
df,image_predictions = get_neighbours_cos_sim(df,meta_image_embedding, tuned_threshold=0.75, mode='image')

# **Text Predictions**

In [None]:
# One embedding matrix per text model.
text_embeddings_xlm = generate_text_embeddings(df, text_models[0])
text_embeddings_bert = generate_text_embeddings(df, text_models[1])

# Average the two text embeddings, then L2-normalise the blend.
meta_text_embeddings = 0.5*text_embeddings_xlm + 0.5*text_embeddings_bert
meta_text_embeddings= F.normalize(meta_text_embeddings)


In [None]:
# Text-only neighbours; tuned_threshold is only applied when GET_CV is False.
df,text_predictions = get_neighbours_cos_sim(df,meta_text_embeddings, tuned_threshold=0.7, mode='text')

# **Joint Predictions**

In [None]:
# Joint image+text embeddings (returned already L2-normalised).
joint_embeddings = generate_joint_embeddings(df, joint_models[0])


In [None]:
# Joint-model neighbours; tuned_threshold is only applied when GET_CV is False.
df,joint_predictions = get_neighbours_cos_sim(df,joint_embeddings, tuned_threshold=0.7, mode='joint')

## **Merge Predictions**

In [None]:
# Searches on a weighted blend of the text, image and joint similarity matrices.
sim_finder= SimilarityFinder(meta_text_embeddings, meta_image_embedding, joint_embeddings)

In [None]:
# Blended-similarity neighbours; the 0.5 threshold is only applied when GET_CV is False.
df,mean_preds= sim_finder.get_preds(df,0.5)

In [None]:
# Quick visual check of the frame before the predictions are merged.
df.head()

In [None]:
# Attach the four prediction sets; both the CV and the submission path need them.
df['image_predictions'] = image_predictions
df['text_predictions'] = text_predictions
df['joint_predictions'] = joint_predictions
df['mean_predictions'] = mean_preds

# Union of all four prediction sets per row, as a space-separated string.
combined_matches = df.apply(combine_predictions, axis = 1)

if GET_CV:
    # Score the merged predictions against the ground truth before it is overwritten.
    df['pred_matches'] = combined_matches
    df['f1'] = f1_score(df['matches'], df['pred_matches'])
    score = df['f1'].mean()
    print(f'Our final f1 cv score is {score}')
    df['matches'] = df['pred_matches']
else:
    df['matches'] = combined_matches

df[['posting_id', 'matches']].to_csv('submission.csv', index = False)

## **DEBUG**