# Inference Notebook

In [None]:
import sys
# Inside a Kaggle kernel (detected through the 'kaggle_web_client' module) the
# pytorch-image-models and transformers sources are read from ../input, so
# their directories are appended to the import path.
if 'kaggle_web_client' in sys.modules:
    sys.path.extend([
        '../input/imports/pytorch-image-models-master/pytorch-image-models-master',
        '../input/imports/transformers-master/transformers-master',
    ])

In [None]:
import os
import re
import cv2
import math
import random
import numpy as np
import pandas as pd
import gc
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
import timm
import albumentations
from albumentations.pytorch.transforms import ToTensorV2

from transformers import AutoTokenizer, AutoModel

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.neighbors import NearestNeighbors

## Config

In [None]:
class CFG:
    """Notebook-wide settings, read everywhere as class attributes.

    The class body executes once when this cell runs, so `print(device)` and the
    Kaggle / compute_cv branches below are evaluated at definition time.
    read_dataset() may later overwrite `compute_cv` and `images_dir`.
    """
    # True: train.csv is used in place of the test set so a CV score can be computed.
    # read_dataset() forces this to False when test.csv has more than 3 rows.
    compute_cv = True  # set False to skip the CV pass (faster notebook save)
    # Names of the df columns whose per-row prediction arrays are merged for submission / CV.
    todo_predictions = ['predictions']
    
    ### Shared by the CNN and BERT models
    use_amp = True  # run inference inside torch.cuda.amp.autocast()
    scale = 30  # ArcFace logit scale (default for both model classes)
    margin = 0.5  # ArcFace angular margin (default for both model classes)
    seed = 2021
    classes = 11014  # default n_classes of the ArcFace heads
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(device)
    
    ### CNN 1
    cnn_model_name = 'swin_base_patch4_window12_384'
    img_size = 384
    cnn_batch_size = 32
    cnn_model_path = '../input/shopee-arcface-models/swin_base_patch4_window12_384-size384-epoch11-bs12x2-cv0.8099-sub.pt'
    num_tta = 3  # passes over the images; pass 0 uses the resize-only transform, later passes a random crop
    cnn_fc_dim = 768
    cnn_use_fc = True
    
    ### BERT 1
    if 'kaggle_web_client' in sys.modules:
        bert_model_name = '../input/bertmodel/paraphrase-xlm-r-multilingual-v1'  # local copy used inside a Kaggle notebook
    else:
        bert_model_name = 'sentence-transformers/paraphrase-xlm-r-multilingual-v1'
    bert_model_path = '../input/shopee-arcface-models/paraphrase-xlm-r-multilingual-v1_len128_epoch7-bs16x1-cv0.7997-sub.pt'
    max_length = 128  # tokenizer truncation length read by ShopeeBertModel.extract_features
    bert_batch_size = 32
    bert_fc_dim = 768
    bert_use_fc = True
    
    ### BERT 2
    if 'kaggle_web_client' in sys.modules:
        bert_model_name2 = '../input/bertmodel/distilbert-base-indonesian'  # local copy used inside a Kaggle notebook
    else:
        bert_model_name2 = 'cahya/distilbert-base-indonesian'
    bert_model_path2 = '../input/shopee-arcface-models/distilbert-base-indonesian_len128_epoch8-bs16x1-cv0.7911-sub.pt'
    # NOTE(review): max_length2 is not read anywhere in this notebook; ShopeeBertModel always
    # uses CFG.max_length. Harmless while both are 128 - confirm if they ever diverge.
    max_length2 = 128
    bert_batch_size2 = 32
    bert_fc_dim2 = 768
    bert_use_fc2 = True
    
    ### Prediction
    # get_predictions() scores a pair as (cnn_cos / cnn_threshold)**6 + (bert_cos / bert_threshold)**6
    # and keeps pairs whose score exceeds 1.
    cnn_threshold = 0.84
    bert_threshold = 0.84
    chunk = 32  # rows per similarity block in get_predictions()
    max_preds = 42  # upper bound on matches kept per row
    # When a row matches only one item, the two highest-scoring items are returned instead.
    nearlest_one = True # True is better
        
    ### Data
    
    train_csv_path = '../input/shopee-product-matching/train.csv'
    test_csv_path = '../input/shopee-product-matching/test.csv'
    
    # Image folder follows compute_cv as set above; read_dataset() switches it to
    # test_images when it detects a full test set.
    if compute_cv == True:
        images_dir = '../input/shopee-product-matching/train_images/'
    else:
        images_dir = '../input/shopee-product-matching/test_images/'

    # DataLoader worker processes.
    if 'kaggle_web_client' in sys.modules:
        num_workers = 4
    else:
        num_workers = 0  # for Windows 10

