-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: move annotations to separate file
- Loading branch information
1 parent
6706fee
commit 918688d
Showing
2 changed files
with
103 additions
and
101 deletions.
There are no files selected for viewing
106 changes: 5 additions & 101 deletions
106
word-prediction-kb-bert/src/sbx_word_prediction_kb_bert/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,119 +1,23 @@ | ||
from typing import Optional | ||
|
||
from sparv import api as sparv_api # type: ignore [import-untyped] | ||
from sparv.api import ( # type: ignore [import-untyped] | ||
Annotation, | ||
Config, | ||
Output, | ||
SparvErrorMessage, | ||
annotator, | ||
) | ||
|
||
from sbx_word_prediction_kb_bert.constants import PROJECT_NAME | ||
from sbx_word_prediction_kb_bert.predictor import TopKPredictor | ||
from sbx_word_prediction_kb_bert.annotations import predict_words__kb_bert | ||
|
||
__all__ = ["predict_words__kb_bert"] | ||
|
||
__description__ = "Calculating word predictions by mask a word in a BERT model." | ||
|
||
|
||
__config__ = [ | ||
Config( | ||
sparv_api.Config( | ||
"sbx_word_prediction_kb_bert.num_predictions", | ||
description="The number of predictions to list", | ||
default=5, | ||
), | ||
Config( | ||
sparv_api.Config( | ||
"sbx_word_prediction_kb_bert.num_decimals", | ||
description="The number of decimals to round the score to", | ||
default=3, | ||
), | ||
] | ||
|
||
__version__ = "0.5.4" | ||
|
||
logger = sparv_api.get_logger(__name__) | ||
|
||
TOK_SEP = " " | ||
|
||
|
||
def load_predictor(num_decimals_str: str) -> TopKPredictor:
    """Create a ``TopKPredictor`` using the configured number of decimals.

    Args:
        num_decimals_str: the raw config value; must parse as an ``int``.

    Raises:
        sparv_api.SparvErrorMessage: if *num_decimals_str* is not an integer.
    """
    try:
        decimals = int(num_decimals_str)
    except ValueError as exc:
        raise sparv_api.SparvErrorMessage(
            f"'{PROJECT_NAME}.num_decimals' must contain an 'int' got: '{num_decimals_str}'"  # noqa: E501
        ) from exc
    return TopKPredictor(num_decimals=decimals)
|
||
|
||
@annotator(
    "Word prediction tagging with a masked Bert model",
    language=["swe"],
    preloader=load_predictor,
    preloader_params=["num_decimals_str"],
    preloader_target="predictor_preloaded",
)
def predict_words__kb_bert(
    out_prediction: Output = Output(
        f"<token>:{PROJECT_NAME}.word-prediction--kb-bert",
        # Fixed: the format example previously had an unbalanced quote
        # ("'|<word>:<score>|...|)" was missing its closing ').
        description="Word predictions from masked BERT (format: '|<word>:<score>|...|')",
    ),
    word: Annotation = Annotation("<token:word>"),
    sentence: Annotation = Annotation("<sentence>"),
    num_predictions_str: str = Config(f"{PROJECT_NAME}.num_predictions"),
    num_decimals_str: str = Config(f"{PROJECT_NAME}.num_decimals"),
    predictor_preloaded: Optional[TopKPredictor] = None,
) -> None:
    """Annotate every token with the top-k BERT predictions for its position.

    Args:
        out_prediction: output attribute receiving '|word:score|...|' strings.
        word: token word annotation read as the model input.
        sentence: sentence annotation used to group tokens.
        num_predictions_str: config value; must parse as an ``int``.
        num_decimals_str: config value forwarded to ``load_predictor``.
        predictor_preloaded: predictor supplied by Sparv's preloader, if any.

    Raises:
        SparvErrorMessage: if ``num_predictions`` is not an integer.
    """
    logger.info("predict_words")
    try:
        num_predictions = int(num_predictions_str)
    except ValueError as exc:
        raise SparvErrorMessage(
            f"'{PROJECT_NAME}.num_predictions' must contain an 'int' got: '{num_predictions_str}'"  # noqa: E501
        ) from exc

    # Reuse the preloaded predictor when available; otherwise build one now.
    predictor = predictor_preloaded or load_predictor(num_decimals_str)

    sentences, _orphans = sentence.get_children(word)
    token_word = list(word.read())
    out_prediction_annotation = word.create_empty_attribute()

    run_word_prediction(
        predictor=predictor,
        num_predictions=num_predictions,
        sentences=sentences,
        token_word=token_word,
        out_prediction_annotations=out_prediction_annotation,
    )

    logger.info("writing annotations")
    out_prediction.write(out_prediction_annotation)
|
||
|
||
def run_word_prediction(
    predictor: TopKPredictor,
    num_predictions: int,
    sentences,
    token_word: list,
    out_prediction_annotations,
) -> None:
    """Mask each token of each sentence in turn and store its top-k predictions.

    Args:
        predictor: predictor whose ``get_top_k_predictions`` is queried.
        num_predictions: number of predictions requested per masked token.
        sentences: iterable of sentences, each a sequence of token indices.
        token_word: token index -> surface word form.
        out_prediction_annotations: writable mapping from token index to result.
    """
    logger.info("run_word_prediction")

    logger.progress(total=len(sentences))  # type: ignore
    for sent in sentences:
        logger.progress()  # type: ignore
        for masked_index in list(sent):
            # Rebuild the sentence text with the current token replaced by [MASK].
            pieces = []
            for idx in sent:
                pieces.append("[MASK]" if idx == masked_index else token_word[idx])
            sent_to_tag = TOK_SEP.join(pieces)

            out_prediction_annotations[masked_index] = predictor.get_top_k_predictions(
                sent_to_tag, k=num_predictions
            )
98 changes: 98 additions & 0 deletions
98
word-prediction-kb-bert/src/sbx_word_prediction_kb_bert/annotations.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
from typing import Optional | ||
|
||
from sparv import api as sparv_api # type: ignore [import-untyped] | ||
from sparv.api import ( # type: ignore [import-untyped] | ||
Annotation, | ||
Config, | ||
Output, | ||
) | ||
|
||
from sbx_word_prediction_kb_bert.constants import PROJECT_NAME | ||
from sbx_word_prediction_kb_bert.predictor import TopKPredictor | ||
|
||
logger = sparv_api.get_logger(__name__) | ||
TOK_SEP = " " | ||
|
||
|
||
def load_predictor(num_decimals_str: str) -> TopKPredictor:
    """Create a ``TopKPredictor`` using the configured number of decimals.

    Args:
        num_decimals_str: the raw config value; must parse as an ``int``.

    Raises:
        sparv_api.SparvErrorMessage: if *num_decimals_str* is not an integer.
    """
    try:
        decimals = int(num_decimals_str)
    except ValueError as exc:
        raise sparv_api.SparvErrorMessage(
            f"'{PROJECT_NAME}.num_decimals' must contain an 'int' got: '{num_decimals_str}'"  # noqa: E501
        ) from exc
    return TopKPredictor(num_decimals=decimals)
|
||
|
||
@sparv_api.annotator(
    "Word prediction tagging with a masked Bert model",
    language=["swe"],
    preloader=load_predictor,
    preloader_params=["num_decimals_str"],
    preloader_target="predictor_preloaded",
)
def predict_words__kb_bert(
    out_prediction: Output = Output(
        f"<token>:{PROJECT_NAME}.word-prediction--kb-bert",
        # Fixed: the format example previously had an unbalanced quote
        # ("'|<word>:<score>|...|)" was missing its closing ').
        description="Word predictions from masked BERT (format: '|<word>:<score>|...|')",
    ),
    word: Annotation = Annotation("<token:word>"),
    sentence: Annotation = Annotation("<sentence>"),
    num_predictions_str: str = Config(f"{PROJECT_NAME}.num_predictions"),
    num_decimals_str: str = Config(f"{PROJECT_NAME}.num_decimals"),
    predictor_preloaded: Optional[TopKPredictor] = None,
) -> None:
    """Annotate every token with the top-k BERT predictions for its position.

    Args:
        out_prediction: output attribute receiving '|word:score|...|' strings.
        word: token word annotation read as the model input.
        sentence: sentence annotation used to group tokens.
        num_predictions_str: config value; must parse as an ``int``.
        num_decimals_str: config value forwarded to ``load_predictor``.
        predictor_preloaded: predictor supplied by Sparv's preloader, if any.

    Raises:
        sparv_api.SparvErrorMessage: if ``num_predictions`` is not an integer.
    """
    logger.info("predict_words")
    try:
        num_predictions = int(num_predictions_str)
    except ValueError as exc:
        raise sparv_api.SparvErrorMessage(
            f"'{PROJECT_NAME}.num_predictions' must contain an 'int' got: '{num_predictions_str}'"  # noqa: E501
        ) from exc

    # Reuse the preloaded predictor when available; otherwise build one now.
    predictor = predictor_preloaded or load_predictor(num_decimals_str)

    sentences, _orphans = sentence.get_children(word)
    token_word = list(word.read())
    out_prediction_annotation = word.create_empty_attribute()

    run_word_prediction(
        predictor=predictor,
        num_predictions=num_predictions,
        sentences=sentences,
        token_word=token_word,
        out_prediction_annotations=out_prediction_annotation,
    )

    logger.info("writing annotations")
    out_prediction.write(out_prediction_annotation)
|
||
|
||
def run_word_prediction(
    predictor: TopKPredictor,
    num_predictions: int,
    sentences,
    token_word: list,
    out_prediction_annotations,
) -> None:
    """Mask each token of each sentence in turn and store its top-k predictions.

    Args:
        predictor: predictor whose ``get_top_k_predictions`` is queried.
        num_predictions: number of predictions requested per masked token.
        sentences: iterable of sentences, each a sequence of token indices.
        token_word: token index -> surface word form.
        out_prediction_annotations: writable mapping from token index to result.
    """
    logger.info("run_word_prediction")

    logger.progress(total=len(sentences))  # type: ignore
    for sent in sentences:
        logger.progress()  # type: ignore
        for masked_index in list(sent):
            # Rebuild the sentence text with the current token replaced by [MASK].
            pieces = []
            for idx in sent:
                pieces.append("[MASK]" if idx == masked_index else token_word[idx])
            sent_to_tag = TOK_SEP.join(pieces)

            out_prediction_annotations[masked_index] = predictor.get_top_k_predictions(
                sent_to_tag, k=num_predictions
            )