In [0]:
# Bring in your ETL notebook
%run /Workspace/energyconsumption/db_ingestion






In [0]:
%run /Workspace/Repos/databricks/awsproject/db_refinement

In [0]:
# This workspace path has been confirmed to work.
%run /Workspace/Users/venkynani701@gmail.com/curated_data


In [0]:
#test_db_ingestion

# Use %run to include the db_ingestion notebook at the top of your test notebook
# %run ./db_ingestion

import pytest
from pyspark.sql import Row

# --------------------------
# Spark fixture
# --------------------------
@pytest.fixture(scope="session")
def spark():
    """Provide a SparkSession to tests that request the ``spark`` fixture.

    The fixture is session-scoped, so pytest builds it once and shares it
    across all tests in the run. ``getOrCreate()`` hands back an already
    running session when there is one, otherwise it starts a local session
    with two worker threads.

    Returns:
        SparkSession: the session used by the tests.
    """
    from pyspark.sql import SparkSession

    builder = SparkSession.builder.master("local[2]").appName("pytest-db-ingestion")
    return builder.getOrCreate()

# --------------------------
# Test log_step()
# --------------------------
def test_log_step(spark):
    """Check that a log-record Row becomes a DataFrame with the log columns.

    Args:
        spark: SparkSession supplied by the ``spark`` fixture.

    NOTE(review): this builds the record by hand and does not invoke
    ``log_step()`` itself; other ``test_log_step`` definitions also exist in
    this file, and the last one executed replaces the earlier ones.
    """
    print("\n=== Running test_log_step ===")
    log_fields = {
        "job_id": "12345",
        "timestamp": "2025-08-25T12:00:00",
        "step": "TEST_STEP",
        "status": "INFO",
        "message": "This is a test log",
    }
    log_df = spark.createDataFrame([Row(**log_fields)])
    display(log_df)

    # Every field a log consumer relies on must survive as a column.
    for expected_column in ("step", "status", "message"):
        assert expected_column in log_df.columns
    print("✅ log_step test passed!")

# --------------------------
# Test row count
# --------------------------
def test_row_count(spark):
    """Check that a three-row sample DataFrame reports a count of 3.

    Args:
        spark: SparkSession supplied by the ``spark`` fixture.
    """
    print("\n=== Running test_row_count ===")
    readings = [
        ("2025-08-01", 10),
        ("2025-08-01", 20),
        ("2025-08-02", 30),
    ]
    frame = spark.createDataFrame(
        [Row(date=day, consumption=amount) for day, amount in readings]
    )
    display(frame)

    n_rows = frame.count()
    print("Row count:", n_rows)

    assert n_rows == 3
    print("✅ row_count test passed!")

# --------------------------
# Test ETL failure handling
# --------------------------
def test_etl_failure():
    """Simulate an ETL failure and check that its message can be captured.

    The exception is raised and handled inside the test itself, so this only
    demonstrates the capture-and-inspect pattern.

    NOTE(review): no ETL code is exercised here.
    """
    print("\n=== Running test_etl_failure ===")
    try:
        raise Exception("S3 not reachable")
    except Exception as err:
        captured_text = str(err)
        print("Captured Exception:", captured_text)
        assert "S3 not reachable" in captured_text
        print("✅ etl_failure test passed!")

In [0]:
# At the top of your test notebook, add:
# %run /path/to/db_refinement

import pytest
from pyspark.sql import Row
from pyspark.sql import functions as F
from pyspark.sql.types import DoubleType

# Do not use `import db_refinement as job`; the notebook is pulled in with %run instead.

# --------------------------
# Spark fixture
# --------------------------
@pytest.fixture(scope="session")
def spark():
    """Provide the SparkSession to tests that request the ``spark`` fixture.

    This definition replaces any earlier ``spark`` fixture defined in the same
    module or notebook namespace, so it must return a usable session: a fixture
    that returns ``None`` would break every test that takes ``spark`` as a
    parameter.

    Returns:
        SparkSession: the already-running session when one exists (as on a
        Databricks cluster), otherwise a newly created one.
    """
    # Function-scope import keeps the fixture self-contained.
    from pyspark.sql import SparkSession

    return SparkSession.builder.getOrCreate()

# --------------------------
# Test log_step()
# --------------------------
def test_log_step():
    """Check that a log-record Row becomes a DataFrame with the expected fields.

    Asserts that the ``step`` column exists and that the single row carries
    ``status == "INFO"``.

    NOTE(review): the record is built by hand and ``log_step()`` itself is not
    invoked; other ``test_log_step`` definitions also exist in this file, and
    the last one executed replaces the earlier ones.
    """
    # The name `spark` is bound to a pytest fixture function in this file, so
    # the session is resolved explicitly instead of read from the global.
    from pyspark.sql import SparkSession

    session = SparkSession.builder.getOrCreate()

    row = Row(
        job_id="test123",
        timestamp="2025-08-25T12:00:00",
        step="TEST_STEP",
        status="INFO",
        message="Testing log"
    )
    df = session.createDataFrame([row])
    display(df)

    assert "step" in df.columns
    assert df.collect()[0]["status"] == "INFO"

