In [1]:
import itertools
import time

import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots


In [2]:
# from IPython import display
def calc_rows_and_columns(
    n_plots:int,
    nrows:int|None,
    ncols:int|None,
) -> tuple[int, int]:
    is_valid_nrows = nrows is not None and nrows>0
    is_valid_ncols = ncols is not None and ncols>0
    
    if not is_valid_nrows and not is_valid_ncols:
        nrows = ncols = int(np.ceil(np.sqrt(n_plots)))
    else:
        nrows = nrows if is_valid_nrows else int(np.ceil(n_plots/ncols))
        ncols = ncols if is_valid_ncols else int(np.ceil(n_plots/nrows))
        
    return nrows, ncols

    
def setup_figure_layout(
    fig, 
    nrows:int,
    ncols:int,
    trace_names:list[str|list[str]],
)-> None:
    for i, ((row, col), names) in enumerate(zip(itertools.product(range(1,nrows+1),range(1,ncols+1)), trace_names, strict=False), start=1):
        for name in [names,] if not isinstance(names, list) else names:
            fig.add_trace(go.Scatter(x=[], y=[], name=name), row=row, col=col)


        legend_name = f"legend{i+2}"
        axis_num = str(i) if i > 1 else ""
        fig.update_traces(row=row, col=col, legend=legend_name)
        fig.update_layout({
            legend_name: dict(
                x=fig.layout["xaxis"+axis_num].domain[0], 
                y=fig.layout["yaxis"+axis_num].domain[1], 
                xanchor='left', 
                yanchor="top", 
                bgcolor='rgba(0,0,0,0)'
            )
        })


class PlotlyLogger:
    
    def __init__(
        self, 
        observable_plots:list[str|list[str]],
        nrows:int|None = None, 
        ncols:int|None = None,
        width_px:int|None=None,
        height_px:int|None=None,
    ):
        nrows, ncols = calc_rows_and_columns(len(observable_plots), nrows, ncols)
        self.fig = go.FigureWidget(make_subplots(rows=nrows, cols=ncols))
        if width_px is not None or height_px is not None:
            self.fig.update_layout(
                autosize=False,
                width=width_px,
                height=height_px,
            )
        
        self.observable_plots = observable_plots
        setup_figure_layout(self.fig, nrows, ncols, observable_plots)
    
    def log_metric(self, name:str, value:float, step:int) -> None:
        # get the correct object, update x and y
        graph_obj = next(filter(lambda f: f.name==name, self.fig.data))
        graph_obj.x += (step,)
        graph_obj.y += (value,)


    def change_size(
        self,
        width_px:int|None = None,
        height_px:int|None = None,
    ):
        self.fig.update_layout(
            autosize=False,
            width=width_px,
            height=height_px,
        )

    def show(self):
        return self.fig

In [3]:
logger = PlotlyLogger(
    ["q0", "q1", ["train loss", "test loss"]],
)
logger.change_size(width_px=600, height_px=800)
logger.show()

['q0', 'q1', ['train loss', 'test loss']]
1 1 q0
1 2 q1
2 1 train loss
2 1 test loss


