In [None]:
# Restart this notebook's Python process via the Databricks SDK's dbutils so the
# remaining cells run on a fresh interpreter (presumably to pick up newly
# installed libraries or edited `utils` modules — confirm).
# Anything defined before the restart is discarded, so this must stay the first cell.
from databricks.sdk import WorkspaceClient
w = WorkspaceClient()

w.dbutils.library.restartPython()

In [None]:
from databricks.connect import DatabricksSession
from pathlib import Path
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window as W
from datetime import datetime, date
import pandas as pd
import polars as pl
import sys

# Databricks Connect session used by every cell below.
spark = DatabricksSession.builder.getOrCreate()

# Put the repository root (the parent of the notebook's working directory) on
# the import path so the local `utils` package can be imported in later cells.
root_dir = Path.cwd().parent
if not any(entry == str(root_dir) for entry in sys.path):
    sys.path.append(str(root_dir))

# Static PPA contract attributes: read with pandas (the frame is also passed to
# extract_rez_cf later) and promoted to a Spark DataFrame.
ppa_details_path = root_dir.joinpath("utils", "ppa_details.csv")
ppa_details_df = pd.read_csv(ppa_details_path)
ppa_details = spark.createDataFrame(ppa_details_df)

In [None]:
from utils.extract import extract_deal_info

# SIM deal/product attributes (a Spark DataFrame) shared by the retail, PPA and
# wholesale earnings cells, which join it on deal_id / product_id and filter it
# on columns such as book, counterparty and optionality.
# NOTE(review): only the retail cell re-parses product_start_date /
# product_end_date with the "d/M/y" format — confirm their type as returned here.
sim_complete_deal_info = extract_deal_info()

# Aurora prices

In [None]:
# Forward price traces for the two scenarios modelled in this notebook, renamed
# to the column names used throughout. `regionid` is the source region with a
# "1" suffix (presumably "SA" -> "SA1", matching jurisdiction_region_id below).
# Cached because every earnings section joins against it.
scenario_prices = (
    spark.table("exploration.denise_ng.latest_forward_price_trace")
    .filter(F.col("Scenario").isin(["Central", "Messy"]))
    .select(
        F.col("Scenario").alias("scenario"),
        # concat is a string function: pass a string literal rather than relying
        # on an implicit int -> string cast of F.lit(1).
        F.concat(F.col("region"), F.lit("1")).alias("regionid"),
        F.col("DateTime").alias("interval_date_time"),
        F.col("Interval_date").alias("interval_date"),
        F.col("Nominal_$").alias("rrp"),
    )
).cache()

# Retail volume

In [None]:
# NEM region id for each retail jurisdiction (ACT loads are mapped to NSW1).
jurisdiction_region_id = spark.createDataFrame(
    pd.DataFrame(
        [
            ("SA", "SA1"),
            ("VIC", "VIC1"),
            ("NSW", "NSW1"),
            ("ACT", "NSW1"),
            ("QLD", "QLD1"),
        ],
        columns=["jurisdiction", "regionid"],
    )
)

# Retail load: the FY17 backcast is replayed onto every forward price interval
# by matching (calendar month, day of month, intra-day period). Quantities are
# converted to MWh by dividing by 1,000 (presumably kWh -> MWh), and each row is
# tagged with the region of its jurisdiction.
# NOTE(review): if FY17 has no 29 February, forward leap-day intervals find no
# backcast row and silently get no retail volume — confirm.
retail_load_volume = (
    spark.table("exploration.khai_chang.stlf_backcast_FY17")
    .withColumns(
        {
            "month_id": F.month("interval_datetime"),
            "day_id": F.dayofmonth("interval_datetime"),
            "period_id": (
                F.hour("interval_datetime") * 2
                + F.minute("interval_datetime") / 30
            ),
        }
    )
    .join(
        scenario_prices.select("interval_date_time")
        .distinct()
        .select(
            "interval_date_time",
            F.month("interval_date_time").alias("month_id"),
            F.dayofmonth("interval_date_time").alias("day_id"),
            (
                F.hour("interval_date_time") * 2
                + F.minute("interval_date_time") / 30
            ).alias("period_id"),
        ),
        on=["month_id", "day_id", "period_id"],
    )
    .select(
        "interval_date_time",
        "product_id",
        "jurisdiction",
        (F.col("quantity") / 1_000).alias("volume_mwh"),
    )
    .join(jurisdiction_region_id, on="jurisdiction")
    .alias("retail_volume")
)


