Skip to content

Commit

Permalink
Finer grained follow_fk/m2m and generate_fk/m2m parameters for AutoFi…
Browse files Browse the repository at this point in the history
…xture.
  • Loading branch information
gregmuellegger committed Mar 5, 2010
1 parent c1bdc57 commit d64af55
Show file tree
Hide file tree
Showing 5 changed files with 235 additions and 51 deletions.
5 changes: 5 additions & 0 deletions example/testapp/models.py
Expand Up @@ -9,10 +9,15 @@ def __unicode__(self):
return self.name


class Category(models.Model):
name = models.CharField(max_length=50)


class Post(models.Model):
name = models.CharField(max_length=50)
text = models.TextField()
author = models.ForeignKey(Author)
categories = models.ManyToManyField(Category, null=True, blank=True)

def __unicode__(self):
return '%s: %s' % (self.name, self.text)
127 changes: 89 additions & 38 deletions src/django_autofixture/autofixture.py
Expand Up @@ -9,6 +9,70 @@ class CreateInstanceError(Exception):
pass


class Data(object):
'''
Emulates behaviour of ``_value`` but can hold additional data.
I'm happy to use an already existing implementation that is out there. But
for now this serves very well.
'''
def __init__(self, *args, **kwargs):
self._value__ = args.pop(0)
self._list = args
self._dict = kwargs

def __getattr__(self, attr):
return getattr(self._value__, attr)


class Link(object):
'''
Handles logic of following or generating foreignkeys and m2m relations.
'''
def __init__(self, fields=None, default=None):
self.fields = {}
self.subfields = {}
self.default = default

fields = fields or {}
if fields is True:
fields = {'ALL': None}
if not isinstance(fields, dict):
fields = dict([(v, None) for v in fields])
for field, value in fields.items():
try:
fieldname, subfield = field.split('__', 1)
self.subfields.setdefault(fieldname, {})[subfield] = value
except ValueError:
self.fields[field] = value

def __getitem__(self, key):
return self.fields.get(key,
self.fields.get('ALL', self.default))

def __iter__(self):
for field in self.fields:
yield field
for key, value in self.subfields.items():
yield '%s__%s' % (key, value)

def __contains__(self, value):
if 'ALL' in self.fields:
return True
if value in self.fields:
return True
return False

def get_deep_links(self, field):
if 'ALL' in self.fields:
fields = {'ALL': self.fields['ALL']}
else:
fields = self.subfields.get(field, {})
if 'ALL' in fields:
fields = {'ALL': fields['ALL']}
return Link(fields, default=self.default)


