diff --git a/src/together/lib/cli/api/fine_tuning.py b/src/together/lib/cli/api/fine_tuning.py index c116d352..d1215359 100644 --- a/src/together/lib/cli/api/fine_tuning.py +++ b/src/together/lib/cli/api/fine_tuning.py @@ -13,13 +13,13 @@ from click.core import ParameterSource # type: ignore[attr-defined] from together import Together -from together.types import FullTrainingType, LoRaTrainingType from together._types import NOT_GIVEN, NotGiven from together.lib.utils import log_warn from together.lib.utils.tools import format_timestamp, finetune_price_to_dollars from together.lib.cli.api.utils import INT_WITH_MAX, BOOL_WITH_AUTO from together.lib.resources.files import DownloadManager from together.lib.utils.serializer import datetime_serializer +from together.types.finetune_response import TrainingTypeFullTrainingType, TrainingTypeLoRaTrainingType from together.lib.resources.fine_tuning import get_model_limits _CONFIRMATION_MESSAGE = ( @@ -513,11 +513,11 @@ def download( ft_job = client.fine_tuning.retrieve(fine_tune_id) loosely_typed_checkpoint_type: str | NotGiven = checkpoint_type - if isinstance(ft_job.training_type, FullTrainingType): + if isinstance(ft_job.training_type, TrainingTypeFullTrainingType): if checkpoint_type != "default": raise ValueError("Only DEFAULT checkpoint type is allowed for FullTrainingType") loosely_typed_checkpoint_type = "model_output_path" - elif isinstance(ft_job.training_type, LoRaTrainingType): + elif isinstance(ft_job.training_type, TrainingTypeLoRaTrainingType): if checkpoint_type == "default": loosely_typed_checkpoint_type = "merged" diff --git a/src/together/lib/resources/files.py b/src/together/lib/resources/files.py index 7120abfd..b82e84d8 100644 --- a/src/together/lib/resources/files.py +++ b/src/together/lib/resources/files.py @@ -18,7 +18,7 @@ from filelock import FileLock from tqdm.utils import CallbackIOWrapper -from ...types import FileType, FilePurpose, FileRetrieveResponse +from ...types import FileType, FilePurpose, FileResponse from ..._types import RequestOptions from ..constants import ( DISABLE_TQDM, @@ -273,9 +273,9 @@ def get_upload_url( return redirect_url, file_id - def callback(self, url: str) -> FileRetrieveResponse: + def callback(self, url: str) -> FileResponse: response = self._client.post( - cast_to=FileRetrieveResponse, + cast_to=FileResponse, path=url, ) @@ -286,7 +286,7 @@ def upload( url: str, file: Path, purpose: FilePurpose, - ) -> FileRetrieveResponse: + ) -> FileResponse: file_size = os.stat(file.as_posix()).st_size file_size_gb = file_size / NUM_BYTES_IN_GB @@ -306,7 +306,7 @@ def _upload_single_file( url: str, file: Path, purpose: FilePurpose, - ) -> FileRetrieveResponse: + ) -> FileResponse: file_id = None redirect_url = None @@ -357,7 +357,7 @@ def _upload_single_file( response = self.callback(f"{url}/{file_id}/preprocess") - assert isinstance(response, FileRetrieveResponse) # type: ignore + assert isinstance(response, FileResponse) # type: ignore return response @@ -374,7 +374,7 @@ def upload( url: str, file: Path, purpose: FilePurpose, - ) -> FileRetrieveResponse: + ) -> FileResponse: """Upload large file using multipart upload""" file_size = os.stat(file.as_posix()).st_size @@ -551,7 +551,7 @@ def _complete_upload( upload_id: str, file_id: str, completed_parts: List[Dict[str, Any]], - ) -> FileRetrieveResponse: + ) -> FileResponse: """Complete the multipart upload""" payload = { @@ -576,7 +576,7 @@ def _complete_upload( if response.status_code == 200: response_data = response.json() file_data = response_data.get("file", response_data) - return FileRetrieveResponse(**file_data) + return FileResponse(**file_data) else: raise APIStatusError( f"Failed to complete multipart upload: {response.text}", @@ -654,9 +654,9 @@ async def get_upload_url( return redirect_url, file_id - async def callback(self, url: str) -> FileRetrieveResponse: + async def callback(self, url: str) -> FileResponse: response = self._client.post( - cast_to=FileRetrieveResponse, + cast_to=FileResponse, path=url, ) @@ -667,7 +667,7 @@ async def upload( url: str, file: Path, purpose: FilePurpose, - ) -> FileRetrieveResponse: + ) -> FileResponse: file_size = os.stat(file.as_posix()).st_size file_size_gb = file_size / NUM_BYTES_IN_GB @@ -687,7 +687,7 @@ async def _upload_single_file( url: str, file: Path, purpose: FilePurpose, - ) -> FileRetrieveResponse: + ) -> FileResponse: file_id = None redirect_url = None @@ -738,7 +738,7 @@ async def _upload_single_file( response = await self.callback(f"{url}/{file_id}/preprocess") - assert isinstance(response, FileRetrieveResponse) # type: ignore + assert isinstance(response, FileResponse) # type: ignore return response @@ -755,7 +755,7 @@ async def upload( url: str, file: Path, purpose: FilePurpose, - ) -> FileRetrieveResponse: + ) -> FileResponse: """Upload large file using multipart upload via ThreadPoolExecutor""" file_size = os.stat(file.as_posix()).st_size @@ -932,7 +932,7 @@ async def _complete_upload( upload_id: str, file_id: str, completed_parts: List[Dict[str, Any]], - ) -> FileRetrieveResponse: + ) -> FileResponse: """Complete the multipart upload""" payload = { @@ -957,7 +957,7 @@ async def _complete_upload( if response.status_code == 200: response_data = response.json() file_data = response_data.get("file", response_data) - return FileRetrieveResponse(**file_data) + return FileResponse(**file_data) else: raise APIStatusError( f"Failed to complete multipart upload: {response.text}", diff --git a/src/together/resources/files.py b/src/together/resources/files.py index a3087c99..e8e4cd34 100644 --- a/src/together/resources/files.py +++ b/src/together/resources/files.py @@ -148,7 +148,7 @@ def upload( *, purpose: FilePurpose | str = "fine-tune", check: bool = True, - ) -> FileRetrieveResponse: + ) -> FileResponse: if check: report_dict = check_file(file) if not report_dict["is_check_passed"]: @@ -165,7 +165,7 @@ def upload( upload_manager = UploadManager(self._client) result = upload_manager.upload("/files", file, purpose) - return FileRetrieveResponse( + return FileResponse( id=result.id, bytes=result.bytes, created_at=result.created_at, @@ -323,7 +323,7 @@ async def upload( *, purpose: FilePurpose | str = "fine-tune", check: bool = True, - ) -> FileRetrieveResponse: + ) -> FileResponse: if check: report_dict = check_file(file) if not report_dict["is_check_passed"]: @@ -340,7 +340,7 @@ async def upload( upload_manager = AsyncUploadManager(self._client) result = await upload_manager.upload("/files", file, purpose) - return FileRetrieveResponse( + return FileResponse( id=result.id, bytes=result.bytes, created_at=result.created_at, diff --git a/src/together/resources/fine_tuning.py b/src/together/resources/fine_tuning.py index f0f1346d..47362d31 100644 --- a/src/together/resources/fine_tuning.py +++ b/src/together/resources/fine_tuning.py @@ -27,7 +27,7 @@ async_to_custom_streamed_response_wrapper, ) from .._base_client import make_request_options -from ..lib.types.fine_tuning import FinetuneTrainingLimits +from ..lib.types.fine_tuning import FinetuneResponse as FinetuneResponseLib, FinetuneTrainingLimits from ..types.finetune_response import FinetuneResponse from ..lib.resources.fine_tuning import get_model_limits, async_get_model_limits, create_finetune_request from ..types.fine_tuning_list_response import FineTuningListResponse @@ -99,7 +99,7 @@ def create( hf_model_revision: str | None = None, hf_api_token: str | None = None, hf_output_repo_name: str | None = None, - ) -> FinetuneResponse: + ) -> FinetuneResponseLib: """ Method to initiate a fine-tuning job @@ -228,7 +228,7 @@ def create( return self._client.post( "/fine-tunes", body=parameter_payload, - cast_to=FinetuneResponse, + cast_to=FinetuneResponseLib, ) def retrieve( diff --git a/tests/integration/resources/test_files.py b/tests/integration/resources/test_files.py index 4303ee24..2e713040 100644 --- a/tests/integration/resources/test_files.py +++ b/tests/integration/resources/test_files.py @@ -6,7 +6,7 @@ from together import Together from together.types import ( - FileRetrieveResponse, + FileResponse, ) @@ -39,7 +39,7 @@ def test_file_upload( ) # Verify the response - assert isinstance(response, FileRetrieveResponse) + assert isinstance(response, FileResponse) assert response.filename == "valid.jsonl" assert response.file_type == "jsonl" assert response.line_count == 0 diff --git a/tests/unit/test_files_resource.py b/tests/unit/test_files_resource.py index 759fb074..c7d63eec 100644 --- a/tests/unit/test_files_resource.py +++ b/tests/unit/test_files_resource.py @@ -8,7 +8,7 @@ from together import Together from together.types import ( - FileRetrieveResponse, + FileResponse, ) @@ -75,7 +75,7 @@ def test_file_upload(mocker: MockerFixture, tmp_path: Path): ) # Verify the response - assert isinstance(response, FileRetrieveResponse) + assert isinstance(response, FileResponse) assert response.filename == "valid.jsonl" assert response.bytes == len(content_bytes) assert response.created_at == 1234567890