In [None]:
'''
Evaluate FID statistics against unbiased (balanced) real data, following [Choi et al., ICML'20].

References:
    - Dingfan Chen, GS-WGAN, 2020, https://github.com/DingfanChen/GS-WGAN/blob/main/evaluation/eval_fid.py
    - Maximilian Seitzer, Python package pytorch-fid, 2020, https://github.com/mseitzer/pytorch-fid

Config:
    dataset: 'mnist' or 'fmnist' (subgroup bias), or 'small_celeba' / 'large_celeba' (class bias).
    target_model: 'gpate', 'datalens' or 'gswgan'; selects how generated data is loaded.
    gen_data_path_list: Paths of the generated-data files to evaluate.
    gpu_num: str indicating the GPU device to run on.
    only_y: True to evaluate class bias (per-class FID), False for subgroup bias (per class and Z).
    batch_size: Batch size for InceptionV3 inference.
'''

import torch
import os
import numpy as np


# config ==============
dataset = 'mnist'  # 'mnist' / 'fmnist' (subgroup bias) or 'small_celeba' / 'large_celeba' (class bias)
target_model = 'gswgan'  # 'gpate' or 'datalens' or 'gswgan'
# Generated data to evaluate: .npz archives when only_y is False, joblib dumps when True.
gen_data_path_list = ['...path to generated .npz data...']
gpu_num = '2'
only_y = False  # class bias= True, subgroups bias= False
#========================

data_dir = '../dataset'
random_seed = 0
batch_size = 100  # images per InceptionV3 forward pass

# environment: CUDA reads CUDA_VISIBLE_DEVICES when it initializes, so set it
# before any torch call that could touch the GPU.
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_num

# random seed
torch.manual_seed(random_seed)
np.random.seed(random_seed)

# Image side length of each celebA variant (also selects the .npz file to read).
_CELEBA_IMG_SIZE = {'small_celeba': 32, 'large_celeba': 64}

if dataset in _CELEBA_IMG_SIZE:
    img_size = _CELEBA_IMG_SIZE[dataset]
    dpath = f'{data_dir}/celebA/train_celeba_gender_{img_size}x{img_size}.npz'
    # Read both arrays through one archive handle that is closed afterwards.
    with np.load(dpath) as npz_file:
        real_data = npz_file['data_x']
        real_label = npz_file['data_y']

    # balance data: keep the same number of label-1 and label-0 samples
    major_idx = np.where(real_label == 1)[0]
    minor_idx = np.where(real_label == 0)[0]

    balanced_num = min(len(major_idx), len(minor_idx))
    balanced_id = np.concatenate([major_idx[:balanced_num], minor_idx[:balanced_num]])

    # channels-last -> channels-first (assumes data_x is stored as N, H, W, C)
    real_data = real_data[balanced_id].transpose((0, 3, 1, 2))
    real_label = real_label[balanced_id]
    # No Z attribute for celebA, so per-Z subgroup statistics are unavailable.
    real_z = None
    # minor, major
    digit_list = [0, 1]

elif dataset == 'mnist':
    dpath = f'{data_dir}/mnist/rotated/unbiased'
    # Loaded as torch tensors (fmnist below is converted to numpy).
    real_data = torch.load(os.path.join(dpath, 'train_data.pt'))
    real_label = torch.load(os.path.join(dpath, 'train_Y.pt'))
    real_z = torch.load(os.path.join(dpath, 'train_A.pt'))
    # minor, major
    digit_list = [1, 3]
    img_size = 28

elif dataset == 'fmnist':
    dpath = f'{data_dir}/fmnist/rotated/unbiased'
    # unsqueeze(1) adds the channel axis: (N, 28, 28) -> (N, 1, 28, 28)
    real_data = torch.load(os.path.join(dpath, 'train_data.pt')).unsqueeze(1).numpy()
    real_label = torch.load(os.path.join(dpath, 'train_Y.pt')).numpy()
    real_z = torch.load(os.path.join(dpath, 'train_A.pt')).numpy()
    # minor, major
    digit_list = [1, 7]
    img_size = 28

else:
    # Fail early instead of hitting a NameError on real_data further down.
    raise NotImplementedError(f'Unsupported dataset: {dataset!r}')


In [None]:
'''
prepare InceptionV3 model
'''


from pytorch_fid.inception import InceptionV3


# Feature extractor for FID: the InceptionV3 block yielding 2048-d features.
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
model = InceptionV3(output_blocks=[block_idx])
model.cuda()  # nn.Module.cuda() moves parameters in place and returns the module
load_model = model  # same object as `model` (what model.cuda() returns)


In [None]:
'''
Compute (or load cached) InceptionV3 activation statistics of the real data
'''

