Skip to content

Commit 665877b

Browse files
authored
feat(hindsight-litellm): support streaming on wrappers (#296)
1 parent a43d208 commit 665877b

File tree

3 files changed

+375
-25
lines changed

3 files changed

+375
-25
lines changed

hindsight-integrations/litellm/hindsight_litellm/wrappers.py

Lines changed: 195 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import threading
1313
from dataclasses import dataclass, fields
1414
from datetime import datetime, timezone
15-
from typing import Any, Dict, List, Optional
15+
from typing import Any, Dict, Iterator, List, Optional
1616

1717
from .config import (
1818
DEFAULT_BANK_ID,
@@ -678,6 +678,149 @@ async def aretain(
678678
)
679679

680680

681+
class _StreamWrapper:
    """Iterates an OpenAI chat-completion stream while accumulating the
    assistant's text deltas, so the full exchange can be persisted once the
    stream is exhausted or closed.

    Acts as a transparent proxy: every chunk passes through unmodified and
    unknown attribute lookups are forwarded to the wrapped stream object.
    """

    def __init__(
        self,
        stream: Any,
        user_query: str,
        model: str,
        wrapper: "HindsightOpenAI",
        settings: HindsightCallSettings,
    ):
        # Underlying OpenAI stream being proxied.
        self._stream = stream
        # Original user message; paired with the streamed answer when stored.
        self._user_query = user_query
        self._model = model
        self._wrapper = wrapper
        self._settings = settings
        # Text deltas collected so far; joined into one string on storage.
        self._collected_content: List[str] = []
        # Ensures the conversation is stored at most once.
        self._finished = False

    def __iter__(self) -> Iterator[Any]:
        return self

    def __next__(self) -> Any:
        try:
            chunk = next(self._stream)
            # Harvest the text delta (if any) before handing the chunk back.
            if hasattr(chunk, "choices") and chunk.choices:
                piece = getattr(chunk.choices[0].delta, "content", None)
                if piece:
                    self._collected_content.append(piece)
            return chunk
        except StopIteration:
            # End of stream: persist whatever text was collected, then signal.
            self._store_if_needed()
            raise

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Persist first so content is stored even on an early context exit.
        self._store_if_needed()
        if hasattr(self._stream, "__exit__"):
            return self._stream.__exit__(exc_type, exc_val, exc_tb)

    def _store_if_needed(self):
        """Persist the collected conversation exactly once, if storing is enabled."""
        if self._finished or not self._settings.store_conversations:
            return

        self._finished = True
        assistant_output = "".join(self._collected_content)
        if not assistant_output:
            return
        try:
            self._wrapper._store_conversation(
                self._user_query, assistant_output, self._model, self._settings
            )
        except Exception as e:
            # Best-effort: a storage failure must never break the caller's stream.
            if self._settings.verbose:
                logger.warning(f"Failed to store streamed conversation: {e}")

    def close(self):
        """Flush collected content, then close the wrapped stream if possible."""
        self._store_if_needed()
        if hasattr(self._stream, "close"):
            self._stream.close()

    def __getattr__(self, name: str) -> Any:
        """Delegate unknown attribute lookups to the wrapped stream."""
        return getattr(self._stream, name)
751+
752+
753+
class _AnthropicStreamWrapper:
    """Iterates an Anthropic Messages stream while accumulating text from
    ``content_block_delta`` events, so the full exchange can be persisted once
    the stream is exhausted or closed.

    Acts as a transparent proxy: every event passes through unmodified and
    unknown attribute lookups are forwarded to the wrapped stream object.
    """

    def __init__(
        self,
        stream: Any,
        user_query: str,
        model: str,
        wrapper: "HindsightAnthropic",
        settings: HindsightCallSettings,
    ):
        # Underlying Anthropic stream being proxied.
        self._stream = stream
        # Original user message; paired with the streamed answer when stored.
        self._user_query = user_query
        self._model = model
        self._wrapper = wrapper
        self._settings = settings
        # Text fragments collected so far; joined into one string on storage.
        self._collected_content: List[str] = []
        # Ensures the conversation is stored at most once.
        self._finished = False

    def __iter__(self) -> Iterator[Any]:
        return self

    def __next__(self) -> Any:
        try:
            chunk = next(self._stream)
            # Only content_block_delta events carry assistant text.
            if getattr(chunk, "type", None) == "content_block_delta":
                delta = getattr(chunk, "delta", None)
                if delta is not None and hasattr(delta, "text"):
                    self._collected_content.append(delta.text)
            return chunk
        except StopIteration:
            # End of stream: persist whatever text was collected, then signal.
            self._store_if_needed()
            raise

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Persist first so content is stored even on an early context exit.
        self._store_if_needed()
        if hasattr(self._stream, "__exit__"):
            return self._stream.__exit__(exc_type, exc_val, exc_tb)

    def _store_if_needed(self):
        """Persist the collected conversation exactly once, if storing is enabled."""
        if self._finished or not self._settings.store_conversations:
            return

        self._finished = True
        assistant_output = "".join(self._collected_content)
        if not assistant_output:
            return
        try:
            self._wrapper._store_conversation(
                self._user_query, assistant_output, self._model, self._settings
            )
        except Exception as e:
            # Best-effort: a storage failure must never break the caller's stream.
            if self._settings.verbose:
                logger.warning(f"Failed to store streamed conversation: {e}")

    def close(self):
        """Flush collected content, then close the wrapped stream if possible."""
        self._store_if_needed()
        if hasattr(self._stream, "close"):
            self._stream.close()

    def __getattr__(self, name: str) -> Any:
        """Delegate unknown attribute lookups to the wrapped stream."""
        return getattr(self._stream, name)
822+
823+
681824
class HindsightOpenAI:
682825
"""Wrapper for OpenAI client with Hindsight memory integration.
683826
@@ -1080,16 +1223,30 @@ def create(self, **kwargs) -> Any:
10801223
# Make the actual API call
10811224
response = self._wrapper._client.chat.completions.create(**openai_kwargs)
10821225

1083-
# Store conversation
1084-
if user_query and settings.store_conversations:
1085-
if response.choices and response.choices[0].message:
1086-
assistant_output = response.choices[0].message.content or ""
1087-
if assistant_output:
1088-
self._wrapper._store_conversation(
1089-
user_query, assistant_output, model, settings
1090-
)
1091-
1092-
return response
1226+
# Handle streaming vs non-streaming responses
1227+
is_streaming = openai_kwargs.get("stream", False)
1228+
if is_streaming:
1229+
# Wrap the stream to collect content and store conversation when done
1230+
if user_query and settings.store_conversations:
1231+
return _StreamWrapper(
1232+
stream=response,
1233+
user_query=user_query,
1234+
model=model,
1235+
wrapper=self._wrapper,
1236+
settings=settings,
1237+
)
1238+
else:
1239+
return response
1240+
else:
1241+
# Non-streaming: store conversation immediately
1242+
if user_query and settings.store_conversations:
1243+
if response.choices and response.choices[0].message:
1244+
assistant_output = response.choices[0].message.content or ""
1245+
if assistant_output:
1246+
self._wrapper._store_conversation(
1247+
user_query, assistant_output, model, settings
1248+
)
1249+
return response
10931250

10941251

10951252
class HindsightAnthropic:
@@ -1477,19 +1634,33 @@ def create(self, **kwargs) -> Any:
14771634
# Make the actual API call
14781635
response = self._wrapper._client.messages.create(**anthropic_kwargs)
14791636

1480-
# Store conversation
1481-
if user_query and settings.store_conversations:
1482-
if response.content:
1483-
assistant_output = ""
1484-
for block in response.content:
1485-
if hasattr(block, "text"):
1486-
assistant_output += block.text
1487-
if assistant_output:
1488-
self._wrapper._store_conversation(
1489-
user_query, assistant_output, model, settings
1490-
)
1491-
1492-
return response
1637+
# Handle streaming vs non-streaming responses
1638+
is_streaming = anthropic_kwargs.get("stream", False)
1639+
if is_streaming:
1640+
# Wrap the stream to collect content and store conversation when done
1641+
if user_query and settings.store_conversations:
1642+
return _AnthropicStreamWrapper(
1643+
stream=response,
1644+
user_query=user_query,
1645+
model=model,
1646+
wrapper=self._wrapper,
1647+
settings=settings,
1648+
)
1649+
else:
1650+
return response
1651+
else:
1652+
# Non-streaming: store conversation immediately
1653+
if user_query and settings.store_conversations:
1654+
if response.content:
1655+
assistant_output = ""
1656+
for block in response.content:
1657+
if hasattr(block, "text"):
1658+
assistant_output += block.text
1659+
if assistant_output:
1660+
self._wrapper._store_conversation(
1661+
user_query, assistant_output, model, settings
1662+
)
1663+
return response
14931664

14941665

14951666
def wrap_openai(

0 commit comments

Comments
 (0)