# Deep One Class Classification (DOC)

## Imports

In [1]:
import os
import time
import torch
import logging
import torchvision
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score

import torchvision.transforms as transforms
from torch.utils.data.sampler import BatchSampler


import warnings
warnings.filterwarnings('ignore')

device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda')

In [2]:
print(device)

cuda


### For code optimization

In [3]:
# uncomment for code optimization
#torch.backends.cudnn.benchmark = True

NUM_WORKERS = int(os.cpu_count() / 2)
print('number of workers: ', NUM_WORKERS)

number of workers:  6


## 1. Helper Functions

### 1a. Evaluation (Loss, Eval Metrics)

In [4]:
# Loss Function
class HingeLoss(nn.Module):
    def __init__(self):
        super(HingeLoss, self).__init__()
    
    def forward(self, ypred, ytrue, margin = 1.0, smooth = False):
        ypred = ypred.squeeze()
        if smooth:
            loss = torch.nn.Softplus()
            out = torch.mean(loss(margin - (ytrue * ypred)))
        else:
            out = torch.mean(torch.relu(margin - (ytrue * ypred)))
        return out


def evalmetrics(y_true, scores):
    auc_score = roc_auc_score(y_true, scores)
    
    return auc_score


### 1b. Data Set Loaders

In [5]:
class BalancedBatchSampler(BatchSampler): 
    """
    BatchSampler - from a MNIST-like dataset, samples n_classes and within these classes samples n_samples.
    Returns batches of size n_classes * n_samples
    # CODE SOURCE : https://discuss.pytorch.org/t/load-the-same-number-of-data-per-class/65198/4
    # LICENSE: NOT AVAILABLE!
    """
    def __init__(self, dataset, n_classes, n_samples, class_id = None, allSamples = False):
        loader = torch.utils.data.DataLoader(dataset)
        self.labels_list = []
        
        for _, label in loader:
            self.labels_list.append(label)
        
        self.labels = torch.LongTensor(self.labels_list)
        self.labels_set = list(set(self.labels.numpy()))
        print(self.labels_set)
        
        self.label_to_indices = {label: np.where(self.labels.numpy() == label)[0]
                                 for label in self.labels_set}
        
        for l in self.labels_set:
            np.random.shuffle(self.label_to_indices[l])
            
        self.used_label_indices_count = {label: 0 for label in self.labels_set}
        self.count = 0
        self.n_classes = n_classes
        self.n_samples = n_samples
        self.dataset = dataset
        self.batch_size = self.n_samples * self.n_classes
        self.class_id = class_id
        self.allSamples = allSamples
 
    def __iter__(self):
        self.count = 0
        while self.count + self.batch_size < len(self.dataset):
            
            if self.class_id is not None:
                classes = self.class_id
            else:
                classes = np.random.choice(self.labels_set, self.n_classes, replace=False)
            
            indices = []
            for class_ in classes:
                if not self.allSamples: 
                    indices.extend(self.label_to_indices[class_][
                                   self.used_label_indices_count[class_]:self.used_label_indices_count[
                                                                        class_] + self.n_samples])
                else: 
                    indices.extend(self.label_to_indices[class_])
                
                self.used_label_indices_count[class_] += self.n_samples
                
                if self.used_label_indices_count[class_] + self.n_samples > len(self.label_to_indices[class_]):
                    np.random.shuffle(self.label_to_indices[class_])
                    self.used_label_indices_count[class_] = 0
                    
            yield indices
            
            self.count += self.n_classes * self.n_samples

    def __len__(self):
        return len(self.dataset) // self.batch_size

## 2. Models