from tqdm import tqdm
from torch.nn.functional import adaptive_avg_pool2d
import sys

STAT_DIR = './stats'

# ========= Functions ====================================
def mkdir(dir):
    '''
    Create directory ``dir`` and any missing parents; no-op if it already exists.

    Raises:
        FileExistsError: if ``dir`` exists but is not a directory.
    '''
    # exist_ok=True avoids the race between an os.path.exists check and creation.
    # (The parameter name shadows the builtin `dir`; kept for caller compatibility.)
    os.makedirs(dir, exist_ok=True)


def get_act(model, batch_size, gen_data):
    '''
    Compute InceptionV3 activation statistics (mean, covariance) of ``gen_data``.

    Args:
        model: InceptionV3 from pytorch-fid, or any ``nn.Module`` whose call
            returns a sequence whose first element is a (B, C, H, W) feature map.
        batch_size: Number of images per forward pass (must be >= 1).
        gen_data: Images as an ndarray (or array-like such as a CPU tensor),
            normalized to [0, 1]. Accepted layouts: channels-first (N, C, H, W)
            with C in {1, 3}, or channels-last (N, H, W, C), e.g. (N, 28, 28, 1).
            Single-channel images are tiled to 3 channels.

    Returns:
        mu: (C,) mean of the pooled activations.
        sigma: (C, C) covariance of the pooled activations.

    Raises:
        ValueError: if ``gen_data`` is empty or ``batch_size`` < 1.

    Note:
        A group with at least ``batch_size`` images uses only the first
        ``(N // batch_size) * batch_size`` images; the remainder is dropped.
        A smaller group is evaluated as one batch.
    '''
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')
    gen_data = np.asarray(gen_data)
    n_imgs = gen_data.shape[0]
    if n_imgs == 0:
        raise ValueError('gen_data is empty; cannot compute activation statistics')

    model.eval()

    if n_imgs < batch_size:
        print(f'Group Size({n_imgs}) is smaller than batch size({batch_size})')
        step = n_imgs  # evaluate the whole small group as a single batch
    else:
        step = batch_size
    n_batches = n_imgs // step

    # Run on the model's own device; fall back to CUDA (if present) for
    # parameter-less models.
    param = next(model.parameters(), None)
    if param is not None:
        device = param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    feats = []
    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        for i in tqdm(range(n_batches)):
            images = gen_data[i * step:(i + 1) * step]

            # Channels-last input -> channels-first.
            if images.shape[1] not in (1, 3):
                images = images.transpose((0, 3, 1, 2))
            # InceptionV3 takes 3 channels: replicate grayscale images.
            if images.shape[1] == 1:
                images = np.tile(images, [1, 3, 1, 1])

            batch = torch.from_numpy(np.ascontiguousarray(images, dtype=np.float32)).to(device)
            pred = model(batch)[0]

            # Pool spatial feature maps to 1x1 before flattening.
            if pred.shape[2] != 1 or pred.shape[3] != 1:
                pred = adaptive_avg_pool2d(pred, output_size=(1, 1))

            feats.append(pred.reshape(pred.shape[0], -1).cpu().numpy().astype(np.float64))

    pred_arr = np.concatenate(feats, axis=0)
    mu = np.mean(pred_arr, axis=0)
    sigma = np.cov(pred_arr, rowvar=False)

    return mu, sigma
# =======================================================


# Real-data reference statistics are cached under STAT_DIR/<dataset>/ and are
# recomputed only when the cache file is missing.

