Skip to content

Commit

Permalink
[CLI-881][CLI-880][CLI-451] Improve wandb sync to handle errors (#2199)
Browse files Browse the repository at this point in the history
  • Loading branch information
vanpelt committed Jun 4, 2021
1 parent c4b6ef3 commit 68bdfc0
Show file tree
Hide file tree
Showing 8 changed files with 245 additions and 62 deletions.
2 changes: 1 addition & 1 deletion .vscode/settings.json
Expand Up @@ -6,7 +6,7 @@
},
"git.ignoreLimitWarning": true,

"editor.formatOnSave": false,
"editor.formatOnSave": true,

"python.linting.enabled": true,
"python.linting.flake8Enabled": true,
Expand Down
29 changes: 24 additions & 5 deletions tests/fixtures/train.py
Expand Up @@ -2,21 +2,40 @@
import time
import random
import wandb
import numpy as np
import os
import signal
import sys

# Training options; --heavy logs many large histograms to grow the .wandb
# file quickly, --sleep_every periodically sleeps to simulate a long run.
parser = argparse.ArgumentParser()
parser.add_argument("--epochs", type=int, default=2)
parser.add_argument("--heavy", action="store_true", default=False)
parser.add_argument("--sleep_every", type=int, default=0)
args = parser.parse_args()
# BUG FIX: was `print("...: {}", format(args))` which printed the literal
# template and str(args) as two separate print arguments; call .format on
# the template instead.
print("Calling init with args: {}".format(args))
print(
    "Environ: {}".format(
        {k: v for k, v in os.environ.items() if k.startswith("WANDB")}
    )
)
wandb.init(config=args)
print("Init called with config {}".format(wandb.config))

# Debug toggles for exercising failure paths; left commented on purpose.
# raise ValueError()
# os.kill(os.getpid(), signal.SIGINT)
for i in range(0, wandb.config.epochs):
    loss = random.uniform(0, wandb.config.epochs - i)
    print("loss: %s" % loss)
    wandb.log({"loss": loss}, commit=False)
    if wandb.config.heavy:
        # Log 50 histograms per epoch to generate a heavy transaction log.
        for x in range(50):
            wandb.log(
                {
                    "hist_{}".format(x): wandb.Histogram(
                        np.random.randint(255, size=(1000))
                    )
                },
                commit=False,
            )
    wandb.log({"cool": True})
    if wandb.config.sleep_every > 0 and i % wandb.config.sleep_every == 0:
        print("sleeping")
        time.sleep(random.random() + 1)
    sys.stdout.flush()
print("Finished")
77 changes: 77 additions & 0 deletions tests/test_offline_sync.py
@@ -0,0 +1,77 @@
import os
import subprocess
import sys
import time
import glob
from .utils import fixture_open


def test_sync_in_progress(live_mock_server, test_dir):
    """Launch an offline run and `wandb sync` it while it is still writing.

    Verifies that syncing an in-progress run exits cleanly, that the run is
    not marked fully synced (no ``*.synced`` marker) until the training
    process finishes, and that each sync upserts the run on the server.
    """
    with open("train.py", "w") as f:
        f.write(fixture_open("train.py").read())
    env = dict(os.environ)
    env["WANDB_MODE"] = "offline"
    env["WANDB_DIR"] = test_dir
    env["WANDB_CONSOLE"] = "off"
    stdout = open("stdout.log", "w+")
    offline_run = subprocess.Popen(
        [
            sys.executable,
            "train.py",
            "--epochs",
            "50",
            "--sleep_every",
            "15",
            "--heavy",
        ],
        stdout=stdout,
        stderr=subprocess.STDOUT,
        bufsize=1,
        close_fds=True,
        env=env,
    )
    # Poll up to ~5 seconds for the run directory to appear.
    attempts = 0
    latest_run = os.path.join(test_dir, "wandb", "latest-run")
    while not os.path.exists(latest_run) and attempts < 50:
        time.sleep(0.1)
        # On windows we have no symlinks, so we grab the run dir
        if attempts > 0 and attempts % 10 == 0:
            if os.path.exists(os.path.join(test_dir, "wandb")):
                run_dir = os.listdir(os.path.join(test_dir, "wandb"))
                if len(run_dir) > 0:
                    latest_run = os.path.join(test_dir, "wandb", run_dir[0])
        attempts += 1
    if attempts == 50:
        # Dump everything we know so launch failures are debuggable in CI.
        print("cur dir contents: ", os.listdir(test_dir))
        print("wandb dir contents: ", os.listdir(os.path.join(test_dir, "wandb")))
        stdout.seek(0)
        print("STDOUT")
        print(stdout.read())
        debug = os.path.join("wandb", "debug.log")
        debug_int = os.path.join("wandb", "debug-internal.log")
        if os.path.exists(debug):
            print("DEBUG")
            print(open(debug).read())
        if os.path.exists(debug_int):
            print("DEBUG INTERNAL")
            # BUG FIX: previously re-printed debug.log here instead of
            # debug-internal.log.
            print(open(debug_int).read())
        assert False, "train.py failed to launch :("
    else:
        print(
            "Starting live syncing after {} seconds from: {}".format(
                attempts * 0.1, latest_run
            )
        )
    for i in range(3):
        # Generally, the first sync will fail because the .wandb file is empty
        sync = subprocess.Popen(["wandb", "sync", latest_run], env=os.environ)
        assert sync.wait() == 0
        # Only confirm we don't have a .synced file if our offline run is still running
        if offline_run.poll() is None:
            assert len(glob.glob(os.path.join(latest_run, "*.synced"))) == 0
    assert offline_run.wait() == 0
    sync = subprocess.Popen(["wandb", "sync", latest_run], env=os.environ)
    assert sync.wait() == 0
    assert len(glob.glob(os.path.join(latest_run, "*.synced"))) == 1
    print("Number of upserts: ", live_mock_server.get_ctx()["upsert_bucket_count"])
    assert live_mock_server.get_ctx()["upsert_bucket_count"] >= 3
