In [119]:
import matplotlib.pyplot as plt
%matplotlib inline
import optuna
import pandas as pd
import numpy as np
from graphviz import Digraph

# Let displayed DataFrames show up to 200 columns, but at most 20 rows.
pd.options.display.max_columns = 200
pd.options.display.max_rows = 20

# Create study object

In [120]:
%cd /workspace/
# Attach to the study stored in the SQLite database under ./result
# (relative to the current working directory). `load_if_exists=True`
# reuses the study named "experiment01" when it is already present
# instead of raising.
study = optuna.create_study(
    study_name='experiment01',
    storage='sqlite:///result/optuna.db',
    direction="minimize",
    load_if_exists=True,
)

[I 2024-07-27 14:23:24,405] Using an existing study with name 'experiment01' instead of creating a new one.


/workspace


# Watch

In [183]:
# Plot the intermediate values recorded for the trials of `study`.
optuna.visualization.plot_intermediate_values(study)

# Graph

In [177]:
# Number of networks per trial: one log directory (net0, net1, ...) is read
# for each of them when collecting accuracies.
num_models = 3
# Row label of epoch_log.csv from which "test_accuracy" is read.
epochs = 10
# NOTE(review): not referenced by any cell visible here — presumably the
# model placed at position 0 in every trial; confirm before relying on it.
target_model = "DeiT_Tiny"

In [178]:
# Load every trial of the study into one table.
df = study.trials_dataframe()

# Fixed bookkeeping columns, followed by every hyper-parameter column
# ('params_*') and every user attribute column ('user_attrs_*').
available_columns = ["number", "state", "value", "datetime_start", "datetime_complete"]
params_columns = [name for name in df.columns if name.startswith('params_')]
user_attrs_columns = [name for name in df.columns if name.startswith('user_attrs_')]
selected_columns = available_columns + params_columns + user_attrs_columns

df = df[selected_columns]

# NOTE(review): the filter on finished trials is disabled and a single
# hard-coded row position (122) is analysed instead.
# complete = df[df["state"] == "COMPLETE"]
complete = df.iloc[122:123]

# Hyper-parameters of the selected trials, and the subset naming the models.
params = complete.loc[:, params_columns]
model = params.filter(like="model", axis=1)

# For each selected trial, read the test accuracy of every network at row
# `epochs` of its per-network epoch log.
acc_columns = [f"model_{i}_acc" for i in range(num_models)]
acc_rows = [
    [
        pd.read_csv(
            f"./result/{trial_number:04d}/log/net{i}/epoch_log.csv",
            index_col="Unnamed: 0",
        ).at[epochs, "test_accuracy"]
        for i in range(num_models)
    ]
    for trial_number in complete["number"]
]
model_acc = pd.DataFrame(acc_rows, index=model.index, columns=acc_columns)

# Rank trials by objective value (ascending) and assemble the summary table:
# objective value, model names, per-model accuracy.
sorted_acc = complete.sort_values(by="value")["value"]
ranking = sorted_acc.index
# The rename targets a column labelled 0; the value column is labelled
# "value", so it leaves the columns unchanged.
sorted_df = pd.concat(
    [sorted_acc, model.loc[ranking], model_acc.loc[ranking]],
    axis=1,
).rename(columns={0: "max_accuracy"})
sorted_df

Unnamed: 0,value,params_model_0_name,params_model_1_name,params_model_2_name,model_0_acc,model_1_acc,model_2_acc
122,,DeiT_Tiny,WRN28_2,WRN28_2,42.24,74.7,41.38


In [179]:
# Display the column-filtered trials table.
df

