In [0]:
%pip install polars pyomo

In [0]:
# Restart the Python process through the Databricks utilities.
# Presumably needed so the packages installed by the preceding %pip cell are importable — confirm.
dbutils.library.restartPython()

In [0]:
from src.storage_model import StorageModel

In [0]:
# import polars as pl
# import plotly.express as px
# from datetime import date, timedelta
from pyspark.sql import functions as F, Window as W, types as T, DataFrame

from src import data, spark
# from src.mtm.scenario_price import generate_mtm_scenario_prices
# from src.earnings import (
#     update_price_model_deal_profiles,
#     update_deal_settlement_details_table,
#     update_retail_rate_calendar_table,
#     settlement_calculations as calcs
# )
from src.earnings.deal_info import deal_settlement_details

# model_id = "20250617"

In [0]:
# SA1 trading prices for 2025-06-10 .. 2025-06-16, collected into a pandas frame.
# A settlement timestamp whose hour and minute are both zero is attributed to the
# previous calendar day before the date window is applied.
settlement_day = F.to_date("settlementdate")
stamped_at_midnight = (F.hour("settlementdate") == 0) & (F.minute("settlementdate") == 0)

spot_price = (
    spark.table("prod.silver_mms.tradingprice")
    .withColumn(
        "interval_date",
        F.when(stamped_at_midnight, F.date_sub(settlement_day, 1)).otherwise(settlement_day),
    )
    .filter(
        (F.col("regionid") == "SA1")
        & F.col("interval_date").between("2025-06-10", "2025-06-16")
    )
    .select(
        F.col("settlementdate").alias("interval_date_time"),
        F.col("rrp").cast("float"),
    )
    .toPandas()
)

In [0]:
import os

In [0]:
import shutil

# Check that the `cbc` executable is on PATH without spawning a shell.
# `exit_code` keeps the convention of the earlier shell probe: 0 when the executable
# is found, non-zero when it is missing.
exit_code = 0 if shutil.which("cbc") is not None else 1

In [0]:
# Build the storage model.
# NOTE(review): the three positional arguments are unlabelled here; their meaning and
# units are defined by StorageModel in src.storage_model — confirm, and consider
# passing them as keyword arguments.
model = StorageModel(111, 2.6, 0.81)

In [0]:
# Solve the storage model against the spot price frame.
result = model.solve(spot_price)

In [0]:
# Display the solve output.
result

In [0]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Dual-axis figure: charge, discharge and storage level share the left axis,
# the price series sits on the right axis.
fig = make_subplots(specs=[[{"secondary_y": True}]])

# (result column, legend label, line styling, drawn on the right-hand axis?)
trace_specs = [
    ("charge_mw", "Charge (MW)", dict(color="blue"), False),
    ("discharge_mw", "Discharge (MW)", dict(color="red"), False),
    ("effective_storage_level_mwh", "Storage Level (MWh)", dict(color="green"), False),
    ("rrp", "RRP", dict(color="orange", dash="dash"), True),
]
for column, label, line_style, on_secondary in trace_specs:
    fig.add_trace(
        go.Scatter(
            x=result["interval_date_time"],
            y=result[column],
            mode="lines",
            name=label,
            line=line_style,
        ),
        secondary_y=on_secondary,
    )

# Axis titles
fig.update_xaxes(title_text="Date/Time")
fig.update_yaxes(title_text="Power (MW) / Storage (MWh)", secondary_y=False)
fig.update_yaxes(title_text="RRP", secondary_y=True)

# Title, unified hover and a horizontal legend anchored above the plot on the right
horizontal_legend = dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
fig.update_layout(
    title="Energy Storage System Performance vs RRP",
    hovermode="x unified",
    legend=horizontal_legend,
)

fig.show()

In [0]:
import pyomo.environ as pyo

In [0]:
import pandas as pd
from functools import reduce

In [0]:
import plotly.express as px

In [0]:
# NOTE(review): pd.merge() is called with no arguments, but pandas requires the left and
# right frames, so this cell raises TypeError as written. The frames and columns meant
# to be merged and plotted are not visible here — fill them in or remove the cell.
px.line(
    pd.merge()
)

In [0]:
# Display the storage details exposed by data.sim.
data.sim.storage_details().display()

In [0]:
# Display the deal factors for deal 2996 only (hard-coded id used for inspection).
data.sim.deal_factors().filter(F.col("deal_id") == 2996).display()

In [0]:
import polars as pl

