#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Test script to demonstrate wakeable poll() interruptibility with real Kafka broker.
#
# This script tests the wakeable poll pattern by:
# 1. Connecting to a real Kafka broker
# 2. Subscribing to a topic
# 3. Calling poll() with a long timeout (or infinite)
# 4. Allowing Ctrl+C to interrupt the poll operation
#
# Usage:
#   python test_wakeable_poll_interrupt.py [bootstrap_servers] [topic]
#
# Example:
#   python test_wakeable_poll_interrupt.py localhost:9092 test-topic
#
# Press Ctrl+C to test interruptibility
#
# NOTE: This script prefers the local development version of confluent_kafka:
#       if a 'src' directory exists next to this script, it is prepended to sys.path.

import os
import sys
import time

# Prefer the in-tree build: when a ``src`` directory sits next to this script,
# put it at the front of sys.path so it is searched before other entries.
script_dir = os.path.split(os.path.abspath(__file__))[0]
src_dir = os.path.join(script_dir, 'src')
if src_dir not in sys.path and os.path.exists(src_dir):
    sys.path[:0] = [src_dir]

from confluent_kafka import Consumer, KafkaException
from confluent_kafka.admin import AdminClient, NewTopic

# Verify we're using the local development version.
# Path separators are normalised to '/' before looking for the
# 'src/confluent_kafka' marker, so the check also matches on platforms whose
# os.sep is not '/'. The script_dir test is kept as a second heuristic.
import confluent_kafka
if ('src/confluent_kafka' in confluent_kafka.__file__.replace(os.sep, '/')
        or script_dir in confluent_kafka.__file__):
    print(f"✓ Using local development version: {confluent_kafka.__file__}")
else:
    print(f"⚠ WARNING: Using installed version: {confluent_kafka.__file__}")
    print("  The wakeable poll changes may not be active!")
    print("  Make sure to build the local version with: python setup.py build_ext --inplace")
    print()

# Defaults. The broker list and topic can be overridden through the
# BOOTSTRAP_SERVERS / TEST_TOPIC environment variables; the group id is fixed.
DEFAULT_BOOTSTRAP_SERVERS = os.getenv('BOOTSTRAP_SERVERS', 'localhost:9092')
DEFAULT_TOPIC = os.getenv('TEST_TOPIC', 'test-wakeable-poll-topic')
DEFAULT_GROUP_ID = 'test-wakeable-poll-group'


# Timeout passed to the measured poll() call. It is also the yardstick used to
# decide whether a Ctrl+C really cut the call short.
_POLL_TIMEOUT_SECONDS = 30.0

# Maximum number of warm-up polls issued before the measured poll() call.
_WARMUP_ATTEMPTS = 5


def _ensure_topic_exists(bootstrap_servers, topic):
    """Best-effort creation of *topic* with 1 partition and replication factor 1.

    Every failure is reported on stdout and then ignored, because the topic
    may already exist. The AdminClient is a local of this helper, so the
    reference to it is dropped as soon as the helper returns.
    """
    try:
        print(f"Ensuring topic '{topic}' exists...")
        admin_client = AdminClient({'bootstrap.servers': bootstrap_servers})
        futures = admin_client.create_topics(
            [NewTopic(topic, num_partitions=1, replication_factor=1)],
            request_timeout=10.0,
        )

        # One future per requested topic; wait for each of them.
        for topic_name, future in futures.items():
            try:
                future.result(timeout=10.0)
            except Exception as exc:
                already_exists = (
                    "already exists" in str(exc).lower()
                    or "TopicExistsException" in type(exc).__name__
                )
                if already_exists:
                    print(f"✓ Topic '{topic_name}' already exists")
                else:
                    print(f"⚠ Could not create topic '{topic_name}': {exc}")
                    print("  Continuing anyway - topic may already exist...")
            else:
                print(f"✓ Topic '{topic_name}' created successfully")
    except Exception as exc:
        print(f"⚠ Could not create topic: {exc}")
        print("  Continuing anyway - topic may already exist...")


def _drain_startup_errors(consumer):
    """Poll a few times so transient "unknown topic" errors are consumed up front.

    Stops at the first poll that returns nothing, a regular message, or any
    error other than UNKNOWN_TOPIC_OR_PART.
    """
    for attempt in range(1, _WARMUP_ATTEMPTS + 1):
        msg = consumer.poll(timeout=1.0)
        if msg is None or not msg.error():
            return
        if "UNKNOWN_TOPIC_OR_PART" not in str(msg.error()):
            return
        # Recoverable: give the topic a moment to become available and retry.
        print(f"  Waiting for topic to be available... (attempt {attempt}/{_WARMUP_ATTEMPTS})")
        time.sleep(1.0)


