In [None]:
import matplotlib.pyplot as plt
import pandas as pd

from wandb_utils import (
    make_wandb_filters,
    fetch_runs,
    get_normalized_arch_values_by_edge,
    get_normalized_arch_values_by_op,
    get_cell_grad_norm,
    get_arch_param_grad_norm,
    get_arch_param_grad_norm_by_edge,
    get_skip_connections,
    get_mean_gradient_matching_score,
    get_benchmark_test_acc,
    get_layer_alignment_scores_all_cells,
    get_layer_alignment_scores_first_and_last_cells,
)

In [None]:
# Build the W&B query: finished runs tagged 'DrNAS-Baseline'. The commented
# keyword arguments are optional extra constraints that can be switched on to
# narrow the selection.
filters = make_wandb_filters(
    state="finished",
    meta_info="'DrNAS-Baseline'",
    # lora_rank=0,
    # lora_warmup=16,
    # oles=True,
    # oles_threshold=0.3,
    # seed=0,
)
print(filters)

runs = fetch_runs(filters)
print(f"Found {len(runs)} runs")

# Order by run name so the listing below is deterministic.
runs = sorted(runs, key=lambda r: r.name)

# NOTE(review): after this loop `run` still refers to the last (name-ordered)
# run, and the original next cell reads it.
for run in runs:
    print(run.name)

In [None]:
# Load the logged history of the last run in the name-ordered listing above.
# Indexing `runs` explicitly (rather than reusing the loop variable that leaked
# out of the previous cell) keeps this cell correct under Restart & Run All.
assert runs, "No runs matched the filters; nothing to inspect."
df = runs[-1].history()
df.head()

In [None]:
def plot_line_chart(
    df,
    title="Beautiful Line Chart",
    xlabel="Index",
    ylabel="Values",
    legend_title="Columns",
    figsize=(10, 6),
):
    """Plot every column of ``df`` as its own line against the index.

    Parameters
    ----------
    df : pandas.DataFrame or pandas.Series
        Data to plot. Each column becomes one line and x values come from the
        index. A Series is treated as a one-column frame.
    title, xlabel, ylabel : str
        Figure title and axis labels. The defaults reproduce the previous
        hard-coded text; pass something descriptive for each chart.
    legend_title : str
        Title shown above the legend.
    figsize : tuple of float
        Figure size in inches, ``(width, height)``.

    Returns
    -------
    None
        The figure is rendered with ``plt.show()``. Nothing is returned, so a
        call on the last line of a cell does not echo an Axes repr.
    """
    # Some helpers may hand back a Series; normalise it so the column loop
    # below works for both shapes.
    if isinstance(df, pd.Series):
        df = df.to_frame()

    fig, ax = plt.subplots(figsize=figsize)

    for column in df.columns:
        ax.plot(df.index, df[column], label=column, linewidth=2)

    ax.set_title(title, fontsize=16)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.tick_params(axis="both", labelsize=10)

    # Only add a legend when there is something to label; with no columns
    # matplotlib would warn that no labelled artists were found.
    if len(df.columns) > 0:
        ax.legend(title=legend_title, fontsize=10, title_fontsize=12)

    ax.grid(True, which="both", linestyle="--", linewidth=0.5)

    fig.tight_layout()
    plt.show()




In [None]:
# Gradient-norm trace of a single cell (cell index 4) over the run history.
new_df = get_cell_grad_norm(df, 4)
plot_line_chart(df=new_df)

In [None]:
# Gradient norms of cells 0..7, joined column-wise so each cell is one line.
# NOTE(review): assumes the searched network has 8 cells; confirm against the
# run config.
dfs = [get_cell_grad_norm(df, cell_idx) for cell_idx in range(8)]
cell_grads_df = pd.concat(dfs, axis=1)
plot_line_chart(cell_grads_df)

In [None]:
# Gradient norms of the architecture parameters for normal vs. reduction cells.
dfs = [get_arch_param_grad_norm(df, cell_type) for cell_type in ("normal", "reduce")]
arch_param_grads_df = pd.concat(dfs, axis=1)
plot_line_chart(arch_param_grads_df)


In [None]:
# Skip-connection series reported by get_skip_connections for normal vs.
# reduction cells.
dfs = [get_skip_connections(df, cell_type) for cell_type in ("normal", "reduce")]
skip_connections_df = pd.concat(dfs, axis=1)
plot_line_chart(skip_connections_df)

In [None]:
# Mean gradient-matching score over the run history. pd.concat(..., axis=1) on
# a single Series yields a one-column DataFrame, which is the shape
# plot_line_chart iterates over.
dfs = [get_mean_gradient_matching_score(df)]
grad_matching_df = pd.concat(dfs, axis=1)
plot_line_chart(grad_matching_df)

In [None]:
# Disabled exploration: normalized architecture values for the normal cell
# (indices 0-7 via get_normalized_arch_values_by_edge) and a per-op variant
# for the reduction cell.
# NOTE(review): the first commented line has an unbalanced ")" after the
# get_normalized_arch_values_by_op(...) call and is a SyntaxError if
# uncommented as-is. Also confirm whether that helper returns a list (as
# pd.concat expects) or a single DataFrame.
# TODO: either re-enable this cell or delete it.
# # pd.concat(get_normalized_arch_values_by_op(df, "reduce", 5)), axis=1)
# dfs = [get_normalized_arch_values_by_edge(df, "normal", idx) for idx in range(8)]
# cell_grads_df = pd.concat(dfs, axis=1)
# plot_line_chart(cell_grads_df)

In [None]:
# Layer-alignment scores for the reduction cells: first across all cells,
# then restricted to the first and last cells. Uses the helper name that is
# actually imported (`..._cells`, plural); the singular form is undefined
# and raised NameError.
plot_line_chart(get_layer_alignment_scores_all_cells(df, "reduce"))
plot_line_chart(get_layer_alignment_scores_first_and_last_cells(df, "reduce"))