# PPA volumes

In [None]:
from utils.aurora_functions import extract_rez_cf, prep_datetime_columns

# Capacity-factor inputs for PPA volumes:
#   rez_cf  - solar/wind traces per REZ for the PPAs in ppa_details_df; joined
#             later on (month_id, day_id, period_id, rez, fuel) and used as the
#             fallback CF.
#   duid_cf - per-DUID traces, presumably given month_id/day_id/period_id
#             columns by prep_datetime_columns(duid=True) (the keys joined on
#             later); takes precedence over rez_cf where present.
rez_cf = extract_rez_cf(["Solar", "Wind"], ppa_details_df)
duid_cf = prep_datetime_columns(spark.table("dev.silver.aurora_duid_cf"), duid=True)

In [None]:
# Expand each PPA product to one row per forward-price interval within its term.
# Then attach capacity factors to get the contracted volume per interval.
#
# The result gets a new name instead of overwriting `ppa_details`, so this cell
# never depends on its own previous output when re-run.
ppa_intervals = (
    ppa_details.select(
        "deal_id",
        "product_id",
        "duid",
        "rez",
        "fuel",
        "size_mw",
        "quantity_factor",
        "floor",
        "turndown",
        F.to_date("product_start_date", format="d/M/y").alias("product_start_date"),
        F.to_date("product_end_date", format="d/M/y").alias("product_end_date"),
    )
    .join(
        scenario_prices.select("interval_date_time", "interval_date").distinct(),
        # Compare the interval's *date* with the inclusive term dates, as the
        # wholesale section does. Comparing the timestamp directly casts
        # product_end_date to midnight, which dropped every interval after
        # 00:00 on the last day of the term.
        F.col("interval_date").between(
            F.col("product_start_date"), F.col("product_end_date")
        ),
    )
    # Hand prep_datetime_columns exactly the columns it received before.
    .drop("interval_date")
)

ppa_volumes = (
    prep_datetime_columns(ppa_intervals)
    .join(duid_cf, ["month_id", "day_id", "period_id", "duid"], "left")
    .join(rez_cf, ["month_id", "day_id", "period_id", "rez", "fuel"], "left")
    .withColumn(
        "volume_mwh",
        # MW * quantity factor * capacity factor over a half-hour interval (/2).
        # The DUID-specific CF wins; the REZ-level CF is the fallback.
        F.col("size_mw")
        * F.col("quantity_factor")
        * F.coalesce(F.col("CF"), F.col("rez_cf"))
        / F.lit(2),
    )
    .select(
        "interval_date_time", "deal_id", "product_id", "floor", "turndown", "volume_mwh"
    )
)
ppa_volumes

In [None]:
# SIM rate profiles (code "R", interval == 5) as effective-dated ranges.
# Each row is valid from its fromdate until the day before the next fromdate for
# the same product/code/daytype/period; the latest row is open-ended.
sim_rate = (
    spark.table("prod.bronze.etrm_sim_profiles")
    .filter(
        (F.col("code") == "R") &
        (F.col("interval") == 5)
    )
    .withColumn(
        "todate",
        F.lead(
            F.date_add(F.col("fromdate"), -1),
            1,
            # Open-ended sentinel, matching the 9999-12-31 used for deal_factors
            # (this previously read date(9999, 12, 13), a day/month typo).
            date(9999, 12, 31)
        ).over(
            W.partitionBy("productid", "code", "daytype", "period")
            .orderBy("fromdate")
        )
    )
    .select(
        F.col("productid").alias("product_id"),
        F.col("fromdate").alias("from_date"),
        F.col("todate").alias("to_date"),
        F.col("daytype").alias("day_type"),
        F.col("period").alias("period_id"),
        F.col("value").alias("rate")
    )
)

