Skip to content
This repository has been archived by the owner on Sep 24, 2020. It is now read-only.

Commit

Permalink
Add output.log
Browse files Browse the repository at this point in the history
  • Loading branch information
raubitsj committed Aug 25, 2020
1 parent 4ee7a43 commit e4048d7
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 13 deletions.
10 changes: 7 additions & 3 deletions wandb/lib/redirect.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __getattr__(self, attr):
return getattr(self.stream, attr)


def _pipe_relay(stopped, fd, name, cb, tee):
def _pipe_relay(stopped, fd, name, cb, tee, output_writer):
while True:
try:
data = os.read(fd, 4096)
Expand All @@ -46,6 +46,8 @@ def _pipe_relay(stopped, fd, name, cb, tee):
break
if tee:
os.write(tee, data)
if output_writer:
output_writer.write(data)
if cb:
try:
cb(name, data)
Expand Down Expand Up @@ -107,10 +109,11 @@ def uninstall(self):


class Capture(object):
def __init__(self, name, cb):
def __init__(self, name, cb, output_writer):
self._started = False
self._name = name
self._cb = cb
self._output_writer = output_writer
self._stopped = None
self._thread = None
self._tee = None
Expand Down Expand Up @@ -138,7 +141,8 @@ def _start(self):
read_thread = threading.Thread(
name=self._name,
target=_pipe_relay,
args=(self._stopped, self._pipe_rd, self._name, self._cb, self._tee),
args=(self._stopped, self._pipe_rd, self._name, self._cb, self._tee,
self._output_writer),
)
read_thread.daemon = True
read_thread.start()
Expand Down
36 changes: 31 additions & 5 deletions wandb/sdk/wandb_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@
from wandb.data_types import _datatypes_set_callback
from wandb.errors import Error
from wandb.interface.summary_record import SummaryRecord
from wandb.lib import module, proto_util, redirect, sparkline
from wandb.lib.filenames import JOBSPEC_FNAME
from wandb.lib import filenames, module, proto_util, redirect, sparkline
from wandb.util import sentry_set_scope, to_forward_slash_path
from wandb.viz import Visualize

Expand Down Expand Up @@ -173,6 +172,8 @@ def __init__(self, config=None, settings=None):
self._final_summary = None
self._sampled_history = None

self._output_writer = None

# Pull info from settings
self._init_from_settings(settings)

Expand Down Expand Up @@ -837,11 +838,15 @@ def _redirect(self, stdout_slave_fd, stderr_slave_fd):

if console == "redirect":
logger.info("redirect1")
out_cap = redirect.Capture(name="stdout", cb=self._redirect_cb)
out_cap = redirect.Capture(
name="stdout", cb=self._redirect_cb, output_writer=self._output_writer
)
out_redir = redirect.Redirect(
src="stdout", dest=out_cap, unbuffered=True, tee=True
)
err_cap = redirect.Capture(name="stderr", cb=self._redirect_cb)
err_cap = redirect.Capture(
name="stderr", cb=self._redirect_cb, output_writer=self._output_writer
)
err_redir = redirect.Redirect(
src="stderr", dest=err_cap, unbuffered=True, tee=True
)
Expand Down Expand Up @@ -938,10 +943,14 @@ def _console_start(self):
# setup fake callback
self._redirect_cb = self._console_callback

output_log_path = os.path.join(self.dir, filenames.OUTPUT_FNAME)
self._output_writer = WriteSerializingFile(open(output_log_path, "wb"))
self._redirect(self._stdout_slave_fd, self._stderr_slave_fd)

def _console_stop(self):
self._restore()
self._output_writer.f.close()
self._output_writer = None

def _on_start(self):
wandb.termlog("Tracking run with wandb version {}".format(wandb.__version__))
Expand Down Expand Up @@ -1160,7 +1169,7 @@ def _save_job_spec(self):
}

s = json.dumps(job_spec, indent=4)
spec_filename = JOBSPEC_FNAME
spec_filename = filenames.JOBSPEC_FNAME
with open(spec_filename, "w") as f:
print(s, file=f)
self.save(spec_filename)
Expand Down Expand Up @@ -1298,3 +1307,20 @@ def huggingface_version():
if hasattr(trans, "__version__"):
return trans.__version__
return None


