In [63]:
from __future__ import annotations
from typing import TYPE_CHECKING
import polars as pl

from common.jupyter import get_nodes, plot_node, get_datasets
from common import polars as ppl
from rich import print
import scipy
if TYPE_CHECKING:
    from nodes.context import Context
    from nodes.node import Node
    from nodes.actions.shift import ShiftParameterValue
    from nodes.metric import DimensionalMetric, DimensionalFlow


In [64]:
from typing import cast
from nodes.actions.action import ActionNode
from nodes.constants import YEAR_COLUMN
from nodes.datasets import DVCDataset

# Load the Zürich model and pick the heating-shift actions whose `shift`
# parameters are fitted later, plus the outcome node they must steer.
nodes = get_nodes('zuerich')
ctx: Context = nodes.context
act_nodes = cast(list[ActionNode], [
    nodes[node_id]
    for node_id in (
        'fossil_fuel_heater_to_district_heat',
        'fossil_fuel_heater_to_heat_pumps',
        'fossil_fuel_heater_to_other',
    )
])
outcome_node: Node = nodes['building_heat_mix']

# Goal values for the outcome node, loaded from the DVC dataset.
ds = DVCDataset('zuerich/building_heat_mix_goals', tags=[])
df = ds.get_copy(ctx)
outcome_cols = (
    outcome_node.compute()
    .paths.to_wide(only_category_names=True)
    .metric_cols
)
# Reorder the goal columns to match the outcome's metric columns, with the year last.
goal_df = (
    df.paths.to_wide(only_category_names=True)
    .select([*outcome_cols, YEAR_COLUMN])
)


In [66]:
from functools import partial
from typing import Callable
import networkx as nx
from scipy import optimize
import yaml

from nodes.constants import FORECAST_COLUMN
if TYPE_CHECKING:
    from nodes.actions.shift import ShiftAmount


def compute_and_compare(vals: list[float], goal: pl.DataFrame, value_setters: list[Callable]):
    """Residual function for ``scipy.optimize.least_squares``.

    Writes each value of ``vals`` into the model through the matching entry of
    ``value_setters`` (paired in order), recomputes ``outcome_node`` and
    compares it with ``goal`` for the goal's year.

    Args:
        vals: Candidate parameter values proposed by the optimizer.
        goal: Wide frame whose first row holds the goal values, plus a
            ``YEAR_COLUMN`` column giving the year to compare against.
        value_setters: Callables that each write one value of ``vals`` into the model.

    Returns:
        One residual (outcome - goal) per goal metric column. The residuals are
        deliberately not squared: ``least_squares`` squares and sums them itself.

    Raises:
        ValueError: If the computed outcome has no row for the goal's year.
    """
    for val, set_value in zip(vals, value_setters):
        set_value(val)

    goal_year = goal[YEAR_COLUMN][0]
    goal_values = goal.drop(YEAR_COLUMN)
    outcome = (
        outcome_node.compute().paths.to_wide(only_category_names=True)
        .filter(pl.col(YEAR_COLUMN) == goal_year)
    )
    if outcome.height == 0:
        raise ValueError('Outcome has no row for year %s' % goal_year)

    # DataFrame subtraction pairs columns by position, so align the outcome to the
    # goal's metric columns by name first (the goal frame also carries the year).
    diff = outcome.select(goal_values.columns) - goal_values
    return [float(v) for v in diff.to_numpy()[0]]


