In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from confopt.searchspace import (
    NASBench1Shot1SearchSpace,
)
from confopt.benchmarks import NB101Benchmark
from confopt.oneshot.archsampler.drnas.sampler import DRNASSampler

In [None]:
# Instantiate the NAS-Bench-1Shot1 search space with the "S3" configuration and
# keep a handle to its architecture parameters; `sample_alphas_and_query` reads
# both `supernet` and `alphas` as notebook-level globals.
# NOTE(review): `alphas` is captured once here. If `set_arch_parameters` updates
# these objects in place, later samples would depend on earlier ones — confirm.
supernet = NASBench1Shot1SearchSpace("S3")
alphas = supernet.arch_parameters

# Benchmark handle used to look up metrics for each sampled genotype.
# NOTE(review): "full" presumably selects the complete NB101 table — confirm.
benchmark_api = NB101Benchmark("full")

In [3]:
def sample_alphas_and_query(n_samples, benchmark_api):
    """
    Sample architectures with a DRNAS sampler and query the benchmark for each.

    Relies on the notebook-level ``alphas`` and ``supernet``: every draw is
    pushed into ``supernet`` via ``set_arch_parameters`` before its genotype is
    read, so the shared supernet is modified as a side effect.

    :param n_samples: Number of architectures to sample and query
    :param benchmark_api: Object exposing ``query(genotype)``
    :return: pandas DataFrame built from the list of query results
             (presumably one row of metrics per sample — verify against the benchmark API)
    """
    sampler = DRNASSampler(alphas)

    def _sample_and_query():
        # Order matters: install the freshly sampled parameters first, then
        # read the genotype they induce and look it up in the benchmark.
        sampled = sampler.sample(alphas)
        supernet.set_arch_parameters(sampled)
        return benchmark_api.query(supernet.get_genotype())

    records = [_sample_and_query() for _ in range(n_samples)]
    return pd.DataFrame(records)

In [4]:
# Two back-to-back sampling runs of 500 architectures each; the plotting cell
# compares them as separate trials.
df1 = sample_alphas_and_query(n_samples=500, benchmark_api=benchmark_api)
df2 = sample_alphas_and_query(n_samples=500, benchmark_api=benchmark_api)

In [5]:
def is_pareto_efficient(costs):
    """
    Find the Pareto-efficient points, treating LARGER values as better.

    A point is efficient when no other point is at least as large in every
    column and strictly larger in at least one. To minimise an objective,
    negate its column before calling (as ``plot_pareto_front`` does).
    Exact duplicates of an efficient point are all reported as efficient.

    :param costs: An (n_points, n_costs) array (or array-like) of objective values
    :return: A (n_points, ) boolean array, indicating whether each point is Pareto efficient
    """
    costs = np.asarray(costs)
    is_efficient = np.ones(costs.shape[0], dtype=bool)
    for i, c in enumerate(costs):
        if is_efficient[i]:
            candidates = costs[is_efficient]
            # A candidate survives if it beats c in at least one objective ...
            beats_c = np.any(candidates > c, axis=1)
            # ... or is an exact duplicate of c: a tie is not domination, so
            # repeated optimal points must not be discarded.
            ties_c = np.all(candidates == c, axis=1)
            is_efficient[is_efficient] = beats_c | ties_c
            is_efficient[i] = True  # And keep self (also covers rows containing NaN)
    return is_efficient

