Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions scope3ai/tracers/huggingface/instrument.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,25 @@
huggingface_chat_wrapper,
huggingface_async_chat_wrapper,
)
from scope3ai.tracers.huggingface.text_to_image import huggingface_text_to_image_wrapper
from scope3ai.tracers.huggingface.text_to_speech import (
huggingface_text_to_speech_wrapper,
)
from scope3ai.tracers.huggingface.speech_to_text import (
huggingface_automatic_recognition_output_wrapper,
)
from scope3ai.tracers.huggingface.text_to_image import (
huggingface_text_to_image_wrapper,
huggingface_text_to_image_wrapper_async,
)
from scope3ai.tracers.huggingface.text_to_speech import (
huggingface_text_to_speech_wrapper,
)
from scope3ai.tracers.huggingface.translation import (
huggingface_translation_wrapper_non_stream,
)
from .utils import hf_raise_for_status_enabled, hf_raise_for_status_wrapper
from .utils import (
hf_raise_for_status_enabled,
hf_raise_for_status_wrapper,
hf_async_raise_for_status_enabled,
get_client_session_async_wrapper,
)


class HuggingfaceInstrumentor:
Expand Down Expand Up @@ -56,6 +64,17 @@ def __init__(self) -> None:
"wrapper": hf_raise_for_status_wrapper,
"enabled": hf_raise_for_status_enabled,
},
{
"module": "huggingface_hub.inference._generated._async_client",
"name": "AsyncInferenceClient._get_client_session",
"wrapper": get_client_session_async_wrapper,
"enable": hf_async_raise_for_status_enabled,
},
{
"module": "huggingface_hub.inference._generated._async_client",
"name": "AsyncInferenceClient.text_to_image",
"wrapper": huggingface_text_to_image_wrapper_async,
},
]

def instrument(self) -> None:
Expand Down
45 changes: 43 additions & 2 deletions scope3ai/tracers/huggingface/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
from dataclasses import dataclass
from typing import Any, Callable, Optional

from huggingface_hub import InferenceClient # type: ignore[import-untyped]
from huggingface_hub import InferenceClient, AsyncInferenceClient # type: ignore[import-untyped]
from huggingface_hub import TextToImageOutput as _TextToImageOutput

from scope3ai.api.types import Scope3AIContext, Model, ImpactRow
from scope3ai.api.typesgen import Task
from scope3ai.lib import Scope3AI
from scope3ai.tracers.huggingface.utils import hf_raise_for_status_capture
from scope3ai.tracers.huggingface.utils import (
hf_raise_for_status_capture,
hf_async_raise_for_status_capture,
)

PROVIDER = "huggingface_hub"