from pyspark.sql import functions as F
from src.earnings import settlement_calculations as calcs
from src import data, spark


# MTM_EARNINGS_TABLE = "exploration.earnings_forecast.daily_mtm_scenario_earnings"

# def calculate_earnings_forecast() -> None:
"""
Performs earnings calculations on the MtM earnings scenario and writes out the result.
"""
# --- Deal details ---
deal_settlement_details = pl.from_arrow(
    data.earnings.deal_settlement_details().toArrow()
)

# --- Price data ---
spot_prices = pl.from_arrow(
    data.earnings.daily_mtm_scenario_prices().toArrow()
)

# Expand LGC prices to calendar dates: each date is matched to the price period
# (period_start .. period_end) that contains it.
lgc_price_by_date = (
    data.market_price.lgc_prices()
    .join(
        data.date_calendar,
        F.col("date").between(F.col("period_start"), F.col("period_end")),
    )
    .select(
        F.col("date").alias("interval_date"),
        F.col("price").alias("lgc_price"),
    )
)
lgc_prices = pl.from_arrow(lgc_price_by_date.toArrow())

# --- Generation data ---
generation_profiles = pl.from_arrow(
    data.earnings.daily_mtm_scenario_generation_profiles().toArrow()
)

# Only deals that carry a turndown value take part in the turndown adjustment.
turndown_windows = (
    deal_settlement_details
    .filter(pl.col("turndown").is_not_null())
    .select("product_id", "region_number", "start_date", "end_date", "turndown")
)
# NOTE(review): apply_generation_turndown is not imported in the visible cells —
# confirm where it is defined.
generation_profiles = apply_generation_turndown(
    turndowns=turndown_windows,
    generation_profiles=generation_profiles,
    spot_prices=spot_prices,
)

# --- Retail data ---
load_profiles = pl.from_arrow(
    data.earnings.daily_mtm_scenario_load_profiles().toArrow()
)
rate_calendar = pl.from_arrow(
    data.earnings.retail_rate_calendar().toArrow()
)

# --- Wholesale product profiles ---
product_profiles = pl.from_arrow(
    data.earnings.daily_mtm_scenario_product_profiles().toArrow()
)

In [0]:
# Show deals whose product name contains "temp", ignoring case.
name_mentions_temp = F.col("product_name").ilike("%temp%")
data.sim.complete_deal_info().filter(name_mentions_temp).display()

In [0]:
def _as_polars(spark_df):
    """Convert a Spark DataFrame to a polars DataFrame through its Arrow table."""
    return pl.from_arrow(spark_df.toArrow())


# Deal terms and scenario spot prices.
deal_settlement_details = _as_polars(data.earnings.deal_settlement_details())
spot_prices = _as_polars(data.earnings.daily_mtm_scenario_prices())

# LGC price for every calendar date that falls inside a price period.
date_in_price_period = F.col("date").between(F.col("period_start"), F.col("period_end"))
lgc_prices = _as_polars(
    data.market_price.lgc_prices()
    .join(data.date_calendar, date_in_price_period)
    .select(
        F.col("date").alias("interval_date"),
        F.col("price").alias("lgc_price"),
    )
)

# Generation profiles, adjusted for the deals that specify a turndown.
generation_profiles = _as_polars(data.earnings.daily_mtm_scenario_generation_profiles())
deals_with_turndown = (
    deal_settlement_details
    .filter(pl.col("turndown").is_not_null())
    .select("product_id", "region_number", "start_date", "end_date", "turndown")
)
# NOTE(review): apply_generation_turndown is not imported in the visible cells —
# confirm where it is defined.
generation_profiles = apply_generation_turndown(
    turndowns=deals_with_turndown,
    generation_profiles=generation_profiles,
    spot_prices=spot_prices,
)

# Retail load and rate inputs.
load_profiles = _as_polars(data.earnings.daily_mtm_scenario_load_profiles())
rate_calendar = _as_polars(data.earnings.retail_rate_calendar())

# Wholesale product profiles.
product_profiles = _as_polars(data.earnings.daily_mtm_scenario_product_profiles())

In [0]:
# PPA energy settlement. This cell starts the combined earnings frame.
ppa_energy_deals = (
    deal_settlement_details
    .filter(pl.col("instrument") == "ppa_energy")
    .select(
        "deal_id",
        "product_id",
        "instrument_id",
        "start_date",
        "end_date",
        "region_number",
        "floor",
        "turndown",
        "price",
        pl.col("quantity").alias("quantity_factor"),
    )
)

