In [None]:
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
from ipywidgets import BoundedFloatText, FloatLogSlider, FloatText, IntSlider, interact

In [None]:
def unpack_pickles(directory):
    """Collect the histories stored in every ``.pickle`` file found under ``directory``.

    The directory tree is walked recursively. Each pickle must hold a dict with
    the keys ``'hp'``, ``'test_loss_hist'``, ``'evaluation_hist'`` and
    ``'killed_param_hist'`` (per the original note, this is the dict the
    bootstrap file dumps).

    Sub-directories and file names are visited in sorted order, so the position
    of a run in the returned lists is the same on every machine; plain
    ``os.walk`` order depends on the filesystem.

    Parameters:
        directory (str): Root folder to search.

    Returns:
        tuple: Four parallel lists, one entry per pickle file, in the order
            ``(hyperparam_hist, test_loss_hist, evaluation_hist, killed_param_hist)``.
            All four are empty when no pickle file is found.

    Raises:
        KeyError: If a pickle lacks one of the four expected keys.
    """
    hyperparam_hist = []
    test_loss_hist = []
    evaluation_hist = []
    killed_param_hist = []
    for root, dirs, files in os.walk(directory):
        # Sorting in place makes os.walk descend in a reproducible order.
        dirs.sort()
        for file in sorted(files):
            if not file.endswith('.pickle'):
                continue
            filepath = os.path.join(root, file)
            with open(filepath, 'rb') as f:
                # SECURITY: pickle.load can execute arbitrary code; only point
                # this at result files you produced yourself.
                data = pickle.load(f)

            evaluation_hist.append(data['evaluation_hist'])
            hyperparam_hist.append(data['hp'])
            test_loss_hist.append(data['test_loss_hist'])
            killed_param_hist.append(data['killed_param_hist'])

    return hyperparam_hist, test_loss_hist, evaluation_hist, killed_param_hist


In [None]:
# Folder that is searched recursively for the ensemble's result pickles.
search_dir = 'output/'
(hyperparam_list,
 test_loss_hist,
 evaluation_hist,
 killed_param_hist) = unpack_pickles(search_dir)

# Stack the per-member loss histories into one array.
test_loss_hist = np.array(test_loss_hist)

# The leading axis has one entry per loaded pickle (ensemble member); each
# member in turn holds one entry per cutoff.
ensemble_size = len(test_loss_hist)
no_cutoffs = len(test_loss_hist[0])

# Cutoffs:
# IMPORTANT NOTE: a leading 0 is included because one evaluation is done with no cutoff.
cutoffs = np.array([
    0,
    1e-4, 5e-3, 7.5e-3,
    1e-2, 5e-2, 7.5e-2,
    1e-1, 5e-1, 7.5e-1,
    1, 5, 7.5,
    10, 50, 75,
    100, 5e2, 7.5e2,
    1e3, 5e3, 7.5e3,
    1e4, 5e4, 7.5e4,
    1e5,
])

### Helper function

In [None]:
def find_outlier_indices(data, top_percentile=95, bottom_percentile=5):
    """
    Partition the positions of ``data`` into high outliers, low outliers and the rest.

    A value is a high outlier when it is strictly greater than the
    ``top_percentile``-th percentile of ``data``, and a low outlier when it is
    strictly smaller than the ``bottom_percentile``-th percentile.

    Parameters:
        data (array-like): The input data.
        top_percentile (float): Percentile above which values count as outliers. Default is 95.
        bottom_percentile (float): Percentile below which values count as outliers. Default is 5.

    Returns:
        tuple: Three index arrays, in this order:
            - positions of the high outliers
            - positions of the low outliers
            - positions of everything else (sorted ascending)
    """
    upper_cut = np.percentile(data, top_percentile)
    lower_cut = np.percentile(data, bottom_percentile)

    above = np.nonzero(data > upper_cut)[0]
    below = np.nonzero(data < lower_cut)[0]

    # Whatever position is not flagged on either side is a non-outlier.
    flagged = np.concatenate((above, below))
    remaining = np.setdiff1d(np.arange(len(data)), flagged)

    return above, below, remaining

## Losses

