Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/enabled all features #91

Merged
merged 13 commits into from
Mar 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.idea
./venv
**venv**
.DS_Store
Expand Down
2 changes: 1 addition & 1 deletion docs/dataset.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ datasetId = "<Dataset ID>" # Don't forget to replace this with the actual datas
dataset = client.dataset(datasetId)

# Retrieve the URL for downloading dataset contents
url = dataset.get_download_link("archive")
url = dataset.get_download_link()
print("Download URL:", url)
```

Expand Down
8 changes: 4 additions & 4 deletions docs/model.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,15 @@ model = client.model(modelId)
model.export(format="pyTorch") # Exports the model as a PyTorch file
```

## Retrieve a Direct Download URL
## Retrieve a Direct Weight URL

Occasionally, you might require direct access to your model's remotely-stored artifacts. This convenient function provides a URL to access and download specific files like your best-performing model weights.
Occasionally, you might require direct access to your model's remotely-stored artifacts. This convenient function provides a URL to access specific files like your best-performing model weights.

```python
modelId = "<Model ID>"
model = client.model(modelId)
download_url = model.get_download_link("best") # Retrieves the download link for the best model checkpoint
print("Model download link:", download_url) # Prints out the download link
weight_url = model.get_weights_url("best") # Retrieves the URL for the model's optimal checkpoint weights. By default, it returns the URL for the best weights. To obtain the most recent weights, specify 'last.
print("Weight URL link:", weight_url) # Prints out the weight url link
```

## Upload a Model Checkpoint
Expand Down
5 changes: 4 additions & 1 deletion hub_sdk/base/paginated_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


class PaginatedList(APIClient):
def __init__(self, base_endpoint, name, page_size=None, headers=None):
def __init__(self, base_endpoint, name, page_size=None, public=None, headers=None):
"""
Initialize a PaginatedList instance.

Expand All @@ -20,6 +20,7 @@ def __init__(self, base_endpoint, name, page_size=None, headers=None):
super().__init__(f"{HUB_FUNCTIONS_ROOT}/v1/{base_endpoint}", headers)
self.name = name
self.page_size = page_size
self.public = public
self.pages = [None]
self.current_page = 0
self.total_pages = 1
Expand Down Expand Up @@ -99,6 +100,8 @@ def list(self, page_size: int = 10, last_record=None, query=None) -> Optional[Re
params["lastRecordId"] = last_record
if query:
params["query"] = query
if self.public is not None:
params["public"] = self.public
return self.get("", params=params)
except Exception as e:
self.logger.error(f"Failed to list {self.name}: %s", e)
1 change: 0 additions & 1 deletion hub_sdk/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
"http://localhost:9099/identitytoolkit.googleapis.com/v1/accounts:signInWithPassword?key=AIzaSyDlTep"
"-ubgWoafviJJneFL35raoJjWFnOw",
)

HUB_FUNCTIONS_ROOT = f"{HUB_API_ROOT}"

HUB_EXCEPTIONS = os.getenv("ULTRALYTICS_HUB_EXCEPTIONS", "true").lower() == "true"
Expand Down
10 changes: 5 additions & 5 deletions hub_sdk/hub_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def model(self, model_id: Optional[str] = None) -> Models:
return Models(model_id, self.get_auth_header())

@require_authentication
def dataset(self, dataset_id: str = None) -> DatasetList:
def dataset(self, dataset_id: str = None) -> Datasets:
"""
Returns an instance of the Datasets class for interacting with datasets.

Expand All @@ -106,7 +106,7 @@ def dataset(self, dataset_id: str = None) -> DatasetList:
Returns:
(Datasets): An instance of the Datasets class.
"""
raise Exception("Coming Soon")
return Datasets(dataset_id, self.get_auth_header())

@require_authentication
def team(self, arg):
Expand All @@ -125,7 +125,7 @@ def project(self, project_id: Optional[str] = None) -> Projects:
Returns:
(Projects): An instance of the Projects class.
"""
raise Exception("Coming Soon")
return Projects(project_id, self.get_auth_header())

