In [None]:
##############################################################################
#	Project		:	Age Estimation
#	Pipeline	:	E2ePipeline3
#	Date		:	1.11.2023
# 	Description	: 	Bias Analysis
##############################################################################

import shutil

# importing the sys module
import sys        
 
# appending the directory of mod.py
# in the sys.path list
sys.path.append('../')   

import json
import pickle
import os
import random
from collections import defaultdict

import torch
from PIL import Image
from torch import nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import models
from torch.utils.data import DataLoader, WeightedRandomSampler
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torch.optim import lr_scheduler
from torch.optim import Adam
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn import utils
from datetime import datetime
from Common.Datasets.Morph2.data_parser import DataParser
from tqdm import tqdm

from Common.Optimizers.RangerLars import RangerLars
from Common.Schedulers.GradualWarmupScheduler import GradualWarmupScheduler
from Common.Analysis.GeneralMethods import get_statistics
from Common.Datasets.CACD.CacdDataParser import CacdDataParser
from Common.Datasets.Morph2.dataset_utils import *

import ep3_config as cfg
from ep3_dataset import QueryAndMultiAgeRefsDataset
from ep3_model import DiffBasedAgeDetectionModel
from ep3_train import train




#####################################################
#           Preparations
#####################################################

# Seed every random number generator used by the pipeline with the same
# configured value so that repeated runs are comparable.
random.seed(cfg.RANDOM_SEED)
np.random.seed(cfg.RANDOM_SEED)
torch.manual_seed(cfg.RANDOM_SEED)

# Use the first CUDA device only when the config asks for a GPU *and* one is
# actually available; otherwise fall back to the CPU.
device = torch.device(
    "cuda:0" if (cfg.USE_GPU and torch.cuda.is_available()) else "cpu"
)
print(device)

# Release cached GPU memory left over from earlier work in this process.
torch.cuda.empty_cache()

#####################################################
#           Data Loading
#####################################################

# Load data
# Load the raw train/test images and metadata for the configured dataset.
print(f"Dataset: {cfg.DATASET_SELECT}")
print("Reading dataset...")

if cfg.DATASET_SELECT == "Morph2":
    data_parser = DataParser(cfg.MORPH2_DATASET_PATH, small_data=cfg.SMALL_DATA)
    data_parser.initialize_data()
    # Only the Morph2 parser exposes the chosen sample indexes; they are used
    # later to sub-select the precomputed embeddings in small-data mode.
    chosen_idxs_trn = data_parser.chosen_idxs_trn
    chosen_idxs_tst = data_parser.chosen_idxs_tst
elif cfg.DATASET_SELECT == "CACD":
    data_parser = CacdDataParser(cfg.CACD_DATASET_PATH)
    data_parser.initialize_data()
else:
    # Fail here with a clear message instead of a NameError further down.
    raise ValueError(
        f"Unsupported cfg.DATASET_SELECT: {cfg.DATASET_SELECT!r} "
        "(expected 'Morph2' or 'CACD')"
    )

x_train, y_train = data_parser.x_train, data_parser.y_train
x_test, y_test = data_parser.x_test, data_parser.y_test

if cfg.RANDOM_SPLIT:
    # Pool the parser's train and test parts and redraw the split with the
    # configured seed and test fraction.
    # NOTE(review): the precomputed embeddings and the dist/isol index files
    # used later are not re-split here - confirm they match this split.
    all_images = np.concatenate((x_train, x_test), axis=0)
    all_labels = np.concatenate((y_train, y_test), axis=0)

    x_train, x_test, y_train, y_test = train_test_split(
        all_images,
        all_labels,
        test_size=cfg.TEST_SIZE_FOR_RS,
        random_state=cfg.RANDOM_SEED,
    )

#####################################################
#           Metadata Loading
#####################################################

# Embeddings (precomputed face-recognition embeddings for the train and
# validation/test parts of the selected dataset).
# NOTE: allow_pickle=True lets NumPy unpickle arbitrary Python objects; only
# load embedding files that come from a trusted source.
face2emb_arr_trn_r = np.load(f'{cfg.DATASET_SELECT}_face2emb_arr_trn_recog.npy', allow_pickle=True)
face2emb_arr_vld_r = np.load(f'{cfg.DATASET_SELECT}_face2emb_arr_vld_recog.npy', allow_pickle=True)