1 change: 1 addition & 0 deletions tests/wandb_artifacts_test.py
Expand Up @@ -515,6 +515,7 @@ def test_add_table_from_dataframe(live_mock_server, test_settings):
run.finish()


@pytest.mark.timeout(120)
def test_artifact_log_with_network_error(live_mock_server, test_settings):
run = wandb.init(settings=test_settings)
artifact = wandb.Artifact("table-example", "dataset")
Expand Down
3 changes: 2 additions & 1 deletion tests/wandb_integration_test.py
Expand Up @@ -65,7 +65,8 @@ def test_resume_allow_success(live_mock_server, test_settings):


@pytest.mark.skipif(
platform.system() == "Windows", reason="File syncing is somewhat busted in windows"
platform.system() == "Windows" or sys.version_info < (3, 6),
reason="File syncing is somewhat busted in windows and python 2",
)
# TODO: Sometimes wandb-summary.json didn't exists, other times requirements.txt in windows
def test_parallel_runs(request, live_mock_server, test_settings, test_name):
Expand Down
60 changes: 45 additions & 15 deletions wandb/sdk/internal/datastore.py
Expand Up @@ -65,12 +65,15 @@ def __init__(self):
self._opened_for_scan = False
self._fp = None
self._index = 0
self._size_bytes = 0

self._crc = [0] * (LEVELDBLOG_LAST + 1)
for x in range(1, LEVELDBLOG_LAST + 1):
self._crc[x] = zlib.crc32(strtobytes(chr(x))) & 0xFFFFFFFF

assert wandb._assert_is_internal_process
assert (
wandb._assert_is_internal_process
), "DataStore can only be used in the internal process"

def open_for_write(self, fname):
self._fname = fname
Expand All @@ -95,24 +98,36 @@ def open_for_scan(self, fname):
logger.info("open for scan: %s", fname)
self._fp = open(fname, "rb")
self._index = 0
self._size_bytes = os.stat(fname).st_size
self._opened_for_scan = True
self._read_header()

def in_last_block(self):
    """Report whether the current scan position lies in the file's final block.

    While reading we need this to decide how to handle records that may
    still be in the middle of being written.
    """
    bytes_remaining = self._size_bytes - self._index
    return bytes_remaining < LEVELDBLOG_DATA_LEN

def scan_record(self):
    """Read the next raw record from the file opened for scanning.

    Returns a ``(dtype, data)`` tuple, or ``None`` at end of file.
    Raises AssertionError when the header is truncated or the checksum
    does not match (corrupt or partially written data).
    """
    # NOTE: removed stale pre-change duplicate asserts left over from the
    # flattened diff; each check appears exactly once with its message.
    assert self._opened_for_scan, "file not open for scanning"
    # TODO(jhr): handle some assertions as file corruption issues
    # assume we have enough room to read header, checked by caller?
    header = self._fp.read(LEVELDBLOG_HEADER_LEN)
    if len(header) == 0:
        return None
    assert (
        len(header) == LEVELDBLOG_HEADER_LEN
    ), "record header is {} bytes instead of the expected {}".format(
        len(header), LEVELDBLOG_HEADER_LEN
    )
    fields = struct.unpack("<IHB", header)
    checksum, dlength, dtype = fields
    # check len, better fit in the block
    self._index += LEVELDBLOG_HEADER_LEN
    data = self._fp.read(dlength)
    checksum_computed = zlib.crc32(data, self._crc[dtype]) & 0xFFFFFFFF
    assert (
        checksum == checksum_computed
    ), "record checksum is invalid, data may be corrupt"
    self._index += dlength
    return dtype, data

