#### Download MVTec (toothbrush) dataset

#### Clone repo

In [None]:
# Fetch the RIAD repository into the working directory (IPython shell escape).
!git clone https://github.com/taikiinoue45/RIAD.git
import sys
# Put the cloned "riad" directory on the import path so that the later
# `from models import UNet` and `from criterions import ...` cells resolve.
sys.path.append("RIAD/riad")

In [None]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
from glob import glob
from tqdm.auto import tqdm
import math

import torch
import torch.nn as nn
import torchvision.transforms as transforms


#### EDA

In [None]:
# Collect the defect-free training images and preview the first one found.
normal_paths = glob("toothbrush/train/good/*.png")

# OpenCV decodes to BGR; convert to RGB so matplotlib shows the right colours.
img = cv2.cvtColor(cv2.imread(normal_paths[0]), cv2.COLOR_BGR2RGB)

plt.imshow(img)
plt.show()

# Number of training images found by the glob.
len(normal_paths)

In [None]:
# Show up to six defective test images side by side with their masks.
# NOTE(review): pairing assumes the two sorted listings line up one-to-one.
defect_paths = sorted(glob("toothbrush/test/defective/*.png"))
mask_paths = sorted(glob("toothbrush/ground_truth/defective/*_mask.png"))

for i, (p1, p2) in enumerate(zip(defect_paths[:6], mask_paths[:6])):
    # Left panel: the image in RGB; right panel: the mask as read from disk.
    img = cv2.cvtColor(cv2.imread(p1), cv2.COLOR_BGR2RGB)
    mask = cv2.imread(p2)

    plt.subplot(1, 2, 1)
    plt.imshow(img)
    plt.subplot(1, 2, 2)
    plt.imshow(mask)
    plt.show()

#### Model

In [None]:
from models import UNet

#### Dataset

