In [1]:
import sys
sys.path.append('../')
from vgs.models.modules import FCNet
from vgs.models.modules import AE_energy
import torch
import numpy as np
from torch.utils.data import DataLoader, TensorDataset

In [2]:
encoder = FCNet(
    in_dim = 272,
    out_dim = 128,
    l_hidden = [1024, 1024, 1024],
    activation = 'relu',
    out_activation = 'linear'
)
decoder = FCNet(
    in_dim = 128,
    out_dim = 272,
    l_hidden = [1024, 1024, 1024],
    activation = 'relu',
    out_activation = 'linear',
)

device = torch.device('cuda:0')

energy = AE_energy(  
    encoder = encoder,
    decoder = decoder,
    tau = 0.1, # Entropy regularization
    learn_out_scale=True
).to(device)

In [3]:
val_data = torch.load('../datasets/ebm_exp/val_mvtec.pth', weights_only=False)
X_test = val_data['feature_align']
X_test = torch.tensor(X_test.reshape(len(X_test), 272, -1))
X_test = X_test / X_test.norm(dim=1, keepdim=True)
y_test = torch.tensor(val_data['label'])
clsname_test = np.array(val_data['clsname'])
mask_test = val_data['mask']

batchsize = 128
test_dataset = TensorDataset(X_test, y_test)
test_loader = DataLoader(test_dataset, batch_size=batchsize, shuffle= False, num_workers=4, pin_memory=True)

In [None]:
from myutils.ebm_utils import *

### TODO: Fill in the path to the energy checkpoints 
energy_ckpts = [] # Example: [f'../results/mvtec/seed_{i}/energy.pth' for i in range(10)]

result_path = '../mvtec_result.txt'
d_cls_results = {}
d_loc_results = {}
for i, ckpt in enumerate(energy_ckpts):
    energy_dict = torch.load(ckpt)
    energy.load_state_dict(energy_dict["state_dict"])
    energy.eval()
    pred_y, pred_y_mask = predict(energy, test_loader, device)
    auc = roc_auc_score(y_test, pred_y)    
    in_pred = pred_y[y_test == 0].numpy().mean()
    out_pred = pred_y[y_test == 1].numpy().mean()
    d_cls_auc = compute_classwise_auc(pred_y, y_test, clsname_test)
    d_loc_auc = compute_classwise_localization_auc(pred_y_mask, mask_test, clsname_test)

    for k, v in d_cls_auc.items():
        if k not in d_cls_results:
            d_cls_results[k] = []
        d_cls_results[k].append(v)
    for k, v in d_loc_auc.items():
        if k not in d_loc_results:
            d_loc_results[k] = []
        d_loc_results[k].append(v)

with open(result_path, 'w') as f:
    f.write('Classification AUC\n')
    f.write('Class, Mean, Std\n')
for k, v in d_cls_results.items():
    with open(result_path, 'a') as f:
        f.write(f'{k}, {100*np.mean(v):.1f}, {100*np.std(v):.2f}\n')

with open(result_path, 'a') as f:
    f.write('\nLocalization AUC\n')
    f.write('Class, Mean, Std\n')
for k, v in d_loc_results.items():
    with open(result_path, 'a') as f:
        f.write(f'{k}, {100*np.mean(v):.1f}, {100*np.std(v):.2f}\n')

  energy_dict = torch.load(ckpt)


bottle: 1.0000
cable: 0.9754
capsule: 0.9087
carpet: 0.9976
grid: 0.9866
hazelnut: 0.9996
leather: 1.0000
metal_nut: 0.9985
pill: 0.9564
screw: 0.8990
tile: 0.9996
toothbrush: 0.9139
transistor: 0.9908
wood: 0.9798
zipper: 0.9659
mean: 0.9715
bottle: 0.9843
cable: 0.9684
capsule: 0.9855
carpet: 0.9875
grid: 0.9704
hazelnut: 0.9833
leather: 0.9849
metal_nut: 0.9565
pill: 0.9584
screw: 0.9875
tile: 0.9496
toothbrush: 0.9881
transistor: 0.9646
wood: 0.9378
zipper: 0.9677
mean: 0.9716


  energy_dict = torch.load(ckpt)


