Skip to content

Commit

Permalink
Merge pull request #3243 from stan-dev/fix/psis-weights-overflow
Browse files Browse the repository at this point in the history
update psis to not add back the max and return the unnormalized resample ratios
  • Loading branch information
SteveBronder committed Dec 8, 2023
2 parents 519a52a + 3c31d68 commit 4637f94
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 42 deletions.
9 changes: 1 addition & 8 deletions src/stan/services/pathfinder/psis.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,14 +270,7 @@ inline Eigen::Array<double, Eigen::Dynamic, 1> psis_weights(
}

// truncate at max of raw wts (i.e., 0 since max has been subtracted)
for (Eigen::Index i = 0; i < llr_weights.size(); ++i) {
if (llr_weights.coeff(i) > 0) {
llr_weights.coeffRef(i) = 0.0;
}
}
auto max_adj = (llr_weights + max_log_ratio).eval();
auto max_adj_exp = max_adj.exp();
return max_adj_exp / max_adj_exp.sum();
return (llr_weights.array() < 0.0).select(llr_weights, 0.0).exp().eval();
}

} // namespace psis
Expand Down
68 changes: 34 additions & 34 deletions src/test/unit/services/pathfinder/psis_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,40 +120,40 @@ TEST(ServicesPSIS, get_psis_weights) {
4.91652175788515, 5.59472544249795, 7.16126247369561, 6.15854810500356,
6.62555174418028, 9.103992907802, 6.7399306070758, 6.04794961458687;
Eigen::Array<double, -1, 1> answer(100);
answer << 0.00614128190501742, 0.00259883000263199, 0.00142240988349195,
0.0274531755800624, 0.0157476498780024, 0.00391823907901988,
8.26414513956511e-14, 0.00322003256443345, 0.00619646954412031,
0.00346793510111242, 0.0444046803091539, 0.004029171121401,
0.00366480923122832, 0.00175331655026224, 0.00491192297921304,
0.000810758279198907, 0.00588439834963157, 0.17822256163288,
0.00297698115204899, 0.00580171797853455, 0.0121325368865641,
0.00408651931997244, 0.00457647470707981, 0.00999374265076167,
4.75420585761083e-06, 0.00250429809109908, 0.00273193115270838,
0.00357946965550952, 0.00255728407443698, 0.00345942783466891,
0.00585416468904412, 0.00866717090821407, 0.00223521120433322,
0.00335369108675144, 0.000278240478565439, 0.00841871317684748,
0.00303339482203412, 0.0108274835712062, 0.00635298901334288,
0.0366679814601552, 0.0128890587651439, 0.0313596935307308,
0.00212538535536358, 0.00646666659496334, 0.00787540323350999,
0.00513279331867044, 0.0114488824095472, 0.00943345186908113,
0.00744772541269413, 0.00262296832653951, 0.00786635558506987,
0.00149440115404536, 0.00693444017895142, 0.00452127888795521,
0.0137317162483214, 0.001172190208605, 0.00470337325067493,
0.00880970502513788, 0.00388287423800577, 0.00615587072760303,
0.0220257764512452, 0.000965367587466615, 0.00414255659344651,
0.00869536963179888, 0.00266583380629748, 0.000131838303785288,
4.379914619164e-16, 0.00226285858615159, 0.0183863334607642,
0.00209481287761514, 2.69620598333185e-11, 2.54440757139945e-05,
0.00400503003258219, 0.0200460546415302, 0.00559698462905835,
0.00716790107562062, 0.00329971992325907, 0.0169711871963573,
0.00473505622866302, 0.057026147481862, 0.00651433804563218,
0.0102597432232274, 0.00575355555045384, 0.00379438966189875,
0.0244370937648389, 0.00151257245292768, 0.00487470007189976,
0.0054709853543924, 0.0044516176713708, 0.00952362373631346,
0.00464287019866733, 9.6288493325318e-09, 0.00147240135547797,
0.00290112423098392, 0.0146773422920254, 0.00509837055791414,
0.00813295746634018, 0.082538932616576, 0.00911848336266506,
0.00456456171607177;
answer << 0.0292354130264213, 0.0123716627387399, 0.00677134531192521,
0.130690129419765, 0.0749662782948843, 0.0186526753832652,
3.93412483260354e-13, 0.0153288814022406, 0.0294981323492076,
0.0165090149903415, 0.211387327470457, 0.019180764490236,
0.0174462291741511, 0.00834661790579885, 0.0233830817852443,
0.00385959373361536, 0.0280125255319936, 0.848423876407167,
0.0141718414653653, 0.0276189277725056, 0.0577566267796964,
0.019453769100265, 0.0217861889969809, 0.0475749523623197,
2.26322735235991e-05, 0.0119216460287785, 0.0130052873071864,
0.0170399723401178, 0.0121738844264701, 0.0164685163693616,
0.0278685989759023, 0.0412598420315477, 0.0106406648922235,
0.0159651593267689, 0.00132455657260805, 0.0400770654535543,
0.0144403972763279, 0.0515439543627507, 0.0302432392178012,
0.174556973513457, 0.0613580295366247, 0.149287006675898,
0.0101178417902614, 0.0307843983930222, 0.0374906525774646,
0.0244345293004686, 0.0545021073956205, 0.0449077026462719,
0.0354547034177988, 0.0124865726028556, 0.037447581482023,
0.0071140578858597, 0.0330112223920261, 0.0215234307334137,
0.0653694786023926, 0.00558018104755882, 0.0223902862183483,
0.0419383719937075, 0.0184843219760041, 0.0293048627374842,
0.104853136810685, 0.00459560733058814, 0.0197205330854113,
0.041394081322883, 0.0126906326060736, 0.000627612821472451,
2.08504698029095e-15, 0.0107722795353974, 0.0875276629668783,
0.00997230230387398, 1.28352185661132e-10, 0.000121125861681488,
0.0190658414638319, 0.0954287225467601, 0.0266442999790266,
0.0341226069278941, 0.0157082310045359, 0.0807908959249787,
0.0225411122122881, 0.271471494180354, 0.0310113370959555,
0.0488412411797974, 0.0273896517843532, 0.0180630934492948,
0.116332150262759, 0.00720056188230011, 0.0232058831016178,
0.0260444837040436, 0.0211918103208897, 0.0453369634785661,
0.0221022630104646, 4.58378871967692e-08, 0.00700932841609052,
0.0138107265625568, 0.0698711068265654, 0.02427066064181,
0.0387167328144493, 0.392924445274868, 0.0434083032508679,
0.0217294775126546;
stan::test::test_logger warner;
auto blah = stan::services::psis::psis_weights(lrms, 20, warner);
for (Eigen::Index i = 0; i < answer.size(); ++i) {
Expand Down

0 comments on commit 4637f94

Please sign in to comment.