Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/linters.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ jobs:
uses: actions/setup-python@v6
with:
python-version: '3.13'
- run: python -m pip install "isort<6"
- run: python -m pip install isort
- name: isort
# Pinned to v3.0.0.
uses: liskin/gh-problem-matcher-wrap@e7b7beaaafa52524748b31a381160759d68d61fb
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ repos:
files: 'docs/.*\.txt$'
args: ["--rst-literal-block"]
- repo: https://github.com/PyCQA/isort
rev: 5.13.2
rev: 7.0.0
hooks:
- id: isort
- repo: https://github.com/PyCQA/flake8
Expand Down
30 changes: 26 additions & 4 deletions django/contrib/contenttypes/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
ReverseManyToOneDescriptor,
lazy_related_operation,
)
from django.db.models.query import prefetch_related_objects
from django.db.models.query_utils import PathInfo
from django.db.models.sql import AND
from django.db.models.sql.where import WhereNode
Expand Down Expand Up @@ -200,11 +201,13 @@ def get_prefetch_querysets(self, instances, querysets=None):
for ct_id, fkeys in fk_dict.items():
if ct_id in custom_queryset_dict:
# Return values from the custom queryset, if provided.
ret_val.extend(custom_queryset_dict[ct_id].filter(pk__in=fkeys))
queryset = custom_queryset_dict[ct_id].filter(pk__in=fkeys)
else:
instance = instance_dict[ct_id]
ct = self.field.get_content_type(id=ct_id, using=instance._state.db)
ret_val.extend(ct.get_all_objects_for_this_type(pk__in=fkeys))
queryset = ct.get_all_objects_for_this_type(pk__in=fkeys)

ret_val.extend(queryset.fetch_mode(instances[0]._state.fetch_mode))

# For doing the join in Python, we have to match both the FK val and
# the content type, so we use a callable that returns a (fk, class)
Expand Down Expand Up @@ -253,6 +256,15 @@ def __get__(self, instance, cls=None):
return rel_obj
else:
rel_obj = None

instance._state.fetch_mode.fetch(self, instance)
return self.field.get_cached_value(instance)

def fetch_one(self, instance):
f = self.field.model._meta.get_field(self.field.ct_field)
ct_id = getattr(instance, f.attname, None)
pk_val = getattr(instance, self.field.fk_field)
rel_obj = None
if ct_id is not None:
ct = self.field.get_content_type(id=ct_id, using=instance._state.db)
try:
Expand All @@ -261,8 +273,14 @@ def __get__(self, instance, cls=None):
)
except ObjectDoesNotExist:
pass
else:
rel_obj._state.fetch_mode = instance._state.fetch_mode
self.field.set_cached_value(instance, rel_obj)
return rel_obj

def fetch_many(self, instances):
is_cached = self.field.is_cached
missing_instances = [i for i in instances if not is_cached(i)]
return prefetch_related_objects(missing_instances, self.field.name)

