Extending SPINS

Here we discuss how to add custom nodes, actions, and flows.

Custom Nodes

To create a custom node,

  • Create a class that inherits from goos.ProblemGraphNode
  • Add a node_type class field that uniquely identifies the class.
  • Add a type-annotated constructor that marks node inputs.
  • Implement eval and grad.
  • For performance gains, implement eval_const_flags.

A basic example of a node that doubles the value of its input is shown below:

class DoubleNode(goos.Function):
  node_type = "mycustomnodes.double_node"

  def __init__(self, node: goos.Function) -> None:
    # Mark `node` as the sole dependency of this node.
    super().__init__(node)

  def eval(self, input_vals: List[goos.NumericFlow]) -> goos.NumericFlow:
    return goos.NumericFlow(input_vals[0].array * 2)

  def grad(self, input_vals: List[goos.NumericFlow],
           grad_val: goos.NumericFlow.Grad) -> goos.NumericFlow.Grad:
    return goos.NumericFlow.Grad(grad_val.array_grad * 2)

This custom node inherits from goos.Function, which in turn inherits from goos.ProblemGraphNode. By inheriting from goos.Function, we enable users of our node to perform other numeric operations on it:

var = goos.Variable(3)
node = DoubleNode(var)
# The following line is only possible because we inherit from
# `goos.Function` rather than `goos.ProblemGraphNode` directly.
obj = node + 3

Next, we named our node "mycustomnodes.double_node". This needs to be a unique name across all possible nodes. We therefore recommend the format "your_own_unique_identifier.node_name". This name is used to serialize and deserialize the node from disk.
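
To see why the name must be unique, consider a minimal sketch of a name-to-class registry of the kind a deserializer could use. This is not the SPINS implementation: `NODE_REGISTRY`, `register`, `node_class_for`, `FakeDoubleNode`, and the `"type"` key are all hypothetical names chosen for illustration.

```python
from typing import Dict

# Hypothetical registry mapping node_type strings to classes (not the SPINS one).
NODE_REGISTRY: Dict[str, type] = {}


def register(cls):
    """Class decorator that records `cls` under its unique `node_type`."""
    if cls.node_type in NODE_REGISTRY:
        raise ValueError(f"duplicate node_type: {cls.node_type}")
    NODE_REGISTRY[cls.node_type] = cls
    return cls


@register
class FakeDoubleNode:
    node_type = "mycustomnodes.double_node"


def node_class_for(serialized: dict) -> type:
    """Looks up the class named by the (assumed) `type` entry of a saved node."""
    return NODE_REGISTRY[serialized["type"]]


restored_cls = node_class_for({"type": "mycustomnodes.double_node"})
```

If two classes shared a name, a lookup like this could not tell which class to rebuild, which is why a prefix such as "your_own_unique_identifier" is useful.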

The node constructor is defined and properly type-annotated. The type annotations are used internally to construct a schema for the node, which is used to validate its inputs. For example, because the sole argument node is defined as a goos.Function, the following will raise an error:

# Raises an error because `3` is not a `goos.Function`!
node = DoubleNode(3)

# Instead, we need to wrap the 3 as a constant.
node = DoubleNode(goos.Constant(3))
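
The idea of validating inputs from type annotations can be sketched with the standard library alone. The snippet below is only an illustration and does not reproduce how SPINS builds its schema: `FakeFunction`, `FakeConstant`, `FakeDoubleNode`, and `check_args` are hypothetical stand-ins.

```python
import inspect
from typing import get_type_hints


class FakeFunction:
    """Hypothetical stand-in for goos.Function."""


class FakeConstant(FakeFunction):
    """Hypothetical stand-in for goos.Constant."""

    def __init__(self, value: float) -> None:
        self.value = value


def check_args(cls, *args) -> None:
    """Checks positional args against the annotations of `cls.__init__`."""
    hints = get_type_hints(cls.__init__)
    # Skip `self`, then pair each parameter name with the supplied argument.
    params = list(inspect.signature(cls.__init__).parameters)[1:]
    for name, arg in zip(params, args):
        expected = hints.get(name)
        if expected is not None and not isinstance(arg, expected):
            raise TypeError(
                f"{name} must be {expected.__name__}, got {type(arg).__name__}")


class FakeDoubleNode(FakeFunction):
    def __init__(self, node: FakeFunction) -> None:
        check_args(FakeDoubleNode, node)
        self.node = node


ok_node = FakeDoubleNode(FakeConstant(3))  # accepted: a FakeFunction subclass
```

Passing a bare `3` to `FakeDoubleNode` raises a `TypeError` in this sketch, mirroring the behaviour described above.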

You may find it annoying for users to have to wrap a constant input as a goos.Constant function node. For the moment, we will live with this fact; we will come back to the topic of usability later.

Because the inputs must be serializable, only certain type annotations are supported at the moment:

  • The native types int, float, str, bool, and complex
  • numpy.ndarray
  • The composite types List and Union
  • Any goos.ProblemGraphNode
  • Any goos.Model

In the constructor, we call super().__init__ with a single argument that marks all the problem graph node dependencies. The flows generated by the nodes given to super().__init__ are passed to eval and grad as input_vals. If there is more than one dependency, the order in which they are passed to super().__init__ indicates the order in which their flows are passed to eval and grad:

class SumAndDoubleNode(goos.Function):
   node_type = "mycustomnodes.sum_and_double_node"

   def __init__(self, node1: goos.Function, node2: goos.Function) -> None:
     super().__init__([node1, node2])

   def eval(self, input_vals: List[goos.NumericFlow]) -> goos.NumericFlow:
     # `input_vals` contains flows for `node1` followed by `node2`.
     node1_val = input_vals[0].array
     node2_val = input_vals[1].array
     # Sum the two inputs and double the result.
     return goos.NumericFlow(2 * (node1_val + node2_val))

Finally, we implement the node logic by defining eval and grad. eval is called to evaluate the function and grad is called to evaluate the gradient. Specifically, eval accepts a list of input flows from the nodes marked as dependencies in the constructor and must produce a single flow as output. grad accepts a list of input flows as well as the current backward gradient value and produces the corresponding gradient flow.

Implementing eval and grad

eval and grad form the backbone of the backpropagation algorithm that automatically computes objective function values and their gradients. Specifically, if the objective function is given by f, then the grad function for a node g with inputs x_1, x_2, \cdots, x_n should compute the partial derivatives \frac{df}{dx_1}, \frac{df}{dx_2}, \cdots, \frac{df}{dx_n}. The partial derivative \frac{df}{dg} is given as the second argument to grad.

For example, suppose we have a node that takes two inputs and implements the function g(x, y) = x \cdot (y + 1):

class NodeG(goos.Function):
   node_type = "mycustomnodes.node_g"

   def __init__(self, node1: goos.Function, node2: goos.Function) -> None:
     super().__init__([node1, node2])

   def eval(self, input_vals: List[goos.NumericFlow]) -> goos.NumericFlow:
     x = input_vals[0].array
     y = input_vals[1].array

     return goos.NumericFlow(x * (y + 1))

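The grad function needs to compute \frac{df}{dx} = \frac{df}{dg} \cdot \frac{dg}{dx} and \frac{df}{dy} = \frac{df}{dg} \cdot \frac{dg}{dy}, where \frac{dg}{dx} = y + 1 and \frac{dg}{dy} = x. The upstream gradient \frac{df}{dg} is passed as the second argument to grad: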

class NodeG(goos.Function):

   def grad(self, input_vals: List[goos.NumericFlow],
            grad_val: goos.NumericFlow.Grad) -> List[goos.NumericFlow.Grad]:
     x = input_vals[0].array
     y = input_vals[1].array

     # Chain rule: multiply the local derivatives by the upstream df/dg.
     df_dx = (y + 1) * grad_val.array_grad
     df_dy = x * grad_val.array_grad

     return [goos.NumericFlow.Grad(df_dx), goos.NumericFlow.Grad(df_dy)]
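
A quick way to catch mistakes in a hand-written grad is to compare it against a finite-difference estimate. The sketch below does this for g(x, y) = x \cdot (y + 1) with plain Python floats instead of flows. The objective f = 3g and the helper names `g_grad` and `finite_diff` are assumptions made for illustration only.

```python
def g(x: float, y: float) -> float:
    """Forward function g(x, y) = x * (y + 1)."""
    return x * (y + 1)


def g_grad(x: float, y: float, df_dg: float):
    """Returns (df/dx, df/dy) given the upstream gradient df/dg."""
    return (y + 1) * df_dg, x * df_dg


def f(x: float, y: float) -> float:
    """Hypothetical objective f = 3 * g, so that df/dg = 3."""
    return 3.0 * g(x, y)


def finite_diff(func, x: float, y: float, eps: float = 1e-6):
    """Central-difference estimate of (df/dx, df/dy)."""
    df_dx = (func(x + eps, y) - func(x - eps, y)) / (2 * eps)
    df_dy = (func(x, y + eps) - func(x, y - eps)) / (2 * eps)
    return df_dx, df_dy


analytic = g_grad(2.0, 5.0, df_dg=3.0)
numeric = finite_diff(f, 2.0, 5.0)
```

The two tuples should agree to within the finite-difference error. A mismatch usually means the upstream gradient was not multiplied in or the outputs are in the wrong order.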

In order to ensure the correctness and reproducibility of SPINS, the following rules must hold true for eval and grad:

  • Flows should NOT be modified. If you want to modify a flow, make a copy first.
  • Flow values should only depend on values computed from the input flows or parameters passed in through the constructor.

Note that in SPINS, the flow system uses duck typing: anything that has the appropriate properties of a flow is considered a flow of that type. Furthermore, a flow may be considered as more than one type of flow. For example, the PixelatedContShapeFlow can be considered a NumericFlow because it has an array property, and it can also be considered a ShapeFlow because it has all the requisite ShapeFlow properties. Consequently, for single-input, single-output nodes, it may be advisable to clone the flow rather than create a new one:

def eval(self, input_vals: List[goos.NumericFlow]) -> goos.NumericFlow:
  out_flow = copy.deepcopy(input_vals[0])
  out_flow.array = ...

  return out_flow

This way, the node can be used with both NumericFlow and PixelatedContShapeFlow: if a NumericFlow is passed as input, then the output is a NumericFlow. If a PixelatedContShapeFlow is passed as input, then the output is a PixelatedContShapeFlow.
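
The type-preserving behaviour of cloning can be demonstrated without SPINS. In the sketch below, `FakeNumericFlow` and `FakeShapeFlow` are hypothetical dataclasses standing in for real flows, and lists stand in for NumPy arrays.

```python
import copy
from dataclasses import dataclass, field
from typing import List


@dataclass
class FakeNumericFlow:
    """Hypothetical stand-in for a flow exposing an `array` property."""
    array: List[float] = field(default_factory=list)


@dataclass
class FakeShapeFlow(FakeNumericFlow):
    """Hypothetical stand-in for a richer flow that also has `array`."""
    pos: List[float] = field(default_factory=lambda: [0.0, 0.0, 0.0])


def double_eval(in_flow):
    """Doubles `array` while preserving the concrete flow type."""
    out_flow = copy.deepcopy(in_flow)  # never modify the input flow itself
    out_flow.array = [2 * v for v in in_flow.array]
    return out_flow


shape_in = FakeShapeFlow(array=[1.0, 2.0], pos=[1.0, 0.0, 0.0])
shape_out = double_eval(shape_in)
numeric_out = double_eval(FakeNumericFlow(array=[1.0]))
```

Here `shape_out` keeps the `pos` field and the `FakeShapeFlow` type, while `shape_in` is left untouched, which satisfies the rule that input flows must not be modified.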


Because nodes must be serializable, we cannot pass arbitrary objects into the constructor of a node. However, it may be beneficial to pass more complex data objects. Currently, the mechanism for doing this is through the schematics Python library. For convenience, we have aliased schematics.models.Model to goos.Model and schematics.types to goos.types and slightly modified the functionality.

For convenience, we have defined a complex number type and a NumPy array type (see goos.optplan.schema_types). We also have implemented a few common schema types in goos.common_schemas.


Sometimes, it may be clumsy to define a node directly in code. In the above DoubleNode example, for instance, a user must wrap a constant number like 3 as a goos.Constant before passing it into DoubleNode. To mitigate these convenience and usability issues, we recommend defining functions that create nodes on behalf of the user. For example,

def double_node(node: Union[goos.Function, float]):
  if not isinstance(node, goos.Function):
    node = goos.Constant(node)
  return DoubleNode(node)

node = double_node(3)
node2 = double_node(goos.Variable(5))

You will see node creation functions of this kind throughout the codebase, where they simplify node creation.

Custom Actions

A custom action must do the following:

  • Inherit from goos.Action
  • Add a node_type class field that uniquely identifies the class.
  • Add a type-annotated constructor that marks node inputs.
  • Implement a run method that accepts an optimization plan as an argument.

From a structural point of view, defining an action is similar to defining a node except that an action inherits from goos.Action instead of goos.ProblemGraphNode and that an action implements run instead of eval and grad. The above discussion about defining node types and a type-annotated constructor remains the same with the following main exception: It is unnecessary to declare any dependencies through super().__init__. Below we show an example of an action that adds one to a variable.

class AddOne(goos.Action):
   node_type = "myactions.add_one"

   def __init__(self, var: goos.Variable) -> None:
     self._var = var

   def run(self, plan: goos.OptimizationPlan) -> None:
     val = plan.get_var_value(self._var)
     plan.set_var_value(self._var, val + 1)

The run method accepts an optimization plan as input and changes the plan state. However, the run method should interact with the plan only in the following ways:

  • Call eval_nodes and eval_grad to evaluate the values and gradients of nodes. An action should NOT call node.get() and node.get_grad().
  • Call get_var_value and set_var_value to get and set the values of a variable. Again, this should be done in lieu of node.get(). Note that setting the value of a frozen variable may raise an exception.
  • Call set_var_bounds and get_var_bounds to change the bounds of a variable.

As a general rule of thumb, an action may always request information about the plan state but may not be able to change the state. The run method should NOT add or remove nodes from the graph as this has the potential to break the reproducibility of the system. The run method should also NOT call any method that would itself trigger a run of the plan (e.g. node.get(run=True)).
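
The pattern of routing every read and write through the plan can be sketched with a toy plan object. `FakePlan`, its `freeze` method, and the use of `ValueError` for frozen variables are assumptions made for illustration. Only the method names get_var_value and set_var_value come from the text above.

```python
class FakePlan:
    """Hypothetical stand-in for the plan state; not the SPINS class."""

    def __init__(self) -> None:
        self._values = {}
        self._frozen = set()

    def get_var_value(self, var: str) -> float:
        return self._values[var]

    def set_var_value(self, var: str, value: float) -> None:
        # Assumed behaviour: writing to a frozen variable is an error.
        if var in self._frozen:
            raise ValueError(f"cannot set frozen variable {var!r}")
        self._values[var] = value

    def freeze(self, var: str) -> None:
        self._frozen.add(var)


class FakeAddOne:
    """Mirrors AddOne above, but against the stand-in plan."""

    def __init__(self, var: str) -> None:
        self._var = var

    def run(self, plan: FakePlan) -> None:
        # All reads and writes go through the plan, never through the node.
        val = plan.get_var_value(self._var)
        plan.set_var_value(self._var, val + 1)


plan = FakePlan()
plan.set_var_value("x", 3.0)
FakeAddOne("x").run(plan)
```

Because the action touches state only through the plan, the plan remains the single place where rules such as frozen variables are enforced.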


As with nodes, we recommend defining a creation function for each action. A typical creation function is as follows:

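def add_one(*args, **kwargs) -> AddOne:
   action = AddOne(*args, **kwargs)
   # Register the action with the default plan on the user's behalf.
   goos.get_default_plan().add_action(action)
   return action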

In this case, the function simply forwards all the arguments to the action class, though more preprocessing can be done. Additionally, the action is automatically added to the default plan, obviating the need for the user to explicitly call add_action.

Custom Flows

You may wish to define a custom flow if none of the existing flows capture the correct description of the object you wish to define. This is often true for new descriptions of shapes.

Flows must follow a few rules:

  • Must inherit from goos.Flow.
  • Must be picklable (or more specifically, dillable).
  • Must provide a constructor that accepts no arguments. Note that this implies that all fields should have default values.
  • Must define an inner class called Grad. This is automatically generated if not explicitly defined.
  • Must define an inner class called ConstFlags. This is automatically generated if not explicitly defined.

Flows can do the following though:

  • Inherit from more than one type of flow, though you should carefully consider the ramifications.
  • Contain other flows.

To define a flow, simply create a class that inherits from Flow:

class MyFlow(goos.Flow):
    myfield: bool = False
    myfield2: np.ndarray = goos.np_zero_field(3)
    myfield3: float = 3

By default, flows are converted into Python dataclasses and thus the dataclass syntax can be used. Default values are provided so that the automatically defined constructor requires no arguments. goos.np_zero_field is a utility function that creates a field initialized with a NumPy array of zeros (it is short for dataclasses.field(default_factory=lambda: np.zeros(n))). This is necessary because a mutable object used directly as a default value would be shared across all instances.

This flow can now be used like so:

flow = MyFlow()
flow.myfield = True

flow = MyFlow(myfield2=np.array([3,4,5]))
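
The reason for using a default factory can be seen with plain dataclasses. The sketch below uses a list in place of a NumPy array so that it runs with the standard library alone. `zero_field` and `FakeFlow` are hypothetical analogues of goos.np_zero_field and MyFlow, not the SPINS implementations.

```python
import dataclasses
from typing import List


def zero_field(n: int):
    """Hypothetical analogue of goos.np_zero_field, using a plain list."""
    return dataclasses.field(default_factory=lambda: [0.0] * n)


@dataclasses.dataclass
class FakeFlow:
    myfield: bool = False
    myfield2: List[float] = zero_field(3)
    myfield3: float = 3


a = FakeFlow()
b = FakeFlow()
a.myfield2[0] = 1.0  # mutating `a` must not leak into `b`
```

Each instance gets its own list because the factory runs once per construction. Writing `myfield2: List[float] = []` instead would be rejected by the dataclasses module with a ValueError.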

Because we did not explicitly define a Grad and ConstFlags class, they were automatically generated. In auto-generated Grad classes, every numeric field with name fieldname will have an associated field fieldname_grad in the Grad class. In auto-generated ConstFlags, all the fields in the flow will exist in ConstFlags except that they all become booleans:

grad_flow = MyFlow.Grad()
grad_flow.myfield2_grad = np.array([3,4,5])

const_flags = MyFlow.ConstFlags(myfield=False, myfield2=True)

Note that automatic generation works only under the assumption that the dataclass model is used for the class. If you choose not to define flow fields this way, you should declare your own Grad and ConstFlags.
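
To make the auto-generation rules concrete, here is a rough sketch of how such classes could be derived from a dataclass using dataclasses.make_dataclass. It is not the SPINS implementation: `make_grad`, `make_const_flags`, `NUMERIC_TYPES`, the `None` default for gradient fields, and the `False` default for flags are all assumptions made for illustration.

```python
import dataclasses
from typing import List

# Assumed notion of "numeric" for this sketch; bool is deliberately excluded.
NUMERIC_TYPES = (int, float, complex, List[float])


@dataclasses.dataclass
class FakeFlow:
    myfield: bool = False
    myfield2: List[float] = dataclasses.field(
        default_factory=lambda: [0.0] * 3)
    myfield3: float = 3


def make_grad(flow_cls):
    """Builds a class with a `<name>_grad` field per numeric field."""
    fields = [(f.name + "_grad", f.type, dataclasses.field(default=None))
              for f in dataclasses.fields(flow_cls)
              if f.type in NUMERIC_TYPES]
    return dataclasses.make_dataclass(flow_cls.__name__ + "Grad", fields)


def make_const_flags(flow_cls):
    """Builds a class with one boolean flag per flow field."""
    fields = [(f.name, bool, dataclasses.field(default=False))
              for f in dataclasses.fields(flow_cls)]
    return dataclasses.make_dataclass(
        flow_cls.__name__ + "ConstFlags", fields)


Grad = make_grad(FakeFlow)
ConstFlags = make_const_flags(FakeFlow)

grad_flow = Grad()
grad_flow.myfield2_grad = [3.0, 4.0, 5.0]
const_flags = ConstFlags(myfield=False, myfield2=True)
```

Both helpers rely on dataclasses.fields, which is why a flow that does not follow the dataclass model needs hand-written Grad and ConstFlags classes.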

Gradient Flow

Every flow has an associated gradient flow that represents the flow containing gradient information. Consequently, during forward evaluation of the nodes, flows are passed as inputs whereas during the backward evaluation, gradient flows are passed as inputs. The gradient flow associated with a flow is simply the flow name plus .Grad. For example, a flow called MyFlow would have a gradient flow named MyFlow.Grad.

By default, if no inner Grad class is defined, a gradient flow class will be automatically constructed based on the defined fields of the Flow. Note that the gradient flow class autogeneration assumes that the Flow operates as a normal dataclass. Therefore, if you do not rely on the dataclass operation of a Flow, you should define your own Grad class.

Constant Flags

In order to optimize evaluation of the computational graph, additional flags known as const flags for each input are passed to eval and grad. For example, a simulation node may use the fact that a Shape is constant to speed up the process of drawing the permittivity distribution. Specifically, every flow must have a ConstFlags inner class. It is automatically generated if not defined. This class has a field for every non-constant field of the flow.

The const flags are used in the following ways:

  • Marking constant flow fields. Constant flow fields are those that cannot change (i.e. do not depend in any way on a Variable).
  • Marking frozen flow fields. Frozen flow fields are those that do not depend on any thawed Variable.

Thus, by definition, all constant flow fields are also frozen flow fields, but frozen flow fields need not be constant. During function or gradient evaluation, the constant flow fields and frozen flow fields are computed and stored in separate instances of ConstFlags. In other words, multiple ConstFlags instances will be created, and they serve different purposes.
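
The distinction between constant and frozen can be expressed in a few lines. In the sketch below, `Var` and `const_and_frozen` are hypothetical names. The dependency list stands in for whatever graph traversal the real system performs.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass(frozen=True)
class Var:
    """Hypothetical stand-in for a Variable; `frozen` means not thawed."""
    name: str
    frozen: bool = False


def const_and_frozen(deps: List[Var]) -> Tuple[bool, bool]:
    """Returns (is_constant, is_frozen) for a value depending on `deps`."""
    is_constant = len(deps) == 0             # depends on no Variable at all
    is_frozen = all(v.frozen for v in deps)  # depends on no thawed Variable
    return is_constant, is_frozen


no_deps = const_and_frozen([])
only_frozen = const_and_frozen([Var("a", frozen=True)])
mixed = const_and_frozen([Var("a", frozen=True), Var("b")])
```

A value with no dependencies is both constant and frozen, a value depending only on frozen variables is frozen but not constant, and a value depending on any thawed variable is neither.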