In [None]:
# default_exp stage
# hide
# Base name of this notebook (no extension); used below to build the .ipynb path.
_FNAME='stage'

import unittest
from unittest import mock
from nbdev.export import notebook2script
import os
# Bare TestCase instance: the test cells call its assert* helpers directly.
TESTCASE = unittest.TestCase()
# NOTE(review): `_dh` is not defined in this file -- presumably IPython's
# directory-history list, making `_dh[0]` the directory the kernel started in; confirm.
_nbpath = os.path.join(_dh[0], _FNAME+'.ipynb')

In [None]:
#export
import sys
import yaml
import atexit

from dvcrecord.params import Params
from dvcrecord.deps import Dependency, make_parser, DO_NOT_INCLUDE_IN_PIPELINE
from dvcrecord.output import Output
from dvcrecord.utils import maybe_yaml, write_yaml, PIPELINE_FILE_DEFAULT

class PipelineStage:
    """One named pipeline stage: gathers params, deps and outputs and renders them.

    The stage holds a ``Params``, a ``Dependency`` and an ``Output`` collaborator and
    can render them (plus the command line) into a single mapping, print that
    mapping, or merge it under ``stages[<name>]`` in a pipeline YAML file.
    """

    def __init__(self, name, params=None, outputs=None, deps=None, parser=None):
        """Set up the stage and register any exit-time action requested on the CLI.

        Args:
            name: Key under which this stage is stored in the pipeline's ``stages``.
            params: Object exposing ``render``; defaults to ``Params()``.
            outputs: Object exposing ``render``; defaults to ``Output()``.
            deps: Object exposing ``render`` and ``register_sourcecode``; defaults to
                ``Dependency(namespace=self.parse_args())``.
            parser: Object exposing ``parse_known_args``; defaults to ``make_parser()``.
        """
        self.name = name

        # Defaults are chosen with explicit ``is None`` tests rather than ``x or default``
        # so that a caller-supplied object is kept even when it evaluates as falsy
        # (for example a container-like object that is currently empty).
        self.parser = make_parser() if parser is None else parser

        self.outputs = Output() if outputs is None else outputs
        self.params = Params() if params is None else params
        # The default Dependency needs the parsed namespace, so the command line is
        # only parsed here when no ``deps`` object was supplied.
        self.deps = Dependency(namespace=self.parse_args()) if deps is None else deps

        # Section name in the rendered stage -> callable producing that section.
        self.rendering_funcs = {
            'params': self.params.render,
            'deps': self.deps.render,
            'outs': self.outputs.render
            }
        self.atexit_actions()

    def atexit_actions(self):
        """Register an interpreter-exit hook according to the parsed CLI flags.

        ``show_render`` is registered when ``dvc_dryrun`` is set; otherwise ``write``
        is registered when ``dvc_record`` is set. Nothing is registered when there
        is no parsed namespace or neither flag is set.
        """
        ns = self.parse_args()
        if ns is None:
            return None
        if ns.dvc_dryrun:
            atexit.register(self.show_render)
        elif ns.dvc_record:
            atexit.register(self.write)

    def parse_args(self, *args, **kwargs):
        """Return the known-arguments namespace from ``self.parser``.

        Arguments are forwarded to ``parse_known_args``; unrecognised arguments are
        discarded. Returns ``None`` when ``self.parser`` is ``None``.
        """
        if self.parser is None:
            return None
        known, _unknown = self.parser.parse_known_args(*args, **kwargs)
        return known

    def render_cmd(self, cli_args=None):
        """Build the stage's command string.

        Args:
            cli_args: Sequence of argument strings. When omitted (or empty) a copy of
                ``sys.argv`` is used.

        Returns:
            ``'python'`` followed by the arguments, space-joined verbatim (no shell
            quoting is applied), with every argument that appears in
            ``DO_NOT_INCLUDE_IN_PIPELINE`` dropped.
        """
        cli_args = cli_args or sys.argv[:]
        cli_args = ['python'] + cli_args
        return ' '.join(arg for arg in cli_args if arg not in DO_NOT_INCLUDE_IN_PIPELINE)

    def render(self, as_yaml=False):
        """Render the stage definition.

        Calls ``self.deps.register_sourcecode()`` first, then every function in
        ``self.rendering_funcs`` with ``as_yaml=False``; sections whose rendering is
        falsy are left out. The command string is stored under ``'cmd'``.

        Args:
            as_yaml: Forwarded to ``maybe_yaml`` together with the assembled mapping.

        Returns:
            Whatever ``maybe_yaml(dvc_config, as_yaml=as_yaml)`` returns.
        """
        dvc_config = {}

        self.deps.register_sourcecode()
        #self.deps.register_param(self.params)

        for key, render_func in self.rendering_funcs.items():
            section = render_func(as_yaml=False)
            if section:
                dvc_config[key] = section

        dvc_config['cmd'] = self.render_cmd()
        return maybe_yaml(dvc_config, as_yaml=as_yaml)

    def show_render(self):
        """Print the stage rendered with ``as_yaml=True``."""
        print(self.render(as_yaml=True))

    def write(self, pipefile=None):
        """Insert (or replace) this stage in a pipeline file and write the file back.

        Args:
            pipefile: Path of the pipeline file; defaults to ``PIPELINE_FILE_DEFAULT``.

        Returns:
            The full pipeline mapping that was passed to ``write_yaml``.

        A missing file, an empty file (which ``yaml.safe_load`` reports as ``None``),
        a document without a ``stages`` key, and a ``stages`` key holding ``None`` are
        all treated as "no stages yet" instead of raising. Other top-level keys and
        other stages already present in the file are preserved.
        """
        pipefile = pipefile or PIPELINE_FILE_DEFAULT
        try:
            with open(pipefile, 'r') as f:
                pipeline = yaml.safe_load(f)
        except FileNotFoundError:
            pipeline = None

        # Normalise the "nothing there yet" shapes before indexing into the mapping.
        if pipeline is None:
            pipeline = {}
        if pipeline.get('stages') is None:
            pipeline['stages'] = {}

        pipeline['stages'][self.name] = self.render(as_yaml=False)
        write_yaml(pipeline, fname=pipefile)
        return pipeline

