# In [None]:
from typing import List, Tuple, Union

from app.core.config import SPELLCHECK_MODEL_DEVICE, SPELLCHECK_MODEL_TOKENIZE

from .commons import spacy_tokenizer
from .corrector import Corrector
from .util import is_module_available

if is_module_available("allennlp"):
    from .elmosclstm import load_model, load_pretrained, model_predictions


class ElmosclstmChecker(Corrector):
    """Spell-checker wrapper around the ``elmo-sclstm`` model.

    Requires the optional ``allennlp`` dependency. This class reads
    ``self.vocab``, ``self.device``, ``self.tokenize`` and ``self.model``,
    and calls ``self.is_model_ready()``. These are presumably provided by
    the ``Corrector`` base class; confirm there.
    """

    def __init__(self, **kwargs):
        """Create the checker.

        Args:
            **kwargs: Forwarded unchanged to ``Corrector.__init__``.

        Raises:
            ImportError: If ``allennlp`` is not installed.
        """
        # Fail fast: the model helpers (load_model, load_pretrained,
        # model_predictions) are only imported when allennlp is available.
        if not is_module_available("allennlp"):
            raise ImportError(
                "install `allennlp` by running `pip install -r extras-requirements.txt`. "
                "See `README.md` for more info."
            )

        super().__init__(**kwargs)

        self.pretrained_name_or_path = "elmo-sclstm"

    async def load_model(self, ckpt_path, class_weights=None, sample_weights=False):
        """Build the model and optionally load checkpoint weights into it.

        Stores the result in ``self.model``.

        Args:
            ckpt_path: Checkpoint location passed to ``load_pretrained``. If
                ``None``, the freshly initialised model is used as-is.
            class_weights: Forwarded to the module-level ``load_model`` helper.
            sample_weights: Forwarded to the module-level ``load_model`` helper.
        """
        print("initializing model")
        # NOTE: the bare name ``load_model`` here resolves to the helper
        # imported from ``.elmosclstm`` (a module global), not to this method.
        initialized_model = load_model(
            self.vocab, class_weights=class_weights, sample_weights=sample_weights
        )
        if ckpt_path is None:
            self.model = initialized_model
        else:
            # Load from the path the caller passed (the one just checked for
            # None), not from a possibly unset ``self.ckpt_path`` attribute.
            self.model = load_pretrained(
                initialized_model, ckpt_path, device=self.device
            )

    async def correct_strings(
        self,
        mystrings: List[str],
        return_all=False,
        batch_size=16,
        topk=1,
        best=True,
        beam_search=False,
    ) -> Union[List[str], Tuple[List[str], List[str]]]:
        """Run spelling correction over a batch of strings.

        Args:
            mystrings: Input sentences to correct.
            return_all: If true, also return the inputs as they were fed to
                the model (tokenized when ``self.tokenize`` is set).
            batch_size: Forwarded to ``model_predictions``.
            topk: Forwarded to ``model_predictions``.
            best: Forwarded to ``model_predictions``.
            beam_search: Forwarded to ``model_predictions``.

        Returns:
            The corrected strings, or ``(inputs, corrected)`` when
            ``return_all`` is true.
        """
        self.is_model_ready()
        if self.tokenize:
            mystrings = [spacy_tokenizer(my_str) for my_str in mystrings]
        # Each input is paired with itself. Presumably model_predictions takes
        # (reference, input) pairs and ignores the reference here; verify.
        data = [(line, line) for line in mystrings]
        return_strings = await model_predictions(
            self.model,
            data,
            self.vocab,
            device=self.device,
            batch_size=batch_size,
            topk=topk,
            best=best,
            beam_search=beam_search,
        )
        if return_all:
            return mystrings, return_strings
        return return_strings


# Module-level checker instance configured from the app settings.
# Creating it runs ElmosclstmChecker.__init__, which raises ImportError at
# import time if allennlp is not installed.
model = ElmosclstmChecker(
    device=SPELLCHECK_MODEL_DEVICE,
    tokenize=SPELLCHECK_MODEL_TOKENIZE,
)
