In [None]:
import glass
import backbones
import utils
from perlin import perlin_mask

import os
import glob
import random
import PIL
import numpy as np
import matplotlib.pyplot as plt
from enum import Enum

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import pytorch_lightning as pl

import warnings
warnings.filterwarnings('ignore')

In [None]:
from argparse import ArgumentParser


def _str2bool(value):
    """Parse a command-line boolean flag value.

    ``type=bool`` is a well-known argparse trap: ``bool("False")`` is ``True``
    because every non-empty string is truthy. This converter accepts the usual
    spellings instead and rejects anything else so that argparse reports a
    proper usage error.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if text in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise ValueError(f"invalid boolean value: {value!r}")


parser = ArgumentParser(description="glass")
parser.add_argument('--image_size', default=288, type=int)  # 288
parser.add_argument('--resize', default=288, type=int)  # 288
parser.add_argument('--backbone', default='wideresnet50', type=str)
# Parameter with a large impact on compute cost (some layers may be dropped
# as long as the performance stays acceptable).
parser.add_argument('--layers_to_extract_from', nargs='+', default=['layer2', 'layer3'], type=str)
parser.add_argument('--pretrain_embed_dimension', default=1024, type=int)
parser.add_argument('--target_embed_dimension', default=1024, type=int)
parser.add_argument('--patchsize', default=3, type=int)
# Parameter with a large impact on compute cost (0.01 - 0.25).
parser.add_argument('--coreset_rate', default=0.1, type=float)
parser.add_argument('--anomaly_scorer_num_nn', default=1, type=int)
parser.add_argument('--batch_size', default=8, type=int)
# Set arbitrarily to match the number of columns in the real capacitor cap data.
parser.add_argument('--batch_size_inf', default=8, type=int)
parser.add_argument('--augment', default=True, type=_str2bool)
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--device', nargs='+', default=[0], type=int)
parser.add_argument('--num_workers', default=0, type=int)
# Parse an explicitly empty argv: inside a notebook sys.argv belongs to the
# kernel launcher, so only the defaults declared above are wanted here.
args = parser.parse_args([])

image_size = args.image_size
resize = args.resize
BATCH_SIZE = args.batch_size
BATCH_SIZE_INF = args.batch_size_inf
SEED = args.seed

# create save path (exist_ok makes the cell safe to re-run and race-free)
save_root = "saved"
os.makedirs(save_root, exist_ok=True)

# set random seeds
def set_seeds(seed=SEED):
    """Seed every random number generator this notebook relies on.

    Args:
        seed: Integer seed applied to NumPy, the ``random`` module, PyTorch
            (CPU and the current CUDA device) and PyTorch Lightning.
            Defaults to the notebook-level ``SEED``.
    """
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Ask cuDNN for deterministic kernels and turn off its auto-tuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Forward the argument: the global SEED used to be passed here, which
    # silently ignored any non-default seed given by the caller.
    pl.seed_everything(seed)

set_seeds()

In [None]:
# Per-channel mean / std handed to transforms.Normalize in AnomalyDataset.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

class DatasetSplit(Enum):
    """Dataset split selector.

    The member value is used as the sub-directory name that
    ``AnomalyDataset.get_image_data`` joins onto the dataset root.
    """
    TRAIN = "train"
    TEST = "test"

class AnomalyDataset(Dataset):
    """Dataset of tiles cut out of speefox inspection images.

    Every source image is treated as a ``GRID_ROWS`` x ``GRID_COLS`` grid and
    only the middle row is kept, so each image contributes ``GRID_COLS``
    tiles. All tiles are transformed and cached in memory at construction
    time; ``__getitem__`` serves the cached tensors. Because the image
    transform runs once in ``__init__``, any random augmentation configured
    through the constructor is sampled once per tile and then frozen.

    For the TRAIN split each item additionally carries a synthetic anomaly:
    a random texture image blended into the tile through a Perlin-noise mask.

    Args:
        source: Dataset root; images are read from
            ``<source>/<split.value>/<anomaly type>/``.
        anomaly_source_path: Directory searched with ``/*/*.jpg`` for the
            texture images blended in during training.
        resize: Size given to ``transforms.Resize`` (turned into
            ``[resize, resize]`` when ``distribution == 1``).
        imagesize: Side length of the centre crop.
        split: ``DatasetSplit`` member choosing the sub-directory.
        rotate_degrees, translate, scale: ``RandomAffine`` settings.
        brightness_factor, contrast_factor, saturation_factor:
            ``ColorJitter`` settings.
        gray_p, h_flip_p, v_flip_p: Probabilities for grayscale conversion
            and horizontal / vertical flips.
        distribution: Only compared against ``1`` to pick the resize form.
        mean, std: Parameters of the normal distribution the blend factor
            ``beta`` is drawn from.
        fg: Stored on the instance; not used by this class itself.
        rand_aug: When truthy, texture images go through a randomly
            assembled augmentation pipeline instead of ``transform_img``.
        downsampling: ``imagesize // downsampling`` is passed to
            ``perlin_mask``.
        anomaly_type: When given, only that sub-directory is loaded.
    """

    # Grid layout of one source image; only row index 1 (the middle row) is used.
    GRID_ROWS = 3
    GRID_COLS = 6

    def __init__(self, 
                 source="D:/Dev/Data/mvtec/speefox",
                 anomaly_source_path="D:/Dev/Data/dtd/images",
                 resize=288,
                 imagesize=288,
                 split=DatasetSplit.TRAIN,
                 rotate_degrees=0,
                 translate=0,
                 brightness_factor=0,
                 contrast_factor=0,
                 saturation_factor=0,
                 gray_p=0,
                 h_flip_p=0,
                 v_flip_p=0,
                 distribution=0,
                 mean=0.5,
                 std=0.1,
                 fg=0,
                 rand_aug=1,
                 downsampling=8,
                 scale=0,
                 anomaly_type=None,):
        super().__init__()
        self.source = source
        self.split = split
        self.distribution = distribution
        self.mean = mean
        self.std = std
        self.fg = fg
        self.rand_aug = rand_aug
        self.downsampling = downsampling
        self.resize = resize if self.distribution != 1 else [resize, resize]
        self.imgsize = imagesize
        self.imagesize = (3, self.imgsize, self.imgsize)
        self.anomaly_type = anomaly_type

        self.imgpaths, self.data_to_iterate = self.get_image_data()
        # Texture images used for synthetic anomalies. (The previous
        # "1 * glob + 0 * first class" expression evaluated to exactly this
        # list but crashed with IndexError when no class directory existed.)
        self.anomaly_source_paths = sorted(glob.glob(anomaly_source_path + "/*/*.jpg"))

        self.transform_img = [
            transforms.Resize(self.resize),
            transforms.ColorJitter(brightness_factor, contrast_factor, saturation_factor),
            transforms.RandomHorizontalFlip(h_flip_p),
            transforms.RandomVerticalFlip(v_flip_p),
            transforms.RandomGrayscale(gray_p),
            transforms.RandomAffine(rotate_degrees,
                                    translate=(translate, translate),
                                    scale=(1.0 - scale, 1.0 + scale),
                                    interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(self.imgsize),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ]
        self.transform_img = transforms.Compose(self.transform_img)

        self.transform_mask = [
            transforms.Resize(self.resize),
            transforms.CenterCrop(self.imgsize),
            transforms.ToTensor(),
        ]
        self.transform_mask = transforms.Compose(self.transform_mask)

        # image processing for speefox data: cut each image into tiles and
        # cache the transformed tensors (three parallel lists, one entry per tile)
        self.data_to_iterate_image_path = []
        self.data_to_iterate_image = []
        self.data_to_iterate_anomaly = []
        for anomaly, image_path in self.data_to_iterate:  # [anomaly type, image path]
            # Close the file handle as soon as the pixels are decoded.
            with PIL.Image.open(image_path) as raw_image:
                image = raw_image.convert("RGB")

            width, height = image.size
            tile_width = width // self.GRID_COLS
            tile_height = height // self.GRID_ROWS
            # Vertical extent of the middle row.
            upper, lower = tile_height, tile_height * 2

            for col in range(self.GRID_COLS):
                left = col * tile_width
                tile = image.crop((left, upper, left + tile_width, lower))
                tile = self.transform_img(tile)

                self.data_to_iterate_image_path.append(image_path)
                self.data_to_iterate_image.append(tile)
                self.data_to_iterate_anomaly.append(anomaly)

        print(f"len data per row: {len(self.data_to_iterate_image)} ({len(self.data_to_iterate)}*{self.GRID_COLS})")

    def rand_augmenter(self):
        """Build a transform pipeline with three randomly chosen augmentations.

        Returns:
            A ``transforms.Compose`` that resizes, applies three distinct
            augmentations drawn from a fixed candidate list, centre-crops,
            converts to a tensor and normalises.
        """
        list_aug = [
            transforms.ColorJitter(contrast=(0.8, 1.2)),
            transforms.ColorJitter(brightness=(0.8, 1.2)),
            transforms.ColorJitter(saturation=(0.8, 1.2), hue=(-0.2, 0.2)),
            transforms.RandomHorizontalFlip(p=1),
            transforms.RandomVerticalFlip(p=1),
            transforms.RandomGrayscale(p=1),
            transforms.RandomAutocontrast(p=1),
            transforms.RandomEqualize(p=1),
            transforms.RandomAffine(degrees=(-45, 45)),
        ]
        # Three different candidates (sampling without replacement).
        aug_idx = np.random.choice(np.arange(len(list_aug)), 3, replace=False)

        transform_aug = [
            transforms.Resize(self.resize),
            list_aug[aug_idx[0]],
            list_aug[aug_idx[1]],
            list_aug[aug_idx[2]],
            transforms.CenterCrop(self.imgsize),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ]

        transform_aug = transforms.Compose(transform_aug)
        return transform_aug
    
    def __getitem__(self, idx):
        """Return the cached tile ``idx`` and, for TRAIN, a synthetic anomaly.

        Returns:
            dict with keys ``image`` (cached tile tensor), ``aug`` (tile with
            a texture blended in), ``mask_s`` (first mask returned by
            ``perlin_mask``), ``is_anomaly`` (0 for the "good" class, else 1)
            and ``image_path``. Outside the TRAIN split ``aug`` and ``mask_s``
            are the placeholder ``torch.tensor([1])``.
        """
        image_path = self.data_to_iterate_image_path[idx]
        image = self.data_to_iterate_image[idx]
        anomaly = self.data_to_iterate_anomaly[idx]
        
        # Placeholders so that every split yields the same dictionary keys.
        mask_fg = mask_s = aug_image = torch.tensor([1])
        if self.split == DatasetSplit.TRAIN:
            # Close the texture file once it has been decoded.
            with PIL.Image.open(np.random.choice(self.anomaly_source_paths)) as raw_aug:
                aug = raw_aug.convert("RGB")
            if self.rand_aug:
                transform_aug = self.rand_augmenter()
                aug = transform_aug(aug)
            else:
                aug = self.transform_img(aug)

            mask_all = perlin_mask(image.shape, self.imgsize // self.downsampling, 0, 6, mask_fg, 1)
            mask_s = torch.from_numpy(mask_all[0])
            mask_l = torch.from_numpy(mask_all[1])

            # Blend factor drawn from N(mean, std) and clipped to [0.2, 0.8].
            beta = np.random.normal(loc=self.mean, scale=self.std)
            beta = np.clip(beta, .2, .8)
            # Outside the mask keep the tile; inside mix tile and texture.
            aug_image = image * (1 - mask_l) + (1 - beta) * aug * mask_l + beta * image * mask_l
        return {
            "image": image,
            "aug": aug_image,
            "mask_s": mask_s,
            "is_anomaly": int(anomaly != "good"),
            "image_path": image_path,
        }
    
    def __len__(self):
        """Number of cached tiles.

        The earlier ``len(self.data_to_iterate * 5)`` under-counted: every
        image contributes ``GRID_COLS`` (6) tiles, so one sixth of the tiles
        were unreachable and fixed-size batches no longer lined up with
        source images.
        """
        return len(self.data_to_iterate_image)
    
    def get_image_data(self):
        """Collect the image paths of the selected split.

        Returns:
            ``(imgpaths, data_to_iterate)`` where ``imgpaths`` maps each
            anomaly type to its sorted list of file paths and
            ``data_to_iterate`` is a flat list of ``[anomaly, image_path]``
            pairs ordered by anomaly type.
        """
        classpath = os.path.join(self.source, self.split.value)  # e.g. D:/Dev/Data/mvtec/speefox/test
        if not self.anomaly_type:
            # e.g. ["good", "A", "AI", "B", "E", "P", "T", "U", "V", "W"];
            # stray files next to the class folders are skipped.
            anomaly_types = [
                name for name in os.listdir(classpath)
                if os.path.isdir(os.path.join(classpath, name))
            ]
        else:  # an explicit anomaly_type was requested
            anomaly_types = [self.anomaly_type]

        imgpaths = {}
        for anomaly in anomaly_types:
            anomaly_path = os.path.join(classpath, anomaly)  # e.g. D:/Dev/Data/mvtec/speefox/test/A
            anomaly_files = sorted(os.listdir(anomaly_path))
            imgpaths[anomaly] = [os.path.join(anomaly_path, x) for x in anomaly_files]

        data_to_iterate = []
        for anomaly in sorted(imgpaths.keys()):
            for image_path in imgpaths[anomaly]:
                data_to_iterate.append([anomaly, image_path])
        return imgpaths, data_to_iterate

In [None]:
# Instantiate GLASS on the device(s) selected in the configuration cell.
device = utils.set_torch_device(gpu_ids=args.device)
model = glass.GLASS(device)

# Load the feature-extraction backbone and tag it with its name.
backbone = backbones.load(args.backbone)
backbone.name, backbone.seed = args.backbone, None

model.load(
    backbone                 = backbone,
    layers_to_extract_from   = args.layers_to_extract_from,
    device                   = device,
    input_shape              = (3, image_size, image_size), ##
    pretrain_embed_dimension = args.pretrain_embed_dimension,
    target_embed_dimension   = args.target_embed_dimension,
    patchsize                = args.patchsize,
    meta_epochs              = 640,
    eval_epochs              = 1,
    dsc_layers               = 4,
    dsc_hidden               = 1024,
    dsc_margin               = 0.5,
)

# Set up the save path (exist_ok keeps the cell re-runnable and avoids the
# check-then-create race of os.path.exists + os.mkdir).
save_path = "saved"
os.makedirs(save_path, exist_ok=True)

models_dir = os.path.join(save_path, "models")
model.set_model_dir(os.path.join(models_dir, "backbone"))

In [None]:
# anomaly_type values noted by the author:
#   ["good", "A", "AI", "B", "E", "P", "T", "U", "V", "W"]

# normal data: the "good" class of the test split
normal_dataset_options = dict(
    source="D:/Dev/Data/speefox",
    anomaly_source_path="D:/Dev/Data/dtd/images",
    resize=image_size,
    imagesize=image_size,
    split=DatasetSplit.TEST,
    anomaly_type="good",
)
test_dataset = AnomalyDataset(**normal_dataset_options)

# Unshuffled batches of 6, loaded in the main process.
# NOTE(review): 6 presumably matches the six tiles cut from each source
# image - confirm against model.tester_speefox.
normal_loader_options = dict(
    batch_size=6,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
)
test_dataloader = DataLoader(test_dataset, **normal_loader_options)

scores_per_cap_normal = model.tester_speefox(test_dataloader)

In [None]:
# anomaly_type values noted by the author:
#   ["good", "A", "AI", "B", "E", "P", "T", "U", "V", "W"]

# abnormal data: the "W" class of the test split
abnormal_dataset_options = dict(
    source="D:/Dev/Data/speefox",
    anomaly_source_path="D:/Dev/Data/dtd/images",
    resize=image_size,
    imagesize=image_size,
    split=DatasetSplit.TEST,
    anomaly_type="W",
)
test_dataset = AnomalyDataset(**abnormal_dataset_options)

# Unshuffled batches of 6, loaded in the main process.
# NOTE(review): 6 presumably matches the six tiles cut from each source
# image - confirm against model.tester_speefox.
abnormal_loader_options = dict(
    batch_size=6,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
)
test_dataloader = DataLoader(test_dataset, **abnormal_loader_options)

scores_per_cap_abnormal = model.tester_speefox(test_dataloader)

In [None]:
# Scatter the scores returned by model.tester_speefox: normal samples first
# (blue), abnormal samples appended after them on the x axis (red).
n_normal = len(scores_per_cap_normal)
n_abnormal = len(scores_per_cap_abnormal)

fig, ax = plt.subplots(figsize=(30, 8))
ax.scatter(x=np.arange(n_normal), y=scores_per_cap_normal, c="blue", label="normal")
ax.scatter(x=np.arange(n_normal, n_normal + n_abnormal), y=scores_per_cap_abnormal, c="red", label="abnormal")
# Labels and legend so the figure can be read without the surrounding code.
ax.set_xlabel("sample index")
ax.set_ylabel("anomaly score")
ax.set_title("Anomaly scores: normal vs. abnormal")
ax.legend()
plt.show()