In [6]:
class CIFAR10_LeNet(nn.Module):
    def __init__(self):
        super(CIFAR10_LeNet, self).__init__()

        self.rep_dim = 128
        self.pool = nn.MaxPool2d(2, 2)

        self.conv1 = nn.Conv2d(3, 32, 5, bias=False, padding=2)
        self.bn2d1 = nn.BatchNorm2d(32, eps=1e-04, affine=False)
        self.conv2 = nn.Conv2d(32, 64, 5, bias=False, padding=2)
        self.bn2d2 = nn.BatchNorm2d(64, eps=1e-04, affine=False)
        self.conv3 = nn.Conv2d(64, 128, 5, bias=False, padding=2)
        self.bn2d3 = nn.BatchNorm2d(128, eps=1e-04, affine=False)
        self.fc1 = nn.Linear(128 * 4 * 4, self.rep_dim, bias=False)
        self.fc2 = nn.Linear(self.rep_dim, int(self.rep_dim/2), bias=False)
        self.fc3 = nn.Linear(int(self.rep_dim/2), 1, bias=False)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool(F.leaky_relu(self.bn2d1(x)))
        x = self.conv2(x)
        x = self.pool(F.leaky_relu(self.bn2d2(x)))
        x = self.conv3(x)
        x = self.pool(F.leaky_relu(self.bn2d3(x)))
        x = x.view(x.size(0), -1)
        x = F.leaky_relu(self.fc1(x))
        x = F.leaky_relu(self.fc2(x))
        x = self.fc3(x)
        return x

## 3. Training Schedule

### All 10 Classes

In [7]:
# number of repetitions per class
# Results reported in paper are the summary statistics over 10 repetitions (i.e., n_reps = 10)
n_reps = 1

logging.basicConfig(filename='CIFAR10_DOC.log', level=logging.INFO)

