In [None]:
'''
Use a nearest-centroid heuristic to infer the sensitive attribute (i.e., rotation) of generated data.
'''


import torch
import numpy as np


# ---- configuration ----------------------------------------------------
# Generator that produced the data. 'gpate' and 'datalens' get dedicated
# loaders further down; any other value is read as an .npz archive.
target_model = 'gswgan'
gen_data_path = '...path to generated .npz data...'
# Real dataset used to build the centroids; loaders exist for 'mnist' and 'fmnist'.
dataset = 'mnist'
# -----------------------------------------------------------------------

data_dir = '../dataset/'
print("Dataset: ", dataset)

# load real data
# Class label used for the 'minor' and the 'major' group of each supported dataset.
_GROUP_DIGITS = {
    'mnist': {'minor': 1, 'major': 3},
    'fmnist': {'minor': 1, 'major': 7},
}

# Fail here with a clear message instead of hitting a NameError further down
# when `dataset` names something that has no loader (e.g. 'celeba').
if dataset not in _GROUP_DIGITS:
    raise ValueError(
        f"Unsupported dataset {dataset!r}; expected one of {sorted(_GROUP_DIGITS)}"
    )

# NOTE(review): torch.load unpickles its input -- only point data_dir at trusted files.
real_dir = f'{data_dir}/{dataset}/rotated/unbiased'
# Dividing by 1.0 yields a floating-point tensor when the stored data are integers.
unbiased_x = torch.load(f'{real_dir}/train_data.pt') / 1.0
unbiased_y = torch.load(f'{real_dir}/train_Y.pt')
unbiased_A = torch.load(f'{real_dir}/train_A.pt')

# Forward and reverse lookup between group name and class label.
group_to_digit = dict(_GROUP_DIGITS[dataset])
digit_to_group = {digit: group for group, digit in group_to_digit.items()}


# data group
# Boolean masks over the real samples: class membership and attribute value.
is_minor = unbiased_y == group_to_digit['minor']
is_major = unbiased_y == group_to_digit['major']
attr_is_one = unbiased_A == 1
attr_is_zero = unbiased_A == 0

# 'cln' groups have A == 1, 'rot' groups have A == 0.
# The _1 / _3 suffixes stand for the minor / major class respectively.
cln_1 = unbiased_x[is_minor & attr_is_one]
rot_1 = unbiased_x[is_minor & attr_is_zero]
cln_3 = unbiased_x[is_major & attr_is_one]
rot_3 = unbiased_x[is_major & attr_is_zero]


# get centroids
# Each centroid is the group averaged over its first axis (the samples) and
# then once more over the first remaining axis.
centroid_dict = {
    name: members.mean(dim=0).mean(dim=0)
    for name, members in (
        ('cln_minor', cln_1),
        ('rot_minor', rot_1),
        ('cln_major', cln_3),
        ('rot_major', rot_3),
    )
}

In [None]:
# load gen data

import joblib
import os
import numpy as np
import torch


# Output location: results are written next to the generated-data file.
save_dir = os.path.dirname(gen_data_path)
savename = os.path.basename(gen_data_path) + '_labeled'


def _load_flat_pkl(path):
    """Load one joblib dump and split it into pixel and label tensors.

    Every row is treated as ``[pixels..., 10 label columns]``: the last 10
    columns are reduced to a class index with argmax (presumably a one-hot /
    score vector -- TODO confirm against the generator), the remaining columns
    are multiplied by 255.

    NOTE(review): joblib.load unpickles the file; only load trusted files.

    Returns:
        (data_x, data_y): float tensor of pixels and long tensor of labels.
    """
    gen_data = joblib.load(path)
    data_x = torch.tensor(gen_data[:, :-10] * 255).float()
    data_y = torch.tensor(np.argmax(gen_data[:, -10:], axis=1)).long()
    return data_x, data_y


# G-PATE ver.
if target_model == 'gpate':
    gen_data_x, gen_data_y = _load_flat_pkl(gen_data_path)

elif target_model == 'datalens':
    # Ten shards are read. The path of shard i+1 is derived from shard i by
    # replacing "-{i}.pkl" with "-{i+1}.pkl", so gen_data_path is expected to
    # name the "-0.pkl" shard. A local variable is advanced instead of
    # gen_data_path so the configured path survives a re-run of this cell.
    shard_path = gen_data_path
    shard_xs, shard_ys = [], []
    for i in range(10):
        shard_x, shard_y = _load_flat_pkl(shard_path)
        shard_xs.append(shard_x)
        shard_ys.append(shard_y)
        shard_path = shard_path.replace(f"-{i}.pkl", f"-{i+1}.pkl")

    # Concatenate once; labels stay integer (long) like in the other branches.
    gen_data_x = torch.cat(shard_xs, dim=0)
    gen_data_y = torch.cat(shard_ys, dim=0)

