
[Tune] How to import a customized trainable function from another file at a custom path #1639

@haoyangz

Description


System information

  • OS: Linux Ubuntu 16.04
  • Ray installed from: source
  • Ray version: 0.3.1
  • Python version: 2.7
  • Exact command to reproduce: python main.py

Describe the problem

I have several model files, each of which defines a different trainable (corresponding to a different neural network structure). In my main script main.py, I try to import the trainable from a given model file (located at mydir/model.py, for instance) by inserting its path into sys.path, and then pass the trainable to run_experiments. The import works in the main script, but Ray fails with the following error:

Traceback (most recent call last):
  File "/cluster/zeng/code/research/software/miniconda/lib/python2.7/site-packages/ray-0.3.1-py2.7-linux-x86_64.egg/ray/actor.py", line 276, in fetch_and_register_actor
    unpickled_class = pickle.loads(pickled_class)
  File "/cluster/zeng/code/research/software/miniconda/lib/python2.7/pickle.py", line 1388, in loads
    return Unpickler(file).load()
  File "/cluster/zeng/code/research/software/miniconda/lib/python2.7/pickle.py", line 864, in load
    dispatch[key](self)
  File "/cluster/zeng/code/research/software/miniconda/lib/python2.7/pickle.py", line 1096, in load_global
    klass = self.find_class(module, name)
  File "/cluster/zeng/code/research/software/miniconda/lib/python2.7/pickle.py", line 1130, in find_class
    __import__(module)
ImportError: No module named model
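
For reference, the import itself succeeds in the driver; the failure only shows up when the pickled class is reconstructed in a worker process. A minimal version of the pattern I am using (full scripts below):

import sys
from ray.tune import register_trainable

sys.path.append('mydir')              # only affects sys.path in this (driver) process
from model import MyTrainableClass    # the import itself succeeds here

# The class seems to be pickled by reference (as model.MyTrainableClass), so the
# worker that unpickles it has to be able to `import model` on its own.
register_trainable("my_class", MyTrainableClass)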

It only works when model.py is in the same directory as main.py, presumably because the worker process that unpickles the class cannot find the model module on its own sys.path.

My question

What's the right way to import a trainable function from a customized path?

To reproduce

My model file (saved to 'mydir/model.py'):

import numpy as np
import os, json
from ray.tune import Trainable, TrainingResult, register_trainable

class MyTrainableClass(Trainable):
    def _setup(self):
        self.timestep = 0

    def _train(self):
        self.timestep += 1
        v = np.tanh(float(self.timestep) / self.config["width"])
        v *= self.config["height"]

        # Here we use `episode_reward_mean`, but you can also report other
        # objectives such as loss or accuracy (see tune/result.py).
        return TrainingResult(episode_reward_mean=v, timesteps_this_iter=1)

    def _save(self, checkpoint_dir):
        path = os.path.join(checkpoint_dir, "checkpoint")
        with open(path, "w") as f:
            f.write(json.dumps({"timestep": self.timestep}))
        return path

    def _restore(self, checkpoint_path):
        with open(checkpoint_path) as f:
            self.timestep = json.loads(f.read())["timestep"]

My main.py file:

#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import json, sys
import os
import random
import numpy as np
import ray
from ray.tune import Trainable, TrainingResult, register_trainable, \
    run_experiments
from ray.tune.hyperband import HyperBandScheduler

sys.path.append('mydir')
from model import *

register_trainable("my_class", MyTrainableClass)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing")
    args, _ = parser.parse_known_args()
    ray.init()

    # Hyperband early stopping, configured with `episode_reward_mean` as the
    # objective and `timesteps_total` as the time unit.
    hyperband = HyperBandScheduler(
        time_attr="timesteps_total", reward_attr="episode_reward_mean",
        max_t=100)

    run_experiments({
        "hyperband_test": {
            "run": "my_class",
            "stop": {"training_iteration": 1 if args.smoke_test else 99999},
            "repeat": 20,
            "resources": {"cpu": 1, "gpu": 0},
            "config": {
                "width": lambda spec: 10 + int(90 * random.random()),
                "height": lambda spec: int(100 * random.random()),
            },
        }
    }, scheduler=hyperband)

To reproduce the error:

python main.py
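
One workaround I am considering (I am not sure it is the intended approach): since the worker processes seem to inherit the driver's environment, putting mydir on PYTHONPATH instead of appending it to sys.path inside main.py should make the model module importable on the workers as well:

PYTHONPATH=mydir python main.py

But I would still like to know whether there is a recommended way to do this from within the script.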
