In [2]:
import threading
import time
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Callable, Dict, Optional, Tuple
from urllib.parse import urlparse

import urllib3
from urllib3.connectionpool import HTTPConnectionPool, HTTPSConnectionPool


class RateLimitExceeded(Exception):
    """Error signalling that a request was refused by the rate limiter."""


def _default_rate_limit_key(req):
    """Place every request in the same bucket by returning a constant key."""
    return "default"


@dataclass
class RateLimitConfig:
    """Settings describing how one URL pattern is rate limited.

    Fields:
        rate: Tokens added to the bucket per second.
        capacity: Maximum number of tokens the bucket can hold.
        retry_after: Seconds to wait before retry; ``None`` when unset.
        key_func: Callable that extracts the rate-limit key from a request.
    """

    rate: float
    capacity: int
    retry_after: Optional[float] = None
    key_func: Callable = _default_rate_limit_key


class TokenBucket:
    """Thread-safe token bucket.

    The bucket starts full, is refilled lazily on each ``consume`` call at
    ``rate`` tokens per second (measured with ``time.monotonic``), and never
    holds more than ``capacity`` tokens.
    """

    def __init__(self, rate: float, capacity: int):
        """
        Create a full bucket.

        Args:
            rate: Refill speed in tokens per second. ``0`` is allowed and
                means the bucket is never refilled.
            capacity: Maximum (and initial) number of tokens.

        Raises:
            ValueError: If ``rate`` is negative.
        """
        self.rate = float(rate)
        if self.rate < 0:
            # A negative rate would drain the bucket over time and produce
            # negative wait times.
            raise ValueError(f"rate must be non-negative, got {rate!r}")
        self.capacity = float(capacity)
        self.tokens = float(capacity)
        self.timestamp = time.monotonic()
        self.lock = threading.Lock()

    def consume(self, tokens: float = 1.0) -> Tuple[bool, float]:
        """
        Attempt to consume tokens from the bucket.

        Args:
            tokens: Number of tokens to take.

        Returns:
            A ``(success, wait_time)`` tuple. On success ``wait_time`` is
            ``0.0``. On failure nothing is consumed and ``wait_time`` is the
            number of seconds until enough tokens will have accumulated, or
            ``float("inf")`` when the request can never be satisfied (the
            bucket does not refill, or ``tokens`` exceeds the capacity).
        """
        with self.lock:
            now = time.monotonic()
            elapsed = now - self.timestamp
            self.timestamp = now

            # Add new tokens based on elapsed time, capped at capacity.
            self.tokens = min(self.capacity, self.tokens + elapsed * self.rate)

            if self.tokens >= tokens:
                self.tokens -= tokens
                return True, 0.0

            # Waiting cannot help if the bucket never refills or can never
            # hold this many tokens; this also avoids dividing by zero.
            if self.rate == 0 or tokens > self.capacity:
                return False, float("inf")

            # Time for the missing tokens to be refilled.
            needed = tokens - self.tokens
            return False, needed / self.rate


class RateLimiter:
    """Rate limiter with configurable strategies per endpoint.

    Each configured URL pattern owns a family of token buckets, one per key
    produced by that pattern's ``key_func``.
    """

    def __init__(self):
        # URL pattern -> settings registered through ``configure``.
        self.configs: Dict[str, RateLimitConfig] = {}
        # (pattern, key) -> bucket, created lazily by ``get_bucket``.
        self.buckets: Dict[Tuple[str, str], TokenBucket] = {}
        # Guards lazy bucket creation.
        self.lock = threading.Lock()

    def configure(self, pattern: str, config: RateLimitConfig):
        """Configure rate limiting for a specific URL pattern.

        A later call with the same pattern replaces the stored config;
        buckets already created for that pattern are kept as they are.
        """
        self.configs[pattern] = config

    def get_bucket(self, pattern: str, key: str) -> TokenBucket:
        """Get or create a token bucket for the given pattern and key.

        Raises:
            KeyError: If a bucket must be created and ``pattern`` has not
                been configured.
        """
        bucket_key = (pattern, key)
        with self.lock:
            if bucket_key not in self.buckets:
                config = self.configs[pattern]
                self.buckets[bucket_key] = TokenBucket(config.rate, config.capacity)
            return self.buckets[bucket_key]

    def _get_matching_config(self, url: str) -> Optional[Tuple[str, RateLimitConfig]]:
        """Find the matching configuration for a URL.

        Returns the first ``(pattern, config)`` pair, in registration order,
        whose pattern occurs as a substring of the URL path, or ``None``.
        """
        parsed = urlparse(url)
        path = parsed.path

        for pattern, config in self.configs.items():
            # Simple pattern matching - could be enhanced with regex
            if pattern in path:
                return pattern, config
        return None

    def check_rate_limit(self, method: str, url: str, **kwargs) -> None:
        """Take one token for a request, waiting if the config allows it.

        URLs that match no configured pattern are not limited. When the
        bucket is empty and the matching config has a ``retry_after`` budget,
        this call sleeps and retries until a token is actually consumed or
        the total wait would exceed the budget.

        Args:
            method: HTTP method; used only in the error message.
            url: Request URL; its path selects the configuration.
            **kwargs: Request keyword arguments, passed to the config's
                ``key_func`` as a single dict to derive the bucket key.

        Raises:
            RateLimitExceeded: If no token is available and waiting is not
                configured or would exceed ``retry_after``.
        """
        config_match = self._get_matching_config(url)
        if not config_match:
            return

        pattern, config = config_match
        key = config.key_func(kwargs)
        bucket = self.get_bucket(pattern, key)

        waited = 0.0
        while True:
            success, wait_time = bucket.consume()
            if success:
                return

            budget = config.retry_after
            if budget is None or waited + wait_time > budget:
                raise RateLimitExceeded(
                    f"Rate limit exceeded for {method} {url}. "
                    f"Try again in {wait_time:.2f} seconds."
                )

            # Sleeping alone does not take a token: loop back so the request
            # is charged against the bucket (another caller may also have
            # taken the refilled token in the meantime).
            time.sleep(wait_time)
            waited += wait_time