class AutoFixture(object):
'''
We don't support the following fields yet:
Expand All @@ -23,10 +87,10 @@ class IGNORE_FIELD(object):
pass

overwrite_defaults = False
follow_fk = True
generate_fk = False
follow_m2m = (1, 5)
generate_m2m = False
follow_fk = Link(True)
generate_fk = Link(False)
follow_m2m = Link('ALL', default=(1,5))
generate_m2m = Link(False, default=(0,0))

none_chance = 0.2
tries = 1000
Expand Down Expand Up @@ -55,7 +119,7 @@ def __init__(self, model,
follow_m2m=None, generate_m2m=None):
'''
Parameters:
``model``:
``model``:
``field_values``: A dictionary with field names of ``model`` as
keys. Values may be static values that are assigned to the
Expand Down Expand Up @@ -106,46 +170,28 @@ def __init__(self, model,
self.overwrite_defaults = overwrite_defaults

if follow_fk is not None:
if not isinstance(follow_fk, Link):
follow_fk = Link(follow_fk)
self.follow_fk = follow_fk
if not hasattr(self.follow_fk, '__iter__'):
if self.follow_fk:
self.follow_fk = [field.name
for field in model._meta.fields
if isinstance(field, related.ForeignKey)]
else:
self.follow_fk = ()

if generate_fk is not None:
if not isinstance(generate_fk, Link):
generate_fk = Link(generate_fk)
self.generate_fk = generate_fk
if not hasattr(self.generate_fk, '__iter__'):
if self.generate_fk:
self.generate_fk = [field.name
for field in model._meta.fields
if isinstance(field, related.ForeignKey)]
else:
self.generate_fk = ()

if follow_m2m is not None:
if not isinstance(follow_m2m, dict):
follow_m2m = Link({'ALL': follow_m2m})
elif not isinstance(follow_m2m, Link):
follow_m2m = Link(follow_m2m)
self.follow_m2m = follow_m2m
if not isinstance(self.follow_m2m, dict):
if self.follow_m2m:
min_count, max_count = self.follow_m2m
self.follow_m2m = {}
for field in model._meta.many_to_many:
self.follow_m2m[field.name] = min_count, max_count
else:
self.follow_m2m = {}

if generate_m2m is not None:
if not isinstance(generate_m2m, dict):
generate_m2m = Link({'ALL': generate_m2m})
elif not isinstance(generate_m2m, Link):
generate_m2m = Link(generate_m2m)
self.generate_m2m = generate_m2m
if not isinstance(self.generate_m2m, dict):
if self.generate_m2m:
min_count, max_count = self.generate_m2m
self.generate_m2m = {}
for field in model._meta.many_to_many:
self.generate_m2m[field.name] = min_count, max_count
else:
self.generate_m2m = {}

for constraint in self.default_constraints:
self.add_constraint(constraint)
Expand Down Expand Up @@ -188,7 +234,10 @@ def get_generator(self, field):
# if generate_fk is set, follow_fk is ignored.
if field.name in self.generate_fk:
return generators.InstanceGenerator(
AutoFixture(field.rel.to),
AutoFixture(
field.rel.to,
follow_fk=self.follow_fk.get_deep_links(field.name),
generate_fk=self.generate_fk.get_deep_links(field.name)),
limit_choices_to=field.rel.limit_choices_to)
if field.name in self.follow_fk:
return generators.InstanceSelector(
Expand Down Expand Up @@ -242,9 +291,11 @@ def get_generator(self, field):
if isinstance(field, fields.SlugField):
generator = generators.SlugGenerator
elif isinstance(field, fields.EmailField):
generator = generators.EmailGenerator
return generators.EmailGenerator(
max_length=min(field.max_length, 30))
elif isinstance(field, fields.URLField):
generator = generators.URLGenerator
return generators.URLGenerator(
max_length=min(field.max_length, 25))
elif field.max_length > 15:
return generators.LoremSentenceGenerator(
common=False,
Expand Down
20 changes: 16 additions & 4 deletions src/django_autofixture/generators.py
Expand Up @@ -180,9 +180,14 @@ def __init__(self, none=True, *args, **kwargs):


class DateTimeGenerator(Generator):
min_date = datetime.datetime.now() - datetime.timedelta(365 * 5)
max_date = datetime.datetime.now() + datetime.timedelta(365 * 1)

def __init__(self, min_date=None, max_date=None, *args, **kwargs):
self.min_date = min_date or datetime.datetime.min
self.max_date = max_date or datetime.datetime.max
if min_date is not None:
self.min_date = min_date
if max_date is not None:
self.max_date = max_date
assert self.min_date < self.max_date
super(DateTimeGenerator, self).__init__(*args, **kwargs)

Expand All @@ -193,9 +198,14 @@ def generate(self):


class DateGenerator(Generator):
min_date = datetime.date.today() - datetime.timedelta(365 * 5)
max_date = datetime.date.today() + datetime.timedelta(365 * 1)

def __init__(self, min_date=None, max_date=None, *args, **kwargs):
self.min_date = min_date or datetime.date.min
self.max_date = max_date or datetime.date.max
if min_date is not None:
self.min_date = min_date
if max_date is not None:
self.max_date = max_date
assert self.min_date < self.max_date
super(DateGenerator, self).__init__(*args, **kwargs)

Expand All @@ -206,6 +216,7 @@ def generate(self):
return date
return datetime.date(date.year, date.month, date.day)


class DecimalGenerator(Generator):
coerce_type = Decimal

Expand All @@ -226,6 +237,7 @@ def generate(self):
10 ** self.decimal_places)
return value


class EmailGenerator(StringGenerator):
chars = string.ascii_lowercase

Expand Down
41 changes: 33 additions & 8 deletions src/django_autofixture/management/commands/loadtestdata.py
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
from django.db import models
from django.db.transaction import commit_on_success
from django.core.management.base import BaseCommand, CommandError
from django_autofixture import signals, AutoFixture
Expand Down Expand Up @@ -37,17 +38,42 @@ class Command(BaseCommand):
u'0,0 which means that no related models are created.'),
)

def format_output(self, obj):
output = unicode(obj)
if len(output) > 50:
output = u'%s ...' % output[:50]
return output

def print_instance(self, sender, model, instance, **kwargs):
reprstr = unicode(instance)
if len(reprstr) > 50:
reprstr = u'%s ...' % reprstr[:50]
if self.verbosity < 1:
return
print '%s(pk=%s): %s' % (
'%s.%s' % (
model._meta.app_label,
model._meta.object_name),
unicode(instance.pk),
reprstr,
self.format_output(instance),
)
if self.verbosity < 2:
return
for field in instance._meta.fields:
if isinstance(field, models.ForeignKey):
obj = getattr(instance, field.name)
if isinstance(obj, models.Model):
print '| %s (pk=%s): %s' % (
field.name,
obj.pk,
self.format_output(obj))
for field in instance._meta.many_to_many:
qs = getattr(instance, field.name).all()
if qs.count():
print '| %s (count=%d):' % (
field.name,
qs.count())
for obj in qs:
print '| | (pk=%s): %s' % (
obj.pk,
self.format_output(obj))

@commit_on_success
def handle(self, *attrs, **options):
Expand Down Expand Up @@ -76,7 +102,7 @@ def handle(self, *attrs, **options):

overwrite_defaults = options['overwrite_defaults']

verbosity = int(options['verbosity'])
self.verbosity = int(options['verbosity'])

models = []
for attr in attrs:
Expand All @@ -98,9 +124,8 @@ def handle(self, *attrs, **options):
u'Unknown model: %s.%s' % (app_label, model_label))
models.append((model, count))

if verbosity >= 1:
signals.instance_created.connect(
self.print_instance)
signals.instance_created.connect(
self.print_instance)

for model, count in models:
fill = AutoFixture(
Expand Down

0 comments on commit d64af55

Please sign in to comment.