In [1]:
import typing

from zntrack import SpawnNode, Node, config, zn
import random

In [2]:
# Tell ZnTrack the file name of this notebook.
# NOTE(review): presumably this lets ZnTrack export the Node classes defined here
# (the stage commands in the logs below import them from `src.`) — confirm.
config.nb_name = "07_spawn_nodes.ipynb"

In [3]:
from zntrack.utils import cwd_temp_dir

# Switch to a temporary working directory for this example (the `git init` output
# below reports a Temp path). The handle is kept so the directory can be removed
# again with `temp_dir.cleanup()`.
temp_dir = cwd_temp_dir()

In [4]:
# Create empty Git and DVC repositories in the current working directory.
!git init
!dvc init

Initialized empty Git repository in C:/Users/fabia/AppData/Local/Temp/tmpcqenvy_w/.git/
Initialized DVC repository.

You can now commit the changes to git.

+---------------------------------------------------------------------+
|                                                                     |
|        DVC has enabled anonymous aggregate usage analytics.         |
|     Read the analytics documentation (and how to opt-out) here:     |
|             <https://dvc.org/doc/user-guide/analytics>              |
|                                                                     |
+---------------------------------------------------------------------+

What's next?
------------
- Check out the documentation: <https://dvc.org/doc>
- Get help and share ideas: <https://dvc.org/chat>
- Star us on GitHub: <https://github.com/iterative/dvc>


# Spawn Nodes

There are many scenarios in which a single Node has to be run multiple times with different parameters and the results are then pooled in some way.
This can be achieved by using a `zntrack.SpawnNode`. The key features of a SpawnNode are:
- one or more `zn.iterable()` parameters
- in the case of multiple `zn.iterable()` parameters, an optional `spawn_filter` to run only certain combinations of parameters
- a pooling Node to gather the results

In [5]:
class SimpleSpawn(SpawnNode):
    """Draw one random integer per combination of ``start``, ``stop`` and ``step``.

    Each attribute declared with ``zn.iterable`` lists the candidate values of
    one parameter; the integer drawn in ``run`` is stored in the ``number`` output.
    """

    # Candidate values for the three ``random.randrange`` arguments.
    start = zn.iterable([-100, -10, -1, 0])
    stop = zn.iterable([1, 10, 100])
    step = zn.iterable([1, 10])

    # Result written by ``run``.
    number = zn.outs()

    def run(self):
        """Store a random integer from ``range(start, stop, step)`` in ``number``."""
        lower, upper, stride = self.start, self.stop, self.step
        self.number = random.randrange(lower, upper, stride)

In [6]:
# Add one DVC stage per spawned parameter combination; with `no_exec=False` each
# stage is also run right away (see the "Running stage" lines in the log below).
SimpleSpawn().write_graph(no_exec=False)

Submit issues to https://github.com/zincware/ZnTrack.
2022-01-15 18:38:22,968 (INFO): Running stage 'SimpleSpawn_1148361016337595827':
> python -c "from src.SimpleSpawn import SimpleSpawn; SimpleSpawn.load(name='SimpleSpawn_1148361016337595827').run_and_save()" 
Creating 'dvc.yaml'
Adding stage 'SimpleSpawn_1148361016337595827' in 'dvc.yaml'
Generating lock file 'dvc.lock'
Updating lock file 'dvc.lock'

To track the changes with git, run:

	git add dvc.yaml dvc.lock 'nodes\SimpleSpawn_1148361016337595827\.gitignore'

Submit issues to https://github.com/zincware/ZnTrack.
2022-01-15 18:38:27,441 (INFO): Running stage 'SimpleSpawn_60154902242571096':
> python -c "from src.SimpleSpawn import SimpleSpawn; SimpleSpawn.load(name='SimpleSpawn_60154902242571096').run_and_save()" 
Adding stage 'SimpleSpawn_60154902242571096' in 'dvc.yaml'
Updating lock file 'dvc.lock'

To track the changes with git, run:

	git add dvc.yaml 'nodes\SimpleSpawn_60154902242571096\.gitignore' dvc.loc

In [7]:
class SimplePool(Node):
    """Gather the spawned ``SimpleSpawn`` nodes and report the result of each one."""

    # Everything returned by ``SimpleSpawn.load()``, declared as dependencies of this Node.
    spawned_nodes: typing.List[SimpleSpawn] = zn.deps(list(SimpleSpawn.load()))

    def run(self):
        """Print the drawn number together with the parameters of every spawned node."""
        for node in self.spawned_nodes:
            message = f"Found {node.number} with start: {node.start}, stop: {node.stop}, step: {node.step}"
            print(message)

In [8]:
# Add SimplePool as a stage to dvc.yaml so that it becomes part of the DVC graph.
# The stage is only registered here: the log below shows "Adding stage" but no
# "Running stage" line.
SimplePool().write_graph()

Submit issues to https://github.com/zincware/ZnTrack.
2022-01-15 18:40:19,485 (INFO): Adding stage 'SimplePool' in 'dvc.yaml'

To track the changes with git, run:

	git add dvc.yaml



In [9]:
# Print the stages currently defined in dvc.yaml as an ASCII graph.
!dvc dag

