## How to conduct OpenAUC optimization?

This example illustrates how to perform OpenAUC optimization with the XCurve library.

First, we build the open-set dataset in three steps:

1. Call `get_class_splits` to split the classes of the original dataset into known (training) classes and unknown (open-set) classes. Dataset paths are set in `XCurve.OpenAUC.utils.config.py`, e.g. `svhn_root = 'D:/dataset/svhn'`. A dataset may have several splits (selected with `split_idx`) that treat different classes as unknown; their details are in `XCurve.OpenAUC.dataloaders.open_set_splits`.
2. Call `get_datasets` to build the open-set datasets.
3. Construct dataloaders for the training set, the closed-set test set, and the open-set test set.
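
For intuition, the toy sketch below mimics what an open-set split amounts to: some class ids are treated as known (used for training and the closed-set test set), and the rest are held out as unknown and only appear in the open-set test set. The helper names `make_open_set_split` and `partition`, the 6/4 known/unknown ratio, and the relabeling of known classes to `0..K-1` are illustrative assumptions, not the behavior of `get_class_splits` or `get_datasets`.

```python
import random

def make_open_set_split(num_classes, num_known, seed=0):
    """Hypothetical helper: randomly choose `num_known` classes as known."""
    rng = random.Random(seed)
    classes = list(range(num_classes))
    rng.shuffle(classes)
    known = sorted(classes[:num_known])
    unknown = sorted(classes[num_known:])
    return known, unknown

def partition(samples, known, unknown):
    """Split (x, y) pairs into a known-class set, relabeled to 0..K-1,
    and an unknown-class set that only serves as open-set test data."""
    remap = {c: i for i, c in enumerate(known)}
    known_set = [(x, remap[y]) for x, y in samples if y in remap]
    unknown_ids = set(unknown)
    unknown_set = [(x, y) for x, y in samples if y in unknown_ids]
    return known_set, unknown_set

known, unknown = make_open_set_split(num_classes=10, num_known=6)
samples = [(f"img{i}", i % 10) for i in range(20)]  # two samples per class
known_set, unknown_set = partition(samples, known, unknown)
print(known, unknown)
print(len(known_set), len(unknown_set))  # 12 8
```

Relabeling the known classes keeps the classifier's output layer at exactly `K` units, since unknown classes never receive a training label.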

```python
from XCurve.OpenAUC.dataloaders.open_set_datasets import get_class_splits, get_datasets
from torch.utils.data import DataLoader
import argparse

parser = argparse.ArgumentParser("Training")
parser.add_argument('--dataset', type=str, default='svhn')
parser.add_argument('--split_idx', default=0, type=int, help='0-4 OSR splits for each dataset')
parser.add_argument('--model', type=str, default='classifier32')
parser.add_argument('--image_size', type=int, default=32)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--transform', type=str, default='rand-augment')
parser.add_argument('--num_workers', type=int, default=0)  # use 0 on Windows
parser.add_argument('--label_smoothing', type=float, default=None, help="Smoothing constant for label smoothing.")
parser.add_argument('--temp', type=float, default=1.0, help="Temperature for label smoothing")
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--alpha', type=float, default=2, help="Beta(alpha, alpha) parameter for manifold mixup in the OpenAUC loss")
parser.add_argument('--lamda', type=float, default=0.05, help="Weight balancing the closed-set and open-set losses")
parser.add_argument('--optim', type=str, default=None, help="Which optimizer to use {adam, sgd}")
parser.add_argument('--lr', type=float, default=0.1, help="Learning rate for the model")
parser.add_argument('--weight_decay', type=float, default=1e-4, help="L2 regularisation on weights")
parser.add_argument('--scheduler', type=str, default='cosine_warm_restarts')
parser.add_argument('--epochs', type=int, default=200)
parser.add_argument('--num_restarts', type=int, default=2, help='How many restarts for cosine_warm_restarts schedule')
args, _ = parser.parse_known_args()  # ignore extra arguments passed by the notebook kernel (e.g. in VS Code)

# Split the classes into known (training) and unknown (open-set) classes
args.train_classes, args.open_set_classes = get_class_splits(args.dataset, args.split_idx)

datasets = get_datasets(args.dataset, transform=args.transform, train_classes=args.train_classes,
                        open_set_classes=args.open_set_classes, image_size=args.image_size, seed=args.seed)

# Only the training loader is shuffled
dataloaders = {}
for k, v in datasets.items():
    dataloaders[k] = DataLoader(v, batch_size=args.batch_size, shuffle=(k == 'train'),
                                num_workers=args.num_workers)

trainloader = dataloaders['train']
testloader = dataloaders['val']          # closed-set test set (known classes)
outloader = dataloaders['test_unknown']  # open-set test set (unknown classes)
```

```
Loading datasets...
Using downloaded and verified file: D:/dataset/svhn\train_32x32.mat
Using downloaded and verified file: D:/dataset/svhn\test_32x32.mat
```


Then, we specify the model, the loss function, and the optimizer. To generate open-set samples with the manifold mixup technique, we use a wrapper model from `XCurve.OpenAUC.models.wrapper_classes` that returns the intermediate embedding (the manifold) together with the logits. We also define the loss `StandardOpenAUCLoss`, whose essential parameters are:
- `loss_close`: the loss function for closed-set samples;
- `alpha`: the parameter of the Beta distribution $B(\alpha, \alpha)$ from which the mixing ratio for manifold mixup is drawn;
- `lambd`: the hyper-parameter that balances the closed-set loss and the open-set loss.
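
The roles of `alpha` and `lambd` can be illustrated with a framework-free sketch. It draws a mixing ratio from Beta(α, α), blends the embeddings of randomly paired samples with different labels into synthetic "open-set" embeddings, and combines the two losses as closed-set loss + λ · open-set loss. The helper names, the pairing strategy (shuffling the batch and skipping same-label pairs), drawing one ratio per batch, and the additive form of the objective are all assumptions for illustration; `StandardOpenAUCLoss` may pair samples and weight the terms differently.

```python
import random

def manifold_mixup(embeddings, labels, alpha, seed=0):
    """Hypothetical helper: mix embeddings of randomly paired samples.

    Each synthetic point is r * h_i + (1 - r) * h_j with r ~ Beta(alpha, alpha).
    Pairs that share a label are skipped, since mixing them is unlikely to
    produce a point outside the known classes.
    """
    rng = random.Random(seed)
    r = rng.betavariate(alpha, alpha)  # one mixing ratio per batch (assumption)
    perm = list(range(len(embeddings)))
    rng.shuffle(perm)
    mixed = []
    for i, j in enumerate(perm):
        if labels[i] != labels[j]:
            h_i, h_j = embeddings[i], embeddings[j]
            mixed.append([r * a + (1 - r) * b for a, b in zip(h_i, h_j)])
    return r, mixed

def total_loss(loss_close, loss_open, lambd):
    """Hypothetical objective: closed-set loss plus a lambd-weighted open-set term."""
    return loss_close + lambd * loss_open

emb = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]
lab = [0, 1, 2, 0]
r, fake_open = manifold_mixup(emb, lab, alpha=2.0)
print(round(r, 3), len(fake_open))
print(total_loss(1.2, 0.4, lambd=0.05))
```

A larger `alpha` concentrates the ratio around 0.5, so the synthetic samples sit roughly halfway between two known classes; a smaller `lambd` keeps the closed-set term dominant.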

```python
from XCurve.OpenAUC.utils.model_utils import get_model
from XCurve.OpenAUC.optimizers import get_optimizer, get_scheduler
from XCurve.OpenAUC.models.wrapper_classes import Classifier32Wrapper
from XCurve.OpenAUC.losses.OpenAUCLoss import StandardOpenAUCLoss
from XCurve.OpenAUC.losses.Softmax import Softmax

# Backbone wrapped so that the forward pass can also return the embedding
net = get_model(args, wrapper_class=Classifier32Wrapper)

# Closed-set loss plus the mixup-based open-set term
criterion = StandardOpenAUCLoss(
    loss_close=Softmax(temp=args.temp, label_smoothing=args.label_smoothing),
    alpha=args.alpha,
    lambd=args.lamda,
)

# Optimize the network and any parameters held by the loss together
params_list = [{'params': net.parameters()}, {'params': criterion.parameters()}]
optimizer = get_optimizer(args=args, params_list=params_list, dataset=args.dataset)
scheduler = get_scheduler(optimizer, args)
```
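
The defaults select `scheduler='cosine_warm_restarts'` with `num_restarts=2` over 200 epochs. As a rough illustration of what such a schedule does, the sketch below evaluates the standard SGDR cosine formula over equal-length cycles. Splitting the run into `num_restarts + 1` cycles and using `eta_min = 0` are assumptions; `get_scheduler` may configure the cycle lengths differently.

```python
import math

def cosine_warm_restarts_lr(epoch, base_lr, total_epochs, num_restarts, eta_min=0.0):
    """Cosine annealing with warm restarts (SGDR) over equal-length cycles.

    Assumption: the run is split into num_restarts + 1 cycles of equal length.
    """
    cycle_len = total_epochs / (num_restarts + 1)
    t_cur = epoch % cycle_len  # epochs elapsed since the last restart
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * t_cur / cycle_len))

for epoch in (0, 33, 66, 67, 199):
    print(epoch, round(cosine_warm_restarts_lr(epoch, base_lr=0.1, total_epochs=200, num_restarts=2), 4))
```

Within each cycle the learning rate decays from `lr` towards `eta_min`, then jumps back to `lr` at the next restart.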

Finally, we train the model for 200 epochs. The training phase follows the standard PyTorch training loop, except that the OpenAUC loss also requires the classifier head $f_\text{post}$ (passed as `net.net.fc`) and the embeddings (manifolds) returned by the wrapper model. In the testing phase, the classes `EnsembleModel` and `OpenSetEvaluator` facilitate evaluation: they turn the closed-set model into an open-set classifier and report the closed-set accuracy (Acc), AUROC, and OpenAUC.

For a more complete training script, please refer to `example/data/openauc.py`.
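
The three printed metrics can be reproduced on toy scores in plain Python. This is a minimal sketch assuming the pairwise definition of OpenAUC: a (known, unknown) test pair counts only if the known sample is correctly classified *and* the unknown sample receives a higher open-set score than the known one; AUROC drops the classification condition. Here a higher score means "more likely unknown" and ties count as half a pair; both are conventions chosen for this sketch, and `OpenSetEvaluator` may score samples differently.

```python
def pair_score(s_unknown, s_known):
    """1 if the unknown sample is ranked above the known one, 0.5 on ties."""
    if s_unknown > s_known:
        return 1.0
    if s_unknown == s_known:
        return 0.5
    return 0.0

def open_set_metrics(known_scores, known_correct, unknown_scores):
    """Return (Acc, AUROC, OpenAUC).

    known_scores   -- open-set scores of known-class test samples (higher = more likely unknown)
    known_correct  -- whether each known sample was classified correctly
    unknown_scores -- open-set scores of unknown-class test samples
    """
    n_k, n_u = len(known_scores), len(unknown_scores)
    acc = sum(known_correct) / n_k
    auroc = openauc = 0.0
    for s_k, ok in zip(known_scores, known_correct):
        for s_u in unknown_scores:
            p = pair_score(s_u, s_k)
            auroc += p
            if ok:  # OpenAUC only credits pairs whose known sample is correct
                openauc += p
    return acc, auroc / (n_k * n_u), openauc / (n_k * n_u)

acc, auroc, openauc = open_set_metrics(
    known_scores=[0.1, 0.4, 0.35, 0.8],
    known_correct=[True, True, False, True],
    unknown_scores=[0.9, 0.5, 0.3],
)
print(acc, round(auroc, 3), round(openauc, 3))  # 0.75 0.667 0.5
```

Because a pair can only count towards OpenAUC when its known sample is correctly classified, OpenAUC never exceeds Acc or AUROC; every epoch in the training log of this example satisfies this bound.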

```python
import torch
from tqdm import tqdm
from XCurve.OpenAUC.utils.common_utils import AverageMeter
from XCurve.OpenAUC.metrics import OpenSetEvaluator, EnsembleModel

def train(net, criterion, optimizer, trainloader):
    net.train()
    losses = AverageMeter()
    torch.cuda.empty_cache()

    for data, labels, _ in tqdm(trainloader):
        data, labels = data.cuda(), labels.cuda()
        optimizer.zero_grad()
        # The wrapper returns the embedding (manifold) as well as the logits
        embedding, logits = net(data, True)
        # The OpenAUC loss also takes the classifier head (f_post) and the embeddings
        _, loss = criterion(logits, labels, net.net.fc, embedding)
        loss.backward()
        optimizer.step()

        losses.update(loss.item(), data.size(0))

    return losses.avg  # average training loss over the epoch

def test(net, testloader, outloader):
    # Turn the closed-set classifier into an open-set model for evaluation
    model = EnsembleModel(net).cuda()
    model.eval()

    evaluator = OpenSetEvaluator(model=model, known_data_loader=testloader, unknown_data_loader=outloader)
    preds = evaluator.predict(save=False)
    results = evaluator.evaluate(evaluator, load=False, preds=preds)

    return results


for epoch in range(args.epochs):
    print("==> Epoch {}/{}".format(epoch + 1, args.epochs), end='\t')

    train(net, criterion, optimizer, trainloader)

    print("==> Test", end='\t')
    results = test(net, testloader, outloader)
    print("Acc:{:.3f}\tAUROC:{:.3f}\tOpenAUC:{:.3f}".format(results['Acc'], results['AUROC'], results['OpenAUC']))

    scheduler.step(epoch=epoch)
```

The run shown here was stopped manually with a `KeyboardInterrupt` during epoch 22. Each epoch took about 17 to 18 s for 265 training iterations, plus about 2 s each for the 67 closed-set and 86 open-set test batches. The metrics printed after each epoch were:

| Epoch | Acc | AUROC | OpenAUC |
|---:|---:|---:|---:|
| 1 | 0.250 | 0.508 | 0.127 |
| 2 | 0.250 | 0.479 | 0.124 |
| 3 | 0.250 | 0.416 | 0.106 |
| 4 | 0.250 | 0.420 | 0.109 |
| 5 | 0.250 | 0.500 | 0.127 |
| 6 | 0.250 | 0.455 | 0.120 |
| 7 | 0.250 | 0.442 | 0.119 |
| 8 | 0.355 | 0.602 | 0.242 |
| 9 | 0.537 | 0.628 | 0.359 |
| 10 | 0.769 | 0.766 | 0.638 |
| 11 | 0.799 | 0.728 | 0.632 |
| 12 | 0.866 | 0.739 | 0.682 |
| 13 | 0.868 | 0.707 | 0.653 |
| 14 | 0.888 | 0.666 | 0.628 |
| 15 | 0.883 | 0.813 | 0.751 |
| 16 | 0.903 | 0.792 | 0.751 |
| 17 | 0.914 | 0.760 | 0.727 |
| 18 | 0.914 | 0.805 | 0.768 |
| 19 | 0.912 | 0.752 | 0.721 |
| 20 | 0.917 | 0.734 | 0.698 |
| 21 | 0.927 | 0.816 | 0.786 |