Skip to content

Commit

Permalink
ruff format notebooks (#1765)
Browse files Browse the repository at this point in the history
  • Loading branch information
ordabayevy committed Mar 17, 2024
1 parent b81e2d1 commit 013e54c
Show file tree
Hide file tree
Showing 14 changed files with 412 additions and 404 deletions.
11 changes: 4 additions & 7 deletions notebooks/source/bad_posterior_geometry.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,12 @@
"\n",
"import numpy as np\n",
"\n",
"import jax.numpy as jnp\n",
"from jax import random\n",
"import jax.numpy as jnp\n",
"\n",
"import numpyro\n",
"import numpyro.distributions as dist\n",
"from numpyro.diagnostics import summary\n",
"\n",
"import numpyro.distributions as dist\n",
"from numpyro.infer import MCMC, NUTS\n",
"\n",
"assert numpyro.__version__.startswith(\"0.14.0\")\n",
Expand Down Expand Up @@ -403,7 +402,7 @@
" x2 = numpyro.sample(\n",
" \"x2\", dist.MultivariateNormal(jnp.zeros(2), covariance_matrix=cov)\n",
" )\n",
" y = numpyro.sample(\"y\", dist.Normal(jnp.zeros(100), 1.0))\n",
" numpyro.sample(\"y\", dist.Normal(jnp.zeros(100), 1.0))\n",
" numpyro.sample(\"obs\", dist.Normal(x1 - x2, 0.1), jnp.ones(2))"
]
},
Expand Down Expand Up @@ -498,9 +497,7 @@
"\n",
"\n",
"def mvn_model():\n",
" x = numpyro.sample(\n",
" \"x\", dist.MultivariateNormal(jnp.zeros(dim), covariance_matrix=cov)\n",
" )\n",
" numpyro.sample(\"x\", dist.MultivariateNormal(jnp.zeros(dim), covariance_matrix=cov))\n",
"\n",
"\n",
"print(\"max_tree_depth = 5 (bad r_hat)\")\n",
Expand Down
11 changes: 6 additions & 5 deletions notebooks/source/bayesian_hierarchical_linear_regression.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import seaborn as sns\n",
"import matplotlib.pyplot as plt"
"import pandas as pd\n",
"import seaborn as sns"
]
},
{
Expand Down Expand Up @@ -240,10 +240,11 @@
"metadata": {},
"outputs": [],
"source": [
"from jax import random\n",
"\n",
"import numpyro\n",
"from numpyro.infer import MCMC, NUTS, Predictive\n",
"import numpyro.distributions as dist\n",
"from jax import random\n",
"from numpyro.infer import MCMC, NUTS, Predictive\n",
"\n",
"assert numpyro.__version__.startswith(\"0.14.0\")"
]
Expand Down
10 changes: 5 additions & 5 deletions notebooks/source/bayesian_hierarchical_stacking.ipynb

Large diffs are not rendered by default.

5 changes: 1 addition & 4 deletions notebooks/source/bayesian_imputation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,10 @@
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"from jax import numpy as jnp\n",
"from jax import random\n",
"from jax.scipy.special import expit\n",
"from jax import numpy as jnp, random\n",
"\n",
"import numpyro\n",
"from numpyro import distributions as dist\n",
"from numpyro.distributions import constraints\n",
"from numpyro.infer import MCMC, NUTS, Predictive\n",
"\n",
"plt.style.use(\"seaborn\")\n",
Expand Down
682 changes: 342 additions & 340 deletions notebooks/source/bayesian_regression.ipynb

Large diffs are not rendered by default.

13 changes: 8 additions & 5 deletions notebooks/source/discrete_imputation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,17 @@
"metadata": {},
"outputs": [],
"source": [
"import numpyro\n",
"from jax import numpy as jnp, random, ops\n",
"from math import inf\n",
"\n",
"from graphviz import Digraph\n",
"\n",
"from jax import numpy as jnp, random\n",
"from jax.scipy.special import expit\n",
"\n",
"import numpyro\n",
"from numpyro import distributions as dist, sample\n",
"from numpyro.infer.mcmc import MCMC\n",
"from numpyro.infer.hmc import NUTS\n",
"from math import inf\n",
"from graphviz import Digraph\n",
"from numpyro.infer.mcmc import MCMC\n",
"\n",
"simkeys = random.split(random.PRNGKey(0), 10)\n",
"nsim = 5000\n",
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/logistic_regression.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
"\n",
"import numpy as np\n",
"\n",
"import jax.numpy as jnp\n",
"from jax import random\n",
"import jax.numpy as jnp\n",
"\n",
"import numpyro\n",
"import numpyro.distributions as dist\n",
Expand Down
7 changes: 4 additions & 3 deletions notebooks/source/model_rendering.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,14 @@
"metadata": {},
"outputs": [],
"source": [
"import flax\n",
"import flax.linen as flax_nn\n",
"from numpyro.contrib.module import flax_module\n",
"import numpy as np\n",
"\n",
"import flax.linen as flax_nn\n",
"from jax import nn\n",
"import jax.numpy as jnp\n",
"\n",
"import numpyro\n",
"from numpyro.contrib.module import flax_module\n",
"import numpyro.distributions as dist\n",
"import numpyro.distributions.constraints as constraints\n",
"\n",
Expand Down
14 changes: 8 additions & 6 deletions notebooks/source/ordinal_regression.ipynb

Large diffs are not rendered by default.

15 changes: 8 additions & 7 deletions notebooks/source/tbip.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,11 @@
},
"outputs": [],
"source": [
"import numpy as np\n",
"from scipy import sparse\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"\n",
"dataPath = \"tbip/data/senate-speeches-114/clean/\"\n",
"\n",
Expand Down Expand Up @@ -273,7 +274,7 @@
},
"outputs": [],
"source": [
"from numpyro import plate, sample, param\n",
"from numpyro import param, plate, sample\n",
"import numpyro.distributions as dist\n",
"from numpyro.distributions import constraints\n",
"\n",
Expand Down Expand Up @@ -410,9 +411,10 @@
"outputs": [],
"source": [
"# Initialize the model\n",
"from jax import jit\n",
"from optax import adam, exponential_decay\n",
"\n",
"from numpyro.infer import SVI, TraceMeanField_ELBO\n",
"from jax import jit\n",
"\n",
"num_steps = 50000\n",
"batch_size = 512 # Large batches are recommended\n",
Expand Down Expand Up @@ -543,8 +545,8 @@
],
"source": [
"# Run SVI\n",
"from tqdm import tqdm\n",
"import pandas as pd\n",
"from tqdm import tqdm\n",
"\n",
"print_steps = 100\n",
"print_intermediate_results = False\n",
Expand Down Expand Up @@ -623,7 +625,6 @@
},
"outputs": [],
"source": [
"import os\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"\n",
Expand Down Expand Up @@ -726,7 +727,7 @@
"\n",
"\n",
"def create_svi_object(guide):\n",
" svi_object = SVI(\n",
" SVI(\n",
" model=tbip.model,\n",
" guide=guide,\n",
" optim=adam(exponential_decay(learning_rate, num_steps, decay_rate)),\n",
Expand Down Expand Up @@ -789,7 +790,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
"version": "3.10.13"
},
"vscode": {
"interpreter": {
Expand Down
6 changes: 3 additions & 3 deletions notebooks/source/time_series_forecasting.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,17 @@
"source": [
"import os\n",
"\n",
"from IPython.display import set_matplotlib_formats\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"from IPython.display import set_matplotlib_formats\n",
"\n",
"import jax.numpy as jnp\n",
"from jax import random\n",
"import jax.numpy as jnp\n",
"\n",
"import numpyro\n",
"import numpyro.distributions as dist\n",
"from numpyro.contrib.control_flow import scan\n",
"from numpyro.diagnostics import autocorrelation, hpdi\n",
"import numpyro.distributions as dist\n",
"from numpyro.infer import MCMC, NUTS, Predictive\n",
"\n",
"if \"NUMPYRO_SPHINXBUILD\" in os.environ:\n",
Expand Down
18 changes: 10 additions & 8 deletions notebooks/source/truncated_distributions.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -48,27 +48,29 @@
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from scipy.stats import poisson as sp_poisson\n",
"\n",
"import jax\n",
"from jax import lax, random\n",
"import jax.numpy as jnp\n",
"from jax.scipy.special import ndtri\n",
"from jax.scipy.stats import norm, poisson\n",
"\n",
"import numpyro\n",
"import numpyro.distributions as dist\n",
"from jax import lax, random\n",
"from jax.scipy.special import ndtr, ndtri\n",
"from jax.scipy.stats import poisson, norm\n",
"from numpyro.distributions import (\n",
" constraints,\n",
" Distribution,\n",
" FoldedDistribution,\n",
" SoftLaplace,\n",
" StudentT,\n",
" TruncatedDistribution,\n",
" TruncatedNormal,\n",
" constraints,\n",
")\n",
"from numpyro.distributions.util import promote_shapes\n",
"from numpyro.infer import DiscreteHMCGibbs, MCMC, NUTS, Predictive\n",
"from scipy.stats import poisson as sp_poisson\n",
"from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs, Predictive\n",
"\n",
"numpyro.enable_x64()\n",
"RNG = random.PRNGKey(0)\n",
Expand Down
21 changes: 11 additions & 10 deletions notebooks/source/variationally_inferred_parameterization.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -38,20 +38,20 @@
},
"outputs": [],
"source": [
"import jax\n",
"import numpyro\n",
"import arviz as az\n",
"import numpy as np\n",
"import pandas as pd\n",
"import jax.numpy as jnp\n",
"from numpyro.infer import MCMC, NUTS\n",
"import numpyro.distributions as dist\n",
"from ucimlrepo import fetch_ucirepo\n",
"\n",
"rng_key = jax.random.PRNGKey(0)\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"import numpyro\n",
"import numpyro.distributions as dist\n",
"from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO\n",
"from numpyro.infer.autoguide import AutoDiagonalNormal\n",
"from numpyro.infer.reparam import LocScaleReparam\n",
"from numpyro.infer import SVI, Trace_ELBO\n",
"from numpyro.infer.autoguide import AutoDiagonalNormal"
"\n",
"rng_key = jax.random.PRNGKey(0)"
]
},
{
Expand Down Expand Up @@ -1338,7 +1338,8 @@
"name": "python3"
},
"language_info": {
"name": "python"
"name": "python",
"version": "3.10.13"
}
},
"nbformat": 4,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ exclude = [
# Same as Black.
line-length = 88
indent-width = 4
extend-include = ["*.ipynb"]

[tool.ruff.lint]
select = ["E", "F", "I", "W"]
Expand Down

0 comments on commit 013e54c

Please sign in to comment.