In [7]:
from __future__ import annotations
from typing import TYPE_CHECKING
import polars as pl

from rich import print
from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedSeq, CommentedMap

# If the imports below fail and you're using VSCode, make the following change
# in your user JSON preferences:
#     "jupyter.notebookFileRoot": "${workspaceFolder}",

from common.jupyter import get_nodes, plot_node, get_datasets
from common import polars as ppl
import scipy
if TYPE_CHECKING:
    from nodes.context import Context
    from nodes.node import Node
    from nodes.actions.shift import ShiftParameterValue
    from nodes.metric import DimensionalMetric, DimensionalFlow

# Shared ruamel.yaml handle; the cells below use it to load, rewrite and dump YAML.
yaml = YAML()


- source:
    categories:
      heating_system: fuel_oil
  dests:
  - categories:
      heating_system: district_heat
  amounts:
  - {year: 2023, source_amount: -0.696, dest_amounts: [100.0]}
  - {year: 2040, source_amount: -0.696, dest_amounts: [100.0]}
- source:
    categories:
      heating_system: natural_gas
  dests:
  - categories:
      heating_system: district_heat
  amounts:
  - {year: 2023, source_amount: -0.988, dest_amounts: [1.0]}
  - {year: 2040, source_amount: -0.988, dest_amounts: [1.0]}


In [2]:
from typing import cast
from nodes.actions.action import ActionNode
from nodes.constants import YEAR_COLUMN
from nodes.datasets import DVCDataset

# Load the node collection for the 'zuerich' instance and grab its context.
nodes = get_nodes('zuerich')
ctx: Context = nodes.context
# The three actions whose 'shift' parameter is tuned by the optimizer cell below.
act_nodes = cast(list[ActionNode], [
    nodes['fossil_fuel_heater_to_district_heat'],
    nodes['fossil_fuel_heater_to_heat_pumps'],
    nodes['fossil_fuel_heater_to_other'],
])
# Node whose computed output is compared against the goal dataset.
outcome_node: Node = nodes['building_heat_mix']
# Goal values come from a DVC dataset; work on a copy obtained through the context.
ds = DVCDataset('zuerich/building_heat_mix_goals', tags=[])
df = ds.get_copy(ctx)
# Metric column names of the outcome node in wide form; used to order the goal columns.
outcome_cols = outcome_node.compute().paths.to_wide(only_category_names=True).metric_cols
goal_df = df.paths.to_wide(only_category_names=True)
# ensure the columns are in the same order
goal_df = goal_df.select(outcome_cols + [YEAR_COLUMN])
# Trailing expression: displayed as the cell output.
goal_df

fuel_oil,natural_gas,district_heat,wood,solar_collectors,heat_pumps,Year
f64,f64,f64,f64,f64,f64,i64
15.6,46.0,24.0,5.0,0.4,9.0,2025
10.1,33.0,33.0,5.3,0.6,18.0,2030
4.6,21.0,41.0,5.6,0.8,27.0,2035
0.0,9.0,49.0,6.0,1.0,35.0,2040


In [8]:
import cProfile
import numpy as np
from functools import partial
from typing import Any, Callable, Concatenate, Self, Tuple, TypeAlias
import networkx as nx
from pydantic import BaseModel
from scipy import optimize


from nodes.constants import FORECAST_COLUMN
from nodes.simple import MixNode
from nodes.actions.shift import ShiftParameterValue
if TYPE_CHECKING:
    from nodes.actions.shift import ShiftAmount
    from params import Parameter


# Callback type: takes an OptimizeParameter and a new float value, returns nothing.
# NOTE(review): not referenced anywhere in the visible code — confirm it is still needed.
ParamSetter: TypeAlias = Callable[['OptimizeParameter', float], None]


