In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights, efficientnet_v2_s, EfficientNet_V2_S_Weights
from models.autoencoder import Decoder

In [2]:
def create_noise(true_feats, mix_noise=1, noise_std=0.1, device='cuda'):
    """Draw zero-mean Gaussian noise shaped like ``true_feats``.

    Each sample in the batch is assigned one of ``mix_noise`` noise levels
    uniformly at random; level ``k`` uses a standard deviation of
    ``noise_std * 1.1**k``.

    Args:
        true_feats: 4-D tensor; only its shape ``(B, C, H, W)`` is read.
        mix_noise: number of candidate noise levels (K).
        noise_std: standard deviation of level 0.
        device: device the returned tensor is moved to.

    Returns:
        Tensor of shape ``(B, C, H, W)`` on ``device``.
    """
    batch, channels, height, width = true_feats.shape
    feat_shape = (batch, channels, height, width)

    # 1. One noise-level index per sample, shape (B,).
    level_idx = torch.randint(0, mix_noise, size=(batch,))
    # 2. One-hot selector over the levels, shape (B, K).
    level_mask = F.one_hot(level_idx, num_classes=mix_noise)

    # 3. One Gaussian draw per level, stacked to (B, K, C, H, W).
    candidates = []
    for level in range(mix_noise):
        level_std = noise_std * 1.1**level
        candidates.append(torch.normal(0, level_std, size=feat_shape))
    stacked = torch.stack(candidates, dim=1)

    # 4. Zero out every level except the selected one and collapse K.
    selector = level_mask[:, :, None, None, None]
    selected = (stacked * selector).sum(dim=1)

    return selected.to(device)

In [3]:
class SimpleSAE(nn.Module):
    """Autoencoder built on the EfficientNetV2-S feature stack.

    Pipeline: feature extractor (run under ``torch.no_grad``) -> per-location
    linear adaptor over the 1280-wide channel axis -> ``Decoder``. The decoder
    output is split into channels 0-2 (returned as the reconstruction) and
    channel 3 (returned as the mask).
    """

    def __init__(self):
        super().__init__()
        # Kept for backward compatibility; reflects CUDA availability, not
        # necessarily where the module's parameters actually live.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # `weights` is keyword-only in torchvision's current API; passing it
        # positionally only works through a deprecated compatibility shim.
        self.feature_extractor = efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.DEFAULT).features
        self.feature_adaptor = nn.Linear(1280, 1280)
        self.decoder = Decoder([1280, 640, 256, 128, 64, 4])

    def _adapted_features(self, x):
        """Extract backbone features and pass them through the linear adaptor."""
        # NOTE(review): no gradient reaches the extractor, but it still follows
        # the module's train()/eval() mode — confirm that is intended.
        with torch.no_grad():
            x = self.feature_extractor(x)
        x = x.permute(0, 2, 3, 1)   # B, H, W, C so nn.Linear acts on channels
        x = self.feature_adaptor(x)
        return x.permute(0, 3, 1, 2)   # back to B, C, H, W

    def _decode(self, feats):
        """Decode features and split the result into (reconstruction, mask)."""
        decoded = self.decoder(feats)
        return decoded[:, 0:3], decoded[:, 3]

    def forward(self, x):
        """Reconstruct ``x`` without noise injection.

        Returns:
            ``(reconstruction, mask)`` — decoder channels 0-2 and channel 3.
        """
        return self._decode(self._adapted_features(x))

    def train_model(self, x):
        """Reconstruct ``x`` after adding Gaussian noise to the adapted features.

        Returns:
            ``(reconstruction, mask)`` — decoder channels 0-2 and channel 3.
        """
        feats = self._adapted_features(x)
        # Put the noise on the features' own device so the addition works
        # wherever the module has been placed.
        noise = create_noise(feats, device=feats.device)
        return self._decode(feats + noise)

In [19]:
from tqdm.auto import tqdm
from torch.utils.data import DataLoader, Subset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from datasets.dataset import MvtecADDataset
from losses.ssim_loss import SSIM_Loss
from losses.gms_loss import MSGMS_Loss, MSGMS_Score
from utils.early_stopping import EarlyStopping
from utils.save import save_anomaly_map, plot_fig, save_model
from scipy.ndimage import gaussian_filter
from torchvision.utils import save_image
from eval.evaluate_experiment import *
import numpy as np
import random
import os
import json