if cfg.SMALL_DATA:
    if cfg.APPLY_TRAIN_SET_SPLIT_FOR_DIST_AND_ISOL or cfg.APPLY_TEST_SET_SPLIT_FOR_DIST_AND_ISOL:
        print("Unsupported modes from small data. Please cancel these modes and rerun. Aborting")
        # sys.exit(1) instead of the interactive-only exit(): always available
        # and reports a non-zero status for this error path.
        sys.exit(1)

    if cfg.DATASET_SELECT == "Morph2":
        # Keep only the embeddings of the samples the parser actually chose.
        face2emb_arr_trn_r = face2emb_arr_trn_r[chosen_idxs_trn]
        face2emb_arr_vld_r = face2emb_arr_vld_r[chosen_idxs_tst]

# Base model inference results loading (JSON files named by the config).
with open(cfg.INPUT_ESTIMATION_FILE_NAME_TEST, 'r') as im2age_map_test_f:
    im2age_map_test = json.load(im2age_map_test_f)

with open(cfg.INPUT_ESTIMATION_FILE_NAME_TRAIN, 'r') as im2age_map_train_and_dist_f:
    im2age_map_train_and_dist = json.load(im2age_map_train_and_dist_f)
                  

"""
Creating these now for the next stages:

train actual
test actual 
dist
embeddings train actual
embeddings test actual
map test
base_model_err_dist_on_non_trained_set
"""


# Select the "actual" evaluation data, embeddings and base-model age map for
# the following stages, depending on which dist/isol split mode is enabled.
# NOTE(review): the branches are mutually exclusive - when both split flags
# are set only the train split is applied. Confirm this is intended.
if cfg.APPLY_TRAIN_SET_SPLIT_FOR_DIST_AND_ISOL:
    print("applying dist and isol train sets split")

    # Load the dist / isolated-train index lists from the indexes directory.
    with open(f'{cfg.INDEXES_SAVE_DIR}/{cfg.DATASET_SELECT}_dist_indexes.pkl', 'rb') as f_dist_indexes:
        dist_indexes = pickle.load(f_dist_indexes)
    with open(f'{cfg.INDEXES_SAVE_DIR}/{cfg.DATASET_SELECT}_isolated_train_indexes.pkl', 'rb') as f_isolated_train_indexes:
        isolated_train_indexes = pickle.load(f_isolated_train_indexes)

    print(f"Original train set size: {len(data_parser.y_train)}")

    # Split the parser's original train set into its dist and isolated parts.
    (
        x_train_dist,
        y_train_dist,
        im2age_map_train_dist,
        x_train_isol,
        y_train_isol,
        im2age_map_train_isol,
    ) = gen_dist_and_isol_test_sets(
        x_src_dataset=data_parser.x_train,
        y_src_dataset=data_parser.y_train,
        im2age_map_src_dataset_orig=im2age_map_train_and_dist,
        dist_indexes=dist_indexes,
        isolated_src_dataset_indexed=isolated_train_indexes,
    )

    print(f"Actual train set size: {len(y_train_isol)}")

    # Base-model statistics computed on the dist part of the train set.
    # (Optional MAE print-outs for the dist/train/test sets are disabled.)
    base_model_err_dist_on_non_trained_set = get_statistics(
        dataset_metadata=y_train_dist,
        dataset_indexes=list(range(len(y_train_dist))),
        im2age_map_batst=im2age_map_train_dist,
    )

    # Outcome to next stages: isolated train part, untouched test part.
    face2emb_arr_trn_r_actual = face2emb_arr_trn_r[isolated_train_indexes]
    face2emb_arr_vld_r_actual = face2emb_arr_vld_r
    x_train_actual, y_train_actual = x_train_isol, y_train_isol
    x_test_actual, y_test_actual = x_test, y_test
    im2age_map_test_actual = im2age_map_test
