In [None]:
from wandb_utils import (
    get_normalized_arch_values_by_edge_df,
    get_normalized_arch_values_by_op_df,
    get_cell_grad_norm_df,
    get_arch_param_grad_norm_df,
    get_arch_param_grad_norm_by_edge_df,
    get_skip_connections_df,
    get_mean_gradient_matching_score_df,
    get_benchmark_test_acc_df,
    get_layer_alignment_scores_all_cells_df,
    get_layer_alignment_scores_first_and_last_cells_df,
    get_wandb_runs_as_dfs,
)

from df_utils import (
    calculate_mean_std,
    clean_dfs,
)

from plot_utils import (
    plot_everything,
    plot_line_chart_with_std_dev,
)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
# Lookup table: short integer id -> `meta_info` tag used to filter Wandb runs.
# The tags are compared against what was logged, so they are kept verbatim here.
# NOTE(review): the baseline tags carry literal single quotes inside the string -
# presumably that is how those runs were tagged; confirm before "fixing" them.
meta_infos = {
    # --- DARTS ---
    0: "'DARTS-Baseline'",
    1: "DARTS-OLES-Threshold-Ablation",
    2: "DARTS-LoRA-Rank-Ablation",
    3: "DARTS-LoRA-Warmup-Ablation",

    # --- DARTS with pruning at 50% ---
    4: "'DARTS-Prune-Baseline'",
    5: "DARTS-Prune-LoRA-Rank-Ablation",
    6: "DARTS-Prune-LoRA-Warmup-Ablation",

    # --- DrNAS: partial connections (k=6), pruning at 50%, 8 cells,
    #     36 initial channels, regularization term ---
    7: "'DrNAS-Baseline'",
    8: "DrNAS-OLES-Threshold-Ablation",
    9: "DrNAS-LoRA-Rank-Ablation",
    10: "DrNAS-LoRA-Warmup-Ablation",

    # --- DrNAS without progressive learning: no partial connections, no pruning,
    #     8 cells, 16 initial channels, regularization term ---
    11: "'DrNAS-Baseline-Basic'",
    # NOTE(review): identical to entry 8 (no "-Basic" suffix), so this filter cannot
    # tell the two OLES ablations apart by tag alone - confirm the logged value.
    12: "DrNAS-OLES-Threshold-Ablation",
    13: "DrNAS-LoRA-Rank-Ablation-Basic",
    # NOTE(review): "Baic" looks like a typo for "Basic"; left untouched because it
    # must match whatever string the runs were actually logged with - verify in Wandb.
    14: "DrNAS-LoRA-Warmup-Ablation-Baic",
}

In [None]:
# Fetch the runs using the Wandb filters.
# The most convenient filter to use is "meta_info", which is shown in Wandb as extra:meta-info.
# See the experiments sheet for the meta_info values.

# Bind the experiment tag once so the run filter and the plot call cannot drift apart.
selected_meta_info = meta_infos[3]
# Passed as the last positional argument of plot_everything (previously a bare 15).
start_epoch_for_plots = 15

dfs = get_wandb_runs_as_dfs(
    state="finished",
    meta_info=selected_meta_info,
    lora_rank=1,
    lora_warmup=8,
    # Optional extra filters - uncomment to narrow the selection further:
    # oles=True,
    # oles_threshold=0.3,
    # seed=0,
)

# Every df summarizes one run.
# Clean them first - remove columns with all NaNs and other non-numerical values.
dfs = clean_dfs(dfs)

# Now we can calculate the mean and std of the dfs.
mean_df, std_df = calculate_mean_std(*dfs)

# A notebook only echoes the *last* expression of a cell, so a bare
# `mean_df.shape, std_df.shape` in the middle is silently discarded - print it instead.
print("mean_df shape:", mean_df.shape, "| std_df shape:", std_df.shape)

# Now we can plot the data.
plot_everything(mean_df, std_df, selected_meta_info, start_epoch_for_plots)

In [None]:
def plot_everything_for_meta_info(
    meta_info,
    lora_rank=None,
    lora_warmup=None,
    oles=None,
    oles_threshold=None,
    seed=None,
    start_epoch_for_plots=0,
):
    """Fetch the finished runs for ``meta_info``, aggregate them and plot the result.

    Caveat: when only ``meta_info`` is given, the aggregation covers *every*
    ablation run carrying that tag, not merely the different seeds of one
    configuration. For example, "DARTS-LoRA-Rank-Ablation" alone averages over
    all seeds of all ranks. Pass the extra filters to pin one configuration.

    Args:
        meta_info: Experiment tag forwarded as the ``meta_info`` run filter and
            also handed to ``plot_everything``.
        lora_rank, lora_warmup, oles, oles_threshold, seed: Optional filters,
            forwarded unchanged to ``get_wandb_runs_as_dfs`` (``None`` by default).
        start_epoch_for_plots: Forwarded to ``plot_everything`` as its last
            positional argument.

    Returns:
        None.
    """
    run_filters = {
        "state": "finished",
        "meta_info": meta_info,
        "lora_rank": lora_rank,
        "lora_warmup": lora_warmup,
        "oles": oles,
        "oles_threshold": oles_threshold,
        "seed": seed,
    }
    # One dataframe per run, cleaned before aggregation.
    per_run_dfs = clean_dfs(get_wandb_runs_as_dfs(**run_filters))
    averaged, spread = calculate_mean_std(*per_run_dfs)
    plot_everything(averaged, spread, meta_info, start_epoch_for_plots)


In [None]:
# Plot every finished run tagged with entry 7 of `meta_infos`; no extra filters are applied.
plot_everything_for_meta_info(meta_info=meta_infos[7])

In [None]:
# Plot the runs tagged with entry 10 of `meta_infos`, restricted to lora_warmup=4.
plot_everything_for_meta_info(
    meta_info=meta_infos[10],
    lora_warmup=4,
)