class OptimizeParameter:
    """Expose one action parameter to a numeric optimizer as a flat list of floats.

    After `configure_for_shift()` the parallel lists `x0`, `bounds`, `xstep`,
    `value_setters` and `ids` describe the free variables: entry *i* of each
    list belongs to the same variable, and `value_setters[i](v)` writes `v`
    into the parameter value.
    """

    def __init__(self, action: ActionNode, param: Parameter):
        """Remember the parameter's current value and start working on a copy.

        Args:
            action: The action that owns `param`.
            param: The parameter whose value will be modified.
        """
        self.action = action
        self.param = param
        # Kept untouched so restore() can put it back.
        self.original_value = param.value
        self.value = param.value
        if isinstance(self.value, BaseModel):
            # Mutate a private deep copy, never the original value object.
            param.value = self.value = self.value.model_copy(deep=True)
        self.reset()

    def restore(self):
        """Put the value captured at construction time back on the parameter."""
        self.param.value = self.original_value

    def reset(self):
        """Clear the optimizer-facing lists (they are kept index-aligned)."""
        self.x0: list[float] = []
        self.bounds: list[Tuple[float, float]] = []
        self.xstep: list[float] = []
        self.value_set_ctx: list[Any] = []
        self.value_setters: list[Callable] = []
        self.ids: list[str] = []

    def set_source_value(self, start: ShiftAmount | None, end: ShiftAmount, new_val: float):
        """Write `new_val` as the source amount of `end` and, if given, of `start`."""
        if start is not None:
            start.source_amount = new_val
        end.source_amount = new_val

    def set_dest_value(self, start: ShiftAmount | None, end: ShiftAmount, idx: int, new_val: float):
        """Write `new_val` into destination slot `idx` of `end` and, if given, of `start`."""
        if start is not None:
            start.dest_amounts[idx] = new_val
        end.dest_amounts[idx] = new_val

    def configure_for_shift(self, start_year: int, end_year: int):
        """Rebuild the shift amounts for `start_year`..`end_year` and register the variables.

        For every entry of the shift value, amounts after `start_year` are
        dropped, an amount for `start_year` is ensured, and a copy of it is
        appended for `end_year`. The source amount, and each destination amount
        when there is more than one, becomes an optimizer variable that is
        written to both the start and the end amount.

        Raises:
            ValueError: If an entry has no amount at or before `start_year`.
        """
        value = self.value
        assert isinstance(value, ShiftParameterValue)

        self.reset()

        for entry in value.root:
            # remove all the values after our start year
            kept = sorted((a for a in entry.amounts if a.year <= start_year), key=lambda a: a.year)
            if not kept:
                raise ValueError(
                    "Action %s: shift entry has no amounts at or before year %d"
                    % (self.action.id, start_year)
                )
            entry.amounts = kept
            start = kept[-1]
            if start.year != start_year:
                # Deep copy: a shallow one would share `dest_amounts` with the
                # earlier amount, so the setters would rewrite history as well.
                start = start.model_copy(update=dict(year=start_year), deep=True)
                entry.amounts.append(start)
            # `deep` is an argument of model_copy(), not a field to update.
            end = start.model_copy(update=dict(year=end_year), deep=True)
            entry.amounts.append(end)

            self.x0.append(start.source_amount)
            self.value_setters.append(partial(self.set_source_value, start, end))
            # Always register bounds (also for a zero source amount) so that
            # `bounds` stays aligned with `x0`.
            self.bounds.append((-100, 100))
            self.xstep.append(0.01)
            self.ids.append('source')

            # A single destination receives everything; nothing to optimize.
            if len(start.dest_amounts) == 1:
                continue
            for idx in range(len(start.dest_amounts)):
                self.x0.append(0.1)
                self.value_setters.append(partial(self.set_dest_value, start, end, idx))
                self.bounds.append((0, 1))
                self.xstep.append(0.01)
                self.ids.append('dest-%d' % idx)

    def _dump_value(self) -> list[dict]:
        """Dump the current value to plain data, with each amount as a flow-style map."""
        out: list[dict] = self.value.model_dump(exclude_none=True, exclude_unset=True)  # pyright: ignore
        for entry in out:
            amounts = entry['amounts']  # pyright: ignore
            for idx, amt in enumerate(list(amounts)):
                m = CommentedMap(amt)
                # Flow style renders the amount on one line: {year: ..., ...}
                m.fa.set_flow_style()
                amounts[idx] = m
        return out

    def set_in_yaml(self):
        """Write the current value into the 'shift' parameter of this action in the instance YAML.

        Reads the file at `ctx.instance.yaml_file_path`, replaces the value and
        writes the file back in place.

        Raises:
            Exception: If the action, or its 'shift' parameter, is not in the file.
        """
        ins = ctx.instance
        assert ins.yaml_file_path
        with open(ins.yaml_file_path, 'r', encoding='utf8') as f:
            cfg = yaml.load(f)
        main = cfg['instance'] if 'instance' in cfg else cfg
        for actcfg in main['actions']:
            if actcfg['id'] == self.action.id:
                break
        else:
            raise Exception("Action %s not found in yaml" % self.action.id)
        for pc in actcfg['params']:
            if pc['id'] == 'shift':
                break
        else:
            raise Exception("Parameter 'shift' of action %s not found in yaml" % self.action.id)

        pc['value'] = self._dump_value()
        with open(ins.yaml_file_path, 'w', encoding='utf8') as f:
            yaml.dump(cfg, f)

    def dump_yaml(self):
        """Print the current value as YAML."""
        import io

        buf = io.BytesIO()
        yaml.dump(self._dump_value(), buf)
        print(buf.getvalue().decode('utf8'))


