In [None]:
import random

from etuples import etuple
from unification import unify, var

import pytensor.tensor as pt
from pytensor.graph import rewrite_graph
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import MergeOptimizer, PatternNodeRewriter, out2in

In [87]:
def find_optimal_P(P, Q, mc):
    """Return the first-order condition of profit with respect to price.

    Profit is ``pi = sum(Q * (P - mc))``; this returns its symbolic gradient
    ``d pi / d P``. Despite the name, the function does not solve for ``P``:
    callers set the returned expression to zero and solve it themselves.

    Parameters
    ----------
    P : TensorVariable
        Price variable to differentiate with respect to.
    Q : TensorVariable
        Quantity (e.g. expected sales); presumably an expression of ``P``.
    mc : TensorVariable
        Marginal cost, subtracted from ``P`` to get the per-unit margin.

    Returns
    -------
    TensorVariable
        ``pt.grad(pi, P)``.
    """
    # `.sum()` reduces to a scalar, which `pt.grad` requires as its cost.
    profit = (Q * (P - mc)).sum()
    return pt.grad(profit, P)

In [97]:
# Symbolic inputs of the demand model.
price_effect = pt.scalar("price_effect")  # d(expected_sales) / d(price)
price, trend, seasonality = (
    pt.vector(label) for label in ("price", "trend", "seasonality")
)
mc = pt.scalar("marginal_cost")

# Linear demand: expected_sales = trend + price * price_effect + seasonality
price_term = price * price_effect
expected_sales = trend + price_term + seasonality

In [98]:
expr = find_optimal_P(P=price, Q=expected_sales, mc=mc)  # d(profit)/d(price)

In [99]:
# Use existing rewrites to simplify expression.
# clone=False keeps the original variables (so `price` etc. remain the same
# objects); rewrite_graph then canonicalizes this FunctionGraph in place.
fgraph = FunctionGraph(outputs=[expr], clone=False)
rewrite_graph(fgraph, include=("canonicalize",))
# dprint returns its output stream; the trailing `;` keeps that repr out of the cell output.
fgraph.dprint();

Add [id A] 5
 ├─ Mul [id B] 4
 │  ├─ Sub [id C] 3
 │  │  ├─ price [id D]
 │  │  └─ ExpandDims{axis=0} [id E] 2
 │  │     └─ marginal_cost [id F]
 │  └─ ExpandDims{axis=0} [id G] 0
 │     └─ price_effect [id H]
 ├─ trend [id I]
 ├─ Mul [id J] 1
 │  ├─ price [id D]
 │  └─ ExpandDims{axis=0} [id G] 0
 │     └─ ···
 └─ seasonality [id K]


<ipykernel.iostream.OutStream at 0x7fcbccd613c0>

In [100]:
# Custom rewrites that bring the canonicalized gradient into the linear form
# `price * a + b`. Each pattern is wrapped directly in an out-to-in walker so
# every name refers to exactly one kind of object.

# (x - y) * z  ->  x * z + (-y) * z
distribute_mul_over_sub = out2in(
    PatternNodeRewriter(
        (pt.mul, (pt.sub, "x", "y"), "z"),
        (pt.add, (pt.mul, "x", "z"), (pt.mul, (pt.neg, "y"), "z")),
    ),
    name="distribute_mul_sub",
)

# (x + y) + z + x + w  ->  x * 2 + (y + z + w)
# NOTE: this matches only an Add with exactly this 4-argument layout, and it
# relies on both occurrences of `x` being the same node -- which is why a
# merge pass has to run between distributing and combining.
combine_addition_terms = out2in(
    PatternNodeRewriter(
        (pt.add, (pt.add, "x", "y"), "z", "x", "w"),
        (pt.add, (pt.mul, "x", 2), (pt.add, "y", "z", "w")),
    ),
    name="combine_addition_terms",
)

# Passes run in this order: distribute the product over the difference,
# merge equivalent nodes so identical terms become a single node, then
# combine the repeated term. Every pass rewrites `fgraph` in place.
for graph_pass in (distribute_mul_over_sub, MergeOptimizer(), combine_addition_terms):
    graph_pass.rewrite(fgraph)

# The output variable may have been replaced, so read it back from the graph.
expr = fgraph.outputs[0]

In [101]:
expr.dprint();  # trailing `;` hides the OutStream repr that dprint returns