In [None]:
# Interval calendar per jurisdiction, classified into SIM day types:
#   "S" - a public holiday in that jurisdiction, or day_of_week 7
#   "6" - day_of_week 6
#   "W" - any other day
calendar = (
    spark.table("prod.silver.calendar_datetime")
    .crossJoin(jurisdiction_region_id.select("jurisdiction").distinct())
    .join(
        spark.table("prod.silver.calendar_public_holiday").select(
            "date", "jurisdiction", "holiday_name"
        ),
        on=["date", "jurisdiction"],
        how="left",
    )
    .select(
        F.col("date").alias("interval_date"),
        "jurisdiction",
        F.col("datetime").alias("interval_date_time"),
        F.col("period").alias("period_id"),
        "day_of_week",
        F.col("holiday_name").isNotNull().alias("public_holiday"),
        F.when(
            F.col("holiday_name").isNotNull() | (F.col("day_of_week") == 7), "S"
        )
        .when(F.col("day_of_week") == 6, "6")
        .otherwise("W")
        .alias("day_type"),
    )
    .distinct()
)

# Attach the effective SIM rate to every calendar interval. The rate row must
# match the interval's day type and period, and its [from_date, to_date] range
# must cover the interval date.
rate_calendar = (
    calendar.alias("calendar")
    .join(
        sim_rate.alias("rate"),
        (F.col("rate.day_type") == F.col("calendar.day_type"))
        & (F.col("rate.period_id") == F.col("calendar.period_id"))
        & F.col("calendar.interval_date").between(
            F.col("rate.from_date"), F.col("rate.to_date")
        ),
    )
    .select(
        "product_id",
        "jurisdiction",
        "interval_date_time",
        "interval_date",
        "calendar.period_id",
        "rate",
    )
).cache()

# Retail earnings

In [None]:
# Retail earnings per scenario and interval:
#   income_amount = load volume * SIM rate
#   cost_amount   = load volume * scenario spot price (rrp)
retail_earnings = (
    sim_complete_deal_info
    .withColumn("product_start_date", F.to_date("product_start_date", format="d/M/y"))
    .withColumn("product_end_date", F.to_date("product_end_date", format="d/M/y"))
    .alias("deal_info")
    .join(
        retail_load_volume,
        "product_id"
    )
    .join(
        scenario_prices
        .alias("scenario_prices"),
        (
            (F.col("scenario_prices.regionid") == F.col("retail_volume.regionid")) &
            (F.col("scenario_prices.interval_date_time") == F.col("retail_volume.interval_date_time")) &
            # Filter on the interval *date* so the whole last day of the product
            # term is included. Comparing the timestamp with product_end_date
            # casts the date to midnight and drops every later interval that day.
            F.col("scenario_prices.interval_date").between(F.col("product_start_date"), F.col("product_end_date"))
        )
    )
    .join(
        rate_calendar,
        ["product_id", "jurisdiction", "interval_date_time"]
    )
    .select(
        F.col("scenario_prices.scenario"),
        F.col("retail_volume.product_id"),
        F.col("product_name"),
        F.col("deal_info.deal_id"),
        F.col("deal_info.deal_name"),
        F.col("deal_info.status"),
        F.col("deal_info.deal_date"),
        F.col("deal_info.strategy"),
        F.col("deal_info.region").alias("deal_region"),
        F.col("retail_volume.regionid").alias("load_regionid"),
        F.col("scenario_prices.interval_date"),
        F.col("period_id"),
        F.col("buy_sell"),
        F.col("retail_volume.interval_date_time"),
        F.col("retail_volume.volume_mwh"),
        F.col("rate"),
        F.col("scenario_prices.rrp")
    )
    .withColumn("income_amount", F.col("volume_mwh") * F.col("rate"))
    .withColumn("cost_amount", F.col("volume_mwh") * F.col("rrp"))
    .withColumn("earnings", F.col("income_amount") - F.col("cost_amount"))
).cache()