def _run_interruptible_poll(consumer):
    """Issue one long poll() and report how it ended.

    Prints the outcome (timeout, error or message). On Ctrl+C it prints a
    verdict and re-raises.

    Raises:
        KeyboardInterrupt: re-raised after the report so the script exits.
    """
    # Monotonic clock: the elapsed time must not be affected by wall-clock jumps.
    started = time.monotonic()
    try:
        print(f"[{time.strftime('%H:%M:%S')}] Calling poll(timeout={_POLL_TIMEOUT_SECONDS})...")
        print("    (This will block until a message arrives, timeout expires, or Ctrl+C)")
        print()

        msg = consumer.poll(timeout=_POLL_TIMEOUT_SECONDS)

        elapsed = time.monotonic() - started
        stamp = time.strftime('%H:%M:%S')

        if msg is None:
            print(f"[{stamp}] poll() returned None (timeout after {elapsed:.2f}s)")
        elif msg.error():
            error_str = str(msg.error())
            print(f"[{stamp}] poll() returned error: {error_str}")
            if "UNKNOWN_TOPIC_OR_PART" in error_str:
                print("  ⚠ Topic may not be available yet. Make sure the topic exists in the cluster.")
        else:
            print(f"[{stamp}] poll() returned message:")
            print(f"    Topic: {msg.topic()}")
            print(f"    Partition: {msg.partition()}")
            print(f"    Offset: {msg.offset()}")
            print(f"    Value: {msg.value()}")
            print(f"    Elapsed time: {elapsed:.2f}s")

    except KeyboardInterrupt:
        elapsed = time.monotonic() - started
        print()
        print("=" * 70)
        print("✓ KeyboardInterrupt caught!")
        print(f"  Interrupted {elapsed:.2f} seconds after poll() was called")
        print("  With wakeable pattern, interruption should occur within ~200ms of Ctrl+C")
        # `elapsed` is measured from the start of poll(), not from the key
        # press, so it cannot be compared against the ~200ms target. What it
        # can show is whether poll() was cut short: if the interrupt only
        # surfaced once the full timeout had passed, the blocking call was
        # presumably not woken by the signal.
        # NOTE(review): a message arriving between Ctrl+C and the timeout would
        # also end poll() early; with an idle topic this is not expected.
        if elapsed < _POLL_TIMEOUT_SECONDS:
            print(f"  ✓ poll() was interrupted before its {_POLL_TIMEOUT_SECONDS:g}s timeout expired")
        else:
            print(f"  ⚠ Interrupt was only delivered after {elapsed:.2f}s, once poll() had run "
                  "its full timeout (may indicate wakeable pattern issue)")
        print("=" * 70)
        raise  # Re-raise to exit


def main():
    """Run the wakeable-poll interruptibility demo against a real broker.

    Bootstrap servers and topic are taken from ``sys.argv[1]`` and
    ``sys.argv[2]``, falling back to the module defaults. The topic is created
    on a best-effort basis, a consumer subscribes to it, and a single long
    ``poll()`` is issued that the user is expected to interrupt with Ctrl+C.

    Exits with status 1 on a KafkaException or any other unexpected
    exception. The consumer is closed on every exit path.
    """
    # Parse command line arguments
    argv = sys.argv
    bootstrap_servers = argv[1] if len(argv) > 1 else DEFAULT_BOOTSTRAP_SERVERS
    topic = argv[2] if len(argv) > 2 else DEFAULT_TOPIC

    banner = "=" * 70
    print(banner)
    print("Wakeable Poll Interruptibility Test")
    print(banner)
    print(f"Bootstrap servers: {bootstrap_servers}")
    print(f"Topic: {topic}")
    print(f"Group ID: {DEFAULT_GROUP_ID}")
    print()
    print("This test demonstrates the wakeable poll pattern:")
    print(f"  - poll() will be called with a long timeout ({_POLL_TIMEOUT_SECONDS:g} seconds)")
    print("  - The operation can be interrupted with Ctrl+C")
    print("  - With the wakeable pattern, Ctrl+C should interrupt within ~200ms")
    print()
    print("Press Ctrl+C to test interruptibility...")
    print(banner)
    print()

    # Create topic if it doesn't exist
    _ensure_topic_exists(bootstrap_servers, topic)

    # Consumer configuration
    conf = {
        'bootstrap.servers': bootstrap_servers,
        'group.id': DEFAULT_GROUP_ID,
        'session.timeout.ms': 6000,
        'auto.offset.reset': 'latest',  # Start from latest to avoid consuming old messages
        # NOTE(review): very low network timeout; presumably chosen for this
        # demo against a local broker - confirm before reusing elsewhere.
        'socket.timeout.ms': 100,
    }

    consumer = None
    try:
        consumer = Consumer(conf)

        print(f"Subscribing to topic: {topic}")
        consumer.subscribe([topic])

        # Wait a bit for subscription to complete and topic to be available
        print("Waiting for subscription to complete...")
        time.sleep(2.0)

        # Poll a few times to clear any initial errors (like topic not available)
        print("Polling to clear initial subscription messages...")
        _drain_startup_errors(consumer)

        print(f"Ready! Starting poll() with {_POLL_TIMEOUT_SECONDS:g} second timeout...")
        print()

        # Test poll() with long timeout - this should be interruptible
        _run_interruptible_poll(consumer)

    except KafkaException as e:
        print(f"Kafka error: {e}")
        sys.exit(1)

    except Exception as e:
        print(f"Unexpected error: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)

    finally:
        # Runs on normal completion, sys.exit() and KeyboardInterrupt alike.
        if consumer is not None:
            print()
            print("Closing consumer...")
            consumer.close()
            print("Consumer closed.")


# Script entry point: run the demo only when executed directly, not on import.
if __name__ == "__main__":
    main()

