# In [None]:
using JuMP
using Ipopt   # Or your favorite NLP solver
using InfiniteOpt, Distributions, NLPModelsIpopt, Ipopt, Random, MadNLP


"""
    PIVOParams

Scalar parameters of the PIVO product/customer model, analogous to `p` in
pivo_simple.jl. The struct is immutable and has only the default positional
constructor, so values must be supplied in the field order declared below.

Roles in the InfiniteOpt model of this file (x1 = product, x2 = customers ever
tried, x3 = customers lost, x4 = accumulated PnL, x5 = value added, w = share of
the team on product):
- product growth and customer acquisition: `productAdvRate`, `acquireSpeed`,
  `repMult`, `marketCap`, `compRate`;
- work demand and churn: `initWorkGen`, `recurringWorkPerCustomer`, `churnProb`;
- cash flow: `salary`, `recurringFee`, `onboardFee`;
- value delivery: `productProductivityCap`.
"""
struct PIVOParams
    productAdvRate::Float64            # d(x1)/dt = productAdvRate*(1 - x1)*w
    marketCap::Float64                 # enters as (marketCap - x2) in gain and competitor term
    churnProb::Float64                 # multiplies softplus(unmet) in the churn rate
    initWorkGen::Float64               # work generated per unit of customer gain
    recurringWorkPerCustomer::Float64  # work generated per current (x2 - x3) customer
    repMult::Float64                   # rep = repMult + (1 - repMult)*(x2 - x3)/x2
    acquireSpeed::Float64              # prefactor of the customer gain term
    salary::Float64                    # team payroll in d(x4)/dt; also scales acquisition cost
    recurringFee::Float64              # revenue per current customer in d(x4)/dt
    onboardFee::Float64                # revenue per signup in d(x4)/dt
    productProductivityCap::Float64    # value delivery denominator uses 1/productProductivityCap
    compRate::Float64                  # competitor effect compRate*(marketCap - x2) in d(x2)/dt
end


# Model parameters (tweak as desired). Each value is named after the PIVOParams
# field it fills, then all are passed positionally in the struct's field order.
p = let
    productAdvRate           = 0.1
    marketCap                = 100.0
    churnProb                = 0.3
    initWorkGen              = 1.0
    recurringWorkPerCustomer = 0.5
    repMult                  = 0.2
    acquireSpeed             = 0.3
    salary                   = 2.0
    recurringFee             = 0.3
    onboardFee               = 30.0
    productProductivityCap   = 50.0
    compRate                 = 0.04
    PIVOParams(productAdvRate, marketCap, churnProb, initWorkGen,
               recurringWorkPerCustomer, repMult, acquireSpeed, salary,
               recurringFee, onboardFee, productProductivityCap, compRate)
end

# Time horizon: the model below is posed on t ∈ [0, T].
T = 30.0


# 1) Create an InfiniteOpt model solved with Ipopt.
model = InfiniteModel(Ipopt.Optimizer)

# 2) Infinite parameter t ∈ [0, T] with 31 supports; derivatives with respect to t
#    are discretized by orthogonal collocation.
@infinite_parameter(model, t in [0, T], num_supports = 31,
                    derivative_method = OrthogonalCollocation(3))

# 3) States as infinite variables, matching the state vector of pivo_simple.jl:
#       x1(t) = product(t)
#       x2(t) = nCustomersEverTried(t)
#       x3(t) = totalLostCustomers(t)
#       x4(t) = accumulatedPnL(t)
#       x5(t) = valueAddedToCustomers(t)
#    All are constrained nonnegative; `start` values only seed the solver.
@variable(model, x1 >= 0, Infinite(t), start = 0.1)    # product(t)
@variable(model, x2 >= 0, Infinite(t), start = 1.0)    # nEverTried(t)
@variable(model, x3 >= 0, Infinite(t), start = 0.1)    # lost(t)
@variable(model, x4 >= 0, Infinite(t), start = 100.0)  # accumPnL(t)
@variable(model, x5 >= 0, Infinite(t), start = 1.0)    # valueAdded(t)

# 4) Control w(t) = fraction of the team working on product, 0 ≤ w(t) ≤ 1.
#    The remainder (1 - w(t)) works with customers.
@variable(model, 0 <= w <= 1, Infinite(t), start = 0.5)

#    With collocation, reduce the control's degrees of freedom by holding w(t)
#    constant between collocation nodes. Older InfiniteOpt releases lack
#    `constant_over_collocation`, so warn instead of failing in that case.
if isdefined(InfiniteOpt, :constant_over_collocation)
    InfiniteOpt.constant_over_collocation(w, t)
else
    @warn "InfiniteOpt.constant_over_collocation is unavailable; " *
          "w(t) may vary between collocation nodes."
end

# 5) Initial conditions at t = 0, using restricted-variable syntax x(0).
@constraint(model, x1(0) == 0.05)   # initial product level
@constraint(model, x2(0) == 1.0)    # customers ever tried
@constraint(model, x3(0) == 0.1)    # customers lost
@constraint(model, x4(0) == 100.0)  # accumulated PnL
@constraint(model, x5(0) == 1.0)    # value added to customers

# 6) Auxiliary expressions.
#    totalCurrent(t) = x2(t) - x3(t): customers tried minus customers lost.
@expression(model, totalCurrent, x2 - x3)

#    rep(t) = repMult + (1 - repMult) * totalCurrent(t) / x2(t);
#    the 1e-6 keeps the denominator positive when x2(t) = 0.
@expression(model, rep,
    p.repMult + (1.0 - p.repMult) * totalCurrent / (x2 + 1e-6)
)

