In [86]:
import os
import gc
import cv2
import math
import copy
import time
import random

# For data manipulation
import numpy as np
import pandas as pd

# Pytorch Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda import amp

# Utils
import joblib
from tqdm import tqdm
from collections import defaultdict

# Sklearn Imports
from sklearn.preprocessing import LabelEncoder, normalize
from sklearn.model_selection import StratifiedKFold

# For Image Models
import timm

# For Similarity Search
import faiss

# Albumentations for augmentations
import albumentations as A
from albumentations.pytorch import ToTensorV2

# For colored terminal text
from colorama import Fore, Back, Style
b_ = Fore.BLUE
y_ = Fore.YELLOW
sr_ = Style.RESET_ALL

import warnings
warnings.filterwarnings("ignore")

# For descriptive error messages
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"    

# Configurations

In [87]:
CONFIG = {
    "seed": 3698,
    "img_size": 448,
    "model_name": "tf_efficientnet_b0_ns",
    "num_classes": 15587,
    "embedding_size": 512,
    "train_batch_size": 64,
    "valid_batch_size": 64,
    "n_fold": 5,
    "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    # ArcFace Hyperparameters
    "s": 30.0, 
    "m": 0.30,
    "ls_eps": 0.0,
    "easy_margin": False
}

In [88]:
model_checkpoint_dir = '../runs/effnet_arcface_gem/checkpoint_fold4/best_loss.pth'
test_folder_dir = '../data/happy-whale-and-dolphin/test_images'

train_csv_dir = '../lists/train_modified.csv'
test_csv_dir = '../lists/test_modified.csv'

## Seed for Reducibility

In [89]:
def set_seed(seed=3698):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)
    
set_seed(CONFIG['seed'])

## Read data

In [90]:
df = pd.read_csv(train_csv_dir, index_col=0)
df_test = pd.read_csv(test_csv_dir, index_col=0)


In [91]:
df.head()

Unnamed: 0,image,species,individual_id,image_path,split,class,width,height
0,00021adfb725ed.jpg,melon_headed_whale,12348,../data/happy-whale-and-dolphin/train_images/0...,Train,whale,804,671
1,000562241d384d.jpg,humpback_whale,1636,../data/happy-whale-and-dolphin/train_images/0...,Train,whale,3504,2336
2,0007c33415ce37.jpg,false_killer_whale,5842,../data/happy-whale-and-dolphin/train_images/0...,Train,whale,3599,2399
3,0007d9bca26a99.jpg,bottlenose_dolphin,4551,../data/happy-whale-and-dolphin/train_images/0...,Train,dolphin,3504,2336
4,00087baf5cef7a.jpg,humpback_whale,8721,../data/happy-whale-and-dolphin/train_images/0...,Train,whale,3599,2699


In [92]:
encoder = LabelEncoder()

with open("le.pkl", "rb") as fp:
    encoder = joblib.load(fp)

In [93]:
skf = StratifiedKFold(n_splits=CONFIG['n_fold'])

for fold, ( _, val_) in enumerate(skf.split(X=df, y=df.individual_id)):
    df.loc[val_ , "kfold"] = fold

## Dataset model

In [95]:
class HappyWhaleDataset(Dataset):
    def __init__(self, df, transforms=None):
        self.df = df
        self.ids = df['image'].values
        self.file_names = df['image_path'].values
        self.labels = df['individual_id'].values
        self.transforms = transforms
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index):
        idx = self.ids[index]
        img_path = self.file_names[index]
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        label = self.labels[index]
        
        if self.transforms:
            img = self.transforms(image=img)["image"]
            
        return {
            'image': img,
            'label': torch.tensor(label, dtype=torch.long),
            'id': idx
        }

## Augmentations

In [96]:
data_transforms = {
    "train" : A.Compose(
            [
                A.Resize(CONFIG['img_size'], CONFIG['img_size']),
                A.ShiftScaleRotate(
                    shift_limit=0.1,
                    scale_limit=0.15,
                    rotate_limit=60,
                    p=0.5
                ),
                A.HueSaturationValue(
                    hue_shift_limit=0.2,
                    sat_shift_limit=0.2,
                    val_shift_limit=0.2,
                    p=0.5
                ),
                A.RandomBrightnessContrast(
                    brightness_limit=(-0.1, 0.1), 
                    contrast_limit=(-0.1, 0.1),
                    p=0.5
                ),
                A.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                    max_pixel_value=255.0,
                    p=1.0
                ),
                ToTensorV2()
            ], p=1.),

    "validate" : A.Compose([
            A.Resize(CONFIG['img_size'], CONFIG['img_size']),
            A.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
                max_pixel_value=255.0,
                p=1.0
            ),
            ToTensorV2()
        ], p=1.)
}