bottle: 1.0000
cable: 0.9786
capsule: 0.9130
carpet: 0.9968
grid: 0.9850
hazelnut: 1.0000
leather: 1.0000
metal_nut: 0.9995
pill: 0.9498
screw: 0.9057
tile: 0.9996
toothbrush: 0.9167
transistor: 0.9946
wood: 0.9851
zipper: 0.9627
mean: 0.9725
bottle: 0.9853
cable: 0.9679
capsule: 0.9853
carpet: 0.9881
grid: 0.9701
hazelnut: 0.9843
leather: 0.9855
metal_nut: 0.9572
pill: 0.9574
screw: 0.9871
tile: 0.9538
toothbrush: 0.9879
transistor: 0.9635
wood: 0.9368
zipper: 0.9677
mean: 0.9719


  energy_dict = torch.load(ckpt)


bottle: 1.0000
cable: 0.9726
capsule: 0.9194
carpet: 0.9980
grid: 0.9841
hazelnut: 0.9996
leather: 1.0000
metal_nut: 0.9980
pill: 0.9465
screw: 0.9059
tile: 1.0000
toothbrush: 0.9056
transistor: 0.9921
wood: 0.9842
zipper: 0.9551
mean: 0.9708
bottle: 0.9850
cable: 0.9676
capsule: 0.9854
carpet: 0.9889
grid: 0.9709
hazelnut: 0.9835
leather: 0.9867
metal_nut: 0.9571
pill: 0.9578
screw: 0.9871
tile: 0.9523
toothbrush: 0.9881
transistor: 0.9637
wood: 0.9390
zipper: 0.9667
mean: 0.9720


  energy_dict = torch.load(ckpt)


bottle: 1.0000
cable: 0.9711
capsule: 0.9122
carpet: 0.9976
grid: 0.9825
hazelnut: 1.0000
leather: 1.0000
metal_nut: 0.9990
pill: 0.9624
screw: 0.8908
tile: 1.0000
toothbrush: 0.9139
transistor: 0.9958
wood: 0.9825
zipper: 0.9653
mean: 0.9715
bottle: 0.9845
cable: 0.9680
capsule: 0.9851
carpet: 0.9878
grid: 0.9706
hazelnut: 0.9841
leather: 0.9869
metal_nut: 0.9566
pill: 0.9579
screw: 0.9870
tile: 0.9531
toothbrush: 0.9876
transistor: 0.9634
wood: 0.9381
zipper: 0.9674
mean: 0.9719


  energy_dict = torch.load(ckpt)


bottle: 1.0000
cable: 0.9730
capsule: 0.9063
carpet: 0.9992
grid: 0.9891
hazelnut: 0.9993
leather: 1.0000
metal_nut: 0.9990
pill: 0.9572
screw: 0.9045
tile: 1.0000
toothbrush: 0.9139
transistor: 0.9925
wood: 0.9921
zipper: 0.9688
mean: 0.9730
bottle: 0.9846
cable: 0.9680
capsule: 0.9854
carpet: 0.9876
grid: 0.9707
hazelnut: 0.9837
leather: 0.9859
metal_nut: 0.9570
pill: 0.9573
screw: 0.9865
tile: 0.9528
toothbrush: 0.9880
transistor: 0.9642
wood: 0.9389
zipper: 0.9692
mean: 0.9720


  energy_dict = torch.load(ckpt)


bottle: 1.0000
cable: 0.9734
capsule: 0.9178
carpet: 0.9988
grid: 0.9866
hazelnut: 1.0000
leather: 1.0000
metal_nut: 0.9995
pill: 0.9607
screw: 0.9067
tile: 0.9996
toothbrush: 0.8750
transistor: 0.9921
wood: 0.9825
zipper: 0.9653
mean: 0.9705
bottle: 0.9849
cable: 0.9674
capsule: 0.9852
carpet: 0.9883
grid: 0.9699
hazelnut: 0.9831
leather: 0.9868
metal_nut: 0.9566
pill: 0.9577
screw: 0.9874
tile: 0.9532
toothbrush: 0.9884
transistor: 0.9640
wood: 0.9381
zipper: 0.9695
mean: 0.9720


  energy_dict = torch.load(ckpt)


