In [None]:
import math
import numpy as np

import jax
from tests.networks.utils import Generator

from experiments.atari.utils import generate_keys
from idqn.environments.atari import AtariEnv
from idqn.sample_collection.replay_buffer import ReplayBuffer
from idqn.networks.architectures.dqn import AtariDQN
from idqn.networks.architectures.iqn import AtariIQN
from idqn.networks.architectures.idqn import AtariiDQN
from idqn.networks.architectures.iiqn import AtariiIQN
from idqn.utils.head_behaviorial_policy import head_behaviorial_policy

# Two PRNG keys from the project helper; only `q_key` is consumed in this cell.
q_key, train_key = generate_keys(1)

env = AtariEnv("Breakout")

# Shapes reused by the buffer, the sample generator and every network below.
frame_shape = (env.state_height, env.state_width)
stacked_state_shape = (*frame_shape, env.n_stacked_frames)

# Same value is handed to all four networks.
# NOTE(review): looks like gamma ** n_step with gamma=0.99 and n_step=1 — confirm.
discount = math.pow(0.99, 1)

# NOTE(review): the positional values presumably are capacity, batch size,
# n-step horizon, gamma and a reward-clipping function — confirm against
# ReplayBuffer's signature.
replay_buffer = ReplayBuffer(frame_shape, 1_000_000, 32, 1, 0.99, lambda x: np.clip(x, -1, 1))

# Produces the example inputs used when lowering the networks' methods.
sample_generator = Generator(32, stacked_state_shape)

# AtariDQN and AtariIQN are built from exactly the same positional arguments.
single_network_args = (stacked_state_shape, env.n_actions, discount, q_key, 0.001, 0.001, 4, 6000)
q_dqn = AtariDQN(*single_network_args)
q_iqn = AtariIQN(*single_network_args)

# First constructor argument of the iterated variants; the same number is
# given to `head_behaviorial_policy` (presumably the number of heads — confirm).
idqn_heads = 5 + 1
iiqn_heads = 3 + 1

q_idqn = AtariiDQN(
    idqn_heads,
    stacked_state_shape,
    env.n_actions,
    discount,
    q_key,
    head_behaviorial_policy("uniform", idqn_heads),
    0.001, 0.001, 4, 30, 6000, True,
)

q_iiqn = AtariiIQN(
    iiqn_heads,
    stacked_state_shape,
    env.n_actions,
    discount,
    q_key,
    head_behaviorial_policy("uniform", iiqn_heads),
    0.001, 0.001, 4, 30, 6000, 32, 64, 64, True,
)

In [None]:
def count_flops(q):
    """Ahead-of-time compile `q.best_action` and `q.learn_on_batch`.

    Each method is wrapped with `jax.jit`, lowered on example inputs and
    compiled. The example state and batch come from the module-level
    `sample_generator`; every random key is `jax.random.PRNGKey(0)`.

    Args:
        q: object exposing `best_action`, `learn_on_batch`, `params`,
            `target_params` and `optimizer_state`.

    Returns:
        The pair `(best_action_compiled, learn_on_batch_compiled)`. No FLOP
        number is computed here; callers read it from the compiled objects.
    """

    def _aot_compile(method, *example_args):
        # jit -> lower on the example inputs -> compile, without executing anything.
        return jax.jit(method, static_argnames="self").lower(*example_args).compile()

    key = jax.random.PRNGKey(0)

    compiled_best_action = _aot_compile(
        q.best_action,
        q.params,
        sample_generator.generate_state(key),
        key,
    )
    compiled_learn_on_batch = _aot_compile(
        q.learn_on_batch,
        q.params,
        q.target_params,
        q.optimizer_state,
        sample_generator.generate_samples(key),
        key,
    )

    return compiled_best_action, compiled_learn_on_batch


def _flops(compiled):
    """Return the "flops" entry reported by `compiled.cost_analysis()`.

    Depending on the installed JAX version, `cost_analysis()` yields either
    the cost dictionary itself or a sequence whose first element is that
    dictionary. Both layouts are accepted so the report does not break with
    `KeyError: 0` after a JAX upgrade.
    """
    analysis = compiled.cost_analysis()
    if isinstance(analysis, (list, tuple)):
        analysis = analysis[0]
    return analysis["flops"]


def _report_flops(title, q, learn_label="FLOPs to learn on a batch: "):
    """Compile `q` via `count_flops` and print its two FLOP counts.

    Args:
        title: header line printed before the numbers.
        q: network object forwarded to `count_flops`.
        learn_label: text printed in front of the learn-on-batch count.

    Returns:
        `(best_action_compiled, learn_on_batch_compiled, flops)` where `flops`
        is the pair `(best_action_flops, learn_on_batch_flops)`.
    """
    print(title)
    best_action_compiled, learn_on_batch_compiled = count_flops(q)
    # Query the cost analysis once per executable and reuse the numbers.
    flops = (_flops(best_action_compiled), _flops(learn_on_batch_compiled))
    print("FLOPs best action: ", flops[0])
    print(learn_label, flops[1])
    print("\n")
    return best_action_compiled, learn_on_batch_compiled, flops


def _report_ratio(title, flops, baseline_flops):
    """Print `flops / baseline_flops` for best-action and learn-on-batch."""
    print(title)
    print("Best action FLOPs ratio:", flops[0] / baseline_flops[0])
    print("Learn on batch FLOPs ratio:", flops[1] / baseline_flops[1])
    print("\n")


dqn_best_action_compiled, dqn_learn_on_batch_compiled, dqn_flops = _report_flops("DQN", q_dqn)

# The i-DQN section has always used a slightly different label; kept as is.
idqn_best_action_compiled, idqn_learn_on_batch_compiled, idqn_flops = _report_flops(
    "i-DQN", q_idqn, learn_label="FLOPs to learn on batch: "
)

_report_ratio("DQN vs i-DQN", idqn_flops, dqn_flops)

iqn_best_action_compiled, iqn_learn_on_batch_compiled, iqn_flops = _report_flops("IQN", q_iqn)

iiqn_best_action_compiled, iiqn_learn_on_batch_compiled, iiqn_flops = _report_flops("i-IQN", q_iiqn)

_report_ratio("IQN vs i-IQN", iiqn_flops, iqn_flops)