In [None]:
import os
from tempfile import TemporaryDirectory
from dvcrecord.utils import write_yaml

def test_stage():
    """Smoke-test a PipelineStage end to end inside a throwaway directory.

    Writes two parameter files and checks values loaded through ``stage.params``,
    registers one input and one output file, reads both, then writes the stage
    to a ``dvc.yaml`` in the same temporary directory.
    """
    stage = PipelineStage(name='unittest')
    with TemporaryDirectory() as workdir:

        # parameter files
        first_params = {'myval': 1, 'stagename': {'otherval': 2}}
        first_params_path = write_yaml(first_params, folder=workdir, fname='params.yaml')
        second_params = {'epochs': 1000}
        second_params_path = write_yaml(second_params, folder=workdir, fname='moreparams.yaml')

        TESTCASE.assertEqual(stage.params.load(first_params_path + ":myval"), 1)
        TESTCASE.assertEqual(stage.params.load(second_params_path + ":epochs"), 1000)

        # dependency: create the file, register it, read via the registered path
        input_written = write_yaml({"input": "data"}, folder=workdir, fname='input.data')
        input_registered = stage.deps.register(input_written)
        with open(input_registered) as source:
            source.read()

        # output: create the file, register it, read via the registered path
        output_written = write_yaml({"output": "data"}, folder=workdir, fname='output.data')
        output_registered = stage.outputs.register(output_written)
        with open(output_registered) as sink:
            sink.read()

        # persist the stage next to the other temporary files
        stage.write(os.path.join(workdir, 'dvc.yaml'))
        
# Run the test as soon as this cell executes, so a failure surfaces in the notebook.
test_stage()

In [None]:
# Hand this notebook's path (built from _FNAME in the setup cell) to nbdev's exporter.
notebook2script(_nbpath)

Converted stage.ipynb.