bottle: 1.0000
cable: 0.9670
capsule: 0.9158
carpet: 0.9984
grid: 0.9866
hazelnut: 1.0000
leather: 1.0000
metal_nut: 0.9985
pill: 0.9547
screw: 0.8975
tile: 1.0000
toothbrush: 0.9139
transistor: 0.9942
wood: 0.9851
zipper: 0.9645
mean: 0.9718
bottle: 0.9850
cable: 0.9686
capsule: 0.9856
carpet: 0.9876
grid: 0.9706
hazelnut: 0.9835
leather: 0.9876
metal_nut: 0.9569
pill: 0.9583
screw: 0.9870
tile: 0.9517
toothbrush: 0.9878
transistor: 0.9634
wood: 0.9383
zipper: 0.9691
mean: 0.9721


  energy_dict = torch.load(ckpt)


bottle: 0.9992
cable: 0.9771
capsule: 0.9099
carpet: 0.9992
grid: 0.9850
hazelnut: 1.0000
leather: 1.0000
metal_nut: 0.9995
pill: 0.9501
screw: 0.9031
tile: 0.9996
toothbrush: 0.9056
transistor: 0.9925
wood: 0.9825
zipper: 0.9666
mean: 0.9713
bottle: 0.9847
cable: 0.9674
capsule: 0.9853
carpet: 0.9884
grid: 0.9703
hazelnut: 0.9826
leather: 0.9858
metal_nut: 0.9571
pill: 0.9576
screw: 0.9872
tile: 0.9536
toothbrush: 0.9882
transistor: 0.9634
wood: 0.9376
zipper: 0.9686
mean: 0.9719


  energy_dict = torch.load(ckpt)


bottle: 1.0000
cable: 0.9708
capsule: 0.9079
carpet: 0.9968
grid: 0.9916
hazelnut: 1.0000
leather: 1.0000
metal_nut: 0.9985
pill: 0.9602
screw: 0.9110
tile: 1.0000
toothbrush: 0.9139
transistor: 0.9938
wood: 0.9886
zipper: 0.9651
mean: 0.9732
bottle: 0.9848
cable: 0.9679
capsule: 0.9853
carpet: 0.9885
grid: 0.9712
hazelnut: 0.9833
leather: 0.9859
metal_nut: 0.9573
pill: 0.9576
screw: 0.9874
tile: 0.9534
toothbrush: 0.9879
transistor: 0.9638
wood: 0.9383
zipper: 0.9680
mean: 0.9720


  energy_dict = torch.load(ckpt)


bottle: 1.0000
cable: 0.9698
capsule: 0.9067
carpet: 0.9996
grid: 0.9891
hazelnut: 1.0000
leather: 1.0000
metal_nut: 0.9990
pill: 0.9648
screw: 0.9037
tile: 1.0000
toothbrush: 0.9111
transistor: 0.9917
wood: 0.9895
zipper: 0.9635
mean: 0.9726
bottle: 0.9853
cable: 0.9681
capsule: 0.9855
carpet: 0.9875
grid: 0.9711
hazelnut: 0.9837
leather: 0.9848
metal_nut: 0.9578
pill: 0.9575
screw: 0.9867
tile: 0.9530
toothbrush: 0.9877
transistor: 0.9643
wood: 0.9378
zipper: 0.9683
mean: 0.9719


