-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
hhsecond
committed
Dec 2, 2019
1 parent
ac4a4a4
commit 391e143
Showing
17 changed files
with
428 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -102,3 +102,8 @@ venv.bak/ | |
|
||
# mypy | ||
.mypy_cache/ | ||
|
||
# pycharm | ||
.idea | ||
|
||
.hangar |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
# almacen | ||
A model and data version toolkit | ||
# StockRoom | ||
A platform to version models, data, parameters, metrics etc alongside git versioned source code. | ||
Althouh it is built as a high level API kit for [hangar](https://github.com/tensorwerk/hangar-py) and comes as part of hangar itself, user doesn't need to know any founding philosophy of hangar work with stockroom unless you need fine grained control |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import setuptools | ||
|
||
with open("README.md", "r") as fh: | ||
long_description = fh.read() | ||
|
||
setuptools.setup( | ||
name="stockroom", | ||
version="0.0.1", | ||
author="hhsecond", | ||
author_email="sherin@tensorwerk.com", | ||
description="A high level data and model versioning toolkit sits on top of hangar", | ||
long_description=long_description, | ||
long_description_content_type="text/markdown", | ||
url="https://github.com/hhsecond/stockroom", | ||
packages=setuptools.find_packages(), | ||
classifiers=[ | ||
"Programming Language :: Python :: 3", | ||
"License :: OSI Approved :: MIT License", | ||
"Operating System :: OS Independent", | ||
], | ||
python_requires='>=3.6', | ||
entry_points={ | ||
'console_scripts': ['stock = stockroom.cli:main'] | ||
}, | ||
install_requires=['hangar'] | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from stockroom.storages import DataStore | ||
from stockroom.storages import ModelStore | ||
from stockroom.storages import ParamStore | ||
from stockroom.storages import MetricStore | ||
from .repository import init, commit | ||
|
||
|
||
# TODO: Simplify APIs by not making users initiate a storage class each time | ||
|
||
__all__ = ['DataStore', 'ModelStore', 'ParamStore', 'MetricStore', 'init', 'commit'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
from pathlib import Path | ||
import click | ||
from hangar import Repository | ||
from . import repository | ||
|
||
|
||
# TODO: move repetative code in hangar and here to a common function | ||
pass_repo = click.make_pass_decorator(Repository, ensure=True) | ||
|
||
|
||
@click.group(no_args_is_help=True, add_help_option=True, invoke_without_command=True) | ||
@click.pass_context | ||
def main(ctx): | ||
cwd = Path.cwd() | ||
ctx.obj = Repository(path=cwd, exists=False) | ||
|
||
|
||
@main.command() | ||
@click.option('--message', '-m', multiple=True, | ||
help=('The commit message. If provided multiple times ' | ||
'each argument gets converted into a new line.')) | ||
@pass_repo | ||
def commit(repo: Repository, message): | ||
"""Commits outstanding changes. | ||
Commit changes to the given files into the repository. You will need to | ||
'push' to push up your changes to other repositories. | ||
""" | ||
from hangar.records.summarize import status | ||
if not message: | ||
with repo.checkout(write=True) as co: | ||
diff = co.diff.staged() | ||
status_txt = status(co.branch_name, diff.diff) | ||
status_txt.seek(0) | ||
marker = '# Changes To Be committed: \n' | ||
hint = ['\n', '\n', marker, '# \n'] | ||
for line in status_txt.readlines(): | ||
hint.append(f'# {line}') | ||
# open default system editor | ||
message = click.edit(''.join(hint)) | ||
if message is None: | ||
click.echo('Aborted!') | ||
return | ||
msg = message.split(marker)[0].rstrip() | ||
if not msg: | ||
click.echo('Aborted! Empty commit message') | ||
return | ||
# TODO: should be done in the __exit__ of hangar checkout | ||
co.close() | ||
else: | ||
msg = '\n'.join(message) | ||
click.echo('Commit message:\n' + msg) | ||
try: | ||
digest = repository.commit(message) | ||
except (FileNotFoundError, RuntimeError) as e: | ||
raise click.ClickException(e) | ||
click.echo(f'Commit Successful. Digest: {digest}') | ||
|
||
|
||
@main.command() | ||
@click.option('--name', prompt='User Name', help='first and last name of user') | ||
@click.option('--email', prompt='User Email', help='email address of the user') | ||
@click.option('--overwrite', is_flag=True, default=False, | ||
help='overwrite a repository if it exists at the current path') | ||
def init(name, email, overwrite): | ||
try: | ||
repository.init(name, email, overwrite) | ||
except RuntimeError as e: | ||
raise click.ClickException(e) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# TODO: best practices like utf8 | ||
# TODO: is this separater enough | ||
SEP = '--_' | ||
PREFIX = '_STOCK' | ||
|
||
|
||
def metakey(model, name): | ||
return f"{PREFIX}_metakey_{model}_{name}" | ||
|
||
|
||
def model_asetkey_from_details(*args): | ||
# TODO: make more reliable hash rather than time.time() | ||
asetkey = f"{PREFIX}{SEP}" | ||
return asetkey + SEP.join(args) | ||
|
||
|
||
def shape_asetkey_from_model_asetkey(model_asetkey): | ||
return model_asetkey + '_shape' | ||
|
||
|
||
# TODO: move this somewhere more sensib | ||
def layers_to_string(layers): | ||
return ','.join(layers) | ||
|
||
|
||
def string_to_layers(string): | ||
return string.split(',') | ||
|
||
|
||
def dtypes_to_string(dtypes): | ||
return ','.join(dtypes) | ||
|
||
|
||
def string_to_dtypes(string): | ||
return string.split(',') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from pathlib import Path | ||
from hangar import Repository | ||
from .utils import get_stock_root | ||
|
||
|
||
def init(name, email, overwrite): | ||
""" init hangar repo, create stock file and add details to .gitignore """ | ||
if not Path.cwd().joinpath('.git').exists(): | ||
raise RuntimeError("stock init should execute only in a" | ||
" git repository. Try running stock " | ||
"init after git init") | ||
repo = Repository(Path.cwd(), exists=False) | ||
if repo.initialized and (not overwrite): | ||
commit_hash = repo.log(return_contents=True)['head'] | ||
print(f'Repo already exists at: {repo.path}') | ||
else: | ||
commit_hash = '' | ||
repo.init(user_name=name, user_email=email, remove_old=overwrite) | ||
|
||
stock_file = Path.cwd().joinpath('head.stock') | ||
if not stock_file.exists(): | ||
with open(stock_file, 'w+') as f: | ||
f.write(commit_hash) | ||
print("Stock file created") | ||
|
||
gitignore = Path.cwd().joinpath('.gitignore') | ||
# TODO make sure this creates the file when file doesn't exist | ||
with open(gitignore, 'a+') as f: | ||
f.seek(0) | ||
if '.hangar' not in f.read(): | ||
f.write('\n.hangar\n') | ||
|
||
|
||
def commit(message): | ||
repo = Repository(Path.cwd()) | ||
with repo.checkout(write=True) as co: | ||
root = get_stock_root() | ||
if not root: | ||
raise FileNotFoundError("Could not find stock file. Aborting..") | ||
digest = co.commit(message) | ||
with open(root.joinpath('head.stock'), 'w') as f: | ||
f.write(digest) | ||
# TODO: print message about file write as well | ||
# TODO: should be done in the __exit__ of hangar checkout | ||
co.close() | ||
return digest |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .datastore import DataStore | ||
from .modelstore import ModelStore | ||
from .metricstore import MetricStore | ||
from .paramstore import ParamStore |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from .storagebase import StorageBase | ||
from ..utils import get_current_head, get_stock_root | ||
|
||
|
||
class DataStore(StorageBase): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def __getitem__(self, item): | ||
root = get_stock_root() | ||
dset = self.repo.checkout(commit=get_current_head(root)) | ||
# TODO: rigorous check like in hangar | ||
if isinstance(item, tuple): | ||
aset = item[0] | ||
index = item[1] | ||
return dset[aset, index] | ||
return dset[item] | ||
|
||
def __setitem__(self, item, value): | ||
# TODO: optimized set item like context manager | ||
dset = self.repo.checkout(write=True) | ||
# TODO: rigorous check like in hangar | ||
if isinstance(item, tuple): | ||
aset = item[0] | ||
index = item[1] | ||
dset[aset, index] = value | ||
dset[item] = value # this will raise error downstream | ||
dset.close() | ||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
class MetricStore: | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
import numpy as np | ||
|
||
from .storagebase import StorageBase | ||
from ..utils import get_current_head, get_stock_root | ||
from .. import parser | ||
|
||
|
||
def get_aset(co, name, dtype=None, longest=None, variable=False): | ||
try: | ||
aset = co.arraysets[name] | ||
return aset | ||
except KeyError: | ||
pass | ||
aset = co.arraysets.init_arrayset( | ||
name, dtype=np.dtype(dtype), shape=(longest,), variable_shape=variable) | ||
return aset | ||
|
||
# TODO: figure out what' the importance of max shape if var_shape is True | ||
|
||
|
||
class ModelStore(StorageBase): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def save(self, name, model): | ||
# TODO: optimize | ||
co = self.repo.checkout(write=True) | ||
if hasattr(model, 'state_dict'): | ||
library = 'torch' | ||
state = model.state_dict() | ||
layers = list(state.keys()) | ||
# TODO: forloop for all needs or list comprehension few times | ||
weights = [x.numpy() for x in state.values()] | ||
str_layer = parser.layers_to_string(layers) | ||
co.metadata[parser.metakey(name, 'layers')] = str_layer | ||
elif hasattr(model, 'get_weights'): | ||
library = 'tf' | ||
# tf model | ||
weights = model.get_weights() | ||
else: | ||
raise TypeError("Unknown model type. StockRoom can work with only " | ||
"``Keras.Model`` or ``torch.nn.Module`` modules") | ||
longest = max([len(x.reshape(-1)) for x in weights]) | ||
co.metadata[parser.metakey(name, 'library')] = library | ||
co.metadata[parser.metakey(name, 'longest')] = str(longest) | ||
co.metadata[parser.metakey(name, 'num_layers')] = str(len(weights)) | ||
dtypes = [w.dtype.name for w in weights] | ||
str_dtypes = parser.dtypes_to_string(dtypes) | ||
co.metadata[parser.metakey(name, 'dtypes')] = str_dtypes | ||
aset_prefix = parser.model_asetkey_from_details(name, str(longest)) | ||
co.metadata[parser.metakey(name, 'aset_prefix')] = aset_prefix | ||
shape_asetn = parser.shape_asetkey_from_model_asetkey(name) | ||
shape_aset = co.arraysets.init_arrayset( | ||
shape_asetn, shape=(10,), dtype=np.int64, variable_shape=True) | ||
for i, w in enumerate(weights): | ||
asetn = parser.model_asetkey_from_details(aset_prefix, dtypes[i]) | ||
aset = get_aset(co, asetn, dtypes[i], longest, variable=True) | ||
aset[i] = w.reshape(-1) | ||
if w.shape: | ||
shape_aset[i] = np.array(w.shape) | ||
else: | ||
shape_aset[i] = np.array(()).astype('int64') | ||
co.close() | ||
|
||
def load(self, name, model): | ||
import torch | ||
root = get_stock_root() | ||
head_commit = get_current_head(root) | ||
co = self.repo.checkout(commit=head_commit) | ||
aset_prefix = co.metadata[parser.metakey(name, 'aset_prefix')] | ||
dtypes = parser.string_to_dtypes(co.metadata[parser.metakey(name, 'dtypes')]) | ||
library = co.metadata[parser.metakey(name, 'library')] | ||
num_layers = int(co.metadata[parser.metakey(name, 'num_layers')]) | ||
weights = [] | ||
for i in range(num_layers): | ||
asetn = parser.model_asetkey_from_details(aset_prefix, dtypes[i]) | ||
aset = get_aset(co, asetn) | ||
shape_asetn = parser.shape_asetkey_from_model_asetkey(name) | ||
shape_aset = co.arraysets[shape_asetn] | ||
w = aset[i].reshape(np.array(shape_aset[i])) | ||
weights.append(w) | ||
if len(weights) != num_layers: | ||
raise RuntimeError("Critical: length doesn't match. Raise an issue") | ||
if library == 'torch': | ||
str_layers = co.metadata[parser.metakey(name, 'layers')] | ||
layers = parser.string_to_layers(str_layers) | ||
if len(layers) != num_layers: | ||
raise RuntimeError("Critical: length doesn't match. Raise an issue") | ||
state = {layers[i]: torch.from_numpy(weights[i]) for i in range(num_layers)} | ||
model.load_state_dict(state) | ||
else: | ||
model.set_weights(weights) | ||
|
||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
class ParamStore: | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from hangar import Repository | ||
from ..utils import get_stock_root | ||
|
||
|
||
class StorageBase(object): | ||
|
||
def __init__(self): | ||
if not hasattr(StorageBase, 'repo'): | ||
root = get_stock_root() | ||
if root is None: | ||
raise RuntimeError("Could not find the stock root. " | ||
"Did you forget to `stock init`?") | ||
StorageBase.root = root | ||
StorageBase.repo = Repository(root) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from pathlib import Path | ||
|
||
|
||
def get_stock_root(): | ||
# TODO: would CWD work always | ||
path = Path.cwd() | ||
while True: | ||
stock_exist = path.joinpath('head.stock').exists() | ||
if stock_exist: | ||
hangar_exist = path.joinpath('.hangar').exists() | ||
git_exist = path.joinpath('.git').exists() | ||
if not hangar_exist and not git_exist: | ||
raise RuntimeError("Stock root should be the root of git and" | ||
"hangar repository") | ||
return path | ||
if path == path.parent: # system root check | ||
return None | ||
path = path.parent | ||
|
||
|
||
def get_current_head(root: Path): | ||
head = root.joinpath('head.stock') | ||
with open(head, 'r') as f: | ||
commit = f.read() | ||
return commit if commit else '' | ||
|
||
|
||
def set_current_head(root: Path, commit: str): | ||
head = root.joinpath('head.stock') | ||
with open(head, 'w+') as f: | ||
f.write(commit) |
Oops, something went wrong.