In [29]:
import google.generativeai as genai
import json
from pymongo import MongoClient, ASCENDING
import redis
import json
from datetime import datetime
import threading
# --- Configuration ------------------------------------------------------------
# Secrets and endpoints are read from the environment instead of being hardcoded.
# The Gemini key that used to be written out here was saved into the notebook:
# treat it as leaked and rotate it.
import os

GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    raise RuntimeError(
        "GEMINI_API_KEY is not set; export it in the environment that launches Jupyter."
    )
genai.configure(api_key=GEMINI_API_KEY)

# MongoDB: permanent store for chat messages.
MONGO_URI = os.environ.get("MONGO_URI", "mongodb://localhost:27017/")
DATABASE_NAME = 'chat_database'
COLLECTION_NAME = 'chat_history'
client = MongoClient(MONGO_URI)
db = client[DATABASE_NAME]
collection = db[COLLECTION_NAME]

# Redis: short-term chat cache (values decoded to str).
REDIS_HOST = os.environ.get("REDIS_HOST", "localhost")
REDIS_PORT = int(os.environ.get("REDIS_PORT", "6379"))
redis_client = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, decode_responses=True)

In [30]:
def call_gemini(prompt_text, temperature=0.7, model_name="gemini-pro"):
    """Send one prompt to Gemini and wrap the reply as an assistant chat message.

    Args:
        prompt_text: The complete prompt to send.
        temperature: Sampling temperature, forwarded to the model via
            ``generation_config`` (it used to be accepted but ignored).
        model_name: Gemini model to query.

    Returns:
        ``{'role': 'assistant', 'content': <response text>}``.
    """
    model = genai.GenerativeModel(model_name)
    # Single-turn request: a chat session would be discarded after one message anyway.
    response = model.generate_content(
        prompt_text,
        generation_config={"temperature": temperature},
    )
    return {'role': 'assistant', 'content': response.text}

def create_mongodb_collection():
    """Make sure the chat collection carries the compound (UserId, Timestamp) index.

    MongoDB creates the collection on first use, so only the index needs explicit setup.
    Success or failure is reported on stdout; exceptions are not propagated.
    """
    index_keys = [('UserId', ASCENDING), ('Timestamp', ASCENDING)]
    try:
        db[COLLECTION_NAME].create_index(index_keys, name='user_timestamp_index')
    except Exception as exc:
        print(f"Error creating MongoDB collection: {exc}")
    else:
        print(f"Collection '{COLLECTION_NAME}' is set up with indexes.")

def test_redis_connection():
    """Smoke-test Redis with the module-level client.

    Appends one sample message to ``test_user_redis:stack``, then prints every
    entry in that list. The sample message is left in Redis. Errors are printed,
    not raised.
    """
    user_id = "test_user_redis"
    stack_key = f"{user_id}:stack"
    try:
        payload = {"content": "Hello from Redis!", "timestamp": datetime.now().isoformat()}
        redis_client.rpush(stack_key, json.dumps(payload))

        stored = redis_client.lrange(stack_key, 0, -1)
        print(f"Retrieved messages from Redis for {user_id}:")
        for raw in stored:
            print(json.loads(raw))
    except Exception as exc:
        print(f"Error in Redis connection test: {exc}")

def clear_all_redis_data(redis_host, redis_port, redis_password=None):
    """Delete every key in the current Redis database (FLUSHDB). Destructive.

    Args:
        redis_host: Redis server host.
        redis_port: Redis server port.
        redis_password: Optional password. It is now forwarded to the client;
            previously it was accepted but silently ignored.

    Errors are printed, not raised.
    """
    try:
        # Dedicated short-lived client; the `with` closes it instead of leaving it open.
        with redis.Redis(
            host=redis_host,
            port=redis_port,
            password=redis_password,
            decode_responses=True,
        ) as flush_client:
            flush_client.flushdb()
        print("Successfully cleared all data in the current Redis database.")
    except Exception as e:
        print(f"Error clearing Redis data: {e}")

