Skip to content

Commit

Permalink
Slight DI refactor for better disabling and user function support
Browse files Browse the repository at this point in the history
  • Loading branch information
tandemdude committed Feb 17, 2024
1 parent 9266754 commit b6930e6
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 30 deletions.
3 changes: 3 additions & 0 deletions lightbulb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from lightbulb.client import *
from lightbulb.commands import *
from lightbulb.context import *
from lightbulb.internal import *

__all__ = [
"exceptions",
Expand Down Expand Up @@ -54,6 +55,8 @@
"mentionable",
"attachment",
"Context",
"ensure_di_context",
"with_di",
]

__version__ = "3.0.0"
22 changes: 10 additions & 12 deletions lightbulb/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,18 +314,16 @@ async def handle_application_command_interaction(self, interaction: hikari.Comma

LOGGER.debug("%r - invoking command", " ".join(command_path))

token = di._di_container.set(self._di_container)
try:
await execution.ExecutionPipeline(context, self._execution_step_order)._run()
except Exception as e:
# TODO - dispatch to error handler
LOGGER.error(
"Error encountered during invocation of command %r",
" ".join(command_path),
exc_info=(type(e), e, e.__traceback__),
)
finally:
di._di_container.reset(token)
with di.ensure_di_context(self):
try:
await execution.ExecutionPipeline(context, self._execution_step_order)._run()
except Exception as e:
# TODO - dispatch to error handler
LOGGER.error(
"Error encountered during invocation of command %r",
" ".join(command_path),
exc_info=(type(e), e, e.__traceback__),
)


class GatewayEnabledClient(Client):
Expand Down
8 changes: 2 additions & 6 deletions lightbulb/commands/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,10 +258,7 @@ def only_on_mondays(pl: lightbulb.ExecutionPipeline, _: lightbulb.Context) -> No
"""

def inner(func: ExecutionHookFuncT) -> ExecutionHook:
if not isinstance(func, di.LazyInjecting):
func = di.LazyInjecting(func) # type: ignore[reportArgumentType]

return ExecutionHook(step, func)
return ExecutionHook(step, di.with_di(func)) # type: ignore[reportArgumentType]

return inner

Expand Down Expand Up @@ -292,8 +289,7 @@ class ExampleCommand(
async def invoke(self, ctx: lightbulb.Context) -> None:
await ctx.respond("example")
"""
if not isinstance(func, di.LazyInjecting):
func = di.LazyInjecting(func)
func = di.with_di(func)

setattr(func, "__lb_cmd_invoke_method__", "_")
return func
4 changes: 4 additions & 0 deletions lightbulb/internal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@
#
# You should have received a copy of the GNU Lesser General Public License
# along with Lightbulb. If not, see <https://www.gnu.org/licenses/>.
from lightbulb.internal.di import ensure_di_context
from lightbulb.internal.di import with_di

__all__ = ["ensure_di_context", "with_di"]
76 changes: 64 additions & 12 deletions lightbulb/internal/di.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# along with Lightbulb. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations

import contextlib
import contextvars
import inspect
import os
Expand All @@ -25,8 +26,39 @@
if t.TYPE_CHECKING:
import svcs

AnyCallableT = t.TypeVar("AnyCallableT", bound=t.Callable[..., t.Any])
_di_container: contextvars.ContextVar[svcs.Container] = contextvars.ContextVar("_di_container")
from lightbulb import client as client_

AnyAsyncCallableT = t.TypeVar("AnyAsyncCallableT", bound=t.Callable[..., t.Awaitable[t.Any]])


DI_ENABLED: t.Final[bool] = os.environ.get("LIGHTBULB_DI_DISABLED", "false").lower() == "true"
DI_CONTAINER: contextvars.ContextVar[svcs.Container] = contextvars.ContextVar("_di_container")


@contextlib.contextmanager
def ensure_di_context(client: client_.Client) -> t.Generator[None, t.Any, t.Any]:
"""
Context manager that ensures a dependency injection context is available for the nested operations.
Args:
client (:obj:`~lightbulb.client.Client`): The client that "hosts" the dependency injection context.
I.e. knows about the dependencies that will be needed.
Example:
.. code-block:: python
with lightbulb.ensure_di_context(client):
await some_function_that_needs_dependencies()
"""
if DI_ENABLED:
token = DI_CONTAINER.set(client._di_container)
try:
yield
finally:
DI_CONTAINER.reset(token)
else:
yield


def find_injectable_kwargs(
Expand Down Expand Up @@ -80,9 +112,14 @@ class LazyInjecting:
You should generally never have to instantiate this yourself - you should instead use one of the
decorators that applies this to the target automatically.
See Also:
:obj:`~with_di`
:obj:`~lightbulb.commands.execcution.hook`
:obj:`~lightbulb.commands.execution.invoke`
"""

__slots__ = ("_func", "_processed", "_self", "__lb_cmd_invoke_method__")
__slots__ = ("_func", "_processed", "_self")

def __init__(
self,
Expand All @@ -97,11 +134,17 @@ def __get__(self, instance: t.Any, owner: t.Type[t.Any]) -> LazyInjecting:
return LazyInjecting(self._func, instance)
return self

def __getattr__(self, item: str) -> t.Any:
return getattr(self._func, item)

def __setattr__(self, key: str, value: t.Any) -> None:
setattr(self._func, key, value)

async def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
new_kwargs: t.Dict[str, t.Any] = {}
new_kwargs.update(kwargs)

di_container: t.Optional[svcs.Container] = _di_container.get(None)
di_container: t.Optional[svcs.Container] = DI_CONTAINER.get(None)
if di_container is None:
raise RuntimeError("cannot prepare dependency injection as client not yet populated")

Expand All @@ -115,14 +158,23 @@ async def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
return await self._func(*args, **new_kwargs)


if os.environ.get("LIGHTBULB_DI_DISABLED", "false").lower() == "true":
def with_di(func: AnyAsyncCallableT) -> AnyAsyncCallableT:
"""
Enables dependency injection on the decorated asynchronous function. If dependency injection
has been disabled globally then this function does nothing and simply returns the object that was passed in.
class FakeLazyInjecting:
__slots__ = ()
Args:
func: The asynchronous function to enable dependency injection for
# To disable DI we just replace the LazyInjecting class with one that does nothing
# TODO - maybe look into doing this a different way in the future
def __new__(cls, func: AnyCallableT, *args: t.Any, **kwargs: t.Any) -> AnyCallableT:
return func
Returns:
The function with dependency injection enabled, or the same function if DI has been disabled globally.
LazyInjecting = FakeLazyInjecting # type: ignore[reportAssignmentType]
Warning:
Dependency injection relies on a context (note: not a lightbulb :obj:`~lightbulb.context.Context`) being
available when the function is called. If the function is called during a lightbulb-controlled flow
(such as command invocation or error handling), then one will be available automatically. Otherwise,
you will have to set up the context yourself using the helper context manager :obj:`~setup_di_context`.
"""
if DI_ENABLED:
return LazyInjecting(func) # type: ignore[reportReturnType]
return func

0 comments on commit b6930e6

Please sign in to comment.