# PatchCore

In [1]:
import common
import sampler
import patchcore
import backbones
import utils

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import os

import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import time

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl

from PIL import Image
from torchvision.transforms import v2

import glob
from tqdm import tqdm

import warnings
warnings.filterwarnings('ignore')

from argparse import ArgumentParser

# Experiment settings, declared as (flag, add_argument keyword options) pairs
# and registered on the parser in this order.
_ARGUMENT_SPECS = (
    ("--image_size", {"default": 128, "type": int}),  # 224
    ("--resize", {"default": 128, "type": int}),  # 224
    ("--backbone", {"default": "wideresnet101", "type": str}),
    ("--layers_to_extract_from", {"nargs": "+", "default": ["layer2", "layer3"], "type": str}),
    ("--pretrain_embed_dimension", {"default": 1024, "type": int}),
    ("--target_embed_dimension", {"default": 1024, "type": int}),
    ("--patchsize", {"default": 3, "type": int}),
    ("--coreset_rate", {"default": 0.1, "type": float}),
    ("--anomaly_scorer_num_nn", {"default": 5, "type": int}),
    ("--batch_size", {"default": 1, "type": int}),
    ("--batch_size_inf", {"default": 1, "type": int}),
    ("--cv", {"default": 5, "type": int}),
    ("--seed", {"default": 826, "type": int}),
    ("--device", {"nargs": "+", "default": [0], "type": int}),
    ("--num_workers", {"default": 0, "type": int}),
    ("--num_data", {"default": 50, "type": int}),
)

parser = ArgumentParser(description="patchcore")
for _flag, _options in _ARGUMENT_SPECS:
    parser.add_argument(_flag, **_options)

# An explicit empty argument list is parsed (instead of sys.argv), so every
# setting takes its declared default.
args = parser.parse_args([])

# Module-level shorthands for the parsed settings used by the cells below.
image_size, resize = args.image_size, args.resize
BATCH_SIZE, BATCH_SIZE_INF = args.batch_size, args.batch_size_inf
NUM_DATA = args.num_data
CV, SEED = args.cv, args.seed

# Directory the fitted model is saved to and later reloaded from.
save_root = "saved"
# exist_ok=True makes re-running the cell a no-op and avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(save_root, exist_ok=True)

def set_seeds(seed=SEED):
    """Seed the random number generators used in this notebook.

    Seeds NumPy, Python's ``random``, torch (CPU and CUDA) and Lightning, and
    switches cuDNN to deterministic mode with benchmarking disabled.

    Args:
        seed: Seed value applied to every generator. Defaults to ``SEED``.
    """
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Pass the argument through: previously the global SEED was used here, so
    # a caller-supplied seed was silently ignored by Lightning.
    pl.seed_everything(seed)

set_seeds()

Seed set to 826


## data loader

In [None]:
class AnomalyDataset(Dataset):
    """Map-style dataset over the ``*.png`` files of a single directory.

    Args:
        transform: Optional callable applied to each image after it has been
            converted to RGB.
        dir: Directory searched (non-recursively) for ``*.png`` files. Only the
            first ``NUM_DATA`` paths in sorted order are kept. The name shadows
            the builtin but is kept because callers pass it by keyword.
        test: When truthy, every image is loaded and transformed once in the
            constructor and served from memory afterwards, so that file
            loading does not count towards per-image inference timing.
    """

    def __init__(self, transform=None, dir="../Data/mvtec/bottle/train/good", test=None):
        super().__init__()
        self.test = test
        self.transform = transform
        self.list_dir = sorted(glob.glob(os.path.join(dir, "*.png")))[:NUM_DATA]
        print(f"num_data: {len(self.list_dir)}")

        # In test mode, preload everything so per-image time can be measured.
        if test:
            self.list_data = [self._load(path) for path in self.list_dir]

    def _load(self, path):
        """Open ``path``, convert it to RGB and apply the transform, if any."""
        # The context manager closes the file handle; convert() returns an
        # independent image that stays valid afterwards.
        with Image.open(path) as img:
            x = img.convert("RGB")
        if self.transform is not None:
            x = self.transform(x)
        return x

    def __len__(self):
        """Return the number of image paths kept by the constructor."""
        return len(self.list_dir)

    def __getitem__(self, idx):
        """Return ``{"image": x}`` for the sample at position ``idx``."""
        if self.test:
            x = self.list_data[idx]
        else:
            # Load only the requested file. The previous implementation looped
            # over every path and always returned the last image.
            x = self._load(self.list_dir[idx])
        return {"image": x}

In [4]:
def _build_transform():
    """Return a new pipeline: to image, resize to (resize, resize), float32 with scale=True."""
    steps = [
        v2.ToImage(),
        v2.Resize(size=(resize, resize)),
        v2.ToDtype(torch.float32, scale=True),
        # Normalisation is currently disabled:
        # v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return v2.Compose(steps)


# Training and test images go through the same steps; each gets its own Compose.
train_transform = _build_transform()
test_transform = _build_transform()

In [None]:
# "good" training images: passed to patch_core.fit further down.
train_dataset = AnomalyDataset(
    transform=train_transform,
    dir="../Data/mvtec/bottle/train/good",
)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=args.num_workers,
)