earnings_df = calcs.ppa_energy(
    deal_info=ppa_energy_deals,
    spot_prices=spot_prices,
    generation_profiles=generation_profiles,
)

In [0]:
# Asset toll energy settlement, appended to the combined earnings frame.
# Because the result is concatenated onto earnings_df, re-running this cell on its own
# appends the same rows again.
asset_toll_deals = (
    deal_settlement_details
    .filter(pl.col("instrument") == "asset_toll_energy")
    .select(
        "deal_id",
        "product_id",
        "instrument_id",
        "start_date",
        "end_date",
        "region_number",
        "tolling_fee",
        pl.col("quantity").alias("quantity_factor"),
    )
)
asset_toll_earnings = calcs.asset_toll_energy(
    deal_info=asset_toll_deals,
    spot_prices=spot_prices,
    generation_profiles=generation_profiles,
)

earnings_df = pl.concat([earnings_df, asset_toll_earnings])

In [0]:
# Generation LGC settlement, appended to the combined earnings frame.
# The deal's own lgc_price is passed to the calculation under the name "price".
# Because the result is concatenated onto earnings_df, re-running this cell on its own
# appends the same rows again.
generation_lgc_deals = (
    deal_settlement_details
    .filter(pl.col("instrument") == "generation_lgc")
    .select(
        "deal_id",
        "product_id",
        "instrument_id",
        "start_date",
        "end_date",
        "region_number",
        pl.col("lgc_price").alias("price"),
        pl.col("quantity").alias("quantity_factor"),
        "lgc_percentage",
    )
)
generation_lgc_earnings = calcs.generation_lgc(
    deal_info=generation_lgc_deals,
    generation_profiles=generation_profiles,
    lgc_prices=lgc_prices,
)

earnings_df = pl.concat([earnings_df, generation_lgc_earnings])

In [0]:
# Retail energy settlement, appended to the combined earnings frame.
# Because the result is concatenated onto earnings_df, re-running this cell on its own
# appends the same rows again.
retail_energy_deals = (
    deal_settlement_details
    .filter(pl.col("instrument") == "retail_energy")
    .select("deal_id", "product_id", "instrument_id", "start_date", "end_date")
)
retail_energy_earnings = calcs.retail_energy(
    deal_info=retail_energy_deals,
    load_profiles=load_profiles,
    spot_prices=spot_prices,
    rate_calendar=rate_calendar,
)

earnings_df = pl.concat([earnings_df, retail_energy_earnings])


In [0]:
# Retail LGC settlement, appended to the combined earnings frame.
# The deal's own lgc_price is passed to the calculation under the name "price".
# Because the result is concatenated onto earnings_df, re-running this cell on its own
# appends the same rows again.
retail_lgc_deals = (
    deal_settlement_details
    .filter(pl.col("instrument") == "retail_lgc")
    .select(
        "deal_id",
        "product_id",
        "instrument_id",
        "start_date",
        "end_date",
        "lgc_percentage",
        pl.col("lgc_price").alias("price"),
    )
)
retail_lgc_earnings = calcs.retail_lgc(
    deal_info=retail_lgc_deals,
    load_profiles=load_profiles,
    lgc_prices=lgc_prices,
)

earnings_df = pl.concat([earnings_df, retail_lgc_earnings])


In [0]:
# Flat energy swap settlement, appended to the combined earnings frame.
# Because the result is concatenated onto earnings_df, re-running this cell on its own
# appends the same rows again.
flat_swap_deals = (
    deal_settlement_details
    .filter(pl.col("instrument") == "flat_energy_swap")
    .select(
        "deal_id",
        "product_id",
        "instrument_id",
        "start_date",
        "end_date",
        "region_number",
        pl.col("quantity").alias("quantity_mw"),
        "price",
    )
)
flat_swap_earnings = calcs.flat_energy_swap(
    deal_info=flat_swap_deals,
    spot_prices=spot_prices,
)

earnings_df = pl.concat([earnings_df, flat_swap_earnings])


In [0]:
# Flat energy cap settlement, appended to the combined earnings frame.
# Because the result is concatenated onto earnings_df, re-running this cell on its own
# appends the same rows again.
flat_cap_deals = (
    deal_settlement_details
    .filter(pl.col("instrument") == "flat_energy_cap")
    .select(
        "deal_id",
        "product_id",
        "instrument_id",
        "start_date",
        "end_date",
        "region_number",
        pl.col("quantity").alias("quantity_mw"),
        "price",
        "strike",
    )
)
flat_cap_earnings = calcs.flat_energy_cap(
    deal_info=flat_cap_deals,
    spot_prices=spot_prices,
)