elif cfg.APPLY_TEST_SET_SPLIT_FOR_DIST_AND_ISOL:
    print("applying dist and isol test sets split")

    # Load the dist / isolated-test index lists (paths are relative to the
    # working directory here, unlike the train branch).
    with open(f'{cfg.DATASET_SELECT}_dist_indexes.pkl', 'rb') as f_dist_indexes:
        dist_indexes = pickle.load(f_dist_indexes)
    with open(f'{cfg.DATASET_SELECT}_isolated_test_indexed.pkl', 'rb') as f_isolated_test_indexed:
        isolated_test_indexed = pickle.load(f_isolated_test_indexed)

    (
        x_test_dist,
        y_test_dist,
        im2age_map_dist,
        x_test_isol,
        y_test_isol,
        im2age_map_isol,
    ) = gen_dist_and_isol_test_sets(x_test, y_test, im2age_map_test, dist_indexes, isolated_test_indexed)

    face2emb_arr_vld_r_actual = face2emb_arr_vld_r[isolated_test_indexed]

    # Base-model statistics computed on the dist part of the test set.
    test_err_distribution = get_statistics(
        dataset_metadata=y_test_dist,
        dataset_indexes=list(range(len(y_test_dist))),
        im2age_map_batst=im2age_map_dist,
    )

    # Outcome to next stages: only the isolated test part is evaluated.
    x_test_actual, y_test_actual = x_test_isol, y_test_isol
    im2age_map_test_actual = im2age_map_isol
else:
    print("NOT applying dist and isol test sets split")
    face2emb_arr_vld_r_actual = face2emb_arr_vld_r

    # Base-model statistics computed on the whole test set.
    test_err_distribution = get_statistics(
        dataset_metadata=y_test,
        dataset_indexes=list(range(len(y_test))),
        im2age_map_batst=im2age_map_test,
    )

    # Outcome to next stages: the test set is used unchanged.
    x_test_actual, y_test_actual = x_test, y_test
    im2age_map_test_actual = im2age_map_test



#####################################################
#           Dataset Creation
#####################################################


# Evaluation-time preprocessing: resize, convert to tensor, then normalise
# (the mean/std values are the commonly used ImageNet statistics).
transf_tst = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Constructor arguments that are identical for every dataset built below;
# only the query ("base") set, the reference set and the age map differ.
_shared_ds_kwargs = dict(
    min_age=cfg.MIN_AGE,
    max_age=cfg.MAX_AGE,
    age_interval=cfg.AGE_INTERVAL,
    transform=transf_tst,
    num_references=cfg.NUM_REFERENCES,
    embeddings_knn=cfg.EMBEDDINGS_KNN,
    dataset_size_factor=cfg.DATASET_SIZE_FACTOR,
    base_set_is_ref_set=False,
    disable_same_ref_being_query=False,
    knn_reduced_pool_size=cfg.KNN_REDUCED_POOL_SIZE,
    sample_knn_reduced_pool=True,
    base_model_distribution=None,
    mode_select="apply_map",
)

# Train set
if cfg.APPLY_TRAIN_SET_SPLIT_FOR_DIST_AND_ISOL:
    # The original train set consists of the actual train set plus the dist
    # set. The dist set is fully isolated from training, so it is treated
    # like a test set: queries come from the original train set while the
    # references (and their embeddings) come from the actual train set only.
    train_ds = QueryAndMultiAgeRefsDataset(
        base_data_set_images=x_train,
        base_data_set_metadata=y_train,
        base_data_set_embeddings=face2emb_arr_trn_r,
        ref_data_set_images=x_train_actual,
        ref_data_set_metadata=y_train_actual,
        ref_data_set_embeddings=face2emb_arr_trn_r_actual,
        im2age_map=im2age_map_train_and_dist,
        **_shared_ds_kwargs,
    )

    print("Train+Dist (q vld, r trn) set size: " + str(len(train_ds)))

# Test set: queries from the full test set, references from the train set.
test_ds = QueryAndMultiAgeRefsDataset(
    base_data_set_images=x_test,
    base_data_set_metadata=y_test,
    base_data_set_embeddings=face2emb_arr_vld_r,
    ref_data_set_images=x_train,
    ref_data_set_metadata=y_train,
    ref_data_set_embeddings=face2emb_arr_trn_r,
    im2age_map=im2age_map_test,
    **_shared_ds_kwargs,
)

