In [1]:
from __future__ import annotations
from typing import TYPE_CHECKING
import polars as pl

from rich import print

# If the imports below fail and you're using VSCode, make the following change
# in your user JSON preferences:
#     "jupyter.notebookFileRoot": "${workspaceFolder}",

from common.jupyter import get_nodes, plot_node, get_datasets
from common import polars as ppl
import scipy
if TYPE_CHECKING:
    from nodes.context import Context
    from nodes.node import Node
    from nodes.actions.shift import ShiftParameterValue
    from nodes.metric import DimensionalMetric, DimensionalFlow


In [2]:
from typing import cast
from nodes.actions.action import ActionNode
from nodes.constants import YEAR_COLUMN
from nodes.datasets import DVCDataset

nodes = get_nodes('zuerich')
ctx: Context = nodes.context

# The action nodes whose `shift` parameters are fitted further below.
act_nodes = cast(list[ActionNode], [
    nodes[node_id]
    for node_id in (
        'fossil_fuel_heater_to_district_heat',
        'fossil_fuel_heater_to_heat_pumps',
        'fossil_fuel_heater_to_other',
    )
])
outcome_node: Node = nodes['building_heat_mix']

# Goal values for the outcome node, brought into the same wide layout as its output.
ds = DVCDataset('zuerich/building_heat_mix_goals', tags=[])
df = ds.get_copy(ctx)
outcome_cols = (
    outcome_node.compute()
    .paths.to_wide(only_category_names=True)
    .metric_cols
)
goal_df = (
    df.paths.to_wide(only_category_names=True)
    # same metric column order as the outcome node's wide output, year last
    .select(outcome_cols + [YEAR_COLUMN])
)


In [3]:
from functools import partial
from typing import Any, Callable, Concatenate, Self, Tuple, TypeAlias
import networkx as nx
from pydantic import BaseModel
from scipy import optimize
import yaml

from nodes.constants import FORECAST_COLUMN
from nodes.simple import MixNode
if TYPE_CHECKING:
    from nodes.actions.shift import ShiftAmount
    from params import Parameter


ParamSetter: TypeAlias = Callable[['OptimizeParameter', float], None]