def test_mongodb_connection():
    """Smoke-test MongoDB with the module-level collection.

    Inserts one sample document for ``test_user_mongodb``, then prints the
    UserId/Timestamp/Content/BatchId fields of every document stored for that
    user. The sample document is left in the collection. Errors are printed,
    not raised.
    """
    user_id = "test_user_mongodb"
    shown_fields = ("UserId", "Timestamp", "Content", "BatchId")
    try:
        collection.insert_one({
            "UserId": user_id,
            "Timestamp": datetime.now().isoformat(),
            "Content": "Hello from MongoDB!",
            "BatchId": 1,  # arbitrary batch id for the smoke test
        })

        documents = list(collection.find({"UserId": user_id}))
        print(f"Retrieved messages from MongoDB for {user_id}:")
        for document in documents:
            print({field: document[field] for field in shown_fields})
    except Exception as exc:
        print(f"Error in MongoDB connection test: {exc}")

def clear_all_mongodb_data(database_name, collection_name, mongo_uri=None):
    """Delete every document in ``database_name.collection_name``. Destructive.

    Args:
        database_name: Database containing the collection.
        collection_name: Collection to empty. Uses ``delete_many({})``, so the
            collection itself is not dropped.
        mongo_uri: Connection string; defaults to the notebook-level ``MONGO_URI``.

    Errors (including failure to create the client) are printed, not raised.
    """
    try:
        # Short-lived client closed by the `with`, so repeated calls don't pile up open clients.
        with MongoClient(mongo_uri or MONGO_URI) as mongo_client:
            result = mongo_client[database_name][collection_name].delete_many({})
        print(f"Successfully cleared {result.deleted_count} documents from MongoDB collection: {collection_name}")
    except Exception as e:
        print(f"Error clearing MongoDB data: {e}")


def summarize_chat_history(messages):
    """Ask Gemini to condense a batch of chat messages into a short two-part summary.

    Args:
        messages: Sequence of message records (e.g. the dicts cached in Redis);
            each one is rendered with ``str()``, one per line.

    Returns:
        The ``{'role': 'assistant', 'content': ...}`` dict returned by ``call_gemini``.
    """
    summarize_prompt = """
    Given a history of chat messages, summarize the conversation between the user and the AI.
    Write two paragraphs, together no more than 250 words:
    User: one paragraph summarizing all user messages.
    AI: one paragraph summarizing all AI messages.

    Message history:
    {conversation}

    Summary:
    """
    # One message per line so the model can see where each message ends.
    conversation = "\n".join(str(msg) for msg in messages)
    prompt = summarize_prompt.format(conversation=conversation)
    # The summary is generated by Gemini.
    return call_gemini(prompt, 0.7)

In [31]:
# Start from a clean slate: wipe the Redis cache and the MongoDB message history.
clear_all_redis_data(redis_host=REDIS_HOST, redis_port=REDIS_PORT)
clear_all_mongodb_data(database_name=DATABASE_NAME, collection_name=COLLECTION_NAME)

Successfully cleared all data in the current Redis database.
Successfully cleared 0 documents from MongoDB collection: chat_history


In [32]:
# Fetch every MongoDB document stored for "user1".
retrieved_messages = [doc for doc in collection.find({"UserId": "user1"})]

In [33]:
# Display the query result; the stores were cleared just before, so this is expected to be empty.
retrieved_messages

[]

