## Towards Accurate Open-Set Recognition via Background-Class Regularization
----
**Wonwoo Cho and Jaegul Choo** In European Conference on Computer Vision (ECCV), 2022

This notebook provides sample training and inference processes of our proposed method in a sample text classification experiment.

In [1]:
import numpy as np

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from torch.autograd import Variable as V
import torchtext

from torchtext.legacy import data
from torchtext.legacy import datasets

## Load datasets

In [3]:
TEXT = data.Field(pad_first=True, lower=True, fix_length=100)
LABEL = data.Field(sequential=False)

train = data.TabularDataset(path='./data/20newsgroups/20ng-train.txt',
                                 format='csv',
                                 fields=[('label', LABEL), ('text', TEXT)])

test = data.TabularDataset(path='./data/20newsgroups/20ng-test.txt',
                                 format='csv',
                                 fields=[('label', LABEL), ('text', TEXT)])

TEXT.build_vocab(train, max_size=10000)
LABEL.build_vocab(train, max_size=10000)

train_iter = data.BucketIterator(train, batch_size=64, repeat=False)
test_iter = data.BucketIterator(test, batch_size=64, repeat=False)

In [4]:
TEXT_custom = data.Field(pad_first=True, lower=True)

custom_data = data.TabularDataset(path='./data/wikitext_reformatted/wikitext103_sentences',
                                  format='csv',
                                  fields=[('text', TEXT_custom)])

TEXT_custom.build_vocab(train.text, max_size=10000)
print('vocab length (including special tokens):', len(TEXT_custom.vocab))

train_iter_oe = data.BucketIterator(custom_data, batch_size=64, repeat=False)

vocab length (including special tokens): 10002


## Initializing model and class-wise anchors

In [68]:
class ClfGRU(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.embedding = nn.Embedding(len(TEXT.vocab), 50, padding_idx=1)
        self.gru = nn.GRU(input_size=50, hidden_size=128, num_layers=2,
            bias=True, batch_first=True,bidirectional=False)
        self.linear = nn.Linear(128, num_classes)
        
    def initialize(self, means):
        self.register_buffer("means", nn.Parameter(means, requires_grad=True))

    def forward(self, x):
        embeds = self.embedding(x)
        hidden = self.gru(embeds)[1][1]
        logits = self.linear(hidden)
        return hidden, logits

means = torch.randn((20, 128), requires_grad=True).cuda()
model = ClfGRU(20).cuda()
model.initialize(means)

In [69]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
optimizer2 = torch.optim.Adam([model.means], lr=0.001)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=15)
scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer2, T_max=15)

## Training function

In [70]:
def train_f(model):
    model.train()
    data_loss_ema = 0
    oe_loss_ema = 0

    for batch_idx, (batch, batch_oe) in enumerate(zip(iter(train_iter), iter(train_iter_oe))):
        inputs = batch.text.t().cuda()
        labels = (batch.label - 1).cuda().long()
        feature, logits = model(inputs)

        inputs_oe = batch_oe.text.t().cuda()
        feature_oe, logits_oe = model(inputs_oe)
        labels_oe = (torch.ones(inputs_oe.shape[0]) * 20).cuda().long()

        feat = torch.cat((feature, feature_oe))
        indlen = feature.shape[0]


        distance = torch.norm(feat.unsqueeze(1) - model.means, p=2, dim=2)**2
        scores = torch.igammac(torch.tensor(feat.shape[-1]/2).cuda(), distance/2)

        _, index = torch.max(scores, dim=1)
        max_scores = scores * F.one_hot(index, num_classes=20)
        mod_scores = torch.cat((max_scores, 1-torch.sum(max_scores, dim=1).unsqueeze(1)),1)

        mod_scores = torch.clamp(mod_scores, 1e-31, 1 - 1e-31)
        oe_loss = F.nll_loss(torch.log(mod_scores), torch.cat((labels, labels_oe), dim=0))

        data_loss = F.cross_entropy(-distance[:indlen]/2, labels)

        loss = data_loss + 5 * oe_loss

        optimizer.zero_grad()
        optimizer2.zero_grad()
        loss.backward()
        
        optimizer.step()
        optimizer2.step()

        data_loss_ema = data_loss_ema * 0.9 + data_loss.data.cpu().numpy() * 0.1
        oe_loss_ema = oe_loss_ema * 0.9 + oe_loss.data.cpu().numpy() * 0.1

        if (batch_idx % 200 == 0 or batch_idx < 10):
            print('iter: {} \t| data_loss_ema: {} \t| oe_loss_ema: {}'.format(
                batch_idx, data_loss_ema, oe_loss_ema))

    scheduler.step()
    scheduler2.step()

## Evaluation function

