In [1]:
import os

# Restrict this kernel to physical GPU 1. This runs before jax is imported
# further down; the device listing printed later shows a single device,
# re-indexed as CudaDevice(id=0).
os.environ.update(CUDA_VISIBLE_DEVICES='1')

In [2]:
# Exported before jax is imported below. Per the JAX GPU memory allocation docs
# this sets the fraction of GPU memory the XLA client preallocates (0.90 here).
%env XLA_PYTHON_CLIENT_MEM_FRACTION=0.90

env: XLA_PYTHON_CLIENT_MEM_FRACTION=0.90


In [3]:
import time
import jax
import wandb

In [4]:
# Sanity check: show which devices and which default backend JAX picked up.
jax.devices(), jax.default_backend()

([CudaDevice(id=0)], 'gpu')

In [5]:
from grid_maze3 import Grid_Maze

In [6]:
def make_benchmark(config):
	"""Build a rollout function that drives ``Grid_Maze`` with sampled actions.

	Args:
		config: dict providing ``"ENV_KWARGS"`` (forwarded to ``Grid_Maze``),
			``"NUM_ENVS"`` and ``"NUM_STEPS"``. It is mutated in place: the key
			``"NUM_ACTORS"`` is set to ``env.num_agents * config["NUM_ENVS"]``.

	Returns:
		``benchmark(rng)``, which resets ``NUM_ENVS`` vmapped environments and
		advances them ``NUM_STEPS`` times. It returns the raw ``jax.lax.scan``
		result, i.e. the pair ``(carry, None)`` where
		``carry == (env_state, obsv, rng)``.
	"""
	maze = Grid_Maze(**config["ENV_KWARGS"])
	config["NUM_ACTORS"] = maze.num_agents * config["NUM_ENVS"]

	def benchmark(rng):
		"""Reset every environment, then roll all of them forward under scan."""
		# Read whenever this Python body runs (not when make_benchmark is
		# called), so a later change to config["NUM_ENVS"] is picked up.
		n_envs = config["NUM_ENVS"]

		def advance(carry, _):
			"""One scan iteration: sample an action batch and step every env."""
			# The previous observation is carried along but not consumed here.
			state, _prev_obs, key = carry

			# One sampling key per environment.
			key, sample_key = jax.random.split(key)
			sample_keys = jax.random.split(sample_key, n_envs).reshape((n_envs, -1))
			actions = jax.vmap(maze.action_spaces.sample)(sample_keys)

			# One transition key per environment.
			key, step_key = jax.random.split(key)
			step_keys = jax.random.split(step_key, n_envs)
			state, obs, _, _ = jax.vmap(maze.step)(step_keys, state, actions)

			return (state, obs, key), None

		# Build the initial carry: batched reset plus the key that seeds the scan.
		_, init_key = jax.random.split(rng)
		carry_key, reset_key = jax.random.split(init_key)
		reset_keys = jax.random.split(reset_key, n_envs)
		state, obs, _ = jax.vmap(maze.reset)(reset_keys)

		return jax.lax.scan(
			advance,
			(state, obs, carry_key),
			xs=None,
			length=config["NUM_STEPS"],
		)

	return benchmark

