In [1]:
import os
from typing import Any, List, Optional, Mapping, Dict
from dotenv import load_dotenv

# LangChain Core
from langchain_core.language_models.llms import BaseLLM
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.outputs import LLMResult, Generation
from langchain_core.messages import (
    HumanMessage, 
    AIMessage, 
    SystemMessage,
    RemoveMessage,
    BaseMessage
)
from langchain_core.messages.utils import trim_messages, count_tokens_approximately
from langchain_core.runnables import RunnableConfig

# LangGraph
from langgraph.graph import StateGraph, MessagesState, START, END
from langgraph.checkpoint.memory import MemorySaver

# Snowflake
import snowflake.connector

load_dotenv()
print("‚úÖ All imports loaded")

‚úÖ All imports loaded


In [2]:
class SnowflakeCortexLLM(BaseLLM):
    """Snowflake Cortex LLM for LangChain"""
    
    model_name: str = "mistral-large"
    
    class Config:
        arbitrary_types_allowed = True
    
    def __init__(self, model_name: str = "mistral-large", **kwargs):
        super().__init__(model_name=model_name, **kwargs)
        object.__setattr__(self, '_conn', None)
    
    def _get_connection(self):
        if self._conn is None:
            object.__setattr__(self, '_conn', snowflake.connector.connect(
                account=os.getenv("SNOWFLAKE_ACCOUNT"),
                user=os.getenv("SNOWFLAKE_USER"),
                password=os.getenv("SNOWFLAKE_PASSWORD"),
                database=os.getenv("SNOWFLAKE_DATABASE"),
                schema=os.getenv("SNOWFLAKE_SCHEMA"),
                warehouse=os.getenv("SNOWFLAKE_WAREHOUSE"),
                role=os.getenv("SNOWFLAKE_ROLE")
            ))
        return self._conn
    
    @property
    def _llm_type(self) -> str:
        return "snowflake_cortex"
    
    def _generate(self, prompts: List[str], stop: Optional[List[str]] = None, 
                  run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any) -> LLMResult:
        generations = []
        for prompt in prompts:
            text = self._call_cortex(prompt)
            generations.append([Generation(text=text)])
        return LLMResult(generations=generations)
    
    def _call_cortex(self, prompt: str) -> str:
        conn = self._get_connection()
        cursor = conn.cursor()
        try:
            escaped = prompt.replace("'", "''")
            sql = f"SELECT SNOWFLAKE.CORTEX.COMPLETE('{self.model_name}', '{escaped}')"
            cursor.execute(sql)
            result = cursor.fetchone()
            return result[0] if result else ""
        except Exception as e:
            return f"Error: {str(e)}"
        finally:
            cursor.close()
    
    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs) -> str:
        return self._call_cortex(prompt)
    
    def __call__(self, prompt: str) -> str:
        return self._call_cortex(prompt)
    
    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"model_name": self.model_name}
    
    def __del__(self):
        try:
            if hasattr(self, '_conn') and self._conn:
                self._conn.close()
        except:
            pass

# Initialize
llm = SnowflakeCortexLLM(model_name="mistral-large")
print("‚úÖ Snowflake Cortex LLM initialized")

# Test it
test_response = llm.invoke("Say one line about Memory Management in LLMs")
print(f"Test: {test_response}")


‚úÖ Snowflake Cortex LLM initialized
Test:  Memory management in Large Language Models (LLMs) involves efficient allocation and deallocation of memory resources to handle extensive data, enabling the model to learn and generate responses without causing system overload.


In [5]:
print("\n" + "="*70)
print("CHUNK 4: DEMONSTRATING CONTEXT OVERFLOW")
print("="*70)

# Create a long conversation
messages = []

for i in range(100000):
    messages.append(HumanMessage(content=f"Question {i+1}: Tell me about data engineering topic {i+1}. " + "Additional context. " * 10))
    messages.append(AIMessage(content=f"Answer {i+1}: Here's detailed information about topic {i+1}. " + "Detailed response. " * 10))

# Count tokens
total_tokens = count_tokens_approximately(messages)

print(f"\nüìä CONTEXT ANALYSIS:")
print(f"   Total messages: {len(messages)}")
print(f"   Total tokens (approx): {total_tokens:,}")
print(f"   GPT-4 limit: 128,000 tokens")
print(f"   Safe limit (80%): 102,400 tokens")

if total_tokens > 102400:
    print(f"\n   ‚ùå OVERFLOW! Over by {total_tokens - 102400:,} tokens")
    print(f"   ‚Üí Slower response, higher cost, errors possible")
else:
    print(f"\n   ‚úÖ Within limits")



CHUNK 4: DEMONSTRATING CONTEXT OVERFLOW

