In [1]:
import time
import argparse
import logging
from tqdm import tqdm
import pandas as pd
from collections import defaultdict
from scipy.stats import gmean

import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from tensorboard_logger import Logger

from resnet import resnet50
from loss import *
from datasets import AgeDB
from utils import *

import os
os.environ["KMP_WARNINGS"] = "FALSE"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# imbalanced related
# LDS
parser.add_argument('--lds', action='store_true', default=True, help='whether to enable LDS')
parser.add_argument('--lds_kernel', type=str, default='gaussian',
                    choices=['gaussian', 'triang', 'laplace'], help='LDS kernel type')
parser.add_argument('--lds_ks', type=int, default=9, help='LDS kernel size: should be odd number')
parser.add_argument('--lds_sigma', type=float, default=1, help='LDS gaussian/laplace kernel sigma')
# FDS
parser.add_argument('--fds', action='store_true', default=False, help='whether to enable FDS')
parser.add_argument('--fds_kernel', type=str, default='gaussian',
                    choices=['gaussian', 'triang', 'laplace'], help='FDS kernel type')
parser.add_argument('--fds_ks', type=int, default=9, help='FDS kernel size: should be odd number')
parser.add_argument('--fds_sigma', type=float, default=1, help='FDS gaussian/laplace kernel sigma')
parser.add_argument('--start_update', type=int, default=0, help='which epoch to start FDS updating')
parser.add_argument('--start_smooth', type=int, default=1, help='which epoch to start using FDS to smooth features')
parser.add_argument('--bucket_num', type=int, default=100, help='maximum bucket considered for FDS')
parser.add_argument('--bucket_start', type=int, default=3, choices=[0, 3],
                    help='minimum(starting) bucket for FDS, 0 for IMDBWIKI, 3 for AgeDB')
parser.add_argument('--fds_mmt', type=float, default=0.9, help='FDS momentum')

# re-weighting: SQRT_INV / INV
parser.add_argument('--reweight', type=str, default='inverse', choices=['none', 'sqrt_inv', 'inverse'], help='cost-sensitive reweighting scheme')
# two-stage training: RRT
parser.add_argument('--retrain_fc', action='store_true', default=False, help='whether to retrain last regression layer (regressor)')

# training/optimization related
parser.add_argument('--dataset', type=str, default='agedb', choices=['imdb_wiki', 'agedb'], help='dataset name')
parser.add_argument('--data_dir', type=str, default='./Faces/UTKFace/', help='data directory')
parser.add_argument('--model', type=str, default='resnet50', help='model name')
parser.add_argument('--store_root', type=str, default='checkpoint', help='root path for storing checkpoints, logs')
parser.add_argument('--store_name', type=str, default='', help='experiment store name')
parser.add_argument('--gpu', type=int, default=None)
parser.add_argument('--optimizer', type=str, default='adam', choices=['adam', 'sgd'], help='optimizer type')
parser.add_argument('--loss', type=str, default='l1', choices=['mse', 'l1', 'focal_l1', 'focal_mse', 'huber'], help='training loss type')
parser.add_argument('--lr', type=float, default=1e-3, help='initial learning rate')
parser.add_argument('--epoch', type=int, default=30, help='number of epochs to train')
parser.add_argument('--momentum', type=float, default=0.9, help='optimizer momentum')
parser.add_argument('--weight_decay', type=float, default=1e-4, help='optimizer weight decay')
parser.add_argument('--schedule', type=int, nargs='*', default=[60, 80], help='lr schedule (when to drop lr by 10x)')
parser.add_argument('--batch_size', type=int, default=64, help='batch size')
parser.add_argument('--print_freq', type=int, default=10, help='logging frequency')
parser.add_argument('--img_size', type=int, default=224, help='image size used in training')
parser.add_argument('--workers', type=int, default=16, help='number of workers used in data loading')
# checkpoints
parser.add_argument('--resume', type=str, default='', help='checkpoint file path to resume training')
parser.add_argument('--pretrained', type=str, default='', help='checkpoint file path to load backbone weights')
parser.add_argument('--evaluate', action='store_true', help='evaluate only flag')

parser.set_defaults(augment=True)
args, unknown = parser.parse_known_args()

args.start_epoch, args.best_loss = 0, 1e5
PATH = "./Faces/UTKFace/"

if len(args.store_name):
    args.store_name = f'_{args.store_name}'
if not args.lds and args.reweight != 'none':
    args.store_name += f'_{args.reweight}'
if args.lds:
    args.store_name += f'_lds_{args.lds_kernel[:3]}_{args.lds_ks}'
    if args.lds_kernel in ['gaussian', 'laplace']:
        args.store_name += f'_{args.lds_sigma}'
if args.fds:
    args.store_name += f'_fds_{args.fds_kernel[:3]}_{args.fds_ks}'
    if args.fds_kernel in ['gaussian', 'laplace']:
        args.store_name += f'_{args.fds_sigma}'
    args.store_name += f'_{args.start_update}_{args.start_smooth}_{args.fds_mmt}'
if args.retrain_fc:
    args.store_name += f'_retrain_fc'
args.store_name = f"{args.dataset}_{args.model}{args.store_name}_{args.optimizer}_{args.loss}_{args.lr}_{args.batch_size}"

prepare_folders(args)

logging.root.handlers = []
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(message)s",
    handlers=[
        logging.FileHandler(os.path.join(args.store_root, args.store_name, 'training.log')),
        logging.StreamHandler()
    ])
print = logging.info
print(f"Args: {args}")
print(f"Store name: {args.store_name}")

tb_logger = Logger(logdir=os.path.join(args.store_root, args.store_name), flush_secs=2)

def reload_data():
    age_list = []
    gender_list = []
    race_list = []
    datetime_list = []
    filename_list = []

    for filename in os.listdir("./Faces/UTKFace/"):

        args = filename.split("_")

        if len(args)<4:
            age = int(args[0])
            gender = int(args[1])
            race = 4
            datetime = args[2].split(".")[0]
        else:
            age = int(args[0])
            gender = int(args[1])
            race = int(args[2])
            datetime = args[3].split(".")[0]

        age_list.append(age)
        gender_list.append(gender)
        race_list.append(race)
        datetime_list.append(datetime)
        filename_list.append(filename)


    d = {'age': age_list, 'gender': gender_list, 'race': race_list, 'datetime': datetime_list, 'filename': filename_list}
    return pd.DataFrame(data=d)

overwrite previous folder: checkpoint\agedb_resnet50_lds_gau_9_1_adam_l1_0.001_64 ? [Y/n] :Y


2022-11-30 18:18:54,072 | Args: Namespace(lds=True, lds_kernel='gaussian', lds_ks=9, lds_sigma=1, fds=False, fds_kernel='gaussian', fds_ks=9, fds_sigma=1, start_update=0, start_smooth=1, bucket_num=100, bucket_start=3, fds_mmt=0.9, reweight='inverse', retrain_fc=False, dataset='agedb', data_dir='./Faces/UTKFace/', model='resnet50', store_root='checkpoint', store_name='agedb_resnet50_lds_gau_9_1_adam_l1_0.001_64', gpu=None, optimizer='adam', loss='l1', lr=0.001, epoch=30, momentum=0.9, weight_decay=0.0001, schedule=[60, 80], batch_size=64, print_freq=10, img_size=224, workers=16, resume='', pretrained='', evaluate=False, augment=True, start_epoch=0, best_loss=100000.0)
2022-11-30 18:18:54,073 | Store name: agedb_resnet50_lds_gau_9_1_adam_l1_0.001_64


