In [None]:
! [ -d "GroceryStoreDataset-master" ] && echo "skipping" || (wget -nc --no-check-certificate https://github.com/marcusklasson/GroceryStoreDataset/archive/refs/heads/master.zip && unzip master.zip -d .)

In [None]:
import os

# Restrict this process to GPU 0. These variables only take effect if they are
# set before CUDA is initialised, i.e. before torch / sap_computer_vision are imported.
# PCI_BUS_ID ordering makes device index "0" match the numbering shown by nvidia-smi.
os.environ.update(CUDA_DEVICE_ORDER="PCI_BUS_ID", CUDA_VISIBLE_DEVICES="0")

In [None]:
import pathlib, json

from sap_computer_vision.datasets import image_folder as imgf

dataset_folder = pathlib.Path('GroceryStoreDataset-master/').resolve()
splits_root = dataset_folder / 'dataset'

# Fail fast with an actionable message if the download cell did not produce the
# expected layout, rather than failing later inside imgf.register().
if not splits_root.is_dir():
    raise FileNotFoundError(
        f"Dataset split folder {splits_root} not found - run the download cell first.")

# The train split is registered first and its class_names are handed to val and
# test, so all three registrations share one class list (presumably keeping label
# indices consistent across splits - confirm against imgf.register).
# NOTE(review): if register() uses detectron2's DatasetCatalog, re-running this
# cell may fail because duplicate dataset names are rejected - confirm.
images_train, class_names = imgf.register('grocery_train',
                                          base_dir=splits_root / 'train')
images_val, class_names = imgf.register('grocery_val',
                                        base_dir=splits_root / 'val',
                                        class_names=class_names)
images_test, class_names = imgf.register('grocery_test',
                                         base_dir=splits_root / 'test',
                                         class_names=class_names)

In [None]:
import numpy as np
from sap_computer_vision import setup_loggers, get_cfg, get_config_file

out_dir = 'distance_learning_grocery_higher_lr'

setup_loggers(out_dir)

# Base configs first, model config last: later merge_from_file calls override
# keys set by earlier ones.
cfg = get_cfg()
cfg.merge_from_file(get_config_file('Base-EarlyStopping'))
cfg.merge_from_file(get_config_file('Base-Evaluation'))
cfg.merge_from_file(get_config_file('TripletDistanceLearner/FPN-Resnet50'))

cfg.OUTPUT_DIR = out_dir
cfg.DATASETS.TRAIN = ('grocery_train', )
cfg.DATASETS.TEST = ('grocery_val', )

# Each batch holds P classes x K examples per class (30 x 4 = 120 examples).
cfg.DATALOADER.PK_SAMPLER.P_CLASSES_PER_BATCH = 30
cfg.DATALOADER.PK_SAMPLER.K_EXAMPLES_PER_CLASS = 4
cfg.DATALOADER.NUM_WORKERS = 10

cfg.SOLVER.MAX_ITER = 5000
cfg.SOLVER.BASE_LR = 0.01
# The LR is multiplied by GAMMA at every milestone; sqrt(0.1) means two
# milestones together give one 10x decay.
cfg.SOLVER.GAMMA = float(np.sqrt(0.1))
cfg.SOLVER.EARLY_STOPPING.ENABLED = False

# Warm up over the first 1% of iterations (int() of a non-negative value is >= 0).
cfg.SOLVER.WARMUP_ITERS = int(0.01 * cfg.SOLVER.MAX_ITER)
# LR milestones at fixed fractions of MAX_ITER. Cast to int so they are whole
# iteration numbers (and dump as ints), like the STRATEGY_SWITCHES steps below.
cfg.SOLVER.STEPS = [int(cfg.SOLVER.MAX_ITER * p) for p in (0.25, 0.375, 0.5, 0.75, 0.9)]

# Enable each listed augmentation that the merged config defines; keys that
# cfg.INPUT does not contain are skipped.
for aug in ['RANDOM_LIGHTING', 'RANDOM_BRIGHTNESS', 'RANDOM_SATURATION', 'RANDOM_CONTRAST', 'RANDOM_ROTATION', 'CROP', 'CUT_OUT']:
    if cfg.INPUT.get(aug, None) is not None:
        cfg.INPUT[aug].ENABLED = True

cfg.MODEL.TRIPLET_DISTANCE_LEARNER.MARGIN_LOSS.MARGIN = 0.5
cfg.MODEL.TRIPLET_DISTANCE_LEARNER.LOSS = 'MARGIN_LOSS'
cfg.MODEL.FEATURE_EXTRACTION.PROJECTION_SIZE = 512
cfg.MODEL.FEATURE_EXTRACTION.INTERMEDIATE_SIZE = None
cfg.DATALOADER.SAMPLER_TRAIN = 'PKSampler'
cfg.MODEL.TRIPLET_DISTANCE_LEARNER.TRIPLET_STRATEGY = ('*', '*')

# Sampler strategy schedule: no switches before `delay` (half of MAX_ITER); after
# that, 21 switches at evenly spaced iterations move the (first, second) weights
# linearly from (0.5, 0.5) to (0.8, 0.2).
# NOTE(review): the meaning of each tuple entry is defined by the PKSampler in
# sap_computer_vision - presumably positive/negative weights; confirm.
delay = cfg.SOLVER.MAX_ITER * 0.5
strategies_pos = np.linspace(0.5, 0.8, 21)
strategies_neg = 1. - strategies_pos
strategies = [(float(p), float(n)) for p, n in zip(strategies_pos, strategies_neg)]
# len(strategies) + 1 points with the last one (== MAX_ITER) dropped, so every
# switch happens strictly before training ends.
switch_steps = np.linspace(delay, cfg.SOLVER.MAX_ITER, len(strategies) + 1)[:-1]
cfg.DATALOADER.PK_SAMPLER.STRATEGY_SWITCHES = [(int(step), strat) for step, strat in zip(switch_steps, strategies)]

cfg.TEST.EVAL_PERIOD = 250

In [None]:
# Save the fully merged config into the run's output directory for reproducibility.
# A separate name avoids rebinding the string `out_dir` to a Path object.
output_path = pathlib.Path(cfg.OUTPUT_DIR)
output_path.mkdir(parents=True, exist_ok=True)
(output_path / 'used_config.yaml').write_text(cfg.dump())

In [None]:
from sap_computer_vision.engine import TripletDistanceTrainer

In [None]:
# Build the trainer from the final config.
# NOTE(review): presumably constructs model, data loaders, optimizer and LR
# schedule in detectron2 DefaultTrainer style - confirm in sap_computer_vision.engine.
trainer = TripletDistanceTrainer(cfg)

In [None]:
# Initialise weights, then run the training loop.
# NOTE(review): with resume=False, detectron2-style trainers load cfg.MODEL.WEIGHTS
# and ignore any checkpoint already in OUTPUT_DIR, so re-running this cell
# restarts training from scratch - confirm for TripletDistanceTrainer.
trainer.resume_or_load(resume=False)
trainer.train()

In [None]:
# Evaluate the trained model on the held-out test split. Work on a copy so that
# `cfg` (used to build the trainer and dumped to used_config.yaml) still lists
# the validation split; defrost() keeps this working even if the config was frozen.
test_cfg = cfg.clone()
test_cfg.defrost()
test_cfg.DATASETS.TEST = ('grocery_test', )
metrics = trainer.test(test_cfg, trainer.model)
print(json.dumps(metrics, indent=2))