# Review of some methods
In my experience, the image retrieval task has two main approaches: pair-based losses and softmax-based losses.
# 1. Based on pair-based loss:
* **Contrastive**:
This method only considers 1 negative and 1 positive at the same time.
* **Triplet loss**
Triplet loss considers triplets: (anchor, positive, negative).
The model tries to push the (anchor, negative) pair apart and pull the (anchor, positive) pair together -> the effectiveness of the method depends on the batch size, it takes a long time to train, and it requires a strong GPU. Why? In backpropagation it only considers 1 anchor (same class as the positive), 1 positive and 1 negative -> so the method is difficult to converge.
Several methods have been proposed to reduce this disadvantage, such as: [Lifted structured](https://arxiv.org/pdf/1511.06452.pdf), [N pair loss](https://proceedings.neurips.cc/paper/2016/file/6b180037abbebea991d8b1232f8a8ca9-Paper.pdf), ...

# 2. Based on softmax loss:
ArcFace, CosFace, SphereFace, [Adaptive margin](https://openaccess.thecvf.com/content_CVPR_2019/html/Liu_AdaptiveFace_Adaptive_Margin_and_Sampling_for_Face_Recognition_CVPR_2019_paper.html)




# SOTA 2021
[Proxy anchor](https://arxiv.org/abs/2003.13911) is a strong method for the image retrieval task.
It outperforms well-known methods such as contrastive loss, triplet loss, NCA... and vastly reduces training time.
The authors choose an anchor for each class (rather than assigning an anchor to each image); the loss then pulls samples of the same class towards that anchor and pushes the other samples away. Proxy anchor does not select pairs or tuples; it only considers each anchor, pulling its own samples and pushing away the samples of other classes, so PA does not need large GPU memory to achieve good results.

In [None]:
from torchvision import datasets
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
from torch.utils.data import DataLoader,Dataset
import pandas as pd
import torchvision.models as models
from PIL import Image
import tensorflow as tf
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import torch.nn.init as init
import sklearn.preprocessing
import albumentations as A
from torch.utils.data import DataLoader, Dataset
from albumentations.pytorch import ToTensorV2
import cv2
import os
import random

In [None]:
def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch random generators and pin cuDNN settings.

    :param seed: integer seed applied to every generator below
    """
    # cuDNN: deterministic kernels on, autotuner off (benchmark=True is faster)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)

    print(f'Setting all seeds to be {seed} to reproduce...')


seed_everything(1024)

In [None]:
class config():
    """Settings for the proxy-anchor training run, used as a plain attribute namespace."""

    # ---- model / objective selection ----
    model_name = 'densenet121' # densenet121, resnet50
    loss_type = 'proxy_anchor'
    optimizer = 'adam'

    # ---- input locations ----
    path_model_pretrained = '/u01/sen/whale/pretrain_model/densenet121_ra-50efcf5c.pth'
    path_info = '../input/happy-whale-and-dolphin/train.csv'
    root_data_2022 = '../input/happy-whale-and-dolphin/train_images/'
    resume_model = ''

    # ---- output locations ----
    path_save_log = './proxy_anchor_412022/'
    # checkpoint path prefix: log directory followed by the model name
    path_save_model = f'{path_save_log}{model_name}'

    phase_idx = 0 ## only on crop

    # ---- training ----
    start_epoch = 0
    train_number_epochs = 60
    continue_train = False
    batch_size = 64
    num_embeddings = 512
    input_size = 224
    worker = 2  # passed to DataLoader as num_workers
    num_class = 29392 #11894 # 4019 #23786 #4019 #19767
    scale_margin = 1
    # first CUDA device when one is available, otherwise the CPU
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    # ---- optimisation ----
    lr = 1e-4
    lr_decay_step = 30
    lr_decay_gamma =  0.25
    weight_decay = 1e-4

    # Runs once, when the class body is executed: make sure the log directory exists.
    os.makedirs(path_save_log, exist_ok=True)

In [None]:
class min_edge_crop(A.ImageOnlyTransform):
    """Crop an image to a square whose side equals its shorter edge.

    The crop window is placed according to ``self.position``, which the
    constructor fixes to ``'center'``.
    """

    def __init__(self, p: float = 0.5, always_apply=True):
        super().__init__(always_apply, p)
        self.position = 'center'

    def _window_start(self, length, size):
        """Return the first index of a ``size``-long window inside an axis of ``length``."""
        if self.position == "left":
            return 0
        if self.position == "center":
            # Round the offset up so an odd excess drops its extra pixel at the
            # start of the axis, matching the inputs the previous code handled.
            return (length - size + 1) // 2
        return length - size

    def apply(self, img, **params):
        """
        crop image base on min size
        :param img: image to be cropped; its first two axes are read as (height, width)
        :return: cropped image whose height and width both equal min(height, width)
        """
        assert self.position in ['center', 'left', 'right'], "position must either be: left, center or right"

        h, w = img.shape[:2]

        if h == w:
            return img

        min_edge = min(h, w)
        # Slice an explicit [start, start + min_edge) window. The former
        # `img[d:-d]` followed by a parity test on h / w (instead of on the
        # excess h - min_edge / w - min_edge) produced non-square outputs,
        # e.g. 10x7 -> 8x7 and 11x7 -> 6x7.
        if h > min_edge:
            top = self._window_start(h, min_edge)
            img = img[top:top + min_edge]

        if w > min_edge:
            left = self._window_start(w, min_edge)
            img = img[:, left:left + min_edge]

        return img

def get_augmentation(phase,input_size):
    """Build the albumentations pipeline for one dataset phase.

    :param phase: 'train' for the augmented pipeline, 'test' or 'valid' for the plain one
    :param input_size: side length handed to A.Resize as both height and width
    :return: an A.Compose pipeline, or None for any other phase value
    """
    if phase not in ('train', 'test', 'valid'):
        return None

    # Every phase starts with a square crop + resize and ends with normalise + tensor conversion.
    head = [min_edge_crop(), A.Resize(height=input_size, width=input_size)]
    tail = [A.Normalize(), ToTensorV2()]

    if phase != 'train':
        return A.Compose(head + tail)

    noise_or_blur = A.OneOf(
        [
            A.GaussNoise(var_limit=[10, 50]),
            A.GaussianBlur(),
            A.MotionBlur(),
            A.MedianBlur(),
        ],
        p=0.2,
    )
    distortion = A.OneOf(
        [
            A.OpticalDistortion(distort_limit=1.0),
            A.GridDistortion(num_steps=5, distort_limit=1.),
            A.ElasticTransform(alpha=3),
        ],
        p=0.2,
    )
    contrast = A.OneOf(
        [
            A.CLAHE(),
            A.RandomBrightnessContrast(),
        ],
        p=0.25,
    )
    train_only = [
        A.ToGray(p=0.01),
        noise_or_blur,
        distortion,
        contrast,
        A.HueSaturationValue(p=0.25),
        A.ShiftScaleRotate(p=0.5, shift_limit=0.0625, scale_limit=0.2, rotate_limit=20),
    ]
    return A.Compose([*head, *train_only, *tail])


In [None]:
class whale_huback():
    """Map-style dataset yielding (image, label) pairs from a metadata DataFrame.

    The frame must provide an 'image' column (path passed to cv2.imread) and an
    'individual_id' column (value wrapped in torch.tensor as the label).
    """

    def __init__(self, df, transform = None):
        """
        :param df: DataFrame with 'image' and 'individual_id' columns
        :param transform: optional callable invoked as transform(image=...)["image"]
        """
        # reset_index gives 0..n-1 labels so the .loc lookups in __getitem__ are positional
        self.df = df.reset_index()
        self.transform = transform

    def __len__(self):
        """Return the number of rows (samples) in the frame."""
        return self.df.shape[0]

    def __getitem__(self,index):
        """Load, colour-convert and optionally transform the sample at ``index``.

        :raises FileNotFoundError: when cv2.imread cannot read the image path
        :return: (sample, label tensor)
        """
        img_path, class_id = self.df.loc[index, 'image'], self.df.loc[index,'individual_id']
        sample = cv2.imread(img_path)
        if sample is None:
            # cv2.imread signals failure with None; fail here with the path
            # instead of letting cvtColor raise an opaque error.
            raise FileNotFoundError(f"Could not read image: {img_path}")

        sample = cv2.cvtColor(sample, cv2.COLOR_BGR2RGB)
        if self.transform is not None:
            sample = self.transform(image=sample)["image"]

        return sample, torch.tensor(class_id)

In [None]:
# Load the metadata table and turn the file names into full image paths.
info = pd.read_csv(config.path_info)
info['image'] = config.root_data_2022+info['image']
print(len(info['image']))
info = info.sample(frac=1)  # shuffle the rows
info['source'] = '2022'

# Relabel every individual_id with a contiguous integer, in order of first appearance.
mapping = {class_id: index for index, class_id in enumerate(info['individual_id'].unique())}
info['individual_id'] = info['individual_id'].map(mapping)
total_class = len(mapping)

# Split by class: labels up to total_class // 2 train, the remaining labels validate.
info['phase'] = np.where(info['individual_id'] <= total_class // 2, 'train', 'valid')
train = info[info['phase']=='train']
# Fixed: the previous `!= 'valid'` filter selected the training rows again.
valid = info[info['phase']=='valid']

print('train.shape',train.shape)
# Fixed: this line used to print train.shape under the 'valid.shape' label.
print('valid.shape',valid.shape)
train.to_csv("source_train.csv")

In [None]:

# Explicit phase -> DataFrame lookup; replaces eval() on the phase name.
phase_frames = {'train': train, 'valid': valid}

dataset = {
    phase: whale_huback(frame, transform=get_augmentation(phase=phase, input_size=config.input_size))
    for phase, frame in phase_frames.items()
}

# Only the training loader shuffles and pins memory.
# NOTE(review): drop_last=True also applies to 'valid', so a trailing partial
# batch is left out of evaluation - confirm this is intended.
dataloader = {
    phase: DataLoader(dataset=dataset[phase], num_workers=config.worker, batch_size=config.batch_size,
                      shuffle=(phase == 'train'), pin_memory=(phase == 'train'), drop_last=True)
    for phase in dataset
}

In [None]:
def binarize(T, nb_classes):
    """One-hot encode integer class labels.

    Uses torch's one-hot encoding directly, so the labels never leave their
    device (the previous version copied them to the host for scikit-learn on
    every call). It also always returns ``nb_classes`` columns, whereas
    ``label_binarize`` collapses the two-class case to a single column.

    :param T: 1-D tensor of integer class ids in ``[0, nb_classes)``
    :param nb_classes: total number of classes (width of the encoding)
    :return: float32 tensor of shape ``(len(T), nb_classes)`` on the device of ``T``
    :raises RuntimeError: raised by torch when a label is outside ``[0, nb_classes)``
    """
    return F.one_hot(T.long(), num_classes=nb_classes).float()

def l2_norm(input):
    """Scale every row of a 2-D tensor to unit Euclidean length.

    :param input: tensor of shape (rows, features)
    :return: tensor of the same shape; 1e-12 is added under the square root
             so an all-zero row does not divide by zero
    """
    squared_sum = input.pow(2).sum(dim=1) + 1e-12
    row_norm = squared_sum.sqrt().view(-1, 1)
    return (input / row_norm.expand_as(input)).view(input.size())
class Proxy_Anchor(torch.nn.Module):
    """Proxy-Anchor loss with one learnable proxy vector per class.

    :param nb_classes: number of classes, i.e. number of proxies
    :param sz_embed: embedding dimension of each proxy
    :param mrg: margin used in the positive and negative exponents
    :param alpha: scale applied to the cosine similarities inside the exponents
    """

    def __init__(self, nb_classes, sz_embed, mrg = 0.1, alpha = 32):
        torch.nn.Module.__init__(self)
        # Proxy Anchor Initialization
        # Use CUDA when present (as before) but fall back to the CPU instead of
        # crashing on hosts without a GPU; callers can still move it with .to().
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.proxies = torch.nn.Parameter(torch.randn(nb_classes, sz_embed).to(device))
        nn.init.kaiming_normal_(self.proxies, mode='fan_out')

        self.nb_classes = nb_classes
        self.sz_embed = sz_embed
        self.mrg = mrg
        self.alpha = alpha

    def forward(self, X, T):
        """Compute the loss for a batch.

        :param X: embeddings, one row per sample, with sz_embed columns
        :param T: integer class label of each row of X
        :return: scalar loss tensor (positive term + negative term)
        """
        P = self.proxies

        # Cosine similarity of every sample (rows) with every proxy (columns)
        cos = F.linear(l2_norm(X), l2_norm(P))
        P_one_hot = binarize(T = T, nb_classes = self.nb_classes)
        N_one_hot = 1 - P_one_hot

        pos_exp = torch.exp(-self.alpha * (cos - self.mrg))
        neg_exp = torch.exp(self.alpha * (cos + self.mrg))

        # Proxies that have at least one positive sample in this batch
        with_pos_proxies = torch.nonzero(P_one_hot.sum(dim = 0) != 0).squeeze(dim = 1)
        num_valid_proxies = len(with_pos_proxies)

        # Per-proxy sums over the batch, keeping only positive / negative pairs
        P_sim_sum = torch.where(P_one_hot == 1, pos_exp, torch.zeros_like(pos_exp)).sum(dim=0)
        N_sim_sum = torch.where(N_one_hot == 1, neg_exp, torch.zeros_like(neg_exp)).sum(dim=0)

        # Positive term averages over proxies present in the batch,
        # negative term over all proxies.
        pos_term = torch.log(1 + P_sim_sum).sum() / num_valid_proxies
        neg_term = torch.log(1 + N_sim_sum).sum() / self.nb_classes
        loss = pos_term + neg_term

        return loss


In [None]:
class Densenet121(nn.Module):
    """torchvision DenseNet-121 whose classifier is swapped for an embedding layer."""

    def __init__(self):
        super().__init__()

        # Pretrained backbone; its classifier is replaced by a linear layer
        # that outputs config.num_embeddings features.
        backbone = models.densenet121(pretrained=True)
        backbone.classifier = nn.Linear(backbone.classifier.in_features, config.num_embeddings)
        self.backbone = backbone

    def forward(self, x):
        """Return the backbone output passed through F.normalize (default arguments)."""
        return F.normalize(self.backbone(x))

# Build the embedding model selected in the config; fail here with a clear
# message rather than with a NameError further down.
if config.model_name == 'densenet121':
    model = Densenet121()
else:
    raise ValueError(f"Unsupported config.model_name: {config.model_name!r}")

model = model.to(config.device)

# Build the loss; it owns the learnable class proxies.
if config.loss_type == 'proxy_anchor':
    loss_func = Proxy_Anchor(nb_classes=config.num_class, sz_embed = config.num_embeddings, mrg=0.1, alpha=32).to(config.device)
else:
    raise ValueError(f"Unsupported config.loss_type: {config.loss_type!r}")

base_lr = float(config.lr)
params_group = [{'params': model.parameters(), 'lr': base_lr},]
if config.loss_type in ('proxy_anchor', 'proxy_nca'):
    # The proxies are trained with a 10x larger learning rate than the model.
    params_group.append({'params': loss_func.proxies, 'lr': base_lr * 10})

if config.optimizer =='sgd':
    optimizer = torch.optim.SGD(params_group, lr=base_lr, weight_decay = config.weight_decay, momentum = 0.9, nesterov=True)
elif config.optimizer == 'adam':
    optimizer = torch.optim.Adam(params_group, lr=base_lr, weight_decay = config.weight_decay)
elif config.optimizer == 'rmsprop':
    optimizer = torch.optim.RMSprop(params_group, lr=base_lr, alpha=0.9, weight_decay = config.weight_decay, momentum = 0.9)
elif config.optimizer == 'adamw':
    optimizer = torch.optim.AdamW(params_group, lr=base_lr, weight_decay = config.weight_decay)
else:
    raise ValueError(f"Unsupported config.optimizer: {config.optimizer!r}")

# Multiply every learning rate by lr_decay_gamma each lr_decay_step scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size = config.lr_decay_step, gamma=config.lr_decay_gamma)

In [None]:
def train(model, loss_func, train_loader, optimizer, epoch, scheduler):
    """Run one pass over ``train_loader``, updating the model after every batch.

    :param model: network mapping an image batch to embeddings
    :param loss_func: callable taking (embeddings, labels) and returning the loss
    :param train_loader: iterable of (images, labels) batches
    :param optimizer: optimizer stepped once per batch
    :param epoch: epoch number, used only in the progress message
    :param scheduler: optional LR scheduler, stepped once after the last batch
    """
    model.train()
    for step, batch in enumerate(train_loader):
        images, targets = batch
        images = images.to(config.device)
        targets = targets.to(config.device)

        optimizer.zero_grad()
        loss = loss_func(model(images), targets)
        loss.backward()
        optimizer.step()

        # progress message every 20th batch
        if not step % 20:
            print("Epoch {} Iteration {}: Loss = {}".format(epoch, step, loss))

    if scheduler is None:
        return
    scheduler.step()

In [None]:
def calc_recall_at_k(T, Y, k):
    """Fraction of samples whose target label is among their first k predictions.

    T : [nb_samples] (target labels)
    Y : [nb_samples x k] (k predicted labels/neighbours)
    """
    hits = sum(
        1
        for target, neighbours in zip(T, Y)
        if target in torch.Tensor(neighbours).long()[:k]
    )
    return hits / float(len(T))
def l2_norm(input):
    """Scale every row of a 2-D tensor to unit Euclidean length.

    :param input: tensor of shape (rows, features)
    :return: tensor of the same shape; 1e-12 is added under the square root
             so an all-zero row does not divide by zero
    """
    squared_sum = input.pow(2).sum(dim=1) + 1e-12
    row_norm = squared_sum.sqrt().view(-1, 1)
    return (input / row_norm.expand_as(input)).view(input.size())
def predict_batchwise(model, dataloader):
    """Run the model over a dataloader and gather its outputs with the other batch fields.

    :param model: network applied to the first field of every batch
    :param dataloader: loader whose batches are sequences of tensors; field 0 holds the images
    :return: list with one stacked tensor per batch field; entry 0 holds the
             model outputs, the remaining entries hold the fields unchanged
    """
    model_is_training = model.training
    model.eval()

    ds = dataloader.dataset
    # One accumulator per field of a dataset sample.
    collected = [[] for _ in range(len(ds[0]))]
    try:
        with torch.no_grad():
            for batch in tqdm(dataloader):
                for i, J in enumerate(batch):
                    if i == 0:
                        # field 0 is the image batch: replace it by the model output
                        J = model(J.to(config.device))
                    # iterating a tensor yields its rows, i.e. one entry per sample
                    collected[i].extend(J)
    finally:
        # Restore the caller's train/eval mode even if inference fails
        # (the previous code also forced train mode first, which was redundant).
        model.train(model_is_training)

    return [torch.stack(column) for column in collected]
def evaluate_cos_SOP(model, dataloader):
    """Compute Recall@k by cosine nearest-neighbour search among the loader's own samples.

    :param model: embedding network
    :param dataloader: loader yielding (images, labels) batches
    :return: list of recalls for k in [1, 2, 4, 5, 10, 100, 1000]
    :raises ValueError: when fewer than two samples are available
    """
    # calculate embeddings with model and get targets
    X, T = predict_batchwise(model, dataloader)
    X = l2_norm(X)

    nb_samples = X.size(0)
    if nb_samples < 2:
        raise ValueError("evaluate_cos_SOP needs at least two samples to rank neighbours")

    # Keep up to 1000 neighbours per query, capped so that topk never asks for
    # more entries than there are samples.
    K = min(1000, nb_samples - 1)

    # Process the queries in chunks to bound the size of the similarity matrix.
    # Tensor.split never yields an empty chunk, unlike the former manual
    # batching, which crashed when the sample count was a multiple of its chunk size.
    Y = []
    for xs in X.split(10000):
        cos_sim = F.linear(xs, X)
        # drop the best hit, presumed to be the query itself
        neighbour_idx = cos_sim.topk(1 + K)[1][:, 1:]
        # index the labels on their own device, which may differ from the embeddings'
        Y.append(T[neighbour_idx.to(T.device)].float().cpu())
    Y = torch.cat(Y, dim=0)

    # calculate recall @ 1, 2, 4, 5, 10, 100, 1000
    recall = []
    for k in [1, 2, 4, 5, 10, 100, 1000]:
        r_at_k = calc_recall_at_k(T, Y, k)
        recall.append(r_at_k)
        print("R@{} : {:.3f}".format(k, 100 * r_at_k))
    return recall

In [None]:
# Optionally warm-start from a checkpoint; only the model weights are restored.
if config.continue_train:
    print('loading status...')
    # map_location lets a checkpoint saved on another device load onto config.device
    checkpoint = torch.load(config.resume_model, map_location=config.device)
    model.load_state_dict(checkpoint['model_state_dict'])

best_recall = 0
model.to(config.device)
writer = SummaryWriter(config.path_save_log)
try:
    # The context manager closes the recall log even if training is interrupted.
    with open(config.path_save_log+"recall.txt", 'w', buffering=1) as f:
        for epoch in range(config.start_epoch, config.train_number_epochs):
            train(model, loss_func, dataloader['train'], optimizer, epoch, scheduler)

            recalls = evaluate_cos_SOP(model, dataloader['valid'])
            print('epoch: {}, recall@1: {}'.format(epoch, recalls))
            r_at_1 = recalls[0]

            # Same payload for the "last" and "best" checkpoints.
            state = {
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler': scheduler,
                'epoch': epoch
            }
            torch.save(state, config.path_save_model+ '_last.pt')

            # Keep a separate copy whenever recall@1 improves.
            if r_at_1 > best_recall:
                best_recall = r_at_1
                torch.save(state, config.path_save_model+ '_best.pt')

            f.write(str(r_at_1)+"\n")
            f.flush()
            writer.add_scalar('precision@1', r_at_1, epoch)
finally:
    # Flush and release the TensorBoard event file.
    writer.close()