else:
    # Read the archive once and close it when done.
    with np.load(gen_data_path) as npz:
        gen_data_x = torch.tensor(npz['data_x'] * 255).float()
        gen_data_y = torch.tensor(npz['data_y']).long()



# Label histogram of the loaded data (displayed as the cell output).
gen_data_y.unique(return_counts=True)

In [None]:
# plot a few generated samples with their labels (z is not computed yet)

import matplotlib.pyplot as plt
%matplotlib inline

def plot_img(sample, shape=(28, 28)):
    """Show one image tensor as a grayscale picture.

    Args:
        sample: torch tensor holding exactly ``shape[0] * shape[1]`` values,
            in any layout that can be reshaped to ``shape``.
        shape: (height, width) used for display; defaults to 28x28.
    """
    # reshape (unlike view) also accepts non-contiguous tensors, and detach
    # allows tensors that track gradients to be converted to NumPy.
    image = sample.detach().reshape(shape).cpu().numpy()
    plt.imshow(image, cmap='gray')
    plt.show()
    

# Show the first few generated samples together with their labels.
num_samples = 5
for idx in range(num_samples):
    plot_img(gen_data_x[idx])
    label = gen_data_y[idx]
    print(f'label: {label}')
    print("=====================================")

In [None]:
# Keep only the generated samples whose label is non-zero.
has_nonzero_label = np.asarray(gen_data_y != 0)
indices = has_nonzero_label.nonzero()[0]

gen_data_x, gen_data_y = gen_data_x[indices], gen_data_y[indices]

# Label histogram after filtering (displayed as the cell output).
gen_data_y.unique(return_counts=True)

In [None]:
# get z label and save data


# One entry per generated sample: 1 when the nearest centroid is a 'cln'
# centroid, 0 otherwise.
z = np.zeros_like(gen_data_y)


for idx in range(0, gen_data_x.shape[0]):
    # Per-sample feature: average over the last axis, then over axis 0.
    # TODO(review): confirm this matches the reduction used for the centroids
    # for the layout of the generated data.
    sample = torch.mean(torch.mean(gen_data_x[idx], dim=-1), 0)
    y = gen_data_y[idx]
    group = digit_to_group[int(y.item())]

    # Only the two centroids that belong to the sample's own class compete.
    dist_dict = {
        name: torch.dist(sample, centroid, 2)
        for name, centroid in centroid_dict.items()
        if group in name
    }

    pred = min(dist_dict, key=dist_dict.get)

    if 'cln' in pred:
        z[idx] = 1


# save data to npz
z_counts = np.unique(z, return_counts=True)
print(z_counts)
with open(os.path.join(save_dir, savename + '_z.txt'), 'w') as f:
    f.write(str(z_counts))
# Pixels are divided by 255.0 again before being stored.
np.savez(os.path.join(save_dir, savename), data_z = z, data_x =gen_data_x / 255.0, data_y =gen_data_y)

In [None]:
# plot test with z

import matplotlib.pyplot as plt
%matplotlib inline
    

# Show randomly drawn samples together with their label and inferred z.
num_samples = 10
random_indices = np.random.choice(gen_data_x.size(0), num_samples)
for idx in random_indices:
    plot_img(gen_data_x[idx])
    label, z_value = gen_data_y[idx], z[idx]
    print(f'label: {label}')
    print(f'z: {z_value}')
    print("=====================================")

In [None]:
# Reliability test (Real to Real)
# Classify every real sample with the same nearest-centroid rule and compare
# the predicted attribute (and class) with the stored ground truth.

cnt = 0

# find closest centroid
for idx in range(0, unbiased_x.shape[0]):
    # Average the sample over its first axis. The centroids are the group mean
    # over samples followed by a mean over the next axis, so this gives the
    # sample the same shape as the centroids it is compared with.
    sample = torch.mean(unbiased_x[idx], dim=0)
    true_y = int(unbiased_y[idx].item())
    group = digit_to_group[true_y]

    # Collect distances in the same dict that min() reads from; only the
    # centroids of the sample's own class are candidates.
    dist_dict = {
        k: torch.dist(sample, v, 2)
        for k, v in centroid_dict.items()
        if group in k
    }

    pred = min(dist_dict, key=dist_dict.get)

    A_pred = 1 if "cln" in pred else 0
    # Centroid names carry 'minor' / 'major', not the digit itself; map the
    # group name back to the class label of the current dataset.
    y_pred = group_to_digit['major'] if 'major' in pred else group_to_digit['minor']

    if y_pred == true_y and A_pred == int(unbiased_A[idx].item()):
        cnt+=1

print("Acc: ", cnt / unbiased_x.shape[0] * 100)

