# Pulse Sequence Design Using Reinforcement Learning

This notebook implements deep deterministic policy gradient (DDPG) to learn pulse sequence design for spin systems. The [OpenAI SpinningUp resource](https://spinningup.openai.com/en/latest/algorithms/ddpg.html#pseudocode) gives a good theoretical background on DDPG, which I used to implement the algorithm below.

DDPG is designed for _continuous_ action spaces, which matches the ultimate goal of this project: applying pulses with arbitrary rotation axes, rotation angles, and durations, instead of being limited to $\pi/2$ pulses along X or Y. However, that means the algorithm is less suited to constrained versions of the problem, such as only applying $\pi/2$ pulses of a fixed length about X or Y.

For training, the following reward function was used:
$$
r = -\log_{10}\left( 1- \left| \text{Tr}\left( \frac{U_\text{target}^\dagger U_\text{exp}}{2^N} \right) \right| \right)
= -\log_{10}\left( 1- \text{fidelity}(U_\text{target}, U_\text{exp}) \right)
$$
For example, if the fidelity is $0.999$, then the reward is $r = -\log_{10}(0.001) = 3$.
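
As a minimal sketch of this reward (not the `rlPulse` implementation), the fidelity and reward can be computed with NumPy. The function names and the small offset `eps`, which keeps the logarithm finite at unit fidelity, are assumptions made for illustration; the base-10 logarithm is chosen to match the worked example above.

```python
import numpy as np

def fidelity(U_target, U_exp):
    """|Tr(U_target^dagger U_exp)| / dim for two dim x dim unitaries."""
    dim = U_target.shape[0]
    return np.abs(np.trace(U_target.conj().T @ U_exp)) / dim

def reward_from_fidelity(f, eps=1e-12):
    """Reward -log10(1 - f); eps (an assumption) avoids log(0) when f == 1."""
    return -np.log10(1 - f + eps)

identity = np.eye(4, dtype=complex)
print(fidelity(identity, identity))   # identical propagators give fidelity 1
print(reward_from_fidelity(0.999))    # close to 3
```

Because the reward diverges as the fidelity approaches 1, small improvements at high fidelity are rewarded much more strongly than the same improvement at low fidelity.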

<!-- For the policy function, I need to perform gradient ascent with the following gradient
$$
\nabla_\theta 1/|B| \sum_{s \in B} Q_\phi (s, \pi_\theta(s))
$$

And for the Q-function, perform gradient descent with
$$
\nabla_\phi 1/|B| \sum_{(s,a,r,s',d) \in B} (Q_\phi(s,a) - y(r,s',d))^2
$$ -->

Other resources:

- https://www.tensorflow.org/tutorials/customization/custom_training_walkthrough#define_the_loss_and_gradient_function
- https://www.tensorflow.org/guide/migrate#customize_the_training_step

In [None]:
import spinSimulation as ss
import rlPulse as rlp
import numpy as np
import scipy.linalg as spla
import importlib
from tqdm import tqdm
import matplotlib.pyplot as plt

In [None]:
importlib.reload(ss)
importlib.reload(rlp)

# Initialize spin system

This sets the parameters of the system ($N$ spin-1/2 particles, which corresponds to a Hilbert space with dimension $2^N$). For the purposes of simulation, $\hbar \equiv 1$.

The total internal Hamiltonian is given by
$$
H_\text{int} = C H_\text{dip} + \Delta \sum_i^N I_z^{(i)}
$$
where $C$ is the coupling strength, $\Delta$ is the chemical shift strength (each spin is assumed to be identical), and $H_\text{dip}$ is given by
$$
H_\text{dip} = \sum_{i,j}^N d_{i,j} \left(3I_z^{(i)}I_z^{(j)} - \mathbf{I}^{(i)} \cdot \mathbf{I}^{(j)}\right)
$$
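
The following is a hedged sketch of how $H_\text{int}$ could be assembled with Kronecker products; it is not the code in `spinSimulation`. The helper names, the restriction of the sum to pairs $i < j$, and the standard-normal draw for $d_{i,j}$ are all assumptions made for illustration.

```python
import numpy as np

# Single spin-1/2 operators (hbar = 1)
sx = np.array([[0, 1], [1, 0]], dtype=complex) / 2
sy = np.array([[0, -1j], [1j, 0]], dtype=complex) / 2
sz = np.array([[1, 0], [0, -1]], dtype=complex) / 2

def embed(op, i, N):
    """Operator `op` acting on spin i of an N-spin system (identity elsewhere)."""
    out = np.array([[1.0 + 0j]])
    for k in range(N):
        out = np.kron(out, op if k == i else np.eye(2))
    return out

def internal_hamiltonian(N, coupling, delta, rng):
    """Return (Hdip, Hint) with Hint = coupling * Hdip + delta * sum_i Iz^(i)."""
    dim = 2**N
    Hdip = np.zeros((dim, dim), dtype=complex)
    for i in range(N):
        for j in range(i + 1, N):          # assumed: each pair counted once
            d_ij = rng.normal()            # assumed coupling distribution
            IxIx = embed(sx, i, N) @ embed(sx, j, N)
            IyIy = embed(sy, i, N) @ embed(sy, j, N)
            IzIz = embed(sz, i, N) @ embed(sz, j, N)
            # 3 Iz Iz - I.I = 2 Iz Iz - Ix Ix - Iy Iy
            Hdip += d_ij * (2 * IzIz - IxIx - IyIy)
    Z = sum(embed(sz, i, N) for i in range(N))
    return Hdip, coupling * Hdip + delta * Z

rng = np.random.default_rng(0)
Hdip, Hint = internal_hamiltonian(N=3, coupling=5e3, delta=500, rng=rng)
print(Hint.shape)
```

Passing a seeded generator keeps the random couplings reproducible, which is convenient when comparing pulse sequences across runs.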

The WAHUHA pulse sequence is designed to remove the dipolar interaction term from the internal Hamiltonian. The pulse sequence is $\tau, P_{-x}, \tau, P_{y}, \tau, \tau, P_{-y}, \tau, P_{x}, \tau$.
The zeroth-order average Hamiltonian for the WAHUHA pulse sequence is
$$
H_\text{WHH} = \frac{\Delta}{3} \sum_i^N \left( I_x^{(i)} + I_y^{(i)} + I_z^{(i)} \right)
$$
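
One way to see why the dipolar term drops out at zeroth order: with ideal (instantaneous) pulses, the toggling-frame dipolar Hamiltonian spends equal total time in its $z$, $y$, and $x$ forms, and those three forms sum to zero. A small two-spin check of that identity, with helper names chosen for illustration:

```python
import numpy as np

# Single spin-1/2 operators (hbar = 1)
sx = np.array([[0, 1], [1, 0]], dtype=complex) / 2
sy = np.array([[0, -1j], [1j, 0]], dtype=complex) / 2
sz = np.array([[1, 0], [0, -1]], dtype=complex) / 2

def pair(op):
    """Two-spin product op^(1) op^(2)."""
    return np.kron(op, op)

# I^(1) . I^(2)
dot = pair(sx) + pair(sy) + pair(sz)

# Dipolar Hamiltonian with the quantization axis along x, y, and z
H_forms = [3 * pair(op) - dot for op in (sx, sy, sz)]

# Equal-time average over the three toggling-frame forms
average = sum(H_forms) / 3
print(np.allclose(average, 0))  # True
```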

In [None]:
N = 4
dim = 2**N

# pulse = .25e-6    # duration of pulse
# delay = 3e-6      # duration of delay
coupling = 5e3    # coupling strength
delta = 500       # chemical shift strength (for identical spins)

(x,y,z) = (ss.x, ss.y, ss.z)
(X,Y,Z) = ss.getTotalSpin(N, dim)

Hdip, Hint = ss.getAllH(N, dim, coupling, delta)
HWHH0 = ss.getHWHH0(X,Y,Z,delta)

# Initialize RL algorithm

An "action" performed on the system corresponds to an RF-pulse applied to the system. A pulse can be parametrized by the axis of rotation (e.g. $(\theta, \phi)$, but for now $\theta = \pi/2$ is assumed so the axis of rotation lies in the xy-plane), the rotation angle, and the duration of the pulse.
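
To make the parametrization concrete, here is a sketch of the propagator for an ideal pulse on a single spin-1/2. It assumes angles in radians, ignores the pulse duration and any evolution under the internal Hamiltonian during the pulse, and uses a hypothetical function name; how `rlPulse` maps action values to physical pulse parameters may differ.

```python
import numpy as np

# Single spin-1/2 operators (hbar = 1)
sx = np.array([[0, 1], [1, 0]], dtype=complex) / 2
sy = np.array([[0, -1j], [1j, 0]], dtype=complex) / 2

def pulse_propagator(phi, rot):
    """Rotation by angle `rot` about the axis (cos phi, sin phi, 0)."""
    axis_op = np.cos(phi) * sx + np.sin(phi) * sy
    # For spin-1/2: exp(-i rot n.I) = cos(rot/2) 1 - 2i sin(rot/2) n.I
    return np.cos(rot / 2) * np.eye(2) - 2j * np.sin(rot / 2) * axis_op

U = pulse_propagator(phi=0.0, rot=np.pi)   # pi pulse about x
print(np.round(U, 6))                      # equals -i * sigma_x
```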

The state of the system could be represented by the propagator, but the size of the propagator grows exponentially with $N$ (it has $4^N$ elements for an $N$-spin system). Because the pulse sequence determines the propagator, the state is represented by the pulse sequence instead.

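The target network parameters $\theta_\text{target}$ are updated by Polyak averaging,
$$
\theta_\text{target} \leftarrow (1-\rho) \theta_\text{target} + \rho\theta
$$
where $\rho$ is the averaging parameter (`polyak` in the cell below), so a small $\rho$ makes the target networks track the learned networks slowly.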
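
A minimal NumPy sketch of this update rule, treating the network parameters as a list of arrays. The function name is hypothetical and this is not the `copyParams` method from `rlPulse`.

```python
import numpy as np

def polyak_update(target_params, params, rho):
    """Return (1 - rho) * target + rho * params for each parameter array."""
    return [(1 - rho) * t + rho * p for t, p in zip(target_params, params)]

target = [np.zeros(3)]
online = [np.ones(3)]
for _ in range(100):
    target = polyak_update(target, online, rho=0.01)
print(target[0])  # each entry is 1 - 0.99**100, roughly 0.634
```

With $\rho = 0.01$, the target parameters cover only about 63% of the distance to fixed learned parameters after 100 updates.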

TODO figure out if this buffer size makes sense

In [None]:
sDim = 3 # state represented by sequences of actions...
aDim = 3 # action = [phi, rot, time]
learningRate = 0.001 # learning rate for optimizer

numExp = 1000 # how many experiences to "play" through and learn from
bufferSize = 1000 # size of the replay buffer (i.e. how many experiences to keep in memory).
batchSize = 100 # size of batch for training
polyak = 0.01 # polyak averaging parameter
gamma = .5 # future reward discount rate

updateAfter = 100 # number of experiences to collect (to fill the buffer) before updating actor/critic networks
updateEvery = 10 # number of experiences between updates (larger is faster)
numUpdates = 1 # how many training updates to perform on a random subset of experiences (s,a,r,s1,d)
testEvery = 250 # how often to evaluate performance on an initial state

p = 1 # action noise parameter
dp = -p/numExp / (3/4)

Initialize the actor and critic, as well as target actor and target critic. The actor learns the policy function
$$
\pi_\theta: S \to A, s \mapsto a
$$
that picks the optimal action $a$ for a given state $s$, with some set of parameters $\theta$ (in this case weights/biases in the neural network). The critic learns the Q-function
$$
Q_\phi: S \times A \to \mathbf{R}, (s,a) \mapsto q
$$
where $q$ is the expected total (discounted) reward from taking action $a$ in state $s$, and $\phi$ is the parameter set for the Q-function model. The target actor/critic have their own parameter sets $\theta_\text{target}$ and $\phi_\text{target}$.
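
In DDPG the target networks enter through the regression target for the critic, $y = r + \gamma (1 - d)\, Q_{\phi_\text{target}}(s', \pi_{\theta_\text{target}}(s'))$, which is the standard form in the SpinningUp pseudocode. A toy NumPy sketch follows; the linear `actor_target` and `critic_target` functions are stand-ins for neural networks and are purely illustrative.

```python
import numpy as np

gamma = 0.5  # discount rate

def actor_target(s):
    """Stand-in for pi_target(s); a real actor would be a neural network."""
    return 0.1 * s

def critic_target(s, a):
    """Stand-in for Q_target(s, a); returns one value per row."""
    return np.sum(s * a, axis=1)

def bellman_target(r, s1, d):
    """y = r + gamma * (1 - d) * Q_target(s1, pi_target(s1))."""
    return r + gamma * (1 - d) * critic_target(s1, actor_target(s1))

r = np.array([1.0, 2.0])
s1 = np.array([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
d = np.array([0.0, 1.0])   # the second transition is terminal
print(bellman_target(r, s1, d))  # [1.05 2.  ]
```

For terminal transitions ($d = 1$) the bootstrap term vanishes, so the target reduces to the immediate reward.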

The "environment" keeps track of the system state, and calculates the reward after each action.

The replay buffer stores the most recent experiences $(s, a, r, s', d)$, up to `bufferSize` of them.
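
For illustration, a fixed-size replay buffer can be sketched with a `deque`. The class and method names here are hypothetical and are not the `rlp.ReplayBuffer` interface.

```python
import random
from collections import deque

class SimpleReplayBuffer:
    """Keeps the most recent `size` experiences and samples them uniformly."""

    def __init__(self, size):
        # A deque with maxlen drops the oldest entry once it is full
        self.buffer = deque(maxlen=size)

    def add(self, s, a, r, s1, d):
        self.buffer.append((s, a, r, s1, d))

    def sample(self, batch_size):
        # Never request more experiences than are stored
        k = min(batch_size, len(self.buffer))
        return random.sample(list(self.buffer), k)

buf = SimpleReplayBuffer(size=3)
for step in range(5):
    buf.add(step, 0.0, 1.0, step + 1, False)
print([e[0] for e in buf.buffer])  # [2, 3, 4]
```

Sampling uniformly from stored experiences, rather than training on consecutive steps, reduces the correlation between samples in a training batch.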

In [None]:
actor = rlp.Actor(sDim,aDim, learningRate, 4, 4)
actorTarget = rlp.Actor(sDim,aDim, learningRate, 4, 4)
critic = rlp.Critic(sDim, aDim, gamma, learningRate, 4, 4)
criticTarget = rlp.Critic(sDim, aDim, gamma, learningRate, 4, 4)
env = rlp.Environment(N, dim, sDim, HWHH0, X, Y)

actorTarget.setParams(actor.getParams())
criticTarget.setParams(critic.getParams())

replayBuffer = rlp.ReplayBuffer(bufferSize)

## DDPG algorithm

In [None]:
# record actions and rewards from learning
actorAMat = np.zeros((numExp,aDim))
aMat = np.zeros((numExp,aDim))
timeMat = np.zeros((numExp, 2)) # duration of sequence and number of pulses
rMat = np.zeros((numExp,))
# record when resets/updates happen
resetStateEps = []
updateEps = [] # TODO remove this
# and record parameter differences between networks and targets (episode #, actor, critic)
paramDistance = []

# record test results: episode, final pulse sequence (to terminal state), rewards at each episode
testResults = []
isTesting = False

numActions = 0

for i in tqdm(range(numExp)):
    s = env.getState()
    # get action based on current state and some level of noise
    actorA = actor.predict(env.state)
    if not isTesting:
        aNoise = rlp.actionNoise(p)
        a = rlp.clipAction(actorA + aNoise)
    else:
        a = rlp.clipAction(actorA)
    
    # update noise parameter
    p = np.maximum(p + dp, .2)
    
    numActions += 1
    
    # evolve state based on action
    env.evolve(a, Hint)
    # get reward
    r = env.reward2()
    
    # get updated state, and whether it's a terminal state
    s1 = env.getState()
    d = env.isDone()
    replayBuffer.add(s,a,r,s1,d)
    
    # record episode data
    aMat[i,:] = a
    actorAMat[i,:] = actorA
    rMat[i] = r
    timeMat[i,:] = [env.t, numActions]
    
    if i % int(numExp/25) == 0:
        # calculate distance between parameters for actors/critics
        paramDistance.append((i, actor.paramDistance(actorTarget), \
                                 critic.paramDistance(criticTarget)))
    
    # if the state is terminal
    if d:
        if isTesting:
            # record results from the test and go back to learning
            testResults.append((i, s1, rMat[(i-numActions+1):(i+1)]))
            isTesting = not isTesting
        else:
            # check if it's time to test performance
            if len(testResults)*testEvery < i:
                isTesting = True
        
        # randomize dipolar coupling strengths for Hint
        Hdip, Hint = ss.getAllH(N, dim, coupling, delta)
        # reset environment
        env.reset()
        resetStateEps.append(i)
        numActions = 0
    
    # update networks
    if i > updateAfter and i % updateEvery == 0:
        updateEps.append(i)
        for update in range(numUpdates):
            batch = replayBuffer.getSampleBatch(batchSize)
            # train critic
            critic.trainStep(batch, actorTarget, criticTarget)
            # train actor
            actor.trainStep(batch, critic)
            # update target networks
            criticTarget.copyParams(critic, polyak)
            actorTarget.copyParams(actor, polyak)

# Results

See how the rewards and actions change over time as the actor/critic (hopefully) learn how to control the spin system.

Looking at the histogram below, most rewards are very small (which makes sense, considering a large subset of rewards are from random actions to begin with). There is a slight bump of rewards centered around 3 (corresponding to fidelities of around 0.999).

In [None]:
%matplotlib inline

plt.hist(rMat, bins=20, color='black', label='rewards')
plt.title('Rewards histogram')
plt.legend()

The rewards increase as the algorithm learns (which is a good thing!), but seem to plateau near 4 once the algorithm converges.

In [None]:
%matplotlib inline

plt.plot(rMat, 'ok', label='rewards')
ymin, ymax = plt.ylim()
# plt.vlines(updateEps, ymin, ymax, color='red', alpha=0.2, label='updates')
#plt.vlines(resetStateEps, ymin, ymax, color='blue', alpha=0.2, linestyles='dashed', label='state reset')
plt.title('Rewards for each episode')
plt.xlabel('Episode number')
# plt.yscale('log')
plt.ylabel('Reward')
plt.legend()

Below are the actions performed on the system. Random noise is added to actions for early episodes (to explore the action space), and gradually less noise is added as the algorithm converges.
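
The noise schedule can be reproduced on its own. This sketch reuses the values from the hyperparameter cell (initial `p = 1`, a linear decrement, and a floor of `0.2`); the snake_case variable names are introduced here for illustration.

```python
num_exp = 1000
p = 1.0
dp = -p / num_exp / (3 / 4)   # about -1/750 per step
p_min = 0.2                   # noise never drops below this floor

schedule = []
for i in range(num_exp):
    schedule.append(p)        # noise level used at step i
    p = max(p + dp, p_min)    # linear decay, clipped at the floor

# First step at which the noise has reached the floor (roughly step 600)
first_floor = next(i for i, v in enumerate(schedule) if v <= p_min)
print(first_floor)
```

So the noise level decays over roughly the first 60% of the run and stays at the floor for the remainder.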

In [None]:
%matplotlib inline

plt.plot(aMat[:,0], 'ok', label='phi', zorder=1)
plt.plot(actorAMat[:,0], '.b', label='phi (actor)', zorder=2)
plt.title('Phi action')
ymin, ymax = plt.ylim()
# plt.vlines(updateEps, ymin, ymax, color='red', alpha=0.2, label='updates')
#plt.vlines(resetStateEps, ymin, ymax, color='blue', alpha=0.2, linestyles='dashed', label='state reset')
plt.xlabel('Episode number')
plt.ylabel('Phi action')
plt.legend()

In [None]:
%matplotlib inline

plt.plot(aMat[:,1], 'ok', label='rot')
plt.plot(actorAMat[:,1], '.b', label='rot (actor)', zorder=2)
plt.title('Rot action')
ymin, ymax = plt.ylim()
# plt.vlines(updateEps, ymin, ymax, color='red', alpha=0.2, label='updates')
#plt.vlines(resetStateEps, ymin, ymax, color='blue', alpha=0.2, linestyles='dashed', label='state reset')
plt.xlabel('Episode number')
plt.ylabel('Rot action')
plt.legend()

In [None]:
%matplotlib inline

plt.plot(aMat[:,2], 'ok', label='time')
plt.plot(actorAMat[:,2], '.b', label='time (actor)', zorder=2)
plt.title('Time action')
ymin, ymax = plt.ylim()
# plt.vlines(updateEps, ymin, ymax, color='red', alpha=0.2, label='updates')
# plt.vlines(resetStateEps, ymin, ymax, color='blue', alpha=0.2, linestyles='dashed', label='state reset')
plt.xlabel('Episode number')
plt.ylabel('Time action')
plt.legend()

In [None]:
%matplotlib inline

plt.plot(timeMat[:,0], 'ok', label='time')
plt.title('Pulse sequence length (time)')
ymin, ymax = plt.ylim()
# plt.vlines(updateEps, ymin, ymax, color='red', alpha=0.2, label='updates')
#plt.vlines(resetStateEps, ymin, ymax, color='blue', alpha=0.2, linestyles='dashed', label='state reset')
plt.xlabel('Episode number')
plt.ylabel('Pulse sequence length (s)')
plt.legend()

In [None]:
%matplotlib inline

plt.plot(timeMat[:,1], 'ok', label='number of pulses')
plt.title('Pulse sequence length (number of pulses)')
ymin, ymax = plt.ylim()
# plt.vlines(updateEps, ymin, ymax, color='red', alpha=0.2, label='updates')
#plt.vlines(resetStateEps, ymin, ymax, color='blue', alpha=0.2, linestyles='dashed', label='state reset')
plt.xlabel('Episode number')
plt.ylabel('Number of pulses')
plt.legend()

The cell below prints the pulse sequences with the highest rewards in the replay buffer. It seems like the algorithm is only learning to apply a single pulse that rotates $2\pi$ about the x-axis. This pulse "sequence" has reasonably high fidelity ($0.99998$), but is clearly not the WAHUHA sequence.

In [None]:
# rBuffer = np.array([_[2] for _  in replayBuffer.buffer])
# indSorted = rBuffer.argsort()
# for i in range(1,5):
#     print("Highest rewards in buffer (#{})\n".format(i))
#     print("Index in buffer: {}\n".format(indSorted[-i]))
#     sequence = replayBuffer.buffer[indSorted[-i]][3] # sequence of actions
#     print(rlp.formatAction(sequence) + "\n")
#     # calculate mean fidelity from ensemble of dipolar couplings
#     fidelities = np.zeros((10,))
#     t = np.sum(rlp.getTimeFromAction(sequence))
#     for i in range(10):
#         Hdip, Hint = ss.getAllH(N, dim, coupling, delta)
#         Uexp = rlp.getPropagatorFromAction(N, dim, sequence, Hint, X, Y)
#         Utarget = ss.getPropagator(HWHH0, t)
#         fidelities[i] = ss.fidelity(Utarget, Uexp)
#     fMean = np.mean(fidelities)
#     print(f"Mean fidelity: {fMean}")
#     r = -1*np.log10(1+1e-12-fMean**(20e-6/t))
#     print(f"Reward: {r}")

Display the test results (no noise added to the actions).

In [None]:
for result in testResults:
    print(f"Test result from episode {result[0]}\n\nChosen pulse sequence:")
    rlp.printAction(result[1])
    print(f"Rewards from the pulse sequence:\n{result[2]}")
    print("\n")

### Analysis of networks

In [None]:
for i in paramDistance:
    print(f"episode {i[0]}:\tactor diff={i[1]:0.2},\tcritic diff={i[2]:0.2}")

In [None]:
# s = np.zeros((16,3), dtype="float32")
# # s[0,:] = [0,0,1]
# # s[1,:] = [0,.25,.3]
# print(actor.predict(s))
# a = np.array([0,0,1], dtype="float32")
# print(critic.predict(s,a))

In [None]:
# # evaluate average fidelity for a pulse sequence
# num = 25
# fidelities = np.zeros((num,))
# for i in range(num):
#     Hdip, Hint = ss.getAllH(N, dim, coupling, delta)
#     Uexp = rlp.getPropagatorFromAction(N, dim, np.array([0,1,.3]), Hint, X, Y)
#     Utarget = ss.getPropagator(HWHH0, (.1)*1e-6)
#     fidelities[i] = ss.fidelity(np.power(Utarget,1), np.power(Uexp,1))
# print("mean fidelity: ", np.mean(fidelities), "std dev: ", np.std(fidelities), "max: ", np.max(fidelities))