In [0]:
import datetime
import os
from urllib.parse import quote_plus

import pandas as pd
import psycopg2
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import col, year, month, to_timestamp
from pyspark.sql.window import Window
from sqlalchemy import create_engine, inspect

In [0]:
# PostgreSQL connection settings, read from the environment.
# Indexing os.environ directly (instead of .get) raises KeyError right here
# if a variable is missing, rather than failing later with a None value.
POSTGRES_USER = os.environ["POSTGRES_USER"]
POSTGRES_PASSWORD = os.environ["POSTGRES_PASSWORD"]
POSTGRES_HOST = os.environ["POSTGRES_HOST"]
POSTGRES_PORT = os.environ["POSTGRES_PORT"]  # kept as a string, as os.environ returns it
# Name of the S3 bucket and the s3:// URI derived from it; the URI is the
# root of the glob path built when reading OpenAQ records.
OPENAQ_BUCKET = "openaq-data-archive"
OPENAQ_BUCKET_URI = f"s3://{OPENAQ_BUCKET}"

In [0]:
def create_feature_store_if_not_exists(db_name="feature-store"):
    """Create the feature-store database on the PostgreSQL server if it is missing.

    Connects to the server's default 'postgres' database using the
    module-level POSTGRES_* settings, looks the target database up in
    pg_database and issues CREATE DATABASE only when it is absent.
    The connection and cursor are always closed, even if a statement fails.

    Args:
        db_name (str): Name of the database to create (default: 'feature-store').

    Raises:
        psycopg2.Error: If connecting or executing a statement fails.
    """
    # Local import keeps the generic name `sql` out of the notebook namespace.
    from psycopg2 import sql

    # Use the validated module-level settings instead of re-reading the
    # environment with .get(), which would silently pass None to psycopg2.
    conn = psycopg2.connect(
        host=POSTGRES_HOST,
        user=POSTGRES_USER,
        password=POSTGRES_PASSWORD,
        dbname="postgres",
        port=POSTGRES_PORT,
    )
    try:
        # We need autocommit to create a database: CREATE DATABASE cannot
        # run inside a transaction block.
        conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)

        with conn.cursor() as cursor:
            # Check if the database exists (value passed as a bound parameter)
            cursor.execute("SELECT 1 FROM pg_database WHERE datname = %s", (db_name,))
            exists = cursor.fetchone()

            if not exists:
                # Identifiers cannot be bound parameters; sql.Identifier
                # quotes the name safely (needed for the hyphen in the default).
                cursor.execute(
                    sql.SQL("CREATE DATABASE {}").format(sql.Identifier(db_name))
                )
                print(f"Database '{db_name}' created successfully")
            else:
                print(f"Database '{db_name}' already exists")
    finally:
        conn.close()

create_feature_store_if_not_exists()

In [0]:
# Credentials are URL-escaped: characters such as '@', ':' or '/' in a raw
# user name or password would otherwise corrupt the database URL.
connection_string = (
    f"postgresql://{quote_plus(POSTGRES_USER)}:{quote_plus(POSTGRES_PASSWORD)}"
    f"@{POSTGRES_HOST}:{POSTGRES_PORT}/openaq"
)

# Create the connection engine
engine = create_engine(connection_string)

try:
    # Fetch tables using pandas
    countries = pd.read_sql_table("countries", engine)
    locations = pd.read_sql_table("locations", engine)
finally:
    # Release the pooled connections even if one of the reads fails
    engine.dispose()

In [0]:
# Swap the long-form country label for its short form (presumably so it
# lines up with the energy table naming -- confirm against the energy DB).
country_renames = {"Russian Federation": "Russia"}
countries["name"] = countries["name"].replace(country_renames)

# Display every known country name in alphabetical order.
countries["name"].sort_values().tolist()

In [0]:
# Credentials are URL-escaped: characters such as '@', ':' or '/' in a raw
# user name or password would otherwise corrupt the database URL.
connection_string = (
    f"postgresql://{quote_plus(POSTGRES_USER)}:{quote_plus(POSTGRES_PASSWORD)}"
    f"@{POSTGRES_HOST}:{POSTGRES_PORT}/energy"
)

# Create the connection engine
engine = create_engine(connection_string)

try:
    # Create an inspector object
    inspector = inspect(engine)

    # Get all table names
    db_table_names = inspector.get_table_names()
finally:
    # Release the pooled connections even if the inspection fails
    engine.dispose()

db_table_names

In [0]:
# Strip the "energy_" marker so each table name becomes a bare country key.
table_names = pd.Series(db_table_names).map(lambda table: table.replace("energy_", ""))

# Compare both sides lower-cased, with underscores treated as spaces.
normalized_tables = table_names.map(lambda key: key.lower().replace("_", " "))
normalized_countries = countries["name"].map(lambda label: label.lower().replace("_", " "))
matched_filter = normalized_tables.isin(normalized_countries)

