Skip to content

Commit

Permalink
[rllib] Add test case that we don't have a hard torch dependency (#7926)
Browse files Browse the repository at this point in the history
  • Loading branch information
ericl committed Apr 8, 2020
1 parent 85481d6 commit e8c19ab
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 8 deletions.
7 changes: 7 additions & 0 deletions rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1005,6 +1005,13 @@ py_test(
srcs = ["tests/test_dependency.py"]
)

# Small-sized test target for tests/test_dependency_torch.py, tagged with
# "tests_dir" and "tests_dir_D".
py_test(
    name = "tests/test_dependency_torch",
    size = "small",
    srcs = ["tests/test_dependency_torch.py"],
    tags = [
        "tests_dir",
        "tests_dir_D",
    ],
)

py_test(
name = "tests/test_eager_support",
tags = ["tests_dir", "tests_dir_E"],
Expand Down
4 changes: 2 additions & 2 deletions rllib/agents/qmix/qmix_policy.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from gym.spaces import Tuple, Discrete, Dict
import logging
import numpy as np
from torch.optim import RMSprop
from torch.distributions import Categorical

import ray
from ray.rllib.agents.qmix.mixers import VDNMixer, QMixer
Expand Down Expand Up @@ -244,6 +242,7 @@ def __init__(self, obs_space, action_space, config):
self.loss = QMixLoss(self.model, self.target_model, self.mixer,
self.target_mixer, self.n_agents, self.n_actions,
self.config["double_q"], self.config["gamma"])
from torch.optim import RMSprop
self.optimiser = RMSprop(
params=self.params,
lr=config["lr"],
Expand Down Expand Up @@ -283,6 +282,7 @@ def compute_actions(self,
random_numbers = torch.rand_like(q_values[:, :, 0])
pick_random = (random_numbers < (self.cur_epsilon
if explore else 0.0)).long()
from torch.distributions import Categorical
random_actions = Categorical(avail).sample().long()
actions = (pick_random * random_actions +
(1 - pick_random) * masked_q_values.argmax(dim=2))
Expand Down
22 changes: 22 additions & 0 deletions rllib/tests/test_dependency_torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#!/usr/bin/env python
# Script check: building and training an A2CTrainer with "use_pytorch": False
# must not cause the torch module to be imported.

import os
import sys

if __name__ == "__main__":
    # Set the flag before importing anything from ray.rllib, so that torch is
    # deliberately not imported for the purposes of this test.
    os.environ["RLLIB_TEST_NO_TORCH_IMPORT"] = "1"

    from ray.rllib.agents.a3c import A2CTrainer

    # Importing the trainer alone must not have pulled torch in.
    assert "torch" not in sys.modules, (
        "Torch initially present, when it shouldn't.")

    # note: no ray.init(), to test it works without Ray
    trainer_config = {
        "use_pytorch": False,
        "num_workers": 0,
    }
    trainer = A2CTrainer(env="CartPole-v0", config=trainer_config)
    trainer.train()

    # torch must still be absent after one training iteration.
    assert "torch" not in sys.modules, "Torch should not be imported"
17 changes: 11 additions & 6 deletions rllib/utils/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,10 @@ def try_import_tfp(error=False):

# Fake module for torch.nn.
class NNStub:
    """Stand-in object for the ``torch.nn`` module.

    The constructor accepts any positional or keyword arguments and ignores
    them.
    """

    def __init__(self, *args, **kwargs):
        # Expose ModuleStub under the name ``Module``.
        self.Module = ModuleStub
        # Fake nn.functional module within torch.nn.
        self.functional = None


# Fake class for torch.nn.Module to allow it to be inherited from.
Expand All @@ -120,7 +123,7 @@ def try_import_torch(error=False):
"""
if "RLLIB_TEST_NO_TORCH_IMPORT" in os.environ:
logger.warning("Not importing Torch for test purposes.")
return None, None
return _torch_stubs()

try:
import torch
Expand All @@ -129,10 +132,12 @@ def try_import_torch(error=False):
except ImportError as e:
if error:
raise e
return _torch_stubs()


nn = NNStub()
nn.Module = ModuleStub
return None, nn
def _torch_stubs():
    """Build the placeholder pair returned in place of ``(torch, nn)``.

    Returns:
        tuple: ``(None, nn_stub)`` where ``nn_stub`` is a fresh ``NNStub``
            instance and ``None`` takes the place of the torch module.
    """
    return None, NNStub()


def get_variable(value,
Expand Down Expand Up @@ -165,7 +170,7 @@ def get_variable(value,
return tf.compat.v1.get_variable(
tf_name, initializer=value, dtype=dtype, trainable=trainable)
elif framework == "torch" and torch_tensor is True:
import torch
torch, _ = try_import_torch()
var_ = torch.from_numpy(value)
var_.requires_grad = trainable
return var_
Expand Down

0 comments on commit e8c19ab

Please sign in to comment.