class SimpleAD():
    def __init__(self, args):
        """Set up data loaders, model, optimizer and fixed debug batches.

        Args:
            args: namespace providing ``seed``, ``img_size``, ``batch_size``,
                ``lr`` and ``prefix``. If ``args.seed`` is ``None`` a random
                seed is drawn and written back onto ``args``.
        """
        self.args = args
        if args.seed is None:
            args.seed = random.randint(1, 10000)
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(args.seed)

        data_root = f"mvtec_anomaly_detection_{args.img_size}"

        # Hold out 20% of the training split for validation.
        train_data = MvtecADDataset(root_dir=data_root, split="train", img_size=args.img_size)
        img_nums = len(train_data)
        valid_num = int(img_nums * 0.2)
        train_num = img_nums - valid_num
        train_dataset, val_dataset = torch.utils.data.random_split(train_data, [train_num, valid_num])

        self.train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
        self.valid_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True)

        # Model and optimisation setup.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = SimpleSAE().to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=args.lr)
        self.scheduler = ReduceLROnPlateau(self.optimizer, mode='min', factor=0.5, patience=3, verbose=True)
        if args.prefix is None:
            self.save_name = f'{self.model.__class__.__name__}_{args.img_size}'
        else:
            self.save_name = f'{self.model.__class__.__name__}_{args.img_size}_{args.prefix}'

        # Record the run configuration next to the metrics.
        os.makedirs(f'metrics/{self.save_name}', exist_ok=True)
        with open(os.path.join(f'metrics/{self.save_name}', 'model_training_log.txt'), 'w') as f:
            state = {k: v for k, v in args._get_kwargs()}
            f.write(str(state))

        # Fixed validation batch used for periodic reconstruction snapshots.
        x_normal_fixed, _, _, _, _ = next(iter(self.valid_loader))
        self.x_normal_fixed = x_normal_fixed.to(self.device)

        # Fixed random batch from the test split, also used for snapshots.
        test_dataset = MvtecADDataset(root_dir=data_root, split="test", img_size=args.img_size)
        # random.sample raises ValueError when asked for more items than the
        # population holds, so never request more than the test set contains.
        sample_size = min(args.batch_size, len(test_dataset))
        random_indices = random.sample(range(len(test_dataset)), sample_size)
        random_subset = Subset(test_dataset, random_indices)
        random_loader = DataLoader(random_subset, batch_size=args.batch_size, shuffle=False)
        x_test_fixed, _, _, _, _ = next(iter(random_loader))
        self.x_test_fixed = x_test_fixed.to(self.device)
        
    def train(self):
        """Run the training loop for up to ``self.args.epochs`` epochs.

        Every epoch trains, validates, steps the LR scheduler and the early
        stopper on the mean validation loss, hands the model to ``save_model``
        and prints the mean losses. Every tenth epoch a reconstruction
        snapshot of the two fixed batches is written. The loop ends early when
        the early stopper fires.
        """
        best_loss = 100000
        stopper = EarlyStopping(patience=10)

        for epoch_idx in tqdm(range(self.args.epochs)):
            epoch_no = epoch_idx + 1

            loss_sum, l1_sum, l2_sum, gms_sum, ssim_sum = self._train()
            valid_sum = self._eval()

            # Snapshot on epochs 10, 20, 30, ...
            if epoch_idx % 10 == 9:
                snapshot_dir = f'metrics/{self.save_name}'
                normal_path = os.path.join(snapshot_dir, f'{epoch_no}-images.jpg')
                test_path = os.path.join(snapshot_dir, f'{epoch_no}test-images.jpg')
                self.save_snapshot(self.x_normal_fixed, self.x_test_fixed, normal_path, test_path)

            # Losses are accumulated as per-batch sums; divide by batch count.
            valid_mean = valid_sum / len(self.valid_loader)
            self.scheduler.step(valid_mean)
            stopper(val_loss=valid_mean)

            train_batches = len(self.train_loader)
            train_mean = loss_sum / train_batches
            best_loss = save_model(self.model, train_mean, valid_mean, best_loss, epoch_no, self.save_name)

            print(f"Epoch [{epoch_no}/{self.args.epochs}], Train Loss: {train_mean:.4f}, Valid Loss {valid_mean:.4f}")
            print(f'Train L1_Loss: {l1_sum / train_batches * self.args.delta:.6f} L2_Loss: {l2_sum / train_batches * self.args.gamma:.6f} GMS_Loss: {gms_sum / train_batches * self.args.alpha:.6f} SSIM_Loss: {ssim_sum / train_batches * self.args.beta:.6f}')

            if stopper.early_stop:
                print("Early stopping triggered")
                break
            
    def _train(self):
        """Train for one pass over ``self.train_loader``.

        Batches whose images contain NaN or Inf are reported and skipped.

        Returns:
            Tuple of per-batch sums (not means):
            ``(total, l1, l2, gms, ssim)``. The component sums are unweighted;
            ``total`` uses the ``delta``/``gamma``/``alpha``/``beta`` weights.
        """
        self.model.train()
        ssim_fn = SSIM_Loss()
        mse_fn = nn.MSELoss()
        msgms_fn = MSGMS_Loss()
        l1_fn = nn.L1Loss()

        sums = {'total': 0, 'l1': 0, 'l2': 0, 'gms': 0, 'ssim': 0}

        for images, _, _, _, _ in tqdm(self.train_loader):
            # Skip batches with non-finite pixels.
            problem = None
            if torch.isnan(images).any():
                problem = "NaN detected in input images"
            elif torch.isinf(images).any():
                problem = "Inf detected in input images"
            if problem is not None:
                print(problem)
                continue

            images = images.to(self.device)
            self.optimizer.zero_grad()
            outputs, mask = self.model.train_model(images)

            parts = {
                'l1': l1_fn(images, outputs),
                'l2': mse_fn(images, outputs),
                'gms': msgms_fn(images, outputs),
                'ssim': ssim_fn(images, outputs),
            }
            loss = (parts['l1'] * self.args.delta
                    + self.args.gamma * parts['l2']
                    + self.args.alpha * parts['gms']
                    + self.args.beta * parts['ssim'])

            sums['total'] += loss.item()
            for key, value in parts.items():
                sums[key] += value.item()

            loss.backward()
            self.optimizer.step()

        return sums['total'], sums['l1'], sums['l2'], sums['gms'], sums['ssim']
    
    def _eval(self):
        """Run one pass over ``self.valid_loader`` without gradient tracking.

        Returns:
            Sum over batches (not the mean) of the weighted loss
            ``delta*L1 + gamma*L2 + alpha*MSGMS + beta*SSIM``.
        """
        self.model.eval()
        ssim = SSIM_Loss()
        mse = nn.MSELoss()
        msgms = MSGMS_Loss()
        l1 = nn.L1Loss()
        valid_loss = 0
        with torch.no_grad():
            for images, _, _, _, _ in tqdm(self.valid_loader):
                images = images.to(self.device)
                # NOTE(review): validation goes through train_model, i.e. the
                # noise-injecting path, so this loss is stochastic — confirm
                # this is intended rather than the plain forward pass.
                outputs, _ = self.model.train_model(images)

                loss = (self.args.delta * l1(images, outputs)
                        + self.args.gamma * mse(images, outputs)
                        + self.args.alpha * msgms(images, outputs)
                        + self.args.beta * ssim(images, outputs))

                # Only the combined loss is returned, so the per-term sums
                # (and their extra .item() calls) are not tracked.
                valid_loss += loss.item()
        return valid_loss
          
    def _test(self, test_loader, root_anomaly_map_dir):
        """Score every batch of ``test_loader`` and save per-image anomaly maps.

        Args:
            test_loader: loader yielding
                ``(images, masks, labels, _, image_paths)`` batches.
            root_anomaly_map_dir: directory passed to ``save_anomaly_map``.

        Returns:
            ``(scores, test_imgs, recon_imgs, gt_list, gt_mask_list)`` — lists
            with one NumPy entry per test image.
        """
        msgms_score = MSGMS_Score()
        scores = []
        test_imgs = []
        gt_list = []
        gt_mask_list = []
        recon_imgs = []
        self.model.eval()
        with torch.no_grad():
            for images, masks, labels, _, image_paths in tqdm(test_loader):
                images = images.to(self.device)
                test_imgs.extend(images.cpu().numpy())
                gt_list.extend(labels.cpu().numpy())
                gt_mask_list.extend(masks.cpu().numpy())
                outputs, mask = self.model(images)
                score = msgms_score(images, outputs)
                # score = F.mse_loss(images, outputs, reduction='none').mean(dim=1)
                score = score.squeeze().cpu().numpy()
                # A bare squeeze() also removes the batch axis when the batch
                # holds a single image; restore it so the code below always
                # iterates over per-image 2-D maps.
                # (assumes the score is one 2-D map per image — TODO confirm
                # against MSGMS_Score)
                if score.ndim == 2:
                    score = score[np.newaxis]

                # Smooth each anomaly map.
                for i in range(score.shape[0]):
                    score[i] = gaussian_filter(score[i], sigma=7)

                scores.extend(score)
                recon_imgs.extend(outputs.cpu().numpy())

                # Save the anomaly map of every image in the batch.
                for i in range(images.size(0)):
                    image_path = image_paths[i]
                    anomaly_map = score[i]
                    save_anomaly_map(anomaly_map, image_path, root_anomaly_map_dir, img_size=self.args.img_size)

        return scores, test_imgs, recon_imgs, gt_list, gt_mask_list
    
    def test(self, evaluated_objects, pro_integration_limit=0.3):
        """Evaluate the model per object and write ``metrics.json``.

        For each object name the test split is scored with ``_test``, metrics
        are obtained from ``calculate_metrics``, comparison figures are saved
        with ``plot_fig``, and the results are stored under the object's name.
        The mean of every scalar metric over all objects is stored under a
        ``mean_``-prefixed key. The dictionary is written to
        ``metrics/<save_name>/metrics.json``.

        Args:
            evaluated_objects: iterable of object names to evaluate.
            pro_integration_limit: value in ``(0, 1]`` forwarded to
                ``calculate_metrics``.
        """
        assert 0.0 < pro_integration_limit <= 1.0
        root_anomaly_map_dir = f'anomaly_maps/{self.save_name}'
        output_dir = f'metrics/{self.save_name}'
        evaluation_dict = dict()

        # (suffix used in the output key, key in the metrics dict returned by
        # calculate_metrics)
        level_metric_names = (('accuracy', 'accuracy'),
                              ('precision', 'precision'),
                              ('recall', 'recall'),
                              ('f1_score', 'f1'))
        levels = ('pixel_level', 'image_level')

        # Per-metric value lists, used for the means. Insertion order fixes
        # the key order of the ``mean_*`` entries in the JSON file.
        collected = {'au_pro': [], 'classification_au_roc': []}
        for level in levels:
            for out_name, _ in level_metric_names:
                collected[f'{level}_{out_name}'] = []

        # Evaluate each dataset object separately.
        for obj in evaluated_objects:
            print(f"=== Evaluate {obj} ===")
            evaluation_dict[obj] = dict()

            test_dataset = MvtecADDataset(root_dir=f"mvtec_anomaly_detection_{self.args.img_size}", split="test", img_size=self.args.img_size, object_names=[obj])
            test_loader = DataLoader(test_dataset, batch_size=self.args.batch_size, shuffle=False)
            scores, test_imgs, recon_imgs, gt_list, gt_mask_list = self._test(test_loader=test_loader, root_anomaly_map_dir=root_anomaly_map_dir)
            scores = np.asarray(scores)

            # Calculate the PRO and ROC curves.
            au_pro, au_roc, pro_curve, roc_curve, pixel_level_metrics, image_level_metrics = \
                calculate_metrics(
                    np.asanyarray(gt_mask_list).squeeze(axis=1),
                    scores,
                    pro_integration_limit)

            threshold = pixel_level_metrics['threshold']
            save_dir = f'metrics/{self.save_name}/pictures_{obj}'
            os.makedirs(save_dir, exist_ok=True)
            plot_fig(test_img=test_imgs, recon_imgs=recon_imgs, scores=scores, gts=gt_mask_list, threshold=threshold, save_dir=save_dir)

            # Scalar metrics of this object, in output order.
            obj_scalars = {'au_pro': au_pro, 'classification_au_roc': au_roc}
            for level, metrics in zip(levels, (pixel_level_metrics, image_level_metrics)):
                for out_name, src_name in level_metric_names:
                    obj_scalars[f'{level}_{out_name}'] = metrics[src_name]

            evaluation_dict[obj].update(obj_scalars)
            evaluation_dict[obj]['classification_roc_curve_fpr'] = roc_curve[0]
            evaluation_dict[obj]['classification_roc_curve_tpr'] = roc_curve[1]

            # Keep track of the values for the mean performance measures.
            for key, value in obj_scalars.items():
                collected[key].append(value)

            print('\n')

        # Compute the mean of the performance measures.
        for key, values in collected.items():
            evaluation_dict[f'mean_{key}'] = np.mean(values).item()

        # Write the evaluation metrics to disk. Uses the ``os`` module
        # directly instead of names leaked by a wildcard import.
        os.makedirs(output_dir, exist_ok=True)
        metrics_path = os.path.join(output_dir, 'metrics.json')
        with open(metrics_path, 'w') as file:
            json.dump(evaluation_dict, file, indent=4)

        print(f"Wrote metrics to {metrics_path}")
    
    def load_model(self, state_dict_path):
        self.model.load_state_dict(torch.load(state_dict_path, weights_only=True))
    
    def save_snapshot(self, x, x2, save_dir, save_dir2):
        """Save input/reconstruction comparison images for two batches.

        Each batch is reconstructed with the model in eval mode, concatenated
        with its reconstruction along dim 3, and written with ``save_image``.

        Args:
            x: first image batch; its comparison is written to ``save_dir``.
            x2: second image batch; its comparison is written to ``save_dir2``.
            save_dir: output file path for the first comparison image.
            save_dir2: output file path for the second comparison image.
        """
        self.model.eval()
        with torch.no_grad():
            for batch, out_path in ((x, save_dir), (x2, save_dir2)):
                recon, _ = self.model(batch)
                # Input and reconstruction side by side.
                side_by_side = torch.cat((batch, recon), dim=3)
                save_image(side_by_side.data.cpu(), out_path, nrow=1, padding=0)
                print('Saved real and fake images into {}...'.format(out_path))

