# Discrete distributions

In [156]:
from IPython import get_ipython
if get_ipython():
    get_ipython().run_line_magic("load_ext", "autoreload")
    get_ipython().run_line_magic("autoreload", "2")

import numpy as np
import pandas as pd
import torch
import math

import xarray as xr

import matplotlib.pyplot as plt
import seaborn as sns

import collections

import latenta as la
la.logger.setLevel("INFO")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [157]:
cells = la.Dim(pd.Series(range(3), name = "cell").astype(str))
genes = la.Dim(pd.Series(range(5), name = "gene").astype(str))
celltypes = la.Dim(pd.Series(range(4), name = "celltype").astype(str))

In [158]:
probs_value = pd.DataFrame(0.1, index = cells.index, columns = celltypes.index)
probs_value = probs_value / probs_value.values.sum(0, keepdims = True)

In [159]:
dist = la.distributions.OneHotCategorical(0.1, definition = la.Definition([cells, celltypes]))

In [160]:
dist.reset_recursive()
dist.run_recursive()
dist.value

tensor([[0., 0., 1., 0.],
        [0., 0., 0., 1.],
        [0., 1., 0., 0.]])

Different component dim (not the end)

In [161]:
probs_value = pd.DataFrame(0.1, index = cells.index, columns = celltypes.index)
probs_value = probs_value / probs_value.values.sum(0, keepdims = True)

In [162]:
dist = la.distributions.OneHotCategorical(0.1, definition = la.Definition([cells, celltypes]), component_dim=cells)

In [163]:
dist.reset_recursive()
dist.run_recursive()
dist.value

tensor([[0., 1., 1., 0.],
        [0., 0., 0., 1.],
        [1., 0., 0., 0.]])

More than 2 dimensions

In [164]:
probs_value = np.ones((cells.size, celltypes.size, genes.size))
probs_value = probs_value / probs_value.sum(0, keepdims = True)

In [165]:
dist = la.distributions.OneHotCategorical(0.1, definition = la.Definition([cells, genes, celltypes]), component_dim=cells)

In [178]:
dist.reset_recursive()
dist.run_recursive()
print(dist.value)
print(dist.likelihood)

tensor([[[0., 1., 0., 0.],
         [0., 1., 0., 0.],
         [0., 1., 0., 1.],
         [0., 1., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 0., 0., 1.],
         [1., 0., 1., 1.],
         [1., 0., 0., 0.],
         [1., 0., 1., 1.],
         [1., 1., 0., 0.]],

        [[1., 0., 1., 0.],
         [0., 0., 0., 0.],
         [0., 0., 1., 0.],
         [0., 0., 0., 0.],
         [0., 0., 1., 1.]]])
tensor([[-2.3026, -2.3026, -2.3026, -2.3026],
        [-2.3026, -2.3026, -2.3026, -2.3026],
        [-2.3026, -2.3026, -2.3026, -2.3026],
        [-2.3026, -2.3026, -2.3026, -2.3026],
        [-2.3026, -2.3026, -2.3026, -2.3026]])
