In [None]:
class widgetContainer:
	pass


import ipywidgets as widgets
from IPython.display import display
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.colors as colors
import plotly.io as pio


pio.renderers.default = "notebook"


def container(cls):
	def new(obj, *args, **kwargs):
		raise NotImplementedError(f"{obj} is a container, __new__ is not defined")

	def init(obj, *args, **kwargs):
		raise NotImplementedError(f"{obj} is a container, __init__ is not defined")

	def call(obj, *args, **kwargs):
		raise NotImplementedError(f"{obj} is a container, __call__ is not defined")

	for name, member in list(cls.__dict__.items()):
		if callable(member):
			setattr(cls, name, staticmethod(member))
	setattr(cls, "__new__", new)
	setattr(cls, "__init__", init)
	setattr(cls, "__call__", call)

	return cls


def _callableSync(obj, update_dict):
	if getattr(obj, "_callable", False):

		# allow assigining params as attributes for simple interfacing
		args = getattr(obj, "_args", ())
		kwds = getattr(obj, "_kwargs", dict()).copy()

		shared = set(kwds) & set(update_dict)
		for k in shared:
			kwds[k] = update_dict[k]
		return obj(*args, **kwds)
	return obj


def getTraceIndex(row, col, n_cols):

	return (row - 1) * n_cols + (col - 1)


class sliderContainer(widgetContainer):
	@staticmethod
	def _idx(val, data):
		return int(round((val - data[0]) / (data[1] - data[0])))

	@staticmethod
	def _closedIdx(data):
		def _inner(val):
			return sliderContainer._idx(val, data)

		return _inner

	@staticmethod
	def createSlider(*, value, _min, _max, step, description, continuous_update, **kwargs):
		slider = widgets.FloatSlider(
			value=value,
			min=_min,
			max=_max,
			step=step,
			description=description,
			continuous_update=continuous_update,
		)

		return slider

	def _createSlider(self, _slider_dict, **kwargs):
		slider_name = list(_slider_dict.keys())[0]
		slider_dict = list(_slider_dict.values())[0]
		slider_data = slider_dict["data"]
		creation_dict = dict(
			value=float(slider_data[0]),
			_min=float(slider_data.min()),
			_max=float(slider_data.max()),
			step=float(slider_data[1] - slider_data[0]),
			description=f"{slider_name}",
			continuous_update=slider_dict.get("continuous_update", False),
		)
		creation_dict.update(**kwargs)

		slider = sliderContainer.createSlider(**creation_dict)

		self.Sliders[slider_name] = slider
		self.Slider_idxFn[slider_name] = sliderContainer._closedIdx(slider_data)
		self.Sliders[slider_name].observe(self.refresh, names="value")

	@property
	def sliders(self):
		return list(self.Sliders.values())

	def __new__(cls, slider_dicts, update_functions, data, **kwargs):
		instance = object.__new__(cls)

		def _closed_init(slider_dicts, update_functions, data, **kwargs):
			def _inner(fig, **kwds):
				kwargs.update(kwds)
				instance.__init__(
					fig=fig,
					slider_dicts=slider_dicts,
					update_functions=update_functions,
					data=data,
					**kwargs,
				)

				return instance

			return _inner

		return _closed_init(slider_dicts, update_functions, data, **kwargs)

	def __init__(self, fig, slider_dicts, update_functions, data, **kwargs):

		self.fig = fig
		self.Sliders = dict()
		self.Slider_idxFn = dict()
		for i in slider_dicts:
			self._createSlider(i)
		for k, v in kwargs.items():
			setattr(self, k, v)
		self.updateFunctions = update_functions	# list probably
		self.data = data
		self.controls = widgets.VBox(self.sliders)	# , layout=widgets.Layout(width="100%")
		self.container = widgets.VBox([self.controls, self.fig])

	def _refreshSliders(self, *args, **kwargs):
		slider_indices = dict()
		Slider_Values = dict()
		for k, v in self.Sliders.items():
			value = v.value
			slider_indices[k] = self.Slider_idxFn[k](value)
			Slider_Values[k] = value
		return slider_indices, Slider_Values

	def _updateFigure(self, slider_indices, values, *args, **kwargs):

		for fn in self.updateFunctions:
			fn(fig=self.fig, data=self.data, **locals())

	def refresh(self, *args, **kwargs):
		self.__call__(*args, **kwargs)

	def __call__(self, *args, **kwargs):
		slider_indices, Slider_Values = self._refreshSliders(*args, **kwargs)
		with self.fig.batch_update():
			self._updateFigure(slider_indices, Slider_Values)