class OptimizeParameter:
    """One action parameter whose value is partly exposed as optimization variables.

    ``self.value`` starts as the parameter's value (a deep copy when it is a pydantic
    ``BaseModel``); ``configure_for_shift`` edits that value and registers variables.
    The per-variable lists ``x0``, ``bounds``, ``xstep``, ``value_setters`` and ``ids``
    are kept aligned index by index, one entry per variable. Each setter takes the new
    float value and writes it into ``self.value``.

    NOTE: nothing here assigns ``self.value`` back to ``self.param``; callers must do
    that if the model should see the edits. ``restore`` puts back the value captured
    at construction time.
    """

    def __init__(self, action: ActionNode, param: Parameter):
        self.action = action
        self.param = param
        self.original_value = param.value
        self.value = param.value
        if isinstance(self.value, BaseModel):
            # Work on a private copy so the original value stays untouched for restore().
            self.value = self.value.model_copy(deep=True)
        self.reset()

    def restore(self):
        """Put the value captured at construction time back on the parameter."""
        self.param.value = self.original_value

    def reset(self):
        """Clear all per-variable lists."""
        self.x0: list[float] = []
        self.bounds: list[Tuple[float, float]] = []
        self.xstep: list[float] = []
        self.value_set_ctx: list[Any] = []  # not used in this section
        self.value_setters: list[Callable] = []
        self.ids: list[str] = []

    def set_source_value(self, start: ShiftAmount | None, end: ShiftAmount, new_val: float):
        """Set ``source_amount`` on ``end`` and, if given, on ``start``."""
        if start is not None:
            start.source_amount = new_val
        end.source_amount = new_val

    def set_dest_value(self, start: ShiftAmount | None, end: ShiftAmount, idx: int, new_val: float):
        """Set ``dest_amounts[idx]`` on ``end`` and, if given, on ``start``."""
        if start is not None:
            start.dest_amounts[idx] = new_val
        end.dest_amounts[idx] = new_val

    def _add_variable(self, var_id: str, x0: float, bounds: Tuple[float, float], step: float, setter: Callable):
        """Register one optimization variable, keeping all per-variable lists aligned."""
        self.ids.append(var_id)
        self.x0.append(x0)
        self.bounds.append(bounds)
        self.xstep.append(step)
        self.value_setters.append(setter)

    @staticmethod
    def _source_bounds(amount: float) -> Tuple[float, float]:
        """Bounds for a source amount: keep its sign; allow both signs when it is zero."""
        if amount < 0:
            return (-100, 0)
        if amount > 0:
            return (0, 100)
        # Zero used to get no bounds at all, misaligning `bounds` with `x0`.
        # NOTE(review): the range allowed for a zero start value is a guess — confirm.
        return (-100, 100)

    def configure_for_shift(self, start_year: int, end_year: int):
        """Turn the shift amounts at ``start_year`` and ``end_year`` into optimization variables.

        For every entry of the shift value, this method does the following:

        - It drops the amounts after ``start_year``.
        - It ensures there is an amount at ``start_year``, copied from the latest
          remaining one if needed.
        - It appends a copy of that amount at ``end_year``.

        The start/end pair then shares its variables: one for ``source_amount`` and,
        when there is more than one destination with a positive total, one per
        ``dest_amounts`` item.

        :raises ValueError: if an entry has no amount at or before ``start_year``.
        """
        # The module-level import is under TYPE_CHECKING only, so import it here for
        # the runtime isinstance() check (otherwise it raises NameError).
        from nodes.actions.shift import ShiftParameterValue

        value = self.value
        assert isinstance(value, ShiftParameterValue)

        self.reset()

        for entry in value.root:
            # Keep only the amounts up to start_year, in year order.
            entry.amounts = sorted((a for a in entry.amounts if a.year <= start_year), key=lambda a: a.year)
            if not entry.amounts:
                raise ValueError('Shift entry has no amounts at or before %d' % start_year)
            start = entry.amounts[-1]
            if start.year != start_year:
                # Deep copy: a shallow copy would share dest_amounts with the earlier
                # amount, so set_dest_value() would silently edit that year as well.
                start = start.model_copy(deep=True)
                start.year = start_year
                entry.amounts.append(start)
            # `deep` is an argument of model_copy(), not a field to set through `update`.
            end = start.model_copy(update=dict(year=end_year), deep=True)
            entry.amounts.append(end)

            self._add_variable(
                'source', start.source_amount, self._source_bounds(start.source_amount),
                0.01, partial(self.set_source_value, start, end),
            )

            sum_amounts = sum(start.dest_amounts)
            # A single destination needs no split; a non-positive total would give
            # lower >= upper bounds, which least_squares rejects.
            if len(start.dest_amounts) == 1 or sum_amounts <= 0:
                continue
            for idx, amount in enumerate(start.dest_amounts):
                self._add_variable(
                    'dest-%d' % idx, amount, (0, sum_amounts),
                    0.1 * sum_amounts, partial(self.set_dest_value, start, end, idx),
                )


