Skip to content

Commit

Permalink
Merge pull request #6 from ygormutti/pysnap/fix-collection-issues
Browse files Browse the repository at this point in the history
Fix some bugs in GenericFormatter and collection formatters
  • Loading branch information
yourbuddyconner committed Jul 1, 2019
2 parents 004ada4 + a5015f3 commit 5f31565
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 17 deletions.
26 changes: 26 additions & 0 deletions examples/pytest/snapshots/snap_test_demo.py
Expand Up @@ -21,3 +21,29 @@
snapshots['test_multiple_files 1'] = FileSnapshot('snap_test_demo/test_multiple_files 1.txt')

snapshots['test_multiple_files 2'] = FileSnapshot('snap_test_demo/test_multiple_files 2.txt')

snapshots['test_nested_objects dict'] = {
'key': GenericRepr('#')
}

snapshots['test_nested_objects defaultdict'] = {
'key': [
GenericRepr('#')
]
}

snapshots['test_nested_objects list'] = [
GenericRepr('#')
]

snapshots['test_nested_objects tuple'] = (
GenericRepr('#')
,)

snapshots['test_nested_objects set'] = set([
GenericRepr('#')
])

snapshots['test_nested_objects frozenset'] = frozenset([
GenericRepr('#')
])
25 changes: 25 additions & 0 deletions examples/pytest/test_demo.py
@@ -1,4 +1,6 @@
# -*- coding: utf-8 -*-
from collections import defaultdict

from pysnap.file import FileSnapshot


Expand Down Expand Up @@ -58,3 +60,26 @@ def test_multiple_files(snapshot, tmpdir):
temp_file1 = tmpdir.join('example2.txt')
temp_file1.write('Hello, world 2!')
snapshot.assert_match(FileSnapshot(str(temp_file1)))


class ObjectWithBadRepr(object):
def __repr__(self):
return "#"


def test_nested_objects(snapshot):
obj = ObjectWithBadRepr()

dict_ = {'key': obj}
defaultdict_ = defaultdict(list, [('key', [obj])])
list_ = [obj]
tuple_ = (obj,)
set_ = set((obj,))
frozenset_ = frozenset((obj,))

snapshot.assert_match(dict_, 'dict')
snapshot.assert_match(defaultdict_, 'defaultdict')
snapshot.assert_match(list_, 'list')
snapshot.assert_match(tuple_, 'tuple')
snapshot.assert_match(set_, 'set')
snapshot.assert_match(frozenset_, 'frozenset')
2 changes: 1 addition & 1 deletion pysnap/file.py
Expand Up @@ -47,7 +47,7 @@ def get_imports(self):
def format(self, value, indent, formatter):
return repr(value)

def assert_value_matches_snapshot(self, test, test_value, snapshot_value):
def assert_value_matches_snapshot(self, test, test_value, snapshot_value, formatter):
snapshot_path = os.path.join(test.module.snapshot_dir, snapshot_value.path)
files_identical = filecmp.cmp(test_value.path, snapshot_path, shallow=False)
assert files_identical, "Stored file differs from test file"
Expand Down
4 changes: 4 additions & 0 deletions pysnap/formatter.py
Expand Up @@ -19,6 +19,10 @@ def format(self, value, indent):
self.imports[module].add(import_name)
return formatter.format(value, indent, self)

def normalize(self, value):
formatter = self.get_formatter(value)
return formatter.normalize(value, self)

@staticmethod
def get_formatter(value):
for formatter in Formatter.formatters:
Expand Down
62 changes: 47 additions & 15 deletions pysnap/formatters.py
@@ -1,4 +1,5 @@
import six
from collections import defaultdict

from .sorted_dict import SortedDict
from .generic_repr import GenericRepr
Expand All @@ -14,12 +15,15 @@ def format(self, value, indent, formatter):
def get_imports(self):
return ()

def assert_value_matches_snapshot(self, test, test_value, snapshot_value):
test.assert_equals(test_value, snapshot_value)
def assert_value_matches_snapshot(self, test, test_value, snapshot_value, formatter):
test.assert_equals(formatter.normalize(test_value), snapshot_value)

def store(self, test, value):
return value

def normalize(self, value, formatter):
return value


class TypeFormatter(BaseFormatter):
def __init__(self, types, format_func):
Expand All @@ -33,6 +37,19 @@ def format(self, value, indent, formatter):
return self.format_func(value, indent, formatter)