@require_authentication
def user(self, user_id: Optional[str] = None) -> Users:
Expand Down Expand Up @@ -167,7 +167,7 @@ def project_list(self, page_size: Optional[int] = None, public: Optional[bool] =
Returns:
(ProjectList): An instance of the ProjectList class.
"""
raise Exception("Coming Soon")
return ProjectList(page_size, public, self.get_auth_header())

@require_authentication
def dataset_list(self, page_size: Optional[int] = None, public: Optional[bool] = None) -> DatasetList:
Expand All @@ -181,7 +181,7 @@ def dataset_list(self, page_size: Optional[int] = None, public: Optional[bool] =
Returns:
(DatasetList): An instance of the DatasetList class.
"""
raise Exception("Coming Soon")
return DatasetList(page_size, public, self.get_auth_header())

@require_authentication
def team_list(self, page_size=None, public=None):
Expand Down
19 changes: 3 additions & 16 deletions hub_sdk/modules/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,25 +110,14 @@ def upload_dataset(self, file: str = None) -> Optional[Response]:
"""
return self.hub_client.upload_dataset(self.id, file)

def get_download_link(self, type: str) -> Optional[str]:
def get_download_link(self) -> Optional[str]:
"""
Get dataset download link.

Args:
type (str):

Returns:
(Optional[str]): Return download link or None if the link is not available.
"""
try:
payload = {"collection": "datasets", "docId": self.id, "object": type}
endpoint = f"{HUB_FUNCTIONS_ROOT}/v1/storage"
response = self.post(endpoint, json=payload)
json = response.json()
return json.get("data", {}).get("url")
except Exception as e:
self.logger.error(f"Failed to download file file for {self.name}: %s", e)
raise e
return self.data.get("url")


class DatasetList(PaginatedList):
Expand All @@ -142,6 +131,4 @@ def __init__(self, page_size=None, public=None, headers=None):
headers (dict, optional): Headers to be included in API requests.
"""
base_endpoint = "datasets"
if public:
base_endpoint = f"public/{base_endpoint}"
super().__init__(base_endpoint, "dataset", page_size, headers)
super().__init__(base_endpoint, "dataset", page_size, public, headers)
26 changes: 2 additions & 24 deletions hub_sdk/modules/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,27 +303,7 @@ def upload_metrics(self, metrics: dict) -> Optional[Response]:
"""
return self.hub_client.upload_metrics(self.id, metrics) # response

def get_download_link(self, type: str) -> Optional[str]:
"""
Get model download link.

Args:
type (Optional[str]):

Returns:
(Optional[str]): Return download link or None if the link is not available.
"""
try:
payload = {"collection": "models", "docId": self.id, "object": type}
endpoint = f"{HUB_FUNCTIONS_ROOT}/v1/storage"
response = self.post(endpoint, json=payload)
json = response.json()
return json.get("data", {}).get("url")
except Exception as e:
self.logger.error(f"Failed to download link for {self.name}: %s", e)
raise e

def start_heartbeat(self, interval: int = 60) -> None:
def start_heartbeat(self, interval: int = 60):
"""
Starts sending heartbeat signals to a remote hub server.

Expand Down Expand Up @@ -396,6 +376,4 @@ def __init__(self, page_size=None, public=None, headers=None):
headers (dict, optional): Headers to be included in API requests.
"""
base_endpoint = "models"
if public:
base_endpoint = f"public/{base_endpoint}"
super().__init__(base_endpoint, "model", page_size, headers)
super().__init__(base_endpoint, "model", page_size, public, headers)
4 changes: 1 addition & 3 deletions hub_sdk/modules/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,4 @@ def __init__(self, page_size: int = None, public: bool = None, headers: dict = N
headers (dict, optional): Headers to be included in API requests.
"""
base_endpoint = "projects"
if public:
base_endpoint = f"public/{base_endpoint}"
super().__init__(base_endpoint, "project", page_size, headers)
super().__init__(base_endpoint, "project", page_size, public, headers)
Loading