class OptimizeParameterSet:
    """A collection of ``OptimizeParameter``s viewed as one flat optimization vector.

    ``x0`` and ``bounds`` concatenate the per-parameter lists in insertion order;
    callers that flatten the parameters' ``value_setters`` must use the same order.
    """

    def __init__(self):
        self.params: list[OptimizeParameter] = []

    def add(self, optp: OptimizeParameter):
        """Append a parameter; its variables come after those already added."""
        self.params.append(optp)

    def restore(self):
        """Restore the original value of every parameter."""
        for param in self.params:
            param.restore()

    @property
    def x0(self) -> list[float]:
        """Initial values of all variables, flattened."""
        return [x for param in self.params for x in param.x0]

    @property
    def bounds(self) -> Tuple[list[float], list[float]]:
        """``(lower, upper)`` bound lists for all variables, as ``least_squares`` accepts."""
        lower: list[float] = []
        upper: list[float] = []
        for param in self.params:
            for lo, hi in param.bounds:
                lower.append(lo)
                upper.append(hi)
        return (lower, upper)

    def print(self, x: list[float] | None = None):
        """Render all variables as a rich table.

        :param x: optional flat vector (e.g. fitted values) aligned with ``self.x0``;
            shown in the ``x`` column when given.
        :raises ValueError: if ``x`` does not have one value per variable.
        """
        from rich.table import Table
        from rich.console import Console

        if x is not None and len(x) != len(self.x0):
            raise ValueError('x has %d values, expected %d' % (len(x), len(self.x0)))

        table = Table()
        for col in ('Action', 'x0', 'bounds', 'step', 'x'):
            table.add_column(col)
        idx = 0
        for param in self.params:
            # strict=True: misaligned per-variable lists are a bug, not something to hide.
            for var_id, x0, bounds, step in zip(param.ids, param.x0, param.bounds, param.xstep, strict=True):
                x_str = '%.4g' % x[idx] if x is not None else ''
                table.add_row(
                    '%s %s' % (param.action.id, var_id),
                    '%.4g' % x0,
                    '(%g, %g)' % tuple(bounds),
                    '%g' % step,
                    x_str,
                )
                idx += 1
        Console().print(table)

def compute_and_compare(vals: list[float], goal: pl.DataFrame, value_setters: list[Callable]):
    """Apply candidate values, recompute ``outcome_node`` and return the residuals.

    Intended as the residual function for ``scipy.optimize.least_squares``, which
    squares and sums the residuals itself, so the raw signed differences are returned.

    :param vals: candidate values, one per entry of ``value_setters`` (same order).
    :param goal: single-row wide frame with ``YEAR_COLUMN`` and the goal metric columns.
    :param value_setters: callables that each write one candidate value into the model.
    :returns: list of ``outcome - goal`` per goal metric column for the goal's year.
    :raises ValueError: if ``goal`` or the outcome does not have exactly one row for the year.
    """
    if goal.height != 1:
        raise ValueError('goal must have exactly one row, got %d' % goal.height)
    # strict=True: a length mismatch means some variables would silently be ignored.
    for val, set_value in zip(vals, value_setters, strict=True):
        set_value(val)

    target_year = goal[YEAR_COLUMN][0]
    goal_values = goal.drop(YEAR_COLUMN)
    outcome = (
        outcome_node.compute().paths.to_wide(only_category_names=True)
        .filter(pl.col(YEAR_COLUMN) == target_year)
        # Align columns by name with the goal (this also drops Year/Forecast).
        .select(goal_values.columns)
    )
    if outcome.height != 1:
        raise ValueError('expected one outcome row for year %s, got %d' % (target_year, outcome.height))
    diff = outcome - goal_values
    return diff.to_numpy()[0].tolist()


def run(start_year: int, target_year: int):
    """Fit the ``shift`` parameters of ``act_nodes`` to the goal for ``target_year``.

    Each action's shift value is cut back to ``start_year`` and given amounts at
    ``start_year`` and ``target_year``; those amounts are the optimization variables
    (see ``OptimizeParameter.configure_for_shift``). ``scipy.optimize.least_squares``
    adjusts them so ``outcome_node`` matches ``goal_df`` in ``target_year``. The
    fitted values, rounded to two decimals, are left on the parameters.

    :raises ValueError: if ``goal_df`` does not have exactly one row for ``target_year``.
    """
    target_year_goal = goal_df.filter(pl.col(YEAR_COLUMN) == target_year)
    if target_year_goal.height != 1:
        raise ValueError('expected one goal row for %d, found %d' % (target_year, target_year_goal.height))

    params = OptimizeParameterSet()
    for act in act_nodes:
        optp = OptimizeParameter(act, act.get_parameter('shift'))
        optp.configure_for_shift(start_year, target_year)
        # The setters edit optp.value, a private copy; install it on the parameter so
        # that recomputing the outcome sees the changes.
        # NOTE(review): assumes assigning Parameter.value keeps this same object (no
        # re-validation copy) — confirm, otherwise the setters will have no effect.
        optp.param.value = optp.value
        params.add(optp)

    x0 = params.x0
    lower_bounds, upper_bounds = params.bounds
    # Same order as x0/bounds: parameters in insertion order, variables in order.
    value_setters = [setter for optp in params.params for setter in optp.value_setters]
    if not x0:
        print('No parameters to optimize for %d' % target_year)
        return

    print(target_year_goal)
    with ctx.run():
        res = optimize.least_squares(
            compute_and_compare,
            x0,
            bounds=(lower_bounds, upper_bounds),
            max_nfev=500,
            kwargs=dict(
                goal=target_year_goal,
                value_setters=value_setters,
            ),
        )
        if not res.success:
            print('Optimization for %d did not converge: %s' % (target_year, res.message))
        print([round(x, 2) for x in res.x])
        # Leave the rounded optimum on the parameters.
        for x, set_value in zip(res.x, value_setters, strict=True):
            set_value(round(float(x), 2))