In [None]:
class MVTecDataset(torch.utils.data.Dataset):
    """Image-only dataset: read a file, convert to RGB, resize, transform.

    Args:
        img_paths: Sequence of image file paths.
        transform: Callable applied to the resized RGB image array; its
            return value is what ``__getitem__`` yields.
        img_size: Side length in pixels of the square every image is
            resized to.
    """

    def __init__(self, img_paths, transform, img_size=256):
        self.img_paths = img_paths
        self.transform = transform
        # Honour the caller's value (it used to be hard-coded to 256).
        self.img_size = img_size

    def __len__(self):
        """Return the number of image paths."""
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Load and preprocess the image at position ``idx``.

        Raises:
            FileNotFoundError: If the file cannot be read as an image.
        """
        img_path = self.img_paths[idx]
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (self.img_size, self.img_size))
        img = self.transform(img)
        return img

# Mini-batch size used when building the data loaders.
BS = 8

# Preprocessing: convert to a tensor, then normalise every channel with
# mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

# Training uses defect-free images only.
normal_paths = glob("toothbrush/train/good/*.png")
train_ds = MVTecDataset(normal_paths, transform)
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=BS, shuffle=True)

# Test-split listings, gathered here only to report their sizes.
val_normal_paths = glob("toothbrush/test/good/*.png")
val_defect_paths = glob("toothbrush/test/defective/*.png")

len(train_ds), len(val_normal_paths), len(val_defect_paths)

In [None]:
# Sanity-check one training sample: undo the 0.5/0.5 normalisation and move
# the channel axis last so matplotlib can draw it.
img = (train_ds[0] * 0.5 + 0.5).permute(1, 2, 0)

plt.imshow(img)
plt.show()

#### Training

In [None]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using: {device}")

# Candidate cut-out cell sizes and the number of disjoint masks per
# reconstruction; both are read by the training and evaluation cells.
cutout_sizes = [2, 4, 8, 16]
num_disjoint_masks = 3

model = UNet().to(device)

# Adam with its default hyper-parameters.
optimizer = torch.optim.Adam(model.parameters())

In [None]:
def reconstruct(model, mb_img, cutout_size, num_disjoint_masks):
    """Rebuild a mini-batch from inpainted, complementary cut-outs.

    For every mask returned by ``create_disjoint_masks`` the batch is
    multiplied by the mask, passed through ``model``, and only the region
    the mask had hidden (``1 - mask``) is kept from the output. The kept
    regions are summed into the result.

    Args:
        model: Callable applied to the masked batch.
        mb_img: 4-D tensor whose last two dimensions are height and width.
        cutout_size: Cell size forwarded to ``create_disjoint_masks``.
        num_disjoint_masks: Number of masks forwarded to
            ``create_disjoint_masks``.

    Returns:
        The sum over all masks of ``model(mb_img * mask) * (1 - mask)``.
    """
    _, _, h, w = mb_img.shape
    disjoint_masks = create_disjoint_masks((h, w), cutout_size, num_disjoint_masks)

    mb_reconst = 0
    for mask in disjoint_masks:
        # Follow the input's device rather than relying on a global one;
        # this is a no-op when the mask is already there.
        mask = mask.to(mb_img.device)
        mb_cutout = mb_img * mask
        mb_inpaint = model(mb_cutout)
        mb_reconst += mb_inpaint * (1 - mask)

    return mb_reconst

def create_disjoint_masks(
    img_size,
    cutout_size = 8,
    num_disjoint_masks = 3,
    mask_device = None,
):
    """Randomly partition an image grid into complementary binary masks.

    The image is tiled with ``cutout_size`` x ``cutout_size`` cells. The
    cell indices are shuffled and split into ``num_disjoint_masks`` groups
    (``np.array_split``, so group sizes may differ by one). Each mask is 0
    on the cells of its group and 1 everywhere else, so every cell is zero
    in exactly one mask.

    Args:
        img_size: ``(height, width)`` of the image.
        cutout_size: Edge length of one grid cell in pixels.
        num_disjoint_masks: Number of masks to produce.
        mask_device: Device the masks are moved to. When ``None`` the
            module-level ``device`` is used, as before.

    Returns:
        List of ``num_disjoint_masks`` float tensors of shape
        ``(height, width)``.
    """
    img_h, img_w = img_size
    grid_h = math.ceil(img_h / cutout_size)
    grid_w = math.ceil(img_w / cutout_size)
    num_grids = grid_h * grid_w
    # The global is only looked up when no explicit device was requested.
    target_device = device if mask_device is None else mask_device

    disjoint_masks = []
    for grid_ids in np.array_split(np.random.permutation(num_grids), num_disjoint_masks):
        flatten_mask = np.ones(num_grids)
        flatten_mask[grid_ids] = 0
        mask = flatten_mask.reshape((grid_h, grid_w))
        # Blow every grid cell up to cutout_size x cutout_size pixels.
        mask = mask.repeat(cutout_size, axis=0).repeat(cutout_size, axis=1)
        # The grid is rounded up, so crop the overhang when the image size is
        # not a multiple of cutout_size; otherwise the mask would not match
        # the image shape.
        mask = mask[:img_h, :img_w]
        mask = torch.tensor(mask, requires_grad=False, dtype=torch.float)
        disjoint_masks.append(mask.to(target_device))

    return disjoint_masks

In [None]:
from criterions import MSGMSLoss, SSIMLoss

In [None]:
def train_epoch(dataloader, model, optimizer):
    """Run one optimisation pass over ``dataloader``.

    For every batch a cut-out size is drawn at random from the module-level
    ``cutout_sizes``, the batch is rebuilt with ``reconstruct`` and the sum
    of the MSE, MSGMS and SSIM losses between the batch and its
    reconstruction is back-propagated. The module-level ``device`` and
    ``num_disjoint_masks`` are read as well.

    Args:
        dataloader: Iterable of image batches; must support ``len()``.
        model: Network being trained.
        optimizer: Optimizer that updates ``model``.

    Returns:
        The total loss averaged over the batches of the epoch.
    """
    num_batches = len(dataloader)  # batches per epoch

    # Build the criteria once per epoch instead of once per batch.
    # NOTE(review): assumes MSGMSLoss / SSIMLoss keep no per-call state -
    # confirm against the criterions module.
    criterion_mse = nn.MSELoss()
    criterion_msgms = MSGMSLoss()
    criterion_ssim = SSIMLoss()

    model.train()  # to training mode.
    epoch_loss = 0
    for img in tqdm(dataloader):
        optimizer.zero_grad()

        img = img.to(device)
        cutout_size = np.random.choice(cutout_sizes)
        img_reconstruct = reconstruct(model, img, cutout_size, num_disjoint_masks)

        loss_mse = criterion_mse(img, img_reconstruct)
        loss_msgms = criterion_msgms(img, img_reconstruct)
        loss_ssim = criterion_ssim(img, img_reconstruct)
        loss_total = loss_mse + loss_msgms + loss_ssim

        loss_total.backward()
        optimizer.step()

        epoch_loss += loss_total.item()
    return epoch_loss/num_batches

In [None]:
# Train for a fixed number of epochs, recording the mean loss of each one
# and checkpointing as we go.
model = model.to(device)
EPOCHS = 300
logs = {'train_loss': []}
best_loss = np.inf

for epoch in tqdm(range(EPOCHS)):
    train_loss = train_epoch(train_loader, model, optimizer)
    print(f'EPOCH: {epoch:04d} train_loss: {train_loss:.4f}')
    logs['train_loss'].append(train_loss)

    # The most recent weights are always written ...
    torch.save(model.state_dict(), "last.pth")

    # ... and a second copy is kept whenever the training loss hits a new low.
    if train_loss < best_loss:
        best_loss = train_loss
        torch.save(model.state_dict(), "best.pth")

#### Download pre-trained model

In [None]:
# Fetch the pre-trained checkpoint with gdown (IPython shell escape).
# NOTE(review): presumably saved as "mvtec-tooth-riad.pth", the file the next
# cell loads - confirm after download.
!gdown --fuzzy 164Kcrmlk3-wo5UX95YMcCT8nDv_AcHjk

In [None]:
# To evaluate the weights written by the training loop instead, use:
# model.load_state_dict(torch.load("best.pth", map_location=device))
# map_location remaps the stored tensors onto the active device, so a
# checkpoint saved on a GPU machine still loads when `device` is "cpu".
model.load_state_dict(torch.load("mvtec-tooth-riad.pth", map_location=device))
# Switch to evaluation mode; the return value is discarded to keep the cell
# output clean.
_ = model.eval()

#### Test

In [None]:
# Dataset for inference
class MVTecDefectDataset(torch.utils.data.Dataset):
    """Evaluation dataset yielding ``(image, mask)`` pairs.

    Args:
        img_paths: Sequence of image file paths.
        mask_paths: Sequence of mask file paths indexed in step with
            ``img_paths``, or a falsy value (``None`` / empty) to pair
            every image with an all-zero mask.
        transform: Callable applied to the resized RGB image array.
        img_size: Side length in pixels of the square that images and
            masks are resized to.
    """

    def __init__(self, img_paths, mask_paths, transform, img_size=256):
        self.img_paths = img_paths
        self.mask_paths = mask_paths
        self.transform = transform
        # Honour the caller's value (it used to be hard-coded to 256).
        self.img_size = img_size

    def __len__(self):
        """Return the number of image paths."""
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Load the image at ``idx`` together with its mask.

        Returns:
            ``(img, mask)``: the transformed image and a 2-D float mask.
            A mask read from disk is divided by 255.

        Raises:
            FileNotFoundError: If the image or mask file cannot be read.
        """
        img_path = self.img_paths[idx]
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (self.img_size, self.img_size))
        img = self.transform(img)

        if self.mask_paths:
            mask_path = self.mask_paths[idx]
            mask = cv2.imread(mask_path)
            if mask is None:
                raise FileNotFoundError(f"Could not read mask: {mask_path}")
            mask = cv2.resize(mask, (self.img_size, self.img_size))
            mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
            mask = mask / 255.
            mask = torch.tensor(mask, dtype=torch.float)
        else:
            # No ground truth supplied: use an all-zero mask with the spatial
            # size of the transformed image.
            _, h, w = img.shape
            mask = torch.zeros((h, w), dtype=torch.float)
        return img, mask

