# AlphaZero implementation for pulse sequence design
_Will Kaufman, December 2020_

[Dalgaard et al. (2020)](https://www.nature.com/articles/s41534-019-0241-0) applied this approach to constructing shaped pulses (as I understand it), but in principle it should be at least as applicable to pulse sequence design. The original [AlphaZero paper](https://science.sciencemag.org/content/362/6419/1140.full) is here.

The general idea behind AlphaZero (as I understand it) is a "smart" tree search that balances previous knowledge (the policy), curiosity about unexplored branches, and exploitation of high-value branches. My thought is that average Hamiltonian theory (AHT) can improve on this, because two things must hold by the end of the pulse sequence: the sequence must be cyclic (the overall frame transformation must be the identity), and equal time must be spent on each axis. These hard constraints should (hopefully) prune the tree and speed up the search.

```python
import qutip as qt
import numpy as np
import matplotlib.pyplot as plt
import sys
import os
```

```python
# Make the project's modules (one directory up) importable
sys.path.append(os.path.abspath('..'))
import pulse_sequences as ps
import alpha_zero as az
```

```python
# Reload the modules after editing them, without restarting the kernel
import importlib
importlib.reload(az)
importlib.reload(ps)
```

```
<module 'pulse_sequences' from '/Users/willkaufman/Projects/rl_pulse/rl_pulse/pulse_sequences.py'>
```

## Define the spin system

```python
delay = 1e-2  # time is relative to chemical shift strength
pulse_width = 5e-3  # same time units as delay
N = 3  # number of spins
ensemble_size = 5  # number of system Hamiltonians in the ensemble
```

```python
# Collective spin operators for the N-spin system
X, Y, Z = ps.get_collective_spin(N)
```

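```python
# One system Hamiltonian per ensemble member, and the pulses simulated
# under each one (rot_error: pulse rotation error)
Hsys_ensemble = [ps.get_Hsys(N) for _ in range(ensemble_size)]
pulses_ensemble = [
    ps.get_pulses(H, X, Y, Z, pulse_width, delay, rot_error=0.01)
    for H in Hsys_ensemble
]
```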

```python
# Target propagator: the identity, i.e. suspend all evolution
Utarget = qt.identity(Hsys_ensemble[0].dims[0])
```

## Average Hamiltonian theory

To keep track of the average Hamiltonian (to lowest order), I define a frame matrix and apply each pulse's rotation matrix to it, then read off how $I_z$ transforms over the course of the pulse sequence. The last row of the frame matrix is the current transformed value of $I_z$.

```python
# How many intervals the transformed I_z spends on each signed axis in yxx48
ps.count_axes(ps.yxx48)
```

```
[8, 8, 8, 8, 8, 8]
```
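The frame-matrix bookkeeping can be sketched in plain NumPy. This is a minimal sketch, not the `pulse_sequences` implementation: it assumes the pulse types are a delay, X, -X, Y, -Y (π/2 rotations, in that order), that each pulse is followed by one delay during which the axis is counted, and one particular sign convention for the rotations. `PULSE_ROTATIONS`, `AXES`, `axis_index`, and `count_axes_sketch` are hypothetical names. Each pulse left-multiplies the frame by its rotation matrix, and the last row (the transformed $I_z$) is binned into one of the six signed axes.

```python
import numpy as np

# Assumed pulse ordering (hypothetical, may differ from ps.pulse_names):
# 0 = delay, 1 = X, 2 = -X, 3 = Y, 4 = -Y, all pi/2 rotations.
# Integer matrices are exact for pi/2 rotations, so no rounding is needed.
R_X = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])  # +90 deg about x
R_Y = np.array([[0, 0, 1], [0, 1, 0], [-1, 0, 0]])  # +90 deg about y
PULSE_ROTATIONS = [np.eye(3, dtype=int), R_X, R_X.T, R_Y, R_Y.T]

# Assumed axis order for the counts
AXES = ["+x", "-x", "+y", "-y", "+z", "-z"]


def axis_index(v):
    """Map a signed unit vector (+-x, +-y, +-z) to an index into AXES."""
    i = int(np.argmax(np.abs(v)))
    return 2 * i + (0 if v[i] > 0 else 1)


def count_axes_sketch(sequence):
    """Count the intervals the transformed I_z spends on each signed axis.

    Returns the per-axis counts and the final frame matrix.
    """
    frame = np.eye(3, dtype=int)
    counts = [0] * len(AXES)
    for pulse in sequence:
        frame = PULSE_ROTATIONS[pulse] @ frame
        counts[axis_index(frame[-1])] += 1  # last row = transformed I_z
    return counts, frame


counts, frame = count_axes_sketch([1, 1, 3, 1, 1, 3])
print(dict(zip(AXES, counts)))
print("cyclic:", np.array_equal(frame, np.eye(3, dtype=int)))
```

Under these assumptions, the sequence `[1, 1, 3, 1, 1, 3]` (X, X, Y, X, X, Y) puts $I_z$ on each of the six signed axes exactly once and returns the frame to the identity, so it satisfies both AHT constraints.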

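```python
# Pulses that can follow the partial sequence [0, 1, 1] without violating the
# time-suspension constraints (5 pulse types, 6-pulse sequence)
ps.get_valid_time_suspension_pulses([0, 1, 1], len(ps.pulse_names), 6)
```

```
[1, 3, 4]
```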
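The kind of check `get_valid_time_suspension_pulses` performs can be sketched as follows. This is a hypothetical reimplementation, not the repo's code: `valid_next_pulses` and its helpers are illustrative names, and it assumes the five pulse types are a delay, X, -X, Y, -Y (π/2 rotations, in that order), a delay after each pulse, and a cap of length / 6 intervals per signed axis (±x, ±y, ±z). It only enforces the axis counts, not the cyclic condition. Under these assumptions it returns `[1, 3, 4]` for the partial sequence `[0, 1, 1]`, matching the output above, and it also reproduces the allowed pulses listed at each step of the MCTS `stats` further down.

```python
import numpy as np

# Assumed pulse ordering (hypothetical): 0 = delay, 1 = X, 2 = -X, 3 = Y, 4 = -Y
R_X = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])  # +90 deg about x
R_Y = np.array([[0, 0, 1], [0, 1, 0], [-1, 0, 0]])  # +90 deg about y
PULSE_ROTATIONS = [np.eye(3, dtype=int), R_X, R_X.T, R_Y, R_Y.T]


def axis_index(v):
    """Index of a signed axis in the order +x, -x, +y, -y, +z, -z."""
    i = int(np.argmax(np.abs(v)))
    return 2 * i + (0 if v[i] > 0 else 1)


def frame_and_counts(sequence):
    """Final frame matrix and per-axis counts of the transformed I_z."""
    frame = np.eye(3, dtype=int)
    counts = [0] * 6
    for pulse in sequence:
        frame = PULSE_ROTATIONS[pulse] @ frame
        counts[axis_index(frame[-1])] += 1
    return frame, counts


def valid_next_pulses(sequence, num_pulses, length):
    """Pulses that keep every signed axis at or below length // 6 intervals."""
    max_per_axis = length // 6
    frame, counts = frame_and_counts(sequence)
    valid = []
    for pulse in range(num_pulses):
        # Axis I_z would land on if this pulse were appended
        axis = axis_index((PULSE_ROTATIONS[pulse] @ frame)[-1])
        if counts[axis] < max_per_axis:
            valid.append(pulse)
    return valid


print(valid_next_pulses([0, 1, 1], 5, 6))
```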

## Tree search

Define nodes for tree search, with the additional constraint that the lowest-order average Hamiltonian matches the desired Hamiltonian.

(The code that implemented tree search with these constraints has been deleted; see the GitHub repo commits from 12/8.)

For 12-pulse sequences, the search got through 16 branches at depth 4 in a minute, or about one every 4 seconds. At depth 4 there are $5^4 = 625$ branches, so a full run would take about $625 \times 4\ \text{s} = 2500\ \text{s}$, roughly 42 minutes. Alternatively, you can generate random pulse sequences until one has the proper lowest-order average and the cyclic property.
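The random-search alternative is easy to sketch. This is a minimal, hypothetical version (not code from the repo): `satisfies_constraints` and `random_search` are illustrative names, and it assumes the five pulse types are a delay, X, -X, Y, -Y (π/2 rotations, in that order), a delay after every pulse, and a sequence length that is a multiple of 6. A sequence is accepted only if the transformed $I_z$ spends equal time on all six signed axes and the overall frame returns to the identity.

```python
import numpy as np

# Assumed pulse ordering (hypothetical): 0 = delay, 1 = X, 2 = -X, 3 = Y, 4 = -Y
R_X = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])  # +90 deg about x
R_Y = np.array([[0, 0, 1], [0, 1, 0], [-1, 0, 0]])  # +90 deg about y
PULSE_ROTATIONS = [np.eye(3, dtype=int), R_X, R_X.T, R_Y, R_Y.T]
IDENTITY = np.eye(3, dtype=int)


def satisfies_constraints(sequence):
    """Equal time on all six signed axes, and a cyclic (identity) frame."""
    frame = IDENTITY
    counts = [0] * 6  # order: +x, -x, +y, -y, +z, -z
    for pulse in sequence:
        frame = PULSE_ROTATIONS[pulse] @ frame
        v = frame[-1]  # transformed I_z
        i = int(np.argmax(np.abs(v)))
        counts[2 * i + (0 if v[i] > 0 else 1)] += 1
    balanced = all(c == len(sequence) // 6 for c in counts)
    return balanced and np.array_equal(frame, IDENTITY)


def random_search(length=6, num_pulses=5, max_tries=200_000, seed=0):
    """Draw uniformly random sequences until one satisfies the constraints."""
    rng = np.random.default_rng(seed)
    for _ in range(max_tries):
        sequence = rng.integers(num_pulses, size=length).tolist()
        if satisfies_constraints(sequence):
            return sequence
    return None  # no valid sequence found within max_tries


print(random_search())
```

For length 6 there are only $5^6 = 15625$ possible sequences, so rejection sampling is cheap here; `[1, 1, 3, 1, 1, 3]` is one sequence that passes both checks under these assumptions.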

## Smarter search with MCTS

Following the [supplementary materials' description under "Search"](https://science.sciencemag.org/content/sci/suppl/2018/12/05/362.6419.1140.DC1/aar6404-Silver-SM.pdf), the search does rollouts and backpropagates the results up the tree.
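A compact, self-contained sketch of the search loop (PUCT selection, expansion with prior probabilities, leaf evaluation, and backpropagation along the visited path) is below. It is not the `alpha_zero` module's implementation: `Node`, `ucb_score`, `run_search`, the uniform priors (standing in for a policy network), and the random rollout used as the leaf evaluation (standing in for a value network) are assumptions for illustration. The constants `C_BASE = 19652` and `C_INIT = 1.25` are the values in the AlphaZero pseudocode.

```python
import math
import random

# PUCT constants from the AlphaZero pseudocode
C_BASE, C_INIT = 19652, 1.25


class Node:
    """Hypothetical search-tree node; not the alpha_zero module's class."""

    def __init__(self, sequence, prior):
        self.sequence = sequence  # pulses chosen so far
        self.prior = prior  # P(s, a), uniform in this sketch
        self.children = {}  # action -> Node
        self.visit_count = 0
        self.value_sum = 0.0

    def value(self):
        """Mean backed-up value Q(s, a); 0 for an unvisited node."""
        return self.value_sum / self.visit_count if self.visit_count else 0.0


def ucb_score(parent, child):
    """PUCT score: Q(s, a) plus a prior-weighted exploration bonus."""
    c = math.log((parent.visit_count + C_BASE + 1) / C_BASE) + C_INIT
    u = c * child.prior * math.sqrt(parent.visit_count) / (1 + child.visit_count)
    return child.value() + u


def run_search(num_actions, length, reward_fn, num_simulations, seed=0):
    rng = random.Random(seed)
    root = Node([], prior=1.0)
    for _ in range(num_simulations):
        node, path = root, [root]
        # 1. Selection: descend by PUCT until reaching a leaf
        while node.children:
            parent = node
            node = max(parent.children.values(), key=lambda ch: ucb_score(parent, ch))
            path.append(node)
        if len(node.sequence) < length:
            # 2. Expansion with uniform priors (a policy network would go here)
            for a in range(num_actions):
                node.children[a] = Node(node.sequence + [a], 1 / num_actions)
            # 3. Evaluation by a random rollout to a full-length sequence
            remaining = length - len(node.sequence)
            rollout = node.sequence + [rng.randrange(num_actions) for _ in range(remaining)]
            value = reward_fn(rollout)
        else:
            value = reward_fn(node.sequence)  # terminal: score the sequence itself
        # 4. Backpropagation (single player, so the value is never negated)
        for visited in path:
            visited.visit_count += 1
            visited.value_sum += value
    return root


# Toy problem: length-2 sequences over 3 actions; only [2, 1] is rewarded
root = run_search(3, 2, lambda s: float(s == [2, 1]), num_simulations=1000)
print({a: child.visit_count for a, child in root.children.items()})
```

On this toy problem the root visits should concentrate on action 2, the only first move that can reach the rewarded sequence. Because pulse sequence design is a single-player problem, the backed-up value is added unchanged at every node on the path, whereas the two-player AlphaZero pseudocode adjusts the value for the player to move.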

```python
config = az.Config()
config.num_simulations = 500  # MCTS simulations per pulse choice
```

```python
# Arguments: N, ensemble_size, pulse_width, delay, sequence length, target
ps_config = ps.PulseSequenceConfig(N, ensemble_size, pulse_width, delay, 6, Utarget)
```

```python
# Optional: profile make_sequence with snakeviz
# %load_ext snakeviz
# %snakeviz -t az.make_sequence(config, ps_config, None)
```

```python
stats = az.make_sequence(config, ps_config, None)
```

```
applying pulse 3
applying pulse 2
applying pulse 3
applying pulse 3
applying pulse 3
applying pulse 2
```


```python
ps_config.value()
```

```
0.06625695298727205
```

```python
stats
```

```
[([],
  [(0, 0.198), (1, 0.2), (2, 0.198), (3, 0.198), (4, 0.206)],
  0.06625695298727205),
 ([3], [(1, 0.25), (2, 0.252), (3, 0.248), (4, 0.25)], 0.06625695298727205),
 ([3, 2], [(2, 0.35), (3, 0.326), (4, 0.324)], 0.06625695298727205),
 ([3, 2, 3], [(2, 0.498), (3, 0.502)], 0.06625695298727205),
 ([3, 2, 3, 3], [(2, 0.506), (3, 0.494)], 0.06625695298727205),
 ([3, 2, 3, 3, 3], [(2, 1.0)], 0.06625695298727205)]
```
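Each `stats` entry appears to hold the sequence so far, the fraction of the 500 simulations that visited each allowed pulse, and the value (the fractions at every step are multiples of 1/500 and sum to 1). Pulse 3 was applied at the first step even though pulse 4 had the largest share (0.206), which is consistent with sampling the next pulse in proportion to visit counts, as in AlphaZero self-play, rather than taking the argmax. A minimal sketch of that selection rule with a temperature parameter follows; `visit_policy` and `select_action` are hypothetical helpers, not part of `alpha_zero`.

```python
import numpy as np


def visit_policy(visit_counts, temperature=1.0):
    """Search policy pi(a) proportional to N(a) ** (1 / temperature)."""
    counts = np.asarray(visit_counts, dtype=float)
    if temperature == 0:
        # Greedy limit: all probability on the most-visited action(s)
        weights = (counts == counts.max()).astype(float)
    else:
        weights = counts ** (1.0 / temperature)
    return weights / weights.sum()


def select_action(actions, visit_counts, temperature=1.0, rng=None):
    """Sample the next pulse from the visit-count policy."""
    rng = np.random.default_rng() if rng is None else rng
    probs = visit_policy(visit_counts, temperature)
    return actions[rng.choice(len(actions), p=probs)]


# Root visit counts from the first stats entry (fractions x 500 simulations)
actions, counts = [0, 1, 2, 3, 4], [99, 100, 99, 99, 103]
print(visit_policy(counts))
print(select_action(actions, counts, temperature=0))
```

A temperature of 1 samples in proportion to visit counts; lowering it toward 0 approaches the greedy choice of the most-visited pulse.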

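```python
# Note: run_mcts is not defined or imported in this notebook; it must be
# provided before this cell will run
pulse, root = run_mcts(config, [Utarget] * 5, [], 6, Utarget)
```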

```python
# Child of the root with the highest max_value
max(root.children.values(), key=lambda x: x.max_value)
```

```python
# Greedily follow the highest-value child down to a leaf
node = root
while node.has_children():
    node = max(node.children.values(), key=lambda x: x.max_value)
```

```python
node.sequence
```

```python
node.max_value
```

```python
[root.children[p].visit_count for p in root.children]
```

```python
root.children[3].value()
```