Skip to content

Commit

Permalink
feat: implement UserAdapter & ChatAdapter, close #83
Browse files Browse the repository at this point in the history
  • Loading branch information
luwqz1 committed Jun 19, 2024
1 parent 08d007f commit 4b00ebd
Show file tree
Hide file tree
Showing 26 changed files with 359 additions and 175 deletions.
10 changes: 5 additions & 5 deletions examples/upload.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import pathlib

from telegrinder import API, Button, Keyboard, Message, Telegrinder, Token
from telegrinder import API, Message, Telegrinder, Token
from telegrinder.bot.rules.is_from import IsPrivate
from telegrinder.modules import logger
from telegrinder.rules import Text
from telegrinder.rules import IsUser, Text
from telegrinder.types import InputFile

api = API(token=Token.from_env())
bot = Telegrinder(api)
kb = (Keyboard().add(Button("Button 1")).add(Button("Button 2"))).get_markup()
cool_bytes = pathlib.Path("assets/satie.jpeg").read_bytes()
cool_bytes = pathlib.Path("examples/assets/satie.jpeg").read_bytes()

logger.set_level("INFO")


@bot.on.message(Text("/photo"))
@bot.on.message(Text("/photo"), IsPrivate() & IsUser())
async def start(message: Message):
await message.answer_photo(
InputFile("satie.jpeg", cool_bytes),
Expand Down
6 changes: 6 additions & 0 deletions telegrinder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,10 @@ async def start(message: Message):
BaseView,
CallbackQueryCute,
CallbackQueryReturnManager,
CallbackQueryRule,
CallbackQueryView,
ChatJoinRequestCute,
ChatJoinRequestRule,
ChatJoinRequestView,
ChatMemberUpdatedCute,
ChatMemberView,
Expand All @@ -63,6 +65,7 @@ async def start(message: Message):
FuncHandler,
InlineQueryCute,
InlineQueryReturnManager,
InlineQueryRule,
MessageCute,
MessageReplyHandler,
MessageReturnManager,
Expand Down Expand Up @@ -153,6 +156,9 @@ async def start(message: Message):
"CallbackQueryView",
"ChatJoinRequest",
"ChatJoinRequestCute",
"CallbackQueryRule",
"ChatJoinRequestRule",
"InlineQueryRule",
"ChatJoinRequestView",
"ChatMemberUpdated",
"ChatMemberUpdatedCute",
Expand Down
4 changes: 3 additions & 1 deletion telegrinder/bot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
register_manager,
)
from .polling import ABCPolling, Polling
from .rules import ABCRule, CallbackQueryRule, MessageRule
from .rules import ABCRule, CallbackQueryRule, ChatJoinRequestRule, InlineQueryRule, MessageRule
from .scenario import ABCScenario, Checkbox, Choice

__all__ = (
Expand All @@ -59,6 +59,8 @@
"CallbackQueryReturnManager",
"CallbackQueryRule",
"CallbackQueryView",
"ChatJoinRequestRule",
"InlineQueryRule",
"ChatJoinRequestCute",
"ChatJoinRequestView",
"ChatMemberUpdatedCute",
Expand Down
12 changes: 4 additions & 8 deletions telegrinder/bot/cute_types/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,12 @@ class UpdateCute(BaseCute[Update], Update, kw_only=True):
api: ABCAPI

@property
def incoming_update(self) -> Option[Model]:
return getattr(
self,
self.update_type.expect("Update object has no incoming update.").value,
)
def incoming_update(self) -> Model:
return getattr(self, self.update_type.value).unwrap()

def get_event(self, event_model: type[ModelT]) -> Option[ModelT]:
match self.incoming_update:
case Some(event) if isinstance(event, event_model):
return Some(event)
if isinstance(self.incoming_update, event_model):
return Some(self.incoming_update)
return Nothing()


Expand Down
2 changes: 1 addition & 1 deletion telegrinder/bot/dispatch/handler/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __repr__(self) -> str:
)

async def check(self, api: ABCAPI, event: Update, ctx: Context | None = None) -> bool:
if self.update_type is not None and self.update_type != event.update_type.unwrap_or_none():
if self.update_type is not None and self.update_type != event.update_type:
return False
ctx = ctx or Context()
temp_ctx = ctx.copy()
Expand Down
3 changes: 3 additions & 0 deletions telegrinder/bot/dispatch/handler/message_reply.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@ def __init__(
is_blocking: bool = True,
as_reply: bool = False,
preset_context: Context | None = None,
**default_params: typing.Any,
) -> None:
self.text = text
self.rules = list(rules)
self.as_reply = as_reply
self.is_blocking = is_blocking
self.default_params = default_params
self.preset_context = preset_context or Context()

def __repr__(self) -> str:
Expand Down Expand Up @@ -51,6 +53,7 @@ async def run(self, event: MessageCute, _: Context) -> typing.Any:
await event.answer(
text=self.text,
reply_parameters=ReplyParameters(event.message_id) if self.as_reply else None,
**self.default_params,
)


