Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] Issue 35440: JSON output writer should include INFOs. #39632

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions rllib/offline/json_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
smart_open = None

from ray.air._internal.json import SafeFallbackEncoder
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.offline.io_context import IOContext
from ray.rllib.offline.output_writer import OutputWriter
from ray.rllib.utils.annotations import override, PublicAPI
Expand All @@ -25,8 +25,8 @@
WINDOWS_DRIVES = [chr(i) for i in range(ord("c"), ord("z") + 1)]


# TODO(jungong) : use DatasetWriter to back JsonWriter, so we reduce
# codebase complexity without losing existing functionality.
# TODO(jungong): use DatasetWriter to back JsonWriter, so we reduce codebase complexity
# without losing existing functionality.
@PublicAPI
class JsonWriter(OutputWriter):
"""Writer object that saves experiences in JSON file chunks."""
Expand Down Expand Up @@ -129,10 +129,6 @@ def _to_json_dict(batch: SampleBatchType, compress_columns: List[str]) -> Dict:
policy_batches[policy_id][k] = _to_jsonable(
v, compress=k in compress_columns
)
# INFOS aren't compatible with Arrow since they are dicts with non-string
# keys.
if SampleBatch.INFOS in policy_batches[policy_id]:
del policy_batches[policy_id][SampleBatch.INFOS]
out["policy_batches"] = policy_batches
else:
out["type"] = "SampleBatch"
Expand Down
15 changes: 14 additions & 1 deletion rllib/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def write_outputs(self, output, fw, output_config=None):
)
)
algo = config.build()
algo.train()
print(algo.train())
return algo

def test_agent_output_ok(self):
Expand All @@ -80,6 +80,19 @@ def test_agent_output_logdir(self):
agent = self.write_outputs("logdir", fw)
self.assertEqual(len(glob.glob(agent.logdir + "/output-*.json")), 1)

def test_agent_output_infos(self):
"""Verify that the infos dictionary is written to the output files.

Note, with torch this is always the case."""
output_config = {"store_infos": True}
for fw in framework_iterator(frameworks=("torch", "tf")):
self.write_outputs(self.test_dir, fw, output_config=output_config)
self.assertEqual(len(os.listdir(self.test_dir + fw)), 1)
reader = JsonReader(self.test_dir + fw + "/*.json")
data = reader.next()
data = convert_ma_batch_to_sample_batch(data)
self.assertTrue("infos" in data)

def test_agent_input_dir(self):
config = (
PGConfig()
Expand Down
Loading