# PPA earnings

In [None]:
# SIM deal factors as effective-dated ranges. Each (deal, type) row applies from
# its fromdate until the day before the next fromdate; the latest row is
# open-ended (to_date 9999-12-31). Type "R" rows are used as escalation factors
# and type "M" rows as tolling fees in the earnings cells.
deal_factors = (
    spark.table("prod.bronze.etrm_sim_deal_factors")
    .withColumn(
        "to_date",
        F.lead(F.date_sub(F.col("fromdate"), 1), 1, "9999-12-31").over(
            W.partitionBy("dealid", "type").orderBy("fromdate")
        ),
    )
    .select(
        F.col("dealid").alias("deal_id"),
        "type",
        F.col("fromdate").cast("date").alias("from_date"),
        "to_date",
        "factor",
    )
    .alias("deal_factors")
)


In [None]:
# PPA earnings per scenario and interval.
#   income_amount: contracted volume * scenario spot price (rrp), set to 0 when
#                  the product has a floor and rrp is below it.
#   cost_amount:   the per-interval tolling fee when one applies (type "M"
#                  factor), otherwise volume * escalated PPA price.
# NOTE(review): when rrp < floor only income is zeroed; volume_mwh and the
#   volume-based cost are unchanged, and the `turndown` column carried through
#   ppa_volumes is not used here — confirm this is the intended floor treatment.
# NOTE(review): buy_sell is carried through but not applied to the amounts.
ppa_earnings = (
    sim_complete_deal_info.alias("deal_info")
    .join(
        ppa_volumes,
        ["deal_id", "product_id"]
    )
    .join(
        scenario_prices,
        ["interval_date_time", "regionid"]
    )
    # Escalation ("R") factor in force on the interval date; intervals without
    # one get a factor of 1 via the fillna below.
    .join(
        deal_factors
        .filter(F.col("type") == "R")
        .withColumnRenamed("factor", "escalation_factor")
        .alias("escalation_factors"),
        (F.col("deal_info.deal_id") == F.col("escalation_factors.deal_id")) &
        F.col("interval_date").between(
            F.col("escalation_factors.from_date"),
            F.col("escalation_factors.to_date")
        ),
        "left"
    )
    # Tolling fee ("M" factor) spread evenly over its effective period:
    # factor / (days in the period, inclusive) / 48 intervals per day.
    # NOTE(review): for the latest, open-ended period (to_date 9999-12-31)
    #   days_in_period is in the millions, so the per-interval fee is
    #   effectively zero — confirm how open-ended fees should be spread.
    # NOTE(review): the fee is matched on deal_id only, so a deal with several
    #   PPA products is charged the fee once per product row — confirm.
    .join(
        deal_factors
        .filter(F.col("type") == "M")
        .withColumn(
            "days_in_period",
            F.date_diff(F.col("to_date"), F.col("from_date")) + 1
        )
        .select(
            "deal_id",
            "from_date",
            "to_date",
            (F.col("factor") / F.col("days_in_period") / 48).alias("tolling_fee")
        )
        .alias("tolling_fees"),
        (F.col("deal_info.deal_id") == F.col("tolling_fees.deal_id")) &
        F.col("interval_date").between(
            F.col("tolling_fees.from_date"),
            F.col("tolling_fees.to_date")
        ),
        "left"
    )
    .fillna({"escalation_factor": 1})
    .select(
        "scenario",
        "product_id",
        "product_name",
        "deal_info.deal_id",
        "deal_name",
        "status",
        "deal_date",
        "strategy",
        "regionid",
        "interval_date_time",
        "interval_date",
        "buy_sell",
        "rrp",
        "floor",
        "tolling_fee",
        (F.col("price") * F.col("escalation_factor")).alias("ppa_price"),
        "volume_mwh"
    )
    .withColumn(
        "income_amount",
        F.when(
            (F.col("floor").isNotNull()) & (F.col("rrp") < F.col("floor")), 
            F.lit(0)
        )
        .otherwise(
            F.col("volume_mwh") * F.col("rrp")
        )
    )
    # A tolling fee, when present, replaces the volume-based PPA cost entirely.
    .withColumn(
        "cost_amount",
        F.when(
            F.col("tolling_fee").isNotNull(), 
            F.col("tolling_fee")
        )
        .otherwise(
            F.col("ppa_price") * F.col("volume_mwh")
        )
    )
    .withColumn("earnings", F.col("income_amount") - F.col("cost_amount"))
).cache()