def plot_pareto_front(df, col1, col2, minimize_col1, minimize_col2):
    """
    Plot the Pareto front for two columns in a DataFrame

    Every row is drawn as a faint blue point; the Pareto-efficient rows are
    overlaid in red and joined by a dashed line in order of increasing ``col1``.

    :param df: pandas DataFrame
    :param col1: Name of the first column (x-axis)
    :param col2: Name of the second column (y-axis)
    :param minimize_col1: True if smaller ``col1`` values are better, False if larger are better
    :param minimize_col2: True if smaller ``col2`` values are better, False if larger are better
    """
    # Take an explicit, writable float copy of the two columns. `.values` may
    # hand back a read-only view (pandas Copy-on-Write), and negating unsigned
    # integer data in place would wrap around instead of flipping the sign.
    objectives = df[[col1, col2]].to_numpy(dtype=float, copy=True)

    # is_pareto_efficient treats larger values as better, so flip the sign of
    # every column that should be minimised.
    if minimize_col1:
        objectives[:, 0] = -objectives[:, 0]

    if minimize_col2:
        objectives[:, 1] = -objectives[:, 1]

    # Find the Pareto efficient points
    pareto_efficient_mask = is_pareto_efficient(objectives)

    # Plot all points
    plt.scatter(df[col1], df[col2], c='blue', alpha=0.2, label='All points')

    # Plot Pareto front points (original, un-negated values)
    pareto_front = df[pareto_efficient_mask]
    plt.scatter(pareto_front[col1], pareto_front[col2], c='red', label='Pareto front', alpha=1)

    # Connect Pareto front points left-to-right along the x-axis
    pareto_front_sorted = pareto_front.sort_values(by=[col1])
    plt.plot(pareto_front_sorted[col1], pareto_front_sorted[col2], c='red', linestyle='--')

    plt.xlabel(col1)
    plt.ylabel(col2)
    plt.title(f'Pareto Front: {col1} vs {col2}')
    plt.legend()
    plt.grid(True)
    plt.show()

def plot_compare_boxplots(df1, df2, column_name, df1_name='DataFrame 1', df2_name='DataFrame 2',
                          ylim=(0.7, 1.0)):
    """
    Create boxplots for a specified column from two DataFrames.

    :param df1: First pandas DataFrame
    :param df2: Second pandas DataFrame
    :param column_name: Name of the column to plot
    :param df1_name: Name to label the first DataFrame (default: 'DataFrame 1')
    :param df2_name: Name to label the second DataFrame (default: 'DataFrame 2')
    :param ylim: ``(lower, upper)`` limits for the y-axis (default: ``(0.7, 1.0)``);
                 pass ``None`` to let matplotlib autoscale, e.g. for columns whose
                 values do not fall inside the default window
    """
    # Set the style for the plot (note: this changes seaborn's global style)
    sns.set_style("whitegrid")

    # Create the figure and axis objects
    fig, ax = plt.subplots(figsize=(10, 6))

    # Prepare data for plotting. The two columns are independent samples, so
    # drop their indices before placing them side by side: concat(axis=1)
    # aligns on index labels and cannot handle duplicated labels.
    data = pd.concat([
        df1[column_name].reset_index(drop=True).rename(df1_name),
        df2[column_name].reset_index(drop=True).rename(df2_name)
    ], axis=1)

    # Create the boxplot (one box per column of `data`)
    sns.boxplot(data=data, ax=ax)

    # Set the title and labels
    ax.set_title(f"Comparison of '{column_name}' between two DataFrames", fontsize=16)
    ax.set_xlabel("DataFrames", fontsize=12)
    ax.set_ylabel(column_name, fontsize=12)

    # Rotate x-axis labels if they are too long
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

    # NOTE(review): the default window presumably targets accuracy-like values
    # expressed as fractions — confirm it suits the plotted column.
    if ylim is not None:
        ax.set_ylim(*ylim)

    # Adjust the layout and display the plot
    fig.tight_layout()
    plt.show()

In [None]:
# Pareto fronts for each trial: fewer trainable parameters and higher test
# top-1 are the preferred directions.
plot_pareto_front(
    df=df1,
    col1="benchmark/trainable_parameters",
    col2="benchmark/test_top1",
    minimize_col1=True,
    minimize_col2=False,
)
plot_pareto_front(
    df=df2,
    col1="benchmark/trainable_parameters",
    col2="benchmark/test_top1",
    minimize_col1=True,
    minimize_col2=False,
)
# Side-by-side distribution of test top-1 for the two trials.
plot_compare_boxplots(
    df1=df1,
    df2=df2,
    column_name='benchmark/test_top1',
    df1_name='Trial 1',
    df2_name='Trial 2',
)