In [None]:
def plot_losses_var_cutoff(idx, feature):
    """Scatter one column of ``feature`` against the row index and save the figure as a PDF.

    Parameters:
        idx (int): Column of ``feature`` to plot; also selects ``cutoffs[idx]``
            for the output file name.
        feature: 2-D array, sliced as ``feature[:, idx]`` (rows are plotted as
            ensemble members).

    Relies on the notebook-level ``cutoffs`` array. The PDF is written under
    ``plots/relu_act/``; the folder is created when it does not exist yet.
    """
    column = feature[:, idx]

    plt.figure(figsize=(8, 6))
    plt.scatter(np.arange(len(column)), column, marker='o', color='b')
    # plt.title(f'Losses for cutoff {cutoffs[idx]}')
    plt.xlabel('Ensemble member')
    plt.ylabel('MSE Loss')
    plt.grid(True)
    plt.tight_layout()

    out_path = f'plots/relu_act/act-relu_cutoff-{cutoffs[idx]}_first_ts.pdf'
    # savefig does not create missing folders; without this a fresh checkout
    # fails with FileNotFoundError.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    plt.savefig(out_path)
    plt.show()

In [None]:
print('Ensemble size: ', ensemble_size)
print('No cutoffs: ', no_cutoffs)

print('Pre-average shape: ', test_loss_hist.shape)
# Collapse the last axis (axis 2), leaving one value per (member, cutoff) pair.
test_loss_hist_averaged = test_loss_hist.mean(axis=2)
print('Post-average shape: ', test_loss_hist_averaged.shape)

# Keep only the members whose averaged loss in column 0 is below 1e3.
# `mask` is reused further down to filter the evaluation history the same way.
mask = test_loss_hist_averaged[:, 0] < 1e3
test_loss_hist_averaged = test_loss_hist_averaged[mask]
print(test_loss_hist_averaged.shape)

In [None]:
# Fix the data argument so that `interact` only has to drive the cutoff index.
plot_elt_losses = lambda idx: plot_losses_var_cutoff(idx=idx, feature=test_loss_hist_averaged)

# One slider position per cutoff index (0 .. no_cutoffs - 1).
slider = IntSlider(value=0, min=0, max=no_cutoffs - 1, description='Cutoff')

# Redraw the scatter plot whenever the slider moves.
interact(plot_elt_losses, idx=slider)

In [None]:
def plot_loss_hist(cutoff_idx, bins, bottom_percentile, top_percentile, feature):
    """Histogram one column of ``feature`` with percentile outliers removed, and save it as a PDF.

    Parameters:
        cutoff_idx (int): Column of ``feature`` to plot; also selects
            ``cutoffs[cutoff_idx]`` for the output file name.
        bins: Forwarded to ``plt.hist`` as ``bins``.
        bottom_percentile (float): Values below this percentile are left out.
        top_percentile (float): Values above this percentile are left out.
        feature: 2-D array, sliced as ``feature[:, cutoff_idx]``.

    Relies on the notebook-level ``cutoffs`` array and on
    ``find_outlier_indices``. The PDF is written under ``plots/relu_act/``;
    the folder is created when it does not exist yet.
    """
    plt.figure(figsize=(8, 6))
    column = feature[:, cutoff_idx]
    # Element [2] of the helper's result holds the non-outlier positions.
    kept = find_outlier_indices(column, top_percentile=top_percentile, bottom_percentile=bottom_percentile)[2]
    plt.hist(column[kept], color='b', bins=bins)
    # The x-axis shows losses; the previous 'Rosenbrock evaluation' label was
    # copied from the evaluation histogram. Now matches plot_losses_var_cutoff.
    plt.xlabel('MSE Loss')
    # NOTE(review): plt.hist draws raw counts here (density is not set), so
    # 'Relative frequency' may not be the intended label - confirm.
    plt.ylabel('Relative frequency')
    plt.grid(True)
    plt.tight_layout()

    out_path = f'plots/relu_act/act-relu_ensemble_eval_ts-1_cutoff-{cutoffs[cutoff_idx]}_outlier-t-{top_percentile}-b-{bottom_percentile}.pdf'
    # savefig does not create missing folders.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    plt.savefig(out_path)
    plt.show()

