In [None]:
import nltk
from inflection import singularize
from loguru import logger
from textdistance import damerau_levenshtein

from app.core.config import SPELLCHECK_MODEL_THRESH
from app.db.events import mongo_db
from app.db.repositories.mongo import MongoBaseRepository
from app.resources.spellcheck import data
from app.schema.spellcheck import SpellCheckSuccessResponse
from app.services.normalize import normalize
from app.services.spellcheck import clean, get_prediction


async def _confident_response(
    mongo_repo: MongoBaseRepository,
    query: str,
    clean_tokens: list,
    tokens: list,
    removed_idx: list,
    *,
    recall_on_normalized: bool = True,
) -> SpellCheckSuccessResponse:
    """Build the ``confident=True`` response shared by the early-exit branches.

    Re-inserts the non-alphabetic tokens (taken from ``tokens`` at the positions
    listed in ``removed_idx``) into ``clean_tokens`` in place, normalizes the
    joined text, runs the high-recall lookup and returns the response.

    Args:
        mongo_repo: Repository used for the high-recall lookup.
        query: The original, untouched search query.
        clean_tokens: Alphabetic tokens; mutated by the re-insertion.
        tokens: Full token list the removed tokens are read from.
        removed_idx: Positions in ``tokens`` of the non-alphabetic tokens.
        recall_on_normalized: When true the high-recall lookup is keyed on the
            normalized text, otherwise on the joined corrected text.
    """
    for idx in removed_idx:
        clean_tokens.insert(idx, tokens[idx])
    corrected_query = " ".join(clean_tokens)
    normalized_txt = await normalize(corrected_query)
    recall_key = normalized_txt if recall_on_normalized else corrected_query
    mres = await mongo_repo._log_and_query_high_recall(recall_key)
    return SpellCheckSuccessResponse(
        query=query,
        corrected_query=corrected_query,
        normalized_query=normalized_txt,
        confident=True,
        high_recall=len(mres) > 0,
    )


def _apply_split(
    parts: list,
    position: int,
    original: str,
    space_corrected_tokens: list,
    tokens: list,
    removed_idx: list,
) -> int:
    """Replace one token by its segmented ``parts`` and return the growth.

    The token at ``position`` in ``space_corrected_tokens`` and the first
    occurrence of ``original`` in ``tokens`` are both replaced by ``parts``
    (in place). Every entry of ``removed_idx`` that points past the replaced
    token in ``tokens`` is shifted by the number of extra tokens created.

    Returns:
        ``len(parts) - 1``, i.e. how many tokens were added.
    """
    space_corrected_tokens[position : position + 1] = parts
    idx = tokens.index(original)
    tokens[idx : idx + 1] = parts
    growth = len(parts) - 1
    for j, removed in enumerate(removed_idx):
        if removed > idx:
            removed_idx[j] = removed + growth
    return growth