class OptimizeParameterSet:
    """Ordered group of OptimizeParameter objects presented as one optimization problem.

    Each property concatenates the matching per-parameter list in the order
    the parameters were added.
    """

    def __init__(self):
        self.params: list[OptimizeParameter] = []

    def add(self, optp: OptimizeParameter):
        """Append a parameter to the set."""
        self.params.append(optp)

    def restore(self):
        """Restore every parameter in the set, in insertion order."""
        for member in self.params:
            member.restore()

    @property
    def x0(self) -> list[float]:
        """Initial values of all variables."""
        return [start for member in self.params for start in member.x0]

    @property
    def bounds(self) -> Tuple[list[float], list[float]]:
        """Bounds as a `(lower bounds, upper bounds)` pair of lists."""
        pairs = [pair for member in self.params for pair in member.bounds]
        return ([pair[0] for pair in pairs], [pair[1] for pair in pairs])

    @property
    def value_setters(self) -> list[Callable]:
        """Setter callbacks of all variables."""
        return [setter for member in self.params for setter in member.value_setters]

    @property
    def xstep(self) -> list[float]:
        """Step sizes of all variables."""
        return [step for member in self.params for step in member.xstep]

    def print(self):
        """Render one table row per variable: action id, x0, bounds and step."""
        from rich.table import Table
        from rich.console import Console

        table = Table()
        headers = ('Action', 'x0', 'bounds', 'step')
        for header in headers:
            table.add_column(header)
        for member in self.params:
            for cells in zip(member.x0, member.bounds, member.xstep):
                table.add_row(member.action.id, *map(str, cells))
        Console().print(table)



def compute_and_compare(vals: list[float], goal: pl.DataFrame, value_setters: list[Callable]):
    """Residual function: apply candidate values, recompute the outcome, compare with the goal.

    Args:
        vals: Candidate values, one per setter.
        goal: Goal frame; its first 'Year' value selects the outcome row to compare.
        value_setters: Callbacks that write a candidate value into the model.

    Returns:
        Array of absolute differences between outcome and goal, with entries
        whose outcome value is negative multiplied by 10.
    """
    # Push the candidate values into the model before recomputing.
    for setter, candidate in zip(value_setters, vals):
        setter(candidate)

    wide = outcome_node.compute().paths.to_wide(only_category_names=True)
    wide = wide.drop(FORECAST_COLUMN)
    goal_year = goal['Year'][0]
    outcome = wide.filter(pl.col('Year') == goal_year).drop('Year')

    # NOTE(review): `outcome` has dropped 'Year' while `goal` is indexed by
    # 'Year' above — confirm the two frames line up column-for-column here.
    delta = outcome - goal
    outcome_row = outcome.to_numpy()[0]
    # Negative outcome values are weighted ten times heavier.
    weights = np.where(outcome_row < 0, 10, 1)
    return np.abs(delta.to_numpy()[0]) * weights