In [20]:
import argparse
import sys
# Command-line style options for this notebook run. They are handed to
# parse_args() explicitly so the kernel's own sys.argv is left untouched.
NOTEBOOK_ARGV = ['--epochs', '100']

parser = argparse.ArgumentParser(description='Process some integers.')
parser.add_argument('--prefix', type=str, default=None)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--img_size', type=int, default=224)
# Loss weights: alpha -> MSGMS, beta -> SSIM, gamma -> L2, delta -> L1.
parser.add_argument('--alpha', type=float, default=1.0)
parser.add_argument('--beta', type=float, default=1.0)
parser.add_argument('--gamma', type=float, default=1.0)
parser.add_argument('--delta', type=float, default=0.0)
parser.add_argument('--seed', type=int, default=None, help='manual seed')
args = parser.parse_args(NOTEBOOK_ARGV)

exp = SimpleAD(args=args)



In [21]:
# Object categories to evaluate, one per line.
OBJECT_NAMES = [
    'bottle',
    'cable',
    'capsule',
    'carpet',
    'grid',
    'hazelnut',
    'leather',
    'metal_nut',
    'pill',
    'screw',
    'tile',
    'toothbrush',
    'transistor',
    'wood',
    'zipper',
]

# Restore the trained weights, then evaluate every category.
CHECKPOINT_PATH = 'save/SimpleSAE_224'
exp.load_model(CHECKPOINT_PATH)
exp.test(evaluated_objects=OBJECT_NAMES, pro_integration_limit=0.3)

