Add local global trend notebook (#119)
* add comment on how to implement categorical dist effectively for large number of categories

* resolve the instability of StickBreaking logdet

* use upstream versions

* return wa_update instead of warmup_update

* add comment on categorical sampler

* upstream notebook

* run the experiment on GPU

* use vectorized version for supervised data

* add data and replicated notebook

* expose max_tree_depth

* add halfcauchy, trunccauchy

* add Pyglt notebook

* nonlocal max_tree_depth

* run isort

* expose max_tree_depth, heuristic step_size

* add option use_prims

* nit

* use prims by default for fori_collect

* avoid compiling twice in warmup too

* update logistic regression notebook

* fix typo

* remove tscan test

* expose warmup_update again because we return jitted kernel

* revert changes in constraints.py

* simplify notebook

* use adapt_step_size=False instead of heuristic=False

* update notebook

* fixed point error

* add a working version

* add forecasting

* fix nan issue

* implement seasonal model

* clean up

* revise formula

* update the content for the notebook

* tune forecasting algorithm

* add median plot

* use generic prior for coef_trend

* use generalized seasonality and seasonal avg method

* minor text edits

* remove data file in favor of URL, remove legacy info, and update seasonality=38 causality
fehiepsi committed May 28, 2019
1 parent d11b911 commit 77b7042
Showing 5 changed files with 571 additions and 4 deletions.
560 changes: 560 additions & 0 deletions notebooks/time_series_forecasting.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions numpyro/distributions/discrete.py
@@ -40,6 +40,7 @@
     multinomial,
     poisson,
     promote_shapes,
+    softmax,
     xlog1py,
     xlogy
 )
@@ -55,8 +56,7 @@ def _to_logits_bernoulli(probs):


 def _to_probs_multinom(logits):
-    x = np.exp(logits - np.max(logits, -1, keepdims=True))
-    return x / x.sum(-1, keepdims=True)
+    return softmax(logits, axis=-1)


 def _to_logits_multinom(probs):
5 changes: 5 additions & 0 deletions numpyro/distributions/util.py
@@ -362,6 +362,11 @@ def binary_cross_entropy_with_logits(x, y):
     return np.clip(x, 0) + np.log1p(np.exp(-np.abs(x))) - x * y


+def softmax(x, axis=-1):
+    unnormalized = np.exp(x - np.max(x, axis, keepdims=True))
+    return unnormalized / np.sum(unnormalized, axis, keepdims=True)
+
+
 @custom_transforms
 def cumsum(x):
     return np.cumsum(x, axis=-1)
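The max-subtraction in the new `softmax` helper is what makes it numerically stable: a plain `np.exp(logits)` overflows for large logits, while shifting by the row maximum leaves the normalized result unchanged. A minimal standalone NumPy sketch mirroring the helper (rather than importing it from numpyro):

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtracting the max before exponentiating keeps np.exp from
    # overflowing; the constant shift cancels after normalization.
    unnormalized = np.exp(x - np.max(x, axis, keepdims=True))
    return unnormalized / np.sum(unnormalized, axis, keepdims=True)

logits = np.array([1000.0, 1001.0, 1002.0])  # np.exp(1000.0) alone overflows to inf
probs = softmax(logits)
# Finite, sums to 1, and matches softmax of the shifted logits [0., 1., 2.]
```

The same trick is what `_to_probs_multinom` now inherits by delegating to this helper instead of repeating the exp-and-normalize inline.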
3 changes: 2 additions & 1 deletion numpyro/mcmc.py
@@ -230,7 +230,8 @@ def init_kernel(init_params,
     with tqdm.trange(num_warmup, desc='warmup') as t:
         for i in t:
             hmc_state, wa_state = warmup_update(i, (hmc_state, wa_state))
-            t.set_postfix_str(get_diagnostics_str(hmc_state), refresh=True)
+            # TODO: set refresh=True when its performance issue is resolved
+            t.set_postfix_str(get_diagnostics_str(hmc_state), refresh=False)
     # Reset `i` and `mean_accept_prob` for fresh diagnostics.
     hmc_state.update(i=0, mean_accept_prob=0)
     return hmc_state
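The switch to `refresh=False` records the postfix but defers the screen redraw to tqdm's own update cycle instead of forcing one on every iteration, which is the performance issue the TODO refers to. A small sketch of the pattern, assuming only `tqdm` itself; the `'acc. prob=...'` string is a hypothetical stand-in for `get_diagnostics_str`:

```python
import io
from tqdm import tqdm

out = io.StringIO()  # keep the bar off the real terminal for this demo
postfixes = []
with tqdm(range(3), file=out) as t:
    for i in t:
        # refresh=False stores the postfix without forcing an immediate redraw
        t.set_postfix_str('acc. prob={:.2f}'.format(0.8), refresh=False)
        postfixes.append(t.postfix)
```

The postfix still shows up on the bar's next scheduled refresh, so diagnostics remain visible; only the per-iteration redraw cost is avoided.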
3 changes: 2 additions & 1 deletion numpyro/util.py
@@ -128,7 +128,8 @@ def _body_fn(i, vals):
             val = body_fun(val)
             collection.append(jit(ravel_fn)(val))
             if diagnostics_fn:
-                t.set_postfix_str(diagnostics_fn(val), refresh=True)
+                # TODO: set refresh=True when its performance issue is resolved
+                t.set_postfix_str(diagnostics_fn(val), refresh=False)

     # XXX: jax.numpy.stack/concatenate is currently so slow
     collection = onp.stack(collection)
