diff --git a/ChangeLog.rst b/ChangeLog.rst index 1b2f7ce..428622f 100644 --- a/ChangeLog.rst +++ b/ChangeLog.rst @@ -11,6 +11,10 @@ master - Provide programmatically accessible package version number. - Migrate package metadata from setup.py to setup.cfg and specify the PEP-517 build-backend to use with the project. +*Bugfix:* + - Fixed a :code:`KeyError` that happened when saving a Model with :code:`update_fields` specified after updating a + field value with an :code:`F` object (#118). + .. _v1.6.0: 1.6.0 (07/04/2021) diff --git a/src/dirtyfields/dirtyfields.py b/src/dirtyfields/dirtyfields.py index 9d971c7..6011564 100644 --- a/src/dirtyfields/dirtyfields.py +++ b/src/dirtyfields/dirtyfields.py @@ -48,6 +48,12 @@ def _connect_m2m_relations(self): name=self.__class__.__name__)) def _as_dict(self, check_relationship, include_primary_key=True): + """ + Capture the model fields' state as a dictionary. + + Only capture values we are confident are in the database, or would be + saved to the database if self.save() is called. + """ all_field = {} deferred_fields = self.get_deferred_fields() @@ -165,7 +171,18 @@ def reset_state(sender, instance, **kwargs): if field.get_attname() in instance.get_deferred_fields(): continue - instance._original_state[field.name] = new_state[field.name] + if field.name in new_state: + instance._original_state[field.name] = ( + new_state[field.name] + ) + else: + # If we are here it means the field was updated in the DB, + # and we don't know the new value in the database. + # e.g it was updated with an F() expression. + # Because we now don't know the value in the DB, + # we remove it from _original_state, because we can't tell + # if its dirty or not. + del instance._original_state[field.name] else: instance._original_state = new_state diff --git a/tests/test_non_regression.py b/tests/test_non_regression.py index 517f6f1..00c3494 100644 --- a/tests/test_non_regression.py +++ b/tests/test_non_regression.py @@ -1,5 +1,6 @@ import pytest from django.db import IntegrityError +from django.db.models import F from django.test.utils import override_settings from .models import (ModelTest, ModelWithForeignKeyTest, ModelWithNonEditableFieldsTest, @@ -173,3 +174,21 @@ def test_access_deferred_field_doesnt_reset_state(): assert tm_deferred.get_deferred_fields() == set() # previously accessing the deferred field would reset the dirty state. assert tm_deferred.get_dirty_fields() == {"boolean": True} + + +@pytest.mark.django_db +def test_f_objects_and_save_update_fields_works(): + # Non regression test case for bug: + # https://github.com/romgar/django-dirtyfields/issues/118 + tm = ExpressionModelTest.objects.create(counter=0) + assert tm.counter == 0 + + tm.counter = F("counter") + 1 + tm.save() + tm.refresh_from_db() + assert tm.counter == 1 + + tm.counter = F("counter") + 1 + tm.save(update_fields=["counter"]) + tm.refresh_from_db() + assert tm.counter == 2 diff --git a/tests/test_save_fields.py b/tests/test_save_fields.py index 2a08c1e..0e564f5 100644 --- a/tests/test_save_fields.py +++ b/tests/test_save_fields.py @@ -1,6 +1,13 @@ import pytest -from .models import ModelTest, MixedFieldsModelTest, ModelWithForeignKeyTest +from django.db.models import F + +from .models import ( + ExpressionModelTest, + MixedFieldsModelTest, + ModelTest, + ModelWithForeignKeyTest, +) from .utils import assert_number_of_queries_on_regex @@ -139,3 +146,68 @@ def test_save_deferred_field_with_update_fields_behaviour(): tm.save(update_fields=['boolean']) tm.boolean = False assert tm.get_dirty_fields() == {'boolean': True} + + +@pytest.mark.django_db +def test_get_dirty_fields_when_saving_with_f_objects(): + """ + This documents how get_dirty_fields() behaves when updating model fields + with F objects. + """ + + tm = ExpressionModelTest.objects.create(counter=0) + assert tm.counter == 0 + assert tm.get_dirty_fields() == {} + + tm.counter = F("counter") + 1 + # tm.counter field is not considered dirty because it doesn't have a simple + # value in memory we can compare to the original value. + # i.e. we don't know what value it will be in the database after the F + # object is translated into SQL. + assert tm.get_dirty_fields() == {} + + tm.save() + # tm.counter is still an F object after save() - we don't know the new + # value in the database. + assert tm.get_dirty_fields() == {} + + tm.counter = 10 + # even though we have now assigned a literal value to tm.counter, we don't + # know the value in the database, so it is not considered dirty. + assert tm.get_dirty_fields() == {} + + tm.save() + assert tm.get_dirty_fields() == {} + + tm.refresh_from_db() + # if we call refresh_from_db(), we load the database value, + # so we can assign a value and make the field dirty again. + tm.counter = 20 + assert tm.get_dirty_fields() == {"counter": 10} + + +@pytest.mark.django_db +def test_get_dirty_fields_when_saving_with_f_objects_update_fields_specified(): + """ + Same as above but with update_fields specified when saving/refreshing + """ + + tm = ExpressionModelTest.objects.create(counter=0) + assert tm.counter == 0 + assert tm.get_dirty_fields() == {} + + tm.counter = F("counter") + 1 + assert tm.get_dirty_fields() == {} + + tm.save(update_fields={"counter"}) + assert tm.get_dirty_fields() == {} + + tm.counter = 10 + assert tm.get_dirty_fields() == {} + + tm.save(update_fields={"counter"}) + assert tm.get_dirty_fields() == {} + + tm.refresh_from_db(fields={"counter"}) + tm.counter = 20 + assert tm.get_dirty_fields() == {"counter": 10}