# For every class in CIFAR10 dataset
for classNum in range(10):
    normalclass = classNum
    BATCH_SIZE = 256

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])


    ###############################################################################################
    #                                     Train Data
    ###############################################################################################
    cifar_train =  torchvision.datasets.CIFAR10(root="./cifar/cifar_train", train=True, download=True, 
                                              transform=transform)

    balanced_batch_sampler = BalancedBatchSampler(cifar_train, 10, n_samples = BATCH_SIZE, class_id = [normalclass])
    train_loader = torch.utils.data.DataLoader(cifar_train, batch_sampler=balanced_batch_sampler, pin_memory=True, num_workers=NUM_WORKERS)


    ###############################################################################################
    #                                     Evaluations on :-
    ###############################################################################################
    # Test Data
    test_dat = torchvision.datasets.CIFAR10('./cifar', train=False, download=True, transform=transform)
    test_loader = torch.utils.data.DataLoader(dataset=test_dat, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True, num_workers=NUM_WORKERS)

    # Train (used to Evaluate i.e. TP_rate etc.)
    balanced_batch_sampler_eval = BalancedBatchSampler(cifar_train, 10, n_samples = BATCH_SIZE, class_id = [normalclass],allSamples=True)
    train_loader_eval = torch.utils.data.DataLoader(cifar_train, batch_sampler=balanced_batch_sampler_eval, pin_memory=True, num_workers=NUM_WORKERS)
    
    # See Paper's Appendix for the Optimal Parameters
    # K = 0, lam = 0.5, EPOCH = 300, SGD, lr = 0.005 
    # K = 1, lam = 1.0, EPOCH = 300, SGD, lr = 0.001
    # K = 2, lam = 0.5, Epoch = 300, SGD, lr = 0.005
    # K = 3, lam = 1.0, Epoch = 300, SGD,lr = 0.001
    # K = 4, lam = 1.0, Epoch = 300, SGD,lr = 0.001
    # K = 5, lam = 0.5, Epoch = 400, SGD,lr = 0.001
    # K = 6, lam = 0.5, Epoch = 300, SGD,lr = 0.001
    # K = 7, lam = 1.0, Epoch = 300, SGD,lr = 0.001
    # K = 8, lam = 0.5, Epoch = 300, SGD,lr = 0.001
    # K = 9, lam = 0.001, Epoch = 50, SGD,lr = 0.001

    if classNum in [1,3,4,7]:
        lam = 1.0
    elif classNum == 9:
        lam = 0.001
    else:
        lam = 0.5
        
    if classNum == 5:
        EPOCH = 400
    elif classNum == 9:
        EPOCH = 50
    else:
        EPOCH = 300
        
    if classNum == 0 or classNum == 2:
        lr = 0.005
    else:
        lr = 0.001
    
    print('----------------------------------------------')
    print('Class: '+str(classNum)+' lam: '+str(lam)+' # epochs: '+str(EPOCH)+' lr: '+str(lr))
    logging.info('----------------------------------------------')
    logging.info('Class: '+str(classNum)+' lam: '+str(lam)+' # epochs: '+str(EPOCH)+' lr: '+str(lr))
    
    for repetition in range(n_reps):
        print('**********************************')
        print('repetition number: '+str(repetition))
        logging.info('**********************************')
        logging.info('repetition number: '+str(repetition))

        net = CIFAR10_LeNet()
        net = net.to(device)

        optimizer = torch.optim.SGD(net.parameters(), lr = lr)  
        train_loss = HingeLoss()

        for epoch in range(EPOCH):

            loss_epoch = 0.0
            n_batches = 0
            epoch_start_time = time.time()

            for (i,data) in enumerate(train_loader):
                inputs, labels = data
                inputs = inputs.to(device)
                ytr = np.ones((len(inputs),1))
                ytr = torch.from_numpy(ytr).to(device).float()

                # Zero the network parameter gradients
                optimizer.zero_grad()

                # Update network parameters via backpropagation: forward + backward + optimize
                outputs = net(inputs) 
                lam = torch.tensor(lam).to(device)
                l2_reg = torch.tensor(0.).to(device)

                for name, param in net.named_parameters():
                    l2_reg += torch.norm(param)**2


                loss = train_loss(outputs,ytr)
                loss += lam * l2_reg

                loss.backward()
                optimizer.step()

                loss_epoch += loss.item()
                n_batches += 1

            if epoch % 50 == 0: # Every 50 Epochs!
                idx_label_score = []
                net.eval()
                with torch.no_grad():

                    for data in test_loader:
                        inputs, labels = data

                        labels = labels.cpu().data.numpy()
                        labels = (labels==normalclass).astype(int)

                        inputs = inputs.to(device)
                        outputs = net(inputs)
                        scores = outputs.data.cpu().numpy()
                        idx_label_score += list(zip(labels.tolist(),
                                                    scores.tolist()))


                tstlabels, scores = zip(*idx_label_score)
                tstlabels = np.array(tstlabels)
                tstlabels[np.where(tstlabels==0)]=-1.0
                scores = np.array(scores)
                Tst_auc_score = evalmetrics(tstlabels,scores.flatten())   

                # log epoch statistics
                epoch_train_time = time.time() - epoch_start_time
                print('  Epoch {}/{}\t Time: {:.3f}\t Loss: {:.8f} \t Test AUC :{:.4f} '
                            .format(epoch + 1, EPOCH, epoch_train_time, loss_epoch / n_batches, Tst_auc_score))
                logging.info('  Epoch {}/{}\t Time: {:.3f}\t Loss: {:.8f} \t Test AUC :{:.4f} '
                            .format(epoch + 1, EPOCH, epoch_train_time, loss_epoch / n_batches, Tst_auc_score))

        # Final Solution
        print(' Final (TOT. EPOCHS {})::  Loss: {:.8f} \t   AUC (Test) :{:.4f} '
                .format(EPOCH, loss_epoch / n_batches, Tst_auc_score))
        logging.info(' Final (TOT. EPOCHS {})::  Loss: {:.8f} \t   AUC (Test) :{:.4f} '
                .format(EPOCH, loss_epoch / n_batches, Tst_auc_score))

Files already downloaded and verified
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Files already downloaded and verified
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
----------------------------------------------
Class: 0 lam: 0.5 # epochs: 300 lr: 0.005
**********************************
repetition number: 0
  Epoch 1/300	 Time: 3.785	 Loss: 64.07535754 	 Test AUC :0.4961 
  Epoch 51/300	 Time: 1.004	 Loss: 0.17551526 	 Test AUC :0.7663 
  Epoch 101/300	 Time: 0.918	 Loss: 0.17170373 	 Test AUC :0.7697 
  Epoch 151/300	 Time: 0.911	 Loss: 0.17140419 	 Test AUC :0.7783 
  Epoch 201/300	 Time: 1.047	 Loss: 0.17258948 	 Test AUC :0.7727 
  Epoch 251/300	 Time: 0.955	 Loss: 0.17273657 	 Test AUC :0.7778 
 Final (TOT. EPOCHS 300)::  Loss: 0.17065016 	   AUC (Test) :0.7778 
Files already downloaded and verified
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Files already downloaded and verified
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
----------------------------------------------
Class: 1 lam: 1.0 # epochs: 300 lr: 0.001
*********************