In [None]:
# Fix the data argument; `interact` supplies the four widget-driven values.
plot_loss_evals_hist = lambda cutoff_idx, bins, bottom_percentile, top_percentile: plot_loss_hist(cutoff_idx, bins, bottom_percentile, top_percentile, test_loss_hist_averaged)

# Slider widgets
cutoff_slider = IntSlider(min=0, max=no_cutoffs-1, value=0, description='Cutoff')
bins = IntSlider(min=1, max=100, value=50, description='Histogram bins')
# BoundedFloatText is used because plain FloatText has no min/max traits: the
# bounds were ignored, and an entry outside [0, 100] made np.percentile raise.
bottom_percentile = BoundedFloatText(min=0, max=100, value=1, description='Bottom acceptance percentile')
top_percentile = BoundedFloatText(min=0, max=100, value=99, description='Top acceptance percentile')

# Interactive plot
interact(plot_loss_evals_hist, cutoff_idx=cutoff_slider, bins=bins, bottom_percentile=bottom_percentile, top_percentile=top_percentile)

## Average cutoff losses for the test samples ensemble 
*(the loss is first averaged over the test samples of each ensemble member, then over the retained ensemble members)*

In [None]:
# Second averaging step: collapse axis 0 of the per-member averages,
# leaving one value per cutoff column.
test_loss_hist_twice_averaged = test_loss_hist_averaged.mean(axis=0)
# Cell output: shape of the raw (un-averaged) loss array.
test_loss_hist.shape

In [None]:
# Average test loss (over test samples and ensemble members) against the cutoff.
plt.figure(figsize=(8, 6))
plt.scatter(cutoffs, test_loss_hist_twice_averaged, marker='o', color='b')
# plt.title('Average test loss/cutoff')
plt.xlabel('Cutoff')
# `cutoffs` starts with 0 (the no-cutoff evaluation). A pure log axis cannot
# place that point, so it was missing from the figure. 'symlog' is linear
# below `linthresh` and logarithmic above it, which keeps the 0 visible.
plt.xscale('symlog', linthresh=cutoffs[cutoffs > 0].min())
plt.ylabel('MSE Loss')
plt.grid(True)
plt.tight_layout()
# savefig does not create missing folders.
os.makedirs('plots/relu_act', exist_ok=True)
plt.savefig('plots/relu_act/average_loss_per_cutoff.pdf')
plt.show()

## Network output

We'll choose just a single test evaluation (wlog one fixed test sample, the same for every ensemble member; the code below uses index 40) and examine how the network output flows as the cutoff is increased.

In [None]:
from datasets import MultivariateGaussianDataset

In [None]:
# Rebuild the test inputs from the settings stored with ensemble member 0.
# NOTE(review): only member 0's hyperparameters are read; nothing here checks
# that the other members used the same test seed - confirm before relying on it.
num_test_samples, means_x, cov_x, test_seed, learnable_func = (
    getattr(hyperparam_list[0], field)
    for field in ('num_test_samples', 'means_x', 'cov_x', 'test_seed', 'learnable_func')
)

test_x_dataset = MultivariateGaussianDataset(
    num_samples=num_test_samples,
    means=means_x,
    cov=cov_x,
    seed=test_seed,
)

# Test sample 40 is the one examined below; y_true is the target function's value there.
examination_data = test_x_dataset[40]
y_true = learnable_func(examination_data)

In [None]:
import torch

In [None]:
# Take index 40 along the third axis of every member's evaluation history, then
# remove the poorly trained models with the `mask` built from the averaged losses.
single_test_sample_evaluation_hist = (
    torch.tensor(evaluation_hist)[:, :, 40].numpy()[mask]
)

# Cell output: shape of the filtered array.
single_test_sample_evaluation_hist.shape

