Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Snippet search #25

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
151 changes: 151 additions & 0 deletions snippet_search/snippet_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import re
from collections import Counter, defaultdict
from typing import Optional

import discord
from discord.ext import commands

from bot import ModmailBot
from core import checks
from core.models import PermissionLevel
from core.paginator import EmbedPaginatorSession
from core.utils import escape_code_block, truncate


# Matches a maximal run of ASCII letters; used by words() to tokenise text.
WORD_PATTERN = re.compile(r"[a-zA-Z]+")
# Minimum score a snippet group needs to be included in the search results.
THRESHOLD = 1.0


def score(query: Optional[str], name: str, content: str) -> float:
"""
Return a numerical sorting score for a snippet based on a query.

More relevant snippets have higher scores. If the query is None,
return a score that always meets the search inclusion threshold.
"""
if query is None:
return THRESHOLD
return (
(common_word_count(query, name) + common_word_count(query, content))

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The common_word_count function is being called multiple times redundantly. Perhaps modify to take a list of names: typing.Iterable[str], eg

return (
    max(
        common_word_count(query, name)
        for name in names
    )
+ common_word_count(query, content)
) / len(words(query))

If we're being even more pedantic, I would also try to minimise the calls to words, but I guess we don't necessarily care about the efficiency that much. I mainly raise this because the heuristic function should probably be called once on all the data available to it, as opposed to name-by-name.

/ len(words(query))
)


def words(s: str) -> list[str]:
    """
    Split the given string into its 'words', in order of appearance.

    What counts as a 'word' is whatever the WORD_PATTERN regex matches.
    The result only feeds the scoring function, so the tokenisation does
    not need to be perfect.
    """
    return [match.group() for match in WORD_PATTERN.finditer(s)]


def common_word_count(s1: str, s2: str) -> int:
"""Return the number of words in common between the two strings."""
return sum(
(
Counter(map(str.casefold, words(s1)))

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason you casefold after using words instead of calling s1.casefold() first? That would also mean the word pattern only needs to match lowercase letters, but that's an aside.

& Counter(map(str.casefold, words(s2)))
).values()
)


def group_snippets_by_content(snippets: dict[str, str]) -> list[tuple[set[str], str]]:
"""
Take a dictionary of snippets (in the form {name: content}) and group together snippets with the same content.

Snippet contents are stripped of leading and trailing whitespace
before comparison.

The result is of the form [(set_of_snippet_names, content)].
"""
names_by_content = defaultdict(set)
for name, content in snippets.items():
names_by_content[content.strip()].add(name)
grouped_snippets = []

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the strip necessary here? Aliases shouldn't differ by leading/trailing whitespace unless someone intentionally modifies them, and you seem to subscribe to that idea given the content = snippets[name] which grabs an unstripped version (ie if the differences mattered we would be stripping that one too).

Working with that added strip yields:

    names_by_content = defaultdict(set)
    for name, content in snippets.items():
        names_by_content[content.strip()].add(name)
    grouped_snippets = []
    for content, group in names_by_content.items():
        grouped_snippets.append((group, content))
    return grouped_snippets

or even

    return [
        (v,k)
        for k,v in names_by_content.items()
        ]

for group in names_by_content.values():
name, *_ = group
content = snippets[name]
grouped_snippets.append((group, content))
return grouped_snippets


class SnippetSearch(commands.Cog):
"""A plugin that provides a command for searching snippets."""

def __init__(self, bot: ModmailBot):
self.bot = bot

@checks.has_permissions(PermissionLevel.SUPPORTER)
@commands.command(name="snippetsearch")
async def snippet_search(
self, ctx: commands.Context, *, query: Optional[str] = None
) -> None:
"""Search for a snippet."""
grouped_snippets = group_snippets_by_content(self.bot.snippets)

scored_groups = []
for i, (names, content) in enumerate(grouped_snippets):
group_score = max(score(query, name, content) for name in names)
scored_groups.append((group_score, i, names, content))

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like the enumerate is there to ensure the tuple sort behaves properly in the case of equal scores; I don't know how much that matters in practice, though.

Could replace 89-99 with:

        for i, (names, content) in enumerate(grouped_snippets):
            group_score = max(score(query, name, content) for name in names)
            #or score(query, names, content) if modifying score

            if group_score >= THRESHOLD: #saves sorting time?
                scored_groups.append((group_score, i, names, content))

        scored_groups.sort(reverse=True)

Line 116 for _, _, names, content in matching_snippet_groups would avoid the need for lines 95-99.

scored_groups.sort(reverse=True)

matching_snippet_groups = [
(names, content)
for group_score, _, names, content in scored_groups
if group_score >= THRESHOLD
]

if not matching_snippet_groups:
embed = discord.Embed(
description="No matching snippets found.",
color=self.bot.error_color,
)
await ctx.send(embed=embed)
return

num_results = len(matching_snippet_groups)

result_summary_embed = discord.Embed(
color=self.bot.main_color,
title=f"Found {num_results} Matching Snippet{'s' if num_results > 1 else ''}:",

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

title=f"Found {num_results} Matching Snippet{'s'*(num_results > 1)}:",
#134 \/
name=f"Name{'s'*(len(names) > 1)}",

unless too illegible (which is likely) 🤷

description=", ".join(
"/".join(f"`{name}`" for name in sorted(names))
for names, content in matching_snippet_groups
),
)

await ctx.send(embed=result_summary_embed)

embeds = []

for names, content in matching_snippet_groups:
formatted_content = (
f"```\n{truncate(escape_code_block(content), 1000)}\n```"
)
embed = (
discord.Embed(
color=self.bot.main_color,
)
.add_field(
name=f"Name{'s' if len(names) > 1 else ''}",
value=",".join(f"`{name}`" for name in sorted(names)),
inline=False,
)
.add_field(
name="Raw Content",
value=formatted_content,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

formatted_content seems the same complexity as the other stuff in the fields, so could move it into the add_field?

inline=False,
)
)
embeds.append(embed)

session = EmbedPaginatorSession(ctx, *embeds)
await session.run()


def setup(bot: ModmailBot) -> None:
    """Plugin entry point: create the SnippetSearch cog and register it on the bot."""
    cog = SnippetSearch(bot)
    bot.add_cog(cog)