Unnamed: 0,number,state,value,datetime_start,datetime_complete,params_00_00_gate,params_00_00_loss,params_00_01_gate,params_00_01_loss,params_00_02_gate,params_00_02_loss,params_01_00_gate,params_01_00_loss,params_01_01_gate,params_01_01_loss,params_01_02_gate,params_01_02_loss,params_02_00_gate,params_02_00_loss,params_02_01_gate,params_02_01_loss,params_02_02_gate,params_02_02_loss,params_1_is_pretrained,params_2_is_pretrained,params_model_0_name,params_model_1_name,params_model_2_name,user_attrs_seed
0,0,PRUNED,91.28,2024-07-27 14:20:51.121296,2024-07-27 14:24:01.049416,ThroughGate,IndepLoss,CorrectGate,KLLoss,NegativeLinearGate,KLLoss,CorrectGate,KLLoss,CorrectGate,IndepLoss,ThroughGate,KLLoss,LinearGate,KLLoss,NegativeLinearGate,KLLoss,NegativeLinearGate,IndepLoss,0,1,DeiT_Tiny,DeiT_Small,ResNet32,0
1,1,PRUNED,74.76,2024-07-27 14:21:01.204812,2024-07-27 14:30:07.788592,ThroughGate,IndepLoss,ThroughGate,KLLoss,LinearGate,KLLoss,CutoffGate,KLLoss,CutoffGate,IndepLoss,ThroughGate,KLLoss,ThroughGate,KLLoss,ThroughGate,KLLoss,NegativeLinearGate,IndepLoss,1,0,DeiT_Tiny,DeiT_Tiny,ResNet32,0
2,2,PRUNED,90.88,2024-07-27 14:21:12.359759,2024-07-27 14:23:50.268436,NegativeLinearGate,IndepLoss,CutoffGate,KLLoss,ThroughGate,KLLoss,ThroughGate,KLLoss,ThroughGate,IndepLoss,ThroughGate,KLLoss,ThroughGate,KLLoss,NegativeLinearGate,KLLoss,NegativeLinearGate,IndepLoss,1,0,DeiT_Tiny,ResNet110,ResNet32,0
3,3,PRUNED,83.94,2024-07-27 14:21:21.067865,2024-07-27 14:27:49.837370,ThroughGate,IndepLoss,NegativeLinearGate,KLLoss,LinearGate,KLLoss,NegativeLinearGate,KLLoss,ThroughGate,IndepLoss,LinearGate,KLLoss,CutoffGate,KLLoss,NegativeLinearGate,KLLoss,NegativeLinearGate,IndepLoss,0,1,DeiT_Tiny,DeiT_Small,DeiT_Tiny,0
4,4,RUNNING,,2024-07-27 14:21:31.233980,NaT,CorrectGate,IndepLoss,CutoffGate,KLLoss,ThroughGate,KLLoss,CutoffGate,KLLoss,LinearGate,IndepLoss,CorrectGate,KLLoss,NegativeLinearGate,KLLoss,CorrectGate,KLLoss,CutoffGate,IndepLoss,1,1,DeiT_Tiny,ResNet32,DeiT_Tiny,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
154,154,RUNNING,,2024-07-27 15:24:49.809385,NaT,NegativeLinearGate,IndepLoss,ThroughGate,KLLoss,NegativeLinearGate,KLLoss,CorrectGate,KLLoss,CorrectGate,IndepLoss,CorrectGate,KLLoss,CorrectGate,KLLoss,NegativeLinearGate,KLLoss,LinearGate,IndepLoss,1,1,DeiT_Tiny,ResNet32,DeiT_Tiny,0
155,155,FAIL,,2024-07-27 15:26:15.835555,2024-07-27 15:26:19.161836,CutoffGate,IndepLoss,CutoffGate,KLLoss,CutoffGate,KLLoss,CorrectGate,KLLoss,ThroughGate,IndepLoss,NegativeLinearGate,KLLoss,CutoffGate,KLLoss,ThroughGate,KLLoss,NegativeLinearGate,IndepLoss,0,0,DeiT_Tiny,WRN28_2,WRN28_2,0
156,156,RUNNING,,2024-07-27 15:26:38.060902,NaT,ThroughGate,IndepLoss,NegativeLinearGate,KLLoss,CorrectGate,KLLoss,LinearGate,KLLoss,CorrectGate,IndepLoss,CutoffGate,KLLoss,CutoffGate,KLLoss,CorrectGate,KLLoss,LinearGate,IndepLoss,1,0,DeiT_Tiny,DeiT_Small,ResNet110,0
157,157,RUNNING,,2024-07-27 15:27:06.539494,NaT,LinearGate,IndepLoss,NegativeLinearGate,KLLoss,ThroughGate,KLLoss,ThroughGate,KLLoss,CutoffGate,IndepLoss,ThroughGate,KLLoss,ThroughGate,KLLoss,LinearGate,KLLoss,ThroughGate,IndepLoss,0,0,DeiT_Tiny,DeiT_Small,WRN28_2,0


In [180]:
import pandas as pd
import numpy as np
from graphviz import Digraph