print("Testing (q vld, r trn) set size: " + str(len(test_ds)))

if cfg.APPLY_TEST_SET_SPLIT_FOR_DIST_AND_ISOL:
    # Same as the test set, but queries are the "actual" test data selected
    # by the split stage.
    test_isol_ds = QueryAndMultiAgeRefsDataset(
        base_data_set_images=x_test_actual,
        base_data_set_metadata=y_test_actual,
        base_data_set_embeddings=face2emb_arr_vld_r_actual,
        ref_data_set_images=x_train,
        ref_data_set_metadata=y_train,
        ref_data_set_embeddings=face2emb_arr_trn_r,
        im2age_map=im2age_map_test_actual,
        **_shared_ds_kwargs,
    )

    print("Testing isolated (q vld, r trn) set size: " + str(len(test_isol_ds)))


# Registries keyed by split name: datasets, their sizes and their loaders.
# Every loader yields one sample at a time, in dataset order.
image_datasets = {'val_full': test_ds}
dataset_sizes = {split: len(ds) for split, ds in image_datasets.items()}
data_loaders = {
    'val_full': DataLoader(
        test_ds,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        num_workers=cfg.NUM_OF_WORKERS_DATALOADER,
        pin_memory=True,
    ),
}

# Optional splits are registered only when the matching mode is enabled.
if cfg.APPLY_TEST_SET_SPLIT_FOR_DIST_AND_ISOL:
    image_datasets['val_isol'] = test_isol_ds
    dataset_sizes['val_isol'] = len(test_isol_ds)
    data_loaders['val_isol'] = DataLoader(
        test_isol_ds,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        num_workers=cfg.NUM_OF_WORKERS_DATALOADER,
        pin_memory=True,
    )

if cfg.APPLY_TRAIN_SET_SPLIT_FOR_DIST_AND_ISOL:
    image_datasets['train'] = train_ds
    dataset_sizes['train'] = len(train_ds)
    data_loaders['train'] = DataLoader(
        train_ds,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        num_workers=cfg.NUM_OF_WORKERS_DATALOADER,
        pin_memory=True,
    )

# Range of age-difference classes: either the configured saturation range or
# the full span between the minimum and maximum age.
if cfg.DIST_APPROX_METHOD == "kde_based_saturated":
    min_age_diff, max_age_diff = cfg.ERROR_SAT_RANGE_MIN, cfg.ERROR_SAT_RANGE_MAX
else:
    min_age_diff, max_age_diff = cfg.MIN_AGE - cfg.MAX_AGE, cfg.MAX_AGE - cfg.MIN_AGE

# Both ends are inclusive, hence the +1.
num_classes_diff = max_age_diff - min_age_diff + 1
print(f"num of diff classes: {num_classes_diff}")


#####################################################
#           Model
#####################################################

# Build the age-estimation model from the configuration.
model = DiffBasedAgeDetectionModel(
    device=device,
    min_age=cfg.MIN_AGE,
    max_age=cfg.MAX_AGE,
    age_interval=cfg.AGE_INTERVAL,
    num_references=cfg.NUM_REFERENCES,
    pretrained_model_path=cfg.PRETRAINED_MODEL_PATH,
    pretrained_model_file_name=cfg.PRETRAINED_MODEL_FILE_NAME,
    load_pretrained=cfg.LOAD_PRETRAINED_RECOG,
    dropout_p=cfg.DROPOUT_P,
    num_of_fc_layers=cfg.NUM_OF_FC_LAYERS,
    is_ordinal=cfg.IS_ORDINAL,
    min_age_diff=min_age_diff,
    max_age_diff=max_age_diff,
    num_classes_diff=num_classes_diff,
    regressors_diff_head=cfg.REGRESSORS_DIFF_HEAD,
    fc_head_base_layer_size=cfg.FC_HEAD_BASE_LAYER_SIZE,
    use_vit=cfg.USE_VIT,
    use_convnext=cfg.USE_CONVNEXT,
    use_efficientnet=cfg.USE_EFFICIENTNET,
    use_resnet51q=cfg.USE_RESNET51Q
)

