Skip to content

Commit

Permalink
working prototype
Browse files Browse the repository at this point in the history
  • Loading branch information
hhsecond committed Dec 2, 2019
1 parent ac4a4a4 commit 391e143
Show file tree
Hide file tree
Showing 17 changed files with 428 additions and 2 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,8 @@ venv.bak/

# mypy
.mypy_cache/

# pycharm
.idea

.hangar
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
# almacen
A model and data version toolkit
# StockRoom
A platform to version models, data, parameters, metrics etc alongside git versioned source code.
Althouh it is built as a high level API kit for [hangar](https://github.com/tensorwerk/hangar-py) and comes as part of hangar itself, user doesn't need to know any founding philosophy of hangar work with stockroom unless you need fine grained control
26 changes: 26 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import setuptools

with open("README.md", "r") as fh:
long_description = fh.read()

setuptools.setup(
name="stockroom",
version="0.0.1",
author="hhsecond",
author_email="sherin@tensorwerk.com",
description="A high level data and model versioning toolkit sits on top of hangar",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/hhsecond/stockroom",
packages=setuptools.find_packages(),
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
],
python_requires='>=3.6',
entry_points={
'console_scripts': ['stock = stockroom.cli:main']
},
install_requires=['hangar']
)
10 changes: 10 additions & 0 deletions stockroom/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from stockroom.storages import DataStore
from stockroom.storages import ModelStore
from stockroom.storages import ParamStore
from stockroom.storages import MetricStore
from .repository import init, commit


# TODO: Simplify APIs by not making users initiate a storage class each time

__all__ = ['DataStore', 'ModelStore', 'ParamStore', 'MetricStore', 'init', 'commit']
71 changes: 71 additions & 0 deletions stockroom/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from pathlib import Path
import click
from hangar import Repository
from . import repository


# TODO: move repetative code in hangar and here to a common function
pass_repo = click.make_pass_decorator(Repository, ensure=True)


@click.group(no_args_is_help=True, add_help_option=True, invoke_without_command=True)
@click.pass_context
def main(ctx):
cwd = Path.cwd()
ctx.obj = Repository(path=cwd, exists=False)


@main.command()
@click.option('--message', '-m', multiple=True,
help=('The commit message. If provided multiple times '
'each argument gets converted into a new line.'))
@pass_repo
def commit(repo: Repository, message):
"""Commits outstanding changes.
Commit changes to the given files into the repository. You will need to
'push' to push up your changes to other repositories.
"""
from hangar.records.summarize import status
if not message:
with repo.checkout(write=True) as co:
diff = co.diff.staged()
status_txt = status(co.branch_name, diff.diff)
status_txt.seek(0)
marker = '# Changes To Be committed: \n'
hint = ['\n', '\n', marker, '# \n']
for line in status_txt.readlines():
hint.append(f'# {line}')
# open default system editor
message = click.edit(''.join(hint))
if message is None:
click.echo('Aborted!')
return
msg = message.split(marker)[0].rstrip()
if not msg:
click.echo('Aborted! Empty commit message')
return
# TODO: should be done in the __exit__ of hangar checkout
co.close()
else:
msg = '\n'.join(message)
click.echo('Commit message:\n' + msg)
try:
digest = repository.commit(message)
except (FileNotFoundError, RuntimeError) as e:
raise click.ClickException(e)
click.echo(f'Commit Successful. Digest: {digest}')


@main.command()
@click.option('--name', prompt='User Name', help='first and last name of user')
@click.option('--email', prompt='User Email', help='email address of the user')
@click.option('--overwrite', is_flag=True, default=False,
help='overwrite a repository if it exists at the current path')
def init(name, email, overwrite):
try:
repository.init(name, email, overwrite)
except RuntimeError as e:
raise click.ClickException(e)


35 changes: 35 additions & 0 deletions stockroom/parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# TODO: best practices like utf8
# TODO: is this separater enough
SEP = '--_'
PREFIX = '_STOCK'


def metakey(model, name):
return f"{PREFIX}_metakey_{model}_{name}"


def model_asetkey_from_details(*args):
# TODO: make more reliable hash rather than time.time()
asetkey = f"{PREFIX}{SEP}"
return asetkey + SEP.join(args)