@container
class Figure_Methods:
	EMPTY = [BaseTraceType("empty")]

	def getTraceIndex(row, col, n_cols):

		return (row - 1) * n_cols + (col - 1)

	def inverseTraceIndex(index, n_cols):

		_row = index // n_cols	# [NOTE] actual row is _row +1
		_col = index % n_cols	# [NOTE] actual col is _col +1
		return _row, _col

	def formatHetrogenous(traces, removeNone=True):

		arr = np.array(traces, dtype=object)
		if np.ndim(arr) == 0:
			return None
		elif np.ndim(arr) == 1:
			pass
		else:
			for r in range(np.ndim(arr)):
				arr = np.hstack(arr)
		if removeNone:
			return arr[arr != None].tolist()

		return arr.tolist()

	def flattenHetrogenous(traces):
		if np.ndim(np.array(traces, dtype=object)) <= 1:
			return traces
		else:
			return np.concatenate(np.array(traces, dtype=object))

	def getTracesSize(traces):
		arr = np.array(traces, dtype=object)
		if np.ndim(arr) <= 1:
			return 1
		return np.size(arr[arr != None])

	def getSuffix(idx):
		axis_idx = ""
		if idx == 0:
			return axis_idx
		return str(idx + 1)

	def _updateFigure(fig, trace, idx, *args, **kwargs):
		suffix = Figure_Methods.getSuffix(idx)

		if "z" in trace:
			trace._orphan_props["scene"] = f"scene{suffix}"

			if fig._grid_ref:

				try:
					n_cols = len(fig._grid_ref[0])
					row, col = Figure_Methods.inverseTraceIndex(idx, n_cols)
					if not fig._grid_ref[row][col][0].subplot_type == "scene":
						# [NOTE] using NotImplementedError as a low probability of intercept raise
						raise NotImplementedError

				except NotImplementedError as e:

					y_domain = fig.layout[f"yaxis{suffix}"]["domain"]
					x_domain = fig.layout[f"xaxis{suffix}"]["domain"]
					scene = {"domain": {"x": x_domain, "y": y_domain}}
					fig.layout[f"scene{suffix}"] = scene

					# [NOTE]  The figure attributes `yaxis{suffix}` and `xaxis{suffix}` don't need to be kept to preserve indexing
					fig.layout.pop(f"yaxis{suffix}", None)
					fig.layout.pop(f"xaxis{suffix}", None)

		fig._data[idx] = trace._orphan_props

		return fig

	def processOrphan(orphan, idx):
		suffix = Figure_Methods.getSuffix(idx)
		orphan_update = dict()

		if "z" in orphan:
			orphan_update.update({"scene": f"scene{suffix}"})
			orphan_update.update({"zaxis": f"z{suffix}"})

		orphan_update = {"xaxis": f"x{suffix}", "yaxis": f"y{suffix}"}

		orphan._orphan_props.update(orphan_update)
		return orphan

	def _appendData(fig, data, idx, *args, **kwargs):
		suffix = Figure_Methods.getSuffix(idx)

		for k, v in data.items():
			if k != "uid":
				_data = getattr(fig.data[idx], k, [])
				if _data is None:
					_data = []
				_data = list(_data)

				_data.extend(v)
				fig.data[idx][k] = _data
			if k == "uid":
				fig.data[idx][k] = v

		return fig

	def _modifyFigure(fig, trace, idx, *args, _modification_type="append", **kwargs):
		target_trace = fig.data[idx]
		if target_trace.type != trace.type:
			fig = Figure_Methods._updateFigure(fig, trace, idx, *args, **kwargs)
		else:
			data = {"x": trace.x, "y": trace.y}
			if "z" in trace:
				data["z"] = trace["z"]
			if "uid" in trace:

				data["uid"] = trace["uid"]
			if _modification_type == "append":

				fig = Figure_Methods._appendData(fig, data, idx, *args, **kwargs)

		return fig

	def modifyFigure(fig, flat_traces, idx=0, *args, **kwargs):
		orphaned = []
		if isinstance(flat_traces, BaseTraceType):
			fig = Figure_Methods._modifyFigure(fig, flat_traces, idx=idx)
			return fig, orphaned
		elif isinstance(flat_traces, list):
			trace = flat_traces[0]
			fig, orphans = Figure_Methods.modifyFigure(fig, trace, idx=idx)
			orphaned.append(orphans)

			_ophans = flat_traces[1:]	# [NOTE] unformatted hence _ prefix
			for o in list(_ophans):
				orphaned.append(Figure_Methods.processOrphan(o, idx=idx))
		elif isinstance(flat_traces, np.ndarray):
			for i, trace in enumerate(flat_traces):
				fig, orphans = Figure_Methods.modifyFigure(fig, trace, idx=i + idx)
			orphaned.append(orphans)
		elif flat_traces is None:
			pass
		else:
			raise TypeError(f"Unknown Trace Type: {type(flat_traces)} | Trace: {flat_traces}")

		orphaned = Figure_Methods.flattenHetrogenous(orphaned)
		return fig, orphaned

	def initialiseFigure(traces, *, fig_parameters, fig_type, **kwargs):
		size = Figure_Methods.getTracesSize(traces)
		if size > 1:

			subplot_parameters = fig_parameters
			fig_parameters = dict()
		fig = go.Figure(data=[None] * size, skip_invalid=True, **fig_parameters)

		if size > 1:
			dimensions = [len(traces), len(traces[0])]

			rows = (dimensions[0:1] or [1])[0]
			cols = (dimensions[1:2] or [1])[0]
			fig = make_subplots(figure=fig, rows=rows, cols=cols, **subplot_parameters)

		return fig

	def addOrphans(fig, orphans):
		flat_traces = Figure_Methods.formatHetrogenous(orphans)
		if np.size(np.array(flat_traces, dtype=object)) == 0:
			return fig
		if orphans is None:
			return fig

		fig.add_traces(orphans)
		return fig