model.to(device)

# Wrap in DataParallel only when multi-GPU is requested and more than one
# CUDA device is actually present.
if cfg.USE_GPU and cfg.MULTI_GPU:
    if torch.cuda.device_count() > 1:
        print("Using multiple GPUs (" + str(torch.cuda.device_count()) + ")")
        model = torch.nn.DataParallel(model)

# Inference only: switch off training-mode behaviour before evaluating.
model.eval()

print("loading weights...")
# Earlier checkpoint paths, kept for reference:
#loaded = torch.load("/home/eng/workspace/AgeEstimationMultiProject/E2ePipeline3/weights/Morph2Diff/unified/iter/time_24_12_2023_16_20_38/weights_54_2.5240.pt")
#loaded = torch.load("/home/eng/workspace/AgeEstimationMultiProject/E2ePipeline3/weights/Morph2Diff/unified/iter/time_08_01_2024_22_18_11/weights_59_2.5057.pt")
#loaded = torch.load("/home/eng/workspace/AgeEstimationMultiProject/E2ePipeline3/weights/Morph2Diff/unified/iter/time_17_01_2024_01_18_28/weights_19_2.4980.pt")
#loaded = torch.load("/home/eng/workspace/AgeEstimationMultiProject/E2ePipeline3/weights/Morph2Diff/unified/iter/time_19_01_2024_23_38_06/weights_19_2.4949.pt")
#loaded = torch.load("/home/eng/workspace/AgeEstimationMultiProject/E2ePipeline3/weights/Morph2Diff/unified/iter/time_20_01_2024_22_06_08/weights_1_2.4824.pt")

# map_location remaps the stored tensors onto the selected device, so a
# checkpoint saved on a GPU can also be loaded on a CPU-only run.
loaded = torch.load(cfg.INFERENCE_MODEL_WEIGHTS_PATH, map_location=device)
# strict=True: every key in the checkpoint must match the model exactly.
model.load_state_dict(loaded['model_state_dict'], strict=True)
			

In [None]:
# for k in im2age_map_test_actual:
# 	im2age_map_test_actual[k]

## For later bias analysis

In [None]:
running_mae_age = 0.0

im2age_map_next = dict()
print("running inference...")

# Per-sample records consumed by the bias-analysis cells that follow.
age_preds = []
age_labels = []
genders = []
races = []

# The isolated and the full evaluation run exactly the same loop; only the
# data loader and the printed captions differ, so pick them once up front.
if cfg.APPLY_TEST_SET_SPLIT_FOR_DIST_AND_ISOL:
    eval_split = 'val_isol'
    eval_banner = "### Isolated test set (evaluation)"
    mae_caption = "MAE (isolated test set)"
else:
    eval_split = 'val_full'
    eval_banner = "### Full test set (evaluation)"
    mae_caption = "MAE (full test set)"

l1_criterion = torch.nn.L1Loss()
running_mae_age_isol = 0.0

print("####################################################################")
print(eval_banner)
for batch in tqdm(data_loaders[eval_split]):
    image_vec = batch['image_vec'].to(device)
    query_age = batch['query_age'].to(device).float()
    query_age_noised = batch['query_age_noised'].to(device).long()
    age_refs = batch['age_refs'].to(device).long()

    with torch.no_grad():
        # Only the first two model outputs (the "f" and "r" age predictions)
        # are needed here; any further heads are ignored.
        age_pred_f, age_pred_r, *_ = model(
            input_images=image_vec,
            query_noisy_age=query_age_noised,
            input_ref_ages=age_refs,
        )
        age_pred = age_pred_f if cfg.INFERENCE_BASED_ON_F else age_pred_r
        batch_mae = l1_criterion(age_pred.reshape(age_pred.shape[0]), query_age)

    # Accumulate as a plain Python float so the final MAE prints as a number
    # rather than as a device tensor.
    running_mae_age_isol += batch_mae.item() * image_vec.size(0)

    # The loaders use batch_size=1, so index 0 addresses the single sample.
    age_preds.append(age_pred.cpu().numpy()[0][0])
    age_labels.append(batch['query_age'].cpu().numpy()[0])
    genders.append(batch['gender_raw'][0])
    races.append(batch['race'][0])

