Skip to content

Commit

Permalink
added ConvProd2DV2 that uses ProductsLayer to reduce the amount of redundant computation
Browse files Browse the repository at this point in the history
  • Loading branch information
jostosh committed Jul 2, 2018
1 parent 827d278 commit 40e2369
Show file tree
Hide file tree
Showing 5 changed files with 683 additions and 24 deletions.
4 changes: 2 additions & 2 deletions libspn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from libspn.graph.permproducts import PermProducts
from libspn.graph.products import Products
from libspn.graph.productslayer import ProductsLayer
from libspn.graph.convprod2d import ConvProd2D
from libspn.graph.convprod2d import ConvProd2D, ConvProd2DV2
from libspn.graph.weights import Weights
from libspn.graph.weights import assign_weights
from libspn.graph.weights import initialize_weights
Expand Down Expand Up @@ -117,7 +117,7 @@
'Scope', 'Input', 'Node', 'ParamNode', 'OpNode', 'VarNode',
'Concat', 'IVs', 'ContVars',
'Sum', 'ParSums', 'Sums', 'SumsLayer', 'ConvSum',
'Product', 'PermProducts', 'Products', 'ProductsLayer', 'ConvProd2D',
'Product', 'PermProducts', 'Products', 'ProductsLayer', 'ConvProd2D', 'ConvProd2DV2',
'GaussianLeaf',
'Weights', 'assign_weights', 'initialize_weights',
'serialize_graph', 'deserialize_graph',
Expand Down
25 changes: 18 additions & 7 deletions libspn/generation/spatial.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
from collections import defaultdict
from libspn.graph.convprod2d import ConvProd2D
from libspn.graph.convprod2d import ConvProd2D, ConvProd2DV2
from libspn.graph.convsum import ConvSum
from libspn.graph.localsum import LocalSum
from libspn.graph.concat import Concat
from libspn.exceptions import StructureError


class ConvSPN:

def __init__(self):
def __init__(self, convprod_version='v1'):
self.level_at = 0
self.nodes_per_level = defaultdict(list)
self.last_nodes = None
self.node_level = dict()
self._convprod_version = convprod_version