def run(params: OptimizeParameterSet, start_year: int, target_year: int):
    """Fit the shift parameters so the outcome approaches the goal for `target_year`.

    Configures every parameter for the `start_year`..`target_year` span, runs
    `scipy.optimize.least_squares` on `compute_and_compare` inside `ctx.run()`,
    prints the result and its Jacobian, and finally writes the solution,
    rounded to three decimals, through the value setters.
    """
    goal_row = goal_df.filter(pl.col('Year') == target_year)

    for optp in params.params:
        optp.configure_for_shift(start_year, target_year)

    print(params.x0)

    with ctx.run():
        result = optimize.least_squares(
            compute_and_compare,
            params.x0,
            bounds=params.bounds,
            diff_step=params.xstep,
            max_nfev=500,
            method='trf',
            kwargs={'goal': goal_row, 'value_setters': params.value_setters},
        )
        print(result)
        print(result.jac)
        # Leave the model holding the rounded solution.
        for setter, solved in zip(params.value_setters, result.x):
            setter(round(float(solved), 3))


def doit():
    """Optimize the 'shift' parameter of every action in `act_nodes` against `goal_df`.

    Steps:
      1. Wrap each action's 'shift' parameter in an OptimizeParameter and
         collect the ids of all nodes on any path from the action to the
         outcome node.
      2. Disable caching on those nodes and, for a MixNode outcome, set
         `skip_normalize`.
      3. Run one optimization from the year after the last historical year to
         the latest goal year.
      4. Print the goal frame, the forecast outcome for the goal years with a
         'Sum' column, and each parameter as YAML.

    The original parameter values are restored on exit, also on failure.
    """
    params = OptimizeParameterSet()
    # OptimizeParameter replaces the parameter value on construction, so the
    # setup is inside the try too: a failure half-way still restores them.
    try:
        path_nodes = set()
        for act in act_nodes:
            all_paths = list(nx.all_simple_paths(ctx.node_graph, source=act.id, target=outcome_node.id))
            assert len(all_paths), "no path from action %s to the outcome node" % act.id
            for path in all_paths:
                path_nodes.update(path)
            params.add(OptimizeParameter(act, act.get_parameter('shift')))

        # NOTE(review): these flags are never reset afterwards — confirm that
        # leaving them set for the rest of the kernel session is intended.
        for node_id in path_nodes:
            ctx.nodes[node_id].disable_cache = True
        if isinstance(outcome_node, MixNode):
            outcome_node.skip_normalize = True

        last_hist_year: int = list(outcome_node.get_output_pl().filter(~pl.col(FORECAST_COLUMN))[YEAR_COLUMN])[-1]
        years = sorted(goal_df[YEAR_COLUMN])
        start_year = last_hist_year + 1
        # A single run straight to the latest goal year; nothing to fit when
        # the goal frame has no rows.
        if years:
            print(start_year, years[-1])
            run(params, start_year, years[-1])

        df = outcome_node.compute().paths.to_wide(only_category_names=True)\
            .filter(pl.col(FORECAST_COLUMN))\
            .filter(pl.col(YEAR_COLUMN).is_in(goal_df[YEAR_COLUMN]))
        print(goal_df)
        print(df.with_columns(pl.sum_horizontal(df.metric_cols).alias('Sum')))
        for param in params.params:
            param.dump_yaml()
    finally:
        params.restore()


# Run the optimization when this cell executes.
doit()
#ctx.node_graph