In [34]:
def pull_all_from_redis(redis_client, pattern="*"):
    """Return a snapshot of every Redis key matching ``pattern`` with its decoded value.

    String values are JSON-decoded when possible and kept as-is otherwise. List values
    (such as the ``{user}:stack`` message lists) become lists of decoded items. Keys of
    any other type map to None. Empty or missing values map to None.

    Args:
        redis_client: A redis-py client.
        pattern: Glob-style key pattern. The previous hardcoded ``'user1'`` only matched
            a key literally named ``user1``, never the ``user1:*`` keys, so the result
            was always empty.

    Returns:
        ``{key: value}`` dict, or None if Redis raised an error (the error is printed).
    """
    def _decode(raw):
        # Empty/missing -> None; JSON -> parsed value; anything else -> the raw string.
        if not raw:
            return None
        try:
            return json.loads(raw)
        except (TypeError, ValueError):
            return raw

    try:
        all_data = {}
        # SCAN walks the keyspace incrementally instead of blocking the server like KEYS.
        for key in redis_client.scan_iter(match=pattern):
            key_type = redis_client.type(key)
            if isinstance(key_type, bytes):  # clients without decode_responses return bytes
                key_type = key_type.decode()
            if key_type == "string":
                all_data[key] = _decode(redis_client.get(key))
            elif key_type == "list":
                # GET on a list raises WRONGTYPE, so lists must be read with LRANGE.
                all_data[key] = [_decode(item) for item in redis_client.lrange(key, 0, -1)]
            else:
                all_data[key] = None
        return all_data
    except Exception as e:
        print(f"Error retrieving data from Redis: {e}")
        return None

# Snapshot the Redis keys returned by pull_all_from_redis and show them.
all_redis_data = pull_all_from_redis(redis_client=redis_client)
print(all_redis_data)

{}


