# Using pre-trained feature extractor for optical anomaly detection

In [None]:
# To autoreload external functions
%load_ext autoreload
%autoreload 2

In [None]:
from typing import Optional
import os
from pathlib import Path
from PIL import Image
import math
import numpy as np
import cv2
from sklearn.metrics import roc_curve, auc
from scipy.ndimage import gaussian_filter
from pyod.models.lunar import LUNAR

import albumentations as A
from albumentations.pytorch import ToTensorV2

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.models.feature_extraction import get_graph_node_names, create_feature_extractor
import timm


from bokeh.plotting import figure, show
import bokeh

import rootutils
root = rootutils.setup_root(Path.cwd(), dotenv=True, pythonpath=True, cwd=False)

from src.visualization.utils import save_plot_from_notbook_for_jekyll, bokeh_notebook_setup, save_plot_from_notebook_to_html
from src.visualization.image import plot_img_rgba, add_bboxes_on_img, add_seg_on_img, plot_img_scalar, add_score_map_on_img

## Setup

In [None]:
bokeh_notebook_setup()

In [None]:
data_path = Path("../data/raw/transistor")
output_path = Path("./logs")

## Introduction

- task similar to previous post
  - we don't have regular patterns anymore
  - we want to have a method that can be applied to several tasks
- previous research resulted in high effectiveness for features extracted from Deep Learning models pre-trained on ImageNet
  - reference to SPADE, Gaussian AD, PaDim and PatchCore
- we will be following the approach of these papers but change a few components

## Dataset