class Plots_New:

	DEFAULT_COLORS = colors.DEFAULT_PLOTLY_COLORS
	LEN_DEFAULT_COLORS = len(DEFAULT_COLORS)

	@staticmethod
	def createGraph(graph_parameters, display_graph=True, **kwargs):

		traces = graph_parameters["traces"]
		fig_parameters = graph_parameters.get("fig_parameters", dict())
		fig_type = graph_parameters.get("fig_type", None)
		layout = graph_parameters.get("layout", dict())

		fig = Figure_Methods.initialiseFigure(
			traces, fig_parameters=fig_parameters, fig_type=fig_type, **kwargs
		)

		flat_traces = Figure_Methods.flattenHetrogenous(traces)
		with fig.batch_update():

			fig, orphans = Figure_Methods.modifyFigure(fig, flat_traces, **kwargs)
			fig = Figure_Methods.addOrphans(fig, orphans)

			if fig_type == "Widget":
				fig = go.FigureWidget(fig)
		fig.update_layout(layout)

		functions = graph_parameters.get("functions", None)
		fig_functions = graph_parameters.get("fig_functions", None)
		if fig_functions:
			for k, v in fig_functions.items():
				func = getattr(fig, k)	# [1.XXX] Cant remember why we can pull it from the fig
				func = _callableSync(func, locals())

				func(v, **kwargs)

		if functions:
			for k, v in functions.items():
				func = getattr(fig, k)	# [1.XXX]
				# func = _callableSync(func,locals())
				func(fig, v, **kwargs)

		if display_graph:
			if fig_type == "Widget":
				# [TODO] add superclass figureContainer, ensure _container is related to this
				_container = graph_parameters.get("container", None)
				_container = _container(fig)
				if isinstance(_container, widgetContainer):
					display(_container.container)
				else:
					display(_container)

			else:
				fig.show()

		return fig

	@staticmethod
	def graphScatter(data, *, normalise_x_axis=False, **kwargs):
		Y = data[0]
		try:
			X = data[1]
		except:
			X = list(range(len(Y)))
		if X is None:
			X = list(range(len(Y)))

		trace = go.Scatter(x=X, y=Y, **kwargs)

		layout = dict(
			barmode="overlay",
			bargap=0,
		)
		graph_parameters = {
			"traces": trace,
			"layout": layout,
		}

		return graph_parameters

	@staticmethod
	def graphVariational(data, **kwargs):

		def getKwargVars(**kwargs):

			alpha_name = kwargs["alpha_name"] if "alpha_name" in kwargs else "alpha"

			beta_name = kwargs["beta_name"] if "beta_name" in kwargs else "beta"

			function_name = kwargs["function_name"] if "function_name" in kwargs else "function"

			alpha_range = (
				kwargs["alpha_range"] if "alpha_range" in kwargs else (0, alpha_len - 1)
			)
			beta_range = kwargs["beta_range"] if "beta_range" in kwargs else (0, beta_len - 1)
			return alpha_name, beta_name, alpha_range, beta_range, function_name

		dimensions = np.shape(data)
		if dimensions[0] > 2:
			_matrix = data
		else:
			raise

		alpha_len, beta_len = len(_matrix), len(_matrix[0])
		alpha_name, beta_name, alpha_range, beta_range, function_name = getKwargVars(**kwargs)

		matrix = np.stack(_matrix.tolist())	# beta × alpha × T

		ymax = np.max(_matrix.tolist())
		X = matrix[0][0].shape[0]

		x = np.arange(X)

		Alpha = np.linspace(*alpha_range, alpha_len)

		Beta = np.linspace(*beta_range, beta_len)

		Xalpha, Yalpha = np.meshgrid(x, Alpha)	# left surface  (beta fixed)
		Xbeta, Ybeta = np.meshgrid(x, Beta)	# right surface (alpha fixed)

		beta_idx, alpha_idx = 0, 0

		alpha_surface = go.Surface(
			z=matrix[beta_idx],
			x=Xalpha,
			y=Yalpha,
			colorscale="Viridis",
			cmin=0,
			cmax=ymax,
			showscale=True,
		)

		beta_surface = go.Surface(
			z=matrix[:, alpha_idx],
			x=Xbeta,
			y=Ybeta,
			colorscale="Viridis",
			cmin=0,
			cmax=ymax,
			showscale=True,
		)

		scatter = go.Scatter(
			x=x, y=matrix[beta_idx, alpha_idx], mode="lines", uid="variational_scatter"
		)

		camera = dict(
			eye=dict(x=-1.8, y=-1.8, z=1.0),
			up=dict(x=0.0, y=0.0, z=1.0),
			center=dict(x=0.0, y=0.0, z=0.0),
		)
		fig_parameters = dict(
			specs=[	# [TODO] Remove Specs since it is infereable now
				[{"type": "surface"}, {"type": "surface"}],
				[{"colspan": 2, "type": "xy"}, None],
			],
			vertical_spacing=0.08,
			row_heights=[0.75, 0.25],
		)

		layout = dict(
			width=1400,
			height=850,
			scene=dict(
				xaxis_title="x",
				yaxis_title=f"{alpha_name}",
				zaxis_title=f"{function_name}",
				camera=camera,
			),
			scene2=dict(
				xaxis_title="x",
				yaxis_title=f"{beta_name}",
				zaxis_title=f"{function_name}",
				camera=camera,
			),
		)

		def beta_update(*, fig, slider_indices, data, **kwargs):
			i = slider_indices[f"{beta_name}"]
			fig.data[0].z = data[i]

		def alpha_update(*, fig, slider_indices, data, **kwargs):
			j = slider_indices[f"{alpha_name}"]

			fig.data[0].z = data[:, j]

		def timeseries_update(*, fig, slider_indices, data, **kwargs):
			i = slider_indices[f"{beta_name}"]
			j = slider_indices[f"{alpha_name}"]

			fig.data[2].y = data[i, j]

		def title_update(*, fig, slider_indices, data, values, **kwargs):
			i = slider_indices[f"{beta_name}"]
			j = slider_indices[f"{alpha_name}"]
			beta_val = values[f"{beta_name}"]
			alpha_val = values[f"{alpha_name}"]

			_function_name = kwargs.get("function_name", f"{function_name}")

			fig.layout.title.text = (
				f"{_function_name} –  {alpha_name} = {alpha_val:.2f}, " f"{beta_name} = {beta_val:.2f}"
			)

		update_fns = [beta_update, alpha_update, timeseries_update, title_update]
		alpha_slider_dict = {f"{alpha_name}": {"data": Alpha}}
		beta_slider_dict = {f"{beta_name}": {"data": Beta}}
		slider_dicts = [beta_slider_dict, alpha_slider_dict]
		wContainer = sliderContainer(slider_dicts, update_fns, matrix)

		graph_parameters = {
			"traces": [[alpha_surface, beta_surface], [scatter, None]],
			"layout": layout,
			"fig_type": "Widget",
			"fig_parameters": fig_parameters,	# Not strictly necessary for specs
			"container": wContainer,
		}
		return graph_parameters

