In [23]:
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10, MNIST
import torchvision.transforms as transforms

from pytorch_ood.detector import OpenMax, MCD,ODIN, odin_preprocessing, Mahalanobis, \
    EnergyBased, Entropy, MaxLogit, KLMatching, ViM, odin, MaxSoftmax, KNN
from pytorch_ood.model import WideResNet
from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed, extract_features

import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import os

class CustomDataset(Dataset):
    """Image dataset described by a plain-text image list.

    Each non-blank line of the list holds a relative image path and an integer
    label separated by whitespace, e.g. ``cifar100/test/0001.png 42``. Fields
    after the label are ignored. Blank lines are skipped.
    """

    def __init__(self, txt_path, img_dir, transform=None, target_transform=None):
        """
        Args:
            txt_path (string): Path to the txt file with image paths and labels.
            img_dir (string): Directory that the image paths in ``txt_path`` are
                relative to.
            transform (callable, optional): Applied to each RGB PIL image.
            target_transform (callable, optional): Applied to each integer label.

        Raises:
            ValueError: If a non-blank line has fewer than two fields or its
                second field is not an integer.
        """
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

        # List of (relative_image_path, int_label) pairs, in file order.
        self.img_labels = []
        with open(txt_path, 'r') as f:
            for line_no, raw_line in enumerate(f, start=1):
                # split() with no argument collapses runs of spaces/tabs.
                parts = raw_line.split()
                if not parts:
                    # Tolerate blank lines, e.g. a trailing empty line.
                    continue
                if len(parts) < 2:
                    raise ValueError(
                        f"{txt_path}:{line_no}: expected '<path> <label>', got {raw_line.strip()!r}"
                    )
                self.img_labels.append((parts[0], int(parts[1])))

    def __len__(self):
        """Number of entries read from the image list."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Load entry ``idx`` and return ``(image, label)`` after the optional transforms."""
        img_path, label = self.img_labels[idx]
        full_img_path = os.path.join(self.img_dir, img_path)
        # The context manager releases the file handle promptly. convert('RGB')
        # decodes the pixels into a new image, so it stays valid after the file
        # is closed, and it turns grayscale/palette images into 3 channels.
        with Image.open(full_img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label

In [24]:
# Reproducibility: seed all RNGs before any data shuffling or model creation.
fix_random_seed(123)
# Use the first GPU when one is available, otherwise fall back to CPU so the
# notebook still runs on machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [25]:
# Root directory that every image-list path below is relative to.
img_dir = 'data/images_classic'

# Preprocessing pipelines. trans_mnist and trans_cifar10 are not referenced by
# the later cells shown here; the OOD sets are all fed through trans_cifar100
# so that they match the network's training preprocessing. In trans_mnist the
# 1-tuple mean/std is broadcast across the 3 grayscale-replicated channels.
trans_mnist = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)
trans_cifar10 = WideResNet.transform_for("cifar10-pt")
trans_cifar100 = WideResNet.transform_for("cifar100-pt")

# In-distribution data: the CIFAR-100 train and test splits.
dataset_train = CustomDataset(
    txt_path='data/benchmark_imglist/cifar100/train_cifar100.txt',
    img_dir=img_dir,
    transform=trans_cifar100,
)
dataset_in_test = CustomDataset(
    txt_path='data/benchmark_imglist/cifar100/test_cifar100.txt',
    img_dir=img_dir,
    transform=trans_cifar100,
)

train_loader = DataLoader(dataset_train, batch_size=128, shuffle=True)
test_loader_id = DataLoader(dataset_in_test, batch_size=128)

In [26]:
# Stage 1: instantiate a 100-class WideResNet with the 'cifar100-pt' pretrained
# weights and place it on the selected device. Module.to() returns the module
# itself, so the chained call binds `model` to the moved network.
model = WideResNet(num_classes=100, pretrained='cifar100-pt').to(device)

In [27]:
# import torch.optim as optim

# criterion = torch.nn.CrossEntropyLoss()
# optimizer = optim.Adam(model.parameters(), lr=0.01)

In [28]:
# num_epochs = 10
# for epoch in range(num_epochs):
#     model.train()
#     running_loss = 0.0
    
#     for images, labels in train_loader:
#         images, labels = images.to(device), labels.to(device)
        
#         optimizer.zero_grad()
#         outputs = model(images)
#         loss = criterion(outputs, labels)
#         loss.backward()
#         optimizer.step()
        
#         running_loss += loss.item()
    
#     print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}")