In [None]:
def plot_eval_cutoff(idx, feature):
    """Scatter one column of ``feature`` against the row index, with reference lines.

    Two horizontal lines are added: the notebook-level ``y_true`` (orange) and
    the mean of the plotted column (pink).

    Parameters:
        idx (int): Column of ``feature`` to plot; also selects ``cutoffs[idx]``
            for the title.
        feature: 2-D array, sliced as ``feature[:, idx]``.

    Relies on the notebook-level ``cutoffs`` and ``y_true``.
    """
    _, ax = plt.subplots(figsize=(8, 6))

    member_outputs = feature[:, idx]
    ax.scatter(np.arange(len(member_outputs)), member_outputs, marker='o', color='b')
    ax.set_title(f'Ensemble evaluations for {cutoffs[idx]: .2f}')
    ax.set_xlabel('Ensemble member index')
    ax.set_ylabel('Test function evaluation')

    # Reference line: the target value.
    target = float(y_true)
    ax.axhline(target, color='orange', label=f'True y: {target: .2f}')

    # Reference line: mean of the plotted column.
    average_learnt_y = np.mean(member_outputs)
    ax.axhline(average_learnt_y, color='pink', label=f'Ensemble avg learnt y: {average_learnt_y: .2f}')

    ax.grid(True)
    ax.legend()
    plt.show()

In [None]:
# Fix the data argument so that `interact` only has to drive the cutoff index.
plot_elt_evals = lambda idx: plot_eval_cutoff(idx=idx, feature=single_test_sample_evaluation_hist)

# One slider position per cutoff index (0 .. no_cutoffs - 1).
slider = IntSlider(value=0, min=0, max=no_cutoffs - 1, description='Cutoff')

# Redraw the scatter plot whenever the slider moves.
interact(plot_elt_evals, idx=slider)

In [None]:
def plot_eval_hist(cutoff_idx, bins, bottom_percentile, top_percentile, feature):
    """Histogram one column of ``feature`` with percentile outliers removed.

    Parameters:
        cutoff_idx (int): Column of ``feature`` to plot; also selects
            ``cutoffs[cutoff_idx]`` for the title.
        bins: Forwarded to ``plt.hist`` as ``bins``.
        bottom_percentile (float): Values below this percentile are left out.
        top_percentile (float): Values above this percentile are left out.
        feature: 2-D array, sliced as ``feature[:, cutoff_idx]``.

    Relies on the notebook-level ``cutoffs`` array and on ``find_outlier_indices``.
    """
    plt.figure(figsize=(8, 6))
    column = feature[:, cutoff_idx]
    # Element [2] of the helper's result holds the non-outlier positions.
    kept = find_outlier_indices(column, top_percentile=top_percentile, bottom_percentile=bottom_percentile)[2]
    plt.hist(column[kept], color='b', bins=bins)
    # Raw f-string: the backslash of the mathtext command is kept literally.
    # The non-raw form relied on an invalid escape sequence, which recent
    # Python versions warn about. The resulting title text is unchanged.
    plt.title(rf'Ensemble evaluations for cutoff $\Lambda:${cutoffs[cutoff_idx]: .2f}')
    plt.xlabel('Rosenbrock evaluation')
    # NOTE(review): plt.hist draws raw counts here (density is not set), so
    # 'Relative frequency' may not be the intended label - confirm.
    plt.ylabel('Relative frequency')
    plt.grid(True)
    plt.show()

In [None]:
# Fix the data argument; `interact` supplies the four widget-driven values.
plot_elt_evals = lambda cutoff_idx, bins, bottom_percentile, top_percentile: plot_eval_hist(cutoff_idx, bins, bottom_percentile, top_percentile, single_test_sample_evaluation_hist)

# Slider widgets
cutoff_slider = IntSlider(min=0, max=no_cutoffs-1, value=0, description='Cutoff')
bins = IntSlider(min=1, max=100, value=50, description='Histogram bins')
# BoundedFloatText is used because plain FloatText has no min/max traits: the
# bounds were ignored, and an entry outside [0, 100] made np.percentile raise.
bottom_percentile = BoundedFloatText(min=0, max=100, value=1, description='Bottom acceptance percentile')
top_percentile = BoundedFloatText(min=0, max=100, value=99, description='Top acceptance percentile')

# Interactive plot
interact(plot_elt_evals, cutoff_idx=cutoff_slider, bins=bins, bottom_percentile=bottom_percentile, top_percentile=top_percentile)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import norm

