In [None]:
import os

from tf_agents.networks import actor_distribution_network
from tf_agents.utils import common
from tf_agents.environments import tf_py_environment
import tensorflow as tf

from tf_agents.agents.reinforce import reinforce_agent
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.metrics import tf_metrics
from tf_agents.drivers import dynamic_episode_driver
from tf_agents.policies import random_tf_policy
import ipywidgets as widgets
from ipywidgets import interact, interact_manual

%load_ext autoreload
%autoreload 
from QTransferEnv import *
from QTransferLib import compute_populations, dephase_factory
from AgentsTraining import *
from Plots import *

In [None]:
# --- Paths ----------------------------------------------------------------
# NOTE(review): these directories are not referenced elsewhere in this notebook.
root_dir = os.path.expanduser('logs/')
train_dir = os.path.join(root_dir, 'train')
eval_dir = os.path.join(root_dir, 'eval')

# --- Reproducibility ------------------------------------------------------
# Seed the global NumPy and TensorFlow RNGs so runs are repeatable.
# NOTE(review): if QTransferEnv / Noise keep their own RNG, seed that as well.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# --- Physical system ------------------------------------------------------
N = 5                            # dimension of the basis used for the states below
Ωmax = 1                         # upper control bound (QTransferEnv omega_max; plot y-limit)
n_steps = 30                     # QTransferEnv n_steps (control steps per episode)
γ = 0.                           # rate passed to dephase_factory(N, γ)
t_max = 2                        # t_max passed to QTransferEnv / compute_populations
initial_state = qt.basis(N, 0)
target_state = qt.basis(N, N-1)
noise = Noise("gaussian", percentage=0.05)
deltas = np.zeros(N, dtype=complex)  # presumably per-level detunings; all zero by default
deltas[1] = 0 + 0j                   # no-op as written; edit to detune index 1

# --- Training -------------------------------------------------------------
num_iterations = 100
reward_gain = 1.0
summaries_flush_secs = 1
num_eval_episodes = 10  # also the buffer size of the average-return/length metrics

fc_layer_params = (100, 50, 30)
optimizer_learning_rate = 1e-3
replay_buffer_episodes_capacity = 10
# The episode drivers write n_steps + 1 frames per episode (the final
# LAST->FIRST boundary transition is stored too). Sizing the buffer as
# episodes * n_steps therefore overwrote the start of the oldest episode.
# Assumes every QTransferEnv episode lasts exactly n_steps steps.
replay_buffer_capacity = replay_buffer_episodes_capacity * (n_steps + 1)
collect_episodes_per_iteration = 1
use_tf_functions = True
batch_size = 256  # NOTE(review): unused - REINFORCE trains on replay_buffer.gather_all()
train_steps_per_iteration = 1
replay_buffer_percentage_used_as_experience = 1
experience_episodes_per_train_step = int(replay_buffer_episodes_capacity * replay_buffer_percentage_used_as_experience)
initial_collect_episodes = experience_episodes_per_train_step

eval_interval = 10
summary_interval = 10

# Larger than num_iterations, i.e. never reached within this run.
train_checkpoint_interval = num_iterations + 1
policy_checkpoint_interval = num_iterations + 1
rb_checkpoint_interval = num_iterations + 1
log_interval = num_iterations + 1

In [None]:
def make_qtransfer_env():
    """Build a QTransferEnv from the notebook-level configuration values."""
    return QTransferEnv(
        N=N,
        t_max=t_max,
        n_steps=n_steps,
        initial_state=initial_state,
        target_state=target_state,
        reward_gain=reward_gain,
        omega_min=0,
        omega_max=Ωmax,
        noise=noise,
        deltas=deltas,
    )


# Two independent instances with identical settings: one for training, one for evaluation.
env_train_py = make_qtransfer_env()
env_eval_py = make_qtransfer_env()

In [None]:
# Shared step counter; handed to the agent as its train_step_counter.
global_step = tf.compat.v1.train.get_or_create_global_step(graph=None)

In [None]:
# TF wrappers around the Python environments (train first, then eval).
tf_env = tf_py_environment.TFPyEnvironment(env_train_py)
eval_tf_env = tf_py_environment.TFPyEnvironment(env_eval_py)

# Specs consumed by the network, agent and policies below.
time_step_spec = tf_env.time_step_spec()
observation_spec, action_spec = time_step_spec.observation, tf_env.action_spec()

