fix: misalignment between token offsets returned from the API and samples in the UI (#821)
Jonathan Gomes Selman committed Dec 15, 2023
1 parent 657e1cb commit 908af1d
Showing 4 changed files with 55 additions and 16 deletions.
2 changes: 1 addition & 1 deletion dataquality/__init__.py
@@ -31,7 +31,7 @@
"""


__version__ = "1.4.0"
__version__ = "1.4.1"

import sys
from typing import Any, List, Optional
29 changes: 22 additions & 7 deletions dataquality/loggers/data_logger/seq2seq/formatters.py
@@ -43,7 +43,7 @@ def format_text(
tokenizer: PreTrainedTokenizerFast,
max_tokens: Optional[int],
split_key: str,
-) -> Tuple[AlignedTokenData, List[List[str]]]:
+) -> Tuple[AlignedTokenData, List[List[str]], List[str]]:
"""Tokenize and align the `text` samples
`format_text` tokenizes and computes token alignments for
@@ -58,7 +58,12 @@
different between the two model architectures. See their respective
implementations for further details.
-Additionally, we assign the necessary `self.logger_config`.
+Additional information computed / variable assignments:
+    - Assign the necessary `self.logger_config` fields
+    - Compute token_label_str: the per-token str representation
+      of each sample (List[str]), saved and used for high DEP tokens.
+    - In Decoder-Only: decode the response tokens to get the str
+      representation of the response (i.e. the target shown in the UI).
Parameters:
-----------
@@ -78,6 +83,10 @@
Aligned token data for *just* target tokens, based on `text`
token_label_str: List[List[str]]
The target tokens (as strings) - see `Seq2SeqDataLogger.token_label_str`
+targets: List[str]
+    The decoded response tokens - i.e. the string representation of the
+    targets for each sample. Note that this is only computed for
+    Decoder-Only models; returns [] for Encoder-Decoder.
"""
pass

@@ -212,7 +221,7 @@ def format_text(
tokenizer: PreTrainedTokenizerFast,
max_tokens: Optional[int],
split_key: str,
-) -> Tuple[AlignedTokenData, List[List[str]]]:
+) -> Tuple[AlignedTokenData, List[List[str]], List[str]]:
"""Further validation for Encoder-Decoder
For Encoder-Decoder we need to:
@@ -257,7 +266,7 @@ def format_text(
id_to_tokens = dict(zip(ids, token_label_ids))
self.logger_config.id_to_tokens[split_key].update(id_to_tokens)

-return batch_aligned_data, token_label_str
+return batch_aligned_data, token_label_str, []

@torch.no_grad()
def generate_sample(
@@ -382,7 +391,7 @@ def format_text(
tokenizer: PreTrainedTokenizerFast,
max_tokens: Optional[int],
split_key: str,
-) -> Tuple[AlignedTokenData, List[List[str]]]:
+) -> Tuple[AlignedTokenData, List[List[str]], List[str]]:
"""Further formatting for Decoder-Only
Text is the formatted prompt of combined input/target
@@ -421,6 +430,8 @@
# Empty initialization
batch_aligned_data = AlignedTokenData([], [])
token_label_str = []
+targets = []

# Decode then re-tokenize just the response labels to get correct offsets
for token_label_ids in tqdm(
tokenized_labels,
@@ -436,12 +447,16 @@
)
)

-response_aligned_data = align_response_tokens_to_character_spans(
+(
+    response_aligned_data,
+    response_str,
+) = align_response_tokens_to_character_spans(
tokenizer,
token_label_ids,
max_input_tokens,
)
batch_aligned_data.append(response_aligned_data)
+targets.append(response_str)

# Save the tokenized response labels for each sample
id_to_tokens = dict(zip(ids, tokenized_labels))
@@ -456,7 +471,7 @@
id_to_formatted_prompt_length
)

-return batch_aligned_data, token_label_str
+return batch_aligned_data, token_label_str, targets

@torch.no_grad()
def generate_sample(
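Why the Decoder-Only path returns the decoded string: the character offsets are computed against the re-tokenized, decoded response, and that string need not be byte-identical to the label as originally logged (decoding can reintroduce special tokens or normalize text), so returning response_str lets the UI display exactly the text the offsets index into. A quick, hedged illustration of the round-trip drift (model choice is illustrative; whether drift occurs depends on the tokenizer):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)

original = "Hello   WORLD"  # label as originally logged
round_trip = tokenizer.decode(tokenizer(original)["input_ids"])

print(repr(original))    # 'Hello   WORLD'
print(repr(round_trip))  # '[CLS] hello world [SEP]' - offsets computed against
                         # this string would not line up with the original label
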
22 changes: 17 additions & 5 deletions dataquality/loggers/data_logger/seq2seq/seq2seq_base.py
@@ -118,10 +118,16 @@ def validate_and_format(self) -> None:
label_len = len(self.labels)
text_len = len(self.texts)
id_len = len(self.ids)
-assert id_len == text_len == label_len, (
-    "IDs, texts, and labels must be the same length, got "
-    f"({id_len} ids, {text_len} texts, {label_len} labels)"
-)
+if label_len > 0:  # Encoder-Decoder case
+    assert id_len == text_len == label_len, (
+        "IDs, texts, and labels must be the same length, got "
+        f"({id_len} ids, {text_len} texts, {label_len} labels)"
+    )
+else:  # Decoder-Only case
+    assert id_len == text_len, (
+        "IDs and texts must be the same length, got "
+        f"({id_len} ids, {text_len} texts)"
+    )
assert self.logger_config.tokenizer, (
"You must set your tokenizer before logging. "
"Use `dq.integrations.seq2seq.core.set_tokenizer`"
@@ -141,7 +147,11 @@ def validate_and_format(self) -> None:
max_tokens = self.logger_config.max_target_tokens
assert max_tokens

-batch_aligned_token_data, token_label_str = self.formatter.format_text(
+(
+    batch_aligned_token_data,
+    token_label_str,
+    targets,
+) = self.formatter.format_text(
text=texts,
ids=self.ids,
tokenizer=self.logger_config.tokenizer,
@@ -151,6 +161,8 @@
self.token_label_offsets = batch_aligned_token_data.token_label_offsets
self.token_label_positions = batch_aligned_token_data.token_label_positions
self.token_label_str = token_label_str
+if len(targets) > 0:  # For Decoder-Only we update the 'targets' here
+    self.labels = targets

def _get_input_df(self) -> DataFrame:
df_dict = {
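For reference, the new branching length check as a standalone function (a restatement for illustration, not the library's code; the function name is ours):

from typing import List

def validate_lengths(ids: List[int], texts: List[str], labels: List[str]) -> None:
    # Encoder-Decoder: labels are logged up front and must align 1:1.
    if len(labels) > 0:
        assert len(ids) == len(texts) == len(labels), (
            "IDs, texts, and labels must be the same length, got "
            f"({len(ids)} ids, {len(texts)} texts, {len(labels)} labels)"
        )
    # Decoder-Only: labels are empty at this point and are filled in later
    # from the decoded response tokens ('targets') returned by format_text.
    else:
        assert len(ids) == len(texts), (
            "IDs and texts must be the same length, got "
            f"({len(ids)} ids, {len(texts)} texts)"
        )

validate_lengths([0, 1], ["a", "b"], [])          # Decoder-Only: passes
validate_lengths([0, 1], ["a", "b"], ["x", "y"])  # Encoder-Decoder: passes
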
18 changes: 15 additions & 3 deletions dataquality/utils/seq2seq/offsets.py
@@ -274,7 +274,7 @@ def align_response_tokens_to_character_spans(
tokenizer: PreTrainedTokenizerFast,
tokenized_response: List[int],
max_input_tokens: Optional[int],
-) -> AlignedTokenData:
+) -> Tuple[AlignedTokenData, str]:
"""Decodes then re-tokenizes the isolated response to get the character alignments
TODO This can prob be done with just tokenizing the "target" in isolation!!
@@ -283,6 +283,15 @@
in the offset map and slice the offset map accordingly.
This may also avoid strange space issues with tokenizers handling words
at the start of a document.
+Return:
+-------
+aligned_token_data: AlignedTokenData
+    Aligned token data for a single Response - batch dim = 1.
+decoded_response: str
+    The string representation of the Response, used as the
+    Target string in the console. Note: we do not remove
+    special characters, so these will appear in the console!
"""
decoded_response = tokenizer.decode(tokenized_response)
re_tokenized_response = tokenizer(
@@ -293,6 +302,9 @@
# I believe that this should be handled! We can prob set to None
max_length=max_input_tokens,
)
-return align_tokens_to_character_spans(
-    re_tokenized_response["offset_mapping"], disable_tqdm=True
+return (
+    align_tokens_to_character_spans(
+        re_tokenized_response["offset_mapping"], disable_tqdm=True
+    ),
+    decoded_response,
)
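
A rough standalone equivalent of the updated helper, with dataquality's AlignedTokenData replaced by a plain list of character spans (an illustrative sketch under that substitution; the function name, variable names, and token IDs are ours, not the library's):

from typing import List, Optional, Tuple
from transformers import AutoTokenizer

def response_char_spans(
    tokenizer,
    tokenized_response: List[int],
    max_input_tokens: Optional[int],
) -> Tuple[List[Tuple[int, int]], str]:
    # Decode the response IDs, then re-tokenize the decoded string so the
    # offset mapping refers to characters of the string we will display.
    decoded_response = tokenizer.decode(tokenized_response)
    encoded = tokenizer(
        decoded_response,
        return_offsets_mapping=True,
        add_special_tokens=False,
        truncation=True,
        max_length=max_input_tokens,
    )
    return encoded["offset_mapping"], decoded_response

tok = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
spans, text = response_char_spans(tok, [464, 2068, 7586, 21831], max_input_tokens=128)
print(text)   # decoded text for the illustrative IDs
print(spans)  # one (start, end) character span per token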
