In [None]:
class Plots:
    """Plotly helpers for building, displaying and live-updating figures.

    Every helper is static; the class is used as a namespace. Third-party
    modules are imported into the class body, so inside methods they must be
    reached through ``Plots`` (``Plots.np``, ``Plots.go``, ``Plots.display``):
    class-body names are not visible as bare names from method bodies.
    """

    import ipywidgets as widgets
    from IPython.display import display
    import plotly.offline as pyoff
    import plotly.graph_objects as go
    from plotly.subplots import make_subplots
    import plotly.colors as colors
    import plotly.io as pio
    import numpy as np

    pio.renderers.default = "notebook"

    # Human-readable plot names keyed by the ``plot`` tag that the
    # ``scatter`` / ``histogram`` decorators attach and that ``_plotGraph``
    # dispatches on (0 -> graphScatter, 1 -> graphHistogram,
    # 2 -> graphVariational). The label for 2 is assumed from the builder's
    # name -- TODO confirm.
    PLOT_MAPPING = {
        0: "Scatter",
        1: "Histogram",
        2: "Variational",
    }

    DEFAULT_COLORS = colors.DEFAULT_PLOTLY_COLORS
    LEN_DEFAULT_COLORS = len(DEFAULT_COLORS)

    @staticmethod
    def histogram(func):
        """Decorator: tag ``func`` so ``_plotGraph`` renders it as a histogram."""
        func.plot = 1
        return func

    @staticmethod
    def scatter(func):
        """Decorator: tag ``func`` so ``_plotGraph`` renders it as a scatter."""
        func.plot = 0
        return func

    @staticmethod
    def _appendPoints(existing, new):
        """Return ``existing`` (possibly None) as a list extended by ``new``.

        ``new`` may be a single scalar (Python or numpy number) or an iterable
        of values.
        """
        existing = () if existing is None else existing
        if Plots.np.ndim(new) == 0:
            return [*existing, new]
        return [*existing, *new]

    @staticmethod
    def updateLiveScatter(fig, x, y):
        """Append point(s) to the first trace of ``fig`` and return the figure.

        If ``fig`` is not already a ``go.FigureWidget`` it is wrapped in one
        and displayed, so later updates redraw that widget in place. Keep the
        returned figure and pass it back on the next call.

        Args:
            fig: A plotly figure whose ``data[0]`` has ``x``/``y`` arrays.
            x: A scalar or an iterable of x values to append.
            y: A scalar or an iterable of y values to append.

        Returns:
            The ``go.FigureWidget`` that was updated.
        """
        if not isinstance(fig, Plots.go.FigureWidget):
            fig = Plots.go.FigureWidget(fig)
            Plots.display(fig)

        scatter = fig.data[0]
        # x and y are inspected independently: one may be a scalar while the
        # other is an iterable, and any numeric scalar type is accepted.
        xs = Plots._appendPoints(scatter.x, x)
        ys = Plots._appendPoints(scatter.y, y)

        # batch_update applies both assignments in a single front-end update.
        with fig.batch_update():
            scatter.x = xs
            scatter.y = ys

        return fig

    @staticmethod
    def _callableSync(obj, update_dict):
        """Invoke a deferred callable, overriding its keywords from ``update_dict``.

        If ``obj`` has a truthy ``_callable`` attribute, it is called with its
        stored ``_args`` (unpacked positionally) and a copy of its stored
        ``_kwargs``. Any stored keyword that is also a key of ``update_dict``
        takes its value from ``update_dict`` instead (``createGraph`` passes
        its ``locals()``). Objects without ``_callable`` are returned as-is.
        """
        if not getattr(obj, "_callable", False):
            return obj

        args = getattr(obj, "_args", ())
        kwds = dict(getattr(obj, "_kwargs", {}))
        for key in set(kwds) & set(update_dict):
            kwds[key] = update_dict[key]
        return obj(*args, **kwds)

    @staticmethod
    def createGraph(graph_parameters, display=True, **kwargs):
        """Build (and optionally show) a plotly figure from a parameter dict.

        graph_parameters = {
            "traces": trace_list,          # a single trace, or a rows x cols
                                           # nested list (cells may be None
                                           # or a list of traces)
            "layout": layout_dict,
            "fig_type": fig_type_str,      # "Widget" -> go.FigureWidget
            "fig_parameters": dict | None, # kwargs for make_subplots/Figure
            "fig_functions": {"function_name": "parameters"},
            **optional_parameters          # e.g. "container", "functions"
        }

        Args:
            graph_parameters: Figure description (see above). Not mutated.
            display: If True, call ``fig.show()``; for ``fig_type == "Widget"``
                the ``container`` callable is applied to the figure and its
                result is shown with IPython's ``display``.
            **kwargs: Extra entries merged over ``graph_parameters``.

        Returns:
            The constructed ``go.Figure`` (or ``go.FigureWidget``).

        Raises:
            ValueError: if ``fig_type == "Widget"`` and no ``container`` given.
        """
        # NOTE: local variable names here are visible to deferred callables
        # through ``locals()`` (see ``_callableSync``), so keep them stable.
        # Merge into a new dict so the caller's mapping is left untouched.
        graph_parameters = {**graph_parameters, **kwargs}

        traces = Plots.np.array(graph_parameters["traces"], dtype=object)

        functions = graph_parameters.get("functions", None)
        fig_functions = graph_parameters.get("fig_functions", None)
        fig_type = graph_parameters.get("fig_type", None)
        # "fig_parameters" may be present but explicitly None.
        fig_parameters = graph_parameters.get("fig_parameters") or dict()
        # More than one row of traces -> lay them out on a subplot grid.
        subplots = (
            Plots.np.shape(traces) != tuple() and Plots.np.shape(traces)[0] > 1
        )

        if fig_type == "Widget":
            # Checked for both layouts so single-trace widgets work too.
            container = graph_parameters.get("container", None)
            if container is None:
                raise ValueError("Widget expects a container")

        if subplots:
            dimensions = [len(traces), len(traces[0])]
            rows, cols = dimensions
            fig = Plots.make_subplots(rows=rows, cols=cols, **fig_parameters)
            if fig_type == "Widget":
                fig = Plots.go.FigureWidget(fig)

            for i in range(rows):
                for j in range(cols):
                    trace = graph_parameters["traces"][i][j]
                    if trace is None:
                        continue
                    if isinstance(trace, list):
                        # A cell may overlay several traces on one subplot.
                        for t in trace:
                            if t is not None:
                                fig.add_trace(t, row=i + 1, col=j + 1)
                    else:
                        fig.add_trace(trace, row=i + 1, col=j + 1)
        else:
            fig = Plots.go.Figure(**fig_parameters)
            if graph_parameters["traces"] is not None:
                fig.add_trace(graph_parameters["traces"])
            if fig_type == "Widget":
                fig = Plots.go.FigureWidget(fig)

        fig.update_layout(graph_parameters["layout"])

        if fig_functions:
            for k, v in fig_functions.items():
                func = getattr(fig, k)
                func = Plots._callableSync(func, locals())
                func(v)

        if functions:
            for k, v in functions.items():
                func = getattr(fig, k)
                # NOTE(review): ``func`` is already bound to ``fig``, so ``fig``
                # is passed a second time as the first argument -- confirm.
                func(fig, v)

        if display:
            if fig_type == "Widget":
                container = Plots._callableSync(container, locals())(fig)
                # ``display`` is the bool parameter here; use IPython's.
                Plots.display(container)
            else:
                fig.show()
        return fig

    @staticmethod
    def plotGraph(function_data, **kwargs):
        """Build graph parameters for ``function_data``, render them, return them.

        ``kwargs`` are forwarded to the per-plot builder selected by
        ``_plotGraph`` (e.g. ``mode``/``density`` for histograms, trace
        options for scatters).
        """
        graph_params = Plots._plotGraph(function_data, **kwargs)
        Plots.createGraph(graph_params)
        return graph_params

    @staticmethod
    def _plotGraph(interface, **kwargs):
        """Turn a recorded function result into titled graph parameters.

        ``interface`` must expose ``MODEL_NAME`` and ``FUNCTION_NAME``
        attributes plus ``"parameters"``, ``"data"`` and ``"function"`` items.
        The builder is chosen from the ``plot`` tag that the ``scatter`` /
        ``histogram`` decorators set on ``interface["function"]``; ``kwargs``
        are forwarded to the scatter/histogram builders.

        Raises:
            ValueError: if the function has no plot tag or an unknown one.
        """
        model_name = interface.MODEL_NAME
        function_name = interface.FUNCTION_NAME
        parameters = interface["parameters"]
        data = interface["data"]

        plot = getattr(interface["function"], "plot", None)
        if plot == 0:
            graph_params = Plots.graphScatter(data, **kwargs)
        elif plot == 1:
            graph_params = Plots.graphHistogram(data, **kwargs)
        elif plot == 2:
            # NOTE(review): graphVariational is not defined in this class;
            # presumably it is attached elsewhere -- confirm.
            graph_params = Plots.graphVariational(data, parameters, function_name)
        else:
            raise ValueError(
                f"{function_name!r} has no known plot type (plot={plot!r}); "
                "decorate it with Plots.scatter or Plots.histogram"
            )

        title = f"{model_name}: {Plots.PLOT_MAPPING[plot]} for {function_name}"
        graph_params["layout"].update(dict(title=title))
        return graph_params

    @staticmethod
    def _normalise(values):
        """Rescale ``values`` to [0, 1].

        Uses ``Data.normalise`` when a global ``Data`` is available; otherwise
        prints a notice and falls back to min-max scaling (a constant input
        maps to all zeros, an empty input is returned unchanged).
        """
        data_module = globals().get("Data", None)
        if data_module is not None:
            return data_module.normalise(values)

        print(
            "Data module not imported. Please import Data module to use normalise function."
        )
        values = Plots.np.asarray(values, dtype=float)
        if values.size == 0:
            return values
        span = values.max() - values.min()
        if span == 0:
            return Plots.np.zeros_like(values)
        return (values - values.min()) / span

    @staticmethod
    def _defaultLayout():
        """Return a fresh default layout dict (callers mutate it with a title)."""
        return dict(
            barmode="overlay",
            bargap=0,
        )

    @staticmethod
    def graphHistogram(
        data,
        *,
        mode="bar",
        normalise_x_axis=False,
        density=False,
        **kwargs,
    ):
        """Build graph parameters for a histogram of pre-binned data.

        Args:
            data: ``(counts, x)`` where ``x`` holds either the bin midpoints
                (same length as ``counts``) or the bin edges (one longer).
            mode: ``"bar"`` for a bar chart, ``"scatter"`` for a step line.
            normalise_x_axis: Rescale the x values to [0, 1].
            density: Divide counts by their total so they sum to 1.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            ``{"traces": trace, "layout": layout}`` for ``createGraph``.

        Raises:
            ValueError: for an unknown ``mode``.
        """
        # Arrays so the edge->midpoint and density arithmetic also work when
        # the caller passes plain lists.
        counts = Plots.np.asarray(data[0])
        midpoints = Plots.np.asarray(data[1])
        if len(midpoints) == len(counts) + 1:
            # data[1] holds bin edges: convert them to bin centres.
            midpoints = (midpoints[:-1] + midpoints[1:]) / 2
        if normalise_x_axis:
            midpoints = Plots._normalise(midpoints)
        if density:
            counts = counts / Plots.np.sum(counts)

        if mode == "bar":
            trace = Plots.go.Bar(x=midpoints, y=counts)
        elif mode == "scatter":
            trace = Plots.go.Scatter(
                x=midpoints,
                y=counts,
                mode="lines",
                line={"shape": "hv"},
            )
        else:
            raise ValueError(
                f"Unknown histogram mode {mode!r}; expected 'bar' or 'scatter'"
            )

        return {
            "traces": trace,
            "layout": Plots._defaultLayout(),
        }

    @staticmethod
    def graphScatter(data, *, normalise_x_axis=False, **kwargs):
        """Build graph parameters for a scatter trace.

        Args:
            data: ``(Y,)`` or ``(Y, X)``; when X is missing or None it defaults
                to ``0..len(Y)-1``.
            normalise_x_axis: Rescale the x values to [0, 1].
            **kwargs: Passed straight to ``go.Scatter``.

        Returns:
            ``{"traces": trace, "layout": layout}`` for ``createGraph``.
        """
        Y = data[0]
        try:
            X = data[1]
        except (IndexError, KeyError):
            X = None
        if X is None:
            X = list(range(len(Y)))
        if normalise_x_axis:
            X = Plots._normalise(X)

        trace = Plots.go.Scatter(x=X, y=Y, **kwargs)

        return {
            "traces": trace,
            "layout": Plots._defaultLayout(),
        }

In [None]:
# Demo of live plotting: stream a sine wave into a scatter figure point by point.
import time

x, y = [], []
data = [x, y]

# graphScatter returns graph parameters ({"traces", "layout"}). createGraph
# builds a plain figure without showing it; the first updateLiveScatter call
# wraps that figure in a FigureWidget and displays it.
trace = Plots.graphScatter(data)
fig = Plots.createGraph(trace, display=False)

for i in range(100):
    # Append one (i, sin(i / 10)) point and keep the returned widget.
    fig = Plots.updateLiveScatter(fig, i, Plots.np.sin(i / 10))
    time.sleep(0.05)  # throttle so the update is visible as an animation