checkpoint\agedb_resnet50_lds_gau_9_1_adam_l1_0.001_64 removed.
===> Creating folder: checkpoint\agedb_resnet50_lds_gau_9_1_adam_l1_0.001_64


In [3]:
def train(train_loader, model, optimizer, epoch):
    batch_time = AverageMeter('Time', ':6.2f')
    data_time = AverageMeter('Data', ':6.4f')
    losses = AverageMeter(f'Loss ({args.loss.upper()})', ':.3f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses],
        prefix="Epoch: [{}]".format(epoch)
    )

    model.train()
    end = time.time()
    for idx, (inputs, targets, weights) in enumerate(train_loader):
        data_time.update(time.time() - end)
        inputs, targets, weights = \
            inputs.cuda(non_blocking=True), targets.cuda(non_blocking=True), weights.cuda(non_blocking=True)
        if args.fds:
            outputs, _ = model(inputs, targets, epoch)
        else:
            outputs = model(inputs, targets, epoch)

        loss = globals()[f"weighted_{args.loss}_loss"](outputs, targets, weights)
        assert not (np.isnan(loss.item()) or loss.item() > 1e6), f"Loss explosion: {loss.item()}"

        losses.update(loss.item(), inputs.size(0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()
        if idx % args.print_freq == 0:
            progress.display(idx)

    if args.fds and epoch >= args.start_update:
        print(f"Create Epoch [{epoch}] features of all training data...")
        encodings, labels = [], []
        with torch.no_grad():
            for (inputs, targets, _) in tqdm(train_loader):
                inputs = inputs.cuda(non_blocking=True)
                outputs, feature = model(inputs, targets, epoch)
                encodings.extend(feature.data.squeeze().cpu().numpy())
                labels.extend(targets.data.squeeze().cpu().numpy())

        encodings, labels = torch.from_numpy(np.vstack(encodings)).cuda(), torch.from_numpy(np.hstack(labels)).cuda()
        model.module.FDS.update_last_epoch_stats(epoch)
        model.module.FDS.update_running_stats(encodings, labels, epoch)

    return losses.avg


def validate(val_loader, model, train_labels=None, prefix='Val',mode= None):
    batch_time = AverageMeter('Time', ':6.3f')
    losses_mse = AverageMeter('Loss (MSE)', ':.3f')
    losses_l1 = AverageMeter('Loss (L1)', ':.3f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses_mse, losses_l1],
        prefix=f'{prefix}: '
    )

    criterion_mse = nn.MSELoss()
    criterion_l1 = nn.L1Loss()
    criterion_gmean = nn.L1Loss(reduction='none')

    model.eval()
    losses_all = []
    preds, labels = [], []
    with torch.no_grad():
        end = time.time()
        for idx, (inputs, targets, _) in enumerate(val_loader):
            inputs, targets = inputs.cuda(non_blocking=True), targets.cuda(non_blocking=True)
            outputs = model(inputs)

            preds.extend(outputs.data.cpu().numpy())
            labels.extend(targets.data.cpu().numpy())

            loss_mse = criterion_mse(outputs, targets)
            loss_l1 = criterion_l1(outputs, targets)
            loss_all = criterion_gmean(outputs, targets)
            losses_all.extend(loss_all.cpu().numpy())

            losses_mse.update(loss_mse.item(), inputs.size(0))
            losses_l1.update(loss_l1.item(), inputs.size(0))

            batch_time.update(time.time() - end)
            end = time.time()
            if idx % args.print_freq == 0:
                progress.display(idx)

        shot_dict = shot_metrics(np.hstack(preds), np.hstack(labels), train_labels,mode = mode)
        loss_gmean = gmean(np.hstack(losses_all), axis=None).astype(float)
        print(f" * Overall: MSE {losses_mse.avg:.3f}\tL1 {losses_l1.avg:.3f}\tG-Mean {loss_gmean:.3f}")
        print(f" * Many: MSE {shot_dict['many']['mse']:.3f}\t"
              f"L1 {shot_dict['many']['l1']:.3f}\tG-Mean {shot_dict['many']['gmean']:.3f}")
        print(f" * Median: MSE {shot_dict['median']['mse']:.3f}\t"
              f"L1 {shot_dict['median']['l1']:.3f}\tG-Mean {shot_dict['median']['gmean']:.3f}")
        print(f" * Low: MSE {shot_dict['low']['mse']:.3f}\t"
              f"L1 {shot_dict['low']['l1']:.3f}\tG-Mean {shot_dict['low']['gmean']:.3f}")

    return losses_mse.avg, losses_l1.avg, loss_gmean


def shot_metrics(preds, labels, train_labels, many_shot_thr=100, low_shot_thr=20,mode = None):
    train_labels = np.array(train_labels).astype(int)

    if isinstance(preds, torch.Tensor):
        preds = preds.detach().cpu().numpy()
        labels = labels.detach().cpu().numpy()
    elif isinstance(preds, np.ndarray):
        pass
    else:
        raise TypeError(f'Type ({type(preds)}) of predictions not supported')

    train_class_count, test_class_count = [], []
    mse_per_class, l1_per_class, l1_all_per_class = [], [], []
    for l in np.unique(labels):
        train_class_count.append(len(train_labels[train_labels == l]))
        test_class_count.append(len(labels[labels == l]))
        mse_per_class.append(np.sum((preds[labels == l] - labels[labels == l]) ** 2))
        l1_per_class.append(np.sum(np.abs(preds[labels == l] - labels[labels == l])))
        l1_all_per_class.append(np.abs(preds[labels == l] - labels[labels == l]))

    many_shot_mse, median_shot_mse, low_shot_mse = [], [], []
    many_shot_l1, median_shot_l1, low_shot_l1 = [], [], []
    many_shot_gmean, median_shot_gmean, low_shot_gmean = [], [], []
    many_shot_cnt, median_shot_cnt, low_shot_cnt = [], [], []

    for i in range(len(train_class_count)):
        if train_class_count[i] > many_shot_thr:
            many_shot_mse.append(mse_per_class[i])
            many_shot_l1.append(l1_per_class[i])
            many_shot_gmean += list(l1_all_per_class[i])
            many_shot_cnt.append(test_class_count[i])
        elif train_class_count[i] < low_shot_thr:
            low_shot_mse.append(mse_per_class[i])
            low_shot_l1.append(l1_per_class[i])
            low_shot_gmean += list(l1_all_per_class[i])
            low_shot_cnt.append(test_class_count[i])
        else:
            median_shot_mse.append(mse_per_class[i])
            median_shot_l1.append(l1_per_class[i])
            median_shot_gmean += list(l1_all_per_class[i])
            median_shot_cnt.append(test_class_count[i])

    shot_dict = defaultdict(dict)
    shot_dict['many']['mse'] = np.sum(many_shot_mse) / np.sum(many_shot_cnt)
    shot_dict['many']['l1'] = np.sum(many_shot_l1) / np.sum(many_shot_cnt)
    shot_dict['many']['gmean'] = 0 if mode == 'high' else gmean(np.hstack(many_shot_gmean), axis=None).astype(float)
    shot_dict['median']['mse'] = np.sum(median_shot_mse) / np.sum(median_shot_cnt)
    shot_dict['median']['l1'] = np.sum(median_shot_l1) / np.sum(median_shot_cnt)
    shot_dict['median']['gmean'] = gmean(np.hstack(median_shot_gmean), axis=None).astype(float)
    shot_dict['low']['mse'] = np.sum(low_shot_mse) / np.sum(low_shot_cnt)
    shot_dict['low']['l1'] = np.sum(low_shot_l1) / np.sum(low_shot_cnt)

    shot_dict['low']['gmean'] = 0 if mode == 'low' else gmean(np.hstack(low_shot_gmean), axis=None).astype(float)

    return shot_dict





In [4]:
def main():
    if args.gpu is not None:
        print(f"Use GPU: {args.gpu} for training")

    # Data
    print('=====> Preparing data...')
    print("UTKFaces")
    df = reload_data()
    
    split_size1=0.8
    split_size2=0.5
    shuffle_dataset = True
    random_seed = 42

    df_train = df.sample(frac=split_size1, random_state=random_seed)
    test_val_df = df.drop(df_train.index)
    df_val = test_val_df.sample(frac=split_size2, random_state=random_seed)
    df_test = test_val_df.drop(df_val.index)
    df_test_18, df_test_80 = df_test[df_test['age'] <18],df_test[df_test['age'] >=80]
    train_labels = df_train['age']

    train_dataset = AgeDB(data_dir=PATH, df=df_train, img_size=args.img_size, split='train',
                          reweight=args.reweight, lds=args.lds, lds_kernel=args.lds_kernel, lds_ks=args.lds_ks, lds_sigma=args.lds_sigma)
    val_dataset = AgeDB(data_dir=PATH, df=df_val, img_size=args.img_size, split='val')
    test_dataset = AgeDB(data_dir=PATH, df=df_test, img_size=args.img_size, split='test')
    test_dataset_18 = AgeDB(data_dir=PATH, df=df_test_18, img_size=args.img_size, split='test')
    test_dataset_80 = AgeDB(data_dir=PATH, df=df_test_80, img_size=args.img_size, split='test')
    
    print('worker = ' +str(args.workers) )
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.workers, pin_memory=True, drop_last=False)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                            num_workers=args.workers, pin_memory=True, drop_last=False)
    test_loader = DataLoader(test_dataset_18, batch_size=args.batch_size, shuffle=False,
                             num_workers=args.workers, pin_memory=True, drop_last=False)
    
    #test_loader_18 = DataLoader(test_dataset_18, batch_size=args.batch_size, shuffle=False,
    #                         num_workers=args.workers, pin_memory=True, drop_last=False)
    #test_loader_80 = DataLoader(test_dataset_80, batch_size=args.batch_size, shuffle=False,
    #                         num_workers=args.workers, pin_memory=True, drop_last=False)
    
    print(f"Training data size: {len(train_dataset)}")
    print(f"Validation data size: {len(val_dataset)}")
    #print(f"Test data size: {len(test_dataset)}")
    print(f"Test data size age < 18: {len(test_dataset_18)}")
    #print(f"Test data size age >= 80: {len(test_dataset_80)}")

    # Model
    print('=====> Building model...')
    
    model = resnet50(fds=args.fds, bucket_num=args.bucket_num, bucket_start=args.bucket_start,
                     start_update=args.start_update, start_smooth=args.start_smooth,
                     kernel=args.fds_kernel, ks=args.fds_ks, sigma=args.fds_sigma, momentum=args.fds_mmt)
    """
    model=resnet50(fds=False, bucket_num=100, bucket_start=3,
                     start_update=0, start_smooth=1,
                     kernel='gaussian', ks=9, sigma=1, momentum=0.9)
    """
    model = torch.nn.DataParallel(model).cuda()

    # evaluate only
    if args.evaluate:
        assert args.resume, 'Specify a trained model using [args.resume]'
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        print(f"===> Checkpoint '{args.resume}' loaded (epoch [{checkpoint['epoch']}]), testing...")
        validate(test_loader, model, train_labels=train_labels, prefix='Test')
        return

    if args.retrain_fc:
        assert args.reweight != 'none' and args.pretrained
        print('===> Retrain last regression layer only!')
        for name, param in model.named_parameters():
            if 'fc' not in name and 'linear' not in name:
                param.requires_grad = False

    # Loss and optimizer
    if not args.retrain_fc:
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) if args.optimizer == 'adam' else \
            torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    else:
        # optimize only the last linear layer
        parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
        names = list(filter(lambda k: k is not None, [k if v.requires_grad else None for k, v in model.module.named_parameters()]))
        assert 1 <= len(parameters) <= 2  # fc.weight, fc.bias
        print(f'===> Only optimize parameters: {names}')
        optimizer = torch.optim.Adam(parameters, lr=args.lr) if args.optimizer == 'adam' else \
            torch.optim.SGD(parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    if args.pretrained:
        checkpoint = torch.load(args.pretrained, map_location="cpu")
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in checkpoint['state_dict'].items():
            if 'linear' not in k and 'fc' not in k:
                new_state_dict[k] = v
        model.load_state_dict(new_state_dict, strict=False)
        print(f'===> Pretrained weights found in total: [{len(list(new_state_dict.keys()))}]')
        print(f'===> Pre-trained model loaded: {args.pretrained}')

    if args.resume:
        if os.path.isfile(args.resume):
            print(f"===> Loading checkpoint '{args.resume}'")
            checkpoint = torch.load(args.resume) if args.gpu is None else \
                torch.load(args.resume, map_location=torch.device(f'cuda:{str(args.gpu)}'))
            args.start_epoch = checkpoint['epoch']
            args.best_loss = checkpoint['best_loss']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(f"===> Loaded checkpoint '{args.resume}' (Epoch [{checkpoint['epoch']}])")
        else:
            print(f"===> No checkpoint found at '{args.resume}'")

    cudnn.benchmark = True

    for epoch in range(args.start_epoch, args.epoch):
        adjust_learning_rate(optimizer, epoch, args)
        train_loss = train(train_loader, model, optimizer, epoch)
        val_loss_mse, val_loss_l1, val_loss_gmean = validate(val_loader, model, train_labels=train_labels)

        loss_metric = val_loss_mse if args.loss == 'mse' else val_loss_l1
        is_best = loss_metric < args.best_loss
        args.best_loss = min(loss_metric, args.best_loss)
        print(f"Best {'L1' if 'l1' in args.loss else 'MSE'} Loss: {args.best_loss:.3f}")
        save_checkpoint(args, {
            'epoch': epoch + 1,
            'model': args.model,
            'best_loss': args.best_loss,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, is_best)
        print(f"Epoch #{epoch}: Train loss [{train_loss:.4f}]; "
              f"Val loss: MSE [{val_loss_mse:.4f}], L1 [{val_loss_l1:.4f}], G-Mean [{val_loss_gmean:.4f}]")

        tb_logger.log_value('train_loss', train_loss, epoch)
        tb_logger.log_value('val_loss_mse', val_loss_mse, epoch)
        tb_logger.log_value('val_loss_l1', val_loss_l1, epoch)
        tb_logger.log_value('val_loss_gmean', val_loss_gmean, epoch)


In [5]:
#train and test
main()

2022-11-30 18:18:54,155 | =====> Preparing data...
2022-11-30 18:18:54,155 | UTKFaces
2022-11-30 18:18:54,241 | Using re-weighting: [INVERSE]
2022-11-30 18:18:54,242 | Using LDS: [GAUSSIAN] (9/1)
2022-11-30 18:18:54,261 | worker = 16
2022-11-30 18:18:54,261 | Training data size: 18966
2022-11-30 18:18:54,262 | Validation data size: 2371
2022-11-30 18:18:54,262 | Test data size age < 18: 432
2022-11-30 18:18:54,262 | =====> Building model...
2022-11-30 18:19:29,052 | Epoch: [0][  0/297]	Time  34.52 ( 34.52)	Data 29.5915 (29.5915)	Loss (L1) 37.809 (37.809)
2022-11-30 18:19:31,169 | Epoch: [0][ 10/297]	Time   0.22 (  3.33)	Data 0.0000 (2.6903)	Loss (L1) 18.116 (41.372)
2022-11-30 18:19:33,394 | Epoch: [0][ 20/297]	Time   0.22 (  1.85)	Data 0.0000 (1.4092)	Loss (L1) 13.541 (34.737)
2022-11-30 18:19:35,628 | Epoch: [0][ 30/297]	Time   0.22 (  1.33)	Data 0.0010 (0.9547)	Loss (L1) 19.258 (33.083)
2022-11-30 18:19:37,851 | Epoch: [0][ 40/297]	Time   0.22 (  1.06)	Data 0.0000 (0.7219)	Loss (L1)

2022-11-30 18:23:18,667 | Val: [ 0/38]	Time 28.496 (28.496)	Loss (MSE) 2106.736 (2106.736)	Loss (L1) 38.371 (38.371)
2022-11-30 18:23:19,370 | Val: [10/38]	Time  0.063 ( 2.654)	Loss (MSE) 1606.832 (1674.452)	Loss (L1) 33.155 (34.066)
2022-11-30 18:23:20,022 | Val: [20/38]	Time  0.064 ( 1.421)	Loss (MSE) 1366.217 (1646.996)	Loss (L1) 30.064 (33.503)
2022-11-30 18:23:20,681 | Val: [30/38]	Time  0.067 ( 0.984)	Loss (MSE) 1273.193 (1630.146)	Loss (L1) 30.147 (33.326)
2022-11-30 18:23:21,743 |  * Overall: MSE 1635.779	L1 33.432	G-Mean 24.105
2022-11-30 18:23:21,744 |  * Many: MSE 1689.201	L1 34.279	G-Mean 25.332
2022-11-30 18:23:21,745 |  * Median: MSE 1147.687	L1 25.406	G-Mean 14.825
2022-11-30 18:23:21,745 |  * Low: MSE 704.745	L1 22.152	G-Mean 14.947
2022-11-30 18:23:21,746 | Best L1 Loss: 16.921
2022-11-30 18:23:21,923 | Epoch #1: Train loss [20.4424]; Val loss: MSE [1635.7790], L1 [33.4321], G-Mean [24.1047]
2022-11-30 18:23:51,109 | Epoch: [2][  0/297]	Time  29.19 ( 29.19)	Data 28.940

2022-11-30 18:26:56,968 | Epoch: [3][260/297]	Time   0.23 (  0.34)	Data 0.0000 (0.1116)	Loss (L1) 14.839 (16.777)
2022-11-30 18:26:59,222 | Epoch: [3][270/297]	Time   0.24 (  0.33)	Data 0.0000 (0.1075)	Loss (L1) 12.681 (16.775)
2022-11-30 18:27:01,447 | Epoch: [3][280/297]	Time   0.22 (  0.33)	Data 0.0000 (0.1037)	Loss (L1) 12.403 (16.732)
2022-11-30 18:27:03,685 | Epoch: [3][290/297]	Time   0.23 (  0.32)	Data 0.0000 (0.1001)	Loss (L1) 11.499 (16.614)
2022-11-30 18:27:33,017 | Val: [ 0/38]	Time 27.368 (27.368)	Loss (MSE) 335.162 (335.162)	Loss (L1) 15.277 (15.277)
2022-11-30 18:27:33,724 | Val: [10/38]	Time  0.064 ( 2.552)	Loss (MSE) 546.741 (477.940)	Loss (L1) 18.171 (17.229)
2022-11-30 18:27:34,370 | Val: [20/38]	Time  0.065 ( 1.368)	Loss (MSE) 414.046 (466.063)	Loss (L1) 15.986 (17.301)
2022-11-30 18:27:35,018 | Val: [30/38]	Time  0.064 ( 0.947)	Loss (MSE) 604.209 (470.482)	Loss (L1) 19.301 (17.402)
2022-11-30 18:27:36,072 |  * Overall: MSE 470.923	L1 17.464	G-Mean 11.933
2022-11-30

2022-11-30 18:31:04,936 | Epoch: [5][220/297]	Time   0.22 (  0.36)	Data 0.0000 (0.1378)	Loss (L1) 15.837 (15.153)
2022-11-30 18:31:07,168 | Epoch: [5][230/297]	Time   0.22 (  0.36)	Data 0.0000 (0.1319)	Loss (L1) 8.486 (15.022)
2022-11-30 18:31:09,404 | Epoch: [5][240/297]	Time   0.22 (  0.35)	Data 0.0000 (0.1264)	Loss (L1) 12.740 (15.053)
2022-11-30 18:31:11,642 | Epoch: [5][250/297]	Time   0.23 (  0.35)	Data 0.0000 (0.1214)	Loss (L1) 15.966 (14.966)
2022-11-30 18:31:13,877 | Epoch: [5][260/297]	Time   0.23 (  0.34)	Data 0.0000 (0.1168)	Loss (L1) 10.926 (14.895)
2022-11-30 18:31:16,139 | Epoch: [5][270/297]	Time   0.22 (  0.34)	Data 0.0000 (0.1124)	Loss (L1) 10.140 (14.804)
2022-11-30 18:31:18,392 | Epoch: [5][280/297]	Time   0.22 (  0.33)	Data 0.0000 (0.1085)	Loss (L1) 10.288 (14.940)
2022-11-30 18:31:20,631 | Epoch: [5][290/297]	Time   0.22 (  0.33)	Data 0.0000 (0.1047)	Loss (L1) 12.551 (14.966)
2022-11-30 18:31:50,494 | Val: [ 0/38]	Time 27.774 (27.774)	Loss (MSE) 248.681 (248.681)	

2022-11-30 18:35:13,455 | Epoch: [7][180/297]	Time   0.22 (  0.39)	Data 0.0000 (0.1595)	Loss (L1) 15.869 (14.421)
2022-11-30 18:35:15,768 | Epoch: [7][190/297]	Time   0.23 (  0.38)	Data 0.0000 (0.1512)	Loss (L1) 11.238 (14.378)
2022-11-30 18:35:18,025 | Epoch: [7][200/297]	Time   0.22 (  0.37)	Data 0.0000 (0.1437)	Loss (L1) 7.553 (14.292)
2022-11-30 18:35:20,271 | Epoch: [7][210/297]	Time   0.22 (  0.37)	Data 0.0000 (0.1369)	Loss (L1) 10.605 (14.258)
2022-11-30 18:35:22,513 | Epoch: [7][220/297]	Time   0.23 (  0.36)	Data 0.0000 (0.1307)	Loss (L1) 9.924 (14.178)
2022-11-30 18:35:24,752 | Epoch: [7][230/297]	Time   0.22 (  0.35)	Data 0.0000 (0.1250)	Loss (L1) 19.413 (14.151)
2022-11-30 18:35:26,996 | Epoch: [7][240/297]	Time   0.23 (  0.35)	Data 0.0000 (0.1198)	Loss (L1) 9.357 (14.129)
2022-11-30 18:35:29,237 | Epoch: [7][250/297]	Time   0.22 (  0.34)	Data 0.0000 (0.1151)	Loss (L1) 8.779 (14.064)
2022-11-30 18:35:31,482 | Epoch: [7][260/297]	Time   0.22 (  0.34)	Data 0.0000 (0.1107)	Loss

2022-11-30 18:39:21,858 | Epoch: [9][140/297]	Time   0.22 (  0.43)	Data 0.0000 (0.2011)	Loss (L1) 12.520 (14.511)
2022-11-30 18:39:24,093 | Epoch: [9][150/297]	Time   0.22 (  0.41)	Data 0.0000 (0.1878)	Loss (L1) 15.860 (14.330)
2022-11-30 18:39:26,324 | Epoch: [9][160/297]	Time   0.22 (  0.40)	Data 0.0010 (0.1761)	Loss (L1) 15.817 (14.236)
2022-11-30 18:39:28,557 | Epoch: [9][170/297]	Time   0.22 (  0.39)	Data 0.0010 (0.1658)	Loss (L1) 11.484 (14.055)
2022-11-30 18:39:30,796 | Epoch: [9][180/297]	Time   0.22 (  0.38)	Data 0.0000 (0.1567)	Loss (L1) 10.079 (13.892)
2022-11-30 18:39:33,029 | Epoch: [9][190/297]	Time   0.22 (  0.37)	Data 0.0000 (0.1485)	Loss (L1) 16.351 (13.782)
2022-11-30 18:39:35,263 | Epoch: [9][200/297]	Time   0.22 (  0.37)	Data 0.0000 (0.1411)	Loss (L1) 10.351 (13.684)
2022-11-30 18:39:37,500 | Epoch: [9][210/297]	Time   0.22 (  0.36)	Data 0.0000 (0.1344)	Loss (L1) 29.805 (13.610)
2022-11-30 18:39:39,737 | Epoch: [9][220/297]	Time   0.22 (  0.35)	Data 0.0000 (0.1283)	

2022-11-30 18:43:30,669 | Epoch: [11][100/297]	Time   0.22 (  0.52)	Data 0.0000 (0.2912)	Loss (L1) 8.597 (13.788)
2022-11-30 18:43:32,905 | Epoch: [11][110/297]	Time   0.22 (  0.49)	Data 0.0000 (0.2649)	Loss (L1) 11.564 (13.946)
2022-11-30 18:43:35,142 | Epoch: [11][120/297]	Time   0.22 (  0.47)	Data 0.0010 (0.2431)	Loss (L1) 13.708 (14.122)
2022-11-30 18:43:37,376 | Epoch: [11][130/297]	Time   0.22 (  0.45)	Data 0.0000 (0.2245)	Loss (L1) 8.476 (14.075)
2022-11-30 18:43:39,611 | Epoch: [11][140/297]	Time   0.22 (  0.43)	Data 0.0000 (0.2086)	Loss (L1) 8.428 (13.907)
2022-11-30 18:43:41,867 | Epoch: [11][150/297]	Time   0.24 (  0.42)	Data 0.0000 (0.1948)	Loss (L1) 13.631 (13.718)
2022-11-30 18:43:44,198 | Epoch: [11][160/297]	Time   0.23 (  0.41)	Data 0.0000 (0.1827)	Loss (L1) 24.676 (13.641)
2022-11-30 18:43:46,528 | Epoch: [11][170/297]	Time   0.23 (  0.40)	Data 0.0000 (0.1720)	Loss (L1) 9.260 (13.502)
2022-11-30 18:43:48,857 | Epoch: [11][180/297]	Time   0.23 (  0.39)	Data 0.0000 (0.1

2022-11-30 18:47:39,380 | Epoch: [13][ 60/297]	Time   0.22 (  0.72)	Data 0.0000 (0.4931)	Loss (L1) 12.245 (12.301)
2022-11-30 18:47:41,773 | Epoch: [13][ 70/297]	Time   0.25 (  0.66)	Data 0.0000 (0.4236)	Loss (L1) 11.881 (13.172)
2022-11-30 18:47:44,107 | Epoch: [13][ 80/297]	Time   0.23 (  0.60)	Data 0.0010 (0.3714)	Loss (L1) 11.258 (13.011)
2022-11-30 18:47:46,432 | Epoch: [13][ 90/297]	Time   0.23 (  0.56)	Data 0.0000 (0.3306)	Loss (L1) 30.925 (13.207)
2022-11-30 18:47:48,790 | Epoch: [13][100/297]	Time   0.26 (  0.53)	Data 0.0000 (0.2979)	Loss (L1) 9.554 (13.225)
2022-11-30 18:47:51,160 | Epoch: [13][110/297]	Time   0.22 (  0.50)	Data 0.0000 (0.2710)	Loss (L1) 11.774 (13.348)
2022-11-30 18:47:53,565 | Epoch: [13][120/297]	Time   0.22 (  0.48)	Data 0.0000 (0.2487)	Loss (L1) 9.469 (13.193)
2022-11-30 18:47:55,803 | Epoch: [13][130/297]	Time   0.22 (  0.46)	Data 0.0000 (0.2297)	Loss (L1) 8.927 (13.269)
2022-11-30 18:47:58,045 | Epoch: [13][140/297]	Time   0.23 (  0.45)	Data 0.0000 (0.

2022-11-30 18:51:48,433 | Epoch: [15][ 10/297]	Time   0.22 (  2.82)	Data 0.0000 (2.5856)	Loss (L1) 9.209 (11.505)
2022-11-30 18:51:50,663 | Epoch: [15][ 20/297]	Time   0.22 (  1.58)	Data 0.0000 (1.3545)	Loss (L1) 8.221 (11.401)
2022-11-30 18:51:52,893 | Epoch: [15][ 30/297]	Time   0.22 (  1.14)	Data 0.0000 (0.9176)	Loss (L1) 10.404 (11.335)
2022-11-30 18:51:55,126 | Epoch: [15][ 40/297]	Time   0.21 (  0.92)	Data 0.0000 (0.6938)	Loss (L1) 7.092 (10.933)
2022-11-30 18:51:57,361 | Epoch: [15][ 50/297]	Time   0.23 (  0.78)	Data 0.0000 (0.5578)	Loss (L1) 7.868 (10.717)
2022-11-30 18:51:59,599 | Epoch: [15][ 60/297]	Time   0.22 (  0.69)	Data 0.0000 (0.4664)	Loss (L1) 9.433 (10.680)
2022-11-30 18:52:01,831 | Epoch: [15][ 70/297]	Time   0.22 (  0.63)	Data 0.0000 (0.4007)	Loss (L1) 15.869 (11.005)
2022-11-30 18:52:04,105 | Epoch: [15][ 80/297]	Time   0.22 (  0.58)	Data 0.0000 (0.3513)	Loss (L1) 7.737 (10.838)
2022-11-30 18:52:06,332 | Epoch: [15][ 90/297]	Time   0.22 (  0.54)	Data 0.0000 (0.312

2022-11-30 18:55:36,833 |  * Many: MSE 251.005	L1 12.173	G-Mean 7.750
2022-11-30 18:55:36,834 |  * Median: MSE 670.264	L1 19.878	G-Mean 12.492
2022-11-30 18:55:36,834 |  * Low: MSE 873.420	L1 25.801	G-Mean 19.450
2022-11-30 18:55:36,836 | Best L1 Loss: 10.180
2022-11-30 18:55:37,015 | Epoch #16: Train loss [11.8425]; Val loss: MSE [291.3638], L1 [12.9300], G-Mean [8.1264]
2022-11-30 18:56:07,951 | Epoch: [17][  0/297]	Time  30.93 ( 30.93)	Data 30.7209 (30.7209)	Loss (L1) 7.372 (7.372)
2022-11-30 18:56:10,386 | Epoch: [17][ 10/297]	Time   0.22 (  3.03)	Data 0.0010 (2.7931)	Loss (L1) 9.192 (9.485)
2022-11-30 18:56:12,740 | Epoch: [17][ 20/297]	Time   0.25 (  1.70)	Data 0.0010 (1.4631)	Loss (L1) 11.832 (9.542)
2022-11-30 18:56:15,130 | Epoch: [17][ 30/297]	Time   0.23 (  1.23)	Data 0.0000 (0.9912)	Loss (L1) 11.785 (9.956)
2022-11-30 18:56:17,386 | Epoch: [17][ 40/297]	Time   0.23 (  0.98)	Data 0.0010 (0.7495)	Loss (L1) 7.365 (10.465)
2022-11-30 18:56:19,646 | Epoch: [17][ 50/297]	Time   0

2022-11-30 18:59:54,472 | Val: [10/38]	Time  0.065 ( 2.596)	Loss (MSE) 220.741 (210.431)	Loss (L1) 11.684 (10.798)
2022-11-30 18:59:55,126 | Val: [20/38]	Time  0.066 ( 1.391)	Loss (MSE) 234.000 (209.902)	Loss (L1) 11.094 (10.916)
2022-11-30 18:59:55,781 | Val: [30/38]	Time  0.065 ( 0.963)	Loss (MSE) 162.800 (204.407)	Loss (L1) 9.074 (10.738)
2022-11-30 18:59:56,968 |  * Overall: MSE 199.538	L1 10.638	G-Mean 6.676
2022-11-30 18:59:56,968 |  * Many: MSE 197.827	L1 10.581	G-Mean 6.604
2022-11-30 18:59:56,969 |  * Median: MSE 217.744	L1 11.166	G-Mean 7.345
2022-11-30 18:59:56,970 |  * Low: MSE 198.687	L1 11.465	G-Mean 8.500
2022-11-30 18:59:56,972 | Best L1 Loss: 10.180
2022-11-30 18:59:57,197 | Epoch #18: Train loss [11.5572]; Val loss: MSE [199.5384], L1 [10.6377], G-Mean [6.6764]
2022-11-30 19:00:28,269 | Epoch: [19][  0/297]	Time  31.07 ( 31.07)	Data 30.7664 (30.7664)	Loss (L1) 10.162 (10.162)
2022-11-30 19:00:30,513 | Epoch: [19][ 10/297]	Time   0.22 (  3.03)	Data 0.0010 (2.7974)	Loss

2022-11-30 19:03:35,834 | Epoch: [20][270/297]	Time   0.22 (  0.33)	Data 0.0000 (0.1034)	Loss (L1) 12.302 (10.943)
2022-11-30 19:03:38,070 | Epoch: [20][280/297]	Time   0.22 (  0.32)	Data 0.0000 (0.0997)	Loss (L1) 9.841 (11.006)
2022-11-30 19:03:40,403 | Epoch: [20][290/297]	Time   0.30 (  0.32)	Data 0.0000 (0.0963)	Loss (L1) 9.078 (11.105)
2022-11-30 19:04:11,716 | Val: [ 0/38]	Time 29.201 (29.201)	Loss (MSE) 296.440 (296.440)	Loss (L1) 14.448 (14.448)
2022-11-30 19:04:12,415 | Val: [10/38]	Time  0.063 ( 2.718)	Loss (MSE) 446.120 (352.405)	Loss (L1) 15.906 (14.571)
2022-11-30 19:04:13,061 | Val: [20/38]	Time  0.065 ( 1.455)	Loss (MSE) 525.347 (345.750)	Loss (L1) 17.390 (14.446)
2022-11-30 19:04:13,704 | Val: [30/38]	Time  0.063 ( 1.006)	Loss (MSE) 293.975 (343.783)	Loss (L1) 12.926 (14.333)
2022-11-30 19:04:14,766 |  * Overall: MSE 339.891	L1 14.200	G-Mean 9.179
2022-11-30 19:04:14,766 |  * Many: MSE 337.766	L1 14.176	G-Mean 9.208
2022-11-30 19:04:14,767 |  * Median: MSE 356.602	L1 14

2022-11-30 19:07:42,717 | Epoch: [22][230/297]	Time   0.22 (  0.35)	Data 0.0000 (0.1216)	Loss (L1) 12.281 (10.955)
2022-11-30 19:07:44,953 | Epoch: [22][240/297]	Time   0.22 (  0.34)	Data 0.0000 (0.1166)	Loss (L1) 16.880 (11.084)
2022-11-30 19:07:47,185 | Epoch: [22][250/297]	Time   0.22 (  0.34)	Data 0.0000 (0.1120)	Loss (L1) 12.339 (11.145)
2022-11-30 19:07:49,422 | Epoch: [22][260/297]	Time   0.22 (  0.33)	Data 0.0000 (0.1077)	Loss (L1) 8.472 (11.101)
2022-11-30 19:07:51,662 | Epoch: [22][270/297]	Time   0.22 (  0.33)	Data 0.0000 (0.1037)	Loss (L1) 12.144 (11.232)
2022-11-30 19:07:54,005 | Epoch: [22][280/297]	Time   0.23 (  0.32)	Data 0.0000 (0.1000)	Loss (L1) 11.793 (11.249)
2022-11-30 19:07:56,376 | Epoch: [22][290/297]	Time   0.22 (  0.32)	Data 0.0000 (0.0966)	Loss (L1) 11.440 (11.269)
2022-11-30 19:08:25,861 | Val: [ 0/38]	Time 27.418 (27.418)	Loss (MSE) 160.274 (160.274)	Loss (L1) 9.429 (9.429)
2022-11-30 19:08:26,575 | Val: [10/38]	Time  0.065 ( 2.558)	Loss (MSE) 200.932 (241

2022-11-30 19:11:57,471 | Epoch: [24][190/297]	Time   0.30 (  0.42)	Data 0.0000 (0.1560)	Loss (L1) 9.375 (11.115)
2022-11-30 19:12:00,413 | Epoch: [24][200/297]	Time   0.30 (  0.41)	Data 0.0000 (0.1483)	Loss (L1) 9.121 (11.079)
2022-11-30 19:12:03,344 | Epoch: [24][210/297]	Time   0.29 (  0.41)	Data 0.0000 (0.1413)	Loss (L1) 7.306 (11.031)
2022-11-30 19:12:06,286 | Epoch: [24][220/297]	Time   0.29 (  0.40)	Data 0.0000 (0.1349)	Loss (L1) 7.656 (11.014)
2022-11-30 19:12:09,217 | Epoch: [24][230/297]	Time   0.30 (  0.40)	Data 0.0000 (0.1290)	Loss (L1) 10.346 (11.002)
2022-11-30 19:12:12,152 | Epoch: [24][240/297]	Time   0.29 (  0.39)	Data 0.0000 (0.1237)	Loss (L1) 8.434 (11.057)
2022-11-30 19:12:15,071 | Epoch: [24][250/297]	Time   0.29 (  0.39)	Data 0.0000 (0.1188)	Loss (L1) 7.799 (11.039)
2022-11-30 19:12:17,995 | Epoch: [24][260/297]	Time   0.29 (  0.38)	Data 0.0000 (0.1142)	Loss (L1) 7.518 (11.039)
2022-11-30 19:12:20,919 | Epoch: [24][270/297]	Time   0.29 (  0.38)	Data 0.0000 (0.1100

2022-11-30 19:16:23,580 | Epoch: [26][150/297]	Time   0.24 (  0.44)	Data 0.0000 (0.2001)	Loss (L1) 16.965 (10.893)
2022-11-30 19:16:25,863 | Epoch: [26][160/297]	Time   0.22 (  0.42)	Data 0.0000 (0.1877)	Loss (L1) 10.499 (11.003)
2022-11-30 19:16:28,101 | Epoch: [26][170/297]	Time   0.22 (  0.41)	Data 0.0000 (0.1767)	Loss (L1) 10.981 (11.044)
2022-11-30 19:16:30,334 | Epoch: [26][180/297]	Time   0.22 (  0.40)	Data 0.0010 (0.1670)	Loss (L1) 13.140 (11.057)
2022-11-30 19:16:32,567 | Epoch: [26][190/297]	Time   0.22 (  0.39)	Data 0.0010 (0.1582)	Loss (L1) 11.876 (11.004)
2022-11-30 19:16:34,804 | Epoch: [26][200/297]	Time   0.22 (  0.38)	Data 0.0000 (0.1504)	Loss (L1) 9.249 (11.010)
2022-11-30 19:16:37,040 | Epoch: [26][210/297]	Time   0.23 (  0.38)	Data 0.0000 (0.1433)	Loss (L1) 8.535 (11.049)
2022-11-30 19:16:39,276 | Epoch: [26][220/297]	Time   0.22 (  0.37)	Data 0.0000 (0.1368)	Loss (L1) 10.616 (11.089)
2022-11-30 19:16:41,512 | Epoch: [26][230/297]	Time   0.22 (  0.36)	Data 0.0000 (0

2022-11-30 19:20:30,275 | Epoch: [28][110/297]	Time   0.23 (  0.49)	Data 0.0000 (0.2616)	Loss (L1) 5.043 (10.422)
2022-11-30 19:20:32,676 | Epoch: [28][120/297]	Time   0.23 (  0.47)	Data 0.0000 (0.2400)	Loss (L1) 10.762 (10.370)
2022-11-30 19:20:35,014 | Epoch: [28][130/297]	Time   0.23 (  0.45)	Data 0.0000 (0.2217)	Loss (L1) 8.699 (10.296)
2022-11-30 19:20:37,261 | Epoch: [28][140/297]	Time   0.22 (  0.43)	Data 0.0000 (0.2060)	Loss (L1) 9.454 (10.439)
2022-11-30 19:20:39,532 | Epoch: [28][150/297]	Time   0.22 (  0.42)	Data 0.0000 (0.1924)	Loss (L1) 8.275 (10.400)
2022-11-30 19:20:41,778 | Epoch: [28][160/297]	Time   0.23 (  0.41)	Data 0.0000 (0.1804)	Loss (L1) 8.285 (10.405)
2022-11-30 19:20:44,025 | Epoch: [28][170/297]	Time   0.22 (  0.40)	Data 0.0000 (0.1699)	Loss (L1) 6.396 (10.361)
2022-11-30 19:20:46,265 | Epoch: [28][180/297]	Time   0.22 (  0.39)	Data 0.0000 (0.1605)	Loss (L1) 6.270 (10.245)
2022-11-30 19:20:48,508 | Epoch: [28][190/297]	Time   0.23 (  0.38)	Data 0.0000 (0.1521

ValueError: need at least one array to concatenate

In [None]:
print(torch.cuda.is_available())

In [6]:
#test only
package = None
#package = 'agedb_resnet50_fds_gau_9_1_0_1_0.9_adam_l1_0.001_64'
#package = 'agedb_resnet50_adam_l1_0.001_64'
#package = 'agedb_resnet50_lds_gau_9_1_adam_l1_0.001_64'
if args.gpu is not None:
    print(f"Use GPU: {args.gpu} for training")

# Data
print('=====> Preparing data...')
print("UTKFaces")
df = reload_data()

split_size1=0.6
split_size2=0.4
shuffle_dataset = True
random_seed = 42

df_train = df.sample(frac=split_size1, random_state=random_seed)
test_val_df = df.drop(df_train.index)
df_val = test_val_df.sample(frac=split_size2, random_state=random_seed)
df_test = test_val_df.drop(df_val.index)
df_test_18, df_test_80 = df[df['age'] <18],df[df['age'] >=80]

train_labels = df_train['age']


train_dataset = AgeDB(data_dir=PATH, df=df_train, img_size=args.img_size, split='train',
                      reweight=args.reweight, lds=args.lds, lds_kernel=args.lds_kernel, lds_ks=args.lds_ks, lds_sigma=args.lds_sigma)
val_dataset = AgeDB(data_dir=PATH, df=df_val, img_size=args.img_size, split='val')
test_dataset = AgeDB(data_dir=PATH, df=df_test, img_size=args.img_size, split='test')
test_dataset_18 = AgeDB(data_dir=PATH, df=df_test_18, img_size=args.img_size, split='test')
test_dataset_80 = AgeDB(data_dir=PATH, df=df_test_80, img_size=args.img_size, split='test')

print('worker = ' +str(args.workers) )
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                          num_workers=args.workers, pin_memory=True, drop_last=False)
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                        num_workers=args.workers, pin_memory=True, drop_last=False)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
                         num_workers=args.workers, pin_memory=True, drop_last=False)

test_loader_18 = DataLoader(test_dataset_18, batch_size=args.batch_size, shuffle=False,
                         num_workers=args.workers, pin_memory=True, drop_last=False)
test_loader_80 = DataLoader(test_dataset_80, batch_size=args.batch_size, shuffle=False,
                         num_workers=args.workers, pin_memory=True, drop_last=False)

print(f"Training data size: {len(train_dataset)}")
print(f"Validation data size: {len(val_dataset)}")
print(f"Test data size: {len(test_dataset)}")
print(f"Test data size age < 18: {len(test_dataset_18)}")
print(f"Test data size age >= 80: {len(test_dataset_80)}")

# Model
print('=====> Building model...')

model = resnet50(fds=args.fds, bucket_num=args.bucket_num, bucket_start=args.bucket_start,
                 start_update=args.start_update, start_smooth=args.start_smooth,
                 kernel=args.fds_kernel, ks=args.fds_ks, sigma=args.fds_sigma, momentum=args.fds_mmt)
"""
model=resnet50(fds=False, bucket_num=100, bucket_start=3,
                 start_update=0, start_smooth=1,
                 kernel='gaussian', ks=9, sigma=1, momentum=0.9)
"""
model = torch.nn.DataParallel(model).cuda()

# evaluate only
if args.evaluate:
    assert args.resume, 'Specify a trained model using [args.resume]'
    checkpoint = torch.load(args.resume)
    model.load_state_dict(checkpoint['state_dict'], strict=False)
    print(f"===> Checkpoint '{args.resume}' loaded (epoch [{checkpoint['epoch']}]), testing...")
    validate(test_loader, model, train_labels=train_labels, prefix='Test')


if args.retrain_fc:
    assert args.reweight != 'none' and args.pretrained
    print('===> Retrain last regression layer only!')
    for name, param in model.named_parameters():
        if 'fc' not in name and 'linear' not in name:
            param.requires_grad = False

# Loss and optimizer
if not args.retrain_fc:
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) if args.optimizer == 'adam' else \
        torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
else:
    # optimize only the last linear layer
    parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
    names = list(filter(lambda k: k is not None, [k if v.requires_grad else None for k, v in model.module.named_parameters()]))
    assert 1 <= len(parameters) <= 2  # fc.weight, fc.bias
    print(f'===> Only optimize parameters: {names}')
    optimizer = torch.optim.Adam(parameters, lr=args.lr) if args.optimizer == 'adam' else \
        torch.optim.SGD(parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

if args.pretrained:
    checkpoint = torch.load(args.pretrained, map_location="cpu")
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in checkpoint['state_dict'].items():
        if 'linear' not in k and 'fc' not in k:
            new_state_dict[k] = v
    model.load_state_dict(new_state_dict, strict=False)
    print(f'===> Pretrained weights found in total: [{len(list(new_state_dict.keys()))}]')
    print(f'===> Pre-trained model loaded: {args.pretrained}')

if args.resume:
    if os.path.isfile(args.resume):
        print(f"===> Loading checkpoint '{args.resume}'")
        checkpoint = torch.load(args.resume) if args.gpu is None else \
            torch.load(args.resume, map_location=torch.device(f'cuda:{str(args.gpu)}'))
        args.start_epoch = checkpoint['epoch']
        args.best_loss = checkpoint['best_loss']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print(f"===> Loaded checkpoint '{args.resume}' (Epoch [{checkpoint['epoch']}])")
    else:
        print(f"===> No checkpoint found at '{args.resume}'")

cudnn.benchmark = True



# test with best checkpoint
print("=" * 120)
print("Test best model on testset...")
checkpoint = torch.load(f"{args.store_root}/{package if package!=None else args.store_name}/ckpt.best.pth.tar")
model.load_state_dict(checkpoint['state_dict'])
print(f"Loaded best model, epoch {checkpoint['epoch']}, best val loss {checkpoint['best_loss']:.4f}")
print("TestSet:")
test_loss_mse, test_loss_l1, test_loss_gmean = validate(test_loader, model, train_labels=train_labels, prefix='Test')
print(f"Test loss: MSE [{test_loss_mse:.4f}], L1 [{test_loss_l1:.4f}], G-Mean [{test_loss_gmean:.4f}]\n")
print("TestSet for age < 18:")
test_loss_mse_18, test_loss_l1_18, test_loss_gmean_18 = validate(test_loader_18, model, train_labels=train_labels, prefix='Test',mode = 'low')
print(f"Test loss: MSE [{test_loss_mse_18:.4f}], L1 [{test_loss_l1_18:.4f}], G-Mean [{test_loss_gmean_18:.4f}]\n")
print("TestSet for age >= 80:")
test_loss_mse_80, test_loss_l1_80, test_loss_gmean_80 = validate(test_loader_80, model, train_labels=train_labels, prefix='Test',mode = 'high')
print(f"Test loss: MSE [{test_loss_mse_80:.4f}], L1 [{test_loss_l1_80:.4f}], G-Mean [{test_loss_gmean_80:.4f}]\nDone")

2022-11-30 19:27:00,074 | =====> Preparing data...
2022-11-30 19:27:00,074 | UTKFaces
2022-11-30 19:27:00,154 | Using re-weighting: [INVERSE]
2022-11-30 19:27:00,154 | Using LDS: [GAUSSIAN] (9/1)
2022-11-30 19:27:00,173 | worker = 16
2022-11-30 19:27:00,173 | Training data size: 14225
2022-11-30 19:27:00,174 | Validation data size: 3793
2022-11-30 19:27:00,174 | Test data size: 5690
2022-11-30 19:27:00,175 | Test data size age < 18: 4233
2022-11-30 19:27:00,175 | Test data size age >= 80: 673
2022-11-30 19:27:00,176 | =====> Building model...
2022-11-30 19:27:00,325 | Test best model on testset...
2022-11-30 19:27:00,502 | Loaded best model, epoch 27, best val loss 9.6899
2022-11-30 19:27:00,502 | TestSet:
2022-11-30 19:27:28,553 | Test: [ 0/89]	Time 28.049 (28.049)	Loss (MSE) 220.430 (220.430)	Loss (L1) 9.582 (9.582)
2022-11-30 19:27:29,261 | Test: [10/89]	Time  0.070 ( 2.614)	Loss (MSE) 92.499 (145.710)	Loss (L1) 7.705 (7.570)
2022-11-30 19:27:29,947 | Test: [20/89]	Time  0.069 ( 1.4