Skip to content

Commit

Permalink
fix insufficient of prediction accuracy due to boundary point (szcf-w…
Browse files Browse the repository at this point in the history
  • Loading branch information
szcf-weiya committed Sep 9, 2023
1 parent 332b6d9 commit a295dc8
Showing 1 changed file with 21 additions and 11 deletions.
32 changes: 21 additions & 11 deletions src/mono_decomp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ function gurobi(nthread::Int = 0)
# and inspired by https://support.gurobi.com/hc/en-us/community/posts/4412624836753-Do-not-print-Set-parameter-Username-in-console
GRBsetintparam(GRB_ENV, GRB_INT_PAR_OUTPUTFLAG, 0)
GRBsetintparam(GRB_ENV, GRB_INT_PAR_THREADS, nthread)
# GRBsetdblparam(GRB_ENV, GRB_DBL_PAR_OPTIMALITYTOL, 1e-9)
# GRBsetdblparam(GRB_ENV, GRB_DBL_PAR_FEASIBILITYTOL, 1e-9)
global OPTIMIZER = () -> Gurobi.Optimizer(GRB_ENV)
end

Expand Down Expand Up @@ -496,7 +498,7 @@ function cv_mono_decomp_ss(x::AbstractVector{T}, y::AbstractVector{T}; figname =
# workspace has been defined, so J would be inherited regardless of prop_nknots
D = mono_decomp_ss(D.workspace, x, y, λopt, μopt)
end
return D, μmin, μs, errs, σerrs, yhat, yhatnew
return D, μmin, μs, errs, σerrs, yhat, yhatnew, γss
end

function summary_res(σs = 0.2:0.2:1.0)
Expand Down Expand Up @@ -842,7 +844,7 @@ end
strict && error(status)
γhats[:, i] .= mean(y) / 2
else
if !(status in [MOI.OPTIMAL, MOI.ALMOST_OPTIMAL])
if !(status in [MOI.OPTIMAL, MOI.ALMOST_OPTIMAL, MOI.LOCALLY_SOLVED])
@debug "$status when λ=, μ=: direct take the solution"
strict && error(status)
end
Expand Down Expand Up @@ -877,7 +879,7 @@ end
strict && error(status)
γhats[:, i] .= mean(y) / 2
else
if !(status in [MOI.OPTIMAL, MOI.ALMOST_OPTIMAL])
if !(status in [MOI.OPTIMAL, MOI.ALMOST_OPTIMAL, MOI.LOCALLY_SOLVED])
@debug "$status"
strict && error(status)
end
Expand Down Expand Up @@ -912,7 +914,7 @@ end
strict && error(status)
γhats[:, i] .= mean(y) / 2
else
if !(status in [MOI.OPTIMAL, MOI.ALMOST_OPTIMAL])
if !(status in [MOI.OPTIMAL, MOI.ALMOST_OPTIMAL, MOI.LOCALLY_SOLVED])
@debug "$status"
strict && error(status)
end
Expand Down Expand Up @@ -986,7 +988,7 @@ function _optim!(y::AbstractVector{T}, J::Int, B::AbstractMatrix{T}, s::Union{No
γhat .= mean(y) / 2
end
else
if !(status in [MOI.OPTIMAL, MOI.ALMOST_OPTIMAL])
if !(status in [MOI.OPTIMAL, MOI.ALMOST_OPTIMAL, MOI.LOCALLY_SOLVED])
@debug "$status"
strict && error(status)
end
Expand Down Expand Up @@ -1096,7 +1098,7 @@ function predict(W::WorkSpaceSS, xnew::AbstractVector, γup::AbstractVector, γd
xm = (xnew .- W.mx) ./ W.rx
# evaluate on the whole dataset, so all should be in the middle
@assert all(0 .<= xm .<= 1)
Bnew = rcopy(R"splines::bs($xm, intercept = TRUE, knots=$(W.knots[2:end-1]))")
Bnew = rcopy(R"splines::bs($xm, intercept = TRUE, knots=$(W.knots[2:end-1]), Boundary.knots = c(0, 1))")
return Bnew * γup, Bnew * γdown
end

Expand All @@ -1122,7 +1124,7 @@ function predict(W::WorkSpaceSS, xnew::AbstractVector, γhat::AbstractVecOrMat)
Bnew = zeros(0, W.J)
else
# xm cannot be empty
Bnew = rcopy(R"splines::bs($xm, intercept = TRUE, knots=$(W.knots[2:end-1]))")
Bnew = rcopy(R"splines::bs($xm, intercept = TRUE, knots=$(W.knots[2:end-1]), Boundary.knots = c(0, 1))")
end
if isa(γhat, AbstractVector)
yhat = zeros(n)
Expand Down Expand Up @@ -1390,15 +1392,23 @@ function cvplot(μerr::AbstractMatrix{T}, σerr::AbstractMatrix{T}, para1::Abstr
n, m = size(μerr)
@assert n == length(para1)
@assert m == length(para2)
if para1[2] - para1[1] para1[3] - para1[2]
if length(para1) < 3
f = x -> x
else
f = x -> log10(x)
if para1[2] - para1[1] para1[3] - para1[2]
f = x -> x
else
f = x -> log10(x)
end
end
if para2[2] - para2[1] para2[3] - para2[2]
if length(para2) < 3
g = x -> x
else
g = x -> log10(x)
if para2[2] - para2[1] para2[3] - para2[2]
g = x -> x
else
g = x -> log10(x)
end
end
p = heatmap(g.(para2), f.(para1), μerr, xlab = lbl[2], ylab = lbl[1], title = title)
ind = argmin(μerr)
Expand Down

0 comments on commit a295dc8

Please sign in to comment.