Expand Down Expand Up @@ -48,7 +51,45 @@ def huggingface_text_to_image_wrapper_non_stream(
return result


async def huggingface_text_to_image_wrapper_async_non_stream(
    wrapped: Callable, instance: AsyncInferenceClient, args: Any, kwargs: Any
) -> TextToImageOutput:
    """Trace an ``AsyncInferenceClient.text_to_image`` call for Scope3AI.

    Captures the underlying HTTP response (to read the ``x-compute-time``
    header), estimates prompt tokens with tiktoken, records the generated
    image dimensions, and submits an :class:`ImpactRow`. The original image
    response is returned wrapped in :class:`TextToImageOutput` with the
    Scope3AI context attached.

    Args:
        wrapped: The original ``text_to_image`` coroutine being traced.
        instance: The async client, used to resolve a default model.
        args: Positional arguments forwarded to ``wrapped``.
        kwargs: Keyword arguments forwarded to ``wrapped``.

    Returns:
        The image output with ``scope3ai`` impact context attached.
    """
    with hf_async_raise_for_status_capture() as capture_response:
        response = await wrapped(*args, **kwargs)
        http_response = capture_response.get()
    # BUG FIX: the recommended-model lookup must use the "text-to-image"
    # task; the original said "text-to-speech" (copy-paste from the TTS
    # tracer), which would resolve the wrong default model.
    model = kwargs.get("model") or instance.get_recommended_model("text-to-image")
    encoder = tiktoken.get_encoding("cl100k_base")
    # The prompt is either the first positional argument or the "prompt" kwarg.
    prompt = args[0] if args else kwargs["prompt"]
    # Header value is in seconds — presumably always present on HF inference
    # responses; TODO confirm and guard if it can be missing.
    compute_time = http_response.headers.get("x-compute-time")
    input_tokens = len(encoder.encode(prompt))
    width, height = response.size
    scope3_row = ImpactRow(
        model=Model(id=model),
        input_tokens=input_tokens,
        task=Task.text_to_image,
        output_images=["{width}x{height}".format(width=width, height=height)],
        request_duration_ms=float(compute_time) * 1000,
        managed_service_id=PROVIDER,
    )

    scope3_ctx = Scope3AI.get_instance().submit_impact(scope3_row)
    result = TextToImageOutput(response)
    result.scope3ai = scope3_ctx
    return result


def huggingface_text_to_image_wrapper(
    wrapped: Callable, instance: InferenceClient, args: Any, kwargs: Any
) -> TextToImageOutput:
    """Synchronous tracing entry point; delegates to the non-stream tracer."""
    result = huggingface_text_to_image_wrapper_non_stream(
        wrapped, instance, args, kwargs
    )
    return result


async def huggingface_text_to_image_wrapper_async(
    wrapped: Callable, instance: AsyncInferenceClient, args: Any, kwargs: Any
) -> TextToImageOutput:
    """Async tracing entry point; delegates to the non-stream async tracer."""
    output = await huggingface_text_to_image_wrapper_async_non_stream(
        wrapped, instance, args, kwargs
    )
    return output
39 changes: 39 additions & 0 deletions scope3ai/tracers/huggingface/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,21 @@
HFRS_VALUE = contextvars.ContextVar(f"{HFRS_BASEKEY}__value", default=None)


# Context variables backing the async raise-for-status capture machinery.
# HFRS_ASYNC_ENABLED gates whether the patched async post() should record
# its response; HFRS_ASYNC_VALUE holds the response object stored by
# async_post_wrapper while a capture context is active.
HFRS_ASYNC_BASEKEY = "scope3ai__huggingface__hf_async_raise_for_status"
HFRS_ASYNC_ENABLED = contextvars.ContextVar(
    f"{HFRS_ASYNC_BASEKEY}__enabled", default=None
)
HFRS_ASYNC_VALUE = contextvars.ContextVar(f"{HFRS_ASYNC_BASEKEY}__value", default=None)


def hf_raise_for_status_enabled():
    """Return True only while a sync raise-for-status capture is active."""
    flag = HFRS_ENABLED.get()
    return flag is True


def hf_async_raise_for_status_enabled():
    """Return True only while an async raise-for-status capture is active."""
    flag = HFRS_ASYNC_ENABLED.get()
    return flag is True


def hf_raise_for_status_wrapper(wrapped, instance, args, kwargs):
try:
result = wrapped(*args, **kwargs)
Expand All @@ -28,3 +39,31 @@ def hf_raise_for_status_capture():
yield HFRS_VALUE
finally:
HFRS_ENABLED.set(False)


@contextlib.contextmanager
def hf_async_raise_for_status_capture():
    """Enable async HTTP-response capture for the duration of the block.

    Clears any stale captured value, switches capture on, and yields the
    context variable that the patched async post() fills with the response.
    Capture is switched off again on exit, even if the body raises.
    """
    HFRS_ASYNC_VALUE.set(None)
    HFRS_ASYNC_ENABLED.set(True)
    try:
        yield HFRS_ASYNC_VALUE
    finally:
        HFRS_ASYNC_ENABLED.set(False)


def async_post_wrapper(session_post):
    """Wrap an async ``post`` callable so each response is recorded.

    The returned coroutine function forwards all arguments to the original
    ``session_post``, stores the awaited response in ``HFRS_ASYNC_VALUE``,
    and returns it unchanged.
    """

    async def _capturing_post(*args, **kwargs):
        response = await session_post(*args, **kwargs)
        HFRS_ASYNC_VALUE.set(response)
        return response

    return _capturing_post


def get_client_session_async_wrapper(wrapped, instance, args, kwargs):
    """Patch the session returned by ``AsyncInferenceClient._get_client_session``.

    Replaces the session's ``post`` with a capturing wrapper so each HTTP
    response is recorded into ``HFRS_ASYNC_VALUE`` for the tracers.

    The original ``try: ... except Exception as e: raise e`` was removed: it
    was a no-op that only truncated the traceback chain — exceptions now
    propagate unchanged.
    """
    session = wrapped(*args, **kwargs)
    session.post = async_post_wrapper(session.post)
    return session
Loading
Loading