# AlphaZero implementation for pulse sequence design
_Will Kaufman, December 2020_

[Dalgaard et al. (2020)](https://www.nature.com/articles/s41534-019-0241-0) applied AlphaZero to constructing shaped pulses (as I understand it), but in principle it should be just as applicable to pulse sequence design, if not more so. See the original [AlphaZero paper](https://science.sciencemag.org/content/362/6419/1140.full) for details.

The general idea behind AlphaZero (as I understand it) is a "smart" tree search that balances prior knowledge (the policy), curiosity about unexplored branches, and exploitation of high-value branches. My thought is that average Hamiltonian theory (AHT) can improve on this: by the end of the pulse sequence, the sequence must be cyclic (the overall frame transformation must be the identity) and equal time must be spent along each axis. These hard constraints should (hopefully) prune the search and speed it up.

## System installation

Make sure the following packages are installed:

- `numpy`
- `scipy`
- `matplotlib`
- `qutip`
- `torch` (PyTorch)
- `tensorboard`

```python
import qutip as qt
import numpy as np
import matplotlib.pyplot as plt
import sys, os
from random import sample
import multiprocessing as mp

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
```

```python
# make the local pulse_sequences and alpha_zero modules importable
sys.path.append(os.path.abspath('..'))
import pulse_sequences as ps
import alpha_zero as az
```

```python
# import importlib
# importlib.reload(az)
# importlib.reload(ps)
```

## Define the spin system

```python
delay = 1e-2  # time is relative to chemical shift strength
pulse_width = 1e-3
N = 3  # number of spins
ensemble_size = 5
```

```python
X, Y, Z = ps.get_collective_spin(N)
Hsys_ensemble = [ps.get_Hsys(N) for _ in range(ensemble_size)]
pulses_ensemble = [
    ps.get_pulses(H, X, Y, Z, pulse_width, delay, rot_error=0.01) for H in Hsys_ensemble
]
Utarget = qt.identity(Hsys_ensemble[0].dims[0])
```

## Average Hamiltonian theory

To keep track of the lowest-order average Hamiltonian, I define a frame matrix, apply each pulse's rotation matrix to it, and track how $I_z$ transforms over the course of the pulse sequence. The last row of the frame matrix is the current (toggling-frame) transformed value of $I_z$.

**TODO** add link to F matrix paper from Harvard

```python
ps.count_axes(ps.yxx48)
```

```
[8, 8, 8, 8, 8, 8]
```

```python
ps.get_valid_time_suspension_pulses([0, 1, 1], len(ps.pulse_names), 6)
```

```
[1, 3, 4]
```

## Tree search

Define nodes that can be used for tree search, with the additional constraint that the lowest-order average Hamiltonian matches the desired Hamiltonian.

(I deleted the code that implemented tree search with constraints; see the GitHub repo commits from 12/8 for that code.)

For 12-pulse sequences, I calculated 16 branches at depth 4 in a minute, or about one every 4 seconds. At depth 4 there are $5^4 = 625$ branches, so a full run would take about $4 \times 625 = 2500$ seconds, or roughly 42 minutes. Alternatively, you can generate random pulse sequences until one has the correct lowest-order average Hamiltonian and is cyclic.
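A minimal sketch of that rejection-sampling alternative. `is_valid` is any user-supplied predicate that checks the cyclic and equal-time conditions; `sample_until_valid` and its defaults are hypothetical. The expected number of draws is about $1/p$, where $p$ is the fraction of all $5^{12} = 244{,}140{,}625$ twelve-pulse sequences that pass, so this is only practical when valid sequences are not too rare.

```python
import random


def sample_until_valid(is_valid, length, num_pulses=5, max_tries=100_000, rng=None):
    """Rejection sampling: draw uniformly random pulse sequences until one passes.

    `is_valid` is any predicate on a list of pulse indices (hypothetically, a check
    that the sequence is cyclic and spends equal time along each axis).
    Returns (sequence, number_of_tries); raises RuntimeError if max_tries runs out.
    """
    rng = rng or random.Random()
    for tries in range(1, max_tries + 1):
        sequence = [rng.randrange(num_pulses) for _ in range(length)]
        if is_valid(sequence):
            return sequence, tries
    raise RuntimeError(f"no valid sequence in {max_tries} tries")
```

Passing a seeded `random.Random` makes a run reproducible; the returned `tries` gives an empirical estimate of $1/p$.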

## Smarter search with MCTS

This follows the [supplementary materials description under "Search"](https://science.sciencemag.org/content/sci/suppl/2018/12/05/362.6419.1140.DC1/aar6404-Silver-SM.pdf) to do rollouts and backpropagate values up the tree.
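A minimal sketch of the selection, expansion, and backup steps, following the pseudocode published with the AlphaZero paper: children are scored with the PUCT rule using that pseudocode's constants `c_base = 19652` and `c_init = 1.25`. The class and function names, the toy prior and value functions, and the choice not to flip the value's sign (this is a single-player search rather than a two-player game) are assumptions, not the `alpha_zero` module's actual code.

```python
import math

# Constants from the AlphaZero pseudocode (exploration-rate schedule).
PB_C_BASE = 19652
PB_C_INIT = 1.25


class Node:
    """One search-tree node; children are keyed by action (pulse index)."""

    def __init__(self, prior):
        self.prior = prior      # P(s, a) from the policy (or uniform)
        self.visit_count = 0    # N(s, a)
        self.value_sum = 0.0    # sum of backed-up values
        self.children = {}

    def expanded(self):
        return bool(self.children)

    def value(self):
        # mean value Q(s, a); 0 for unvisited nodes
        return self.value_sum / self.visit_count if self.visit_count else 0.0


def ucb_score(parent, child):
    # PUCT: Q(s, a) + C(s) * P(s, a) * sqrt(N(s)) / (1 + N(s, a))
    pb_c = math.log((parent.visit_count + PB_C_BASE + 1) / PB_C_BASE) + PB_C_INIT
    pb_c *= math.sqrt(parent.visit_count) / (child.visit_count + 1)
    return pb_c * child.prior + child.value()


def select_child(node):
    return max(node.children.items(), key=lambda item: ucb_score(node, item[1]))


def expand(node, priors):
    total = sum(priors.values())
    for action, weight in priors.items():
        node.children[action] = Node(weight / total)


def backpropagate(path, value):
    # single-player search: every node on the path gets the same value (no sign flip)
    for node in path:
        node.value_sum += value
        node.visit_count += 1


def run_mcts(root, priors_fn, value_fn, num_simulations):
    for _ in range(num_simulations):
        node, path, actions = root, [root], []
        while node.expanded():                  # selection
            action, node = select_child(node)
            path.append(node)
            actions.append(action)
        priors = priors_fn(actions)             # {} means terminal or dead end
        if priors:
            expand(node, priors)                # expansion
        backpropagate(path, value_fn(actions))  # evaluation + backup
    return {a: child.visit_count for a, child in root.children.items()}


# Hypothetical toy problem: 3 actions, depth 3, reward = fraction of actions equal to 2.
MAX_DEPTH = 3


def toy_priors(actions):
    return {} if len(actions) >= MAX_DEPTH else {a: 1.0 for a in range(3)}


def toy_value(actions):
    return actions.count(2) / MAX_DEPTH
```

For example, `run_mcts(Node(1.0), toy_priors, toy_value, 200)` should concentrate most visits on action 2. In the pulse-sequence setting, `priors_fn` would return only the pulses that keep the AHT constraints satisfiable (an empty dict marks a dead end), and `value_fn` would be a random rollout or the value network.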

```python
config = az.Config()
config.num_simulations = 100
ps_config = ps.PulseSequenceConfig(N=N, ensemble_size=ensemble_size,
                                   max_sequence_length=48, Utarget=Utarget,
                                   pulse_width=pulse_width, delay=delay)
```

```python
# %load_ext snakeviz
# %snakeviz -t az.make_sequence(config, ps_config, None)
```

```python
stats = az.make_sequence(config, ps_config, None)
```

```python
ps_config.value()
```

```
0.060341634870554293
```

```python
print(stats[-1][0])
```

```
[2, 1, 4, 2, 4, 0, 0, 4, 4, 2, 0, 1, 4, 3, 1, 2, 0, 2, 3, 0, 3, 3, 4, 2, 3, 4, 3, 1, 1, 1, 2, 3, 0, 1, 1, 1, 0, 3, 2, 0, 2, 3, 0, 4, 4, 4, 4]
```


```python
print(stats[-1][1])
```

```
[0. 0. 0. 0. 1.]
```
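The printed `[0. 0. 0. 0. 1.]` looks like the search policy at the final step, with all of its probability on pulse 4 (the pulse appended in the next cell). In AlphaZero this distribution comes from the root visit counts, $\pi(a) \propto N(a)^{1/\tau}$ with temperature $\tau$. A minimal sketch, assuming `az` normalizes visit counts this way; `visit_policy` is a hypothetical name.

```python
def visit_policy(visit_counts, temperature=1.0):
    """Hypothetical helper: pi(a) proportional to N(a) ** (1 / temperature).

    temperature == 0 means greedy (all probability on the most-visited action).
    Assumes at least one action has been visited.
    """
    if temperature == 0:
        best = max(range(len(visit_counts)), key=lambda a: visit_counts[a])
        return [1.0 if a == best else 0.0 for a in range(len(visit_counts))]
    weights = [n ** (1.0 / temperature) for n in visit_counts]
    total = sum(weights)
    return [w / total for w in weights]
```

A higher temperature flattens the distribution (more exploration when generating training data); a temperature near 0 picks the most-visited pulse.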


```python
ps.count_axes(stats[-1][0] + [4])
```

```
[8, 8, 8, 8, 8, 8]
```

```python
ps_config.frame
```

```
array([[ 1.,  0.,  0.],
       [ 0.,  0., -1.],
       [ 0.,  1.,  0.]])
```

## Replay buffer

Inspired by [this pytorch tutorial](https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html).

```python
rb = az.ReplayBuffer(int(1e5))
```
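A minimal sketch of a replay buffer in the style of the `ReplayMemory` class from that tutorial: a `deque` with a maximum length plus uniform random sampling. The class name, the `Transition` fields (modeled on AlphaZero's state, search policy, and outcome), and the method names are assumptions; the actual `az.ReplayBuffer` interface isn't shown here.

```python
import random
from collections import deque, namedtuple

# Hypothetical record layout, modeled on AlphaZero's (s, pi, z) training targets
Transition = namedtuple("Transition", ("state", "search_policy", "value"))


class ReplayBufferSketch:
    """Fixed-capacity FIFO buffer; the oldest entries are dropped when full."""

    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)

    def add(self, state, search_policy, value):
        self.memory.append(Transition(state, search_policy, value))

    def sample(self, batch_size):
        # uniform sampling without replacement
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
```

The bounded `deque` keeps memory use fixed and automatically discards the oldest (least relevant) self-play data as the networks improve.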

```python
config = az.Config()
config.num_simulations = 100
ps_config = ps.PulseSequenceConfig(N=N, ensemble_size=ensemble_size,
                                   max_sequence_length=48, Utarget=Utarget,
                                   pulse_width=pulse_width, delay=delay)
```

To save data in a reasonable format (and to feed it to an RNN), the state is represented as a sequence, where 0 marks the start of the sequence and 1-5 are the possible pulses (1: delay, 2: x, etc.).
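A minimal sketch of that encoding, assuming the raw pulse indices are 0-4 (as in the printed sequences above) and are shifted up by one so that 0 can serve as the start token. `encode_state` and `one_hot` are hypothetical names.

```python
START_TOKEN = 0
NUM_TOKENS = 6  # start token plus 5 pulses


def encode_state(pulses):
    """Prepend the start token and shift pulse indices 0-4 up to 1-5."""
    return [START_TOKEN] + [p + 1 for p in pulses]


def one_hot(tokens, num_tokens=NUM_TOKENS):
    """One row per time step, suitable as RNN input after conversion to a tensor."""
    return [[1 if i == t else 0 for i in range(num_tokens)] for t in tokens]
```

The start token means even an empty sequence gives the RNN one input step. For tensors, `torch.nn.functional.one_hot(torch.tensor(tokens), num_classes=6)` does the same thing.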

```python
az.get_training_data(config, ps_config, rb, num_iter=5)
```

```python
# rb.sample()
```

## MCTS with policy and value networks

```python
p = az.Policy()
v = az.Value()
net = az.Network(p, v)
```
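When a policy network is available, its output can be combined with the AHT constraints by renormalizing the prior over the allowed pulses only. A minimal sketch, assuming the network produces one logit per pulse and the allowed pulses come from a pruning rule such as `ps.get_valid_time_suspension_pulses`; `masked_prior` is a hypothetical name, and `az.Network` may handle this differently.

```python
import math


def masked_prior(logits, valid_actions):
    """Softmax over the logits of valid actions only; invalid actions get no entry.

    Returns {action: probability}; an empty dict means no valid action (dead end).
    """
    if not valid_actions:
        return {}
    # subtract the largest valid logit for numerical stability
    m = max(logits[a] for a in valid_actions)
    exps = {a: math.exp(logits[a] - m) for a in valid_actions}
    total = sum(exps.values())
    return {a: e / total for a, e in exps.items()}
```

Masking before normalizing means the search never spends simulations on pulses that would break the time-suspension constraints, however much prior the network assigns them.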

```python
# net.save('models')
```

```python
ps_config.reset()
```

```python
# %load_ext snakeviz
# %snakeviz -t az.make_sequence(config, ps_config, net)
```


```python
stats = az.make_sequence(config, ps_config, net)
```

## Optimizing networks with replay buffer data

See [this doc](https://pytorch.org/tutorials/intermediate/tensorboard_tutorial.html) for writing training loss to tensorboard data, and [this doc](https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-for-inference) for saving/loading models.

So the above code seems to run, but boy does it run slowly... I'll probably want to gather a lot of data with a random policy (no network) first, then train from there.

TODO
- [ ] Look into multiprocessing, gathering training data and training continuously
- [ ] Figure out GPU utilization (if I can...)
- [ ] Set up on Discovery and run!

```python
policy_optimizer = optim.Adam(p.parameters(), lr=1e-5)
value_optimizer = optim.Adam(v.parameters(), lr=1e-5)
writer = SummaryWriter()
```

```python
global_step = 0
```

```python
# global_step = az.train_step(rb, p, policy_optimizer, v, value_optimizer, writer, global_step=global_step, num_iters=500)
```
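For reference, AlphaZero trains a single network $(\mathbf{p}, v) = f_\theta(s)$ on the loss below, where $z$ is the final outcome, $\boldsymbol{\pi}$ is the MCTS search policy, and $c$ is an L2 regularization weight. Since this notebook uses separate policy and value networks with separate Adam optimizers, a natural (assumed) split is a mean-squared-error value loss $(z - v)^2$ and a cross-entropy policy loss $-\boldsymbol{\pi}^\top \log \mathbf{p}$. Here $z$ would presumably be the final sequence's reward (e.g. `ps_config.value()`), though `az.train_step` may differ.

$$
\ell = (z - v)^2 - \boldsymbol{\pi}^\top \log \mathbf{p} + c \lVert \theta \rVert^2
$$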

## Multiprocessing

Setting up this algorithm to run in parallel is quite important. I'm using [multiprocessing](https://docs.python.org/3/library/multiprocessing.html) to handle the parallelism; PyTorch also provides `torch.multiprocessing`, a drop-in replacement that can share Tensors between processes through shared memory. With 2 processors on my laptop, the speedup is about 90% (not bad...).

Each process needs its own random seed; otherwise every process ends up producing the same results...
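A minimal sketch of per-task seeding with NumPy's `SeedSequence`, which gives each task its own statistically independent random stream while keeping the whole run reproducible. `ROOT_ENTROPY`, `task_rng`, `worker`, and `run_parallel` are hypothetical; `worker` just draws random pulses as a stand-in for `f`, and whether `PulseSequenceConfig` can accept such a generator is not shown here.

```python
import multiprocessing as mp

import numpy as np

ROOT_ENTROPY = 20201208  # hypothetical root seed; change it for a fresh batch of runs


def task_rng(task_id):
    """Independent generator per task, mirroring how SeedSequence.spawn builds children."""
    seed_seq = np.random.SeedSequence(ROOT_ENTROPY, spawn_key=(task_id,))
    return np.random.default_rng(seed_seq)


def worker(task_id):
    rng = task_rng(task_id)
    # stand-in for a randomized search: 12 random pulse indices in 0-4
    return rng.integers(0, 5, size=12).tolist()


def run_parallel(num_tasks, num_processes=2):
    # call this from the main process only (e.g. under `if __name__ == "__main__":`)
    with mp.Pool(num_processes) as pool:
        return pool.map(worker, range(num_tasks))
```

Using the task index (the `x` that `f` currently ignores) as the spawn key means the same task always gets the same stream, and different tasks never share one. With the "spawn" start method (the default on macOS since Python 3.8), the worker function must also be importable by the child processes, so it is safer to define it in a module such as `alpha_zero` than in a notebook cell.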

```python
def f(x):
    # x is the task index passed in by Pool.map; it isn't used here, so each
    # worker's randomness depends on how ps_config sets up its rng
    config = az.Config()
    config.num_simulations = 100
    ps_config = ps.PulseSequenceConfig(N=N, ensemble_size=ensemble_size,
                                       max_sequence_length=48, Utarget=Utarget,
                                       pulse_width=pulse_width, delay=delay)
    # load policy and value networks from disk, in evaluation mode
    policy = az.Policy()
    policy.load_state_dict(torch.load('network/policy'))
    policy.eval()
    value = az.Value()
    value.load_state_dict(torch.load('network/value'))
    value.eval()
    net = az.Network(policy, value)
    return az.make_sequence(config, ps_config, network=net, rng=ps_config.rng)
```

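```python
# name the pool `pool`, not `p`, so the policy network `p` defined above isn't overwritten
with mp.Pool(5) as pool:
    output = pool.map(f, range(5))
```

```
Can't select action: no child actions to perform!
Can't select action: no child actions to perform!
```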


```python
for stat in output:
    az.add_stats_to_buffer(stat, rb)
```

```python
len(rb)
```

```
960
```