def generate_graph(params, sorted_df, model_acc, sorted_acc, top=0):
    """Build, render and return the Graphviz graph of one ranked trial.

    Parameters
    ----------
    params : pandas.DataFrame
        One row per trial. Columns whose name contains "loss", "gate",
        "model" and "is_pretrained" are read. Gate columns must be named
        ``<prefix>_<i>_<j>_gate``; the two integers address a cell of the
        gate grid.
    sorted_df : pandas.DataFrame
        Ranked trials; ``sorted_df.index[top]`` selects the row of `params`.
    model_acc : pandas.DataFrame
        Per-trial accuracy of each model, one column per model.
    sorted_acc : pandas.Series
        Ranked objective values; only its index is used, to order `model_acc`.
    top : int, default 0
        Position of the trial in the ranking. Also used as the output file name.

    Returns
    -------
    graphviz.Digraph
        The graph, after it has been rendered to ``./topn_graph/<top>.pdf``.

    Raises
    ------
    KeyError
        If a drawn edge uses a gate name that has no entry in the colour table.
    """

    def trial_values(keyword):
        # Values of the selected trial for every column containing `keyword`.
        return params.loc[:, params.columns.str.contains(keyword)].loc[sorted_df.index[top]]

    # Loss
    loss = trial_values("loss")
    model_name = trial_values("model")

    # The loss entries form a square grid with one row/column per model.
    wh = int(np.sqrt(len(loss)))
    df_loss = pd.DataFrame(loss.values.reshape((wh, wh)), columns=model_name, index=model_name)

    # Gate: start from an empty grid labelled like df_loss, then fill it from
    # the gate columns. Built directly instead of via the deprecated
    # DataFrame.applymap.
    gate_values = trial_values("gate")
    df_gate = pd.DataFrame(
        np.full(df_loss.shape, None, dtype=object),
        index=df_loss.index,
        columns=df_loss.columns,
    )

    # is_pretrained (one flag per model except the first one)
    df_is_pretrained = trial_values("is_pretrained")

    # NOTE(review): the first integer of the column name selects the grid row
    # and the second the grid column; the edge loop below reads the grid as
    # [target, source], so the first integer acts as the edge target.
    for gate_name, val in gate_values.to_dict().items():
        _, row_pos, col_pos, _ = gate_name.split("_")
        df_gate.iloc[int(row_pos), int(col_pos)] = val

    # "1"/"2"/"3" are indices into the "dark28" colour scheme set on the edges.
    edge_color = {
        "ThroughGate": "3",
        "LinearGate": "1",
        "CorrectGate": "2",
        "NegativeLinearGate": "#4682B4",
    }

    G = Digraph(format="pdf", engine="dot")

    acc = model_acc.loc[sorted_acc.index].iloc[top]

    def node_id(pos):
        # Node identifier/label: "<1-based position>. <model name> (<accuracy>%)".
        # `acc` is labelled by column name, so it is indexed by position
        # explicitly; plain `acc[pos]` relies on the deprecated positional
        # fallback of Series.__getitem__.
        return f"{pos+1}. " + df_loss.index[pos] + f" ({acc.iloc[pos]}%)"

    def pretrained_flag(pos):
        # The first model has no is_pretrained column and counts as not pretrained.
        return df_is_pretrained.iloc[pos - 1] if pos != 0 else 0

    num_nodes = len(df_loss)

    # All models start as gray nodes; the first one is then restyled in pink.
    for target in range(num_nodes):
        G.node(node_id(target), color='gray90', fillcolor='gray90', style='filled')
    G.node(node_id(0), color='pink', fillcolor='pink', style='radial')

    for target in range(num_nodes):
        is_pretrained = pretrained_flag(target)
        for source in range(num_nodes):
            # A pretrained target receives no incoming edge at all.
            if is_pretrained:
                edge_gate = "CutoffGate"
            else:
                edge_gate = df_gate.iloc[target, source]
            if edge_gate == "CutoffGate":
                continue

            label = edge_gate.replace("Gate", "")
            if source == target:
                # Diagonal cells become an edge from a separate "Label" node;
                # CorrectGate is drawn as ThroughGate there.
                if edge_gate == "CorrectGate":
                    edge_gate = "ThroughGate"
                    label = "Through"
                G.edge(f"{target}",
                       node_id(target),
                       label=label, fontsize="13", fontcolor=edge_color[edge_gate],
                       color=edge_color[edge_gate], colorscheme="dark28")
                G.node(f"{target}", label="Label", color='white', style='filled')
            else:
                G.edge(node_id(source),
                       node_id(target),
                       label=label, fontsize="13", fontcolor=edge_color[edge_gate],
                       color=edge_color[edge_gate], colorscheme="dark28")

    # Models that receive no input (whole row cut off, or pretrained) are
    # highlighted in light blue.
    for target in range(num_nodes):
        is_pretrained = pretrained_flag(target)
        if (df_gate.iloc[target] == "CutoffGate").all() or is_pretrained:
            G.node(node_id(target),
                   color='lightblue', fillcolor='lightblue', style='radial')

    G.render(filename=f"{top}", directory="./topn_graph", cleanup=True, format="pdf")
    return G