if not only_y:
    # Subgroup bias: statistics of all real data and of each (class, Z) subgroup
    # (cln: Z == 1, rot: Z == 0). The "_3"/"_1" names and cached keys refer to
    # the major/minor class of digit_list.
    stat_file = os.path.join(STAT_DIR, dataset, 'stat.npz')
    if not os.path.exists(stat_file):
        print('Computing statistic.')

        ## Save real statistics
        mkdir(os.path.join(STAT_DIR, dataset))

        # note real data has shape [bs, 1, 28, 28], while gen data has [bs, 28, 28, 1];
        # with one channel a plain reshape converts between them. np.asarray accepts
        # both the torch tensors (mnist) and the ndarrays (fmnist) loaded earlier.
        real_data = np.asarray(real_data, dtype=np.float32).reshape(-1, 28, 28, 1) / 255.0
        # Flat numpy masks so they can index the numpy real_data.
        real_y = np.asarray(real_label).reshape(-1)
        real_a = np.asarray(real_z).reshape(-1)

        # get stats of all groups
        minor, major = digit_list

        idx_cln_3 = (real_y == major) & (real_a == 1)
        idx_rot_3 = (real_y == major) & (real_a == 0)
        idx_cln_1 = (real_y == minor) & (real_a == 1)
        idx_rot_1 = (real_y == minor) & (real_a == 0)

        m_real_all, s_real_all = get_act(model, batch_size, real_data)
        m_real_cln_3, s_real_cln_3 = get_act(model, batch_size, real_data[idx_cln_3])
        m_real_rot_3, s_real_rot_3 = get_act(model, batch_size, real_data[idx_rot_3])
        m_real_cln_1, s_real_cln_1 = get_act(model, batch_size, real_data[idx_cln_1])
        m_real_rot_1, s_real_rot_1 = get_act(model, batch_size, real_data[idx_rot_1])

        np.savez(stat_file,
                 mu_all=m_real_all, sigma_all=s_real_all,
                 mu_cln_3=m_real_cln_3, sigma_cln_3=s_real_cln_3,
                 mu_rot_3=m_real_rot_3, sigma_rot_3=s_real_rot_3,
                 mu_cln_1=m_real_cln_1, sigma_cln_1=s_real_cln_1,
                 mu_rot_1=m_real_rot_1, sigma_rot_1=s_real_rot_1)

    else:
        ## Load pre-computed statistics
        print('Loaded pre-computed statistic.')
        with np.load(stat_file) as f:  # closes the archive after reading
            m_real_all, s_real_all = f['mu_all'], f['sigma_all']
            m_real_cln_3, s_real_cln_3 = f['mu_cln_3'], f['sigma_cln_3']
            m_real_rot_3, s_real_rot_3 = f['mu_rot_3'], f['sigma_rot_3']
            m_real_cln_1, s_real_cln_1 = f['mu_cln_1'], f['sigma_cln_1']
            m_real_rot_1, s_real_rot_1 = f['mu_rot_1'], f['sigma_rot_1']

else:
    # Class bias: statistics of all real data and of each class. This uses its
    # own cache file for two reasons: its keys differ from the subgroup cache, and
    # older caches stored the major class under the minor-class keys.
    stat_file = os.path.join(STAT_DIR, dataset, 'stat_y.npz')

    if not os.path.exists(stat_file):
        print('Computing statistic.')

        ## Save real statistics
        mkdir(os.path.join(STAT_DIR, dataset))

        # NOTE(review): real_data is used as loaded (celebA: channels-first and not
        # rescaled here) — confirm its value range matches what get_act expects.

        # get stats of all groups
        minor, major = digit_list

        idx_major = (real_label == major).squeeze()
        idx_minor = (real_label == minor).squeeze()

        m_real_all, s_real_all = get_act(model, batch_size, real_data)
        m_real_major, s_real_major = get_act(model, batch_size, real_data[idx_major])
        m_real_minor, s_real_minor = get_act(model, batch_size, real_data[idx_minor])

        np.savez(stat_file,
                 mu_all=m_real_all, sigma_all=s_real_all,
                 mu_major=m_real_major, sigma_major=s_real_major,
                 mu_minor=m_real_minor, sigma_minor=s_real_minor)

    else:
        ## Load pre-computed statistics
        print('Loaded pre-computed statistic.')
        with np.load(stat_file) as f:  # closes the archive after reading
            m_real_all, s_real_all = f['mu_all'], f['sigma_all']
            m_real_major, s_real_major = f['mu_major'], f['sigma_major']
            m_real_minor, s_real_minor = f['mu_minor'], f['sigma_minor']

    


In [None]:
'''
Compute FIDs of every generated-data file in gen_data_path_list
'''

from pytorch_fid.fid_score import calculate_frechet_distance
import numpy as np
from collections import defaultdict



# Per-run FID values keyed by evaluated group. Each generated-data file appends
# one value per key; the result cell reports mean/std across files.
fid_values = defaultdict(list)