=== Evaluate bottle ===


  0%|          | 0/6 [00:00<?, ?it/s]

Compute PRO curve...
Sort 4164608 anomaly scores...
AU-PRO (FPR limit: 0.3): 0.5983889377326753
Threshold: 0.1865
Pixel-level Accuracy: 0.9074
Pixel-level Precision: 0.2911
Pixel-level Recall: 0.4204
Pixel-level F1 Score: 0.3440
Image-level classification AU-ROC: 0.9285714285714286
Threshold: 0.4626
Image-level Accuracy: 0.5783
Image-level Precision: 1.0000
Image-level Recall: 0.4444
Image-level F1 Score: 0.6154


=== Evaluate cable ===


  0%|          | 0/10 [00:00<?, ?it/s]

Compute PRO curve...
Sort 7526400 anomaly scores...
AU-PRO (FPR limit: 0.3): 0.4097633831878087
Threshold: 0.2929
Pixel-level Accuracy: 0.9428
Pixel-level Precision: 0.1910
Pixel-level Recall: 0.3062
Pixel-level F1 Score: 0.2353
Image-level classification AU-ROC: 0.7358508245877061
Threshold: 0.3888
Image-level Accuracy: 0.6533
Image-level Precision: 0.7500
Image-level Recall: 0.6522
Image-level F1 Score: 0.6977


