Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve handling of custom objects in BP.insert/replace_bot #2151

Merged
merged 7 commits into from Nov 14, 2020
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
37 changes: 30 additions & 7 deletions telegram/ext/basepersistence.py
Expand Up @@ -17,7 +17,7 @@
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
"""This module contains the BasePersistence class."""

import warnings
from abc import ABC, abstractmethod
from collections import defaultdict
from copy import copy
Expand Down Expand Up @@ -128,12 +128,12 @@ def set_bot(self, bot: Bot) -> None:
self.bot = bot

@classmethod
def replace_bot(cls, obj: object) -> object:
def replace_bot(cls, obj: object) -> object: # pylint: disable=R0911
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def replace_bot(cls, obj: object) -> object: # pylint: disable=R0911
def replace_bot(cls, obj: object) -> object:

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pylint complains about the methods having too many return statements, so moving that line to the except Exception will make the test fail …

"""
Replaces all instances of :class:`telegram.Bot` that occur within the passed object with
:attr:`REPLACED_BOT`. Currently, this handles objects of type ``list``, ``tuple``, ``set``,
``frozenset``, ``dict``, ``defaultdict`` and objects that have a ``__dict__`` or
``__slot__`` attribute.
``__slot__`` attribute, excluding objects that can't be copied with `copy.copy`.

Args:
obj (:obj:`object`): The object
Expand All @@ -146,7 +146,18 @@ def replace_bot(cls, obj: object) -> object:
if isinstance(obj, (list, tuple, set, frozenset)):
return obj.__class__(cls.replace_bot(item) for item in obj)

new_obj = copy(obj)
try:
new_obj = copy(obj)
except Exception:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
except Exception:
except Exception: # pylint: disable=R0911

warnings.warn(
'BasePersistence.replace_bot caught an error while trying to copy an object. '
'Objects that can not be copied will be assumed to not contain a telegram.Bot '
'instance and will not be handled further. See the docs of '
'BasePersistence.replace_bot for more information.',
RuntimeWarning,
Bibo-Joshi marked this conversation as resolved.
Show resolved Hide resolved
)
return obj

if isinstance(obj, (dict, defaultdict)):
new_obj = cast(dict, new_obj)
new_obj.clear()
Expand All @@ -173,7 +184,7 @@ def insert_bot(self, obj: object) -> object: # pylint: disable=R0911
Replaces all instances of :attr:`REPLACED_BOT` that occur within the passed object with
:attr:`bot`. Currently, this handles objects of type ``list``, ``tuple``, ``set``,
``frozenset``, ``dict``, ``defaultdict`` and objects that have a ``__dict__`` or
``__slot__`` attribute.
``__slot__`` attribute, excluding objects that can't be copied with `copy.copy`.

Args:
obj (:obj:`object`): The object
Expand All @@ -183,12 +194,23 @@ def insert_bot(self, obj: object) -> object: # pylint: disable=R0911
"""
if isinstance(obj, Bot):
return self.bot
if obj == self.REPLACED_BOT:
if isinstance(obj, str) and obj == self.REPLACED_BOT:
return self.bot
if isinstance(obj, (list, tuple, set, frozenset)):
return obj.__class__(self.insert_bot(item) for item in obj)

new_obj = copy(obj)
try:
new_obj = copy(obj)
except Exception:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
except Exception:
except Exception: # pylint: disable=R0911

warnings.warn(
'BasePersistence.insert_bot caught an error while trying to copy an object. '
'Objects that can not be copied will be assumed to not contain a telegram.Bot '
'instance and will not be handled further. See the docs of '
'BasePersistence.insert_bot for more information.',
RuntimeWarning,
)
Bibo-Joshi marked this conversation as resolved.
Show resolved Hide resolved
return obj

if isinstance(obj, (dict, defaultdict)):
new_obj = cast(dict, new_obj)
new_obj.clear()
Expand All @@ -207,6 +229,7 @@ def insert_bot(self, obj: object) -> object: # pylint: disable=R0911
self.insert_bot(self.insert_bot(getattr(new_obj, attr_name))),
)
return new_obj

return obj

@abstractmethod
Expand Down
145 changes: 110 additions & 35 deletions tests/test_persistence.py
Expand Up @@ -17,6 +17,7 @@
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
import signal
from threading import Lock

from telegram.utils.helpers import encode_conversations_to_json

Expand Down Expand Up @@ -88,6 +89,42 @@ def update_user_data(self, user_id, data):
return OwnPersistence(store_chat_data=True, store_user_data=True, store_bot_data=True)


@pytest.fixture(scope="function")
def bot_persistence():
class BotPersistence(BasePersistence):
def __init__(self):
super().__init__()
self.bot_data = None
self.chat_data = defaultdict(dict)
self.user_data = defaultdict(dict)

def get_bot_data(self):
return self.bot_data

def get_chat_data(self):
return self.chat_data