# "good" test images: their scores are used later to pick the threshold.
valid_dataset = AnomalyDataset(
    transform=test_transform,
    dir="../Data/mvtec/bottle/test/good",
)
valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=args.num_workers,
)

num_data: 50
num_data: 41


## train

In [6]:
# Resolve the torch device from the requested GPU ids and create the model on it.
device = utils.set_torch_device(gpu_ids=args.device)
patch_core = patchcore.PatchCore(device)

# Keyword settings for PatchCore.load, gathered in one place. The feature
# sampler and the nearest-neighbour backend are constructed here, in the same
# order as before, and handed over as ready-made objects.
_load_settings = {
    "backbone": args.backbone,
    "layers_to_extract_from": args.layers_to_extract_from,
    "device": device,
    "input_shape": (3, image_size, image_size),
    "pretrain_embed_dimension": args.pretrain_embed_dimension,
    "target_embed_dimension": args.target_embed_dimension,
    "patchsize": args.patchsize,
    "anomaly_scorer_num_nn": args.anomaly_scorer_num_nn,
    "featuresampler": sampler.GreedyCoresetSampler(percentage=args.coreset_rate, device=device),
    "nn_method": common.FaissNN(on_gpu=False, num_workers=args.num_workers),
}
patch_core.load(**_load_settings)

In [7]:
# Fit the model on the training dataloader (the "good" training images).
patch_core.fit(train_dataloader)

# Save the fitted model under `save_root` so the next cell can reload it.
patch_core.save_to_path(save_path=save_root)



In [None]:
# Reload the saved model from `save_root` with a fresh FaissNN (on_gpu=False).
_reload_nn_method = common.FaissNN(on_gpu=False, num_workers=args.num_workers)
patch_core.load_from_path(
    load_path=save_root,
    device=device,
    nn_method=_reload_nn_method,
)

# Score the "good" validation images; the second return value is not needed here.
scores, _ = patch_core.predict(valid_dataloader)

# The largest score among the validation images becomes the decision threshold.
threshold = np.max(scores)
print(f"threshold: {threshold}")

Inferring...:  98%|█████████▊| 40/41 [00:25<00:00,  1.66it/s]

## inference

In [None]:
# Available defect categories per product:
# screw: manipulated_front | scratch_head | scratch_neck | thread_side | thread_top
# bottle: broken_large | broken_small | contamination
test_dataset = AnomalyDataset(dir="../Data/mvtec/bottle/test/broken_large", transform=test_transform, test=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE_INF, shuffle=False, num_workers=args.num_workers)

# Accumulators consumed by the following cells; each holds one entry per run.
threshold_list = []
scores_list = []
seg_list = []

# Inference, timed with a monotonic clock so wall-clock adjustments cannot
# distort the measurement.
start = time.perf_counter()
scores, seg = patch_core.predict(
    test_dataloader
)
elapsed = time.perf_counter() - start

# Divide by the real number of images: len(dataloader) * batch_size overcounts
# whenever the last batch is only partially filled. max() guards the empty case.
num_images = max(len(test_dataset), 1)
print(f"Avg Prediction Time: {elapsed / num_images :.6f}")

threshold_list.append(threshold)
scores_list.append(scores)
seg_list.append(seg)

In [None]:
# Average the collected thresholds into a single value.
threshold = np.mean(threshold_list)
print(f"threshold: {threshold}")

# Element-wise maximum over the collected score sets (axis 0 indexes the sets).
scores = np.max(scores_list, axis=0)

# 0 where the score is below the threshold, 1 otherwise.
below_threshold = scores < threshold
prediction = (~below_threshold).astype(int)

n_anomaly = np.sum(prediction)
print(f"n_anomaly: {n_anomaly}")

In [None]:
# Min-max normalisation of the scores and of the segmentation maps (per entry
# of scores_list / seg_list) was tried here and is currently disabled; the raw
# segmentation outputs are averaged instead.

# Stack the collected segmentation outputs and average over axis 0.
segmentations = np.asarray(seg_list).mean(axis=0)

# Reorder the axes of every preloaded test image from (0, 1, 2) to (1, 2, 0)
# -- presumably channels-first to channels-last for imshow; confirm.
imgs = [np.transpose(image, (1, 2, 0)) for image in test_dataset.list_data]

In [None]:
# One figure per test image: the image on the left, its averaged segmentation
# output on the right (colour scale capped at vmax=7).
for idx, image in enumerate(imgs):
    f, axes = plt.subplots(1, 2)
    image_ax, map_ax = axes
    image_ax.imshow(image)
    map_ax.imshow(segmentations[idx], vmax=7)
    f.tight_layout()