# Wholesale volume and earnings

In [None]:
from utils.aurora_functions import calculate_amount

# SIM day-type profiles for shaped wholesale products, collapsed to one row per
# (interval date, period, product, schedule). Each row keeps the latest interval
# timestamp and the mean MW / strike across duplicates. interval_date is only a
# grouping key and is dropped afterwards.
wholesale_profiles = (
    spark.table("exploration.denise_ng.sim_daytypes_profile")
    .select(
        *[
            F.col(source).alias(target)
            for source, target in [
                ("PRODUCT_ID", "product_id"),
                ("SCHEDULE", "schedule"),
                ("PERIOD_30MIN", "period_id"),
                ("INTERVAL_DATE_TIME", "interval_date_time"),
                ("INTERVAL_DATE", "interval_date"),
                ("PROFILE_MW", "profile_mw"),
                ("PROFILE_STRIKE", "profile_strike"),
            ]
        ]
    )
    .groupBy("interval_date", "period_id", "product_id", "schedule")
    .agg(
        F.max("interval_date_time").alias("interval_date_time"),
        F.mean("profile_mw").alias("profile_mw"),
        F.mean("profile_strike").alias("profile_strike"),
    )
    .drop("interval_date")
)

# Wholesale hedge earnings per scenario and interval.
# Steps: filter the in-scope deals, expand them over the price intervals in each
# product's term, apply escalation factors and (for shaped products) the SIM
# MW/strike profile, then let calculate_amount derive the cost/income columns.
wholesale_earnings = (
    sim_complete_deal_info.alias("deal_info")
    # Keep firm Swap/Cap deals and flat-schedule Floors that have not been
    # transferred. Exclude the Novated / Counterparty Default books, Zen / Sunshot
    # counterparties and ACCU products.
    # NOTE(review): `~isin` / `~ilike` / `~like` evaluate to NULL when book,
    #   counterparty or product_name is NULL, so such deals are dropped too —
    #   confirm that is intended.
    # The status filter below is commented out, so Cancelled / Hypothetical deals
    # are currently kept.
    .filter(
        F.col("transfer_date").isNull()
        & ~F.col("book").isin("Novated", "Counterparty Default")
        &
        # ~F.col("status").isin("Cancelled", "Hypothetical") &
        (
            F.col("instrument_calculation").isin("Swap", "Cap")
            | (
                (F.col("instrument_calculation") == "Floor")
                & (F.col("product_schedule") == "Flat")
            )
        )
        & ~(
            F.col("counterparty").ilike("zen %")
            | F.col("counterparty").ilike("sunshot%")
        )
        & ~F.col("product_name").like("%ACCU%")
        & (F.col("optionality") == "Firm")
    )
    .join(scenario_prices, "regionid")
    # Product term filter on the interval date (inclusive at both ends).
    # NOTE(review): product_start_date / product_end_date are used as returned by
    #   extract_deal_info, whereas the retail cell first applies
    #   to_date(..., "d/M/y") — confirm they are already dates here.
    .filter(
        F.col("interval_date").between(
            F.col("product_start_date"), F.col("product_end_date")
        )
    )
    # Escalation ("R") factor in force on the interval date (1 when absent,
    # via the fillna below).
    .join(
        deal_factors.filter(F.col("type") == "R")
        .withColumnRenamed("factor", "escalation_factor")
        .alias("escalation_factors"),
        (F.col("deal_info.deal_id") == F.col("escalation_factors.deal_id"))
        & F.col("interval_date").between(
            F.col("escalation_factors.from_date"), F.col("escalation_factors.to_date")
        ),
        "left",
    )
    # Shaped products pick up a per-interval MW shape and strike. Products with
    # no profile row are treated as flat.
    # NOTE(review): wholesale_profiles is keyed by schedule as well, so a product
    #   with profile rows for several schedules would fan out here — confirm.
    .join(
        wholesale_profiles.alias("profiles"),
        ["product_id", "interval_date_time"],
        "left",
    )
    .withColumn("shaped_product", F.col("profiles.interval_date_time").isNotNull())
    .withColumn(
        "strike_price",
        F.when(F.col("shaped_product"), F.col("profile_strike")).otherwise(
            F.col("strike_price")
        ),
    )
    # Flat products get profile_mw = 1; missing escalation factors become 1.
    .fillna(1, ["profile_mw", "escalation_factor"])
    # quantity (presumably MW) * shape over a half-hour interval -> MWh.
    .withColumn("volume_mwh", F.col("quantity") * F.col("profile_mw") / 2)
    .select(
        "scenario",
        "product_id",
        "deal_info.deal_id",
        "deal_name",
        "product_name",
        "status",
        "deal_date",
        "strategy",
        "regionid",
        "interval_date_time",
        "interval_date",
        "shaped_product",
        "buy_sell",
        (F.col("price") * F.col("escalation_factor")).alias("contract_price"),
        "instrument_calculation",
        "rrp",
        "volume_mwh",
        "strike_price",
    )
)