# Show how many energy tables do / do not correspond to a known country.
matched_filter.value_counts()

In [0]:
# Keep only the country keys that matched a known country name.
table_names = table_names.loc[matched_filter]

In [0]:
def check_and_recreate_table(engine, table_name, df=None):
    """Replace the table `table_name` with the contents of `df`.

    If the table already exists it is dropped and recreated; otherwise it is
    created. The drop is delegated to pandas (`if_exists='replace'`), which
    quotes the identifier, so mixed-case names that `to_sql` created earlier
    can be replaced on a re-run.

    Args:
        engine: SQLAlchemy engine (or connection) for the target database.
        table_name (str): Name of the table to (re)create.
        df (pandas.DataFrame): Data to upload. Required despite the default.

    Returns:
        bool: True once the table has been written.

    Raises:
        ValueError: If `df` is None. Raised before the database is touched,
            so an invalid call never drops an existing table.
    """
    # Validate first: the original order dropped the table and only then
    # noticed there was nothing to recreate it from.
    if df is None:
        raise ValueError("df must be provided to create the table")

    table_exists = table_name in inspect(engine).get_table_names()
    if table_exists:
        print(f"Table '{table_name}' exists. It will be dropped and recreated.")

    # Create from DataFrame; 'replace' performs the DROP + CREATE in one step
    print(f"Creating table '{table_name}' from DataFrame.")
    df.to_sql(table_name, engine, index=False, if_exists="replace")
    print(f"Table '{table_name}' created and data uploaded.")

    return True

# Improved function for more efficient reading using partition pruning
def read_openaq_with_partition_pruning(location_ids, start_year=2024):
    """
    Read OpenAQ CSV records for the given locations and years from S3.

    Builds a single glob path that names the wanted `locationid=` and `year=`
    directories explicitly, so only those directories are listed and read.

    Args:
        location_ids (list): List of location IDs to include
        start_year (int): Minimum year to include (default: 2024)

    Returns:
        DataFrame: OpenAQ records with `datetime` converted to a timestamp.
            An empty DataFrame with the expected columns is returned when
            there is nothing to read or the read fails.
    """
    # Explicit schema for the empty fallback. Spark cannot infer column types
    # from an empty dataset, so passing only column names would itself raise.
    # NOTE(review): column types are assumed from the OpenAQ CSV layout -- confirm.
    empty_schema = (
        "location_id INT, sensors_id INT, location STRING, datetime TIMESTAMP, "
        "lat DOUBLE, lon DOUBLE, parameter STRING, units STRING, value DOUBLE"
    )

    # Convert location IDs to strings
    location_ids_str = [str(location_id) for location_id in location_ids]

    # Get all years from start_year to the current year; int() also accepts
    # numpy integers and whole-number floats coming from pandas.
    current_year = datetime.datetime.now().year
    years = range(int(start_year), current_year + 1)

    # An empty "{}" alternation would match nothing useful; skip the read.
    if not location_ids_str or len(years) == 0:
        print("No locations or years to read; returning an empty DataFrame.")
        return spark.createDataFrame([], empty_schema)

    # Build the brace alternations used by the glob
    location_pattern = "{" + ",".join(location_ids_str) + "}"
    year_pattern = "{" + ",".join(str(y) for y in years) + "}"

    # Construct path with glob patterns
    glob_path = f"{OPENAQ_BUCKET_URI}/records/csv.gz/locationid={location_pattern}/year={year_pattern}/month=*/*.csv.gz"

    print(f"Reading data with glob pattern: {glob_path}")

    try:
        df = (
            spark.read.format("csv")
            .option("header", "true")
            .option("inferSchema", "true")
            .option("compression", "gzip")
            .load(glob_path)
        )

        # Convert datetime string to timestamp
        df = df.withColumn("datetime", to_timestamp(col("datetime")))

        return df
    except Exception as e:
        # Deliberate best-effort: a location without data for the requested
        # years should not abort the whole run.
        print(f"Error reading data: {e}")
        # Return empty dataframe with expected schema
        return spark.createDataFrame([], empty_schema)


