-
Couldn't load subscription status.
- Fork 2
feat(Hugging face): Vision methods - image classification / image segmentation / object detection #61
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
feat(Hugging face): Vision methods - image classification / image segmentation / object detection #61
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
5da6829
feat: added vision draft methods for hugging face
kevdevg 6bbb847
fix: remove env var
kevdevg 5913d7e
feat: hugging face vision
kevdevg 872b52b
fix: use file path as parameter in examples
kevdevg 3a8d3f1
fix: vision pillow read bytes
kevdevg File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
106 changes: 106 additions & 0 deletions
106
scope3ai/tracers/huggingface/vision/image_classification.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,106 @@ | ||
| import io | ||
| import time | ||
| from dataclasses import dataclass | ||
| from typing import Any, Callable, Optional, Union, List | ||
| from PIL import Image | ||
| from aiohttp import ClientResponse | ||
| from huggingface_hub import ( | ||
| InferenceClient, | ||
| AsyncInferenceClient, | ||
| ImageClassificationOutputElement, | ||
| ) # type: ignore[import-untyped] | ||
| from requests import Response | ||
|
|
||
| from scope3ai.api.types import Scope3AIContext, Model, ImpactRow | ||
| from scope3ai.api.typesgen import Task | ||
| from scope3ai.constants import PROVIDERS | ||
| from scope3ai.lib import Scope3AI | ||
| from scope3ai.response_interceptor.aiohttp_interceptor import aiohttp_response_capture | ||
| from scope3ai.response_interceptor.requests_interceptor import requests_response_capture | ||
|
|
||
| PROVIDER = PROVIDERS.HUGGINGFACE_HUB.value | ||
| HUGGING_FACE_IMAGE_CLASSIFICATION_TASK = "image-classification" | ||
|
|
||
|
|
||
| @dataclass | ||
| class ImageClassificationOutput: | ||
| elements: List[ImageClassificationOutputElement] = None | ||
| scope3ai: Optional[Scope3AIContext] = None | ||
|
|
||
|
|
||
| def _hugging_face_image_classification_wrapper( | ||
| timer_start: Any, | ||
| model: Any, | ||
| response: Any, | ||
| http_response: Union[ClientResponse, Response], | ||
| args: Any, | ||
| kwargs: Any, | ||
| ) -> ImageClassificationOutput: | ||
| input_tokens = 0 | ||
| if http_response: | ||
| compute_time = http_response.headers.get("x-compute-time") | ||
| else: | ||
| compute_time = time.perf_counter() - timer_start | ||
| try: | ||
| image_param = args[0] if len(args) > 0 else kwargs["image"] | ||
| if type(image_param) is str: | ||
| input_image = Image.open(args[0] if len(args) > 0 else kwargs["image"]) | ||
| else: | ||
| input_image = Image.open(io.BytesIO(image_param)) | ||
| input_width, input_height = input_image.size | ||
| input_images = [ | ||
| ("{width}x{height}".format(width=input_width, height=input_height)) | ||
| ] | ||
| except Exception: | ||
| pass | ||
| scope3_row = ImpactRow( | ||
| model=Model(id=model), | ||
| input_tokens=input_tokens, | ||
| task=Task.image_classification, | ||
| output_images=[], # No images to output in classification | ||
| request_duration_ms=float(compute_time) * 1000, | ||
| managed_service_id=PROVIDER, | ||
| input_images=input_images, | ||
| ) | ||
|
|
||
| scope3_ctx = Scope3AI.get_instance().submit_impact(scope3_row) | ||
| result = ImageClassificationOutput() | ||
| result.elements = response | ||
| result.scope3ai = scope3_ctx | ||
| return result | ||
|
|
||
|
|
||
| def huggingface_image_classification_wrapper( | ||
| wrapped: Callable, instance: InferenceClient, args: Any, kwargs: Any | ||
| ) -> ImageClassificationOutput: | ||
| timer_start = time.perf_counter() | ||
| http_response: Response | None = None | ||
| with requests_response_capture() as responses: | ||
| response = wrapped(*args, **kwargs) | ||
| http_responses = responses.get() | ||
| if len(http_responses) > 0: | ||
| http_response = http_responses[-1] | ||
| model = kwargs.get("model") or instance.get_recommended_model( | ||
| HUGGING_FACE_IMAGE_CLASSIFICATION_TASK | ||
| ) | ||
| return _hugging_face_image_classification_wrapper( | ||
| timer_start, model, response, http_response, args, kwargs | ||
| ) | ||
|
|
||
|
|
||
| async def huggingface_image_classification_wrapper_async( | ||
| wrapped: Callable, instance: AsyncInferenceClient, args: Any, kwargs: Any | ||
| ) -> ImageClassificationOutput: | ||
| timer_start = time.perf_counter() | ||
| http_response: ClientResponse | None = None | ||
| with aiohttp_response_capture() as responses: | ||
| response = await wrapped(*args, **kwargs) | ||
| http_responses = responses.get() | ||
| if len(http_responses) > 0: | ||
| http_response = http_responses[-1] | ||
| model = kwargs.get("model") or instance.get_recommended_model( | ||
| HUGGING_FACE_IMAGE_CLASSIFICATION_TASK | ||
| ) | ||
| return _hugging_face_image_classification_wrapper( | ||
| timer_start, model, response, http_response, args, kwargs | ||
| ) | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.