async def get_spell_correction(query: str) -> SpellCheckSuccessResponse:
    """Return the spell-corrected form of a search query.

    The query goes through these stages, returning early with
    ``confident=True`` whenever a stage fully resolves it:

    1. ``clean`` the query; an empty result returns an empty correction.
    2. Apply overrides to the whole text, merge wrongly spaced words, then
       apply overrides per token and per n-gram (n >= 2).
    3. Set aside tokens that are not purely alphabetic; they are re-inserted
       unchanged at their original positions before returning.
    4. Return early for a single in-vocab word, an exact protected/override
       match, or when every token is protected or shorter than 3 characters.
    5. Split merged words with the SymSpell dictionaries.
    6. Run the model (``get_prediction``) and keep the original token when the
       model only changed singular/plural form or the token is protected.

    Args:
        query: Raw search query.

    Returns:
        A ``SpellCheckSuccessResponse`` with the corrected and normalized
        query, the confidence flag and the high-recall flag.
    """
    mongo_repo = MongoBaseRepository(mongo_db.client)
    logger.info(f"Search Query: {query}")

    # Cleaning
    text = clean(query)
    logger.info(f"Cleaned query: {text}")
    if not text:
        return SpellCheckSuccessResponse(
            query=query, corrected_query="", confident=True
        )

    # First layer of overrides
    mres = await mongo_repo._log_and_query_overrides(text)
    if mres:
        text = mres[0]
        logger.info(f"Overrided keyword: {text}")

    # Merging keywords: only accept the segmentation when it reduces the
    # number of tokens, i.e. when it actually merged something.
    merged_tokens = (
        data.sym_spell_merge.word_segmentation(text)
        .corrected_string.replace("&", " &")
        .split()
    )
    tokens = text.split()
    if len(merged_tokens) < len(tokens):
        tokens = merged_tokens
        logger.info(f"Merged space mistake: {text}->{' '.join(tokens)}")

    # Applying overrides, first per token and then per phrase
    text = " ".join(tokens)
    for token in tokens:
        mres = await mongo_repo._log_and_query_overrides(token)
        if mres:
            text = text.replace(token, mres[0])
            logger.info(f"Word overrided: {token}->{mres[0]}")
    for phrase in nltk.everygrams(text.split(), min_len=2):
        phrase_text = " ".join(phrase)
        mres = await mongo_repo._log_and_query_overrides(phrase_text)
        if mres:
            text = text.replace(phrase_text, mres[0])
            logger.info(f"Phrase overrided: {phrase_text}->{mres[0]}")
    tokens = text.split()

    # Splitting alpha words and alphanumerics
    clean_tokens = [token for token in tokens if token.isalpha()]
    removed_idx = [
        idx for idx, token in enumerate(tokens) if not token.isalpha()
    ]
    if not clean_tokens:
        logger.info("No clean words present in the query")
        return SpellCheckSuccessResponse(
            query=query, corrected_query=" ".join(tokens), confident=True
        )

    # Single word present in vocabulary case
    clean_text = " ".join(clean_tokens)
    if len(clean_tokens) == 1 and clean_tokens[0] in data.vocab:
        logger.info("Clean unigram already present in vocab")
        return await _confident_response(
            mongo_repo, query, clean_tokens, tokens, removed_idx
        )

    # Checking for Protected and overrides for exact match
    # NOTE(review): these two exact-match branches key the high-recall lookup
    # on the un-normalized corrected text, while every other branch uses the
    # normalized text. Preserved as-is; confirm whether this is intentional.
    mres = await mongo_repo._log_and_query_protected(clean_text)
    if mres:
        logger.info(f"Protected keyword: {clean_text}")
        return await _confident_response(
            mongo_repo,
            query,
            clean_text.split(),
            tokens,
            removed_idx,
            recall_on_normalized=False,
        )
    mres = await mongo_repo._log_and_query_overrides(clean_text)
    if mres:
        clean_text = mres[0]
        logger.info(f"Overrided keyword: {clean_text}")
        # The original code normalized an undefined `corrected_tokens` here,
        # which raised on every override hit; the helper normalizes the
        # overridden tokens instead.
        return await _confident_response(
            mongo_repo,
            query,
            clean_text.split(),
            tokens,
            removed_idx,
            recall_on_normalized=False,
        )

    # Checking protected for partial match; tokens shorter than 3 characters
    # are also left untouched.
    correct_idx = []
    for idx, token in enumerate(clean_tokens):
        mres = await mongo_repo._log_and_query_protected(token)
        if mres or len(token) < 3:
            correct_idx.append(idx)
    if len(correct_idx) == len(clean_tokens):
        return await _confident_response(
            mongo_repo, query, clean_tokens, tokens, removed_idx
        )

    # Splitting merged words
    space_corrected_tokens = clean_tokens.copy()
    # Positions of the protected tokens *after* splitting; a split of an
    # earlier token moves every later token to the right by `offset`.
    shifted_correct_idx = []
    offset = 0
    single_word = len(clean_tokens) == 1
    for i, token in enumerate(clean_tokens):
        if i in correct_idx:
            shifted_correct_idx.append(i + offset)
            continue
        if token in data.vocab:
            continue

        # Exact segmentation (no edits allowed inside the parts)
        res = data.sym_spell_split.word_segmentation(
            token,
            max_edit_distance=0,
        ).corrected_string
        parts = res.split()
        if (
            token != res
            and len(await mongo_repo._log_and_query_ngrams(res)) > 0
            and all(len(w) > 1 for w in parts)
        ):
            offset += _apply_split(
                parts,
                i + offset,
                token,
                space_corrected_tokens,
                tokens,
                removed_idx,
            )
            continue

        # Single-word query: try a direct dictionary lookup first
        if single_word:
            suggestions = data.sym_spell_split.lookup(
                token, max_edit_distance=2, verbosity=1
            )
            if suggestions:
                space_corrected_tokens[i] = suggestions[0].term
                continue

        # Segmentation allowing one edit inside the parts
        res = data.sym_spell_split.word_segmentation(
            token,
            max_edit_distance=1,
        ).corrected_string
        parts = res.split()
        if token != res and (
            len(await mongo_repo._log_and_query_ngrams(res)) > 0
            and all(len(w) > 1 for w in parts)
            or (single_word and res in data.vocab)
        ):
            offset += _apply_split(
                parts,
                i + offset,
                token,
                space_corrected_tokens,
                tokens,
                removed_idx,
            )

    if space_corrected_tokens != clean_tokens:
        clean_tokens = space_corrected_tokens
        logger.info(f"Symspell correction: {' '.join(space_corrected_tokens)}")

    # Spell Correction by Model
    res = await get_prediction(" ".join(clean_tokens))
    predictions = res[0]
    corrected_tokens = [item["correction"] for item in predictions]
    confidence = [item["confidence"] for item in predictions]
    logger.info(f"Model correction: {' '.join(corrected_tokens)}")

    # Ignore corrections that only change singular/plural form
    for i, token in enumerate(clean_tokens):
        singular = singularize(corrected_tokens[i])
        if singular == token or singular == singularize(token):
            corrected_tokens[i] = token

    # Protected / short tokens always keep their original spelling
    for idx in shifted_correct_idx:
        corrected_tokens[idx] = clean_tokens[idx]

    # Calculating confidence. The edit distance is measured before the removed
    # tokens are re-inserted: they exist on one side only and would otherwise
    # inflate the distance even when the model changed nothing.
    edit_distance = damerau_levenshtein(
        " ".join(corrected_tokens), " ".join(clean_tokens)
    )
    is_confident = edit_distance < 3 and all(
        score > SPELLCHECK_MODEL_THRESH for score in confidence
    )

    for idx in removed_idx:
        corrected_tokens.insert(idx, tokens[idx])

    corrected_query = " ".join(corrected_tokens)
    normalized_txt = await normalize(corrected_query)

    mres = await mongo_repo._log_and_query_high_recall(normalized_txt)
    return SpellCheckSuccessResponse(
        query=query,
        corrected_query=corrected_query,
        normalized_query=normalized_txt,
        confident=is_confident,
        high_recall=len(mres) > 0,
    )
