# Pulse Sequence Design Using Reinforcement Learning

This notebook implements deep deterministic policy gradient (DDPG) to learn pulse sequence design for spin systems.

TODO write up a more thorough background on the DDPG algorithm

For training, the following reward function was used:
$$
r = -\log_{10}\left( 1- \left| \text{Tr}\left( \frac{U_\text{target}^\dagger U_\text{exp}}{D} \right) \right| \right)
= -\log_{10}\left( 1- \text{fidelity}(U_\text{target}, U_\text{exp}) \right)
$$
where $D$ is the dimension of the propagators. For example, if the fidelity is $0.999$, then the reward is $r = -\log_{10}(0.001) = 3$.
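The reward above can be sketched in a few lines of NumPy. This is only an illustration: the functions `fidelity` and `reward` here are hypothetical stand-ins and are not taken from `spinSimulation` or `rlPulse`, and the base-10 logarithm is assumed from the worked example (a fidelity of $0.999$ giving a reward of $3$).

```python
import numpy as np

def fidelity(U_target, U_exp):
    """|Tr(U_target^dagger U_exp)| / D for two D x D propagators."""
    D = U_target.shape[0]
    return np.abs(np.trace(U_target.conj().T @ U_exp)) / D

def reward(U_target, U_exp):
    """r = -log10(1 - fidelity); base 10 is an assumption (see lead-in)."""
    return -np.log10(1 - fidelity(U_target, U_exp))

# Toy check: a single-spin z-rotation by a small angle, compared to the identity.
theta = 0.1
U_target = np.eye(2, dtype=complex)
U_exp = np.diag([np.exp(-1j * theta / 2), np.exp(1j * theta / 2)])

print(fidelity(U_target, U_exp))  # equals cos(theta / 2) for this toy case
print(reward(U_target, U_exp))
```

Note that this reward diverges as the fidelity approaches 1, so a practical implementation may need to clip the argument of the logarithm.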

The [OpenAI SpinningUp resource](https://spinningup.openai.com/en/latest/algorithms/ddpg.html#pseudocode) provided the theoretical background on DDPG, and the TensorFlow documentation guided the implementation of the algorithm below.

<!-- For the policy function, I need to perform gradient ascent with the following gradient
$$
\nabla_\theta 1/|B| \sum_{s \in B} Q_\phi (s, \pi_\theta(s))
$$

And for the Q-function, perform gradient descent with
$$
\nabla_\phi 1/|B| \sum_{(s,a,r,s',d) \in B} (Q_\phi(s,a) - y(r,s',d))^2
$$ -->

Other resources:

- https://keras.io/getting-started/sequential-model-guide/
- https://www.tensorflow.org/guide/keras/overview
- https://www.tensorflow.org/tutorials/customization/custom_training_walkthrough#define_the_loss_and_gradient_function
- https://github.com/floodsung/DDPG/blob/master/actor_network.py
- https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow/blob/master/contents/9_Deep_Deterministic_Policy_Gradient_DDPG/DDPG.py

Also helpful: https://www.tensorflow.org/guide/migrate#customize_the_training_step

In [None]:
import spinSimulation as ss
import rlPulse as rlp
import numpy as np
import scipy.linalg as spla
import importlib
from tqdm import tqdm  # progress bar used in the training loop

In [None]:
importlib.reload(ss)
importlib.reload(rlp)

# Initialize spin system

In [None]:
N = 4
dim = 2**N

pulse = .25e-6    # duration of pulse
delay = 3e-6      # duration of delay
f1 = 1/(4*pulse)  # for pi/2 pulses
coupling = 5e3    # coupling strength
delta = 500       # chemical shift strength (for identical spins)

(x,y,z) = (ss.x, ss.y, ss.z)
(X,Y,Z) = ss.getTotalSpin(N, dim)

Hdip, Hint = ss.getAllH(N, dim, coupling, delta)
HWHH0 = ss.getHWHH0(X,Y,Z,delta)

# Initialize RL algorithm

The target network parameters $\theta_\text{target}$ are updated by Polyak averaging with parameter $\rho$ (`polyak` in the code):
$$
\theta_\text{target} = \rho \theta_\text{target} + (1-\rho)\theta
$$
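As a minimal sketch of this update, assuming the network parameters can be treated as a list of NumPy arrays (the function `polyak_update` is hypothetical and is not the `updateParams` method in `rlPulse`):

```python
import numpy as np

def polyak_update(target_params, params, rho):
    """Return rho * target + (1 - rho) * params for each parameter array."""
    return [rho * t + (1 - rho) * p for t, p in zip(target_params, params)]

# Toy parameters: target network starts at 0, online network sits at 1.
target = [np.zeros((2, 2)), np.zeros(2)]
online = [np.ones((2, 2)), np.ones(2)]

target = polyak_update(target, online, rho=0.75)
print(target[1])  # each entry moves a quarter of the way toward the online value
```

A value of $\rho$ closer to 1 makes the target networks change more slowly, which tends to stabilize the Bellman targets at the cost of slower learning.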

TODO figure out if this buffer size makes sense
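To make concrete what the buffer size controls, here is a minimal sketch of a fixed-capacity replay buffer. The class `SimpleReplayBuffer` is hypothetical and only illustrates the idea; the actual `rlp.ReplayBuffer` may be implemented differently.

```python
import random
from collections import deque

class SimpleReplayBuffer:
    """Fixed-capacity buffer; the oldest experience is dropped once full."""

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def add(self, s, a, r, s1, d):
        # Store one experience tuple (state, action, reward, next state, done).
        self.buffer.append((s, a, r, s1, d))

    def sample(self, batch_size):
        # Sample without replacement, capped at the current buffer length.
        k = min(batch_size, len(self.buffer))
        return random.sample(list(self.buffer), k)

buf = SimpleReplayBuffer(capacity=3)
for i in range(5):
    buf.add(i, 0, 0.0, i + 1, False)

print([exp[0] for exp in buf.buffer])  # only the 3 most recent states remain: [2, 3, 4]
```

A smaller buffer keeps training focused on experiences from the recent policy, while a larger one retains older and more varied experiences.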

In [None]:
sDim = 3 # state represented by sequences of actions...
aDim = 3 # action = [phi, rot, time]

numExp = 2000 # how many experiences to "play" through and learn from
bufferSize = 500 # size of the replay buffer (i.e. how many experiences to keep in memory).
batchSize = 50 # size of batch (subset of replay buffer) to use as training for actor and critic.
p = 1 # action noise parameter
polyak = 0.75 # polyak averaging parameter
gamma = 0.5 # future reward discount rate

printEvery = 50
updateAfter = 500 # start updating actor/critic networks after this many episodes
updateEvery = 50  # update networks every __ episodes
numUpdates = 2 # how many training updates to perform on a random subset of experiences (s,a,r,s1,d)
randomizeDipolarEvery = 10
lowerNoiseAfter = 500


pDiff = (0-p)/(numExp-lowerNoiseAfter)

Initialize the actor and critic, as well as target actor and target critic. The actor learns the policy function
$$
\pi_\theta: S \to A, s \mapsto a
$$
that picks the optimal action $a$ for a given state $s$, with some set of parameters $\theta$ (in this case weights/biases in the neural network). The critic learns the Q-function
$$
Q_\phi: S \times A \to \mathbf{R}, (s,a) \mapsto q
$$
where $q$ is the expected total (discounted) reward from taking action $a$ in state $s$, and $\phi$ is the parameter set for the Q-function model. The target actor and target critic have their own parameter sets $\theta_\text{target}$ and $\phi_\text{target}$.
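In the standard DDPG formulation (as in the SpinningUp pseudocode), the target networks are used to build the Bellman target $y = r + \gamma (1 - d)\, Q_{\phi_\text{target}}(s', \pi_{\theta_\text{target}}(s'))$ that the critic is regressed toward. The sketch below illustrates that computation with toy stand-in functions; it assumes, without confirmation from the source, that `rlp.Critic.trainStep` does something equivalent, and all names in it are hypothetical.

```python
import numpy as np

def bellman_targets(r, s1, d, actor_target, critic_target, gamma):
    """y = r + gamma * (1 - d) * Q_target(s1, pi_target(s1))."""
    a1 = actor_target(s1)
    return r + gamma * (1.0 - d) * critic_target(s1, a1)

def critic_loss(q, y):
    """Mean squared Bellman error over the batch."""
    return np.mean((q - y) ** 2)

# Toy stand-ins for the target networks (hypothetical, for illustration only).
actor_target = lambda s: 0.5 * s
critic_target = lambda s, a: np.sum(s * a, axis=1)

# A batch of two transitions; the second one is terminal (d = 1).
r = np.array([1.0, 2.0])
s1 = np.array([[1.0, 0.0, 1.0], [0.0, 2.0, 0.0]])
d = np.array([0.0, 1.0])

y = bellman_targets(r, s1, d, actor_target, critic_target, gamma=0.5)
print(y)  # the terminal transition keeps only its immediate reward
```

The actor is then trained by gradient ascent on the critic's value of its own actions, $Q_\phi(s, \pi_\theta(s))$, averaged over the batch.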

In [None]:
actor = rlp.Actor(sDim,aDim,None)
actorTarget = rlp.Actor(sDim,aDim,None)
critic = rlp.Critic(sDim,aDim,None, gamma)
criticTarget = rlp.Critic(sDim,aDim,None, gamma)
env = rlp.Environment(N, dim, sDim, HWHH0, X, Y)

actorTarget.setParams(actor.getParams())
criticTarget.setParams(critic.getParams())

replayBuffer = rlp.ReplayBuffer(bufferSize)

## DDPG algorithm

In [None]:
rMat = np.zeros((numExp,))
aMat = np.zeros((numExp,aDim))
timeMat = np.zeros((numExp, 2)) # record length of episode so far and number of pulses
# keep track of when resets/updates happen
resetStateEps = []
updateEps = []

for i in tqdm(range(numExp)):
    # randomize dipolar coupling strengths for Hint
    if i > 0 and i % randomizeDipolarEvery == 0:
        Hdip, Hint = ss.getAllH(N, dim, coupling, delta)
    
    s = env.getState()
    # get action based on current state and some level of noise
    a = rlp.clipAction(actor.predict(env.state) + rlp.actionNoise(p))
    # evolve state based on action
    env.evolve(a, Hint)
    # get reward
    r = env.reward()
    
    # record episode data
    aMat[i,:] = a
    rMat[i] = r
    timeMat[i,:] = [env.t, np.sum(env.state[:,2] != 0)]
    
    # get updated state, and whether it's a terminal state
    s1 = env.getState()
    d = env.isDone()
    replayBuffer.add(s,a,r,s1,d)
    
    # update noise parameter
    if i > lowerNoiseAfter:
        p += pDiff
    
    # CHECK IF TERMINAL
    if d:
        env.reset()
        resetStateEps.append(i)
    # UPDATE NETWORKS
    if (i > updateAfter) and (i % updateEvery == 0):
#         print("updating actor/critic networks (episode {})".format(i))
        updateEps.append(i)
        for update in range(numUpdates):
            batch = replayBuffer.getSampleBatch(batchSize)
            # train critic
            critic.trainStep(batch, actorTarget, criticTarget)
            # train actor
            actor.trainStep(batch, critic)
            # update target networks
            criticTarget.updateParams(critic, polyak)
            actorTarget.updateParams(actor, polyak)
            

# Results

See how the rewards and actions change over time as the actor/critic (hopefully) learn how to control the spin system.

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

plt.plot(rMat, color='black', label='rewards')
ymin, ymax = plt.ylim()
plt.vlines(updateEps, ymin, ymax, color='red', alpha=0.2, label='updates')
#plt.vlines(resetStateEps, ymin, ymax, color='blue', alpha=0.2, linestyles='dashed', label='state reset')
plt.title('Rewards for each episode')
plt.xlabel('Episode number')
plt.ylabel('Reward')
plt.legend()

In [None]:
%matplotlib inline

plt.hist(rMat, bins=20, color='black', label='rewards')
# ymin, ymax = plt.ylim()
# plt.vlines(updateEps, ymin, ymax, color='red', alpha=0.2, label='updates')
#plt.vlines(resetStateEps, ymin, ymax, color='blue', alpha=0.2, linestyles='dashed', label='state reset')
plt.title('Rewards histogram')
plt.legend()

In [None]:
%matplotlib inline

plt.plot(aMat[:,0], 'ok', label='phi')
plt.title('Phi action')
ymin, ymax = plt.ylim()
plt.vlines(updateEps, ymin, ymax, color='red', alpha=0.2, label='updates')
#plt.vlines(resetStateEps, ymin, ymax, color='blue', alpha=0.2, linestyles='dashed', label='state reset')
plt.xlabel('Episode number')
plt.ylabel('Phi action')
plt.legend()

In [None]:
%matplotlib inline

plt.plot(aMat[:,1], 'ok', label='rot')
plt.title('Rot action')
ymin, ymax = plt.ylim()
plt.vlines(updateEps, ymin, ymax, color='red', alpha=0.2, label='updates')
#plt.vlines(resetStateEps, ymin, ymax, color='blue', alpha=0.2, linestyles='dashed', label='state reset')
plt.xlabel('Episode number')
plt.ylabel('Rot action')
plt.legend()

In [None]:
%matplotlib inline

plt.plot(aMat[:,2], 'ok', label='time')
plt.title('Time action')
ymin, ymax = plt.ylim()
plt.vlines(updateEps, ymin, ymax, color='red', alpha=0.2, label='updates')
#plt.vlines(resetStateEps, ymin, ymax, color='blue', alpha=0.2, linestyles='dashed', label='state reset')
plt.xlabel('Episode number')
plt.ylabel('Time action')
plt.legend()

In [None]:
# print the state with highest reward in the buffer (might not be highest of all episodes because buffer forgets)
rBuffer = [exp[2] for exp in replayBuffer.buffer]  # reward is the third entry of (s,a,r,s1,d)
rlp.printAction(replayBuffer.buffer[np.argmax(rBuffer)][3])
print("reward: ", np.max(rBuffer))

In [None]:
# evaluate average fidelity for a pulse sequence

fidelities = np.zeros((50,))
for i in range(50):
    Hdip, Hint = ss.getAllH(N, dim, coupling, delta)
    Uexp = spla.expm(-1j*(Hint*10.0e-6 - X*0*np.pi)) @ \
           spla.expm(-1j*(Hint*3.47e-6 - X*0*np.pi)) @ \
           spla.expm(-1j*(Hint*7.11e-6 + Y*0*np.pi))
    Utarget = ss.getPropagator(HWHH0, (7.11+3.47+10)*1e-6)
    fidelities[i] = ss.fidelity(Utarget, Uexp)
print("mean fidelity: ", np.mean(fidelities), "std dev: ", np.std(fidelities), "max: ", np.max(fidelities))