class RateLimitedPoolManager(urllib3.PoolManager):
    """PoolManager subclass that checks a RateLimiter before each request."""

    def __init__(self, *args, rate_limiter: Optional[RateLimiter] = None, **kwargs):
        """
        Args:
            *args: Forwarded to ``urllib3.PoolManager``.
            rate_limiter: Limiter consulted by ``urlopen``; a fresh
                ``RateLimiter`` with no configured patterns is used when
                omitted.
            **kwargs: Forwarded to ``urllib3.PoolManager``.
        """
        super().__init__(*args, **kwargs)
        self.rate_limiter = rate_limiter or RateLimiter()

    def urlopen(self, method: str, url: str, **kwargs):
        """Apply the rate limit, then delegate to ``PoolManager.urlopen``.

        The limiter check may sleep or raise ``RateLimitExceeded`` before
        the request is sent.
        """
        # The attribute may have been cleared after construction.
        if self.rate_limiter:
            self.rate_limiter.check_rate_limit(method, url, **kwargs)
        return super().urlopen(method, url, **kwargs)


@contextmanager
def rate_limited_urllib3(configs: Dict[str, RateLimitConfig]):
    """
    Temporarily replace ``urllib3.PoolManager`` with a rate-limited subclass.

    While the context is active, ``urllib3.PoolManager`` refers to a subclass
    of ``RateLimitedPoolManager`` whose instances all share one
    ``RateLimiter`` built from ``configs`` (URL pattern -> settings). The
    shared limiter is yielded, and the previous ``urllib3.PoolManager`` is put
    back on exit, even if the body raises.

    Example:
        config = RateLimitConfig(rate=1, capacity=5)
        with rate_limited_urllib3({'/api/': config}):
            # Make your requests here
            http = urllib3.PoolManager()
            response = http.request('GET', 'https://api.example.com/api/data')
    """
    # One limiter shared by every pool manager created inside the context.
    limiter = RateLimiter()
    for url_pattern, settings in configs.items():
        limiter.configure(url_pattern, settings)

    # Subclass that always injects the shared limiter.
    class PatchedPoolManager(RateLimitedPoolManager):
        def __init__(self, *args, **kwargs):
            kwargs.update(rate_limiter=limiter)
            super().__init__(*args, **kwargs)

    # Remember what to restore, then install the patch.
    saved_pool_manager = urllib3.PoolManager
    urllib3.PoolManager = PatchedPoolManager

    try:
        yield limiter
    finally:
        urllib3.PoolManager = saved_pool_manager

In [6]:
import openai

# Demo cell: issue several chat-completion calls inside the rate-limiting
# context and report if the limiter rejects one of them.
# The client is created before the urllib3 patch is applied.
# NOTE(review): openai>=1.0 clients appear to send requests through httpx
# rather than urllib3 — confirm the patched PoolManager is actually exercised
# by the calls below.
client = openai.Client()

# NOTE(review): patterns are matched as substrings of the URL path; OpenAI
# endpoints presumably live under "/v1/...", so verify that "/api/" matches.
configs = {
    "/api/": RateLimitConfig(
        rate=1.0,  # 1 request per second
        capacity=5,  # burst of 5 requests allowed
        retry_after=5.0,  # wait up to 5 seconds before failing
        # One bucket per Authorization header value; "default" when the
        # request kwargs carry no such header.
        key_func=lambda kwargs: kwargs.get("headers", {}).get(
            "Authorization", "default"
        ),
    )
}

# Use the rate limiter in a context
with rate_limited_urllib3(configs):
    # Instance of the patched PoolManager; not used further in this cell.
    http = urllib3.PoolManager()

    # Make some requests: 7 calls against a burst capacity of 5.
    try:
        for _ in range(7):
            client.chat.completions.create(
                messages=[{"role": "user", "content": "hi"}], model="gpt-4o-mini"
            )

    except RateLimitExceeded as e:
        print(f"Rate limit exceeded: {e}")