# --------------------------
# Test numeric casting
# --------------------------
def test_numeric_casting():
    """Check that "?" placeholders become NULL and numeric strings become doubles.

    Two sample rows are cast: ``"1.23"`` must come back as ``1.23`` and the
    ``"?"`` marker must come back as ``None``.
    """
    # The name `spark` is bound to a pytest fixture function in this file, so
    # the session is resolved explicitly instead of read from the global.
    from pyspark.sql import SparkSession

    session = SparkSession.builder.getOrCreate()

    sample_data = [
        Row(Date="1/1/2007", Time="00:00:00", Global_active_power="1.23"),
        Row(Date="1/1/2007", Time="00:01:00", Global_active_power="?"),
    ]
    df = session.createDataFrame(sample_data)

    # Replace the "?" marker with NULL first, then cast what remains to double.
    df_casted = df.withColumn(
        "Global_active_power",
        F.when(F.col("Global_active_power") == "?", None)
         .otherwise(F.col("Global_active_power"))
         .cast(DoubleType())
    )

    display(df_casted)

    values = [row["Global_active_power"] for row in df_casted.collect()]
    assert values[0] == 1.23
    assert values[1] is None

# --------------------------
# Test timestamp transformation
# --------------------------
def test_timestamp_creation():
    """Check that separate Date and Time strings combine into a timestamp.

    The two columns are joined with a space and parsed with the pattern
    ``d/M/yyyy H:m:s``. The first row must parse to a non-null value and the
    second must keep its time of day, ``12:30:45``.
    """
    # The name `spark` is bound to a pytest fixture function in this file, so
    # the session is resolved explicitly instead of read from the global.
    from pyspark.sql import SparkSession

    session = SparkSession.builder.getOrCreate()

    sample_data = [
        Row(Date="1/1/2007", Time="00:00:00"),
        Row(Date="1/1/2007", Time="12:30:45"),
    ]
    df = session.createDataFrame(sample_data)

    df_ts = df.withColumn(
        "timestamp",
        F.to_timestamp(F.concat_ws(" ", F.col("Date"), F.col("Time")), "d/M/yyyy H:m:s")
    )

    display(df_ts)

    timestamps = [row["timestamp"] for row in df_ts.collect()]
    assert timestamps[0] is not None
    assert str(timestamps[1]).endswith("12:30:45")

# --------------------------
# Test deduplication
# --------------------------
def test_deduplication():
    """Check that rows sharing a timestamp collapse to one row.

    Three sample rows contain two distinct timestamps, so two rows must remain
    after ``dropDuplicates(["timestamp"])``.
    """
    # The name `spark` is bound to a pytest fixture function in this file, so
    # the session is resolved explicitly instead of read from the global.
    from pyspark.sql import SparkSession

    session = SparkSession.builder.getOrCreate()

    sample_data = [
        Row(timestamp="2025-08-25 00:00:00", value=10),
        Row(timestamp="2025-08-25 00:00:00", value=10),
        Row(timestamp="2025-08-25 00:01:00", value=20),
    ]
    df = session.createDataFrame(sample_data)

    df_dedup = df.dropDuplicates(["timestamp"])
    display(df_dedup)

    count = df_dedup.count()
    assert count == 2

In [0]:
# At the top of your test notebook, add:
# %run /path/to/db_curateddata

import pytest
from pyspark.sql import Row
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Do not use `import db_curateddata as job`; the notebook is pulled in with %run instead.

# --------------------------
# Test log_step()
# --------------------------
def test_log_step():
    """Check that a curated-stage log Row becomes a DataFrame with the expected fields.

    Asserts that the ``step`` column exists and that the single row carries
    ``status == "INFO"``.

    NOTE(review): the record is built by hand and ``log_step()`` itself is not
    invoked; other ``test_log_step`` definitions also exist in this file, and
    the last one executed replaces the earlier ones.
    """
    # The name `spark` is bound to a pytest fixture function in this file, so
    # the session is resolved explicitly instead of read from the global.
    from pyspark.sql import SparkSession

    session = SparkSession.builder.getOrCreate()

    row = Row(
        job_id="pytest123",
        timestamp="2025-08-25T12:00:00",
        step="CURATED_TEST",
        status="INFO",
        message="Testing curated log"
    )
    df = session.createDataFrame([row])
    display(df)

    assert "step" in df.columns
    assert df.collect()[0]["status"] == "INFO"