earnings_df = pl.concat([earnings_df, flat_cap_earnings])


In [0]:
# Profiled energy swap settlement, appended to the combined earnings frame.
# Because the result is concatenated onto earnings_df, re-running this cell on its own
# appends the same rows again.
profiled_swap_deals = (
    deal_settlement_details
    .filter(pl.col("instrument") == "profiled_energy_swap")
    .select(
        "deal_id",
        "product_id",
        "instrument_id",
        "start_date",
        "end_date",
        "region_number",
        pl.col("quantity").alias("quantity_mw"),
        "price",
    )
)
profiled_swap_earnings = calcs.profiled_energy_swap(
    deal_info=profiled_swap_deals,
    product_profiles=product_profiles,
    spot_prices=spot_prices,
)

earnings_df = pl.concat([earnings_df, profiled_swap_earnings])


In [0]:
# Profiled energy cap settlement, appended to the combined earnings frame.
# Because the result is concatenated onto earnings_df, re-running this cell on its own
# appends the same rows again.
# NOTE(review): unlike the flat cap selection, no "strike" column is passed here —
# confirm that calcs.profiled_energy_cap does not need it.
profiled_cap_deals = (
    deal_settlement_details
    .filter(pl.col("instrument") == "profiled_energy_cap")
    .select(
        "deal_id",
        "product_id",
        "instrument_id",
        "start_date",
        "end_date",
        "region_number",
        pl.col("quantity").alias("quantity_mw"),
        "price",
    )
)
profiled_cap_earnings = calcs.profiled_energy_cap(
    deal_info=profiled_cap_deals,
    product_profiles=product_profiles,
    spot_prices=spot_prices,
)

earnings_df = pl.concat([earnings_df, profiled_cap_earnings])


In [0]:
# Sanity check: list earnings rows whose volume, income or cost is null.
# Income and cost are picked from buy_income / sell_income according to the deal's
# buy flag: buy deals take buy_income as income, other deals take sell_income.
buy_flag_by_deal = (
    deal_settlement_details
    .select("deal_id", "product_id", "buy")
    .unique()
)
is_buy = pl.col("buy")
has_missing_value = (
    pl.col("volume_mwh").is_null()
    | pl.col("income").is_null()
    | pl.col("cost").is_null()
)

(
    earnings_df
    .join(buy_flag_by_deal, on=["deal_id", "product_id"])
    .select(
        "deal_id",
        "product_id",
        "instrument_id",
        "region_number",
        "buy",
        "interval_date",
        "period_id",
        "volume_mwh",
        pl.when(is_buy)
        .then(pl.col("buy_income"))
        .otherwise(pl.col("sell_income"))
        .alias("income"),
        pl.when(is_buy)
        .then(pl.col("sell_income"))
        .otherwise(pl.col("buy_income"))
        .alias("cost"),
    )
    .filter(has_missing_value)
)

In [0]:
# Show only the LGC instruments (generation and retail) from the deal details.
is_lgc_instrument = pl.col("instrument").is_in(["generation_lgc", "retail_lgc"])
deal_settlement_details.filter(is_lgc_instrument)

In [0]:
# Attach the buy flag to every earnings row, resolve income and cost from
# buy_income / sell_income, and overwrite the scenario earnings table with the result.
buy_flag_by_deal = (
    deal_settlement_details
    .select("deal_id", "product_id", "buy")
    .unique()
)
is_buy = pl.col("buy")

settled_earnings = (
    earnings_df
    .join(buy_flag_by_deal, on=["deal_id", "product_id"])
    .select(
        "deal_id",
        "product_id",
        "instrument_id",
        "region_number",
        "buy",
        "interval_date",
        "period_id",
        "volume_mwh",
        pl.when(is_buy)
        .then(pl.col("buy_income"))
        .otherwise(pl.col("sell_income"))
        .alias("income"),
        pl.when(is_buy)
        .then(pl.col("sell_income"))
        .otherwise(pl.col("buy_income"))
        .alias("cost"),
    )
)

# Hand the frame to Spark as an Arrow table and replace the table contents.
(
    spark.createDataFrame(settled_earnings.to_arrow())
    .write.mode("overwrite")
    .saveAsTable("exploration.earnings_forecast.daily_mtm_scenario_earnings")
)

