In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pyroml import PyroModel, Stage, Trainer

# Required to make imgaug work (presumably imported transitively by the
# anomalib imports below): it reads `np.sctypes`, which NumPy 2.0 removed.
# Only recreate a minimal mapping when it is missing, so NumPy 1.x keeps its
# own, more complete `np.sctypes` instead of having it overwritten.
if not hasattr(np, "sctypes"):
    np.sctypes = dict(
        float=[np.float16, np.float32, np.float64],
        int=[np.int8, np.int16, np.int32, np.int64],
        uint=[np.uint8, np.uint16, np.uint32, np.uint64],
    )


from anomalib.models.components import DynamicBufferMixin, KCenterGreedy  # noqa: E402
from anomalib.models.image.patchcore.anomaly_map import (  # noqa: E402
    AnomalyMapGenerator,
)

In [3]:
class PatchcoreModel(DynamicBufferMixin, nn.Module):
    """PatchCore anomaly-detection model built around a frozen feature extractor.

    In training mode ``forward`` returns flattened patch embeddings; the caller
    collects them and condenses them into ``memory_bank`` with
    ``subsample_embedding``. In eval mode ``forward`` scores every patch by its
    distance to the nearest memory-bank entry.

    Args:
        backbone: Feature extractor. Assumed to return a ``dict[str, Tensor]`` of
            4-D ``(batch, channels, height, width)`` feature maps — TODO confirm
            against the backbone wrapper.
        num_neighbors: Number of memory-bank neighbours used to re-weight the
            image-level score (see ``compute_anomaly_score``).
    """

    def __init__(
        self,
        backbone: nn.Module,
        num_neighbors: int = 9,
    ) -> None:
        super().__init__()

        self.num_neighbors = num_neighbors
        # 3x3 average pooling with stride 1 / padding 1 keeps the spatial size.
        self.feature_pooler = nn.AvgPool2d(3, 1, 1)
        self.feature_extractor = backbone
        self.anomaly_map_generator = AnomalyMapGenerator()
        # Empty until subsample_embedding() fills it; registered as a buffer so it
        # follows the module across devices and into the state dict.
        self.register_buffer("memory_bank", torch.Tensor())
        self.memory_bank: torch.Tensor

    def postprocess_features(
        self, features: dict[str, torch.Tensor]
    ) -> dict[str, torch.Tensor]:
        """Smooth every feature map with the local 3x3 average pooler.

        Args:
            features: Mapping from layer name to feature map.

        Returns:
            A new mapping with the same keys and pooled feature maps.
        """
        return {layer: self.feature_pooler(feature) for layer, feature in features.items()}

    def get_embedding(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Run the frozen backbone and merge its feature maps into one embedding map.

        The backbone is switched to eval mode and run without gradients.
        """
        self.feature_extractor.eval()
        with torch.no_grad():
            features = self.feature_extractor(input_tensor)
        features = self.postprocess_features(features)
        return self.generate_embedding(features)

    def embed(self, input_tensor: torch.Tensor) -> tuple[torch.Tensor, int, int, int]:
        """Embed a batch and flatten it to one row per patch.

        Returns:
            ``(embedding, B, W, H)`` where ``embedding`` has shape
            ``(B * W * H, C)``. ``W`` and ``H`` are just the third and fourth
            dimensions of the embedding map, in that order; passing them back to
            ``embedding_to_score`` restores the same layout.
        """
        embedding = self.get_embedding(input_tensor)
        B, _, W, H = embedding.shape
        embedding = self.reshape_embedding(embedding)
        return embedding, B, W, H

    def embedding_to_score(
        self, embedding: torch.Tensor, B: int, W: int, H: int
    ) -> dict[str, torch.Tensor]:
        """Score flattened patch embeddings against the memory bank.

        Args:
            embedding: ``(B * W * H, C)`` patch embeddings, as returned by ``embed``.
            B, W, H: Batch size and spatial dims, as returned by ``embed``.

        Returns:
            ``{"anomaly_map": (B, 1, W, H) float32, "pred_score": (B,)}``.
        """
        # Distance of every patch to its single nearest memory-bank entry.
        patch_scores, locations = self.nearest_neighbors(
            embedding=embedding, n_neighbors=1
        )
        patch_scores = patch_scores.reshape((B, -1))
        locations = locations.reshape((B, -1))

        pred_score = self.compute_anomaly_score(patch_scores, locations, embedding)

        # Lay the patch scores back out on the feature-map grid; cast to float32
        # in case the model runs in reduced precision.
        anomaly_map = patch_scores.reshape((B, 1, W, H)).float()
        return {"anomaly_map": anomaly_map, "pred_score": pred_score}

    def forward(self, input_tensor: torch.Tensor):
        """Return Embedding during training, or a dict of anomaly map and anomaly score during testing.

        Steps performed:
        1. Get features from a CNN.
        2. Generate embedding based on the features.
        3. Compute anomaly map in test mode.

        If the memory bank is still empty in eval mode, a warning is emitted and
        all-zero outputs are returned with the same shapes as a normal eval pass.

        Args:
            input_tensor (torch.Tensor): Input tensor

        Returns:
            Tensor | dict[str, torch.Tensor]: Embedding for training, anomaly map and anomaly score for testing.
        """
        embedding, B, W, H = self.embed(input_tensor)

        if self.training:
            return embedding

        if len(self.memory_bank) == 0:
            warnings.warn("empty memory bank during eval / test")
            # Same shapes and device as embedding_to_score(): a (B, 1, W, H)
            # patch grid and one score per image.
            return {
                "anomaly_map": torch.zeros((B, 1, W, H), device=embedding.device),
                "pred_score": torch.zeros(B, device=embedding.device),
            }

        return self.embedding_to_score(embedding, B, W, H)

    def generate_embedding(self, features) -> torch.Tensor:
        """Generate embedding from hierarchical feature map.

        Every feature map after the first is bilinearly resized to the first
        map's spatial size, then all maps are concatenated along dim 1.

        Args:
            features (dict[str, Tensor]): Hierarchical feature maps from the
                backbone, in layer order.

        Returns:
            Embedding map with the first layer's spatial size.
        """
        feature_maps = list(features.values())
        target_size = feature_maps[0].shape[-2:]
        aligned = [feature_maps[0]] + [
            F.interpolate(feature_map, size=target_size, mode="bilinear")
            for feature_map in feature_maps[1:]
        ]
        # One concatenation instead of growing the tensor layer by layer.
        return torch.cat(aligned, 1)

    @staticmethod
    def reshape_embedding(embedding: torch.Tensor) -> torch.Tensor:
        """Reshape Embedding.

        Reshapes Embedding to the following format:
            - [Batch, Embedding, Patch, Patch] to [Batch*Patch*Patch, Embedding]

        Args:
            embedding (torch.Tensor): Embedding tensor extracted from CNN features.

        Returns:
            Tensor: Reshaped embedding tensor.
        """
        embedding_size = embedding.size(1)
        return embedding.permute(0, 2, 3, 1).reshape(-1, embedding_size)

    def subsample_embedding(
        self, embedding: torch.Tensor, sampling_ratio: float
    ) -> None:
        """Subsample embedding based on coreset sampling and store to memory.

        Args:
            embedding (torch.Tensor): ``(N, C)`` patch embeddings to subsample.
            sampling_ratio (float): Coreset sampling ratio.
        """
        sampler = KCenterGreedy(embedding=embedding, sampling_ratio=sampling_ratio)
        # Assigning to the registered buffer name replaces the buffer's contents.
        self.memory_bank = sampler.sample_coreset()

    @staticmethod
    def euclidean_dist(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Calculate pair-wise distance between row vectors in x and those in y.

        Replaces torch cdist with p=2, as cdist is not properly exported to onnx and openvino format.
        Resulting matrix is indexed by x vectors in rows and y vectors in columns.
        Leading (batch) dimensions broadcast through ``torch.matmul``.

        Args:
            x: input tensor 1, ``(..., N, C)``
            y: input tensor 2, ``(..., M, C)``

        Returns:
            ``(..., N, M)`` matrix of distances between row vectors in x and y.
        """
        x_sq_norm = x.pow(2).sum(dim=-1, keepdim=True)  # |x|^2
        y_sq_norm = y.pow(2).sum(dim=-1, keepdim=True)  # |y|^2

        # |x - y|^2 = |x|^2 - 2 * x @ y.T + (|y|^2).T
        res = (
            x_sq_norm
            - 2 * torch.matmul(x, y.transpose(-2, -1))
            + y_sq_norm.transpose(-2, -1)
        )
        # Rounding can make tiny squared distances negative; clamp before sqrt.
        return res.clamp_min_(0).sqrt_()

    def nearest_neighbors(
        self, embedding: torch.Tensor, n_neighbors: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Nearest Neighbours using brute force method and euclidean norm.

        Args:
            embedding (torch.Tensor): ``(N, C)`` features to compare with the memory bank.
            n_neighbors (int): Number of neighbors to look at.

        Returns:
            Tensor: Patch scores — ``(N,)`` when ``n_neighbors == 1``, else
                ``(N, n_neighbors)`` sorted nearest first.
            Tensor: Memory-bank indices of the nearest neighbor(s), same shape.
        """
        distances = self.euclidean_dist(embedding, self.memory_bank)

        if n_neighbors == 1:
            # min is cheaper than topk for a single neighbour.
            return distances.min(1)
        return distances.topk(k=n_neighbors, largest=False, dim=1)

    def compute_anomaly_score(
        self,
        patch_scores: torch.Tensor,
        locations: torch.Tensor,
        embedding: torch.Tensor,
    ) -> torch.Tensor:
        """Compute Image-Level Anomaly Score.

        Args:
            patch_scores (torch.Tensor): ``(B, num_patches)`` patch-level anomaly scores.
            locations: ``(B, num_patches)`` memory bank locations of the nearest
                neighbor for each patch location.
            embedding: ``(B * num_patches, C)`` feature embeddings that generated
                the patch scores.

        Returns:
            Tensor: ``(B,)`` image-level anomaly scores.
        """
        # Don't need to compute weights if num_neighbors is 1
        if self.num_neighbors == 1:
            return patch_scores.amax(1)

        batch_size, num_patches = patch_scores.shape
        batch_idx = torch.arange(batch_size)

        # 1. Find the patch with the largest distance to its nearest neighbor in
        #    each image (indices of m^test,* in the paper) and its features.
        max_patches = torch.argmax(patch_scores, dim=1)
        max_patches_features = embedding.reshape(batch_size, num_patches, -1)[
            batch_idx, max_patches
        ]

        # 2. Distance of that patch to its nearest neighbor: s^* in the paper.
        score = patch_scores[batch_idx, max_patches]

        if self.num_neighbors <= 0:
            # Re-weighting disabled: the raw maximum patch score is the result.
            return score

        # 3. Memory-bank entry nearest to that patch (m^*), then its own support
        #    samples N_b(m^*) in the bank, capped by the bank size.
        nn_index = locations[batch_idx, max_patches]
        nn_sample = self.memory_bank[nn_index, :]
        memory_bank_effective_size = self.memory_bank.shape[0]
        _, support_samples = self.nearest_neighbors(
            nn_sample,
            n_neighbors=min(self.num_neighbors, memory_bank_effective_size),
        )
        if support_samples.dim() == 1:
            # nearest_neighbors returns (B,) instead of (B, k) when k == 1
            # (single-entry memory bank); restore the neighbour axis so the
            # distances below stay per image.
            support_samples = support_samples.unsqueeze(1)

        # 4. Distance of each max patch to each of its support samples: (B, 1, k).
        distances = self.euclidean_dist(
            max_patches_features.unsqueeze(1), self.memory_bank[support_samples]
        )

        # 5. Weight = 1 - softmax over the support distances, taken at the first
        #    (closest, since topk sorts ascending) support sample.
        weights = (1 - F.softmax(distances.squeeze(1), 1))[..., 0]

        # 6. Apply the weight factor to the score: s in the paper.
        return weights * score


class PyroPatchcore(PyroModel):
    """pyroml wrapper that fits and evaluates a ``PatchcoreModel``.

    Training batches are only embedded; the embeddings are accumulated on the
    CPU. When training ends they are condensed into the model's memory bank by
    coreset subsampling. Validation batches are scored against that memory bank.

    Args:
        backbone: Feature extractor forwarded to ``PatchcoreModel``.
        coreset_sampling_ratio: Sampling ratio passed to the coreset sampler
            when the memory bank is built.
        num_neighbors: Neighbour count forwarded to ``PatchcoreModel``.
    """

    # Batch keys that may hold the input images, tried in order. "img" is the
    # key this class originally required. "image" is the key exposed by the
    # MVTec dataset samples loaded in this notebook.
    IMAGE_KEYS = ("img", "image")

    def __init__(
        self,
        backbone: nn.Module,
        coreset_sampling_ratio: float = 0.1,
        num_neighbors: int = 9,
    ) -> None:
        super().__init__()
        self.model = PatchcoreModel(
            backbone=backbone,
            num_neighbors=num_neighbors,
        )
        self.coreset_sampling_ratio = coreset_sampling_ratio
        # Per-batch training embeddings, kept on the CPU until construct().
        self.embeddings: list[torch.Tensor] = []

    def _image_key(self, batch) -> str:
        """Return the first key from ``IMAGE_KEYS`` present in ``batch``.

        Raises:
            KeyError: If the batch holds none of the accepted image keys.
        """
        for key in self.IMAGE_KEYS:
            if key in batch:
                return key
        raise KeyError(
            f"batch has none of the image keys {self.IMAGE_KEYS}; got {list(batch)}"
        )

    def training_step(self, batch) -> None:
        """Embed a training batch and stash the embeddings on the CPU."""
        embedding = self.model(batch[self._image_key(batch)])
        self.embeddings.append(embedding.cpu())

    def validation_step(self, batch):
        """Score a batch against the memory bank.

        Returns:
            A new dict with every entry of ``batch`` except the images, plus
            ``"anomaly_maps"`` and ``"pred_scores"`` moved to the CPU.
        """
        image_key = self._image_key(batch)
        output = self.model(batch[image_key])
        # Build a new dict rather than `del batch[...]` so the caller's batch is
        # left untouched.
        result = {key: value for key, value in batch.items() if key != image_key}
        result["anomaly_maps"] = output["anomaly_map"].cpu()
        result["pred_scores"] = output["pred_score"].cpu()
        return result

    def step(self, batch: dict[str, torch.Tensor], stage: Stage):
        """Dispatch to ``training_step`` for ``Stage.TRAIN``, else ``validation_step``."""
        if stage == Stage.TRAIN:
            return self.training_step(batch)
        return self.validation_step(batch)

    def configure_optimizers(self, trainer: Trainer):
        """Nothing to configure: the memory bank is built in ``construct``."""
        pass

    def _fit(self, loss: torch.Tensor):
        """Intentionally a no-op: this model has no parameters to optimise."""
        pass

    def on_train_end(self, args):
        """Build the memory bank once training has finished."""
        self.construct()

    def anomaly_map(self, patch_scores, image_size=256):
        """Upscale patch scores to image resolution via the anomaly map generator.

        Args:
            patch_scores: Patch-level scores (the model's ``"anomaly_map"`` output).
            image_size: Output size — an int for a square image, or an
                ``(height, width)`` pair.
        """
        # NOTE(review): the default of 256 differs from the 224x224 crops used in
        # this notebook; pass image_size explicitly when that matters.
        if isinstance(image_size, int):
            size = [image_size, image_size]
        else:
            size = list(image_size)
        return self.model.anomaly_map_generator(patch_scores, size)

    def construct(self) -> None:
        """Build the memory bank from the embeddings gathered during training.

        Raises:
            RuntimeError: If no training embeddings were collected.
        """
        if not self.embeddings:
            raise RuntimeError(
                "No training embeddings were collected; cannot build the memory bank."
            )
        print("Aggregating the embedding extracted from the training set.")
        # NOTE(review): the module is moved to the CPU while the stacked embeddings
        # go to the trainer's device, presumably to free accelerator memory for
        # the coreset search. Confirm that the trainer moves the model back
        # before evaluation.
        self.cpu()
        embeddings = torch.vstack(self.embeddings).to(self.trainer.device)
        print("Applying core-set subsampling to get the embedding.")
        self.model.subsample_embedding(embeddings, self.coreset_sampling_ratio)
        # Free the per-batch copies and make sure a later fit starts from scratch
        # instead of mixing in stale embeddings.
        self.embeddings.clear()

In [4]:
from pyroml.models.backbone import Backbone


# ResNet-18 feature extractor returning the outputs of layers 2 and 3.
# Alternative backbone tried earlier (kept for reference):
#   Backbone.load("vit_base_patch16_224.mae", pre_trained=True, image_size=(3, 224, 224))
backbone = Backbone.load(
    "resnet18",
    pre_trained=True,
    layers=(2, 3),
    image_size=(3, 224, 224),
)
backbone.feature_dims

Backbone has 2,782,784 params


{'layer2': torch.Size([1, 128, 28, 28]),
 'layer3': torch.Size([1, 256, 14, 14])}

In [5]:
from timm.data import resolve_model_data_config, create_transform

# Ask timm for the preprocessing settings it resolves for this backbone, then
# build the matching (non-training) transform pipeline from them.
data_config = resolve_model_data_config(backbone)
transforms = create_transform(**data_config, is_training=False)
(data_config, transforms)

({'input_size': (3, 224, 224),
  'interpolation': 'bicubic',
  'mean': (0.485, 0.456, 0.406),
  'std': (0.229, 0.224, 0.225),
  'crop_pct': 0.875,
  'crop_mode': 'center'},
 Compose(
     Resize(size=256, interpolation=bicubic, max_size=None, antialias=True)
     CenterCrop(size=(224, 224))
     MaybeToTensor()
     Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
 ))

In [None]:
import fiftyone as fo
import fiftyone.utils.huggingface as fouh

# NOTE(review): the saved output of this cell shows a huggingface_hub version
# mismatch followed by `KeyError: 'tags'`, and `dataset` is not used by any
# later cell shown here (MVTec is loaded through pyroml's MVTecDataset below).
# Consider removing this cell so "Restart & Run All" does not stop here.
# Relax fiftyone's package-requirement check; presumably level 1 reports the
# mismatch instead of raising — TODO confirm against fiftyone's config docs.
fo.config.requirement_error_level = 1
# Load the dataset
# Note: other available arguments include 'max_samples', etc
dataset = fouh.load_from_hub("Voxel51/mvtec-ad")

The requested operation requires that 'huggingface_hub>=0.20.0' is installed on your machine, but found 'huggingface_hub==0.19.4'.


  from .autonotebook import tqdm as notebook_tqdm


Downloading config file fiftyone.yml from Voxel51/mvtec-ad


fiftyone.yml: 100%|██████████| 127/127 [00:00<?, ?B/s] 

Loading dataset





KeyError: 'tags'

In [13]:
from pyroml.template.mvtec.dataset import MVTecDataset
from timm.data.transforms_factory import transforms_imagenet_eval

# timm's default eval-time preprocessing.
transform = transforms_imagenet_eval()
full_train_ds = MVTecDataset(split="train", transform=transform)

# Both subsets come from the *training* split: `te_ds` is a held-out part of
# the training images, not the MVTec test split. A seeded generator makes the
# split reproducible across kernel restarts.
N_TRAIN = 2500
SPLIT_SEED = 42
assert len(full_train_ds) >= N_TRAIN, (
    f"need at least {N_TRAIN} samples, got {len(full_train_ds)}"
)
tr_ds, te_ds = torch.utils.data.random_split(
    full_train_ds,
    [N_TRAIN, len(full_train_ds) - N_TRAIN],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

len(tr_ds), len(te_ds), tr_ds[0].keys()

Resolving data files:   0%|          | 0/6617 [00:00<?, ?it/s]

Saving the dataset (0/10 shards):   0%|          | 0/6613 [00:00<?, ? examples/s]

dict_keys(['image'])


KeyError: 'img'

In [None]:
# Shape probe: run one random 224x224 image through the backbone's feature trunk.
# no_grad avoids building an autograd graph for what is only a shape check.
with torch.no_grad():
    probe_features = backbone.model.forward_features(torch.rand(1, 3, 224, 224))
probe_features.shape

In [None]:
# Keep 1% of the training patch embeddings in the memory bank and re-weight
# image scores using 9 memory-bank neighbours.
model = PyroPatchcore(
    backbone,
    coreset_sampling_ratio=0.01,
    num_neighbors=9,
)
model

In [None]:
from pyroml.callbacks.progress.tqdm_progress import TQDMProgress


# Use the GPU when one is available. Otherwise fall back to the CPU, in float32,
# so a fresh "Restart & Run All" does not fail on machines without CUDA.
# Half precision is kept for the GPU path only.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

trainer = Trainer(
    lr=0,  # nothing is optimised: training only collects embeddings
    # Every epoch appends its embeddings again, so more epochs would only
    # duplicate the memory-bank candidates.
    max_epochs=1,
    evaluate_on=False,
    device=DEVICE,
    pin_memory=False,
    dtype=DTYPE,
    wandb=False,
    callbacks=[TQDMProgress(stack_bars=False)],
)

trainer.fit(model, tr_dataset=tr_ds)