In [1]:
import ipywidgets as widgets
from IPython.display import display
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.colors as colors
import plotly.io as pio

# from Helpers import container

pio.renderers.default = "notebook"
import numpy as np

In [2]:
def isIterable(obj):
	if isinstance(obj, str):
		return False
	try:
		iter(obj)
		obj[0]
		return True
	except Exception:
		return False

In [3]:
flat_traces = [object, object]
np.ndim(flat_traces)

1

In [None]:
# Case 1&2 simple fig, fig with overlays
# case 3&4 subplots, subplots with overlays


def flatten_nested(arr):
	result = []
	if not isinstance(arr, (list, np.ndarray)):
		return arr
	for item in arr:
		if item is None:
			continue
		if isinstance(item, (list, np.ndarray)):

			result.extend(flatten_nested(item))
		else:
			result.append(item)
	return np.array(result, dtype=object)


def flatten(arr):
	result = []
	if not isinstance(arr, (list, np.ndarray)):
		return arr

	for item in arr:
		if item is None:
			continue
		if isinstance(item, np.ndarray):

			result.extend(flatten(item))
		else:
			result.append(item)
	return result	# np.array(result, dtype=object)


def container(cls):
	def new(obj, *args, **kwargs):
		raise NotImplementedError(f"{obj} is a container, __new__ is not defined")

	def init(obj, *args, **kwargs):
		raise NotImplementedError(f"{obj} is a container, __init__ is not defined")

	def call(obj, *args, **kwargs):
		raise NotImplementedError(f"{obj} is a container, __call__ is not defined")

	for name, member in list(cls.__dict__.items()):
		if callable(member):
			setattr(cls, name, staticmethod(member))
	setattr(cls, "__new__", new)
	setattr(cls, "__init__", init)
	setattr(cls, "__call__", call)

	return cls


def _callableSync(obj, update_dict):
	if getattr(obj, "_callable", False):

		# allow assigining params as attributes for simple interfacing
		args = getattr(obj, "_args", ())
		kwds = getattr(obj, "_kwargs", dict()).copy()

		shared = set(kwds) & set(update_dict)
		for k in shared:
			kwds[k] = update_dict[k]
		return obj(args, **kwds)
	return obj


def getTraceIndex(row, col, n_cols):

	return (row - 1) * n_cols + (col - 1)


def interface_required(func):
	def _wrapper(*args, **kwargs):
		try:
			import Interface

		except ImportError:
			raise ImportError("Interface module is required for this functionality.")
		return func(*args, **kwargs)

	return _wrapper


def _callableSync(obj, update_dict):
	if getattr(obj, "_callable", False):

		# allow assigining params as attributes for simple interfacing
		args = getattr(obj, "_args", ())
		kwds = getattr(obj, "_kwargs", dict()).copy()

		shared = set(kwds) & set(update_dict)
		for k in shared:
			kwds[k] = update_dict[k]
		return obj(args, **kwds)
	return obj


def getTraceIndex(row, col, n_cols):

	return (row - 1) * n_cols + (col - 1)


