In [None]:
import jax
import jax.numpy as jnp
import numpy as np
from genjax import Pytree
from condorgmm.condor.utils import MyPytree, find_first_above
from dataclasses import dataclass
import genjax
# Presumably switches on genjax's pretty-printing for notebook display — confirm against genjax docs.
genjax.pretty()

In [None]:
@dataclass
class _Domain:
    values : jnp.ndarray
    _numpy_values : np.ndarray

    def __init__(self, values):
        self.values = values
        self._numpy_values = np.array(values)

    def __eq__(self, other):
        return bool(np.all(self._numpy_values == other._numpy_values))

    def __hash__(self):
        return hash(tuple(self._numpy_values))
    

@Pytree.dataclass
class Domain(MyPytree):
    """A finite set of values that discrete floats index into.

    The values are wrapped in a hashable ``_Domain`` and held in the single
    field ``_dom``, which is marked ``Pytree.static()``.
    """

    _dom: _Domain = Pytree.static()

    def __init__(self, values):
        self._dom = _Domain(values)

    def __len__(self):
        """Number of values in the domain."""
        return len(self.values)

    @property
    def values(self):
        """The wrapped array of domain values."""
        wrapped = self._dom
        return wrapped.values

    @property
    def discrete_float_values(self):
        """
        Every element of this domain as one batched `FloatFromDiscreteSet`;
        entry ``i`` of the batch refers to ``values[i]``.
        """
        def _at(index):
            return FloatFromDiscreteSet(idx=index, domain=self)

        all_indices = jnp.arange(self.values.shape[0])
        return jax.vmap(_at)(all_indices)

    def first_value_above(self, val) -> "FloatFromDiscreteSet":
        """
        Return a `FloatFromDiscreteSet` pointing at the first value in the
        domain that is greater than or equal to `val`, as located by
        `find_first_above`. This is the smallest such value only if
        `values` is sorted ascending (assumed — TODO confirm).

        If no such value exists, returns FloatFromDiscreteSet(-1, domain).
        NOTE(review): calling `.value` on that result indexes `values[-1]`,
        i.e. it yields the last domain value rather than signalling "none".
        """
        return FloatFromDiscreteSet(
            idx=find_first_above(self.values, val), domain=self
        )



# Example: build a three-element Domain; as the cell's last expression it is displayed.
Domain(jnp.array([1., 2., 3.]))

In [None]:
@Pytree.dataclass
class FloatFromDiscreteSet(MyPytree):
    """A float represented as an integer index into a `Domain`.

    ``idx`` is the only field not marked ``Pytree.static()``; ``domain``
    is static.
    """

    # Integer index into ``domain.values``: a Python int or an integer
    # array (scalar, or batched e.g. via ``jax.vmap`` or ``tile``).
    idx: jnp.ndarray
    domain: Domain = Pytree.static()

    @property
    def value(self):
        """The float(s) ``domain.values[idx]``.

        NOTE: a negative ``idx`` (such as a ``-1`` "not found" sentinel)
        uses ordinary negative indexing and therefore yields the *last*
        domain value instead of signalling that no value exists.
        """
        return self.domain.values[self.idx]

    @property
    def shape(self):
        """Shape of ``idx``; ``()`` for a scalar, including a plain Python int."""
        return jnp.shape(self.idx)

    def tile(self, *tile_args, **tile_kwargs):
        """Return a copy whose ``idx`` is tiled with ``jnp.tile``; the domain is shared."""
        return FloatFromDiscreteSet(
            idx=jnp.tile(self.idx, *tile_args, **tile_kwargs), domain=self.domain
        )

    def __eq__(self, other):
        """Compare domains, then require every index to match.

        Returns the (falsy) domain comparison result when the domains
        differ; otherwise a JAX boolean from ``jnp.all`` over the
        elementwise index comparison, so it also works on traced indices.
        Operands that are not ``FloatFromDiscreteSet`` return
        ``NotImplemented`` instead of raising ``AttributeError``.
        """
        if not isinstance(other, FloatFromDiscreteSet):
            return NotImplemented
        return self.domain == other.domain and jnp.all(
            jnp.asarray(self.idx) == jnp.asarray(other.idx)
        )

In [None]:
class DomainMismatchError(ValueError, AssertionError):
    """Raised when a value is scored under a domain other than its own.

    Also subclasses ``AssertionError`` so code written against the earlier
    ``assert``-based check still catches it.
    """


@genjax.Pytree.dataclass
class UniformFromDomain(MyPytree, genjax.ExactDensity):
    """Uniform distribution over the indices of a `Domain`."""

    def sample(self, key, domain: Domain) -> FloatFromDiscreteSet:
        """Draw an index uniformly from ``[0, len(domain))`` using ``key``."""
        idx = jax.random.randint(key, (), 0, len(domain))
        return FloatFromDiscreteSet(idx=idx, domain=domain)

    def logpdf(self, val: FloatFromDiscreteSet, domain: Domain):
        """Return ``-log(len(domain))``, the log-probability of any single index.

        Raises:
            DomainMismatchError: if ``val.domain`` differs from ``domain``.
                An explicit raise (rather than ``assert``) keeps the check
                active when Python runs with ``-O``.
        """
        if not (val.domain == domain):
            raise DomainMismatchError(
                "FloatFromDiscreteSet value belongs to a different Domain "
                "than the one passed to logpdf"
            )
        return -jnp.log(len(domain))


uniform_from_domain = UniformFromDomain()

In [None]:
# Two domains whose values differ (dom2 holds dom's values plus 1); later
# cells draw from `dom` and score against `dom2`.
values = jnp.array([1.0, 2.0, 3.0, 4.0])
dom = Domain(values)
dom2 = Domain(values + 1)
# Compare the two domains inside jit, passing both as pytree arguments;
# presumably evaluates to False since their values differ — confirm output.
jax.jit(lambda x, y: x == y)(dom, dom2)

In [None]:
# Draw one value from `dom` with a jitted sampler (fixed key 0), and build a
# jitted logpdf for use in the next cell.
sample = jax.jit(uniform_from_domain.sample)(jax.random.key(0), dom)
jitted_logpdf = jax.jit(uniform_from_domain.logpdf)

In [None]:
# Score `sample` (drawn from `dom`) against `dom2`. The two domains hold
# different values, so the domain check inside `logpdf` is expected to
# reject this call rather than return a log-probability.
jitted_logpdf(sample, dom2)