In [71]:
def evaluate(model):
    model.eval()
    running_loss = 0
    num_examples = 0
    correct = 0

    for batch_idx, batch in enumerate(iter(test_iter)):
        inputs = batch.text.t().cuda()
        labels = (batch.label - 1).cuda()

        feature, logits = model(inputs)
        
        distance = torch.norm(feature.unsqueeze(1) - model.means, p=2, dim=2)**2
        dist_preds = torch.argmax(-distance/2, dim=1)
            
        acc_ind = (dist_preds == labels)
        correct += acc_ind.sum().data.cpu().numpy()
        num_examples += inputs.shape[0]

    acc = correct / num_examples
    loss = running_loss / num_examples

    return acc, loss

## Train and Validation

In [73]:
print('test acc: {} \t| test loss: {}\n'.format(acc, loss))
for epoch in range(15):
    print('Epoch', epoch)
    train_f(model)
    acc, loss = evaluate(model)
    print('test acc: {} \t| test loss: {}\n'.format(acc, loss))

test acc: 0.05167375132837407 	| test loss: 0.0

Epoch 0
iter: 0 	| data_loss_ema: 1.3363681793212892 	| oe_loss_ema: 3.5393924713134766
iter: 1 	| data_loss_ema: 2.379087295532227 	| oe_loss_ema: 6.45477466583252
iter: 2 	| data_loss_ema: 3.059560829162598 	| oe_loss_ema: 9.216882724761964
iter: 3 	| data_loss_ema: 3.8100886825561524 	| oe_loss_ema: 11.755770700454715
iter: 4 	| data_loss_ema: 4.389196296081543 	| oe_loss_ema: 13.877079693031314
iter: 5 	| data_loss_ema: 4.707281022857666 	| oe_loss_ema: 15.895047993259434
iter: 6 	| data_loss_ema: 5.024770608178711 	| oe_loss_ema: 17.659764919397357
iter: 7 	| data_loss_ema: 5.378938082425538 	| oe_loss_ema: 19.35660581027012
iter: 8 	| data_loss_ema: 5.5537258537728285 	| oe_loss_ema: 20.93909021825678
iter: 9 	| data_loss_ema: 5.730370165597083 	| oe_loss_ema: 22.201275556782665
test acc: 0.6065356004250797 	| test loss: 0.0

Epoch 1
iter: 0 	| data_loss_ema: 0.1047799825668335 	| oe_loss_ema: 1.0684310913085937
iter: 1 	| data_los

## Load Unknown-Class Datasets

In [44]:
TEXT_20ng = data.Field(pad_first=True, lower=True, fix_length=100)
LABEL_20ng = data.Field(sequential=False)

train_20ng = data.TabularDataset(path='./data/20newsgroups/20ng-train.txt',
                                 format='csv',
                                 fields=[('label', LABEL_20ng), ('text', TEXT_20ng)])

test_20ng = data.TabularDataset(path='./data/20newsgroups/20ng-test.txt',
                                 format='csv',
                                 fields=[('label', LABEL_20ng), ('text', TEXT_20ng)])

TEXT_20ng.build_vocab(train_20ng, max_size=10000)
LABEL_20ng.build_vocab(train_20ng, max_size=10000)
print('vocab length (including special tokens):', len(TEXT_20ng.vocab))

train_iter_20ng = data.BucketIterator(train_20ng, batch_size=64, repeat=False)
test_iter_20ng = data.BucketIterator(test_20ng, batch_size=64, repeat=False)

vocab length (including special tokens): 10002


In [27]:
TEXT_m30k = data.Field(pad_first=True, lower=True)

m30k_data = data.TabularDataset(path='./data/multi30k/train.txt',
                                  format='csv',
                                  fields=[('text', TEXT_m30k)])

TEXT_m30k.build_vocab(train_20ng.text, max_size=10000)
print('vocab length (including special tokens):', len(TEXT_m30k.vocab))

train_iter_m30k = data.BucketIterator(m30k_data, batch_size=64, repeat=False)

vocab length (including special tokens): 10002


In [12]:
TEXT_wmt16 = data.Field(pad_first=True, lower=True)

wmt16_data = data.TabularDataset(path='./data/wmt16/wmt16_sentences',
                                  format='csv',
                                  fields=[('text', TEXT_wmt16)])

TEXT_wmt16.build_vocab(train_20ng.text, max_size=10000)
print('vocab length (including special tokens):', len(TEXT_wmt16.vocab))

train_iter_wmt16 = data.BucketIterator(wmt16_data, batch_size=64, repeat=False)

vocab length (including special tokens): 10002


In [13]:
# set up fields
TEXT_imdb = data.Field(pad_first=True, lower=True)
LABEL_imdb = data.Field(sequential=False)

# make splits for data
train_imdb, test_imdb = datasets.IMDB.splits(TEXT_imdb, LABEL_imdb)

# build vocab
TEXT_imdb.build_vocab(train_20ng.text, max_size=10000)
LABEL_imdb.build_vocab(train_imdb, max_size=10000)
print('vocab length (including special tokens):', len(TEXT_imdb.vocab))

vocab length (including special tokens): 10002


In [14]:
# make iterators
train_iter_imdb, test_iter_imdb = data.BucketIterator.splits(
    (train_imdb, test_imdb), batch_size=64, repeat=False)

## FPR95 Measure

