From 7dabb0fd422241e6f92ba1e4b26ad896cf7abb75 Mon Sep 17 00:00:00 2001 From: "Pablo R. Mier" Date: Tue, 14 May 2024 15:32:34 +0200 Subject: [PATCH] Many improvements on the backend side to extend the ST method --- .github/workflows/deploy-docs.yml | 3 - corneto/backend/_base.py | 63 +++++- corneto/backend/_cvxpy_backend.py | 15 +- corneto/backend/_picos_backend.py | 6 + corneto/experimental/__init__.py | 0 corneto/methods/steiner.py | 249 +++++++++++++++++++++- docs/guide/intro/prior-knowledge.ipynb | 281 +++++++++++++++++-------- tests/test_backend.py | 30 +++ 8 files changed, 546 insertions(+), 101 deletions(-) create mode 100644 corneto/experimental/__init__.py diff --git a/.github/workflows/deploy-docs.yml b/.github/workflows/deploy-docs.yml index 22a2b10..9a11ae5 100644 --- a/.github/workflows/deploy-docs.yml +++ b/.github/workflows/deploy-docs.yml @@ -59,9 +59,6 @@ jobs: git add . git commit -m "Update docs for branch $BRANCH_NAME" --allow-empty git push -u origin gh-pages - - # Optionally, remove the temporary directory (not necessary on CI environments) - # rm -rf "$TEMP_DIR" env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} BRANCH_NAME: ${{ github.head_ref || github.ref_name }} diff --git a/corneto/backend/_base.py b/corneto/backend/_base.py index c550636..601609b 100644 --- a/corneto/backend/_base.py +++ b/corneto/backend/_base.py @@ -109,6 +109,22 @@ def _elementwise_mul(self, other: Any) -> Any: def multiply(self, other: Any) -> "CExpression": return self._elementwise_mul(other) + @abc.abstractmethod + def _hstack(self, other: "CExpression") -> Any: + pass + + @_delegate + def hstack(self, other: "CExpression") -> "CExpression": + return self._hstack(other) + + @abc.abstractmethod + def _vstack(self, other: "CExpression") -> Any: + pass + + @_delegate + def vstack(self, other: "CExpression") -> "CExpression": + return self._vstack(other) + @abc.abstractmethod def _norm(self, p: int = 2) -> Any: pass @@ -468,7 +484,7 @@ def __iadd__(self, other: Any) -> "ProblemDef": def solve( self, solver: Optional[Union[str, Solver]] = None, - max_seconds: int = None, + max_seconds: Optional[int] = None, warm_start: bool = False, verbosity: int = 0, **options, @@ -1049,26 +1065,61 @@ def Xor(self, x: CExpression, y: CExpression, varname="_xor"): [xor >= x - y, xor >= y - x, xor <= x + y, xor <= 2 - x - y] ) - def linear_or(self, x: CSymbol, axis: Optional[int] = None, varname="_linear_or"): - # Check if the variable is binary, otherwise throw an error - if x._vartype != VarType.BINARY: + def linear_or(self, x: CExpression, axis: Optional[int] = None, varname="or"): + # Check if the variable has a vartype and is binary + if hasattr(x, "_vartype") and x._vartype != VarType.BINARY: raise ValueError(f"Variable x has type {x._vartype} instead of BINARY") + else: + for s in x._proxy_symbols: + if s._vartype != VarType.BINARY: + # Show warning only + LOGGER.warn( + f"Variable {s.name} has type {s._vartype}, expression is assumed to be binary" + ) + break + Z = x.sum(axis=axis) Z_norm = Z / x.shape[axis] # between 0-1 # Create a new binary variable to compute linearized or Or = self.Variable(varname, Z.shape, 0, 1, vartype=VarType.BINARY) return self.Problem([Or >= Z_norm, Or <= Z]) - def linear_and(self, x: CSymbol, axis: Optional[int] = None, varname="_linear_and"): + def linear_and(self, x: CExpression, axis: Optional[int] = None, varname="and"): # Check if the variable is binary, otherwise throw an error - if x._vartype != VarType.BINARY: + if hasattr(x, "_vartype") and x._vartype != VarType.BINARY: raise ValueError(f"Variable x has type {x._vartype} instead of BINARY") + else: + for s in x._proxy_symbols: + if s._vartype != VarType.BINARY: + # Show warning only + LOGGER.warn( + f"Variable {s.name} has type {s._vartype}, expression is assumed to be binary" + ) + break Z = x.sum(axis=axis) N = x.shape[axis] Z_norm = Z / N And = self.Variable(varname, Z.shape, 0, 1, vartype=VarType.BINARY) return self.Problem([And <= Z_norm, And >= Z - N + 1]) + def vstack(self, arg_list: Iterable[CSymbol]): + v = None + for a in arg_list: + if v is None: + v = a + else: + v = v.vstack(a) + return v + + def hstack(self, arg_list: Iterable[CSymbol]): + h = None + for a in arg_list: + if h is None: + h = a + else: + h = h.hstack(a) + return h + class NoBackend(Backend): def __init__(self) -> None: diff --git a/corneto/backend/_cvxpy_backend.py b/corneto/backend/_cvxpy_backend.py index 08b199e..f0bc6f1 100644 --- a/corneto/backend/_cvxpy_backend.py +++ b/corneto/backend/_cvxpy_backend.py @@ -40,6 +40,12 @@ def _sum(self, axis: Optional[int] = None) -> Any: def _max(self, axis: Optional[int] = None) -> Any: return cp.max(self._expr, axis=axis) + def _hstack(self, other: CExpression) -> Any: + return cp.hstack([self._expr, other]) + + def _vstack(self, other: CExpression) -> Any: + return cp.vstack([self._expr, other]) + @property def value(self) -> np.ndarray: return self._expr.value @@ -169,10 +175,17 @@ def _solve( cfg = options.get("mosek_params", dict()) cfg.update({"mioMaxTime": float(max_seconds)}) options["mosek_params"] = cfg - elif s == "SCIPY": + elif s == "SCIPY" or s == "HIGHS": cfg = options.get("scipy_options", dict()) cfg.update({"time_limit": float(max_seconds), "disp": verbosity > 0}) options["scipy_options"] = cfg + else: + # Warning that a mapping is not available, check backend documentation + LOGGER.warn(f"""max_seconds parameter mapping for {s} not found. + Please refer to the backend documentation for more + information. For example, using CVXPY with GUROBI solver, + the parameter TimeLimit can be directly passed with + `problem.solve(solver='GUROBI', TimeLimit=max_seconds)`""") P.solve(solver=s, verbose=verbosity > 0, warm_start=warm_start, **options) return P diff --git a/corneto/backend/_picos_backend.py b/corneto/backend/_picos_backend.py index a98e363..75047e4 100644 --- a/corneto/backend/_picos_backend.py +++ b/corneto/backend/_picos_backend.py @@ -40,6 +40,12 @@ def _sum(self, axis: Optional[int] = None) -> Any: def _max(self, axis: Optional[int] = None) -> Any: raise NotImplementedError() + def _hstack(self, other: CExpression) -> Any: + return self._expr & other + + def _vstack(self, other: CExpression) -> Any: + return self._expr // other + @property def value(self) -> np.ndarray: return self._expr.value diff --git a/corneto/experimental/__init__.py b/corneto/experimental/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/corneto/methods/steiner.py b/corneto/methods/steiner.py index 501f71b..70e2e76 100644 --- a/corneto/methods/steiner.py +++ b/corneto/methods/steiner.py @@ -1,10 +1,11 @@ import numpy as np +from corneto._constants import VAR_FLOW from corneto._graph import Attr, BaseGraph, EdgeType from corneto.backend import DEFAULT_BACKEND, Backend -def exact_steiner_tree( +def __exact_steiner_tree( G: BaseGraph, terminals, edge_weights=None, @@ -97,3 +98,249 @@ def exact_steiner_tree( # For vertices, we need to check in which edges they appear. # If any of those edges is selected, then the node was collected return P, Gc + + +def exact_steiner_tree( + G: BaseGraph, + terminals, + edge_weights=None, + root=None, + tolerance=1e-3, + strict_acyclic=False, + flow_name=VAR_FLOW, + backend: Backend = DEFAULT_BACKEND, +): + prized_nodes, prizes = [], [] + if isinstance(terminals, dict): + prized = {k: v for k, v in terminals.items() if v != 0} + if len(prized) > 0: + prized_nodes, prizes = zip(*prized.items()) + # non_zero_prizes = [k for k, v in terminals.items() if v != 0] + terminals = list(terminals.keys()) + + # V = {v: i for i, v in enumerate(G.V)} + dummy_edges = dict() + K = backend + if K is None: + raise ValueError("Invalid backend") + # If root not provided, take the first terminal as a root. + # Note that in undirected graphs, it doesn't matter the root node as long as + # the graph is connected + if root is None: + root = terminals[0] + Gc = G.copy() + # TODO: If graph is directed, edges for inflow/outflow should be reversible + eidx = Gc.add_edge( + (), root, type=EdgeType.UNDIRECTED + ) # () -> root (input flow, source flow node) + dummy_edges[root] = eidx + ids = [] + for v in terminals: + if v != root: + idx = Gc.add_edge(v, (), type=EdgeType.UNDIRECTED) + ids.append(idx) # terminal -> () (sink node, remove flow) + dummy_edges[v] = idx + ids = np.array(ids) + # lower/upper bounds for flow. If directed, lb=0, ub>0, if undirected, lb<0, ub>0 + # NOTE: bounds are arbitrary, but very large/small numbers can introduce issues with integrality tolerances + # TODO: lb/ub could be provided, or taken from the graph + lb = np.array( + [ + 0 if prop.has_attr(Attr.EDGE_TYPE, EdgeType.DIRECTED) else -10 + for prop in Gc.get_attr_edges() + ] + ) + if strict_acyclic: + P = K.AcyclicFlow(Gc, lb=lb, ub=10, varname=flow_name) + else: + P = K.Flow(Gc, lb=lb, ub=10, varname=flow_name) + ids_e = list(set(range(Gc.ne)) - set(ids + [eidx])) + # Indicators for the edges (1=unconstrained, 0=blocked flow) + F = P.expr[flow_name] + P += K.Indicator(F, indexes=ids_e) + Fi = P.expr[f"{flow_name}_i"] + if strict_acyclic: + P += P.expr[f"{flow_name}_ipos"] + P.expr[f"{flow_name}_ineg"] <= Fi + + # TODO: Take as argument, read from graph + if edge_weights is None: + edge_weights = np.array([prop.get("weight", 0) for prop in Gc.get_attr_edges()]) + elif isinstance(edge_weights, (list, tuple)): + edge_weights = np.array(edge_weights) + else: + raise ValueError("Unknown type for edge_weights (list or tuple)") + + P.add_objectives(edge_weights @ Fi) # sum the total cost of selected edges + + if len(prized_nodes) == 0: + # If not prized + P += F[eidx] == 10 # inject non-zero flow + P += F[ids] >= 10 / (len(terminals) + 1) # Force all terminals to be present + else: + id_edge_prized = np.array([dummy_edges[v] for v in prized_nodes]) + P += K.NonZeroIndicator(F, indexes=id_edge_prized, tolerance=tolerance) + I_prized_selected = ( + P.symbols[f"{flow_name}_ipos"] + P.symbols[f"{flow_name}_ineg"] + ) + P.add_objectives(np.array(prizes) @ I_prized_selected, weights=-1) + + # Add an objective for non-zero flow on prized nodes + + # For vertices, we need to check in which edges they appear. + # If any of those edges is selected, then the node was collected + return P, Gc + + +def create_exact_multi_steiner_tree( + G: BaseGraph, + terminal_per_condition, + edge_weights_per_condition=None, + root_vertices=None, + tolerance=1e-3, + strict_acyclic=False, + lam=0.01, + flow_name="flow", + backend: Backend = DEFAULT_BACKEND, +): + if backend is None: + raise ValueError("Invalid backend") + # Detect the number of conditions + if isinstance(terminal_per_condition, list): + # Get all the internal lists (conditions) + conditions = [l for l in terminal_per_condition if isinstance(l, list)] + elif isinstance(terminal_per_condition, dict): + conditions = [d for d in terminal_per_condition.values() if isinstance(d, dict)] + else: + raise ValueError("Invalid terminals format") + # Create the N problems + num_vertices, num_edges = G.shape + big_P = None + for i in range(len(conditions)): + terminals = conditions[i] + P, Gc = exact_steiner_tree( + G, + terminals, + edge_weights=edge_weights_per_condition[i] + if edge_weights_per_condition is not None + else None, + root=root_vertices[i] if root_vertices is not None else None, + tolerance=tolerance, + strict_acyclic=strict_acyclic, + backend=backend, + flow_name=f"{flow_name}{i}", + ) + if big_P is None: + big_P = P + else: + big_P += P + + # We create a linking binary vector computing the or of the selected edges + vars = [big_P.expr[f"{flow_name}{i}_i"][:num_edges] for i in range(len(conditions))] + I = backend.vstack(vars) + big_P += backend.linear_or(I, axis=0, varname="is_unblocked") + # big_P.register("is_unblocked", v_or) + big_P.add_objectives(sum(big_P.expr.is_unblocked), weights=lam) + return big_P + + +def _exact_multi_steiner_tree( + G: BaseGraph, + terminal_per_condition, + edge_weights_per_condition=None, + root_vertices=None, + tolerance=1e-3, + strict_acyclic=False, + backend: Backend = DEFAULT_BACKEND, +): + if backend is None: + raise ValueError("Invalid backend") + # Detect the number of conditions + if isinstance(terminal_per_condition, list): + # Get all the internal lists (conditions) + conditions = [l for l in terminal_per_condition if isinstance(l, list)] + elif isinstance(terminal_per_condition, dict): + conditions = [d for d in terminal_per_condition.values() if isinstance(d, dict)] + else: + raise ValueError("Invalid terminals format") + + Gc = G.copy() + for i in range(len(conditions)): + prized_nodes, prizes = [], [] + terminals = conditions[i] + if isinstance(terminals, dict): + prized = {k: v for k, v in terminals.items() if v != 0} + if len(prized) > 0: + prized_nodes, prizes = zip(*prized.items()) + terminals = list(terminals.keys()) + + dummy_edges = dict() + + if root_vertices is None: + root = terminals[0] + else: + root = root_vertices[i] + # TODO: Check if this root does not have a flow edge + eidx = Gc.add_edge( + (), root, type=EdgeType.UNDIRECTED + ) # () -> root (input flow, source flow node) + dummy_edges[root] = eidx + ids = [] + for v in terminals: + if v != root: + idx = Gc.add_edge(v, (), type=EdgeType.UNDIRECTED) + ids.append(idx) # terminal -> () (sink node, remove flow) + dummy_edges[v] = idx + ids = np.array(ids) + lb = np.array( + [ + 0 if prop.has_attr(Attr.EDGE_TYPE, EdgeType.DIRECTED) else -10 + for prop in Gc.get_attr_edges() + ] + ) + + if strict_acyclic: + raise NotImplementedError( + "Acyclic Flow for multi flows required, or Acyclic component (WIP)" + ) + else: + P = backend.Flow(Gc, lb=lb, ub=10, n_flows=len(conditions)) + + for i in range(len(conditions)): + F = P.expr.flow[:, i] + # Here there are differences, depending on if terminals have prizes or not. + # If they have prizes, terminals are optional. In order to be optimal, in-out flow + # through terminals have to be optional. + ids_e = list(set(range(Gc.ne)) - set([*ids, eidx])) + # Indicators for the edges (1=unconstrained, 0=blocked flow) + P += backend.Indicator(F, indexes=ids_e) + Fi = P.symbols["_flow_i"] + if strict_acyclic: + P += P.symbols["_flow_ipos"] + P.symbols["_flow_ineg"] <= Fi + edge_weights = None + if edge_weights_per_condition is not None: + edge_weights = edge_weights_per_condition[i] + + if edge_weights is None: + edge_weights = np.array( + [prop.get("weight", 0) for prop in Gc.get_attr_edges()] + ) + elif isinstance(edge_weights, (list, tuple)): + edge_weights = np.array(edge_weights) + else: + raise ValueError("Unknown type for edge_weights (list or tuple)") + + P.add_objectives(edge_weights @ Fi) # sum the total cost of selected edges + + if len(prized_nodes) == 0: + # If not prized, force all terminals to be present + # NOTE: This can lead to infeasibilities if a terminal cannot be connected + P += F[eidx] == 10 # inject non-zero flow + P += F[ids] >= 10 / (len(terminals) + 1) + else: + id_edge_prized = np.array([dummy_edges[v] for v in prized_nodes]) + P += backend.NonZeroIndicator( + F, indexes=id_edge_prized, tolerance=tolerance + ) + I_prized_selected = P.symbols["_flow_ipos"] + P.symbols["_flow_ineg"] + # Maximize the selection of nodes with prizes + P.add_objectives(np.array(prizes) @ I_prized_selected, weights=-1) diff --git a/docs/guide/intro/prior-knowledge.ipynb b/docs/guide/intro/prior-knowledge.ipynb index c05524c..ce31ea0 100644 --- a/docs/guide/intro/prior-knowledge.ipynb +++ b/docs/guide/intro/prior-knowledge.ipynb @@ -23,7 +23,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 40, "id": "a1684e94", "metadata": {}, "outputs": [ @@ -76,7 +76,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 41, "id": "4512c595", "metadata": {}, "outputs": [ @@ -133,10 +133,10 @@ "\n" ], "text/plain": [ - "" + "" ] }, - "execution_count": 2, + "execution_count": 41, "metadata": {}, "output_type": "execute_result" } @@ -159,7 +159,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 42, "id": "ad64bcb7", "metadata": {}, "outputs": [ @@ -234,10 +234,10 @@ "\n" ], "text/plain": [ - "" + "" ] }, - "execution_count": 3, + "execution_count": 42, "metadata": {}, "output_type": "execute_result" } @@ -258,7 +258,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 43, "id": "d87919e0", "metadata": {}, "outputs": [ @@ -337,10 +337,10 @@ "\n" ], "text/plain": [ - "" + "" ] }, - "execution_count": 4, + "execution_count": 43, "metadata": {}, "output_type": "execute_result" } @@ -351,6 +351,61 @@ "G.plot()" ] }, + { + "cell_type": "markdown", + "id": "f924788c", + "metadata": {}, + "source": [ + "Order of vertices and edges are preserved in the order of addition. Given an edge (u, v) added to the graph, u is added to the graph only if it's not already present." + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "id": "8768fab0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(1, 2, 3, 4)" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "G.V" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "8fd6ecf9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "((frozenset({1}), frozenset({2})),\n", + " (frozenset({2}), frozenset({3})),\n", + " (frozenset({1}), frozenset({3})),\n", + " (frozenset({3}), frozenset({4})),\n", + " (frozenset({3}), frozenset({4})),\n", + " (frozenset({3}), frozenset({4})))" + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "G.E" + ] + }, { "cell_type": "markdown", "id": "3b793e80", @@ -361,7 +416,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 46, "id": "cfc44e12", "metadata": {}, "outputs": [ @@ -375,7 +430,7 @@ " '__target_attr': {4: {'__value': {}}}}" ] }, - "execution_count": 5, + "execution_count": 46, "metadata": {}, "output_type": "execute_result" } @@ -388,7 +443,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 47, "id": "990951f4", "metadata": {}, "outputs": [ @@ -398,7 +453,7 @@ "0.5" ] }, - "execution_count": 6, + "execution_count": 47, "metadata": {}, "output_type": "execute_result" } @@ -409,7 +464,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 48, "id": "fe14480b", "metadata": {}, "outputs": [ @@ -423,7 +478,7 @@ " '__target_attr': {4: {'__value': {}}}}" ] }, - "execution_count": 7, + "execution_count": 48, "metadata": {}, "output_type": "execute_result" } @@ -434,7 +489,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 49, "id": "215eebcf", "metadata": {}, "outputs": [ @@ -444,7 +499,7 @@ "[6]" ] }, - "execution_count": 8, + "execution_count": 49, "metadata": {}, "output_type": "execute_result" } @@ -455,7 +510,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 50, "id": "58cc1998", "metadata": {}, "outputs": [ @@ -465,7 +520,7 @@ "(frozenset({3}), frozenset({4}))" ] }, - "execution_count": 9, + "execution_count": 50, "metadata": {}, "output_type": "execute_result" } @@ -485,7 +540,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 51, "id": "281142cf", "metadata": {}, "outputs": [ @@ -495,7 +550,7 @@ "{3: {'__value': {}}}" ] }, - "execution_count": 10, + "execution_count": 51, "metadata": {}, "output_type": "execute_result" } @@ -507,7 +562,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 52, "id": "9c1df45f", "metadata": {}, "outputs": [ @@ -517,7 +572,7 @@ "{3: {'__value': {}}}" ] }, - "execution_count": 11, + "execution_count": 52, "metadata": {}, "output_type": "execute_result" } @@ -548,7 +603,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 53, "id": "d5b8730c", "metadata": {}, "outputs": [ @@ -600,10 +655,10 @@ "\n" ], "text/plain": [ - "" + "" ] }, - "execution_count": 12, + "execution_count": 53, "metadata": {}, "output_type": "execute_result" } @@ -615,7 +670,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 54, "id": "558c4a7e", "metadata": {}, "outputs": [ @@ -632,7 +687,7 @@ " '__target_attr': {'C': {'__value': {}}}}]" ] }, - "execution_count": 13, + "execution_count": 54, "metadata": {}, "output_type": "execute_result" } @@ -643,7 +698,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 55, "id": "c3310b77", "metadata": {}, "outputs": [], @@ -663,7 +718,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 56, "id": "d29ddc6e", "metadata": {}, "outputs": [], @@ -673,7 +728,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 57, "id": "035d04d4", "metadata": {}, "outputs": [ @@ -758,10 +813,10 @@ "\n" ], "text/plain": [ - "" + "" ] }, - "execution_count": 16, + "execution_count": 57, "metadata": {}, "output_type": "execute_result" } @@ -789,7 +844,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 62, "id": "20ebd58d", "metadata": {}, "outputs": [ @@ -809,143 +864,189 @@ "\n", "\n", "A\n", - "\n", - "A\n", + "\n", + "A\n", "\n", "\n", "\n", "e_0_center\n", - "\n", + "\n", "\n", "\n", "\n", "A->e_0_center\n", - "\n", + "\n", "\n", "\n", "\n", "B\n", - "\n", - "B\n", + "\n", + "B\n", "\n", "\n", "\n", "B->e_0_center\n", - "\n", - "\n", - "\n", - "\n", - "D\n", - "\n", - "D\n", - "\n", - "\n", - "\n", - "e_1_center\n", - "\n", - "\n", - "\n", - "\n", - "D->e_1_center\n", - "\n", + "\n", "\n", "\n", - "\n", + "\n", "C\n", - "\n", - "C\n", + "\n", + "C\n", "\n", "\n", "\n", "e_4_target\n", - "\n", + "\n", "\n", "\n", "\n", "C->e_4_target\n", - "\n", - "\n", + "\n", + "\n", "\n", - "\n", - "\n", - "e_0_center->D\n", - "\n", - "\n", + "\n", + "\n", + "D\n", + "\n", + "D\n", + "\n", + "\n", + "\n", + "e_1_center\n", + "\n", + "\n", + "\n", + "\n", + "D->e_1_center\n", + "\n", "\n", "\n", - "\n", + "\n", "e_0_center->C\n", - "\n", - "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "e_0_center->D\n", + "\n", + "\n", "\n", "\n", "\n", "F\n", - "\n", - "F\n", + "\n", + "F\n", "\n", "\n", "\n", "E\n", - "\n", - "E\n", + "\n", + "E\n", "\n", "\n", "\n", "e_1_center->F\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", "\n", "e_1_center->E\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", "\n", "e_2_source\n", - "\n", + "\n", "\n", "\n", "\n", "e_2_source->A\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", "\n", "e_3_source\n", - "\n", + "\n", "\n", "\n", "\n", "e_3_source->B\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", "\n" ], "text/plain": [ - "" + "" ] }, - "execution_count": 19, + "execution_count": 62, "metadata": {}, "output_type": "execute_result" } ], "source": [ "G = cn.Graph()\n", - "G.add_edge({\"A\", \"B\"}, {\"C\", \"D\"})\n", - "G.add_edge(\"D\", {\"E\", \"F\"})\n", + "G.add_edge((\"A\", \"B\"), (\"C\", \"D\"))\n", + "G.add_edge(\"D\", (\"E\", \"F\"))\n", "G.add_edge((), \"A\")\n", "G.add_edge((), \"B\")\n", "G.add_edge(\"C\", ())\n", "G.plot()" ] }, + { + "cell_type": "code", + "execution_count": 63, + "id": "9646595e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "('A', 'B', 'C', 'D', 'F', 'E')" + ] + }, + "execution_count": 63, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "G.V" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "id": "1aada7ad", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "((frozenset({'A', 'B'}), frozenset({'C', 'D'})),\n", + " (frozenset({'D'}), frozenset({'E', 'F'})),\n", + " (frozenset(), frozenset({'A'})),\n", + " (frozenset(), frozenset({'B'})),\n", + " (frozenset({'C'}), frozenset()))" + ] + }, + "execution_count": 64, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "G.E" + ] + }, { "cell_type": "markdown", "id": "dcbd2a10", diff --git a/tests/test_backend.py b/tests/test_backend.py index 6bf1730..8263729 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -135,6 +135,36 @@ def test_opt_delegate_sum_axis1(backend): assert np.isclose(esum.value, 60) +def test_hstack_expression_matrix(backend): + x = backend.Variable("x", (2, 2)) + y = backend.Variable("y", (2, 3)) + z = x.hstack(y) + assert z.shape == (2, 5) + + +def test_vstack_expression_matrix(backend): + x = backend.Variable("x", (2, 2)) + y = backend.Variable("y", (3, 2)) + z = x.vstack(y) + assert z.shape == (5, 2) + + +def test_hstack_backend(backend): + x = backend.Variable("x", (1, 3)) + y = backend.Variable("y", (1, 6)) + z = backend.Variable("z", (1, 1)) + t = backend.hstack([x, y, z]) + assert t.shape == (1, 10) + + +def test_vstack_backend(backend): + x = backend.Variable("x", (3, 1)) + y = backend.Variable("y", (6, 1)) + z = backend.Variable("z", (1, 1)) + t = backend.vstack([x, y, z]) + assert t.shape == (10, 1) + + def test_cexpression_name(backend): x = backend.Variable("x") e = x <= 10