In [7]:
%load_ext autoreload
%autoreload 2

import math
from dataclasses import dataclass
from typing import Any, Collection, Dict, List

import pandas as pd
import pytest
import torch

import latenta as la

@dataclass
class Setting:
    """One entry of the distribution test matrix.

    Attributes:
        base: Callable that builds the object under test; it is invoked as
            ``base(**params)``.
        params: Keyword arguments forwarded to ``base``.
        tests: Names of the checks enabled for this entry. Only membership
            (``name in tests``) is queried, so any collection of strings
            (set, list, tuple, dict keys) is accepted.
    """

    base: Any
    params: Dict[str, Any]
    tests: Collection[str]

def _string_index_dim(size, name):
    """Build a ``la.Dim`` from a pandas Series named ``name`` holding ``range(size)`` cast to str."""
    return la.Dim(pd.Series(range(size), name=name).astype(str))


# Dimensions used by the parameter definitions in the cells below.
cells = _string_index_dim(3, "cell")
genes = _string_index_dim(4, "gene")
clusters = _string_index_dim(2, "cluster")
axes = la.Dim(["x", "y"], id="axis")

# One-dimensional definitions over cells and over genes.
cell_def = la.Definition([cells])
gene_def = la.Definition([genes])

# Named groups of checks: each maps a check name to an empty dict.
# `exponential_tests` is `full_tests` plus an "icdf" entry.
full_tests = {
    check: {}
    for check in ("init", "dependent_dims", "likelihood", "redefine", "spread")
}
exponential_tests = dict(full_tests, icdf={})

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [22]:
# Latent whose `p` is a 10-knot RandomWalk with an argument-less Normal as
# its step distribution, assembled from the inside out.
walk_step = la.distributions.Normal()
walk_prior = la.distributions.RandomWalk(step=walk_step, n_knots=10)
latent = la.Latent(p=walk_prior)

In [13]:
# Check names shared by every entry; each entry unpacks them into its own set.
_CORE_TESTS = ("init", "dependent_dims", "likelihood", "redefine")