class CollectionFormatter(TypeFormatter):
def normalize(self, value, formatter):
iterator = iter(value.items()) if isinstance(value, dict) else iter(value)
return value.__class__(formatter.normalize(item) for item in iterator)


class DefaultDictFormatter(TypeFormatter):
def normalize(self, value, formatter):
return defaultdict(
value.default_factory, (formatter.normalize(item) for item in value.items())
)


def trepr(s):
text = '\n'.join([repr(line).lstrip('u')[1:-1] for line in s.split('\n')])
quotes, dquotes = "'''", '"""'
Expand Down Expand Up @@ -73,36 +90,48 @@ def format_dict(value, indent, formatter):


def format_list(value, indent, formatter):
return '[%s]' % format_sequence(value, indent, formatter)


def format_sequence(value, indent, formatter):
items = [
formatter.lfchar + formatter.htchar * (indent + 1) + formatter.format(item, indent + 1)
for item in value
]
return '[%s]' % (','.join(items) + formatter.lfchar + formatter.htchar * indent)
return ','.join(items) + formatter.lfchar + formatter.htchar * indent


def format_tuple(value, indent, formatter):
items = [
formatter.lfchar + formatter.htchar * (indent + 1) + formatter.format(item, indent + 1)
for item in value
]
return '(%s,)' % (','.join(items) + formatter.lfchar + formatter.htchar * indent)
return '(%s%s' % (format_sequence(value, indent, formatter), ',)' if len(value) == 1 else ")")


def format_set(value, indent, formatter):
return 'set([%s])' % format_sequence(value, indent, formatter)


def format_frozenset(value, indent, formatter):
return 'frozenset([%s])' % format_sequence(value, indent, formatter)


class GenericFormatter(BaseFormatter):
def can_format(self, value):
return True

def store(self, formatter, value):
def store(self, test, value):
return GenericRepr.from_value(value)

def normalize(self, value, formatter):
return GenericRepr.from_value(value)

def format(self, value, indent, formatter):
# `value` will always be a GenericRepr object because that's what `store` returns.
if not isinstance(value, GenericRepr):
value = GenericRepr.from_value(value)
return repr(value)

def get_imports(self):
return [('pysnap', 'GenericRepr')]

def assert_value_matches_snapshot(self, test, test_value, snapshot_value):
def assert_value_matches_snapshot(self, test, test_value, snapshot_value, formatter):
test_value = GenericRepr.from_value(test_value)
# Assert equality between the representations to provide a nice textual diff.
test.assert_equals(test_value.representation, snapshot_value.representation)
Expand All @@ -111,10 +140,13 @@ def assert_value_matches_snapshot(self, test, test_value, snapshot_value):
def default_formatters():
return [
TypeFormatter(type(None), format_none),
TypeFormatter(dict, format_dict),
TypeFormatter(tuple, format_tuple),
TypeFormatter(list, format_list),
DefaultDictFormatter(defaultdict, format_dict),
CollectionFormatter(dict, format_dict),
CollectionFormatter(tuple, format_tuple),
CollectionFormatter(list, format_list),
CollectionFormatter(set, format_set),
CollectionFormatter(frozenset, format_frozenset),
TypeFormatter(six.string_types, format_str),
TypeFormatter((int, float, complex, bool, bytes, set, frozenset), format_std_type),
TypeFormatter((int, float, complex, bool, bytes), format_std_type),
GenericFormatter()
]
3 changes: 3 additions & 0 deletions pysnap/generic_repr.py
Expand Up @@ -8,6 +8,9 @@ def __repr__(self):
def __eq__(self, other):
return isinstance(other, GenericRepr) and self.representation == other.representation

def __hash__(self):
return hash(self.representation)

@staticmethod
def from_value(value):
representation = repr(value)
Expand Down
2 changes: 1 addition & 1 deletion pysnap/module.py
Expand Up @@ -226,7 +226,7 @@ def store(self, data):

def assert_value_matches_snapshot(self, test_value, snapshot_value):
formatter = Formatter.get_formatter(test_value)
formatter.assert_value_matches_snapshot(self, test_value, snapshot_value)
formatter.assert_value_matches_snapshot(self, test_value, snapshot_value, Formatter())

def assert_equals(self, value, snapshot):
assert value == snapshot
Expand Down

0 comments on commit 5f31565

Please sign in to comment.