In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
# cd ''

In [None]:
from AnomalyDetection.dataset import MVTecDataset
from AnomalyDetection.utils import compute_distance_matrix, denormalize, concatenate_embeddings

In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision.models import wide_resnet50_2
from torchvision import transforms as T

import os
from tqdm import tqdm
import pickle

import numpy as np
import matplotlib.pyplot as plt
from collections import OrderedDict
from PIL import Image

from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve
from sklearn.metrics import precision_recall_curve
from scipy.ndimage import gaussian_filter

In [None]:
# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Load the pretrained Wide ResNet-50-2, move it to the selected device and
# switch it to evaluation mode. `.to()` and `.eval()` both return the module
# itself, so `model` is the same object the constructor produced.
model = wide_resnet50_2(pretrained=True, progress=True).to(device).eval()
# Emit an empty line after any download progress output.
print()

Downloading: "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth" to /root/.cache/torch/hub/checkpoints/wide_resnet50_2-95faca4d.pth
100%|██████████| 132M/132M [00:01<00:00, 134MB/s]





In [None]:
# Category names passed to MVTecDataset(class_name=...) by the experiment loops.
CLASS_NAMES = [
    'bottle', 'cable', 'capsule',
    'carpet', 'grid', 'hazelnut',
    'leather', 'metal_nut', 'pill',
    'screw', 'tile', 'toothbrush',
    'transistor', 'wood', 'zipper',
]

## SPADE

In [None]:
# Number of smallest train/test distances averaged into each image score.
top_k = 20
# Results folder for this run; its name carries the top_k setting.
save_path = f'result_spade_{top_k}'

In [None]:
# Create '<save_path>/temp' (and any missing parents); exist_ok=True keeps
# re-running this cell from raising when the folder is already there.
os.makedirs(os.path.join(save_path, 'temp'), exist_ok=True)

In [None]:
from AnomalyDetection.base import get_train_features, get_test_features 
from AnomalyDetection.spade import spade_localization

In [None]:
# SPADE evaluation, one category at a time: extract features, score every test
# image by its distance to the closest training images, then score pixels.
for class_name in CLASS_NAMES:
    # Train and test loaders for the current category.
    train_dataset = MVTecDataset(class_name=class_name, train=True)
    train_dataloader = DataLoader(train_dataset, batch_size=32, pin_memory=True)
    test_dataset = MVTecDataset(class_name=class_name, train=False)
    test_dataloader = DataLoader(test_dataset, batch_size=32, pin_memory=True)

    # Backbone features for both splits; the test call also returns the
    # image-level labels, ground-truth masks and the test images.
    train_outputs = get_train_features(model, train_dataloader, class_name, device)
    test_outputs, gt_list, gt_mask_list, test_imgs = get_test_features(
        model, test_dataloader, class_name, device)

    # Distances between the 'avgpool' features of the two splits, each
    # flattened to one vector per image (test first, train second).
    dist_matrix = compute_distance_matrix(
        test_outputs['avgpool'].flatten(1),
        train_outputs['avgpool'].flatten(1),
        device=device,
    )

    # Image score = mean of the top_k smallest distances in each row.
    topk_values, topk_indexes = dist_matrix.topk(k=top_k, dim=1, largest=False)
    scores = topk_values.mean(dim=1).detach().cpu().numpy()

    # Image-level ROC curve and AUC.
    fpr, tpr, _ = roc_curve(gt_list, scores)
    roc_auc = roc_auc_score(gt_list, scores)
    print(f'{class_name} ROCAUC: {roc_auc}')

    # anomaly localization
    score_map_list = spade_localization(train_outputs, test_outputs, topk_indexes, class_name)

    # Flatten masks and score maps to 1-D so they can be compared pixel by pixel.
    flatten_gt_mask_list = np.concatenate(gt_mask_list).ravel()
    flatten_score_map_list = np.concatenate(score_map_list).ravel()

    # Pixel-level ROC curve and AUC.
    fpr, tpr, _ = roc_curve(flatten_gt_mask_list, flatten_score_map_list)
    per_pixel_rocauc = roc_auc_score(flatten_gt_mask_list, flatten_score_map_list)
    print(f'{class_name} pixel ROCAUC: {per_pixel_rocauc}')

    # Pick the score threshold with the highest F1; entries whose
    # precision + recall is zero are left at 0 instead of dividing by zero.
    precision, recall, thresholds = precision_recall_curve(
        flatten_gt_mask_list, flatten_score_map_list)
    num = 2 * precision * recall
    denom = precision + recall
    f1 = np.zeros_like(num)
    np.divide(num, denom, out=f1, where=denom != 0)
    threshold = thresholds[f1.argmax()]


