From 8dde110523f7ae0700f39fc4606df298e57ad67d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Daniel?= Date: Thu, 28 Mar 2024 00:04:47 +0100 Subject: [PATCH] wip --- germanium/tools/django.py | 19 +++++++++++++++---- germanium/tools/http.py | 4 ++-- germanium/tools/rest.py | 4 ++-- 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/germanium/tools/django.py b/germanium/tools/django.py index 7d75352..548b09d 100644 --- a/germanium/tools/django.py +++ b/germanium/tools/django.py @@ -6,6 +6,10 @@ from django.db import DEFAULT_DB_ALIAS, connections +RUN_PRE_COMMIT = "run_pre_commit" +RUN_ON_COMMIT = "run_on_commit" + + class CatchCallbacks: def __init__(self, connection, callback_name): @@ -28,11 +32,18 @@ def end(self): def get_callbacks(self, start_count=None): start_count = self._start_count if start_count is None else start_count - if self._end_count == None: + if self._end_count is None: watching_callbacks = self._watching_callbacks[start_count:] else: watching_callbacks = self._watching_callbacks[start_count:self._end_count] - return [callbacks[-1] for callbacks in watching_callbacks] + + callbacks = [] + for watching_callback in watching_callbacks: + if self._callback_name == RUN_PRE_COMMIT: + callbacks.append(watching_callback[-1]) + elif self._callback_name == RUN_ON_COMMIT: + callbacks.append(watching_callback[-2]) + return callbacks def execute(self): if self._executed: @@ -60,8 +71,8 @@ def execute(self): class CommitCallbacks: def __init__(self, connection): - self.pre_commit = CatchCallbacks(connection, 'run_pre_commit') - self.on_commit = CatchCallbacks(connection, 'run_on_commit') + self.pre_commit = CatchCallbacks(connection, RUN_PRE_COMMIT) + self.on_commit = CatchCallbacks(connection, RUN_ON_COMMIT) def end(self): self.pre_commit.end() diff --git a/germanium/tools/http.py b/germanium/tools/http.py index 9f792e2..ce2cedd 100644 --- a/germanium/tools/http.py +++ b/germanium/tools/http.py @@ -1,12 +1,12 @@ from urllib.parse import urlencode -from django.utils.encoding import force_text +from django.utils.encoding import force_str from .trivials import assert_equal, assert_in def get_full_path(*paths): - string_paths = (force_text(path) for path in paths) + string_paths = (force_str(path) for path in paths) full_path = '/'.join((path[:-1] if path.endswith('/') else path for path in string_paths)) return full_path if full_path.endswith('/') else full_path + '/' diff --git a/germanium/tools/rest.py b/germanium/tools/rest.py index cc80d6d..1391793 100644 --- a/germanium/tools/rest.py +++ b/germanium/tools/rest.py @@ -1,6 +1,6 @@ import json -from django.utils.encoding import force_text +from django.utils.encoding import force_str from .trivials import assert_true, assert_equal, assert_in, fail from .http import assert_http_ok, assert_http_created @@ -12,7 +12,7 @@ def assert_valid_JSON(data, msg='Json is not valid'): can be loaded properly. """ try: - json.loads(force_text(data)) + json.loads(force_str(data)) except: fail(msg)