In [2]:
import torch.backends.cudnn
torch.backends.cudnn.benchmark = True

import argparse
from operator import itemgetter

from helpers.aligned_printer import AlignedPrinter
from helpers.testset import Testset
from test.multiscale_tester import MultiscaleTester


In [5]:

class Struct:
    """Lightweight attribute bag: each keyword argument becomes an instance attribute.

    Lets a plain dict of options be read with attribute access
    (``obj.key``) instead of item access (``obj['key']``).
    """

    def __init__(self, **entries):
        # Write straight into the instance namespace, one entry at a time.
        namespace = vars(self)
        for attr_name, attr_value in entries.items():
            namespace[attr_name] = attr_value

# Evaluation options, exposed as attributes of a Struct (flags.log_dir, ...).
# Options not read in this cell are handed to MultiscaleTester through `flags`.
flags = Struct(
    log_dir='weights/',
    # One or more experiment IDs, separated by ',' or '|'.
    log_dates='0306_0001',
    # Comma-separated image folders/files to evaluate.
    # NOTE(review): absolute, machine-specific path -- adjust for your environment.
    images='/var/tmp/sauravkadavath/distorted_datasets/imagenet_r/n02051845_pelican/',
    # If set, every Testset is filtered via filter_filenames(match_filenames).
    match_filenames=None,
    max_imgs_per_folder=None,
    # If set, '_crop<value>' is appended to the Testset ID.
    crop=None,
    # Display names, split with the same separator as log_dates.
    names='L3C',
    overwrite_cache=False,
    reset_entire_cache=False,
    # Comma-separated integers; each one is tested for every log_date.
    restore_itr='-1',
    recursive='auto',
    sample=None,
    write_to_files=None,
    compare_theory=False,
    time_report=None,
    sort_output='testset',
)

# Append flags.crop to the ID so that a cropped run gets its own, unique cache entry.
crop_id_suffix = f'_crop{flags.crop}' if flags.crop else None

# One Testset per comma-separated entry of flags.images (trailing '/' removed).
testsets = [
    Testset(image_path.rstrip('/'), flags.max_imgs_per_folder, append_id=crop_id_suffix)
    for image_path in flags.images.split(',')
]

# Optionally restrict every test set to the requested file names.
if flags.match_filenames:
    for ts in testsets:
        ts.filter_filenames(flags.match_filenames)

# Experiment IDs may be separated by ',' or, as in tensorboard strings, by '|'.
splitter = ',' if ',' in flags.log_dates else '|'  # support tensorboard strings, too
log_dates = flags.log_dates.split(splitter)

# Parse the checkpoint iterations once, before any testing starts: the value is
# loop-invariant, and a malformed entry now raises ValueError immediately
# instead of only after earlier test runs have already completed.
restore_itrs = [int(itr) for itr in flags.restore_itr.split(',')]

# Collect the results of every (log_date, restore_itr) combination on all test sets.
results = []
for log_date in log_dates:
    for restore_itr in restore_itrs:
        print('Testing {} at {} ---'.format(log_date, restore_itr))
        tester = MultiscaleTester(log_date, flags, restore_itr)
        results += tester.test_all(testsets)

# Display label per log_date: 'name (log_date)' if names were given via
# flags.names (--names), otherwise just 'log_date'.
# NOTE: despite its name, this dict is keyed by log_date and maps to the label.
names_to_log_date = {log_date: log_date for log_date in log_dates}
if flags.names:
    # zip() stops at the shorter input, so when fewer names than log_dates are
    # given, the remaining log_dates keep their plain label instead of being
    # absent from the mapping.
    names = flags.names.split(splitter)
    names_to_log_date.update(
        (log_date, f'{name} ({log_date})')
        for log_date, name in zip(log_dates, names))


Testing 0306_0001 at -1 ---
Restoring weights/0306_0001 cr oi/ckpts/ckpt_0001002750.pt... (strict=True)
odict_keys(['sub_rgb_mean.weight', 'sub_rgb_mean.bias', 'heads.0.head.0.weight', 'heads.0.head.0.bias', 'heads.0.head.1.head.weight', 'heads.0.head.1.head.bias', 'heads.1.head.weight', 'heads.1.head.bias', 'heads.2.head.weight', 'heads.2.head.bias', 'nets.0.enc.levels', 'nets.0.enc.down.weight', 'nets.0.enc.down.bias', 'nets.0.enc.body.0.body.0.weight', 'nets.0.enc.body.0.body.0.bias', 'nets.0.enc.body.0.body.2.weight', 'nets.0.enc.body.0.body.2.bias', 'nets.0.enc.body.1.body.0.weight', 'nets.0.enc.body.1.body.0.bias', 'nets.0.enc.body.1.body.2.weight', 'nets.0.enc.body.1.body.2.bias', 'nets.0.enc.body.2.body.0.weight', 'nets.0.enc.body.2.body.0.bias', 'nets.0.enc.body.2.body.2.weight', 'nets.0.enc.body.2.body.2.bias', 'nets.0.enc.body.3.body.0.weight', 'nets.0.enc.body.3.body.0.bias', 'nets.0.enc.body.3.body.2.weight', 'nets.0.enc.body.3.body.2.bias', 'nets.0.enc.body.4.body.0.weigh

0306_0001: deviantart_13 (        34): mean bpsp=2.946488159043448 crops:freq -> 1:354

KeyboardInterrupt: 