# Inheritance

ZnTrack supports inheritance from a Node base class.
This can be useful, for example, if you want to try out different methods of the same kind.
In the following example, we demonstrate this with several Nodes that share the same inputs and outputs but implement different functions in their `run` method.

In [1]:
from zntrack import config

config.nb_name = "02_Inheritance.ipynb"

In [2]:
# Work in a temporary directory
from zntrack.utils import cwd_temp_dir
temp_dir = cwd_temp_dir()

!git init
!dvc init

Initialized empty Git repository in C:/Users/fabia/AppData/Local/Temp/tmpc5a1k84s/.git/
Initialized DVC repository.

You can now commit the changes to git.



In [3]:
from zntrack import Node, zn

In [4]:
class NodeBase(Node):
    node_name = "basic_number"

    inputs: float = zn.params()
    output: float = zn.outs()

In [5]:
class AddNumber(NodeBase):
    """Shift input by an offset"""
    offset: float = zn.params()

    def run(self):
        self.output = self.inputs + self.offset

class MultiplyNumber(NodeBase):
    """Multiply input by a factor"""
    factor: float = zn.params()

    def run(self):
        self.output = self.inputs * self.factor

In [6]:
add_number = AddNumber(inputs=10.0, offset=15.0)
add_number.write_graph(run=True)

Submit issues to https://github.com/zincware/ZnTrack.


Because the Nodes inherit from `NodeBase` and the `node_name` is defined in the parent class, any of these classes can be used to load the outputs, as long as the outputs are shared.
Keep in mind when working with inheritance that the loaded output was not necessarily created by the Node class used to load it.
On the other hand, this can be handy for dependency handling.
A subsequent Node can, for example, depend on the parent Node without needing to know where the values actually come from.
For instance, an ML model might implement a `predict` method in the parent Node while its children have entirely different internal structures.
An evaluation Node that only needs the `predict` method can then be used with all children of the model class.
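The dependency pattern described above can be sketched in plain Python, without ZnTrack. The names `ModelBase`, `LinearModel`, `ConstantModel`, and `evaluate` are hypothetical and only illustrate the idea: the evaluation step relies solely on the `predict` method declared by the parent class, so it works with any child regardless of its internals.

```python
from abc import ABC, abstractmethod


class ModelBase(ABC):
    """Hypothetical parent class that fixes the shared interface."""

    @abstractmethod
    def predict(self, x: float) -> float:
        ...


class LinearModel(ModelBase):
    def __init__(self, slope: float, intercept: float):
        self.slope = slope
        self.intercept = intercept

    def predict(self, x: float) -> float:
        return self.slope * x + self.intercept


class ConstantModel(ModelBase):
    """A child with an entirely different structure but the same interface."""

    def __init__(self, value: float):
        self.value = value

    def predict(self, x: float) -> float:
        return self.value


def evaluate(model: ModelBase, data: list) -> float:
    """Mean absolute error; only uses the interface declared by ModelBase."""
    errors = [abs(model.predict(x) - y) for x, y in data]
    return sum(errors) / len(errors)


data = [(0.0, 1.0), (1.0, 3.0), (2.0, 5.0)]
print(evaluate(LinearModel(slope=2.0, intercept=1.0), data))  # 0.0
print(evaluate(ConstantModel(value=3.0), data))
```

Swapping one child for another requires no change to `evaluate`, which mirrors how a downstream Node can depend on `NodeBase` alone.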

In [7]:
NodeBase.load().output

25.0

In [8]:
!dvc dag

+--------------+ 
| basic_number | 
+--------------+ 


In [9]:
multiply_number = MultiplyNumber(inputs=6.0, factor=6.0)
multiply_number.write_graph(run=True)



In [10]:
NodeBase.load().output

36.0

As expected, the node name remains the same, so the new Node replaces the previous one.
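One way to picture this replacement is a mapping keyed by node name: both `AddNumber` and `MultiplyNumber` write under `basic_number`, so the later entry wins. The dictionary and the `write_stage` helper below are a simplified, hypothetical stand-in assumed for illustration; they are not how ZnTrack or DVC actually store stages (DVC keeps stage definitions in `dvc.yaml`).

```python
# Toy model (not ZnTrack internals): stages keyed by their node name.
stages: dict = {}


def write_stage(node_name: str, cls_name: str, params: dict, output: float) -> None:
    """Register a stage; an existing entry with the same name is overwritten."""
    stages[node_name] = {"class": cls_name, "params": params, "output": output}


write_stage("basic_number", "AddNumber", {"inputs": 10.0, "offset": 15.0}, 10.0 + 15.0)
write_stage("basic_number", "MultiplyNumber", {"inputs": 6.0, "factor": 6.0}, 6.0 * 6.0)

print(len(stages))  # 1
print(stages["basic_number"]["class"], stages["basic_number"]["output"])  # MultiplyNumber 36.0
```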

In [11]:
!dvc dag

+--------------+ 
| basic_number | 
+--------------+ 


## Nodes as parameters

Sometimes it is useful to pass a Node as a parameter, or to use the `run` method of a given Node while storing its outputs somewhere else.
For example, an active learning cycle might use the model and evaluation classes while storing the outputs in the active learning Node.
You might still want to reuse the other Nodes to avoid overhead.

In the following, we use the `run` method of a `NodeBase` Node and additionally define a dataclass-like Node that only stores parameters.
Internally, ZnTrack disables all outputs of a Node passed this way.
To keep the DAG working, a `_hash = zn.Hash()` attribute is introduced.
Its value is computed from the parameters and the current timestamp and only serves as a file dependency for DVC.
Adding `zn.Hash()` to any Node adds an output file but has no other effect.
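The idea behind such a hash can be sketched with the standard library. The function `node_hash` is hypothetical: it assumes SHA-256 over a JSON dump of the parameters and a timestamp, which is not necessarily ZnTrack's actual algorithm. It shows why the value changes on every run even when the parameters stay the same, so the file written from it always reflects the latest execution.

```python
import hashlib
import json
import time
from typing import Optional


def node_hash(params: dict, timestamp: Optional[float] = None) -> str:
    """Hypothetical stand-in for zn.Hash: digest of parameters plus a timestamp."""
    if timestamp is None:
        timestamp = time.time()
    # sort_keys makes the digest independent of dict insertion order
    payload = json.dumps({"params": params, "timestamp": timestamp}, sort_keys=True)
    return hashlib.sha256(payload.encode()).hexdigest()


params = {"a0": 60.0, "a1": 10.0}
h1 = node_hash(params, timestamp=1.0)
h2 = node_hash(params, timestamp=2.0)
print(len(h1))  # 64
print(h1 == h2)  # False: a new timestamp gives a new value
```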

In [12]:
class DivideNumber(NodeBase):
    """Multiply input by a factor"""
    divider: float = zn.params()
    _hash = zn.Hash()

    def run(self):
        self.output = self.inputs * self.divider


class Polynomial(Node):
    a0: float = zn.params()
    a1: float = zn.params()
    _hash = zn.Hash()

class ManipulateNumber(Node):
    inputs: float = zn.params()
    output: float = zn.outs()
    value_handler: NodeBase = zn.Nodes()
    polynomial: Polynomial = zn.Nodes()

    def run(self):
        # use the passed method
        self.value_handler.inputs = self.inputs
        self.value_handler.run()
        self.output = self.value_handler.output
        # polynomials
        self.output = self.polynomial.a0 + self.polynomial.a1 * self.output

In [13]:
manipulate_number = ManipulateNumber(
    inputs=10.0,
    value_handler=DivideNumber(divider=3.0, inputs=None),
    polynomial=Polynomial(a0=60.0, a1=10.0),
)

In [14]:
manipulate_number.write_graph(run=True)



In [15]:
manipulate_number = manipulate_number.load()

In [16]:
manipulate_number.output

360.0
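The value 360.0 follows from the two steps in `ManipulateNumber.run`: the value handler computes 10.0 * 3.0 = 30.0 (`DivideNumber.run` as written multiplies by `divider`), and the polynomial then gives 60.0 + 10.0 * 30.0 = 360.0. The plain-Python sketch below reproduces this without ZnTrack; `ScaleByDivider`, `PolynomialParams`, and `manipulate` are hypothetical names chosen for illustration.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ScaleByDivider:
    """Mirror of DivideNumber.run above, which multiplies by `divider`."""

    divider: float
    inputs: Optional[float] = None
    output: Optional[float] = None

    def run(self) -> None:
        self.output = self.inputs * self.divider


@dataclass
class PolynomialParams:
    """Parameter-only container, like the Polynomial Node."""

    a0: float
    a1: float


def manipulate(inputs: float, handler: ScaleByDivider, poly: PolynomialParams) -> float:
    # Reuse the handler's run method; the result is consumed here rather than stored by the handler
    handler.inputs = inputs
    handler.run()
    return poly.a0 + poly.a1 * handler.output


print(manipulate(10.0, ScaleByDivider(divider=3.0), PolynomialParams(a0=60.0, a1=10.0)))  # 360.0
```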

In [17]:
!dvc dag

+--------------+ 
| basic_number | 
+--------------+ 
+-----------------------------+                    +--------------------------------+
| ManipulateNumber-polynomial |                    | ManipulateNumber-value_handler |
+-----------------------------+                    +--------------------------------+
                           ****                      *****                     
                               ****              ****                          
                                   ***        ***                              
                                +------------------+                           
                                | ManipulateNumber |                           
                                +------------------+                           


In [None]:
temp_dir.cleanup()