In [0]:
# Read the scenario earnings table back and display it.
spark.table("exploration.earnings_forecast.daily_mtm_scenario_earnings").display()

In [0]:
%sql
-- Define the daily MtM scenario earnings table in exploration.earnings_forecast,
-- with a primary key and foreign keys to the region and instrument tables.
-- NOTE(review): CREATE OR REPLACE redefines the table. This notebook also writes
-- exploration.earnings_forecast.daily_mtm_scenario_earnings from Python, so running this
-- statement after that write presumably discards the written rows — confirm the intended order.
USE CATALOG exploration;
USE SCHEMA earnings_forecast;
CREATE OR REPLACE TABLE daily_mtm_scenario_earnings (
  product_id SMALLINT NOT NULL,
  deal_id SMALLINT NOT NULL,
  instrument_id SMALLINT NOT NULL,
  region_number TINYINT NOT NULL,
  buy BOOLEAN NOT NULL,
  interval_date DATE NOT NULL,
  period_id SMALLINT NOT NULL,
  volume_mwh FLOAT NOT NULL,
  income FLOAT NOT NULL,
  cost FLOAT NOT NULL,
  CONSTRAINT pk_mtm_earnings  PRIMARY KEY (product_id, deal_id, instrument_id, region_number, interval_date, period_id),
  CONSTRAINT fk_mtm_earnings_region FOREIGN KEY (region_number) REFERENCES scenario_modelling.region_numbers(region_number),
  CONSTRAINT fk_mtm_earnings_instrument FOREIGN KEY (instrument_id) REFERENCES instruments(instrument_id)
);

In [0]:
# Rate coverage per product and jurisdiction, largest first.
# The rate count is divided by 288 — presumably the number of 5-minute periods in a
# day, which would make "days" a day count; confirm.
rate_calendar_coverage = (
    spark.table("exploration.earnings_forecast.retail_rate_calendar")
    .groupBy("product_id", "jurisdiction_id")
    .agg((F.count("rate") / 288).alias("days"))
    .orderBy(F.col("days").desc(), "product_id", "jurisdiction_id")
)
rate_calendar_coverage.display()

In [0]:
# Display the stored deal settlement details table.
spark.table("exploration.earnings_forecast.deal_settlement_details").display()

In [0]:
# Import the builder under an alias. Earlier cells rebind the name
# `deal_settlement_details` to a polars DataFrame, so calling that name here would
# raise TypeError (a DataFrame is not callable) when the notebook runs top to bottom.
from src.earnings.deal_info import deal_settlement_details as build_deal_settlement_details

# Build the deal settlement details, cache them for reuse, and display them.
deal_info = build_deal_settlement_details().cache()
deal_info.display()

In [0]:
# `date` and `timedelta` are used below, but their import in the setup cell is commented
# out, so this cell raised NameError on a fresh kernel. Import them here.
from datetime import date, timedelta

# Cache the Sim inputs reused by the deal-info queries.
complete_deal_info = data.sim.complete_deal_info().cache()
ppa_details = data.sim.ppa_details().cache()
deal_factors = data.sim.deal_factors().cache()
product_profiles = data.sim.product_profiles().cache()

# Calendar restricted to today .. today + 5 * 365 days. "Today" is read once so both
# ends of the window are based on the same date.
horizon_start = date.today()
horizon_end = horizon_start + timedelta(days=365 * 5)
date_calendar = (
    data.date_calendar.cache()
    .filter(F.col("date").between(F.lit(horizon_start), F.lit(horizon_end)))
)
lgc_rates = data.sim.lgc_rates()