# calculate_amount (utils.aurora_functions) adds the named amount column. It
# presumably uses buy_sell / instrument_calculation / contract_price /
# strike_price / rrp / volume_mwh — confirm.
wholesale_earnings = calculate_amount(wholesale_earnings, amount_col_name="cost_amount")
wholesale_earnings = calculate_amount(wholesale_earnings, amount_col_name="income_amount")
wholesale_earnings = wholesale_earnings.withColumn("earnings", F.col("income_amount") - F.col("cost_amount"))

# Templers BESS earnings

In [None]:
from datetime import date

# Templers BESS energy contract, built from the BESS dispatch curves.
# COST_CONTRACT_ENERGY in the source table is rescaled from the x/21 basis
# (presumably a $21/MWh contract price — confirm) to the price in force on each
# interval date:
#   before pre_refi_cutoff                -> $27/MWh
#   [post_refi_start, escalation_start)   -> $23.5/MWh
#   on/after escalation_start             -> $22.5/MWh escalated 2.5% per year

# Reference dates
pre_refi_cutoff = date(2026, 1, 1)
post_refi_start = date(2026, 1, 1)
escalation_start = date(2031, 1, 1)

pre_refi_product_name = "Pre-Refi $27/MWh"
post_refi_product_name = "Post-Refi $23.5/MWh"
post_refi_esc_product_name = "Post-Refi $22.5/MWh + 2.5%"

# Post-refi, pre-escalation window, half-open on the right. Column.between is
# inclusive at both ends, so it priced escalation_start itself (2031-01-01) at
# the un-escalated rate and gave it the wrong product name.
in_post_refi_flat_period = (F.col("interval_date") >= F.lit(post_refi_start)) & (
    F.col("interval_date") < F.lit(escalation_start)
)


