diff --git a/bot.py b/bot.py index 725afefc37..5201434ec0 100644 --- a/bot.py +++ b/bot.py @@ -583,7 +583,7 @@ async def on_ready(self): ) logger.line() - await self.threads.populate_cache() + await self.threads.quick_populate_cache() # closures closures = self.config["closures"] @@ -621,21 +621,10 @@ async def on_ready(self): for log in await self.api.get_open_logs(): if self.get_channel(int(log["channel_id"])) is None: logger.debug("Unable to resolve thread with channel %s.", log["channel_id"]) - log_data = await self.api.post_log( - log["channel_id"], - { - "open": False, - "title": None, - "closed_at": str(discord.utils.utcnow()), - "close_message": "Channel has been deleted, no closer found.", - "closer": { - "id": str(self.user.id), - "name": self.user.name, - "discriminator": self.user.discriminator, - "avatar_url": self.user.display_avatar.url, - "mod": True, - }, - }, + log_data = await self.api.close_log( + channel_id=log["channel_id"], + close_message="Channel has been deleted, no closer found.", + closer=self.user, ) if log_data: logger.debug("Successfully closed thread with channel %s.", log["channel_id"]) diff --git a/core/clients.py b/core/clients.py index e80225d9e9..7fc794782c 100644 --- a/core/clients.py +++ b/core/clients.py @@ -1,9 +1,10 @@ import secrets import sys from json import JSONDecodeError -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, List, Optional, Union import discord +import pymongo.results from aiohttp import ClientResponse, ClientResponseError from discord import DMChannel, Member, Message, TextChannel from discord.ext import commands @@ -429,6 +430,12 @@ async def update_repository(self) -> dict: async def get_user_info(self) -> Optional[dict]: return NotImplemented + async def add_recipients(self, channel_id: int, recipient: List[discord.User]): + return NotImplemented + + async def close_log(self, channel_id: int, title: str, close_message: str, closer: discord.User) -> dict: + return NotImplemented + async def update_title(self, title: str, channel_id: Union[str, int]): return NotImplemented @@ -566,6 +573,9 @@ async def get_log(self, channel_id: Union[str, int]) -> dict: logger.debug("Retrieving channel %s logs.", channel_id) return await self.logs.find_one({"channel_id": str(channel_id)}) + async def get_logs(self, channel_id: List[Union[str, int]]) -> dict: + return await self.logs.find({"channel_id": {"$in": [str(i) for i in channel_id]}}).to_list(None) + async def get_log_link(self, channel_id: Union[str, int]) -> str: doc = await self.get_log(channel_id) logger.debug("Retrieving log link for channel %s.", channel_id) @@ -593,6 +603,7 @@ async def create_log_entry(self, recipient: Member, channel: TextChannel, creato "recipient": { "id": str(recipient.id), "name": recipient.name, + "global_name": recipient.global_name, "discriminator": recipient.discriminator, "avatar_url": recipient.display_avatar.url, "mod": False, @@ -600,6 +611,7 @@ async def create_log_entry(self, recipient: Member, channel: TextChannel, creato "creator": { "id": str(creator.id), "name": creator.name, + "global_name": creator.global_name, "discriminator": creator.discriminator, "avatar_url": creator.display_avatar.url, "mod": isinstance(creator, Member), @@ -662,6 +674,7 @@ async def append_log( "author": { "id": str(message.author.id), "name": message.author.name, + "global_name": message.author.global_name, "discriminator": message.author.discriminator, "avatar_url": message.author.display_avatar.url, "mod": not isinstance(message.channel, DMChannel), @@ -714,6 +727,7 @@ async def create_note(self, recipient: Member, message: Message, message_id: Uni "author": { "id": str(message.author.id), "name": message.author.name, + "global_name": message.author.global_name, "discriminator": message.author.discriminator, "avatar_url": message.author.display_avatar.url, }, @@ -735,6 +749,52 @@ async def delete_note(self, message_id: Union[int, str]): async def edit_note(self, message_id: Union[int, str], message: str): await self.db.notes.update_one({"message_id": str(message_id)}, {"$set": {"message": message}}) + async def add_recipients(self, channel_id: int, recipient: List[discord.User]): + results: pymongo.results.UpdateResult = await self.bot.db.logs.update_one( + {"channel_id": str(channel_id)}, + { + "$addToSet": { + "other_recipients": { + "$each": [ + { + "id": r.id, + "name": r.name, + "global_name": r.global_name, + "discriminator": r.discriminator, + "avatar_url": r.display_avatar.url, + } + for r in recipient + ] + } + } + }, + ) + if results.matched_count == 0: + raise ValueError(f"Channel id {channel_id} not found in mongodb") + return + + async def close_log(self, channel_id: int, title: str, close_message: str, closer: discord.User) -> dict: + # TODO doesn't set title yet + return await self.bot.db.logs.find_one_and_update( + {"channel_id": str(channel_id)}, + { + "$set": { + "open": False, + "closed_at": str(discord.utils.utcnow()), + "title": title, + "close_message": close_message, + "closer": { + "id": str(closer.id), + "name": closer.name, + "global_name": closer.name, + "discriminator": closer.discriminator, + "avatar_url": closer.display_avatar.url, + "mod": True, + }, + }, + }, + ) + def get_plugin_partition(self, cog): cls_name = cog.__class__.__name__ return self.db.plugins[cls_name] diff --git a/core/thread.py b/core/thread.py index 556c1a30dc..aeed08bdd2 100644 --- a/core/thread.py +++ b/core/thread.py @@ -5,6 +5,7 @@ import typing import warnings from datetime import timedelta +from time import perf_counter import discord import isodate @@ -441,22 +442,11 @@ async def _close(self, closer, silent=False, delete_channel=True, message=None, # Logging if self.channel: - log_data = await self.bot.api.post_log( - self.channel.id, - { - "open": False, - "title": match_title(self.channel.topic), - "closed_at": str(discord.utils.utcnow()), - "nsfw": self.channel.nsfw, - "close_message": message, - "closer": { - "id": str(closer.id), - "name": closer.name, - "discriminator": closer.discriminator, - "avatar_url": closer.display_avatar.url, - "mod": True, - }, - }, + log_data = await self.bot.api.close_log( + channel_id=self.channel.id, + title=match_title(self.channel.topic), + closer=closer, + close_message=message, ) else: log_data = None @@ -1217,6 +1207,9 @@ async def add_users(self, users: typing.List[typing.Union[discord.Member, discor topic += f"\nOther Recipients: {ids}" + # Add recipients to database + await self.bot.api.add_recipients(self._channel.id, users) + await self.channel.edit(topic=topic) await self._update_users_genesis() @@ -1244,11 +1237,14 @@ class ThreadManager: def __init__(self, bot): self.bot = bot - self.cache = {} + self.cache: typing.Dict[int, Thread] = {} async def populate_cache(self) -> None: + # time method runtime + start = perf_counter() for channel in self.bot.modmail_guild.text_channels: await self.find(channel=channel) + logger.info("Cache populated in %fs.", time.perf_counter() - start) def __len__(self): return len(self.cache) @@ -1259,6 +1255,27 @@ def __iter__(self): def __getitem__(self, item: str) -> Thread: return self.cache[item] + async def quick_populate_cache(self) -> None: + start = perf_counter() + + # create a list containing the id of every text channel in the modmail guild + channel_ids = [channel.id for channel in self.bot.modmail_guild.text_channels] + logs = await self.bot.api.get_logs(channel_ids) + + for log in logs: + recipients = log["other_recipients"] + + tasks = [self.bot.get_or_fetch_user(user_data["id"]) for user_data in recipients] + recipient_users: list[discord.Member] = await asyncio.gather(*tasks) + + self.cache[log["recipient"]["id"]] = Thread( + self, + recipient=log["creator"]["id"], + channel=log["channel_id"], + other_recipients=recipient_users, + ) + logger.debug("Cache populated in %fs.", perf_counter() - start) + async def find( self, *, @@ -1322,44 +1339,36 @@ def check(topic): return thread - async def _find_from_channel(self, channel): + async def _find_from_channel(self, channel) -> typing.Optional[Thread]: """ - Tries to find a thread from a channel channel topic, - if channel topic doesnt exist for some reason, falls back to + Tries to find a thread from a channel topic, + if channel topic doesn't exist for some reason, falls back to searching channel history for genesis embed and extracts user_id from that. """ - if not channel.topic: - return None + logger.debug("_find_from_channel") + logger.debug(f"channel: {channel}") - _, user_id, other_ids = parse_channel_topic(channel.topic) + # TODO cache thread for channel ID - if user_id == -1: - return None + log = await self.bot.api.get_log(channel.id) - if user_id in self.cache: - return self.cache[user_id] + if log is None: + return None - try: - recipient = await self.bot.get_or_fetch_user(user_id) - except discord.NotFound: - recipient = None + logger.debug("This is a thread channel") - other_recipients = [] - for uid in other_ids: - try: - other_recipient = await self.bot.get_or_fetch_user(uid) - except discord.NotFound: - continue - other_recipients.append(other_recipient) + recipients = log["other_recipients"] + # Create a list of tasks to fetch the users + tasks = [self.bot.get_or_fetch_user(user_data["id"]) for user_data in recipients] + # Fetch the users + recipient_users: list[discord.Member] = await asyncio.gather(*tasks) - if recipient is None: - thread = Thread(self, user_id, channel, other_recipients) - else: - self.cache[user_id] = thread = Thread(self, recipient, channel, other_recipients) + thread = Thread( + self, recipient=log["creator"]["id"], channel=channel, other_recipients=recipient_users + ) thread.ready = True - return thread async def create(