In [0]:
# Wholesale deal terms: firm Swap/Cap deals (products with "ACCU" in the name excluded),
# expanded to one row per calendar date inside the product term and then collapsed to one
# row per distinct set of terms with the first and last date on which those terms apply.
# NOTE(review): taking min/max per group assumes each set of terms covers a single
# unbroken date range — confirm.
wholesale_deal_info = (
    date_calendar
    .select(F.col("date").alias("interval_date"))
    .join(
        complete_deal_info
        .join(data.region_numbers(), "regionid")
        .filter(
            F.col("instrument_calculation").isin(["Swap", "Cap"]) &
            ~F.col("product_name").like("%ACCU%") & # Is there a way to positively identify energy contracts instead of excluding non-energy?
            (F.col("optionality") == "Firm")
            # To be implemented:
            #   - Options 3? (No options currently past Q2 25(?), and earnings methodology not specified)
            #   - Quantos 2? (Wind quanto til Aug 25, but not properly captured in Sim)
            #   - Floor 1 (Easy)
        )
        .alias("deal_info"),
        F.col("interval_date").between(F.col("product_start_date"), F.col("product_end_date"))
    )
    # Attach the type "R" deal factor whose from/to window contains the date; dates with
    # no matching factor default to 1.0.
    # NOTE(review): escalation_factor is not part of the grouping or the final select
    # below, so it does not affect the output price here — confirm this is intended.
    .join(
        deal_factors
        .filter(F.col("type") == "R")
        .withColumnRenamed("factor", "escalation_factor")
        .alias("escalation_factors"),
        (F.col("deal_info.deal_id") == F.col("escalation_factors.deal_id")) &
        F.col("interval_date").between(
            F.col("escalation_factors.from_date"),
            F.col("escalation_factors.to_date")
        ),
        "left"
    )
    .fillna({"escalation_factor": 1.0})
    .groupBy(
        "product_id",
        "instrument_calculation",
        "deal_info.deal_id",
        "buy_sell",
        "region_number",
        "quantity",
        "price",
        "strike_price"
    )
    .agg(
        F.min("interval_date").alias("start_date"),
        F.max("interval_date").alias("end_date")
    )
    # Flag products that have a non-null profile_mw; products without any profile row
    # default to False.
    # NOTE(review): distinct() keeps both a True and a False row for a product whose
    # profile_mw is null on some rows only, which would duplicate that product's deals —
    # confirm this cannot happen.
    .join(
        product_profiles
        .select(
            "product_id",
            F.col("profile_mw").isNotNull().alias("shape_product")
        )
        .distinct(),
        "product_id",
        "left"
    )
    .fillna({"shape_product": False})
    .select(
        "deal_id",
        "product_id",
        # deal_type combines the profile flag with the Swap/Cap calculation type.
        F.when(
            F.col("shape_product") &
            (F.col("instrument_calculation") == "Swap"),
            F.lit("profiled_swap")
        ).when(
            F.col("shape_product") &
            (F.col("instrument_calculation") == "Cap"),
            F.lit("profiled_cap")
        ).when(
            (~F.col("shape_product")) &
            (F.col("instrument_calculation") == "Swap"),
            F.lit("flat_swap")
        ).when(
            (~F.col("shape_product")) &
            (F.col("instrument_calculation") == "Cap"),
            F.lit("flat_cap")
        ).alias("deal_type"),
        "buy_sell",
        "region_number",
        F.col("start_date"),
        F.col("end_date"),
        F.col("quantity").cast("double"),
        F.col("price").cast("double"),
        F.col("strike_price").cast("double").alias("strike"),
        # Columns that are always null for wholesale deals in this query.
        F.lit(None).cast("double").alias("tolling_fee"),
        F.lit(None).cast("double").alias("floor"),
        F.lit(None).cast("double").alias("turndown"),
        F.lit(None).cast("double").alias("lgc_price"),
        F.lit(None).cast("double").alias("lgc_percentage"),
        F.col("shape_product")
    )
    .orderBy("deal_id", "product_id", "start_date", "end_date")
)

In [0]:
# Display the wholesale deal terms.
wholesale_deal_info.display()

In [0]:
# Display the complete deal info rows for deal 3032 only (hard-coded id used for inspection).
complete_deal_info.filter(F.col("deal_id") == 3032).display()

In [0]:
# Print the column schema of the bronze ETRM Sim products table.
spark.table("prod.bronze.etrm_sim_products").printSchema()

In [0]:
# Show each product's "has a non-null profile_mw" flag next to its name and schedule.
profile_flags = (
    product_profiles
    .select(
        "product_id",
        F.col("profile_mw").isNotNull().alias("shape_product"),
    )
    .distinct()
)
product_names = (
    spark.table("prod.bronze.etrm_sim_products")
    .select(
        F.col("id").alias("product_id"),
        F.col("name"),
        F.col("schedule"),
    )
)

profile_flags.join(product_names, "product_id").display()

In [0]:
# Retail customer deal terms: one row per deal/product and LGC terms, with the first and
# last calendar date on which those terms apply.
# NOTE(review): taking min/max per group assumes each set of terms covers a single
# unbroken date range — confirm.
null_double = F.lit(None).cast("double")
retail_customer_deals = complete_deal_info.filter(F.col("instrument") == "Retail Customer")
within_product_term = F.col("interval_date").between(
    F.col("product_start_date"),
    F.col("product_end_date"),
)