In [None]:
/tmp/ipykernel_1048143/471784250.py:10: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  energy_dict = torch.load(ckpt)
=============
bottle: 1.0000
cable: 0.9683
capsule: 0.9055
carpet: 0.9968
grid: 0.9858
hazelnut: 1.0000
leather: 1.0000
metal_nut: 0.9985
pill: 0.9599
screw: 0.9000
tile: 1.0000
toothbrush: 0.9222
transistor: 0.9917
wood: 0.9895
zipper: 0.9706
mean: 0.9726
=============
=============
bottle: 0.9840
cable: 0.9674
capsule: 0.9851
carpet: 0.9892
grid: 0.9698
hazelnut: 0.9832
leather: 0.9889
metal_nut: 0.9562
pill: 0.9574
screw: 0.9876
tile: 0.9526
toothbrush: 0.9874
transistor: 0.9630
wood: 0.9400
zipper: 0.9659
mean: 0.9718
=============
/tmp/ipykernel_1048143/471784250.py:10: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  energy_dict = torch.load(ckpt)
=============
bottle: 1.0000
cable: 0.9771
capsule: 0.9087
carpet: 0.9980
grid: 0.9883
hazelnut: 1.0000
leather: 1.0000
metal_nut: 1.0000
pill: 0.9599
screw: 0.9020
tile: 0.9996
toothbrush: 0.9111
transistor: 0.9958
wood: 0.9904
zipper: 0.9706
mean: 0.9734
=============
=============
bottle: 0.9836
cable: 0.9680
capsule: 0.9851
carpet: 0.9872
grid: 0.9713
hazelnut: 0.9838
leather: 0.9875
metal_nut: 0.9567
pill: 0.9574
screw: 0.9864
tile: 0.9543
toothbrush: 0.9872
transistor: 0.9640
wood: 0.9376
zipper: 0.9670
mean: 0.9718
=============
/tmp/ipykernel_1048143/471784250.py:10: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  energy_dict = torch.load(ckpt)
=============
bottle: 0.9992
cable: 0.9706
capsule: 0.8979
carpet: 0.9972
grid: 0.9866
hazelnut: 0.9993
leather: 1.0000
metal_nut: 0.9995
pill: 0.9602
screw: 0.9024
tile: 1.0000
toothbrush: 0.9222
transistor: 0.9917
wood: 0.9895
zipper: 0.9719
mean: 0.9725
=============
=============
bottle: 0.9840
cable: 0.9672
capsule: 0.9853
carpet: 0.9894
grid: 0.9708
hazelnut: 0.9828
leather: 0.9889
metal_nut: 0.9569
pill: 0.9577
screw: 0.9862
tile: 0.9537
toothbrush: 0.9873
transistor: 0.9642
wood: 0.9375
zipper: 0.9677
mean: 0.9720
=============
/tmp/ipykernel_1048143/471784250.py:10: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  energy_dict = torch.load(ckpt)
=============
bottle: 1.0000
cable: 0.9753
capsule: 0.8935
carpet: 0.9948
grid: 0.9883
hazelnut: 1.0000
leather: 1.0000
metal_nut: 1.0000
pill: 0.9577
screw: 0.8967
tile: 0.9996
toothbrush: 0.9111
transistor: 0.9942
wood: 0.9895
zipper: 0.9593
mean: 0.9707
=============
=============
bottle: 0.9831
cable: 0.9673
capsule: 0.9851
carpet: 0.9881
grid: 0.9703
hazelnut: 0.9837
leather: 0.9874
metal_nut: 0.9560
pill: 0.9578
screw: 0.9867
tile: 0.9482
toothbrush: 0.9870
transistor: 0.9636
wood: 0.9379
zipper: 0.9670
mean: 0.9713
=============
/tmp/ipykernel_1048143/471784250.py:10: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  energy_dict = torch.load(ckpt)
=============
bottle: 1.0000
cable: 0.9762
capsule: 0.9134
carpet: 0.9984
grid: 0.9891
hazelnut: 1.0000
leather: 1.0000
metal_nut: 0.9990
pill: 0.9591
screw: 0.8893
tile: 1.0000
toothbrush: 0.9139
transistor: 0.9888
wood: 0.9965
zipper: 0.9711
mean: 0.9730
=============
=============
bottle: 0.9830
cable: 0.9671
capsule: 0.9854
carpet: 0.9886
grid: 0.9702
hazelnut: 0.9837
leather: 0.9862
metal_nut: 0.9559
pill: 0.9575
screw: 0.9863
tile: 0.9514
toothbrush: 0.9873
transistor: 0.9631
wood: 0.9392
zipper: 0.9685
mean: 0.9715
=============
/tmp/ipykernel_1048143/471784250.py:10: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  energy_dict = torch.load(ckpt)
=============
bottle: 1.0000
cable: 0.9730
capsule: 0.8995
carpet: 0.9980
grid: 0.9900
hazelnut: 1.0000
leather: 1.0000
metal_nut: 0.9995
pill: 0.9613
screw: 0.9072
tile: 0.9996
toothbrush: 0.9083
transistor: 0.9954
wood: 0.9877
zipper: 0.9661
mean: 0.9724
=============
=============
bottle: 0.9836
cable: 0.9672
capsule: 0.9850
carpet: 0.9881
grid: 0.9694
hazelnut: 0.9840
leather: 0.9871
metal_nut: 0.9565
pill: 0.9577
screw: 0.9872
tile: 0.9549
toothbrush: 0.9875
transistor: 0.9632
wood: 0.9387
zipper: 0.9681
mean: 0.9719
=============
/tmp/ipykernel_1048143/471784250.py:10: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  energy_dict = torch.load(ckpt)
=============
bottle: 1.0000
cable: 0.9657
capsule: 0.9027
carpet: 0.9988
grid: 0.9925
hazelnut: 1.0000
leather: 1.0000
metal_nut: 0.9980
pill: 0.9591
screw: 0.9074
tile: 0.9993
toothbrush: 0.9083
transistor: 0.9946
wood: 0.9860
zipper: 0.9688
mean: 0.9721
=============
=============
bottle: 0.9843
cable: 0.9675
capsule: 0.9850
carpet: 0.9890
grid: 0.9710
hazelnut: 0.9830
leather: 0.9870
metal_nut: 0.9563
pill: 0.9572
screw: 0.9866
tile: 0.9529
toothbrush: 0.9875
transistor: 0.9624
wood: 0.9369
zipper: 0.9681
mean: 0.9717
=============
/tmp/ipykernel_1048143/471784250.py:10: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  energy_dict = torch.load(ckpt)
=============
bottle: 1.0000
cable: 0.9687
capsule: 0.9122
carpet: 0.9992
grid: 0.9875
hazelnut: 1.0000
leather: 1.0000
metal_nut: 0.9990
pill: 0.9643
screw: 0.9045
tile: 1.0000
toothbrush: 0.9139
transistor: 0.9921
wood: 0.9877
zipper: 0.9656
mean: 0.9730
=============
=============
bottle: 0.9836
cable: 0.9678
capsule: 0.9851
carpet: 0.9878
grid: 0.9704
hazelnut: 0.9844
leather: 0.9888
metal_nut: 0.9565
pill: 0.9577
screw: 0.9871
tile: 0.9526
toothbrush: 0.9870
transistor: 0.9633
wood: 0.9398
zipper: 0.9683
mean: 0.9720
=============
/tmp/ipykernel_1048143/471784250.py:10: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  energy_dict = torch.load(ckpt)
=============
bottle: 1.0000
cable: 0.9597
capsule: 0.9154
carpet: 0.9980
grid: 0.9850
hazelnut: 1.0000
leather: 1.0000
metal_nut: 0.9995
pill: 0.9493
screw: 0.9113
tile: 1.0000
toothbrush: 0.8972
transistor: 0.9933
wood: 0.9842
zipper: 0.9672
mean: 0.9707
=============
=============
bottle: 0.9835
cable: 0.9668
capsule: 0.9854
carpet: 0.9881
grid: 0.9700
hazelnut: 0.9821
leather: 0.9864
metal_nut: 0.9562
pill: 0.9579
screw: 0.9872
tile: 0.9530
toothbrush: 0.9868
transistor: 0.9635
wood: 0.9376
zipper: 0.9695
mean: 0.9716
=============
/tmp/ipykernel_1048143/471784250.py:10: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  energy_dict = torch.load(ckpt)
=============
bottle: 1.0000
cable: 0.9695
capsule: 0.9071
carpet: 0.9980
grid: 0.9883
hazelnut: 0.9996
leather: 1.0000
metal_nut: 0.9995
pill: 0.9534
screw: 0.9018
tile: 0.9982
toothbrush: 0.8944
transistor: 0.9771
wood: 0.9868
zipper: 0.9601
mean: 0.9689
=============
=============
bottle: 0.9834
cable: 0.9672
capsule: 0.9854
carpet: 0.9878
grid: 0.9705
hazelnut: 0.9832
leather: 0.9866
metal_nut: 0.9566
pill: 0.9570
screw: 0.9862
tile: 0.9499
toothbrush: 0.9873
transistor: 0.9634
wood: 0.9369
zipper: 0.9677
mean: 0.9713
=============