Add [id A]
 ├─ Mul [id B]
 │  ├─ Mul [id C]
 │  │  ├─ price [id D]
 │  │  └─ ExpandDims{axis=0} [id E]
 │  │     └─ price_effect [id F]
 │  └─ ExpandDims{axis=0} [id G]
 │     └─ 2 [id H]
 └─ Add [id I]
    ├─ Mul [id J]
    │  ├─ Neg [id K]
    │  │  └─ ExpandDims{axis=0} [id L]
    │  │     └─ marginal_cost [id M]
    │  └─ ExpandDims{axis=0} [id E]
    │     └─ ···
    ├─ trend [id N]
    └─ seasonality [id O]


<ipykernel.iostream.OutStream at 0x7fcbccd613c0>

In [102]:
# Equivalence-preserving rewrites used to generate alternative forms of a graph
# for pattern matching. `ignore_newtrees=True` keeps the nodes a pass creates
# from being fed back into the same pass (otherwise e.g. a commuted `add`
# could be commuted again).
rewrites = [
    out2in(
        PatternNodeRewriter(in_pattern, out_pattern),
        name=rewrite_name,
        ignore_newtrees=True,
    )
    for rewrite_name, in_pattern, out_pattern in (
        ("commutative_add", (pt.add, "x", "y"), (pt.add, "y", "x")),
        ("commutative_mul", (pt.mul, "x", "y"), (pt.mul, "y", "x")),
        (
            "associative_mul",
            (pt.mul, (pt.mul, "x", "y"), "z"),
            (pt.mul, "x", (pt.mul, "y", "z")),
        ),
    )
]


def yield_arithmetic_variants(expr, n, rng=None):
    """Yield equivalent forms of ``expr`` produced by random rewrites.

    ``expr`` is copied into a private FunctionGraph (its input variables are
    shared, not copied), so the caller's graph -- and any other FunctionGraph
    built on it -- is not modified. Each step applies one rewrite chosen at
    random from the module-level ``rewrites`` list.

    Parameters
    ----------
    expr : Variable
        Expression to rewrite.
    n : int
        Number of random rewrite steps to attempt.
    rng : random.Random, optional
        Source of randomness. Pass a seeded ``random.Random`` for reproducible
        results; defaults to the global ``random`` module.

    Yields
    ------
    Variable
        The working graph's output: once before any rewrite, then after every
        step that changed the graph. Rewrites update the graph in place, so the
        same Variable object may be yielded more than once.
    """
    chooser = random if rng is None else rng
    # clone=True stops the rewrites from mutating the caller's Apply nodes;
    # copy_inputs=False keeps the original input variables, so identity checks
    # such as `match[price_] is price` still work on the variants.
    fgraph = FunctionGraph(outputs=[expr], clone=True, copy_inputs=False)
    # The unrewritten expression may already have the desired shape.
    yield fgraph.outputs[0]
    for _ in range(n):
        rewrite = chooser.choice(rewrites)
        nodes_before = set(fgraph.apply_nodes)
        # NOTE(review): `apply` returns profiling information rather than a
        # "changed" flag, so detect changes by comparing the graph's nodes.
        rewrite.apply(fgraph)
        if fgraph.apply_nodes != nodes_before:
            yield fgraph.outputs[0]

In [103]:
# Randomly rewrite the first-order condition until it has the shape
# `price * a + b`, where `price` binds to the original input variable itself.
a, b, price_ = var("a"), var("b"), var("price")
pattern = etuple(pt.add, etuple(pt.mul, price_, a), b)

# Lazily unify each variant and keep the first one whose price slot is `price`.
match_dict = next(
    (
        candidate
        for candidate in (
            unify(variant, pattern)
            for variant in yield_arithmetic_variants(expr, n=100)
        )
        if candidate and candidate[price_] is price
    ),
    None,
)
if match_dict is None:
    raise ValueError("No matching variant found")
match_dict

{~price: price, ~a: Mul.0, ~b: Add.0}

In [104]:
# The first-order condition matched `price * a + b`; setting it to zero gives
# price = -b / a.
optimal_result = -match_dict[b] / match_dict[a]
optimal_result.dprint();  # trailing `;` hides the OutStream repr that dprint returns

True_div [id A]
 ├─ Neg [id B]
 │  └─ Add [id C]
 │     ├─ Mul [id D]
 │     │  ├─ Neg [id E]
 │     │  │  └─ ExpandDims{axis=0} [id F]
 │     │  │     └─ marginal_cost [id G]
 │     │  └─ ExpandDims{axis=0} [id H]
 │     │     └─ price_effect [id I]
 │     ├─ trend [id J]
 │     └─ seasonality [id K]
 └─ Mul [id L]
    ├─ ExpandDims{axis=0} [id H]
    │  └─ ···
    └─ ExpandDims{axis=0} [id M]
       └─ 2 [id N]


<ipykernel.iostream.OutStream at 0x7fcbccd613c0>