Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/apply scale by marginals #86

Merged
merged 4 commits into from
Feb 11, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions moscot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
import moscot.costs
import moscot.solvers
import moscot.backends
import moscot.problems
13 changes: 9 additions & 4 deletions moscot/backends/ott/_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,17 @@ def transport_matrix(self) -> npt.ArrayLike:
def cost(self) -> float:
return float(self._output.reg_ot_cost)

def _ones(self, n: int) -> jnp.ndarray:
return jnp.ones((n,))


class SinkhornOutput(OTTBaseOutput):
def _apply(self, x: npt.ArrayLike, *, forward: bool) -> npt.ArrayLike:
axis = int(not forward)
if x.ndim == 1:
return self._output.apply(x, axis=axis)
return self._output.apply(x, axis=1 - forward)
if x.ndim == 2:
# convert to batch first
return self._output.apply(x.T, axis=axis).T
return self._output.apply(x.T, axis=1 - forward).T
raise ValueError("TODO - dim error")

@property
Expand Down Expand Up @@ -66,7 +68,7 @@ def shape(self) -> Tuple[int, int]:

@property
def converged(self) -> bool:
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
costs, tol = self._output.costs, self._threshold
costs = self._output.costs
costs = costs[costs != -1]
# TODO(michalk8): is this correct?
# modified the condition from:
Expand All @@ -89,3 +91,6 @@ def cost(self) -> float:
@property
def converged(self) -> bool:
return self._converged

def _ones(self, n: int) -> jnp.ndarray:
return jnp.ones((n,))
2 changes: 1 addition & 1 deletion moscot/backends/ott/_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from moscot.solvers._output import BaseSolverOutput
from moscot.backends.ott._output import GWOutput, SinkhornOutput, LRSinkhornOutput
from moscot.solvers._base_solver import BaseSolver
from moscot.solvers._tagged_arry import TaggedArray
from moscot.solvers._tagged_array import TaggedArray

__all__ = ("Cost", "SinkhornSolver", "GWSolver", "FGWSolver")

Expand Down
2 changes: 1 addition & 1 deletion moscot/problems/_anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from moscot._utils import _get_backend_losses
from moscot.costs._costs import __all__ as moscot_losses, BaseLoss
from moscot.solvers._tagged_arry import Tag, TaggedArray
from moscot.solvers._tagged_array import Tag, TaggedArray


@dataclass(frozen=True)
Expand Down
8 changes: 5 additions & 3 deletions moscot/problems/_base_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from moscot.solvers._output import BaseSolverOutput
from moscot.problems._anndata import AnnDataPointer
from moscot.solvers._base_solver import BaseSolver
from moscot.solvers._tagged_arry import Tag, TaggedArray
from moscot.solvers._tagged_array import Tag, TaggedArray


class BaseProblem(ABC):
Expand Down Expand Up @@ -219,21 +219,23 @@ def push(
data: Optional[Union[str, npt.ArrayLike]] = None,
subset: Optional[Sequence[Any]] = None,
normalize: bool = True,
**kwargs: Any,
) -> npt.ArrayLike:
# TODO: check if solved - decorator?
data = self._get_mass(self.adata, data=data, subset=subset, normalize=normalize)
return self.solution.push(data)
return self.solution.push(data, **kwargs)

def pull(
self,
data: Optional[Union[str, npt.ArrayLike]] = None,
subset: Optional[Sequence[Any]] = None,
normalize: bool = True,
**kwargs: Any,
) -> npt.ArrayLike:
# TODO: check if solved - decorator?
adata = self.adata if self._adata_y is None else self._adata_y
data = self._get_mass(adata, data=data, subset=subset, normalize=normalize)
return self.solution.pull(data)
return self.solution.pull(data, **kwargs)

@property
def solution(self) -> Optional[BaseSolverOutput]:
Expand Down
4 changes: 3 additions & 1 deletion moscot/problems/_compound_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def _apply(
normalize: bool = True,
forward: bool = True,
return_all: bool = False,
scale_by_marginals: bool = False,
**kwargs: Any,
) -> Dict[Tuple[Any, Any], npt.ArrayLike]:
def get_data(plan: Tuple[Any, Any]) -> Optional[npt.ArrayLike]:
Expand Down Expand Up @@ -108,7 +109,8 @@ def get_data(plan: Tuple[Any, Any]) -> Optional[npt.ArrayLike]:
ds = [get_data(plan)]
for step in steps:
problem = self._problems[step]
ds.append((problem.push if forward else problem.pull)(ds[-1], subset=subset, normalize=normalize))
fun = problem.push if forward else problem.pull
ds.append(fun(ds[-1], subset=subset, normalize=normalize, scale_by_marginals=scale_by_marginals))

# TODO(michalk8): shall we include initial input? or add as option?
res[plan] = ds[1:] if return_all else ds[-1]
Expand Down
2 changes: 1 addition & 1 deletion moscot/solvers/_base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy.typing as npt

from moscot.solvers._output import BaseSolverOutput
from moscot.solvers._tagged_arry import Tag, TaggedArray
from moscot.solvers._tagged_array import Tag, TaggedArray

ArrayLike = Union[npt.ArrayLike, TaggedArray]

Expand Down
21 changes: 19 additions & 2 deletions moscot/solvers/_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import numpy.typing as npt


# TODO(michalk8):
# 1. mb. use more contrained type hints
# 2. consider always returning 2-dim array, even if 1-dim is passed (not sure which convenient for user)
class BaseSolverOutput(ABC):
@abstractmethod
def _apply(self, x: npt.ArrayLike, *, forward: bool) -> npt.ArrayLike:
Expand All @@ -29,16 +32,30 @@ def cost(self) -> float:
def converged(self) -> bool:
pass

def push(self, x: npt.ArrayLike) -> npt.ArrayLike:
# TODO(michalk8): mention in docs it needs to be broadcastable
@abstractmethod
def _ones(self, n: int) -> npt.ArrayLike:
pass

def push(self, x: npt.ArrayLike, scale_by_marginals: bool = False) -> npt.ArrayLike:
if x.shape[0] != self.shape[0]:
raise ValueError("TODO: wrong shape")
x = self._scale_by_marginals(x, forward=True) if scale_by_marginals else x
return self._apply(x, forward=True)

def pull(self, x: npt.ArrayLike) -> npt.ArrayLike:
def pull(self, x: npt.ArrayLike, scale_by_marginals: bool = False) -> npt.ArrayLike:
if x.shape[0] != self.shape[1]:
raise ValueError("TODO: wrong shape")
x = self._scale_by_marginals(x, forward=False) if scale_by_marginals else x
return self._apply(x, forward=False)

def _scale_by_marginals(self, x: npt.ArrayLike, *, forward: bool) -> npt.ArrayLike:
# alt. we could use the public push/pull
scale = self._apply(self._ones(self.shape[forward]), forward=not forward)
if x.ndim == 2:
scale = scale[:, None]
return x / scale
michalk8 marked this conversation as resolved.
Show resolved Hide resolved

def _format_params(self, fmt: Callable[[Any], str]) -> str:
params = {"shape": self.shape, "cost": round(self.cost, 4), "converged": self.converged}
return ", ".join(f"{name}={fmt(val)}" for name, val in params.items())
Expand Down
File renamed without changes.