def run(start_year: int, target_year: int):
    """Fit the action ``shift`` parameters so the outcome meets the goal in ``target_year``.

    For every node in ``act_nodes`` the ``shift`` parameter value is replaced by
    a deep copy in which all amounts after ``start_year`` are dropped. Each entry
    then gets a ``start`` amount at ``start_year``. An existing one is reused;
    otherwise it is copied from the latest earlier amount. Each entry also gets
    an ``end`` amount at ``target_year``. The free variables are the source
    amount and, for entries with more than one destination, each destination
    amount. Every variable is written to both ``start`` and ``end``.

    The variables are fitted with ``scipy.optimize.least_squares`` against the
    ``goal_df`` row for ``target_year``. They are then written back, rounded to
    two decimals, so a following call can continue from the fitted values.

    Side effects (not undone here): replaces the ``shift`` parameter value of
    every action node and sets ``disable_cache = True`` on every node lying on a
    path from an action to ``outcome_node``.
    """
    target_year_goal = goal_df.filter(pl.col(YEAR_COLUMN) == target_year)

    # Nodes on any path from an action to the outcome. Their output caching is
    # disabled so each optimizer evaluation sees the latest parameter values.
    path_nodes = set()

    # Optimization variables. These lists must stay index-aligned with each other.
    x0 = []
    value_setters = []
    lower_bounds = []
    upper_bounds = []

    def set_source_value(start: ShiftAmount | None, end: ShiftAmount, new_val: float):
        if start is not None:
            start.source_amount = new_val
        end.source_amount = new_val

    def set_dest_value(start: ShiftAmount | None, end: ShiftAmount, idx: int, new_val: float):
        if start is not None:
            start.dest_amounts[idx] = new_val
        end.dest_amounts[idx] = new_val

    for act in act_nodes:
        all_paths = list(nx.all_simple_paths(ctx.node_graph, source=act.id, target=outcome_node.id))
        assert all_paths, 'no path from %s to %s' % (act.id, outcome_node.id)
        for path in all_paths:
            path_nodes.update(path)

        param = act.get_parameter('shift')
        val: ShiftParameterValue = param.value.model_copy(deep=True)
        param.value = val

        for entry in val.root:
            # Keep only the amounts up to the start year, in chronological order.
            entry.amounts = sorted(
                (a for a in entry.amounts if a.year <= start_year),
                key=lambda a: a.year,
            )
            start = entry.amounts[-1]
            if start.year != start_year:
                # Deep copy: a shallow copy would share the ``dest_amounts`` list with
                # the earlier amount, so fitting it would also rewrite earlier years.
                start = start.model_copy(deep=True)
                start.year = start_year
                entry.amounts.append(start)
            # ``deep`` must be passed to model_copy() itself; inside ``update`` it
            # would be stored as an attribute and the copy would be shallow.
            end = start.model_copy(update=dict(year=target_year), deep=True)
            entry.amounts.append(end)

            x0.append(start.source_amount)
            value_setters.append(partial(set_source_value, start, end))
            # Keep the sign of an existing shift. A zero shift may go either way,
            # but it must still get bounds so they stay aligned with x0.
            if start.source_amount < 0:
                lower, upper = -100, 0
            elif start.source_amount > 0:
                lower, upper = 0, 100
            else:
                lower, upper = -100, 100
            lower_bounds.append(lower)
            upper_bounds.append(upper)

            if len(start.dest_amounts) == 1:
                continue
            sum_amounts = sum(start.dest_amounts)
            if sum_amounts <= 0:
                # Nothing to redistribute; least_squares also requires lower < upper.
                continue
            for idx, amount in enumerate(start.dest_amounts):
                x0.append(amount)
                value_setters.append(partial(set_dest_value, start, end, idx))
                lower_bounds.append(0)
                upper_bounds.append(sum_amounts)

    for node_id in path_nodes:
        ctx.get_node(node_id).disable_cache = True

    print('start:\n%s\nlb:\n%s\nub:\n%s\n' % (x0, lower_bounds, upper_bounds))
    print(target_year_goal)
    with ctx.run():
        res = optimize.least_squares(
            compute_and_compare,
            x0, bounds=[lower_bounds, upper_bounds],
            max_nfev=500,
            kwargs=dict(
                goal=target_year_goal,
                value_setters=value_setters,
            )
        )
        print([round(x, 2) for x in res.x])
        # Write the fitted values back so later runs start from them.
        for x, set_value in zip(res.x, value_setters):
            set_value(round(float(x), 2))

# Remember the original shift parameters so they can be restored afterwards.
orig_values = [act.get_parameter_value('shift') for act in act_nodes]

# Fitting starts at the first year after the last historical (non-forecast) year.
last_hist_year: int = outcome_node.get_output_pl().filter(~pl.col(FORECAST_COLUMN))[YEAR_COLUMN].max()
years = sorted(goal_df[YEAR_COLUMN])
start_year = last_hist_year + 1
try:
    # Fit each goal year in turn; each period starts right after the previous target.
    for year in years:
        print(year)
        run(start_year, year)
        start_year = year + 1
    # Snapshot the forecast produced by the fitted parameters before they are reverted.
    df = (
        outcome_node.compute().paths.to_wide(only_category_names=True)
        .filter(pl.col(FORECAST_COLUMN))
    )
finally:
    # Always restore the original parameters (even if a run failed) and print
    # the original -> fitted values for each action.
    for act, orig_val in zip(act_nodes, orig_values):
        val: ShiftParameterValue = act.get_parameter_value('shift')
        act.set_parameter_value('shift', orig_val)
        orig_data = orig_val.model_dump(exclude_none=False)
        new_data = val.model_dump(exclude_none=False)
        print('- id: %s\n%s\n->\n%s\n\n' % (act.id, yaml.safe_dump(orig_data), yaml.safe_dump(new_data)))

#ctx.node_graph