
Commit 6706fee
feat: add preloader
fixes #28
kod-kristoff committed Jun 12, 2024
1 parent 82ac1b8 commit 6706fee
Showing 3 changed files with 34 additions and 21 deletions.
46 changes: 29 additions & 17 deletions word-prediction-kb-bert/src/sbx_word_prediction_kb_bert/__init__.py
@@ -1,12 +1,15 @@
+from typing import Optional
+
+from sparv import api as sparv_api  # type: ignore [import-untyped]
 from sparv.api import (  # type: ignore [import-untyped]
     Annotation,
     Config,
     Output,
     SparvErrorMessage,
     annotator,
-    get_logger,
 )
 
+from sbx_word_prediction_kb_bert.constants import PROJECT_NAME
 from sbx_word_prediction_kb_bert.predictor import TopKPredictor
 
 __description__ = "Calculating word predictions by mask a word in a BERT model."
@@ -27,40 +30,49 @@
 
 __version__ = "0.5.4"
 
-logger = get_logger(__name__)
+logger = sparv_api.get_logger(__name__)
 
 TOK_SEP = " "
 
 
-@annotator("Word prediction tagging with a masked Bert model", language=["swe"])
+def load_predictor(num_decimals_str: str) -> TopKPredictor:
+    try:
+        num_decimals = int(num_decimals_str)
+    except ValueError as exc:
+        raise sparv_api.SparvErrorMessage(
+            f"'{PROJECT_NAME}.num_decimals' must contain an 'int' got: '{num_decimals_str}'"  # noqa: E501
+        ) from exc
+
+    return TopKPredictor(num_decimals=num_decimals)
+
+
+@annotator(
+    "Word prediction tagging with a masked Bert model",
+    language=["swe"],
+    preloader=load_predictor,
+    preloader_params=["num_decimals_str"],
+    preloader_target="predictor_preloaded",
+)
 def predict_words__kb_bert(
     out_prediction: Output = Output(
-        "<token>:sbx_word_prediction_kb_bert.word-prediction--kb-bert",
+        f"<token>:{PROJECT_NAME}.word-prediction--kb-bert",
         cls="word_prediction",
         description="Word predictions from masked BERT (format: '|<word>:<score>|...|)",
     ),
     word: Annotation = Annotation("<token:word>"),
     sentence: Annotation = Annotation("<sentence>"),
-    num_predictions_str: str = Config("sbx_word_prediction_kb_bert.num_predictions"),
-    num_decimals_str: str = Config("sbx_word_prediction_kb_bert.num_decimals"),
+    num_predictions_str: str = Config(f"{PROJECT_NAME}.num_predictions"),
+    num_decimals_str: str = Config(f"{PROJECT_NAME}.num_decimals"),
+    predictor_preloaded: Optional[TopKPredictor] = None,
 ) -> None:
     logger.info("predict_words")
     try:
         num_predictions = int(num_predictions_str)
     except ValueError as exc:
         raise SparvErrorMessage(
-            f"'sbx_word_prediction_kb_bert.num_predictions' must contain an 'int' got: '{num_predictions_str}'"  # noqa: E501
-        ) from exc
-    try:
-        num_decimals = int(num_decimals_str)
-    except ValueError as exc:
-        raise SparvErrorMessage(
-            f"'sbx_word_prediction_kb_bert.num_decimals' must contain an 'int' got: '{num_decimals_str}'"  # noqa: E501
+            f"'{PROJECT_NAME}.num_predictions' must contain an 'int' got: '{num_predictions_str}'"  # noqa: E501
         ) from exc
 
-    predictor = TopKPredictor(
-        num_decimals=num_decimals,
-    )
+    predictor = predictor_preloaded or load_predictor(num_decimals_str)
 
     sentences, _orphans = sentence.get_children(word)
     token_word = list(word.read())
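Context for the change above: constructing TopKPredictor means loading a BERT model, which is expensive, so the annotator now declares load_predictor as a Sparv preloader via the preloader, preloader_params, and preloader_target arguments. Sparv can then build the predictor once per process and inject it as predictor_preloaded on every call, while the `predictor_preloaded or load_predictor(num_decimals_str)` fallback keeps the annotator working when nothing is preloaded. A minimal sketch of that contract, assuming a hypothetical driver (run_with_preloading, config, and documents are invented for illustration, not Sparv's actual API):

from typing import Any, Callable, Dict, List

def run_with_preloading(
    annotator_func: Callable[..., None],
    preloader: Callable[..., Any],
    preloader_params: List[str],
    preloader_target: str,
    config: Dict[str, str],
    documents: List[Dict[str, Any]],
) -> None:
    # Build the expensive resource (here: the BERT-backed TopKPredictor)
    # exactly once per process...
    args = [config[name] for name in preloader_params]
    preloaded = preloader(*args)
    for doc_kwargs in documents:
        # ...and hand the shared instance to every annotator call via the
        # parameter named by preloader_target ("predictor_preloaded").
        annotator_func(**doc_kwargs, **{preloader_target: preloaded})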
1 change: 1 addition & 0 deletions word-prediction-kb-bert/src/sbx_word_prediction_kb_bert/constants.py

@@ -0,0 +1 @@
+PROJECT_NAME: str = "sbx_word_prediction_kb_bert"
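The f-strings in __init__.py above pull their prefix from this new one-line module, so config keys and annotation names can no longer drift out of sync with the project name. For illustration, with values taken straight from the diff:

from sbx_word_prediction_kb_bert.constants import PROJECT_NAME

# Both names now track the single PROJECT_NAME definition:
config_key = f"{PROJECT_NAME}.num_predictions"
output_name = f"<token>:{PROJECT_NAME}.word-prediction--kb-bert"
print(config_key)   # sbx_word_prediction_kb_bert.num_predictions
print(output_name)  # <token>:sbx_word_prediction_kb_bert.word-prediction--kb-bert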
8 changes: 4 additions & 4 deletions word-prediction-kb-bert/src/sbx_word_prediction_kb_bert/predictor.py

@@ -57,8 +57,8 @@ def _default_model(cls) -> BertForMaskedLM:
             ),
         )
         if torch.cuda.is_available() and torch.cuda.device_count() == 1:
-            model = model.cuda()
-        return model
+            model = model.cuda()  # type: ignore
+        return model  # type: ignore
 
     @classmethod
     def _default_tokenizer(cls) -> BertTokenizer:
@@ -72,12 +72,12 @@ def _default_tokenizer(cls) -> BertTokenizer:
 
     def get_top_k_predictions(self, text: str, k: int = 5) -> str:
         tokenized_inputs = self.tokenizer(text)
-        if len(tokenized_inputs["input_ids"]) <= 512:
+        if len(tokenized_inputs["input_ids"]) <= 512:  # type: ignore
             return self._run_pipeline(text, k)
         if text.count("[MASK]") == 1:
             return self._run_pipeline_on_mask_context(text, k)
         raise RuntimeError(
-            f"can't handle large input and multiple [MASK]: {len(tokenized_inputs['input_ids'])} tokens > 512 tokens"  # noqa: E501
+            f"can't handle large input and multiple [MASK]: {len(tokenized_inputs['input_ids'])} tokens > 512 tokens"  # noqa: E501 # type: ignore
         )
 
     def _run_pipeline_on_mask_context(self, text, k):
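The guard in get_top_k_predictions above reflects BERT's hard limit of 512 wordpiece tokens: short input goes straight to the fill-mask pipeline, and longer input is only handled when a single [MASK] allows cutting a context window around it. A standalone illustration of the same check, assuming the KB/bert-base-swedish-cased checkpoint (the plugin's actual default model is configured outside this diff):

from transformers import AutoTokenizer

# Assumed checkpoint for illustration; the plugin's real default model
# is defined outside this diff.
tokenizer = AutoTokenizer.from_pretrained("KB/bert-base-swedish-cased")

text = "Han beställde en [MASK] till frukost."
n_tokens = len(tokenizer(text)["input_ids"])
if n_tokens <= 512:
    print(f"{n_tokens} tokens: run the fill-mask pipeline on the full text")
elif text.count("[MASK]") == 1:
    print("over 512 tokens: run on a window around the single [MASK]")
else:
    raise RuntimeError("can't handle large input and multiple [MASK]")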
