# Working with child classes

In some scenarios it makes sense to provide different methods for a single stage, e.g. to compare a dense neural network with a CNN, each defined in its own custom class.
Such a scenario can be handled by using one class per method.

In [1]:
from zntrack import ZnTrackProject, config

config.nb_name = "02_PassingClasses.ipynb"

In [2]:
import os
from zntrack.utils import cwd_temp_dir

temp_dir = cwd_temp_dir()

In [3]:
project = ZnTrackProject()
project.create_dvc_repository()

## Inheriting from Base-Nodes

The next part of the documentation shows how to pass a Python class to a Node to enable different methods.
While this can be very useful, it is often easier to create a Base-Node and define the custom methods as subclasses of this base.


In [4]:
from zntrack import Node, zn

In [5]:
class NumberManipulationBase(Node):
    node_name = "NumberManipulationBase"
    # Setting node_name here gives all child classes the same stage name, so they
    # overwrite each other. Without it, each child class gets its own stage and they can coexist.
    input_number = zn.params()
    output_number = zn.outs()


class MultiplyNumber(NumberManipulationBase):
    factor = zn.params()

    def run(self):
        self.output_number = self.input_number * self.factor


class DivideNumber(NumberManipulationBase):
    divider = zn.params()

    def run(self):
        self.output_number = self.input_number / self.divider

In [6]:
MultiplyNumber(input_number=10, factor=3).write_graph(run=True)
print(MultiplyNumber.load().output_number)

DivideNumber(input_number=10, divider=2).write_graph(run=True)
print(DivideNumber.load().output_number)

Submit issues to https://github.com/zincware/ZnTrack.
30
5.0


Because the child classes share the same `node_name` and outputs are loaded lazily, you may also be able to access the output of `DivideNumber` through `NumberManipulationBase` or `MultiplyNumber`.
This only works for ZnTrackOptions that are shared between the Nodes.
If you try to access, e.g., `factor`, you will get an error, because `factor` is not an attribute of `DivideNumber`.
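The behaviour described above can be pictured with a toy model. This is a plain-Python sketch, NOT how ZnTrack stores data: we assume a hypothetical in-memory `storage` dict keyed by `node_name`, and the classes `ToyNode`, `ToyBase`, `ToyMultiply` and `ToyDivide` are illustrative stand-ins for the Nodes above.

```python
storage = {}  # hypothetical stand-in for the files a stage writes (toy model only)


class ToyNode:
    node_name = None  # None -> use the class name, so each class gets its own entry

    @classmethod
    def key(cls):
        return cls.node_name or cls.__name__

    def save(self):
        # overwrite everything previously stored under this name
        storage[self.key()] = dict(vars(self))

    @classmethod
    def load(cls):
        obj = cls.__new__(cls)  # skip __init__, only restore the stored attributes
        obj.__dict__.update(storage[cls.key()])
        return obj


class ToyBase(ToyNode):
    node_name = "NumberManipulationBase"  # shared by every child class


class ToyMultiply(ToyBase):
    def __init__(self, input_number, factor):
        self.input_number = input_number
        self.factor = factor
        self.output_number = input_number * factor


class ToyDivide(ToyBase):
    def __init__(self, input_number, divider):
        self.input_number = input_number
        self.divider = divider
        self.output_number = input_number / divider


ToyMultiply(10, 3).save()
ToyDivide(10, 2).save()  # same node_name, so this replaces the multiply result

loaded = ToyMultiply.load()
print(loaded.output_number)  # 5.0, written by ToyDivide
print(hasattr(loaded, "factor"))  # False, ToyDivide never stored a factor
```

In the toy model, loading through any class that shares the name returns whatever was written last, and only attributes that the last writer actually stored are available.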

## Creating Operations

Best practice for adding different custom operations or methods is to let them inherit from a common parent class that defines the method performing the computation.
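The `Base` class in the next cell only signals a missing `compute` when it is called, by raising `NotImplementedError`. As a variant, Python's standard `abc` module can enforce the contract as soon as an operation is instantiated. This is a plain-Python sketch; `AbstractBase`, `Negate` and `Broken` are hypothetical names, and we have not checked how `zn.Method()` interacts with `ABC` subclasses.

```python
from abc import ABC, abstractmethod


class AbstractBase(ABC):
    """Common parent: every operation must implement compute(inp)."""

    @abstractmethod
    def compute(self, inp):
        ...


class Negate(AbstractBase):  # hypothetical example operation
    def compute(self, inp):
        return -inp


class Broken(AbstractBase):  # forgot to implement compute
    pass


print(Negate().compute(4))  # -4
try:
    Broken()  # fails here, before any computation is attempted
except TypeError as err:
    print("cannot instantiate:", err)
```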

In [7]:
class Base:
    def compute(self, inp):
        raise NotImplementedError

For simplicity we look at very simple functions here, but they can be of arbitrary complexity.
We apply the `check_signature` decorator, an optional check that tests whether the keyword arguments of `__init__` are identical to the class attribute names.
The decorator itself is optional, but this naming convention is mandatory for ZnTrack to work as intended.
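To illustrate what such a check guards against, here is a minimal sketch of one way it could work. `check_signature_sketch` is a hypothetical stand-in written with the standard `inspect` and `functools` modules; it is not ZnTrack's `check_signature`, whose exact behaviour may differ. It runs `__init__` and then verifies that every keyword argument was stored as an attribute with the same name.

```python
import functools
import inspect


def check_signature_sketch(init):
    """Hypothetical stand-in for check_signature (not ZnTrack's code)."""
    # collect the named parameters of __init__, ignoring self, *args and **kwargs
    params = [
        name
        for name, p in inspect.signature(init).parameters.items()
        if name != "self"
        and p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    ]

    @functools.wraps(init)
    def wrapper(self, *args, **kwargs):
        init(self, *args, **kwargs)
        missing = [name for name in params if not hasattr(self, name)]
        if missing:
            raise TypeError(
                f"{type(self).__name__}: arguments {missing} are not stored "
                "as attributes of the same name"
            )

    return wrapper


class Good:
    @check_signature_sketch
    def __init__(self, shift: float):
        self.shift = shift  # name matches the argument


class Bad:
    @check_signature_sketch
    def __init__(self, shift: float):
        self.offset = shift  # renamed, so the state could not be rebuilt from attributes


Good(shift=5)
try:
    Bad(shift=5)
except TypeError as err:
    print(err)
```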

In [8]:
from zntrack.utils.decorators import check_signature


class ShiftValues(Base):
    @check_signature
    def __init__(self, shift: float):
        self.shift = shift

    def compute(self, inp):
        return inp + self.shift


class ScaleValues(Base):
    @check_signature
    def __init__(self, factor: float):
        self.factor = factor

    def compute(self, inp):
        return inp * self.factor

The actual Node uses the typical ZnTrack functionality, extended by `zn.Method()`.

In [9]:
class Calculator(Node):
    operation: Base = zn.Method()
    input_value = zn.params()
    result = zn.outs()

    def __init__(self, input_value=None, operation=None, **kwargs):
        super().__init__(**kwargs)
        self.input_value = input_value
        self.operation = operation

    def run(self):
        self.result = self.operation.compute(self.input_value)

With this definition in place, we can pass an instance of one of our compute classes to the Node. The Node saves the state of the instance and restores that state when the stage is reproduced with `dvc repro`.
Let's start with a simple shift of the given input values.
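Why the attribute names must match the `__init__` arguments becomes clearer with a toy model of "save the state, then rebuild the instance". This is a sketch under stated assumptions, not ZnTrack's storage format: the hypothetical helpers `dump_method` and `load_method` record the class name plus the instance attributes and rebuild the object by passing those attributes back as keyword arguments, using a hypothetical `REGISTRY` to look up the class.

```python
class ShiftValues:  # same shape as the notebook's class
    def __init__(self, shift: float):
        self.shift = shift

    def compute(self, inp):
        return inp + self.shift


class ScaleValues:
    def __init__(self, factor: float):
        self.factor = factor

    def compute(self, inp):
        return inp * self.factor


# hypothetical lookup table from class name to class
REGISTRY = {cls.__name__: cls for cls in (ShiftValues, ScaleValues)}


def dump_method(obj):
    # the attributes double as __init__ keyword arguments because the names match
    return {"cls": type(obj).__name__, "kwargs": dict(vars(obj))}


def load_method(state):
    return REGISTRY[state["cls"]](**state["kwargs"])


state = dump_method(ShiftValues(shift=5))
print(state)  # {'cls': 'ShiftValues', 'kwargs': {'shift': 5}}
rebuilt = load_method(state)
print(rebuilt.compute(10))  # 15
```

If an attribute were stored under a different name than the argument, `cls(**kwargs)` would fail, which is exactly the mismatch a signature check is meant to catch early.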

In [10]:
Calculator(input_value=10, operation=ShiftValues(shift=5)).write_graph(run=True)



We can now load the Stage and look at the result.

In [11]:
Calculator.load().result

15

Similarly, we can use an instance of the other class, which has different attributes, in the same way.

In [12]:
Calculator(input_value=10, operation=ScaleValues(factor=2)).write_graph(no_exec=False)



In [13]:
Calculator.load().result

20

It is also possible to use multiple methods, or methods with multiple arguments. We first show multiple arguments by combining shift and scale into a single class.

In [14]:
class ShiftAndScale(Base):
    @check_signature
    def __init__(self, shift, factor):
        self.shift = shift
        self.factor = factor

    def compute(self, inp):
        return self.factor * inp + self.shift

In [15]:
Calculator(input_value=10, operation=ShiftAndScale(shift=5, factor=2)).write_graph(
    no_exec=False
)



In [16]:
Calculator.load().result

25

Alternatively, we can use both methods inside a single Node.

In [17]:
class CombinedCalculator(Node):
    shift: Base = zn.Method()
    scale: Base = zn.Method()
    input_value = zn.params()
    result = zn.outs()

    def __init__(self, input_value=None, shift=None, scale=None, **kwargs):
        super().__init__(**kwargs)
        self.input_value = input_value
        self.shift = shift
        self.scale = scale

    def run(self):
        tmp = self.scale.compute(self.input_value)
        self.result = self.shift.compute(tmp)

In [18]:
CombinedCalculator(
    input_value=10, shift=ShiftValues(shift=5), scale=ScaleValues(factor=2)
).write_graph(no_exec=False)



In [19]:
CombinedCalculator.load().result

25
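`CombinedCalculator.run` applies the scale before the shift, which is why the result matches `ShiftAndScale`. The order of the operations matters, as this plain-Python sketch outside ZnTrack shows; `apply_in_order` is a hypothetical helper that chains any objects exposing a `compute` method.

```python
class ShiftValues:
    def __init__(self, shift):
        self.shift = shift

    def compute(self, inp):
        return inp + self.shift


class ScaleValues:
    def __init__(self, factor):
        self.factor = factor

    def compute(self, inp):
        return inp * self.factor


def apply_in_order(value, operations):
    # feed the output of each operation into the next one
    for op in operations:
        value = op.compute(value)
    return value


shift, scale = ShiftValues(shift=5), ScaleValues(factor=2)
print(apply_in_order(10, [scale, shift]))  # 25: (10 * 2) + 5, the order used in CombinedCalculator.run
print(apply_in_order(10, [shift, scale]))  # 30: (10 + 5) * 2
```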

In [None]:
os.chdir("..")
temp_dir.cleanup()