| feature extraction | train | bottle |: 100%|██████████| 7/7 [00:22<00:00,  3.20s/it]
| feature extraction | test | bottle |: 100%|██████████| 3/3 [01:19<00:00, 26.61s/it]


bottle ROCAUC: 0.9642857142857142


| localization | test | bottle |: 100%|██████████| 83/83 [02:46<00:00,  2.01s/it]


bottle pixel ROCAUC: 0.9752557312700088


| feature extraction | train | cable |:  86%|████████▌ | 6/7 [00:18<00:03,  3.15s/it]


KeyboardInterrupt: ignored

## SPADE Transformer

In [None]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.29.0-py3-none-any.whl (7.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.1/7.1 MB[0m [31m65.3 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.11.0 (from transformers)
  Downloading huggingface_hub-0.14.1-py3-none-any.whl (224 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m224.5/224.5 kB[0m [31m19.1 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers)
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m107.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.14.1 tokenizers-0.13.3 transformers-4.29.0


In [None]:
# Number of smallest train/test distances averaged into each image score.
top_k = 10
# Results folder for this run; its name carries the top_k setting.
save_path = f'result_{top_k}_transformer'

In [None]:
# Create '<save_path>/temp' (and any missing parents); exist_ok=True keeps
# re-running this cell from raising when the folder is already there.
os.makedirs(os.path.join(save_path, 'temp'), exist_ok=True)

In [None]:
from AnomalyDetection.transformer_base import get_train_features, get_test_features 
from AnomalyDetection.spade import spade_localization

In [1]:
# Image-level SPADE evaluation using the transformer feature extractors
# (these take no `model` argument), one category at a time.
for class_name in CLASS_NAMES:
    # Train and test loaders for the current category.
    train_dataset = MVTecDataset(class_name=class_name, train=True)
    train_dataloader = DataLoader(train_dataset, batch_size=32, pin_memory=True)
    test_dataset = MVTecDataset(class_name=class_name, train=False)
    test_dataloader = DataLoader(test_dataset, batch_size=32, pin_memory=True)

    # Features for both splits; the test call also returns the image-level
    # labels, ground-truth masks and the test images.
    train_outputs = get_train_features(train_dataloader, class_name, device)
    test_outputs, gt_list, gt_mask_list, test_imgs = get_test_features(
        test_dataloader, class_name, device)

    # Distances between the 'avgpool' features (test first, train second);
    # the image score is the mean of the top_k smallest distances per row.
    dist_matrix = compute_distance_matrix(
        test_outputs['avgpool'],
        train_outputs['avgpool'],
        device=device,
    )
    topk_values, topk_indexes = dist_matrix.topk(k=top_k, dim=1, largest=False)
    scores = topk_values.mean(dim=1).detach().cpu().numpy()

    # Image-level ROC curve and AUC.
    fpr, tpr, _ = roc_curve(gt_list, scores)
    roc_auc = roc_auc_score(gt_list, scores)
    print(f'{class_name} ROCAUC: {roc_auc}')

## PaDiM

In [None]:
import torch
import torch.nn.functional as F
import random
from scipy.spatial.distance import mahalanobis

In [None]:
from AnomalyDetection.base import get_train_features, get_test_features 
from AnomalyDetection.padim import gaussian_train, gaussian_test

In [None]:
# Results folder for the PaDiM experiment; a 'temp' subfolder is created under it.
save_path = 'result_padim_exp'

In [None]:
# Create '<save_path>/temp' (and any missing parents); exist_ok=True keeps
# re-running this cell from raising when the folder is already there.
os.makedirs(os.path.join(save_path, 'temp'), exist_ok=True)

In [None]:
# t_d is the size of the index range to sample from; d is how many distinct
# indices are kept.
# NOTE(review): 1792 presumably matches the channel count of the concatenated
# backbone features consumed by gaussian_train/gaussian_test — confirm.
t_d = 1792
d = 550

# Draw the subset from a dedicated, seeded generator: every run (and every
# Restart & Run All) selects the same indices, and the global `random` state
# used elsewhere is left untouched.
idx_seed = 1024
idx = torch.tensor(random.Random(idx_seed).sample(range(t_d), d))

In [None]:
# PaDiM evaluation, one category at a time: extract features, fit the Gaussian
# model on the training features, score the test set, then report image-level
# and pixel-level ROC AUC tagged with the category name.
for class_name in CLASS_NAMES:

    # Train and test loaders for the current category.
    train_dataset = MVTecDataset(class_name=class_name, train=True)
    train_dataloader = DataLoader(train_dataset, batch_size=32, pin_memory=True)
    test_dataset = MVTecDataset(class_name=class_name, train=False)
    test_dataloader = DataLoader(test_dataset, batch_size=32, pin_memory=True)

    # Read the shape of the first test batch only; gaussian_test receives it
    # as `image_dims`. Note the shape includes that batch's leading dimension.
    for x, y, mask in test_dataloader:
        image_dims = x.shape
        break

    # Backbone features for both splits; the test call also returns the
    # image-level labels, ground-truth masks and the test images.
    train_outputs = get_train_features(model, train_dataloader, class_name, device)
    test_outputs, gt_list, gt_mask_list, test_imgs = get_test_features(model, test_dataloader, class_name, device)

    # Fit on the training features restricted to `idx`, then score the test set.
    train_outputs = gaussian_train(train_outputs, idx, device)
    scores = gaussian_test(train_outputs, test_outputs, idx, image_dims, device)

    # Image score = maximum over all of that image's per-position scores.
    img_scores = scores.reshape(scores.shape[0], -1).max(axis=1)
    gt_list = np.asarray(gt_list)
    fpr, tpr, _ = roc_curve(gt_list, img_scores)
    img_roc_auc = roc_auc_score(gt_list, img_scores)
    # Include the category so each line of the 15-class run is attributable.
    print(f'{class_name} Anomaly Detection ROCAUC: {img_roc_auc}')

    # Pick the score threshold with the highest F1; entries whose
    # precision + recall is zero are left at 0 instead of dividing by zero.
    gt_mask = np.asarray(gt_mask_list)
    precision, recall, thresholds = precision_recall_curve(gt_mask.flatten(), scores.flatten())
    num = 2 * precision * recall
    denom = precision + recall
    f1 = np.divide(num, denom, out=np.zeros_like(num), where=denom != 0)
    threshold = thresholds[np.argmax(f1)]

    # Pixel-level ROC curve and AUC.
    fpr, tpr, _ = roc_curve(gt_mask.flatten(), scores.flatten())
    per_pixel_rocauc = roc_auc_score(gt_mask.flatten(), scores.flatten())
    print(f'{class_name} Localization ROCAUC: {per_pixel_rocauc:.3f}')

| feature extraction | train | bottle |: 100%|██████████| 7/7 [00:10<00:00,  1.45s/it]
| feature extraction | test | bottle |: 100%|██████████| 3/3 [00:03<00:00,  1.23s/it]
