From d64af556d46fdbb585c3f7599023f1001816f15d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gregor=20M=C3=BCllegger?= Date: Fri, 5 Mar 2010 01:02:31 +0100 Subject: [PATCH] Finer grained follow_fk/m2m and generate_fk/m2m parameters for AutoFixture. --- example/testapp/models.py | 5 + src/django_autofixture/autofixture.py | 127 ++++++++++++------ src/django_autofixture/generators.py | 20 ++- .../management/commands/loadtestdata.py | 41 ++++-- src/django_autofixture/tests.py | 93 ++++++++++++- 5 files changed, 235 insertions(+), 51 deletions(-) diff --git a/example/testapp/models.py b/example/testapp/models.py index 3b174c4..3688a19 100644 --- a/example/testapp/models.py +++ b/example/testapp/models.py @@ -9,10 +9,15 @@ def __unicode__(self): return self.name +class Category(models.Model): + name = models.CharField(max_length=50) + + class Post(models.Model): name = models.CharField(max_length=50) text = models.TextField() author = models.ForeignKey(Author) + categories = models.ManyToManyField(Category, null=True, blank=True) def __unicode__(self): return '%s: %s' % (self.name, self.text) diff --git a/src/django_autofixture/autofixture.py b/src/django_autofixture/autofixture.py index a7445f4..4ae3baf 100644 --- a/src/django_autofixture/autofixture.py +++ b/src/django_autofixture/autofixture.py @@ -9,6 +9,70 @@ class CreateInstanceError(Exception): pass +class Data(object): + ''' + Emulates behaviour of ``_value`` but can hold additional data. + + I'm happy to use an already existing implementation that is out there. But + for now this serves very well. + ''' + def __init__(self, *args, **kwargs): + self._value__ = args.pop(0) + self._list = args + self._dict = kwargs + + def __getattr__(self, attr): + return getattr(self._value__, attr) + + +class Link(object): + ''' + Handles logic of following or generating foreignkeys and m2m relations. + ''' + def __init__(self, fields=None, default=None): + self.fields = {} + self.subfields = {} + self.default = default + + fields = fields or {} + if fields is True: + fields = {'ALL': None} + if not isinstance(fields, dict): + fields = dict([(v, None) for v in fields]) + for field, value in fields.items(): + try: + fieldname, subfield = field.split('__', 1) + self.subfields.setdefault(fieldname, {})[subfield] = value + except ValueError: + self.fields[field] = value + + def __getitem__(self, key): + return self.fields.get(key, + self.fields.get('ALL', self.default)) + + def __iter__(self): + for field in self.fields: + yield field + for key, value in self.subfields.items(): + yield '%s__%s' % (key, value) + + def __contains__(self, value): + if 'ALL' in self.fields: + return True + if value in self.fields: + return True + return False + + def get_deep_links(self, field): + if 'ALL' in self.fields: + fields = {'ALL': self.fields['ALL']} + else: + fields = self.subfields.get(field, {}) + if 'ALL' in fields: + fields = {'ALL': fields['ALL']} + return Link(fields, default=self.default) + + class AutoFixture(object): ''' We don't support the following fields yet: @@ -23,10 +87,10 @@ class IGNORE_FIELD(object): pass overwrite_defaults = False - follow_fk = True - generate_fk = False - follow_m2m = (1, 5) - generate_m2m = False + follow_fk = Link(True) + generate_fk = Link(False) + follow_m2m = Link('ALL', default=(1,5)) + generate_m2m = Link(False, default=(0,0)) none_chance = 0.2 tries = 1000 @@ -55,7 +119,7 @@ def __init__(self, model, follow_m2m=None, generate_m2m=None): ''' Parameters: - ``model``: + ``model``: ``field_values``: A dictionary with field names of ``model`` as keys. Values may be static values that are assigned to the @@ -106,46 +170,28 @@ def __init__(self, model, self.overwrite_defaults = overwrite_defaults if follow_fk is not None: + if not isinstance(follow_fk, Link): + follow_fk = Link(follow_fk) self.follow_fk = follow_fk - if not hasattr(self.follow_fk, '__iter__'): - if self.follow_fk: - self.follow_fk = [field.name - for field in model._meta.fields - if isinstance(field, related.ForeignKey)] - else: - self.follow_fk = () if generate_fk is not None: + if not isinstance(generate_fk, Link): + generate_fk = Link(generate_fk) self.generate_fk = generate_fk - if not hasattr(self.generate_fk, '__iter__'): - if self.generate_fk: - self.generate_fk = [field.name - for field in model._meta.fields - if isinstance(field, related.ForeignKey)] - else: - self.generate_fk = () if follow_m2m is not None: + if not isinstance(follow_m2m, dict): + follow_m2m = Link({'ALL': follow_m2m}) + elif not isinstance(follow_m2m, Link): + follow_m2m = Link(follow_m2m) self.follow_m2m = follow_m2m - if not isinstance(self.follow_m2m, dict): - if self.follow_m2m: - min_count, max_count = self.follow_m2m - self.follow_m2m = {} - for field in model._meta.many_to_many: - self.follow_m2m[field.name] = min_count, max_count - else: - self.follow_m2m = {} if generate_m2m is not None: + if not isinstance(generate_m2m, dict): + generate_m2m = Link({'ALL': generate_m2m}) + elif not isinstance(generate_m2m, Link): + generate_m2m = Link(generate_m2m) self.generate_m2m = generate_m2m - if not isinstance(self.generate_m2m, dict): - if self.generate_m2m: - min_count, max_count = self.generate_m2m - self.generate_m2m = {} - for field in model._meta.many_to_many: - self.generate_m2m[field.name] = min_count, max_count - else: - self.generate_m2m = {} for constraint in self.default_constraints: self.add_constraint(constraint) @@ -188,7 +234,10 @@ def get_generator(self, field): # if generate_fk is set, follow_fk is ignored. if field.name in self.generate_fk: return generators.InstanceGenerator( - AutoFixture(field.rel.to), + AutoFixture( + field.rel.to, + follow_fk=self.follow_fk.get_deep_links(field.name), + generate_fk=self.generate_fk.get_deep_links(field.name)), limit_choices_to=field.rel.limit_choices_to) if field.name in self.follow_fk: return generators.InstanceSelector( @@ -242,9 +291,11 @@ def get_generator(self, field): if isinstance(field, fields.SlugField): generator = generators.SlugGenerator elif isinstance(field, fields.EmailField): - generator = generators.EmailGenerator + return generators.EmailGenerator( + max_length=min(field.max_length, 30)) elif isinstance(field, fields.URLField): - generator = generators.URLGenerator + return generators.URLGenerator( + max_length=min(field.max_length, 25)) elif field.max_length > 15: return generators.LoremSentenceGenerator( common=False, diff --git a/src/django_autofixture/generators.py b/src/django_autofixture/generators.py index d04550c..bcaeb12 100644 --- a/src/django_autofixture/generators.py +++ b/src/django_autofixture/generators.py @@ -180,9 +180,14 @@ def __init__(self, none=True, *args, **kwargs): class DateTimeGenerator(Generator): + min_date = datetime.datetime.now() - datetime.timedelta(365 * 5) + max_date = datetime.datetime.now() + datetime.timedelta(365 * 1) + def __init__(self, min_date=None, max_date=None, *args, **kwargs): - self.min_date = min_date or datetime.datetime.min - self.max_date = max_date or datetime.datetime.max + if min_date is not None: + self.min_date = min_date + if max_date is not None: + self.max_date = max_date assert self.min_date < self.max_date super(DateTimeGenerator, self).__init__(*args, **kwargs) @@ -193,9 +198,14 @@ def generate(self): class DateGenerator(Generator): + min_date = datetime.date.today() - datetime.timedelta(365 * 5) + max_date = datetime.date.today() + datetime.timedelta(365 * 1) + def __init__(self, min_date=None, max_date=None, *args, **kwargs): - self.min_date = min_date or datetime.date.min - self.max_date = max_date or datetime.date.max + if min_date is not None: + self.min_date = min_date + if max_date is not None: + self.max_date = max_date assert self.min_date < self.max_date super(DateGenerator, self).__init__(*args, **kwargs) @@ -206,6 +216,7 @@ def generate(self): return date return datetime.date(date.year, date.month, date.day) + class DecimalGenerator(Generator): coerce_type = Decimal @@ -226,6 +237,7 @@ def generate(self): 10 ** self.decimal_places) return value + class EmailGenerator(StringGenerator): chars = string.ascii_lowercase diff --git a/src/django_autofixture/management/commands/loadtestdata.py b/src/django_autofixture/management/commands/loadtestdata.py index 2a9a878..a4197a7 100644 --- a/src/django_autofixture/management/commands/loadtestdata.py +++ b/src/django_autofixture/management/commands/loadtestdata.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +from django.db import models from django.db.transaction import commit_on_success from django.core.management.base import BaseCommand, CommandError from django_autofixture import signals, AutoFixture @@ -37,17 +38,42 @@ class Command(BaseCommand): u'0,0 which means that no related models are created.'), ) + def format_output(self, obj): + output = unicode(obj) + if len(output) > 50: + output = u'%s ...' % output[:50] + return output + def print_instance(self, sender, model, instance, **kwargs): - reprstr = unicode(instance) - if len(reprstr) > 50: - reprstr = u'%s ...' % reprstr[:50] + if self.verbosity < 1: + return print '%s(pk=%s): %s' % ( '%s.%s' % ( model._meta.app_label, model._meta.object_name), unicode(instance.pk), - reprstr, + self.format_output(instance), ) + if self.verbosity < 2: + return + for field in instance._meta.fields: + if isinstance(field, models.ForeignKey): + obj = getattr(instance, field.name) + if isinstance(obj, models.Model): + print '| %s (pk=%s): %s' % ( + field.name, + obj.pk, + self.format_output(obj)) + for field in instance._meta.many_to_many: + qs = getattr(instance, field.name).all() + if qs.count(): + print '| %s (count=%d):' % ( + field.name, + qs.count()) + for obj in qs: + print '| | (pk=%s): %s' % ( + obj.pk, + self.format_output(obj)) @commit_on_success def handle(self, *attrs, **options): @@ -76,7 +102,7 @@ def handle(self, *attrs, **options): overwrite_defaults = options['overwrite_defaults'] - verbosity = int(options['verbosity']) + self.verbosity = int(options['verbosity']) models = [] for attr in attrs: @@ -98,9 +124,8 @@ def handle(self, *attrs, **options): u'Unknown model: %s.%s' % (app_label, model_label)) models.append((model, count)) - if verbosity >= 1: - signals.instance_created.connect( - self.print_instance) + signals.instance_created.connect( + self.print_instance) for model, count in models: fill = AutoFixture( diff --git a/src/django_autofixture/tests.py b/src/django_autofixture/tests.py index c117e2d..199af2f 100644 --- a/src/django_autofixture/tests.py +++ b/src/django_autofixture/tests.py @@ -5,7 +5,8 @@ from django.db import models from django.test import TestCase from django_autofixture import generators -from django_autofixture.autofixture import AutoFixture, CreateInstanceError +from django_autofixture.autofixture import AutoFixture, CreateInstanceError, \ + Link def y2k(): @@ -23,6 +24,16 @@ class OtherSimpleModel(models.Model): name = models.CharField(max_length=50) +class DeepLinkModel1(models.Model): + related = models.ForeignKey('SimpleModel') + related2 = models.ForeignKey('SimpleModel', + related_name='deeplinkmodel1_rel2', + null=True, blank=True) + +class DeepLinkModel2(models.Model): + related = models.ForeignKey('DeepLinkModel1') + + class BasicModel(models.Model): chars = models.CharField(max_length=50) blankchars = models.CharField(max_length=100, blank=True) @@ -160,6 +171,25 @@ def test_generate_foreignkeys(self): self.assertEqual(obj.related.__class__, BasicModel) self.assertEqual(obj.limitedfk.name, 'foo') + def test_deep_generate_foreignkeys(self): + filler = AutoFixture( + DeepLinkModel2, + generate_fk=True) + for obj in filler.create(10): + self.assertEqual(obj.related.__class__, DeepLinkModel1) + self.assertEqual(obj.related.related.__class__, SimpleModel) + self.assertEqual(obj.related.related2.__class__, SimpleModel) + + def test_deep_generate_foreignkeys2(self): + filler = AutoFixture( + DeepLinkModel2, + follow_fk=False, + generate_fk=('related', 'related__related')) + for obj in filler.create(10): + self.assertEqual(obj.related.__class__, DeepLinkModel1) + self.assertEqual(obj.related.related.__class__, SimpleModel) + self.assertEqual(obj.related.related2, None) + def test_generate_only_some_foreignkeys(self): filler = AutoFixture( RelatedModel, @@ -333,3 +363,64 @@ def test_instance_selector(self): # works also with queryset as argument result = generators.InstanceSelector(SimpleModel.objects.all()).generate() self.assertEqual(result.__class__, SimpleModel) + + +class TestLinkClass(TestCase): + def test_flat_link(self): + link = Link(('foo', 'bar')) + self.assertTrue('foo' in link) + self.assertTrue('bar' in link) + self.assertFalse('spam' in link) + + self.assertEqual(link['foo'], None) + self.assertEqual(link['spam'], None) + + def test_nested_links(self): + link = Link(('foo', 'foo__bar', 'spam__ALL')) + self.assertTrue('foo' in link) + self.assertFalse('spam' in link) + self.assertFalse('egg' in link) + + foolink = link.get_deep_links('foo') + self.assertTrue('bar' in foolink) + self.assertFalse('egg' in foolink) + + spamlink = link.get_deep_links('spam') + self.assertTrue('bar' in spamlink) + self.assertTrue('egg' in spamlink) + + def test_links_with_value(self): + link = Link({'foo': 1, 'spam__egg': 2}, default=0) + self.assertTrue('foo' in link) + self.assertEqual(link['foo'], 1) + self.assertFalse('spam' in link) + self.assertEqual(link['spam'], 0) + + spamlink = link.get_deep_links('spam') + self.assertTrue('egg' in spamlink) + self.assertEqual(spamlink['bar'], 0) + self.assertEqual(spamlink['egg'], 2) + + def test_always_true_link(self): + link = Link(True) + self.assertTrue('field' in link) + self.assertTrue('any' in link) + + link = link.get_deep_links('field') + self.assertTrue('field' in link) + self.assertTrue('any' in link) + + link = Link(True) + self.assertTrue('field' in link) + self.assertTrue('any' in link) + + link = link.get_deep_links('field') + self.assertTrue('field' in link) + self.assertTrue('any' in link) + + def test_inherit_always_true_value(self): + link = Link({'ALL': 1}) + self.assertEqual(link['foo'], 1) + + sublink = link.get_deep_links('foo') + self.assertEqual(sublink['bar'], 1)