# for each gen_data, evaluate FIDs
for gen_data_path in gen_data_path_list:
    print("Current gen_data: ", gen_data_path)

    if not only_y:
        # Subgroup bias: .npz archive with data_x / data_y / data_z arrays.
        # gpate/datalens: first 60000 samples, rescaled by 1/255 and reshaped to
        # (N, 28, 28, 1). Other models: first 10000 samples, used as stored.
        rescale = target_model in ('gpate', 'datalens')
        n_keep = 60000 if rescale else 10000
        with np.load(gen_data_path) as gen_data:  # closes the archive handle
            gen_data_x = gen_data['data_x'][:n_keep]
            gen_data_y = gen_data['data_y'][:n_keep]
            gen_data_z = gen_data['data_z'][:n_keep]
        if rescale:
            gen_data_x = (gen_data_x / 255.0).reshape(-1, 28, 28, 1)

        # overall fid
        m_gen_all, s_gen_all = get_act(model, batch_size, gen_data_x)
        fid_value_all = calculate_frechet_distance(m_real_all, s_real_all, m_gen_all, s_gen_all)
        print("fid_value_all: ", fid_value_all)
        fid_values['overall'].append(np.round(fid_value_all, 3))

        # Per-subgroup FID against the matching real subgroup (cln: Z == 1, rot: Z == 0).
        minor, major = digit_list

        idx_cln_3 = (gen_data_y == major) & (gen_data_z == 1)
        idx_rot_3 = (gen_data_y == major) & (gen_data_z == 0)
        idx_cln_1 = (gen_data_y == minor) & (gen_data_z == 1)
        idx_rot_1 = (gen_data_y == minor) & (gen_data_z == 0)

        m_gen_cln_3, s_gen_cln_3 = get_act(model, batch_size, gen_data_x[idx_cln_3])
        m_gen_rot_3, s_gen_rot_3 = get_act(model, batch_size, gen_data_x[idx_rot_3])
        m_gen_cln_1, s_gen_cln_1 = get_act(model, batch_size, gen_data_x[idx_cln_1])
        m_gen_rot_1, s_gen_rot_1 = get_act(model, batch_size, gen_data_x[idx_rot_1])

        fid_value_cln_3 = calculate_frechet_distance(m_real_cln_3, s_real_cln_3, m_gen_cln_3, s_gen_cln_3)
        fid_value_rot_3 = calculate_frechet_distance(m_real_rot_3, s_real_rot_3, m_gen_rot_3, s_gen_rot_3)
        fid_value_cln_1 = calculate_frechet_distance(m_real_cln_1, s_real_cln_1, m_gen_cln_1, s_gen_cln_1)
        fid_value_rot_1 = calculate_frechet_distance(m_real_rot_1, s_real_rot_1, m_gen_rot_1, s_gen_rot_1)

        print('fid value for major class: {:.3f}(Z=1) {:.3f}(Z=0)'.format(fid_value_cln_3, fid_value_rot_3))
        print('fid value for minor class: {:.3f}(Z=1) {:.3f}(Z=0)'.format(fid_value_cln_1, fid_value_rot_1))

        fid_values['fid_major_z1'].append(fid_value_cln_3)
        fid_values['fid_major_z0'].append(fid_value_rot_3)
        fid_values['fid_minor_z1'].append(fid_value_cln_1)
        fid_values['fid_minor_z0'].append(fid_value_rot_1)

    else:
        # Class bias: joblib dump of a 2-D array. The last two columns are label
        # scores (argmax -> class id); the rest is a flattened image.
        import joblib

        # NOTE(review): joblib.load unpickles the file; only evaluate trusted outputs.
        gen_data = joblib.load(gen_data_path)
        gen_data_x, gen_data_y = np.hsplit(gen_data, [-2])

        # Evaluate a random subset of at most 60000 samples, drawn without
        # replacement. The cap avoids a ValueError for smaller files.
        n_eval = min(60000, len(gen_data_x))
        random_indices = np.random.choice(len(gen_data_x), n_eval, replace=False)
        gen_data_x = gen_data_x[random_indices].reshape(-1, 3, img_size, img_size)
        gen_data_y = np.argmax(gen_data_y[random_indices], axis=1)

        # overall fid
        m_gen_all, s_gen_all = get_act(model, batch_size, gen_data_x)
        fid_value_all = calculate_frechet_distance(m_real_all, s_real_all, m_gen_all, s_gen_all)
        print("fid_value_all: ", fid_value_all)
        fid_values['overall'].append(np.round(fid_value_all, 3))

        for group in (0, 1):
            indices = np.where(gen_data_y == group)[0]
            print("Y = ", group, "\tlen: ", len(indices))
            m_gen, s_gen = get_act(model, batch_size, gen_data_x[indices])
            # NOTE(review): each class is scored against the statistics of ALL real
            # data (m_real_all), not the per-class real statistics — confirm intended.
            fid_value = calculate_frechet_distance(m_real_all, s_real_all, m_gen, s_gen)
            fid_values[group].append(np.round(fid_value, 2))
            print("group: ", group, "fid_value: ", fid_value)


In [None]:
# result folder: one directory per evaluated model
result_file_folder = target_model
os.makedirs(result_file_folder, exist_ok=True)

# save results: per group, the per-run FID values plus their mean and
# population standard deviation (np.std default ddof=0) across runs.
with open(os.path.join(result_file_folder, 'utility_result.txt'), 'w', encoding='utf-8') as f:
    for k, v in fid_values.items():
        # Plain floats keep the listing readable (NumPy >= 2 reprs scalars as np.float64(...)).
        values = [float(x) for x in v]
        f.write(f'fid value for {k}: {values}\n')
        f.write(f'\tmean: {np.mean(v):.3f}\n')
        f.write(f'\tstd: {np.std(v):.3f}\n')