def shape_asetkey_from_model_asetkey(model_asetkey):
return model_asetkey + '_shape'


# TODO: move this somewhere more sensib
def layers_to_string(layers):
return ','.join(layers)


def string_to_layers(string):
return string.split(',')


def dtypes_to_string(dtypes):
return ','.join(dtypes)


def string_to_dtypes(string):
return string.split(',')
46 changes: 46 additions & 0 deletions stockroom/repository.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from pathlib import Path
from hangar import Repository
from .utils import get_stock_root


def init(name, email, overwrite):
""" init hangar repo, create stock file and add details to .gitignore """
if not Path.cwd().joinpath('.git').exists():
raise RuntimeError("stock init should execute only in a"
" git repository. Try running stock "
"init after git init")
repo = Repository(Path.cwd(), exists=False)
if repo.initialized and (not overwrite):
commit_hash = repo.log(return_contents=True)['head']
print(f'Repo already exists at: {repo.path}')
else:
commit_hash = ''
repo.init(user_name=name, user_email=email, remove_old=overwrite)

stock_file = Path.cwd().joinpath('head.stock')
if not stock_file.exists():
with open(stock_file, 'w+') as f:
f.write(commit_hash)
print("Stock file created")

gitignore = Path.cwd().joinpath('.gitignore')
# TODO make sure this creates the file when file doesn't exist
with open(gitignore, 'a+') as f:
f.seek(0)
if '.hangar' not in f.read():
f.write('\n.hangar\n')


def commit(message):
repo = Repository(Path.cwd())
with repo.checkout(write=True) as co:
root = get_stock_root()
if not root:
raise FileNotFoundError("Could not find stock file. Aborting..")
digest = co.commit(message)
with open(root.joinpath('head.stock'), 'w') as f:
f.write(digest)
# TODO: print message about file write as well
# TODO: should be done in the __exit__ of hangar checkout
co.close()
return digest
4 changes: 4 additions & 0 deletions stockroom/storages/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .datastore import DataStore
from .modelstore import ModelStore
from .metricstore import MetricStore
from .paramstore import ParamStore
32 changes: 32 additions & 0 deletions stockroom/storages/datastore.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from .storagebase import StorageBase
from ..utils import get_current_head, get_stock_root


class DataStore(StorageBase):
def __init__(self):
super().__init__()

def __getitem__(self, item):
root = get_stock_root()
dset = self.repo.checkout(commit=get_current_head(root))
# TODO: rigorous check like in hangar
if isinstance(item, tuple):
aset = item[0]
index = item[1]
return dset[aset, index]
return dset[item]

def __setitem__(self, item, value):
# TODO: optimized set item like context manager
dset = self.repo.checkout(write=True)
# TODO: rigorous check like in hangar
if isinstance(item, tuple):
aset = item[0]
index = item[1]
dset[aset, index] = value
dset[item] = value # this will raise error downstream
dset.close()




2 changes: 2 additions & 0 deletions stockroom/storages/metricstore.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class MetricStore:
pass
97 changes: 97 additions & 0 deletions stockroom/storages/modelstore.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import numpy as np

from .storagebase import StorageBase
from ..utils import get_current_head, get_stock_root
from .. import parser


def get_aset(co, name, dtype=None, longest=None, variable=False):
try:
aset = co.arraysets[name]
return aset
except KeyError:
pass
aset = co.arraysets.init_arrayset(
name, dtype=np.dtype(dtype), shape=(longest,), variable_shape=variable)
return aset

# TODO: figure out what' the importance of max shape if var_shape is True


class ModelStore(StorageBase):
def __init__(self):
super().__init__()

