# Extending our work

```python
from genjax import vi
```

This notebook illustrates several ways to build on our work and its implementation:

* (**Extending ADEV, the automatic differentiation algorithm, with new samplers equipped with gradient strategies.**) Once a sampler implements the ADEV interfaces, it can be freely lifted into the `Distribution` type of our language and used in both model and guide code. We illustrate this process by implementing `beta_implicit` and using it in a model and guide program from the Pyro tutorials.
* (**Using a standard loss function (like `genjax.vi.elbo`) with new models and guides.**) Because our system is programmable, this is a standard way to extend our work. It is covered by the same tutorial as the first case, above.
* (**Implementing new loss functions using the modeling interfaces of our language.**) We illustrate this process by implementing [SDOS](https://arxiv.org/abs/2103.01030), an estimator for a symmetric KL divergence, in our language, and by automating the derivation of gradients for a guide program.

We cover each of these possible extensions in turn below.

## Implementing new samplers for ADEV

ADEV is an extensible AD algorithm: users can implement new samplers equipped with gradient strategies, and use them in ADEV programs.
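A gradient strategy is a rule for estimating the derivative of an expectation taken over a sampler's output. The following plain-JAX sketch does not use adevjax, and its function names are hypothetical. It shows two classic strategies for a normal sampler, both targeting d/dθ E_{x ~ N(θ, 1)}[x²] = 2θ:

* reparameterization, which differentiates through the sampler;
* the score-function (REINFORCE) estimator, which weights each sample by d/dθ log N(x; θ, 1) = x - θ.

```python
import jax
import jax.numpy as jnp


def f(x):
    # The integrand whose expectation we differentiate.
    return x ** 2


def reparam_grad(key, theta, n):
    # Reparameterization: x = theta + eps with eps ~ N(0, 1), so x is a
    # differentiable function of theta and we can differentiate through it.
    eps = jax.random.normal(key, (n,))
    return jax.grad(lambda t: jnp.mean(f(t + eps)))(theta)


def score_grad(key, theta, n):
    # Score function: E[f(x) * d/dtheta log N(x; theta, 1)], where the
    # score of a unit-variance normal in its mean is (x - theta).
    x = theta + jax.random.normal(key, (n,))
    return jnp.mean(f(x) * (x - theta))


key = jax.random.PRNGKey(0)
theta = 1.5
r = reparam_grad(key, theta, 200_000)  # close to 2 * theta = 3.0
s = score_grad(key, theta, 200_000)    # also close to 3.0, but noisier
```

Both estimators are unbiased. They differ in variance and in what they require of the sampler. That trade-off is why ADEV lets each primitive choose its own strategy.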

```python
from adevjax import ADEVPrimitive
```

### Implementing a `beta_implicit` sampler

In [ADEV appendix B.7](https://arxiv.org/pdf/2212.06386.pdf), the authors outline a gradient strategy for distribution samplers whose CDF is available. In the literature, this is called implicit differentiation (or implicit reparameterization).
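The idea is the standard implicit reparameterization identity. Hold the CDF value F(z; θ) fixed and differentiate, which gives dz/dθ = -(∂F(z; θ)/∂θ) / p(z; θ). No inverse CDF is needed. The sketch below is plain JAX, not ADEV code, and uses an Exponential distribution as an assumed example. Its closed-form reparameterization z = -log(u)/λ gives dz/dλ = -z/λ, so the implicit formula can be checked against it.

```python
import jax
import jax.numpy as jnp


def exp_cdf(z, lam):
    # CDF of Exponential(rate=lam).
    return 1.0 - jnp.exp(-lam * z)


def implicit_grad(z, lam):
    # dF/dlam at fixed z, and the density p(z) = dF/dz at fixed lam.
    dF_dlam = jax.grad(exp_cdf, argnums=1)(z, lam)
    pdf = jax.grad(exp_cdf, argnums=0)(z, lam)
    # Implicit reparameterization: dz/dlam = -(dF/dlam) / p(z).
    return -dF_dlam / pdf


key = jax.random.PRNGKey(0)
lam = 2.0
z = jax.random.exponential(key) / lam  # a sample from Exponential(rate=2)
g = implicit_grad(z, lam)              # should equal -z / lam
```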

Several libraries take advantage of this strategy already, including the `distributions` module of [TensorFlow Probability](https://www.tensorflow.org/probability/api_docs/python/tfp/distributions), for distributions like [Beta](https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/Beta).

Our system lets extenders use these strategies directly whenever a third-party library _already implements the differentiation rules_ through JAX's native JVP rule system.
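JAX itself registers such a rule for `jax.random.gamma`: the sampler is differentiable in its shape parameter, so `jax.jvp` can push a tangent through a sample. The sketch below shows this with the key held fixed. It uses JAX's gamma sampler rather than TFP's Beta so that it stays dependency-free. The parameter value is arbitrary.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
alpha = jnp.array(2.5)


def sample_gamma(a):
    # Same key on every call, so the sample is a deterministic function of `a`.
    return jax.random.gamma(key, a)


# JAX's registered JVP rule for the gamma sampler yields d(sample)/d(alpha).
primal, tangent = jax.jvp(sample_gamma, (alpha,), (jnp.array(1.0),))
```

For a fixed z, the gamma CDF decreases in its shape parameter. By the implicit formula, the tangent is therefore positive.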

```python
from dataclasses import dataclass

import jax
from adevjax import ADEVPrimitive
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions


# Defining a new primitive.
@dataclass
class BetaIMPLICIT(ADEVPrimitive):
    # `flatten` is required to register this type as a JAX PyTree type.
    def flatten(self):
        return (), ()

    # New primitives require a `sample` implementation with the signature
    #   sample(self, key: PRNGKey, *args)
    # where `PRNGKey` is the type of JAX PRNG keys.
    def sample(self, key, alpha, beta):
        v = tfd.Beta(concentration1=alpha, concentration0=beta).sample(seed=key)
        return v

    # New primitives require an implementation of their gradient strategy
    # in the `jvp_estimate` method.
    #
    # The ADEV interpreter calls this method with the primals, the tangents,
    # and two continuations for the rest of the computation.
    def jvp_estimate(self, key, primals, tangents, konts):
        # `kpure` is unused here: this strategy only needs the dual continuation.
        kpure, kdual = konts

        # TFP already equips its Beta sampler with implicit differentiation
        # rules for JVP, so we reuse those rules directly.
        def _inner(alpha, beta):
            # Invokes TFP's implicit reparameterization:
            # https://github.com/tensorflow/probability/blob/v0.23.0/tensorflow_probability/python/distributions/beta.py#L292-L306
            x = tfd.Beta(concentration1=alpha, concentration0=beta).sample(seed=key)
            return x

        # JAX's JVP (which uses TFP's registered implicit differentiation
        # rule for Beta) gives us a primal and a tangent.
        primal_out, tangent_out = jax.jvp(_inner, primals, tangents)

        # Pass the result to ADEV's dual continuation to continue
        # ADEV's forward mode.
        return kdual((primal_out,), (tangent_out,))


# Creating an instance, to be exported and used as a sampler.
beta_implicit = BetaIMPLICIT()
```
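To see the continuation-passing shape of `jvp_estimate` without installing adevjax or TFP, here is a hypothetical stand-in, `BetaViaGammaSketch`. It has the same `sample` / `jvp_estimate` methods but does not subclass `ADEVPrimitive`. It builds a Beta sample as g1 / (g1 + g2) from two gamma samples, relying on JAX's implicit gamma gradients. The hand-written `kpure` and `kdual` are illustrative stand-ins for "the rest of the program" (here, x ↦ x²), not the ADEV interpreter's real continuations.

```python
import jax
import jax.numpy as jnp


class BetaViaGammaSketch:
    """Hypothetical stand-in mirroring BetaIMPLICIT's method shapes."""

    @staticmethod
    def _beta(key, alpha, beta):
        # Beta(alpha, beta) as a ratio of independent gamma samples.
        k1, k2 = jax.random.split(key)
        g1 = jax.random.gamma(k1, alpha)
        g2 = jax.random.gamma(k2, beta)
        return g1 / (g1 + g2)

    def sample(self, key, alpha, beta):
        return self._beta(key, alpha, beta)

    def jvp_estimate(self, key, primals, tangents, konts):
        kpure, kdual = konts
        inner = lambda a, b: self._beta(key, a, b)
        primal_out, tangent_out = jax.jvp(inner, primals, tangents)
        return kdual((primal_out,), (tangent_out,))


def kpure(x):
    # Hypothetical pure continuation: the rest of the program is x ** 2.
    return x ** 2


def kdual(primals, tangents):
    # Hypothetical dual continuation: push the tangent through x ** 2.
    return jax.jvp(lambda x: x ** 2, primals, tangents)


key = jax.random.PRNGKey(0)
alpha0, beta0 = jnp.array(2.0), jnp.array(3.0)
sampler = BetaViaGammaSketch()
# Tangent (1, 0): differentiate with respect to alpha only.
out, dout = sampler.jvp_estimate(
    key, (alpha0, beta0), (jnp.array(1.0), jnp.array(0.0)), (kpure, kdual)
)
```

`out` is the squared sample and `dout` its forward-mode derivative with respect to `alpha`. This is the pair that `kdual` hands back to the caller.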

Now, with a new ADEV sampler in hand, we lift it to a `genjax.vi.ADEVDistribution`, a type of distribution that is compatible with Gen's generative computations.

```python
beta_implicit = vi.ADEVDistribution(
    beta_implicit,
    lambda v: v
)
```

This object can now be used in guide code as part of variational inference.

## Implementing new models and guides

In this section, we'll illustrate how to use our system with new model and guide programs. We'll directly use our `beta_implicit` from above to implement a tutorial from Pyro's documentation.
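Before turning to GenJAX, it may help to see the computation a model, a Beta guide and the ELBO describe, written in plain JAX. In this sketch the coin-flip model, its Beta(10, 10) prior, the data and the helper names are all hypothetical; it is not the Pyro tutorial's code and does not use the `genjax.vi` API. The guide's Beta sample is built from implicitly differentiated gamma samples, so the ELBO estimate is differentiable in the guide parameters.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import beta as beta_dist

# Hypothetical data: 1.0 = heads, 0.0 = tails.
data = jnp.array([1.0] * 6 + [0.0] * 4)


def log_joint(f):
    # Model: fairness f ~ Beta(10, 10); each flip ~ Bernoulli(f).
    log_prior = beta_dist.logpdf(f, 10.0, 10.0)
    log_lik = jnp.sum(data * jnp.log(f) + (1.0 - data) * jnp.log1p(-f))
    return log_prior + log_lik


def sample_beta(key, a, b):
    # Reparameterized Beta sample via two gamma samples.
    k1, k2 = jax.random.split(key)
    g1 = jax.random.gamma(k1, a)
    g2 = jax.random.gamma(k2, b)
    return g1 / (g1 + g2)


def elbo(params, key, n=64):
    # Guide: f ~ Beta(a, b), with (a, b) = exp(params) to keep them positive.
    a, b = jnp.exp(params)
    keys = jax.random.split(key, n)
    fs = jax.vmap(lambda k: sample_beta(k, a, b))(keys)
    log_q = beta_dist.logpdf(fs, a, b)
    log_p = jax.vmap(log_joint)(fs)
    # Monte Carlo estimate of E_q[log p(data, f) - log q(f)].
    return jnp.mean(log_p - log_q)


params = jnp.log(jnp.array([2.0, 2.0]))
value, grads = jax.value_and_grad(elbo)(params, jax.random.PRNGKey(0))
params = params + 0.05 * grads  # one gradient-ascent step on the ELBO
```

This model is conjugate, so the exact posterior is Beta(10 + heads, 10 + tails) = Beta(16, 14). That gives a convenient target for checking a fitted guide.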

## Implementing new loss functions
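SDOS, mentioned at the start of this notebook, targets a symmetric KL divergence. The sketch below does not reproduce SDOS. As a hedged, generic stand-in, it estimates the symmetric KL KL(p‖q) + KL(q‖p) by Monte Carlo, drawing samples from each distribution. It checks the estimate against the closed form for two Gaussians, whose parameters are arbitrary choices.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def kl_gauss(m1, s1, m2, s2):
    # Closed-form KL(N(m1, s1^2) || N(m2, s2^2)).
    return jnp.log(s2 / s1) + (s1 ** 2 + (m1 - m2) ** 2) / (2 * s2 ** 2) - 0.5


def symmetric_kl_mc(key, p, q, n=100_000):
    (mp, sp), (mq, sq) = p, q
    kp, kq = jax.random.split(key)
    xp = mp + sp * jax.random.normal(kp, (n,))  # samples from p
    xq = mq + sq * jax.random.normal(kq, (n,))  # samples from q
    # KL(p||q) = E_p[log p - log q];  KL(q||p) = E_q[log q - log p].
    kl_pq = jnp.mean(norm.logpdf(xp, mp, sp) - norm.logpdf(xp, mq, sq))
    kl_qp = jnp.mean(norm.logpdf(xq, mq, sq) - norm.logpdf(xq, mp, sp))
    return kl_pq + kl_qp


p, q = (0.0, 1.0), (1.0, 1.5)
est = symmetric_kl_mc(jax.random.PRNGKey(0), p, q)
exact = kl_gauss(*p, *q) + kl_gauss(*q, *p)  # about 1.0694
```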