Skip to content
Permalink
Branch: master
Find file Copy path
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
125 lines (110 sloc) 4.23 KB
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from datetime import datetime
import json
import logging
import numpy as np
import os
from six.moves.urllib.parse import urlparse
import time
try:
from smart_open import smart_open
except ImportError:
smart_open = None
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.offline.io_context import IOContext
from ray.rllib.offline.output_writer import OutputWriter
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.compression import pack, compression_supported
logger = logging.getLogger(__name__)
@PublicAPI
class JsonWriter(OutputWriter):
    """Writer object that saves experiences in JSON file chunks.

    Every call to ``write`` appends one JSON-encoded batch followed by a
    newline to the current output file. Once the number of characters
    written to the current file reaches ``max_file_size``, the file is
    closed and the next write goes to a newly created file.
    """

    @PublicAPI
    def __init__(self,
                 path,
                 ioctx=None,
                 max_file_size=64 * 1024 * 1024,
                 compress_columns=frozenset(["obs", "new_obs"])):
        """Initialize a JsonWriter.

        Arguments:
            path (str): a path/URI of the output directory to save files in.
            ioctx (IOContext): current IO context object.
            max_file_size (int): max size of single files before rolling over.
            compress_columns (list): list of sample batch columns to compress.
        """
        self.ioctx = ioctx or IOContext()
        self.max_file_size = max_file_size
        self.compress_columns = compress_columns

        # urlparse() reports the drive letter of a Windows path such as
        # "C:\\out" as a one-character scheme. Only treat `path` as a URI
        # when the scheme is longer than a single letter, so that local
        # Windows paths still take the local-directory branch below.
        scheme = urlparse(path).scheme
        if len(scheme) > 1:
            self.path_is_uri = True
        else:
            path = os.path.abspath(os.path.expanduser(path))
            # Try to create local dirs if they don't exist
            try:
                os.makedirs(path)
            except OSError:
                pass  # already exists
            assert os.path.exists(path), "Failed to create {}".format(path)
            self.path_is_uri = False
        self.path = path
        self.file_index = 0  # suffix of the next output file to create
        self.bytes_written = 0  # characters written to the current file
        self.cur_file = None  # opened lazily by _get_file()

    @override(OutputWriter)
    def write(self, sample_batch):
        """Append `sample_batch` as one JSON line to the current output file.

        Arguments:
            sample_batch: the batch to serialize; it is passed to
                `_to_json` together with `self.compress_columns`.
        """
        start = time.time()
        data = _to_json(sample_batch, self.compress_columns)
        f = self._get_file()
        f.write(data)
        f.write("\n")
        if hasattr(f, "flush"):  # legacy smart_open impls
            f.flush()
        # Count the record separator as well, so the rollover check in
        # _get_file() sees everything that was written to the file.
        self.bytes_written += len(data) + 1
        logger.debug("Wrote {} bytes to {} in {}s".format(
            len(data), f,
            time.time() - start))

    def _get_file(self):
        """Return the file object to write to, rolling over when needed.

        A new file is opened when no file is open yet or when
        `self.bytes_written` has reached `self.max_file_size`; the previous
        file, if any, is closed first.

        Raises:
            ValueError: if the output path is a URI and `smart_open` could
                not be imported.
        """
        if not self.cur_file or self.bytes_written >= self.max_file_size:
            if self.cur_file:
                self.cur_file.close()
            timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
            path = os.path.join(
                self.path, "output-{}_worker-{}_{}.json".format(
                    timestr, self.ioctx.worker_index, self.file_index))
            if self.path_is_uri:
                if smart_open is None:
                    raise ValueError(
                        "You must install the `smart_open` module to write "
                        "to URIs like {}".format(path))
                self.cur_file = smart_open(path, "w")
            else:
                self.cur_file = open(path, "w")
            self.file_index += 1
            self.bytes_written = 0
            logger.info("Writing to new output file {}".format(self.cur_file))
        return self.cur_file
def _to_jsonable(v, compress):
if compress and compression_supported():
return str(pack(v))
elif isinstance(v, np.ndarray):
return v.tolist()
return v
def _to_json(batch, compress_columns):
    """Serialize a batch into a JSON string.

    Arguments:
        batch: a `MultiAgentBatch`, or any other object exposing a `data`
            mapping of column name to value.
        compress_columns: container of column names; a column is passed to
            `_to_jsonable` with `compress=True` when its name is in here.

    Returns:
        str: for a `MultiAgentBatch`, a JSON object with the keys "type",
        "count" and "policy_batches" (policy id -> converted columns);
        otherwise a JSON object with "type" set to "SampleBatch" plus one
        entry per column of `batch.data`.
    """

    def encode_columns(columns):
        # Convert every column, compressing only the requested ones.
        return {
            name: _to_jsonable(value, compress=name in compress_columns)
            for name, value in columns.items()
        }

    if isinstance(batch, MultiAgentBatch):
        payload = {
            "type": "MultiAgentBatch",
            "count": batch.count,
            "policy_batches": {
                policy_id: encode_columns(sub_batch.data)
                for policy_id, sub_batch in batch.policy_batches.items()
            },
        }
    else:
        payload = {"type": "SampleBatch"}
        payload.update(encode_columns(batch.data))
    return json.dumps(payload)
You can’t perform that action at this time.