#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Test script to demonstrate wakeable producer flush() interruptibility with real Kafka broker.
#
# This script tests the wakeable flush pattern by:
# 1. Connecting to a real Kafka broker
# 2. Producing multiple messages to a topic
# 3. Calling flush() with a long timeout (or infinite)
# 4. Allowing Ctrl+C to interrupt the flush operation
#
# Usage:
#   python test_wakeable_producer_flush_interrupt.py [bootstrap_servers] [topic]
#
# Example:
#   python test_wakeable_producer_flush_interrupt.py localhost:9092 test-topic
#
# Press Ctrl+C to test interruptibility
#
# NOTE: This script prefers the local development version of confluent_kafka:
#       if a 'src' directory exists next to this script, it is prepended to sys.path.

import os
import sys
import time
import threading

# Prefer the in-tree build: if a 'src' directory sits next to this script and
# is not already on sys.path, put it first so the imports below resolve to it.
script_dir = os.path.dirname(os.path.abspath(__file__))
src_dir = os.path.join(script_dir, 'src')
if src_dir not in sys.path and os.path.exists(src_dir):
    sys.path.insert(0, src_dir)

import confluent_kafka
from confluent_kafka import Producer, KafkaException
from confluent_kafka.admin import AdminClient, NewTopic

# Tell the user which confluent_kafka was actually imported. The module counts
# as "local" when its path mentions src/confluent_kafka or lives under the
# script's own directory; anything else gets a warning.
if 'src/confluent_kafka' not in confluent_kafka.__file__ and script_dir not in confluent_kafka.__file__:
    print(f"⚠ WARNING: Using installed version: {confluent_kafka.__file__}")
    print("  The wakeable flush changes may not be active!")
    print("  Make sure to build the local version with: python setup.py build_ext --inplace")
    print()
else:
    print(f"✓ Using local development version: {confluent_kafka.__file__}")

# Defaults, overridable through the environment (and, in main(), via argv)
DEFAULT_BOOTSTRAP_SERVERS = os.getenv('BOOTSTRAP_SERVERS', 'localhost:9092')
DEFAULT_TOPIC = os.getenv('TEST_TOPIC', 'test-wakeable-producer-flush-topic')


def is_queue_full_error(exc):
    """Return True if *exc* means the producer's local queue is full.

    ``Producer.produce()`` is documented to raise ``BufferError`` when the
    internal message queue is full. The text of that exception is the
    human-readable librdkafka message rather than the symbolic error name,
    so matching on the substring "QUEUE_FULL" alone misses it. The substring
    check is kept as a fallback for errors whose string form does carry the
    symbolic ``_QUEUE_FULL`` code.
    """
    return isinstance(exc, BufferError) or "QUEUE_FULL" in str(exc)


