## How to conduct AUTKC optimization?

This example illustrates how to perform AUTKC (Area Under the Top-K Curve) optimization with the XCurve library.

### Import the optimizer, loss function, dataset, and dataloader

First, we obtain the dataloaders for training and validation with the `get_data_loader` function, whose essential parameters are explained below:
- `dataset_dir`: This parameter specifies the directory of the dataset. We have implemented `cifar-10`, `cifar-100`, `tiny-imagenet-200`, and `place-365`;
- `batch_size`: This parameter specifies the size of each batch for the dataloader;
- `workers`: This parameter specifies the number of workers for the dataloader;
- `train_ratio`: This parameter specifies the ratio of samples used for training; the remaining samples are used for validation.
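To illustrate what `train_ratio` controls, here is a minimal sketch of a train/validation split in plain PyTorch. It uses `torch.utils.data.random_split` on a small random stand-in dataset. `split_train_val` is a hypothetical helper, and `get_data_loader` may split the data differently (for example, with another seeding scheme).

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split


def split_train_val(dataset, train_ratio=0.9, seed=0):
    """Hypothetical helper: split `dataset` into train/val subsets by `train_ratio`."""
    n_train = int(len(dataset) * train_ratio)
    n_val = len(dataset) - n_train
    generator = torch.Generator().manual_seed(seed)  # reproducible split
    return random_split(dataset, [n_train, n_val], generator=generator)


# Random stand-in for an image dataset: 1,000 RGB 32x32 images with 100 classes
full_set = TensorDataset(torch.randn(1000, 3, 32, 32), torch.randint(0, 100, (1000,)))
train_set, val_set = split_train_val(full_set, train_ratio=0.9)

# batch_size and num_workers correspond to the `batch_size` and `workers` parameters above
train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=0)
val_loader = DataLoader(val_set, batch_size=128, shuffle=False, num_workers=0)
print(len(train_set), len(val_set))  # 900 100
```

With `num_workers=0`, batches are loaded in the main process. A setting such as `workers=4` instead lets several subprocesses prepare batches in parallel.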

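```python
import os
from XCurve.AUTKC.dataloaders import get_data_loader

# Adjust dataset_root to the directory where your datasets are stored
dataset_root, dataset = 'D:/dataset', 'cifar-100'
dataset_dir = os.path.join(dataset_root, dataset)
# Returns the training and validation loaders, a third value unused here, and the number of classes
train_loader, val_loader, _, num_class = get_data_loader(dataset_dir, batch_size=128, workers=4, train_ratio=0.9)
```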

```
Files already downloaded and verified
Files already downloaded and verified
```


Then, we build the model, the loss function, and the optimizer. By default, we use the `resnet18` model provided by PyTorch (torchvision). The `StandardAUTKCLoss` module computes the AUTKC loss, and its essential parameters are explained below:
- `surrogate`: This parameter specifies the surrogate loss, whose options include `Sq`, `Exp`, `Logit`, and `Hinge`;
- `K`: This parameter specifies the hyperparameter `K` of the AUTKC loss;
- `epoch_to_paced`: This parameter specifies the number of warm-up epochs for training. By default, the cross-entropy (CE) loss is used as the warm-up loss.
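The idea behind these parameters can be made concrete. Ignoring ties, a label y is in the top-k if and only if its score exceeds the k-th largest score among the other labels. Averaging a surrogate of that comparison over k = 1, ..., K therefore gives a smooth stand-in for 1 − AUTKC@K.

The sketch below implements this reading. It is not XCurve's implementation, and several details are assumptions:
- `StandardAUTKCLoss` may, for example, normalize scores first.
- The surrogate formulas (such as `(1 - t) ** 2` for `Sq`) are assumed.
- Whether the warm-up switches at `epoch < epoch_to_paced` is assumed.
- `autkc_surrogate_loss` and `paced_autkc_loss` are hypothetical names.

```python
import torch
import torch.nn.functional as F

# Assumed surrogates ell(t) of the 0-1 loss 1[t <= 0]; each decreases in the margin t
SURROGATES = {
    'Sq': lambda t: (1 - t) ** 2,
    'Exp': lambda t: torch.exp(-t),
    'Logit': lambda t: F.softplus(-t),  # log(1 + exp(-t)), numerically stable
    'Hinge': lambda t: torch.clamp(1 - t, min=0),
}


def autkc_surrogate_loss(scores, targets, K, surrogate='Sq'):
    """Hypothetical sketch. scores: (N, C) float, targets: (N,) long, with K <= C - 1."""
    ell = SURROGATES[surrogate]
    idx = targets.unsqueeze(1)
    s_y = scores.gather(1, idx)                      # (N, 1) true-label scores
    others = scores.scatter(1, idx, float('-inf'))   # hide the true label from topk
    top_others = others.topk(K, dim=1).values        # (N, K) k-th largest other score, k = 1..K
    return ell(s_y - top_others).mean()              # average over k = 1..K and over samples


def paced_autkc_loss(scores, targets, epoch, K=5, surrogate='Sq', epoch_to_paced=3):
    """Hypothetical warm-up: plain cross-entropy for the first `epoch_to_paced` epochs."""
    if epoch < epoch_to_paced:
        return F.cross_entropy(scores, targets)
    return autkc_surrogate_loss(scores, targets, K, surrogate)


scores = torch.tensor([[2.0, 1.0, 0.5, -1.0]], requires_grad=True)
targets = torch.tensor([0])
loss = paced_autkc_loss(scores, targets, epoch=5, K=2)  # past warm-up, so the AUTKC surrogate is used
loss.backward()
print(round(loss.item(), 4))  # 0.125
```

In this example, the margins against the two strongest competing labels are 1.0 and 1.5. The square surrogate gives (0 + 0.25) / 2 = 0.125. Note that `Sq` still penalizes margins larger than 1, unlike `Hinge`.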

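```python
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from XCurve.AUTKC.losses.AUTKCLoss import StandardAUTKCLoss

# ResNet-18 trained from scratch, with the final layer replaced to output num_class scores
# (newer torchvision versions deprecate `pretrained=False` in favor of `weights=None`)
model = models.resnet18(pretrained=False)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_class)
model = model.cuda()

# AUTKC loss with the square surrogate, K = 5, and 3 warm-up epochs
surrogate, K, epoch_to_paced = 'Sq', 5, 3
criterion = StandardAUTKCLoss(surrogate, K, epoch_to_paced).cuda()

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.001)
```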



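Finally, we train the model for 10 epochs. To this end, we first specify `k_list`, the values of k used for evaluation. In each epoch, the code follows a standard PyTorch training pipeline. The `evaluate` function then computes the Top-k accuracy and the AUTKC performance for each k in `k_list`.

Note that in this short demo the metrics are computed on the outputs of the last training mini-batch of each epoch, not on `val_loader`. In the log, each value is followed by its running average over the epochs so far (in parentheses).

For a more detailed training process, please refer to `example/data/autkc.py` and run `python autkc.py --loss autkc --surrogate Exp --resume checkpoints/***`.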
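Both metrics can be sketched directly from their definitions:
- Top-k accuracy is the fraction of samples whose true label is among the k highest scores.
- AUTKC@k is taken here to be the average of the Top-1 through Top-k accuracies, that is, the area under the top-k accuracy curve up to k.

`topk_and_autkc` is a hypothetical stand-in for `evaluate`. It is assumed to return percentages, as the log suggests. XCurve's implementation, including its tie-breaking, may differ.

```python
import torch


def topk_and_autkc(scores, targets, k_list):
    """Hypothetical sketch: Top-k accuracy and AUTKC@k (in percent) for each k in k_list."""
    max_k = max(k_list)
    top_idx = scores.topk(max_k, dim=1).indices        # (N, max_k) labels sorted by score
    hits = (top_idx == targets.unsqueeze(1)).float()   # 1 at the position of the true label
    in_topk = hits.cumsum(dim=1)                       # in_topk[i, k-1] = 1 iff label i is in its top-k
    accs = [in_topk[:, k - 1].mean().item() * 100 for k in k_list]
    # The mean over the first k columns is the average of the Top-1, ..., Top-k accuracies
    autkcs = [in_topk[:, :k].mean().item() * 100 for k in k_list]
    return accs, autkcs


scores = torch.tensor([[0.1, 0.5, 0.4],
                       [0.9, 0.05, 0.05]])
targets = torch.tensor([2, 0])
acc_vals, autkc_vals = topk_and_autkc(scores, targets, [1, 2])
print(acc_vals, autkc_vals)  # [50.0, 100.0] [50.0, 75.0]
```

The first sample's label ranks second and the second sample's label ranks first. Top-1 is therefore 50%, Top-2 is 100%, and AUTKC@2 = (50 + 100) / 2 = 75%. Since AUTKC@k averages over all ranks up to k, it rewards placing the true label near the top, not merely inside the top-k.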

```python
from XCurve.AUTKC.metrics import evaluate
from XCurve.AUTKC.utils.common_utils import AverageMeter

k_list = [3, 5]
topks = [AverageMeter('Acc@%d' % k, ':6.2f') for k in k_list]
autkcs = [AverageMeter('AUTKC@%d' % k, ':6.2f') for k in k_list]

for epoch in range(10):
    model.train()
    for inputs, targets in train_loader:
        targets = targets.squeeze().cuda(non_blocking=True)
        inputs = inputs.float().cuda(non_blocking=True)
        optimizer.zero_grad()

        outputs = model(inputs).squeeze()
        # Losses with an `epoch_to_paced` warm-up also need the current epoch
        loss = criterion(outputs, targets, epoch) if hasattr(criterion, 'epoch_to_paced') else criterion(outputs, targets)
        loss.backward()
        optimizer.step()

    # Metrics on the last training mini-batch of the epoch (not on val_loader)
    accs, autkc = evaluate(outputs.data, targets, k_list)
    for topk_meter, autkc_meter, acc, au in zip(topks, autkcs, accs, autkc):
        topk_meter.update(acc, inputs.size(0))
        autkc_meter.update(au, inputs.size(0))

    # "val (avg)": current epoch value and running average over the epochs so far
    autkc_str = '  '.join('AUTKC@{} {m.val:.2f} ({m.avg:.2f})'.format(k, m=m) for k, m in zip(k_list, autkcs))
    topks_str = '  '.join('Acc@{} {m.val:.2f} ({m.avg:.2f})'.format(k, m=m) for k, m in zip(k_list, topks))
    print(epoch, autkc_str, topks_str, sep='\t')
```


```
0	AUTKC@3 33.80 (33.80)  AUTKC@5 40.83 (40.83)	Acc@3 40.28 (40.28)  Acc@5 54.17 (54.17)
1	AUTKC@3 43.52 (38.66)  AUTKC@5 50.00 (45.42)	Acc@3 51.39 (45.83)  Acc@5 61.11 (57.64)
2	AUTKC@3 40.28 (39.20)  AUTKC@5 45.83 (45.56)	Acc@3 47.22 (46.30)  Acc@5 55.56 (56.94)
3	AUTKC@3 57.41 (43.75)  AUTKC@5 63.61 (50.07)	Acc@3 68.06 (51.74)  Acc@5 75.00 (61.46)
4	AUTKC@3 57.87 (46.57)  AUTKC@5 62.22 (52.50)	Acc@3 63.89 (54.17)  Acc@5 69.44 (63.06)
5	AUTKC@3 56.02 (48.15)  AUTKC@5 61.94 (54.07)	Acc@3 63.89 (55.79)  Acc@5 72.22 (64.58)
6	AUTKC@3 53.24 (48.88)  AUTKC@5 59.72 (54.88)	Acc@3 63.89 (56.94)  Acc@5 69.44 (65.28)
7	AUTKC@3 62.50 (50.58)  AUTKC@5 67.22 (56.42)	Acc@3 70.83 (58.68)  Acc@5 75.00 (66.49)
8	AUTKC@3 62.96 (51.95)  AUTKC@5 66.39 (57.53)	Acc@3 69.44 (59.88)  Acc@5 72.22 (67.13)
9	AUTKC@3 67.13 (53.47)  AUTKC@5 72.50 (59.03)	Acc@3 75.00 (61.39)  Acc@5 81.94 (68.61)
```
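The log above is computed on the last training mini-batch of each epoch, so it is a noisy estimate of generalization. A minimal sketch of evaluating on the held-out `val_loader` instead follows:
1. Collect the scores for all validation samples without gradients.
2. Pass them to the metric function once.

`collect_outputs` is a hypothetical helper. We also assume that `evaluate` accepts the concatenated CPU tensors. A tiny linear model and random data stand in for the tutorial's objects so that the snippet runs on its own.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def collect_outputs(model, loader, device):
    """Hypothetical helper: run `model` over `loader` and return all scores and labels."""
    model.eval()  # disable dropout and use BatchNorm running statistics
    all_scores, all_targets = [], []
    for inputs, targets in loader:
        all_scores.append(model(inputs.float().to(device)).cpu())
        all_targets.append(targets.reshape(-1))
    return torch.cat(all_scores), torch.cat(all_targets)


# Stand-ins for the tutorial's model and val_loader
device = torch.device('cpu')
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 100)).to(device)
toy_val = TensorDataset(torch.randn(300, 3, 32, 32), torch.randint(0, 100, (300,)))
toy_loader = DataLoader(toy_val, batch_size=128)

scores, labels = collect_outputs(toy_model, toy_loader, device)
print(scores.shape, labels.shape)  # torch.Size([300, 100]) torch.Size([300])
# In the tutorial, one would then call: accs, autkc = evaluate(scores, labels, k_list)
```

The training loop already calls `model.train()` at the start of every epoch. This helper can therefore be called at the end of each epoch without further changes.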