Like in the previous post, we will use the [MVTec anomaly detection dataset](https://www.mvtec.com/company/research/datasets/mvtec-ad) which you can download from the website.
The dataset contains 15 different categories. For the examples in this post we will use the 'Transistor' category.

Here is a normal example without anomaly

In [None]:
img_path = data_path / "train/good/000.png"

img = Image.open(img_path)
img = img.convert("RGBA")

p = plot_img_rgba(img)
show(p)

and in contrast an anomalous example

In [None]:
img_path = data_path / "test/bent_lead/000.png"
seg_path = data_path / "ground_truth/bent_lead/000_mask.png"

img = Image.open(img_path)
img = img.convert("RGBA")

seg = Image.open(seg_path)
seg = np.array(seg)

p = plot_img_rgba(img)
p = add_seg_on_img(p, seg)

show(p)

## Feature Extraction

TODO: Update using native torchvision models with the new API (e.g. automatically load transforms): See [MODELS AND PRE-TRAINED WEIGHTS](https://pytorch.org/vision/stable/models.html)

Like in the PaDim or PatchCore paper we are going to extract features for each image patch of the training set using a neural network architecture for vision tasks pre-trained on the ImageNet dataset. The patch size is determined by our choice for the network layer.
To do the feature extraction we use the PyTorch `feature_extraction` package [based on Torch FX](https://pytorch.org/blog/FX-feature-extraction-torchvision/).
The goal of this post is to demonstrate the principle rather than optimizing our approach to the dataset. Hence, we will simplify many steps compared to the paper.

For the backbone we pick ResNet 34 as it is simple to use and doesn't require much memory.

In [None]:
backbone = timm.create_model("resnet34", pretrained=True)

In the papers features from several layers were combined. To keep it simple, we will use only one layer.
To see the available layer names for feature extraction you can use

In [None]:
train_nodes, eval_nodes = get_graph_node_names(backbone)

Looking at `train_nodes` or `eval_nodes`, you will see that ResNet 34 has 4 main blocks. If you just want to pick the last node of a block, the feature_extraction module allows you to use truncated node names. We will use `'layer2'` to get the last node of all the `layer2.x.ops` nodes. We choose layer 2 as a compromise between expressive high-level features and a reasonably high spatial resolution of the feature map.

In [None]:
feature_extractor = create_feature_extractor(backbone, return_nodes=["layer2"])

As in the paper, we fix the weights to the pre-trained ImageNet weights. Hence, we can turn off gradient computation to save memory

In [None]:
for param in feature_extractor.parameters():
    param.requires_grad = False

# Freezing the weights is not enough: modules are in training mode until eval()
# is called, and in that mode BatchNorm normalises with the statistics of the
# current batch and keeps updating its running averages on every forward pass.
# eval() makes the extracted features independent of batch size and composition.
feature_extractor = feature_extractor.eval()

To simplify experimenting with different configurations, we use a Config object

In [None]:
class Config:
    """Central place for the settings shared by the cells of this notebook."""

    # --- input / data loading ---
    img_shape = (256, 256)  # (height, width), must be a multiple of 16
    batch_size = 4
    num_workers: int = 2  # adjust to the number of processing cores you want to use

    # --- properties that depend on the chosen ResNet layer ---
    red_factor = 8  # spatial reduction factor of the layer
    n_feats = 128  # number of features of the layer

    # prefer the GPU whenever one is available
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

TODO: shorten the explanation. I can't assume ppl have read the paper

To save the features we will follow the memory bank approach from PatchCore. This means we will save the extracted features into a large array without linking them to the original patch location. This is in contrast to the PaDim approach. The advantage is that our approach becomes more robust to rotations, translations and other variations of the objects in the dataset. The disadvantage is that the number of feature vectors we have to compare each patch to becomes quite large. 
TODO brief calculation: N images * Height * Width of feature maps
A nearest-neighbor lookup over such a large memory bank, as used in the PatchCore paper, is only feasible with tricks like coreset reduction.
We will get around this by choosing a different anomaly detection approach.

First, we create a PyTorch Dataset object to hold the image data and specify the necessary transformations

In [None]:
class TrainDataset(Dataset):
    """Dataset yielding the images stored directly inside one directory.

    Args:
        data_path: Directory containing the image files.
        transforms: Optional albumentations pipeline applied to every image
            (called as ``transforms(image=...)["image"]``).
    """

    def __init__(self, data_path: os.PathLike, transforms: "Optional[A.Compose]" = None):
        # zero-argument super(): the one-argument form ``super(TrainDataset)``
        # creates an unbound super object and never reaches Dataset.__init__
        super().__init__()

        # Sort for a reproducible sample order (iterdir() order is arbitrary) and
        # keep regular files only so stray sub-directories are not opened as images.
        # Path() lets callers pass a plain string as well.
        self.img_paths = sorted(p for p in Path(data_path).iterdir() if p.is_file())
        self.transforms = transforms

    def __getitem__(self, index: int):
        """Return the image at ``index`` as RGB, transformed if transforms are set."""
        # the context manager releases the file handle once the pixels are copied
        with Image.open(self.img_paths[index]) as pil_img:
            img = np.array(pil_img.convert("RGB"))

        if self.transforms:
            img = self.transforms(image=img)["image"]

        return img

    def __len__(self) -> int:
        """Return the number of image files found in the directory."""
        return len(self.img_paths)

As we are using a backbone network pre-trained on ImageNet, we need to apply the same normalization transformations

In [None]:
train_path = data_path / "train/good"

default_transforms = A.Compose(
    [
        A.Resize(Config.img_shape[0], Config.img_shape[1]),
        A.Normalize(
            mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0
        ),
        ToTensorV2(),
    ]
)

train_ds = TrainDataset(train_path, transforms=default_transforms)

Afterwards, we create the DataLoader object to feed our data to the feature extractor

In [None]:
train_dl = DataLoader(
    train_ds,
    batch_size=Config.batch_size,
    shuffle=False,
    num_workers=Config.num_workers,
)

We create a function containing the logic to call the feature extractor with a batch of images and collect the resulting features:

In [None]:
# Debug (TODO: remove)
# imgs = next(iter(train_dl))
# print(feature_extractor(imgs)['layer2'].shape)

In [None]:
def get_features(imgs, extractor, cfg):
    """Extract one feature vector per spatial location of a feature map.

    Args:
        imgs: Batch of image tensors; moved to ``cfg.device`` before the forward pass.
        extractor: Callable returning a dict that maps node names to feature maps.
            Only the first entry is used and it is treated as channels-first
            (N, C, H, W).
        cfg: Configuration object providing ``device``.

    Returns:
        NumPy array of shape (N * H * W, C): one row per patch, ordered by
        image, then feature-map row, then feature-map column.
    """
    imgs = imgs.to(cfg.device)

    with torch.no_grad():
        feature_dict = extractor(imgs)

    # only the first returned feature map is used
    feats = next(iter(feature_dict.values()))

    feats = feats.cpu().numpy()
    # channels-last, so that every row of the flattened array is one patch
    feats = np.transpose(feats, (0, 2, 3, 1))

    # Flatten with the actual channel count. Reshaping with a configured value
    # that does not match the layer would silently mix values of different patches.
    return feats.reshape(-1, feats.shape[-1])

Finally, we can put everything together to compute the feature memory bank

In [None]:
h, w = Config.img_shape[:2]

# Spatial size of the target layer's feature map (ResNet resolution reduction)
h_layer = math.ceil(h / Config.red_factor)
w_layer = math.ceil(w / Config.red_factor)

# one row per patch of every training image
memory_bank_size = len(train_ds) * h_layer * w_layer
memory_bank = np.empty((memory_bank_size, Config.n_feats), dtype=np.float32)

feature_extractor = feature_extractor.to(Config.device)

# write offset into the memory bank
i_mem = 0

for imgs in train_dl:
    feats = get_features(imgs, feature_extractor, Config)
    memory_bank[i_mem : i_mem + feats.shape[0]] = feats
    i_mem += feats.shape[0]

# np.empty leaves unwritten rows uninitialised, so make sure the bank was
# filled completely before it is used to fit the anomaly detector
if i_mem != memory_bank_size:
    raise RuntimeError(
        f"memory bank expected {memory_bank_size} feature vectors but received {i_mem}; "
        "check Config.img_shape and Config.red_factor against the chosen layer"
    )

Printing the memory bank shape, we see that it contains over 200k feature vectors.

In [None]:
print(memory_bank.shape)

Hence, in the next step, in which we compute an anomaly score for each patch of a test image by comparing it with the memory bank, we need an anomaly detection approach that can deal with such a large number of vectors.

## Anomaly Detection

For the anomaly detection part, we will extract the features of a target image with the same model as before. Afterwards, we will apply an off-the-shelf anomaly detection algorithm from the [Python Outlier Detection (PyOD) library](https://github.com/yzhao062/pyod).

Side remark: I will use the terms anomaly detection and outlier detection interchangeably.

TODO: why do we pick LUNAR

Fitting the anomaly detection model on this large set of feature vectors may take a few minutes

In [None]:
clf = LUNAR()
clf.fit(memory_bank)

Afterwards, we pick a defect image from the test data and extract its features

In [None]:
img_path = data_path / "test/bent_lead/000.png"
seg_path = data_path / "ground_truth/bent_lead/000_mask.png"

img = Image.open(img_path)

# Force 3-channel RGB exactly like TrainDataset does, so that the normalisation
# and the backbone receive the same input layout as during feature collection
# (a grayscale or RGBA file would otherwise break the 3-channel normalisation).
img_np = np.array(img.convert("RGB"))
img_t = default_transforms(image=img_np)["image"]
# add the leading batch dimension for the single image
img_t = torch.unsqueeze(img_t, 0)

test_feats = get_features(img_t, feature_extractor, Config)

To get an anomaly score map, we have to reshape the features to first match the image patch locations and eventually resize it to the original image size

In [None]:
# one anomaly score per patch, arranged on the feature-map grid
ano_scores = clf.decision_function(test_feats)
score_patches = np.reshape(ano_scores, (h_layer, w_layer))

# upsample to the original image size; cv2 takes the target size as (width, height)
anomaly_map = cv2.resize(score_patches, dsize=(img.width, img.height))

# Gaussian blur smooths out possible resizing artifacts
anomaly_map = gaussian_filter(anomaly_map, sigma=4)

# shift the scores so that they start at 0
anomaly_map -= anomaly_map.min()

This allows us to overlay the anomaly score map with the original defect image and to compare with the ground truth annotation

In [None]:
seg = Image.open(seg_path)
seg = np.array(seg)

p_img = plot_img_rgba(img, title="Image with ground truth annotation")
p_img = add_seg_on_img(p_img, seg)
p_ano = plot_img_rgba(img, title="Image with prediction")
p_ano = add_score_map_on_img(p_ano, anomaly_map, alpha=0.6)
p = bokeh.layouts.row(p_img, p_ano)
show(p)

And indeed, we can see that the area with the highest anomaly scores corresponds to the marked ground-truth defect annotation.

### Putting everything together

In [None]:
class AnomalyDetector:
    """Compute a pixel-wise anomaly score map for a single image.

    Args:
        transforms: Preprocessing pipeline, called as ``transforms(image=...)["image"]``.
        feature_extractor: Model handed to ``get_features``; moved to ``cfg.device``.
        clf: Outlier detector exposing ``decision_function`` (expected to be
            fitted already).
        cfg: Configuration providing ``device``, ``img_shape`` and ``red_factor``.
    """

    def __init__(self, transforms, feature_extractor, clf, cfg) -> None:
        self.transforms = transforms
        self.feature_extractor = feature_extractor.to(cfg.device)
        self.clf = clf
        self.cfg = cfg

        # spatial size of the feature map produced for a resized input image
        self.h_layer = math.ceil(cfg.img_shape[0] / cfg.red_factor)
        self.w_layer = math.ceil(cfg.img_shape[1] / cfg.red_factor)

    def __call__(self, img: Image.Image) -> np.ndarray:
        """Return an anomaly map with the height and width of ``img``, shifted to start at 0."""
        # Force 3-channel RGB so that grayscale or RGBA inputs are preprocessed
        # the same way as the RGB-converted images used to collect the features.
        img_np = np.array(img.convert("RGB"))
        img_t = self.transforms(image=img_np)["image"]
        # add the leading batch dimension for the single image
        img_t = torch.unsqueeze(img_t, 0)

        feats = get_features(img_t, self.feature_extractor, self.cfg)

        # one score per patch, arranged on the feature-map grid
        ano_scores = self.clf.decision_function(feats)
        score_patches = ano_scores.reshape(self.h_layer, self.w_layer)

        # cv2 takes the target size as (width, height)
        anomaly_map = cv2.resize(score_patches, (img.width, img.height))

        # apply Gaussian blur to smooth out possible resizing artifacts
        anomaly_map = gaussian_filter(anomaly_map, sigma=4)
        # make anomaly scores start at 0
        anomaly_map = anomaly_map - anomaly_map.min()

        return anomaly_map

Let's test our new anomaly detector on different defect images

In [None]:
detector = AnomalyDetector(default_transforms, feature_extractor, clf, Config)

# iterdir() yields entries in arbitrary order. Sort both listings so that the
# n-th image is paired with the n-th mask (e.g. 000.png with 000_mask.png).
img_it = iter(sorted((data_path / "test/damaged_case").iterdir()))
seg_it = iter(sorted((data_path / "ground_truth/damaged_case").iterdir()))

In [None]:
img_path = next(img_it)
seg_path = next(seg_it)

img = Image.open(img_path)
seg = Image.open(seg_path)
seg = np.array(seg)

anomaly_map = detector(img)

p_img = plot_img_rgba(img)
p_img = add_seg_on_img(p_img, seg)
p_ano = plot_img_rgba(img)
p_ano = add_score_map_on_img(p_ano, anomaly_map, alpha=0.6)
p = bokeh.layouts.row(p_img, p_ano)

show(p)

## Validation

To quantify how well this approach works over all test data, we will make anomaly score predictions over all test images and compute the area under the receiver operating characteristic curve (AUROC) metric. See [Receiver operating characteristic](https://en.wikipedia.org/wiki/Receiver_operating_characteristic) for more details. This also allows us to compare the approach with recent literature.

Like in the 'training' phase, we first create a PyTorch Dataset. As the predictor class can already handle native Python image objects, we don't necessarily need a DataLoader. The DataLoader would allow us to speed up the validation process by using batches, but for this blog post we will keep it simple.

In [None]:
class ValidationDataset(Dataset):
    """Dataset yielding ``(image, ground-truth mask)`` pairs for all test classes.

    Args:
        data_path: Directory with one sub-directory per class, each holding
            the test images of that class.
        gt_path: Directory with one sub-directory per defect class holding the
            masks, named ``<image stem>_mask<image suffix>``.
    """

    def __init__(
        self,
        data_path: os.PathLike,
        gt_path: os.PathLike,
    ):
        # zero-argument super(): the one-argument form creates an unbound super
        # object and never reaches Dataset.__init__
        super().__init__()

        data_path = Path(data_path)
        gt_path = Path(gt_path)

        self.img_paths = []
        self.gt_paths = []

        # sorted for a reproducible sample order; iterdir() order is arbitrary
        for class_dir in sorted(p for p in data_path.iterdir() if p.is_dir()):
            for img_path in sorted(class_dir.iterdir()):
                self.img_paths.append(img_path)
                # masks carry a "_mask" suffix, e.g. 000.png -> 000_mask.png
                mask_name = f"{img_path.stem}_mask{img_path.suffix}"
                self.gt_paths.append(gt_path / class_dir.name / mask_name)

    def __getitem__(self, index: int):
        """Return the RGB PIL image and its mask as a 2D array with values in [0, 1]."""
        img_path = self.img_paths[index]
        # look the mask up in gt_paths (not img_paths), otherwise the image
        # itself would be returned as its own ground truth
        gt_path = self.gt_paths[index]

        # convert() loads the pixel data, so the file can be closed right after
        with Image.open(img_path) as img_file:
            img = img_file.convert("RGB")

        if gt_path.exists():
            with Image.open(gt_path) as gt_file:
                gt = np.array(gt_file.convert("L")) / 255
        else:
            # there are no gt annotations for good cases -> all 0
            # (one value per pixel, without a channel axis, like the real masks)
            gt = np.zeros((img.height, img.width))

        return img, gt

    def __len__(self) -> int:
        """Return the number of test images over all classes."""
        return len(self.img_paths)

In [None]:
val_path = data_path / "test"
gt_path = data_path / "ground_truth"

val_ds = ValidationDataset(val_path, gt_path)

With that in place, we can loop through the validation dataset and store ground truth and anomaly score predictions

In [None]:
# Run the detector over the whole validation set and collect pixel-level and
# image-level ground truth / prediction pairs for the ROC computation.

# load one sample only to learn the image size used for pre-allocation
# (assumes all validation images share this size — TODO confirm)
img, gt = val_ds[0]

# flat buffers with one entry per pixel of every validation image
pred_size = len(val_ds) * img.height * img.width
preds_pix = np.empty(pred_size, dtype=np.float32)
gts_pix = np.empty(pred_size, dtype=np.int32)
# buffers with one entry per validation image
preds_img = np.empty(len(val_ds), dtype=np.float32)
gts_img = np.empty(len(val_ds), dtype=np.int32)

feature_extractor = feature_extractor.to(Config.device)

# write offset into the pixel-level buffers
i_pix = 0

for i in range(len(val_ds)):
    img, gt = val_ds[i]
    # cast to the integer dtype of the ground-truth buffers
    gt = gt.astype(np.int32)

    anomaly_map = detector(img)
    n_pix = anomaly_map.shape[0] * anomaly_map.shape[1]

    # store the flattened map and mask in this image's slot
    # NOTE(review): requires gt to hold exactly n_pix values (a 2D mask) — confirm
    preds_pix[i_pix : i_pix + n_pix] = anomaly_map.reshape((-1,))
    gts_pix[i_pix : i_pix + n_pix] = gt.reshape((-1,))

    # use max score of the map as image-level anomaly score
    preds_img[i] = anomaly_map.max()
    # for good images gt will be all zero, for defect images max will be 1
    gts_img[i] = gt.max()

    i_pix += n_pix

The AUROC score is computed using the ground truth values and prediction scores

In [None]:
fpr_img, tpr_img, thresholds_img = roc_curve(gts_img, preds_img)
auroc_img = auc(fpr_img, tpr_img)

print(f"image-wise AUROC: {auroc_img:.5f}")

In [None]:
p = figure(
    title=f"ROC curve for image-wise prediction (area = {auroc_img:.5f})",
    x_axis_label="False Positive Rate",
    y_axis_label="True Positive Rate",
)
p.line(fpr_img, tpr_img, line_width=2)
show(p)

Pixel-wise AUROC (TODO keep?)

In [None]:
# fpr_pix, tpr_pix, thresholds_pix = roc_curve(gts_pix, preds_pix)
# auroc_pix = auc(fpr_pix, tpr_pix)

# print(f"pixel-wise AUROC: {auroc_pix:.5f}")