In [None]:
import sys
sys.path.append('../')
from vgs.models.modules import FCNet
from vgs.models.modules import AE_energy
import torch
import numpy as np
from torch.utils.data import DataLoader, TensorDataset

In [3]:
# Encoder/decoder pair wrapped by AE_energy: 272 -> 128 -> 272 dims.
# 272 must match the feature dimension used when reshaping the validation
# features (X_test.reshape(len(X_test), 272, -1)).
encoder = FCNet(
    in_dim=272,
    out_dim=128,
    l_hidden=[1024, 1024, 1024],
    activation='relu',
    out_activation='linear',
)
decoder = FCNet(
    in_dim=128,
    out_dim=272,
    l_hidden=[1024, 1024, 1024],
    activation='relu',
    out_activation='linear',
)

# Fall back to CPU so the notebook still runs (slowly) on machines without a GPU;
# a hard-coded 'cuda:0' makes `.to(device)` fail there.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

energy = AE_energy(
    encoder=encoder,
    decoder=decoder,
    tau=0.1,  # Entropy regularization
    learn_out_scale=True,
).to(device)

In [6]:
# NOTE: weights_only=False unpickles arbitrary Python objects from the file;
# only load files from a trusted source.
val_data = torch.load('../datasets/ebm_exp/val_mvtec.pth', weights_only=False)

# Features are reshaped to (N, 272, -1) and L2-normalised along dim 1.
# torch.as_tensor avoids the extra copy (and the UserWarning) that torch.tensor()
# incurs when the stored data is already a tensor.
# NOTE(review): dtype is whatever the file stores; the model presumably expects
# float32 — confirm that `predict` casts if the file holds float64.
X_test = torch.as_tensor(val_data['feature_align'])
X_test = X_test.reshape(len(X_test), 272, -1)
# Clamp the norm (same eps as torch.nn.functional.normalize) so an all-zero
# feature vector becomes zeros instead of NaN.
X_test = X_test / X_test.norm(dim=1, keepdim=True).clamp_min(1e-12)
y_test = torch.as_tensor(val_data['label'])
clsname_test = np.array(val_data['clsname'])
mask_test = val_data['mask']

batchsize = 128
test_dataset = TensorDataset(X_test, y_test)
# Pinned memory only helps host->GPU transfers; skip it on CPU-only machines.
test_loader = DataLoader(
    test_dataset,
    batch_size=batchsize,
    shuffle=False,
    num_workers=4,
    pin_memory=torch.cuda.is_available(),
)

In [None]:
from myutils.ebm_utils import *

### TODO: Fill in the path to the energy checkpoints
energy_ckpts = []  # Example: [f'../results/mvtec/test_{i}/energy.pth' for i in range(10)]

result_path = '../mvtec_result.txt'

# Per-class AUCs collected across checkpoints: {class name: [AUC per checkpoint]}.
d_cls_results = {}
d_loc_results = {}
for ckpt in energy_ckpts:
    # map_location lets a checkpoint saved on one device load onto the device in use.
    energy_dict = torch.load(ckpt, map_location=device)
    energy.load_state_dict(energy_dict["state_dict"])
    energy.eval()
    pred_y, pred_y_mask = predict(energy, test_loader, device)
    d_cls_auc = compute_classwise_auc(pred_y, y_test, clsname_test)
    d_loc_auc = compute_classwise_localization_auc(pred_y_mask, mask_test, clsname_test)

    for cls_key, cls_auc in d_cls_auc.items():
        d_cls_results.setdefault(cls_key, []).append(cls_auc)
    for cls_key, loc_auc in d_loc_auc.items():
        d_loc_results.setdefault(cls_key, []).append(loc_auc)


def _write_auc_table(f, title, results):
    """Write one "Class, Mean, Std" table of AUCs, expressed as percentages.

    Args:
        f: text file opened for writing.
        title: heading line written above the table.
        results: mapping of class name -> list of AUC values (one per checkpoint).
            Std is the population std (np.std default, ddof=0).
    """
    f.write(f'{title}\n')
    f.write('Class, Mean, Std\n')
    for cls_key, aucs in results.items():
        f.write(f'{cls_key}, {100*np.mean(aucs):.1f}, {100*np.std(aucs):.2f}\n')


# Open the result file once instead of re-opening it in append mode for every row.
with open(result_path, 'w') as f:
    _write_auc_table(f, 'Classification AUC', d_cls_results)
    f.write('\n')
    _write_auc_table(f, 'Localization AUC', d_loc_results)