def get_user_data(self):
return self.user_data

def get_conversations(self, name):
raise NotImplementedError

def update_bot_data(self, data):
self.bot_data = data

def update_chat_data(self, chat_id, data):
self.chat_data[chat_id] = data

def update_user_data(self, user_id, data):
self.user_data[user_id] = data

def update_conversation(self, name, key, new_state):
raise NotImplementedError

return BotPersistence()


@pytest.fixture(scope="function")
def bot_data():
return {'test1': 'test2', 'test3': {'test4': 'test5'}}
Expand Down Expand Up @@ -437,38 +474,7 @@ class MyUpdate:
dp.process_update(MyUpdate())
assert 'An uncaught error was raised while processing the update' not in caplog.text

def test_bot_replace_insert_bot(self, bot):
class BotPersistence(BasePersistence):
def __init__(self):
super().__init__()
self.bot_data = None
self.chat_data = defaultdict(dict)
self.user_data = defaultdict(dict)

def get_bot_data(self):
return self.bot_data

def get_chat_data(self):
return self.chat_data

def get_user_data(self):
return self.user_data

def get_conversations(self, name):
raise NotImplementedError

def update_bot_data(self, data):
self.bot_data = data

def update_chat_data(self, chat_id, data):
self.chat_data[chat_id] = data

def update_user_data(self, user_id, data):
self.user_data[user_id] = data

def update_conversation(self, name, key, new_state):
raise NotImplementedError

def test_bot_replace_insert_bot(self, bot, bot_persistence):
class CustomSlottedClass:
__slots__ = ('bot',)

Expand Down Expand Up @@ -506,8 +512,6 @@ def replace_bot():

def __eq__(self, other):
if isinstance(other, CustomClass):
# print(self.__dict__)
# print(other.__dict__)
return (
self.bot is other.bot
and self.slotted_object == other.slotted_object
Expand All @@ -520,7 +524,7 @@ def __eq__(self, other):
)
return False

persistence = BotPersistence()
persistence = bot_persistence
persistence.set_bot(bot)
cc = CustomClass()

Expand All @@ -543,6 +547,77 @@ def __eq__(self, other):
assert persistence.get_user_data()[123][1] == cc
assert persistence.get_user_data()[123][1].bot is bot

def test_bot_replace_insert_bot_unpickable_objects(self, bot, bot_persistence, recwarn):
"""Here check that unpickable objects are just returned verbatim."""
persistence = bot_persistence
persistence.set_bot(bot)

class CustomClass:
def __copy__(self):
raise TypeError('UnhandledException')

lock = Lock()

persistence.update_bot_data({1: lock})
assert persistence.bot_data[1] is lock
persistence.update_chat_data(123, {1: lock})
assert persistence.chat_data[123][1] is lock
persistence.update_user_data(123, {1: lock})
assert persistence.user_data[123][1] is lock

assert persistence.get_bot_data()[1] is lock
assert persistence.get_chat_data()[123][1] is lock
assert persistence.get_user_data()[123][1] is lock

cc = CustomClass()

persistence.update_bot_data({1: cc})
assert persistence.bot_data[1] is cc
persistence.update_chat_data(123, {1: cc})
assert persistence.chat_data[123][1] is cc
persistence.update_user_data(123, {1: cc})
assert persistence.user_data[123][1] is cc

assert persistence.get_bot_data()[1] is cc
assert persistence.get_chat_data()[123][1] is cc
assert persistence.get_user_data()[123][1] is cc

assert len(recwarn) == 2
assert str(recwarn[0].message).startswith(
"BasePersistence.replace_bot caught an error while trying to copy an object."
)
assert str(recwarn[1].message).startswith(
"BasePersistence.insert_bot caught an error while trying to copy an object."
)

def test_bot_replace_insert_bot_objects_with_faulty_equality(self, bot, bot_persistence):
"""Here check that trying to compare obj == self.REPLACED_BOT doesn't lead to problems."""
persistence = bot_persistence
persistence.set_bot(bot)

class CustomClass:
def __init__(self, data):
self.data = data

def __eq__(self, other):
raise RuntimeError("Can't be compared")

cc = CustomClass({1: bot, 2: 'foo'})
expected = {1: BasePersistence.REPLACED_BOT, 2: 'foo'}

persistence.update_bot_data({1: cc})
assert persistence.bot_data[1].data == expected
persistence.update_chat_data(123, {1: cc})
assert persistence.chat_data[123][1].data == expected
persistence.update_user_data(123, {1: cc})
assert persistence.user_data[123][1].data == expected

expected = {1: bot, 2: 'foo'}

assert persistence.get_bot_data()[1].data == expected
assert persistence.get_chat_data()[123][1].data == expected
assert persistence.get_user_data()[123][1].data == expected


@pytest.fixture(scope='function')
def pickle_persistence():
Expand Down