In [27]:
import os 

from wandb_utils import (
    get_normalized_arch_values_by_edge_df,
    get_normalized_arch_values_by_op_df,
    get_cell_grad_norm_df,
    get_arch_param_grad_norm_df,
    get_arch_param_grad_norm_by_edge_df,
    get_skip_connections_df,
    get_mean_gradient_matching_score_df,
    get_benchmark_test_acc_df,
    get_layer_alignment_scores_all_cells_df,
    get_layer_alignment_scores_first_and_last_cells_df,
    get_wandb_runs_as_dfs,
)

from df_utils import (
    calculate_mean_std,
    concat_dfs_with_column_prefixes,
    clean_dfs,
    get_arch_parameters,
    get_darts_genotype,
)

from plot_utils import (
    plot_everything,
    plot_arch_values_by_edges,
    plot_gradient_matching_scores,
    plot_benchmark_test_acc,
    plot_skip_connections,
    plot_layer_alignment_scores,
    plot_line_chart_with_std_dev,
)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [28]:
# Experiment id -> wandb `meta_info` tag used to select runs in the cells below.
# NOTE(review): four tags ("'DARTS-Baseline'", "'DARTS-Prune-Baseline'",
# "'DrNAS-Baseline'", "'DrNAS-Baseline-Basic'") contain literal single quotes
# inside the string; presumably that is how those runs were tagged in wandb —
# confirm before normalising them, since the value is used as a run filter.
# Ids are list positions: append new tags at the end so existing ids stay stable.
meta_infos = dict(enumerate([
    # DARTS
    "'DARTS-Baseline'",                      # 0
    "DARTS-OLES-Threshold-Ablation",         # 1
    "DARTS-LoRA-Rank-Ablation",              # 2
    "DARTS-LoRA-Warmup-Ablation",            # 3
    # DARTS with pruning at 50%
    "'DARTS-Prune-Baseline'",                # 4
    "DARTS-Prune-LoRA-Rank-Ablation",        # 5
    "DARTS-Prune-LoRA-Warmup-Ablation",      # 6
    # DrNAS with partial connections (k=6), pruning at 50%, 8 cells,
    # 36 initial channels, regularization term
    "'DrNAS-Baseline'",                      # 7
    "DrNAS-OLES-Threshold-Ablation",         # 8
    "DrNAS-LoRA-Rank-Ablation",              # 9
    "DrNAS-LoRA-Warmup-Ablation",            # 10
    # DrNAS without progressive learning (no partial connections, no pruning,
    # 8 cells, init channels 16, regularization term)
    "'DrNAS-Baseline-Basic'",                # 11
    "DrNAS-OLES-Threshold-Ablation-Basic",   # 12
    "DrNAS-LoRA-Rank-Ablation-Basic",        # 13
    "DrNAS-LoRA-Warmup-Ablation-Basic",      # 14
    # DARTS on NB201
    "NB201-DARTS-Baseline",                  # 15
    "NB201-DARTS-OLES",                      # 16
    "NB201-DARTS-LoRA",                      # 17
]))

In [None]:
# Baseline DARTS: every finished run tagged meta_infos[0]. No LoRA / OLES / seed
# filters are applied, so all seeds are aggregated into the mean/std frames.
darts_base_dfs = clean_dfs(
    get_wandb_runs_as_dfs(state="finished", meta_info=meta_infos[0])
)
darts_base_mean_df, darts_base_std_df = calculate_mean_std(*darts_base_dfs)

In [None]:
# OLES ablation: finished runs tagged meta_infos[1] with OLES enabled.
# The threshold is 0.3 for DARTS (NB201 runs use 0.7 instead).
oles_darts_base_dfs = clean_dfs(
    get_wandb_runs_as_dfs(
        state="finished",
        meta_info=meta_infos[1],
        oles=True,
        oles_threshold=0.3,
    )
)
oles_darts_base_mean_df, oles_darts_base_std_df = calculate_mean_std(*oles_darts_base_dfs)

In [None]:
# LoRA ablation: finished runs tagged meta_infos[2] with LoRA rank 1 and warm-up 16.
lora_darts_base_dfs = clean_dfs(
    get_wandb_runs_as_dfs(
        state="finished",
        meta_info=meta_infos[2],
        lora_rank=1,
        lora_warmup=16,
    )
)
lora_darts_base_mean_df, lora_darts_base_std_df = calculate_mean_std(*lora_darts_base_dfs)
# Show the aggregated frame's shape as a quick sanity check.
lora_darts_base_mean_df.shape