# Cannonical method via indexes
@container
class Figure_Methods:
	EMPTY = [{}]

	def appendToTrace(fig, trace, idx=None, *args, **kwargs):
		if trace is None:
			return fig
		# print(idx)

		if ("row" in kwargs and "col" in kwargs) and idx is None:

			n_cols = len(fig._grid_ref[0])	#
			idx = getTraceIndex(kwargs["row"], kwargs["cols"], n_cols)
		if idx is None:
			idx = 0

		fnc = [list.append, list.extend]
		_trace = fig.data[idx]
		# for k, v in trace._orphan_props.items():
		# 	if not k.startswith("x") or not k.startswith("y") or not k.startswith("x"):
		# 		setattr(_trace, k, v)
		if "x" not in kwargs:
			x = trace.x

		if "y" not in kwargs:
			y = trace.y

		if hasattr(_trace, "z"):
			if "z" not in kwargs:
				z = trace.z
		else:
			z = None

		func = fnc[hasattr(x, "__len__")]

		t = list(_trace.x or list())
		func(t, x)
		_trace.x = t

		t = list(_trace.y or list())
		func(t, y)
		_trace.y = t

		if z:
			t = list(_trace.z or list())
			func(t, z)
			_trace.z = t

		for k, v in kwargs.items():
			setattr(_trace, k, v)

		return fig

	#
	def processOrphan(orphan, idx):

		axis_idx = ""

		if idx == 1:
			axis_idx = str(axis_idx + 1)
		orphan_update = {"xaxis": f"x{axis_idx}", "yaxis": f"y{axis_idx}"}
		if hasattr(orphan, "z"):
			orphan_update.update({"zaxis": f"z{axis_idx}"})

		orphan._orphan_props.update(orphan_update)	# uses builting _orphan_probs attribute
		return orphan

	def addTraces(fig, flat_traces):

		if flat_traces is None:
			return fig

		flat_traces = flatten_nested(flat_traces)

		fig.add_traces(list(flat_traces))
		return fig

	def updateTraces(fig, flat_traces, idx=0):
		# flat_traces = np.array(flat_traces, dtype=object)
		orphaned = []
		if np.ndim(np.array(flat_traces, dtype=object)) == 0:

			fig = Figure_Methods.appendToTrace(fig, flat_traces, idx=idx)
			return fig, orphaned

		if np.ndim(np.array(flat_traces, dtype=object)) == 1:

			if isinstance(flat_traces, list):
				_ophans = flat_traces[1:]
				trace = flat_traces[0]
				for o in list(_ophans):
					orphaned.append(Figure_Methods.processOrphan(o, idx=idx))
				fig, orphans = Figure_Methods.updateTraces(fig, trace, idx=idx)

				orphans.append(orphans)
				orphaned = flatten_nested(orphaned)
				return fig, orphaned

			for i, trace in enumerate(flat_traces):

				fig, orphans = Figure_Methods.updateTraces(fig, trace, idx=i + idx)
				if orphans is not None:
					orphaned.append(orphans)
			orphaned = flatten_nested(orphaned)
			return fig, orphaned

			return fig, np.array(orphaned, dtype=object)
		if np.ndim(flat_traces) == 2:

			for i, trace in enumerate(flat_traces):
				fig, orphans = Figure_Methods.updateTraces(fig, trace, idx=i + idx)
				if orphans is not None:
					orphaned.append(orphans)
			orphaned = flatten_nested(orphaned)
			return fig, orphaned
		orphaned = flatten_nested(orphaned)

		raise NotImplementedError(f"{flat_traces}")
		return fig, orphaned

	def populateEmpty(fig, *args, **kwargs):
		if fig._grid_ref is None:
			fig.add_trace(*Figure_Methods.EMPTY)
		else:
			_rows = len(fig._grid_ref)
			_cols = len(fig._grid_ref[0])
			empty = Figure_Methods.EMPTY * _rows * _cols
			rows = list(range(1, _rows + 1)) if _rows > 1 else 1
			cols = list(range(1, _cols + 1)) if _cols > 1 else 1

			fig.add_traces(empty, rows=rows, cols=cols)
		return fig