def create_feature_store_country_data_by_sensor(country, openaq_countries, openaq_locations, energy_engine, feature_store_engine):
    """Build one feature-store table per OpenAQ sensor for a single country.

    For every sensor located in the country, the last reading of each year
    is joined (on year) with the country's yearly energy data, and the result
    is written to a table named
    `<Country>_<parameter>_sensor_<sensor_id>` in the feature store.

    Args:
        country (str): Country key, i.e. the energy table name without its
            "energy_" prefix (words separated by underscores).
        openaq_countries (pandas.DataFrame): OpenAQ countries; the "id" and
            "name" columns are read.
        openaq_locations (pandas.DataFrame): OpenAQ locations; the "id" and
            "country_id" columns are read.
        energy_engine: SQLAlchemy engine for the energy database.
        feature_store_engine: SQLAlchemy engine for the feature-store database.

    Returns:
        None. Countries without locations or energy data are skipped.
    """
    db_country_name = "energy_" + country

    # Human-readable name used in the output table names. All words are kept;
    # previously names with three or more words were cut to the first two.
    display_country = " ".join(word.capitalize() for word in country.split("_"))

    # Match OpenAQ names case-insensitively with underscores as spaces, so
    # lower-case connecting words (e.g. "and", "of") still match.
    normalized_country = country.lower().replace("_", " ")
    name_matches = (
        openaq_countries["name"]
        .str.lower()
        .str.replace("_", " ", regex=False)
        .eq(normalized_country)
    )
    openaq_country = openaq_countries[name_matches]
    country_locations = openaq_locations.loc[
        openaq_locations["country_id"].isin(openaq_country["id"])
    ]

    location_ids = country_locations["id"].to_list()
    if not location_ids:
        print(f"No OpenAQ locations found for '{display_country}'; skipping.")
        return

    # Read through the engine that was passed in, not a notebook-level global.
    energy_data = pd.read_sql_table(db_country_name, energy_engine)
    min_year = energy_data["year"].min()
    if pd.isna(min_year):
        print(f"No energy data with a year for '{display_country}'; skipping.")
        return

    openaq_data = read_openaq_with_partition_pruning(
        location_ids=location_ids, start_year=int(min_year)
    )

    # Energy-share columns, excluding demand and the renewables/fossil aggregates
    excluded_terms = ("demand", "renewables", "fossil")
    share_columns = [
        column_name
        for column_name in energy_data.columns
        if "share" in column_name
        and "energy" in column_name
        and not any(term in column_name for term in excluded_terms)
    ]

    energy_columns = share_columns + ["country", "population", "gdp", "per_capita_electricity", "year"]
    energy_sdf = spark.createDataFrame(energy_data[energy_columns])

    # Extract the year as a separate column to join and partition on
    openaq_data_with_year = openaq_data.withColumn("year", F.year("datetime"))

    # Rank the readings of each (sensor, year) from newest to oldest ...
    window_spec = Window.partitionBy("sensors_id", "year").orderBy(F.desc("datetime"))
    with_row_number = openaq_data_with_year.withColumn("row_number", F.row_number().over(window_spec))

    # ... and keep only the newest one per sensor and year
    yearly_last_values = with_row_number.filter(F.col("row_number") == 1).drop("row_number")

    # Get unique sensor IDs
    sensor_ids = [row.sensors_id for row in openaq_data.select("sensors_id").distinct().collect()]

    # NOTE(review): each iteration re-executes the Spark plan for one sensor;
    # consider persisting `yearly_last_values` if the cluster supports it.
    for sensor_id in sensor_ids:
        sensor_df = yearly_last_values.filter(F.col("sensors_id") == sensor_id)
        sensor_df = sensor_df.join(energy_sdf, on="year", how="inner")

        # Collect once and reuse the result for both the emptiness check and
        # the upload (previously a head() job followed by a toPandas() job).
        sensor_pdf = sensor_df.toPandas()
        if sensor_pdf.empty:
            continue

        parameter = sensor_pdf["parameter"].iloc[0]

        # f-string formatting: sensor_id is numeric when the schema is inferred,
        # so plain string concatenation would raise TypeError.
        feature_table_name = f"{display_country.replace(' ', '_')}_{parameter}_sensor_{sensor_id}"

        check_and_recreate_table(feature_store_engine, feature_table_name, sensor_pdf)

In [0]:
# Credentials are URL-escaped: characters such as '@', ':' or '/' in a raw
# user name or password would otherwise corrupt the database URLs.
connection_string = (
    f"postgresql://{quote_plus(POSTGRES_USER)}:{quote_plus(POSTGRES_PASSWORD)}"
    f"@{POSTGRES_HOST}:{POSTGRES_PORT}/energy"
)
energy_engine = create_engine(connection_string)

connection_string = (
    f"postgresql://{quote_plus(POSTGRES_USER)}:{quote_plus(POSTGRES_PASSWORD)}"
    f"@{POSTGRES_HOST}:{POSTGRES_PORT}/feature-store"
)
feature_store_engine = create_engine(connection_string)

try:
    # Build the per-sensor feature tables for every matched country
    for country in table_names:
        create_feature_store_country_data_by_sensor(country, countries, locations, energy_engine, feature_store_engine)
finally:
    # Release pooled connections even if one country fails part-way through
    energy_engine.dispose()
    feature_store_engine.dispose()