def add_dilate_stride(
self, *input_nodes, kernel_size=2, strides=(1, 4), dilation_rate=(2, 1),
Expand Down Expand Up @@ -64,11 +66,20 @@ def add_stack(
strides, dilation_rate, kernel_size, prod_num_channels, padding_algorithm,
pad_left, pad_right, pad_top, pad_bottom, sum_num_channels, name_prefixes,
name_suffixes, sum_node_type):
next_node = ConvProd2D(
*input_nodes, grid_dim_sizes=spatial_dims, pad_bottom=pad_b, pad_top=pad_t,
pad_left=pad_l, pad_right=pad_r, num_channels=prod_nc,
name="{}Prod{}".format(name_pref, name_suff), dilation_rate=dilation_r,
kernel_size=kernel_s, padding_algorithm=pad_algo, strides=stride)
if self._convprod_version == 'v2':
if len(*input_nodes) > 1:
input_concat = Concat(*input_nodes, axis=3)
next_node = ConvProd2DV2(
input_concat, grid_dim_sizes=spatial_dims, pad_bottom=pad_b, pad_top=pad_t,
pad_left=pad_l, pad_right=pad_r, num_channels=prod_nc,
name="{}Prod{}".format(name_pref, name_suff), dilation_rate=dilation_r,
kernel_size=kernel_s, padding_algorithm=pad_algo, strides=stride)
else:
next_node = ConvProd2D(
*input_nodes, grid_dim_sizes=spatial_dims, pad_bottom=pad_b, pad_top=pad_t,
pad_left=pad_l, pad_right=pad_r, num_channels=prod_nc,
name="{}Prod{}".format(name_pref, name_suff), dilation_rate=dilation_r,
kernel_size=kernel_s, padding_algorithm=pad_algo, strides=stride)
spatial_dims = next_node.output_shape_spatial[:2]
input_nodes = [next_node]
print("Built node {}: {} x {} x {}".format(next_node, *next_node.output_shape_spatial))
Expand Down
60 changes: 57 additions & 3 deletions libspn/graph/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@
from libspn.inference.type import InferenceType
from libspn.exceptions import StructureError
from libspn.utils.serialization import register_serializable
from libspn.graph.convsum import ConvSum
from libspn.graph.localsum import LocalSum
from libspn.graph.convprod2d import ConvProd2D, ConvProd2DV2
import tensorflow as tf
import numpy as np


@register_serializable
Expand All @@ -23,9 +28,10 @@ class Concat(OpNode):
name (str): Name of the node.
"""

def __init__(self, *inputs, name="Concat"):
def __init__(self, *inputs, name="Concat", axis=1):
    """Create a Concat node joining the values of *inputs* along *axis*.

    Args:
        inputs: Input nodes (or (node, indices) tuples) to concatenate.
        name (str): Name of the node.
        axis (int): Concatenation axis of the value tensor; 1 joins flat
            feature columns, 3 joins the channel dimension of spatial
            (NHWC) inputs.
    """
    super().__init__(InferenceType.MARGINAL, name)
    self.set_inputs(*inputs)
    # Stored after base-class setup; axis == 3 marks this node as spatial.
    self._axis = axis

def serialize(self):
data = super().serialize()
Expand Down Expand Up @@ -99,8 +105,45 @@ def _compute_value(self, *input_tensors):
raise StructureError("%s is missing inputs." % self)
# Concatenate inputs
input_tensors = self._gather_input_tensors(*input_tensors)
return utils.concat_maybe(input_tensors, 1)
input_shapes = self._gather_input_shapes()
reshaped_tensors = [tf.reshape(t, (-1,) + s) for t, s in zip(input_tensors, input_shapes)]
out = utils.concat_maybe(reshaped_tensors, axis=self._axis)
if self.is_spatial:
out = tf.reshape(out, (-1, int(np.prod(self.output_shape_spatial))))
return out

@property
def output_shape_spatial(self):
    """Spatial (rows, cols, channels) shape of this node's output.

    Only defined for spatial concatenation (axis 3): the inputs share
    their leading spatial axes, and their channel counts are summed.

    Raises:
        AttributeError: If this node does not concatenate spatially.
    """
    if self._axis != 3:
        raise AttributeError("Requested spatial output shape of a Concat node that is "
                             "not spatial.")
    shapes = self._gather_input_shapes()
    # The per-input shapes omit the batch dim, so the concat axis shifts by 1.
    channel_axis = self._axis - 1
    total_channels = sum(shape[channel_axis] for shape in shapes)
    return shapes[0][:channel_axis] + (total_channels,)

def _gather_input_shapes(self):
    """Return one shape tuple per input.

    Spatial nodes contribute their (rows, cols, channels) spatial shape;
    any other node contributes its flat (out_size,) shape.

    Raises:
        StructureError: If the shapes differ in rank, or disagree on any
            axis other than the concatenation axis.
    """
    spatial_types = (ConvProd2D, ConvProd2DV2, ConvSum, LocalSum)
    shapes = [inp.node.output_shape_spatial if isinstance(inp.node, spatial_types)
              else (inp.node.get_out_size(),)
              for inp in self.inputs]

    rank = len(shapes[0])
    if any(len(s) != rank for s in shapes):
        raise StructureError("All shapes must be of same dimension, now have: {}".format(
            [len(s) for s in shapes]
        ))
    # All axes before the (batch-less) concat axis must match across inputs.
    leading = shapes[0][:self._axis - 1]
    if any(s[:self._axis - 1] != leading for s in shapes):
        raise StructureError("All non-concatenation axes must be identical.")
    return shapes

def _num_channels_per_input(self):
if not self.is_spatial:
raise AttributeError("Requested number of channels per input while this Concat node "
"is not spatial.")
shapes = self._gather_input_shapes()
return [s[self._axis - 1] for s in shapes]

@utils.docinherit(OpNode)
def _compute_log_value(self, *input_tensors):
    # Concatenation is value-agnostic: joining log-space tensors is the
    # same operation as joining linear-space ones, so the marginal value
    # computation is reused unchanged.
    log_value = self._compute_value(*input_tensors)
    return log_value
Expand All @@ -112,14 +155,25 @@ def _compute_mpe_value(self, *input_tensors):
@utils.docinherit(OpNode)
def _compute_log_mpe_value(self, *input_tensors):
    # MPE inference in log space also reduces to a plain gather-and-concat,
    # so it delegates to the shared value computation.
    log_mpe_value = self._compute_value(*input_tensors)
    return log_mpe_value

@property
def is_spatial(self):
    """Whether this node concatenates spatial (NHWC) inputs, i.e. axis 3."""
    spatial_axis = 3
    return self._axis == spatial_axis

def _compute_mpe_path(self, counts, *input_values, add_random=False, use_unweighted=False):
    """Split the incoming MPE counts and scatter them back to the inputs.

    Args:
        counts (Tensor): Counts propagated from the parent, one flat row
            per batch element.
        input_values: Value tensors of the inputs, used for sizing and
            for scattering the split counts back.
        add_random: Unused; kept for OpNode interface compatibility.
        use_unweighted: Unused; kept for OpNode interface compatibility.

    Raises:
        StructureError: If this node has no inputs.
    """
    # Check inputs
    if not self._inputs:
        raise StructureError("%s is missing inputs." % self)
    if self.is_spatial:
        # Restore the spatial (rows, cols, channels) layout, split along
        # the channel axis by each input's channel count, then flatten
        # every chunk back to its input's flat (batch, size) layout.
        input_shapes = self._gather_input_shapes()
        counts = tf.reshape(counts, (-1,) + self.output_shape_spatial)
        split = utils.split_maybe(counts, self._num_channels_per_input(), axis=self._axis)
        split = [tf.reshape(t, (-1, int(np.prod(s)))) for t, s in zip(split, input_shapes)]
    else:
        # Flat concat: split along axis 1 by each input's column count.
        # Only needed here, so avoid computing sizes on the spatial path.
        input_sizes = self.get_input_sizes(*input_values)
        split = utils.split_maybe(counts, input_sizes, 1)
    return self._scatter_to_input_tensors(*[(t, v) for t, v in
                                            zip(split, input_values)])

Expand Down

0 comments on commit 40e2369

Please sign in to comment.