Skip to content

Commit

Permalink
Merge pull request #416 from oughtinc/fix-scaling-for-logistic-fit
Browse files Browse the repository at this point in the history
Fix scaling for logistic fit
  • Loading branch information
brachbach committed Sep 29, 2020
2 parents 5de9199 + 10a6f0a commit e174863
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 1 deletion.
5 changes: 4 additions & 1 deletion ergo/platforms/metaculus/question/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,10 @@ def get_submission_from_samples(

normalized_samples = self.scale.normalize_points(samples)
_dist = dist.LogisticMixture.from_samples(
normalized_samples, fixed_params={"num_components": 3}, verbose=verbose
normalized_samples,
fixed_params={"num_components": 3},
verbose=verbose,
scale=Scale(0, 1),
)
return self.prepare_logistic_mixture(_dist)

Expand Down
12 changes: 12 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,18 @@ def logistic_mixture():
)


@pytest.fixture(scope="module")
def smooth_logistic_mixture():
xscale = Scale(1, 1000000.0)
return LogisticMixture(
components=[
Logistic(loc=400000, s=100000, scale=xscale),
Logistic(loc=700000, s=50000, scale=xscale),
],
probs=[0.8, 0.2],
)


@pytest.fixture(scope="module")
def logistic_mixture10():
xscale = Scale(-20, 40)
Expand Down
17 changes: 17 additions & 0 deletions tests/test_metaculus.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,23 @@ def test_get_questions_question_status(metaculus):
).all()


def test_submission_from_samples_smooth(metaculus_questions, smooth_logistic_mixture):
samples = np.array([smooth_logistic_mixture.sample() for _ in range(5000)])
fit_mixture = metaculus_questions.continuous_linear_open_question.get_submission_from_samples(
samples
)
normalized_samples_from_fit_mixture = [fit_mixture.sample() for _ in range(5000)]
mixture_samples = metaculus_questions.continuous_linear_open_question.denormalize_samples(
normalized_samples_from_fit_mixture
)
assert float(np.mean(samples)) == pytest.approx(
float(np.mean(mixture_samples)), rel=0.1
)
assert float(np.var(samples)) == pytest.approx(
float(np.var(mixture_samples)), rel=0.2
)


@pytest.mark.xfail(reason="Fitting doesn't reliably work yet #219")
def test_submission_from_samples_linear(metaculus_questions, logistic_mixture_samples):
normalized_mixture = metaculus_questions.continuous_linear_open_question.get_submission_from_samples(
Expand Down

0 comments on commit e174863

Please sign in to comment.