Numba speedup for wiring + log potentials #133

Merged
merged 19 commits into from Apr 8, 2022
4 changes: 2 additions & 2 deletions examples/pmp_binary_deconvolution.py
@@ -222,7 +222,7 @@ def plot_images(images, display=True, nr=None):

# %%
pW = 0.25
-pS = 1e-72
+pS = 1e-100
pX = 1e-100

# Sparsity inducing priors for W and S
@@ -240,7 +240,7 @@ def plot_images(images, display=True, nr=None):
# We draw a batch of samples from the posterior in parallel by transforming `run_bp`/`get_beliefs` with `jax.vmap`

# %%
-np.random.seed(seed=42)
+np.random.seed(seed=40)
n_samples = 4

bp_arrays = jax.vmap(functools.partial(run_bp, damping=0.5), in_axes=0, out_axes=0)(
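The example above draws its posterior samples by transforming `run_bp` with `jax.vmap` over a leading batch axis. A minimal, self-contained sketch of that batching pattern, using a hypothetical stand-in for PGMax's `run_bp` (the real function is not reproduced here):

import functools

import jax
import jax.numpy as jnp


def fake_run_bp(evidence, damping=0.5):
    # Hypothetical stand-in for PGMax's run_bp: any per-sample function works here.
    return damping * evidence


n_samples = 4
batched_evidence = jnp.ones((n_samples, 10))  # one evidence array per sample

# vmap maps the function over axis 0, so all n_samples runs execute in parallel
batched_out = jax.vmap(functools.partial(fake_run_bp, damping=0.5), in_axes=0, out_axes=0)(
    batched_evidence
)
assert batched_out.shape == (n_samples, 10)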
4 changes: 2 additions & 2 deletions examples/rbm.py
@@ -14,7 +14,7 @@
# ---

# %% [markdown]
-# [Restricted Boltzmann Machine (RBM)](https://en.wikipedia.org/wiki/Restricted_Boltzmann_machine) is a well-known and widely used PGM for learning probabilistic distributions over binary data. we demonstrate how we can easily implement [perturb-and-max-product (PMP)](https://proceedings.neurips.cc/paper/2021/hash/07b1c04a30f798b5506c1ec5acfb9031-Abstract.html) sampling from an RBM trained on MNIST digits using PGMax. PMP is a recently proposed method for approximately sampling from a PGM by computing the maximum-a-posteriori (MAP) configuration (using max-product LBP) of a perturbed version of the model.
+# [Restricted Boltzmann Machine (RBM)](https://en.wikipedia.org/wiki/Restricted_Boltzmann_machine) is a well-known and widely used PGM for learning probabilistic distributions over binary data. We demonstrate how we can easily implement [perturb-and-max-product (PMP)](https://proceedings.neurips.cc/paper/2021/hash/07b1c04a30f798b5506c1ec5acfb9031-Abstract.html) sampling from an RBM trained on MNIST digits using PGMax. PMP is a recently proposed method for approximately sampling from a PGM by computing the maximum-a-posteriori (MAP) configuration (using max-product LBP) of a perturbed version of the model.
#
# We start by making some necessary imports.

@@ -56,7 +56,7 @@
# %% [markdown]
# [`NDVariableArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVariableArray.html#pgmax.fg.groups.NDVariableArray) is a convenient class for specifying a group of variables living on a multidimensional grid with the same number of states, and shares some similarities with [`numpy.ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html). The [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph) `fg` is initialized with a set of variables, which can be either a single [`VariableGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VariableGroup.html#pgmax.fg.groups.VariableGroup) (e.g. an [`NDVariableArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVariableArray.html#pgmax.fg.groups.NDVariableArray)), or a list/dictionary of [`VariableGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VariableGroup.html#pgmax.fg.groups.VariableGroup)s. Once initialized, the set of variables in `fg` is fixed and cannot be changed.
#
-# After initialization, `fg` does not have any factors. PGMax supports imperatively adding factors to a [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph). We can add the unary and pairwise factors one at a time to `fg` by
+# After initialization, `fg` does not have any factors. PGMax supports imperatively adding factors to a [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph). We can add the unary and pairwise factors by grouping them using

# %%
# Add unary factors
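The PMP description in the rbm.py markdown above (perturb the model, then take the MAP of the perturbed model) can be illustrated outside PGMax. A conceptual numpy sketch of the underlying perturb-and-MAP idea for independent binary units, where Gumbel perturbation plus argmax yields exact samples (toy logits assumed; this is not RBM or PGMax code):

import numpy as np

rng = np.random.default_rng(0)
logits = np.array([0.3, -1.2, 2.0])  # toy per-unit log-odds of state 1
log_potentials = np.stack([np.zeros_like(logits), logits], axis=1)  # states {0, 1}

# Perturb the log-potentials with Gumbel noise, then take the per-unit MAP:
# for independent units this is the Gumbel-max trick, i.e. an exact sample.
# PMP applies the same perturbation inside an RBM and uses max-product LBP for the MAP.
perturbed = log_potentials + rng.gumbel(size=log_potentials.shape)
sample = perturbed.argmax(axis=1)
print(sample)  # e.g. array([1, 0, 1])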
219 changes: 122 additions & 97 deletions pgmax/factors/enumeration.py
@@ -6,6 +6,7 @@

import jax
import jax.numpy as jnp
import numba as nb
import numpy as np

from pgmax.bp import bp_utils
@@ -23,106 +24,81 @@ class EnumerationWiring(nodes.Wiring):
factor_configs_edge_states[ii, 0] contains the global EnumerationFactor config index,
factor_configs_edge_states[ii, 1] contains the corresponding global edge_state index.
Both indices only take into account the EnumerationFactors of the FactorGraph

Attributes:
num_val_configs: Number of valid configurations for this wiring
"""

factor_configs_edge_states: Union[np.ndarray, jnp.ndarray]

@property
def inference_arguments(self) -> Mapping[str, Union[np.ndarray, int]]:
"""
Returns:
A dictionnary of elements used to run belief propagation.
"""
def __post_init__(self):
super().__post_init__()

if self.factor_configs_edge_states.shape[0] == 0:
num_val_configs = 0
else:
num_val_configs = int(self.factor_configs_edge_states[-1, 0]) + 1

return {
"factor_configs_edge_states": self.factor_configs_edge_states,
"num_val_configs": num_val_configs,
}
object.__setattr__(self, "num_val_configs", num_val_configs)
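To illustrate the attribute computed above, a small sketch with assumed toy values: for a single factor over two 2-state variables with three valid configurations, a factor_configs_edge_states array and the num_val_configs that __post_init__ derives from it.

import numpy as np

# Column 0: global config index; column 1: global edge_state index (see docstring above)
factor_configs_edge_states = np.array(
    [[0, 0], [0, 2],
     [1, 1], [1, 2],
     [2, 1], [2, 3]]
)
num_val_configs = int(factor_configs_edge_states[-1, 0]) + 1  # == 3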


@dataclass(frozen=True, eq=False)
class EnumerationFactor(nodes.Factor):
"""An enumeration factor

Args:
configs: Array of shape (num_val_configs, num_variables)
factor_configs: Array of shape (num_val_configs, num_variables)
An array containing an explicit enumeration of all valid configurations
log_potentials: Array of shape (num_val_configs,)
An array containing the log of the potential value for each valid configuration

Raises:
ValueError: If:
(1) The dtype of the configs array is not int
(1) The dtype of the factor_configs array is not int
(2) The dtype of the potential array is not float
(3) Configs does not have the correct shape
(3) factor_configs does not have the correct shape
(4) The potential array does not have the correct shape
(5) The configs array contains invalid values
(5) The factor_configs array contains invalid values
"""

configs: np.ndarray
factor_configs: np.ndarray
log_potentials: np.ndarray

def __post_init__(self):
self.configs.flags.writeable = False
if not np.issubdtype(self.configs.dtype, np.integer):
self.factor_configs.flags.writeable = False
if not np.issubdtype(self.factor_configs.dtype, np.integer):
raise ValueError(
f"Configurations should be integers. Got {self.configs.dtype}."
f"Configurations should be integers. Got {self.factor_configs.dtype}."
)

if not np.issubdtype(self.log_potentials.dtype, np.floating):
raise ValueError(
f"Potential should be floats. Got {self.log_potentials.dtype}."
)

if self.configs.ndim != 2:
if self.factor_configs.ndim != 2:
raise ValueError(
"configs should be a 2D array containing a list of valid configurations for "
f"EnumerationFactor. Got a configs array of shape {self.configs.shape}."
"factor_configs should be a 2D array containing a list of valid configurations for "
f"EnumerationFactor. Got a factor_configs array of shape {self.factor_configs.shape}."
)

if len(self.variables) != self.configs.shape[1]:
if len(self.variables) != self.factor_configs.shape[1]:
raise ValueError(
f"Number of variables {len(self.variables)} doesn't match given configurations {self.configs.shape}"
f"Number of variables {len(self.variables)} doesn't match given configurations {self.factor_configs.shape}"
)

if self.log_potentials.shape != (self.configs.shape[0],):
if self.log_potentials.shape != (self.factor_configs.shape[0],):
raise ValueError(
f"Expected log potentials of shape {(self.configs.shape[0],)} for "
f"({self.configs.shape[0]}) valid configurations. Got log potentials of "
f"Expected log potentials of shape {(self.factor_configs.shape[0],)} for "
f"({self.factor_configs.shape[0]}) valid configurations. Got log potentials of "
f"shape {self.log_potentials.shape}."
)

vars_num_states = np.array([variable.num_states for variable in self.variables])
if not np.logical_and(
self.configs >= 0, self.configs < vars_num_states[None]
self.factor_configs >= 0, self.factor_configs < vars_num_states[None]
).all():
raise ValueError("Invalid configurations for given variables")

def compile_wiring(
self, vars_to_starts: Mapping[nodes.Variable, int]
) -> EnumerationWiring:
"""Compile EnumerationWiring for the EnumerationFactor

Args:
vars_to_starts: A dictionary that maps variables to their global starting indices
For an n-state variable, a global start index of m means the global indices
of its n variable states are m, m + 1, ..., m + n - 1

Returns:
EnumerationWiring for the EnumerationFactor
"""
return compile_enumeration_wiring(
factor_edges_num_states=self.edges_num_states,
variables_for_factors=tuple([self.variables]),
factor_configs=self.configs,
vars_to_starts=vars_to_starts,
num_factors=1,
)

@staticmethod
def concatenate_wirings(wirings: Sequence[EnumerationWiring]) -> EnumerationWiring:
"""Concatenate a list of EnumerationWirings
@@ -176,60 +152,109 @@ def concatenate_wirings(wirings: Sequence[EnumerationWiring]) -> EnumerationWiri
),
)

@staticmethod
def compile_wiring(
factor_edges_num_states: np.ndarray,
variables_for_factors: Tuple[nodes.Variable, ...],
factor_configs: np.ndarray,
vars_to_starts: Mapping[nodes.Variable, int],
num_factors: int,
) -> EnumerationWiring:
"""Compile an EnumerationWiring for an EnumerationFactor or a FactorGroup with EnumerationFactors.
Internally calls _compile_var_states_numba and _compile_enumeration_wiring_numba for speed.

def compile_enumeration_wiring(
factor_edges_num_states: np.ndarray,
variables_for_factors: Tuple[Tuple[nodes.Variable, ...], ...],
factor_configs: np.ndarray,
vars_to_starts: Mapping[nodes.Variable, int],
num_factors: int,
) -> EnumerationWiring:
"""Compile an EnumerationWiring for an EnumerationFactor or a FactorGroup with EnumerationFactors.
Args:
factor_edges_num_states: An array concatenating the number of states for the variables connected to each
Factor of the FactorGroup. Each variable will appear once for each Factor it connects to.
variables_for_factors: A tuple of tuples containing variables connected to each Factor of the FactorGroup.
Each variable will appear once for each Factor it connects to.
factor_configs: Array of shape (num_val_configs, num_variables) containing an explicit enumeration
of all valid configurations.
vars_to_starts: A dictionary that maps variables to their global starting indices
For an n-state variable, a global start index of m means the global indices
of its n variable states are m, m + 1, ..., m + n - 1
num_factors: Number of Factors in the FactorGroup.

Args:
factor_edges_num_states: An array concatenating the number of states for the variables connected to each
Factor of the FactorGroup. Each variable will appear once for each Factor it connects to.
variables_for_factors: A tuple of tuples containing variables connected to each Factor of the FactorGroup.
Each variable will appear once for each Factor it connects to.
factor_configs: Array of shape (num_val_configs, num_variables) containing an explicit enumeration
of all valid configurations.
vars_to_starts: A dictionary that maps variables to their global starting indices
For an n-state variable, a global start index of m means the global indices
of its n variable states are m, m + 1, ..., m + n - 1
num_factors: Number of Factors in the FactorGroup.
Returns:
The EnumerationWiring
"""
var_states = np.array(
[vars_to_starts[variable] for variable in variables_for_factors]
)
num_states = np.array(
[variable.num_states for variable in variables_for_factors]
)
num_states_cumsum = np.insert(np.cumsum(num_states), 0, 0)
var_states_for_edges = np.empty(shape=(num_states_cumsum[-1],), dtype=int)
_compile_var_states_numba(var_states_for_edges, num_states_cumsum, var_states)

Returns:
The EnumerationWiring
num_configs, num_variables = factor_configs.shape
factor_configs_edge_states = np.empty(
(num_factors * num_configs * num_variables, 2), dtype=int
)
assert factor_edges_num_states.shape[0] == num_factors * num_variables
factor_edges_starts = np.insert(np.cumsum(factor_edges_num_states), 0, 0)
_compile_enumeration_wiring_numba(
factor_configs_edge_states, factor_configs, factor_edges_starts, num_factors
)

return EnumerationWiring(
edges_num_states=factor_edges_num_states,
var_states_for_edges=var_states_for_edges,
factor_configs_edge_states=factor_configs_edge_states,
)


@nb.jit(parallel=False, cache=True, fastmath=True, nopython=True)
def _compile_var_states_numba(
@wlehrach (Contributor) commented on Apr 11, 2022:

Is there a reason you make the caller allocate these arrays? In general it's cleaner and less likely to result in error to allocate return arrays inside numba rather than mutating a passed-in array. You can get a very small optimization by re-using arrays between calls (so for highly performance-sensitive code it can be useful), but you're not doing that here. You can refer to the dtype of incoming arrays as well and copy that.
var_states_for_edges: np.ndarray,
num_states_cumsum: np.ndarray,
var_states: np.ndarray,
) -> np.ndarray:
"""Fast numba computation of the var_states_for_edges of a Wiring.
var_states_for_edges is updated in-place.
"""
var_states_for_edges = []
for variables_for_factor in variables_for_factors:
for variable in variables_for_factor:
num_states = variable.num_states
this_var_states_for_edges = np.arange(
vars_to_starts[variable], vars_to_starts[variable] + num_states
)
var_states_for_edges.append(this_var_states_for_edges)

# Note: edges_starts corresponds to the factor_to_msgs_start for the LogicalFactors
edges_starts = np.insert(factor_edges_num_states.cumsum(), 0, 0)[:-1].reshape(
-1, factor_configs.shape[1]
)
for variable_idx in nb.prange(num_states_cumsum.shape[0] - 1):
start_variable, end_variable = (
num_states_cumsum[variable_idx],
num_states_cumsum[variable_idx + 1],
)
var_states_for_edges[start_variable:end_variable] = var_states[
variable_idx
] + np.arange(end_variable - start_variable)

factor_configs_edge_states = np.stack(
[
np.repeat(
np.arange(factor_configs.shape[0] * num_factors),
factor_configs.shape[1],
),
(factor_configs[None] + edges_starts[:, None, :]).flatten(),
],
axis=1,
)
return EnumerationWiring(
edges_num_states=factor_edges_num_states,
var_states_for_edges=np.concatenate(var_states_for_edges),
factor_configs_edge_states=factor_configs_edge_states,
)
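A minimal sketch (not part of this diff) of the allocate-inside alternative that @wlehrach's review comment above describes for _compile_var_states_numba, assuming Numba's nopython mode accepts the dtype of an incoming array when allocating the output:

import numba as nb
import numpy as np


@nb.njit(cache=True)
def _compile_var_states_alloc(num_states_cumsum, var_states):
    # Allocate the result here, copying the dtype of the incoming array,
    # instead of mutating a caller-provided buffer.
    var_states_for_edges = np.empty(num_states_cumsum[-1], var_states.dtype)
    for variable_idx in range(num_states_cumsum.shape[0] - 1):
        start = num_states_cumsum[variable_idx]
        end = num_states_cumsum[variable_idx + 1]
        var_states_for_edges[start:end] = var_states[variable_idx] + np.arange(end - start)
    return var_states_for_edges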

@nb.jit(parallel=False, cache=True, fastmath=True, nopython=True)
def _compile_enumeration_wiring_numba(
factor_configs_edge_states: np.ndarray,
factor_configs: np.ndarray,
factor_edges_starts: np.ndarray,
num_factors: int,
) -> np.ndarray:
"""Fast numba computation of the factor_configs_edge_states of an EnumerationWiring.
factor_configs_edge_states is updated in-place.
"""

num_configs, num_variables = factor_configs.shape

for factor_idx in nb.prange(num_factors):
for config_idx in range(num_configs):
factor_config_idx = num_configs * factor_idx + config_idx
factor_configs_edge_states[
num_variables
* factor_config_idx : num_variables
* (factor_config_idx + 1),
0,
] = factor_config_idx

for var_idx in range(num_variables):
factor_configs_edge_states[
num_variables * factor_config_idx + var_idx, 1
] = (
factor_edges_starts[num_variables * factor_idx + var_idx]
+ factor_configs[config_idx, var_idx]
)


@functools.partial(jax.jit, static_argnames=("num_val_configs", "temperature"))
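The decorator above marks num_val_configs and temperature as static arguments of the jitted message-passing function (truncated in this diff). A hedged sketch, with a hypothetical helper rather than the PR's actual code, of why a quantity like num_val_configs must be static: it fixes output shapes inside the jitted computation, so jax.jit treats it as a compile-time constant and recompiles if it changes.

import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnames=("num_val_configs",))
def max_per_config(config_scores, config_index, num_val_configs):
    # segment_max needs a compile-time number of segments to know its output shape
    return jax.ops.segment_max(config_scores, config_index, num_segments=num_val_configs)


scores = jnp.array([0.1, 0.7, 0.2, 0.5])
index = jnp.array([0, 0, 1, 1])
print(max_per_config(scores, index, num_val_configs=2))  # [0.7 0.5]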