From 908af1db8200cc846ea5ffdaa6a2c77d0d1d1dc7 Mon Sep 17 00:00:00 2001
From: Jonathan Gomes Selman
Date: Fri, 15 Dec 2023 18:30:33 -0500
Subject: [PATCH] fix: misalignment between token offsets returned from the
 API and samples in the UI (#821)

---
 dataquality/__init__.py                       |  2 +-
 .../loggers/data_logger/seq2seq/formatters.py | 29 ++++++++++++++-----
 .../data_logger/seq2seq/seq2seq_base.py       | 22 ++++++++++----
 dataquality/utils/seq2seq/offsets.py          | 18 ++++++++++--
 4 files changed, 55 insertions(+), 16 deletions(-)

diff --git a/dataquality/__init__.py b/dataquality/__init__.py
index 9d036af70..0c262777d 100644
--- a/dataquality/__init__.py
+++ b/dataquality/__init__.py
@@ -31,7 +31,7 @@
 """
 
-__version__ = "1.4.0"
+__version__ = "1.4.1"
 
 import sys
 from typing import Any, List, Optional
 
diff --git a/dataquality/loggers/data_logger/seq2seq/formatters.py b/dataquality/loggers/data_logger/seq2seq/formatters.py
index b9080b071..18cfbf186 100644
--- a/dataquality/loggers/data_logger/seq2seq/formatters.py
+++ b/dataquality/loggers/data_logger/seq2seq/formatters.py
@@ -43,7 +43,7 @@ def format_text(
         tokenizer: PreTrainedTokenizerFast,
         max_tokens: Optional[int],
         split_key: str,
-    ) -> Tuple[AlignedTokenData, List[List[str]]]:
+    ) -> Tuple[AlignedTokenData, List[List[str]], List[str]]:
         """Tokenize and align the `text` samples
 
         `format_text` tokenizes and computes token alignments for
@@ -58,7 +58,12 @@ def format_text(
         different between the two model architectures. See their
         respective implementations for further details.
 
-        Additionally, we assign the necessary `self.logger_config`.
+        Additional information computed / variable assignments:
+            - Assign the necessary `self.logger_config` fields
+            - Compute token_label_str: the per-token str representation
+              of each sample (List[str]), saved and used for high DEP tokens.
+            - In Decoder-Only: Decode the response tokens to get the str
+              representation of the response (i.e. the target shown in the UI).
 
         Parameters:
         -----------
@@ -78,6 +83,10 @@ def format_text(
             Aligned token data for *just* target tokens, based on `text`
         token_label_str: List[List[str]]
             The target tokens (as strings) - see `Seq2SeqDataLogger.token_label_str`
+        targets: List[str]
+            The decoded response tokens - i.e. the string representation of the
+            Targets for each sample. Note that this is only computed for
+            Decoder-Only models. Returns [] for Encoder-Decoder
         """
         pass
 
@@ -212,7 +221,7 @@ def format_text(
         tokenizer: PreTrainedTokenizerFast,
         max_tokens: Optional[int],
         split_key: str,
-    ) -> Tuple[AlignedTokenData, List[List[str]]]:
+    ) -> Tuple[AlignedTokenData, List[List[str]], List[str]]:
         """Further validation for Encoder-Decoder
 
         For Encoder-Decoder we need to:
@@ -257,7 +266,7 @@ def format_text(
         id_to_tokens = dict(zip(ids, token_label_ids))
         self.logger_config.id_to_tokens[split_key].update(id_to_tokens)
 
-        return batch_aligned_data, token_label_str
+        return batch_aligned_data, token_label_str, []
 
     @torch.no_grad()
     def generate_sample(
@@ -382,7 +391,7 @@ def format_text(
         tokenizer: PreTrainedTokenizerFast,
         max_tokens: Optional[int],
         split_key: str,
-    ) -> Tuple[AlignedTokenData, List[List[str]]]:
+    ) -> Tuple[AlignedTokenData, List[List[str]], List[str]]:
         """Further formatting for Decoder-Only
 
         Text is the formatted prompt of combined input/target
@@ -421,6 +430,8 @@ def format_text(
         # Empty initialization
         batch_aligned_data = AlignedTokenData([], [])
         token_label_str = []
+        targets = []
+
         # Decode then re-tokenize just the response labels to get correct offsets
         for token_label_ids in tqdm(
             tokenized_labels,
@@ -436,12 +447,16 @@ def format_text(
                 )
             )
 
-            response_aligned_data = align_response_tokens_to_character_spans(
+            (
+                response_aligned_data,
+                response_str,
+            ) = align_response_tokens_to_character_spans(
                 tokenizer,
                 token_label_ids,
                 max_input_tokens,
             )
             batch_aligned_data.append(response_aligned_data)
+            targets.append(response_str)
 
         # Save the tokenized response labels for each sample
         id_to_tokens = dict(zip(ids, tokenized_labels))
@@ -456,7 +471,7 @@ def format_text(
             id_to_formatted_prompt_length
         )
 
-        return batch_aligned_data, token_label_str
+        return batch_aligned_data, token_label_str, targets
 
     @torch.no_grad()
     def generate_sample(
diff --git a/dataquality/loggers/data_logger/seq2seq/seq2seq_base.py b/dataquality/loggers/data_logger/seq2seq/seq2seq_base.py
index 83ea5eb8b..02bd8a5d8 100644
--- a/dataquality/loggers/data_logger/seq2seq/seq2seq_base.py
+++ b/dataquality/loggers/data_logger/seq2seq/seq2seq_base.py
@@ -118,10 +118,16 @@ def validate_and_format(self) -> None:
         label_len = len(self.labels)
         text_len = len(self.texts)
         id_len = len(self.ids)
-        assert id_len == text_len == label_len, (
-            "IDs, texts, and labels must be the same length, got "
-            f"({id_len} ids, {text_len} texts, {label_len} labels)"
-        )
+        if label_len > 0:  # Encoder-Decoder case
+            assert id_len == text_len == label_len, (
+                "IDs, texts, and labels must be the same length, got "
+                f"({id_len} ids, {text_len} texts, {label_len} labels)"
+            )
+        else:  # Decoder-Only case
+            assert id_len == text_len, (
+                "IDs and texts must be the same length, got "
+                f"({id_len} ids, {text_len} texts)"
+            )
         assert self.logger_config.tokenizer, (
             "You must set your tokenizer before logging. "
             "Use `dq.integrations.seq2seq.core.set_tokenizer`"
@@ -141,7 +147,11 @@ def validate_and_format(self) -> None:
 
         max_tokens = self.logger_config.max_target_tokens
         assert max_tokens
-        batch_aligned_token_data, token_label_str = self.formatter.format_text(
+        (
+            batch_aligned_token_data,
+            token_label_str,
+            targets,
+        ) = self.formatter.format_text(
             text=texts,
             ids=self.ids,
             tokenizer=self.logger_config.tokenizer,
@@ -151,6 +161,8 @@ def validate_and_format(self) -> None:
         self.token_label_offsets = batch_aligned_token_data.token_label_offsets
         self.token_label_positions = batch_aligned_token_data.token_label_positions
         self.token_label_str = token_label_str
+        if len(targets) > 0:  # For Decoder-Only we update the 'targets' here
+            self.labels = targets
 
     def _get_input_df(self) -> DataFrame:
         df_dict = {
diff --git a/dataquality/utils/seq2seq/offsets.py b/dataquality/utils/seq2seq/offsets.py
index 25f32c7c4..e70ed9f3f 100644
--- a/dataquality/utils/seq2seq/offsets.py
+++ b/dataquality/utils/seq2seq/offsets.py
@@ -274,7 +274,7 @@ def align_response_tokens_to_character_spans(
     tokenizer: PreTrainedTokenizerFast,
     tokenized_response: List[int],
     max_input_tokens: Optional[int],
-) -> AlignedTokenData:
+) -> Tuple[AlignedTokenData, str]:
     """Decodes then re-tokenizes the isolated response to get the character alignments
 
     TODO This can prob be done with just tokenizing the "target" in isolation!!
@@ -283,6 +283,15 @@ def align_response_tokens_to_character_spans(
     in the offset map and slice the offset map accordingly. This may also avoid
     strange space issues with tokenizers handling words at the start of a
     document.
+
+    Return:
+    -------
+    aligned_token_data: AlignedTokenData
+        Aligned token data for a single Response - batch dim = 1.
+    decoded_response: str
+        The string representation of the Response, used as the
+        Target string in the console. Note: we do not remove
+        special characters, so these will appear in the console!
     """
     decoded_response = tokenizer.decode(tokenized_response)
     re_tokenized_response = tokenizer(
@@ -293,6 +302,9 @@ def align_response_tokens_to_character_spans(
         # I believe that this should be handled! We can prob set to None
         max_length=max_input_tokens,
     )
-    return align_tokens_to_character_spans(
-        re_tokenized_response["offset_mapping"], disable_tqdm=True
+    return (
+        align_tokens_to_character_spans(
+            re_tokenized_response["offset_mapping"], disable_tqdm=True
+        ),
+        decoded_response,
     )