# In [None]:
'''
Given an .npz file, calculates the fairness discrepancy and KL divergence to uniform distribution.
'''


import torch
import numpy as np
import torch.nn.functional as F
import os
from collections import defaultdict


# ---- configuration ----
# `model` names the generator being evaluated ('gpate', 'datalens' or 'gswgan');
# it becomes the last component of the result folder path.
model = "gswgan"
# One entry per generated-data .npz archive to evaluate.
gen_data_path_list = [
    "...path to npz file...",
]


# ideal case: every (y, z) group holds an equal share of the samples
result_dict = defaultdict(list)
unif = torch.Tensor([0.25, 0.25, 0.25, 0.25])

# for each gen_data, evaluate KL-to-uniform and the fairness discrepancy
for data_path in gen_data_path_list:
    print("Current gen_data: ", data_path)

    # load data: open the archive once and close it as soon as the arrays are
    # read (data_x is not needed for these metrics, so it is not loaded)
    with np.load(data_path) as npz_file:
        data_y = npz_file['data_y']
        data_z = npz_file['data_z']

    # the proportions below are only meaningful for aligned, non-empty labels
    if len(data_y) == 0 or len(data_y) != len(data_z):
        raise ValueError(
            f'{data_path}: data_y and data_z must be non-empty and of equal '
            f'length (got {len(data_y)} and {len(data_z)})'
        )

    # categorize groups: one string key per observed (y, z) combination
    pairs = [str(y) + str(z) for y, z in zip(data_y, data_z)]
    groups, counts = np.unique(pairs, return_counts=True)

    num_missing = len(unif) - len(counts)
    if num_missing < 0:
        raise ValueError(
            f'{data_path}: found {len(counts)} (y, z) groups but the uniform '
            f'reference only has {len(unif)} entries'
        )

    group_dict = dict(zip(groups, counts))
    data_distrib = torch.tensor(counts, dtype=unif.dtype) / len(data_y)
    # Groups that never occur get probability 0. Pad up to the size of the
    # reference instead of assuming exactly one group is missing. Because the
    # reference is uniform, both metrics are unaffected by which slot holds
    # the zeros.
    data_distrib = torch.cat(
        (data_distrib, torch.zeros(num_missing, dtype=unif.dtype)), dim=0
    )

    # F.kl_div(input=log q, target=p, reduction='sum') = sum p * (log p - log q),
    # i.e. KL(data_distrib || unif). Store plain floats so the summary
    # statistics computed over result_dict operate on numbers, not tensors.
    kl_base = F.kl_div(unif.log(), data_distrib, reduction='sum').item()
    result_dict['kl_to_uniform'].append(round(kl_base, 3))
    print(f'kl_to_uniform: {kl_base:.3f}')

    # fairness discrepancy: L2 distance between the uniform reference and the
    # observed group distribution
    fd_base = torch.dist(unif, data_distrib, p=2).item()
    result_dict['fairness_discrepancy'].append(round(fd_base, 3))
    print(f'fairness discrepancy: {fd_base:.3f}')


# Destination directory for this model's report; created on demand.
result_file_folder = f'/home/soyeon/nas/pfgan_hub/evaluation/diversity/{model}'
os.makedirs(result_file_folder, exist_ok=True)


def _report_lines(results):
    """Yield the report text line by line: raw values, then mean and std."""
    for metric, values in results.items():
        yield f'Result for {metric}: {values}\n'
        yield f'\tmean: {np.mean(values):.3f}\n'
        yield f'\tstd: {np.std(values):.3f}\n'


# Write one three-line section per metric collected in result_dict.
report_path = os.path.join(result_file_folder, 'fairness_result.txt')
with open(report_path, 'w') as f:
    f.writelines(_report_lines(result_dict))