In [182]:
from IPython.display import display, HTML
from graphviz import Source

def display_multiple_graphs(params, sorted_df, model_acc, sorted_acc, num_graphs=3):
    """Render the first `num_graphs` ranked trials and show them in a styled grid.

    Every graph is produced by ``generate_graph`` (ranks 0 .. num_graphs-1),
    converted to SVG, wrapped in a titled card ("Top <rank+1> KTG") and the
    whole page fragment is shown through IPython's ``display(HTML(...))``.
    The first four arguments are forwarded unchanged to ``generate_graph``.
    Nothing is returned.
    """
    rendered = [
        generate_graph(params, sorted_df, model_acc, sorted_acc, top=rank)
        for rank in range(num_graphs)
    ]

    # Style sheet emitted in front of the cards.
    css_block = """
    <style>
        @import url('https://fonts.googleapis.com/css2?family=Poppins:wght@300;400;600&display=swap');

        body {
            background-color: #0f0e17;
            color: #fffffe;
            font-family: 'Poppins', sans-serif;
        }
        .graph-container {
            display: grid;
            grid-template-columns: repeat(auto-fit, minmax(400px, 1fr));
            gap: 40px;
            padding: 40px;
            background: linear-gradient(135deg, #0f0e17 0%, #232946 100%);
            border-radius: 20px;
            box-shadow: 0 20px 40px rgba(0, 0, 0, 0.4);
        }
        .graph-item {
            background: linear-gradient(45deg, #232946, #2e3a5c);
            border-radius: 15px;
            box-shadow: 0 10px 20px rgba(0, 0, 0, 0.2);
            padding: 25px;
            transition: all 0.4s cubic-bezier(0.175, 0.885, 0.32, 1.275);
            border: 2px solid #4a4e69;
            overflow: hidden;
            position: relative;
        }
        .graph-item::before {
            content: '';
            position: absolute;
            top: -50%;
            left: -50%;
            width: 200%;
            height: 200%;
            background: radial-gradient(circle, rgba(255,255,255,0.1) 0%, rgba(255,255,255,0) 70%);
            transform: scale(0);
            transition: transform 0.6s ease-out;
        }
        .graph-item:hover {
            transform: translateY(-15px) scale(1.03);
            box-shadow: 0 15px 30px rgba(0, 0, 0, 0.3);
            border-color: #ff8906;
        }
        .graph-item:hover::before {
            transform: scale(1);
        }
        .graph-title {
            font-family: 'Poppins', sans-serif;
            color: #ff8906;
            text-align: center;
            margin-bottom: 20px;
            font-size: 24px;
            font-weight: 600;
            letter-spacing: 1.5px;
            text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
        }
        .graph-svg {
            width: 100%;
            height: auto;
            display: flex;
            justify-content: center;
            align-items: center;
        }
        .graph-svg svg {
            max-width: 100%;
            height: auto;
            background-color: #fffffe;
            border-radius: 10px;
            padding: 15px;
            box-shadow: 0 5px 15px rgba(0, 0, 0, 0.1);
        }
    </style>
    """

    # Collect the fragments and join them once at the end.
    pieces = [css_block, '<div class="graph-container">']
    for position, dot_graph in enumerate(rendered, start=1):
        svg_markup = Source(dot_graph.source).pipe(format='svg').decode('utf-8')
        # Strip zero-sized width/height attributes so the CSS sizing applies.
        svg_markup = svg_markup.replace('width="0"', '').replace('height="0"', '')
        pieces.append('<div class="graph-item">')
        pieces.append(f'<div class="graph-title">Top {position} KTG</div>')
        pieces.append(f'<div class="graph-svg">{svg_markup}</div>')
        pieces.append('</div>')
    pieces.append('</div>')

    display(HTML(''.join(pieces)))

# Draw one graph for every row of the ranked summary table.
display_multiple_graphs(params, sorted_df, model_acc, sorted_acc, num_graphs=len(sorted_df))