## AUPRC Loss for Image Retrieval

This code base provides datasets, metrics, and an AUPRC loss for image retrieval.

### Datasets

The following class is a wrapper of image retrieval datasets:
> CLASS XCurve.AUPRC.RetrievalDataset(data_dir, list_dir, subset, input_size, batchsize, num_sample_per_id, normal_mean=[0.485, 0.456, 0.406], normal_std=[0.229, 0.224, 0.225], split='train')  [\[SOURCE\]](https://github.com/statusrank/XCurve/blob/master/XCurve/AUPRC/datasets/dataset.py)

#### Parameters:
- data_dir (str): Path to the data.
- list_dir (str): Path to the subset lists.
- subset (str): Subset used for training, validation or testing.
- input_size (int): Input image size.
- batchsize (int): Number of samples per iteration. Must match the batch size of the DataLoader.
- num_sample_per_id (int): Number of samples for each id.
- normal_mean (list[float], optional): Mean for model input normalization.
- normal_std (list[float], optional): Standard deviation for model input normalization.
- split (str, options: ['train', 'val', 'test']): Data split.

Three benchmark datasets are provided: iNaturalist, Stanford Online Products (SOP), and PKU VehicleID (VehID). Their default configurations are provided as dictionaries in `XCurve.AUPRC.DefaultInatDatasetCfg`, `XCurve.AUPRC.DefaultSOPDatasetCfg`, and `XCurve.AUPRC.DefaultVehIDDatasetCfg`, respectively.
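
To change a single setting without mutating a shared default, one option is to copy the configuration first. The sketch below uses a plain dictionary as a stand-in, with a subset of the iNaturalist values shown in the example output further down; `with_overrides` is a hypothetical helper and not part of XCurve.

```python
import copy

# Stand-in for XCurve.AUPRC.DefaultInatDatasetCfg (only a subset of its keys).
default_cfg = {
    "data_dir": "./data/iNaturalist/images",
    "list_dir": "./data/iNaturalist/split_list",
    "num_sample_per_id": 4,
    "input_size": 256,
    "batchsize": 56,
}


def with_overrides(cfg, **overrides):
    """Return a deep copy of cfg with selected keys replaced (hypothetical helper)."""
    unknown = set(overrides) - set(cfg)
    if unknown:
        # Reject misspelled keys instead of silently adding them.
        raise KeyError(f"unknown config keys: {sorted(unknown)}")
    new_cfg = copy.deepcopy(cfg)
    new_cfg.update(overrides)
    return new_cfg


small_cfg = with_overrides(default_cfg, batchsize=32, input_size=224)
print(small_cfg["batchsize"], default_cfg["batchsize"])  # the default is untouched
```

A copy of the real default dictionary could be adjusted in the same way before it is unpacked into `RetrievalDataset`.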

#### Functions:
- RetrievalDataset.\_\_getitem\_\_(idx) -> tuple(torch.Tensor, torch.Tensor): Load an image and the corresponding labels.
- RetrievalDataset.get_cnt_per_id() -> list\[int\]: Load a list describing the number of samples for each id.
- RetrievalDataset.reset() -> None: Shuffle the data. Call it after each epoch.
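
The way these three methods fit into an epoch loop can be sketched with a toy stand-in. `ToyRetrievalDataset` is hypothetical: it mimics the method names only, returns sample indices instead of image tensors, and assumes the per-id counts are ordered by id, which may not match the real class.

```python
import random
from collections import Counter


class ToyRetrievalDataset:
    """Hypothetical stand-in exposing the same three methods as RetrievalDataset."""

    def __init__(self, labels, seed=0):
        self.labels = list(labels)
        self.order = list(range(len(self.labels)))
        self.rng = random.Random(seed)

    def __len__(self):
        return len(self.order)

    def __getitem__(self, idx):
        sample_idx = self.order[idx]
        # A real dataset would return an image tensor; here we return the index.
        return sample_idx, self.labels[sample_idx]

    def get_cnt_per_id(self):
        # Number of samples for each id, ordered by id (an assumption of this toy).
        counts = Counter(self.labels)
        return [counts[i] for i in sorted(counts)]

    def reset(self):
        # Reshuffle the sample order for the next epoch.
        self.rng.shuffle(self.order)


dataset = ToyRetrievalDataset(labels=[0, 0, 0, 1, 1, 2])
print(dataset.get_cnt_per_id())
for epoch in range(2):
    seen = [dataset[i] for i in range(len(dataset))]
    # ... a training step on `seen` would go here ...
    dataset.reset()  # reshuffle once the epoch is finished
```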

#### Example:

```python
from torch.utils.data import DataLoader
from XCurve.AUPRC import RetrievalDataset, DefaultInatDatasetCfg

args = DefaultInatDatasetCfg
print(args)
dataset = RetrievalDataset(**args, split='train')
# Key access works whether or not the config also supports attribute access.
dataloader = DataLoader(dataset, batch_size=args['batchsize'], num_workers=4)
print(dataloader.dataset.get_cnt_per_id()[:10])
for i, (img, lbl) in enumerate(dataloader):
    print(img.shape, lbl.shape)
    # do something here
    break
# Reshuffle the data once the epoch is finished.
dataloader.dataset.reset()
```

Output:

```text
{'subset': None, 'dataset_train': 'train', 'dataset_val': 'val', 'dataset_test': 'test', 'data_dir': './data/iNaturalist/images', 'list_dir': './data/iNaturalist/split_list', 'num_sample_per_id': 4, 'inst_blc': True, 'input_size': 256, 'batchsize': 56, 'normal_mean': [0.485, 0.456, 0.406], 'normal_std': [0.229, 0.224, 0.225]}
shuffling data...
train set has 278656 samples per epoch
[26, 30, 24, 19, 22, 13, 16, 15, 23, 19]
torch.Size([56, 3, 256, 256]) torch.Size([56, 1])
shuffling data...
```


### AUPRC Loss

The following function builds an instance of the class that computes the List-stable AUPRC loss:
> XCurve.AUPRC.ListStableAUPRC(tau1=0.1, tau2=0.001, beta=0.001, prior_mul=0.1, num_sample_per_id=4, var_reg_weight_pos=5, var_reg_weight_neg=1) -> SOPRC [\[SOURCE\]](https://github.com/statusrank/XCurve/blob/master/XCurve/AUPRC/losses/__init__.py)

#### Parameters:
- tau1 (float): Controls the surrogate loss of positive-negative pairs. See $\tau_1$ in Eq.(7).
- tau2 (float): Controls the surrogate loss of positive-positive pairs. See $\tau_2$ in Eq.(7).
- beta (float): Controls the exponential moving average. See $\beta$ in Eq.(10).
- prior_mul (float): Imbalance ratio of the id with the most positive examples.
- num_sample_per_id (int): Number of examples for each id.
- var_reg_weight_pos (float): Weight of the variance regularization term w.r.t. positive examples.
- var_reg_weight_neg (float): Weight of the variance regularization term w.r.t. negative examples.

The default configuration is provided as a dictionary in `XCurve.AUPRC.DefaultLossCfg`.
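
As a rough intuition for `beta` only, the generic update below shows one exponential-moving-average step. It assumes the convention in which `beta` weights the newest observation; Eq.(10) is not reproduced here, and the estimator in the library may differ in what it averages and how. `ema_update` is a hypothetical name.

```python
def ema_update(running, observation, beta):
    """One exponential-moving-average step; beta weights the new observation."""
    return (1.0 - beta) * running + beta * observation


running = 0.0
for observation in [1.0, 1.0, 1.0]:
    # With beta=0.5 the estimate moves halfway toward each new observation.
    running = ema_update(running, observation, beta=0.5)
print(running)  # 0.875
```

Under this convention, a small value such as the default 0.001 makes the running estimate change slowly from one step to the next.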

The following class computes the List-stable AUPRC loss; its parameters are the same as above:
> CLASS SOPRC (num_sample_per_id, temp, beta, prior_mul=0.1, **kwargs) [\[SOURCE\]](https://github.com/statusrank/XCurve/blob/master/XCurve/AUPRC/losses/soprc.py)

#### Functions:
- SOPRC.forward(feats, targets) -> torch.Tensor: Compute loss.
- SOPRC.update_cnt_per_id(cnt_per_id) -> None: Update the list describing the number of samples for each id. It must be called before the loss is computed.

#### Example:

```python
import torch
import torch.nn.functional as F
from XCurve.AUPRC import (ListStableAUPRC, DefaultLossCfg,
                          RetrievalDataset, DefaultInatDatasetCfg)

dataset = RetrievalDataset(**DefaultInatDatasetCfg, split='train')
criterion = ListStableAUPRC(**DefaultLossCfg)
# The loss needs the per-id sample counts before its first call.
criterion.update_cnt_per_id(dataset.get_cnt_per_id())
# Random L2-normalized features: 4 ids with 4 samples each.
feats = F.normalize(torch.randn((16, 128)).cuda(), dim=1, p=2)
targets = torch.tensor([5,5,5,5,3,3,3,3,1,1,1,1,9,9,9,9]).cuda()
loss = criterion(feats, targets)
print(loss.item())
```


Output:

```text
shuffling data...
train set has 280056 samples per epoch
1.0337904691696167
```
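
For reference, AUPRC on a ranked list is commonly estimated by average precision (AP). The sketch below computes AP for a single query from similarity scores. It is an illustration of the metric only, not the XCurve metric implementation, and `average_precision` is a hypothetical name.

```python
def average_precision(scores, is_positive):
    """AP of one ranked list: mean precision at the rank of each positive."""
    ranked = sorted(zip(scores, is_positive), key=lambda pair: pair[0], reverse=True)
    hits = 0
    precisions = []
    for rank, (_, positive) in enumerate(ranked, start=1):
        if positive:
            hits += 1
            precisions.append(hits / rank)
    # A query without positives has no defined AP; 0.0 is returned by convention here.
    return sum(precisions) / len(precisions) if precisions else 0.0


# Similarity of four gallery items to one query; items 0 and 2 share its id.
scores = [0.9, 0.8, 0.7, 0.1]
is_positive = [True, False, True, False]
ap = average_precision(scores, is_positive)
print(ap)  # (1/1 + 2/3) / 2
```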
