From 9b78360515ab5e00a32caea0db69f57cc2b4b402 Mon Sep 17 00:00:00 2001 From: Markus Unterwaditzer Date: Wed, 9 Apr 2014 18:36:46 +0200 Subject: [PATCH] Use safer write to avoid corruption of data. --- tests/storage/test_filesystem.py | 9 +++++++++ vdirsyncer/storage/filesystem.py | 27 +++++++++++++++++++++++++-- 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/tests/storage/test_filesystem.py b/tests/storage/test_filesystem.py index 62c6d46d0..bdd8be370 100644 --- a/tests/storage/test_filesystem.py +++ b/tests/storage/test_filesystem.py @@ -41,3 +41,12 @@ def test_is_not_directory(self, tmpdir): def test_create_is_true(self, tmpdir): self.storage_class(str(tmpdir), '.txt', collection='asd') assert tmpdir.listdir() == [tmpdir.join('asd')] + + def test_broken_data(self, tmpdir): + s = self.storage_class(str(tmpdir), '.txt') + class BrokenItem(object): + raw = b'Ц, Ш, Л, ж, Д, З, Ю' + uid = 'jeezus' + with pytest.raises(UnicodeError): + s.upload(BrokenItem) + assert not tmpdir.listdir() diff --git a/vdirsyncer/storage/filesystem.py b/vdirsyncer/storage/filesystem.py index ec9c4f27d..2b56ce1ed 100644 --- a/vdirsyncer/storage/filesystem.py +++ b/vdirsyncer/storage/filesystem.py @@ -19,6 +19,29 @@ def _get_etag(fpath): return '{:.9f}'.format(os.path.getmtime(fpath)) +class safe_write(object): + f = None + tmppath = None + fpath = None + mode = None + + def __init__(self, fpath, mode): + self.tmppath = fpath + '.tmp' + self.fpath = fpath + self.mode = mode + + def __enter__(self): + self.f = f = open(self.tmppath, self.mode) + self.write = f.write + return self + + def __exit__(self, type, value, tb): + if type is None: + os.rename(self.tmppath, self.fpath) + else: + os.remove(self.tmppath) + + class FilesystemStorage(Storage): '''Saves data in vdir collection @@ -93,7 +116,7 @@ def upload(self, item): fpath = self._get_filepath(href) if os.path.exists(fpath): raise exceptions.AlreadyExistingError(item.uid) - with open(fpath, 'wb+') as f: + with safe_write(fpath, 'wb+') as f: f.write(item.raw.encode(self.encoding)) return href, _get_etag(fpath) @@ -108,7 +131,7 @@ def update(self, href, item, etag): if etag != actual_etag: raise exceptions.WrongEtagError(etag, actual_etag) - with open(fpath, 'wb') as f: + with safe_write(fpath, 'wb') as f: f.write(item.raw.encode(self.encoding)) return _get_etag(fpath)