In [None]:
%load_ext blackcellmagic 
# %black -l 120
%load_ext autoreload
%autoreload 2
%env XLA_PYTHON_CLIENT_PREALLOCATE=false

In [None]:
import subprocess as sp
from tests.flops_computation.dqn import DQN
from tests.flops_computation.tfdqn import TFDQN
from tests.flops_computation.isdqn import iSDQN
import jax
import jax.numpy as jnp
import time
from tests.utils import Generator


def count_params(params):
	"""Return the total number of scalar entries over every leaf of the ``params`` pytree.

	Each leaf contributes its ``.size`` (product of its shape); an empty pytree yields 0.
	"""
	total = 0
	for leaf in jax.tree.leaves(params):
		total += leaf.size
	return total

def get_memory_usage():
	"""Return the ``memory.used`` value of the first GPU row printed by ``nvidia-smi``.

	The CSV output starts with a header line; the first data row looks like
	``"<value> <unit>"`` (nvidia-smi reports MiB), and only the leading integer is kept.
	"""
	query = ["nvidia-smi", "--query-gpu=memory.used", "--format=csv"]
	csv_text = sp.check_output(query).decode('ascii')
	# Row 0 is the CSV header; row 1 is the first GPU.
	first_gpu_row = csv_text.split("\n")[1]
	value, *_unit = first_gpu_row.split(" ")
	return int(value)

def measure_memory_params(params):
	"""Estimate the GPU memory taken by one float32 copy of the ``params`` pytree.

	Allocates 1000 float32 copies of ``params`` (copy ``i`` filled with the value ``i``),
	blocks until they are materialised, and divides the increase reported by
	``get_memory_usage`` by the number of copies. The copies are then released and the
	function sleeps 5 seconds — presumably so the freed memory is visible to
	``nvidia-smi`` before any following measurement.

	Returns:
		Per-copy memory increase, in the units of ``get_memory_usage`` (float).
	"""
	n_copies = 1000
	baseline = get_memory_usage()

	def scaled_copy(scale):
		# One float32 replica of the pytree, every entry equal to ``scale``.
		return jax.tree.map(lambda w: jnp.ones_like(w).astype(jnp.float32) * scale, params)

	copies = jax.block_until_ready([scaled_copy(i) for i in range(n_copies)])
	per_copy = (get_memory_usage() - baseline) / n_copies
	del copies
	time.sleep(5)

	return per_copy

def count_flops(q, has_target_params=False, generator=None):
	"""Lower and AOT-compile ``q.best_action`` and ``q.learn_on_batch`` for cost analysis.

	Despite its name, this returns the compiled executables; callers read the FLOP
	counts from their ``cost_analysis()``.

	Args:
		q: agent exposing ``params``, ``optimizer_state``, ``best_action`` and
			``learn_on_batch`` (plus ``target_params`` when ``has_target_params``).
		has_target_params: if True, ``learn_on_batch`` is lowered with
			``q.target_params`` passed right after ``q.params``.
		generator: object with ``state(key)`` and ``samples(key)`` methods used to
			build example inputs for lowering. Defaults to the module-level
			``sample_generator``, looked up at call time.

	Returns:
		Tuple ``(best_action_compiled, learn_on_batch_compiled)``.
	"""
	if generator is None:
		# Backward-compatible fallback: the original read this global implicitly.
		generator = sample_generator
	key = jax.random.PRNGKey(0)

	best_action_compiled = jax.jit(q.best_action).lower(q.params, generator.state(key), key=key).compile()

	# Only agents with a target network take target_params as the second argument.
	learn_args = (q.params, q.target_params) if has_target_params else (q.params,)
	learn_on_batch_compiled = (
		jax.jit(q.learn_on_batch).lower(*learn_args, q.optimizer_state, generator.samples(key)).compile()
	)

	return best_action_compiled, learn_on_batch_compiled

# Input configuration shared by the sample generator and every agent below.
# NOTE(review): names reflect the assumed meaning of these positional arguments
# (observation shape, number of actions, batch size) — confirm against the
# constructors in tests.utils / tests.flops_computation.
OBS_SHAPE = (84, 84, 4)
N_ACTIONS = 10
BATCH_SIZE = 32
# Set to True to also report the memory saved by each iSDQN variant. This is slow:
# measure_memory_params allocates 1000 copies of the parameters and then sleeps.
MEASURE_ISDQN_MEMORY = False

sample_generator = Generator(BATCH_SIZE, OBS_SHAPE, N_ACTIONS)


def _flops(compiled):
	"""Return the ``"flops"`` entry of an AOT-compiled function's XLA cost analysis.

	Depending on the JAX version, ``Compiled.cost_analysis()`` returns either a list
	of dicts (one per executable) or a single dict; both forms are accepted.
	"""
	analysis = compiled.cost_analysis()
	if isinstance(analysis, (list, tuple)):
		analysis = analysis[0]
	return analysis["flops"]


def _report_flops(best_action_compiled, learn_on_batch_compiled):
	"""Print the FLOPs of one ``best_action`` call and of one ``learn_on_batch`` call."""
	print("FLOPs best action: ", _flops(best_action_compiled))
	print("FLOPs to learn on a batch: ", _flops(learn_on_batch_compiled), "\n")


for architecture, features in [["cnn", [32, 64, 64, 512]], ["impala", [32, 64, 64, 512]]]:
	print(f"--- {architecture} architecture ---")

	# DQN is lowered with target params, so its parameter count is doubled to include
	# the target copy (assumed to share the structure of q_dqn.params).
	q_dqn = DQN(jax.random.PRNGKey(0), OBS_SHAPE, N_ACTIONS, features, True, architecture, 0.001, 0.9, 1, 1, 100)
	print("TD-DQN", count_params(q_dqn.params) * 2)
	_report_flops(*count_flops(q_dqn, has_target_params=True))

	tfq_dqn = TFDQN(jax.random.PRNGKey(0), OBS_SHAPE, N_ACTIONS, features, True, architecture, 0.001, 0.9, 1, 1, 100)
	print("TF-DQN", count_params(tfq_dqn.params))
	_report_flops(*count_flops(tfq_dqn))

	# TF-DQN is lowered without target params; the memory of one DQN parameter set is
	# reported as its saving. nvidia-smi (via get_memory_usage) reports MiB.
	memory_tf = measure_memory_params(q_dqn.params)
	print(f"{memory_tf} MiB saved for TF")

	for K in [1, 4, 9, 49]:
		q_isdqn = iSDQN(jax.random.PRNGKey(0), OBS_SHAPE, N_ACTIONS, K, features, True, architecture, 0, 0, 1, 1, 1, 1)

		print(f"iSDQN K={K}", count_params(q_isdqn.params))
		if MEASURE_ISDQN_MEMORY:
			print(f"{2 * memory_tf - measure_memory_params(q_isdqn.params)} MiB saved for iSDQN K={K}")
		_report_flops(*count_flops(q_isdqn))
	print("\n")