diff --git a/pyproject.toml b/pyproject.toml index 86ce42b..e6d6063 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api" [tool.poetry] name = "together" -version = "1.5.29" +version = "1.5.30" authors = ["Together AI "] description = "Python client for Together's Cloud Platform!" readme = "README.md" diff --git a/src/together/utils/files.py b/src/together/utils/files.py index ef169ae..8f18bc6 100644 --- a/src/together/utils/files.py +++ b/src/together/utils/files.py @@ -102,81 +102,163 @@ def check_file( return report_dict -def validate_messages(messages: List[Dict[str, str | bool]], idx: int) -> None: - """Validate the messages column.""" +def _check_conversation_type(messages: List[Dict[str, str | bool]], idx: int) -> None: + """Check that the conversation has correct type. + + Args: + messages: The messages in the conversation. + Can be any type, this function ensures that the messages are a list of dictionaries. + idx: Line number in the file. + + Raises: + InvalidFileFormatError: If the conversation type is invalid. + """ if not isinstance(messages, list): raise InvalidFileFormatError( message=f"Invalid format on line {idx + 1} of the input file. " - f"Expected a list of messages. Found {type(messages)}", + f"The `messages` column must be a list. Found {type(messages)}", line_number=idx + 1, error_source="key_value", ) - if not messages: + if len(messages) == 0: raise InvalidFileFormatError( message=f"Invalid format on line {idx + 1} of the input file. " - f"Expected a non-empty list of messages. Found empty list", + f"The `messages` column must not be empty.", line_number=idx + 1, error_source="key_value", ) - has_weights = any("weight" in message for message in messages) - - previous_role = None for message in messages: if not isinstance(message, dict): raise InvalidFileFormatError( message=f"Invalid format on line {idx + 1} of the input file. " - f"Expected a dictionary in the messages list. Found {type(message)}", + f"The `messages` column must be a list of dicts. Found {type(message)}", line_number=idx + 1, error_source="key_value", ) + for column in REQUIRED_COLUMNS_MESSAGE: if column not in message: raise InvalidFileFormatError( - message=f"Field `{column}` is missing for a turn `{message}` on line {idx + 1} " - "of the the input file.", + message=f"Missing required column `{column}` in message on line {idx + 1}.", line_number=idx + 1, error_source="key_value", ) - else: - if not isinstance(message[column], str): - raise InvalidFileFormatError( - message=f"Invalid format on line {idx + 1} in the column {column} for turn `{message}` " - f"of the input file. Expected string. Found {type(message[column])}", - line_number=idx + 1, - error_source="text_field", - ) - - if has_weights and "weight" in message: - weight = message["weight"] - if not isinstance(weight, int): - raise InvalidFileFormatError( - message="Weight must be an integer", - line_number=idx + 1, - error_source="key_value", - ) - if weight not in {0, 1}: + if not isinstance(message[column], str): raise InvalidFileFormatError( - message="Weight must be either 0 or 1", + message=f"Column `{column}` is not a string on line {idx + 1}. Found {type(message[column])}", line_number=idx + 1, - error_source="key_value", + error_source="text_field", ) - if message["role"] not in POSSIBLE_ROLES_CONVERSATION: + + +def _check_conversation_roles( + require_assistant_role: bool, assistant_role_exists: bool, idx: int +) -> None: + """Check that the conversation has correct roles. + + Args: + require_assistant_role: Whether to require at least one assistant role. + assistant_role_exists: Whether an assistant role exists in the conversation. + idx: Line number in the file. + + Raises: + InvalidFileFormatError: If the conversation roles are invalid. + """ + if require_assistant_role and not assistant_role_exists: + raise InvalidFileFormatError( + message=f"Invalid format on line {idx + 1} of the input file. " + "At least one message with the assistant role must be present in the example.", + line_number=idx + 1, + error_source="key_value", + ) + + +def _check_message_weight(message: Dict[str, str | bool], idx: int) -> None: + """Check that the message has a weight with the correct type and value. + + Args: + message: The message to check. + idx: Line number in the file. + + Raises: + InvalidFileFormatError: If the message weight is invalid. + """ + if "weight" in message: + weight = message["weight"] + if not isinstance(weight, int): raise InvalidFileFormatError( - message=f"Found invalid role `{message['role']}` in the messages on the line {idx + 1}. " - f"Possible roles in the conversation are: {POSSIBLE_ROLES_CONVERSATION}", + message=f"Weight must be an integer on line {idx + 1}.", line_number=idx + 1, error_source="key_value", ) - - if previous_role == message["role"]: + if weight not in {0, 1}: raise InvalidFileFormatError( - message=f"Invalid role turns on line {idx + 1} of the input file. " - "`user` and `assistant` roles must alternate user/assistant/user/assistant/...", + message=f"Weight must be either 0 or 1 on line {idx + 1}.", line_number=idx + 1, error_source="key_value", ) - previous_role = message["role"] + + +def _check_message_role( + message: Dict[str, str | bool], previous_role: str | None, idx: int +) -> str | bool: + """Check that the message has correct roles. + + Args: + message: The message to check. + previous_role: The role of the previous message. + idx: Line number in the file. + + Returns: + str: The role of the current message. + + Raises: + InvalidFileFormatError: If the message role is invalid. + """ + if message["role"] not in POSSIBLE_ROLES_CONVERSATION: + raise InvalidFileFormatError( + message=f"Invalid role `{message['role']}` in conversation on line {idx + 1}. " + f"Possible roles: {', '.join(POSSIBLE_ROLES_CONVERSATION)}", + line_number=idx + 1, + error_source="key_value", + ) + if previous_role is not None and message["role"] == previous_role: + raise InvalidFileFormatError( + message=f"Invalid role turns on line {idx + 1} of the input file. " + "After the optional system message, conversation roles must alternate between user/assistant/user/assistant.", + line_number=idx + 1, + error_source="key_value", + ) + return message["role"] + + +def validate_messages( + messages: List[Dict[str, str | bool]], idx: int, require_assistant_role: bool = True +) -> None: + """Validate the messages column. + + Args: + messages: List of message dictionaries to validate. + idx: Line number in the file. + require_assistant_role: Whether to require at least one assistant role. + + Raises: + InvalidFileFormatError: If the messages are invalid. + """ + _check_conversation_type(messages, idx) + + has_weights = any("weight" in message for message in messages) + previous_role = None + assistant_role_exists = False + + for message in messages: + if has_weights: + _check_message_weight(message, idx) + previous_role = _check_message_role(message, previous_role, idx) + assistant_role_exists |= previous_role == "assistant" + + _check_conversation_roles(require_assistant_role, assistant_role_exists, idx) def validate_preference_openai(example: Dict[str, Any], idx: int = 0) -> None: @@ -203,37 +285,73 @@ def validate_preference_openai(example: Dict[str, Any], idx: int = 0) -> None: error_source="key_value", ) - validate_messages(example["input"]["messages"], idx) + validate_messages(example["input"]["messages"], idx, require_assistant_role=False) + + if example["input"]["messages"][-1]["role"] == "assistant": + raise InvalidFileFormatError( + message=f"The last message in the input conversation must not be from the assistant on line {idx + 1}.", + line_number=idx + 1, + error_source="key_value", + ) + + keys = ["preferred_output", "non_preferred_output"] + + for key in keys: + if key not in example: + raise InvalidFileFormatError( + message=f"The dataset is malformed, the `{key}` field must be present in the input dictionary on line {idx + 1}.", + line_number=idx + 1, + error_source="key_value", + ) + + if not isinstance(example[key], list): + raise InvalidFileFormatError( + message=f"The dataset is malformed, the `{key}` field must be a list on line {idx + 1}.", + line_number=idx + 1, + error_source="key_value", + ) + + if len(example[key]) != 1: + raise InvalidFileFormatError( + message=f"The dataset is malformed, the `{key}` list must contain exactly one message on line {idx + 1}.", + line_number=idx + 1, + error_source="key_value", + ) - for output_field in ["preferred_output", "non_preferred_output"]: - if not isinstance(example[output_field], list): + if not isinstance(example[key][0], dict): raise InvalidFileFormatError( - message=f"The dataset is malformed, the `{output_field}` field must be a list.", + message=f"The dataset is malformed, the first element of `{key}` must be a dictionary on line {idx + 1}.", line_number=idx + 1, error_source="key_value", ) - if len(example[output_field]) != 1: + if "role" not in example[key][0]: raise InvalidFileFormatError( - message=f"The dataset is malformed, the `{output_field}` list must contain exactly one message.", + message=f"The dataset is malformed, the first element of `{key}` must have a 'role' field on line {idx + 1}.", line_number=idx + 1, error_source="key_value", ) - if "role" not in example[output_field][0]: + + if example[key][0]["role"] != "assistant": raise InvalidFileFormatError( - message=f"The dataset is malformed, the `{output_field}` message is missing the `role` field.", + message=f"The dataset is malformed, the first element of `{key}` must have the 'assistant' role on line {idx + 1}.", line_number=idx + 1, error_source="key_value", ) - elif example[output_field][0]["role"] != "assistant": + + if "content" not in example[key][0]: raise InvalidFileFormatError( - message=f"The dataset is malformed, the `{output_field}` must contain an assistant message.", + message=f"The dataset is malformed, the first element of `{key}` must have a 'content' field on line {idx + 1}.", line_number=idx + 1, error_source="key_value", ) - validate_messages(example["preferred_output"], idx) - validate_messages(example["non_preferred_output"], idx) + if not isinstance(example[key][0]["content"], str): + raise InvalidFileFormatError( + message=f"The dataset is malformed, the 'content' field in `{key}` must be a string on line {idx + 1}.", + line_number=idx + 1, + error_source="key_value", + ) def _check_utf8(file: Path) -> Dict[str, Any]: @@ -410,7 +528,12 @@ def _check_jsonl(file: Path, purpose: FilePurpose | str) -> Dict[str, Any]: message_column = JSONL_REQUIRED_COLUMNS_MAP[ DatasetFormat.CONVERSATION ][0] - validate_messages(json_line[message_column], idx) + require_assistant = purpose != FilePurpose.Eval + validate_messages( + json_line[message_column], + idx, + require_assistant_role=require_assistant, + ) else: for column in JSONL_REQUIRED_COLUMNS_MAP[current_format]: if not isinstance(json_line[column], str): diff --git a/tests/unit/test_files_checks.py b/tests/unit/test_files_checks.py index 4888718..728452c 100644 --- a/tests/unit/test_files_checks.py +++ b/tests/unit/test_files_checks.py @@ -182,7 +182,12 @@ def test_check_jsonl_inconsistent_dataset_format(tmp_path: Path): # Create a JSONL file with inconsistent dataset formats file = tmp_path / "inconsistent_format.jsonl" content = [ - {"messages": [{"role": "user", "content": "Hi"}]}, + { + "messages": [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hi! How can I help you?"}, + ] + }, {"text": "How are you?"}, # Missing 'messages' ] with file.open("w") as f: @@ -207,7 +212,7 @@ def test_check_jsonl_invalid_role(tmp_path: Path): report = check_file(file) assert not report["is_check_passed"] - assert "Found invalid role `invalid_role`" in report["message"] + assert "Invalid role `invalid_role` in conversation" in report["message"] def test_check_jsonl_non_alternating_roles(tmp_path: Path): @@ -230,6 +235,22 @@ def test_check_jsonl_non_alternating_roles(tmp_path: Path): assert "Invalid role turns" in report["message"] +def test_check_jsonl_assistant_role_exists(tmp_path: Path): + # Create a JSONL file with no assistant role + file = tmp_path / "assistant_role_exists.jsonl" + content = [{"messages": [{"role": "user", "content": "Hi"}]}] + with file.open("w") as f: + f.write("\n".join(json.dumps(item) for item in content)) + + report = check_file(file) + + assert not report["is_check_passed"] + assert ( + "At least one message with the assistant role must be present" + in report["message"] + ) + + def test_check_jsonl_invalid_value_type(tmp_path: Path): # Create a JSONL file with an invalid value type file = tmp_path / "invalid_value_type.jsonl" @@ -257,7 +278,7 @@ def test_check_jsonl_missing_field_in_conversation(tmp_path: Path): report = check_file(file) assert not report["is_check_passed"] - assert "Field `content` is missing for a turn" in report["message"] + assert "Missing required column `content`" in report["message"] def test_check_jsonl_wrong_turn_type(tmp_path: Path): @@ -277,7 +298,7 @@ def test_check_jsonl_wrong_turn_type(tmp_path: Path): report = check_file(file) assert not report["is_check_passed"] assert ( - "Invalid format on line 1 of the input file. Expected a dictionary" + "Invalid format on line 1 of the input file. The `messages` column must be a list of dicts." in report["message"] ) @@ -301,9 +322,7 @@ def test_check_jsonl_empty_messages(tmp_path: Path): report = check_file(file) assert not report["is_check_passed"] - assert ( - "Expected a non-empty list of messages. Found empty list" in report["message"] - ) + assert "The `messages` column must not be empty" in report["message"] def test_check_jsonl_valid_weights_all_messages(tmp_path: Path):