In [15]:
class ChatManager:
    """Two-tier chat memory for multiple users.

    Every message is appended to MongoDB (permanent record) and to a per-user Redis
    list (working cache). When the current batch holds ``batch_size`` messages, a
    background thread summarizes it, stores the summary in Redis and removes that
    batch from the cached list.

    Redis keys per user:
        ``{user_id}:stack``    -- list of JSON messages ``{"batch_id", "content", "timestamp"}``
        ``{user_id}:batch_id`` -- batch id assigned to new messages (absent means 1)
        ``{user_id}:summary``  -- JSON ``{"batch_id", "content", "count"}`` of the latest summary
    """

    def __init__(self, redis_host, redis_port, mongo_uri, database_name, collection_name, batch_size=20):
        """Connect to Redis and MongoDB.

        Args:
            redis_host: Redis server host.
            redis_port: Redis server port.
            mongo_uri: MongoDB connection string.
            database_name: MongoDB database holding the message collection.
            collection_name: MongoDB collection that receives every message.
            batch_size: Messages per batch before a summary is generated (default 20).
        """
        # decode_responses=True so values come back as str, ready for json.loads.
        self.redis_client = redis.Redis(host=redis_host, port=redis_port, decode_responses=True)
        # Serializes this instance's read-modify-write sequences on the Redis keys.
        self.lock = threading.Lock()
        self.batch_size = batch_size

        self.mongo_client = MongoClient(mongo_uri)
        self.db = self.mongo_client[database_name]
        self.collection = self.db[collection_name]

    def count_messages_with_batch(self, stack_key, batch_id):
        """Return how many messages in the Redis list ``stack_key`` carry ``batch_id``."""
        return sum(
            1
            for raw in self.redis_client.lrange(stack_key, 0, -1)
            if json.loads(raw).get("batch_id") == batch_id
        )

    def handle_new_message(self, user_id, message_text):
        """Record a message in Redis and MongoDB and start a summary once the batch is full."""
        with self.lock:
            stack_key = f"{user_id}:stack"
            batch_id_key = f"{user_id}:batch_id"

            stored_batch_id = self.redis_client.get(batch_id_key)
            batch_id = 1 if stored_batch_id is None else int(stored_batch_id)

            timestamp = datetime.now().isoformat()
            message = {"batch_id": batch_id, "content": message_text, "timestamp": timestamp}
            self.redis_client.rpush(stack_key, json.dumps(message))

            # Persist permanently in MongoDB.
            self.collection.insert_one({
                "UserId": user_id,
                "Timestamp": timestamp,
                "Content": message_text,
                "BatchId": batch_id,
            })

            # '>=' rather than '==': a stack rehydrated from MongoDB or a lost batch_id key can
            # leave more than batch_size messages in this batch, and '==' would then never fire.
            if self.count_messages_with_batch(stack_key, batch_id) >= self.batch_size:
                # Move new messages to the next batch before summarizing this one.
                self.redis_client.set(batch_id_key, batch_id + 1)
                # The thread blocks on self.lock until this method returns.
                threading.Thread(target=self._create_summary, args=(user_id, batch_id)).start()

    def _create_summary(self, user_id, batch_id):
        """Summarize batch ``batch_id`` for ``user_id`` and drop its messages from the stack.

        The previous summary (if any) is passed to the summarizer together with the batch.
        Only one summary record is kept per user. Without that step, each new summary
        silently discarded everything before its own batch.

        NOTE(review): the LLM call runs while holding self.lock, so handle_new_message
        waits for it to finish.
        """
        with self.lock:
            stack_key = f"{user_id}:stack"
            summary_key = f"{user_id}:summary"

            all_messages = [json.loads(raw) for raw in self.redis_client.lrange(stack_key, 0, -1)]
            batch_messages = [msg for msg in all_messages if msg.get("batch_id") == batch_id]
            if not batch_messages:
                return  # nothing left to summarize (e.g. the stack was reset meanwhile)

            previous_summary = self.redis_client.get(summary_key)
            to_summarize = batch_messages
            if previous_summary is not None:
                to_summarize = [json.loads(previous_summary)] + batch_messages

            summary_content = summarize_chat_history(to_summarize)
            summary = {"batch_id": batch_id, "content": summary_content, "count": len(batch_messages)}

            # Messages from other batches stay in the cache.
            remaining = [json.dumps(msg) for msg in all_messages if msg.get("batch_id") != batch_id]

            # Swap summary and trimmed stack in one MULTI/EXEC transaction: get_chat_history
            # does not take self.lock and could otherwise see the stack deleted but not refilled.
            pipe = self.redis_client.pipeline()
            pipe.set(summary_key, json.dumps(summary))
            pipe.delete(stack_key)
            if remaining:
                pipe.rpush(stack_key, *remaining)
            pipe.execute()

    def get_chat_history(self, user_id):
        """Return the user's latest summary record (if any) followed by unsummarized messages."""
        # Read both keys in one transaction so a concurrent summary swap is seen all-or-nothing.
        pipe = self.redis_client.pipeline()
        pipe.get(f"{user_id}:summary")
        pipe.lrange(f"{user_id}:stack", 0, -1)
        summary, raw_stack = pipe.execute()

        history = [json.loads(summary)] if summary is not None else []
        history.extend(json.loads(raw) for raw in raw_stack)
        return history

    def load_messages_from_mongodb(self, user_id, limit=20):
        """Rebuild the user's Redis stack from the newest ``limit`` messages in MongoDB.

        The cached summary is deleted and the batch-id key is left untouched.
        NOTE(review): rehydrated messages from already-summarized batches keep their old
        batch_id and are never summarized again -- confirm this is acceptable.
        """
        stack_key = f"{user_id}:stack"
        summary_key = f"{user_id}:summary"
        newest_first = list(self.collection.find({"UserId": user_id}).sort("Timestamp", -1).limit(limit))

        rehydrated = []
        for item in reversed(newest_first):  # oldest first, matching rpush order
            print("Reloading:" + str(item))
            rehydrated.append(json.dumps({
                "batch_id": int(item['BatchId']),
                "content": item['Content'],
                "timestamp": item['Timestamp'],
            }))

        # Reset and refill atomically, and under the lock so no concurrent message is wiped.
        with self.lock:
            pipe = self.redis_client.pipeline()
            pipe.delete(stack_key, summary_key)
            if rehydrated:
                pipe.rpush(stack_key, *rehydrated)
            pipe.execute()

In [22]:
import pprint
from sample_conversation import conversation

chat_manager = ChatManager(
    REDIS_HOST, REDIS_PORT, MONGO_URI, DATABASE_NAME, COLLECTION_NAME
)

