In [1]:
import common
import sampler
import glass
import backbones
import utils

from perlin import perlin_mask

import os
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import time
import PIL
from enum import Enum

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl

from PIL import Image
from torchvision import transforms

import glob
from tqdm import tqdm

import warnings
warnings.filterwarnings('ignore')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from argparse import ArgumentParser

parser = ArgumentParser(description="patchcore")
parser.add_argument('--image_size', default=288, type=int) # 288
parser.add_argument('--resize', default=288, type=int) # 288
parser.add_argument('--backbone', default='wideresnet50', type=str) 
parser.add_argument('--layers_to_extract_from', nargs='+', default=['layer2', 'layer3'], type=str) # 연산량에 큰 영향을 주는 파라미터 (성능만 괜찮으면 layer를 일부 제거해도 무방)
parser.add_argument('--pretrain_embed_dimension', default=1024, type=int)
parser.add_argument('--target_embed_dimension', default=1024, type=int)
parser.add_argument('--patchsize', default=3, type=int)
parser.add_argument('--coreset_rate', default=0.1, type=float) # 연산량에 큰 영향을 주는 파라미터 (0.01 - 0.25)
parser.add_argument('--anomaly_scorer_num_nn', default=1, type=int)
parser.add_argument('--batch_size', default=2, type=int)
parser.add_argument('--batch_size_inf', default=6, type=int) # 실제 capacitor cap 데이터의 열에 일치하도록 임의 설정
parser.add_argument('--augment', default=True, type=bool)
parser.add_argument('--seed', default=0, type=int) 
parser.add_argument('--device', nargs='+', default=[0], type=int)
parser.add_argument('--num_workers', default=0, type=int)
args = parser.parse_args('')

image_size = args.image_size
resize = args.resize
BATCH_SIZE = args.batch_size
BATCH_SIZE_INF = args.batch_size_inf
SEED = args.seed

# create save path
save_root = "saved"
if not os.path.exists(save_root):
    os.mkdir(save_root)

# set random seeds
def set_seeds(seed=SEED):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    pl.seed_everything(SEED)

set_seeds()

Seed set to 0


In [3]:
class AnomalyDataset(Dataset):
    def __init__(self, transform=None, dir="../Data/mvtec/screw/train/good", test=False):
        super().__init__()
        self.test = test        
        self.transform = transform
        self.list_dir = sorted(glob.glob(os.path.join(dir, "*.png")))
        self.list_data = []
        if test: self.list_original_data = []

        for idx, dir in enumerate(self.list_dir):
            x = Image.open(dir).convert("RGB")
            if test: self.list_original_data.append(x)
            if self.transform: x = self.transform(x)
            self.list_data.append(x)
        
        print(f"num_data: {len(self.list_data)}")

    def __len__(self):
        return len(self.list_dir)
    
    def __getitem__(self, idx):
        x = self.list_data[idx]
        return {"image": x}

In [None]:
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

class DatasetSplit(Enum):
    TRAIN = "train"
    TEST = "test"