retail_deal_info = (
    date_calendar
    .select(F.col("date").alias("interval_date"))
    .join(retail_customer_deals, within_product_term)
    .join(data.region_numbers(), "regionid")
    .join(lgc_rates, ["product_id", "interval_date"])
    .groupBy(
        "deal_id",
        "product_id",
        "buy_sell",
        "region_number",
        "lgc_price",
        "lgc_percentage",
    )
    .agg(
        F.min("interval_date").alias("start_date"),
        F.max("interval_date").alias("end_date"),
    )
    .select(
        "deal_id",
        "product_id",
        F.lit("retail").alias("deal_type"),
        "buy_sell",
        "region_number",
        "start_date",
        "end_date",
        # Columns that are always null for retail deals in this query.
        null_double.alias("quantity"),
        null_double.alias("price"),
        null_double.alias("strike"),
        null_double.alias("tolling_fee"),
        null_double.alias("floor"),
        null_double.alias("turndown"),
        "lgc_price",
        "lgc_percentage",
    )
    .orderBy("deal_id", "product_id", "start_date", "end_date")
)

In [0]:
# Display the retail deal terms.
retail_deal_info.display()

In [0]:
# def ppa_deal_info():
# PPA / asset-toll deal terms: PPA details expanded to one row per calendar date inside
# the deal's start/end dates, enriched with price, escalation, tolling fee and LGC rates,
# then collapsed to one row per distinct set of terms with the first and last date on
# which those terms apply.
# NOTE(review): taking min/max per group assumes each set of terms covers a single
# unbroken date range — confirm.
ppa_deal_info = (
    date_calendar
    .select(F.col("date").alias("interval_date"))
    .join(
        ppa_details.alias("ppa_details"),
        (F.col("interval_date").between(F.col("start_date"), F.col("end_date")))
    )
    # Contract price comes from the complete deal info.
    .join(
        complete_deal_info.select("product_id", "deal_id", "price"),
        ["product_id", "deal_id"]
    )
    # Type "R" deal factors supply the escalation factor for the date.
    .join(
        deal_factors
        .filter(F.col("type") == "R")
        .withColumnRenamed("factor", "escalation_factor")
        .alias("escalation_factors"),
        (F.col("ppa_details.deal_id") == F.col("escalation_factors.deal_id")) &
        F.col("interval_date").between(
            F.col("escalation_factors.from_date"),
            F.col("escalation_factors.to_date")
        ),
        "left"
    )
    # Type "M" deal factors supply the tolling fee: the factor is divided by the number
    # of days in its from/to window (both ends counted), giving a per-day value.
    .join(
        deal_factors
        .filter(F.col("type") == "M")
        .withColumn(
            "days_in_period",
            F.date_diff(F.col("to_date"), F.col("from_date")) + 1
        )
        .select(
            "deal_id",
            "from_date",
            "to_date",
            (F.col("factor") / F.col("days_in_period")).alias("tolling_fee")
        )
        .alias("tolling_fees"),
        (F.col("ppa_details.deal_id") == F.col("tolling_fees.deal_id")) &
        F.col("interval_date").between(
            F.col("tolling_fees.from_date"),
            F.col("tolling_fees.to_date")
        ),
        "left"
    )
    .join(
        lgc_rates,
        ["product_id", "interval_date"],
        "left"
    )
    # Dates without a matching escalation factor leave the price unchanged.
    .fillna({"escalation_factor": 1})
    .groupBy(
        "ppa_details.deal_id",
        "ppa_details.product_id",
        "region_number",
        # The grouped price is the contract price multiplied by the escalation factor.
        (F.col("price")*F.col("escalation_factor")).cast("double").alias("price"),
        F.col("quantity_factor").alias("quantity"),
        "floor",
        "turndown",
        "lgc_price",
        "lgc_percentage",
        "tolling_fee"
    )
    .agg(
        F.min("interval_date").alias("start_date"),
        F.max("interval_date").alias("end_date")
    )
    .select(
        "deal_id",
        "product_id",
        # Rows with a tolling fee are labelled "asset_toll"; all others "ppa".
        F.when(
            F.col("tolling_fee").isNotNull(),
            F.lit("asset_toll")
        ).otherwise(F.lit("ppa")).alias("settlement_calculation"),
        "region_number",
        "start_date",
        "end_date",
        "quantity",
        "price",
        F.lit(None).cast("double").alias("strike"),
        "tolling_fee",
        "floor",
        "turndown",
        F.when(
            F.col("tolling_fee").isNotNull(),
            F.lit(0) # Hack for TB2, would be better to have this reflected in the source table
        ).otherwise(F.col("lgc_price")).alias("lgc_price"),
        "lgc_percentage"
    )
    .orderBy("deal_id", "product_id", "start_date", "end_date")
)