=== Evaluate capsule ===


  0%|          | 0/9 [00:00<?, ?it/s]

Compute PRO curve...
Sort 6623232 anomaly scores...
AU-PRO (FPR limit: 0.3): 0.6687455951885157
Threshold: 0.2073
Pixel-level Accuracy: 0.9832
Pixel-level Precision: 0.1992
Pixel-level Recall: 0.2776
Pixel-level F1 Score: 0.2319
Image-level classification AU-ROC: 0.566414040686079
Threshold: 0.2310
Image-level Accuracy: 0.5682
Image-level Precision: 0.8611
Image-level Recall: 0.5688
Image-level F1 Score: 0.6851


=== Evaluate carpet ===


  0%|          | 0/8 [00:00<?, ?it/s]

Compute PRO curve...
Sort 5870592 anomaly scores...
AU-PRO (FPR limit: 0.3): 0.7098573915978729
Threshold: 0.1732
Pixel-level Accuracy: 0.9541
Pixel-level Precision: 0.2039
Pixel-level Recall: 0.6408
Pixel-level F1 Score: 0.3094
Image-level classification AU-ROC: 0.8611556982343499
Threshold: 0.2349
Image-level Accuracy: 0.6496
Image-level Precision: 1.0000
Image-level Recall: 0.5393
Image-level F1 Score: 0.7007


