In [None]:
import argparse
import os
import sys
import random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
from torchvision import transforms
import timm
from sklearn.metrics import roc_auc_score
from scipy.ndimage import gaussian_filter
import re
import shutil
from pathlib import Path
from scipy.ndimage import gaussian_filter, map_coordinates
import cv2
import noise
from rich.console import Console
from rich.table import Table
from tqdm import tqdm
from PIL import ImageFilter
import io

In [None]:
import warnings
# Suppress every Python warning for the rest of the kernel session.
# NOTE(review): this also hides genuinely useful warnings (deprecations, numeric issues).
warnings.filterwarnings("ignore")

# Anomaly Generator

In [None]:
# Patch-size bounds for AnomalyGenerator, indexed as bounds[dataset][class_name] -> (lower, upper).
# AnomalyGenerator rejects sampled patches whose width or height (in pixels of the image it
# receives) is below `lower`, and falls back to a lower x lower patch when sampling fails;
# `upper` is unpacked but not read in this section of the notebook.
# NOTE(review): the "" keys (brats2021, headct) presumably denote datasets without
# sub-categories — confirm with the corresponding dataset loaders.
bounds = {
    "mvtec" : {
        "toothbrush":(14, 39), "cable":(14, 65), "screw":(9, 35), "transistor":(14, 86), "capsule":(14, 39) , "bottle":(14, 65), "hazelnut":(14, 65), "metal_nut":(14, 65), "pill":(14, 39), "zipper":(14, 65),
        "wood":(14, 86), "carpet":(14, 86), "grid":(14, 86), "leather":(14, 86), "tile":(14, 86)
    },
    'visa' : {
        'candle':(14, 86), 'capsules':(6, 86), 'cashew':(10, 86), 'chewinggum':(14, 86), 'fryum':(6, 15), 'macaroni1':(14, 86),
        'macaroni2':(14, 86), 'pcb1':(14, 86), 'pcb2':(14, 86), 'pcb3':(14, 86), 'pcb4':(14, 86), 'pipe_fryum':(14, 86)
    },
    'mpdd' : {
        'bracket_black':(6, 86), 'bracket_brown':(6, 86), 'bracket_white':(6, 86), 'connector':(6, 86), 'metal_plate':(14, 86), 'tubes':(6, 86)
    },
    'btad' : {
        '01':(9, 86), '02':(14, 86), '03':(9, 86)
    },
    'dtd' : {
        'Blotchy_099':(14, 86) , 'Fibrous_183':(14, 86) , 'Marbled_078':(14, 86) , 'Matted_069':(14, 86) , 'Mesh_114':(14, 86) , 'Perforated_037':(14, 86) , 'Stratified_154':(14, 86) , 'Woven_001':(14, 86) , 'Woven_068':(14, 86) , 'Woven_104':(14, 86) , 'Woven_125':(14, 86) , 'Woven_127':(14, 86)
    },
    'brats2021' : {
        "":(14, 86)
    },
    'headct' : {
        "":(9, 86)
    },
    'wfdd' : {
        "grey_cloth":(9, 86), "grid_cloth":(9, 86), "pink_flower":(9, 86), "yellow_cloth":(9, 86)
    }
}

