In [None]:
import sys
sys.path.insert(0, "./original/")

import json
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms
import numpy as np

from argparse import Namespace
from pathlib import Path

from original.benchmark.comm import preprocess

# Figure 1

In [None]:
%%capture
policy_results_d = Path("logs/cifar100-ResNet20-4/augmentations")

# Load validation dataset for image visualization
validset = torchvision.datasets.CIFAR100(root="data/", train=False, download=True, transform=transforms.ToTensor())

In [None]:
# Collect the per-image S_pri scores of every evaluated policy.
#
# The matched files are sorted because directory globbing yields entries in
# arbitrary, filesystem-dependent order; a fixed order keeps `all_policies`,
# the row order of `spri`, and argmin tie-breaking reproducible across runs.
policy_files = sorted(policy_results_d.glob("*[0-9].json"))
if not policy_files:
    # Fail with an actionable message instead of an IndexError on spri[0].
    raise FileNotFoundError(
        f"No policy result files matching '*[0-9].json' found in {policy_results_d}"
    )

# A file stem such as "3-1-18" encodes the policy as a list of transformation ids.
all_policies = [list(map(int, p.stem.split("-"))) for p in policy_files]

spri = []
for policy, policy_results_f in zip(all_policies, policy_files):
    # Open the path that glob actually matched rather than rebuilding the
    # name from the parsed ints, which would miss names like "03.json".
    with open(policy_results_f) as f:
        policy_results = json.load(f)
    spri.append(policy_results["S_pri"])

# Number of per-image scores stored for a policy (taken from the first file).
num_images = len(spri[0])
# Mean S_pri per policy (average over images); row i belongs to all_policies[i].
spri = np.mean(spri, axis=1)
# The policy with the lowest mean S_pri is treated as the best one.
the_best_policy = all_policies[spri.argmin()]

In [None]:
# Images and policies that contribute to the Figure 1 curves.
image_ids = list(range(num_images))
policies = [the_best_policy]
# policies = all_policies

# Baseline curves: the result file with an empty policy name (".json"),
# presumably the run without any transformation -- confirm against the
# benchmark script that writes these files.
no_policy_results_f = policy_results_d.joinpath(".json")
with no_policy_results_f.open() as f:
    no_policy_results = json.load(f)
grad_sim_not = []
for idx in image_ids:
    grad_sim_not.append(no_policy_results["grad_sim"][idx])

# Curves for the selected policies: one GradSim entry per (policy, image) pair.
grad_sim_t = []
for policy in policies:
    policy_name = "-".join(str(t) for t in policy)
    policy_results_f = policy_results_d / (policy_name + ".json")
    with policy_results_f.open() as f:
        policy_results = json.load(f)
    for idx in image_ids:
        grad_sim_t.append(policy_results["grad_sim"][idx])

# Collapse the collected entries into one averaged curve per setting.
grad_sim_t = np.mean(grad_sim_t, axis=0)
grad_sim_not = np.mean(grad_sim_not, axis=0)

# When exactly one image is selected, also display that image.
if len(image_ids) == 1:
    # Maps a result index to a validation-set index: 200, 205, ..., 695.
    # NOTE(review): presumably mirrors the sampling used when the results
    # were generated -- confirm.
    sample_list = list(range(200, 700, 5))
    chw_image = validset[sample_list[image_ids[0]]][0].numpy()
    # Channel-first -> channel-last, as expected by imshow.
    img = np.transpose(chw_image, (1, 2, 0))
    plt.imshow(img)
    plt.show()

def plot_figure1(grad_sim_t, grad_sim_not, bins=20):
    """Draw the two GradSim curves of Figure 1 on the current pyplot figure.

    Args:
        grad_sim_t: Sequence of GradSim values obtained with the
            privacy-aware transformation policy, plotted against its index.
        grad_sim_not: Sequence of GradSim values obtained without any
            transformation, plotted as a dash-dotted line.
        bins: Denominator used in the x tick labels ("i/bins") and upper
            bound of the tick range; ticks are placed every 5 steps.
            Presumably the curves hold bins + 1 points -- confirm.
    """
    plt.plot(grad_sim_t, label="w/. privacy-aware transform")
    plt.plot(grad_sim_not, "-.", label="w/o any transform")

    plt.xlabel("i")
    plt.ylabel("GradSim")

    # Label every fifth position as a fraction of `bins`.
    tick_positions = list(range(0, bins + 1, 5))
    tick_labels = ["{}/{}".format(pos, bins) for pos in tick_positions]
    plt.xticks(tick_positions, tick_labels)

    # Clip the x axis to the ticked range and fix the y axis to [0, 1].
    first_tick, last_tick = tick_positions[0], tick_positions[-1]
    plt.xlim(first_tick, last_tick)
    plt.ylim(0.0, 1.0)

    plt.tight_layout()
    plt.legend()
    plt.show()

# Render Figure 1 from the two averaged GradSim curves, using the default bins=20.
plot_figure1(grad_sim_t, grad_sim_not)

# Figure 2

# Figure 3

# Figure 4

# Figure 5

In [None]:
# Positions (into all_policies / spri) of the policies that consist of exactly
# one transformation.
single_policies_idx = [i for i, p in enumerate(all_policies) if len(p) == 1]

# The table below is indexed by transformation id, so it must be sized by the
# largest id that occurs, not by how many single-transformation policies were
# found; otherwise a missing result file or sparse ids raise IndexError.
num_transformations = max(
    (t for i in single_policies_idx for t in all_policies[i]),
    default=-1,  # no single-transformation policies -> empty table
) + 1
transformation_spri = [[] for _ in range(num_transformations)]
for i in single_policies_idx:
    policy = all_policies[i]
    policy_spri = spri[i]
    for t in policy:
        transformation_spri[t].append(policy_spri)

# Mean S_pri per transformation id; ids without any result become NaN
# explicitly (instead of np.mean([]) emitting a RuntimeWarning), so they are
# drawn as empty bars and sort last in argsort.
transformation_spri = [
    np.mean(scores) if scores else np.nan for scores in transformation_spri
]

def plot_figur5(transformation_spri):
    """Draw the Figure 5 bar chart of S_pri per transformation.

    The five transformations with the lowest S_pri are drawn in red, all
    others in gray, and their indices are printed after the plot is shown.

    Args:
        transformation_spri: Sequence of S_pri values, one per
            transformation; position i is drawn as bar i.
    """
    scores = np.array(transformation_spri)
    count = len(transformation_spri)

    # Indices of the five smallest scores, in ascending score order.
    best_5 = scores.argsort()[:5]

    # Start with every bar gray, then highlight the best ones in red.
    colors = ["gray"] * count
    for idx in best_5:
        colors[idx] = "r"

    plt.bar(range(count), scores, color=colors)
    plt.xlabel("Transformation index")
    plt.ylabel("S_pri")

    # Fixed y window with five evenly spaced ticks.
    y_low, y_high = 0.25, 0.45
    plt.ylim(y_low, y_high)
    plt.yticks(np.linspace(y_low, y_high, 5))
    plt.show()

    print("5 best transformations:", best_5)

# Render Figure 5 from the per-transformation mean S_pri values.
plot_figur5(transformation_spri)

# Figure 6

# Table 1

# Table 2

# Table 3

# Table 4