class WriteSerializingFile(object):
    """Wrapper for a file object that serializes writes.

    Every ``write`` is performed and flushed while holding a lock, so
    several threads can share one underlying file without their writes
    interleaving.

    Attributes are kept public for existing callers:
        lock: the ``threading.Lock`` guarding the file.
        f: the wrapped file object.
    """

    def __init__(self, f):
        """Wrap ``f``, a file-like object providing ``write`` and ``flush``."""
        self.lock = threading.Lock()
        self.f = f

    def write(self, *args, **kargs):
        """Write to the wrapped file and flush it, under the lock.

        Arguments are forwarded unchanged to ``f.write``; its return
        value is passed back to the caller.
        """
        with self.lock:
            result = self.f.write(*args, **kargs)
            self.f.flush()
        return result

    def close(self):
        """Close the wrapped file while holding the lock."""
        with self.lock:
            self.f.close()
36 changes: 31 additions & 5 deletions wandb/sdk_py27/wandb_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@
from wandb.data_types import _datatypes_set_callback
from wandb.errors import Error
from wandb.interface.summary_record import SummaryRecord
from wandb.lib import module, proto_util, redirect, sparkline
from wandb.lib.filenames import JOBSPEC_FNAME
from wandb.lib import filenames, module, proto_util, redirect, sparkline
from wandb.util import sentry_set_scope, to_forward_slash_path
from wandb.viz import Visualize

Expand Down Expand Up @@ -173,6 +172,8 @@ def __init__(self, config=None, settings=None):
self._final_summary = None
self._sampled_history = None

self._output_writer = None

# Pull info from settings
self._init_from_settings(settings)

Expand Down Expand Up @@ -837,11 +838,15 @@ def _redirect(self, stdout_slave_fd, stderr_slave_fd):

if console == "redirect":
logger.info("redirect1")
out_cap = redirect.Capture(name="stdout", cb=self._redirect_cb)
out_cap = redirect.Capture(
name="stdout", cb=self._redirect_cb, output_writer=self._output_writer
)
out_redir = redirect.Redirect(
src="stdout", dest=out_cap, unbuffered=True, tee=True
)
err_cap = redirect.Capture(name="stderr", cb=self._redirect_cb)
err_cap = redirect.Capture(
name="stderr", cb=self._redirect_cb, output_writer=self._output_writer
)
err_redir = redirect.Redirect(
src="stderr", dest=err_cap, unbuffered=True, tee=True
)
Expand Down Expand Up @@ -938,10 +943,14 @@ def _console_start(self):
# setup fake callback
self._redirect_cb = self._console_callback

output_log_path = os.path.join(self.dir, filenames.OUTPUT_FNAME)
self._output_writer = WriteSerializingFile(open(output_log_path, "wb"))
self._redirect(self._stdout_slave_fd, self._stderr_slave_fd)

def _console_stop(self):
self._restore()
self._output_writer.f.close()
self._output_writer = None

def _on_start(self):
wandb.termlog("Tracking run with wandb version {}".format(wandb.__version__))
Expand Down Expand Up @@ -1160,7 +1169,7 @@ def _save_job_spec(self):
}

s = json.dumps(job_spec, indent=4)
spec_filename = JOBSPEC_FNAME
spec_filename = filenames.JOBSPEC_FNAME
with open(spec_filename, "w") as f:
print(s, file=f)
self.save(spec_filename)
Expand Down Expand Up @@ -1298,3 +1307,20 @@ def huggingface_version():
if hasattr(trans, "__version__"):
return trans.__version__
return None


class WriteSerializingFile(object):
    """Wrapper for a file object that serializes writes.

    Every ``write`` is performed and flushed while holding a lock, so
    several threads can share one underlying file without their writes
    interleaving.

    Attributes are kept public for existing callers:
        lock: the ``threading.Lock`` guarding the file.
        f: the wrapped file object.
    """

    def __init__(self, f):
        """Wrap ``f``, a file-like object providing ``write`` and ``flush``."""
        self.lock = threading.Lock()
        self.f = f

    def write(self, *args, **kargs):
        """Write to the wrapped file and flush it, under the lock.

        Arguments are forwarded unchanged to ``f.write``; its return
        value is passed back to the caller.
        """
        with self.lock:
            result = self.f.write(*args, **kargs)
            self.f.flush()
        return result

    def close(self):
        """Close the wrapped file while holding the lock."""
        with self.lock:
            self.f.close()

0 comments on commit e4048d7

Please sign in to comment.