In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, year, month, to_timestamp
import datetime
from pyspark.sql import functions as F
import os
import pandas as pd
from sqlalchemy import create_engine, inspect
import psycopg2
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
from sqlalchemy import inspect
import pandas as pd
import datetime
from pyspark.sql.functions import to_timestamp, col
from pyspark.sql.window import Window

In [0]:
# PostgreSQL connection settings, read from the environment.
# os.environ[...] raises KeyError when a variable is unset, so a missing
# setting fails here rather than later when a connection is attempted.
POSTGRES_USER = os.environ["POSTGRES_USER"]
POSTGRES_PASSWORD = os.environ["POSTGRES_PASSWORD"]
POSTGRES_HOST = os.environ["POSTGRES_HOST"]
POSTGRES_PORT = os.environ["POSTGRES_PORT"]
# Bucket name and s3:// URI used to build the OpenAQ record paths.
OPENAQ_BUCKET = "openaq-data-archive"
OPENAQ_BUCKET_URI = f"s3://{OPENAQ_BUCKET}"

In [0]:
def create_feature_store_if_not_exists():
    """Create the 'feature-store' database if it doesn't exist.

    Connects to the default 'postgres' database using the POSTGRES_HOST,
    POSTGRES_USER, POSTGRES_PASSWORD and POSTGRES_PORT environment variables,
    checks pg_database for a 'feature-store' entry and issues CREATE DATABASE
    when none is found. Prints which of the two outcomes happened.

    The cursor and connection are always closed, including when a statement
    raises.
    """
    # Connect to the default 'postgres' database first; the target database
    # may not exist yet.
    conn = psycopg2.connect(
        host=os.environ.get('POSTGRES_HOST'),
        user=os.environ.get('POSTGRES_USER'),
        password=os.environ.get('POSTGRES_PASSWORD'),
        dbname='postgres',
        port=os.environ.get("POSTGRES_PORT")
    )

    try:
        # CREATE DATABASE cannot run inside a transaction block, so the
        # connection has to be in autocommit mode.
        conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)

        with conn.cursor() as cursor:
            # Check if the database exists
            cursor.execute("SELECT 1 FROM pg_database WHERE datname='feature-store'")
            exists = cursor.fetchone()

            if not exists:
                cursor.execute("CREATE DATABASE \"feature-store\"")
                print("Database 'feature-store' created successfully")
            else:
                print("Database 'feature-store' already exists")
    finally:
        # Release the connection even if one of the statements above failed.
        conn.close()

# Make sure the 'feature-store' database exists before any cell connects to it.
create_feature_store_if_not_exists()

In [0]:
connection_string = f"postgresql://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_HOST}:{POSTGRES_PORT}/openaq"

# Create the connection engine
engine = create_engine(connection_string)

try:
    # Fetch the OpenAQ reference tables into pandas DataFrames
    countries = pd.read_sql_table("countries", engine)
    locations = pd.read_sql_table("locations", engine)
finally:
    # Release the engine's pooled connections even if a read fails
    engine.dispose()

In [0]:
# Rename "Russian Federation" to "Russia" in the OpenAQ country names
# (presumably so it lines up with the energy table naming — TODO confirm).
countries["name"] = countries["name"].replace({"Russian Federation": "Russia"})
# Display the country names in sorted order for inspection.
countries["name"].sort_values().to_list()

In [0]:
connection_string = f"postgresql://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_HOST}:{POSTGRES_PORT}/energy"

# Create the connection engine
engine = create_engine(connection_string)

try:
    # Create an inspector object
    inspector = inspect(engine)

    # Get all table names in the energy database
    db_table_names = inspector.get_table_names()
finally:
    # Release the engine's pooled connections even if inspection fails
    engine.dispose()

db_table_names

