# Using a pre-trained feature extractor for optical anomaly detection

In [None]:
# To autoreload external functions
%load_ext autoreload
%autoreload 2

In [None]:
from typing import Optional
import os
from pathlib import Path
from PIL import Image
import math
import numpy as np

import albumentations as A
from albumentations.pytorch import ToTensorV2

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.models.feature_extraction import get_graph_node_names, create_feature_extractor
import timm


from bokeh.plotting import figure, show

import rootutils
root = rootutils.setup_root(Path.cwd(), dotenv=True, pythonpath=True, cwd=False)

from src.visualization.utils import save_plot_from_notbook_for_jekyll, bokeh_notebook_setup, save_plot_from_notebook_to_html
from src.visualization.image import plot_img_rgba, add_bboxes_on_img, add_seg_on_img

## Setup

In [None]:
bokeh_notebook_setup()

In [None]:
data_path = Path("../data/raw/transistor")
output_path = Path("./logs")

## Introduction

- The task is similar to the one in the previous post, with two differences:
  - the images no longer contain regular patterns
  - we want a method that can be applied to several tasks
- Previous research has shown that features extracted from deep learning models pre-trained on ImageNet are highly effective for this task
  - see SPADE, Gaussian AD, PaDiM and PatchCore
- We will follow the approach of these papers but change a few components

## Dataset

Like in the previous post, we will use the [MVTec anomaly detection dataset](https://www.mvtec.com/company/research/datasets/mvtec-ad) which you can download from the website.
The dataset contains 15 different categories. For the examples in this post we will use the 'Transistor' category.

Here is a normal example without an anomaly:

In [None]:
img_path = data_path / "train/good/000.png"

img = Image.open(img_path)
img = img.convert("RGBA")

p = plot_img_rgba(img)
show(p)

And, in contrast, an anomalous example:

In [None]:
img_path = data_path / "test/bent_lead/000.png"
seg_path = data_path / "ground_truth/bent_lead/000_mask.png"

img = Image.open(img_path)
img = img.convert("RGBA")

seg = Image.open(seg_path)
seg = np.array(seg)

p = plot_img_rgba(img)
p = add_seg_on_img(p, seg)

show(p)

## Feature Extraction

As in the PaDiM and PatchCore papers, we extract features for each image patch of the training set using a vision network pre-trained on the ImageNet dataset. The patch size is determined by our choice of network layer.
For the feature extraction we use the torchvision `feature_extraction` module [based on Torch FX](https://pytorch.org/blog/FX-feature-extraction-torchvision/).
The goal of this post is to demonstrate the principle rather than to optimize the approach for the dataset. Hence, we will simplify many steps compared to the papers.

For the backbone we pick ResNet 34 as it is simple to use and doesn't require much memory.

In [None]:
backbone = timm.create_model("resnet34", pretrained=True)

In the papers, features from several layers were combined. To keep it simple, we will use only one layer.
To see the available layer names for feature extraction, you can use

In [None]:
train_nodes, eval_nodes = get_graph_node_names(backbone)

Looking at `train_nodes` or `eval_nodes`, you will see that ResNet 34 has 4 main blocks. If you just want to pick the last node of a block, the `feature_extraction` module allows you to use truncated node names. Here we will use `'layer3'` to get the last node of all the `layer3.x.ops` nodes. We choose layer 3 as a compromise between expressive high-level features and a reasonably high spatial resolution of the feature map.

In [None]:
feature_extractor = create_feature_extractor(backbone, return_nodes=["layer3"])

As in the papers, we keep the pre-trained ImageNet weights fixed. Hence, we can turn off gradient computation to save memory

In [None]:
for param in feature_extractor.parameters():
    param.requires_grad = False

feature_extractor.eval()  # use the stored batch norm statistics for inference

To simplify experimenting with different configurations, we use a Config object

In [None]:
class Config:
    img_shape = (512, 512)  # height, width (multiple of 16)
    batch_size = 4
    num_workers: int = 2  # adjust to the number of processing cores you want to use
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    red_factor = 16  # spatial reduction factor (depends on the chosen ResNet layer)
    n_feats = 256  # number of features (depends on the chosen ResNet layer)
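
The values of `red_factor` and `n_feats` follow from the ResNet 34 architecture: the stem reduces the resolution by a factor of 4, and each of the stages `layer2` to `layer4` halves it again while doubling the number of channels. As a small sanity check, the sketch below tabulates the resulting feature map shapes for our input size. The dictionary `RESNET34_STAGES` and the helper `feature_map_shape` are illustrative names and not part of timm or torchvision.

```python
import math

# (spatial reduction factor, number of channels) at the output of each ResNet 34 stage
RESNET34_STAGES = {
    "layer1": (4, 64),
    "layer2": (8, 128),
    "layer3": (16, 256),
    "layer4": (32, 512),
}


def feature_map_shape(img_shape, layer):
    """Return (channels, height, width) of the feature map for a given input size."""
    red_factor, n_feats = RESNET34_STAGES[layer]
    h, w = img_shape
    return n_feats, math.ceil(h / red_factor), math.ceil(w / red_factor)


for layer in RESNET34_STAGES:
    print(layer, feature_map_shape((512, 512), layer))
```

If you switch to a different layer, both `red_factor` and `n_feats` in the Config have to be adjusted accordingly.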


To store the features we follow the memory bank approach from PatchCore: we save the extracted features into one large array without linking them to their original patch location. This is in contrast to the PaDiM approach. The advantage is that our approach becomes more robust to rotations, translations and other variations of the objects in the dataset. The disadvantage is that the number of feature vectors each patch has to be compared with becomes quite large:
for N training images and feature maps of height H and width W, the memory bank contains N * H * W feature vectors.
A nearest neighbor lookup as in the PatchCore paper therefore requires tricks like coreset reduction.
We will get around this by choosing a different anomaly detection approach.
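
To get a feeling for the size of such a memory bank, the number of feature vectors is the number of training images times the height and width of the feature map. The following back-of-the-envelope sketch computes this together with the memory needed for a dense float32 array. The helper `memory_bank_stats` is a hypothetical name, and the image count of 213 is an assumed example value; use `len(train_ds)` for the real number.

```python
def memory_bank_stats(n_images, feat_h, feat_w, n_feats, bytes_per_value=4):
    """Number of feature vectors and size in bytes of a dense float32 memory bank."""
    n_vectors = n_images * feat_h * feat_w
    n_bytes = n_vectors * n_feats * bytes_per_value
    return n_vectors, n_bytes


# Assumed example values: 213 training images, 32 x 32 feature map, 256 channels
n_vectors, n_bytes = memory_bank_stats(n_images=213, feat_h=32, feat_w=32, n_feats=256)
print(f"{n_vectors} feature vectors, {n_bytes / 2**20:.0f} MiB")
```

With these assumed values every test patch would have to be compared with 218,112 stored vectors, which illustrates why a plain nearest neighbor search becomes expensive.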

First, we create a PyTorch Dataset object to hold the image data and specify the necessary transformations

In [None]:
class TrainDataset(Dataset):
    def __init__(self, data_path: os.PathLike, transforms: Optional[A.Compose] = None):
        super().__init__()

        # Sort the paths so that the image order is reproducible
        self.img_paths = sorted(Path(data_path).iterdir())
        self.transforms = transforms

    def __getitem__(self, index: int):
        img_path = self.img_paths[index]

        img = Image.open(img_path)
        img = img.convert("RGB")
        img = np.array(img)

        if self.transforms:
            img = self.transforms(image=img)["image"]

        return img

    def __len__(self) -> int:
        return len(self.img_paths)

In [None]:
train_path = data_path / "train/good"

default_transforms = A.Compose(
    [
        A.Resize(Config.img_shape[0], Config.img_shape[1]),
        A.Normalize(
            mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0
        ),
        ToTensorV2(),
    ]
)

train_ds = TrainDataset(train_path, transforms=default_transforms)
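
The mean and standard deviation passed to `A.Normalize` are the usual ImageNet channel statistics. To make explicit what the normalization step does to the pixel values, here is a minimal NumPy sketch of the same arithmetic. The function `normalize` and the dummy image are illustrative only and not part of the pipeline above.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def normalize(img_uint8, max_pixel_value=255.0):
    """Scale an (H, W, 3) uint8 image to [0, 1] and standardize each channel."""
    img = img_uint8.astype(np.float32) / max_pixel_value
    return (img - IMAGENET_MEAN) / IMAGENET_STD


# A uniform mid-gray dummy image instead of a real transistor photo
dummy = np.full((4, 4, 3), 128, dtype=np.uint8)
print(normalize(dummy)[0, 0])
```

The pre-trained backbone saw inputs normalized this way during training, so the same statistics should be used when extracting features.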

Afterwards, we create the DataLoader object to feed our data to the feature extractor

In [None]:
train_dl = DataLoader(
    train_ds,
    batch_size=Config.batch_size,
    shuffle=False,
    num_workers=Config.num_workers,
)

We create a function containing the logic to call the feature extractor with a batch of images and collect the resulting features:

In [None]:
def get_features(imgs, extractor, cfg):
    imgs = imgs.to(cfg.device)

    with torch.no_grad():
        feature_dict = extractor(imgs)

    layer = list(feature_dict.keys())[0]

    feats = feature_dict[layer]

    feats = feats.cpu().numpy()
    feats = np.transpose(feats, (0, 2, 3, 1))
    feats = feats.reshape(-1, cfg.n_feats)

    return feats
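
The transpose followed by the reshape is the step that turns a batch of feature maps into a flat list of feature vectors, one per patch. The toy example below uses a tiny dummy array instead of real network outputs to check that each row of the flattened array indeed holds all channels of a single spatial location.

```python
import numpy as np

# Dummy feature maps: batch of 2 images, 3 channels, 2 x 2 spatial locations
b, c, h, w = 2, 3, 2, 2
feats = np.arange(b * c * h * w, dtype=np.float32).reshape(b, c, h, w)

# Same two steps as in get_features: channels last, then flatten the locations
flat = np.transpose(feats, (0, 2, 3, 1)).reshape(-1, c)

# Row index of the patch at image 1, row 0, column 1
row = (1 * h + 0) * w + 1
print(flat.shape)  # (8, 3)
print(np.array_equal(flat[row], feats[1, :, 0, 1]))  # True
```

Reshaping without the transpose would also yield an array of shape `(8, 3)`, but its rows would mix values from different locations and channels.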

Finally, we can put everything together to compute the feature memory bank

In [None]:
h, w = Config.img_shape[:2]

# Spatial resolution of the layer 3 feature map
h_layer = math.ceil(h / Config.red_factor)
w_layer = math.ceil(w / Config.red_factor)

memory_bank_size = len(train_ds) * h_layer * w_layer
memory_bank = np.empty((memory_bank_size, Config.n_feats), dtype=np.float32)

feature_extractor = feature_extractor.to(Config.device)

i_mem = 0

for imgs in train_dl:
    feats = get_features(imgs, feature_extractor, Config)
    memory_bank[i_mem : i_mem + feats.shape[0]] = feats
    i_mem += feats.shape[0]
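
Although the memory bank does not store patch locations explicitly, the fill order of the loop still allows us to recover them for debugging, as long as the DataLoader does not shuffle. The following sketch maps a row index back to the image index and the corresponding grid cell in the resized image. The helper `locate_patch` is a hypothetical name, and the returned box is only the grid cell; the actual receptive field of a layer 3 feature is considerably larger.

```python
def locate_patch(row, h_layer, w_layer, red_factor):
    """Map a memory bank row to (image index, pixel box) in the resized image."""
    img_idx, rest = divmod(row, h_layer * w_layer)
    y, x = divmod(rest, w_layer)
    # Top-left and bottom-right corner of the grid cell in pixel coordinates
    box = (x * red_factor, y * red_factor, (x + 1) * red_factor, (y + 1) * red_factor)
    return img_idx, box


print(locate_patch(row=1057, h_layer=32, w_layer=32, red_factor=16))
```

The image index refers to the position in `train_ds.img_paths`, which only matches the memory bank order because `shuffle=False` is set in the DataLoader.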

Printing the memory bank shape, we see that it contains more than 200,000 feature vectors, one per feature map location and training image.

In [None]:
print(memory_bank.shape)

Hence, in the next step, where we compute an anomaly score for each patch of a test image by comparing it with the memory bank, we need an anomaly detection approach that can deal with such a large number of vectors.

## Anomaly Detection