In [1]:
import numpy

In [4]:
# Launch the training script through the shell (IPython "!" escape); `-u` makes
# Python's stdout/stderr unbuffered so progress appears in the cell as it runs.
# NOTE(review): RUN_NAME, DATASET_PATH and RESULTS_PATH are still placeholder
# values -- replace them before running.
!python -u train_net.py  \
        --GPU=0 \
        --DATASET=KSDD \
        --RUN_NAME=RUN_NAME \
        --DATASET_PATH=/path/to/dataset \
        --RESULTS_PATH=/path/to/save/results \
        --SAVE_IMAGES=True \
        --DILATE=7 \
        --EPOCHS=50 \
        --LEARNING_RATE=1.0 \
        --DELTA_CLS_LOSS=0.01 \
        --BATCH_SIZE=1 \
        --WEIGHTED_SEG_LOSS=True \
        --WEIGHTED_SEG_LOSS_P=2 \
        --WEIGHTED_SEG_LOSS_MAX=1 \
        --DYN_BALANCED_LOSS=True \
        --GRADIENT_ADJUSTMENT=True \
        --FREQUENCY_SAMPLING=True \
        --TRAIN_NUM=33 \
        --NUM_SEGMENTED=33 \
        --FOLD=0

^C


In [13]:
import os
import pickle
from enum import Enum

import numpy as np

from data.dataset import Dataset


def read_split(num_segmented: int, fold: int, kind: str):
    """Return the sample list of one fold from a pickled DAGM split file.

    The file ``splits/DAGM/split_{num_segmented}.pyb`` (relative to the current
    working directory) must unpickle to a ``(train_samples, test_samples)``
    pair, each of which is indexed by fold.

    Args:
        num_segmented: Selects which split file is read.
        fold: 1-based fold number; fold ``n`` is stored at index ``n - 1``.
        kind: ``'TRAIN'`` or ``'TEST'``, choosing which half of the pair to use.

    Returns:
        The entry stored for ``fold`` in the selected half of the split.

    Raises:
        ValueError: If ``kind`` is not ``'TRAIN'``/``'TEST'`` or ``fold`` is
            smaller than 1.
        FileNotFoundError: If the split file does not exist.
        IndexError: If ``fold`` exceeds the number of stored folds.
    """
    # Validate the arguments before touching the disk so that a typo fails fast
    # with a message that names the offending value.
    if kind not in ('TRAIN', 'TEST'):
        raise ValueError(
            f"Unknown split kind {kind!r}; expected 'TRAIN' or 'TEST'")
    # fold - 1 would become a negative index for fold <= 0 and silently return
    # a fold counted from the end of the list instead of failing.
    if fold < 1:
        raise ValueError(f"fold must be >= 1 (folds are 1-based), got {fold!r}")

    split_path = os.path.join("splits", "DAGM", f"split_{num_segmented}.pyb")
    # SECURITY: pickle.load can execute arbitrary code; only read split files
    # that come from a trusted source.
    with open(split_path, "rb") as f:
        train_samples, test_samples = pickle.load(f)

    fold_samples = train_samples if kind == 'TRAIN' else test_samples
    return fold_samples[fold - 1]
# Load the sample list stored for fold 1 of the TEST half of split file 0.
samples = read_split(num_segmented=0, fold=1, kind='TEST')

In [14]:
# Show the loaded split: a list of (image file name, bool) pairs.
samples

