Skip to content

Commit

Permalink
added shape checking to arrays in Brownian return types
Browse files Browse the repository at this point in the history
  • Loading branch information
andyElking committed May 9, 2024
1 parent 53acc09 commit 4f36bb3
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions diffrax/_custom_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,41 +47,43 @@
Control = PyTree[Shaped[ArrayLike, "?*control"], "C"]
Args = PyTree[Any]

BM = PyTree[Shaped[ArrayLike, "?*bm"], "BM"]

DenseInfo = dict[str, PyTree[Array]]
DenseInfos = dict[str, PyTree[Shaped[Array, "times-1 ..."]]]
BufferDenseInfos = dict[str, PyTree[eqxi.MaybeBuffer[Shaped[Array, "times ..."]]]]
sentinel: Any = eqxi.doc_repr(object(), "sentinel")


class AbstractBrownianIncrement(eqx.Module):
dt: eqx.AbstractVar[PyTree[FloatScalarLike]]
W: eqx.AbstractVar[PyTree[Array]]
dt: eqx.AbstractVar[PyTree[FloatScalarLike, "BM"]]
W: eqx.AbstractVar[BM]


class AbstractSpaceTimeLevyArea(AbstractBrownianIncrement):
H: eqx.AbstractVar[PyTree[Array]]
H: eqx.AbstractVar[BM]


class AbstractSpaceTimeTimeLevyArea(AbstractSpaceTimeLevyArea):
K: eqx.AbstractVar[PyTree[Array]]
K: eqx.AbstractVar[BM]


class BrownianIncrement(AbstractBrownianIncrement):
dt: PyTree[FloatScalarLike]
W: PyTree[Array]
dt: PyTree[FloatScalarLike, "BM"]
W: BM


class SpaceTimeLevyArea(AbstractSpaceTimeLevyArea):
dt: PyTree[FloatScalarLike]
W: PyTree[Array]
H: PyTree[Array]
dt: PyTree[FloatScalarLike, "BM"]
W: BM
H: BM


class SpaceTimeTimeLevyArea(AbstractSpaceTimeTimeLevyArea):
dt: PyTree[FloatScalarLike]
W: PyTree[Array]
H: PyTree[Array]
K: PyTree[Array]
dt: PyTree[FloatScalarLike, "BM"]
W: BM
H: BM
K: BM


def levy_tree_transpose(
Expand Down

0 comments on commit 4f36bb3

Please sign in to comment.