In [30]:
def fetch_genotypes(
        expected_last_epoch,
        meta_info,
        lora_rank=None,
        lora_warmup=None,
        oles=None,
        oles_threshold=None,
        seed=None,
        has_reduction_cell=True,
):
    """Derive a genotype for each finished wandb run matching the filters and save it.

    Each genotype is printed and written to
    ``genotypes/<meta_info>/<run index>/genotype.txt``, relative to the current
    working directory. Re-running is safe: existing folders are reused and
    existing ``genotype.txt`` files are overwritten.

    Args:
        expected_last_epoch: Forwarded to ``get_arch_parameters`` as
            ``expected_last_epoch``.
        meta_info: wandb ``meta_info`` tag used to select runs. Also used verbatim
            as the output sub-folder name.
        lora_rank, lora_warmup, oles, oles_threshold, seed: Optional run filters
            forwarded unchanged to ``get_wandb_runs_as_dfs`` (``None`` = unset).
        has_reduction_cell: Forwarded to ``get_arch_parameters``. Defaults to
            ``True``, the value previously hard-coded here.

    Returns:
        None. Results are printed and written to disk.
    """
    dfs = get_wandb_runs_as_dfs(
        state="finished",
        meta_info=meta_info,
        lora_rank=lora_rank,
        lora_warmup=lora_warmup,
        oles=oles,
        oles_threshold=oles_threshold,
        seed=seed,
    )

    genotypes = {}
    for idx, df in enumerate(dfs):
        arch_params = get_arch_parameters(
            df,
            has_reduction_cell=has_reduction_cell,
            expected_last_epoch=expected_last_epoch,
        )
        genotype = get_darts_genotype(*arch_params)
        genotypes[idx] = genotype
        print(f"Genotype {idx}: {genotype}")

    # One folder per meta_info, one sub-folder per run index.
    folder = os.path.join("genotypes", meta_info)
    # exist_ok=True so the cell can be re-run (e.g. Restart & Run All) without
    # raising FileExistsError.
    os.makedirs(folder, exist_ok=True)
    print(f"Writing genotypes to {folder}")
    for idx, genotype in genotypes.items():
        run_folder = os.path.join(folder, str(idx))
        os.makedirs(run_folder, exist_ok=True)
        with open(os.path.join(run_folder, "genotype.txt"), "w") as f:
            # str() in case get_darts_genotype returns a non-str object
            # (f.write only accepts str); it is a no-op for str values.
            f.write(str(genotype))


In [None]:
# Export genotypes of the LoRA ablation runs (rank 1, warm-up 16), read at epoch 49.
fetch_genotypes(
    expected_last_epoch=49,
    meta_info=meta_infos[2],
    lora_rank=1,
    lora_warmup=16,
)

In [None]:
# Export genotypes of the OLES ablation runs, read at epoch 49.
# OLES threshold: 0.3 for DARTS (NB201 uses 0.7).
fetch_genotypes(
    expected_last_epoch=49,
    meta_info=meta_infos[1],
    oles=True,
    oles_threshold=0.3,
)


In [None]:
# Export genotypes of the DARTS baseline runs (no LoRA / OLES filters), read at epoch 49.
fetch_genotypes(
    expected_last_epoch=49,
    meta_info=meta_infos[0],
)

In [None]:
# zero_df = pd.DataFrame(np.zeros(darts_base_std_df.shape), columns=darts_base_std_df.columns, index=darts_base_std_df.index)

# seed = 0
# edge = 0
# for edge in range(14):
#     plot_arch_values_by_edges(darts_base_dfs[seed], zero_df, [edge], f"DARTS-Prune-Baseline-Edge-{edge}", cell_types=("normal",), ignore_ops=[0])
#     # plot_arch_values_by_edges(oles_darts_base_dfs[seed], zero_df, [edge], f"OLES-DARTS-Edge-{edge}", cell_types=("normal",),  ignore_ops=[0])
#     plot_arch_values_by_edges(lora_darts_base_dfs[seed], zero_df, [edge], f"LoRA-DARTS-Prune-Edge-{edge}", cell_types=("normal",),  ignore_ops=[0])

In [None]:
# Merge the three runs into one mean frame (m) and one std frame (s).
# Each frame is presumably column-prefixed with its run label by
# concat_dfs_with_column_prefixes. m and s are used by all plotting cells below.
m, s = (
    concat_dfs_with_column_prefixes(
        run_frames,
        ["DARTS-Baseline", "OLES-DARTS", "LoRA-DARTS"],
    )
    for run_frames in (
        [darts_base_mean_df, oles_darts_base_mean_df, lora_darts_base_mean_df],
        [darts_base_std_df, oles_darts_base_std_df, lora_darts_base_std_df],
    )
)

In [None]:
# Gradient matching scores for all three runs, from the combined mean/std frames.
# For a single run, pass its own frames instead, e.g.
# (darts_base_mean_df, darts_base_std_df, "DARTS-Baseline").
plot_gradient_matching_scores(
    m,
    s,
    "Gradient Matching Scores",
)

In [None]:
# Skip-connection counts for all three runs, from the combined mean/std frames.
# For a single run, pass its own frames instead, e.g.
# (oles_darts_base_mean_df, oles_darts_base_std_df, "OLES-DARTS").
plot_skip_connections(
    m,
    s,
    "Skip Connections",
)

In [None]:
# Benchmark test accuracy for all three runs, from the combined mean/std frames.
# For a single run, pass its own frames instead, e.g.
# (lora_darts_base_mean_df, lora_darts_base_std_df, "LoRA-DARTS").
plot_benchmark_test_acc(
    m,
    s,
    "Benchmark Test Accuracy",
)

In [None]:
# Layer alignment scores for all three runs, from the combined mean/std frames.
plot_layer_alignment_scores(
    m,
    s,
    "Layer Alignment Scores",
)