[('0001.PNG', True),
 ('0002.PNG', True),
 ('0003.PNG', True),
 ('0004.PNG', True),
 ('0005.PNG', True),
 ('0006.PNG', True),
 ('0007.PNG', True),
 ('0008.PNG', True),
 ('0009.PNG', True),
 ('0010.PNG', True),
 ('0011.PNG', True),
 ('0012.PNG', True),
 ('0013.PNG', True),
 ('0014.PNG', True),
 ('0015.PNG', True),
 ('0016.PNG', True),
 ('0017.PNG', True),
 ('0018.PNG', True),
 ('0019.PNG', True),
 ('0020.PNG', True),
 ('0021.PNG', True),
 ('0022.PNG', True),
 ('0023.PNG', True),
 ('0024.PNG', True),
 ('0025.PNG', True),
 ('0026.PNG', True),
 ('0027.PNG', True),
 ('0028.PNG', True),
 ('0029.PNG', True),
 ('0030.PNG', True),
 ('0031.PNG', True),
 ('0032.PNG', True),
 ('0033.PNG', True),
 ('0034.PNG', True),
 ('0035.PNG', True),
 ('0036.PNG', True),
 ('0037.PNG', True),
 ('0038.PNG', True),
 ('0039.PNG', True),
 ('0040.PNG', True),
 ('0041.PNG', True),
 ('0042.PNG', True),
 ('0043.PNG', True),
 ('0044.PNG', True),
 ('0045.PNG', True),
 ('0046.PNG', True),
 ('0047.PNG', True),
 ('0048.PNG',

In [4]:
# Names of the per-class sub-folders, "Class1" through "Class15"; used as the
# default class list when no single class is requested.
_CLASSNAMES = [f"Class{class_idx}" for class_idx in range(1, 16)]
class DatasetSplit(Enum):
    """Dataset partitions; each value is the name of the split's sub-folder.

    Requires ``from enum import Enum`` to be in scope (import it in the
    notebook's import cell).
    """

    TRAIN = "train"
    VAL = "val"
    TEST = "test"
    
class MVTecDataset(Dataset):
    """Dataset that loads split-listed images together with their label masks.

    Two loading paths live in this class:

    * ``read_contents`` (run from ``__init__``) takes the sample list returned
      by ``read_split`` and fills ``pos_samples`` / ``neg_samples``.
    * ``get_image_data`` walks a ``<source>/<class>/<split>/<anomaly>`` folder
      tree; nothing inside this class calls it.

    NOTE(review): ``self.path``, ``self.cfg``, ``self.kind``,
    ``self.grayscale``, ``self.image_size`` and the helpers
    ``read_img_resize``, ``read_label_resize``, ``to_tensor``, ``downsize``,
    ``distance_transform`` and ``init_extra`` are not defined here; presumably
    they are provided by the project-local ``Dataset`` base class -- confirm.
    """

    # Translates the ``kind`` string into the split enum whose value names the
    # per-class sub-folder read by ``get_image_data``.
    _KIND_TO_SPLIT = {
        "TRAIN": DatasetSplit.TRAIN,
        "TEST": DatasetSplit.TEST,
        "VAL": DatasetSplit.VAL,
    }

    def __init__(self, kind: str, cfg, classname=None):
        """Set up the dataset and immediately load its samples.

        Args:
            kind: ``"TRAIN"``, ``"TEST"`` or ``"VAL"``.
            cfg: Configuration object. This class reads ``DATASET_PATH``,
                ``FOLD``, ``NUM_SEGMENTED``, ``DILATE``,
                ``WEIGHTED_SEG_LOSS_MAX`` and ``WEIGHTED_SEG_LOSS_P`` from it.
            classname: Optional single class folder name to restrict
                ``get_image_data`` to; ``None`` selects every entry of
                ``_CLASSNAMES``.

        Raises:
            ValueError: If ``kind`` is not one of the supported strings.
        """
        super(MVTecDataset, self).__init__(os.path.join(
            cfg.DATASET_PATH, f"{cfg.FOLD}"), cfg, kind)

        # Fail with a clear message instead of leaving ``self.split`` unset.
        try:
            self.split = self._KIND_TO_SPLIT[kind]
        except KeyError:
            raise ValueError(
                f"Unknown dataset kind {kind!r}; expected one of "
                f"{sorted(self._KIND_TO_SPLIT)}") from None

        self.source = cfg.DATASET_PATH
        # 1.0 keeps every image in its split; values below 1.0 make
        # get_image_data cut each anomaly folder into a train and a val part.
        self.train_val_split = 1.0

        # ``classname`` used to be read here without ever being defined, which
        # raised NameError on every construction; it is now an optional
        # parameter.
        self.classnames_to_use = [classname] if classname is not None else _CLASSNAMES
        self.read_contents()

    def get_image_data(self):
        """Collect image and mask paths from the per-class folder tree.

        Returns:
            A pair ``(imgpaths_per_class, data_to_iterate)``:
            ``imgpaths_per_class[classname][anomaly]`` is the sorted list of
            image paths, and ``data_to_iterate`` is a flat list of
            ``[classname, anomaly, image_path, mask_path_or_None]`` entries.
        """
        imgpaths_per_class = {}
        maskpaths_per_class = {}
        is_test_split = self.split == DatasetSplit.TEST

        for classname in self.classnames_to_use:
            classpath = os.path.join(self.source, classname, self.split.value)
            maskpath = os.path.join(self.source, classname, "ground_truth")

            imgpaths_per_class[classname] = {}
            maskpaths_per_class[classname] = {}

            for anomaly in os.listdir(classpath):
                anomaly_path = os.path.join(classpath, anomaly)
                image_paths = [
                    os.path.join(anomaly_path, x)
                    for x in sorted(os.listdir(anomaly_path))
                ]

                if self.train_val_split < 1.0:
                    # Leading fraction goes to TRAIN, the remainder to VAL;
                    # any other split keeps the full list.
                    train_val_split_idx = int(
                        len(image_paths) * self.train_val_split)
                    if self.split == DatasetSplit.TRAIN:
                        image_paths = image_paths[:train_val_split_idx]
                    elif self.split == DatasetSplit.VAL:
                        image_paths = image_paths[train_val_split_idx:]

                imgpaths_per_class[classname][anomaly] = image_paths

                if is_test_split and anomaly != "good":
                    anomaly_mask_path = os.path.join(maskpath, anomaly)
                    maskpaths_per_class[classname][anomaly] = [
                        os.path.join(anomaly_mask_path, x)
                        for x in sorted(os.listdir(anomaly_mask_path))
                    ]
                else:
                    maskpaths_per_class[classname]["good"] = None

        # Unrolls the data dictionary to an easy-to-iterate list.
        data_to_iterate = []
        for classname in sorted(imgpaths_per_class.keys()):
            for anomaly in sorted(imgpaths_per_class[classname].keys()):
                for i, image_path in enumerate(imgpaths_per_class[classname][anomaly]):
                    data_tuple = [classname, anomaly, image_path]
                    if is_test_split and anomaly != "good":
                        # Masks are paired with images by position in the two
                        # sorted listings.
                        data_tuple.append(maskpaths_per_class[classname][anomaly][i])
                    else:
                        data_tuple.append(None)
                    data_to_iterate.append(data_tuple)

        return imgpaths_per_class, data_to_iterate

    def read_contents(self):
        """Load every sample of the configured split into memory.

        A sample whose ``Label/<name>_label.PNG`` file exists is stored in
        ``pos_samples``; every other sample is stored in ``neg_samples`` with
        an all-zero mask. Each stored tuple is ``(image, seg_mask,
        seg_loss_mask, is_segmented, image_path, mask_path_slot,
        img_name_short)``. Also sets ``num_pos``, ``num_neg`` and ``len`` and
        finally calls ``init_extra()``.
        """
        pos_samples, neg_samples = [], []

        samples = read_split(self.cfg.NUM_SEGMENTED, self.cfg.FOLD, self.kind)

        sub_dir = self.kind.lower()

        for image_name, is_segmented in samples:
            image_path = os.path.join(self.path, sub_dir, image_name)
            image = self.read_img_resize(
                image_path, self.grayscale, self.image_size)
            # Strip the extension, whatever its length.
            img_name_short = os.path.splitext(image_name)[0]
            seg_mask_path = os.path.join(
                self.path, sub_dir, "Label",  f"{img_name_short}_label.PNG")

            if os.path.exists(seg_mask_path):
                seg_mask, _ = self.read_label_resize(
                    seg_mask_path, self.image_size, dilate=self.cfg.DILATE)
                image = self.to_tensor(image)
                # The loss mask is derived from the full-size mask, before the
                # mask itself is downsized.
                seg_loss_mask = self.distance_transform(
                    seg_mask, self.cfg.WEIGHTED_SEG_LOSS_MAX, self.cfg.WEIGHTED_SEG_LOSS_P)
                seg_mask = self.to_tensor(self.downsize(seg_mask))
                seg_loss_mask = self.to_tensor(self.downsize(seg_loss_mask))
                pos_samples.append(
                    (image, seg_mask, seg_loss_mask, is_segmented, image_path, None, img_name_short))

            else:
                # No label file: all-zero mask with a uniform all-ones loss
                # mask. The zero mask must be built from the raw image, before
                # the image is converted to a tensor.
                seg_mask = np.zeros_like(image)
                image = self.to_tensor(image)
                seg_loss_mask = self.to_tensor(
                    self.downsize(np.ones_like(seg_mask)))
                seg_mask = self.to_tensor(self.downsize(seg_mask))
                neg_samples.append(
                    (image, seg_mask, seg_loss_mask, True, image_path, seg_mask_path, img_name_short))

        self.pos_samples = pos_samples
        self.neg_samples = neg_samples

        self.num_pos = len(pos_samples)
        self.num_neg = len(neg_samples)
        # TRAIN reports twice the positive count; every other kind reports the
        # total number of loaded samples.
        if self.kind in ['TRAIN']:
            self.len = 2 * len(pos_samples)
        else:
            self.len = len(pos_samples) + len(neg_samples)

        self.init_extra()


In [6]:
# NOTE(review): this call fails as written -- MVTecDataset.__init__ also needs
# a `cfg` argument (see the TypeError recorded below). The class reads
# DATASET_PATH, FOLD, NUM_SEGMENTED, DILATE, WEIGHTED_SEG_LOSS_MAX and
# WEIGHTED_SEG_LOSS_P from it; no such config object is created in this notebook.
mvtec = MVTecDataset(kind="TRAIN")

TypeError: __init__() missing 1 required positional argument: 'cfg'

In [3]:
which python3

SyntaxError: invalid syntax (<ipython-input-3-3778c32ba68a>, line 1)

In [4]:
import gc, torch

# Run the garbage collector first so unreachable Python objects (and any
# tensors they still hold) are freed before the CUDA cache is emptied.
gc.collect()

# Then ask PyTorch to release the unused memory held by its CUDA caching
# allocator; memory of tensors that are still referenced is not affected.
torch.cuda.empty_cache()