In [1]:
import numpy as np

from DeepSAD import DeepSAD
from datasets.main import load_dataset
from datasets.customize import CustomizeDataset

In [2]:
dataset_name = 'customize'
data_path = '../../data/20200803_label_patch/train'
normal_class = 0
known_outlier_class = 1
n_known_outlier_classes = 1
ratio_known_normal = 0.99
ratio_known_outlier = 0
ratio_pollution = 0.0
seed = 451
net_name = 'vgg'

In [3]:
# Load data
dataset = load_dataset(
    dataset_name,
    data_path,
    normal_class,
    known_outlier_class,
    n_known_outlier_classes,
    ratio_known_normal,
    ratio_known_outlier,
    ratio_pollution,
    random_state=np.random.RandomState(seed),
)

In [4]:
test_set = CustomizeDataset(root=data_path, dataset_name=dataset_name, n_known_outlier_classes=n_known_outlier_classes,
                 ratio_known_normal=ratio_known_normal, ratio_known_outlier=ratio_known_outlier,
                 ratio_pollution=ratio_pollution,
                 random_state=None, split=False
                )

TypeError: zeros_like() received an invalid combination of arguments - got (numpy.ndarray), but expected one of:
 * (Tensor input, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (Tensor input, bool requires_grad)


In [5]:
test_set.

AttributeError: 'CustomizeDataset' object has no attribute 'data'

In [48]:
dataset.test_set

<base.local_dataset.LocalDataset at 0x7fe500f89950>

In [7]:
load_model = '../log/DeepSAD/customize/model.tar'
device = 'cuda'

In [8]:
eta = 1.0
deepSAD = DeepSAD(eta)
deepSAD.set_network(net_name)

# If specified, load Deep SAD model (center c, network weights, and possibly autoencoder weights)
if load_model:
    deepSAD.load_model(model_path=load_model, load_ae=True, map_location=device)

In [9]:
deepSAD.test(test_set)

Min distance: 0.0028568999841809273; Max distance: 0.02352345734834671; Avg: 0.007556464655224133


In [17]:
deepSAD.trainer

<optim.DeepSAD_trainer.DeepSADTrainer at 0x7fe50306e650>

In [10]:
deepSAD.trainer.test_scores

[(0, 0, 0.008454952389001846),
 (1, 0, 0.014362163841724396),
 (2, 0, 0.004431477747857571),
 (3, 0, 0.0052476916462183),
 (4, 0, 0.002913469448685646),
 (5, 0, 0.003655507229268551),
 (6, 0, 0.00724844541400671),
 (7, 0, 0.004695083014667034),
 (8, 0, 0.006936338730156422),
 (9, 0, 0.005717538297176361),
 (10, 0, 0.0039015968795865774),
 (11, 0, 0.004129611421376467),
 (12, 0, 0.005292694550007582),
 (13, 0, 0.004933517891913652),
 (14, 0, 0.005312385503202677),
 (15, 0, 0.006217647343873978),
 (16, 0, 0.009674100205302238),
 (17, 0, 0.007505523040890694),
 (18, 0, 0.012054529041051865),
 (19, 0, 0.01291203685104847),
 (20, 0, 0.012477578595280647),
 (21, 0, 0.004904408473521471),
 (22, 0, 0.008192755281925201),
 (23, 0, 0.007722916081547737),
 (24, 0, 0.009389521554112434),
 (25, 0, 0.00823487900197506),
 (26, 0, 0.003692066762596369),
 (27, 0, 0.005389336030930281),
 (28, 0, 0.009660228155553341),
 (29, 0, 0.005074864253401756),
 (30, 0, 0.0075704678893089294),
 (31, 0, 0.0170681402

In [43]:
dataset.test_set.data

array(['../../data/20200803_label_patch/train/0/mach001_camera01_2020_8_3_14_48_36_512_600.jpg',
       '../../data/20200803_label_patch/train/0/mach001_camera01_2020_8_3_14_30_17_1024_630.jpg',
       '../../data/20200803_label_patch/train/0/mach001_camera02_2020_8_3_14_32_19_0_600.jpg',
       '../../data/20200803_label_patch/train/0/mach001_camera01_2020_8_3_14_45_18_1280_600.jpg',
       '../../data/20200803_label_patch/train/0/mach001_camera01_2020_8_3_14_41_1_384_630.jpg',
       '../../data/20200803_label_patch/train/0/mach001_camera02_2020_8_3_14_33_12_640_600.jpg',
       '../../data/20200803_label_patch/train/0/mach001_camera02_2020_8_3_14_28_27_0_600.jpg',
       '../../data/20200803_label_patch/train/0/mach001_camera02_2020_8_3_14_29_6_1408_600.jpg',
       '../../data/20200803_label_patch/train/0/mach001_camera02_2020_8_3_14_34_24_128_630.jpg',
       '../../data/20200803_label_patch/train/0/mach001_camera01_2020_8_3_14_28_14_640_630.jpg',
       '../../data/20200803_label

In [46]:
[p for p in dataset.train_set.dataset.data if p.split("/")[-2] in ["1", "2"]]

['../../data/20200803_label_patch/train/1/mach001_camera01_2020_8_3_12_46_14_1536_600.jpg',
 '../../data/20200803_label_patch/train/1/mach001_camera01_2020_8_3_12_38_28_512_630.jpg',
 '../../data/20200803_label_patch/train/1/mach001_camera01_2020_8_3_12_43_30_384_600.jpg',
 '../../data/20200803_label_patch/train/1/mach001_camera01_2020_8_3_12_41_39_0_630.jpg',
 '../../data/20200803_label_patch/train/1/mach001_camera01_2020_8_3_12_38_14_1024_600.jpg',
 '../../data/20200803_label_patch/train/1/mach001_camera02_2020_8_3_12_43_30_640_600.jpg',
 '../../data/20200803_label_patch/train/2/mach001_camera03_2020_8_3_13_45_19_375_600.jpg',
 '../../data/20200803_label_patch/train/1/mach001_camera01_2020_8_3_12_44_59_1024_600.jpg',
 '../../data/20200803_label_patch/train/2/mach001_camera03_2020_8_3_13_44_13_375_630.jpg',
 '../../data/20200803_label_patch/train/1/mach001_camera02_2020_8_3_13_32_43_375_630.jpg',
 '../../data/20200803_label_patch/train/1/mach001_camera02_2020_8_3_12_37_44_0_630.jpg',


In [35]:
"../../data/20200803_label_patch/train/0/mach001_camera02_2020_8_3_14_28_14_768_630.jpg".split("/")[-2]

'0'