## Arcface layer

In [97]:
class ArcMarginProduct(nn.Module):
    """Implement of large margin arc distance: :
        Args:
            in_features: size of each input sample
            out_features: size of each output sample
            s: norm of input feature
            m: margin
            cos(theta + m)
        """
    def __init__(self, in_features, out_features, s=30.0, 
                 m=0.50, easy_margin=False, ls_eps=0.0):
        super(ArcMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.ls_eps = ls_eps  # label smoothing
        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label):
        # --------------------------- cos(theta) & phi(theta) ---------------------
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # --------------------------- convert label to one-hot ---------------------
        # one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda')
        # _device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        one_hot = torch.zeros(cosine.size(), device="cuda:0")
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.out_features
        # -------------torch.where(out_i = {x_i if condition_i else y_i) ------------
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s

        return output

## GeM pooling layer

In [98]:
class GeM(nn.Module):
    """Credit: https://amaarora.github.io/2020/08/30/gempool.html"""

    def __init__(self, p=3, eps=1e-6):
        super(GeM, self).__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def gem(self, x, p=3, eps=1e-6):
        return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p)

    def forward(self, x):
        return self.gem(x, p=self.p, eps=self.eps)

    def __repr__(self):
        return self.__class__.__name__ + \
            '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + \
                ', ' + 'eps=' + str(self.eps) + ')'

## Base EffNet model

In [99]:
class BaseEffnetModel(nn.Module):
    def __init__(
        self, model_name, embedding_size, num_classes, 
        s, m, easy_margin, ls_eps, pretrained=True, freeze_backbone=False
    ):
        super(BaseEffnetModel, self).__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained)
        in_features = self.model.classifier.in_features
        self.model.classifier = nn.Identity()
        self.model.global_pool = nn.Identity()
        self.pooling = GeM()
        self.embedding = nn.Linear(in_features, embedding_size)
        self.fc = ArcMarginProduct(
            in_features=embedding_size,
            out_features=num_classes,
            s=s,
            m=m,
            easy_margin=easy_margin,
            ls_eps=ls_eps
        )

    def forward(self, images, labels):
        features = self.model(images)
        pooled_features = self.pooling(features).flatten(1)
        embedding = self.embedding(pooled_features)
        output = self.fc(embedding, labels)
        return output
    
    def extract(self, images):
        features = self.model(images)
        pooled_features = self.pooling(features).flatten(1)
        embedding = self.embedding(pooled_features)
        return embedding


In [100]:
model = BaseEffnetModel(
    model_name=CONFIG['model_name'],
    embedding_size= CONFIG['embedding_size'],
    num_classes= CONFIG['num_classes'],
    s= CONFIG['s'],
    m= CONFIG['m'],
    easy_margin= CONFIG['easy_margin'],
    ls_eps= CONFIG['ls_eps'],
)


In [101]:
eff_state = torch.load(model_checkpoint_dir)['model_state_dict']
model.load_state_dict(eff_state)
model.to(CONFIG['device'])