=== Evaluate grid ===


  0%|          | 0/5 [00:00<?, ?it/s]

Compute PRO curve...
Sort 3913728 anomaly scores...
AU-PRO (FPR limit: 0.3): 0.2909166868862982
Threshold: 0.2366
Pixel-level Accuracy: 0.9829
Pixel-level Precision: 0.0233
Pixel-level Recall: 0.0361
Pixel-level F1 Score: 0.0283
Image-level classification AU-ROC: 0.7167919799498746
Threshold: 0.2824
Image-level Accuracy: 0.3590
Image-level Precision: 0.8889
Image-level Recall: 0.1404
Image-level F1 Score: 0.2424


=== Evaluate hazelnut ===


  0%|          | 0/7 [00:00<?, ?it/s]

Compute PRO curve...
Sort 5519360 anomaly scores...
AU-PRO (FPR limit: 0.3): 0.7818229292800275
Threshold: 0.2937
Pixel-level Accuracy: 0.9812
Pixel-level Precision: 0.5661
Pixel-level Recall: 0.5107
Pixel-level F1 Score: 0.5370
Image-level classification AU-ROC: 0.827857142857143
Threshold: 0.2955
Image-level Accuracy: 0.7273
Image-level Precision: 0.8125
Image-level Recall: 0.7429
Image-level F1 Score: 0.7761


=== Evaluate leather ===


  0%|          | 0/8 [00:00<?, ?it/s]

Compute PRO curve...
Sort 6221824 anomaly scores...
AU-PRO (FPR limit: 0.3): 0.9007491215650754
Threshold: 0.3000
Pixel-level Accuracy: 0.9905
Pixel-level Precision: 0.3207
Pixel-level Recall: 0.4172
Pixel-level F1 Score: 0.3626
Image-level classification AU-ROC: 0.9402173913043478
Threshold: 0.3613
Image-level Accuracy: 0.5323
Image-level Precision: 1.0000
Image-level Recall: 0.3696
Image-level F1 Score: 0.5397


=== Evaluate metal_nut ===


  0%|          | 0/8 [00:00<?, ?it/s]

Compute PRO curve...
Sort 5770240 anomaly scores...
AU-PRO (FPR limit: 0.3): 0.515807654773548
Threshold: 0.1885
Pixel-level Accuracy: 0.8641
Pixel-level Precision: 0.4444
Pixel-level Recall: 0.6377
Pixel-level F1 Score: 0.5238
Image-level classification AU-ROC: 0.7385141739980449
Threshold: 0.4124
Image-level Accuracy: 0.5217
Image-level Precision: 0.9318
Image-level Recall: 0.4409
Image-level F1 Score: 0.5985


=== Evaluate pill ===


  0%|          | 0/11 [00:00<?, ?it/s]