#    gain(t) = acquireSpeed * rep(t) * (marketCap - x2(t)) * x1(t)
@expression(model, gain,
    p.acquireSpeed * rep * (p.marketCap - x2) * x1
)

#    currentWork(t) = initWorkGen * gain(t) + recurringWorkPerCustomer * totalCurrent(t):
#    work demand, later compared against d(x5)/dt to measure unmet work for churn.
@expression(model, currentWork,
    p.initWorkGen * gain + p.recurringWorkPerCustomer * totalCurrent
)


#    churn(t) needs max(0, u); use a smooth approximation so the model stays
#    differentiable for the NLP solver:
#        softplus(u) = log(1 + exp(k*u)) / k   (sharper corner for larger k).
"""
    softplus(u, k=50.0)

Smooth approximation of `max(0, u)`: `log(1 + exp(k*u)) / k`. For `k > 0` it exceeds
`max(0, u)` by at most `log(2)/k`, so larger `k` gives a sharper corner.

This generic method builds the literal formula, so symbolic arguments (e.g.
JuMP/InfiniteOpt expressions) become a smooth nonlinear expression for the solver.
"""
function softplus(u, k=50.0)
    return (1 / k) * log(1 + exp(k * u))
end

"""
    softplus(u::Real, k=50.0)

Numeric method for the same function, evaluated in the overflow-free form
`(max(k*u, 0) + log1p(exp(-abs(k*u)))) / k`. The literal formula overflows `exp`
(returning `Inf`) once `k*u` exceeds about 709 in Float64.
"""
function softplus(u::Real, k=50.0)
    ku = k * u
    # log(1 + e^x) == max(x, 0) + log1p(e^{-|x|}); the exponent is never positive.
    return (max(ku, zero(ku)) + log1p(exp(-abs(ku)))) / k
end


#    unmet(t) = currentWork(t) - d(x5)/dt: a rough measure of work not covered by
#    the rate at which value is delivered to customers.
@expression(model, unmet,
    currentWork - deriv(x5, t)
)

#    churnRate(t) = churnProb * softplus(unmet(t)), a smooth stand-in for
#    churnProb * max(0, unmet(t)).
@expression(model, churnRate,
    p.churnProb * softplus(unmet)
)

# 7) ODE constraints, written with InfiniteOpt derivative operators.
#    d/dt product = productAdvRate * (1 - product) * w(t)
@constraint(model, deriv_x1, deriv(x1, t) == p.productAdvRate * (1 - x1) * w)

#    d/dt nEverTried = gain(t) - competitor effect compRate * (marketCap - nEverTried)
@constraint(model, deriv_x2, deriv(x2, t) == gain - p.compRate * (p.marketCap - x2))

#    d/dt lost = churnRate(t)
@constraint(model, deriv_x3, deriv(x3, t) == churnRate)

#    d/dt accumPnL = onboarding revenue - acquisition cost + recurring revenue - payroll.
#    signups(t) is a smooth max(0, gain(t)), so a negative gain never earns onboarding
#    fees. (It was previously gain*softplus(gain), which grows like gain².)
@expression(model, signups, softplus(gain))

#    Multiplier on salary charged per signup (placeholder for a logistic cost model).
@expression(model, costAcquire, 0.4)

@constraint(model, deriv_x4, deriv(x4, t) ==
    p.onboardFee * signups -
    costAcquire * p.salary * signups +
    p.recurringFee * totalCurrent -
    p.salary    # whole team is paid: product share w(t) plus customer share 1 - w(t)
)

#    d/dt valueAdded = (x1 / denom) * (1 - w(t)), a simplified form of pivo_simple.jl
#    with denom = (1 - x1) + 1 / productProductivityCap.
@expression(model, denom,
    (1 - x1) + (1 / p.productProductivityCap)
)

@constraint(model, deriv_x5, deriv(x5, t) == (x1 / denom) * (1 - w))

# 8) Note: with collocation, the control w(t) can be held constant between
#    collocation nodes via InfiniteOpt's `constant_over_collocation(w, t)`, called
#    once after `w` and `t` are defined.

# 9) Objective: maximize the final accumulated PnL x4(T).
#    x4(T) evaluates x4 at t = T; tying it to the finite variable finalPnL lets us
#    read it back with value(finalPnL). Minimizing -finalPnL is equivalent to
#    maximizing it, so the reported objective value equals -x4(T).
@variable(model, finalPnL)
@constraint(model, finalPnL_def, finalPnL == x4(T))
@objective(model, Min, -finalPnL)

# 10) Solve and report.
optimize!(model)

println("Termination status: ", JuMP.termination_status(model))

# Only query values when the solver returned a primal solution; otherwise
# objective_value/value would throw.
if JuMP.has_values(model)
    println("Objective value (negative means we are maximizing x4(T)): ",
            JuMP.objective_value(model))
    final_pnl_value = value(finalPnL)
    println("Optimal final PnL ≈ ", final_pnl_value)

    # Control w(t) at the (public) supports of t.
    w_vals = value(w)
    w_ts   = supports(w)
    println("\nControl w(t) at discrete points:")
    for (t_supp, w_val) in zip(w_ts, w_vals)
        println(" t = $(t_supp[1]):  w(t) = $w_val")
    end

    # Accumulated PnL x4(t) at the same supports.
    x4_vals = value(x4)
    x4_ts   = supports(x4)
    println("\naccumPnL x4(t) at discrete points:")
    for (t_supp, x4_val) in zip(x4_ts, x4_vals)
        println(" t = $(t_supp[1]):  x4(t) = $x4_val")
    end
else
    @warn "No primal solution available; skipping the solution report."
end