# Test matrix: one Setting per distribution class, with the keyword arguments
# used to construct it and the set of checks to run against it.
design = [
    Setting(
        base=la.distributions.Delta,
        params={"loc": la.Fixed(0.0, definition=cell_def)},
        tests={*_CORE_TESTS, "spread", "icdf"},
    ),
    Setting(
        base=la.distributions.Normal,
        params={"loc": la.Fixed(0.0, definition=cell_def), "scale": 1.0},
        tests={*_CORE_TESTS, "spread", "icdf", "latent"},
    ),
    Setting(
        base=la.distributions.LogNormal,
        params={"loc": la.Fixed(0.0, definition=cell_def), "scale": 1.0},
        tests={*_CORE_TESTS, "spread", "icdf", "latent"},
    ),
    Setting(
        base=la.distributions.LogitNormal,
        params={"loc": la.Fixed(0.0, definition=cell_def), "scale": 1.0},
        tests={*_CORE_TESTS, "spread", "icdf", "latent"},
    ),
    Setting(
        base=la.distributions.CircularNormal,
        params={
            "loc": la.Fixed(0.0, definition=cell_def),
            "scale": la.Fixed(1.0, definition=la.Definition([axes])),
        },
        tests={*_CORE_TESTS, "spread", "icdf", "latent"},
    ),
    Setting(
        base=la.distributions.Uniform,
        params={"low": la.Fixed(0.0, definition=cell_def), "high": 1.0},
        tests={*_CORE_TESTS, "spread", "icdf", "latent"},
    ),
    Setting(
        base=la.distributions.VonMises,
        params={"loc": la.Fixed(0.0, definition=cell_def), "concentration": 1.0},
        tests={*_CORE_TESTS, "latent"},
    ),
    Setting(
        base=la.distributions.Beta,
        params={
            "concentration0": la.Fixed(0.5, definition=cell_def),
            "concentration1": 1.0,
        },
        tests={*_CORE_TESTS, "spread", "latent"},
    ),
    Setting(
        base=la.distributions.Dirichlet,
        params={
            "concentration": la.Fixed(1.0, definition=la.Definition([cells, clusters])),
            "component_dim": clusters,
        },
        tests={*_CORE_TESTS, "spread", "latent"},
    ),
    Setting(
        base=la.distributions.Bernouilli,
        params={"probs": la.Fixed(0.5, definition=cell_def)},
        tests=set(_CORE_TESTS),
    ),
    Setting(
        base=la.distributions.Binomial,
        params={"total_count": la.Fixed(2.0, definition=cell_def), "probs": 0.5},
        tests=set(_CORE_TESTS),
    ),
    Setting(
        base=la.distributions.NegativeBinomial1,
        params={"logits": la.Fixed(math.log(100.0), definition=cell_def)},
        tests=set(_CORE_TESTS),
    ),
    Setting(
        base=la.distributions.PairMixture,
        params={
            "distribution0": la.distributions.Normal(),
            "distribution1": la.distributions.Normal(),
            "weight": la.Fixed(0.5, definition=cell_def),
        },
        tests={*_CORE_TESTS, "icdf", "latent"},
    ),
    Setting(
        base=la.distributions.Mixture,
        params={
            "distributions": {
                "a": la.distributions.Normal(),
                "b": la.distributions.Laplace(),
                "c": la.distributions.Normal(),
            },
            "weight": la.Fixed(0.5, definition=cell_def),
        },
        tests={*_CORE_TESTS, "spread", "icdf", "latent"},
    ),
    Setting(
        base=la.distributions.RandomWalk,
        params={
            "step": la.distributions.Normal(definition=cell_def),
            "n_knots": 10,
        },
        tests={*_CORE_TESTS, "icdf", "latent"},
    ),
    Setting(
        base=la.distributions.GridRandomWalk,
        params={
            "steps": {
                "a": la.distributions.Normal(definition=cell_def),
                "b": la.distributions.Laplace(definition=cell_def),
            },
            "n_knots": [5, 10],
            "circular": [False, False],
        },
        tests={*_CORE_TESTS, "icdf", "latent"},
    ),
]


In [63]:
# Run the enabled checks for every entry of `design`. Several steps are bare
# attribute accesses with no assert: presumably they only need to complete
# without raising (TODO confirm). The explicit asserts compare shapes/ranks.
#
# NOTE: all checks share the single `dist` binding. The "icdf", "spread" and
# "latent" checks therefore run against whichever `dist` was bound last (the
# redefined one whenever "redefine" is enabled).
for setting in design:
    print(setting.base)
    # Baseline: construct from the bare parameters and evaluate.
    dist = setting.base(**setting.params)
    dist.run_recursive()
    dist.value

    if "likelihood" in setting.tests:
        dist.likelihood

    if "dependent_dims" in setting.tests:
        # Rebuild with `cells` passed as a dependent dimension.
        dist = setting.base(**setting.params, dependent_dims = {cells})
        dist.run_recursive()
        dist.value
        dist.likelihood

    if "redefine" in setting.tests:
        # Rebuild, then redefine with the cleaned value definition expanded by
        # `genes`; value shape and likelihood rank must follow the definition.
        dist = setting.base(**setting.params, dependent_dims = {cells})
        dist = dist.redefine(definition = dist.value_definition.clean.expand_right(genes))
        dist.run_recursive()
        dist.value
        dist.likelihood
        assert dist.value.shape == dist.value_definition.shape
        assert dist.likelihood.ndim == dist.value_definition.ndim

    if "icdf" in setting.tests:
        assert dist.has_icdf

        # 0-d input: the result has exactly the value definition's shape.
        icdf = dist.icdf(torch.tensor(0.5))
        assert icdf.ndim == dist.ndim
        assert icdf.shape == dist.value_definition.shape

        # 1-d input of length 2: the result gains a leading axis of length 2.
        icdf = dist.icdf(torch.tensor([0.1, 0.5]))
        assert icdf.ndim == dist.ndim + 1
        assert icdf.shape == (2, *dist.value_definition.shape)

    if "spread" in setting.tests:
        # spread(10) must return one more dimension than the distribution has.
        spread = dist.spread(10)
        assert spread.ndim == dist.ndim + 1

    if "latent" in setting.tests:
        # Rebinds the notebook-level name `latent` on every matching iteration.
        latent = la.Latent(dist)
        assert dist.prior().shape == latent.q.prior().shape

