Commit

Simpler, smarter Analysis. Closes #197.
onyxfish committed Aug 29, 2015
1 parent 7eb5653 commit 3eee866
Showing 2 changed files with 42 additions and 95 deletions.
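
This commit collapses the old per-stage fingerprint and data files into a single cache file per analysis, named for the fingerprint of the stage's trace and source. A minimal usage sketch of the API as it stands after this change; the load/summarize step functions below are made up for illustration and are not part of agate:

    from agate.analysis import Analysis

    def load(data):
        # Each step receives the dict built up by its ancestors and mutates it.
        data['rows'] = [1, 2, 3]

    def summarize(data):
        data['total'] = sum(data['rows'])

    analysis = Analysis(load, cache_dir='.agate')
    analysis.then(summarize)

    # First run executes both steps and writes .agate/<fingerprint>.cache files.
    analysis.run()

    # A second run finds matching cache files and loads both steps from disk.
    analysis.run()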
75 changes: 31 additions & 44 deletions agate/analysis.py
@@ -11,6 +11,8 @@
except ImportError: # pragma: no cover
import pickle

from agate.utils import memoize

class Analysis(object):
"""
An Analysis is a function whose code configuration and output can be
@@ -23,15 +25,16 @@ class Analysis(object):
:param func: A callable that implements the analysis. Must accept a `data`
argument that is the state inherited from its ancestor analyses.
:param parent: The parent analysis of this one, if any.
:param cache_path: Where to stored the cache files for this analysis.
:param cache_dir: Where to store the cache files for this analysis.
"""
def __init__(self, func, parent=None, cache_path='.agate'):
def __init__(self, func, parent=None, cache_dir='.agate'):
self._name = func.__name__
self._func = func
self._parent = parent
self._cache_path = cache_path
self._cache_dir = cache_dir
self._next_analyses = []

@memoize
def _trace(self):
"""
Returns the sequence of Analysis instances that lead to this one.
@@ -41,66 +44,50 @@ def _trace(self):

return [self._name]

@memoize
def _fingerprint(self):
"""
Generate a fingerprint for this analysis function.
"""
hasher = hashlib.md5()

trace = self._trace()
hasher.update('\n'.join(trace))
hasher.update('\n'.join(trace).encode('utf-8'))

source = inspect.getsource(self._func)
hasher.update(source.encode('utf-8'))

return hasher.hexdigest()

def _save_fingerprint(self):
@memoize
def _cache_path(self):
"""
Save the fingerprint of this analysis function to its cache.
Get the full cache path for the current fingerprint.
"""
path = os.path.join(self._cache_path, '%s.fingerprint' % self._name)

if not os.path.exists(self._cache_path):
os.makedirs(self._cache_path)
return os.path.join(self._cache_dir, '%s.cache' % self._fingerprint())

with open(path, 'w') as f:
f.write(self._fingerprint())

def _load_fingerprint(self):
def _check_cache(self):
"""
Load the fingerprint of this analysis function from its cache.
Check if there exists a cache file for the current fingerprint.
"""
path = os.path.join(self._cache_path, '%s.fingerprint' % self._name)

if not os.path.exists(path):
return None
return os.path.exists(self._cache_path())

with open(path) as f:
fingerprint = f.read()

return fingerprint

def _save_data(self, data):
def _save_cache(self, data):
"""
Save the output data for this analysis to its cache.
"""
path = os.path.join(self._cache_path, '%s.data' % self._name)
if not os.path.exists(self._cache_dir):
os.makedirs(self._cache_dir)

f = bz2.BZ2File(path, 'w')
f = bz2.BZ2File(self._cache_path(), 'w')
f.write(pickle.dumps(data))
f.close()

def _load_data(self):
def _load_cache(self):
"""
Load the output data for this analysis from its cache.
"""
path = os.path.join(self._cache_path, '%s.data' % self._name)

if not os.path.exists(path):
raise IOError('Data cache missing at %s' % path)

f = bz2.BZ2File(path)
f = bz2.BZ2File(self._cache_path())
data = pickle.loads(f.read())
f.close()

@@ -115,7 +102,7 @@ def then(self, next_func):
`data` argument that is the state inherited from its ancestor
analyses.
"""
analysis = Analysis(next_func, parent=self, cache_path=self._cache_path)
analysis = Analysis(next_func, parent=self, cache_dir=self._cache_dir)

self._next_analyses.append(analysis)

@@ -144,23 +131,23 @@ def run(self, data={}, refresh=False):
local_data = deepcopy(data)

self._func(local_data)
self._save_fingerprint()
self._save_data(local_data)
self._save_cache(local_data)
else:
if self._fingerprint() != self._load_fingerprint():
fingerprint = self._fingerprint()

if self._check_cache():
print('Loaded from cache: %s' % self._name)

local_data = self._load_cache()
else:
print('Running: %s' % self._name)

local_data = deepcopy(data)

self._func(local_data)
self._save_fingerprint()
self._save_data(local_data)
self._save_cache(local_data)

refresh = True
else:
print('Loaded from cache: %s' % self._name)

local_data = self._load_data()

for analysis in self._next_analyses:
analysis.run(local_data, refresh)
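
Note that _trace(), _fingerprint() and _cache_path() are now wrapped in a memoize decorator imported from agate.utils, which is not shown in this diff. A plausible minimal version for zero-argument methods, included here only to illustrate what the decorator is assumed to do (the actual agate.utils implementation may differ):

    import functools

    def memoize(func):
        """Cache a zero-argument method's return value on its instance."""
        attr = '_memoized_' + func.__name__

        @functools.wraps(func)
        def wrapper(self):
            # Compute once per instance, then reuse the stored value.
            if not hasattr(self, attr):
                setattr(self, attr, func(self))
            return getattr(self, attr)

        return wrapper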
62 changes: 11 additions & 51 deletions tests/test_analysis.py
@@ -49,7 +49,7 @@ def stage_noop(self, data):
pass

def test_data_flow(self):
analysis = Analysis(self.stage1, cache_path=TEST_CACHE)
analysis = Analysis(self.stage1, cache_dir=TEST_CACHE)
analysis.then(self.stage2)

data = {}
@@ -63,7 +63,7 @@ def test_data_flow(self):
self.assertEqual(self.data_after_stage2, { 'stage1': 5, 'stage2': 25 })

def test_caching(self):
analysis = Analysis(self.stage1, cache_path=TEST_CACHE)
analysis = Analysis(self.stage1, cache_dir=TEST_CACHE)
analysis.then(self.stage2)

analysis.run()
@@ -77,87 +77,47 @@ def test_caching(self):
self.assertEqual(self.executed_stage2, 1)

def test_descendent_fingerprint_deleted(self):
analysis = Analysis(self.stage1, cache_path=TEST_CACHE)
analysis.then(self.stage2)
analysis = Analysis(self.stage1, cache_dir=TEST_CACHE)
stage2_analysis = analysis.then(self.stage2)

analysis.run()

self.assertEqual(self.executed_stage1, 1)
self.assertEqual(self.executed_stage2, 1)

path = os.path.join(TEST_CACHE, 'stage2.fingerprint')
os.remove(path)
os.remove(stage2_analysis._cache_path())

analysis.run()

self.assertEqual(self.executed_stage1, 1)
self.assertEqual(self.executed_stage2, 2)

def test_ancestor_fingerprint_deleted(self):
analysis = Analysis(self.stage1, cache_path=TEST_CACHE)
analysis = Analysis(self.stage1, cache_dir=TEST_CACHE)
analysis.then(self.stage2)

analysis.run()

self.assertEqual(self.executed_stage1, 1)
self.assertEqual(self.executed_stage2, 1)

path = os.path.join(TEST_CACHE, 'stage1.fingerprint')
os.remove(path)

analysis.run()

self.assertEqual(self.executed_stage1, 2)
self.assertEqual(self.executed_stage2, 2)

def test_descendent_fingerprint_mismatch(self):
analysis = Analysis(self.stage1, cache_path=TEST_CACHE)
analysis.then(self.stage2)

analysis.run()

self.assertEqual(self.executed_stage1, 1)
self.assertEqual(self.executed_stage2, 1)

path = os.path.join(TEST_CACHE, 'stage2.fingerprint')

with open(path, 'w') as f:
f.write('foo')

analysis.run()

self.assertEqual(self.executed_stage1, 1)
self.assertEqual(self.executed_stage2, 2)

def test_ancestor_fingerprint_mismatch(self):
analysis = Analysis(self.stage1, cache_path=TEST_CACHE)
analysis.then(self.stage2)

analysis.run()

self.assertEqual(self.executed_stage1, 1)
self.assertEqual(self.executed_stage2, 1)

path = os.path.join(TEST_CACHE, 'stage1.fingerprint')

with open(path, 'w') as f:
f.write('foo')
os.remove(analysis._cache_path())

analysis.run()

self.assertEqual(self.executed_stage1, 2)
self.assertEqual(self.executed_stage2, 2)

def test_cache_reused(self):
analysis = Analysis(self.stage1, cache_path=TEST_CACHE)
analysis = Analysis(self.stage1, cache_dir=TEST_CACHE)
analysis.then(self.stage2)

analysis.run()

self.assertEqual(self.executed_stage1, 1)
self.assertEqual(self.executed_stage2, 1)

analysis2 = Analysis(self.stage1, cache_path=TEST_CACHE)
analysis2 = Analysis(self.stage1, cache_dir=TEST_CACHE)
analysis2.then(self.stage2)

analysis2.run()
@@ -166,7 +126,7 @@ def test_cache_reused(self):
self.assertEqual(self.executed_stage2, 1)

def test_ancestor_changed(self):
analysis = Analysis(self.stage1, cache_path=TEST_CACHE)
analysis = Analysis(self.stage1, cache_dir=TEST_CACHE)
noop = analysis.then(self.stage_noop)
noop.then(self.stage2)

@@ -175,7 +135,7 @@ def test_ancestor_changed(self):
self.assertEqual(self.executed_stage1, 1)
self.assertEqual(self.executed_stage2, 1)

analysis2 = Analysis(self.stage1, cache_path=TEST_CACHE)
analysis2 = Analysis(self.stage1, cache_dir=TEST_CACHE)
analysis2.then(self.stage2)

analysis2.run()
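
The rewritten tests drive cache invalidation through _cache_path() directly: deleting a stage's cache file forces that stage and everything downstream of it to re-run, while upstream stages still load from cache. A small illustration of the same behaviour, reusing the hypothetical load/summarize steps from the sketch near the top of this page:

    import os

    analysis = Analysis(load, cache_dir='.agate')
    summarize_analysis = analysis.then(summarize)

    analysis.run()  # both stages execute and are cached

    # Drop only the descendant's cache file.
    os.remove(summarize_analysis._cache_path())

    analysis.run()  # load comes from cache; summarize runs again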