class Plots:

	DEFAULT_COLORS = colors.DEFAULT_PLOTLY_COLORS
	LEN_DEFAULT_COLORS = len(DEFAULT_COLORS)

	@staticmethod
	def flattenTraces(traces):

		if np.ndim(traces) < 2:
			return traces
		else:
			return np.hstack(traces)

	@staticmethod
	def histogram(func):
		func.plot = 2	# hello
		return func

	@staticmethod
	def scatter(func):
		func.plot = 0
		return func

	def _callableSync(obj, update_dict):
		if getattr(obj, "_callable", False):

			# allow assigining params as attributes for simple interfacing
			args = getattr(obj, "_args", ())
			kwds = getattr(obj, "_kwargs", dict()).copy()

			shared = set(kwds) & set(update_dict)
			for k in shared:
				kwds[k] = update_dict[k]
			return obj(args, **kwds)
		return obj

	@staticmethod
	def processTraces(traces, fig_type=None, fig_parameters=fig_parameters, **kwargs):
		subplots = np.ndim(np.array(traces, dtype=object)) > 1
		empty_len = None

		if subplots:
			traces = np.array(traces, dtype=object)
			dimensions = [len(traces), len(traces[0])]

			rows = (dimensions[0:1] or [1])[0]
			cols = (dimensions[1:2] or [1])[0]
			placeholder_traces = Figure_Methods.EMPTY * rows * cols
			fig = make_subplots(rows=rows, cols=cols, **fig_parameters)

		else:

			fig = go.Figure(**fig_parameters)
		fig = Figure_Methods.populateEmpty(fig)

		if fig_type == "Widget":
			fig = go.FigureWidget(fig)

		fig, orphaned = Figure_Methods.updateTraces(fig, traces)
		if orphaned is not None:
			orphaned = list(flatten_nested(orphaned))
		fig = Figure_Methods.addTraces(fig, orphaned)
		return fig

	@staticmethod
	def createGraph(graph_parameters, display_graph=True, **kwargs):
		"""
		graph_parameters={
		        "traces":trace_list,# trace_list.shape = rows,cols
		        "layout":layout_dict,
		        "fig_type":fig_type_str,
		        "fig_parameters":dict|None
		        "fig_functions":{"function_name":"parameters"}
		        **optional_parameters
		}

		"""
		graph_parameters.update(**kwargs)
		traces = graph_parameters["traces"]

		fig_type = graph_parameters.get("fig_type", None)

		fig = Plots.processTraces(traces, fig_type=fig_type, **kwargs)
		functions = graph_parameters.get("functions", None)
		fig_functions = graph_parameters.get("fig_functions", None)
		fig_parameters = graph_parameters.get("fig_parameters", dict())

		fig.update_layout(graph_parameters["layout"])

		if fig_functions:
			for k, v in fig_functions.items():
				func = getattr(fig, k)	# [REVIEW]
				func = Plots._callableSync(func, locals())

				func(v, **kwargs)

		if functions:
			for k, v in functions.items():
				func = getattr(fig, k)
				# func = _callableSync(func,locals())
				func(fig, v, **kwargs)
		if display_graph:
			if fig_type == "Widget":
				container = graph_parameters.get("container", None)
				if container is None:
					raise ValueError("Widget expects a base")
				container = _callableSync(container, locals())(fig)
				display(container)
			else:
				fig.show()
		return fig

	@staticmethod
	def graphHistogram(
		data,
		*,
		mode="bar",
		normalise_x_axis=False,
		density=False,
		**kwargs,
	):

		counts = data[0]
		midpoints = data[1]
		if len(midpoints) == len(counts) + 1:
			# Assume data[1[ is bin edges
			midpoints = (midpoints[:-1] + midpoints[1:]) / 2
			# midpoints = midpoints[:-1]
		if normalise_x_axis:
			Data = globals().get("Data", None)
			if Data is None:
				print(
					"Data module not imported. Please import Data module to use normalise function."
				)
				midpoints = (midpoints - midpoints.min()) / (midpoints.max() - midpoints.min())
			else:
				midpoints = Data.normalise(midpoints)
		if density:
			counts = counts / np.sum(counts)
		if mode == "bar":
			trace = go.Bar(x=midpoints, y=counts)
		if mode == "scatter":
			trace = go.Scatter(
				x=midpoints,
				y=counts,
				mode="lines",
				line={"shape": "hv"},
			)

		layout = dict(
			barmode="overlay",
			bargap=0,
		)
		graph_parameters = {
			"traces": trace,
			"layout": layout,
		}
		return graph_parameters

	@staticmethod
	def graphScatter(data, *, normalise_x_axis=False, **kwargs):
		Y = data[0]
		try:
			X = data[1]
		except:
			X = list(range(len(Y)))
		if X is None:
			X = list(range(len(Y)))

		trace = go.Scatter(x=X, y=Y, **kwargs)

		layout = dict(
			barmode="overlay",
			bargap=0,
		)
		graph_parameters = {
			"traces": trace,
			"layout": layout,
		}

		return graph_parameters

	@staticmethod
	def pulseLocation(y_data):
		y = y_data
		y = np.array(y)
		x = np.arange(len(y), dtype=np.float32)
		dy = np.diff(y, prepend=y[0])
		a = (np.cumsum(x * dy) + x * dy) / len(y)
		b = np.diff(a, prepend=a[0] - (a[1] - a[0]) / 2)
		indices = np.where(np.abs(b) > 0.66 * np.abs(y))

		return indices[0], b, a

	@staticmethod
	def graphVariational(data, **kwargs):

		def getKwargVars(**kwargs):

			alpha_name = kwargs["alpha_name"] if "alpha_name" in kwargs else "alpha"

			beta_name = kwargs["beta_name"] if "beta_name" in kwargs else "beta"

			function_name = kwargs["function_name"] if "function_name" in kwargs else "function"

			alpha_range = (
				kwargs["alpha_range"] if "alpha_range" in kwargs else (0, alpha_len - 1)
			)
			beta_range = kwargs["beta_range"] if "beta_range" in kwargs else (0, beta_len - 1)
			return alpha_name, beta_name, alpha_range, beta_range, function_name

		dimensions = np.shape(data)
		# parameters=parameters[0][0]
		if dimensions[0] > 2:
			_matrix = data
		else:
			raise

		alpha_len, beta_len = len(_matrix), len(_matrix[0])
		alpha_name, beta_name, alpha_range, beta_range, function_name = getKwargVars(**kwargs)

		matrix = np.stack(_matrix.tolist())	# beta × alpha × T

		ymax = np.max(_matrix.tolist())
		X = matrix[0][0].shape[0]

		x = np.arange(X)

		Alpha = np.linspace(*alpha_range, alpha_len)

		Beta = np.linspace(*beta_range, beta_len)

		Xalpha, Yalpha = np.meshgrid(x, Alpha)	# left surface  (beta fixed)
		Xbeta, Ybeta = np.meshgrid(x, Beta)	# right surface (alpha fixed)

		beta_idx, alpha_idx = 0, 0

		alpha_surface = go.Surface(
			z=matrix[beta_idx],
			x=Xalpha,
			y=Yalpha,
			colorscale="Viridis",
			cmin=0,
			cmax=ymax,
			showscale=True,
		)

		beta_surface = go.Surface(
			z=matrix[:, alpha_idx],
			x=Xbeta,
			y=Ybeta,
			colorscale="Viridis",
			cmin=0,
			cmax=ymax,
			showscale=True,
		)

		scatter = go.Scatter(x=x, y=matrix[beta_idx, alpha_idx], mode="lines")

		camera = dict(
			eye=dict(x=-1.8, y=-1.8, z=1.0),
			up=dict(x=0.0, y=0.0, z=1.0),
			center=dict(x=0.0, y=0.0, z=0.0),
		)
		fig_parameters = dict(
			specs=[
				[{"type": "surface"}, {"type": "surface"}],
				[{"colspan": 2, "type": "xy"}, None],
			],
			vertical_spacing=0.08,
			row_heights=[0.75, 0.25],
		)

		layout = dict(
			width=1400,
			height=850,
			scene=dict(
				xaxis_title="x",
				yaxis_title=f"{alpha_name}",
				zaxis_title=f"{function_name}",
				camera=camera,
			),
			scene2=dict(
				xaxis_title="x",
				yaxis_title=f"{beta_name}",
				zaxis_title=f"{function_name}",
				camera=camera,
			),
		)

		def container(fig, **kwargs):

			# alpha_name, beta_name, func_name = updateSliderNames(**kwargs)

			def _idx(val, grid):
				return int(round((val - grid[0]) / (grid[1] - grid[0])))

			alpha_slider = widgets.FloatSlider(
				value=float(Alpha[alpha_idx]),
				min=float(Alpha.min()),
				max=float(Alpha.max()),
				step=float(Alpha[1] - Alpha[0]),
				description=f"{alpha_name}",
				continuous_update=False,
			)

			beta_slider = widgets.FloatSlider(
				value=float(Beta[beta_idx]),
				min=float(Beta.min()),
				max=float(Beta.max()),
				step=float(Beta[1] - Beta[0]),
				description=f"{beta_name}",
				continuous_update=False,
			)

			def refresh(_=None):
				i = _idx(beta_slider.value, Beta)	# current beta index
				j = _idx(alpha_slider.value, Alpha)	# current alpha index

				with fig.batch_update():
					fig.data[0].z = matrix[i]

					fig.data[1].z = matrix[:, j]

					# 1‑D line
					fig.data[2].y = matrix[i, j]

					fig.layout.title.text = (
						f"{function_name} –  {alpha_name} = {Alpha[j]:.2f}, " f"{function_name} = {Beta[i]:.2f}"
					)

			alpha_slider.observe(refresh, names="value")
			beta_slider.observe(refresh, names="value")
			controls = widgets.VBox(
				[alpha_slider, beta_slider], layout=widgets.Layout(width="100%")
			)
			container = widgets.VBox([controls, fig], layout=widgets.Layout(width="100%"))

			return container

		graph_parameters = {
			"traces": [[alpha_surface, beta_surface], [scatter, None]],
			"layout": layout,
			"fig_type": "Widget",
			"fig_parameters": fig_parameters,
			"container": container,
		}
		return graph_parameters

	PLOT_MAPPING = {
		0: graphScatter,
		1: NotImplemented,
		2: graphHistogram,
	}