Expand Down
13 changes: 7 additions & 6 deletions telegrinder/bot/dispatch/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,17 @@
from telegrinder.bot.dispatch.handler.abc import ABCHandler
from telegrinder.bot.rules.abc import ABCRule

T = typing.TypeVar("T", bound=BaseCute)
AdaptTo = typing.TypeVar("AdaptTo")
Event = typing.TypeVar("Event", bound=BaseCute)
_: typing.TypeAlias = typing.Any


async def process_inner(
event: T,
event: Event,
raw_event: Update,
middlewares: list[ABCMiddleware[T]],
handlers: list["ABCHandler[T]"],
return_manager: ABCReturnManager[T] | None = None,
middlewares: list[ABCMiddleware[Event]],
handlers: list["ABCHandler[Event]"],
return_manager: ABCReturnManager[Event] | None = None,
) -> bool:
logger.debug("Processing {!r}...", event.__class__.__name__)
ctx = Context(raw_update=raw_event)
Expand Down Expand Up @@ -58,7 +59,7 @@ async def process_inner(

async def check_rule(
api: ABCAPI,
rule: "ABCRule[T]",
rule: "ABCRule[Event, AdaptTo]",
update: Update,
ctx: Context,
) -> bool:
Expand Down
6 changes: 1 addition & 5 deletions telegrinder/bot/dispatch/view/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,7 @@ def get_event_type(cls) -> Option[type[EventType]]:

@staticmethod
def get_raw_event(update: Update) -> Option[Model]:
match update.update_type:
case Some(update_type):
return getattr(update, update_type.value)
case _:
return Nothing()
return getattr(update, update.update_type.value)

@typing.overload
@classmethod
Expand Down
2 changes: 1 addition & 1 deletion telegrinder/bot/dispatch/view/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ async def check(self, event: Update) -> bool:
return (
True
if self.update_type is None
else self.update_type == event.update_type.unwrap_or_none()
else self.update_type == event.update_type
)


Expand Down
62 changes: 44 additions & 18 deletions telegrinder/bot/rules/abc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import inspect
import typing
from abc import ABC, abstractmethod

import typing_extensions as typing

from telegrinder.bot.cute_types import BaseCute, MessageCute, UpdateCute
from telegrinder.bot.dispatch.context import Context
from telegrinder.bot.dispatch.process import check_rule
Expand All @@ -10,13 +11,14 @@
from telegrinder.tools.magic import cache_translation, get_cached_translation
from telegrinder.types.objects import Update as UpdateObject

T = typing.TypeVar("T", bound=BaseCute)
AdaptTo = typing.TypeVar("AdaptTo", default=UpdateCute)
EventCute = typing.TypeVar("EventCute", bound=BaseCute, default=UpdateCute)

Message: typing.TypeAlias = MessageCute
Update: typing.TypeAlias = UpdateCute


def with_caching_translations(func):
def with_caching_translations(func: typing.Callable[..., typing.Any]):
"""Should be used as decorator for .translate method. Caches rule translations."""

async def wrapper(self: "ABCRule[typing.Any]", translator: ABCTranslator):
Expand All @@ -29,15 +31,15 @@ async def wrapper(self: "ABCRule[typing.Any]", translator: ABCTranslator):
return wrapper


class ABCRule(ABC, typing.Generic[T]):
adapter: ABCAdapter[UpdateObject, T] = RawUpdateAdapter() # type: ignore
requires: list["ABCRule[T]"] = []
class ABCRule(ABC, typing.Generic[EventCute, AdaptTo]):
adapter: ABCAdapter[UpdateObject, AdaptTo] = RawUpdateAdapter() # type: ignore
requires: list["ABCRule[EventCute]"] = []

@abstractmethod
async def check(self, event: T, ctx: Context) -> bool:
async def check(self, event: AdaptTo, ctx: Context) -> bool:
pass

def __init_subclass__(cls, requires: list["ABCRule[T]"] | None = None):
def __init_subclass__(cls, requires: list["ABCRule[EventCute, AdaptTo]"] | None = None) -> None:
"""Merges requirements from inherited classes and rule-specific requirements."""

requirements = []
Expand All @@ -48,17 +50,41 @@ def __init_subclass__(cls, requires: list["ABCRule[T]"] | None = None):
requirements.extend(requires or ())
cls.requires = list(dict.fromkeys(requirements))

def __and__(self, other: "ABCRule[T]"):
def __and__(self, other: "ABCRule[EventCute, AdaptTo]") -> "AndRule[EventCute, AdaptTo]":
"""And Rule.
```python
rule = HasText() & HasCaption()
rule #> AndRule(HasText(), HasCaption()) -> True if all rules in an AndRule are True, otherwise False.
```
"""

return AndRule(self, other)

def __or__(self, other: "ABCRule[T]"):
def __or__(self, other: "ABCRule[EventCute, AdaptTo]") -> "OrRule[EventCute, AdaptTo]":
"""Or Rule.
```python
rule = HasText() | HasCaption()
rule #> OrRule(HasText(), HasCaption()) -> True if any rule in an OrRule are True, otherwise False.
```
"""

return OrRule(self, other)

def __neg__(self) -> "ABCRule[T]":
def __invert__(self) -> "NotRule[EventCute, AdaptTo]":
"""Not Rule.
```python
rule = ~HasText()
rule # NotRule(HasText()) -> True if rule returned False, otherwise False.
```
"""

return NotRule(self)

def __repr__(self) -> str:
return "<rule: {!r}, adapter: {!r}>".format(
return "<{}: adapter={!r}>".format(
self.__class__.__name__,
self.adapter,
)
Expand All @@ -67,8 +93,8 @@ async def translate(self, translator: ABCTranslator) -> typing.Self:
return self


class AndRule(ABCRule[T]):
def __init__(self, *rules: ABCRule[T]):
class AndRule(ABCRule[EventCute, AdaptTo]):
def __init__(self, *rules: ABCRule[EventCute, AdaptTo]) -> None:
self.rules = rules

async def check(self, event: Update, ctx: Context) -> bool:
Expand All @@ -80,8 +106,8 @@ async def check(self, event: Update, ctx: Context) -> bool:
return True


class OrRule(ABCRule[T]):
def __init__(self, *rules: ABCRule[T]):
class OrRule(ABCRule[EventCute, AdaptTo]):
def __init__(self, *rules: ABCRule[EventCute, AdaptTo]) -> None:
self.rules = rules

async def check(self, event: Update, ctx: Context) -> bool:
Expand All @@ -93,8 +119,8 @@ async def check(self, event: Update, ctx: Context) -> bool:
return False


class NotRule(ABCRule[T]):
def __init__(self, rule: ABCRule[T]):
class NotRule(ABCRule[EventCute, AdaptTo]):
def __init__(self, rule: ABCRule[EventCute, AdaptTo]) -> None:
self.rule = rule

async def check(self, event: Update, ctx: Context) -> bool:
Expand Down
2 changes: 2 additions & 0 deletions telegrinder/bot/rules/adapter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from .errors import AdapterError
from .event import EventAdapter
from .raw_update import RawUpdateAdapter
from .user import UserAdapter

__all__ = (
"ABCAdapter",
"AdapterError",
"EventAdapter",
"RawUpdateAdapter",
"UserAdapter",
)
8 changes: 4 additions & 4 deletions telegrinder/bot/rules/adapter/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
from telegrinder.bot.rules.adapter.errors import AdapterError
from telegrinder.model import Model

UpdateT = typing.TypeVar("UpdateT", bound=Model)
CuteT = typing.TypeVar("CuteT", bound=BaseCute)
From = typing.TypeVar("From", bound=Model)
To = typing.TypeVar("To")


class ABCAdapter(abc.ABC, typing.Generic[UpdateT, CuteT]):
class ABCAdapter(abc.ABC, typing.Generic[From, To]):
@abc.abstractmethod
async def adapt(self, api: ABCAPI, update: UpdateT) -> Result[CuteT, AdapterError]:
async def adapt(self, api: ABCAPI, update: From) -> Result[To, AdapterError]:
pass


Expand Down
38 changes: 38 additions & 0 deletions telegrinder/bot/rules/adapter/chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import typing

from fntypes.result import Error, Ok, Result

from telegrinder.api.abc import ABCAPI
from telegrinder.bot.cute_types.base import BaseCute
from telegrinder.bot.rules.adapter.abc import ABCAdapter
from telegrinder.bot.rules.adapter.errors import AdapterError
from telegrinder.bot.rules.adapter.raw_update import RawUpdateAdapter
from telegrinder.bot.rules.adapter.utils import Source, get_by_sources
from telegrinder.types.objects import Chat, Update

ToCute = typing.TypeVar("ToCute", bound=BaseCute)


@typing.runtime_checkable
class HasChat(Source, typing.Protocol):
chat: Chat


class ChatAdapter(ABCAdapter[Update, Chat]):
def __init__(self) -> None:
self.raw_adapter = RawUpdateAdapter()

def __repr__(self) -> str:
return f"<{self.__class__.__name__}: Update -> UpdateCute -> Chat>"

async def adapt(self, api: ABCAPI, update: Update) -> Result[Chat, AdapterError]:
match await self.raw_adapter.adapt(api, update):
case Ok(event):
if (source := get_by_sources(event.incoming_update, HasChat)):
return Ok(source)
return Error(AdapterError(f"{event.incoming_update.__class__.__name__!r} has no chat."))
case Error(_) as error:
return error


__all__ = ("ChatAdapter",)
3 changes: 1 addition & 2 deletions telegrinder/bot/rules/adapter/errors.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
class AdapterError(RuntimeError):
pass
class AdapterError(RuntimeError): ...


__all__ = ("AdapterError",)
Loading

0 comments on commit 4b00ebd

Please sign in to comment.