class AnomalyDataset(Dataset):
    def __init__(self, 
                 source="D:/Dev/Data/mvtec/screw",
                 anomaly_source_path="D:/Dev/Data/dtd/images",
                 # dataset_name='mvtec',
                 # classname='screw',
                 resize=288,
                 imagesize=288,
                 split=DatasetSplit.TRAIN,
                 rotate_degrees=0,
                 translate=0,
                 brightness_factor=0,
                 contrast_factor=0,
                 saturation_factor=0,
                 gray_p=0,
                 h_flip_p=0,
                 v_flip_p=0,
                 distribution=0,
                 mean=0.5,
                 std=0.1,
                 fg=0,
                 rand_aug=1,
                 downsampling=8,
                 scale=0,
                 batch_size=8):
        super().__init__()
        self.source = source
        self.split = split
        self.batch_size = batch_size
        self.distribution = distribution
        self.mean = mean
        self.std = std
        self.fg = fg
        self.rand_aug = rand_aug
        self.downsampling = downsampling
        self.resize = resize if self.distribution != 1 else [resize, resize]
        self.imgsize = imagesize
        self.imagesize = (3, self.imgsize, self.imgsize)
        # self.classname = classname
        # self.dataset_name = dataset_name

        self.imgpaths, self.data_to_iterate = self.get_image_data()
        self.anomaly_source_paths = sorted(1 * glob.glob(anomaly_source_path + "/*/*.jpg") +
                                           0 * list(self.imgpaths.values())[0])

        self.transform_img = [
            transforms.Resize(self.resize),
            transforms.ColorJitter(brightness_factor, contrast_factor, saturation_factor),
            transforms.RandomHorizontalFlip(h_flip_p),
            transforms.RandomVerticalFlip(v_flip_p),
            transforms.RandomGrayscale(gray_p),
            transforms.RandomAffine(rotate_degrees,
                                    translate=(translate, translate),
                                    scale=(1.0 - scale, 1.0 + scale),
                                    interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(self.imgsize),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ]
        self.transform_img = transforms.Compose(self.transform_img)

        self.transform_mask = [
            transforms.Resize(self.resize),
            transforms.CenterCrop(self.imgsize),
            transforms.ToTensor(),
        ]
        self.transform_mask = transforms.Compose(self.transform_mask)

    def rand_augmenter(self):
        list_aug = [
            transforms.ColorJitter(contrast=(0.8, 1.2)),
            transforms.ColorJitter(brightness=(0.8, 1.2)),
            transforms.ColorJitter(saturation=(0.8, 1.2), hue=(-0.2, 0.2)),
            transforms.RandomHorizontalFlip(p=1),
            transforms.RandomVerticalFlip(p=1),
            transforms.RandomGrayscale(p=1),
            transforms.RandomAutocontrast(p=1),
            transforms.RandomEqualize(p=1),
            transforms.RandomAffine(degrees=(-45, 45)),
        ]
        aug_idx = np.random.choice(np.arange(len(list_aug)), 3, replace=False)

        transform_aug = [
            transforms.Resize(self.resize),
            list_aug[aug_idx[0]],
            list_aug[aug_idx[1]],
            list_aug[aug_idx[2]],
            transforms.CenterCrop(self.imgsize),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ]

        transform_aug = transforms.Compose(transform_aug)
        return transform_aug
    
    def __getitem__(self, idx):
        # TEST, anomaly != "good": [anomaly, img_path, maskpath]
        # else: [anomaly, img_path, None]
        anomaly, image_path, mask_path = self.data_to_iterate[idx] # [anomaly유형, img경로]
        image = PIL.Image.open(image_path).convert("RGB")
        image = self.transform_img(image)
        
        mask_fg = mask_s = aug_image = torch.tensor([1])
        if self.split == DatasetSplit.TRAIN:
            aug = PIL.Image.open(np.random.choice(self.anomaly_source_paths)).convert("RGB")
            if self.rand_aug:
                transform_aug = self.rand_augmenter()
                aug = transform_aug(aug)
            else:
                aug = self.transform_img(aug)

            mask_all = perlin_mask(image.shape, self.imgsize // self.downsampling, 0, 6, mask_fg, 1)
            mask_s = torch.from_numpy(mask_all[0])
            mask_l = torch.from_numpy(mask_all[1])

            beta = np.random.normal(loc=self.mean, scale=self.std)
            beta = np.clip(beta, .2, .8)
            aug_image = image * (1 - mask_l) + (1 - beta) * aug * mask_l + beta * image * mask_l

        if self.split == DatasetSplit.TEST and mask_path is not None:
            mask_gt = PIL.Image.open(mask_path).convert('L')
            mask_gt = self.transform_mask(mask_gt)
        else:
            mask_gt = torch.zeros([1, *image.size()[1:]])

        # print(f"image: {image.shape} mask_gt: {mask_gt.shape}") # [3, 288, 288] [1, 288, 288]

        return {
            "image": image,
            "aug": aug_image,
            "mask_s": mask_s,
            "mask_gt": mask_gt,
            "is_anomaly": int(anomaly != "good"),
            "image_path": image_path,
        }
    
    def __len__(self):
        return len(self.data_to_iterate)
    
    def get_image_data(self):
        imgpaths = {}
        maskpaths = {}

        classpath = os.path.join(self.source, self.split.value)
        # maskpath = os.path.join(self.source, "ground_truth")
        anomaly_types = os.listdir(classpath)

        for anomaly in anomaly_types:
            anomaly_path = os.path.join(classpath, anomaly)
            anomaly_files = sorted(os.listdir(anomaly_path))
            imgpaths[anomaly] = [os.path.join(anomaly_path, x) for x in anomaly_files]

            if self.split == DatasetSplit.TEST and anomaly != "good":
                # anomaly_mask_path = os.path.join(maskpath, anomaly)
                # anomaly_mask_files = sorted(os.listdir(anomaly_mask_path))
                # maskpaths[anomaly] = [os.path.join(anomaly_mask_path, x) for x in anomaly_mask_files]
            else:
                maskpaths["good"] = None

        data_to_iterate = []
        # for classname in sorted(imgpaths_per_class.keys()):
        for anomaly in sorted(imgpaths.keys()):
            for i, image_path in enumerate(imgpaths[anomaly]):
                # data_tuple = [classname, anomaly, image_path]
                data_tuple = [anomaly, image_path]
                if self.split == DatasetSplit.TEST and anomaly != "good":
                    data_tuple.append(maskpaths[anomaly][i])
                else:
                    data_tuple.append(None)
                data_to_iterate.append(data_tuple)

        return imgpaths, data_to_iterate

In [5]:
train_dataset = AnomalyDataset(
    source="D:/Dev/Data/mvtec/speefox",
    anomaly_source_path="D:/Dev/Data/dtd/images",
    resize=960,
    imagesize=960,
    split=DatasetSplit.TRAIN,
    rotate_degrees=0,
    translate=0,
    brightness_factor=0,
    contrast_factor=0,
    saturation_factor=0,
    gray_p=0,
    h_flip_p=0,
    v_flip_p=0,
    distribution=0,
    mean=0.5,
    std=0.1,
    fg=0,
    rand_aug=1,
    downsampling=8,
    scale=0,
    batch_size=8
)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=0,
    # prefetch_factor=2,
    pin_memory=True,
)

In [6]:
test_dataset = AnomalyDataset(
    source="D:/Dev/Data/mvtec/speefox",
    anomaly_source_path="D:/Dev/Data/dtd/images",
    resize=960,
    imagesize=960,
    split=DatasetSplit.TEST,
)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=False,
    num_workers=0,
    # prefetch_factor=2,
    pin_memory=True,
)