In [0]:
# Remove every occurrence of "energy_" from each table name, leaving the
# country part of the name.
table_names = pd.Series(db_table_names).map(lambda x: x.replace("energy_", ""))
# Flag the table names that equal an OpenAQ country name once both sides are
# lowercased and have underscores replaced by spaces.
matched_filter = table_names.map(lambda x: x.lower().replace("_", " ")).isin(
    countries["name"].map(lambda x: x.lower().replace("_", " "))
)
# Show how many table names matched (True) and did not match (False).
matched_filter.value_counts()

In [0]:
# Keep only the table names that matched an OpenAQ country name.
table_names = table_names[matched_filter]

In [0]:
def check_and_recreate_table(engine, table_name, df=None):
    """
    Drop a table if it already exists and recreate it from a DataFrame.

    Args:
        engine: SQLAlchemy engine
        table_name (str): Name of the table
        df (DataFrame): DataFrame containing data to upload

    Returns:
        bool: True if successful

    Raises:
        ValueError: If ``df`` is None. This is checked before anything is
            dropped, so an invalid call never destroys an existing table.
    """
    # Validate first: dropping the table and then failing would lose data.
    if df is None:
        raise ValueError("DataFrame must be provided to create the table")

    inspector = inspect(engine)
    table_exists = table_name in inspector.get_table_names()

    if table_exists:
        from sqlalchemy import text

        # Identifiers cannot be sent as bound parameters, so let the dialect
        # quote the name. Without quoting, PostgreSQL folds mixed-case names
        # to lowercase (the DROP then misses the real table) and names with
        # characters such as '-' or "'" are a syntax error.
        quoted_name = engine.dialect.identifier_preparer.quote(table_name)
        with engine.begin() as conn:
            # text() is required: SQLAlchemy 2.x rejects plain strings here.
            conn.execute(text(f"DROP TABLE IF EXISTS {quoted_name}"))
        print(f"Table '{table_name}' dropped.")

    # Create the table from DataFrame
    print(f"Creating table '{table_name}' from DataFrame.")
    df.to_sql(table_name, engine, index=False, if_exists='replace')
    print(f"Table '{table_name}' created and data uploaded.")

    return True

def read_openaq_with_partition_pruning(location_ids, start_year=2024):
    """
    Read OpenAQ CSV records for the given locations and years.

    Builds a single glob path over the ``locationid=`` / ``year=`` / ``month=``
    directory layout under OPENAQ_BUCKET_URI so that only the requested
    locations and years are listed and read.

    Args:
        location_ids (list): List of location IDs to include
        start_year (int): Minimum year to include (default: 2024). Any value
            accepted by ``int()`` works, e.g. a numpy integer or whole float.

    Returns:
        DataFrame: OpenAQ records with ``datetime`` converted to a timestamp.
        An empty DataFrame (from ``create_empty_openaq_dataframe``) is
        returned when there are no location IDs, when ``start_year`` is not a
        usable year or lies in the future, or when the read fails.
    """
    # Convert location IDs to strings once
    location_ids_str = [str(location_id) for location_id in location_ids]

    # Early return for empty location list
    if not location_ids_str:
        return create_empty_openaq_dataframe()

    # pandas aggregations can hand back floats or NaN; range() needs a real int.
    try:
        first_year = int(start_year)
    except (TypeError, ValueError):
        print(f"Invalid start_year {start_year!r}; no data read")
        return create_empty_openaq_dataframe()

    # All years from the start year to the current one, inclusive
    current_year = datetime.datetime.now().year
    years = range(first_year, current_year + 1)

    # A start year in the future would yield an empty "{}" glob alternative.
    if not years:
        print(f"start_year {first_year} is after the current year {current_year}; no data read")
        return create_empty_openaq_dataframe()

    # Build the glob alternatives
    location_pattern = "{" + ",".join(location_ids_str) + "}"
    year_pattern = "{" + ",".join(str(y) for y in years) + "}"

    # Construct path with glob patterns
    glob_path = f"{OPENAQ_BUCKET_URI}/records/csv.gz/locationid={location_pattern}/year={year_pattern}/month=*/*.csv.gz"

    print(f"Reading data with glob pattern: {glob_path}")

    try:
        df = (
            spark.read.format("csv")
            .option("header", "true")
            .option("inferSchema", "true")
            .option("compression", "gzip")
            .load(glob_path)
        )

        # Convert datetime string to timestamp
        return df.withColumn("datetime", to_timestamp(col("datetime")))
    except Exception as e:
        # Best effort: a failed read is reported and treated as "no data".
        print(f"Error reading data: {e}")
        return create_empty_openaq_dataframe()