Compute PRO curve...
Sort 8379392 anomaly scores...
AU-PRO (FPR limit: 0.3): 0.7721693929672467
Threshold: 0.1481
Pixel-level Accuracy: 0.9610
Pixel-level Precision: 0.4206
Pixel-level Recall: 0.4253
Pixel-level F1 Score: 0.4229
Image-level classification AU-ROC: 0.640480087288598
Threshold: 0.2459
Image-level Accuracy: 0.5808
Image-level Precision: 0.8989
Image-level Recall: 0.5674
Image-level F1 Score: 0.6957


=== Evaluate screw ===


  0%|          | 0/10 [00:00<?, ?it/s]

Compute PRO curve...
Sort 8028160 anomaly scores...
AU-PRO (FPR limit: 0.3): 0.8262349472277485
Threshold: 0.2026
Pixel-level Accuracy: 0.9809
Pixel-level Precision: 0.0393
Pixel-level Recall: 0.2844
Pixel-level F1 Score: 0.0691
Image-level classification AU-ROC: 0.6394753023160483
Threshold: 0.3184
Image-level Accuracy: 0.3812
Image-level Precision: 0.8125
Image-level Recall: 0.2185
Image-level F1 Score: 0.3444


=== Evaluate tile ===


  0%|          | 0/8 [00:00<?, ?it/s]

Compute PRO curve...
Sort 5870592 anomaly scores...
AU-PRO (FPR limit: 0.3): 0.515360239189716
Threshold: 0.1771
Pixel-level Accuracy: 0.8528
Pixel-level Precision: 0.2080
Pixel-level Recall: 0.3889
Pixel-level F1 Score: 0.2711
Image-level classification AU-ROC: 0.9502164502164502
Threshold: 0.3368
Image-level Accuracy: 0.6496
Image-level Precision: 1.0000
Image-level Recall: 0.5119
Image-level F1 Score: 0.6772


=== Evaluate toothbrush ===


  0%|          | 0/3 [00:00<?, ?it/s]

Compute PRO curve...
Sort 2107392 anomaly scores...
AU-PRO (FPR limit: 0.3): 0.8118730503070536
Threshold: 0.1467
Pixel-level Accuracy: 0.9672
Pixel-level Precision: 0.2801
Pixel-level Recall: 0.7233
Pixel-level F1 Score: 0.4039
Image-level classification AU-ROC: 0.9944444444444445
Threshold: 0.3268
Image-level Accuracy: 0.5952
Image-level Precision: 1.0000
Image-level Recall: 0.4333
Image-level F1 Score: 0.6047


=== Evaluate transistor ===


  0%|          | 0/7 [00:00<?, ?it/s]

Compute PRO curve...
Sort 5017600 anomaly scores...
AU-PRO (FPR limit: 0.3): 0.495110008199993
Threshold: 0.2666
Pixel-level Accuracy: 0.9041
Pixel-level Precision: 0.2426
Pixel-level Recall: 0.4717
Pixel-level F1 Score: 0.3204
Image-level classification AU-ROC: 0.8366666666666667
Threshold: 0.3432
Image-level Accuracy: 0.7200
Image-level Precision: 0.6111
Image-level Recall: 0.8250
Image-level F1 Score: 0.7021


=== Evaluate wood ===


  0%|          | 0/5 [00:00<?, ?it/s]

Compute PRO curve...
Sort 3963904 anomaly scores...
AU-PRO (FPR limit: 0.3): 0.45907605542078955
Threshold: 0.1915
Pixel-level Accuracy: 0.9360
Pixel-level Precision: 0.2174
Pixel-level Recall: 0.2520
Pixel-level F1 Score: 0.2334
Image-level classification AU-ROC: 0.9298245614035088
Threshold: 0.2729
Image-level Accuracy: 0.5570
Image-level Precision: 1.0000
Image-level Recall: 0.4167
Image-level F1 Score: 0.5882


=== Evaluate zipper ===


  0%|          | 0/10 [00:00<?, ?it/s]

Compute PRO curve...
Sort 7576576 anomaly scores...
AU-PRO (FPR limit: 0.3): 0.34261895809522647
Threshold: 0.0711
Pixel-level Accuracy: 0.7266
Pixel-level Precision: 0.0506
Pixel-level Recall: 0.6877
Pixel-level F1 Score: 0.0943
Image-level classification AU-ROC: 0.6932773109243697
Threshold: 0.3019
Image-level Accuracy: 0.6291
Image-level Precision: 0.9091
Image-level Recall: 0.5882
Image-level F1 Score: 0.7143


Wrote metrics to metrics/SimpleSAE_224/metrics.json