# Policy network: observations -> action distribution.
actor_net = actor_distribution_network.ActorDistributionNetwork(
    observation_spec, action_spec, fc_layer_params=fc_layer_params)

# REINFORCE agent with normalised returns; `global_step` is its train-step counter.
actor_optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=optimizer_learning_rate)
tf_agent = reinforce_agent.ReinforceAgent(
    time_step_spec,
    action_spec,
    actor_network=actor_net,
    optimizer=actor_optimizer,
    normalize_returns=True,
    train_step_counter=global_step,
)
tf_agent.initialize()

# Trajectory storage written by the collect drivers; train_step() consumes it via gather_all().
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    max_length=replay_buffer_capacity,
    batch_size=tf_env.batch_size,
    data_spec=tf_agent.collect_data_spec,
)
replay_observer = [replay_buffer.add_batch]

# Training metrics: counters first, then windowed averages over the last
# `num_eval_episodes` collected episodes.
train_metrics = [tf_metrics.NumberOfEpisodes(), tf_metrics.EnvironmentSteps()]
train_metrics += [
    metric_cls(buffer_size=num_eval_episodes, batch_size=tf_env.batch_size)
    for metric_cls in (tf_metrics.AverageReturnMetric,
                       tf_metrics.AverageEpisodeLengthMetric)
]

# Evaluation metrics (same averages, default batch size), fed by the eval driver.
eval_metrics = [
    metric_cls(buffer_size=num_eval_episodes)
    for metric_cls in (tf_metrics.AverageReturnMetric,
                       tf_metrics.AverageEpisodeLengthMetric)
]

# The agent's evaluation and collection policies, plus a random policy used
# only by the warm-up driver.
eval_policy, collect_policy = tf_agent.policy, tf_agent.collect_policy
initial_collect_policy = random_tf_policy.RandomTFPolicy(time_step_spec, action_spec)

# Collect drivers feed the replay buffer and the training metrics.
# Each driver gets its own observer list.
initial_collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
    tf_env, initial_collect_policy,
    observers=replay_observer + train_metrics,
    num_episodes=initial_collect_episodes,
)
collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
    tf_env, collect_policy,
    observers=replay_observer + train_metrics,
    num_episodes=collect_episodes_per_iteration,
)

# Evaluator: one episode per run, observed only by the eval metrics.
eval_driver = dynamic_episode_driver.DynamicEpisodeDriver(
    eval_tf_env, eval_policy, observers=eval_metrics, num_episodes=1)

if use_tf_functions:
    # Compile the driver loops and the agent update into TF graphs.
    for _driver in (initial_collect_driver, collect_driver, eval_driver):
        _driver.run = common.function(_driver.run)
    del _driver
    tf_agent.train = common.function(tf_agent.train)

# Random-policy warm-up, only on an empty buffer so re-running this cell
# does not keep appending episodes.
if int(replay_buffer.num_frames()) == 0:
    initial_collect_driver.run()

# Driver state threaded through collect_driver.run() in the training loop.
time_step = None
policy_state = collect_policy.get_initial_state(tf_env.batch_size)

# NOTE(review): the two values below are not referenced elsewhere in this notebook.
timed_at_step = global_step.numpy()
time_acc = 0

def train_step():
    """Run one agent update on everything currently in the replay buffer.

    Returns:
        The value returned by ``tf_agent.train`` (the agent's loss info).
    """
    return tf_agent.train(replay_buffer.gather_all())

# Compile the update step too when TF functions are enabled.
train_step = common.function(train_step) if use_tf_functions else train_step

global_step_val = global_step.numpy()  # NOTE(review): not referenced elsewhere in this notebook

In [None]:
# Evaluate the agent's policy once before training.
# Reset first so re-running this cell does not average in earlier runs.
for metric in eval_metrics:
    metric.reset()
# Keep the eval policy state separate: `policy_state` belongs to the collect
# driver and is threaded through the training loop.
final_time_step, _eval_policy_state = eval_driver.run()
print("Initial Average Return: ", eval_metrics[0].result().numpy())

In [None]:
# Training history consumed by the plotting cells below.
return_list = []     # sum of rewards of the episode(s) collected at each iteration
episode_list = []    # collected trajectory per iteration, indexed [batch, time, ...]
iteration_list = []  # iteration index for each entry above