In [None]:
# Throughput sweep over the number of parallel environments, one W&B run per
# agent count. Each measurement compiles the rollout ahead of time and times a
# single blocking execution of it.
for num_agents in [40]:
	config = {
		"NUM_STEPS": 5,
		"NUM_ENVS": 1000,
		"ACTIVATION": "relu",
		"ENV_NAME": "grid_maze",
		"NUM_SEEDS": 1,
		"SEED": 0,
	}

	config["ENV_KWARGS"] = {
		"width": 20,
		"height": 20,
		"obstacle_density": 0.5,
		"num_agents": num_agents,
		"grain_factor": 4,
		"obstacle_size": 0.4,
		"contact_force": 500,
		"contact_margin": 1e-3,
		"dt": 0.01,
		"max_steps": 500,
		"frameskip": 4,
	}

	wandb.init(
		project="env_comparisons",
		config={
			"n_runs": 1,
			"rollout_length": config["NUM_STEPS"],
			"device": str(jax.devices()[0]),
			# NOTE(review): key is misspelled ("benchamrk"); left unchanged so
			# the logged config schema stays the same.
			"benchamrk_config": config
		},
		name="rl2/myenv_jax"
	)

	# try/finally: close the W&B run even when a benchmark raises part-way.
	try:
		### JAXMARL BENCHMARK
		num_envs = [100, 600, 1100, 1600, 2100, 2600, 3100, 3600, 4100, 4600]
		for num in num_envs:
			config["NUM_ENVS"] = num

			total_time = 0.
			for run in range(wandb.config.n_runs):
				jax.clear_caches()
				# Jitted exactly once, below; wrapping it in jax.jit twice was redundant.
				benchmark_fn = make_benchmark(config)
				rng = jax.random.PRNGKey(config["SEED"])
				rng, _rng = jax.random.split(rng)

				# Lower and compile up front so compilation is not timed.
				benchmark_jit = jax.jit(benchmark_fn).lower(_rng).compile()

				# The timed call runs the whole compiled benchmark: the vmapped
				# reset plus the NUM_STEPS-step scan built by make_benchmark.
				before = time.perf_counter_ns()
				runner_state = jax.block_until_ready(benchmark_jit(_rng))
				after = time.perf_counter_ns()

				total_time += (after - before) / 1e9

			# Constructed only to read the agent/entity counts for logging.
			env = Grid_Maze(**config["ENV_KWARGS"])

			# SPS: environment steps per second over all runs; OPS scales it by agent count.
			sps = wandb.config.n_runs * config['NUM_STEPS'] * config['NUM_ENVS'] / total_time
			ops = sps * env.num_agents

			wandb.log({"num_envs": config["NUM_ENVS"], "SPS": sps, "OPS": ops, "n_agents": env.num_agents, "n_objects": env.num_entities})
	finally:
		wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33mapshenitsyn[0m ([33mapshenitsyn-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


0,1
OPS,▁▆▇█▇▇▇▆▆▆
SPS,▁▆▇█▇▇▇▆▆▆
n_agents,▁▁▁▁▁▁▁▁▁▁
n_objects,▁▁▁▁▁▁▁▁▁▁
num_envs,▁▂▃▃▄▅▆▆▇█

0,1
OPS,430301.20829
SPS,10757.53021
n_agents,40.0
n_objects,2680.0
num_envs,4600.0


: 

In [7]:
# Throughput sweep over obstacle density (0.05 .. 0.55), one W&B run per
# density, each covering several environment batch sizes.
for density_pct in range(5, 60, 5):
	# Convert the integer percentage into the fraction the environment takes.
	obstacle_density = density_pct / 100
	config = {
		"NUM_STEPS": 100,
		"NUM_ENVS": 1000,
		"ACTIVATION": "relu",
		"ENV_NAME": "grid_maze",
		"NUM_SEEDS": 1,
		"SEED": 0,
	}

	config["ENV_KWARGS"] = {
		"width": 18,
		"height": 18,
		"obstacle_density": obstacle_density,
		"num_agents": 32,
		"grain_factor": 4,
		"obstacle_size": 0.4,
		"contact_force": 500,
		"contact_margin": 1e-3,
		"dt": 0.01,
		"max_steps": 100,
		"frameskip": 6,
	}

	wandb.init(
		project="jaxmarl_fps",
		config={
			"n_runs": 5,
			"rollout_length": config["NUM_STEPS"],
			"device": str(jax.devices()[0]),
			# NOTE(review): key is misspelled ("benchamrk"); left unchanged so
			# the logged config schema stays the same.
			"benchamrk_config": config
		},
		name=f"rl2/grid_maze3_a{config['ENV_KWARGS']['num_agents']}_od{config['ENV_KWARGS']['obstacle_density']}_jupyter"
	)

	# try/finally: close the W&B run even when the sweep is interrupted or raises.
	try:
		### JAXMARL BENCHMARK
		num_envs = [400, 1000, 2000, 3000, 4000, 5000, 10000]
		for num in num_envs:
			try:
				config["NUM_ENVS"] = num

				total_time = 0.
				for run in range(wandb.config.n_runs):
					jax.clear_caches()
					# Jitted exactly once, below; wrapping it in jax.jit twice was redundant.
					benchmark_fn = make_benchmark(config)
					rng = jax.random.PRNGKey(config["SEED"])
					rng, _rng = jax.random.split(rng)

					# Lower and compile up front so compilation is not timed.
					benchmark_jit = jax.jit(benchmark_fn).lower(_rng).compile()

					# The timed call runs the whole compiled benchmark: the vmapped
					# reset plus the NUM_STEPS-step scan built by make_benchmark.
					before = time.perf_counter_ns()
					runner_state = jax.block_until_ready(benchmark_jit(_rng))
					after = time.perf_counter_ns()

					total_time += (after - before) / 1e9

				# Constructed only to read the agent/entity/obstacle counts for logging.
				env = Grid_Maze(**config["ENV_KWARGS"])

				# SPS: environment steps per second over all runs; OPS scales it by agent count.
				sps = wandb.config.n_runs * config['NUM_STEPS'] * config['NUM_ENVS'] / total_time
				ops = sps * env.num_agents

				wandb.log({"num_envs": config["NUM_ENVS"], "SPS": sps, "OPS": ops, "n_agents": env.num_agents, "n_objects": env.num_entities, "n_obstacles": env.num_obstacles})
			except Exception as exc:
				# Best effort: give up on the remaining (larger) batch sizes for this
				# density, presumably because the device ran out of memory. Report
				# the error rather than hiding it. KeyboardInterrupt is deliberately
				# not caught, so Ctrl-C stops the whole sweep.
				print(f"num_envs={num} failed ({type(exc).__name__}: {exc}); skipping remaining batch sizes")
				break
	finally:
		wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33mapshenitsyn[0m ([33mapshenitsyn-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


0,1
OPS,▃█▅▅▅▄▁
SPS,▃█▅▅▅▄▁
n_agents,▁▁▁▁▁▁▁
n_objects,▁▁▁▁▁▁▁
n_obstacles,▁▁▁▁▁▁▁
num_envs,▁▁▂▃▄▄█

0,1
OPS,2936909.06759
SPS,91778.40836
n_agents,32.0
n_objects,440.0
n_obstacles,16.0
num_envs,10000.0


0,1
OPS,▃█▅▅▄▃▁
SPS,▃█▅▅▄▃▁
n_agents,▁▁▁▁▁▁▁
n_objects,▁▁▁▁▁▁▁
n_obstacles,▁▁▁▁▁▁▁
num_envs,▁▁▂▃▄▄█

0,1
OPS,1883562.99396
SPS,58861.34356
n_agents,32.0
n_objects,632.0
n_obstacles,32.0
num_envs,10000.0


0,1
OPS,▄▇█▇▆▄▁
SPS,▄▇█▇▆▄▁
n_agents,▁▁▁▁▁▁▁
n_objects,▁▁▁▁▁▁▁
n_obstacles,▁▁▁▁▁▁▁
num_envs,▁▁▂▃▄▄█

0,1
OPS,1376658.83931
SPS,43020.58873
n_agents,32.0
n_objects,824.0
n_obstacles,48.0
num_envs,10000.0


0,1
OPS,▄▆█▆▅▄▁
SPS,▄▆█▆▅▄▁
n_agents,▁▁▁▁▁▁▁
n_objects,▁▁▁▁▁▁▁
n_obstacles,▁▁▁▁▁▁▁
num_envs,▁▁▂▃▄▄█

0,1
OPS,1074962.53109
SPS,33592.5791
n_agents,32.0
n_objects,1016.0
n_obstacles,64.0
num_envs,10000.0


0,1
OPS,▄▇█▆▅▄▁
SPS,▄▇█▆▅▄▁
n_agents,▁▁▁▁▁▁▁
n_objects,▁▁▁▁▁▁▁
n_obstacles,▁▁▁▁▁▁▁
num_envs,▁▁▂▃▄▄█

0,1
OPS,874931.89184
SPS,27341.62162
n_agents,32.0
n_objects,1220.0
n_obstacles,81.0
num_envs,10000.0


0,1
OPS,▅▇█▆▅▄▁
SPS,▅▇█▆▅▄▁
n_agents,▁▁▁▁▁▁▁
n_objects,▁▁▁▁▁▁▁
n_obstacles,▁▁▁▁▁▁▁
num_envs,▁▁▂▃▄▄█

0,1
OPS,740695.35668
SPS,23146.7299
n_agents,32.0
n_objects,1412.0
n_obstacles,97.0
num_envs,10000.0


0,1
OPS,▅▇█▆▅▃▁
SPS,▅▇█▆▅▃▁
n_agents,▁▁▁▁▁▁▁
n_objects,▁▁▁▁▁▁▁
n_obstacles,▁▁▁▁▁▁▁
num_envs,▁▁▂▃▄▄█

0,1
OPS,592799.75373
SPS,18524.9923
n_agents,32.0
n_objects,1604.0
n_obstacles,113.0
num_envs,10000.0


0,1
OPS,▅▇█▆▅▃▁
SPS,▅▇█▆▅▃▁
n_agents,▁▁▁▁▁▁▁
n_objects,▁▁▁▁▁▁▁
n_obstacles,▁▁▁▁▁▁▁
num_envs,▁▁▂▃▄▄█

0,1
OPS,546557.5781
SPS,17079.92432
n_agents,32.0
n_objects,1796.0
n_obstacles,129.0
num_envs,10000.0


0,1
OPS,▃▇█▆▅▃▁
SPS,▃▇█▆▅▃▁
n_agents,▁▁▁▁▁▁▁
n_objects,▁▁▁▁▁▁▁
n_obstacles,▁▁▁▁▁▁▁
num_envs,▁▁▂▃▄▄█

0,1
OPS,475394.92466
SPS,14856.0914
n_agents,32.0
n_objects,1988.0
n_obstacles,145.0
num_envs,10000.0


0,1
OPS,▃██▆▅▄▁
SPS,▃██▆▅▄▁
n_agents,▁▁▁▁▁▁▁
n_objects,▁▁▁▁▁▁▁
n_obstacles,▁▁▁▁▁▁▁
num_envs,▁▁▂▃▄▄█

0,1
OPS,454759.00673
SPS,14211.21896
n_agents,32.0
n_objects,2192.0
n_obstacles,162.0
num_envs,10000.0


0,1
OPS,▃█▇▆▅▄▁
SPS,▃█▇▆▅▄▁
n_agents,▁▁▁▁▁▁▁
n_objects,▁▁▁▁▁▁▁
n_obstacles,▁▁▁▁▁▁▁
num_envs,▁▁▂▃▄▄█

0,1
OPS,403024.14994
SPS,12594.50469
n_agents,32.0
n_objects,2384.0
n_obstacles,178.0
num_envs,10000.0


In [13]:
# The benchmark returns jax.lax.scan's (carry, ys) pair, with
# carry = (env_state, obsv, rng); [0][1] therefore selects the observation batch.
runner_state[0][1].shape

(10000, 32, 178, 2)

In [8]:
# Single untimed rollout, kept so the resulting runner_state can be inspected.
config = {
	"NUM_STEPS": 100,
	"NUM_ENVS": 1000,
	"ACTIVATION": "relu",
	"ENV_NAME": "grid_maze",
	"NUM_SEEDS": 1,
	"SEED": 0,
}

# NOTE(review): no "frameskip" entry here, unlike the sweep cells; Grid_Maze
# presumably falls back to its own default — confirm that is intended.
config["ENV_KWARGS"] = {
	"width": 18,
	"height": 18,
	"obstacle_density": 0.55,
	"num_agents": 32,
	"grain_factor": 4,
	"obstacle_size": 0.4,
	"contact_force": 500,
	"contact_margin": 1e-3,
	"dt": 0.015,
	"max_steps": 100,
}

jax.clear_caches()
# Jitted exactly once, below; wrapping it in jax.jit twice was redundant.
benchmark_fn = make_benchmark(config)
rng = jax.random.PRNGKey(config["SEED"])
rng, _rng = jax.random.split(rng)

# Lower and compile ahead of time, then run once and wait for the result.
benchmark_jit = jax.jit(benchmark_fn).lower(_rng).compile()

runner_state = jax.block_until_ready(benchmark_jit(_rng))

In [12]:
# [0][0] is the env_state element of the scan carry returned by the benchmark;
# show the shape of its agent_pos field.
runner_state[0][0].agent_pos.shape

(1000, 32, 2)