FileNotFoundError: [WinError 3] 지정된 경로를 찾을 수 없습니다: 'D:/Dev/Data/mvtec/speefox\\ground_truth\\test'

In [None]:
device = utils.set_torch_device(gpu_ids=args.device)
model = glass.GLASS(device)

backbone = backbones.load(args.backbone)
backbone.name, backbone.seed = args.backbone, None

model.load(
    backbone                 = backbone,
    layers_to_extract_from   = args.layers_to_extract_from,
    device                   = device,
    input_shape              = (3, image_size, image_size), ##
    pretrain_embed_dimension = args.pretrain_embed_dimension,
    target_embed_dimension   = args.target_embed_dimension,
    patchsize                = args.patchsize,
)

# 저장 경로 설정
save_path = "saved"
if not os.path.exists(save_path):
    os.mkdir(save_path)

models_dir = os.path.join(save_path, "models")
model.set_model_dir(os.path.join(models_dir, f"backbone"))

In [None]:
# train
flag = model.trainer(train_dataloader, test_dataloader)

# save model
i_auroc, i_ap, p_auroc, p_ap, p_pro, epoch = model.tester(test_dataloader, dataset_name)
result_collect.append(
    {
        "dataset_name": dataset_name,
        "image_auroc": i_auroc,
        "image_ap": i_ap,
        "pixel_auroc": p_auroc,
        "pixel_ap": p_ap,
        "pixel_pro": p_pro,
        "best_epoch": epoch,
    }
)

In [None]:
# # load model
# model.load_from_path(load_path=save_root, 
#                           device=device, 
#                           nn_method=common.FaissNN(on_gpu=True, num_workers=8))

# # validate -> threshold 설정을 위해 validation 수행
# scores, _ = model.predict(valid_dataloader)

# # set threshold
# threshold = np.max(scores)
# print(f"threshold: {threshold}")

In [None]:
# """
# fault classes
# screw: manipulated_front | scratch_head | scratch_neck | thread_side | thread_top
# bottle: broken_large | broken_small | contamination
# metal_nut: bent | color | flip | scratch
# """
# test_dataset = AnomalyDataset(dir="../Data/mvtec/bottle/test/broken_large", transform=test_transform, test=True)
# test_dataloader = DataLoader(test_dataset, 
#                              batch_size=BATCH_SIZE_INF, 
#                              shuffle=False, 
#                              num_workers=args.num_workers,
#                              pin_memory=True)

In [None]:
# scores_list = []
# seg_list = []

# # inference
# start = time.time()
# scores, seg = model.predict(test_dataloader)
# print(f"Avg Prediction Time: {(time.time() - start) / (len(test_dataloader) * BATCH_SIZE_INF) :.6f}")

# scores_list.append(scores)
# seg_list.append(seg)

# scores = np.max(scores_list, axis=0)
# prediction = np.where(scores < threshold, 0, 1)

# print(f"n_anomaly: {np.sum(prediction)} / {scores.shape[-1]}")

In [None]:
# # scores = np.array(scores_list)
# # min_scores = scores.min(axis=-1).reshape(-1, 1)
# # max_scores = scores.max(axis=-1).reshape(-1, 1)
# # scores = (scores - min_scores) / (max_scores - min_scores)
# # scores = np.mean(scores, axis=0)

# segmentations = np.array(seg_list)
# min_scores = (segmentations.reshape(len(segmentations), -1).min(axis=-1).reshape(-1, 1, 1, 1))
# max_scores = (segmentations.reshape(len(segmentations), -1).max(axis=-1).reshape(-1, 1, 1, 1))
# segmentations = (segmentations - min_scores) / (max_scores - min_scores)
# segmentations = np.mean(segmentations, axis=0)

# # imgs = [np.transpose(x, (1, 2, 0)) for x in test_dataset.list_original_data]
# imgs = [x for x in test_dataset.list_original_data]

In [None]:
# for idx in range(len(imgs)):
#     f, axes = plt.subplots(1, 2)
#     axes[0].imshow(imgs[idx])
#     axes[1].imshow(segmentations[idx])
#     f.tight_layout()