In [None]:
import sys

# Make the parent directory importable; presumably the project-local `models`
# package imported in the next cell lives one level up (confirm against the
# repository layout).
# The membership guard keeps the cell idempotent: re-running it no longer
# appends a duplicate '../' entry to sys.path each time.
if '../' not in sys.path:
    sys.path.append('../')

In [None]:
import torch
import torchvision.transforms
import itertools
import pandas as pd
import numpy as np
from tqdm.auto import tqdm, trange

from models.train_full import train, test
from models.models import CNNClassifier
from models.utils import set_seed

In [None]:
def pretty(ld, indent=0, path='result.txt'):
    """Dump a sequence of dicts to a text file in a brace-delimited layout.

    Each dict is written as a ``{`` line, then one ``key:value`` line per
    item (prefixed with ``indent + 1`` tab characters), then a closing
    ``},`` line. The target file is truncated on every call.

    Args:
        ld: Iterable of dicts to write. Keys and values are rendered with
            ``str``.
        indent: Extra tab levels placed in front of every ``key:value``
            line. The brace lines are always written at column 0.
        path: Destination file. Defaults to ``'result.txt'`` (relative to
            the current working directory), the location this function
            wrote to before the parameter existed.
    """
    # The prefix does not depend on the item, so build it once.
    prefix = '\t' * (indent + 1)
    with open(path, 'w', encoding='utf-8') as file:
        for d in tqdm(ld):
            file.write('{\n')
            # `!s` forces str() so objects with a custom __format__ are
            # rendered exactly as str(obj) would render them.
            file.writelines(
                f'{prefix}{key!s}:{value!s}\n' for key, value in d.items()
            )
            file.write('},\n')

In [None]:
# Gate for the training cell: when False the `if do_train:` block is skipped,
# while the evaluation cells further down still run.
do_train = True

# Value handed to set_seed() before each training run.
seed = 4444

# Keys looked up in every result's 'dict' entry to rank the evaluated models.
metric_filter_1, metric_filter_2 = 'val_mcc', 'val_mse'

In [None]:
# Hyper-parameter grid: each key maps to the list of candidate values to try.
dict_model = {
    # keyword arguments forwarded to CNNClassifier
    'in_channels': [3],
    'out_channels': [2],
    'dim_layers': [[32, 64, 128]],
    'block_conv_layers': [3],
    'residual': [True],
    'max_pooling': [True, False],
    # training-only entry: popped from each config before the model is built
    'transforms': [torchvision.transforms.RandomHorizontalFlip()],
}

# Expand the grid into one flat config dict per combination (Cartesian
# product of all the value lists, keys kept in declaration order).
list_model = [
    dict(zip(dict_model, combo))
    for combo in itertools.product(*dict_model.values())
]

In [None]:
if do_train:
    # One training run per configuration produced by the grid expansion.
    for d in tqdm(list_model):
        # Re-seed before every run — presumably so each configuration starts
        # from the same RNG state (depends on what set_seed covers).
        set_seed(seed)

        # Split the augmentation from the model kwargs. A new dict is built
        # so the entry stored in list_model is left untouched.
        transforms = d['transforms']
        d = {name: value for name, value in d.items() if name != 'transforms'}

        train(
            model=CNNClassifier(**d),
            dict_model=d,
            log_dir="./logs_full",
            data_path="./data/UTKFace",
            save_path="./models/saved_full",
            lr=1e-2,
            optimizer_name="adamw",
            n_epochs=65,
            batch_size=64,
            num_workers=2,
            scheduler_mode='min_mse',
            debug_mode=False,
            device=None,
            steps_save=1,
            use_cpu=False,
            transforms=transforms,
            loss_age_weight=1e-2,
        )

#### Results

In [None]:
# Run the project's test() routine against the same data_path and save_path
# that the training cell uses; the returned value is ranked in the cells below.
res_test = test(
    data_path="./data/UTKFace",
    save_path='./models/saved_full',
    n_runs=1,
    batch_size=64,
    num_workers=0,
    debug_mode=False,
    use_cpu=False,
    save=True,
    verbose=False,
)

In [None]:
# Rank the evaluated models by metric_filter_1, highest value first:
# np.argsort sorts ascending and the trailing [::-1] reverses it, so
# index 0 is the entry with the largest metric value.
# res_test is indexed directly instead of being aliased to `all`, which
# would shadow the builtin all().
sort_idx = np.argsort([k['dict'][metric_filter_1] for k in res_test])[::-1]
res_test[sort_idx[0]]['dict']

In [None]:
# NOTE(review): the cell above ranks `res_test` itself, while this one ranks
# `res_test[2]` — confirm which of the two matches what test() returns.
# `all` shadows the builtin all(); the name is kept only because the
# commented-out pretty(...) call in the last cell reads it.
all = res_test[2]
# Ascending order: assuming metric_filter_2 ('val_mse') is an error measure,
# lower is better, so the best model must land at index 0. The earlier
# [::-1] reversal put the highest-error entry first instead.
sort_idx = np.argsort([k['dict'][metric_filter_2] for k in all])
all[sort_idx[0]]['dict']

In [None]:
# Optional: uncomment to write every result dict, in sort_idx order, to result.txt.
# pretty([all[k]['dict'] for k in sort_idx])