# Pulse Sequence Design using PPO
_Written by Will Kaufman_

This notebook walks through a reinforcement learning approach to pulse sequence design for spin systems. It uses [TF-Agents](https://www.tensorflow.org/agents), a reinforcement learning library built on TensorFlow, a common machine learning framework.

In [None]:
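import numpy as np
import os
import spin_simulation as ss
import time
import tensorflow as tf
from absl import logging  # used by the training loop for progress messages

# Make INFO-level messages from the training loop visible in the notebook
logging.set_verbosity(logging.INFO)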

from tf_agents.agents.ppo import ppo_clip_agent
from tf_agents.drivers import dynamic_episode_driver
from tf_agents.environments import tf_py_environment, parallel_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import actor_distribution_network, value_network
from tf_agents.policies import random_tf_policy, policy_saver
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.trajectories import time_step as ts
from tf_agents.utils import common

from environments import spin_sys_discrete

In [None]:
import importlib
importlib.reload(spin_sys_discrete)

## Define algorithm hyperparameters



In [None]:
num_iterations = 1000 # @param {type:"integer"}
episode_length = 5 # @param {type:"integer"}

# collect parameters
num_environment_steps = 5000  # @param {type:"integer"}
collect_episodes_per_iteration = 20 # @param {type:"integer"}
num_parallel_environments = 20 # @param {type:"integer"}
replay_buffer_max_length = 1000  # @param {type:"integer"}

# training parameters
num_epochs = 25
learning_rate = 1e-3  # @param {type:"number"}

# evaluation parameters
num_eval_episodes = 5  # @param {type:"integer"}
eval_interval = 200  # @param {type:"integer"}

batch_size = 12 #64  # @param {type:"integer"}

# summaries and logging parameters
train_checkpoint_interval=500
policy_checkpoint_interval=500
log_interval=50
summary_interval=50
summaries_flush_secs=1
use_tf_functions=True
debug_summaries=False
summarize_grads_and_vars=False

In [None]:
root_dir = "~/projects/rl_pulse/data/"

root_dir = os.path.expanduser(root_dir)
train_dir = os.path.join(root_dir, 'train')
eval_dir = os.path.join(root_dir, 'eval')
saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

In [None]:
train_summary_writer = tf.compat.v2.summary.create_file_writer(
    train_dir, flush_millis=summaries_flush_secs * 1000)
train_summary_writer.set_as_default()

eval_summary_writer = tf.compat.v2.summary.create_file_writer(
    eval_dir, flush_millis=summaries_flush_secs * 1000)
eval_metrics = [
    tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
    tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
]

## Initialize the spin system

This sets the parameters of the system ($N$ spin-1/2 particles, which corresponds to a Hilbert space with dimension $2^N$). For the purposes of simulation, $\hbar \equiv 1$.

The total internal Hamiltonian is given by
$$
H_\text{int} = C H_\text{dip} + \delta \sum_i^N I_z^{i}
$$
where $C$ is the coupling strength, $\delta$ is the chemical shift strength (each spin is assumed to be identical), and $H_\text{dip}$ is given by
$$
H_\text{dip} = \sum_{i,j}^N d_{i,j} \left(3I_z^{i}I_z^{j} - \mathbf{I}^{i} \cdot \mathbf{I}^{j}\right)
$$

The target Hamiltonian is set to be the zeroth-order average Hamiltonian of the WHH-4 (WAHUHA) pulse sequence, which is designed to remove the dipolar interaction term from the internal Hamiltonian. The pulse sequence is $\tau, \overline{X}, \tau, Y, \tau, \tau, \overline{Y}, \tau, X, \tau$.
The zeroth-order average Hamiltonian for the WHH-4 pulse sequence is
$$
H_\text{WHH}^{(0)} = \frac{\delta}{3} \sum_i^N \left( I_x^{i} + I_y^{i} + I_z^{i} \right)
$$
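
A short sketch of where this expression comes from, assuming ideal, instantaneous pulses (an assumption consistent with `pulse_width=0` used later, but not derived in this notebook). In the toggling frame, the $I_z$ operator of each spin points along a different axis during each of the six $\tau$ intervals, with signs fixed by the pulse phases and taken here so that every component enters with a plus sign:
$$
\tilde{I}_z:\quad I_z \;\to\; I_y \;\to\; I_x \;\to\; I_x \;\to\; I_y \;\to\; I_z
$$
Averaging the chemical shift term over the $6\tau$ cycle gives
$$
\overline{H}_\delta^{(0)} = \frac{\delta}{6} \sum_i^N \left( 2 I_x^{i} + 2 I_y^{i} + 2 I_z^{i} \right)
= \frac{\delta}{3} \sum_i^N \left( I_x^{i} + I_y^{i} + I_z^{i} \right)
$$
The dipolar term is bilinear in the spin operators, so it is insensitive to those signs, and each pair $(i,j)$ averages to zero:
$$
\frac{1}{3} \sum_{a \in \{x,y,z\}} \left( 3 I_a^{i} I_a^{j} - \mathbf{I}^{i} \cdot \mathbf{I}^{j} \right)
= \mathbf{I}^{i} \cdot \mathbf{I}^{j} - \mathbf{I}^{i} \cdot \mathbf{I}^{j} = 0
$$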

In [None]:
N=4
dim = 2**N
coupling = 1e3
delta = 500
(X,Y,Z) = ss.get_total_spin(N=N, dim=dim)
H_target = ss.get_H_WHH_0(X, Y, Z, delta)
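
The helpers in `spin_simulation` are not shown in this notebook. As a rough illustration of what the cell above computes, the following sketch builds total spin operators with Kronecker products and forms the target Hamiltonian. The functions `embed`, `total_spin`, and `whh_average_hamiltonian` are hypothetical stand-ins, not the actual `ss` implementation.

```python
import numpy as np

# Single spin-1/2 operators, with hbar = 1.
IX = 0.5 * np.array([[0, 1], [1, 0]], dtype=complex)
IY = 0.5 * np.array([[0, -1j], [1j, 0]], dtype=complex)
IZ = 0.5 * np.array([[1, 0], [0, -1]], dtype=complex)


def embed(op, site, n_spins):
    """Place a single-spin operator at `site` in the 2**n_spins-dimensional space."""
    out = np.eye(1, dtype=complex)
    for k in range(n_spins):
        out = np.kron(out, op if k == site else np.eye(2, dtype=complex))
    return out


def total_spin(n_spins):
    """Hypothetical stand-in for ss.get_total_spin: returns total (X, Y, Z)."""
    return tuple(
        sum(embed(op, i, n_spins) for i in range(n_spins))
        for op in (IX, IY, IZ)
    )


def whh_average_hamiltonian(X, Y, Z, delta):
    """Zeroth-order WHH-4 average Hamiltonian, delta / 3 * (X + Y + Z)."""
    return delta / 3 * (X + Y + Z)


X_tot, Y_tot, Z_tot = total_spin(4)
H_sketch = whh_average_hamiltonian(X_tot, Y_tot, Z_tot, delta=500)
print(H_sketch.shape)
```

The total operators satisfy the usual angular momentum commutation relation $[X, Y] = iZ$, which is a quick sanity check on any such construction.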

The `SpinSystemDiscreteEnv` class keeps track of the system dynamics, and implements methods that are necessary for RL:

- `action_spec`: Returns an `ArraySpec` that gives the shape and range of a valid action. For example, in a discrete action space, an action will be an integer scalar between 0 and `numActions - 1`. For a continuous action space, an action will be a 3-dimensional vector representing phase, amplitude, and duration of the pulse.
- `observation_spec`: Returns an `ArraySpec` that gives the shape and range of a valid observation. In this case, the observations are all the actions performed on the environment so far.
- `_reset`: Resets the environment. This means setting the propagator to the identity, and choosing a new random dipolar interaction matrix $(d_{i,j})$.
- `_step`: Evolves the environment according to the action. Returns a `TimeStep` which includes the step type (`FIRST`, `MID`, or `LAST`), the **reward**, the discount rate to apply to future rewards, and an **observation** of the environment.
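
To make the reset/step cycle concrete, here is a toy environment in plain Python. It is not the `SpinSystemDiscreteEnv` implementation and does not use TF-Agents: the `Step` tuple, the `ToyPulseEnv` class, the one-hot observation layout, the choice of five actions, and the zero reward are all assumptions made for illustration.

```python
from collections import namedtuple

# Hypothetical stand-in for a TimeStep: step type, reward, discount, observation.
Step = namedtuple("Step", ["step_type", "reward", "discount", "observation"])


class ToyPulseEnv:
    """Toy episodic environment whose observation is the action history."""

    def __init__(self, num_actions=5, episode_length=5):
        self.num_actions = num_actions
        self.episode_length = episode_length
        self.actions = []

    def _observation(self):
        # One row per time step, one-hot over the action taken (assumed layout).
        obs = [[0.0] * self.num_actions for _ in range(self.episode_length)]
        for t, action in enumerate(self.actions):
            obs[t][action] = 1.0
        return obs

    def reset(self):
        # A real environment would also reset the propagator to the identity
        # and draw a new random dipolar coupling matrix here.
        self.actions = []
        return Step("FIRST", 0.0, 1.0, self._observation())

    def step(self, action):
        if not 0 <= action < self.num_actions:
            raise ValueError(f"action must be in [0, {self.num_actions - 1}]")
        self.actions.append(action)
        done = len(self.actions) == self.episode_length
        reward = 0.0  # placeholder; the real reward comes from the fidelity
        discount = 0.0 if done else 1.0
        return Step("LAST" if done else "MID", reward, discount, self._observation())


toy_env = ToyPulseEnv()
step = toy_env.reset()
for action in [1, 2, 4, 3, 0]:
    step = toy_env.step(action)
print(step.step_type)
```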

The reward function $r(s,a)$ can in general depend on the environment state _and_ action performed. However, because the goal of pulse sequence design is to find high-fidelity pulse sequences, the reward only depends on the state. 
$$
r = -\log \left( 1-
    \left|
        \frac{\text{Tr} (U_\text{target}^\dagger U_\text{exp})}{\text{Tr}(\mathbb{1})}
    \right|
    \right)
% = -\log\left( 1- \text{fidelity}(U_\text{target}, U_\text{exp}) \right)
$$
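
A minimal sketch of this reward in NumPy, assuming the propagators are available as square complex arrays. The function names and the `floor` argument are hypothetical: the floor is added here only to keep the logarithm finite when the fidelity reaches 1, and it is not part of the formula above.

```python
import numpy as np


def fidelity(U_target, U_exp):
    """|Tr(U_target^dagger U_exp)| / Tr(identity)."""
    dim = U_target.shape[0]
    return abs(np.trace(U_target.conj().T @ U_exp)) / dim


def reward(U_target, U_exp, floor=1e-12):
    """-log(1 - fidelity), with the infidelity floored (an added assumption)."""
    infidelity = max(1.0 - fidelity(U_target, U_exp), floor)
    return -np.log(infidelity)


identity = np.eye(2, dtype=complex)
phase_gate = np.diag([1.0, 1.0j])  # partially overlaps with the identity
print(fidelity(identity, phase_gate), reward(identity, phase_gate))
```

Because of the logarithm, the reward grows without bound as the fidelity approaches 1, so each additional "nine" of fidelity is worth roughly the same increment in reward.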



In [None]:
# Reuse the system parameters defined above instead of repeating literal values
env = spin_sys_discrete.SpinSystemDiscreteEnv(N=N, dim=dim, coupling=coupling,
    delta=delta, H_target=H_target, X=X, Y=Y, delay=5e-6, pulse_width=0,
    delay_after=True, state_size=episode_length)
# env.reset()

# train_py_env = spin_sys_discrete.SpinSystemDiscreteEnv(N=4, dim=16, coupling=1e3,
#     delta=500, H_target=H_target, X=X, Y=Y, delay=5e-6, pulse_width=0,
#     delay_after=True)
# eval_py_env = spin_sys_discrete.SpinSystemDiscreteEnv(N=4, dim=16, coupling=1e3,
#     delta=500, H_target=H_target, X=X, Y=Y, delay=5e-6, pulse_width=0,
#     delay_after=True)

print('Observation Spec:')
print(env.time_step_spec().observation)

print('Reward Spec:')
print(env.time_step_spec().reward)

print('Action Spec:')
print(env.action_spec())

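# Note: both wrappers share the same underlying `env` instance, so stepping or
# resetting one also changes the state seen by the other.
train_env = tf_py_environment.TFPyEnvironment(env)
eval_env = tf_py_environment.TFPyEnvironment(env)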

## Define actor and value networks

In PPO, there are two separate networks: the _actor_ network and the _value_ network. The actor network learns the policy function $\pi(a|s)$, while the value network learns $v_\pi(s)$.
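
The value network's estimates are what turn rewards into advantages, which tell the actor whether an action did better or worse than expected. The agent in this notebook is created with `use_gae=True`, so a sketch of generalized advantage estimation (GAE) may help. This is a plain-Python illustration rather than the TF-Agents implementation, and the `gamma` and `lam` defaults are illustrative values, not settings taken from this notebook.

```python
def gae_advantages(rewards, values, gamma=0.99, lam=0.95):
    """Generalized advantage estimates for one episode.

    `values` holds v(s_0), ..., v(s_T): one more entry than `rewards`,
    where the last entry is the value of the final state.
    """
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        # One-step temporal-difference error at time t.
        td_error = rewards[t] + gamma * values[t + 1] - values[t]
        # Exponentially weighted sum of future TD errors.
        running = td_error + gamma * lam * running
        advantages[t] = running
    return advantages


example_rewards = [0.0, 0.0, 1.0]
example_values = [0.2, 0.4, 0.7, 0.0]
print(gae_advantages(example_rewards, example_values))
```

With `lam=0` the estimate reduces to the one-step TD error, and with `lam=1` it becomes the discounted return minus the value estimate.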

In [None]:
actor_net = actor_distribution_network.ActorDistributionNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    fc_layer_params= (50, 50),
    activation_fn=tf.keras.activations.tanh)

value_net = value_network.ValueNetwork(
    train_env.observation_spec(),
    fc_layer_params= (50, 50),
    activation_fn=tf.keras.activations.tanh)

See what the initial value estimate is from the value network.

In [None]:
value_net(train_env.current_time_step().observation)[0].numpy()

In [None]:
value_net.summary()

In [None]:
value_net.get_layer("EncodingNetwork").summary()

## Create agent

In RL, the "agent" has a policy that determines its behavior. For PPO, the actor network defines a probability distribution over actions: during data collection the agent samples from that distribution, and during evaluation TF-Agents by default uses a greedy policy that picks the most likely action. These policies are accessed with `agent.policy` (for evaluation) and `agent.collect_policy` (for data collection).
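
The difference between an evaluation policy and a collection policy can be pictured with a small plain-Python sketch. The probabilities below are made up, and the two helper functions are hypothetical; they are not TF-Agents calls.

```python
import random


def greedy_action(action_probs):
    """Evaluation-style choice: the single most probable action."""
    return max(range(len(action_probs)), key=lambda a: action_probs[a])


def sampled_action(action_probs, rng=random):
    """Collection-style choice: draw an action according to its probability."""
    return rng.choices(range(len(action_probs)), weights=action_probs, k=1)[0]


probs = [0.1, 0.6, 0.3]  # illustrative pi(a|s) for three actions
rng = random.Random(0)
samples = [sampled_action(probs, rng) for _ in range(1000)]

print(greedy_action(probs))
print([samples.count(a) for a in range(len(probs))])
```

Sampling during collection keeps the agent exploring less likely actions, while the greedy choice gives a deterministic readout of what the policy currently prefers.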

According to [the docs](https://www.tensorflow.org/agents/api_docs/python/tf_agents/agents/tf_agent/TFAgent?hl=fa#args), I can adjust `train_sequence_length=None` for RNN-based agents. When using non-RNN DQN, though, I don't have that option. 

In [None]:
# is there a v2 optimizer I could use?
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

In [None]:
global_step = tf.Variable(0, name="global_step", dtype=tf.int64)

agent = ppo_clip_agent.PPOClipAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    optimizer,
    actor_net=actor_net,
    value_net=value_net,
    entropy_regularization=0.0,
    importance_ratio_clipping=0.2,
    normalize_observations=False,
    normalize_rewards=False,
    use_gae=True,
    num_epochs=num_epochs,
    debug_summaries=debug_summaries,
    summarize_grads_and_vars=summarize_grads_and_vars,
    train_step_counter=global_step)

agent.initialize()
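
The `importance_ratio_clipping=0.2` argument above is the clipping parameter in PPO's clipped surrogate objective. A per-sample sketch in plain Python follows; it is a simplified illustration that ignores the value loss and entropy terms, and it is not the TF-Agents implementation.

```python
def clipped_surrogate(ratio, advantage, clip=0.2):
    """Per-sample PPO objective: min(r * A, clip(r, 1 - clip, 1 + clip) * A).

    `ratio` is pi_new(a|s) / pi_old(a|s) and `advantage` is the estimated
    advantage of the action that was taken.
    """
    clipped_ratio = max(1.0 - clip, min(ratio, 1.0 + clip))
    return min(ratio * advantage, clipped_ratio * advantage)


# A good action (positive advantage) earns no extra credit past ratio 1.2 ...
print(clipped_surrogate(1.5, 1.0))
# ... while making a bad action more likely is penalized without any clipping.
print(clipped_surrogate(1.5, -1.0))
```

The clipping removes the incentive to move the new policy far from the policy that collected the data within a single update.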

In [None]:
eval_policy = agent.policy
collect_policy = agent.collect_policy

random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
                                                train_env.action_spec())

In [None]:
train_env.time_step_spec()

## Metrics for training/evaluation

In [None]:
environment_steps_metric = tf_metrics.EnvironmentSteps()
step_metrics = [
    tf_metrics.NumberOfEpisodes(),
    environment_steps_metric,
]

train_metrics = step_metrics + [
    tf_metrics.AverageReturnMetric(
        batch_size=1), # TODO replace with num_parallel_environments
    tf_metrics.AverageEpisodeLengthMetric(
        batch_size=1), # TODO replace with num_parallel_environments
]

In [None]:
def compute_avg_return(environment, policy, num_episodes=10, print_actions=False):

    total_return = 0.0
    for _ in range(num_episodes):

        time_step = environment.reset()
        policy_state = policy.get_initial_state(environment.batch_size)
        episode_return = 0.0

        while not time_step.is_last():
            action_step = policy.action(time_step, policy_state = policy_state)
            policy_state = action_step.state
            time_step = environment.step(action_step.action)
            episode_return += time_step.reward
            if print_actions:
                print(f"action: {action_step.action}, reward: {time_step.reward}, return: {episode_return}")
        total_return += episode_return

    avg_return = total_return / num_episodes
    return avg_return.numpy()[0]

In [None]:
compute_avg_return(eval_env, random_policy, num_eval_episodes)

## Create the replay buffer

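A replay buffer stores trajectories (sequences of states, actions, and rewards) from data collection so the agent can train on them. Because PPO is an on-policy algorithm, the buffer here only holds the episodes collected under the current policy: each training iteration reads the whole buffer and then clears it.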
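
The collect, train, clear cycle used later in this notebook can be sketched with a small plain-Python buffer. The `OnPolicyBuffer` class is a hypothetical illustration and is much simpler than `TFUniformReplayBuffer`.

```python
from collections import deque


class OnPolicyBuffer:
    """Bounded buffer that is read in full and then emptied."""

    def __init__(self, max_length):
        # Oldest items are dropped once max_length is exceeded.
        self._items = deque(maxlen=max_length)

    def add(self, transition):
        self._items.append(transition)

    def gather_all(self):
        return list(self._items)

    def clear(self):
        self._items.clear()

    def __len__(self):
        return len(self._items)


buffer = OnPolicyBuffer(max_length=1000)
for iteration in range(3):
    for step_index in range(5):
        buffer.add((iteration, step_index))  # stand-in for a real transition
    batch = buffer.gather_all()
    # ... a policy update on `batch` would go here ...
    buffer.clear()

print(len(batch), len(buffer))
```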

In [None]:
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=train_env.batch_size,
    max_length=replay_buffer_max_length,
)

replay_buffer

## Add checkpoints and policy saver

In [None]:
train_checkpointer = common.Checkpointer(
    ckpt_dir=train_dir,
    agent=agent,
    global_step=global_step,
    metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
policy_checkpointer = common.Checkpointer(
    ckpt_dir=os.path.join(train_dir, 'policy'),
    policy=eval_policy,
    global_step=global_step)
saved_model = policy_saver.PolicySaver(
    eval_policy, train_step=global_step)

## Save some trajectories to the replay buffer

In [None]:
def collect_step(environment, policy, buffer):
    time_step = environment.current_time_step()
    if time_step.is_last():
        time_step = environment.reset()
    action_step = policy.action(time_step)
    next_time_step = environment.step(action_step.action)
    traj = trajectory.from_transition(time_step, action_step, next_time_step)
    
    # Add trajectory to the replay buffer
    buffer.add_batch(traj)

def collect_data(env, policy, buffer, steps):
    for _ in range(steps):
        collect_step(env, policy, buffer)

Collect 64 episodes' worth of steps (`episode_length * 64`) with the agent's collect policy and store them in the replay buffer.

In [None]:
collect_step(train_env, collect_policy, replay_buffer)

In [None]:
train_env.reset()

collect_data(env=train_env,
    policy=collect_policy,
    buffer=replay_buffer,
    steps=episode_length*64)

A TensorFlow `Dataset` can take care of sampling the replay buffer and generating trajectories. The cells below, which convert the replay buffer to a `Dataset`, are left commented out because training in this notebook reads the whole buffer with `replay_buffer.gather_all()` instead.

In [None]:
# # Dataset generates trajectories with shape [Bx2x...]
# dataset = replay_buffer.as_dataset(
#     num_parallel_calls=2,
#     sample_batch_size=batch_size, 
#     num_steps=2).prefetch(3)


# dataset

In [None]:
# iterator = iter(dataset)

# print(iterator)

In [None]:
#iterator.next()

## Create the driver

TODO add writeup to this section

In [None]:
collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
    train_env,
    collect_policy,
    observers=[replay_buffer.add_batch] + train_metrics,
    num_episodes=collect_episodes_per_iteration)

In [None]:
def train_step():
    trajectories = replay_buffer.gather_all()
    return agent.train(experience=trajectories)

Convert functions to `tf.function`s for speedup.

In [None]:
collect_driver.run = common.function(collect_driver.run, autograph=False)
agent.train = common.function(agent.train, autograph=False)
train_step = common.function(train_step)
#agent.collect_policy.action = common.function(agent.collect_policy.action)

## Train the agent

In [None]:
# Reset the train step
agent.train_step_counter.assign(0)

# Evaluate the agent's policy once before training.
# avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
# returns = [avg_return]
# print(returns)

In [None]:
collect_time = 0
train_time = 0
timed_at_step = global_step.numpy()

In [None]:
#%load_ext line_profiler
# define some code
#%lprun -f train_agent train_agent()

In [None]:
# def train_agent():
#     train_env.reset()
#     policy_state = agent.collect_policy.get_initial_state(train_env.batch_size)

#     for _ in range(num_iterations):

#         # Collect a few steps using collect_policy and save to the replay buffer.
# #         final_time_step, policy_state = driver.run()
#         for _ in range(collect_steps_per_iteration):
#             #print(policy_state)
#             collect_step(train_env,
#                          agent.collect_policy,
#                          replay_buffer)

#         # Sample a batch of data from the buffer and update the agent's network.
#         experience, unused_info = next(iterator)
#         train_loss = agent.train(experience).loss

#         step = agent.train_step_counter.numpy()

#         if step % log_interval == 0:
#             # print(q_net(np.zeros((1,5,5), dtype="float32"))[0].numpy())
#             print(f'step = {step}: loss = {train_loss}')

#         if step % eval_interval == 0:
#             avg_return = compute_avg_return(eval_env, agent.policy)
#             print(f'step = {step}: Average Return = {avg_return}')
#             if avg_return > 50:
#                 break
#             returns.append(avg_return)

## TODO

- [x] Include eval [like this](https://github.com/tensorflow/agents/blob/v0.5.0/tf_agents/agents/ppo/examples/v2/train_eval_clip_agent.py#L238)
- [ ] Continue debugging code below (lots of things I failed to define above...)
- [ ] See what result is, if it works well

In [None]:
while environment_steps_metric.result() < num_environment_steps:
    global_step_val = global_step.numpy()
    if global_step_val % eval_interval == 0:
        metric_utils.eager_compute(
            eval_metrics,
            eval_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
    
    start_time = time.time()
    collect_driver.run()
    collect_time += time.time() - start_time

    start_time = time.time()
    total_loss, _ = train_step()
    replay_buffer.clear()
    train_time += time.time() - start_time
    
    for train_metric in train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=step_metrics)

    if global_step_val % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step_val, total_loss)
        steps_per_sec = (
            (global_step_val - timed_at_step) / (collect_time + train_time))
        logging.info('%.3f steps/sec', steps_per_sec)
        logging.info('collect_time = %.3f, train_time = %.3f', collect_time,
                     train_time)
        # steps_per_sec is only defined in this branch, so record it here
        with tf.compat.v2.summary.record_if(True):
            tf.compat.v2.summary.scalar(
                name='global_steps_per_sec', data=steps_per_sec, step=global_step)

        # Restart the timers so the next report covers one logging window
        timed_at_step = global_step_val
        collect_time = 0
        train_time = 0

    if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)

    if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)
        saved_model_path = os.path.join(
            saved_model_dir, 'policy_' + ('%d' % global_step_val).zfill(9))
        saved_model.save(saved_model_path)

## Evaluate the agent

See what pulse sequences it's performing

In [None]:
compute_avg_return(eval_env, agent.policy, num_episodes=1, print_actions=True)

Look at the actor network structure (the encoding network and the final dense layers).

In [None]:
actor_net.summary()

In [None]:
w = value_net.get_layer("EncodingNetwork").get_weights()
for weight in w:
    print(weight.shape)

And see what the value network returns for a play-through

In [None]:
# `time_step` is used as the variable name to avoid shadowing the imported `ts` module
time_step = train_env.reset()
print(value_net(time_step.observation, step_type=time_step.step_type)[0].numpy())
for action in [1, 2, 4, 3, 0]:
    time_step = train_env.step(action)
    print(value_net(time_step.observation, step_type=time_step.step_type)[0].numpy())
print(time_step.reward.numpy())

## Manually interact with the environment

In [None]:
eval_env.reset()
# run the WHH-4 sequence
eval_env.step(1)
eval_env.step(2)
eval_env.step(4)
eval_env.step(3)
eval_env.step(0)