mae_age = running_mae_age_isol / dataset_sizes[eval_split]
print(f"{mae_caption}: {mae_age}")

In [None]:
import pandas as pd
import seaborn as sns
sns.set_theme(style="whitegrid", palette="pastel")

# One row per evaluated sample: prediction, label and demographic attributes.
data = {'age_preds': age_preds, 'age_labels': age_labels, 'genders': genders, 'races': races}
df = pd.DataFrame(data=data)

# Expand the single-letter codes into readable names; an unknown code raises
# KeyError rather than being silently dropped.
race_dict = {'W':'White', 'B':'Black', 'H':'Hispanic', 'O':'Other', 'A':'Asian'}
gender_dict = {'M':'Male', 'F':'Female'}

df['genders'] = df['genders'].apply(lambda x: gender_dict[x])
df['races'] = df['races'].apply(lambda x: race_dict[x])

# Drop the 'Other' race group. The explicit copy makes the result an
# independent frame, so adding columns to it later does not trigger pandas'
# chained-assignment (SettingWithCopy) warning.
df = df[df['races'] != 'Other'].copy()


df.head(10)

In [None]:
# Absolute error between predicted and labelled age, per sample.
df['abs_age_diff'] = (df['age_preds'] - df['age_labels']).abs()

In [None]:
# Persist the per-sample results table (including its index) as CSV in the
# current working directory.
df.to_csv(path_or_buf='results_good_eval_2_8_2025_morph2_effcientnetv2_best_2.45.csv')

In [None]:
# Display the results table as the cell output.
df

In [None]:
# print(data['age_labels'][0])
# print(data["age_preds"][0])
# print(data["genders"][0])
# print(data["races"][0])

In [None]:
#cfg.INPUT_ESTIMATION_FILE_NAME

## Age Bias Analysis

In [None]:
NUM_OF_BINS = 13

# Age of every training sample; each y_train entry is a JSON string that
# carries an "age" field.
ages = [int(json.loads(sample_meta)["age"]) for sample_meta in y_train]
ages_a = np.array(ages)

# Number of training samples per 5-year age bin. Bins start at age 15 and
# are keyed "<low>to<high>" with both ends inclusive (15to19, 20to24, ...).
num_of_samples_train = dict()
for bin_idx in range(NUM_OF_BINS):
    bin_low = 5 * (bin_idx + 3)
    bin_high = bin_low + 4
    in_bin = (ages_a >= bin_low) & (ages_a <= bin_high)
    num_of_samples_train[f"{bin_low}to{bin_high}"] = int(np.count_nonzero(in_bin))

#plt.hist(ages, bins=50)


# age_ranges = {
# 	"15to19": (15, 19),
# 	"20to24": (20, 24),
# 	"25to29": (25, 29),
# 	"30to34": (30, 34),
# 	"35to39": (35, 39),
# 	"40to44": (40, 44),
# 	"45to49": (45, 49),
# 	"50to54": (50, 54),
# 	"55to59": (55, 59),
# 	"60to64": (60, 64),
# 	"65to70": (65, 70)
# }
# ages = []
# for i in range(len(y_train)):
# 	metadata = json.loads(y_train[i])
# 	age = float(metadata["age"])
# 	ages.append(age)
# 	for key, (lower, upper) in age_ranges.items():
# 		if lower <= age <= upper:
# 			num_of_samples[f"{lower}to{upper}"] += 1
# 			print(lower, upper, age)
# 			break

# for i in range(len(y_test)):
# 	metadata = json.loads(y_test[i])
# 	age = float(metadata["age"])
# 	for key, (lower, upper) in age_ranges.items():
# 		if lower <= age <= upper:
# 			num_of_samples[f"{lower}to{upper}"] += 1
# 			break




In [None]:
# Display the per-bin training sample counts as the cell output.
num_of_samples_train

In [None]:

import math
print("running inference...")

# Signed prediction errors (prediction - label) collected per 5-year age bin.
# The keys follow the "<low>to<high>" scheme, starting at age 15.
error_bins = {
    f"{5*(bin_idx+3)}to{5*(bin_idx+3)+4}": [] for bin_idx in range(NUM_OF_BINS)
}

