"""OpenTelemetry Vertex AI instrumentation"""

import logging
import types
from typing import Collection

from opentelemetry import context as context_api
from opentelemetry._events import get_event_logger
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.utils import _SUPPRESS_INSTRUMENTATION_KEY, unwrap
from opentelemetry.instrumentation.vertexai.config import Config
from opentelemetry.instrumentation.vertexai.event_emitter import (
    emit_prompt_events,
    emit_response_events,
)
from opentelemetry.instrumentation.vertexai.span_utils import (
    set_input_attributes,
    set_model_input_attributes,
    set_model_response_attributes,
    set_response_attributes,
)
from opentelemetry.instrumentation.vertexai.utils import dont_throw, should_emit_events
from opentelemetry.instrumentation.vertexai.version import __version__
from opentelemetry.semconv_ai import (
    SUPPRESS_LANGUAGE_MODEL_INSTRUMENTATION_KEY,
    LLMRequestTypeValues,
    SpanAttributes,
)
from opentelemetry.trace import SpanKind, get_tracer
from opentelemetry.trace.status import Status, StatusCode
from wrapt import wrap_function_wrapper

logger = logging.getLogger(__name__)

# ... unchanged definitions collapsed in the diff view (the _instruments
# requirement returned by instrumentation_dependencies below, and the
# WRAPPED_METHODS entries) ...
WRAPPED_METHODS = [
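    # Illustrative shape of one entry (the values here are assumptions for the
    # sketch; the keys "package", "object", "method", and "is_async" are the
    # ones read by _instrument below):
    # {
    #     "package": "vertexai.generative_models",
    #     "object": "GenerativeModel",
    #     "method": "generate_content",
    #     "is_async": False,
    # },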
]


def is_streaming_response(response):
return isinstance(response, types.GeneratorType)

def is_async_streaming_response(response):
return isinstance(response, types.AsyncGeneratorType)


def handle_streaming_response(span, event_logger, llm_model, response, token_usage):
set_model_response_attributes(span, llm_model, token_usage)
if should_emit_events():
emit_response_events(response, event_logger)
else:
set_response_attributes(span, llm_model, response)
if span.is_recording():
span.set_status(Status(StatusCode.OK))


def _build_from_streaming_response(span, event_logger, response, llm_model):
complete_response = ""
token_usage = None
for item in response:
        # ... per-chunk accumulation of response text and usage metadata
        # (collapsed in the diff view) ...

yield item_to_yield

    handle_streaming_response(
        span, event_logger, llm_model, complete_response, token_usage
    )

span.set_status(Status(StatusCode.OK))
span.end()


async def _abuild_from_streaming_response(span, event_logger, response, llm_model):
complete_response = ""
token_usage = None
async for item in response:
        # ... per-chunk accumulation of response text and usage metadata
        # (collapsed in the diff view) ...

yield item_to_yield

    # Pass the accumulated text rather than the raw stream, mirroring the sync variant.
    handle_streaming_response(
        span, event_logger, llm_model, complete_response, token_usage
    )

span.set_status(Status(StatusCode.OK))
span.end()


@dont_throw
def _handle_request(span, event_logger, args, kwargs, llm_model):
    set_model_input_attributes(span, kwargs, llm_model)
    if should_emit_events():
        emit_prompt_events(args, event_logger)
    else:
        set_input_attributes(span, args)


@dont_throw
def _handle_response(span, event_logger, response, llm_model):
    set_model_response_attributes(span, llm_model, response.usage_metadata)
    if should_emit_events():
        emit_response_events(response, event_logger)
    else:
        set_response_attributes(
            span, llm_model, response.candidates[0].text if response.candidates else ""
        )

    if span.is_recording():
        span.set_status(Status(StatusCode.OK))
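
# Note: should_emit_events() and the set_* / emit_* helpers live in the
# package's utils, span_utils, and event_emitter modules, which this diff does
# not show. A plausible sketch of the toggle (an assumption, not code from
# this PR):
#
#     def should_emit_events() -> bool:
#         # Prefer event emission unless legacy span-attribute mode is on.
#         return not Config.use_legacy_attributes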


def _with_tracer_wrapper(func):
"""Helper for providing tracer for wrapper functions."""

    def _with_tracer(tracer, event_logger, to_wrap):
        def wrapper(wrapped, instance, args, kwargs):
            return func(tracer, event_logger, to_wrap, wrapped, instance, args, kwargs)

return wrapper

return _with_tracer


@_with_tracer_wrapper
async def _awrap(tracer, event_logger, to_wrap, wrapped, instance, args, kwargs):
"""Instruments and calls every function defined in TO_WRAP."""
if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY) or context_api.get_value(
SUPPRESS_LANGUAGE_MODEL_INSTRUMENTATION_KEY
    # ... early return when instrumentation is suppressed, llm_model lookup,
    # and span creation collapsed in the diff view ...
},
)

    _handle_request(span, event_logger, args, kwargs, llm_model)

response = await wrapped(*args, **kwargs)

if response:
if is_streaming_response(response):
            return _build_from_streaming_response(
                span, event_logger, response, llm_model
            )
elif is_async_streaming_response(response):
            return _abuild_from_streaming_response(
                span, event_logger, response, llm_model
            )
else:
            _handle_response(span, event_logger, response, llm_model)

span.end()
return response


@_with_tracer_wrapper
def _wrap(tracer, event_logger, to_wrap, wrapped, instance, args, kwargs):
"""Instruments and calls every function defined in TO_WRAP."""
if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY) or context_api.get_value(
SUPPRESS_LANGUAGE_MODEL_INSTRUMENTATION_KEY
    # ... early return when instrumentation is suppressed, llm_model lookup,
    # and span creation collapsed in the diff view ...
},
)

    _handle_request(span, event_logger, args, kwargs, llm_model)

response = wrapped(*args, **kwargs)

if response:
if is_streaming_response(response):
            return _build_from_streaming_response(
                span, event_logger, response, llm_model
            )
elif is_async_streaming_response(response):
            return _abuild_from_streaming_response(
                span, event_logger, response, llm_model
            )
else:
            _handle_response(span, event_logger, response, llm_model)

span.end()
return response


class VertexAIInstrumentor(BaseInstrumentor):
"""An instrumentor for VertextAI's client library."""

    def __init__(self, exception_logger=None, use_legacy_attributes=True):
super().__init__()
Config.exception_logger = exception_logger
Config.use_legacy_attributes = use_legacy_attributes

def instrumentation_dependencies(self) -> Collection[str]:
return _instruments

def _instrument(self, **kwargs):
tracer_provider = kwargs.get("tracer_provider")
tracer = get_tracer(__name__, __version__, tracer_provider)

event_logger = None

if should_emit_events():
event_logger_provider = kwargs.get("event_logger_provider")
event_logger = get_event_logger(
__name__,
__version__,
event_logger_provider=event_logger_provider,
)

for wrapped_method in WRAPPED_METHODS:
wrap_package = wrapped_method.get("package")
wrap_object = wrapped_method.get("object")
            wrap_method = wrapped_method.get("method")
            wrap_function_wrapper(
wrap_package,
f"{wrap_object}.{wrap_method}",
(
                    _awrap(tracer, event_logger, wrapped_method)
if wrapped_method.get("is_async")
                    else _wrap(tracer, event_logger, wrapped_method)
),
)

    # ... rest of the module (e.g. _uninstrument, which uses the imported
    # unwrap) collapsed in the diff view ...


# opentelemetry/instrumentation/vertexai/config.py (the second file in this PR)

class Config:
exception_logger = None
use_legacy_attributes = True
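
# Usage sketch (illustrative, not from this PR): with the default
# use_legacy_attributes=True the instrumentor records prompts and completions
# as span attributes; use_legacy_attributes=False switches to event emission,
# in which case an event logger provider should be wired in. The provider
# classes below are the standard opentelemetry-sdk ones; the exact wiring is
# an assumption.
#
#     from opentelemetry.instrumentation.vertexai import VertexAIInstrumentor
#     from opentelemetry.sdk._events import EventLoggerProvider
#     from opentelemetry.sdk.trace import TracerProvider
#
#     VertexAIInstrumentor(use_legacy_attributes=False).instrument(
#         tracer_provider=TracerProvider(),
#         event_logger_provider=EventLoggerProvider(),
#     )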