def create_empty_openaq_dataframe():
    """Create an empty DataFrame with the expected OpenAQ schema.

    Returns:
        DataFrame: Zero rows with the columns location_id, sensors_id,
        location, datetime, lat, lon, parameter, units and value.
    """
    # Spark cannot infer column types from zero rows: passing only column
    # names with an empty list raises "can not infer schema from empty
    # dataset". An explicit DDL schema avoids the inference step.
    # NOTE(review): the column types are an assumption about what schema
    # inference yields for the OpenAQ CSV files (datetime as a timestamp
    # after conversion) — confirm against a real read.
    schema = (
        "location_id INT, "
        "sensors_id INT, "
        "location STRING, "
        "datetime TIMESTAMP, "
        "lat DOUBLE, "
        "lon DOUBLE, "
        "parameter STRING, "
        "units STRING, "
        "value DOUBLE"
    )
    return spark.createDataFrame([], schema)

def normalize_country_name(country):
    """Turn an underscore-separated country name into space-separated, capitalized words."""
    # Joining a single word yields that word unchanged, so no special case is needed.
    return " ".join(map(str.capitalize, country.split("_")))

def get_share_energy_columns(energy_data):
    """Return the energy-share column names, in their original order.

    A column is kept when its name contains both "share" and "energy" and
    contains none of "demand", "renewables" or "fossil".
    """
    wanted_tokens = ("share", "energy")
    unwanted_tokens = ("demand", "renewables", "fossil")

    selected = []
    for column_name in energy_data.columns:
        has_all_wanted = all(token in column_name for token in wanted_tokens)
        has_unwanted = any(token in column_name for token in unwanted_tokens)
        if has_all_wanted and not has_unwanted:
            selected.append(column_name)
    return selected

