Skip to content

Commit

Permalink
Merge branch 'feature/thread-find-refactor' of https://github.com/kha…
Browse files Browse the repository at this point in the history
…kers/OpenModmail into khakers-feature/thread-find-refactor
  • Loading branch information
raidensakura committed Mar 31, 2024
2 parents cfab16f + 4197f4e commit ab4b3a3
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 60 deletions.
21 changes: 5 additions & 16 deletions bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ async def on_ready(self):
)
logger.line()

await self.threads.populate_cache()
await self.threads.quick_populate_cache()

# closures
closures = self.config["closures"]
Expand Down Expand Up @@ -621,21 +621,10 @@ async def on_ready(self):
for log in await self.api.get_open_logs():
if self.get_channel(int(log["channel_id"])) is None:
logger.debug("Unable to resolve thread with channel %s.", log["channel_id"])
log_data = await self.api.post_log(
log["channel_id"],
{
"open": False,
"title": None,
"closed_at": str(discord.utils.utcnow()),
"close_message": "Channel has been deleted, no closer found.",
"closer": {
"id": str(self.user.id),
"name": self.user.name,
"discriminator": self.user.discriminator,
"avatar_url": self.user.display_avatar.url,
"mod": True,
},
},
log_data = await self.api.close_log(
channel_id=log["channel_id"],
close_message="Channel has been deleted, no closer found.",
closer=self.user,
)
if log_data:
logger.debug("Successfully closed thread with channel %s.", log["channel_id"])
Expand Down
62 changes: 61 additions & 1 deletion core/clients.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import secrets
import sys
from json import JSONDecodeError
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, List, Optional, Union

import discord
import pymongo.results
from aiohttp import ClientResponse, ClientResponseError
from discord import DMChannel, Member, Message, TextChannel
from discord.ext import commands
Expand Down Expand Up @@ -429,6 +430,12 @@ async def update_repository(self) -> dict:
async def get_user_info(self) -> Optional[dict]:
return NotImplemented

async def add_recipients(self, channel_id: int, recipient: List[discord.User]):
return NotImplemented

async def close_log(self, channel_id: int, title: str, close_message: str, closer: discord.User) -> dict:
return NotImplemented

async def update_title(self, title: str, channel_id: Union[str, int]):
return NotImplemented

Expand Down Expand Up @@ -566,6 +573,9 @@ async def get_log(self, channel_id: Union[str, int]) -> dict:
logger.debug("Retrieving channel %s logs.", channel_id)
return await self.logs.find_one({"channel_id": str(channel_id)})

async def get_logs(self, channel_id: List[Union[str, int]]) -> dict:
return await self.logs.find({"channel_id": {"$in": [str(i) for i in channel_id]}}).to_list(None)

async def get_log_link(self, channel_id: Union[str, int]) -> str:
doc = await self.get_log(channel_id)
logger.debug("Retrieving log link for channel %s.", channel_id)
Expand Down Expand Up @@ -593,13 +603,15 @@ async def create_log_entry(self, recipient: Member, channel: TextChannel, creato
"recipient": {
"id": str(recipient.id),
"name": recipient.name,
"global_name": recipient.global_name,
"discriminator": recipient.discriminator,
"avatar_url": recipient.display_avatar.url,
"mod": False,
},
"creator": {
"id": str(creator.id),
"name": creator.name,
"global_name": creator.global_name,
"discriminator": creator.discriminator,
"avatar_url": creator.display_avatar.url,
"mod": isinstance(creator, Member),
Expand Down Expand Up @@ -662,6 +674,7 @@ async def append_log(
"author": {
"id": str(message.author.id),
"name": message.author.name,
"global_name": message.author.global_name,
"discriminator": message.author.discriminator,
"avatar_url": message.author.display_avatar.url,
"mod": not isinstance(message.channel, DMChannel),
Expand Down Expand Up @@ -714,6 +727,7 @@ async def create_note(self, recipient: Member, message: Message, message_id: Uni
"author": {
"id": str(message.author.id),
"name": message.author.name,
"global_name": message.author.global_name,
"discriminator": message.author.discriminator,
"avatar_url": message.author.display_avatar.url,
},
Expand All @@ -735,6 +749,52 @@ async def delete_note(self, message_id: Union[int, str]):
async def edit_note(self, message_id: Union[int, str], message: str):
await self.db.notes.update_one({"message_id": str(message_id)}, {"$set": {"message": message}})

async def add_recipients(self, channel_id: int, recipient: List[discord.User]):
results: pymongo.results.UpdateResult = await self.bot.db.logs.update_one(
{"channel_id": str(channel_id)},
{
"$addToSet": {
"other_recipients": {
"$each": [
{
"id": r.id,
"name": r.name,
"global_name": r.global_name,
"discriminator": r.discriminator,
"avatar_url": r.display_avatar.url,
}
for r in recipient
]
}
}
},
)
if results.matched_count == 0:
raise ValueError(f"Channel id {channel_id} not found in mongodb")
return

async def close_log(self, channel_id: int, title: str, close_message: str, closer: discord.User) -> dict:
# TODO doesn't set title yet
return await self.bot.db.logs.find_one_and_update(
{"channel_id": str(channel_id)},
{
"$set": {
"open": False,
"closed_at": str(discord.utils.utcnow()),
"title": title,
"close_message": close_message,
"closer": {
"id": str(closer.id),
"name": closer.name,
"global_name": closer.name,
"discriminator": closer.discriminator,
"avatar_url": closer.display_avatar.url,
"mod": True,
},
},
},
)

def get_plugin_partition(self, cog):
cls_name = cog.__class__.__name__
return self.db.plugins[cls_name]
Expand Down
95 changes: 52 additions & 43 deletions core/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import typing
import warnings
from datetime import timedelta
from time import perf_counter

import discord
import isodate
Expand Down Expand Up @@ -441,22 +442,11 @@ async def _close(self, closer, silent=False, delete_channel=True, message=None,

# Logging
if self.channel:
log_data = await self.bot.api.post_log(
self.channel.id,
{
"open": False,
"title": match_title(self.channel.topic),
"closed_at": str(discord.utils.utcnow()),
"nsfw": self.channel.nsfw,
"close_message": message,
"closer": {
"id": str(closer.id),
"name": closer.name,
"discriminator": closer.discriminator,
"avatar_url": closer.display_avatar.url,
"mod": True,
},
},
log_data = await self.bot.api.close_log(
channel_id=self.channel.id,
title=match_title(self.channel.topic),
closer=closer,
close_message=message,
)
else:
log_data = None
Expand Down Expand Up @@ -1217,6 +1207,9 @@ async def add_users(self, users: typing.List[typing.Union[discord.Member, discor

topic += f"\nOther Recipients: {ids}"

# Add recipients to database
await self.bot.api.add_recipients(self._channel.id, users)

await self.channel.edit(topic=topic)
await self._update_users_genesis()

Expand Down Expand Up @@ -1244,11 +1237,14 @@ class ThreadManager:

def __init__(self, bot):
self.bot = bot
self.cache = {}
self.cache: typing.Dict[int, Thread] = {}

async def populate_cache(self) -> None:
# time method runtime
start = perf_counter()
for channel in self.bot.modmail_guild.text_channels:
await self.find(channel=channel)
logger.info("Cache populated in %fs.", time.perf_counter() - start)

def __len__(self):
return len(self.cache)
Expand All @@ -1259,6 +1255,27 @@ def __iter__(self):
def __getitem__(self, item: str) -> Thread:
return self.cache[item]

async def quick_populate_cache(self) -> None:
start = perf_counter()

# create a list containing the id of every text channel in the modmail guild
channel_ids = [channel.id for channel in self.bot.modmail_guild.text_channels]
logs = await self.bot.api.get_logs(channel_ids)

for log in logs:
recipients = log["other_recipients"]

tasks = [self.bot.get_or_fetch_user(user_data["id"]) for user_data in recipients]
recipient_users: list[discord.Member] = await asyncio.gather(*tasks)

self.cache[log["recipient"]["id"]] = Thread(
self,
recipient=log["creator"]["id"],
channel=log["channel_id"],
other_recipients=recipient_users,
)
logger.debug("Cache populated in %fs.", perf_counter() - start)

async def find(
self,
*,
Expand Down Expand Up @@ -1322,44 +1339,36 @@ def check(topic):

return thread

async def _find_from_channel(self, channel):
async def _find_from_channel(self, channel) -> typing.Optional[Thread]:
"""
Tries to find a thread from a channel channel topic,
if channel topic doesnt exist for some reason, falls back to
Tries to find a thread from a channel topic,
if channel topic doesn't exist for some reason, falls back to
searching channel history for genesis embed and
extracts user_id from that.
"""

if not channel.topic:
return None
logger.debug("_find_from_channel")
logger.debug(f"channel: {channel}")

_, user_id, other_ids = parse_channel_topic(channel.topic)
# TODO cache thread for channel ID

if user_id == -1:
return None
log = await self.bot.api.get_log(channel.id)

if user_id in self.cache:
return self.cache[user_id]
if log is None:
return None

try:
recipient = await self.bot.get_or_fetch_user(user_id)
except discord.NotFound:
recipient = None
logger.debug("This is a thread channel")

other_recipients = []
for uid in other_ids:
try:
other_recipient = await self.bot.get_or_fetch_user(uid)
except discord.NotFound:
continue
other_recipients.append(other_recipient)
recipients = log["other_recipients"]
# Create a list of tasks to fetch the users
tasks = [self.bot.get_or_fetch_user(user_data["id"]) for user_data in recipients]
# Fetch the users
recipient_users: list[discord.Member] = await asyncio.gather(*tasks)

if recipient is None:
thread = Thread(self, user_id, channel, other_recipients)
else:
self.cache[user_id] = thread = Thread(self, recipient, channel, other_recipients)
thread = Thread(
self, recipient=log["creator"]["id"], channel=channel, other_recipients=recipient_users
)
thread.ready = True

return thread

async def create(
Expand Down

0 comments on commit ab4b3a3

Please sign in to comment.