In [None]:
import sys
import os
import json

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

  
# append the path of the parent directory
sys.path.append("..")
from nonsmooth_implicit_diff import plot_utils
from nonsmooth_implicit_diff import utils
from collections import defaultdict


In [None]:
def load_json_files(folder_path):
    """Collect the config/results pair of every experiment below ``folder_path``.

    Every directory reached by walking ``folder_path`` recursively is checked
    for a ``results.json`` and a ``conf.json`` file. A directory contributes an
    entry only when *both* files are present; incomplete directories are
    skipped, so an entry can never be built from values left over from a
    previously visited directory.

    Args:
        folder_path: Root directory to search.

    Returns:
        list[dict]: One ``dict(config=..., results=...)`` per complete
        experiment directory, in sorted traversal order.
    """
    exps = []
    for root, dirs, _files in os.walk(folder_path):
        # Sorting in place makes the traversal (and hence the order of the
        # returned runs) independent of the filesystem's listing order.
        dirs.sort()
        for subdir in dirs:
            subfolder_path = os.path.join(root, subdir)
            results_file_path = os.path.join(subfolder_path, 'results.json')
            config_file_path = os.path.join(subfolder_path, 'conf.json')

            # Only complete runs are usable downstream.
            if not (os.path.exists(results_file_path) and os.path.exists(config_file_path)):
                continue

            # Load results.json
            with open(results_file_path, 'r') as results_file:
                results = json.load(results_file)
            print(f"Loaded results.json from {subfolder_path}: {results}")

            # Load conf.json
            with open(config_file_path, 'r') as config_file:
                config = json.load(config_file)
            print(f"Loaded conf.json from {subfolder_path}: {config}")

            exps.append(dict(config=config, results=results))
    return exps

# Root directory that is scanned for per-run results.json / conf.json files.
folder_path = '../exps/data_poisoning_stochastic/'
exps = load_json_files(folder_path=folder_path)

In [None]:
# Peek at the metric names stored for the 'fixed' method in the first loaded run.
exps[0]['results']['hg_results']['fixed'].keys()

In [None]:
def assert_equal_except_key(dict1, dict2, key_to_ignore, rel_tol=1):
    """Assert that two dicts agree on every shared key except ``key_to_ignore``.

    Only keys present in both dicts are compared; keys that appear in just one
    of them are ignored. Values are compared as follows:

    * lists in ``dict1`` are never compared (always treated as equal);
    * a float in ``dict1`` compared against a number uses the relative
      difference ``|v1 - v2| / max(|v1|, |v2|) < rel_tol``; identical values
      (including ``0.0`` vs ``0.0``) are always equal;
    * everything else uses ``==``.

    With the default ``rel_tol=1`` the float check is very permissive: any two
    non-zero floats of the same sign pass, and only opposite signs or
    zero-vs-non-zero are rejected. Pass a smaller ``rel_tol`` for a strict check.

    Args:
        dict1: Reference dict.
        dict2: Dict compared against ``dict1``.
        key_to_ignore: Key whose values are allowed to differ.
        rel_tol: Exclusive upper bound on the relative float difference.

    Raises:
        AssertionError: If a compared value differs.
    """
    def check_equal(v1, v2):
        if isinstance(v1, list):
            return True
        # bool is excluded explicitly because it is a subclass of int.
        if isinstance(v1, float) and isinstance(v2, (int, float)) and not isinstance(v2, bool):
            # Exact equality first: avoids 0/0 when both values are zero.
            if v1 == v2:
                return True
            return abs(v1 - v2) / max(abs(v1), abs(v2)) < rel_tol
        # Non-numeric counterpart (or non-float reference): plain equality,
        # so a type mismatch is reported as a failed assertion, not a TypeError.
        return v1 == v2

    for k, v in dict1.items():
        if k == key_to_ignore or k not in dict2:
            continue
        if not check_equal(v, dict2[k]):
            raise AssertionError(f"Values for key {k} not equal! {v}, {dict2[k]}")


# Preprocess: group the seeds
# The first loaded run is the reference, both for the configuration check and
# for the set of methods/metrics that get collected.
e1 = exps[0]
conf = e1['config']

# Every run must share the reference configuration, apart from its seed.
for run in exps:
    assert_equal_except_key(conf, run['config'], "random_state")