def doit():
    """Fit the shift parameters to each goal year after the history, report, then restore.

    Goal years after the last historical year are processed in ascending order; each
    fit starts the year after the previous target (the first starts the year after the
    last historical year). Caching is disabled on every node on a path from an action
    to ``outcome_node`` (presumably so recomputation reflects the parameter edits).

    At the end, whether or not fitting succeeded:

    - The fitted outcome is printed next to the goals.
    - The original node flags and parameter values are put back.
    - The before/after shift values are printed.
    """
    path_nodes: set[str] = set()
    orig_values = []
    for act in act_nodes:
        all_paths = list(nx.all_simple_paths(ctx.node_graph, source=act.id, target=outcome_node.id))
        assert all_paths, 'no path from %s to %s' % (act.id, outcome_node.id)
        for path in all_paths:
            path_nodes.update(path)
        orig_values.append(act.get_parameter_value('shift'))

    # Remember the flags changed below so they can be reset afterwards; otherwise they
    # would leak into later cells run in the same kernel.
    # NOTE(review): the False defaults assume the attributes may be unset — confirm.
    orig_disable_cache = {node_id: getattr(ctx.nodes[node_id], 'disable_cache', False) for node_id in path_nodes}
    for node_id in path_nodes:
        ctx.nodes[node_id].disable_cache = True
    orig_skip_normalize: bool | None = None
    if isinstance(outcome_node, MixNode):
        orig_skip_normalize = getattr(outcome_node, 'skip_normalize', False)
        outcome_node.skip_normalize = True

    try:
        hist_years = outcome_node.get_output_pl().filter(~pl.col(FORECAST_COLUMN))[YEAR_COLUMN]
        last_hist_year: int = int(hist_years.max())
        # Goals at or before the last historical year cannot be reached by the forecast.
        years = sorted(y for y in goal_df[YEAR_COLUMN] if y > last_hist_year)
        start_year = last_hist_year + 1
        for year in years:
            print(start_year, year)
            run(start_year, year)
            start_year = year + 1
    finally:
        try:
            # Show the fitted outcome next to the goals before anything is restored.
            fitted_df = outcome_node.compute().paths.to_wide(only_category_names=True).filter(pl.col(FORECAST_COLUMN))
            print(goal_df)
            print(fitted_df)
        finally:
            for node_id, flag in orig_disable_cache.items():
                ctx.nodes[node_id].disable_cache = flag
            if orig_skip_normalize is not None:
                outcome_node.skip_normalize = orig_skip_normalize
            for act, orig_val in zip(act_nodes, orig_values, strict=True):
                val: ShiftParameterValue = act.get_parameter_value('shift')
                # Restore the value on the Parameter object itself; assigning into
                # act.parameters (presumably a dict of Parameter objects) would replace
                # the Parameter with a bare value.
                act.get_parameter('shift').value = orig_val
                orig_data = orig_val.model_dump(exclude_none=False)
                new_data = val.model_dump(exclude_none=False)
                print('- id: %s\n%s\n->\n%s\n\n' % (act.id, yaml.safe_dump(orig_data), yaml.safe_dump(new_data)))


doit()