<class 'latenta.distributions.delta.Delta'>
<class 'latenta.distributions.normal.Normal'>
<class 'latenta.distributions.normal.LogNormal'>
<class 'latenta.distributions.normal.LogitNormal'>
<class 'latenta.distributions.normal.CircularNormal'>
<class 'latenta.distributions.uniform.Uniform'>
<class 'latenta.distributions.vonmises.VonMises'>
<class 'latenta.distributions.beta.Beta'>
<class 'latenta.distributions.dirichlet.Dirichlet'>
<class 'latenta.distributions.bernouilli.Bernouilli'>
<class 'latenta.distributions.bernouilli.Binomial'>
<class 'latenta.distributions.negative_binomial.NegativeBinomial1'>
<class 'latenta.distributions.mixture.PairMixture'>
<class 'latenta.distributions.mixture.Mixture'>
<class 'latenta.distributions.random_walk.RandomWalk'>
<class 'latenta.distributions.random_walk.GridRandomWalk'>


In [61]:
# Single-setting copy of the check loop's body, for re-running one entry by
# hand. It uses whatever `setting` is currently bound in the kernel; after the
# `for setting in design:` loop has run, that is the last entry of `design`.
# NOTE(review): keep in sync with the loop body, or factor both into a function.
print(setting.base)
# Baseline: construct from the bare parameters and evaluate.
dist = setting.base(**setting.params)
dist.run_recursive()
dist.value

if "likelihood" in setting.tests:
    dist.likelihood

if "dependent_dims" in setting.tests:
    # Rebuild with `cells` passed as a dependent dimension.
    dist = setting.base(**setting.params, dependent_dims = {cells})
    dist.run_recursive()
    dist.value
    dist.likelihood

if "redefine" in setting.tests:
    # Rebuild, then redefine with the cleaned value definition expanded by
    # `genes`; value shape and likelihood rank must follow the definition.
    dist = setting.base(**setting.params, dependent_dims = {cells})
    dist = dist.redefine(definition = dist.value_definition.clean.expand_right(genes))
    dist.run_recursive()
    dist.value
    dist.likelihood
    assert dist.value.shape == dist.value_definition.shape
    assert dist.likelihood.ndim == dist.value_definition.ndim

# The checks below act on whichever `dist` was bound last above.
if "icdf" in setting.tests:
    assert dist.has_icdf
    
    # 0-d input: the result has exactly the value definition's shape.
    icdf = dist.icdf(torch.tensor(0.5))
    assert icdf.ndim == dist.ndim
    assert icdf.shape == dist.value_definition.shape

    # 1-d input of length 2: the result gains a leading axis of length 2.
    icdf = dist.icdf(torch.tensor([0.1, 0.5]))
    assert icdf.ndim == dist.ndim + 1
    assert icdf.shape == (2, *dist.value_definition.shape)

if "spread" in setting.tests:
    # spread(10) must return one more dimension than the distribution has.
    spread = dist.spread(10)
    assert spread.ndim == dist.ndim + 1

if "latent" in setting.tests:
    # Rebinds the notebook-level name `latent`.
    latent = la.Latent(dist)
    assert dist.prior().shape == latent.q.prior().shape

<class 'latenta.distributions.random_walk.GridRandomWalk'>