üìä CONTEXT ANALYSIS:
   Total messages: 200000
   Total tokens (approx): 13,697,986
   GPT-4 limit: 128,000 tokens
   Safe limit (80%): 102,400 tokens

   ‚ùå OVERFLOW! Over by 13,595,586 tokens
   ‚Üí Slower response, higher cost, errors possible


In [6]:
print("\n" + "="*70)
print("CHUNK 5: SOLUTION 1 - trim_messages()")
print("="*70)

# Use official trim_messages function
trimmed = trim_messages(
    messages,
    strategy="last",  # Keep last messages
    token_counter=count_tokens_approximately,
    max_tokens=2000,  # Keep up to 2000 tokens
    start_on="human",  # Start with human message
)

tokens_after_trim = count_tokens_approximately(trimmed)

print(f"\n‚úÇÔ∏è  TRIMMING RESULTS:")
print(f"   Messages: {len(messages)} ‚Üí {len(trimmed)}")
print(f"   Tokens: {total_tokens:,} ‚Üí {tokens_after_trim:,}")
print(f"   Saved: {total_tokens - tokens_after_trim:,} tokens ({((total_tokens - tokens_after_trim)/total_tokens*100):.1f}%)")
print(f"   ‚úÖ Within limits!")



CHUNK 5: SOLUTION 1 - trim_messages()

‚úÇÔ∏è  TRIMMING RESULTS:
   Messages: 200000 ‚Üí 28
   Tokens: 13,697,986 ‚Üí 1,920
   Saved: 13,696,066 tokens (100.0%)
   ‚úÖ Within limits!


In [7]:
print("\n" + "="*70)
print("CHUNK 6: AUTOMATIC TRIMMING WITH AGENT (Snowflake Cortex)")
print("="*70)

# Build agent using LangGraph (since create_agent may not work with custom LLM)
class State(MessagesState):
    """State with messages"""
    pass

def chatbot(state: State):
    """Chatbot node that calls Snowflake Cortex"""
    
    # Get messages
    messages_for_llm = state.get("llm_input_messages", state["messages"])
    
    # Convert to prompt
    prompt = "\n".join([
        f"{m.type}: {m.content}" for m in messages_for_llm
    ])
    prompt += "\n\nAssistant:"
    
    # Call Snowflake Cortex
    response = llm.invoke(prompt)
    
    return {"messages": [AIMessage(content=response)]}


def pre_model_hook(state):
    """Trim messages before LLM call"""
    
    trimmed = trim_messages(
        state["messages"],
        strategy="last",
        token_counter=count_tokens_approximately,
        max_tokens=2000,
        start_on="human",
    )
    
    print(f"   üîß Trimmed: {len(state['messages'])} ‚Üí {len(trimmed)} messages")
    
    return {"llm_input_messages": trimmed}


# Build graph with trimming
workflow = StateGraph(State)
workflow.add_node("trim", pre_model_hook)
workflow.add_node("chatbot", chatbot)
workflow.add_edge(START, "trim")
workflow.add_edge("trim", "chatbot")
workflow.add_edge("chatbot", END)

# Compile with checkpointer
checkpointer = MemorySaver()
app = workflow.compile(checkpointer=checkpointer)

print("‚úÖ Agent with automatic trimming created")


# Test the agent
config: RunnableConfig = {"configurable": {"thread_id": "1"}}

print("\nüìù Testing agent with automatic trimming...\n")

# Message 1
response1 = app.invoke({"messages": [HumanMessage(content="Hi, my name is Bob")]}, config)
print(f"1. {response1['messages'][-1].content[:80]}...")

# Message 2
response2 = app.invoke({"messages": [HumanMessage(content="What's the capital of France?")]}, config)
print(f"2. {response2['messages'][-1].content[:80]}...")

# Message 3
response3 = app.invoke({"messages": [HumanMessage(content="What's my name?")]}, config)
print(f"3. {response3['messages'][-1].content[:80]}...")

print(f"\n   ‚úÖ Agent remembers 'Bob' with automatic trimming!")
print(f"   Total messages in state: {len(response3['messages'])}")




CHUNK 6: AUTOMATIC TRIMMING WITH AGENT (Snowflake Cortex)
‚úÖ Agent with automatic trimming created

üìù Testing agent with automatic trimming...

   üîß Trimmed: 1 ‚Üí 1 messages
1.  Hello Bob! It's a pleasure to meet you. How can I assist you today?...
   üîß Trimmed: 3 ‚Üí 3 messages
2.  The capital of France is Paris.

Reference(s):
[INST] retrieval:1
id: wikipedia...
   üîß Trimmed: 5 ‚Üí 5 messages
3.  Your name is Bob, as you mentioned earlier in our conversation....

   ‚úÖ Agent remembers 'Bob' with automatic trimming!
   Total messages in state: 6


