-
-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Snippet search #25
base: main
Are you sure you want to change the base?
Snippet search #25
Changes from all commits
aac8e9c
3fbfa76
fb2c120
e11a08b
65d0f5b
9b2b19f
1f5f51d
e611e1f
d398f80
29ebaba
53abbf9
831c1c7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
import re | ||
from collections import Counter, defaultdict | ||
from typing import Optional | ||
|
||
import discord | ||
from discord.ext import commands | ||
|
||
from bot import ModmailBot | ||
from core import checks | ||
from core.models import PermissionLevel | ||
from core.paginator import EmbedPaginatorSession | ||
from core.utils import escape_code_block, truncate | ||
|
||
|
||
WORD_PATTERN = re.compile(r"[a-zA-Z]+") | ||
THRESHOLD = 1.0 | ||
|
||
|
||
def score(query: Optional[str], name: str, content: str) -> float:
    """
    Return a numerical sorting score for a snippet based on a query.

    More relevant snippets have higher scores. If the query is None,
    return a score that always meets the search inclusion threshold.

    Otherwise the score is the number of words the query shares with
    the snippet name plus the number it shares with the snippet content,
    divided by the number of words in the query. A query that contains
    no words at all (for example only digits or punctuation) scores 0.0
    rather than raising ZeroDivisionError.
    """
    if query is None:
        return THRESHOLD

    query_word_count = len(words(query))
    if query_word_count == 0:
        # Nothing in the query can match any snippet, and dividing by the
        # word count below would raise ZeroDivisionError.
        return 0.0

    shared_word_count = (
        common_word_count(query, name) + common_word_count(query, content)
    )
    return shared_word_count / query_word_count
|
||
|
||
def words(s: str) -> list[str]:
    """
    Return every 'word' found in the given string, in order of appearance.

    What counts as a 'word' is decided solely by the WORD_PATTERN regex.
    This only feeds the scoring function, so it is a rough heuristic
    rather than a precise tokenizer.
    """
    # Each match's full text is collected; WORD_PATTERN has no capture groups.
    return [match.group() for match in WORD_PATTERN.finditer(s)]
|
||
|
||
def common_word_count(s1: str, s2: str) -> int: | ||
"""Return the number of words in common between the two strings.""" | ||
return sum( | ||
( | ||
Counter(map(str.casefold, words(s1))) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Any reason you casefold after using |
||
& Counter(map(str.casefold, words(s2))) | ||
).values() | ||
) | ||
|
||
|
||
def group_snippets_by_content(snippets: dict[str, str]) -> list[tuple[set[str], str]]: | ||
""" | ||
Take a dictionary of snippets (in the form {name: content}) and group together snippets with the same content. | ||
|
||
Snippet contents are stripped of leading and trailing whitespace | ||
before comparison. | ||
|
||
The result is of the form [(set_of_snippet_names, content)]. | ||
""" | ||
names_by_content = defaultdict(set) | ||
for name, content in snippets.items(): | ||
names_by_content[content.strip()].add(name) | ||
grouped_snippets = [] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is the Working with that added strip yields: names_by_content = defaultdict(set)
for name, content in snippets.items():
names_by_content[content.strip()].add(name)
grouped_snippets = []
for content, group in names_by_content.items():
grouped_snippets.append((group, content))
return grouped_snippets or even return [
(v,k)
for k,v in names_by_content.items()
] |
||
for group in names_by_content.values(): | ||
name, *_ = group | ||
content = snippets[name] | ||
grouped_snippets.append((group, content)) | ||
return grouped_snippets | ||
|
||
|
||
class SnippetSearch(commands.Cog): | ||
"""A plugin that provides a command for searching snippets.""" | ||
|
||
def __init__(self, bot: ModmailBot): | ||
self.bot = bot | ||
|
||
@checks.has_permissions(PermissionLevel.SUPPORTER) | ||
@commands.command(name="snippetsearch") | ||
async def snippet_search( | ||
self, ctx: commands.Context, *, query: Optional[str] = None | ||
) -> None: | ||
"""Search for a snippet.""" | ||
grouped_snippets = group_snippets_by_content(self.bot.snippets) | ||
|
||
scored_groups = [] | ||
for i, (names, content) in enumerate(grouped_snippets): | ||
group_score = max(score(query, name, content) for name in names) | ||
scored_groups.append((group_score, i, names, content)) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Looks like the enumerate is to ensure tuple sorting is done properly in the case of equal scores, idk how much it matters particularly tho. Could replace 89-99 with:
for i, (names, content) in enumerate(grouped_snippets):
group_score = max(score(query, name, content) for name in names)
#or score(query, names, content) if modifying score
if group_score >= THRESHOLD: #saves sorting time?
scored_groups.append((group_score, i, names, content))
scored_groups.sort(reverse=True) Line 116 |
||
scored_groups.sort(reverse=True) | ||
|
||
matching_snippet_groups = [ | ||
(names, content) | ||
for group_score, _, names, content in scored_groups | ||
if group_score >= THRESHOLD | ||
] | ||
|
||
if not matching_snippet_groups: | ||
embed = discord.Embed( | ||
description="No matching snippets found.", | ||
color=self.bot.error_color, | ||
) | ||
await ctx.send(embed=embed) | ||
return | ||
|
||
num_results = len(matching_snippet_groups) | ||
|
||
result_summary_embed = discord.Embed( | ||
color=self.bot.main_color, | ||
title=f"Found {num_results} Matching Snippet{'s' if num_results > 1 else ''}:", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. title=f"Found {num_results} Matching Snippet{'s'*(num_results > 1)}:",
#134 \/
name=f"Name{'s'*(len(names) > 1)}", unless too illegible (which is likely) 🤷 |
||
description=", ".join( | ||
"/".join(f"`{name}`" for name in sorted(names)) | ||
for names, content in matching_snippet_groups | ||
), | ||
) | ||
|
||
await ctx.send(embed=result_summary_embed) | ||
|
||
embeds = [] | ||
|
||
for names, content in matching_snippet_groups: | ||
formatted_content = ( | ||
f"```\n{truncate(escape_code_block(content), 1000)}\n```" | ||
) | ||
embed = ( | ||
discord.Embed( | ||
color=self.bot.main_color, | ||
) | ||
.add_field( | ||
name=f"Name{'s' if len(names) > 1 else ''}", | ||
value=",".join(f"`{name}`" for name in sorted(names)), | ||
inline=False, | ||
) | ||
.add_field( | ||
name="Raw Content", | ||
value=formatted_content, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
inline=False, | ||
) | ||
) | ||
embeds.append(embed) | ||
|
||
session = EmbedPaginatorSession(ctx, *embeds) | ||
await session.run() | ||
|
||
|
||
def setup(bot: ModmailBot) -> None:
    """Create a SnippetSearch cog for the given bot and register it."""
    cog = SnippetSearch(bot)
    bot.add_cog(cog)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `common_word_count` function is being called multiple times redundantly. Perhaps modify it to take a list of names, `names: typing.Iterable[str]`, e.g. [code example not captured].
If we're being even more pedantic, we would try to minimise calls to `words`, but I guess we don't necessarily care about the efficiency that much. I just kept this as the heuristic function should probably be called as a whole on all the data available for it, as opposed to name-by-name.