Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/together/lib/cli/api/fine_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
from click.core import ParameterSource # type: ignore[attr-defined]

from together import Together
from together.types import FullTrainingType, LoRaTrainingType
from together._types import NOT_GIVEN, NotGiven
from together.lib.utils import log_warn
from together.lib.utils.tools import format_timestamp, finetune_price_to_dollars
from together.lib.cli.api.utils import INT_WITH_MAX, BOOL_WITH_AUTO
from together.lib.resources.files import DownloadManager
from together.lib.utils.serializer import datetime_serializer
from together.types.finetune_response import TrainingTypeFullTrainingType, TrainingTypeLoRaTrainingType
from together.lib.resources.fine_tuning import get_model_limits

_CONFIRMATION_MESSAGE = (
Expand Down Expand Up @@ -513,11 +513,11 @@ def download(
ft_job = client.fine_tuning.retrieve(fine_tune_id)

loosely_typed_checkpoint_type: str | NotGiven = checkpoint_type
if isinstance(ft_job.training_type, FullTrainingType):
if isinstance(ft_job.training_type, TrainingTypeFullTrainingType):
if checkpoint_type != "default":
raise ValueError("Only DEFAULT checkpoint type is allowed for FullTrainingType")
loosely_typed_checkpoint_type = "model_output_path"
elif isinstance(ft_job.training_type, LoRaTrainingType):
elif isinstance(ft_job.training_type, TrainingTypeLoRaTrainingType):
if checkpoint_type == "default":
loosely_typed_checkpoint_type = "merged"

Expand Down
34 changes: 17 additions & 17 deletions src/together/lib/resources/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from filelock import FileLock
from tqdm.utils import CallbackIOWrapper

from ...types import FileType, FilePurpose, FileRetrieveResponse
from ...types import FileType, FilePurpose, FileResponse
from ..._types import RequestOptions
from ..constants import (
DISABLE_TQDM,
Expand Down Expand Up @@ -273,9 +273,9 @@ def get_upload_url(

return redirect_url, file_id

def callback(self, url: str) -> FileRetrieveResponse:
def callback(self, url: str) -> FileResponse:
response = self._client.post(
cast_to=FileRetrieveResponse,
cast_to=FileResponse,
path=url,
)

Expand All @@ -286,7 +286,7 @@ def upload(
url: str,
file: Path,
purpose: FilePurpose,
) -> FileRetrieveResponse:
) -> FileResponse:
file_size = os.stat(file.as_posix()).st_size
file_size_gb = file_size / NUM_BYTES_IN_GB

Expand All @@ -306,7 +306,7 @@ def _upload_single_file(
url: str,
file: Path,
purpose: FilePurpose,
) -> FileRetrieveResponse:
) -> FileResponse:
file_id = None

redirect_url = None
Expand Down Expand Up @@ -357,7 +357,7 @@ def _upload_single_file(

response = self.callback(f"{url}/{file_id}/preprocess")

assert isinstance(response, FileRetrieveResponse) # type: ignore
assert isinstance(response, FileResponse) # type: ignore

return response

Expand All @@ -374,7 +374,7 @@ def upload(
url: str,
file: Path,
purpose: FilePurpose,
) -> FileRetrieveResponse:
) -> FileResponse:
"""Upload large file using multipart upload"""

file_size = os.stat(file.as_posix()).st_size
Expand Down Expand Up @@ -551,7 +551,7 @@ def _complete_upload(
upload_id: str,
file_id: str,
completed_parts: List[Dict[str, Any]],
) -> FileRetrieveResponse:
) -> FileResponse:
"""Complete the multipart upload"""

payload = {
Expand All @@ -576,7 +576,7 @@ def _complete_upload(
if response.status_code == 200:
response_data = response.json()
file_data = response_data.get("file", response_data)
return FileRetrieveResponse(**file_data)
return FileResponse(**file_data)
else:
raise APIStatusError(
f"Failed to complete multipart upload: {response.text}",
Expand Down Expand Up @@ -654,9 +654,9 @@ async def get_upload_url(

return redirect_url, file_id

async def callback(self, url: str) -> FileRetrieveResponse:
async def callback(self, url: str) -> FileResponse:
response = self._client.post(
cast_to=FileRetrieveResponse,
cast_to=FileResponse,
path=url,
)

Expand All @@ -667,7 +667,7 @@ async def upload(
url: str,
file: Path,
purpose: FilePurpose,
) -> FileRetrieveResponse:
) -> FileResponse:
file_size = os.stat(file.as_posix()).st_size
file_size_gb = file_size / NUM_BYTES_IN_GB

Expand All @@ -687,7 +687,7 @@ async def _upload_single_file(
url: str,
file: Path,
purpose: FilePurpose,
) -> FileRetrieveResponse:
) -> FileResponse:
file_id = None

redirect_url = None
Expand Down Expand Up @@ -738,7 +738,7 @@ async def _upload_single_file(

response = await self.callback(f"{url}/{file_id}/preprocess")

assert isinstance(response, FileRetrieveResponse) # type: ignore
assert isinstance(response, FileResponse) # type: ignore

return response

Expand All @@ -755,7 +755,7 @@ async def upload(
url: str,
file: Path,
purpose: FilePurpose,
) -> FileRetrieveResponse:
) -> FileResponse:
"""Upload large file using multipart upload via ThreadPoolExecutor"""

file_size = os.stat(file.as_posix()).st_size
Expand Down Expand Up @@ -932,7 +932,7 @@ async def _complete_upload(
upload_id: str,
file_id: str,
completed_parts: List[Dict[str, Any]],
) -> FileRetrieveResponse:
) -> FileResponse:
"""Complete the multipart upload"""

payload = {
Expand All @@ -957,7 +957,7 @@ async def _complete_upload(
if response.status_code == 200:
response_data = response.json()
file_data = response_data.get("file", response_data)
return FileRetrieveResponse(**file_data)
return FileResponse(**file_data)
else:
raise APIStatusError(
f"Failed to complete multipart upload: {response.text}",
Expand Down
8 changes: 4 additions & 4 deletions src/together/resources/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def upload(
*,
purpose: FilePurpose | str = "fine-tune",
check: bool = True,
) -> FileRetrieveResponse:
) -> FileResponse:
if check:
report_dict = check_file(file)
if not report_dict["is_check_passed"]:
Expand All @@ -165,7 +165,7 @@ def upload(
upload_manager = UploadManager(self._client)
result = upload_manager.upload("/files", file, purpose)

return FileRetrieveResponse(
return FileResponse(
id=result.id,
bytes=result.bytes,
created_at=result.created_at,
Expand Down Expand Up @@ -323,7 +323,7 @@ async def upload(
*,
purpose: FilePurpose | str = "fine-tune",
check: bool = True,
) -> FileRetrieveResponse:
) -> FileResponse:
if check:
report_dict = check_file(file)
if not report_dict["is_check_passed"]:
Expand All @@ -340,7 +340,7 @@ async def upload(
upload_manager = AsyncUploadManager(self._client)
result = await upload_manager.upload("/files", file, purpose)

return FileRetrieveResponse(
return FileResponse(
id=result.id,
bytes=result.bytes,
created_at=result.created_at,
Expand Down
6 changes: 3 additions & 3 deletions src/together/resources/fine_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
async_to_custom_streamed_response_wrapper,
)
from .._base_client import make_request_options
from ..lib.types.fine_tuning import FinetuneTrainingLimits
from ..lib.types.fine_tuning import FinetuneResponse as FinetuneResponseLib, FinetuneTrainingLimits
from ..types.finetune_response import FinetuneResponse
from ..lib.resources.fine_tuning import get_model_limits, async_get_model_limits, create_finetune_request
from ..types.fine_tuning_list_response import FineTuningListResponse
Expand Down Expand Up @@ -99,7 +99,7 @@ def create(
hf_model_revision: str | None = None,
hf_api_token: str | None = None,
hf_output_repo_name: str | None = None,
) -> FinetuneResponse:
) -> FinetuneResponseLib:
"""
Method to initiate a fine-tuning job

Expand Down Expand Up @@ -228,7 +228,7 @@ def create(
return self._client.post(
"/fine-tunes",
body=parameter_payload,
cast_to=FinetuneResponse,
cast_to=FinetuneResponseLib,
)

def retrieve(
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/resources/test_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from together import Together
from together.types import (
FileRetrieveResponse,
FileResponse,
)


Expand Down Expand Up @@ -39,7 +39,7 @@ def test_file_upload(
)

# Verify the response
assert isinstance(response, FileRetrieveResponse)
assert isinstance(response, FileResponse)
assert response.filename == "valid.jsonl"
assert response.file_type == "jsonl"
assert response.line_count == 0
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_files_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from together import Together
from together.types import (
FileRetrieveResponse,
FileResponse,
)


Expand Down Expand Up @@ -75,7 +75,7 @@ def test_file_upload(mocker: MockerFixture, tmp_path: Path):
)

# Verify the response
assert isinstance(response, FileRetrieveResponse)
assert isinstance(response, FileResponse)
assert response.filename == "valid.jsonl"
assert response.bytes == len(content_bytes)
assert response.created_at == 1234567890
Expand Down