In [None]:
def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) RNGs and set the cuDNN flags.

    Args:
        seed: Integer seed applied to every generator; also exported as the
            PYTHONHASHSEED environment variable.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    for apply_seed in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    # NOTE(review): benchmark=True lets cuDNN autotune its kernels for speed, so runs may
    # not be bit-identical even with deterministic=True - confirm this trade-off is intended.
    cudnn.benchmark = True # set True to be faster

# Seed every RNG once, before any model or data loader is created.
seed_everything(CFG.seed)

## Utils

In [None]:
def read_dataset():
    """Load the frame to predict on and build the matching image paths.

    Side effects: when test.csv has more than 3 rows, CFG.compute_cv is set to
    False and CFG.images_dir is pointed at the test image folder.
    (Presumably a 3-row test.csv is the public sample file - confirm.)

    Returns:
        (df, image_paths): train.csv when CFG.compute_cv is still True, otherwise
        test.csv; image_paths is CFG.images_dir prepended to df['image'].
    """
    test_df = pd.read_csv(CFG.test_csv_path)

    full_test_set = len(test_df) > 3
    if full_test_set:
        CFG.compute_cv = False
        CFG.images_dir = '../input/shopee-product-matching/test_images/'

    if not CFG.compute_cv:
        frame = test_df
        print('Test shape is', frame.shape )
    else:
        frame = pd.read_csv(CFG.train_csv_path)
        print('Using train as test to compute CV. Shape is', frame.shape)

    return frame, CFG.images_dir + frame['image']

In [None]:
def f1_score(y_true, y_pred):
    """Row-wise F1 between two Series of whitespace-separated id strings.

    Args:
        y_true: pandas Series of ground-truth strings.
        y_pred: pandas Series of predicted strings, aligned with y_true.

    Returns:
        NumPy array holding 2 * |true & pred| / (|true| + |pred|) for each row.
    """
    true_sets = y_true.apply(lambda text: set(text.split()))
    pred_sets = y_pred.apply(lambda text: set(text.split()))

    overlap = np.array([len(truth & guess) for truth, guess in zip(true_sets, pred_sets)])
    n_pred = pred_sets.apply(len).values
    n_true = true_sets.apply(len).values

    return 2 * overlap / (n_pred + n_true)

## ArcFace

In [None]:
class ArcMarginProduct(nn.Module):
    """ArcFace head: cosine logits with an additive angular margin on the target class.

    Args:
        in_features: Size of each input embedding.
        out_features: Number of classes.
        scale: Factor the margin-adjusted cosines are multiplied by.
        margin: Angular margin (radians) added to the target-class angle.
        easy_margin: If True, the margin is applied only where cos(theta) > 0.
        ls_eps: Label-smoothing epsilon applied to the one-hot targets (0 disables).
    """

    def __init__(self, in_features, out_features, scale=30.0, margin=0.50, easy_margin=False, ls_eps=0.0):
        super(ArcMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.scale = scale
        self.margin = margin
        self.ls_eps = ls_eps  # label smoothing
        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        # Constants for cos(theta + m) = cos(theta) * cos(m) - sin(theta) * sin(m).
        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        # Beyond theta = pi - m the margin is replaced by the linear penalty `cosine - mm`.
        self.th = math.cos(math.pi - margin)
        self.mm = math.sin(math.pi - margin) * margin

        self.criterion = nn.CrossEntropyLoss()

    def forward(self, input, label):
        """Return `(scaled_logits, cross_entropy_loss)` for a batch.

        Args:
            input: Embeddings of shape (batch, in_features).
            label: Integer class ids, one per row of `input`; moved to the logits'
                device and cast to int64 here.
        """
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        # Round-off can push cosine**2 slightly above 1; clamp so sqrt never yields NaN.
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(min=0.0))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # --------------------------- convert label to one-hot ---------------------------
        # Follow the logits' device instead of a global setting so the head works wherever
        # the module and its inputs live.
        target = label.to(device=cosine.device, dtype=torch.long)
        one_hot = torch.zeros(cosine.size(), device=cosine.device)
        one_hot.scatter_(1, target.view(-1, 1), 1)
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.out_features

        # Margin-adjusted cosine for the target class, plain cosine for every other class.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.scale
        return output, self.criterion(output, target)