In [29]:
# Switch the network to evaluation mode for OOD scoring. Both eval() and to()
# act on the module in place, so `model` keeps naming the same object.
model.eval()
model.to(device)

In [30]:
# correct = 0
# total = 0
# with torch.no_grad():
#     for images, labels in test_loader_id:
#         images, labels = images.to(device), labels.to(device)
#         outputs = model(images)
#         _, predicted = torch.max(outputs.data, 1)
#         total += labels.size(0)
#         correct += (predicted == labels).sum().item()

# accuracy = 100 * correct / total
# print(f'Accuracy of the model on the test images: {accuracy}%')

In [31]:
# Stage 2: create the OOD detector and fit it on the CIFAR-100 training loader.
# Exactly one detector is active; the commented lines are alternatives that can
# be swapped in (all of them are imported in the first cell).
# NOTE(review): the KNN method (Sun et al., ICML 2022) scores samples by distances
# between penultimate-layer features. Here KNN receives the full classifier
# `model`, whereas the Mahalanobis/ViM alternatives below receive `model.features`.
# Confirm which input pytorch_ood's KNN is meant to get.

# detector = MaxSoftmax(model)
detector = KNN(model)
# detector = OpenMax(model, tailsize=25, alpha=5, euclid_weight=0.5)
# detector = MCD(model, mode="mean")
# detector = ODIN(model)
# detector = Mahalanobis(model.features)
# detector = EnergyBased(model)
# detector = Entropy(model)
# detector = MaxLogit(model)
# detector = KLMatching(model)
# detector = ViM(model.features, d=10, w=model.fc.weight, b=model.fc.bias)

# fit() returns the detector; the trailing ';' keeps its repr out of the output.
detector.fit(train_loader, device=device);

<pytorch_ood.detector.knn.KNN at 0x7fba1bdac890>

In [32]:

# Stage 3: evaluate the fitted detector. Each OOD image list below is paired with
# the CIFAR-100 test split (in-distribution), and OODMetrics is computed on the
# union. Results are appended to metrics_list in the order of directory_list.
directory_list = [
    'data/benchmark_imglist/cifar10/test_cifar10.txt',
    'data/benchmark_imglist/mnist/test_mnist.txt',
    'data/benchmark_imglist/cifar100/test_places365.txt',
    'data/benchmark_imglist/cifar100/test_svhn.txt',
    'data/benchmark_imglist/cifar100/test_texture.txt',
    'data/benchmark_imglist/cifar100/test_tin.txt'
]


def _score_loader(loader, detector, device):
    """Run ``detector`` over every batch of ``loader``.

    Returns:
        (scores, labels): detector outputs (detached, moved to CPU) and the
        labels yielded by the loader, each concatenated along dim 0.

    Raises:
        ValueError: If the loader yields no batches (e.g. an empty image list).
    """
    scores, labels = [], []
    for x, y in loader:
        scores.append(detector(x.to(device)).detach().cpu())
        labels.append(y)
    if not scores:
        raise ValueError("loader yielded no batches; is the image list empty?")
    return torch.cat(scores), torch.cat(labels)


# The in-distribution scores are the same for every OOD set, so compute them
# once instead of re-scoring the CIFAR-100 test split on each iteration.
id_scores, id_labels = _score_loader(test_loader_id, detector, device)

metrics_list = []

for imglist_path in directory_list:
    dataset_out_test = CustomDataset(txt_path=imglist_path,
                                     img_dir=img_dir,
                                     transform=trans_cifar100,
                                     target_transform=ToUnknown())
    test_loader_ood = DataLoader(dataset_out_test, batch_size=128)
    ood_scores, ood_labels = _score_loader(test_loader_ood, detector, device)

    metrics = OODMetrics()
    metrics.update(id_scores, id_labels)
    metrics.update(ood_scores, ood_labels)
    metrics_list.append(metrics.compute())

In [33]:
# Report AUROC per OOD set, labelled by the image-list file it was computed on.
# A length mismatch means the evaluation cell was not (fully) re-run.
assert len(metrics_list) == len(directory_list), (
    f"metrics_list has {len(metrics_list)} entries but directory_list has "
    f"{len(directory_list)}; re-run the evaluation cell"
)
for imglist_path, metric in zip(directory_list, metrics_list):
    ood_name = os.path.splitext(os.path.basename(imglist_path))[0]
    print(f"{ood_name:<16} AUROC = {float(metric['AUROC']):.5f}")

0.71527,0.67423,0.70965,0.83286,0.79731,0.78375,