Skip to content

Commit

Permalink
[RLlib] Discussion 4351: Conv2d default filter tests and add default …
Browse files Browse the repository at this point in the history
…setting for 96x96 image obs space. (#21560)
  • Loading branch information
sven1977 committed Jan 13, 2022
1 parent a3442df commit 3ac4dab
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 1 deletion.
7 changes: 7 additions & 0 deletions rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1208,6 +1208,13 @@ py_test(
srcs = ["models/tests/test_attention_nets.py"]
)

py_test(
name = "test_conv2d_default_stacks",
tags = ["team:ml", "models"],
size = "medium",
srcs = ["models/tests/test_conv2d_default_stacks.py"]
)

py_test(
name = "test_convtranspose2d_stack",
tags = ["team:ml", "models"],
Expand Down
56 changes: 56 additions & 0 deletions rllib/models/tests/test_conv2d_default_stacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import gym
import unittest

from ray.rllib.models.catalog import ModelCatalog, MODEL_DEFAULTS
from ray.rllib.models.tf.visionnet import VisionNetwork
from ray.rllib.models.torch.visionnet import VisionNetwork as TorchVision
from ray.rllib.utils.framework import try_import_torch, try_import_tf
from ray.rllib.utils.test_utils import framework_iterator

torch, nn = try_import_torch()
tf1, tf, tfv = try_import_tf()


class TestConv2DDefaultStacks(unittest.TestCase):
"""Tests our ConvTranspose2D Stack modules/layers."""

def test_conv2d_default_stacks(self):
"""Tests, whether conv2d defaults are available for img obs spaces.
"""
action_space = gym.spaces.Discrete(2)

shapes = [
(480, 640, 3),
(240, 320, 3),
(96, 96, 3),
(84, 84, 3),
(42, 42, 3),
(10, 10, 3),
]
for shape in shapes:
print(f"shape={shape}")
obs_space = gym.spaces.Box(-1.0, 1.0, shape=shape)
for fw in framework_iterator():
model = ModelCatalog.get_model_v2(
obs_space,
action_space,
2,
MODEL_DEFAULTS.copy(),
framework=fw)
self.assertTrue(
isinstance(model, (VisionNetwork, TorchVision)))
if fw == "torch":
output, _ = model({
"obs": torch.from_numpy(obs_space.sample()[None])
})
else:
output, _ = model({"obs": obs_space.sample()[None]})
# B x [action logits]
self.assertTrue(output.shape == (1, 2))
print("ok")


if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))
12 changes: 11 additions & 1 deletion rllib/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ def get_filter_config(shape):
List[list]: The Conv2D filter configuration usable as `conv_filters`
inside a model config dict.
"""
shape = list(shape)
# VizdoomGym (large 480x640).
filters_480x640 = [
[16, [24, 32], [14, 18]],
Expand All @@ -86,6 +85,12 @@ def get_filter_config(shape):
[32, [6, 6], 4],
[256, [9, 9], 1],
]
# 96x96x3 (e.g. CarRacing-v0).
filters_96x96 = [
[16, [8, 8], 4],
[32, [4, 4], 2],
[256, [11, 11], 2],
]
# Atari.
filters_84x84 = [
[16, [8, 8], 4],
Expand All @@ -103,12 +108,17 @@ def get_filter_config(shape):
[16, [5, 5], 2],
[32, [5, 5], 2],
]

shape = list(shape)
if len(shape) in [2, 3] and (shape[:2] == [480, 640]
or shape[1:] == [480, 640]):
return filters_480x640
elif len(shape) in [2, 3] and (shape[:2] == [240, 320]
or shape[1:] == [240, 320]):
return filters_240x320
elif len(shape) in [2, 3] and (shape[:2] == [96, 96]
or shape[1:] == [96, 96]):
return filters_96x96
elif len(shape) in [2, 3] and (shape[:2] == [84, 84]
or shape[1:] == [84, 84]):
return filters_84x84
Expand Down

0 comments on commit 3ac4dab

Please sign in to comment.