def __set__(self, instance, value):
ct = None
Expand Down Expand Up @@ -622,7 +640,11 @@ def _apply_rel_filters(self, queryset):
Filter the queryset for the instance this manager is bound to.
"""
db = self._db or router.db_for_read(self.model, instance=self.instance)
return queryset.using(db).filter(**self.core_filters)
return (
queryset.using(db)
.fetch_mode(self.instance._state.fetch_mode)
.filter(**self.core_filters)
)

def _remove_prefetched_objects(self):
try:
Expand Down
6 changes: 6 additions & 0 deletions django/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,12 @@ class FieldError(Exception):
pass


class FieldFetchBlocked(FieldError):
"""On-demand fetching of a model field blocked."""

pass


NON_FIELD_ERRORS = "__all__"


Expand Down
6 changes: 3 additions & 3 deletions django/db/backends/postgresql/compiler.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from django.db.models.sql.compiler import (
from django.db.models.sql.compiler import ( # isort:skip
SQLAggregateCompiler,
SQLCompiler,
SQLDeleteCompiler,
SQLInsertCompiler as BaseSQLInsertCompiler,
SQLUpdateCompiler,
)
from django.db.models.sql.compiler import SQLInsertCompiler as BaseSQLInsertCompiler
from django.db.models.sql.compiler import SQLUpdateCompiler

__all__ = [
"SQLAggregateCompiler",
Expand Down
4 changes: 4 additions & 0 deletions django/db/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
WindowFrame,
WindowFrameExclusion,
)
from django.db.models.fetch_modes import FETCH_ONE, FETCH_PEERS, RAISE
from django.db.models.fields import * # NOQA
from django.db.models.fields import __all__ as fields_all
from django.db.models.fields.composite import CompositePrimaryKey
Expand Down Expand Up @@ -105,6 +106,9 @@
"GeneratedField",
"JSONField",
"OrderWrt",
"FETCH_ONE",
"FETCH_PEERS",
"RAISE",
"Lookup",
"Transform",
"Manager",
Expand Down
25 changes: 22 additions & 3 deletions django/db/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from django.db.models.constants import LOOKUP_SEP
from django.db.models.deletion import CASCADE, Collector
from django.db.models.expressions import DatabaseDefault
from django.db.models.fetch_modes import FETCH_ONE
from django.db.models.fields.composite import CompositePrimaryKey
from django.db.models.fields.related import (
ForeignObjectRel,
Expand Down Expand Up @@ -466,6 +467,14 @@ def __get__(self, instance, cls=None):
return res


class ModelStateFetchModeDescriptor:
def __get__(self, instance, cls=None):
if instance is None:
return self
res = instance.fetch_mode = FETCH_ONE
return res


class ModelState:
"""Store model instance state."""

Expand All @@ -476,6 +485,14 @@ class ModelState:
# on the actual save.
adding = True
fields_cache = ModelStateFieldsCacheDescriptor()
fetch_mode = ModelStateFetchModeDescriptor()
peers = ()

def __getstate__(self):
state = self.__dict__.copy()
# Weak references can't be pickled.
state.pop("peers", None)
return state


class Model(AltersData, metaclass=ModelBase):
Expand Down Expand Up @@ -595,7 +612,7 @@ def __init__(self, *args, **kwargs):
post_init.send(sender=cls, instance=self)

@classmethod
def from_db(cls, db, field_names, values):
def from_db(cls, db, field_names, values, *, fetch_mode=None):
if len(values) != len(cls._meta.concrete_fields):
values_iter = iter(values)
values = [
Expand All @@ -605,6 +622,8 @@ def from_db(cls, db, field_names, values):
new = cls(*values)
new._state.adding = False
new._state.db = db
if fetch_mode is not None:
new._state.fetch_mode = fetch_mode
return new

def __repr__(self):
Expand Down Expand Up @@ -714,8 +733,8 @@ def refresh_from_db(self, using=None, fields=None, from_queryset=None):
should be an iterable of field attnames. If fields is None, then
all non-deferred fields are reloaded.

When accessing deferred fields of an instance, the deferred loading
of the field will call this method.
When fetching deferred fields for a single instance (the FETCH_ONE
fetch mode), the deferred loading uses this method.
"""
if fields is None:
self._prefetched_objects_cache = {}
Expand Down
61 changes: 61 additions & 0 deletions django/db/models/fetch_modes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from django.core.exceptions import FieldFetchBlocked


class FetchMode:
__slots__ = ()

track_peers = False

def fetch(self, fetcher, instance):
raise NotImplementedError("Subclasses must implement this method.")


class FetchOne(FetchMode):
__slots__ = ()

def fetch(self, fetcher, instance):
fetcher.fetch_one(instance)

def __reduce__(self):
return "FETCH_ONE"


FETCH_ONE = FetchOne()


class FetchPeers(FetchMode):
__slots__ = ()

track_peers = True

def fetch(self, fetcher, instance):
instances = [
peer
for peer_weakref in instance._state.peers
if (peer := peer_weakref()) is not None
]
if len(instances) > 1:
fetcher.fetch_many(instances)
else:
fetcher.fetch_one(instance)

def __reduce__(self):
return "FETCH_PEERS"


FETCH_PEERS = FetchPeers()


class Raise(FetchMode):
__slots__ = ()

def fetch(self, fetcher, instance):
klass = instance.__class__.__qualname__
field_name = fetcher.field.name
raise FieldFetchBlocked(f"Fetching of {klass}.{field_name} blocked.") from None

def __reduce__(self):
return "RAISE"


RAISE = Raise()
Loading