# Working with child classes

There can be scenarios where it makes sense to switch between different methods for a single stage, e.g. testing a dense neural network against a CNN, each defined in its own custom class.
Such a scenario can be handled by using one class per method.

In [1]:
from zntrack import ZnTrackProject, config

config.nb_name = "02_PassingClasses.ipynb"

In [2]:
import os
from zntrack.utils import cwd_temp_dir

temp_dir = cwd_temp_dir()

In [3]:
project = ZnTrackProject()
project.create_dvc_repository()

2022-01-13 19:39:44,413 (INFO): Setting up GIT/DVC repository.


## Creating Operations

Best practice for adding different custom operations or methods is to inherit from a common parent class that defines the method performing the computation.

In [4]:
class Base:
    def compute(self, inp):
        raise NotImplementedError
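As an alternative to raising `NotImplementedError`, the common parent could be an abstract base class from Python's standard `abc` module. The sketch below is not what this notebook uses, and it assumes (untested here) that `zn.Method()` handles ABC subclasses like any other class. The benefit is that a subclass that forgets to implement `compute` fails as soon as it is instantiated, instead of only when `compute` is called later inside a pipeline run. `Incomplete` is a hypothetical class used only to show that failure.

```python
from abc import ABC, abstractmethod


class Base(ABC):
    """Common parent: every operation must implement compute()."""

    @abstractmethod
    def compute(self, inp):
        """Return the transformed input."""


class ShiftValues(Base):
    def __init__(self, shift: float):
        self.shift = shift

    def compute(self, inp):
        return inp + self.shift


class Incomplete(Base):
    """Hypothetical subclass that forgot to implement compute()."""


print(ShiftValues(shift=5).compute(10))  # 15

try:
    Incomplete()  # abstract method not implemented -> TypeError at construction
except TypeError as err:
    print(f"rejected: {err}")
```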

For simplicity we will look at some very simple functions, but they can be of arbitrary complexity.
We apply the `check_signature` decorator, an optional check that tests whether the keyword arguments of `__init__` are identical to the class attribute names.
The check itself is optional, but matching names are mandatory for ZnTrack to work as intended.

In [5]:
from zntrack.utils.decorators import check_signature


class ShiftValues(Base):
    @check_signature
    def __init__(self, shift: float):
        self.shift = shift

    def compute(self, inp):
        return inp + self.shift


class ScaleValues(Base):
    @check_signature
    def __init__(self, factor: float):
        self.factor = factor

    def compute(self, inp):
        return inp * self.factor
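To illustrate what such a check can look like, here is a hypothetical decorator, `check_signature_sketch`. It is not ZnTrack's implementation, only a sketch under the assumption that "identical names" means every named `__init__` parameter ends up as an instance attribute of the same name. It wraps `__init__`, lets it run, and then verifies the attributes. `GoodShift` and `BadShift` are hypothetical classes.

```python
import functools
import inspect


def check_signature_sketch(init):
    """Hypothetical stand-in for a signature check (not ZnTrack's code).

    After __init__ runs, verify that each named parameter was stored
    as an instance attribute with exactly the same name.
    """
    kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    params = [
        name
        for name, p in inspect.signature(init).parameters.items()
        if name != "self" and p.kind in kinds
    ]

    @functools.wraps(init)
    def wrapper(self, *args, **kwargs):
        init(self, *args, **kwargs)
        missing = [name for name in params if not hasattr(self, name)]
        if missing:
            raise TypeError(
                f"{type(self).__name__}.__init__ does not store {missing} as attributes"
            )

    return wrapper


class GoodShift:
    @check_signature_sketch
    def __init__(self, shift: float):
        self.shift = shift


class BadShift:
    @check_signature_sketch
    def __init__(self, shift: float):
        self.offset = shift  # attribute name differs from the keyword argument


print(GoodShift(shift=5).shift)  # 5
```

Calling `BadShift(shift=5)` raises a `TypeError`, because the value was stored under `offset` instead of `shift`.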

The actual Node makes use of the typical ZnTrack functionality, extended by `zn.Method()`.

In [6]:
from zntrack import Node, zn

In [7]:
class Calculator(Node):
    operation: Base = zn.Method()
    input_value = zn.params()
    result = zn.outs()

    def __init__(self, input_value=None, operation=None, **kwargs):
        super().__init__(**kwargs)
        self.input_value = input_value
        self.operation = operation

    def run(self):
        self.result = self.operation.compute(self.input_value)
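Stripped of ZnTrack and DVC, `Calculator` follows the strategy pattern: the Node only knows that it receives an object with a `compute` method and delegates to it. The plain-Python sketch below uses a hypothetical `PlainCalculator` class to show that delegation on its own. It does no parameter tracking, serialization, or caching.

```python
class ShiftValues:
    def __init__(self, shift: float):
        self.shift = shift

    def compute(self, inp):
        return inp + self.shift


class ScaleValues:
    def __init__(self, factor: float):
        self.factor = factor

    def compute(self, inp):
        return inp * self.factor


class PlainCalculator:
    """Same shape as the Calculator Node, minus ZnTrack/DVC (hypothetical)."""

    def __init__(self, input_value, operation):
        self.input_value = input_value
        self.operation = operation
        self.result = None

    def run(self):
        # Delegate to whichever operation was injected; the calculator
        # never needs to know which concrete class it is.
        self.result = self.operation.compute(self.input_value)


calc = PlainCalculator(input_value=10, operation=ScaleValues(factor=2))
calc.run()
print(calc.result)  # 20
```

Swapping `ScaleValues(factor=2)` for `ShiftValues(shift=5)` changes the behaviour without touching `PlainCalculator`. That is the property the ZnTrack Node relies on.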

With this definition in place, we can pass an instance of one of our compute classes to the Node. ZnTrack will save the state of the instance and restore that state during `dvc repro`.
Let's start with a simple shift of the given input value.

In [8]:
Calculator(input_value=10, operation=ShiftValues(shift=5)).write_graph(no_exec=True)
!dvc repro

Submit issues to https://github.com/zincware/ZnTrack.
Running stage 'Calculator':
> python -c "from src.Calculator import Calculator; Calculator.load(name='Calculator').run_and_save()" 
Generating lock file 'dvc.lock'
Updating lock file 'dvc.lock'

To track the changes with git, run:

	git add dvc.lock
Use `dvc push` to send your updates to remote storage.


We can now load the Stage and look at the result.

In [9]:
Calculator.load().result

15
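The round trip behind this, storing an operation instance and rebuilding it later in a fresh process, can be sketched in plain Python. This is not ZnTrack's actual storage format. The sketch assumes a hypothetical JSON layout (class name plus attributes) and a hypothetical `REGISTRY` that maps stored names back to classes. The rebuild step passes the stored attributes back as keyword arguments, which only works when attribute names equal the `__init__` keyword names.

```python
import json


class ShiftValues:
    def __init__(self, shift: float):
        self.shift = shift

    def compute(self, inp):
        return inp + self.shift


class ScaleValues:
    def __init__(self, factor: float):
        self.factor = factor

    def compute(self, inp):
        return inp * self.factor


# Hypothetical registry mapping stored class names back to classes.
REGISTRY = {cls.__name__: cls for cls in (ShiftValues, ScaleValues)}


def dump_method(obj) -> str:
    """Store the class name plus the instance attributes (hypothetical format)."""
    return json.dumps({"cls": type(obj).__name__, "kwargs": vars(obj)})


def load_method(text: str):
    """Rebuild the instance by passing the stored attributes as kwargs."""
    data = json.loads(text)
    return REGISTRY[data["cls"]](**data["kwargs"])


stored = dump_method(ShiftValues(shift=5))
print(stored)  # {"cls": "ShiftValues", "kwargs": {"shift": 5}}
print(load_method(stored).compute(10))  # 15
```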

Similarly, we can use an instance of the other class, which has different attributes, in exactly the same way.

In [10]:
Calculator(input_value=10, operation=ScaleValues(factor=2)).write_graph(no_exec=False)

Submit issues to https://github.com/zincware/ZnTrack.


In [11]:
Calculator.load().result

20

It is also possible to use multiple methods or to pass multiple arguments to a method. We can show this by combining shift and scale into a single class.

In [12]:
class ShiftAndScale(Base):
    @check_signature
    def __init__(self, shift, factor):
        self.shift = shift
        self.factor = factor

    def compute(self, inp):
        return self.factor * inp + self.shift

In [13]:
Calculator(input_value=10, operation=ShiftAndScale(shift=5, factor=2)).write_graph(
    no_exec=False
)

Submit issues to https://github.com/zincware/ZnTrack.


In [14]:
Calculator.load().result

25

Alternatively, we can use both methods inside a single Node.

In [15]:
class CombinedCalculator(Node):
    shift: Base = zn.Method()
    scale: Base = zn.Method()
    input_value = zn.params()
    result = zn.outs()

    def __init__(self, input_value=None, shift=None, scale=None, **kwargs):
        super().__init__(**kwargs)
        self.input_value = input_value
        self.shift = shift
        self.scale = scale

    def run(self):
        tmp = self.scale.compute(self.input_value)
        self.result = self.shift.compute(tmp)
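Note that `CombinedCalculator.run` applies `scale` first and `shift` second, which reproduces `ShiftAndScale` (`factor * inp + shift`). The order matters. The plain-Python check below composes the two operations both ways for an input of 10 and compares them with `ShiftAndScale`.

```python
class ShiftValues:
    def __init__(self, shift: float):
        self.shift = shift

    def compute(self, inp):
        return inp + self.shift


class ScaleValues:
    def __init__(self, factor: float):
        self.factor = factor

    def compute(self, inp):
        return inp * self.factor


class ShiftAndScale:
    def __init__(self, shift, factor):
        self.shift = shift
        self.factor = factor

    def compute(self, inp):
        return self.factor * inp + self.shift


shift, scale = ShiftValues(shift=5), ScaleValues(factor=2)

scale_then_shift = shift.compute(scale.compute(10))  # order used by CombinedCalculator.run
shift_then_scale = scale.compute(shift.compute(10))  # reversed order
combined = ShiftAndScale(shift=5, factor=2).compute(10)

print(scale_then_shift, shift_then_scale, combined)  # 25 30 25
```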

In [16]:
CombinedCalculator(
    input_value=10, shift=ShiftValues(shift=5), scale=ScaleValues(factor=2)
).write_graph(no_exec=False)

Submit issues to https://github.com/zincware/ZnTrack.


In [17]:
CombinedCalculator.load().result

25

In [None]:
os.chdir("..")
temp_dir.cleanup()