BaseEffnetModel(
  (model): EfficientNet(
    (conv_stem): Conv2dSame(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (bn1): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (act1): SiLU(inplace=True)
    (blocks): Sequential(
      (0): Sequential(
        (0): DepthwiseSeparableConv(
          (conv_dw): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn1): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
            (act1): SiLU(inplace=True)
            (conv_expand): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
            (gate): Sigmoid()
          )
          (conv_pw): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_

In [102]:
df.head(5)

Unnamed: 0,image,species,individual_id,image_path,split,class,width,height,kfold
0,00021adfb725ed.jpg,melon_headed_whale,12348,../data/happy-whale-and-dolphin/train_images/0...,Train,whale,804,671,0.0
1,000562241d384d.jpg,humpback_whale,1636,../data/happy-whale-and-dolphin/train_images/0...,Train,whale,3504,2336,1.0
2,0007c33415ce37.jpg,false_killer_whale,5842,../data/happy-whale-and-dolphin/train_images/0...,Train,whale,3599,2399,0.0
3,0007d9bca26a99.jpg,bottlenose_dolphin,4551,../data/happy-whale-and-dolphin/train_images/0...,Train,dolphin,3504,2336,0.0
4,00087baf5cef7a.jpg,humpback_whale,8721,../data/happy-whale-and-dolphin/train_images/0...,Train,whale,3599,2699,0.0


In [103]:
@torch.inference_mode()
def get_embeddings(model, dataloader, device):
    model.eval()
    
    LABELS = []
    EMBEDS = []
    IDS = []
    
    bar = tqdm(enumerate(dataloader), total=len(dataloader))
    for step, data in bar:        
        images = data['image'].to(device, dtype=torch.float)
        labels = data['label'].to(device, dtype=torch.long)
        ids = data['id']

        outputs = model.extract(images)
        
        LABELS.append(labels.cpu().numpy())
        EMBEDS.append(outputs.cpu().numpy())
        IDS.append(ids)
    
    EMBEDS = np.vstack(EMBEDS)
    LABELS = np.concatenate(LABELS)
    IDS = np.concatenate(IDS)
    
    return EMBEDS, LABELS, IDS

In [104]:
def prepare_loaders(df, fold):
    df_train = df[df.kfold != fold].reset_index(drop=True)
    df_valid = df[df.kfold == fold].reset_index(drop=True)
    
    train_dataset = HappyWhaleDataset(df_train, transforms=data_transforms["train"])
    valid_dataset = HappyWhaleDataset(df_valid, transforms=data_transforms["validate"])

    train_loader = DataLoader(train_dataset, batch_size=CONFIG['train_batch_size'], 
                              num_workers=2, shuffle=False, pin_memory=True)
    valid_loader = DataLoader(valid_dataset, batch_size=CONFIG['valid_batch_size'], 
                              num_workers=2, shuffle=False, pin_memory=True)
    
    return train_loader, valid_loader

In [106]:
train_loader, valid_loader = prepare_loaders(df, fold=0)


In [107]:
train_embeds, train_labels, train_ids = get_embeddings(model, train_loader, CONFIG['device'])
valid_embeds, valid_labels, valid_ids = get_embeddings(model, valid_loader, CONFIG['device'])

100%|██████████| 638/638 [16:45<00:00,  1.58s/it]
100%|██████████| 160/160 [04:12<00:00,  1.58s/it]


In [108]:
train_embeds = normalize(train_embeds, axis=1, norm='l2')
valid_embeds = normalize(valid_embeds, axis=1, norm='l2')

In [109]:
train_labels = encoder.inverse_transform(train_labels)
valid_labels = encoder.inverse_transform(valid_labels)

In [110]:
index = faiss.IndexFlatIP(CONFIG['embedding_size'])
index.add(train_embeds)

In [111]:
D, I = index.search(valid_embeds, k=50)

In [112]:
allowed_targets = np.unique(train_labels)

In [113]:
val_targets_df = pd.DataFrame(np.stack([valid_ids, valid_labels], axis=1), columns=['image','target'])
val_targets_df.loc[~val_targets_df.target.isin(allowed_targets), 'target'] = 'new_individual'
val_targets_df.target.value_counts()

new_individual    1850
37c7aba965a5        80
114207cab555        34
a6e325d8e924        31
19fbb960f07d        31
                  ... 
c511fbe2acd1         1
b90f72a6be9b         1
579a23a02325         1
045ca1b5a580         1
26145086bca6         1
Name: target, Length: 4012, dtype: int64

In [115]:
valid_df = []
for i, val_id in tqdm(enumerate(valid_ids)):
    targets = train_labels[I[i]]
    distances = D[i]
    subset_preds = pd.DataFrame(np.stack([targets,distances],axis=1),columns=['target','distances'])
    subset_preds['image'] = val_id
    valid_df.append(subset_preds)

10207it [00:13, 737.45it/s] 


In [116]:
valid_df = pd.concat(valid_df).reset_index(drop=True)
valid_df = valid_df.groupby(['image','target']).distances.max().reset_index()
valid_df.head()

Unnamed: 0,image,target,distances
0,00021adfb725ed.jpg,01981b6e8ac9,0.995767
1,00021adfb725ed.jpg,0b6c7058e353,0.9952
2,00021adfb725ed.jpg,0c8017782d1f,0.994356
3,00021adfb725ed.jpg,0d4afee74fa8,0.993907
4,00021adfb725ed.jpg,111a78cd63bf,0.993708


In [117]:
valid_df = valid_df.sort_values('distances', ascending=False).reset_index(drop=True)
valid_df.to_csv('val_neighbors.csv')

In [118]:
sample_list = ['938b7e931166', '5bf17305f073', '7593d2aee842', '7362d7a01d00','956562ff2888']


In [119]:
def get_predictions(test_df, threshold=0.2):
    predictions = {}
    for i, row in tqdm(test_df.iterrows()):
        if row.image in predictions:
            if len(predictions[row.image]) == 5:
                continue
            predictions[row.image].append(row.target)
        elif row.distances > threshold:
            predictions[row.image] = [row.target, 'new_individual']
        else:
            predictions[row.image] = ['new_individual', row.target]

    for x in tqdm(predictions):
        if len(predictions[x]) < 5:
            remaining = [y for y in sample_list if y not in predictions]
            predictions[x] = predictions[x] + remaining
            predictions[x] = predictions[x][:5]
        
    return predictions

In [120]:
def map_per_image(label, predictions):
    try:
        return 1 / (predictions[:5].index(label) + 1)
    except ValueError:
        return 0.0

## Compute CV

In [121]:
best_th = 0
best_cv = 0
for th in [0.1*x for x in range(11)]:
    all_preds = get_predictions(valid_df, threshold=th)
    cv = 0
    for i,row in val_targets_df.iterrows():
        target = row.target
        preds = all_preds[row.image]
        val_targets_df.loc[i,th] = map_per_image(target, preds)
    cv = val_targets_df[th].mean()
    print(f"CV at threshold {th}: {cv}")
    if cv > best_cv:
        best_th = th
        best_cv = cv

398126it [00:15, 26189.60it/s]
100%|██████████| 10207/10207 [00:00<00:00, 3202518.02it/s]


CV at threshold 0.0: 0.20625224519120874


398126it [00:15, 26063.33it/s]
100%|██████████| 10207/10207 [00:00<00:00, 3079946.83it/s]


CV at threshold 0.1: 0.20625224519120874


398126it [00:15, 25140.47it/s]
100%|██████████| 10207/10207 [00:00<00:00, 2635836.78it/s]


CV at threshold 0.2: 0.20625224519120874


398126it [00:15, 26021.30it/s]
100%|██████████| 10207/10207 [00:00<00:00, 3105190.46it/s]


CV at threshold 0.30000000000000004: 0.20625224519120874


398126it [00:15, 25998.17it/s]
100%|██████████| 10207/10207 [00:00<00:00, 3179211.42it/s]


CV at threshold 0.4: 0.20625224519120874


398126it [00:15, 26266.34it/s]
100%|██████████| 10207/10207 [00:00<00:00, 3157871.28it/s]


CV at threshold 0.5: 0.20625224519120874


398126it [00:15, 26337.76it/s]
100%|██████████| 10207/10207 [00:00<00:00, 3161369.14it/s]


CV at threshold 0.6000000000000001: 0.20625224519120874


398126it [00:15, 25784.04it/s]
100%|██████████| 10207/10207 [00:00<00:00, 3285087.55it/s]


CV at threshold 0.7000000000000001: 0.20625224519120874


398126it [00:15, 25764.13it/s]
100%|██████████| 10207/10207 [00:00<00:00, 2423919.20it/s]


CV at threshold 0.8: 0.20625224519120874


398126it [00:15, 26247.41it/s]
100%|██████████| 10207/10207 [00:00<00:00, 3268284.67it/s]


CV at threshold 0.9: 0.20625224519120874


398126it [00:14, 26621.84it/s]
100%|██████████| 10207/10207 [00:00<00:00, 2946607.54it/s]


CV at threshold 1.0: 0.2506335521374223


In [122]:
print("Best threshold", best_th)
print("Best cv", best_cv)
val_targets_df.describe()

Best threshold 1.0
Best cv 0.2506335521374223


Unnamed: 0,0.0,0.1,0.2,0.30000000000000004,0.4,0.5,0.6000000000000001,0.7000000000000001,0.8,0.9,1.0
count,10207.0,10207.0,10207.0,10207.0,10207.0,10207.0,10207.0,10207.0,10207.0,10207.0,10207.0
mean,0.206252,0.206252,0.206252,0.206252,0.206252,0.206252,0.206252,0.206252,0.206252,0.206252,0.250634
std,0.319175,0.319175,0.319175,0.319175,0.319175,0.319175,0.319175,0.319175,0.319175,0.319175,0.384931
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
50%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
75%,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5
max,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0


In [123]:
## Adjustment: Since Public lb has nearly 10% 'new_individual' (Be Careful for private LB)
val_targets_df['is_new_individual'] = val_targets_df.target=='new_individual'
print(val_targets_df.is_new_individual.value_counts().to_dict())
val_scores = val_targets_df.groupby('is_new_individual').mean().T
val_scores['adjusted_cv'] = val_scores[True]*0.1+val_scores[False]*0.9
best_threshold_adjusted = val_scores['adjusted_cv'].idxmax()
print("best_threshold",best_threshold_adjusted)
val_scores

{False: 8357, True: 1850}
best_threshold 0.0


is_new_individual,False,True,adjusted_cv
0.0,0.141225,0.5,0.177102
0.1,0.141225,0.5,0.177102
0.2,0.141225,0.5,0.177102
0.3,0.141225,0.5,0.177102
0.4,0.141225,0.5,0.177102
0.5,0.141225,0.5,0.177102
0.6000000000000001,0.141225,0.5,0.177102
0.7000000000000001,0.141225,0.5,0.177102
0.8,0.141225,0.5,0.177102
0.9,0.141225,0.5,0.177102


## Inference

In [124]:
train_embeds = np.concatenate([train_embeds, valid_embeds])
train_labels = np.concatenate([train_labels, valid_labels])
print(train_embeds.shape,train_labels.shape)

(51033, 512) (51033,)


In [125]:
index = faiss.IndexFlatIP(CONFIG['embedding_size'])
index.add(train_embeds)

In [129]:
df_test['individual_id'] = -1

In [131]:
test_dataset = HappyWhaleDataset(df_test, transforms=data_transforms["validate"])
test_loader = DataLoader(test_dataset, batch_size=CONFIG['valid_batch_size'], 
                         num_workers=2, shuffle=False, pin_memory=True)

In [132]:
test_embeds, _, test_ids = get_embeddings(model, test_loader, CONFIG['device'])
test_embeds = normalize(test_embeds, axis=1, norm='l2')

100%|██████████| 437/437 [11:21<00:00,  1.56s/it]


In [133]:
D, I = index.search(test_embeds, k=50)


In [134]:
test_df = []
for i, test_id in tqdm(enumerate(test_ids)):
    targets = train_labels[I[i]]
    distances = D[i]
    subset_preds = pd.DataFrame(np.stack([targets, distances], axis=1), columns=['target','distances'])
    subset_preds['image'] = test_id
    test_df.append(subset_preds)
    
test_df = pd.concat(test_df).reset_index(drop=True)
test_df = test_df.groupby(['image','target']).distances.max().reset_index()
test_df = test_df.sort_values('distances', ascending=False).reset_index(drop=True)
test_df.to_csv('test_neighbors.csv')

27956it [00:10, 2621.56it/s]


In [135]:
predictions = get_predictions(test_df, best_threshold_adjusted)

predictions = pd.Series(predictions).reset_index()
predictions.columns = ['image','predictions']
predictions['predictions'] = predictions['predictions'].apply(lambda x: ' '.join(x))
predictions.to_csv('../result/submission.csv',index=False)
predictions.head()

1114903it [00:44, 25124.89it/s]
100%|██████████| 27956/27956 [00:00<00:00, 3047350.76it/s]


Unnamed: 0,image,predictions
0,6827583d7f8de8.jpg,be315fad6ef4 new_individual 99591df6b695 aeae6...
1,5ee9b577ad8a7c.jpg,12f2a4406d7e new_individual cd720f8127f5 1c496...
2,3674a3bf79e4e8.jpg,f195c38bcf17 new_individual cc0e0b020a90 0b297...
3,db0c734e80a4e5.jpg,2e268c8dbd31 new_individual 7d07bf29721f cd720...
4,3cbaaacdf49bed.jpg,84a261c0e5cf new_individual cd720f8127f5 78453...
