-
Notifications
You must be signed in to change notification settings - Fork 5.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[RLlib] Discussion 4351: Conv2d default filter tests and add default …
…setting for 96x96 image obs space. (#21560)
- Loading branch information
Showing
3 changed files
with
74 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import gym | ||
import unittest | ||
|
||
from ray.rllib.models.catalog import ModelCatalog, MODEL_DEFAULTS | ||
from ray.rllib.models.tf.visionnet import VisionNetwork | ||
from ray.rllib.models.torch.visionnet import VisionNetwork as TorchVision | ||
from ray.rllib.utils.framework import try_import_torch, try_import_tf | ||
from ray.rllib.utils.test_utils import framework_iterator | ||
|
||
torch, nn = try_import_torch() | ||
tf1, tf, tfv = try_import_tf() | ||
|
||
|
||
class TestConv2DDefaultStacks(unittest.TestCase): | ||
"""Tests our ConvTranspose2D Stack modules/layers.""" | ||
|
||
def test_conv2d_default_stacks(self): | ||
"""Tests, whether conv2d defaults are available for img obs spaces. | ||
""" | ||
action_space = gym.spaces.Discrete(2) | ||
|
||
shapes = [ | ||
(480, 640, 3), | ||
(240, 320, 3), | ||
(96, 96, 3), | ||
(84, 84, 3), | ||
(42, 42, 3), | ||
(10, 10, 3), | ||
] | ||
for shape in shapes: | ||
print(f"shape={shape}") | ||
obs_space = gym.spaces.Box(-1.0, 1.0, shape=shape) | ||
for fw in framework_iterator(): | ||
model = ModelCatalog.get_model_v2( | ||
obs_space, | ||
action_space, | ||
2, | ||
MODEL_DEFAULTS.copy(), | ||
framework=fw) | ||
self.assertTrue( | ||
isinstance(model, (VisionNetwork, TorchVision))) | ||
if fw == "torch": | ||
output, _ = model({ | ||
"obs": torch.from_numpy(obs_space.sample()[None]) | ||
}) | ||
else: | ||
output, _ = model({"obs": obs_space.sample()[None]}) | ||
# B x [action logits] | ||
self.assertTrue(output.shape == (1, 2)) | ||
print("ok") | ||
|
||
|
||
if __name__ == "__main__": | ||
import pytest | ||
import sys | ||
sys.exit(pytest.main(["-v", __file__])) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters