Review notes (free text ahead of the first "diff --git" header; patch tools
skip it):
  * The patch had been flattened onto a few physical lines.  One record per
    line is restored here; blank context/added lines were placed so that
    every hunk matches the old/new line counts in its @@ header.
  * File.read(): attachment files are now opened with self.io.open() instead
    of the builtin open(), consistent with the self.io.exists() check right
    above it and with write()/delete(), which already go through self.io.
  * Test comment typo fixed ("files system" -> "the file system").
  * NOTE(review): leading whitespace on context lines was reconstructed from
    Python block structure.  If a context line does not match the target
    tree, apply with `git apply --ignore-whitespace`.
  * NOTE(review): the "index" blob ids below describe the pre-review content.

diff --git a/palladium/persistence.py b/palladium/persistence.py
index d723a3f..890fa48 100644
--- a/palladium/persistence.py
+++ b/palladium/persistence.py
@@ -196,7 +196,17 @@ def read(self, version=None):
 
         with self.io.open(fname, 'rb') as fh:
             with gzip.open(fh, 'rb') as f:
-                return pickle.load(f)
+                model = pickle.load(f)
+
+        attachments = annotate(model).get('__attachments__', [])
+        for key in attachments:
+            fname_attach = self.attach_filename(version=version, key=key)
+            if self.io.exists(fname_attach):
+                with self.io.open(fname_attach, 'rb') as f:
+                    data_attach = base64.b64encode(f.read())
+                annotate(model, {key: data_attach})
+
+        return model
 
     def write(self, model):
         last_version = 0
@@ -207,11 +217,29 @@ def write(self, model):
         version = last_version + 1
         li.append(annotate(model, {'version': version}))
 
+        annotations = annotate(model)
+        attachments = {
+            key: data
+            for key, data in annotations.items()
+            if key.startswith('attachments/')
+        }
+        if attachments:
+            for key in attachments:
+                del annotations[key]
+            annotations['__attachments__'] = tuple(attachments.keys())
+            annotate(model, annotations)
+
         fname = self.path.format(version=version) + '.pkl.gz'
         with self.io.open(fname, 'wb') as fh:
             with gzip.open(fh, 'wb') as f:
                 pickle.dump(model, f)
 
+        if attachments:
+            for key, data in attachments.items():
+                fname_attach = self.attach_filename(version=version, key=key)
+                with self.io.open(fname_attach, 'wb') as f:
+                    f.write(base64.b64decode(data))
+
         self._update_md({'models': li})
         return version
 
@@ -230,19 +258,33 @@ def activate(self, version):
         self._update_md({'properties': md['properties']})
 
     def delete(self, version):
-        md = self._read_md()
-        versions = [m['version'] for m in md['models']]
         version = int(version)
-        if version not in versions:
+        md = self._read_md()
+        try:
+            model_md = [m for m in md['models'] if m['version'] == version][0]
+        except IndexError:
             raise LookupError("No such version: {}".format(version))
+
         self._update_md({
             'models': [m for m in md['models'] if m['version'] != version]})
         self.io.remove(self.path.format(version=version) + '.pkl.gz')
 
+        attachments = model_md.get('__attachments__', [])
+        for key in attachments:
+            fname_attach = self.attach_filename(version=version, key=key)
+            if self.io.exists(fname_attach):
+                self.io.remove(fname_attach)
+
     @property
     def _md_filename(self):
         return self.path.format(version='metadata') + '.json'
 
+    def attach_filename(self, version, key):
+        return (
+            self.path.format(version=version) +
+            '-{}'.format(key[len('attachments/'):])
+        )
+
     def _read_md(self):
         if self.io.exists(self._md_filename):
             with self.io.open(self._md_filename, 'r') as f:
diff --git a/palladium/tests/test_persistence.py b/palladium/tests/test_persistence.py
index e50f9b3..eec3c0f 100644
--- a/palladium/tests/test_persistence.py
+++ b/palladium/tests/test_persistence.py
@@ -13,6 +13,8 @@
 import requests_mock
 import pytest
 
+from palladium.interfaces import annotate
+
 
 class Dummy:
     def __init__(self, **kwargs):
@@ -74,12 +76,14 @@ def test_read(self, File):
             patch('palladium.persistence.File.list_properties') as lp,\
             patch('palladium.persistence.os.path.exists') as exists,\
             patch('palladium.persistence.open') as open,\
+            patch('palladium.persistence.annotate') as annotate,\
             patch('palladium.persistence.gzip.open') as gzopen,\
             patch('palladium.persistence.pickle.load') as load:
             lm.return_value = [{'version': 99}]
             lp.return_value = {'active-model': '99'}
             exists.side_effect = lambda fn: fn == '/models/model-99.pkl.gz'
             open.return_value = MagicMock()
+            annotate.return_value = {}
             result = File('/models/model-{version}').read()
             open.assert_called_with('/models/model-99.pkl.gz', 'rb')
             assert result == load.return_value
@@ -89,11 +93,13 @@ def test_read_with_version(self, File):
         with patch('palladium.persistence.File.list_models') as lm,\
             patch('palladium.persistence.os.path.exists') as exists,\
             patch('palladium.persistence.open') as open,\
+            patch('palladium.persistence.annotate') as annotate,\
             patch('palladium.persistence.gzip.open') as gzopen,\
             patch('palladium.persistence.pickle.load') as load:
             lm.return_value = [{'version': 99}]
             exists.side_effect = lambda fn: fn == '/models/model-432.pkl.gz'
             open.return_value = MagicMock()
+            annotate.return_value = {}
             result = File('/models/model-{version}').read(432)
             open.assert_called_with('/models/model-432.pkl.gz', 'rb')
             assert result == load.return_value
@@ -385,6 +391,69 @@ def test_upgrade_1_0_no_metadata(self, File):
             dump.assert_called_with(new_md, open_rv, indent=4)
 
 
+class TestFileAttachments:
+    @pytest.fixture
+    def persister(self, tmpdir):
+        from palladium.persistence import File
+        model1 = Dummy()
+        annotate(model1, {'attachments/myatt.txt': 'aGV5',
+                          'attachments/my2ndatt.txt': 'aG8='})
+        model2 = Dummy()
+        annotate(model2, {'attachments/myatt.txt': 'aG8='})
+        persister = File(str(tmpdir) + '/model-{version}')
+        persister.write(model1)
+        persister.write(model2)
+        return persister
+
+    def test_filenames(self, persister, tmpdir):
+        # Attachment files are namespaced by the model:
+        assert sorted(os.listdir(tmpdir)) == [
+            'model-1-my2ndatt.txt', 'model-1-myatt.txt', 'model-1.pkl.gz',
+            'model-2-myatt.txt', 'model-2.pkl.gz',
+            'model-metadata.json',
+        ]
+
+    def test_attachment_file_contents(self, persister, tmpdir):
+        # Attachment data is written to the file system, decoded:
+        with open(tmpdir + '/model-1-myatt.txt', 'rb') as f:
+            assert f.read() == b'hey'
+        with open(tmpdir + '/model-1-my2ndatt.txt', 'rb') as f:
+            assert f.read() == b'ho'
+        with open(tmpdir + '/model-2-myatt.txt', 'rb') as f:
+            assert f.read() == b'ho'
+
+    def test_attachment_not_in_metadata_file(self, persister, tmpdir):
+        # Attachment data is not written to the metadata file:
+        with open(tmpdir + '/model-metadata.json') as f:
+            md = json.loads(f.read())
+        assert len(md['models']) == 2
+        for model_md in md['models']:
+            assert 'attachments/myatt.txt' not in model_md
+
+    def test_attachment_not_in_pickle(self, persister, tmpdir):
+        # Attachment data is not pickled as part of the model:
+        with open(tmpdir + '/model-1.pkl.gz', 'rb') as fh:
+            with gzip.open(fh, 'rb') as f:
+                model1 = pickle.load(f)
+        assert 'attachments/myatt.txt' not in annotate(model1)
+
+    def test_loaded_back_on_read(self, persister, tmpdir):
+        # Attachment is read back from the file into metadata
+        # dictionary on read:
+        model1 = persister.read(version=1)
+        assert annotate(model1)['attachments/myatt.txt'] == b'aGV5'
+        assert annotate(model1)['attachments/my2ndatt.txt'] == b'aG8='
+
+    def test_deleted_on_delete(self, persister, tmpdir):
+        # Attachment files are removed from the file system when a
+        # model is deleted:
+        persister.delete(1)
+        assert sorted(os.listdir(tmpdir)) == [
+            'model-2-myatt.txt', 'model-2.pkl.gz',
+            'model-metadata.json',
+        ]
+
+
 class TestDatabase:
     @pytest.fixture
     def Database(self):
@@ -665,7 +734,7 @@ def handle_put_md(request, context):
     def test_download(self, mocked_requests, persister):
         """ test download and activation of a model """
 
-        expected = Dummy(name='mymodel')
+        expected = Dummy(name='mymodel', __metadata__={})
         zipped_model = gzip.compress(pickle.dumps(expected))
 
         get_md_url = "%s/mymodel-metadata.json" % (self.base_url,)