1 change: 1 addition & 0 deletions README.md
@@ -36,6 +36,7 @@ uv add scope3ai
| OpenAI | ✅ | | | | |
| Huggingface | ✅ | ✅ | ✅ | ✅ | ✅ |
| LiteLLM | ✅ | | | | |
| MistralAi | ✅ | | | | |

Roadmap:
- Google
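The new README row advertises chat support for MistralAI. Below is a minimal, hedged sketch of what that looks like from the caller's side; Scope3AI.init() as the entry point and the environment variables are assumptions, not taken from this diff, while the Mistral call mirrors the test cassette further down:

import os

from mistralai import Mistral
from scope3ai import Scope3AI

scope3 = Scope3AI.init()  # assumption: this registers the tracers, including the new MistralAI one
client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])

response = client.chat.complete(
    model="mistral-tiny",
    messages=[{"role": "user", "content": "Hello World!"}],
)
# The chat wrapper introduced in this PR attaches a Scope3AIContext to the response.
print(response.scope3ai)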
8 changes: 4 additions & 4 deletions pyproject.toml
@@ -5,7 +5,7 @@ description = "Track the environmental impact of your use of AI"
readme = "README.md"
requires-python = ">=3.9"
dependencies = [
-    "httpx>=0.28.1",
+    "httpx>=0.27.2",
"litellm>=1.53.3",
"pillow>=11.0.0",
"pydantic>=2.10.3",
@@ -29,9 +29,6 @@ litellm = [
"litellm>=1.53.3",
"rapidfuzz>=3.10.1",
]
-mistralai = [
-    "mistralai>=0.4.2",
-]
anthropic = [
"anthropic>=0.40.0",
]
@@ -44,6 +41,9 @@ huggingface-hub = [
"minijinja>=2.5.0",
"tiktoken>=0.8.0",
]
+mistralai = [
+    "mistralai>=1.2.5"
+]
google-generativeai = [
"google-generativeai>=0.8.3",
]
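Because mistralai is an optional extras group, the new tracer only activates when the SDK is installed. A hedged one-liner using the same installer the README hunk above shows (the extra name comes straight from this file):

uv add "scope3ai[mistralai]"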
9 changes: 9 additions & 0 deletions scope3ai/lib.py
@@ -59,12 +59,21 @@ def init_litellm_instrumentor() -> None:
instrumentor.instrument()


def init_mistral_v1_instrumentor() -> None:
if importlib.util.find_spec("mistralai") is not None:
from scope3ai.tracers.mistrarlai_v1.instrument import MistralAIInstrumentor

instrumentor = MistralAIInstrumentor()
instrumentor.instrument()


_INSTRUMENTS = {
"anthropic": init_anthropic_instrumentor,
"cohere": init_cohere_instrumentor,
"openai": init_openai_instrumentor,
"huggingface_hub": init_huggingface_hub_instrumentor,
"litellm": init_litellm_instrumentor,
"mistralai": init_mistral_v1_instrumentor,
}


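For reference, the new _INSTRUMENTS entry amounts to the following guard-and-patch sequence, shown here as a standalone sketch using only names from this diff; in normal use it runs automatically during library initialization:

import importlib.util

if importlib.util.find_spec("mistralai") is not None:
    from scope3ai.tracers.mistrarlai_v1.instrument import MistralAIInstrumentor

    # Same pattern as init_mistral_v1_instrumentor above: patch mistralai.chat.Chat in place.
    MistralAIInstrumentor().instrument()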
131 changes: 131 additions & 0 deletions scope3ai/tracers/mistrarlai_v1/chat.py
@@ -0,0 +1,131 @@
import time
from collections.abc import AsyncGenerator, Iterable
from typing import Any, Callable, Optional

from mistralai import Mistral
from mistralai.models import ChatCompletionResponse as _ChatCompletionResponse
from mistralai.models import CompletionChunk as _CompletionChunk
from mistralai.models import CompletionEvent

from scope3ai import Scope3AI
from scope3ai.api.types import Scope3AIContext
from scope3ai.api.typesgen import ImpactRow, Model

PROVIDER = "mistralai"


class ChatCompletionResponse(_ChatCompletionResponse):
scope3ai: Optional[Scope3AIContext] = None


class CompletionChunk(_CompletionChunk):
scope3ai: Optional[Scope3AIContext] = None


def mistralai_v1_chat_wrapper(
wrapped: Callable,
instance: Mistral,
args: Any,
kwargs: Any,
) -> ChatCompletionResponse:
timer_start = time.perf_counter()
response = wrapped(*args, **kwargs)
request_latency = time.perf_counter() - timer_start
scope3_row = ImpactRow(
model=Model(id=response.model),
input_tokens=response.usage.prompt_tokens,
output_tokens=response.usage.completion_tokens,
request_duration_ms=request_latency * 1000,
managed_service_id=PROVIDER,
)
scope3ai_ctx = Scope3AI.get_instance().submit_impact(scope3_row)
chat = ChatCompletionResponse(**response.model_dump())
chat.scope3ai = scope3ai_ctx
return chat


def mistralai_v1_chat_wrapper_stream(
wrapped: Callable,
instance: Mistral,
args: Any,
kwargs: Any,
) -> Iterable[CompletionEvent]:
timer_start = time.perf_counter()
stream = wrapped(*args, **kwargs)
token_count = 0
for i, chunk in enumerate(stream):
if i > 0 and chunk.data.choices[0].finish_reason is None:
token_count += 1
model_name = chunk.data.model
if chunk.data:
request_latency = time.perf_counter() - timer_start
scope3_row = ImpactRow(
model=Model(id=model_name),
input_tokens=token_count,
output_tokens=chunk.data.usage.completion_tokens
if chunk.data.usage
else None,
request_duration_ms=request_latency * 1000,
managed_service_id=PROVIDER,
)
scope3ai_ctx = Scope3AI.get_instance().submit_impact(scope3_row)
chunk.data = CompletionChunk(
**chunk.data.model_dump(), scope3ai=scope3ai_ctx
)
yield chunk


async def mistralai_v1_async_chat_wrapper(
wrapped: Callable,
instance: Mistral,
args: Any,
kwargs: Any,
) -> ChatCompletionResponse:
timer_start = time.perf_counter()
response = await wrapped(*args, **kwargs)
request_latency = time.perf_counter() - timer_start
scope3_row = ImpactRow(
model=Model(id=response.model),
input_tokens=response.usage.prompt_tokens,
output_tokens=response.usage.completion_tokens,
request_duration_ms=request_latency * 1000,
managed_service_id=PROVIDER,
)
scope3ai_ctx = Scope3AI.get_instance().submit_impact(scope3_row)
chat = ChatCompletionResponse(**response.model_dump())
chat.scope3ai = scope3ai_ctx
return chat


async def _generator(
stream: AsyncGenerator[CompletionEvent, None], timer_start: float
) -> AsyncGenerator[CompletionEvent, None]:
token_count = 0
async for chunk in stream:
if chunk.data.usage is not None:
token_count = chunk.data.usage.completion_tokens
request_latency = time.perf_counter() - timer_start
model_name = chunk.data.model
scope3_row = ImpactRow(
model=Model(id=model_name),
input_tokens=token_count,
output_tokens=chunk.data.usage.completion_tokens
if chunk.data.usage
else None,
request_duration_ms=request_latency * 1000,
managed_service_id=PROVIDER,
)
scope3ai_ctx = Scope3AI.get_instance().submit_impact(scope3_row)
chunk.data = CompletionChunk(**chunk.data.model_dump(), scope3ai=scope3ai_ctx)
yield chunk


async def mistralai_v1_async_chat_wrapper_stream(
wrapped: Callable,
instance: Mistral,
args: Any,
kwargs: Any,
) -> AsyncGenerator[CompletionEvent, None]:
timer_start = time.perf_counter()
stream = await wrapped(*args, **kwargs)
return _generator(stream, timer_start)
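A hedged sketch of consuming the wrapped synchronous stream; it assumes instrumentation has already been applied (see instrument.py below) and a valid MISTRAL_API_KEY. After wrapping, Chat.stream yields CompletionEvent objects whose data has been rebuilt as the CompletionChunk subclass above, so every chunk exposes a scope3ai attribute:

import os

from mistralai import Mistral

client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])

last_ctx = None
for event in client.chat.stream(
    model="mistral-tiny",
    messages=[{"role": "user", "content": "Hello World!"}],
):
    # Guard against chunks with empty choices or empty deltas.
    if event.data.choices and event.data.choices[0].delta.content:
        print(event.data.choices[0].delta.content, end="", flush=True)
    last_ctx = event.data.scope3ai  # the last chunk reflects the full request duration

print()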
40 changes: 40 additions & 0 deletions scope3ai/tracers/mistrarlai_v1/instrument.py
@@ -0,0 +1,40 @@
from wrapt import wrap_function_wrapper # type: ignore[import-untyped]

from scope3ai.tracers.mistrarlai_v1.chat import (
mistralai_v1_chat_wrapper,
mistralai_v1_async_chat_wrapper,
mistralai_v1_chat_wrapper_stream,
mistralai_v1_async_chat_wrapper_stream,
)


class MistralAIInstrumentor:
def __init__(self) -> None:
self.wrapped_methods = [
{
"module": "mistralai.chat",
"name": "Chat.complete",
"wrapper": mistralai_v1_chat_wrapper,
},
{
"module": "mistralai.chat",
"name": "Chat.complete_async",
"wrapper": mistralai_v1_async_chat_wrapper,
},
{
"module": "mistralai.chat",
"name": "Chat.stream",
"wrapper": mistralai_v1_chat_wrapper_stream,
},
{
"module": "mistralai.chat",
"name": "Chat.stream_async",
"wrapper": mistralai_v1_async_chat_wrapper_stream,
},
]

def instrument(self) -> None:
for wrapper in self.wrapped_methods:
wrap_function_wrapper(
wrapper["module"], wrapper["name"], wrapper["wrapper"]
)
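For readers unfamiliar with wrapt, a hedged illustration of the convention instrument() relies on: wrap_function_wrapper patches the dotted target in place, and every wrapper receives (wrapped, instance, args, kwargs), which is exactly the signature the wrappers in chat.py use. The logging wrapper below is hypothetical and only for illustration:

from wrapt import wrap_function_wrapper


def logging_wrapper(wrapped, instance, args, kwargs):
    # wrapt hands us the original callable, the bound instance, and the call arguments.
    result = wrapped(*args, **kwargs)
    print(f"{wrapped.__name__} returned {type(result).__name__}")
    return result


# Hypothetical extra patch, applied the same way MistralAIInstrumentor.instrument() applies the real ones.
wrap_function_wrapper("mistralai.chat", "Chat.complete", logging_wrapper)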
79 changes: 79 additions & 0 deletions tests/cassettes/test_mistralai_async_chat.yaml
@@ -0,0 +1,79 @@
interactions:
- request:
body: '{"messages":[{"content":"Hello World!","role":"user"}],"model":"mistral-tiny","safe_prompt":false,"stream":false,"top_p":1.0}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate
authorization:
- DUMMY
connection:
- keep-alive
content-length:
- '125'
content-type:
- application/json
host:
- api.mistral.ai
user-agent:
- mistral-client-python/1.2.5
method: POST
uri: https://api.mistral.ai/v1/chat/completions
response:
body:
string: !!binary |
H4sIAAAAAAAAAzzQQW7CMBCF4au4b9ONiwIEknjTLZyhqtDgTBpTxxPFg1qEcvcqqmD762n0ae4I
LRwa6qqq2NTtvi7L3fbc1Pumq6nzBXclNTUs5Hxhr3DwPenKyzBG1iAJFn5iUm7h1tV2V1fromgs
Bmk5wmEIWSeKbxrSbdn2EjxnuI87Qmr5F66wGDhn+mK4OyaJDAfKOWSlpLBQkXjyFGOGS9cYLbwk
5bRoDhyjGO154hdz1NdsUvBsVMzArOYm15U5yI/xlMzR/F9dqlFp6faO2aILKeT+NDFlSXDIKiPm
T4vrwzROMox6UvnmlOH2C0kpPsOmXEiPjzzzup7nPwAAAP//AwBG0Z1HYQEAAA==
headers:
CF-Cache-Status:
- DYNAMIC
CF-RAY:
- 8fbf86f0889ada8f-MIA
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Fri, 03 Jan 2025 02:23:29 GMT
Server:
- cloudflare
Transfer-Encoding:
- chunked
access-control-allow-origin:
- '*'
alt-svc:
- h3=":443"; ma=86400
ratelimitbysize-limit:
- '500000'
ratelimitbysize-query-cost:
- '32003'
ratelimitbysize-remaining:
- '467997'
ratelimitbysize-reset:
- '31'
x-envoy-upstream-service-time:
- '141'
x-kong-proxy-latency:
- '1'
x-kong-request-id:
- 564b303d23cf7f55aa9c761e5fa8216e
x-kong-upstream-latency:
- '142'
x-ratelimitbysize-limit-minute:
- '500000'
x-ratelimitbysize-limit-month:
- '1000000000'
x-ratelimitbysize-remaining-minute:
- '467997'
x-ratelimitbysize-remaining-month:
- '999903991'
status:
code: 200
message: OK
version: 1
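A hedged sketch of the kind of test this cassette would back. It assumes pytest-recording's vcr marker and pytest-asyncio for replay; the tracer_init fixture name and the assertions are illustrative, not taken from the repository, while the model, prompt, and DUMMY key mirror the recorded request:

import pytest

from mistralai import Mistral


@pytest.mark.vcr
@pytest.mark.asyncio
async def test_mistralai_async_chat(tracer_init):  # tracer_init: hypothetical fixture that calls Scope3AI.init()
    client = Mistral(api_key="DUMMY")
    response = await client.chat.complete_async(
        messages=[{"role": "user", "content": "Hello World!"}],
        model="mistral-tiny",
    )
    assert len(response.choices) > 0
    assert response.scope3ai is not None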