In [9]:
print("\n" + "="*70)
print("CHUNK 7: SOLUTION 2 - RemoveMessage (Delete Old Messages)")
print("="*70)

def post_model_hook(state):
    """Delete old messages after each response"""
    
    messages = state["messages"]
    
    if len(messages) > 10:
        # Keep only last 10 messages
        to_remove = messages[:len(messages)-10]
        print(f"   üóëÔ∏è  Removing {len(to_remove)} old messages")
        return {"messages": [RemoveMessage(id=m.id) for m in to_remove]}
    
    return {}


# Build graph with deletion
workflow_delete = StateGraph(State)
workflow_delete.add_node("chatbot", chatbot)
workflow_delete.add_node("delete", post_model_hook)
workflow_delete.add_edge(START, "chatbot")
workflow_delete.add_edge("chatbot", "delete")
workflow_delete.add_edge("delete", END)

app_delete = workflow_delete.compile(checkpointer=MemorySaver())

print("‚úÖ Agent with automatic deletion created")


# Test deletion
config2: RunnableConfig = {"configurable": {"thread_id": "2"}}

print("\nüìù Testing agent with RemoveMessage...\n")

# Add many messages
for i in range(2):
    app_delete.invoke({"messages": [HumanMessage(content=f"Message {i+1}")]}, config2)

final_state = app_delete.get_state(config2)
print(f"   After 15 messages, state has: {len(final_state.values['messages'])} messages")
print(f"   ‚úÖ Old messages automatically deleted!")




CHUNK 7: SOLUTION 2 - RemoveMessage (Delete Old Messages)
‚úÖ Agent with automatic deletion created

üìù Testing agent with RemoveMessage...

   After 15 messages, state has: 4 messages
   ‚úÖ Old messages automatically deleted!


In [10]:
print("\n" + "="*70)
print("CHUNK 8: SOLUTION 3 - SUMMARIZATION")
print("="*70)

def create_summary(messages: List[BaseMessage], llm) -> str:
    """Create summary of messages using Snowflake Cortex"""
    
    # Convert messages to text
    conversation = "\n".join([
        f"{m.type}: {m.content}" for m in messages
    ])
    
    # Summarization prompt
    prompt = f"""Summarize the following conversation concisely, preserving key facts:

{conversation}

Summary (2-3 sentences):"""
    
    summary = llm.invoke(prompt)
    return summary


def summarization_hook(state):
    """Summarize old messages, keep recent"""
    
    messages = state["messages"]
    
    # Trigger at 20 messages
    if len(messages) > 20:
        keep_recent = 10
        old_messages = messages[:-keep_recent]
        recent_messages = messages[-keep_recent:]
        
        print(f"   üìù Summarizing {len(old_messages)} old messages...")
        
        # Create summary
        summary_text = create_summary(old_messages, llm)
        summary_msg = SystemMessage(content=f"Previous conversation summary: {summary_text}")
        
        # Create remove commands for old messages
        remove_old = [RemoveMessage(id=m.id) for m in old_messages]
        
        print(f"   ‚úÖ Summary created, keeping {keep_recent} recent messages")
        
        return {"messages": remove_old + [summary_msg]}
    
    return {}


# Build graph with summarization
workflow_summary = StateGraph(State)
workflow_summary.add_node("chatbot", chatbot)
workflow_summary.add_node("summarize", summarization_hook)
workflow_summary.add_edge(START, "chatbot")
workflow_summary.add_edge("chatbot", "summarize")
workflow_summary.add_edge("summarize", END)

app_summary = workflow_summary.compile(checkpointer=MemorySaver())

print("‚úÖ Agent with automatic summarization created")


# Test summarization
config3: RunnableConfig = {"configurable": {"thread_id": "3"}}

print("\nüìù Testing agent with summarization...\n")

# Add many messages to trigger summarization
for i in range(25):
    response = app_summary.invoke(
        {"messages": [HumanMessage(content=f"Tell me fact {i+1} about AI")]}, 
        config3
    )

final_state = app_summary.get_state(config3)
print(f"   After 25 messages:")
print(f"   Messages in state: {len(final_state.values['messages'])}")
print(f"   ‚úÖ Summarization triggered and old messages compressed!")

# Check if there's a summary message
for msg in final_state.values['messages']:
    if isinstance(msg, SystemMessage) and "summary" in msg.content.lower():
        print(f"\n   üìù Summary found: {msg.content[:100]}...")
        break


CHUNK 8: SOLUTION 3 - SUMMARIZATION
‚úÖ Agent with automatic summarization created

üìù Testing agent with summarization...

   üìù Summarizing 12 old messages...
   ‚úÖ Summary created, keeping 10 recent messages
   üìù Summarizing 11 old messages...
   ‚úÖ Summary created, keeping 10 recent messages


KeyboardInterrupt: 