Skip to content

Commit

Permalink
add options to run bnn and gp in gpu (#195)
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi authored and neerajprad committed Jun 7, 2019
1 parent c5091a5 commit 3291ea6
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
3 changes: 3 additions & 0 deletions examples/bnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import jax.numpy as np
import jax.random as random
from jax import vmap
from jax.config import config as jax_config

import numpyro.distributions as dist
from numpyro.handlers import sample, seed, substitute, trace
Expand Down Expand Up @@ -91,6 +92,7 @@ def get_data(N=50, D_X=3, sigma_obs=0.05, N_test=500):


def main(args):
jax_config.update('jax_platform_name', args.device)
N, D_X, D_H = args.num_data, 3, args.num_hidden
X, Y, X_test = get_data(N=N, D_X=D_X)

Expand Down Expand Up @@ -128,5 +130,6 @@ def main(args):
parser.add_argument("--num-warmup", nargs='?', default=1000, type=int)
parser.add_argument("--num-data", nargs='?', default=100, type=int)
parser.add_argument("--num-hidden", nargs='?', default=5, type=int)
parser.add_argument("--device", default='cpu', type=str, help='use "cpu" or "gpu".')
args = parser.parse_args()
main(args)
8 changes: 6 additions & 2 deletions examples/gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import jax.numpy as np
import jax.random as random
from jax import vmap
from jax.config import config as jax_config

import numpyro.distributions as dist
from numpyro.handlers import sample
Expand Down Expand Up @@ -87,7 +88,8 @@ def get_data(N=30, sigma_obs=0.15, N_test=400):


def main(args):
X, Y, X_test = get_data(N=25)
jax_config.update('jax_platform_name', args.device)
X, Y, X_test = get_data(N=args.num_data)

# do inference
rng, rng_predict = random.split(random.PRNGKey(0))
Expand Down Expand Up @@ -118,8 +120,10 @@ def main(args):


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(description="Gaussian Process example")
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
parser.add_argument("--num-warmup", nargs='?', default=1000, type=int)
parser.add_argument("--num-data", nargs='?', default=25, type=int)
parser.add_argument("--device", default='cpu', type=str, help='use "cpu" or "gpu".')
args = parser.parse_args()
main(args)

0 comments on commit 3291ea6

Please sign in to comment.