+---------------------------------+  
| SimpleSpawn_1148361016337595827 |  
+---------------------------------+  
+-------------------------------+  
| SimpleSpawn_60154902242571096 |  
+-------------------------------+  
+-------------------------------+  
| SimpleSpawn_20602384826158361 |  
+-------------------------------+  
+---------------------------------+  
| SimpleSpawn_3320054484246673420 |  
+---------------------------------+  
+---------------------------------+  
| SimpleSpawn_7716239215721179738 |  
+---------------------------------+  
+--------------------------------+ 
| SimpleSpawn_234859725327373250 | 
+--------------------------------+ 
+---------------------------------+  
| SimpleSpawn_1432771092565796107 |  
+---------------------------------+  
+--------------------------------+ 
| SimpleSpawn_124358062171880444 | 
+--------------------------------+ 
+---------------------------------+  
| SimpleSpawn_5009040046883217039 |  
+---------------------------------+ 

In [10]:
# Run the pooling step directly in this kernel; it prints one line per spawned
# SimpleSpawn node.
SimplePool().run()

Found -79 with start: -100, stop: 1, step: 1
Found -10 with start: -100, stop: 1, step: 10
Found -4 with start: -100, stop: 10, step: 1
Found -10 with start: -100, stop: 10, step: 10
Found -13 with start: -100, stop: 100, step: 1
Found -40 with start: -100, stop: 100, step: 10
Found -1 with start: -10, stop: 1, step: 1
Found 0 with start: -10, stop: 1, step: 10
Found 3 with start: -10, stop: 10, step: 1
Found 0 with start: -10, stop: 10, step: 10
Found 51 with start: -10, stop: 100, step: 1
Found 80 with start: -10, stop: 100, step: 10
Found 0 with start: -1, stop: 1, step: 1
Found -1 with start: -1, stop: 1, step: 10
Found 8 with start: -1, stop: 10, step: 1
Found 9 with start: -1, stop: 10, step: 10
Found 8 with start: -1, stop: 100, step: 1
Found 69 with start: -1, stop: 100, step: 10
Found 0 with start: 0, stop: 1, step: 1
Found 0 with start: 0, stop: 1, step: 10
Found 0 with start: 0, stop: 10, step: 1
Found 0 with start: 0, stop: 10, step: 10
Found 59 with start: 0, stop: 100, st

We can restrict this full grid search by adding a `spawn_filter` method, which takes all `zn.iterable()` parameters as arguments and returns whether that combination should be run:

In [11]:
class FilteredSpawn(SpawnNode):
    """Spawn only the parameter combinations whose three values are pairwise different."""

    # Candidate values of the three parameters.
    param1 = zn.iterable([1, 2, 3])
    param2 = zn.iterable([1, 2, 3])
    param3 = zn.iterable([1, 2, 3])

    # Result written by ``run``.
    numbers = zn.outs()

    def spawn_filter(self, param1, param2, param3) -> bool:
        """Return True only if no two of the given values are equal."""
        if param1 != param2 and param1 != param3:
            return param2 != param3
        return False

    def run(self):
        """Store the parameter values of this combination as a list in ``numbers``."""
        combination = [self.param1, self.param2, self.param3]
        self.numbers = combination

In [12]:
# Iterating a FilteredSpawn instance gives one node per spawned combination;
# `run()` has no return value, hence the list of None shown below.
[spawned.run() for spawned in FilteredSpawn()]

[None, None, None, None, None, None]

In [13]:
# Add one DVC stage per spawned parameter combination and, because of
# `no_exec=False`, run each of them immediately (see the log below).
FilteredSpawn().write_graph(no_exec=False)

Submit issues to https://github.com/zincware/ZnTrack.
2022-01-15 18:40:26,506 (INFO): Running stage 'FilteredSpawn_5958485563372509555':
> python -c "from src.FilteredSpawn import FilteredSpawn; FilteredSpawn.load(name='FilteredSpawn_5958485563372509555').run_and_save()" 
Adding stage 'FilteredSpawn_5958485563372509555' in 'dvc.yaml'
Updating lock file 'dvc.lock'

To track the changes with git, run:

	git add dvc.lock dvc.yaml 'nodes\FilteredSpawn_5958485563372509555\.gitignore'

Submit issues to https://github.com/zincware/ZnTrack.
2022-01-15 18:40:31,977 (INFO): Running stage 'FilteredSpawn_1483594487064820420':
> python -c "from src.FilteredSpawn import FilteredSpawn; FilteredSpawn.load(name='FilteredSpawn_1483594487064820420').run_and_save()" 
Adding stage 'FilteredSpawn_1483594487064820420' in 'dvc.yaml'
Updating lock file 'dvc.lock'

To track the changes with git, run:

	git add dvc.yaml dvc.lock 'nodes\FilteredSpawn_1483594487064820420\.gitignore'

Submit issues 

# Remove the temporary working directory created at the top of the notebook.
# NOTE(review): results can presumably no longer be loaded from disk afterwards —
# confirm this runs only after the last cell that calls `.load()`.
temp_dir.cleanup()

In [14]:
# Load the spawned FilteredSpawn nodes and collect the `numbers` output of each.
[spawned.numbers for spawned in FilteredSpawn.load()]

[[1, 2, 3], [1, 3, 2], [2, 1, 3], [2, 3, 1], [3, 1, 2], [3, 2, 1]]