def main():
    """Run the interactive flush() interruptibility test against a broker.

    Command line (both optional): ``[bootstrap_servers] [topic]``; missing
    values fall back to DEFAULT_BOOTSTRAP_SERVERS / DEFAULT_TOPIC.

    Steps:
      1. Best-effort creation of the topic through the admin client.
      2. Produce an initial batch of messages and poll for one second.
      3. Start a daemon thread that keeps producing until told to stop.
      4. Call ``flush()`` with an infinite timeout and wait for Ctrl+C.
      5. Always stop the thread, do a short final flush and close the producer.

    Exits with status 1 on KafkaException or any other unexpected exception.
    A KeyboardInterrupt during flush() is reported and then re-raised.
    """
    # Parse command line arguments
    bootstrap_servers = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_BOOTSTRAP_SERVERS
    topic = sys.argv[2] if len(sys.argv) > 2 else DEFAULT_TOPIC

    print("=" * 70)
    print("Wakeable Producer Flush Interruptibility Test")
    print("=" * 70)
    print(f"Bootstrap servers: {bootstrap_servers}")
    print(f"Topic: {topic}")
    print()
    print("This test demonstrates the wakeable flush pattern:")
    print("  - flush() will be called with infinite timeout (blocks until Ctrl+C)")
    print("  - The operation can be interrupted with Ctrl+C")
    print("  - With the wakeable pattern, Ctrl+C should interrupt within ~200ms")
    print()
    print("Press Ctrl+C to test interruptibility...")
    print("=" * 70)
    print()

    # Create topic if it doesn't exist. Every failure here is non-fatal: the
    # topic may already exist or be auto-created by the broker.
    admin_client = None
    try:
        print(f"Ensuring topic '{topic}' exists...")
        admin_conf = {'bootstrap.servers': bootstrap_servers}
        admin_client = AdminClient(admin_conf)

        # Try to create the topic
        new_topic = NewTopic(topic, num_partitions=1, replication_factor=1)
        fs = admin_client.create_topics([new_topic], request_timeout=10.0)

        # Wait for topic creation
        for topic_name, f in fs.items():
            try:
                f.result(timeout=10.0)
                print(f"✓ Topic '{topic_name}' created successfully")
            except Exception as e:
                if "already exists" in str(e).lower() or "TopicExistsException" in str(type(e).__name__):
                    print(f"✓ Topic '{topic_name}' already exists")
                else:
                    print(f"⚠ Could not create topic '{topic_name}': {e}")
                    print("  Continuing anyway - topic may already exist...")
    except Exception as e:
        print(f"⚠ Could not create topic: {e}")
        print("  Continuing anyway - topic may already exist...")
    finally:
        admin_client = None  # Drop our reference to the admin client

    # Producer configuration intended to make flush() block for a while.
    # Key insight: flush() blocks while messages are queued or in-flight
    # waiting for acknowledgment.
    conf = {
        'bootstrap.servers': bootstrap_servers,
        'socket.timeout.ms': 60000,  # Long socket timeout
        'acks': 'all',  # Wait for all acknowledgments - CRITICAL for flush() to block
        # Batching: cap each batch at 100 messages and let the producer
        # linger up to 100ms to accumulate a batch before sending it.
        'batch.num.messages': 100,
        'linger.ms': 100,
        'queue.buffering.max.messages': 100000,  # Large queue
        # This setting is expressed in kilobytes, so 104857600 is ~100 GiB;
        # with the small messages produced here the message-count limit
        # above is the one that is reached first.
        'queue.buffering.max.kbytes': 104857600,
        # IMPORTANT: With max.in.flight=1 and acks=all, each produce request
        # must be acknowledged before the next is sent, making flush() block longer
        'max.in.flight.requests.per.connection': 1,
        # Request timeout - how long to wait for broker response
        'request.timeout.ms': 30000,  # 30 seconds
        # Delivery timeout - how long to wait for delivery report
        'delivery.timeout.ms': 30000,  # 30 seconds
    }

    producer = None
    producer_thread = None  # Stays None if we fail before the thread is started
    stop_producing = threading.Event()
    production_stats = {'count': 0, 'errors': 0}

    def continuous_producer():
        """Background thread that continuously produces messages as fast as possible"""
        message_num = 0
        # Counted separately from production_stats['errors'] so that
        # queue-full retries do not use up the "print the first few" budget.
        other_errors = 0
        while not stop_producing.is_set():
            try:
                # Produce messages as fast as possible - no delays
                producer.produce(topic,
                                 value=f'continuous-message-{message_num}'.encode(),
                                 key=f'continuous-key-{message_num}'.encode())
                production_stats['count'] += 1
                message_num += 1
            except Exception as e:
                production_stats['errors'] += 1
                if is_queue_full_error(e):
                    # Queue is full - wait a tiny bit and try again
                    time.sleep(0.001)  # 1ms wait when queue is full
                else:
                    other_errors += 1
                    if other_errors < 5:
                        print(f"  ⚠ Production error: {e}")
                    time.sleep(0.01)  # Wait on other errors

    try:
        # Create producer
        producer = Producer(conf)

        # Strategy: Produce many messages and ensure they're sent but not all acknowledged
        # flush() blocks when messages are in-flight waiting for acknowledgment
        # With acks=all, it waits for all replicas to acknowledge
        print("Producing messages to fill queue and send to broker...")
        initial_batch = 50000  # Produce many messages
        produced_initial = 0
        for i in range(initial_batch):
            try:
                producer.produce(topic,
                                 value=f'initial-message-{i}'.encode(),
                                 key=f'initial-key-{i}'.encode())
                produced_initial += 1
                if (i + 1) % 10000 == 0:
                    print(f"  Produced {i+1}/{initial_batch} messages...")
            except Exception as e:
                if is_queue_full_error(e):
                    # Back-pressure, not a failure: pause briefly and move on
                    # (this message is skipped, as before).
                    time.sleep(0.01)
                    continue
                print(f"⚠ Could not produce initial message {i+1}: {e}")
                break

        print(f"✓ Produced {produced_initial} messages")
        print()

        # Poll briefly to send some messages, but not all will be acknowledged yet
        # This creates a mix of messages in queue and in-flight
        print("Polling briefly to send messages to broker...")
        print("  (Some messages will be sent and waiting for acknowledgment)")
        poll_start = time.time()
        while time.time() - poll_start < 1.0:  # Poll for 1 second
            producer.poll(timeout=0.1)  # Process some delivery callbacks
        queue_after_poll = len(producer)
        print(f"  Queue after brief poll: {queue_after_poll} messages")
        print("  (Some messages are in-flight, waiting for acknowledgment)")
        print()

        # Start background thread to continuously produce messages
        print("Starting background thread to continuously produce messages...")
        producer_thread = threading.Thread(target=continuous_producer, daemon=True)
        producer_thread.start()
        print("✓ Background producer thread started")
        print("  (Will produce messages continuously during flush)")
        print()

        # Give a moment for background thread to start
        time.sleep(0.5)

        print("Ready! Starting flush() with infinite timeout (will block until Ctrl+C)...")
        print("  KEY INSIGHT: flush() blocks when messages are in-flight waiting for acknowledgment")
        print("  With acks=all and max.in.flight=1, messages are sent one at a time")
        print("  and flush() waits for each acknowledgment before sending the next")
        print("  With continuous production, new messages keep being added,")
        print("  keeping messages in-flight and making flush() block")
        print()

        initial_queue = len(producer)
        print(f"  Queue before flush: {initial_queue} messages")
        print("  (Many messages are in queue and in-flight, waiting for acknowledgment)")
        print()

        # Test flush() with infinite timeout - this should be interruptible
        start_time = time.time()
        try:
            print(f"[{time.strftime('%H:%M:%S')}] Calling flush() with infinite timeout...")
            print("    (This will block until all messages are flushed or Ctrl+C)")
            print("    Background thread is continuously adding messages to the queue")
            print(f"    Current queue length: {len(producer)} messages")
            print("    Press Ctrl+C to test interruptibility...")
            print()

            remaining = producer.flush(timeout=-1.0)  # Infinite timeout - will block

            elapsed = time.time() - start_time
            print(f"[{time.strftime('%H:%M:%S')}] flush() returned")
            print(f"    Messages remaining in queue: {remaining}")
            print(f"    Elapsed time: {elapsed:.2f}s")

            if remaining == 0:
                print("    ✓ All messages flushed successfully!")
                if elapsed < 1.0:
                    print("    ⚠ Messages flushed very quickly - operation completed before interrupt test")
                    print("    💡 Tip: Try producing more messages or using a slower broker configuration")
            else:
                print(f"    ⚠ {remaining} messages still in queue (timeout or interruption)")

        except KeyboardInterrupt:
            elapsed = time.time() - start_time
            # Stop the background producer thread
            stop_producing.set()
            print()
            print("=" * 70)
            print("✓ KeyboardInterrupt caught!")
            print(f"  Interrupted after {elapsed:.2f} seconds")
            print(f"  Background thread produced {production_stats['count']} messages during this time")
            if production_stats['errors'] > 0:
                print(f"  (Had {production_stats['errors']} production errors)")
            print("  With wakeable pattern, interruption should occur within ~200ms of Ctrl+C")
            # NOTE(review): `elapsed` is measured from the start of flush(),
            # not from the moment Ctrl+C was pressed, so the verdict below
            # only reflects interruption latency if Ctrl+C is pressed
            # immediately after flush() starts. Judge the real latency by
            # how quickly this banner appears after pressing Ctrl+C.
            if elapsed < 0.5:
                print("  ✓ Fast interruption confirmed!")
            else:
                print(f"  ⚠ Interruption took {elapsed:.2f}s (may indicate wakeable pattern issue)")
            print("=" * 70)
            raise  # Re-raise to exit

    except KafkaException as e:
        print(f"Kafka error: {e}")
        sys.exit(1)

    except Exception as e:
        print(f"Unexpected error: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)

    finally:
        # Stop the background producer thread
        stop_producing.set()
        # The producer supports len() (used above), so a bare truthiness test
        # may evaluate False when its queue is empty and skip the cleanup;
        # compare against None instead.
        if producer is not None:
            print()
            print("Stopping background producer thread...")
            # Wait for the thread to actually finish so it cannot call
            # produce() while the producer is being flushed and closed. The
            # timeout keeps cleanup bounded; the thread is a daemon anyway.
            if producer_thread is not None:
                producer_thread.join(timeout=1.0)
            print(f"  Final stats: {production_stats['count']} messages produced")
            if production_stats['errors'] > 0:
                print(f"  ({production_stats['errors']} production errors)")
            print()
            print("Closing producer...")
            # Try to flush remaining messages with short timeout. This stays
            # best-effort, but failures are reported instead of being hidden.
            try:
                remaining = producer.flush(timeout=2.0)
                if remaining > 0:
                    print(f"  Note: {remaining} messages may still be in queue")
            except KeyboardInterrupt:
                # Another Ctrl+C during cleanup: stop waiting, still close.
                print("  Final flush interrupted; closing anyway")
            except Exception as e:
                print(f"  Note: final flush failed: {e}")
            producer.close()
            print("Producer closed.")


# Entry point: run the interactive test only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()

