|
12 | 12 | import threading |
13 | 13 | from dataclasses import dataclass, fields |
14 | 14 | from datetime import datetime, timezone |
15 | | -from typing import Any, Dict, List, Optional |
| 15 | +from typing import Any, Dict, Iterator, List, Optional |
16 | 16 |
|
17 | 17 | from .config import ( |
18 | 18 | DEFAULT_BANK_ID, |
@@ -678,6 +678,149 @@ async def aretain( |
678 | 678 | ) |
679 | 679 |
|
680 | 680 |
|
class _StreamWrapper:
    """Iterates an OpenAI streaming response while accumulating the assistant's
    text deltas, then persists the full conversation once the stream is
    exhausted, closed, or its context manager exits.
    """

    def __init__(
        self,
        stream: Any,
        user_query: str,
        model: str,
        wrapper: "HindsightOpenAI",
        settings: HindsightCallSettings,
    ):
        self._stream = stream
        self._user_query = user_query
        self._model = model
        self._wrapper = wrapper
        self._settings = settings
        # Text fragments gathered from each streamed delta; joined at the end.
        self._collected_content: List[str] = []
        # Ensures the conversation is persisted at most once.
        self._finished = False

    def __iter__(self) -> Iterator[Any]:
        return self

    def __next__(self) -> Any:
        try:
            chunk = next(self._stream)
        except StopIteration:
            # Provider has no more chunks: persist whatever was gathered.
            self._store_if_needed()
            raise
        choices = getattr(chunk, "choices", None)
        if choices:
            piece = getattr(choices[0].delta, "content", None)
            if piece:
                self._collected_content.append(piece)
        return chunk

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._store_if_needed()
        inner_exit = getattr(self._stream, "__exit__", None)
        if inner_exit is not None:
            return inner_exit(exc_type, exc_val, exc_tb)

    def _store_if_needed(self):
        """Persist the joined stream content exactly once, if storage is enabled."""
        if self._finished or not self._settings.store_conversations:
            return
        self._finished = True
        assistant_output = "".join(self._collected_content)
        if not assistant_output:
            return
        try:
            self._wrapper._store_conversation(
                self._user_query, assistant_output, self._model, self._settings
            )
        except Exception as e:
            # Best-effort persistence: never let storage failures break the caller.
            if self._settings.verbose:
                logger.warning(f"Failed to store streamed conversation: {e}")

    def close(self):
        """Flush pending storage, then close the wrapped stream if it supports it."""
        self._store_if_needed()
        closer = getattr(self._stream, "close", None)
        if closer is not None:
            closer()

    def __getattr__(self, name: str) -> Any:
        """Delegate any other attribute access to the wrapped stream."""
        return getattr(self._stream, name)
| 751 | + |
| 752 | + |
class _AnthropicStreamWrapper:
    """Iterates an Anthropic streaming response while accumulating text from
    ``content_block_delta`` events, then persists the full conversation once
    the stream is exhausted, closed, or its context manager exits.
    """

    def __init__(
        self,
        stream: Any,
        user_query: str,
        model: str,
        wrapper: "HindsightAnthropic",
        settings: HindsightCallSettings,
    ):
        self._stream = stream
        self._user_query = user_query
        self._model = model
        self._wrapper = wrapper
        self._settings = settings
        # Text fragments gathered from each delta event; joined at the end.
        self._collected_content: List[str] = []
        # Ensures the conversation is persisted at most once.
        self._finished = False

    def __iter__(self) -> Iterator[Any]:
        return self

    def __next__(self) -> Any:
        try:
            chunk = next(self._stream)
        except StopIteration:
            # Provider has no more events: persist whatever was gathered.
            self._store_if_needed()
            raise
        if getattr(chunk, "type", None) == "content_block_delta":
            delta = getattr(chunk, "delta", None)
            if delta is not None and hasattr(delta, "text"):
                self._collected_content.append(delta.text)
        return chunk

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._store_if_needed()
        inner_exit = getattr(self._stream, "__exit__", None)
        if inner_exit is not None:
            return inner_exit(exc_type, exc_val, exc_tb)

    def _store_if_needed(self):
        """Persist the joined stream content exactly once, if storage is enabled."""
        if self._finished or not self._settings.store_conversations:
            return
        self._finished = True
        assistant_output = "".join(self._collected_content)
        if not assistant_output:
            return
        try:
            self._wrapper._store_conversation(
                self._user_query, assistant_output, self._model, self._settings
            )
        except Exception as e:
            # Best-effort persistence: never let storage failures break the caller.
            if self._settings.verbose:
                logger.warning(f"Failed to store streamed conversation: {e}")

    def close(self):
        """Flush pending storage, then close the wrapped stream if it supports it."""
        self._store_if_needed()
        closer = getattr(self._stream, "close", None)
        if closer is not None:
            closer()

    def __getattr__(self, name: str) -> Any:
        """Delegate any other attribute access to the wrapped stream."""
        return getattr(self._stream, name)
| 822 | + |
| 823 | + |
681 | 824 | class HindsightOpenAI: |
682 | 825 | """Wrapper for OpenAI client with Hindsight memory integration. |
683 | 826 |
|
@@ -1080,16 +1223,30 @@ def create(self, **kwargs) -> Any: |
1080 | 1223 | # Make the actual API call |
1081 | 1224 | response = self._wrapper._client.chat.completions.create(**openai_kwargs) |
1082 | 1225 |
|
1083 | | - # Store conversation |
1084 | | - if user_query and settings.store_conversations: |
1085 | | - if response.choices and response.choices[0].message: |
1086 | | - assistant_output = response.choices[0].message.content or "" |
1087 | | - if assistant_output: |
1088 | | - self._wrapper._store_conversation( |
1089 | | - user_query, assistant_output, model, settings |
1090 | | - ) |
1091 | | - |
1092 | | - return response |
| 1226 | + # Handle streaming vs non-streaming responses |
| 1227 | + is_streaming = openai_kwargs.get("stream", False) |
| 1228 | + if is_streaming: |
| 1229 | + # Wrap the stream to collect content and store conversation when done |
| 1230 | + if user_query and settings.store_conversations: |
| 1231 | + return _StreamWrapper( |
| 1232 | + stream=response, |
| 1233 | + user_query=user_query, |
| 1234 | + model=model, |
| 1235 | + wrapper=self._wrapper, |
| 1236 | + settings=settings, |
| 1237 | + ) |
| 1238 | + else: |
| 1239 | + return response |
| 1240 | + else: |
| 1241 | + # Non-streaming: store conversation immediately |
| 1242 | + if user_query and settings.store_conversations: |
| 1243 | + if response.choices and response.choices[0].message: |
| 1244 | + assistant_output = response.choices[0].message.content or "" |
| 1245 | + if assistant_output: |
| 1246 | + self._wrapper._store_conversation( |
| 1247 | + user_query, assistant_output, model, settings |
| 1248 | + ) |
| 1249 | + return response |
1093 | 1250 |
|
1094 | 1251 |
|
1095 | 1252 | class HindsightAnthropic: |
@@ -1477,19 +1634,33 @@ def create(self, **kwargs) -> Any: |
1477 | 1634 | # Make the actual API call |
1478 | 1635 | response = self._wrapper._client.messages.create(**anthropic_kwargs) |
1479 | 1636 |
|
1480 | | - # Store conversation |
1481 | | - if user_query and settings.store_conversations: |
1482 | | - if response.content: |
1483 | | - assistant_output = "" |
1484 | | - for block in response.content: |
1485 | | - if hasattr(block, "text"): |
1486 | | - assistant_output += block.text |
1487 | | - if assistant_output: |
1488 | | - self._wrapper._store_conversation( |
1489 | | - user_query, assistant_output, model, settings |
1490 | | - ) |
1491 | | - |
1492 | | - return response |
| 1637 | + # Handle streaming vs non-streaming responses |
| 1638 | + is_streaming = anthropic_kwargs.get("stream", False) |
| 1639 | + if is_streaming: |
| 1640 | + # Wrap the stream to collect content and store conversation when done |
| 1641 | + if user_query and settings.store_conversations: |
| 1642 | + return _AnthropicStreamWrapper( |
| 1643 | + stream=response, |
| 1644 | + user_query=user_query, |
| 1645 | + model=model, |
| 1646 | + wrapper=self._wrapper, |
| 1647 | + settings=settings, |
| 1648 | + ) |
| 1649 | + else: |
| 1650 | + return response |
| 1651 | + else: |
| 1652 | + # Non-streaming: store conversation immediately |
| 1653 | + if user_query and settings.store_conversations: |
| 1654 | + if response.content: |
| 1655 | + assistant_output = "" |
| 1656 | + for block in response.content: |
| 1657 | + if hasattr(block, "text"): |
| 1658 | + assistant_output += block.text |
| 1659 | + if assistant_output: |
| 1660 | + self._wrapper._store_conversation( |
| 1661 | + user_query, assistant_output, model, settings |
| 1662 | + ) |
| 1663 | + return response |
1493 | 1664 |
|
1494 | 1665 |
|
1495 | 1666 | def wrap_openai( |
|
0 commit comments