# --------------------------
# Test daily aggregation
# --------------------------
def test_daily_aggregation():
    """Check the per-day roll-up of power, voltage and sub-metering readings.

    Two readings on the same date are grouped by day. The summed active power
    must round to 2.0 and the summed ``Sub_metering_1`` must equal 25.
    """
    # The name `spark` is bound to a pytest fixture function in this file, so
    # the session is resolved explicitly instead of read from the global.
    from pyspark.sql import SparkSession

    session = SparkSession.builder.getOrCreate()

    sample_data = [
        Row(timestamp="2025-08-25 00:00:00", Global_active_power=1.2, Voltage=220.0,
            Sub_metering_1=10, Sub_metering_2=20, Sub_metering_3=5),
        Row(timestamp="2025-08-25 01:00:00", Global_active_power=0.8, Voltage=221.0,
            Sub_metering_1=15, Sub_metering_2=25, Sub_metering_3=10),
    ]
    df = session.createDataFrame(sample_data)

    daily_agg_df = df.groupBy(
        F.to_date("timestamp").alias("date")
    ).agg(
        F.sum("Global_active_power").alias("total_active_power_kw"),
        F.avg("Voltage").alias("avg_voltage"),
        F.sum("Sub_metering_1").alias("total_sub_metering_1_wh"),
        F.sum("Sub_metering_2").alias("total_sub_metering_2_wh"),
        F.sum("Sub_metering_3").alias("total_sub_metering_3_wh"),
    )

    display(daily_agg_df)

    # Both sample rows fall on one date, so the result has a single row.
    result = daily_agg_df.collect()[0]
    # Round before comparing: 1.2 + 0.8 is a floating-point sum.
    assert round(result["total_active_power_kw"], 1) == 2.0
    assert result["total_sub_metering_1_wh"] == 25

# --------------------------
# Test peak hour calculation
# --------------------------
def test_peak_hour():
    """Check that the hour with the highest summed power is picked as the peak.

    Three hourly readings on one date are aggregated per (date, hour) and
    ranked by descending power within the date. Hour 1 holds the largest
    value (5.0), so it must be reported as the peak consumption hour.
    """
    # The name `spark` is bound to a pytest fixture function in this file, so
    # the session is resolved explicitly instead of read from the global.
    from pyspark.sql import SparkSession

    session = SparkSession.builder.getOrCreate()

    sample_data = [
        Row(timestamp="2025-08-25 00:00:00", Global_active_power=1.2),
        Row(timestamp="2025-08-25 01:00:00", Global_active_power=5.0),
        Row(timestamp="2025-08-25 02:00:00", Global_active_power=2.0),
    ]
    df = session.createDataFrame(sample_data)

    hourly_consumption = df.groupBy(
        F.to_date("timestamp").alias("date"),
        F.hour("timestamp").alias("hour")
    ).agg(
        F.sum("Global_active_power").alias("hourly_active_power_kw")
    )

    # Rank hours within each date, highest consumption first; rank 1 is the peak.
    window_spec = Window.partitionBy("date").orderBy(F.col("hourly_active_power_kw").desc())
    peak_hour_df = hourly_consumption.withColumn(
        "rank", F.row_number().over(window_spec)
    ).filter(
        F.col("rank") == 1
    ).select(
        "date", F.col("hour").alias("peak_consumption_hour")
    )

    display(peak_hour_df)

    peak_hour = peak_hour_df.collect()[0]["peak_consumption_hour"]
    assert peak_hour == 1  # Hour 01:00 has highest power

# --------------------------
# Test final gold dataframe enrichment
# --------------------------
def test_final_gold_features():
    """Check the calendar features added to a daily aggregate row.

    A single aggregate row dated 2025-08-25 is enriched with ``day_name``,
    ``day_of_week`` and ``month``. The test asserts that ``day_of_week`` is
    present and that ``month`` is 8.
    """
    # The name `spark` is bound to a pytest fixture function in this file, so
    # the session is resolved explicitly instead of read from the global.
    from pyspark.sql import SparkSession

    session = SparkSession.builder.getOrCreate()

    sample_data = [
        Row(date="2025-08-25", total_active_power_kw=10.5, avg_voltage=220.5,
            total_sub_metering_1_wh=25, total_sub_metering_2_wh=30,
            total_sub_metering_3_wh=40, peak_consumption_hour=12)
    ]
    df = session.createDataFrame(sample_data)

    enriched_df = df.withColumn(
        "day_name", F.date_format("date", "EEEE")
    ).withColumn(
        "day_of_week", F.dayofweek("date")
    ).withColumn(
        "month", F.month("date")
    )

    display(enriched_df)

    row = enriched_df.collect()[0]
    assert "day_of_week" in row.asDict()
    assert row["month"] == 8