# The isolated and the full evaluation run exactly the same loop; only the
# data loader and the printed captions differ, so pick them once up front.
if cfg.APPLY_TEST_SET_SPLIT_FOR_DIST_AND_ISOL:
    eval_split = 'val_isol'
    eval_banner = "### Isolated test set (evaluation)"
    mae_caption = "MAE (isolated test set)"
else:
    eval_split = 'val_full'
    eval_banner = "### Full test set (evaluation)"
    mae_caption = "MAE (Full test set)"

l1_criterion = torch.nn.L1Loss()
running_mae_age_isol = 0.0

print("####################################################################")
print(eval_banner)
for batch in tqdm(data_loaders[eval_split]):
    image_vec = batch['image_vec'].to(device)
    query_age = batch['query_age'].to(device).float()
    query_age_noised = batch['query_age_noised'].to(device).long()
    age_refs = batch['age_refs'].to(device).long()

    with torch.no_grad():
        # Only the first two model outputs (the "f" and "r" age predictions)
        # are needed. Star-unpacking the rest keeps this call independent of
        # how many extra heads the model returns (the isolated branch used to
        # unpack one value fewer than the model returns elsewhere in this
        # file, which raised ValueError).
        age_pred_f, age_pred_r, *_ = model(
            input_images=image_vec,
            query_noisy_age=query_age_noised,
            input_ref_ages=age_refs,
        )
        age_pred = age_pred_f if cfg.INFERENCE_BASED_ON_F else age_pred_r
        batch_mae = l1_criterion(age_pred.reshape(age_pred.shape[0]), query_age)

    # Accumulate as a plain Python float so the final MAE prints as a number
    # rather than as a device tensor.
    running_mae_age_isol += batch_mae.item() * image_vec.size(0)

    # The loaders use batch_size=1, so index 0 addresses the single sample.
    age_pred_used = age_pred.cpu().numpy()[0][0]
    age_label = batch['query_age'].cpu().numpy()[0]

    # Map the label to its 5-year bin. A label outside the pre-built bins
    # raises KeyError instead of being silently dropped.
    bin_idx = int(math.floor(age_label / 5)) - 3
    error_bins[f"{5*(bin_idx+3)}to{5*(bin_idx+3)+4}"].append(age_pred_used - age_label)

mae_age = running_mae_age_isol / dataset_sizes[eval_split]
print(f"{mae_caption}: {mae_age}")

In [None]:
# Per-bin mean absolute error and standard deviation of the signed errors.
# Bins without any sample are skipped. The loop variable is not named `bin`
# so the builtin of that name stays usable in the rest of the notebook.
mae_bins = dict()
std_bins = dict()
for bin_name, bin_errors in error_bins.items():
    if not bin_errors:
        continue
    err_arr = np.array(bin_errors)
    mae_bins[bin_name] = np.mean(np.abs(err_arr))
    std_bins[bin_name] = np.std(err_arr)

In [None]:
# Display the error array left over from the loop above, i.e. the signed
# errors of the last non-empty bin only.
err_arr

In [None]:
# Display the per-bin mean absolute error as the cell output.
mae_bins

In [None]:
# Display the per-bin standard deviation of the signed errors.
std_bins

In [None]:
# Print one LaTeX table row per non-empty bin: bin label, number of TRAINING
# samples in that bin, MAE and STD of the evaluation errors. The loop
# variable is not named `bin` to avoid shadowing the builtin.
print("age  |  #Samples  |  MAE  |  STD  ")
for bin_name, bin_mae in mae_bins.items():
    print(f"{bin_name} & {num_of_samples_train[bin_name]} & {bin_mae} & {std_bins[bin_name]} \\\\")

In [None]:
# Flatten the signed errors of all bins into one list, in bin order. Empty
# bins contribute nothing, so no explicit length check is needed, and no
# loop variable named `bin` is left behind to shadow the builtin.
errs = [err for bin_errors in error_bins.values() for err in bin_errors]

In [None]:
# Overall mean absolute error across every binned sample (cell output).
np.abs(np.asarray(errs)).mean()
