Skip to content

Commit

Permalink
Many improvements on the backend side to extend the ST method
Browse files Browse the repository at this point in the history
  • Loading branch information
pablormier committed May 14, 2024
1 parent 1910439 commit 7dabb0f
Show file tree
Hide file tree
Showing 8 changed files with 546 additions and 101 deletions.
3 changes: 0 additions & 3 deletions .github/workflows/deploy-docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,6 @@ jobs:
git add .
git commit -m "Update docs for branch $BRANCH_NAME" --allow-empty
git push -u origin gh-pages

# Optionally, remove the temporary directory (not necessary on CI environments)
# rm -rf "$TEMP_DIR"
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
63 changes: 57 additions & 6 deletions corneto/backend/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,22 @@ def _elementwise_mul(self, other: Any) -> Any:
def multiply(self, other: Any) -> "CExpression":
return self._elementwise_mul(other)

@abc.abstractmethod
def _hstack(self, other: "CExpression") -> Any:
pass

@_delegate
def hstack(self, other: "CExpression") -> "CExpression":
return self._hstack(other)

@abc.abstractmethod
def _vstack(self, other: "CExpression") -> Any:
pass

@_delegate
def vstack(self, other: "CExpression") -> "CExpression":
return self._vstack(other)

@abc.abstractmethod
def _norm(self, p: int = 2) -> Any:
pass
Expand Down Expand Up @@ -468,7 +484,7 @@ def __iadd__(self, other: Any) -> "ProblemDef":
def solve(
self,
solver: Optional[Union[str, Solver]] = None,
max_seconds: int = None,
max_seconds: Optional[int] = None,
warm_start: bool = False,
verbosity: int = 0,
**options,
Expand Down Expand Up @@ -1049,26 +1065,61 @@ def Xor(self, x: CExpression, y: CExpression, varname="_xor"):
[xor >= x - y, xor >= y - x, xor <= x + y, xor <= 2 - x - y]
)

def linear_or(self, x: CSymbol, axis: Optional[int] = None, varname="_linear_or"):
# Check if the variable is binary, otherwise throw an error
if x._vartype != VarType.BINARY:
def linear_or(self, x: CExpression, axis: Optional[int] = None, varname="or"):
# Check if the variable has a vartype and is binary
if hasattr(x, "_vartype") and x._vartype != VarType.BINARY:
raise ValueError(f"Variable x has type {x._vartype} instead of BINARY")
else:
for s in x._proxy_symbols:
if s._vartype != VarType.BINARY:
# Show warning only
LOGGER.warn(
f"Variable {s.name} has type {s._vartype}, expression is assumed to be binary"
)
break

Z = x.sum(axis=axis)
Z_norm = Z / x.shape[axis] # between 0-1
# Create a new binary variable to compute linearized or
Or = self.Variable(varname, Z.shape, 0, 1, vartype=VarType.BINARY)
return self.Problem([Or >= Z_norm, Or <= Z])

def linear_and(self, x: CSymbol, axis: Optional[int] = None, varname="_linear_and"):
def linear_and(self, x: CExpression, axis: Optional[int] = None, varname="and"):
# Check if the variable is binary, otherwise throw an error
if x._vartype != VarType.BINARY:
if hasattr(x, "_vartype") and x._vartype != VarType.BINARY:
raise ValueError(f"Variable x has type {x._vartype} instead of BINARY")
else:
for s in x._proxy_symbols:
if s._vartype != VarType.BINARY:
# Show warning only
LOGGER.warn(
f"Variable {s.name} has type {s._vartype}, expression is assumed to be binary"
)
break
Z = x.sum(axis=axis)
N = x.shape[axis]
Z_norm = Z / N
And = self.Variable(varname, Z.shape, 0, 1, vartype=VarType.BINARY)
return self.Problem([And <= Z_norm, And >= Z - N + 1])

def vstack(self, arg_list: Iterable[CSymbol]):
v = None
for a in arg_list:
if v is None:
v = a
else:
v = v.vstack(a)
return v

def hstack(self, arg_list: Iterable[CSymbol]):
h = None
for a in arg_list:
if h is None:
h = a
else:
h = h.hstack(a)
return h


class NoBackend(Backend):
def __init__(self) -> None:
Expand Down
15 changes: 14 additions & 1 deletion corneto/backend/_cvxpy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ def _sum(self, axis: Optional[int] = None) -> Any:
def _max(self, axis: Optional[int] = None) -> Any:
return cp.max(self._expr, axis=axis)

def _hstack(self, other: CExpression) -> Any:
return cp.hstack([self._expr, other])

def _vstack(self, other: CExpression) -> Any:
return cp.vstack([self._expr, other])

@property
def value(self) -> np.ndarray:
return self._expr.value
Expand Down Expand Up @@ -169,10 +175,17 @@ def _solve(
cfg = options.get("mosek_params", dict())
cfg.update({"mioMaxTime": float(max_seconds)})
options["mosek_params"] = cfg
elif s == "SCIPY":
elif s == "SCIPY" or s == "HIGHS":
cfg = options.get("scipy_options", dict())
cfg.update({"time_limit": float(max_seconds), "disp": verbosity > 0})
options["scipy_options"] = cfg
else:
# Warning that a mapping is not available, check backend documentation
LOGGER.warn(f"""max_seconds parameter mapping for {s} not found.
Please refer to the backend documentation for more
information. For example, using CVXPY with GUROBI solver,
the parameter TimeLimit can be directly passed with
`problem.solve(solver='GUROBI', TimeLimit=max_seconds)`""")

P.solve(solver=s, verbose=verbosity > 0, warm_start=warm_start, **options)
return P
6 changes: 6 additions & 0 deletions corneto/backend/_picos_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ def _sum(self, axis: Optional[int] = None) -> Any:
def _max(self, axis: Optional[int] = None) -> Any:
raise NotImplementedError()

def _hstack(self, other: CExpression) -> Any:
return self._expr & other

def _vstack(self, other: CExpression) -> Any:
return self._expr // other

@property
def value(self) -> np.ndarray:
return self._expr.value
Expand Down
Empty file.
Loading

0 comments on commit 7dabb0f

Please sign in to comment.