# hg_results[method][metric] -> array whose first axis runs over the loaded runs.
hg_results = dict()
for method_name, metrics_dict in e1['results']['hg_results'].items():
    stacked = defaultdict(list)
    for metric_name in metrics_dict:
        stacked[metric_name] = np.array(
            [run['results']['hg_results'][method_name][metric_name] for run in exps]
        )
    hg_results[method_name] = stacked
hg_results.keys()


In [None]:
# Show the metric names available for the 'fixed' method after stacking across runs.
hg_results['fixed'].keys()

In [None]:
# Plot each error metric against the number of epochs, one figure per metric
# and one line per hypergradient method.
config = exps[0]['config']
hparams = dict(alpha_l1=config['alpha_l1'], alpha_l2=config['alpha_l2'])
print(hparams)

# (legend label, per-method metric arrays, matplotlib line style)
lines_to_plot = (
    # ('ITD', hg_results['reverse'], 'solid'),
    # ('AID-CG', hg_results['CG'], 'solid'),
    ('AID-FP', hg_results['fixed'], 'solid'),
    ('NSID dec', hg_results['fixed_stoch_dec'], 'dashed'),
    ('NSID const', hg_results['fixed_stoch_const'], 'dotted'),
    # ('SID dec', hg_results['fixed_stoch_dec_no_g'], 'dotted'),
    # ('SID const', hg_results['fixed_stoch_const_no_g'], 'dashed'),
)

# (key in the per-method results, human-readable name)
metrics = (
    ('norm_diff', 'Approximation error'),
    ('norm_diff_norm', 'Normalized approx. error'),
)


mult_size=1.2
for metric, metric_name in metrics:
    fig, ax = plt.subplots(figsize=(3*mult_size, 3*mult_size))

    for (name, res, style) in lines_to_plot:
        # Axis 0 indexes the loaded runs, so these are averages over runs.
        # `t`, `n_samples` and `line_mean` stay module-level on purpose: the
        # sanity-check cell at the end of the notebook prints their lengths.
        t = res['t'].mean(axis=0)
        n_samples = res['n_samples'].mean(axis=0)
        # Convert processed samples to epochs.
        # NOTE(review): the divisor presumably is the training-set size; a later
        # cell also subtracts config['test_size'] — confirm which is intended.
        n_epochs = n_samples/int(config['n_samples']*(1-config['val_size']))
        # n_epochs = n_samples/10000

        line = res[metric]
        line_mean = line.mean(axis=0)

        # Calculate geometric standard deviation
        y_gstd = np.exp(np.std(np.log(line), axis=0))

        # NOTE: despite their names these hold the 30th/70th percentiles, and
        # they are not used by the plot below; kept so the notebook namespace
        # is unchanged.
        y_10th = np.percentile(line, 30, axis=0)
        y_90th = np.percentile(line, 70, axis=0)
        ax.plot(n_epochs, line_mean, label=name, linestyle=style, marker="o")
        # Multiplicative band of one geometric std around the arithmetic mean.
        ax.fill_between(n_epochs, line_mean/y_gstd, line_mean*y_gstd, alpha=0.2)

    ax.set_yscale('log')
    # Raw f-string: "\l" is not a valid escape sequence, so the backslash of
    # the mathtext "\lambda" must be kept literal explicitly.
    ax.set_title(rf"Data Poisoning. $\lambda = ({config['alpha_l1']}, {config['alpha_l2']})$")

    ax.set_xlabel('# of epochs')
    # ax.set_ylabel(metric_name)
    ax.set_xlim((-1, 200))
    # Hand-tuned y-ranges per metric.
    if metric=="norm_diff_norm":
        ax.set_ylim((7e-2, 1.1e0))
    elif metric=="norm_diff":
        ax.set_ylim((5e-5,7e-4))

    ax.legend()
    plt.tight_layout()
    #plt.savefig(f'poisoning_al1_{hparams["alpha_l1"]}_al2_{hparams["alpha_l2"]}_{metric}.pdf')

    plt.show()

In [None]:
# Sanity check: n_samples scaled by the fraction left after removing val_size and test_size
# (presumably the training-set size).
# NOTE(review): the plotting cell converts samples to epochs with
# n_samples*(1-val_size) only, without test_size — confirm which one is intended.
config['n_samples']*(1-config['val_size'] - config['test_size'])

In [None]:
# Debug check: compare the lengths of the x/y series. `t`, `n_samples` and `line_mean`
# are the values left over from the last iteration of the plotting loop, so this cell
# only works after that cell has run.
print(len(t), len(n_samples), len(line_mean))