In [45]:
def stable_cumsum(arr, rtol=1e-05, atol=1e-08):
    """Use high precision for cumsum and check that final value matches sum
    Parameters
    ----------
    arr : array-like
        To be cumulatively summed as flat
    rtol : float
        Relative tolerance, see ``np.allclose``
    atol : float
        Absolute tolerance, see ``np.allclose``
    """
    out = np.cumsum(arr, dtype=np.float64)
    expected = np.sum(arr, dtype=np.float64)
    if not np.allclose(out[-1], expected, rtol=rtol, atol=atol):
        raise RuntimeError('cumsum was found to be unstable: '
                           'its last element does not correspond to sum')
    return out

def fpr_and_fdr_at_recall(y_true, y_score, recall_level=0.95, pos_label=None):
    classes = np.unique(y_true)
    if (pos_label is None and
            not (np.array_equal(classes, [0, 1]) or
                     np.array_equal(classes, [-1, 1]) or
                     np.array_equal(classes, [0]) or
                     np.array_equal(classes, [-1]) or
                     np.array_equal(classes, [1]))):
        raise ValueError("Data is not binary and pos_label is not specified")
    elif pos_label is None:
        pos_label = 1.

    # make y_true a boolean vector
    y_true = (y_true == pos_label)

    # sort scores and corresponding truth values
    desc_score_indices = np.argsort(y_score, kind="mergesort")[::-1]
    y_score = y_score[desc_score_indices]
    y_true = y_true[desc_score_indices]

    # y_score typically has many tied values. Here we extract
    # the indices associated with the distinct values. We also
    # concatenate a value for the end of the curve.
    distinct_value_indices = np.where(np.diff(y_score))[0]
    threshold_idxs = np.r_[distinct_value_indices, y_true.size - 1]

    # accumulate the true positives with decreasing threshold
    tps = stable_cumsum(y_true)[threshold_idxs]
    fps = 1 + threshold_idxs - tps      # add one because of zero-based indexing

    thresholds = y_score[threshold_idxs]

    recall = tps / tps[-1]

    last_ind = tps.searchsorted(tps[-1])
    sl = slice(last_ind, None, -1)      # [last_ind::-1]
    recall, fps, tps, thresholds = np.r_[recall[sl], 1], np.r_[fps[sl], 0], np.r_[tps[sl], 0], thresholds[sl]

    cutoff = np.argmin(np.abs(recall - recall_level))

    return fps[cutoff] / (np.sum(np.logical_not(y_true)))

## OSR Evaluation

In [74]:
score = []
true = []
y_true = []
y_pred = []

model.eval()
with torch.no_grad():
    for batch_idx, batch in enumerate(iter(test_iter)):
        inputs = batch.text.t().cuda()
        labels = (batch.label - 1).cuda()

        feature, logits = model(inputs)

        distance = torch.norm(feature.unsqueeze(1) - model.means, p=2, dim=2)**2       
        min_dist, _ = torch.min(distance, dim=1)
            
        logits_all = -distance/2
        preds_all = torch.argmax(logits_all, dim=1)
            
        y_true = np.append(y_true, labels.detach().cpu().numpy())
        y_pred = np.append(y_pred, preds_all.detach().cpu().numpy())

        score = np.append(score, min_dist.detach().cpu().numpy())
        true = np.append(true, np.zeros_like(min_dist.detach().cpu().numpy()))
        
        if batch_idx >= 116:
            break
            
    for batch_idx, batch in enumerate(iter(train_iter_imdb)):
        inputs = batch.text.t().cuda()
        feature, logits = model(inputs)

        distance = torch.norm(feature.unsqueeze(1) - model.means, p=2, dim=2)**2       
        min_dist, _ = torch.min(distance, dim=1)
            
        logits_all = -distance/2
        preds_all = torch.argmax(logits_all, dim=1)
        
        y_true = np.append(y_true, labels.detach().cpu().numpy())
        y_pred = np.append(y_pred, preds_all.detach().cpu().numpy())
        
        score = np.append(score, min_dist.detach().cpu().numpy())
        true = np.append(true, np.ones_like(min_dist.detach().cpu().numpy()))
        
        if batch_idx >= 116:
            break

In [80]:
from sklearn.metrics import roc_auc_score, average_precision_score
print("AUROC:", roc_auc_score(true, score))
print("AUPR:", average_precision_score(true, score))
print("FPR95:",fpr_and_fdr_at_recall(true, score))

AUROC: 0.9990908282834529
AUPR: 0.9990259424232494
FPR95: 0.00334941050375134


In [77]:
from sklearn import metrics
threshold = np.percentile(score[64*117:], 10)
print(threshold)

(score[64*117:] < threshold).sum() / len(score[64*117:])

182.2787353515625


0.10008038585209003

In [78]:
import copy
index = np.where(score > threshold)
y_pred2 = copy.deepcopy(y_pred)
y_pred2[index] = 20

print((y_true[:64*117] == y_pred2[:64*117]).sum()/(64*117))

0.750801282051282