def save(self, name, model):
# TODO: optimize
co = self.repo.checkout(write=True)
if hasattr(model, 'state_dict'):
library = 'torch'
state = model.state_dict()
layers = list(state.keys())
# TODO: forloop for all needs or list comprehension few times
weights = [x.numpy() for x in state.values()]
str_layer = parser.layers_to_string(layers)
co.metadata[parser.metakey(name, 'layers')] = str_layer
elif hasattr(model, 'get_weights'):
library = 'tf'
# tf model
weights = model.get_weights()
else:
raise TypeError("Unknown model type. StockRoom can work with only "
"``Keras.Model`` or ``torch.nn.Module`` modules")
longest = max([len(x.reshape(-1)) for x in weights])
co.metadata[parser.metakey(name, 'library')] = library
co.metadata[parser.metakey(name, 'longest')] = str(longest)
co.metadata[parser.metakey(name, 'num_layers')] = str(len(weights))
dtypes = [w.dtype.name for w in weights]
str_dtypes = parser.dtypes_to_string(dtypes)
co.metadata[parser.metakey(name, 'dtypes')] = str_dtypes
aset_prefix = parser.model_asetkey_from_details(name, str(longest))
co.metadata[parser.metakey(name, 'aset_prefix')] = aset_prefix
shape_asetn = parser.shape_asetkey_from_model_asetkey(name)
shape_aset = co.arraysets.init_arrayset(
shape_asetn, shape=(10,), dtype=np.int64, variable_shape=True)
for i, w in enumerate(weights):
asetn = parser.model_asetkey_from_details(aset_prefix, dtypes[i])
aset = get_aset(co, asetn, dtypes[i], longest, variable=True)
aset[i] = w.reshape(-1)
if w.shape:
shape_aset[i] = np.array(w.shape)
else:
shape_aset[i] = np.array(()).astype('int64')
co.close()

def load(self, name, model):
import torch
root = get_stock_root()
head_commit = get_current_head(root)
co = self.repo.checkout(commit=head_commit)
aset_prefix = co.metadata[parser.metakey(name, 'aset_prefix')]
dtypes = parser.string_to_dtypes(co.metadata[parser.metakey(name, 'dtypes')])
library = co.metadata[parser.metakey(name, 'library')]
num_layers = int(co.metadata[parser.metakey(name, 'num_layers')])
weights = []
for i in range(num_layers):
asetn = parser.model_asetkey_from_details(aset_prefix, dtypes[i])
aset = get_aset(co, asetn)
shape_asetn = parser.shape_asetkey_from_model_asetkey(name)
shape_aset = co.arraysets[shape_asetn]
w = aset[i].reshape(np.array(shape_aset[i]))
weights.append(w)
if len(weights) != num_layers:
raise RuntimeError("Critical: length doesn't match. Raise an issue")
if library == 'torch':
str_layers = co.metadata[parser.metakey(name, 'layers')]
layers = parser.string_to_layers(str_layers)
if len(layers) != num_layers:
raise RuntimeError("Critical: length doesn't match. Raise an issue")
state = {layers[i]: torch.from_numpy(weights[i]) for i in range(num_layers)}
model.load_state_dict(state)
else:
model.set_weights(weights)





2 changes: 2 additions & 0 deletions stockroom/storages/paramstore.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class ParamStore:
pass
14 changes: 14 additions & 0 deletions stockroom/storages/storagebase.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from hangar import Repository
from ..utils import get_stock_root


class StorageBase(object):

def __init__(self):
if not hasattr(StorageBase, 'repo'):
root = get_stock_root()
if root is None:
raise RuntimeError("Could not find the stock root. "
"Did you forget to `stock init`?")
StorageBase.root = root
StorageBase.repo = Repository(root)
31 changes: 31 additions & 0 deletions stockroom/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from pathlib import Path


def get_stock_root():
# TODO: would CWD work always
path = Path.cwd()
while True:
stock_exist = path.joinpath('head.stock').exists()
if stock_exist:
hangar_exist = path.joinpath('.hangar').exists()
git_exist = path.joinpath('.git').exists()
if not hangar_exist and not git_exist:
raise RuntimeError("Stock root should be the root of git and"
"hangar repository")
return path
if path == path.parent: # system root check
return None
path = path.parent


def get_current_head(root: Path):
head = root.joinpath('head.stock')
with open(head, 'r') as f:
commit = f.read()
return commit if commit else ''


def set_current_head(root: Path, commit: str):
head = root.joinpath('head.stock')
with open(head, 'w+') as f:
f.write(commit)

0 comments on commit 391e143

Please sign in to comment.