# Seed user1's history with (at most) the first 10 turns of the sample conversation.
i = 0
for i, c in enumerate(conversation, start=1):
    chat_manager.handle_new_message("user1", str(c))
    if i == 10:
        break


In [23]:
# Fetch user1's cached context: latest summary (if any) plus unsummarized messages.
history = chat_manager.get_chat_history(user_id="user1")

In [24]:
# Display the cached history: the latest summary record (if any) followed by unsummarized messages.
history

[{'batch_id': 1,
  'content': "{'role': 'user', 'content': 'Can you tell me what AWS is all about?'}",
  'timestamp': '2024-10-08T21:42:59.796641'},
 {'batch_id': 1,
  'content': "{'role': 'assistant', 'content': 'AWS stands for Amazon Web Services, a cloud computing platform offering a range of infrastructure and application services.'}",
  'timestamp': '2024-10-08T21:42:59.800064'},
 {'batch_id': 1,
  'content': "{'role': 'user', 'content': 'What types of services does AWS offer?'}",
  'timestamp': '2024-10-08T21:42:59.804019'},
 {'batch_id': 1,
  'content': "{'role': 'assistant', 'content': 'AWS provides services like computing power, storage options, networking, machine learning, security, and more.'}",
  'timestamp': '2024-10-08T21:42:59.807980'},
 {'batch_id': 1,
  'content': "{'role': 'user', 'content': 'What are some of the popular services within AWS?'}",
  'timestamp': '2024-10-08T21:42:59.810180'},
 {'batch_id': 1,
  'content': "{'role': 'assistant', 'content': 'Popular serv

In [26]:
question_prompt = """
Given a question and chat history, answer the question in the context of the conversation.

Chat History:
{chat_history}

Question: {question}
"""

# Six interactive turns for user1: each question is answered against the cached history,
# then both the question and the answer are recorded through the ChatManager.
for i in range(6):
    question = input()
    message = dict(role="user", content=question)
    history = str(chat_manager.get_chat_history(user_id="user1"))
    prompt = question_prompt.format(question=question, chat_history=history)
    response = call_gemini(prompt, temperature=0.7)
    for turn in (message, response):
        chat_manager.handle_new_message("user1", str(turn))
    print(response)
    

{'role': 'assistant', 'content': 'This conversation is about the Amazon Web Services (AWS) cloud computing platform. It covers various aspects of AWS, including its services, benefits, and best practices.'}
{'role': 'assistant', 'content': 'The chat history does not contain any information about the top services.'}
{'role': 'assistant', 'content': "- AWS's managed database offerings\n- Advantages of DynamoDB\n- Benefits of RDS\n- Importance of AWS IAM\n- Best practices for IAM security\n- Implementation of multi-factor authentication in IAM\n- Strategies for optimizing AWS costs\n- Nature of spot instances\n- Meaning of rightsizing resources\n- Overall topic of the conversation"}
{'role': 'assistant', 'content': 'DynamoDB is a managed NoSQL database service that provides fast and flexible performance and automatic scaling. It is designed for applications that require high throughput and low latency. DynamoDB manages all aspects of setup, maintenance, and backups.'}
{'role': 'assistant'

In [27]:
# Re-read the history after the Q&A turns; it may now start with a summary record.
history = chat_manager.get_chat_history(user_id="user1")
pprint.PrettyPrinter().pprint(history)

[{'batch_id': 2,
  'content': {'content': '**User Summary:**\n'
                         "The user inquired about AWS's managed database "
                         'offerings, the advantages of DynamoDB, the benefits '
                         'of RDS, the importance of AWS IAM, best practices '
                         'for IAM security, the implementation of multi-factor '
                         'authentication in IAM, strategies for optimizing AWS '
                         'costs, the nature of spot instances, the meaning of '
                         'rightsizing resources, and the overall topic of the '
                         'conversation.\n'
                         '\n'
                         '**AI Summary:**\n'
                         'The AI provided thorough explanations of AWS '
                         'services and best practices. It clarified the '
                         'advantages of DynamoDB, emphasizing its optimized '
                         'performance,