templers_earnings = (
    spark.table("exploration.denise_ng.bess_allcurves_20250314")
    .withColumnRenamed("settlementdate", "interval_date_time")
    .filter(F.col("scenario").isin(["central", "messy"]))
    .join(
        scenario_prices.select("interval_date_time", "interval_date").distinct(),
        "interval_date_time",
    )
    .withColumn("scenario", F.initcap(F.col("scenario")))
    .withColumn(
        "cost_adjustment_ratio",
        F.when(F.col("interval_date") < F.lit(pre_refi_cutoff), 27.0 / 21.0)
        .when(in_post_refi_flat_period, 23.5 / 21.0)
        .otherwise(
            # Escalation compounds per calendar year from escalation_start's year.
            (22.5 * (1.025 ** (F.year("interval_date") - escalation_start.year)))
            / 21.0
        ),
    )
    .withColumn(
        "COST_CONTRACT_ENERGY",
        F.col("COST_CONTRACT_ENERGY") * F.col("cost_adjustment_ratio"),
    )
    .withColumn(
        "product_name",
        F.when(
            F.col("interval_date") < F.lit(pre_refi_cutoff),
            F.lit(pre_refi_product_name),
        )
        .when(in_post_refi_flat_period, F.lit(post_refi_product_name))
        .otherwise(F.lit(post_refi_esc_product_name)),
    )
    .select(
        F.col("scenario"),
        F.lit("BESS").alias("group"),
        F.lit(888_888).alias("product_id"),
        F.col("product_name"),
        F.lit(888_888).alias("deal_id"),
        F.lit("Templers BESS Energy").alias("deal_name"),
        F.lit("Confirmed").alias("status"),
        F.lit(date(2025, 3, 13)).alias("deal_date"),
        F.lit("Value Solar").alias("strategy"),
        F.lit("SA1").alias("regionid"),
        F.col("interval_date_time"),
        F.col("interval_date"),
        F.lit("Buy").alias("buy_sell"),
        # Net metered energy (generation minus load).
        (F.col("GEN_METERED") - F.col("LOAD_METERED")).alias("volume_mwh"),
        F.col("rrp"),
        F.col("COST_CONTRACT_ENERGY").alias("cost_amount"),
        (F.col("REVENUE_ENERGY_METERED") - F.col("COST_ENERGY_METERED")).alias(
            "income_amount"
        ),
        (
            F.col("REVENUE_ENERGY_METERED")
            - F.col("COST_ENERGY_METERED")
            - F.col("COST_CONTRACT_ENERGY")
        ).alias("earnings"),
    )
)

display(templers_earnings)

# Output

In [None]:
# Stack retail, PPA, wholesale and Templers earnings into one long table with a
# common schema (unionByName matches on column names), then overwrite the output
# table.
def _earnings_output_columns(group, deal_id_col, regionid_col):
    """Return the shared output column list for one earnings source.

    group: literal written to the "group" column.
    deal_id_col, regionid_col: expressions for those two columns, because the
        sources expose them under different (qualified or aliased) names.
    """
    return [
        "scenario",
        F.lit(group).alias("group"),
        "product_id",
        "product_name",
        deal_id_col,
        "deal_name",
        "status",
        "deal_date",
        "strategy",
        regionid_col,
        "interval_date_time",
        "interval_date",
        "buy_sell",
        F.col("volume_mwh"),
        "rrp",
        "cost_amount",
        "income_amount",
        "earnings",
    ]


all_earnings = (
    retail_earnings.select(
        *_earnings_output_columns(
            "retail", "deal_id", F.col("load_regionid").alias("regionid")
        )
    )
    .unionByName(
        ppa_earnings.select(
            *_earnings_output_columns("ppa", "deal_info.deal_id", "regionid")
        )
    )
    .unionByName(
        wholesale_earnings.select(
            *_earnings_output_columns("wholesale", "deal_info.deal_id", "regionid")
        )
    )
    .unionByName(templers_earnings)
)

all_earnings.write.mode("overwrite").saveAsTable(
    "exploration.victor_goh.earnings_aurora_scenarios"
)
