[CLI-881][CLI-880][CLI-451] Improve wandb sync to handle errors #2199

Merged: 19 commits, Jun 4, 2021
Changes from 14 commits
2 changes: 1 addition & 1 deletion .vscode/settings.json
@@ -6,7 +6,7 @@
},
"git.ignoreLimitWarning": true,

"editor.formatOnSave": false,
"editor.formatOnSave": true,
Contributor (author):

I was getting sick of forgetting to format my docs with black. I know there are a few files that aren't formatted but I figure the vast majority are and you can always save without formatting if needed.

Member:

But make sure you are pinned to the right version of black, otherwise you will get problems. But I'm fine with it if other people don't mind.
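(Illustrative only: a tiny guard along these lines could fail fast when the installed black doesn't match the agreed pin. The script name and version string below are assumptions, not the repo's actual requirement.)

# check_black_version.py -- hypothetical helper, not part of this PR
import importlib.metadata

PINNED_BLACK = "21.5b1"  # assumed pin; use whatever the repo actually requires

installed = importlib.metadata.version("black")
if installed != PINNED_BLACK:
    raise SystemExit(
        "black {} is installed but {} is expected; formatting output may differ".format(
            installed, PINNED_BLACK
        )
    )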


"python.linting.enabled": true,
"python.linting.flake8Enabled": true,
29 changes: 24 additions & 5 deletions tests/fixtures/train.py
@@ -2,21 +2,40 @@
import time
import random
import wandb
import numpy as np
import os
import signal
import sys

parser = argparse.ArgumentParser()
parser.add_argument('--epochs', type=int, default=2)
parser.add_argument("--epochs", type=int, default=2)
parser.add_argument("--heavy", action="store_true", default=False)
parser.add_argument("--sleep_every", type=int, default=0)
args = parser.parse_args()
print("Calling init")
print("Calling init with args: {}", format(args))
print("Environ: {}".format({k: v for k, v in os.environ.items() if k.startswith("WANDB")}))
wandb.init(config=args)
print("Init called")
print("Init called with config {}".format(wandb.config))

#raise ValueError()
#os.kill(os.getpid(), signal.SIGINT)
# raise ValueError()
# os.kill(os.getpid(), signal.SIGINT)
for i in range(0, wandb.config.epochs):
    loss = random.uniform(0, wandb.config.epochs - i)
    print("loss: %s" % loss)
    wandb.log({"loss": loss}, commit=False)
    if wandb.config.heavy:
        for x in range(50):
            wandb.log(
                {
                    "hist_{}".format(x): wandb.Histogram(
                        np.random.randint(255, size=(1000))
                    )
                },
                commit=False,
            )
    wandb.log({"cool": True})
    if wandb.config.sleep_every > 0 and i % wandb.config.sleep_every == 0:
        print("sleeping")
        time.sleep(random.random() + 1)
    sys.stdout.flush()
print("Finished")
71 changes: 71 additions & 0 deletions tests/test_offline_sync.py
@@ -0,0 +1,71 @@
import os
import subprocess
import sys
import time
import glob
from .utils import fixture_open


def test_sync_in_progress(live_mock_server, test_dir):
    with open("train.py", "w") as f:
        f.write(fixture_open("train.py").read())
    env = dict(os.environ)
    env["WANDB_MODE"] = "offline"
    env["WANDB_DIR"] = test_dir
    env["WANDB_CONSOLE"] = "off"
    stdout = open("stdout.log", "w+")
    offline_run = subprocess.Popen(
        [
            sys.executable,
            "train.py",
            "--epochs",
            "50",
            "--sleep_every",
            "15",
            "--heavy",
        ],
        stdout=stdout,
        stderr=subprocess.STDOUT,
        bufsize=1,
        close_fds=True,
        env=env,
    )
    attempts = 0
    latest_run = os.path.join(test_dir, "wandb", "latest-run")
    while not os.path.exists(latest_run) and attempts < 100:
        time.sleep(0.1)
        attempts += 1
    if attempts == 100:
        print("cur dir contents: ", os.listdir(test_dir))
        print("wandb dir contents: ", os.listdir(os.path.join(test_dir, "wandb")))
        stdout.seek(0)
        print("STDOUT")
        print(stdout.read())
        debug = os.path.join("wandb", "debug.log")
        debug_int = os.path.join("wandb", "debug-internal.log")
        if os.path.exists(debug):
            print("DEBUG")
            print(open(debug).read())
        if os.path.exists(debug_int):
            print("DEBUG INTERNAL")
            print(open(debug_int).read())
        assert False, "train.py failed to launch :("
    else:
        print("Starting live syncing after {} seconds".format(attempts * 0.1))
        sync_file = ".wandb"
        for i in range(3):
            # Generally, the first sync will fail because the .wandb file is empty
            sync = subprocess.Popen(["wandb", "sync", latest_run], env=os.environ)
            assert sync.wait() == 0
            matches = glob.glob(os.path.join(latest_run, "*.wandb"))
            if len(matches) > 0:
                sync_file = matches[0]
            # Only confirm we don't have a .synced file if our offline run is still running
            if offline_run.poll() is None:
                assert not os.path.exists(os.path.join(sync_file + ".synced"))
        assert offline_run.wait() == 0
        sync = subprocess.Popen(["wandb", "sync", latest_run], env=os.environ)
        assert sync.wait() == 0
        assert os.path.exists(os.path.join(sync_file + ".synced"))
        print("Number of upserts: ", live_mock_server.get_ctx()["upsert_bucket_count"])
        assert live_mock_server.get_ctx()["upsert_bucket_count"] >= 3
60 changes: 45 additions & 15 deletions wandb/sdk/internal/datastore.py
@@ -65,12 +65,15 @@ def __init__(self):
        self._opened_for_scan = False
        self._fp = None
        self._index = 0
        self._size_bytes = 0

        self._crc = [0] * (LEVELDBLOG_LAST + 1)
        for x in range(1, LEVELDBLOG_LAST + 1):
            self._crc[x] = zlib.crc32(strtobytes(chr(x))) & 0xFFFFFFFF

        assert wandb._assert_is_internal_process
        assert (
            wandb._assert_is_internal_process
        ), "DataStore can only be used in the internal process"

    def open_for_write(self, fname):
        self._fname = fname
@@ -95,24 +98,36 @@ def open_for_scan(self, fname):
        logger.info("open for scan: %s", fname)
        self._fp = open(fname, "rb")
        self._index = 0
        self._size_bytes = os.stat(fname).st_size
        self._opened_for_scan = True
        self._read_header()

    def in_last_block(self):
        """When reading, we want to know if we're in the last block to
        handle in progress writes"""
        return self._index > self._size_bytes - LEVELDBLOG_DATA_LEN

    def scan_record(self):
        assert self._opened_for_scan
        assert self._opened_for_scan, "file not open for scanning"
        # TODO(jhr): handle some assertions as file corruption issues
        # assume we have enough room to read header, checked by caller?
        header = self._fp.read(LEVELDBLOG_HEADER_LEN)
        if len(header) == 0:
            return None
        assert len(header) == LEVELDBLOG_HEADER_LEN
        assert (
            len(header) == LEVELDBLOG_HEADER_LEN
        ), "record header is {} bytes instead of the expected {}".format(
            len(header), LEVELDBLOG_HEADER_LEN
        )
        fields = struct.unpack("<IHB", header)
        checksum, dlength, dtype = fields
        # check len, better fit in the block
        self._index += LEVELDBLOG_HEADER_LEN
        data = self._fp.read(dlength)
        checksum_computed = zlib.crc32(data, self._crc[dtype]) & 0xFFFFFFFF
        assert checksum == checksum_computed
        assert (
            checksum == checksum_computed
        ), "record checksum is invalid, data may be corrupt"
        self._index += dlength
        return dtype, data

@@ -125,7 +140,7 @@ def scan_data(self):
            pad_check = strtobytes("\x00" * space_left)
            pad = self._fp.read(space_left)
            # verify they are zero
            assert pad == pad_check
            assert pad == pad_check, "invalid padding"
            self._index += space_left

        record = self.scan_record()
@@ -135,7 +150,9 @@ def scan_data(self):
        if dtype == LEVELDBLOG_FULL:
            return data

        assert dtype == LEVELDBLOG_FIRST
        assert (
            dtype == LEVELDBLOG_FIRST
        ), "expected record to be type {} but found {}".format(LEVELDBLOG_FIRST, dtype)
        while True:
            offset = self._index % LEVELDBLOG_BLOCK_LEN
            record = self.scan_record()
@@ -145,7 +162,11 @@ def scan_data(self):
            if dtype == LEVELDBLOG_LAST:
                data += new_data
                break
            assert dtype == LEVELDBLOG_MIDDLE
            assert (
                dtype == LEVELDBLOG_MIDDLE
            ), "expected record to be type {} but found {}".format(
                LEVELDBLOG_MIDDLE, dtype
            )
            data += new_data
        return data

@@ -156,21 +177,28 @@ def _write_header(self):
            LEVELDBLOG_HEADER_MAGIC,
            LEVELDBLOG_HEADER_VERSION,
        )
        assert len(data) == 7
        assert (
            len(data) == LEVELDBLOG_HEADER_LEN
        ), "header size is {} bytes, expected {}".format(
            len(data), LEVELDBLOG_HEADER_LEN
        )
        self._fp.write(data)
        self._index += len(data)

    def _read_header(self):
        header_length = 7
        header = self._fp.read(header_length)
        header = self._fp.read(LEVELDBLOG_HEADER_LEN)
        assert (
            len(header) == LEVELDBLOG_HEADER_LEN
        ), "header is {} bytes instead of the expected {}".format(
            len(header), LEVELDBLOG_HEADER_LEN
        )
        ident, magic, version = struct.unpack("<4sHB", header)
        if ident != strtobytes(LEVELDBLOG_HEADER_IDENT):
            raise Exception("Invalid header")
        if magic != LEVELDBLOG_HEADER_MAGIC:
            raise Exception("Invalid header")
        if version != LEVELDBLOG_HEADER_VERSION:
            raise Exception("Invalid header")
        assert len(header) == header_length
Contributor (author):

This was one bug where we didn't verify we had a valid header before attempting to unpack it.
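(For context, a minimal sketch of the kind of check this adds: confirm the full header was actually read before unpacking it. The constants and function below are placeholders for illustration, not the real LEVELDBLOG_* values or the DataStore API.)

import struct

HEADER_LEN = 7          # assumed: 4-byte ident + 2-byte magic + 1-byte version
HEADER_IDENT = b"XXXX"  # placeholder identifier bytes
HEADER_MAGIC = 0x1234   # placeholder magic number
HEADER_VERSION = 0      # placeholder format version


def read_header(fp):
    header = fp.read(HEADER_LEN)
    # Validate the length first; unpacking a short read would otherwise
    # raise an opaque struct.error instead of a clear message.
    assert len(header) == HEADER_LEN, "header is {} bytes instead of the expected {}".format(
        len(header), HEADER_LEN
    )
    ident, magic, version = struct.unpack("<4sHB", header)
    if ident != HEADER_IDENT or magic != HEADER_MAGIC or version != HEADER_VERSION:
        raise Exception("Invalid header")
    return ident, magic, version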

        self._index += len(header)

    def _write_record(self, s, dtype=None):
@@ -179,7 +207,7 @@ def _write_record(self, s, dtype=None):
        # (this is a precondition to calling this method)
        assert len(s) + LEVELDBLOG_HEADER_LEN <= (
            LEVELDBLOG_BLOCK_LEN - self._index % LEVELDBLOG_BLOCK_LEN
        )
        ), "not enough space to write new records"

        dlength = len(s)
        dtype = dtype or LEVELDBLOG_FULL
@@ -220,7 +248,7 @@ def _write_data(self, s):
            self._write_record(s[:data_room], LEVELDBLOG_FIRST)
            data_used += data_room
            data_left -= data_room
            assert data_left
            assert data_left, "data_left should be non-zero"

            # write middles (if any)
            while data_left > LEVELDBLOG_DATA_LEN:
@@ -231,8 +259,10 @@ def _write_data(self, s):
                data_used += LEVELDBLOG_DATA_LEN
                data_left -= LEVELDBLOG_DATA_LEN

            # write last
            # write last and flush the entire block to disk
            self._write_record(s[data_used:], LEVELDBLOG_LAST)
            self._fp.flush()
            os.fsync(self._fp.fileno())
Contributor (author):

@raubitsj not positive this is the right thing to do, but figured we should at least flush every block to disk explicitly. I don't think this actually fixes any issues with syncing and I can take it out if you think it will have negative performance consequences.
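(As a standalone sketch of what the two calls above do; the helper name is made up for illustration and is not part of DataStore.)

import os


def flush_to_disk(fp):
    fp.flush()             # push Python's userspace buffer to the OS
    os.fsync(fp.fileno())  # ask the OS to commit its buffers to stable storage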

Member:

This isn't quite doing what you think it is doing.
I had planned on making flushes happen time-based, which is why I left it out. Writing on every block is fine, but this only syncs on a record that spanned multiple blocks. It will be fine, but it won't actually guarantee that things are flushed regularly, though in most cases it will happen.
Let's keep it for now if it was tested to work OK.
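(Sketch of the time-based flushing described here, assuming a writer that owns a file handle; the class and interval are illustrative, not the actual DataStore API.)

import os
import time


class TimedFlusher:
    def __init__(self, fp, interval_seconds=5.0):
        self._fp = fp
        self._interval = interval_seconds
        self._last_flush = time.monotonic()

    def maybe_flush(self):
        # Call after every record write; only touches the disk once the
        # interval has elapsed, so data gets flushed regularly regardless
        # of how records happen to span blocks.
        now = time.monotonic()
        if now - self._last_flush >= self._interval:
            self._fp.flush()
            os.fsync(self._fp.fileno())
            self._last_flush = now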


        return file_offset, self._index - file_offset, flush_index, flush_offset

@@ -249,7 +279,7 @@ def write(self, obj):
"""
raw_size = obj.ByteSize()
s = obj.SerializeToString()
assert len(s) == raw_size
assert len(s) == raw_size, "invalid serialization"
ret = self._write_data(s)
return ret