# Defective test images together with their ground-truth masks.
# NOTE(review): pairing assumes the two sorted listings line up one-to-one.
defect_paths = sorted(glob("toothbrush/test/defective/*.png"))
mask_paths = sorted(glob("toothbrush/ground_truth/defective/*_mask.png"))
val_defect_ds = MVTecDefectDataset(defect_paths, mask_paths, transform)

# Defect-free test images; passing None for the masks yields all-zero masks.
val_normal_paths = glob("toothbrush/test/good/*.png")
val_normal_ds = MVTecDefectDataset(val_normal_paths, None, transform)

# Evaluate both subsets through one loader.
val_ds = torch.utils.data.ConcatDataset([val_defect_ds, val_normal_ds])
val_loader = torch.utils.data.DataLoader(val_ds, batch_size=BS, shuffle=True)

In [None]:
# Inspect the first validation sample: the image tensor and its mask.
img, mask = val_ds[0]
img.shape, mask.shape

# Undo the 0.5/0.5 normalisation and move channels last, then draw the image
# followed by the mask.
img = (img * 0.5 + 0.5).permute(1, 2, 0)
for _view in (img, mask):
    plt.imshow(_view)
    plt.show()


In [None]:
import torch.nn.functional as F
import matplotlib.pyplot as plt
from numpy import ndarray as NDArray
from sklearn.metrics import roc_auc_score, roc_curve