Expand All @@ -125,7 +140,7 @@ def scan_data(self):
pad_check = strtobytes("\x00" * space_left)
pad = self._fp.read(space_left)
# verify they are zero
assert pad == pad_check
assert pad == pad_check, "invald padding"
self._index += space_left

record = self.scan_record()
Expand All @@ -135,7 +150,9 @@ def scan_data(self):
if dtype == LEVELDBLOG_FULL:
return data

assert dtype == LEVELDBLOG_FIRST
assert (
dtype == LEVELDBLOG_FIRST
), "expected record to be type {} but found {}".format(LEVELDBLOG_FIRST, dtype)
while True:
offset = self._index % LEVELDBLOG_BLOCK_LEN
record = self.scan_record()
Expand All @@ -145,7 +162,11 @@ def scan_data(self):
if dtype == LEVELDBLOG_LAST:
data += new_data
break
assert dtype == LEVELDBLOG_MIDDLE
assert (
dtype == LEVELDBLOG_MIDDLE
), "expected record to be type {} but found {}".format(
LEVELDBLOG_MIDDLE, dtype
)
data += new_data
return data

Expand All @@ -156,21 +177,28 @@ def _write_header(self):
LEVELDBLOG_HEADER_MAGIC,
LEVELDBLOG_HEADER_VERSION,
)
assert len(data) == 7
assert (
len(data) == LEVELDBLOG_HEADER_LEN
), "header size is {} bytes, expected {}".format(
len(data), LEVELDBLOG_HEADER_LEN
)
self._fp.write(data)
self._index += len(data)

def _read_header(self):
    """Read and validate the leveldb-log file header.

    Raises ``Exception`` when the ident, magic number, or version byte
    does not match the expected values.
    """
    # NOTE: dropped the stale pre-change lines (`header_length = 7` and the
    # trailing redundant length assert) left over from the flattened diff;
    # the length is checked once against LEVELDBLOG_HEADER_LEN with a
    # descriptive message.
    header = self._fp.read(LEVELDBLOG_HEADER_LEN)
    assert (
        len(header) == LEVELDBLOG_HEADER_LEN
    ), "header is {} bytes instead of the expected {}".format(
        len(header), LEVELDBLOG_HEADER_LEN
    )
    ident, magic, version = struct.unpack("<4sHB", header)
    if ident != strtobytes(LEVELDBLOG_HEADER_IDENT):
        raise Exception("Invalid header")
    if magic != LEVELDBLOG_HEADER_MAGIC:
        raise Exception("Invalid header")
    if version != LEVELDBLOG_HEADER_VERSION:
        raise Exception("Invalid header")
    self._index += len(header)

def _write_record(self, s, dtype=None):
Expand All @@ -179,7 +207,7 @@ def _write_record(self, s, dtype=None):
# (this is a precondition to calling this method)
assert len(s) + LEVELDBLOG_HEADER_LEN <= (
LEVELDBLOG_BLOCK_LEN - self._index % LEVELDBLOG_BLOCK_LEN
)
), "not enough space to write new records"

dlength = len(s)
dtype = dtype or LEVELDBLOG_FULL
Expand Down Expand Up @@ -220,7 +248,7 @@ def _write_data(self, s):
self._write_record(s[:data_room], LEVELDBLOG_FIRST)
data_used += data_room
data_left -= data_room
assert data_left
assert data_left, "data_left should be non-zero"

# write middles (if any)
while data_left > LEVELDBLOG_DATA_LEN:
Expand All @@ -231,8 +259,10 @@ def _write_data(self, s):
data_used += LEVELDBLOG_DATA_LEN
data_left -= LEVELDBLOG_DATA_LEN

# write last
# write last and flush the entire block to disk
self._write_record(s[data_used:], LEVELDBLOG_LAST)
self._fp.flush()
os.fsync(self._fp.fileno())

return file_offset, self._index - file_offset, flush_index, flush_offset

Expand All @@ -249,7 +279,7 @@ def write(self, obj):
"""
raw_size = obj.ByteSize()
s = obj.SerializeToString()
assert len(s) == raw_size
assert len(s) == raw_size, "invalid serialization"
ret = self._write_data(s)
return ret

Expand Down

0 comments on commit 68bdfc0

Please sign in to comment.