with trange(num_iterations, dynamic_ncols=False) as t:
    for i in t:
        t.set_description(f'episode {i}')

        # REINFORCE is on-policy: drop whatever the buffer holds (random warm-up
        # episodes or last iteration's episode) so the update only sees
        # trajectories produced by the current collect policy.
        replay_buffer.clear()
        time_step, policy_state = collect_driver.run(
            time_step=time_step,
            policy_state=policy_state,
        )

        experience = replay_buffer.gather_all()
        train_loss = train_step()

        episode_list.append(experience)
        return_list.append(float(tf.reduce_sum(experience.reward[0]).numpy()))
        iteration_list.append(i)

        if i % eval_interval == 0 or i == num_iterations - 1:
            for metric in eval_metrics:
                metric.reset()
            eval_driver.run()

            t.set_postfix({"return": eval_metrics[0].result().numpy(),
                           "loss": float(train_loss.loss)})

In [None]:
# Sanity-check the training history instead of re-initialising it.
# Re-binding these lists to [] here would discard everything the training
# loop recorded and break every plotting cell below.
assert len(return_list) == len(episode_list) == len(iteration_list), \
    "training history lists are out of sync"
assert episode_list, "episode_list is empty - run the training loop cell first"

In [None]:
# Plot the training returns stored in `return_list`.
# NOTE(review): 0.9 looks like a smoothing factor for plot_training_returns - confirm in Plots.
plot_training_returns(return_list, 0.9)

In [None]:
times = env_eval_py.times  # time grid of the eval environment, reused by the cells below


def show_episode(Iteration):
    """Plot the recorded episode ``episode_list[Iteration]`` against ``times``.

    Distinct name (instead of a shared ``interactive``) so later widget cells
    do not silently shadow this callback.
    """
    plot_episode(times, episode_list[Iteration])


interact(show_episode, Iteration=widgets.IntSlider(min=0, max=len(episode_list)-1, step=1, value=0))

In [None]:
# Observations of batch element 0 of the last recorded episode, indexed [time, feature].
final_observation = episode_list[-1].observation.numpy()[0]
print("Population {}: {}".format(N, final_observation[-1, N - 1]))
print("Max population: {}".format(np.max(final_observation[:, N - 1])))
print("Total population: {}".format(np.sum(final_observation[-1, :N])))

In [None]:
def show_pulses(Iteration):
    """Plot the actions (control pulses) of ``episode_list[Iteration]``.

    The y-axis is fixed to [0, 1.05 * Ωmax] so pulses from different
    iterations are directly comparable.
    """
    pulses = np.array(episode_list[Iteration].action.numpy()[0, :, :])
    ax = plot_pulses(times, pulses)
    ax.set_ylim((0, Ωmax*1.05))


interact(show_pulses, Iteration=widgets.IntSlider(min=0, max=len(episode_list)-1, step=1, value=0))

In [None]:
def show_env_populations(Iteration):
    """Replay the actions of ``episode_list[Iteration]`` through the eval env and plot populations.

    Prints the last entry of the final row of the populations array.
    """
    pulses = np.array(episode_list[Iteration].action.numpy()[0, :, :])
    # Presumably a stack of density matrices indexed (time, N, N) - TODO confirm.
    dm = env_eval_py.run_qstepevolution(pulses)
    populations = np.diagonal(dm, axis1=1, axis2=2).real
    plot_populations(env_eval_py.times, populations)
    print("Population {}: {}".format(N, populations[-1][-1]))


interact(show_env_populations, Iteration=widgets.IntSlider(min=0, max=len(episode_list)-1, step=1, value=0))

In [None]:
def show_lib_populations(Iteration):
    """Recompute populations for the actions of ``episode_list[Iteration]`` with QTransferLib.

    Uses the notebook's deltas and the dephasing rate γ, then plots them and
    prints the last entry of the final row of the populations.
    """
    pulses = np.array(episode_list[Iteration].action.numpy()[0, :, :])
    # Local name differs from the module-level `times` so it is not shadowed.
    lib_times, populations = compute_populations(
        N, initial_state, t_max, pulses,
        constant_delta_factory(deltas, n_steps + 1),
        dephase_factory(N, γ),
    )
    plot_populations(lib_times, populations)
    print("Population {}: {}".format(N, populations[-1][-1]))


interact(show_lib_populations, Iteration=widgets.IntSlider(min=0, max=len(episode_list)-1, step=1, value=0))