[rllib] Allow for JAX framework #8732

Closed
2 tasks done
KristianHolsheimer opened this issue Jun 2, 2020 · 10 comments · Fixed by #8748
Labels
bug Something that is supposed to be working; but isn't P1 Issue that should be fixed within a few weeks

Comments

@KristianHolsheimer
Contributor

What is the problem?

I've been using JAX as my framework for a little while now. I just upgraded to the nightly build (due to some unrelated issues) and now RLlib is telling me I need to install TensorFlow or Torch.

I tried setting {'framework': 'jax', ...} in my trainer config, but this results in another error: RLlib doesn't recognize any framework other than one of [tf|tfe|torch|auto].

Ray version: ray-0.9.0.dev0, Python 3.8, Ubuntu Linux 20.04 LTS

Script to reproduce:

import ray
from ray.rllib.policy.policy import Policy as BasePolicy
from ray.rllib.agents.trainer_template import build_trainer


class Policy(BasePolicy):
    def compute_actions(self, obs_batch, **kwargs):
        actions = [self.action_space.sample() for _ in obs_batch]
        return actions, [], {}

    def get_weights(self):
        pass
    
    def set_weights(self, weights):
        pass

    def learn_on_batch(self, sample_batch):
        pass


trainer = build_trainer(
    name='foo',
    default_policy=Policy,
)

ray.init()
ray.tune.run(
    trainer,
    config={
        'framework': 'jax',
        'env': 'FrozenLake-v0',
    },
    stop={'training_iteration': 1}
)

If we cannot run your script, we cannot fix your issue.

  • I have verified my script runs in a clean environment and reproduces the issue.
  • I have verified the issue also occurs with the latest wheels.
@KristianHolsheimer KristianHolsheimer added bug Something that is supposed to be working; but isn't triage Needs triage (eg: priority, bug/not-bug, and owning component) labels Jun 2, 2020
@ericl ericl added P1 Issue that should be fixed within a few weeks and removed triage Needs triage (eg: priority, bug/not-bug, and owning component) labels Jun 2, 2020
@ericl
Collaborator

ericl commented Jun 2, 2020

@sven1977 should we support framework=None or something like this?

@KristianHolsheimer
Contributor Author

Yes, that would probably be more sustainable.

Am I right in thinking that as long as my Policy-derived class implements all required methods it doesn't depend on any specific framework?

And if that's the case, why pick a default framework at all?

@KristianHolsheimer
Contributor Author

Would RLlib allow for mixing of frameworks?

What would go wrong if I use TensorFlow to collect experience, PyTorch to learn and JAX/NumPy to run trials?

Of course, we can think of many reasons not to use such a setup, but to what extent does RLlib's functionality depend on a choice of framework?

@ericl
Collaborator

ericl commented Jun 3, 2020

Yeah, I think the framework parameter might have been a bit heavy-handed. We do allow mixing of policies with different frameworks.

I made a PR to remove the framework import checking, which should allow any kind of policy to be used no matter what the setting is.

@sven1977
Contributor

sven1977 commented Jun 3, 2020

Yeah, let's allow framework=None as well, in which case, RLlib shouldn't check anything.
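
To make the proposal concrete, here is a minimal sketch of a check that lets None through untouched and rejects unknown strings. The helper name validate_framework and the constant SUPPORTED_FRAMEWORKS are hypothetical and are not taken from RLlib's source; the allowed values are the ones named in this thread.

```python
from typing import Optional

# Values named in this thread as supported; None means "do not check anything".
SUPPORTED_FRAMEWORKS = ("tf", "tfe", "torch", None)


def validate_framework(framework: Optional[str]) -> Optional[str]:
    """Hypothetical helper: pass known values (including None) through unchanged."""
    if framework not in SUPPORTED_FRAMEWORKS:
        raise ValueError(
            f"Unknown framework {framework!r}; expected one of {SUPPORTED_FRAMEWORKS}"
        )
    return framework
```

With a check like this, framework=None is accepted, while framework='jax' still raises until it is added to the allowed values.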

@KristianHolsheimer
Contributor Author

KristianHolsheimer commented Jun 3, 2020

@sven1977 This might be a silly question, but why would we want to check the framework at all?

I feel that the latest changes @ericl made in #8748 are a better setup, i.e. drop the framework checks altogether. The framework config setting then becomes just a hint that lets you write conditional logic when a specific value is set.

For instance, I might want to implement some logic if config['framework'] == 'jax', which shouldn't cause any Exceptions elsewhere in the codebase.
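
A minimal sketch of the "framework as a hint" idea, assuming the config is a plain dict. The function name describe_backend and the returned strings are hypothetical and only illustrate the branching; this is not how RLlib handles the setting.

```python
def describe_backend(config: dict) -> str:
    """Hypothetical helper: branch on the framework hint, never raise on unknown values."""
    framework = config.get("framework")
    if framework == "jax":
        # User-defined logic would go here.
        return "using JAX-specific code path"
    if framework in ("tf", "tfe", "torch"):
        return f"using built-in {framework} code path"
    # Unknown or missing values simply fall through without an exception.
    return "no framework-specific logic"
```

The point of the design is that an unrecognized value is ignored by code that does not care about it, instead of being rejected up front.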

@rkooo567 rkooo567 added the rllib label Jun 3, 2020
@sven1977
Contributor

sven1977 commented Jun 5, 2020

Ok, this works now. Just explicitly use None as your framework in the config.
@KristianHolsheimer your point is valid, but we do want to apply type checking here as we internally really only support tf|torch|tfe|None. If you want, just create a new key in your Trainer's config: e.g. jax=True for now and check the value of that. We will look into adding JAX very soon and probably have some rudimentary support for this in the near future (generic default Policy/Models).

config={
    'framework': None,
    'env': 'FrozenLake-v0',
},
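
As a rough illustration of the custom-key workaround, the sketch below reads a jax flag from a plain dict config. RandomPolicySketch is a hypothetical, framework-free stand-in that does not import RLlib. Whether RLlib accepts an extra config key without it being declared in the trainer's default config is not confirmed in this thread, so treat that part as an assumption.

```python
import random


class RandomPolicySketch:
    """Hypothetical stand-in for a Policy subclass; it does not depend on RLlib."""

    def __init__(self, num_actions: int, config: dict):
        self.num_actions = num_actions
        # Custom flag suggested above; treated as False when the key is absent.
        self.use_jax = bool(config.get("jax", False))

    def compute_actions(self, obs_batch):
        # One random action per observation, mirroring the reproduction script.
        actions = [random.randrange(self.num_actions) for _ in obs_batch]
        return actions, [], {}


config = {"framework": None, "env": "FrozenLake-v0", "jax": True}
policy = RandomPolicySketch(num_actions=4, config=config)
```

Any JAX-specific behaviour can then be guarded by policy.use_jax while 'framework' stays None.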

Closing this issue.

@sven1977 sven1977 closed this as completed Jun 5, 2020
@KristianHolsheimer
Contributor Author

Good to hear that JAX is on your road map. Let me know if I can help.

I shared a couple of testing scripts in #8776

@dynamicwebpaige
Contributor

dynamicwebpaige commented Mar 17, 2022

@sven1977 Is there a good issue to keep an eye on, RE: JAX support?

Poking around in the ray-project/ray source code, it seems like there have been at least a few recent JAX-related additions, even if the documentation is a bit sparse. 🙂

@JiahaoYao
Contributor

Hi @sven and @dynamicwebpaige, #8776 is fixed, hopefully, by adding ray.get() before the remote task.

6 participants