In [5]:
"""
quick_run_graph_variational.py
One-off smoke-test – **no pytest** – for
Plots.graphVariational + Plots.createGraph
"""

import numpy as np
import plotly.graph_objects as go

# ----- replace this import path with the real one in your project

# -----------------------------------------------------------------


def main():
	# --------------------------------------------------------------
	# 1. Create demo data  (square so z & y match: alpha_len == beta_len)
	# --------------------------------------------------------------
	rng = np.random.default_rng(2025)
	alpha_len = beta_len = 5	#  <-- equal lengths avoids shape mismatch
	t_len = 40
	data = rng.random((alpha_len, beta_len, t_len))

	# --------------------------------------------------------------
	# 2. Build the graph parameters and the FigureWidget
	# --------------------------------------------------------------
	gp = Plots.graphVariational(
		data,
		alpha_name="α",
		beta_name="β",
		function_name="ℒ",
	)

	fig = Plots.createGraph(gp, display_graph=False)	# no auto-display

	# --------------------------------------------------------------
	# 3. Console sanity checks
	# --------------------------------------------------------------
	print("✓ got FigureWidget :", isinstance(fig, go.FigureWidget))
	print("  trace types      :", [tr.type for tr in fig.data])
	# Surface-0 grid shape consistency
	surf0 = fig.data[0]
	print("  surface-0 z-shape:", np.shape(surf0.z), " | Y-grid shape:", np.shape(surf0.y))
	# Custom axis labels
	print("  scene-1 Y label  :", fig.layout.scene.yaxis.title.text)
	print("  scene-2 Y label  :", fig.layout.scene2.yaxis.title.text)

	# --------------------------------------------------------------
	# 4. Show the widget if you’re in a GUI / notebook
	# --------------------------------------------------------------
	# fig.show()


if __name__ == "__main__":
	main()

NameError: name 'fig_parameters' is not defined

In [None]:
# params = Plots.graphScatter([[3, 4, 5], [0, 1, 2]])
# fig = Plots.createGraph(params)

In [None]:
import numpy as np
import plotly.graph_objects as go


# Assume the user's provided code for Plots, Figure_Methods, etc. is defined above this line.
#
# ... (user's code from the previous prompt) ...
def testSimplePlot():
	"""Verifies creation of a single scatter plot without subplots."""
	print("Running: testSimplePlot")

	# 1. Arrange
	x_data = [1, 2, 3, 4]
	y_data = [10, 11, 12, 13]

	graph_params = Plots.graphScatter(data=[y_data, x_data])

	# 2. Act
	fig = Plots.createGraph(graph_params, display_graph=False)

	# 3. Assert
	# Corrected assertion: Convert tuple from fig.data to list for comparison
	display(fig)
	print("Passed: testSimplePlot\n")
	return fig


def testOverlayPlot():
	"""Verifies overlaying two traces on a single non-subplot figure."""
	print("Running: testOverlayPlot")

	# 1. Arrange
	trace_one = go.Scatter(x=[1, 2, 3], y=[1, 2, 3], name="Trace 1")
	trace_two = go.Scatter(x=[1, 2, 3], y=[3, 2, 1], name="Trace 2")

	graph_params = {"traces": [trace_one, trace_two], "layout": {}}

	# 2. Act
	fig = Plots.createGraph(graph_params, display_graph=False)

	# 3. Assert
	# assert len(fig.data) == 2, f"Expected 2 traces, but found {len(fig.data)}"
	# assert fig.data[0].name == "Trace 1"
	# assert fig.data[1].name == "Trace 2"
	# Corrected assertion: Convert tuple from fig.data to list for comparison
	# assert list(fig.data[1].y) == [3, 2, 1]
	display(fig)
	print("Passed: testOverlayPlot\n")
	return fig