FigureWidget({
    'data': [{'legend': 'legend3',
              'name': 'q0',
              'type': 'scatter',
              'uid': '5ba5884b-4123-4df3-bd0f-d6d5e8b72fba',
              'x': [],
              'xaxis': 'x',
              'y': [],
              'yaxis': 'y'},
             {'legend': 'legend4',
              'name': 'q1',
              'type': 'scatter',
              'uid': '0390e73d-3843-40b9-88c6-7da01ff82740',
              'x': [],
              'xaxis': 'x2',
              'y': [],
              'yaxis': 'y2'},
             {'legend': 'legend5',
              'name': 'train loss',
              'type': 'scatter',
              'uid': 'db018544-3e0b-4963-8d1b-4de19d86e268',
              'x': [],
              'xaxis': 'x3',
              'y': [],
              'yaxis': 'y3'},
             {'legend': 'legend5',
              'name': 'test loss',
              'type': 'scatter',
              'uid': '06916fd8-b270-4014-a003-30bb92ea33bd',
              'x': [],
      

In [4]:
for i in range(30):
    for name in ["q0", "q1", "train loss", "test loss"]:
        if name.startswith("test") and i%5>0:
            continue
        logger.log_metric(name, np.random.normal(), i)
    time.sleep(1)

In [15]:
# adding labels

In [5]:
import ipywidgets as widgets

In [6]:
oxdna_sim = widgets.Button(description="oxDNA-1", button_style="warning", icon="hourglass-half")
prog_bar = widgets.IntProgress(value=1, description="Optimizing", bar_style="warning", orientation="horizontal")

In [7]:
def update(btn):
    if btn.button_style=="success":
        btn.button_style="warning"
        btn.icon="hourglass-half"
    else:
        btn.button_style="success"
        btn.icon="check"

oxdna_sim.on_click(update)
oxdna_sim.__dir__()

['_trait_values',
 '_trait_notifiers',
 '_trait_validators',
 '_cross_validation_lock',
 '_model_id',
 '_click_handlers',
 '__module__',
 '__doc__',
 '_view_name',
 '_model_name',
 'description',
 'disabled',
 'icon',
 'button_style',
 'style',
 '__init__',
 '_validate_icon',
 'on_click',
 'click',
 '_handle_button_msg',
 '_trait_default_generators',
 '_all_trait_default_generators',
 '_traits',
 '_static_immutable_initial_values',
 '_descriptors',
 '_instance_inits',
 '__annotations__',
 '_dom_classes',
 'tabbable',
 'tooltip',
 'layout',
 'add_class',
 'remove_class',
 'focus',
 'blur',
 '_repr_keys',
 '_widget_construction_callback',
 '_control_comm',
 'widgets',
 '_active_widgets',
 '_widget_types',
 'widget_types',
 'close_all',
 'on_widget_constructed',
 '_call_widget_constructed',
 'handle_control_comm_opened',
 '_handle_control_comm_msg',
 'handle_comm_opened',
 'get_manager_state',
 '_get_embed_state',
 'get_view_spec',
 '_model_module',
 '_model_module_version',
 '_view_modul

In [96]:
import functools
import random
from enum import Enum
from pathlib import Path

import ipywidgets as widgets

LBL_TOP_HEADER = "Optimization Status"
LBL_PROG_BAR = "Optimizing"
LBL_SIM_HEADER = "Simulators"
LBL_OBS_HEADER = "Observables"
LBL_OBJ_HEADER = "Objectives"

class Status(Enum):
    STARTED = 0
    RUNNING = 1
    COMPLETE = 2
    ERROR = 3


class JupyterLogger:

    STATUS_STYLE = {
        Status.STARTED: {"button_style":"primary", "icon":""},
        Status.RUNNING: {"button_style":"info", "icon":"hourglass-half"},
        Status.COMPLETE: {"button_style":"success", "icon":"check"},
        Status.ERROR: {"button_style":"danger", "icon":"exclamation"},
    }
    
    def __init__(
        self,
        simulators: list[str],
        observables: list[str],
        objectives: list[str],
        metrics_to_log: list[list[str]|str],
        max_opt_steps:int,
        plots_size_px:tuple[int,int]|None = None,
        plots_nrows_ncols:tuple[int,int]|None = None,
        log_dir:str|Path|None = None,
    ):
        #super().__init__(log_dir)

        self.prog_bar = widgets.IntProgress(
            min=0, 
            max=max_opt_steps,
            description=LBL_PROG_BAR,
            bar_style="info", 
            orientation="horizontal"
        )

        btn_f = functools.partial(
            widgets.Button, 
            disabled=True,
            **JupyterLogger.STATUS_STYLE[Status.STARTED],
        )
        
        self.sim_btns = [btn_f(description=sim, ) for sim in simulators]
        self.obs_btns = [btn_f(description=obs) for obs in observables]
        self.obj_btns = [btn_f(description=obj) for obj in objectives]

        nrows, ncols = plots_nrows_ncols if plots_nrows_ncols else (None, None)
        width_px, height_px = plots_size_px if plots_size_px else (None, None)

        self.plots = PlotlyLogger(
            metrics_to_log,
            nrows = nrows,
            ncols = ncols,
            width_px = width_px,
            height_px = height_px,
        )

        self.dashboard = widgets.VBox([
            widgets.Label(value=LBL_TOP_HEADER),
            self.prog_bar,
            widgets.HBox([
                widgets.VBox([
                    widgets.Label(value=LBL_SIM_HEADER),
                    *self.sim_btns,
                ]),
                widgets.VBox([
                    widgets.Label(value=LBL_OBS_HEADER),
                    *self.obs_btns,
                ]),
                widgets.VBox([
                    widgets.Label(value=LBL_OBJ_HEADER),
                    *self.obj_btns,
                ]),
            ]),
            self.plots.show(),
        ])


    def show(self) -> widgets.DOMWidget:
        return self.dashboard

    def increment_prog_bar(self, value:int=1) -> None:
        self.prog_bar.value += value

    def log_metric(self, name:str, value:float, step:int) -> None:
        self.plots.log_metric(name, value, step)


    def _update_status(self, btns:list[widgets.Button], name:str, status:Status) -> None:
        next(filter(lambda btn:btn.description==name, btns)).set_state(JupyterLogger.STATUS_STYLE[status])

    def update_simulator_status(self, name:str, status:Status) -> None:
        self._update_status(self.sim_btns, name, status)
    
    def set_simulator_started(self, name:str) -> None:
        self._update_simulator_status(name, Status.STARTED)
    
    def set_simulator_running(self, name:str) -> None:
        self._update_simulator_status(name, Status.RUNNING)

    def set_simulator_complete(self, name:str) -> None:
        self._update_simulator_status(name, Status.COMPELTE)

    def set_simulator_running(self, name:str) -> None:
        self._update_simulator_status(name, Status.ERROR)

    
    def update_objective_status(self, name:str, status:Status) -> None:
        self._update_status(self.obj_btns, name, status)

    def set_objective_started(self, name:str) -> None:
        self._update_objective_status(name, Status.STARTED)
    
    def set_objective_running(self, name:str) -> None:
        self._update_objective_status(name, Status.RUNNING)

    def set_objective_complete(self, name:str) -> None:
        self._update_objective_status(name, Status.COMPELTE)

    def set_objective_running(self, name:str) -> None:
        self._update_objective_status(name, Status.ERROR)

    
    def update_observable_status(self, name:str, status:Status) -> None:
        self._update_status(self.obs_btns, name, status)

    def set_observable_started(self, name:str) -> None:
        self._update_objective_status(name, Status.STARTED)
    
    def set_observable_running(self, name:str) -> None:
        self._update_objective_status(name, Status.RUNNING)

    def set_observable_complete(self, name:str) -> None:
        self._update_objective_status(name, Status.COMPELTE)

    def set_observable_running(self, name:str) -> None:
        self._update_objective_status(name, Status.ERROR)



        

In [97]:
logger = JupyterLogger(
    simulators = ["oxDNA-1", "oxDNA-2", "jaxmd-1"],
    observables = ["oxDNA-1-traj", "oxDNA-2-traj", "ptwist", "dptwist_dopt"],
    objectives = ["ptwist", ],
    metrics_to_log = ["obs-1", "obs-2", ["loss-1", "test-loss-1"], ["param-3", "param-4"]],
    max_opt_steps = 100,
    plots_size_px = None,
    plots_nrows_ncols = None,
    log_dir = None,
)

logger.show()

VBox(children=(Label(value='Optimization Status'), IntProgress(value=0, bar_style='info', description='Optimiz…

In [98]:
import random
statuses = [
    Status.RUNNING,
    Status.COMPLETE,
    Status.ERROR,
]

for i in range(1, 50):
    
    logger.update_simulator_status(random.choice(["oxDNA-1", "oxDNA-2", "jaxmd-1"]), random.choice(statuses))
    logger.update_objective_status(random.choice(["ptwist",]), random.choice(statuses))
    logger.update_observable_status(random.choice(["oxDNA-1-traj", "oxDNA-2-traj", "ptwist", "dptwist_dopt"]), random.choice(statuses))
    
    for metric in ["obs-1", "obs-2", "loss-1", "test-loss-1", "param-3", "param-4"]:
        if name.startswith("test") and i%5>0:
            continue
        logger.log_metric(metric, np.random.normal(), i)
    logger.increment_prog_bar()
    
    time.sleep(0.5)

In [8]:
widgets.VBox([
    widgets.Label(value="Optimization Status"),
    prog_bar,
    widgets.HBox([
        widgets.VBox([
            widgets.Label(value="Simulators"),
            oxdna_sim,
        ]),
        widgets.VBox([
            widgets.Label(value="Observables"),
            widgets.Label(value="ptwist"),            
            widgets.Label(value="trajectory"),
        ]),
        widgets.VBox([
            widgets.Label(value="Objectives"),
            widgets.Label(value="ptwist"),            
        ]),
    ]),
    logger.show(),
])



In [12]:
oxdna_sim.set_state(dict(button_style="warning"))