Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion django/contrib/gis/db/models/lookups.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class GISLookup(Lookup):
band_lhs = None

def __init__(self, lhs, rhs):
rhs, *self.rhs_params = rhs if isinstance(rhs, (list, tuple)) else [rhs]
rhs, *self.rhs_params = rhs if isinstance(rhs, (list, tuple)) else (rhs,)
super().__init__(lhs, rhs)
self.template_params = {}
self.process_rhs_params()
Expand Down
6 changes: 3 additions & 3 deletions django/db/models/fields/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,8 +316,8 @@ class JSONExact(lookups.Exact):
def process_rhs(self, compiler, connection):
rhs, rhs_params = super().process_rhs(compiler, connection)
# Treat None lookup values as null.
if rhs == "%s" and rhs_params == [None]:
rhs_params = ["null"]
if rhs == "%s" and (*rhs_params,) == (None,):
rhs_params = ("null",)
if connection.vendor == "mysql":
func = ["JSON_EXTRACT(%s, '$')"] * len(rhs_params)
rhs %= tuple(func)
Expand Down Expand Up @@ -552,7 +552,7 @@ def process_rhs(self, compiler, connection):

def as_oracle(self, compiler, connection):
rhs, rhs_params = super().process_rhs(compiler, connection)
if rhs_params == ["null"]:
if rhs_params and (*rhs_params,) == ("null",):
# Field has key and it's NULL.
has_key_expr = HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name)
has_key_sql, has_key_params = has_key_expr.as_oracle(compiler, connection)
Expand Down
6 changes: 3 additions & 3 deletions django/db/models/lookups.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,12 +232,12 @@ def process_lhs(self, compiler, connection, lhs=None):
lhs_sql = (
connection.ops.lookup_cast(self.lookup_name, field_internal_type) % lhs_sql
)
return lhs_sql, list(params)
return lhs_sql, tuple(params)

def as_sql(self, compiler, connection):
lhs_sql, params = self.process_lhs(compiler, connection)
rhs_sql, rhs_params = self.process_rhs(compiler, connection)
params.extend(rhs_params)
params = (*params, *rhs_params)
rhs_sql = self.get_rhs_op(connection, rhs_sql)
return "%s %s" % (lhs_sql, rhs_sql), params

Expand Down Expand Up @@ -725,7 +725,7 @@ def as_sql(self, compiler, connection):
rhs_sql, _ = self.process_rhs(compiler, connection)
rhs_sql = self.get_direct_rhs_sql(connection, rhs_sql)
start, finish = self.year_lookup_bounds(connection, self.rhs)
params.extend(self.get_bound_params(start, finish))
params = (*params, *self.get_bound_params(start, finish))
return "%s %s" % (lhs_sql, rhs_sql), params
return super().as_sql(compiler, connection)

Expand Down
4 changes: 2 additions & 2 deletions django/db/models/sql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1703,8 +1703,8 @@ def field_as_sql(self, field, get_placeholder, val):
sql, params = "%s", [val]

# The following hook is only used by Oracle Spatial, which sometimes
# needs to yield 'NULL' and [] as its placeholder and params instead
# of '%s' and [None]. The 'NULL' placeholder is produced earlier by
# needs to yield 'NULL' and () as its placeholder and params instead
# of '%s' and (None,). The 'NULL' placeholder is produced earlier by
# OracleOperations.get_geom_placeholder(). The following line removes
# the corresponding None parameter. See ticket #10888.
params = self.connection.ops.modify_insert_params(sql, params)
Expand Down
19 changes: 17 additions & 2 deletions django/test/runner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import ctypes
import faulthandler
import functools
import hashlib
import io
import itertools
Expand Down Expand Up @@ -485,6 +486,16 @@ def _init_worker(
)


def _safe_init_worker(init_worker, counter, *args, **kwargs):
try:
init_worker(counter, *args, **kwargs)
except Exception:
with counter.get_lock():
# Set a value that will not increment above zero any time soon.
counter.value = -1000
raise


def _run_subsuite(args):
"""
Run a suite of tests with a RemoteTestRunner and return a RemoteTestResult.
Expand Down Expand Up @@ -558,7 +569,7 @@ def run(self, result):
counter = multiprocessing.Value(ctypes.c_int, 0)
pool = multiprocessing.Pool(
processes=self.processes,
initializer=self.init_worker.__func__,
initializer=functools.partial(_safe_init_worker, self.init_worker.__func__),
initargs=[
counter,
self.initial_settings,
Expand All @@ -585,7 +596,11 @@ def run(self, result):

try:
subsuite_index, events = test_results.next(timeout=0.1)
except multiprocessing.TimeoutError:
except multiprocessing.TimeoutError as err:
if counter.value < 0:
err.add_note("ERROR: _init_worker failed, see prior traceback")
pool.close()
raise
continue
except StopIteration:
pool.close()
Expand Down
2 changes: 1 addition & 1 deletion tests/custom_lookups/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def as_sql(self, compiler, connection):
real_lhs = self.lhs.lhs
lhs_sql, params = self.process_lhs(compiler, connection, real_lhs)
rhs_sql, rhs_params = self.process_rhs(compiler, connection)
params.extend(rhs_params)
params = (*params, *rhs_params)
# Build SQL where the integer year is concatenated with last month
# and day, then convert that to date. (We try to have SQL like:
# WHERE somecol <= '2013-12-31')
Expand Down
7 changes: 5 additions & 2 deletions tests/test_runner/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

import collections.abc
import functools
import multiprocessing
import os
import sys
Expand Down Expand Up @@ -738,8 +739,10 @@ def run_suite(self, suite, **kwargs):
"test_runner_apps.simple.tests",
]
)
# Initializer must be a function.
self.assertIs(mocked_pool.call_args.kwargs["initializer"], _init_worker)
# Initializer must be a partial function binding _init_worker.
initializer = mocked_pool.call_args.kwargs["initializer"]
self.assertIsInstance(initializer, functools.partial)
self.assertIs(initializer.args[0], _init_worker)
initargs = mocked_pool.call_args.kwargs["initargs"]
self.assertEqual(len(initargs), 7)
self.assertEqual(initargs[5], True) # debug_mode
Expand Down