def mean_smoothing(amaps, kernel_size=21):
    """Blur maps with a ``kernel_size`` x ``kernel_size`` box (mean) filter.

    Args:
        amaps: 4-D floating-point tensor with a single channel; the
            ``(1, 1, k, k)`` kernel with ``groups=1`` requires exactly one
            input channel.
        kernel_size: Edge length of the averaging window.

    Returns:
        The zero-padded convolution of ``amaps`` with the mean kernel. With
        an odd ``kernel_size`` the spatial size is unchanged.
    """
    # Create the kernel with the input's dtype and device so that inputs
    # other than CPU float32 work without a mismatch or an extra copy.
    mean_kernel = torch.ones(
        1, 1, kernel_size, kernel_size, dtype=amaps.dtype, device=amaps.device
    ) / kernel_size ** 2
    return F.conv2d(amaps, mean_kernel, padding=kernel_size // 2, groups=1)

def compute_auroc(epoch, ep_reconst, ep_gt):
    """Plot the per-sample ROC curve and return the AUROC.

    Every sample is collapsed to a single score and a single label by
    taking the maximum over all of its elements.

    Args:
        epoch: Not used by this function.
        ep_reconst: Array of per-sample score maps; first axis indexes samples.
        ep_gt: Array of per-sample ground-truth maps, same length.

    Returns:
        The value returned by ``roc_auc_score``.
    """
    num_data = len(ep_reconst)
    # Flatten each sample and keep its largest value -> shape (num_data,).
    y_true = ep_gt.reshape(num_data, -1).max(axis=1)
    y_score = ep_reconst.reshape(num_data, -1).max(axis=1)

    score = roc_auc_score(y_true, y_score)
    fpr, tpr, _ = roc_curve(y_true, y_score)

    plt.plot(fpr, tpr, marker="o", label=f"AUROC Score: {round(score, 3)}")
    plt.xlabel("FPR: FP / (TN + FP)", fontsize=14)
    plt.ylabel("TPR: TP / (TP + FN)", fontsize=14)
    plt.legend(fontsize=14)
    plt.tight_layout()
    plt.show()

    return score


# Per-sample outputs collected over the whole validation set.
artifacts = {
    "img": [],
    "reconst": [],
    "gt": [],
    "amap": [],
}

# Build the criterion once rather than once per batch and cut-out size.
# NOTE(review): assumes MSGMSLoss keeps no per-call state - confirm against
# the criterions module.
msgms_criterion = MSGMSLoss()

for mb_img, mb_gt in tqdm(val_loader):
    # One transfer per batch is enough; it does not depend on cutout_size.
    mb_img = mb_img.to(device)
    mb_amap = 0
    with torch.no_grad():
        # Accumulate the MSGMS map over every cut-out size.
        for cutout_size in cutout_sizes:
            mb_reconst = reconstruct(model, mb_img, cutout_size, num_disjoint_masks)
            mb_amap += msgms_criterion(mb_img, mb_reconst, as_loss=False)

    mb_amap = mean_smoothing(mb_amap)
    artifacts["amap"].extend(mb_amap.squeeze(1).detach().cpu().numpy())
    # Undo the 0.5/0.5 normalisation and move channels last for plotting.
    mb_img = mb_img*0.5 + 0.5
    artifacts["img"].extend(mb_img.permute(0, 2, 3, 1).detach().cpu().numpy())
    # Only the reconstruction from the last cut-out size is kept.
    mb_reconst = mb_reconst*0.5 + 0.5
    artifacts["reconst"].extend(mb_reconst.permute(0, 2, 3, 1).detach().cpu().numpy())
    artifacts["gt"].extend(mb_gt.detach().cpu().numpy())

# Min-max normalise the maps over the whole set.
ep_amap = np.array(artifacts["amap"])
amap_min = ep_amap.min()
amap_range = ep_amap.max() - amap_min
if amap_range > 0:
    ep_amap = (ep_amap - amap_min) / amap_range
else:
    # All values identical: avoid 0/0 producing NaNs.
    ep_amap = np.zeros_like(ep_amap)
artifacts["amap"] = list(ep_amap)

# compute_auroc does not use its first argument. Passing None avoids a
# NameError on `epoch` when the training cell was skipped in favour of the
# pre-trained weights.
auroc = compute_auroc(None, np.array(artifacts["amap"]), np.array(artifacts["gt"]))

#### Visualization

In [None]:
# For every evaluated sample draw a four-panel figure (map, input,
# reconstruction, ground truth) followed by a histogram of the map values.
_columns = (artifacts["amap"], artifacts["img"], artifacts["reconst"], artifacts["gt"])
for i, (amap, img, recons, gt) in enumerate(zip(*_columns)):
    plt.figure(figsize=(20, 5))
    _panels = (
        ("AMAP", amap, {"cmap": "jet", "vmin": 0, "vmax": 1}),
        ("Img", img, {}),
        ("Reconstruction", recons, {}),
        ("GT", gt, {}),
    )
    for _col, (_title, _data, _opts) in enumerate(_panels, start=1):
        plt.subplot(1, 4, _col)
        plt.title(_title)
        plt.imshow(_data, **_opts)
    plt.show()

    # Value distribution of the map; the dashed red line marks its maximum.
    plt.figure(figsize=(20, 4))
    plt.hist(amap.ravel())
    plt.title(f"AMP value min: {amap.min():.4f}, max: {amap.max():.4f}, mean: {amap.mean():.4f}")
    plt.xlim(xmin=0., xmax=1.)
    plt.axvline(amap.max(), color='red', linestyle='dashed', linewidth=2)
    plt.show()