# Analyzing features for anomaly detection from pre-trained models

In [None]:
# Automatically reload edited external modules (e.g. the project's `src` helpers) before each cell runs
%load_ext autoreload
%autoreload 2

In [None]:
import pickle
from typing import Optional
import os
from dotenv import load_dotenv
from pathlib import Path
from PIL import Image
import math
import random
import numpy as np
import cv2
from sklearn.decomposition import PCA
from sklearn.metrics import roc_curve, auc
from scipy.ndimage import gaussian_filter
import umap

import albumentations as A
from albumentations.pytorch import ToTensorV2

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.models import list_models, get_model
from torchvision.models.feature_extraction import (
    get_graph_node_names,
    create_feature_extractor,
)


import matplotlib.pyplot as plt
import seaborn as sns
from bokeh.plotting import figure, show
import bokeh

import rootutils

# Locate the project root, searching upwards from the notebook's working directory.
# NOTE(review): per rootutils' parameters, pythonpath=True presumably puts the root on
# sys.path (needed for the `src.*` imports below), dotenv=True loads a `.env` file, and
# cwd=False leaves the working directory unchanged (so later relative paths such as
# "../data" stay relative to the notebook) -- confirm against the rootutils docs.
root = rootutils.setup_root(Path.cwd(), dotenv=True, pythonpath=True, cwd=False)

from src.visualization.utils import (
    save_plot_from_notbook_for_jekyll,
    bokeh_notebook_setup,
)
from src.visualization.image import (
    plot_img_rgba,
    add_seg_on_img,
    add_score_map_on_img,
)
from src.visualization.features import plot_feature_samples, plot_feature_3d_samples

## Setup

In [None]:
# Interactive (widget-based) matplotlib backend, e.g. so the 3D plots further down can be rotated
%matplotlib ipympl

In [None]:
# Route bokeh output into the notebook (project helper)
bokeh_notebook_setup()

sns.set_style('darkgrid')

# Seed Python's and NumPy's global RNGs so random steps (e.g. training-image
# subsampling via `random.sample`) repeat across runs
seed = 1
for seed_fn in (random.seed, np.random.seed):
    seed_fn(seed)

In [None]:
# Input: MVTec AD "metal_nut" category (relative to the notebook's working directory).
# Output: directory for saved plots/logs.
data_path, output_path = Path("../data/raw/metal_nut"), Path("./logs")

## Introduction

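- take the feature extraction approach from the previous blog post
- apply standard PCA (keeping the components that explain 90 % of the training-feature variance) and embed the result in 2 dimensions with UMAP
- repeat with a modified ("negative") PCA that keeps only the remaining feature combinations with the smallest variance
- plot and compare normal and anomalous features
- repeat the experiment with a 3D embedding
  - explore interactive 3D plots (matplotlib with the `ipympl` backend)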

## Dataset

Again, we use the 'Metal Nut' category from the [MVTec anomaly detection dataset](https://www.mvtec.com/company/research/datasets/mvtec-ad).

## Feature Extraction

See the previous post for details.

In [None]:
class Config:
    """Experiment settings, accessed through class attributes (``Config.x``).

    ``red_factor`` and ``n_feats`` start as ``None``; the notebook fills them in
    after running one batch through the feature extractor.
    """

    # Backbone and the layer(s) whose activations are used as patch features
    model_name: str = "convnext_base"
    layer_names: list = ["features.3"]

    # Input resolution as (height, width)
    img_shape: tuple = (224, 224)

    # DataLoader settings; set num_workers to the number of processing cores you want to use
    batch_size: int = 4
    num_workers: int = 2

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Derived values (depend on the chosen layer)
    red_factor: Optional[int] = None  # spatial reduction factor (equivalent to patch size)
    n_feats: Optional[int] = None  # number of features

In [None]:
class PatchCoreModel(nn.Module):
    """Wrap a feature extractor and locally average every feature map it returns.

    Each map goes through a 3x3 average pool with stride 1 and padding 1, so its
    height/width are unchanged and every position aggregates its 3x3
    neighbourhood. Zero padding is included in the average at the borders
    (``AvgPool2d`` default ``count_include_pad=True``).
    """

    def __init__(self, feature_extractor):
        super().__init__()
        self.feature_extractor = feature_extractor
        # kernel 3, stride 1, padding 1 -> same spatial size as the input
        self.patch_layer = torch.nn.AvgPool2d(kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Return ``{layer_name: pooled_feature_map}`` for the input batch ``x``."""
        raw_feats = self.feature_extractor(x)
        return {name: self.patch_layer(fmap) for name, fmap in raw_feats.items()}


class TrainDataset(Dataset):
    """Defect-free training images, loaded as RGB and optionally transformed.

    Args:
        data_path: Directory containing the training images (one file per image).
        transforms: Optional albumentations pipeline, called as ``transforms(image=...)``.
        N_train: If given and smaller than the number of images, keep a random
            subset of ``N_train`` images, drawn with Python's ``random`` module
            (so it follows ``random.seed``).
    """

    def __init__(
        self,
        data_path: os.PathLike,
        transforms: Optional[A.Compose] = None,
        N_train: Optional[int] = None,
    ):
        # super().__init__() -- `super(TrainDataset)` builds an unbound super object
        # and never reaches Dataset.__init__.
        super().__init__()

        # Sorted so the dataset order (and thus a seeded subsample) does not depend
        # on the filesystem's listing order; sub-directories are skipped.
        self.img_paths = sorted(p for p in Path(data_path).iterdir() if p.is_file())
        self.transforms = transforms

        if N_train is not None and len(self.img_paths) > N_train:
            self.img_paths = random.sample(self.img_paths, N_train)

    def __getitem__(self, index: int):
        """Return image ``index`` as an RGB array, or the output of ``transforms``."""
        with Image.open(self.img_paths[index]) as pil_img:
            img = np.array(pil_img.convert("RGB"))

        if self.transforms:
            img = self.transforms(image=img)["image"]

        return img

    def __len__(self) -> int:
        return len(self.img_paths)


class ValidationDataset(Dataset):
    """Test images paired with pixel-level ground-truth masks.

    Images are read from ``data_path/<class>/<image>``; the matching mask is
    expected at ``gt_path/<class>/<image stem>_mask<image suffix>``. Images
    without a mask file (the defect-free "good" class has none) get an all-zero
    mask. Masks are scaled from 0..255 to 0..1.

    Args:
        data_path: Root of the test split, one sub-directory per class.
        gt_path: Root of the ground-truth masks, mirroring ``data_path``'s classes.
        transforms: Optional albumentations pipeline, called as
            ``transforms(image=..., mask=...)``.

    ``__getitem__`` returns ``(img, gt)``.
    """

    def __init__(
        self,
        data_path: os.PathLike,
        gt_path: os.PathLike,
        transforms: Optional[A.Compose] = None,
    ):
        # super().__init__() -- `super(ValidationDataset)` never reaches Dataset.__init__.
        super().__init__()

        data_path = Path(data_path)
        gt_path = Path(gt_path)

        self.img_paths = []
        self.gt_paths = []

        # Sorted for a deterministic sample order independent of the filesystem.
        class_dirs = sorted(p for p in data_path.iterdir() if p.is_dir())
        for class_dir in class_dirs:
            for img_path in sorted(p for p in class_dir.iterdir() if p.is_file()):
                self.img_paths.append(img_path)
                self.gt_paths.append(
                    gt_path / class_dir.name / f"{img_path.stem}_mask{img_path.suffix}"
                )
        self.transforms = transforms

    def __getitem__(self, index: int):
        img_path = self.img_paths[index]
        gt_path = self.gt_paths[index]

        # Always an RGB array (like TrainDataset), with or without transforms.
        with Image.open(img_path) as pil_img:
            img = np.array(pil_img.convert("RGB"))

        if gt_path.exists():
            with Image.open(gt_path) as pil_gt:
                gt = np.array(pil_gt.convert("L")) / 255
        else:
            # there are no gt annotations for good cases -> all 0
            gt = np.zeros(img.shape[:2])

        if self.transforms:
            transformed = self.transforms(image=img, mask=gt)
            img = transformed["image"]
            gt = transformed["mask"]

        return img, gt

    def __len__(self) -> int:
        return len(self.img_paths)

In [None]:
def get_features(imgs, extractor, cfg):
    """Extract patch features for a batch and flatten them to one row per patch.

    Args:
        imgs: Image batch tensor; moved to ``cfg.device`` before extraction.
        extractor: Callable returning ``{layer_name: (B, C_l, H_l, W_l) tensor}``.
        cfg: Object with ``device`` and ``n_feats`` (total channels over all layers).

    Returns:
        ``np.ndarray`` of shape ``(B * H * W, cfg.n_feats)``, where ``H, W`` is the
        spatial size of the first layer. Rows are ordered batch-major, then
        row-major over the feature grid.

    All returned layers are concatenated along the channel axis. Deeper layers
    with a smaller grid are bilinearly resized to the first layer's grid, which
    is the one ``cfg.red_factor`` is derived from. Previously only the first
    layer was used, which disagreed with ``n_feats`` summing over all layers.
    """
    imgs = imgs.to(cfg.device)

    with torch.no_grad():
        feature_dict = extractor(imgs)

    layer_feats = list(feature_dict.values())
    target_size = layer_feats[0].shape[-2:]
    layer_feats = [
        fmap
        if fmap.shape[-2:] == target_size
        else F.interpolate(fmap, size=target_size, mode="bilinear", align_corners=False)
        for fmap in layer_feats
    ]
    feats = torch.cat(layer_feats, dim=1)

    feats = feats.cpu().numpy()
    # (B, C, H, W) -> (B, H, W, C) so each row of the reshape is one spatial position
    feats = np.transpose(feats, (0, 2, 3, 1))
    return feats.reshape(-1, cfg.n_feats)


def get_ground_truths(gts, cfg):
    """Downscale a batch of pixel masks to patch labels on the feature grid.

    Args:
        gts: Tensor of masks, shape ``(B, H, W)``, values in [0, 1].
        cfg: Object with ``red_factor`` (input pixels per feature-map cell).

    Returns:
        Flat ``np.ndarray`` of ``B * (H // red_factor) * (W // red_factor)``
        labels, ordered batch-major then row-major over the grid. This matches
        the row order produced by ``get_features``.
    """
    gts = gts.numpy()
    h, w = gts.shape[1:3]
    feat_h, feat_w = h // cfg.red_factor, w // cfg.red_factor
    patch_gts = np.zeros((gts.shape[0], feat_h, feat_w))

    for i, gt in enumerate(gts):
        # OpenCV takes dsize as (width, height); passing (height, width) would
        # transpose, or fail for, non-square masks.
        gt = cv2.resize(gt, dsize=(feat_w, feat_h), interpolation=cv2.INTER_LINEAR)
        # floor: a patch is labelled anomalous only if its interpolated mask value
        # reaches 1 (np.rint would instead label patches that are at least half anomalous).
        patch_gts[i] = np.floor(gt)

    return patch_gts.reshape(-1)

In [None]:
backbone = get_model(Config.model_name, weights="DEFAULT")
feature_extractor = create_feature_extractor(backbone, return_nodes=Config.layer_names)
# The backbone is only used for inference: freeze all weights.
feature_extractor.requires_grad_(False)

# eval() disables training-only behaviour (e.g. dropout / stochastic depth, which
# torchvision's ConvNeXt uses). Left in train mode, the extracted features would be
# random from run to run.
feature_extractor = PatchCoreModel(feature_extractor).eval()

In [None]:
# Folder layout of the MVTec AD category
train_path, val_path, gt_path = (
    data_path / "train/good",
    data_path / "test",
    data_path / "ground_truth",
)

# Resize to the configured input size, normalise with the ImageNet channel
# statistics, then convert the array to a CHW tensor.
default_transforms = A.Compose(
    [
        A.Resize(height=Config.img_shape[0], width=Config.img_shape[1]),
        A.Normalize(
            max_pixel_value=255.0,
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
        ToTensorV2(),
    ]
)

train_ds = TrainDataset(train_path, transforms=default_transforms)

# Fixed order: every training patch ends up in the memory bank anyway.
train_dl = DataLoader(
    train_ds,
    num_workers=Config.num_workers,
    batch_size=Config.batch_size,
    shuffle=False,
)

In [None]:
# Push one training batch through the extractor (a single forward pass, no autograd)
# to read off each requested layer's channel count and the spatial reduction factor.
# Model and batch are put on the same device so this cell can be re-run after the
# model has been moved to Config.device.
feature_extractor = feature_extractor.to(Config.device)
imgs = next(iter(train_dl))
with torch.no_grad():
    sample_feature_dict = feature_extractor(imgs.to(Config.device))
feats_shapes = [sample_feature_dict[layer_name].shape for layer_name in Config.layer_names]

# Total feature dimension: channels (dim 1) summed over all layers
Config.n_feats = sum(fs[1] for fs in feats_shapes)
# Input height divided by the first layer's feature-map height
Config.red_factor = Config.img_shape[0] // feats_shapes[0][2]

print("n feats:", Config.n_feats)
print("red factor:", Config.red_factor)

In [None]:
h, w = Config.img_shape[:2]
h_layer = math.ceil(h / Config.red_factor)
w_layer = math.ceil(w / Config.red_factor)

# Expected number of patch features: one row per feature-map position per training image
memory_bank_size = len(train_ds) * h_layer * w_layer
memory_bank = np.empty((memory_bank_size, Config.n_feats), dtype=np.float32)

feature_extractor = feature_extractor.to(Config.device)

i_mem = 0
for imgs in train_dl:
    feats = get_features(imgs, feature_extractor, Config)
    memory_bank[i_mem : i_mem + feats.shape[0]] = feats
    i_mem += feats.shape[0]

# np.empty leaves unwritten rows uninitialized; drop any so they never reach the PCA fit.
memory_bank = memory_bank[:i_mem]

In [None]:
print(f"Train memory bank shape: {memory_bank.shape}")

## Validation Data Features

In [None]:
# Test split with pixel-level ground truth, preprocessed exactly like the training images
val_ds = ValidationDataset(val_path, gt_path, transforms=default_transforms)

# Deterministic order (no shuffling)
val_dl = DataLoader(
    val_ds,
    shuffle=False,
    num_workers=Config.num_workers,
    batch_size=Config.batch_size,
)

In [None]:
# Expected number of validation patches: one per feature-map position per image
val_feat_bank_size = len(val_ds) * h_layer * w_layer
val_feat_bank = np.empty((val_feat_bank_size, Config.n_feats), dtype=np.float32)
val_gt_bank = np.zeros(val_feat_bank_size, dtype=np.float32)

feature_extractor = feature_extractor.to(Config.device)

i_mem = 0
for imgs, gts in val_dl:
    feats = get_features(imgs, feature_extractor, Config)
    patch_gts = get_ground_truths(gts, Config)
    # Row k of the feature bank and entry k of the label bank must describe the same patch.
    if patch_gts.shape[0] != feats.shape[0]:
        raise ValueError(
            f"Got {feats.shape[0]} feature rows but {patch_gts.shape[0]} patch labels; "
            "the feature-map size and the ground-truth downscaling disagree."
        )
    val_feat_bank[i_mem : i_mem + feats.shape[0]] = feats
    val_gt_bank[i_mem : i_mem + feats.shape[0]] = patch_gts
    i_mem += feats.shape[0]

# Drop any rows that were never written (np.empty leaves them uninitialized).
val_feat_bank = val_feat_bank[:i_mem]
val_gt_bank = val_gt_bank[:i_mem]

In [None]:
print(f"Validation feature bank shape: {val_feat_bank.shape}")
print(f"Validation ground truth bank shape: {val_gt_bank.shape}")

## PCA

See the [Gaussian-AD code](https://github.com/ORippler/gaussian-ad-mvtec/blob/bc10bd736d85b750410e6b0e7ac843061e09511e/src/gaussian/model.py#L207) for a reference implementation of a PCA that keeps the components with the least variance ("negative PCA").

In [None]:
# Full PCA (every component kept) fitted on the defect-free training patch features
X_train = memory_bank
pca = PCA(n_components=None)
pca.fit(X_train)

In [None]:
# Keep every 50th validation patch (features and labels stay aligned),
# presumably to keep the UMAP embeddings below tractable
X_val, y = val_feat_bank[::50], val_gt_bank[::50]

In [None]:
variance_threshold = 0.9
variances = pca.explained_variance_ratio_.cumsum()
# Index of the first component at which the cumulative explained variance exceeds the threshold
i_comp_thresh = (variances > variance_threshold).argmax()

# Center with the training mean, as PCA.transform does
X_val_centered = X_val - pca.mean_

# Normal PCA: components 0..i_comp_thresh, together explaining > variance_threshold
pca_comps = pca.components_[: i_comp_thresh + 1]
X_pca = np.matmul(X_val_centered, pca_comps.T)

# Negative PCA: exactly the remaining low-variance components, i.e. the complement
# of the set above. The previous slice [i_comp_thresh - 1:] overlapped the
# standard set by two components.
npca_comps = pca.components_[i_comp_thresh + 1 :]
if npca_comps.shape[0] == 0:
    raise ValueError(
        "variance_threshold leaves no low-variance components for the negative PCA"
    )
X_npca = np.matmul(X_val_centered, npca_comps.T)

print(X_pca.shape)
print(X_npca.shape)

In [None]:
n_dim = 2

# One independent UMAP reducer per projection; fitted in the order standard, then negative
umap_for_pca = umap.UMAP(n_components=n_dim)
umap_for_npca = umap.UMAP(n_components=n_dim)
X_pca_embed = umap_for_pca.fit_transform(X_pca)
X_npca_embed = umap_for_npca.fit_transform(X_npca)

# Side-by-side scatter plots of both embeddings, together with the patch labels `y`
p_pca, p_npca = (
    plot_feature_samples(embedding, y, title=plot_title, width=600, height=600)
    for embedding, plot_title in (
        (X_pca_embed, "Feature embedding after standard PCA"),
        (X_npca_embed, "Feature embedding after negative PCA"),
    )
)
p = bokeh.layouts.row(p_pca, p_npca)
show(p)

In [None]:
n_dim = 3

# Same as the 2D case, but embedding into three dimensions for the 3D plots below
umap_for_pca = umap.UMAP(n_components=n_dim)
umap_for_npca = umap.UMAP(n_components=n_dim)
X_pca_embed = umap_for_pca.fit_transform(X_pca)
X_npca_embed = umap_for_npca.fit_transform(X_npca)

In [None]:
# 3D embedding of the standard-PCA features, together with the patch labels `y`
ax = plot_feature_3d_samples(X_pca_embed, y, title="Feature embedding after standard PCA")
plt.show()

In [None]:
# 3D embedding of the negative-PCA features, together with the patch labels `y`
ax = plot_feature_3d_samples(X_npca_embed, y, title="Feature embedding after negative PCA")
plt.show()

In [None]:
# NOTE(review): `p` is the bokeh row of the 2D standard/negative-PCA embeddings, not a
# ROC curve; the file name below looks carried over from a previous post -- rename before enabling.
# plot_path = output_path / "ROC_curve.html"
# save_plot_from_notbook_for_jekyll(p, plot_path)

## Conclusion