def create_feature_store_country_data_by_sensor(country, openaq_countries, openaq_locations, energy_engine, feature_store_engine):
    """
    Create feature store data by sensor for a specific country.

    For every sensor found in the country's OpenAQ data, the latest reading of
    each year is joined with that year's energy indicators and written to its
    own feature-store table named
    ``<Country_Name>_<parameter>_sensor_<sensor_id>``.

    The function prints a message and returns without writing anything when
    the country or its locations are missing from the OpenAQ tables, when the
    energy table is empty or cannot be read, or when no OpenAQ readings exist.

    Args:
        country (str): Country name (underscore-separated)
        openaq_countries (DataFrame): OpenAQ countries data
        openaq_locations (DataFrame): OpenAQ locations data
        energy_engine: SQLAlchemy engine for energy data
        feature_store_engine: SQLAlchemy engine for feature store
    """
    # Normalize database and display country names
    db_country_name = "energy_" + country
    normalized_country = normalize_country_name(country)

    # Look the country up ignoring case and the underscore/space difference.
    # Exact matching on the capitalized name misses names containing
    # lowercase words, e.g. "Bosnia and Herzegovina" vs "Bosnia And Herzegovina".
    country_key = country.lower().replace("_", " ")
    openaq_name_keys = (
        openaq_countries["name"].str.lower().str.replace("_", " ", regex=False)
    )
    openaq_country = openaq_countries[openaq_name_keys.eq(country_key)]

    # Early return if country not found
    if openaq_country.empty:
        print(f"Country '{normalized_country}' not found in OpenAQ data")
        return

    country_locations = openaq_locations.loc[
        openaq_locations["country_id"].isin(openaq_country["id"])
    ]

    # Early return if no locations found
    if country_locations.empty:
        print(f"No locations found for country '{normalized_country}'")
        return

    # Read energy data and determine minimum year
    try:
        energy_data = pd.read_sql_table(db_country_name, energy_engine)
        if energy_data.empty:
            print(f"No energy data found for {normalized_country}")
            return

        # .min() may return a numpy float (or NaN); int() turns it into a
        # plain year or raises, which is reported by the handler below.
        min_year = int(energy_data["year"].min())
    except Exception as e:
        print(f"Error reading energy data: {e}")
        return

    # Read OpenAQ data with partition pruning
    openaq_data = read_openaq_with_partition_pruning(
        location_ids=country_locations["id"].to_list(),
        start_year=min_year
    )

    # Cache before the first action so the emptiness check below fills the
    # cache instead of reading the source files a second time later on.
    openaq_data.cache()
    energy_features = None

    try:
        # Check if OpenAQ data is empty
        if openaq_data.count() == 0:
            print(f"No OpenAQ data found for {normalized_country}")
            return

        # Select energy columns and convert to Spark DataFrame
        energy_columns = get_share_energy_columns(energy_data) + ["country", "population", "gdp", "per_capita_electricity", "year"]
        energy_features = spark.createDataFrame(energy_data[energy_columns])
        energy_features.cache()

        # Extract year from datetime
        openaq_data_with_year = openaq_data.withColumn("year", F.year("datetime"))

        # Define window for getting the latest reading per sensor per year
        window_spec = Window.partitionBy("sensors_id", "year").orderBy(F.desc("datetime"))

        # Add row number and filter to keep only the most recent reading
        yearly_last_values = (
            openaq_data_with_year
            .withColumn("row_number", F.row_number().over(window_spec))
            .filter(F.col("row_number") == 1)
            .drop("row_number")
        )

        sensor_ids = [
            row.sensors_id
            for row in openaq_data.select("sensors_id").distinct().collect()
        ]

        # Process each sensor
        for sensor_id in sensor_ids:
            # Filter data for this sensor
            sensor_df = yearly_last_values.filter(F.col("sensors_id") == sensor_id)

            # Get parameter for this sensor; an empty result also means the
            # sensor has no rows, so no separate count() is needed.
            parameter_row = sensor_df.select("parameter").head(1)
            if not parameter_row:
                continue

            parameter = parameter_row[0]["parameter"]

            # Join with energy data and bring the result to pandas once;
            # checking emptiness there avoids an extra Spark count().
            sensor_features = sensor_df.join(energy_features, on="year", how="inner").toPandas()

            # Skip if join resulted in empty DataFrame
            if sensor_features.empty:
                continue

            # Create feature table name
            feature_table_name = normalized_country.replace(" ", "_") + "_" + parameter + "_sensor_" + str(sensor_id)

            # Save to database
            check_and_recreate_table(feature_store_engine, feature_table_name, sensor_features)
    finally:
        # Release cached data on every exit path, including errors.
        openaq_data.unpersist()
        if energy_features is not None:
            energy_features.unpersist()

In [0]:
connection_string = f"postgresql://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_HOST}:{POSTGRES_PORT}/energy"
energy_engine = create_engine(connection_string)

connection_string = f"postgresql://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_HOST}:{POSTGRES_PORT}/feature-store"
feature_store_engine = create_engine(connection_string)

try:
    # Build the per-sensor feature tables for every matched country
    for country in table_names:
        create_feature_store_country_data_by_sensor(country, countries, locations, energy_engine, feature_store_engine)
finally:
    # Release pooled connections even if a country fails part-way through
    energy_engine.dispose()
    feature_store_engine.dispose()