class RandomAugmentations:
    """Random chains of photometric / geometric corruptions for synthetic defect patches.

    :meth:`apply` samples between 0 and ``num_augmentations[level]`` distinct
    augmentations from ``augmentations[level]`` and always inserts
    :meth:`color_transformation` at a random position in that chain. Each
    augmentation takes an ``H x W x C`` uint8 NumPy array plus the level name and
    returns a new array of the same shape (presumably RGB — the caller converts
    from a PIL image).

    Args:
        seed: Seed for the NumPy generator used by :meth:`elastic_transform`.
            ``None`` gives a non-deterministic generator. The other augmentations
            draw from the global ``random`` / ``np.random`` state.
    """

    def __init__(self, seed=None):
        self.seed = seed
        # Created once rather than per call: re-seeding inside elastic_transform made
        # every elastic warp reuse the exact same noise field whenever a seed was set.
        self._elastic_rng = np.random.RandomState(seed)

        # (min, max) sampling range of every augmentation parameter, per severity level.
        self.param_ranges = {
            'brightness': {'light': (0.1, 0.1), 'medium': (0.4, 0.4), 'heavy': (0.8, 1)},
            'contrast':   {'light': (0.1, 0.1), 'medium': (0.4, 0.4),  'heavy': (0.8, 1)},
            'saturation': {'light': (0.1, 0.1), 'medium': (0.4, 0.4), 'heavy': (0.8, 1)},
            'hue':        {'light': (0.1, 0.1), 'medium': (0.3, 0.3), 'heavy': (0.5, 0.5)},

            'elastic_alpha': {'light': (10, 20), 'medium': (20, 40), 'heavy': (40, 100)},

            'torn_lines': {'light': (1, 3), 'medium': (5, 10), 'heavy': (10, 20)},

            'perlin_scale': {'light': (20, 50), 'medium': (10, 20), 'heavy': (5, 10)},
            'perlin_threshold': {'light': (200, 255), 'medium': (150, 200), 'heavy': (128, 150)},

            'swirl_strength': {'light': (0.5, 1.0), 'medium': (1.0, 1.5), 'heavy': (1.5, 2)},

            'erase_ratio': {'light': (0.01, 0.05), 'medium': (0.05, 0.1), 'heavy': (0.1, 0.2)},
            'erase_rects': {'light': (1, 2), 'medium': (2, 3), 'heavy': (3, 5)},

            'blur_radius': {'light': (0.2, 0.5), 'medium': (0.8, 1.5), 'heavy': (2.0, 3)},

            'jpeg_quality': {'light': (30, 50), 'medium': (20, 30), 'heavy': (1, 20)},
        }

        # Upper bound on how many extra augmentations are chained per level.
        self.num_augmentations = {'light': 1, 'medium': 2, 'heavy': 4}

        self.augmentations = {
            "light": [
                self.gaussian_blur, self.elastic_transform, self.swirl_distortion,
            ],
            "medium": [
                self.gaussian_blur, self.swirl_distortion, self.jpeg_artifacts,
                self.elastic_transform
            ],
            "heavy": [
                self.elastic_transform, self.torn_paper_effect, self.jpeg_artifacts,
                self.perlin_noise_mask, self.swirl_distortion, self.random_erasing,
                self.gaussian_blur,
            ]
        }

    def apply(self, image, level='medium'):
        """Apply a random augmentation chain for ``level`` to a PIL image; returns a PIL image."""
        image_np = np.array(image)
        n_augmentations = random.randint(0, self.num_augmentations[level])
        selected_augmentations = random.sample(self.augmentations[level], n_augmentations)

        # Colour jitter is always applied, at a random point in the chain.
        selected_augmentations.insert(random.randint(0, len(selected_augmentations)), self.color_transformation)

        for augmentation in selected_augmentations:
            image_np = augmentation(image_np, level)

        return Image.fromarray(image_np)

    def elastic_transform(self, image, level):
        """Warp ``image`` with a smoothed random displacement field.

        Works on ``H x W`` and ``H x W x C`` arrays; the output keeps the input shape
        and dtype. The displacement noise comes from the instance's seeded generator,
        so consecutive calls differ but the sequence is reproducible for a fixed seed.
        """
        alpha = random.uniform(*self.param_ranges['elastic_alpha'][level])
        sigma = 3.0
        height, width = image.shape[:2]

        dx = gaussian_filter(self._elastic_rng.rand(height, width) * 2 - 1, sigma, mode="reflect") * alpha
        dy = gaussian_filter(self._elastic_rng.rand(height, width) * 2 - 1, sigma, mode="reflect") * alpha
        x, y = np.meshgrid(np.arange(width), np.arange(height))
        indices = (y + dy).flatten(), (x + dx).flatten()

        def warp(channel):
            # Bilinear resampling at the displaced coordinates, reflecting at borders.
            return map_coordinates(channel, indices, order=1, mode='reflect').reshape(height, width)

        if image.ndim == 2:
            return warp(image)

        distorted_image = np.zeros_like(image)
        for c in range(image.shape[2]):
            distorted_image[..., c] = warp(image[..., c])
        return distorted_image

    def torn_paper_effect(self, image, level):
        """Draw random 1-px black/white/colour lines across a copy of ``image``."""
        image_np = image.copy()
        height, width = image_np.shape[:2]
        num_lines = random.randint(*self.param_ranges['torn_lines'][level])
        for _ in range(num_lines):
            start_x = np.random.randint(0, width)
            start_y = np.random.randint(0, height)
            end_x = np.random.randint(0, width)
            end_y = np.random.randint(0, height)
            cv2.line(image_np, (start_x, start_y), (end_x, end_y), [random.choice([0, 255]) for _ in range(3)], thickness=1)
        return image_np

    def perlin_noise_mask(self, image, level):
        """Paint one random colour over pixels where rescaled Perlin noise exceeds a threshold.

        Operates on a copy; a constant noise field (zero range) leaves the image unchanged.
        """
        scale = random.uniform(*self.param_ranges['perlin_scale'][level])
        threshold = random.randint(*self.param_ranges['perlin_threshold'][level])
        image_np = image.copy()
        height, width = image_np.shape[:2]
        mask = np.zeros((height, width), dtype=np.float32)
        for i in range(height):
            for j in range(width):
                mask[i, j] = noise.pnoise2(i / scale, j / scale, octaves=6)
        value_range = mask.max() - mask.min()
        if value_range > 0:
            # Rescale noise to [0, 255] so it is comparable with the integer threshold.
            mask = (mask - mask.min()) / value_range * 255
            # high is exclusive: 256 so that 255 is a possible channel value.
            image_np[mask > threshold] = np.random.randint(0, 256, 3)
        return image_np

    def color_transformation(self, image, level):
        """Random ColorJitter whose brightness/contrast/saturation/hue factors depend on ``level``."""
        b = random.uniform(*self.param_ranges['brightness'][level])
        c = random.uniform(*self.param_ranges['contrast'][level])
        s = random.uniform(*self.param_ranges['saturation'][level])
        h = random.uniform(*self.param_ranges['hue'][level])
        transform = transforms.ColorJitter(brightness=b, contrast=c, saturation=s, hue=h)
        return np.array(transform(Image.fromarray(image)))

    def swirl_distortion(self, image, level):
        """Rotate pixels about the image centre by an angle that decays with distance."""
        strength = random.uniform(*self.param_ranges['swirl_strength'][level])
        patch_np = np.array(image)
        height, width = patch_np.shape[:2]
        center_x, center_y = width // 2, height // 2
        y, x = np.indices((height, width))
        x = x - center_x
        y = y - center_y
        distance = np.sqrt(x**2 + y**2)
        # Gaussian fall-off radius; clamped to 1 so tiny patches do not divide by zero.
        radius = max(min(height, width) // 3, 1)
        angle = strength * np.exp(-distance**2 / (2 * radius**2))
        new_x = center_x + x * np.cos(angle) - y * np.sin(angle)
        new_y = center_y + x * np.sin(angle) + y * np.cos(angle)
        map_x = np.clip(new_x, 0, width - 1).astype(np.float32)
        map_y = np.clip(new_y, 0, height - 1).astype(np.float32)
        return cv2.remap(patch_np, map_x, map_y, interpolation=cv2.INTER_LINEAR)

    def random_erasing(self, image, level):
        """Fill a few random rectangles of a copy of ``image`` with random colours."""
        image_np = image.copy()
        h, w = image_np.shape[:2]
        erase_area_ratio = random.uniform(*self.param_ranges['erase_ratio'][level])
        num_rectangles = random.randint(*self.param_ranges['erase_rects'][level])
        for _ in range(num_rectangles):
            erase_area = int(erase_area_ratio * h * w)
            erase_aspect_ratio = random.uniform(0.3, 3.3)
            erase_height = int(np.sqrt(erase_area / erase_aspect_ratio))
            erase_width = int(erase_aspect_ratio * erase_height)
            x = random.randint(0, max(0, w - erase_width))
            y = random.randint(0, max(0, h - erase_height))
            # high is exclusive: 256 so that 255 is a possible channel value.
            image_np[y:y+erase_height, x:x+erase_width] = np.random.randint(0, 256, 3)
        return image_np

    def gaussian_blur(self, image, level):
        """PIL Gaussian blur with a random radius for ``level``."""
        radius = random.uniform(*self.param_ranges['blur_radius'][level])
        return np.array(Image.fromarray(image).filter(ImageFilter.GaussianBlur(radius=radius)))

    def jpeg_artifacts(self, image, level):
        """Round-trip ``image`` through in-memory JPEG at a random (low) quality."""
        quality = random.randint(*self.param_ranges['jpeg_quality'][level])
        pil_image = Image.fromarray(image)
        buffer = io.BytesIO()
        pil_image.save(buffer, format='JPEG', quality=quality)
        return np.array(Image.open(buffer))
    

class AnomalyGenerator(object):
    """Creates synthetic anomalies by copying an augmented patch to another location.

    For each image in a batch, a source patch is cropped from the (dilated)
    foreground, corrupted by ``RandomAugmentations``, randomly rotated and
    pasted at a second foreground location. The returned mask marks the pasted
    pixels.

    Args:
        dataset: key into the module-level ``bounds`` table.
        class_name: key inside ``bounds[dataset]``.
        seed: forwarded to ``RandomAugmentations``.
    """

    def __init__(self, dataset, class_name, seed):
        # (lower, upper) patch-size bounds; only lower_bound is read in this class.
        self.lower_bound, self.upper_bound = bounds[dataset][class_name]

        self.random_augmentor = RandomAugmentations(seed=seed)

    def rotate(self, patch, width, height, min_angle=-90, max_angle=90):
        """Rotate ``patch`` by a random angle and resize it back to ``(width, height)``.

        Returns the RGB patch and its alpha channel, which is 0 where the rotation
        exposed empty corners and so serves as the paste mask.
        """
        random_rotate = random.uniform(min_angle, max_angle)
        patch = patch.convert("RGBA").rotate(random_rotate, expand=True)
        patch = patch.resize((width, height), resample=Image.BICUBIC)
        mask = patch.split()[-1]

        return patch.convert("RGB"), mask

    def intersect_masks(self, mask1, mask2):
        """Pixel-wise AND of two masks, returned as a 0/255 uint8 PIL image."""
        intersection = np.logical_and(np.array(mask1), np.array(mask2)).astype(np.uint8) * 255
        return Image.fromarray(intersection)

    @staticmethod
    def _leading_ones(values):
        """Number of consecutive entries equal to 1 at the start of a 1-D array."""
        is_one = values == 1
        if is_one.all():  # also covers an empty slice -> 0
            return int(is_one.size)
        return int(np.argmin(is_one))  # index of the first non-foreground entry

    def get_max_shape(self, x, y, foreground_mask):
        """Length of the foreground (==1) run starting at (x, y), rightwards and downwards.

        Returns ``(max_width, max_height)``.
        """
        max_width = self._leading_ones(foreground_mask[y, x:])
        max_height = self._leading_ones(foreground_mask[y:, x])
        return max_width, max_height

    def sample_patch_size(self, foreground_mask, x1, y1, x2, y2, max_width, max_height):
        """Draw a patch size whose footprint is all foreground at both (x1, y1) and (x2, y2).

        Tries up to 11 random sizes in ``[0, max_width] x [0, max_height]`` and
        returns the first that fits, or the last one drawn if none does.
        """
        for _ in range(11):
            patch_width = random.randint(0, max_width)
            patch_height = random.randint(0, max_height)

            patch_region_src = foreground_mask[y1:y1 + patch_height, x1:x1 + patch_width]
            patch_region_dst = foreground_mask[y2:y2 + patch_height, x2:x2 + patch_width]

            if np.all(patch_region_src == 1) and np.all(patch_region_dst == 1):
                break

        return patch_width, patch_height

    def expand_mask(self, mask, kernel_size=(3, 3)):
        """Dilate a binary mask 10 times with a ``kernel_size`` box kernel (uint8 output)."""
        kernel = np.ones(kernel_size, np.uint8)
        return cv2.dilate(mask.astype(np.uint8), kernel, iterations=10)

    def sample_coordinate_shape(self, foreground_mask, max_attempts=250):
        """Pick source/destination corners and a patch size inside the dilated foreground.

        Returns ``(x1, y1, x2, y2, patch_width, patch_height)`` as Python ints. After
        ``max_attempts + 1`` failed tries it falls back to a
        ``lower_bound x lower_bound`` patch at random foreground corners, clamped so
        the patch stays inside the image.
        """
        foreground_mask = self.expand_mask(foreground_mask)

        h, w = foreground_mask.shape
        coords = np.column_stack(np.where(foreground_mask == 1))
        if len(coords) == 0:
            # No foreground at all: treat the whole image as foreground instead of
            # crashing in random.randint(0, -1).
            foreground_mask = np.ones((h, w), dtype=np.uint8)
            coords = np.column_stack(np.where(foreground_mask == 1))

        for _ in range(max_attempts + 1):
            y1, x1 = coords[random.randint(0, len(coords) - 1)]
            max_width1, max_height1 = self.get_max_shape(x1, y1, foreground_mask)

            y2, x2 = coords[random.randint(0, len(coords) - 1)]
            max_width2, max_height2 = self.get_max_shape(x2, y2, foreground_mask)

            max_width, max_height = min(max_width1, max_width2), min(max_height1, max_height2)

            patch_width, patch_height = self.sample_patch_size(foreground_mask, x1, y1, x2, y2, max_width, max_height)

            if patch_width < self.lower_bound or patch_height < self.lower_bound:
                continue

            patch_region_src = foreground_mask[y1:y1 + patch_height, x1:x1 + patch_width]
            patch_region_dst = foreground_mask[y2:y2 + patch_height, x2:x2 + patch_width]

            if (np.all(patch_region_src == 1) and np.all(patch_region_dst == 1)
                    and x2 + patch_width <= w and y2 + patch_height <= h):
                return int(x1), int(y1), int(x2), int(y2), int(patch_width), int(patch_height)

        size = self.lower_bound
        y1, x1 = coords[random.randint(0, len(coords) - 1)]
        y2, x2 = coords[random.randint(0, len(coords) - 1)]
        # Keep the fallback patch inside the image so the crop is not zero-padded
        # and the paste is not clipped.
        max_x, max_y = max(w - size, 0), max(h - size, 0)
        x1, x2 = min(int(x1), max_x), min(int(x2), max_x)
        y1, y2 = min(int(y1), max_y), min(int(y2), max_y)
        return x1, y1, x2, y2, size, size

    def __call__(self, imgs, foreground_masks):
        """Return ``(augmented_imgs, anomaly_masks)`` for a batch.

        Args:
            imgs: tensor of shape (B, C, H, W); presumably float in [0, 1] since it is
                converted with ``ToPILImage`` — confirm with the caller.
            foreground_masks: per-image foreground masks; ``squeeze(0)`` of each
                must give an (H, W) array.

        Returns:
            Stacked tensors: the augmented images and (B, 1, H, W) masks that are 1
            on pasted pixels and 0 elsewhere.
        """
        batch_size, _, h, w = imgs.shape
        to_pil = transforms.ToPILImage()
        to_tensor = transforms.ToTensor()
        transformed_imgs = []
        transformed_masks = []

        for i in range(batch_size):
            img_pil = to_pil(imgs[i].cpu())
            foreground_mask = foreground_masks[i].cpu().squeeze(0).numpy()

            x1, y1, x2, y2, patch_width, patch_height = self.sample_coordinate_shape(foreground_mask)

            patch = img_pil.crop((x1, y1, x1 + patch_width, y1 + patch_height))

            level = np.random.choice(['light', 'medium', 'heavy'], p=[0.2, 0.2, 0.6])
            patch = self.random_augmentor.apply(patch, level)

            patch, rotation_mask = self.rotate(patch, patch_width, patch_height)

            mask = np.ones((patch_height, patch_width), dtype=np.uint8)
            mask = cv2.resize(mask, (patch_width, patch_height), interpolation=cv2.INTER_CUBIC)
            # Only pixels that survived the rotation (non-transparent) are pasted.
            mask = self.intersect_masks(mask, rotation_mask)

            augmented = img_pil.copy()
            augmented.paste(patch, (x2, y2), mask=mask)

            org_mask = Image.fromarray(np.zeros((h, w), dtype='uint8')).convert('L')
            org_mask.paste(mask, (x2, y2))

            transformed_imgs.append(to_tensor(augmented))
            transformed_masks.append(to_tensor(org_mask))

        return torch.stack(transformed_imgs), torch.stack(transformed_masks)

# PatchGuard

In [None]:
def my_forward_wrapper(attn_obj):
    """Build a replacement ``forward`` for a timm-style attention module that records
    its attention probabilities.

    The returned closure computes multi-head self-attention explicitly (rather than
    through a fused kernel) so the post-dropout attention matrix can be stored on
    ``attn_obj.attn_map`` with shape ``(B, num_heads, N, N)``.

    ``attn_obj`` must expose ``qkv``, ``num_heads``, ``scale``, ``attn_drop``,
    ``proj`` and ``proj_drop``. Optional ``q_norm``, ``k_norm`` and ``norm``
    layers are applied when present, so their output matches the module's own
    forward pass.
    """
    q_norm = getattr(attn_obj, "q_norm", None)
    k_norm = getattr(attn_obj, "k_norm", None)
    out_norm = getattr(attn_obj, "norm", None)

    def my_forward(x, attn_mask=None):
        """Self-attention over ``x`` of shape (B, N, C); returns (B, N, C).

        ``attn_mask`` is optional: either a boolean mask (True = attend) or an
        additive float mask, broadcastable to (B, num_heads, N, N).
        """
        B, N, C = x.shape
        head_dim = C // attn_obj.num_heads
        qkv = attn_obj.qkv(x).reshape(B, N, 3, attn_obj.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        if q_norm is not None:
            q = q_norm(q)
        if k_norm is not None:
            k = k_norm(k)

        attn = (q @ k.transpose(-2, -1)) * attn_obj.scale
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn = attn.masked_fill(~attn_mask, float("-inf"))
            else:
                attn = attn + attn_mask
        attn = attn.softmax(dim=-1)
        attn = attn_obj.attn_drop(attn)
        attn_obj.attn_map = attn  # read later by FeatureExtractor for regularisation

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        if out_norm is not None:
            x = out_norm(x)
        x = attn_obj.proj(x)
        x = attn_obj.proj_drop(x)
        return x

    return my_forward

class FeatureExtractor(nn.Module):
    """Wraps a timm ViT and returns concatenated, layer-normalised patch-token features.

    Args:
        hf_path: timm model name passed to ``timm.create_model``.
        feature_layer_indices: 1-based block indices whose patch-token outputs are
            concatenated along the last axis.
        reg_layer_indices: 1-based block indices whose attention maps (restricted
            to patch tokens) are returned when ``use_reg`` is True.
        image_size: input resolution passed to timm (square).
        device: device for the backbone and the normalisation constants.
    """

    def __init__(self, hf_path, feature_layer_indices, reg_layer_indices, image_size, device):
        super(FeatureExtractor, self).__init__()

        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        # Non-persistent buffers follow .to()/.cuda() with the module but stay out of
        # state_dict, so previously saved checkpoints still load unchanged.
        self.register_buffer("mu", torch.tensor(mean).view(1, 3, 1, 1).to(device), persistent=False)
        self.register_buffer("std", torch.tensor(std).view(1, 3, 1, 1).to(device), persistent=False)
        self.norm = lambda x: (x - self.mu) / self.std

        self.feature_layer_indices = feature_layer_indices
        self.reg_layer_indices = reg_layer_indices

        # NOTE(review): pretrained=False builds randomly initialised weights; presumably a
        # checkpoint is loaded elsewhere — confirm.
        self.pretrained_model = timm.create_model(hf_path, pretrained=False, num_classes=0, img_size=image_size).to(device)

        self.embed_dim = len(feature_layer_indices) * self.pretrained_model.embed_dim
        self.patch_size = self.pretrained_model.patch_embed.patch_size[0]
        self.num_patches = (image_size // self.patch_size) ** 2

        # Number of prefix tokens (class + register) preceding the patch tokens.
        # Prefer timm's own count; fall back to parsing "regN" from the model name.
        num_prefix = getattr(self.pretrained_model, "num_prefix_tokens", None)
        if num_prefix is None:
            match = re.search(r'reg(\d+)', hf_path)
            num_prefix = int(match.group(1)) + 1 if match else 1
        self.start_index = num_prefix

        # Replace attention forward of every block we read so it records attn_map.
        for i in set(feature_layer_indices + reg_layer_indices):
            block_attn = self.pretrained_model.blocks[i - 1].attn
            block_attn.forward = my_forward_wrapper(block_attn)

    def forward(self, x, use_reg=True):
        """Return ``(features, attention_weights)``.

        ``features`` concatenates the normalised patch tokens of every feature
        layer along the last axis. ``attention_weights`` is a list with one
        patch-to-patch attention map per regularisation layer, or ``None`` when
        ``use_reg`` is False.
        """
        x = self.norm(x)
        x = self.pretrained_model.patch_embed(x)
        x = self.pretrained_model._pos_embed(x)
        x = self.pretrained_model.patch_drop(x)
        x = self.pretrained_model.norm_pre(x)

        out = []
        attention_weights = []
        s = self.start_index

        # Run only as deep as needed: the last feature layer, or the last
        # regularisation layer if that one is deeper.
        last_layer = max(self.feature_layer_indices)
        if use_reg and self.reg_layer_indices:
            last_layer = max(last_layer, max(self.reg_layer_indices))

        for idx, layer in enumerate(self.pretrained_model.blocks, start=1):
            x = layer(x)

            if idx in self.feature_layer_indices:
                out.append(self.pretrained_model.norm(x[:, s:, :]))

            if use_reg and idx in self.reg_layer_indices:
                # Drop all prefix tokens (class + registers), matching the features above.
                attention_weights.append(layer.attn.attn_map[:, :, s:, s:])

            if idx == last_layer:
                break

        features = torch.cat(out, dim=-1)

        return (features, attention_weights) if use_reg else (features, None)
                
class AttentionBlock(nn.Module):
    """Transformer encoder block used by the discriminator.

    Multi-head self-attention is added residually and layer-normalised; an MLP
    (Linear -> GELU -> Dropout -> Linear) is then added as a second residual with
    no further normalisation. Inputs are batch-first (``batch_first=True``).
    """

    def __init__(self, embed_dim, hidden_dim, num_heads, dropout=0.0):
        super().__init__()
        # Sub-modules are created in this exact order so parameter initialisation
        # (RNG consumption) and state_dict keys stay stable.
        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        mlp_layers = [
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
        ]
        self.linear = nn.Sequential(*mlp_layers)
        self.dropout1, self.dropout2 = nn.Dropout(dropout), nn.Dropout(dropout)

    def forward(self, x):
        attended = self.attn(x, x, x)[0]
        normed = self.layer_norm_1(x + self.dropout1(attended))
        mlp_out = self.dropout2(self.linear(normed))
        return normed + mlp_out

class Discriminator(nn.Module):
    """Transformer head mapping per-patch embeddings to one anomaly logit per patch.

    A learned positional encoding of shape (num_patches, embed_dim) is added to the
    input, which then goes through ``num_layers`` AttentionBlock layers and a
    linear projection to a single value per token.

    Input: (B, num_patches, embed_dim). Output: (B, num_patches).
    """

    def __init__(self, embed_dim, hidden_dim, num_patches, num_layers=1, num_heads=12, dropout_rate=0):
        super().__init__()
        encoder_blocks = (
            AttentionBlock(embed_dim, hidden_dim, num_heads, dropout=dropout_rate)
            for _ in range(num_layers)
        )
        self.transformer_encoder = nn.Sequential(*encoder_blocks)
        self.output_layer = nn.Sequential(nn.Linear(embed_dim, 1))
        self.positional_encodings = nn.Parameter(torch.randn(num_patches, embed_dim))

    def forward(self, x):
        tokens = x + self.positional_encodings[None, :, :]
        encoded = self.transformer_encoder(tokens)
        return self.output_layer(encoded).squeeze(-1)

class PatchGuard(nn.Module):
    """ViT feature extractor followed by a transformer discriminator giving per-patch logits.

    Args:
        args: configuration mapping; reads "hf_path", "feature_layers", "reg_layers",
            "image_size", "hidden_dim", "dsc_layers", "dsc_heads" and, optionally,
            "dsc_dropout" (discriminator dropout, default 0.2).
        device: device the sub-modules are moved to.
    """

    def __init__(self, args, device):
        super().__init__()

        self.feature_extractor = FeatureExtractor(args["hf_path"], args["feature_layers"], args["reg_layers"], args["image_size"], device)

        embed_dim = self.feature_extractor.embed_dim
        self.num_patches = self.feature_extractor.num_patches
        self.patch_size = self.feature_extractor.patch_size
        # num_patches is a perfect square; round guards against float sqrt error.
        self.patches_per_side = int(round(np.sqrt(self.num_patches)))

        dropout = args.get("dsc_dropout", 0.2)
        self.discriminator = Discriminator(embed_dim, args["hidden_dim"], self.num_patches, args["dsc_layers"], args["dsc_heads"], dropout).to(device)

    def forward(self, x):
        """Per-patch anomaly logits for images ``x`` (attention maps are not collected)."""
        embeddings, _ = self.feature_extractor(x, False)
        return self.discriminator(embeddings)

# Attack

In [None]:
def pgd_attack(model, images, masks, epsilon, num_iter):
    """L-infinity PGD that pushes normal-patch scores up and anomalous-patch scores down.

    Args:
        model: callable mapping images to per-patch scores of shape (B, P).
        images: input batch; pixel values are kept in [0, 1].
        masks: (B, P) patch labels (non-zero = anomalous).
        epsilon: L-infinity radius of the perturbation.
        num_iter: number of signed-gradient steps; step size is ``2.5 * epsilon / num_iter``.

    Returns:
        Detached adversarial images. For ``num_iter <= 0`` a detached copy of ``images``.

    Gradients are taken with respect to the input only (``torch.autograd.grad``), so
    the model parameters' ``.grad`` fields are left untouched.
    """
    original_X = images.clone().detach()
    if num_iter <= 0:
        return original_X

    alpha = (2.5 * epsilon) / num_iter
    X = original_X.clone()

    # Masks are constant across iterations; the epsilon avoids division by zero
    # for images without normal or without anomalous patches.
    zeros_count = (masks == 0).sum(dim=1)
    non_zeros_count = (masks != 0).sum(dim=1)

    for _ in range(num_iter):
        X.requires_grad_(True)
        scores = model(X)

        anomalous_loss = (masks * scores).sum(dim=1) / (non_zeros_count + 1e-8)
        normal_loss = ((1 - masks) * scores).sum(dim=1) / (zeros_count + 1e-8)
        loss = normal_loss.sum() - anomalous_loss.sum()
        (grad,) = torch.autograd.grad(loss, X)

        with torch.no_grad():
            adv_X = X + alpha * grad.sign()
            delta = torch.clamp(adv_X - original_X, min=-epsilon, max=epsilon)
            X = torch.clamp(original_X + delta, min=0, max=1)

    return X.detach()


# Loss

In [None]:
class Loss(nn.Module):
    """Per-patch BCE-with-logits loss plus an optional attention regulariser.

    Args:
        reg_type: ``"KL_divergence"``, ``"L2_norm"`` (not recommended), ``"soft_R"`` or
            ``"R"``; any other value makes :meth:`reg` return 0.
        device: stored for interface compatibility.
        num_patches: number of patch tokens; ``1 / num_patches`` is the uniform
            attention level used as KL target and as R / soft_R threshold.
    """

    def __init__(self, reg_type, device, num_patches):
        super(Loss, self).__init__()
        self.reg_type = reg_type
        self.device = device
        self.num_patches = num_patches
        # NOTE(review): reg_weights is not read by this class; forward() adds reg_term unweighted.
        self.reg_weights = {"KL_divergence":0.01, "L2_norm":0.1, "soft_R":0.1, "R":0.1}
        self.reg_hypers = {"R":{"tau":0.01}}
        # Threshold of the "soft_R" smooth mask (previously referenced but never
        # defined -> AttributeError); mirrors the hard 1/num_patches threshold of "R".
        self.delta = 1 / num_patches

    def forward(self, scores, masks, attn_weights=None):
        """Mean over images of the summed per-patch BCE, plus ``reg(attn_weights)`` if given.

        ``scores`` are logits and ``masks`` float 0/1 targets of the same shape;
        per-patch losses are summed over dim 1 (presumably the patch axis).
        """
        loss_per_patch = F.binary_cross_entropy_with_logits(scores, masks, reduction='none')
        loss_per_image = loss_per_patch.sum(dim=1)

        reg_term = self.reg(attn_weights) if attn_weights is not None else 0
        return loss_per_image.mean() + reg_term

    def reg(self, attn_weights):
        """Sum the configured regularisation penalty over a list of attention maps."""
        reg_term = 0
        uniform = 1 / self.num_patches
        if self.reg_type == "KL_divergence":
            # Per-layer weights; at most len(coefs) regularised layers are supported.
            coefs = [0.005, 0.01, 0.05]
            for i, attn_weight in enumerate(attn_weights):
                # F.kl_div(log p, q) is KL(q || p): here KL(uniform || attention).
                target = torch.full_like(attn_weight, uniform)
                reg_term += coefs[i] * F.kl_div(torch.log(attn_weight + 1e-8), target, reduction='mean')
        elif self.reg_type == "L2_norm":
            # not recommended
            for attn_weight in attn_weights:
                reg_term += attn_weight.norm(p='fro')
        elif self.reg_type == "soft_R":
            tau = self.reg_hypers["R"]["tau"]
            for attn_weight in attn_weights:
                soft_mask = torch.sigmoid((self.delta - attn_weight) / tau)  # smooth threshold
                reg_term += 1 / (torch.sum(attn_weight * soft_mask) + 1e-8)
        elif self.reg_type == "R":
            for attn_weight in attn_weights:
                mask = (attn_weight <= uniform).float()  # hard threshold
                reg_term += 1 / (torch.sum(attn_weight * mask) + 1e-8)

        return reg_term

# Dataset

In [None]:
class MVTec(data.Dataset):
    """MVTec-AD style dataset yielding ``(image, label, mask, foreground_mask)``.

    Layout under ``path/class_name``::

        {train,test}/good/*                     normal images
        test/<defect>/<stem>.<ext>              anomalous images
        ground_truth/<defect>/<stem>_mask.<ext> pixel masks
        train/foreground_mask/*                 foreground masks (listed classes only)

    Args:
        path: dataset root.
        class_name: category sub-directory.
        transform: applied to the RGB image; expected to return a (C, H, W) tensor.
        mask_transform: applied to the RGB ground-truth mask image.
        seed: unused (kept for interface compatibility).
        split: ``'train'`` or ``'test'``.
        size: side length foreground masks are resized to.
    """

    def __init__(self, path, class_name, transform=None, mask_transform=None, seed=0, split='train', size=224):
        self.transform = transform
        self.mask_transform = mask_transform
        self.data = []
        self.size = size
        self.has_fg_mask = class_name in ['bottle', 'cable', 'capsule', 'hazelnut', 'metal_nut', 'pill', 'screw', 'toothbrush', 'zipper']

        root = os.path.join(path, class_name)
        normal_dir = os.path.join(root, split, "good")
        use_fg_mask = split == 'train' and self.has_fg_mask

        if use_fg_mask:
            self.foreground_mask_path = os.path.join(root, split, "foreground_mask")

        # sorted(): os.listdir order is filesystem-dependent; sorting keeps runs reproducible.
        for img_file in sorted(os.listdir(normal_dir)):
            fg_path = os.path.join(self.foreground_mask_path, img_file) if use_fg_mask else None
            self.data.append((os.path.join(normal_dir, img_file), None, fg_path))

        if split == 'test':
            test_dir = os.path.join(root, "test")
            gt_dir = os.path.join(root, "ground_truth")
            for defect in sorted(os.listdir(test_dir)):
                defect_dir = os.path.join(test_dir, defect)
                if not os.path.isdir(defect_dir) or defect_dir == normal_dir:
                    continue
                for img_file in sorted(os.listdir(defect_dir)):
                    # Build the mask path from components instead of string-replacing
                    # "test", which also rewrote any "test" elsewhere in the root path.
                    stem, ext = os.path.splitext(img_file)
                    mask_path = os.path.join(gt_dir, defect, stem + '_mask' + ext)
                    self.data.append((os.path.join(defect_dir, img_file), mask_path, None))

            random.shuffle(self.data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` as ``(image, label, mask, foreground_mask)``.

        ``label`` is 1 for anomalous and 0 for normal samples. ``mask`` is a 2-D
        float tensor (1 = defective pixel, all zeros for normal samples).
        ``foreground_mask`` is an all-ones (H, W) tensor unless a foreground-mask
        file is listed, in which case it is that file resized to ``(size, size)``
        and converted with ``ToTensor`` (shape (1, size, size)).
        """
        img_path, mask_path, fore_path = self.data[idx]
        with Image.open(img_path) as img:
            image = self.transform(img.convert('RGB'))
        # Tensor layout is (C, H, W); unpack in that order so the zero/one masks match
        # the image even when it is not square.
        _, height, width = image.shape

        if mask_path:
            with Image.open(mask_path) as gt:
                mask = self.mask_transform(gt.convert('RGB'))
            # A pixel is defective if any channel is non-zero.
            mask = 1.0 - torch.all(mask == 0, dim=0).float()
            label = 1
        else:
            mask = torch.zeros((height, width))
            label = 0

        foreground_mask = torch.ones((height, width))
        if fore_path:
            with Image.open(fore_path) as fg:
                fg_resized = fg.convert('L').resize((self.size, self.size), Image.NEAREST)
            foreground_mask = transforms.ToTensor()(fg_resized)

        return image, label, mask, foreground_mask
    

class MPDD(data.Dataset):
    """MPDD dataset (MVTec-style layout) yielding ``(image, label, mask, foreground_mask)``.

    Layout under ``path/class_name``::

        {train,test}/good/*                     normal images
        test/<defect>/<stem>.<ext>              anomalous images
        ground_truth/<defect>/<stem>_mask.<ext> pixel masks
        train/foreground_mask/*                 foreground masks (created if missing)

    Args:
        path: dataset root.
        class_name: category sub-directory.
        transform: applied to the RGB image; expected to return a (C, H, W) tensor.
        mask_transform: applied to the RGB ground-truth mask image.
        seed: unused (kept for interface compatibility).
        split: ``'train'`` or ``'test'``.
        size: side length foreground masks are resized to.
    """

    def __init__(self, path, class_name, transform=None, mask_transform=None, seed=0, split='train', size=224):
        self.transform = transform
        self.mask_transform = mask_transform
        self.data = []
        self.size = size
        self.has_fg_mask = class_name in ['bracket_black', 'bracket_brown', 'bracket_white', 'connector', 'metal_plate', 'tubes']

        root = os.path.join(path, class_name)
        normal_dir = os.path.join(root, split, "good")
        use_fg_mask = split == 'train' and self.has_fg_mask

        if use_fg_mask:
            self.foreground_mask_path = os.path.join(root, split, "foreground_mask")
            # Side effect: the foreground-mask directory is created if it does not exist.
            os.makedirs(self.foreground_mask_path, exist_ok=True)

        # sorted(): os.listdir order is filesystem-dependent; sorting keeps runs reproducible.
        for img_file in sorted(os.listdir(normal_dir)):
            fg_path = os.path.join(self.foreground_mask_path, img_file) if use_fg_mask else None
            self.data.append((os.path.join(normal_dir, img_file), None, fg_path))

        if split == 'test':
            test_dir = os.path.join(root, "test")
            gt_dir = os.path.join(root, "ground_truth")
            for defect in sorted(os.listdir(test_dir)):
                defect_dir = os.path.join(test_dir, defect)
                if not os.path.isdir(defect_dir) or defect_dir == normal_dir:
                    continue
                for img_file in sorted(os.listdir(defect_dir)):
                    # Build the mask path from components instead of string-replacing
                    # "test", which also rewrote any "test" elsewhere in the root path.
                    stem, ext = os.path.splitext(img_file)
                    mask_path = os.path.join(gt_dir, defect, stem + '_mask' + ext)
                    self.data.append((os.path.join(defect_dir, img_file), mask_path, None))

            random.shuffle(self.data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` as ``(image, label, mask, foreground_mask)``.

        ``label`` is 1 for anomalous and 0 for normal samples. ``mask`` is a 2-D
        float tensor (1 = defective pixel, all zeros for normal samples).
        ``foreground_mask`` is an all-ones (H, W) tensor unless a foreground-mask
        file is listed, in which case it is that file resized to ``(size, size)``
        and converted with ``ToTensor`` (shape (1, size, size)).
        """
        img_path, mask_path, fore_path = self.data[idx]
        with Image.open(img_path) as img:
            image = self.transform(img.convert('RGB'))
        # Tensor layout is (C, H, W); unpack in that order so the zero/one masks match
        # the image even when it is not square.
        _, height, width = image.shape

        if mask_path:
            with Image.open(mask_path) as gt:
                mask = self.mask_transform(gt.convert('RGB'))
            # A pixel is defective if any channel is non-zero.
            mask = 1.0 - torch.all(mask == 0, dim=0).float()
            label = 1
        else:
            mask = torch.zeros((height, width))
            label = 0

        foreground_mask = torch.ones((height, width))
        if fore_path:
            with Image.open(fore_path) as fg:
                fg_resized = fg.convert('L').resize((self.size, self.size), Image.NEAREST)
            foreground_mask = transforms.ToTensor()(fg_resized)

        return image, label, mask, foreground_mask
    
class BTAD(data.Dataset):
    """BTAD dataset yielding ``(image, label, mask, foreground_mask)``.

    Layout under ``path/class_name``::

        {train,test}/ok/*                 normal images
        test/<defect>/<file>              anomalous images
        ground_truth/<defect>/<file>      pixel masks (.bmp -> .png for classes 01 and 02)
        train/foreground_mask/*           foreground masks (created if missing)

    Args:
        path: dataset root.
        class_name: ``'01'``, ``'02'`` or ``'03'``.
        transform: applied to the RGB image; expected to return a (C, H, W) tensor.
        mask_transform: applied to the RGB ground-truth mask image.
        seed: unused (kept for interface compatibility).
        split: ``'train'`` or ``'test'``.
        size: side length foreground masks are resized to.
    """

    def __init__(self, path, class_name, transform=None, mask_transform=None, seed=0, split='train', size=224):
        self.transform = transform
        self.mask_transform = mask_transform
        self.data = []
        self.size = size
        self.has_fg_mask = class_name in ['01', '03']

        root = os.path.join(path, class_name)
        normal_dir = os.path.join(root, split, "ok")
        use_fg_mask = split == 'train' and self.has_fg_mask

        if use_fg_mask:
            self.foreground_mask_path = os.path.join(root, split, "foreground_mask")
            # Side effect: the foreground-mask directory is created if it does not exist.
            os.makedirs(self.foreground_mask_path, exist_ok=True)

        # sorted(): os.listdir order is filesystem-dependent; sorting keeps runs reproducible.
        for img_file in sorted(os.listdir(normal_dir)):
            fg_path = os.path.join(self.foreground_mask_path, img_file) if use_fg_mask else None
            self.data.append((os.path.join(normal_dir, img_file), None, fg_path))

        if split == 'test':
            test_dir = os.path.join(root, "test")
            gt_dir = os.path.join(root, "ground_truth")
            for defect in sorted(os.listdir(test_dir)):
                defect_dir = os.path.join(test_dir, defect)
                if not os.path.isdir(defect_dir) or defect_dir == normal_dir:
                    continue
                for img_file in sorted(os.listdir(defect_dir)):
                    # Build the mask path from components instead of string-replacing
                    # "test", which also rewrote any "test" elsewhere in the root path.
                    mask_name = img_file
                    if class_name in ["01", "02"]:
                        mask_name = mask_name.replace('.bmp', '.png')
                    mask_path = os.path.join(gt_dir, defect, mask_name)
                    self.data.append((os.path.join(defect_dir, img_file), mask_path, None))

            random.shuffle(self.data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` as ``(image, label, mask, foreground_mask)``.

        ``label`` is 1 for anomalous and 0 for normal samples. ``mask`` is a 2-D
        float tensor (1 = defective pixel, all zeros for normal samples).
        ``foreground_mask`` is an all-ones (H, W) tensor unless a foreground-mask
        file is listed, in which case it is that file resized to ``(size, size)``
        and converted with ``ToTensor`` (shape (1, size, size)).
        """
        img_path, mask_path, fore_path = self.data[idx]
        with Image.open(img_path) as img:
            image = self.transform(img.convert('RGB'))
        # Tensor layout is (C, H, W); unpack in that order so the zero/one masks match
        # the image even when it is not square.
        _, height, width = image.shape

        if mask_path:
            with Image.open(mask_path) as gt:
                mask = self.mask_transform(gt.convert('RGB'))
            # A pixel is defective if any channel is non-zero.
            mask = 1.0 - torch.all(mask == 0, dim=0).float()
            label = 1
        else:
            mask = torch.zeros((height, width))
            label = 0

        foreground_mask = torch.ones((height, width))
        if fore_path:
            with Image.open(fore_path) as fg:
                fg_resized = fg.convert('L').resize((self.size, self.size), Image.NEAREST)
            foreground_mask = transforms.ToTensor()(fg_resized)

        return image, label, mask, foreground_mask

class VisA(data.Dataset):
    """VisA dataset for one category.

    Directory layout read by ``__init__``::

        <path>/<class_name>/Data/Images/Normal/<image>
        <path>/<class_name>/Data/Images/Anomaly/<stem><ext>
        <path>/<class_name>/Data/Images/foreground_mask/<same file name as Normal>
        <path>/<class_name>/Data/Masks/Anomaly/<stem>.png

    Normal images that have a foreground mask form the training split; normal
    images without one are kept in ``self.normal_test`` (full paths) and become
    the normal part of the test split together with all anomalous images.

    ``self.data`` holds ``(image_path, mask_path_or_None, foreground_mask_path_or_None)``.
    The test split is shuffled with the global ``random`` module (``seed`` is
    accepted for interface compatibility but not used here).
    """

    def __init__(self, path, class_name, transform=None, mask_transform=None, seed=0, split='train', size=224):
        self.path_normal = os.path.join(path, class_name, "Data", "Images", "Normal")
        self.path_anomaly = os.path.join(path, class_name, "Data", "Images", "Anomaly")
        self.foreground_mask_path = os.path.join(path, class_name, "Data", "Images", "foreground_mask")
        # Ground-truth masks live in a parallel "Masks" tree. Build that path
        # explicitly rather than string-replacing "Images" anywhere in the path.
        anomaly_mask_dir = os.path.join(path, class_name, "Data", "Masks", "Anomaly")
        self.normal_test = []

        self.class_name = class_name
        self.transform = transform
        self.mask_transform = mask_transform
        self.data = []
        self.size = size

        # Split Normal images by whether a foreground mask exists for them.
        train_files = []
        for img_file in os.listdir(self.path_normal):
            if os.path.exists(os.path.join(self.foreground_mask_path, img_file)):
                train_files.append(img_file)
            else:
                self.normal_test.append(os.path.join(self.path_normal, img_file))

        if split == 'train':
            for img_file in train_files:
                image_dir = os.path.join(self.path_normal, img_file)
                foreground_mask_dir = os.path.join(self.foreground_mask_path, img_file)
                self.data.append((image_dir, None, foreground_mask_dir))
        elif split == 'test':
            # normal_test already stores full paths; joining them with
            # path_normal again would corrupt relative dataset roots.
            for image_dir in self.normal_test:
                self.data.append((image_dir, None, None))

            for img_file in os.listdir(self.path_anomaly):
                image_dir = os.path.join(self.path_anomaly, img_file)
                stem = os.path.splitext(img_file)[0]
                mask_dir = os.path.join(anomaly_mask_dir, stem + ".png")
                self.data.append((image_dir, mask_dir, None))

            random.shuffle(self.data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns a tuple ``(image, label, mask, foreground_mask)``:
            image: ``self.transform`` applied to the RGB image.
            label: 1 when the entry has a ground-truth mask path, else 0.
            mask: float tensor, 1.0 where any channel of the transformed
                ground-truth mask is non-zero; for entries without a mask, all
                zeros with the image's (H, W) spatial size.
            foreground_mask: when a foreground-mask path is recorded, that file
                as grayscale, nearest-resized to ``size`` x ``size`` and converted
                with ToTensor; otherwise an all-ones (H, W) tensor.
        """
        img_path, mask_path, fore_path = self.data[idx]
        with Image.open(img_path) as src:
            image = src.convert('RGB')
        image = self.transform(image)

        # Channel-first tensor (as produced by ToTensor): the last two dims are (H, W).
        height, width = image.shape[-2], image.shape[-1]

        if mask_path:
            with Image.open(mask_path) as src:
                mask = src.convert('RGB')
            mask = self.mask_transform(mask)
            # A pixel is anomalous if any channel is non-zero.
            mask = 1.0 - torch.all(mask == 0, dim=0).float()
            label = 1
        else:
            mask = torch.zeros((height, width))
            label = 0

        foreground_mask = torch.ones((height, width))
        if fore_path:
            with Image.open(fore_path) as src:
                fg = src.convert('L')
            fg = fg.resize((self.size, self.size), Image.NEAREST)
            foreground_mask = transforms.ToTensor()(fg)

        return image, label, mask, foreground_mask
    
class DTD(data.Dataset):
    """DTD texture dataset for one category.

    Directory layout read by ``__init__``::

        <path>/<class_name>/<split>/good/<image>
        <path>/<class_name>/test/<defect>/<stem><ext>
        <path>/<class_name>/ground_truth/<defect>/<stem>_mask<ext>

    ``self.data`` holds ``(image_path, mask_path_or_None, None)`` tuples; no
    foreground masks are recorded for this dataset. The test split is shuffled
    with the global ``random`` module (``seed`` is accepted but unused here).
    """

    def __init__(self, path, class_name, transform=None, mask_transform=None, seed=0, split='train', size=224):
        self.transform = transform
        self.mask_transform = mask_transform
        self.data = []
        self.size = size

        path = os.path.join(path, class_name)
        mask_suffix = '_mask'

        normal_dir = os.path.join(path, split, "good")
        for img_file in os.listdir(normal_dir):
            self.data.append((os.path.join(normal_dir, img_file), None, None))

        if split == 'test':
            test_dir = os.path.join(path, "test")
            gt_dir = os.path.join(path, "ground_truth")
            for defect in os.listdir(test_dir):
                defect_dir = os.path.join(test_dir, defect)
                if not os.path.isdir(defect_dir) or defect_dir == normal_dir:
                    continue
                for img_file in os.listdir(defect_dir):
                    # Build the mask path from its components: a string replace of
                    # "test" on the full path would also rewrite the dataset root.
                    stem, ext = os.path.splitext(img_file)
                    mask_dir = os.path.join(gt_dir, defect, stem + mask_suffix + ext)
                    self.data.append((os.path.join(defect_dir, img_file), mask_dir, None))

            random.shuffle(self.data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns a tuple ``(image, label, mask, foreground_mask)``:
            image: ``self.transform`` applied to the RGB image.
            label: 1 when the entry has a ground-truth mask path, else 0.
            mask: float tensor, 1.0 where any channel of the transformed
                ground-truth mask is non-zero; for entries without a mask, all
                zeros with the image's (H, W) spatial size.
            foreground_mask: when a foreground-mask path is recorded, that file
                as grayscale, nearest-resized to ``size`` x ``size`` and converted
                with ToTensor; otherwise an all-ones (H, W) tensor.
        """
        img_path, mask_path, fore_path = self.data[idx]
        with Image.open(img_path) as src:
            image = src.convert('RGB')
        image = self.transform(image)

        # Channel-first tensor (as produced by ToTensor): the last two dims are (H, W).
        height, width = image.shape[-2], image.shape[-1]

        if mask_path:
            with Image.open(mask_path) as src:
                mask = src.convert('RGB')
            mask = self.mask_transform(mask)
            # A pixel is anomalous if any channel is non-zero.
            mask = 1.0 - torch.all(mask == 0, dim=0).float()
            label = 1
        else:
            mask = torch.zeros((height, width))
            label = 0

        foreground_mask = torch.ones((height, width))
        if fore_path:
            with Image.open(fore_path) as src:
                fg = src.convert('L')
            fg = fg.resize((self.size, self.size), Image.NEAREST)
            foreground_mask = transforms.ToTensor()(fg)

        return image, label, mask, foreground_mask
    
class BraTS2021(data.Dataset):
    """BraTS2021 slice dataset.

    Directory layout read by ``__init__``::

        <path>/<split>/normal/<image>
        <path>/train/foreground_mask/<same file name>   (train split only)
        <path>/test/<other dir>/<image>
        <path>/ground_truth/<image name with "flair" replaced by "seg">

    For the train split the foreground-mask directory is created if missing and
    every normal image is paired with ``foreground_mask/<file>``. NOTE(review):
    the mask files themselves are not produced here; presumably another part of
    the notebook writes them. Confirm before training.

    ``self.data`` holds ``(image_path, mask_path_or_None, foreground_mask_path_or_None)``.
    The test split is shuffled with the global ``random`` module.
    """

    def __init__(self, path, transform=None, mask_transform=None, seed=0, split='train', size=224):
        self.transform = transform
        self.mask_transform = mask_transform
        self.data = []
        self.size = size

        normal_dir = os.path.join(path, split, "normal")

        if split == 'train':
            self.foreground_mask_path = os.path.join(path, split, "foreground_mask")
            os.makedirs(self.foreground_mask_path, exist_ok=True)

        for img_file in os.listdir(normal_dir):
            foreground_mask_dir = os.path.join(self.foreground_mask_path, img_file) if split == 'train' else None
            self.data.append((os.path.join(normal_dir, img_file), None, foreground_mask_dir))

        if split == 'test':
            test_dir = os.path.join(path, "test")
            gt_dir = os.path.join(path, "ground_truth")
            for entry in os.listdir(test_dir):
                anomaly_dir = os.path.join(test_dir, entry)
                if not os.path.isdir(anomaly_dir) or anomaly_dir == normal_dir:
                    continue
                for img_file in os.listdir(anomaly_dir):
                    # Rename only the file name: replacing "flair" in the full path
                    # would also rewrite a dataset root that happens to contain it.
                    mask_dir = os.path.join(gt_dir, img_file.replace("flair", "seg"))
                    self.data.append((os.path.join(anomaly_dir, img_file), mask_dir, None))

            random.shuffle(self.data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns a tuple ``(image, label, mask, foreground_mask)``:
            image: ``self.transform`` applied to the RGB image.
            label: 1 when the entry has a ground-truth mask path, else 0.
            mask: float tensor, 1.0 where any channel of the transformed
                ground-truth mask is non-zero; for entries without a mask, all
                zeros with the image's (H, W) spatial size.
            foreground_mask: when a foreground-mask path is recorded, that file
                as grayscale, nearest-resized to ``size`` x ``size`` and converted
                with ToTensor; otherwise an all-ones (H, W) tensor.
        """
        img_path, mask_path, fore_path = self.data[idx]
        with Image.open(img_path) as src:
            image = src.convert('RGB')
        image = self.transform(image)

        # Channel-first tensor (as produced by ToTensor): the last two dims are (H, W).
        height, width = image.shape[-2], image.shape[-1]

        if mask_path:
            with Image.open(mask_path) as src:
                mask = src.convert('RGB')
            mask = self.mask_transform(mask)
            # A pixel is anomalous if any channel is non-zero.
            mask = 1.0 - torch.all(mask == 0, dim=0).float()
            label = 1
        else:
            mask = torch.zeros((height, width))
            label = 0

        foreground_mask = torch.ones((height, width))
        if fore_path:
            with Image.open(fore_path) as src:
                fg = src.convert('L')
            fg = fg.resize((self.size, self.size), Image.NEAREST)
            foreground_mask = transforms.ToTensor()(fg)

        return image, label, mask, foreground_mask
    
class HeadCT(data.Dataset):
    """HeadCT dataset.

    Directory layout read by ``__init__``::

        <path>/<split>/good/<image>
        <path>/train/foreground_mask/<same file name>   (train split only)
        <path>/test/<defect>/<stem><ext>
        <path>/ground_truth/<defect>/<stem>_mask<ext>

    For the train split the foreground-mask directory is created if missing and
    every normal image is paired with ``foreground_mask/<file>``. NOTE(review):
    the mask files themselves are not produced here; presumably another part of
    the notebook writes them. Confirm before training.

    ``self.data`` holds ``(image_path, mask_path_or_None, foreground_mask_path_or_None)``.
    The test split is shuffled with the global ``random`` module.
    """

    def __init__(self, path, transform=None, mask_transform=None, seed=0, split='train', size=224):
        self.transform = transform
        self.mask_transform = mask_transform
        self.data = []
        self.size = size

        mask_suffix = '_mask'

        normal_dir = os.path.join(path, split, "good")

        if split == 'train':
            self.foreground_mask_path = os.path.join(path, split, "foreground_mask")
            os.makedirs(self.foreground_mask_path, exist_ok=True)

        for img_file in os.listdir(normal_dir):
            foreground_mask_dir = os.path.join(self.foreground_mask_path, img_file) if split == 'train' else None
            self.data.append((os.path.join(normal_dir, img_file), None, foreground_mask_dir))

        if split == 'test':
            test_dir = os.path.join(path, "test")
            gt_dir = os.path.join(path, "ground_truth")
            for defect in os.listdir(test_dir):
                defect_dir = os.path.join(test_dir, defect)
                if not os.path.isdir(defect_dir) or defect_dir == normal_dir:
                    continue
                for img_file in os.listdir(defect_dir):
                    # Build the mask path from its components: a string replace of
                    # "test" on the full path would also rewrite the dataset root.
                    stem, ext = os.path.splitext(img_file)
                    mask_dir = os.path.join(gt_dir, defect, stem + mask_suffix + ext)
                    self.data.append((os.path.join(defect_dir, img_file), mask_dir, None))

            random.shuffle(self.data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns a tuple ``(image, label, mask, foreground_mask)``:
            image: ``self.transform`` applied to the RGB image.
            label: 1 when the entry has a ground-truth mask path, else 0.
            mask: float tensor, 1.0 where any channel of the transformed
                ground-truth mask is non-zero; for entries without a mask, all
                zeros with the image's (H, W) spatial size.
            foreground_mask: when a foreground-mask path is recorded, that file
                as grayscale, nearest-resized to ``size`` x ``size`` and converted
                with ToTensor; otherwise an all-ones (H, W) tensor.
        """
        img_path, mask_path, fore_path = self.data[idx]
        with Image.open(img_path) as src:
            image = src.convert('RGB')
        image = self.transform(image)

        # Channel-first tensor (as produced by ToTensor): the last two dims are (H, W).
        height, width = image.shape[-2], image.shape[-1]

        if mask_path:
            with Image.open(mask_path) as src:
                mask = src.convert('RGB')
            mask = self.mask_transform(mask)
            # A pixel is anomalous if any channel is non-zero.
            mask = 1.0 - torch.all(mask == 0, dim=0).float()
            label = 1
        else:
            mask = torch.zeros((height, width))
            label = 0

        foreground_mask = torch.ones((height, width))
        if fore_path:
            with Image.open(fore_path) as src:
                fg = src.convert('L')
            fg = fg.resize((self.size, self.size), Image.NEAREST)
            foreground_mask = transforms.ToTensor()(fg)

        return image, label, mask, foreground_mask
    
class WFDD(data.Dataset):
    """WFDD dataset for one category.

    Directory layout read by ``__init__``::

        <path>/<class_name>/<split>/good/<image>
        <path>/<class_name>/test/<defect>/<stem><ext>
        <path>/<class_name>/ground_truth/<defect>/<stem>_mask<ext>

    ``self.data`` holds ``(image_path, mask_path_or_None, None)`` tuples; no
    foreground masks are recorded for this dataset. The test split is shuffled
    with the global ``random`` module (``seed`` is accepted but unused here).
    """

    def __init__(self, path, class_name, transform=None, mask_transform=None, seed=0, split='train', size=224):
        self.transform = transform
        self.mask_transform = mask_transform
        self.data = []
        self.size = size

        path = os.path.join(path, class_name)
        mask_suffix = '_mask'

        normal_dir = os.path.join(path, split, "good")
        for img_file in os.listdir(normal_dir):
            self.data.append((os.path.join(normal_dir, img_file), None, None))

        if split == 'test':
            test_dir = os.path.join(path, "test")
            gt_dir = os.path.join(path, "ground_truth")
            for defect in os.listdir(test_dir):
                defect_dir = os.path.join(test_dir, defect)
                if not os.path.isdir(defect_dir) or defect_dir == normal_dir:
                    continue
                for img_file in os.listdir(defect_dir):
                    # Build the mask path from its components: a string replace of
                    # "test" on the full path would also rewrite the dataset root.
                    stem, ext = os.path.splitext(img_file)
                    mask_dir = os.path.join(gt_dir, defect, stem + mask_suffix + ext)
                    self.data.append((os.path.join(defect_dir, img_file), mask_dir, None))

            random.shuffle(self.data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns a tuple ``(image, label, mask, foreground_mask)``:
            image: ``self.transform`` applied to the RGB image.
            label: 1 when the entry has a ground-truth mask path, else 0.
            mask: float tensor, 1.0 where any channel of the transformed
                ground-truth mask is non-zero; for entries without a mask, all
                zeros with the image's (H, W) spatial size.
            foreground_mask: when a foreground-mask path is recorded, that file
                as grayscale, nearest-resized to ``size`` x ``size`` and converted
                with ToTensor; otherwise an all-ones (H, W) tensor.
        """
        img_path, mask_path, fore_path = self.data[idx]
        with Image.open(img_path) as src:
            image = src.convert('RGB')
        image = self.transform(image)

        # Channel-first tensor (as produced by ToTensor): the last two dims are (H, W).
        height, width = image.shape[-2], image.shape[-1]

        if mask_path:
            with Image.open(mask_path) as src:
                mask = src.convert('RGB')
            mask = self.mask_transform(mask)
            # A pixel is anomalous if any channel is non-zero.
            mask = 1.0 - torch.all(mask == 0, dim=0).float()
            label = 1
        else:
            mask = torch.zeros((height, width))
            label = 0

        foreground_mask = torch.ones((height, width))
        if fore_path:
            with Image.open(fore_path) as src:
                fg = src.convert('L')
            fg = fg.resize((self.size, self.size), Image.NEAREST)
            foreground_mask = transforms.ToTensor()(fg)

        return image, label, mask, foreground_mask

# Utils

In [None]:
def patchify(x, patch_size):
    """Split images into non-overlapping square patches.

    Args:
        x: tensor of shape (B, C, H, W), or (B, H, W), which is treated as
            single-channel (a channel dimension is inserted at dim 1).
        patch_size: side length of each patch; also used as the stride, so
            patches do not overlap. Trailing rows/columns that do not fill a
            whole patch are dropped (nn.Unfold behaviour).

    Returns:
        Tensor of shape (B, num_patches, C, patch_size, patch_size), with
        patches ordered row-major over the patch grid.
    """
    if x.dim() == 3:
        # (B, H, W) -> (B, 1, H, W): this adds a *channel* dimension, not a batch one.
        x = x.unsqueeze(1)

    bs, c, h, w = x.shape

    unfold = torch.nn.Unfold(kernel_size=patch_size, stride=patch_size)
    cols = unfold(x)  # (B, C * patch_size * patch_size, num_patches)

    num_patches = cols.shape[-1]  # == (h // patch_size) * (w // patch_size)
    return cols.reshape(bs, c, patch_size, patch_size, num_patches).permute(0, 4, 1, 2, 3)

def label_patch(x):
    """Turn patchified pixel masks into per-patch labels.

    A patch is labelled 1.0 when any of its values in dims 2, 3 and 4 is
    strictly positive, otherwise 0.0. For the 5-D output of ``patchify``,
    (B, num_patches, C, p, p), the result is a (B, num_patches) float tensor.
    """
    positive = x > 0
    # Collapse dims 4, 3, then 2 (highest first so the lower indices stay valid).
    return positive.any(dim=4).any(dim=3).any(dim=2).float()

def get_dataloader(image_size, path, dataset_name, class_name, batch_size, test_batch_size, num_workers, seed):
    """Build the train and test DataLoaders for one dataset category.

    Args:
        image_size: side length that images and masks are resized to (square).
        path: dataset root passed to the dataset class.
        dataset_name: one of 'mvtec', 'visa', 'mpdd', 'btad', 'dtd',
            'brats2021', 'headct', 'wfdd'.
        class_name: category name (not passed to 'brats2021' / 'headct').
        batch_size: training batch size (training loader shuffles).
        test_batch_size: test batch size (test loader does not shuffle).
        num_workers: DataLoader worker processes.
        seed: forwarded to the dataset constructors.

    Returns:
        (train_loader, test_loader).

    Raises:
        ValueError: if dataset_name is not one of the supported names.
    """
    # Light photometric augmentation, applied to training images only.
    train_transform = transforms.Compose([
        transforms.Resize((image_size, image_size), Image.LANCZOS),
        transforms.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.01),
        transforms.RandomApply([transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 0.5))], p=0.2),
        transforms.ToTensor(),
    ])
    # Evaluation must be deterministic, so test images get no random augmentation.
    test_transform = transforms.Compose([
        transforms.Resize((image_size, image_size), Image.LANCZOS),
        transforms.ToTensor(),
    ])
    # The datasets binarise masks as "any non-zero channel", so masks are resized
    # with nearest-neighbour. A smoothing filter would make boundary pixels
    # partially non-zero and dilate every ground-truth mask.
    mask_transform = transforms.Compose([
        transforms.Resize((image_size, image_size), Image.NEAREST),
        transforms.ToTensor(),
    ])

    common = dict(mask_transform=mask_transform, seed=seed, size=image_size)

    if dataset_name == 'mvtec':
        train_set = MVTec(path, class_name, transform=train_transform, split='train', **common)
        test_set = MVTec(path, class_name, transform=test_transform, split='test', **common)
    elif dataset_name == 'visa':
        train_set = VisA(path, class_name, transform=train_transform, split='train', **common)
        test_set = VisA(path, class_name, transform=test_transform, split='test', **common)
    elif dataset_name == 'mpdd':
        train_set = MPDD(path, class_name, transform=train_transform, split='train', **common)
        test_set = MPDD(path, class_name, transform=test_transform, split='test', **common)
    elif dataset_name == 'btad':
        train_set = BTAD(path, class_name, transform=train_transform, split='train', **common)
        test_set = BTAD(path, class_name, transform=test_transform, split='test', **common)
    elif dataset_name == "dtd":
        train_set = DTD(path, class_name, transform=train_transform, split='train', **common)
        test_set = DTD(path, class_name, transform=test_transform, split='test', **common)
    elif dataset_name == "brats2021":
        train_set = BraTS2021(path, transform=train_transform, split='train', **common)
        test_set = BraTS2021(path, transform=test_transform, split='test', **common)
    elif dataset_name == "headct":
        train_set = HeadCT(path, transform=train_transform, split='train', **common)
        test_set = HeadCT(path, transform=test_transform, split='test', **common)
    elif dataset_name == 'wfdd':
        train_set = WFDD(path, class_name, transform=train_transform, split='train', **common)
        test_set = WFDD(path, class_name, transform=test_transform, split='test', **common)
    else:
        raise ValueError(
            f"Unknown dataset '{dataset_name}'; expected one of "
            "'mvtec', 'visa', 'mpdd', 'btad', 'dtd', 'brats2021', 'headct', 'wfdd'."
        )

    train_loader = data.DataLoader(train_set, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=num_workers)
    test_loader = data.DataLoader(test_set, batch_size=test_batch_size, shuffle=False, num_workers=num_workers)

    print(f"Dataloaders for dataset {dataset_name} and class {class_name} have been prepared.")

    return train_loader, test_loader
    

def soft_topk_pooling(scores, k, temperature=1.0):
    """Softmax-weighted average of the k largest scores in each row.

    Args:
        scores: tensor whose dim 1 holds the per-row scores.
        k: how many of the largest scores to pool.
        temperature: softmax temperature; lower values weight the maximum more.

    Returns:
        Tensor with dim 1 reduced away.
    """
    top_scores = scores.topk(k, dim=1).values
    attention = F.softmax(top_scores / temperature, dim=1)
    return torch.sum(top_scores * attention, dim=1)
    

def get_auc(test_scores, test_labels, test_masks, patches_per_side, sigma, radius, k):
    """Compute image-level and pixel-level ROC AUC from per-patch anomaly scores.

    Args:
        test_scores: list of score tensors concatenated along dim 0; each row is
            one image's patch scores and is reshaped to
            (patches_per_side, patches_per_side).
        test_labels: list of image-label tensors (positive class = anomalous).
        test_masks: list of ground-truth pixel-mask tensors whose last two dims
            are (H, W).
        patches_per_side: side length of the patch-score grid.
        sigma, radius: Gaussian smoothing parameters for scipy's
            gaussian_filter over the two spatial axes.
        k: number of highest patch scores averaged into each image score
            (clamped to the number of patches).

    Returns:
        (image_auroc, pixel_auroc)
    """
    scores = torch.cat(test_scores, dim=0)

    k = min(k, scores.shape[1])
    topk_values, _ = torch.topk(scores, k, dim=1)
    pred_labels = torch.mean(topk_values, dim=1)
    image_labels = torch.cat(test_labels, dim=0)

    image_auroc = roc_auc_score(image_labels.view(-1).cpu().numpy(), pred_labels.view(-1).cpu().numpy())

    masks = torch.cat(test_masks, dim=0)
    patch_scores = scores.reshape(-1, patches_per_side, patches_per_side)
    # Upsample to the mask's actual (H, W) rather than assuming square masks.
    mask_hw = (masks.shape[-2], masks.shape[-1])
    pixel_scores = F.interpolate(patch_scores.unsqueeze(1), size=mask_hw, mode='bilinear', align_corners=False)
    localization = gaussian_filter(pixel_scores.squeeze(1).cpu().detach().numpy(), sigma=sigma, radius=radius, axes=(1, 2))

    pixel_auroc = roc_auc_score(masks.reshape(-1).cpu().numpy(), localization.reshape(-1))

    return image_auroc, pixel_auroc


def save_model(model, filepath="./model.pth"):
    """Write ``model.state_dict()`` to ``filepath`` with torch.save and print where it went."""
    state = model.state_dict()
    torch.save(state, filepath)
    print(f"Model saved to {filepath}")

def load_model(model, filepath="./model.pth"):
    """Load a state dict written by ``save_model`` into ``model`` in place.

    The checkpoint is first deserialised onto the CPU (``map_location="cpu"``),
    so a checkpoint saved from a CUDA model also loads on a CPU-only machine.
    ``load_state_dict`` then copies the tensors onto the model's own device.

    If ``filepath`` does not exist, an error is printed and ``sys.exit(1)`` is
    called (raises SystemExit).

    NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    """
    try:
        state_dict = torch.load(filepath, map_location="cpu")
    except FileNotFoundError:
        print(f"Error: File '{filepath}' not found.")
        sys.exit(1)
    model.load_state_dict(state_dict)
    print(f"Model loaded from {filepath}")

def log_loss(epoch, loss, filepath="./loss_log.txt"):
    """Append a single "Epoch <epoch> : <loss>" line to ``filepath``."""
    entry = f"Epoch {epoch} : {loss}\n"
    with open(filepath, mode="a") as handle:
        handle.write(entry)

def display_results(metrics_dict, description):
    """Print ``metrics_dict`` as a two-column rich table titled ``description``.

    Each value is formatted with four decimal places.
    """
    table = Table(title=f"{description}")
    for header, colour in (("Metric", "cyan"), ("Value", "magenta")):
        table.add_column(header, style=colour, justify="center")

    for name, score in metrics_dict.items():
        table.add_row(f"{name}", f"{score:.4f}")

    Console().print(table)

# Vis

In [None]:
def image_transform(image):
    """Scale a float array (nominally in [0, 1]) to uint8 [0, 255], clipping out-of-range values."""
    scaled = np.clip(image * 255, 0, 255)
    return scaled.astype(np.uint8)
    
def cvt2heatmap(gray):
    """Colour a 0-255 grayscale array with OpenCV's JET colormap (OpenCV returns BGR order)."""
    return cv2.applyColorMap(np.uint8(gray), cv2.COLORMAP_JET)

def show_cam_on_image(img, anomaly_map):
    """Blend ``img`` and ``anomaly_map`` equally and rescale so the brightest value is 255.

    Both inputs are read as 0-255 arrays; the result is uint8.
    """
    blended = np.float32(anomaly_map) / 255 + np.float32(img) / 255
    blended = blended / blended.max()
    return (255 * blended).astype(np.uint8)

def min_max_norm(image):
    """Linearly rescale ``image`` so its minimum maps to 0 and its maximum to 1.

    A constant input (max == min) returns all zeros instead of the NaNs that
    the 0 / 0 division would otherwise produce.
    """
    a_min, a_max = image.min(), image.max()
    shifted = image - a_min
    span = a_max - a_min
    if span == 0:
        # Multiplying by 0.0 yields the same float dtype the division would.
        return shifted * 0.0
    return shifted / span
    
def get_heatmap(raw_image, segmentation):
    """Overlay a smoothed, normalised JET heatmap of ``segmentation`` on ``raw_image``.

    Args:
        raw_image: channel-first image tensor with values in [0, 1].
        segmentation: 2-D anomaly map with the image's spatial size.

    Returns:
        uint8 (H, W, 3) array. The image is flipped RGB->BGR before blending with
        the OpenCV (BGR) colormap, and the blend is flipped back afterwards.
    """
    smoothed = min_max_norm(gaussian_filter(segmentation, sigma=4))
    colored = cvt2heatmap(smoothed * 255.0)
    pixels = image_transform(raw_image.detach().cpu().numpy())
    hwc = np.uint8(np.transpose(pixels, (1, 2, 0)))
    overlay = show_cam_on_image(hwc[..., ::-1], colored)
    return overlay[..., ::-1]

def transparent_cmap(cmap, N=255):
    """Return a copy of ``cmap`` whose alpha channel ramps linearly from 0 to 0.8.

    The colormap passed in is left untouched. Modifying it in place would alter
    the globally registered colormap (e.g. ``plt.cm.jet``) for every later plot.

    N: the alpha ramp has ``N + 4`` entries, which must equal the number of rows
    in matplotlib's private ``_lut`` (colormap size + 3 under/over/bad rows).
    The default N=255 fits a 256-colour map such as jet.
    """
    import copy

    mycmap = copy.copy(cmap)
    mycmap._init()
    mycmap._lut[:, -1] = np.linspace(0, 0.8, N + 4)
    return mycmap

def visualize_heatmap(args):
    """Render qualitative results for the test split of args["dataset"] / args["class_name"].

    For every test image it writes PNGs under ./plots:
      * clean_image/ and adv_image/: the clean input and its PGD-perturbed version
        (args["epsilon_visualization"], args["step_visualization"]), with the
        ground-truth patch labels passed to pgd_attack;
      * mask/: the clean image with ground-truth anomalous pixels painted red;
      * clean_heatmap/ and adv_heatmap/: the image overlaid with the smoothed,
        min-max-normalised patch-score map.
    Files are named img<k>.png, with k the sample's position in the test loader.
    ./plots is finally archived to ./visualization.zip.
    """
    device = torch.device("cuda" if args["device"] != "cpu" and torch.cuda.is_available() else "cpu")
    _, test_loader = get_dataloader(args["image_size"], args["dataset_dir"], args["dataset"], args["class_name"], args["train_batch_size"], args["test_batch_size"], args["num_workers"], args["seed"])

    model = PatchGuard(args, device).to(device)
    load_model(model, args["checkpoint_dir"] + f"patchguard_{args['dataset']}_{args['class_name']}.pth")
    # Inference mode, as in test(). Without it, train-mode layer behaviour leaks into the plots.
    model.eval()

    plot_path = Path("./plots")
    out_dirs = {}
    for name in ("clean_image", "adv_image", "clean_heatmap", "adv_heatmap", "mask"):
        out_dirs[name] = plot_path / name
        out_dirs[name].mkdir(exist_ok=True, parents=True)

    cmap = transparent_cmap(plt.cm.jet)

    def save_figure(rgb, target, overlay=None):
        # Draw an (H, W, 3) image, optionally with a heatmap on top, as a borderless PNG.
        fig, ax = plt.subplots(1, 1, figsize=(5, 5))
        ax.imshow(rgb)
        if overlay is not None:
            ax.imshow(overlay, cmap=cmap, interpolation='bilinear')
        ax.axis('off')
        fig.savefig(target, bbox_inches='tight', pad_inches=0, format='png')
        plt.close(fig)

    with torch.no_grad():
        sample_offset = 0
        for images, _, masks, _ in test_loader:
            images, masks = images.to(device), masks.to(device)
            samples_in_batch = images.shape[0]

            for mode in ("clean", "adv"):
                image_dir = out_dirs["clean_image"] if mode == "clean" else out_dirs["adv_image"]
                heatmap_dir = out_dirs["clean_heatmap"] if mode == "clean" else out_dirs["adv_heatmap"]

                if mode == "adv":
                    # PGD needs input gradients even inside the surrounding no_grad().
                    with torch.set_grad_enabled(True):
                        images = pgd_attack(model, images, label_patch(patchify(masks, model.patch_size)), args["epsilon_visualization"], args["step_visualization"])

                scores = model(images)
                batch_size, num_patches = scores.shape
                image_size = images.shape[-1]
                patches_per_side = int(np.sqrt(num_patches))

                for b in range(batch_size):
                    file_name = f'img{sample_offset + b}.png'

                    patch_scores = scores[b].reshape((patches_per_side, patches_per_side))
                    scores_interpolated = F.interpolate(patch_scores.unsqueeze(0).unsqueeze(0),
                                                        size=image_size,
                                                        mode='bilinear',
                                                        align_corners=False
                                                        ).squeeze(0).squeeze(0)

                    segmentation = gaussian_filter(scores_interpolated.cpu().detach().numpy(), sigma=args["smoothing_sigma"], radius=args["smoothing_radius"])
                    segmentation = gaussian_filter(segmentation, sigma=4)
                    segmentation = min_max_norm(segmentation)

                    rgb = images[b].cpu().detach().permute(1, 2, 0).numpy()
                    save_figure(rgb, os.path.join(image_dir, file_name))

                    # Ground-truth overlay (clean images only): anomalous pixels painted red.
                    if mode == "clean":
                        modified_image = images[b].clone()
                        modified_image[0, masks[b] > 0] = 1.0
                        modified_image[1, masks[b] > 0] = 0.0
                        modified_image[2, masks[b] > 0] = 0.0
                        save_figure(modified_image.cpu().detach().permute(1, 2, 0).numpy(),
                                    os.path.join(out_dirs["mask"], file_name))

                    save_figure(rgb, os.path.join(heatmap_dir, file_name), overlay=segmentation)

            # args is a dict, so attribute access (args.test_batch_size) would fail;
            # advancing by the real batch size also handles a short final batch.
            sample_offset += samples_in_batch

    print("Visualization complete.")
    shutil.make_archive('visualization', 'zip', plot_path)
    return

# Test

In [None]:
def test(model, test_loader, device, args, adv_test, epsilon=8/255, steps=10):
    """Evaluate ``model`` on ``test_loader`` and return (image_auroc, pixel_auroc).

    When ``adv_test`` is True, every batch is first perturbed with the attack
    named by args["attack_type"] (only "PGD" is implemented). The attack target
    is the ground-truth patch labels (patchify + label_patch), with the given
    ``epsilon`` and ``steps``. The AUROCs come from get_auc using
    args["smoothing_sigma"], args["smoothing_radius"] and args["top_k"].

    Raises:
        ValueError: if ``adv_test`` is set and args["attack_type"] is not "PGD".
            Without this check the attack would be skipped and clean scores
            would be reported as adversarial results.
    """
    if adv_test and args["attack_type"] != "PGD":
        raise ValueError(f"Unsupported attack_type {args['attack_type']!r}; only 'PGD' is implemented.")

    model.eval()

    test_scores = []
    test_labels = []
    test_masks = []

    with torch.no_grad():
        for batch in test_loader:
            images, labels = batch[0].to(device), batch[1].to(device)
            pixel_masks = batch[2]
            if adv_test:
                patch_targets = label_patch(patchify(pixel_masks, model.patch_size)).to(device)
                # PGD needs input gradients even inside the surrounding no_grad().
                with torch.set_grad_enabled(True):
                    images = pgd_attack(model, images, patch_targets, epsilon, steps)

            scores = model(images)

            test_scores.append(scores.cpu())
            test_labels.append(labels.cpu())
            test_masks.append(pixel_masks.cpu())

    image_auc, pixel_auc = get_auc(test_scores, test_labels, test_masks, model.patches_per_side, args["smoothing_sigma"], args["smoothing_radius"], args["top_k"])

    return image_auc, pixel_auc

def run_test(args):
    """Evaluate the saved PatchGuard checkpoint for args["dataset"] / args["class_name"].

    Always prints clean image- and pixel-level AUC. If args["adv_test"] is set,
    it also prints them under args["attack_type"] for each epsilon in
    args["epsilon_test"], with args["step_test"] attack steps.
    """
    use_cuda = args["device"] != "cpu" and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    model = PatchGuard(args, device).to(device)
    load_model(model, args["checkpoint_dir"] + f"patchguard_{args['dataset']}_{args['class_name']}.pth")
    _, test_loader = get_dataloader(args["image_size"], args["dataset_dir"], args["dataset"], args["class_name"], args["train_batch_size"], args["test_batch_size"], args["num_workers"], args["seed"])

    clean_image_auc, clean_pixel_auc = test(model, test_loader, device, args, False)
    display_results({"Image-level AUC": clean_image_auc, "Pixel-level AUC": clean_pixel_auc}, "Clean Performance")

    if not args["adv_test"]:
        return

    steps = args["step_test"]
    for eps in args["epsilon_test"]:
        adv_image_auc, adv_pixel_auc = test(model, test_loader, device, args, True, eps, steps)
        display_results(
            {"Image-level AUC": adv_image_auc, "Pixel-level AUC": adv_pixel_auc},
            f"{args['attack_type']} attack (eps={eps}, step={steps})",
        )


# Train

In [None]:
def set_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices), and make cuDNN deterministic."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def train_step(model, anomaly_generator, train_loader, optimizer, criterion, use_reg, device, args):
    """Run one epoch over ``train_loader`` and return the average loss.

    For each batch the loss is the sum of ``criterion(scores, targets, attn_weights)``
    terms over:
      * the normal images, with all-zero patch targets;
      * if args["adv_train"]: PGD-perturbed normal images (all-zero targets);
      * synthetic anomalies from ``anomaly_generator(images, foreground_masks)``,
        with patch targets from patchify + label_patch on the generated masks;
      * if args["adv_train"]: PGD-perturbed synthetic anomalies (same targets).
    One optimizer step is taken per batch.

    Batches are read as ``batch[0]`` = images and ``batch[3]`` = foreground masks,
    matching the (image, label, mask, foreground_mask) tuples the datasets return.

    Returns:
        total_loss / total_sample. Each batch adds loss.item() * batch_size to the
        numerator and (number of loss terms) * batch_size to the denominator, i.e.
        the per-term loss averaged with batch-size weights.
        NOTE(review): an empty train_loader leaves total_sample at 0 and raises
        ZeroDivisionError.
    """
    total_sample = 0
    total_loss = 0

    batch_iterator = tqdm(train_loader, disable=not args["use_tqdm"], desc="Training Batches")
    for batch in batch_iterator:
        loss = 0

        normal_data = [batch[0].to(device)]

        if args["adv_train"]:
            images = batch[0].to(device).clone()
            # Attack normal images towards the all-normal target (zeros for every patch).
            adv_normal_images = pgd_attack(model, images, torch.zeros(images.shape[0], model.num_patches).to(device), args["epsilon_train"], args["step_train"])
            normal_data.append(adv_normal_images)

        for imgs in normal_data:
            features, attn_weights = model.feature_extractor(imgs, use_reg)

            scores_true = model.discriminator(features)

            # Assumes features is (batch, num_patches, ...); every patch is normal -> target 0.
            masks_true = torch.zeros(features.shape[0], features.shape[1]).to(device)
            loss += criterion(scores_true, masks_true, attn_weights)

        # Anomaly synthesis runs on the CPU copy of the batch; moved to device below.
        images = batch[0].clone()
        foreground_masks = batch[3]

        augmented_images, augmented_masks = anomaly_generator(images, foreground_masks)
        # Pixel masks -> per-patch binary labels (presumably (B, H, W) or (B, 1, H, W) masks; confirm against AnomalyGenerator).
        augmented_masks = label_patch(patchify(augmented_masks, model.patch_size))
        augmented_images, augmented_masks = augmented_images.to(device), augmented_masks.to(device)

        anomaly_data = [augmented_images]
        if args["adv_train"]:
            adv_distorted_images = pgd_attack(model, augmented_images.clone(), augmented_masks, args["epsilon_train"], args["step_train"])
            anomaly_data.append(adv_distorted_images)

        for imgs in anomaly_data:
            features, attn_weights = model.feature_extractor(imgs, use_reg)
            scores_aug = model.discriminator(features)

            loss += criterion(scores_aug, augmented_masks, attn_weights)

        # zero_grad right before backward also clears any parameter gradients
        # accumulated while pgd_attack ran.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # images here is the CPU batch clone, so shape[0] is the batch size.
        total_sample += (len(anomaly_data) + len(normal_data)) * images.shape[0]
        total_loss += loss.item() * images.shape[0]
    
    return total_loss / total_sample


def train(model, anomaly_generator, train_loader, optimizer, lr_scheduler, criterion, use_reg, epochs, device, args):
    """Train for ``epochs`` epochs.

    Each epoch runs ``train_step`` and steps ``lr_scheduler`` once. The mean loss
    is shown in the progress bar, printed, and appended to the loss log via
    ``log_loss``.
    """
    model.train()

    progress = tqdm(range(epochs), disable=not args["use_tqdm"], desc="Epochs")
    for epoch in progress:
        epoch_loss = train_step(model, anomaly_generator, train_loader, optimizer, criterion, use_reg, device, args)
        lr_scheduler.step()

        progress.set_postfix(loss=epoch_loss)
        print(f"Current Epoch {epoch}, Current Loss {epoch_loss}")
        log_loss(epoch, epoch_loss)

def run_train(args):
    """Train PatchGuard on args["dataset"] / args["class_name"] and save the checkpoint.

    Setup: AdamW at args["lr"], with cosine annealing over args["epochs"] down to
    args["lr"] * args["lr_decay_factor"]. The checkpoint is written to
    args["checkpoint_dir"] + "patchguard_<dataset>_<class_name>.pth"
    (plain concatenation, so checkpoint_dir should be "" or end with a separator).
    """
    set_seed(args["seed"])

    use_cuda = args["device"] != "cpu" and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    model = PatchGuard(args, device).to(device)

    train_loader, _ = get_dataloader(args["image_size"], args["dataset_dir"], args["dataset"], args["class_name"], args["train_batch_size"], args["test_batch_size"], args["num_workers"], args["seed"])

    # NOTE(review): args["wd"] is not passed here, so AdamW uses its library-default
    # weight_decay unless PatchGuard applies it elsewhere. Confirm this is intended.
    optimizer = optim.AdamW(model.parameters(), lr=args["lr"])
    min_lr = args["lr"] * args["lr_decay_factor"]
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args["epochs"], eta_min=min_lr)

    criterion = Loss(args["reg_type"], device, model.num_patches)
    anomaly_generator = AnomalyGenerator(args["dataset"], args["class_name"], args["seed"])

    train(model, anomaly_generator, train_loader, optimizer, scheduler, criterion, args["use_reg"], args["epochs"], device, args)

    save_model(model, args["checkpoint_dir"] + f"patchguard_{args['dataset']}_{args['class_name']}.pth")

# Main

In [None]:
def main(args):
    """Dispatch on args["mode"].

    "train" runs run_train, "test" runs run_test, and any other value runs
    visualize_heatmap.
    """
    handlers = {"train": run_train, "test": run_test}
    handler = handlers.get(args["mode"], visualize_heatmap)
    handler(args)

## Train

In [None]:
# Experiment configuration. Fill in the empty placeholders before running.
args = {
    "mode": "train",
    "seed": 0,
    "class_name": "",       # category name, e.g. a key of bounds[<dataset>]
    "dataset": "",          # one of the names handled in get_dataloader (e.g. "mvtec", "visa")
    "dataset_dir": "",      # dataset root directory
    "checkpoint_dir": "",   # concatenated directly with the file name: use "" or end with "/"
    "device": "cuda",

    # TODO(review): this value was left blank (`"epochs": ,`), a SyntaxError that
    # stopped the whole cell from running. Set it for your experiment.
    "epochs": 50,
    "train_batch_size": 16,
    "test_batch_size": 16,
    "lr": 0.0008,
    "lr_decay_factor": 0.0125,
    "lr_adaptor": 0.0001,   # NOTE(review): not referenced in this part of the notebook
    "wd": 0.00001,          # NOTE(review): run_train does not pass this to AdamW
    "image_size": 224,
    "num_workers": 1,
    "use_tqdm": False,

    # Feature extractor config
    "hf_path": "vit_small_patch14_dinov2.lvd142m",
    "feature_layers": [12],
    "reg_layers": [6, 9, 12],

    # Discriminator config
    "hidden_dim": 2048,
    "dsc_layers": 1,
    "dsc_heads": 4,
    "top_k": 3,
    "smoothing_sigma": 6,
    "smoothing_radius": 7,

    # Adversarial attack config
    "attack_type": "PGD",
    "adv_train": True,
    "adv_test": True,
    "epsilon_train": 8/255,  # in [0, 1] pixel units
    "epsilon_test": [4/255, 8/255],  # in [0, 1] pixel units
    "epsilon_visualization": 8/255,  # in [0, 1] pixel units
    "step_train": 10,
    "step_test": 10,
    "step_visualization": 10,

    # Regularizer config
    "use_reg": True,
    "reg_type": "KL_divergence",
}

main(args)


## Adversarial Attack

In [None]:
# Attack and test: evaluate the checkpoint trained above. This reports clean AUROC,
# then PGD results for each epsilon in epsilon_test (adv_test=True).
# Build a copy rather than mutating `args`, so the training config above stays intact.
test_args = {**args, "mode": "test"}
main(test_args)