def testOverlayPlot():
	"""Verifies overlaying two traces on a single non-subplot figure."""
	print("Running: testOverlayPlot")

	# 1. Arrange
	trace_one = go.Scatter(x=[1, 2, 3], y=[1, 2, 3], name="Trace 1")
	trace_two = go.Scatter(x=[1, 2, 3], y=[3, 2, 1], name="Trace 2")

	graph_params = {
		"traces": [trace_one, trace_two],
		"layout": {},
	}

	# 2. Act
	fig = Plots.createGraph(graph_params, display_graph=False)

	# 3. Assert
	# assert len(fig.data) == 2, f"Expected 2 traces, but found {len(fig.data)}"
	# assert fig.data[0].name == "Trace 1"
	# assert fig.data[1].name == "Trace 2"
	# assert all(fig.data[1].y == [3, 2, 1])
	display(fig)

	print("Passed: testOverlayPlot\n")
	return fig


def testSubplots():
	"""Verifies creation of a 2x1 subplot figure with one trace per plot."""
	print("Running: testSubplots")

	# 1. Arrange
	trace_one = go.Scatter(x=[1, 2], y=[1, 1], name="Subplot 1")
	trace_two = go.Scatter(x=[1, 2], y=[2, 2], name="Subplot 2")

	# Note: The structure needs to be a 2D array-like for subplots
	traces_grid = [[trace_one], [trace_two]]

	graph_params = {
		"traces": traces_grid,
		"layout": {},
		"fig_parameters": {"subplot_titles": ["Top", "Bottom"]},
	}

	# 2. Act
	fig = Plots.createGraph(graph_params, display_graph=False)

	# 3. Assert
	# assert len(fig.data) == 2, f"Expected 2 traces, but found {len(fig.data)}"
	# assert fig.data[0].name == "Subplot 1"
	# assert fig.data[1].name == "Subplot 2"
	# assert fig.data[0].yaxis == "y"	# Belongs to first subplot
	# assert fig.data[1].yaxis == "y2"	# Belongs to second subplot
	display(fig)

	print("Passed: testSubplots\n")


def testSubplotsWithOverlay():
	"""Verifies an overlay on one subplot within a multi-subplot figure."""
	print("Running: testSubplotsWithOverlay")

	# 1. Arrange
	overlay_one = go.Scatter(x=[1, 2], y=[1, 1], name="Overlay 1")
	overlay_two = go.Scatter(x=[1, 2], y=[2, 2], name="Overlay 2")
	single_trace = go.Scatter(x=[3, 4], y=[3, 3], name="Single")

	# Traces for a 1x2 grid. The first subplot gets a list of traces.
	traces_grid = [[[overlay_one, overlay_two], single_trace]]

	graph_params = {"traces": traces_grid, "layout": {}, "fig_parameters": {}}

	# 2. Act
	fig = Plots.createGraph(graph_params, display_graph=False)

	# 3. Assert
	# assert len(fig.data) == 3, f"Expected 3 traces, but found {len(fig.data)}"

	# Find traces by subplot assignment
	# subplot_1_traces = [t for t in fig.data if t.xaxis == "x" and t.yaxis == "y"]
	# subplot_2_traces = [t for t in fig.data if t.xaxis == "x2" and t.yaxis == "y2"]

	# assert len(subplot_1_traces) == 2, "Expected 2 traces in the first subplot"
	# assert len(subplot_2_traces) == 1, "Expected 1 trace in the second subplot"

	# The single trace should be in the second subplot
	# assert subplot_2_traces[0].name == "Single"

	# The two overlay traces should be in the first subplot
	# names_in_subplot_1 = {t.name for t in subplot_1_traces}
	# assert names_in_subplot_1 == {"Overlay 1", "Overlay 2"}
	display(fig)
	print("Passed: testSubplotsWithOverlay\n")

In [None]:
testSimplePlot()
f = testOverlayPlot()
# f.data
testSubplots()
testSubplotsWithOverlay()
# print("All tests completed.")

In [None]:
a = [[1, 2, 3, [4, 5]]]
a = np.array(a, dtype=object)
np.ndim(a)

In [None]:
import numpy as np

# Example inhomogeneous array
arr = np.array([[5, [6, 7]], [9, [6, 7]]], dtype=object)

# Flatten while keeping inner lists intact
# flattened = [item for sublist in arr for item in (sublist if isinstance(sublist, (list, tuple)) else [sublist])]
np.hstack(arr)
# arr
# print(flattened)

In [None]:
# fig = make_subplots(rows=1, cols=2)
# trace = go.Scatter(x=[1, 2, 3], y=[1, 2, 3])
# # trace.update_layout(dict(xaxis="x2", yaxis="y2"))
# trace._orphan_props.update({"xaxis": "x2", "yaxis": "y2"})
# fig.add_traces(
# 	[
# 		go.Scatter(x=[1, 2, 3], y=[1, 2, 3], name="t"),
# 		go.Scatter(x=[1, 2, 3], y=[1, 2, 3]),
# 		trace,
# 	]
# )