This repository has been archived by the owner on May 22, 2019. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 44
/
publish.py
140 lines (126 loc) · 4.53 KB
/
publish.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import json
import logging
import os
import time
import uuid
import asdf
from clint.textui import progress
from dateutil.parser import parse as parse_datetime
from google.cloud.storage import Client
import requests
from ast2vec.meta import extract_index_meta
from ast2vec.model import Model
class FileReadTracker:
    """File-like wrapper that tracks the read position of a binary fileobj.

    When the supplied logger is enabled for INFO, a console progress bar
    is rendered and advanced on every :meth:`read`.
    """

    def __init__(self, file, logger):
        self._file = file
        self._position = 0
        # Determine the total size by seeking to the end, then rewind.
        file.seek(0, 2)
        self._size = file.tell()
        self._enabled = logger.isEnabledFor(logging.INFO)
        if self._enabled:
            self._progress = progress.Bar(expected_size=self._size)
        file.seek(0)

    @property
    def size(self):
        """Total size of the wrapped file, in bytes."""
        return self._size

    def read(self, size=None):
        """Read up to *size* bytes, advance the counter and the bar."""
        chunk = self._file.read(size)
        self._position += len(chunk)
        if self._enabled:
            self._progress.show(self._position)
        return chunk

    def tell(self):
        """Number of bytes consumed so far via :meth:`read`."""
        return self._position

    def done(self):
        """Finalize the progress bar, if one was displayed."""
        if self._enabled:
            self._progress.done()
def publish_model(args):
    """
    Pushes the model to Google Cloud Storage and updates the index file.

    The bucket is serialized via an "index.lock" sentinel blob: each agent
    writes a unique token and then reads the sentinel back; only the agent
    whose token survived owns the lock.

    :param args: :class:`argparse.Namespace` with "model", "gcs", "force",
                 "credentials" and "update_default".
    :return: None if successful, 1 otherwise.
    """
    log = logging.getLogger("publish")
    log.info("Reading %s...", os.path.abspath(args.model))
    tree = asdf.open(args.model).tree
    meta = tree["meta"]
    log.info("Locking the bucket...")
    # Unique token so we can tell whether *we* won the lock race below.
    transaction = uuid.uuid4().hex.encode()
    if args.credentials:
        client = Client.from_service_account_json(args.credentials)
    else:
        client = Client()
    bucket = client.get_bucket(args.gcs)
    sentinel = bucket.blob("index.lock")
    locked = False
    while not locked:
        while sentinel.exists():
            log.warning("Failed to acquire the lock, waiting...")
            time.sleep(1)
        # At this step, several agents may think the lockfile does not exist
        try:
            sentinel.upload_from_string(transaction)
            # Only one agent succeeds to check this condition
            locked = sentinel.download_as_string() == transaction
        except Exception:
            # Was a bare "except:", which also swallowed KeyboardInterrupt /
            # SystemExit; GCS detects the changed-while-reading collision.
            log.warning("Failed to acquire the lock, retrying...")
    try:
        blob = bucket.blob("models/%s/%s.asdf" % (meta["model"], meta["uuid"]))
        if blob.exists() and not args.force:
            log.error("Model %s already exists, aborted.", meta["uuid"])
            return 1
        log.info("Uploading %s from %s...", meta["model"],
                 os.path.abspath(args.model))
        with open(args.model, "rb") as fin:
            tracker = FileReadTracker(fin, log)
            try:
                blob.upload_from_file(
                    tracker, content_type="application/x-yaml",
                    size=tracker.size)
            finally:
                # Always close the progress bar, even on upload failure.
                tracker.done()
        blob.make_public()
        model_url = blob.public_url
        log.info("Uploaded as %s", model_url)
        log.info("Updating the models index...")
        blob = bucket.get_blob(Model.INDEX_FILE)
        index = json.loads(blob.download_as_string().decode("utf-8"))
        index["models"].setdefault(meta["model"], {})[meta["uuid"]] = \
            extract_index_meta(meta, model_url)
        if args.update_default:
            index["models"][meta["model"]][Model.DEFAULT_NAME] = meta["uuid"]
        blob.upload_from_string(json.dumps(index, indent=4, sort_keys=True))
        blob.make_public()
    finally:
        # Release the lock no matter how publishing ended.
        sentinel.delete()
def list_models(args):
    """
    Outputs the list of known models in the registry.

    For every model family, versions are printed newest-first; the default
    version (if any) is marked with a "*".

    :param args: :class:`argparse.Namespace` with "gcs".
    :return: None if successful, 1 if the index could not be parsed.
    """
    # stream=True was pointless here: .content reads the whole body anyway.
    r = requests.get(Model.compose_index_url(args.gcs))
    content = r.content.decode("utf-8")
    try:
        index = json.loads(content)
    except json.decoder.JSONDecodeError:
        # Registry answered with something that is not JSON - show it raw.
        print(content)
        return 1
    for key, val in index["models"].items():
        print(key)
        # The "default" entry maps directly to the default version's uuid.
        default = val.get("default")
        versions = [m for m in val.items() if m[0] != "default"]
        for mid, meta in sorted(
                versions,
                key=lambda m: parse_datetime(m[1]["created_at"]),
                reverse=True):
            print("  %s %s" % ("*" if default == mid else " ", mid),
                  meta["created_at"])