def plot_final_eval_hist_fit(cutoff_idx, bins, bottom_percentile, top_percentile, feature):
    """Histogram one column of ``feature``, overlay a fitted Gaussian, and save the figure as a PDF.

    Percentile outliers are removed first. A normal distribution is then fitted
    to the remaining values with ``scipy.stats.norm.fit`` and drawn on the same
    count scale as the histogram. The x-axis is made symmetric about zero and a
    vertical line marks ``y_true[0]``.

    Parameters:
        cutoff_idx (int): Column of ``feature`` to plot; also selects
            ``cutoffs[cutoff_idx]`` for the output file name.
        bins: Forwarded to ``plt.hist`` as ``bins``.
        bottom_percentile (float): Values below this percentile are left out.
        top_percentile (float): Values above this percentile are left out.
        feature: 2-D array, sliced as ``feature[:, cutoff_idx]``.

    Relies on the notebook-level ``cutoffs`` and ``y_true`` and on
    ``find_outlier_indices``. The PDF is written under ``plots/relu_quadratic/``;
    the folder is created when it does not exist yet.
    """
    plt.figure(figsize=(8, 6))
    column = feature[:, cutoff_idx]
    # Element [2] of the helper's result holds the non-outlier positions.
    kept = find_outlier_indices(column, top_percentile=top_percentile, bottom_percentile=bottom_percentile)[2]
    samples = column[kept]
    _, bin_edges, _ = plt.hist(samples, color='b', bins=bins, density=False)

    # Fit a Gaussian to the plotted values
    mu, sigma = norm.fit(samples)

    # Make the x-axis symmetric about zero: read the current x-limits and
    # widen them to the larger absolute bound on both sides.
    low, high = plt.xlim()
    bound = max(abs(low), abs(high))
    plt.xlim(-bound, bound)

    x = np.linspace(-bound, bound, 5000)
    # Scale the density by bin width and sample count so the curve is on the
    # same (count) scale as the histogram bars.
    bin_width = bin_edges[1] - bin_edges[0]
    fitted_curve = norm.pdf(x, mu, sigma) * bin_width * len(samples)
    # The format string is '-' only: the earlier 'r-' also named a colour, which
    # clashed with color='orange' and made matplotlib warn. The line stays orange.
    plt.plot(x, fitted_curve, '-', color='orange', label=f'Fitted Gaussian (μ={mu:.2f}, σ={sigma:.2f})')
    plt.axvline(y_true[0], color='pink', linewidth=3, label='True value')

    plt.xlabel('Test function evaluation')
    # NOTE(review): bars and curve are on a count scale, so 'Relative
    # frequency' may not be the intended label - confirm.
    plt.ylabel('Relative frequency')
    plt.grid(True)
    plt.legend()
    plt.tight_layout()

    out_path = f'plots/relu_quadratic/relu_quadratic_cutoff_{cutoffs[cutoff_idx]}.pdf'
    # savefig does not create missing folders.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    plt.savefig(out_path)
    plt.show()


In [None]:
# Fix the data argument; `interact` supplies the four widget-driven values.
plot_elt_evals = lambda cutoff_idx, bins, bottom_percentile, top_percentile: plot_final_eval_hist_fit(cutoff_idx, bins, bottom_percentile, top_percentile, single_test_sample_evaluation_hist)

# Slider widgets
cutoff_slider = IntSlider(min=0, max=no_cutoffs-1, value=0, description='Cutoff')
bins = IntSlider(min=1, max=1000, value=50, description='Histogram bins')
# BoundedFloatText is used because plain FloatText has no min/max traits: the
# bounds were ignored, and an entry outside [0, 100] made np.percentile raise.
# The defaults of 0 and 100 keep every sample.
bottom_percentile = BoundedFloatText(min=0, max=100, value=0, description='Bottom acceptance percentile')
top_percentile = BoundedFloatText(min=0, max=100, value=100, description='Top acceptance percentile')

# Interactive plot
interact(plot_elt_evals, cutoff_idx=cutoff_slider, bins=bins, bottom_percentile=bottom_percentile, top_percentile=top_percentile)