## BERT Model

In [None]:
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the sequence axis, counting only unmasked tokens.

    Args:
        model_output: Indexable whose first element holds the token embeddings.
        attention_mask: Mask with one entry per token; it is broadcast over the
            embedding axis and cast to float before weighting.

    Returns:
        Masked sum of the embeddings divided by the mask total, which is clamped
        to at least 1e-9 so fully masked rows do not divide by zero.
    """
    hidden_states = model_output[0]
    weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()

    masked_total = (hidden_states * weights).sum(dim=1)
    token_count = weights.sum(dim=1).clamp(min=1e-9)

    return masked_total / token_count


class ShopeeBertModel(nn.Module):
    """Text encoder: transformer backbone, mean pooling, optional FC+BN neck, ArcFace head.

    In training mode `forward` returns the ArcFace head's output; in eval mode it
    returns the embeddings produced by `extract_features`.
    """

    def __init__(
        self,
        n_classes = CFG.classes,
        model_name = None,
        fc_dim = 768,
        margin = CFG.margin,
        scale = CFG.scale,
        use_fc = True
    ):
        """Load tokenizer and backbone for `model_name` and build the neck and head."""
        super().__init__()
        print('Building Model Backbone for {} model'.format(model_name))

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.backbone = AutoModel.from_pretrained(model_name).to(CFG.device)
        self.use_fc = use_fc

        # The neck is always built (even when use_fc is False) so the module's
        # parameter names do not depend on use_fc.
        # assumes the backbone emits 768-d token embeddings - TODO confirm per model
        backbone_dim = 768
        self.dropout = nn.Dropout(p=0.1)
        self.classifier = nn.Linear(backbone_dim, fc_dim)
        self.bn = nn.BatchNorm1d(fc_dim)
        self._init_params()

        self.final = ArcMarginProduct(
            fc_dim,
            n_classes,
            scale = scale,
            margin = margin,
            easy_margin = False,
            ls_eps = 0.0
        )

    def _init_params(self):
        """Xavier-normal FC weight, zero FC bias, identity BatchNorm affine."""
        nn.init.xavier_normal_(self.classifier.weight)
        for tensor, value in (
            (self.classifier.bias, 0),
            (self.bn.weight, 1),
            (self.bn.bias, 0),
        ):
            nn.init.constant_(tensor, value)

    def forward(self, texts, labels=torch.tensor([0])):
        """Embed `texts`; in training mode also run the ArcFace head on `labels`."""
        embeddings = self.extract_features(texts)
        if not self.training:
            return embeddings
        return self.final(embeddings, labels.to(CFG.device))

    def extract_features(self, texts):
        """Tokenize `texts` (padded, truncated to CFG.max_length) and return pooled embeddings."""
        tokens = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=CFG.max_length,
            return_tensors='pt',
        ).to(CFG.device)
        mask = tokens['attention_mask']
        backbone_out = self.backbone(tokens['input_ids'], attention_mask=mask)
        pooled = mean_pooling(backbone_out, mask)

        if not self.use_fc:
            return pooled

        return self.bn(self.classifier(self.dropout(pooled)))

In [None]:
def get_bert_embeddings(column, model_name, model_path, fc_dim=768, use_fc=True, chunk=32):
    """Embed one text column of the notebook-global `df` with a ShopeeBertModel checkpoint.

    Args:
        column: Name of the `df` column holding the titles.
        model_name: Hugging Face name or local path of the backbone/tokenizer.
        model_path: Path of the state dict loaded into the model.
        fc_dim: Output size of the FC neck.
        use_fc: Whether the model applies its FC+BN neck.
        chunk: Number of titles embedded per forward pass.

    Returns:
        Tensor on CFG.device with one row per `df` row. Its width is `fc_dim` when
        `use_fc` is True, otherwise 768 (the pooled backbone width the model assumes).

    Note:
        Reads the global `df`; it must be defined before this function is called.
    """
    print('Getting BERT ArcFace embeddings...')

    model = ShopeeBertModel(model_name=model_name, fc_dim=fc_dim, use_fc=use_fc)
    model.to(CFG.device)
    model.load_state_dict(torch.load(model_path, map_location=CFG.device))
    model.eval()

    n_rows = df.shape[0]
    # Size the buffer from what the model actually returns; a fixed 768 only
    # worked when fc_dim happened to be 768.
    embedding_dim = fc_dim if use_fc else 768
    bert_embeddings = torch.zeros((n_rows, embedding_dim)).to(CFG.device)

    # Pre-bind so the cleanup below is valid even when no batch is processed.
    titles = model_output = None

    # The extra trailing start re-embeds the last `chunk` rows as one batch and overwrites
    # them (presumably to avoid keeping results from a short final batch - confirm). A
    # negative start on small frames slices from row 0.
    starts = list(range(0, n_rows, chunk)) + [n_rows - chunk] if n_rows > 0 else []
    for i in tqdm(starts, ncols=100):
        titles = []
        for title in df[column][i : i + chunk].values:
            try:
                # Resolve literal escape sequences and drop non-ASCII characters.
                title = ' ' + title.encode('utf-8').decode("unicode_escape").encode('ascii', 'ignore').decode("unicode_escape") + ' '
            except Exception:
                # Best effort: keep the raw title when it cannot be decoded this way.
                pass
            title = title.lower()

            titles.append(title)

        with torch.no_grad():
            if CFG.use_amp:
                with torch.cuda.amp.autocast():
                    model_output = model(titles)
            else:
                model_output = model(titles)

        bert_embeddings[i : i + chunk] = model_output

    # Drop references before clearing the CUDA cache.
    del model, titles, model_output
    gc.collect()
    torch.cuda.empty_cache()

    return bert_embeddings

## Image Model

In [None]:
class ShopeeCnnModel(nn.Module):
    """Image encoder: timm backbone, optional FC+BN neck, ArcFace head.

    In training mode `forward` returns the ArcFace head's output; in eval mode it
    returns the embeddings produced by `extract_features`.

    Args:
        model_name: timm model name; must belong to one of the families listed below.
        fc_dim: Output size of the FC neck.
        n_classes: Number of ArcFace classes.
        margin: ArcFace angular margin.
        scale: ArcFace logit scale.
        use_fc: Whether `extract_features` applies dropout, FC and BatchNorm.
        pretrained: Forwarded to `timm.create_model`.

    Raises:
        ValueError: If `model_name` is not in a supported family.
    """

    # Backbone families, grouped by how the classification head is removed.
    _EFFICIENTNET_NAMES = (
        'tf_efficientnet_b0_ns', 'tf_efficientnet_b1_ns', 'tf_efficientnet_b2_ns', 'tf_efficientnet_b3_ns',
        'tf_efficientnet_b4_ns', 'tf_efficientnet_b5_ns', 'tf_efficientnet_b6_ns', 'tf_efficientnet_b7_ns',
        'efficientnet_v2s', 'efficientnet_v2m', 'efficientnet_v2l',
    )
    _NFNET_NAMES = (
        'dm_nfnet_f0', 'dm_nfnet_f1', 'dm_nfnet_f2', 'dm_nfnet_f3', 'dm_nfnet_f4',
        'eca_nfnet_l0', 'eca_nfnet_l1',
    )
    _SWIN_NAMES = (
        'swin_small_patch4_window7_224', 'swin_base_patch4_window7_224', 'swin_base_patch4_window12_384',
    )

    def __init__(
        self,
        model_name,
        fc_dim,
        n_classes = CFG.classes,
        margin = CFG.margin,
        scale = CFG.scale,
        use_fc = True,
        pretrained = True):

        super(ShopeeCnnModel,self).__init__()
        print('Building Model Backbone for {} model'.format(model_name))

        self.backbone = timm.create_model(model_name, pretrained=pretrained)

        if model_name in self._EFFICIENTNET_NAMES:
            in_features = self.backbone.classifier.in_features
            self.backbone.classifier = nn.Identity()
            self.backbone.global_pool = nn.Identity()

        elif model_name in self._NFNET_NAMES:
            in_features = self.backbone.head.fc.in_features
            self.backbone.head.fc = nn.Identity()
            self.backbone.head.global_pool = nn.Identity()

        elif model_name in self._SWIN_NAMES:
            in_features = self.backbone.head.in_features
            self.backbone.head = nn.Identity()

        else:
            # Without this branch `in_features` would be unbound further down.
            raise ValueError('Unsupported model_name for ShopeeCnnModel: {!r}'.format(model_name))

        self.model_name = model_name
        self.pooling =  nn.AdaptiveAvgPool2d(1)
        self.use_fc = use_fc

        # The neck is always built (even when use_fc is False) so the module's
        # parameter names do not depend on use_fc.
        self.dropout = nn.Dropout(p=0.0)
        self.classifier = nn.Linear(in_features, fc_dim)
        self.bn = nn.BatchNorm1d(fc_dim)
        self._init_params()
        in_features = fc_dim

        self.final = ArcMarginProduct(
            in_features,
            n_classes,
            scale = scale,
            margin = margin,
            easy_margin = False,
            ls_eps = 0.0
        )

    def _init_params(self):
        """Xavier-normal FC weight, zero FC bias, identity BatchNorm affine."""
        nn.init.xavier_normal_(self.classifier.weight)
        nn.init.constant_(self.classifier.bias, 0)
        nn.init.constant_(self.bn.weight, 1)
        nn.init.constant_(self.bn.bias, 0)

    def forward(self, image, label=None):
        """Embed `image`; in training mode also run the ArcFace head on `label`.

        `label` is ignored in eval mode, so inference callers may omit it.
        """
        features = self.extract_features(image)
        if self.training:
            if label is None:
                raise ValueError('label is required when the model is in training mode')
            logits = self.final(features, label)
            return logits
        else:
            return features

    def extract_features(self, x):
        """Run the backbone (pooling its output for non-Swin models) and the optional neck."""
        batch_size = x.shape[0]
        x = self.backbone(x)
        # Decide from this instance's own backbone name rather than a global setting, so a
        # model built with a different name still takes the right path. Swin output is
        # used as-is; the other families have their pooling stripped in __init__ and are
        # pooled here.
        if self.model_name not in self._SWIN_NAMES:
            x = self.pooling(x).view(batch_size, -1)

        if self.use_fc:
            x = self.dropout(x)
            x = self.classifier(x)
            x = self.bn(x)

        return x

In [None]:
def get_valid_transforms(img_size=512):
    """Build the non-random pipeline: square resize, Normalize, then ToTensorV2.

    Args:
        img_size: Side length in pixels the image is resized to.
    """
    resize = albumentations.Resize(img_size, img_size, p=1.)
    normalize = albumentations.Normalize(
        mean = [0.485, 0.456, 0.406],
        std = [0.229, 0.224, 0.225],
        max_pixel_value=255.0,
        p=1.0,
    )
    to_tensor = ToTensorV2(p=1.0)

    return albumentations.Compose([resize, normalize, to_tensor])

def get_test_transforms(img_size=512):
    """Build the random pipeline: RandomResizedCrop, Normalize, then ToTensorV2.

    Args:
        img_size: Side length in pixels of the crop output. The crop is requested
            with scale=(0.6, 1.0) and a fixed 1:1 ratio.
    """
    random_crop = albumentations.RandomResizedCrop(
        img_size, img_size, scale=(0.6, 1.0), ratio=(1.0, 1.0)
    )
    normalize = albumentations.Normalize(
        mean = [0.485, 0.456, 0.406],
        std = [0.229, 0.224, 0.225],
        max_pixel_value=255.0,
        p=1.0,
    )
    to_tensor = ToTensorV2(p=1.0)

    return albumentations.Compose([random_crop, normalize, to_tensor])

In [None]:
class ShopeeTestImageDataset(Dataset):
    """Dataset yielding one (optionally transformed) RGB image per path.

    Args:
        image_paths: Indexable collection of file paths exposing `.shape[0]`
            (for example a pandas Series or a NumPy array).
        transforms: Optional callable invoked as `transforms(image=img)` that
            returns a mapping with an 'image' entry.
    """

    def __init__(self, image_paths, transforms=None):
        self.image_paths = image_paths
        self.transform = transforms

    def __len__(self):
        return self.image_paths.shape[0]

    def __getitem__(self, index):
        """Read the image at `index`, convert BGR to RGB and apply the transform.

        Raises:
            FileNotFoundError: If OpenCV cannot read the file.
        """
        image_path = self.image_paths[index]
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals a missing or undecodable file by returning None; fail here
            # with the offending path rather than with an opaque cvtColor error.
            raise FileNotFoundError('Could not read image: {}'.format(image_path))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']

        return image

In [None]:
def get_cnn_embeddings(model, dataloader):
    """Run `model` in eval mode over `dataloader` and stack the outputs.

    Args:
        model: Module called as `model(images, label)`; a dummy label is passed.
        dataloader: Iterable of image batches.

    Returns:
        float32 NumPy array with the per-batch outputs concatenated along axis 0.
    """
    model.eval()

    batch_outputs = []
    progress = tqdm(enumerate(dataloader), total=len(dataloader), desc="get_cnn_embeddings", ncols=80)
    for _, image in progress:
        batch = image.to(CFG.device)

        with torch.no_grad():
            if not CFG.use_amp:
                features = model(batch, torch.tensor([1]))
            else:
                with torch.cuda.amp.autocast():
                    features = model(batch, torch.tensor([1]))

        batch_outputs.append(features.detach().cpu().numpy().astype('float32'))

    # Only removes this function's local name; the caller's reference keeps the model alive.
    del model
    stacked = np.concatenate(batch_outputs)
    del batch_outputs
    gc.collect()
    return stacked

## Prediction Function

In [None]:
def get_predictions(df, cnn_embeddings_mean_half, bert_embeddings_half, cnn_threshold=1.0, bert_threshold=1.0, chunk=32, nearlest_one=True, max_preds=50):
    """Find, for every row of `df`, the posting_ids whose combined similarity exceeds 1.

    The score of a pair is (cnn_sim / cnn_threshold) ** 6 + (bert_sim / bert_threshold) ** 6,
    where each sim is the dot product of the corresponding embedding rows (cosine
    similarity when the caller passes L2-normalised embeddings).

    Args:
        df: Frame with a `posting_id` column, row-aligned with both embedding tensors.
        cnn_embeddings_mean_half: Image embeddings, one row per `df` row.
        bert_embeddings_half: Text embeddings, one row per `df` row.
        cnn_threshold: Divisor applied to the image similarity.
        bert_threshold: Divisor applied to the text similarity.
        chunk: Number of query rows scored per matrix multiplication.
        nearlest_one: If True and exactly one item passes, return the two
            highest-scoring items instead.
        max_preds: Maximum number of matches kept per row.

    Returns:
        List with one array of posting_ids per row, best score first.
        An empty `df` yields an empty list.
    """
    # NOTE(review): the even power makes strongly negative similarities score like
    # positive ones. Left unchanged because the thresholds were tuned with this formula.
    print('Finding similar ones...')
    n_rows = len(df)
    # Looked up once; indexing this array is equivalent to df.iloc[IDX].posting_id.values.
    posting_ids = df.posting_id.values
    # Pre-bind so the cleanup below is valid even when the loop body never runs.
    cnn_cts = bert_cts = None

    preds = []
    for a in tqdm(range(0, n_rows, chunk)):
        b = min(a + chunk, n_rows)
        cnn_cts = torch.matmul(cnn_embeddings_mean_half, cnn_embeddings_mean_half[a:b].T).T
        bert_cts = torch.matmul(bert_embeddings_half, bert_embeddings_half[a:b].T).T

        for k in range(b - a):
            sim = (cnn_cts[k,] / cnn_threshold) ** 6 + (bert_cts[k,] / bert_threshold) ** 6
            sim_desc = torch.sort(sim, descending=True)

            IDX = sim_desc[1][sim_desc[0] > 1][:max_preds].cpu().detach().numpy()

            if (len(IDX) == 1) and nearlest_one:
                # A lone match gets the two best-scoring items instead.
                IDX = sim_desc[1][:2].cpu().detach().numpy()

            preds.append(posting_ids[IDX])

    # Drop the last similarity blocks before clearing the CUDA cache.
    del cnn_cts, bert_cts
    gc.collect()
    torch.cuda.empty_cache()

    return preds

# Calculating Predictions

In [None]:
# Load the frame to predict on (train.csv in CV mode, otherwise test.csv) together with
# the image paths built from its 'image' column.
df, image_paths = read_dataset()

### BERT 1 Embeddings

In [None]:
# Embed the 'title' column with the first text model (CFG.bert_model_name / CFG.bert_model_path).
# get_bert_embeddings reads the global `df`.
bert_embeddings = get_bert_embeddings(column='title', model_name=CFG.bert_model_name, model_path=CFG.bert_model_path,
                                      fc_dim=CFG.bert_fc_dim, use_fc=CFG.bert_use_fc, chunk=CFG.bert_batch_size)
print('bert_embeddings.shape:', bert_embeddings.shape)

### BERT 2 Embeddings

In [None]:
# Embed the same 'title' column with the second text model (CFG.bert_model_name2 / CFG.bert_model_path2).
bert_embeddings2 = get_bert_embeddings(column='title', model_name=CFG.bert_model_name2, model_path=CFG.bert_model_path2,
                                       fc_dim=CFG.bert_fc_dim2, use_fc=CFG.bert_use_fc2, chunk=CFG.bert_batch_size2)
print('bert_embeddings2.shape:', bert_embeddings2.shape)

### Image 1 Embeddings

In [None]:
# Build the image model without downloading pretrained weights, then load the checkpoint.
model = ShopeeCnnModel(model_name = CFG.cnn_model_name, fc_dim=CFG.cnn_fc_dim, use_fc=CFG.cnn_use_fc, pretrained=False)
model.to(CFG.device)
model.load_state_dict(torch.load(CFG.cnn_model_path, map_location=CFG.device))

cnn_embeddings_all = []

# Test-time augmentation: pass 0 uses the resize-only transform, every later pass a
# random resized crop; the per-pass embeddings are averaged below.
for tta in range(CFG.num_tta):
    if tta == 0:
        test_dataset = ShopeeTestImageDataset(image_paths=image_paths, transforms=get_valid_transforms(img_size=CFG.img_size))
    else:
        test_dataset = ShopeeTestImageDataset(image_paths=image_paths, transforms=get_test_transforms(img_size=CFG.img_size))

    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=CFG.cnn_batch_size, num_workers=CFG.num_workers,
                                                  pin_memory=True, shuffle=False, drop_last=False)
    cnn_embeddings = get_cnn_embeddings(model, test_dataloader)

    cnn_embeddings_all.append(cnn_embeddings)

del cnn_embeddings

cnn_embeddings_mean = np.mean(cnn_embeddings_all, axis=0)
print('cnn_embeddings_mean.shape:', cnn_embeddings_mean.shape)

# The model is not used after this cell. Its notebook-level reference must be dropped
# here (the `del` inside get_cnn_embeddings only removes a local name), otherwise
# empty_cache() cannot release the weights.
del cnn_embeddings_all, model
gc.collect()
torch.cuda.empty_cache()

## Prediction

In [None]:
# Move the TTA-averaged image embeddings to the device as float16.
cnn_embeddings_mean_half = torch.tensor(cnn_embeddings_mean, dtype=torch.float16).to(CFG.device)

# Element-wise average of the two text models' embeddings, in float16.
bert_embeddings_half = (bert_embeddings.to(torch.float16) + bert_embeddings2.to(torch.float16)) / 2

# Both embedding sets are passed through F.normalize so that the dot products taken in
# get_predictions are cosine similarities.
predictions = get_predictions(df,
                              F.normalize(cnn_embeddings_mean_half),
                              F.normalize(bert_embeddings_half),
                              cnn_threshold=CFG.cnn_threshold,
                              bert_threshold=CFG.bert_threshold,
                              chunk=CFG.chunk,
                              max_preds=CFG.max_preds,
                              nearlest_one=CFG.nearlest_one)

# One array of matched posting_ids per row; 'predictions' is the column listed in CFG.todo_predictions.
df['predictions'] = predictions

# Submission

In [None]:
def combine_predictions(row):
    """Merge a row's prediction columns into one space-separated string.

    The arrays stored under every column named in CFG.todo_predictions are
    concatenated, reduced to their sorted unique values and joined with spaces.
    """
    merged = np.concatenate([row[name] for name in CFG.todo_predictions])
    unique_ids = np.unique(merged)
    return ' '.join( unique_ids )

In [None]:
# Build the space-separated 'matches' string for every row, write the two-column
# submission file, then read it back for display in the next cell.
df['matches'] = df.apply(combine_predictions, axis=1)
df[['posting_id', 'matches']].to_csv('submission.csv', index=False)
submission_df = pd.read_csv('submission.csv')

In [None]:
# Display the submission exactly as it was read back from submission.csv.
submission_df

## Compute CV

In [None]:
def combine_for_cv(row):
    """Return the sorted unique ids found across a row's prediction columns.

    Same merge as the submission helper, but the result stays a NumPy array so
    it can be intersected with the target ids.
    """
    per_column = [row[name] for name in CFG.todo_predictions]
    merged = np.concatenate(per_column)
    return np.unique(merged)

def getMetric(col):
    """Return a row-wise scorer comparing `row[col]` with `row.target`.

    The returned function computes 2 * |target & pred| / (|target| + |pred|)
    and is meant to be passed to `DataFrame.apply(..., axis=1)`.
    """
    def f1score(row):
        actual = row.target
        predicted = row[col]
        hits = np.intersect1d(actual, predicted)
        return 2 * len(hits) / (len(actual) + len(predicted))
    return f1score

In [None]:
def histplot(preds, num_correct, num_target=None):
    """Overlay log-scale histograms of predicted, correctly predicted and target counts.

    Args:
        preds: Number of predicted matches per row.
        num_correct: Number of correct predicted matches per row.
        num_target: Number of true matches per row. When omitted, the
            notebook-global `num_target` is used, so existing two-argument calls
            keep working.

    Raises:
        NameError: If `num_target` is omitted and no such global exists.
    """
    if num_target is None:
        # Backward-compatible fallback to the notebook-level variable.
        if 'num_target' not in globals():
            raise NameError("histplot needs `num_target`: pass it explicitly or define it globally")
        num_target = globals()['num_target']

    plt.figure(figsize=(20, 4))
    # Only (lower, upper) limits are passed: a third positional argument would be taken as
    # `emit`, which newer Matplotlib releases accept by keyword only.
    plt.xlim(0, 60)
    plt.ylim(1, 4e4)
    plt.hist(preds, label='Predict (all)', bins=60, range=(0, 60), alpha=0.6, log=True, align='left', color='red', rwidth=0.3,)
    plt.hist(num_correct, label='Predict (correct)', bins=60, range=(0, 60), alpha=0.3, log=True, align='left', color='blue', rwidth=0.6)
    plt.hist(num_target, label='Target', bins=60, range=(0, 60), alpha=0.3, histtype='stepfilled', log=True, align='left', color='gray')
    plt.legend()
    plt.show()

In [None]:
if CFG.compute_cv:
    # Ground truth for a row: every posting_id that shares its label_group.
    tmp = df.groupby('label_group').posting_id.agg('unique').to_dict()
    df['target'] = df.label_group.map(tmp)
    # 'oof' holds the merged unique predictions across all CFG.todo_predictions columns.
    df['oof'] = df.apply(combine_for_cv, axis=1)
    df['f1'] = df.apply(getMetric('oof'), axis=1)
    print('CV Score =', df.f1.mean())
    
    # Mean row-wise F1 of each prediction column on its own.
    for todo in CFG.todo_predictions:
        print(f"{todo} :", round(df.apply(getMetric(todo), axis=1).mean(), 4))

In [None]:
if CFG.compute_cv:
    # Per-row counts of correct predictions, all predictions and true matches.
    df['correct_oof'] = df.apply(lambda row: np.intersect1d(row['oof'], row['target']), axis=1)
    num_correct_oof = df['correct_oof'].apply(lambda x: len(x))
    num_oof_preds = df['oof'].apply(lambda x: len(x))
    # histplot reads `num_target` as a global, so it must be defined before the call.
    num_target = df['target'].apply(lambda x: len(x))
    histplot(num_oof_preds, num_correct_oof)

In [None]:
if CFG.compute_cv:
    # A bare expression nested inside `if` is not echoed by the notebook (only a cell's
    # last top-level expression is), so the counts are printed explicitly.
    print(num_oof_preds.value_counts())

End