In [2]:
from pathlib import Path
from typing import Any, Dict

import numpy as np
from IPython.display import display
from PIL import Image
from pytorch_lightning import Trainer
from torchvision.transforms import ToPILImage

from anomalib.config import get_configurable_parameters
from anomalib.data import get_datamodule
from anomalib.models import get_model
from anomalib.pre_processing.transforms import Denormalize
from anomalib.utils.callbacks import LoadModelCallback, get_callbacks


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Model whose bundled YAML config drives this notebook.
MODEL = "multi_category_patchcore"
CONFIG_PATH = f"../anomalib/models/{MODEL}/config.yaml"

# Echo the raw config so the reader can see every setting used below.
config_text = Path(CONFIG_PATH).read_text(encoding="utf-8")
print(config_text)

dataset:
  name: multi_category_mvtec #options: [mvtec, btech, folder]
  format: multi_category_mvtec
  path: ./datasets/MVTec
  task: segmentation
  category:
    - bottle
    - cable
    - capsule
    - carpet
    - grid
    - hazelnut
    - leather
    - metal_nut
    - pill
    - screw
    - tile
    - toothbrush
    - transistor
    - wood
    - zipper
  image_size: 224
  train_batch_size: 32
  test_batch_size: 1
  num_workers: 0
  transform_config:
    train: null
    val: null
  create_validation_set: false
  tiling:
    apply: false
    tile_size: null
    stride: null
    remove_border_count: 0
    use_random_tiling: False
    random_tile_count: 16

model:
  name: multi_category_patchcore
  backbone: wide_resnet50_2
  classifier: mobilenet_v2
  pre_trained: true
  layers:
    - layer2
    - layer3
  coreset_sampling_ratio: 0.1
  num_neighbors: 9
  normalization_method: min_max # options: [null, min_max, cdf]

metrics:
  image:
    - F1Score
    - AUROC
  pixel:
    - F1Score
 

In [4]:
# Parse the YAML config; the resulting object is passed to the model,
# the callbacks and the datamodule.
config = get_configurable_parameters(config_path=CONFIG_PATH)

# Override the dataset location relative to this notebook
# (change this to wherever the MVTec dataset is stored).
config.dataset.path = "../datasets/MVTec"

# Show which MVTec categories will be used.
config.dataset.category

['bottle', 'cable', 'capsule', 'carpet', 'grid', 'hazelnut', 'leather', 'metal_nut', 'pill', 'screw', 'tile', 'toothbrush', 'transistor', 'wood', 'zipper']

In [5]:
# Build the datamodule described by `config`.
datamodule = get_datamodule(config)

# Lightning's DataModule contract: prepare_data() performs the on-disk work
# (e.g. downloading the dataset) and setup() then builds the datasets from
# what is on disk, so prepare_data() must run first. Calling setup() first
# can fail on a fresh machine where the dataset is not yet present.
datamodule.prepare_data()
datamodule.setup()