In [0]:
# Display the PPA / asset-toll deal terms.
ppa_deal_info.display()

In [0]:
# Generate the MtM scenario prices for the selected price model.
# NOTE(review): in the visible cells both the import of generate_mtm_scenario_prices
# (from src.mtm.scenario_price) and the `model_id` assignment exist only as commented-out
# lines, so this cell raises NameError on a fresh kernel. Restore both deliberately, or
# remove the cell — presumably this call regenerates stored prices, so confirm before
# re-enabling.
generate_mtm_scenario_prices(model_id)

In [0]:
%sql
-- Define the scenario generation profile table in exploration.earnings_forecast.
-- NOTE(review): CREATE OR REPLACE redefines the table, presumably discarding existing rows — confirm.
-- NOTE(review): model_id is a required column and a foreign key but is not part of the
-- primary key — confirm profiles are meant to be unique per product/period across models.
USE CATALOG exploration;
USE SCHEMA earnings_forecast;
CREATE OR REPLACE TABLE scenario_generation_profiles (
  model_id SMALLINT NOT NULL,
  product_id INTEGER NOT NULL,
  year SMALLINT NOT NULL,
  month_id SMALLINT NOT NULL,
  day_id SMALLINT NOT NULL,
  period_id SMALLINT NOT NULL,
  generation_mwh FLOAT NOT NULL,
  CONSTRAINT pk_generation_profile PRIMARY KEY (product_id, year, month_id, day_id, period_id),
  CONSTRAINT fk_generation_model FOREIGN KEY (model_id) REFERENCES scenario_modelling.price_models(model_id)
);

In [0]:
# Update the deal profiles for the selected price model.
# NOTE(review): in the visible cells both the import of update_price_model_deal_profiles
# (from src.earnings) and the `model_id` assignment exist only as commented-out lines, so
# this cell raises NameError on a fresh kernel. Restore both deliberately, or remove the
# cell — presumably this call rewrites stored profiles, so confirm before re-enabling.
update_price_model_deal_profiles(model_id)

In [0]:
# Compare the modelled scenario prices with quarterly futures, per year, quarter and region.
price_sample = pl.from_arrow(
    spark.table("exploration.earnings_forecast.daily_mtm_scenario_prices").toArrow()
)

# Futures swap and cap prices keyed by year, quarter and region number.
futures_by_quarter = pl.from_arrow(
    data.market_price.quarterly_futures_prices()
    .join(data.price_model_scenario.region_numbers(), "regionid")
    .select(
        F.year("period_start").alias("year"),
        F.quarter("period_start").alias("quarter"),
        "region_number",
        "regionid",
        F.col("cap_boq").alias("futures_cap"),
        F.col("swap_boq").alias("futures_swap"),
    )
    .toArrow()
)

# Mean modelled price, and mean of the amount by which the price exceeds 300
# (zero when it does not), for each year / quarter / region.
modelled_by_quarter = (
    price_sample
    .group_by([
        pl.col("interval_date").dt.year().alias("year"),
        pl.col("interval_date").dt.quarter().alias("quarter"),
        pl.col("region_number"),
    ])
    .agg(
        pl.col("rrp").mean().alias("mean_rrp"),
        pl.when(pl.col("rrp") > 300)
        .then(pl.col("rrp") - 300)
        .otherwise(0)
        .mean()
        .alias("cap_payout"),
    )
)

# Price columns are cast to Float64 and rounded to two decimals for display.
rounded_prices = [
    pl.col(column_name).cast(pl.Float64()).round(2)
    for column_name in ("mean_rrp", "futures_swap", "cap_payout", "futures_cap")
]

(
    modelled_by_quarter
    .join(futures_by_quarter, ["year", "quarter", "region_number"])
    .select(
        pl.col("year"),
        pl.col("quarter"),
        pl.col("regionid"),
        *rounded_prices,
    )
    .sort("year", "quarter")
)