In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
import optuna
import pandas as pd
import numpy as np
from graphviz import Digraph

# Notebook display settings: show up to 200 columns but cap rendered rows at 20.
for option_name, option_value in (("display.max_columns", 200), ("display.max_rows", 20)):
    pd.set_option(option_name, option_value)

# Create study object

In [None]:
# Open the SQLite-backed study "experiment01" (reusing it if it already exists,
# via load_if_exists); its objective value is minimized.
study = optuna.create_study(
    study_name='experiment01',
    storage='sqlite:///result/optuna.db',
    direction="minimize",
    load_if_exists=True,
)

# Watch intermediate values

In [None]:
# Curves of the intermediate values each trial reported (via `trial.report`).
intermediate_values_fig = optuna.visualization.plot_intermediate_values(study)
intermediate_values_fig

# Graph

In [None]:
# Experiment settings used by the cells below:
#   num_models   - networks per trial; logs are read from log/net0 .. log/net{num_models-1}
#   epochs       - row label read from each epoch_log.csv as the final accuracy
#   target_model - NOTE(review): not referenced by any cell shown here
num_models, epochs, target_model = 3, 200, "ResNet32"

In [None]:
# Results table: one row per completed trial, sorted by objective value.
# NOTE(review): `complete["params"]` below selects a column *group*, which relies on
# `trials_dataframe()` returning MultiIndex columns (the default in older Optuna
# releases) — confirm against the installed Optuna version.
trials_raw = study.trials_dataframe()
df = trials_raw[["number", "state", "value", "datetime_start", "datetime_complete", "params", "user_attrs"]]

# Keep only finished trials; their per-network epoch logs are read below.
complete = df[df["state"] == optuna.structs.TrialState.COMPLETE]
params = complete["params"]

# Hyper-parameters whose name contains "model" (presumably the architecture of each network).
model = params.loc[:, params.columns.str.contains("model")]

# Final test accuracy of every network in every completed trial, read from
# ./result/<trial number, 4 digits>/log/net<i>/epoch_log.csv.
# NOTE(review): `.at[epochs, ...]` reads the row *labelled* `epochs`, which assumes the
# log index reaches `epochs` (e.g. 1-based epochs); a 0-based log of `epochs` epochs
# would end at `epochs - 1`. Confirm against the training code.
model_acc = pd.DataFrame(
    [
        [
            pd.read_csv(f"./result/{trial_number:04d}/log/net{i}/epoch_log.csv",
                        index_col="Unnamed: 0").at[epochs, "test_accuracy"]
            for i in range(num_models)
        ]
        for trial_number in complete["number"]
    ],
    index=model.index,
    columns=[f"model_{i}_acc" for i in range(num_models)],
)

# Ascending objective value (the study minimizes it), so the first row is the best trial.
sorted_df = complete.sort_values(by="value")
sorted_acc = sorted_df["value"]

# Header the objective column "max_accuracy". The Series is named "value", so a
# `rename(columns={0: ...})` on the concatenated frame never matched anything;
# renaming the Series itself makes the header change take effect.
# NOTE(review): the study *minimizes* this value — confirm it really is an accuracy.
sorted_df = pd.concat(
    [
        sorted_acc.rename("max_accuracy"),
        model.loc[sorted_acc.index],
        model_acc.loc[sorted_acc.index],
    ],
    axis=1,
)
sorted_df

In [None]:
# Rank of the trial to inspect below: a position in `sorted_df`, which is ordered by
# ascending objective value (the study minimizes it), so 0 selects the best trial.
top = 0

In [None]:
### Loss
# Inspect the trial at position `top` of `sorted_df` (0 = lowest objective value).
trial_idx = sorted_df.index[top]
loss = params.loc[:, params.columns.str.contains("loss")].loc[trial_idx]
model_name = params.loc[:, params.columns.str.contains("model")].loc[trial_idx]

# The loss hyper-parameters are reshaped (row-major) into a square matrix whose rows
# and columns are labelled by model; the drawing loop below treats a row as the
# target node and a column as the source node.
wh = int(np.sqrt(len(loss)))
if wh * wh != len(loss):
    raise ValueError(f"expected a square number of loss parameters, got {len(loss)}")
df_loss = pd.DataFrame(loss.values.reshape((wh, wh)), columns=model_name, index=model_name)
# A bare expression mid-cell is not rendered by Jupyter; display it explicitly.
display(df_loss)

### Gate
gate_params = params.loc[:, params.columns.str.contains("gate")].loc[trial_idx]
# Empty object-dtype matrix with the same labels as df_loss, filled with gate types.
df_gate = pd.DataFrame(index=df_loss.index, columns=df_loss.columns, dtype=object)

for gate_name, gate_type in gate_params.items():
    # Gate names split into three "_"-separated parts; the first two are integer
    # positions stored as (row, column). The drawing loop reads row = target,
    # column = source, so the first index is the *target* model.
    row, col, _ = gate_name.split("_")
    df_gate.iloc[int(row), int(col)] = gate_type

# Gate type -> colour index in Graphviz's "dark28" colour scheme.
# CutoffGate has no entry because it draws no edge.
edge_color = {
    "ThroughGate": "3",
    "LinearGate": "1",
    "CorrectGate": "2",
}

G = Digraph(format="pdf", engine="dot")

# Final test accuracy of each network in the selected trial (labels model_<i>_acc).
acc = model_acc.loc[sorted_acc.index].iloc[top]


def node_name(i):
    """Graphviz node id (also its visible label) for model position `i`."""
    # `acc.iloc[i]` is explicit positional access; `acc[i]` on this string-labelled
    # Series relied on a positional fallback that pandas 2.x deprecates.
    return f"{i + 1}. {df_loss.index[i]} ({acc.iloc[i]}%)"


n_models = len(df_loss)

# Declare every model node in grey, then re-declare model 0 in pink to highlight it.
for i in range(n_models):
    G.node(node_name(i), color='gray90', fillcolor='gray90', style='filled')
G.node(node_name(0), color='pink', fillcolor='pink', style='radial')

# One edge per non-cut-off gate, from column `source` to row `target`.
for target in range(n_models):
    for source in range(n_models):
        gate_type = df_gate.iloc[target, source]
        if gate_type == "CutoffGate":
            continue
        label = gate_type.replace("Gate", "")
        if source == target:
            # Diagonal gates are drawn from a separate node labelled "Label";
            # a CorrectGate there is drawn as a Through edge instead.
            if gate_type == "CorrectGate":
                gate_type, label = "ThroughGate", "Through"
            G.edge(f"{target}", node_name(target),
                   label=label, fontsize="13", fontcolor=edge_color[gate_type],
                   color=edge_color[gate_type], colorscheme="dark28")
            G.node(f"{target}", label="Label", color='white', style='filled')
        else:
            G.edge(node_name(source), node_name(target),
                   label=label, fontsize="13", fontcolor=edge_color[gate_type],
                   color=edge_color[gate_type], colorscheme="dark28")

# A model whose whole row is CutoffGate receives no incoming edge; highlight it in blue.
for target in range(n_models):
    if (df_gate.iloc[target] == "CutoffGate").all():
        G.node(node_name(target), color='lightblue', fillcolor='lightblue', style='radial')

G.render(filename=f"{top}", directory="./topn_graph", cleanup=True, format="pdf")
G