<a href="https://colab.research.google.com/github/real-rookie/novelty-detection-algorithms-evaluation/blob/main/generic_inter_set.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# env
!pip install lightning
!pip install anomalib
!pip install OpenVINO
!pip install wandb

In [None]:
# unzip code and datasets
!unzip -o /content/drive/MyDrive/novelty-detection-algorithms-evaluation.zip -d /home/
%cd /home

In [None]:
import os
import random
import shutil

import numpy as np
from torchvision import datasets
from torchvision.transforms import ToTensor
from torchvision.utils import save_image

In [None]:
# set parameters
# Supported datasets: name -> [image output directory, number of classes].
DATASET_INFO = {
    "MNIST": ["datasets/MNIST/images", 10],
    "FashionMNIST": ["datasets/FashionMNIST/images", 10],
    "CIFAR10": ["datasets/CIFAR10/images", 10],
}

# Dataset to export / train / test on; must be a key of DATASET_INFO.
dataset = "CIFAR10"
dataset_path, num_total_classes = DATASET_INFO[dataset]

In [None]:
# make datasets
%cd /home
if dataset in ["MNIST", "FashionMNIST", "CIFAR10"]:
    os.system(f"rm -rf {dataset_path}")
    os.system(f"mkdir -p {dataset_path}/train")
    os.system(f"mkdir -p {dataset_path}/test/normal")
    os.system(f"mkdir -p {dataset_path}/test/novel")

In [None]:
# Download (if needed) and load the train/test splits as tensors.
# name -> (torchvision dataset class, download root). The roots are kept exactly as
# before (CIFAR10 uses its own sub-directory) so existing downloads are reused.
DATASET_LOADERS = {
    "MNIST": (datasets.MNIST, "datasets"),
    "FashionMNIST": (datasets.FashionMNIST, "datasets"),
    "CIFAR10": (datasets.CIFAR10, "datasets/CIFAR10"),
}
if dataset not in DATASET_LOADERS:
    # Raise rather than os.abort(): abort() sends SIGABRT and kills the kernel
    # process, while an exception only stops this cell / Run All with a traceback.
    raise ValueError(f"Wrong dataset specified: {dataset!r}")
dataset_cls, download_root = DATASET_LOADERS[dataset]
train_data = dataset_cls(root=download_root, train=True, download=True, transform=ToTensor())
test_data = dataset_cls(root=download_root, train=False, download=True, transform=ToTensor())

In [None]:
# Export images as PNGs named "<label>_<per-class index>.png".
# Labels below num_total_classes // 2 are "normal"; the remaining labels are "novel".
#   - train split: only normal-class images are exported (novel classes are excluded)
#   - test split: normal -> test/normal, novel -> test/novel
# The per-class counters make file names unique and are printed as a summary.
num_normal_classes = num_total_classes // 2  # == floor(n / 2); computed once, not per image
train_counter = np.zeros(num_total_classes, dtype=int)
test_counter = np.zeros(num_total_classes, dtype=int)
if dataset in ["MNIST", "FashionMNIST", "CIFAR10"]:
    for img, label in train_data:
        if label < num_normal_classes:
            save_image(img, f"{dataset_path}/train/{label}_{train_counter[label]}.png")
            train_counter[label] += 1
    for img, label in test_data:
        split = "normal" if label < num_normal_classes else "novel"
        save_image(img, f"{dataset_path}/test/{split}/{label}_{test_counter[label]}.png")
        test_counter[label] += 1
print(f"train: {train_counter}")
print(f"test: {test_counter}")

In [None]:
# train and testing
%cd /home/novelty-detection-algorithms-evaluation
!python generic_inter_set.py --mode train --data CIFAR10 --model RD4AD

In [None]:
!python generic_inter_set.py --mode test --data CIFAR10 --model RD4AD