In [0]:
%pip install pyspark-data-sources[all] 

In [0]:
"""
OpenSky Network Data Source for Apache Spark - Global Version

Fetches aircraft data globally without region restrictions.
Supports both batch and streaming operations.
"""

import requests
import time
from datetime import datetime, timezone
from typing import Dict, List, Tuple, Any, Optional, Iterator
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry

from pyspark.sql.datasource import DataSource, DataSourceReader, SimpleDataSourceStreamReader
from pyspark.sql.types import *

# Short name returned by OpenSkyDataSource.name(); used as format("opensky").
DS_NAME: str = "opensky"


class OpenSkyAPIError(Exception):
    """Root error type for failures while talking to the OpenSky Network API."""


class RateLimitError(OpenSkyAPIError):
    """Signals that OpenSky answered a request with HTTP 429 (too many requests)."""


class OpenSkyStreamReader(SimpleDataSourceStreamReader):
    """Poll the OpenSky Network ``/states/all`` endpoint as a Spark stream.

    Every :meth:`read` issues one global request (no bounding box) and returns
    the parsed state vectors plus an offset ``{"last_fetch": <snapshot time>}``.
    Requests are spaced at least ``MIN_REQUEST_INTERVAL`` seconds apart, and the
    HTTP session retries ``RETRY_STATUS_CODES`` with exponential backoff.

    When both the ``client_id`` and ``client_secret`` options are set, an OAuth2
    client-credentials token is obtained and sent as a Bearer token.
    """

    MIN_REQUEST_INTERVAL = 5.0  # seconds between requests
    # Daily quotas; the selected one is stored on ``self.rate_limit`` but is
    # not consulted by any method of this class (informational only).
    ANONYMOUS_RATE_LIMIT = 100  # calls per day
    AUTHENTICATED_RATE_LIMIT = 4000  # calls per day
    MAX_RETRIES = 3
    RETRY_BACKOFF = 2
    RETRY_STATUS_CODES = [429, 500, 502, 503, 504]
    TOKEN_URL = "https://auth.opensky-network.org/auth/realms/opensky-network/protocol/openid-connect/token"
    STATES_URL = "https://opensky-network.org/api/states/all"
    REQUEST_TIMEOUT = 10  # seconds, used for both token and state requests

    def __init__(self, schema: StructType, options: Dict[str, str]):
        super().__init__()
        self.schema = schema
        self.options = options
        self.session = self._create_session()
        self.last_request_time = 0

        self.client_id = options.get("client_id")
        self.client_secret = options.get("client_secret")
        self.access_token = None
        self.token_expires_at = 0

        if self.client_id and self.client_secret:
            self._get_access_token()  # OAuth2 authentication
            self.rate_limit = self.AUTHENTICATED_RATE_LIMIT
        else:
            self.rate_limit = self.ANONYMOUS_RATE_LIMIT

    def _get_access_token(self):
        """Obtain (or reuse) an OAuth2 access token via the client-credentials flow.

        The cached token is reused until ``token_expires_at``, which is set
        300 s before the server-reported expiry so it is refreshed early.

        Raises:
            OpenSkyAPIError: if the token request fails.
        """
        current_time = time.time()
        if self.access_token and current_time < self.token_expires_at:
            return  # Token still valid

        data = {
            "grant_type": "client_credentials",
            "client_id": self.client_id,
            "client_secret": self.client_secret,
        }

        try:
            response = requests.post(self.TOKEN_URL, data=data, timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()
            token_data = response.json()

            self.access_token = token_data["access_token"]
            expires_in = token_data.get("expires_in", 1800)
            self.token_expires_at = current_time + expires_in - 300

        except requests.exceptions.RequestException as e:
            raise OpenSkyAPIError(f"Failed to get access token: {str(e)}") from e

    def _create_session(self) -> requests.Session:
        """Create a requests session that retries transient HTTP failures."""
        session = requests.Session()
        retry_strategy = Retry(
            total=self.MAX_RETRIES,
            backoff_factor=self.RETRY_BACKOFF,
            status_forcelist=self.RETRY_STATUS_CODES,
            # Once retries are exhausted, hand back the last response instead of
            # raising MaxRetryError/RetryError; otherwise the explicit 429 check
            # in _fetch_states could never be reached.
            raise_on_status=False,
        )
        adapter = HTTPAdapter(max_retries=retry_strategy)
        session.mount("https://", adapter)
        session.mount("http://", adapter)
        return session

    def initialOffset(self) -> Dict[str, int]:
        """Offset used before any data has been fetched."""
        return {"last_fetch": 0}

    def _handle_rate_limit(self):
        """Ensure at least MIN_REQUEST_INTERVAL seconds between requests"""
        current_time = time.time()
        time_since_last_request = current_time - self.last_request_time

        if time_since_last_request < self.MIN_REQUEST_INTERVAL:
            sleep_time = self.MIN_REQUEST_INTERVAL - time_since_last_request
            time.sleep(sleep_time)

        self.last_request_time = time.time()

    def _fetch_states(self) -> requests.Response:
        """Fetch all state vectors globally.

        Returns:
            The successful HTTP response.

        Raises:
            RateLimitError: if the API still answers 429 after retries.
            OpenSkyAPIError: for any other request/HTTP failure.
        """
        self._handle_rate_limit()

        if self.client_id and self.client_secret:
            self._get_access_token()

        headers = {}
        if self.access_token:
            headers["Authorization"] = f"Bearer {self.access_token}"

        try:
            # No bounding-box params means global coverage; ``extended=1`` asks
            # the API to include the aircraft category (state index 17).
            response = self.session.get(
                self.STATES_URL,
                params={"extended": 1},
                headers=headers,
                timeout=self.REQUEST_TIMEOUT,
            )

            if response.status_code == 429:
                raise RateLimitError("API rate limit exceeded")
            response.raise_for_status()

            return response

        except requests.exceptions.RequestException as e:
            error_msg = f"API request failed: {str(e)}"
            if isinstance(e, requests.exceptions.Timeout):
                error_msg = "API request timed out"
            elif isinstance(e, requests.exceptions.ConnectionError):
                error_msg = "Connection error occurred"
            raise OpenSkyAPIError(error_msg) from e

    def valid_state(self, state: List) -> bool:
        """Return True if ``state`` has at least 17 fields and non-null icao24/longitude/latitude."""
        if not state or len(state) < 17:
            return False

        return (
            state[0] is not None  # icao24
            and state[5] is not None  # longitude
            and state[6] is not None  # latitude
        )

    def parse_state(self, state: List, timestamp: int) -> Tuple:
        """Convert one raw OpenSky state vector into a row matching the schema.

        Raw index layout (OpenSky REST API): 0 icao24, 1 callsign,
        2 origin_country, 3 time_position, 4 last_contact, 5 longitude,
        6 latitude, 7 baro_altitude, 8 on_ground, 9 velocity, 10 true_track,
        11 vertical_rate, 12 sensors, 13 geo_altitude, 14 squawk, 15 spi,
        16 position_source, 17 category (only in extended responses).

        Args:
            state: raw state-vector list (at least 17 elements).
            timestamp: snapshot unix time, emitted as ``time_ingest``.

        Returns:
            Tuple ordered like the data source schema; unparseable or missing
            values become None.
        """

        def safe_float(value: Any) -> Optional[float]:
            try:
                return float(value) if value is not None else None
            except (ValueError, TypeError):
                return None

        def safe_int(value: Any) -> Optional[int]:
            try:
                return int(value) if value is not None else None
            except (ValueError, TypeError):
                return None

        def safe_bool(value: Any) -> Optional[bool]:
            return bool(value) if value is not None else None

        def safe_timestamp(value: Any) -> Optional[datetime]:
            # time_position may be null; a bad value must not fail the batch.
            try:
                return datetime.fromtimestamp(value, tz=timezone.utc) if value is not None else None
            except (ValueError, TypeError, OverflowError, OSError):
                return None

        category = safe_int(state[17]) if len(state) > 17 else None

        return (
            datetime.fromtimestamp(timestamp, tz=timezone.utc),  # time_ingest
            state[0],  # icao24
            state[1],  # callsign
            state[2],  # origin_country
            safe_timestamp(state[3]),  # time_position
            safe_timestamp(state[4]),  # last_contact
            safe_float(state[5]),  # longitude
            safe_float(state[6]),  # latitude
            safe_float(state[13]),  # geo_altitude (raw index 13)
            safe_bool(state[8]),  # on_ground
            safe_float(state[9]),  # velocity
            safe_float(state[10]),  # true_track
            safe_float(state[11]),  # vertical_rate
            state[12],  # sensors
            safe_float(state[7]),  # baro_altitude (raw index 7)
            state[14],  # squawk
            safe_bool(state[15]),  # spi
            category,  # category (raw index 17; index 16 is position_source)
        )

    def readBetweenOffsets(self, start: Dict[str, int], end: Dict[str, int]) -> Iterator[Tuple]:
        """Return rows for the batch between ``start`` and ``end``.

        NOTE(review): this re-polls the live API instead of replaying the rows
        originally read for [start, end), so a re-read (presumably on failure
        recovery) yields different data. Consider caching rows by offset.
        """
        data, _ = self.read(start)
        return iter(data)

    def read(self, start: Dict[str, int]) -> Tuple[List[Tuple], Dict[str, int]]:
        """Fetch one snapshot and return ``(rows, next_offset)``.

        On any error the error is printed and ``([], start)`` is returned, so
        the stream does not advance.
        """
        try:
            response = self._fetch_states()
            data = response.json()

            # Use the same snapshot time for the rows and the offset; fall back
            # to the local clock if the payload has no "time".
            snapshot_time = data.get("time")
            if snapshot_time is None:
                snapshot_time = int(time.time())

            # ``or []`` also covers a present-but-null "states" field.
            states = data.get("states") or []
            valid_states = [
                self.parse_state(s, snapshot_time)
                for s in states
                if self.valid_state(s)
            ]

            return (valid_states, {"last_fetch": snapshot_time})

        except OpenSkyAPIError as e:
            print(f"OpenSky API Error: {str(e)}")
            return ([], start)
        except Exception as e:
            print(f"Unexpected error: {str(e)}")
            return ([], start)


class OpenSkyBatchReader(DataSourceReader):
    """One-shot batch reader returning a single snapshot of global aircraft states.

    HTTP access, authentication and parsing are all delegated to an internal
    OpenSkyStreamReader; a batch read is simply one poll from offset 0.
    """

    def __init__(self, schema: StructType, options: Dict[str, str]):
        self.schema = schema
        self.options = options
        # Delegate fetching/parsing to the streaming implementation.
        self.stream_reader = OpenSkyStreamReader(schema, options)

    def read(self, partition) -> Iterator[Tuple]:
        """Fetch one snapshot and return its rows as an iterator.

        ``partition`` is not used: the whole snapshot comes from one API call.
        """
        starting_offset = {"last_fetch": 0}
        rows, _unused_offset = self.stream_reader.read(starting_offset)
        return iter(rows)


class OpenSkyDataSource(DataSource):
    """
    Apache Spark DataSource for streaming and batch reading of real-time aircraft 
    tracking data from OpenSky Network globally.

    Supports both streaming (readStream) and batch (read) operations.
    Fetches aircraft data from all regions worldwide.

    Options ``client_id`` and ``client_secret`` enable OAuth2 authentication
    and must be supplied together.

    Examples
    --------
    Streaming usage:
    >>> df = spark.readStream.format("opensky").load()
    >>> query = df.writeStream.format("console").start()

    Batch usage (single snapshot):
    >>> df = spark.read.format("opensky").load()
    >>> df.show()

    With authentication for higher rate limits:
    >>> df = spark.read.format("opensky") \\
    ...     .option("client_id", "your_client_id") \\
    ...     .option("client_secret", "your_secret") \\
    ...     .load()
    """

    def __init__(self, options: Optional[Dict[str, str]] = None):
        """Store options and validate the credential pair.

        Raises:
            ValueError: if only one of ``client_id`` / ``client_secret`` is set.
        """
        super().__init__(options or {})
        self.options = options or {}

        # The readers only authenticate when BOTH values are set; reject a
        # partial configuration instead of silently using anonymous access.
        if "client_id" in self.options and not self.options.get("client_secret"):
            raise ValueError("client_secret must be provided when client_id is set")
        if "client_secret" in self.options and not self.options.get("client_id"):
            raise ValueError("client_id must be provided when client_secret is set")

    @classmethod
    def name(cls) -> str:
        """Format name used with ``spark.read.format(...)``."""
        return DS_NAME

    def schema(self) -> StructType:
        """Output schema; column order matches OpenSkyStreamReader.parse_state."""
        return StructType(
            [
                StructField("time_ingest", TimestampType()),
                StructField("icao24", StringType()),
                StructField("callsign", StringType()),
                StructField("origin_country", StringType()),
                StructField("time_position", TimestampType()),
                StructField("last_contact", TimestampType()),
                StructField("longitude", DoubleType()),
                StructField("latitude", DoubleType()),
                StructField("geo_altitude", DoubleType()),
                StructField("on_ground", BooleanType()),
                StructField("velocity", DoubleType()),
                StructField("true_track", DoubleType()),
                StructField("vertical_rate", DoubleType()),
                StructField("sensors", ArrayType(IntegerType())),
                StructField("baro_altitude", DoubleType()),
                StructField("squawk", StringType()),
                StructField("spi", BooleanType()),
                StructField("category", IntegerType()),
            ]
        )

    def reader(self, schema: StructType) -> "DataSourceReader":
        """Returns a batch reader for one-time data snapshots"""
        return OpenSkyBatchReader(schema, self.options)

    def simpleStreamReader(self, schema: StructType) -> "SimpleDataSourceStreamReader":
        """Returns a streaming reader for continuous data ingestion"""
        return OpenSkyStreamReader(schema, self.options)

In [0]:
from pyspark import pipelines as dp
# Make format("opensky") resolvable in this Spark session.
spark.dataSource.register(OpenSkyDataSource)


# Pipeline streaming table fed by continuous polling of the OpenSky API.
@dp.table(name="geospatial.opensky.flight_states_all")
def df_stream() -> Any:
    opensky_reader = spark.readStream.format("opensky")
    return opensky_reader.load()