# AlphaZero implementation for pulse sequence design
_Will Kaufman, December 2020_

[Dalgaard et al. (2020)](https://www.nature.com/articles/s41534-019-0241-0) applied this approach to constructing shaped pulses (as I understand it), but in theory it should be at least as applicable to pulse sequence design. See also the original [AlphaZero paper](https://science.sciencemag.org/content/362/6419/1140.full).

The general idea behind AlphaZero (as I understand it) is to do a "smart" tree search that balances previous knowledge (the policy), curiosity about unexplored branches, and exploitation of high-value branches. My thought is that this can be improved with average Hamiltonian theory (AHT): by the end of the pulse sequence, the sequence must be cyclic (the overall frame transformation must be the identity) and equal time must be spent on each axis. These hard constraints should (hopefully) speed up the search.
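To make the "balance" concrete, here is a rough sketch of the child-selection score as I read it from the pseudocode published with the AlphaZero paper. The function name `puct_score` and its argument names are mine, not from `alpha_zero.py`, and the default constants are the ones I recall from that pseudocode, so treat them as assumptions.

```python
import math


def puct_score(parent_visits, child_visits, child_value_sum, prior,
               c_base=19652, c_init=1.25):
    """Score used to pick which child to descend into (hypothetical helper).

    The score is the child's mean value (high-value branches) plus an
    exploration bonus that grows with the prior (previous knowledge) and
    shrinks as the child accumulates visits (curiosity).
    """
    exploration = math.log((parent_visits + c_base + 1) / c_base) + c_init
    exploration *= math.sqrt(parent_visits) / (child_visits + 1)
    mean_value = child_value_sum / child_visits if child_visits > 0 else 0.0
    return mean_value + exploration * prior


# With equal priors and equal mean values, the less-visited child wins.
rarely_visited = puct_score(parent_visits=100, child_visits=0,
                            child_value_sum=0.0, prior=0.5)
often_visited = puct_score(parent_visits=100, child_visits=50,
                           child_value_sum=0.0, prior=0.5)
```

An AHT constraint would enter before this score is computed, by removing children that can no longer lead to a cyclic, equal-time sequence.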

In [1]:
import qutip as qt
import numpy as np
import matplotlib.pyplot as plt
import sys, os
from random import sample

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
sys.path.append(os.path.abspath('..'))
import pulse_sequences as ps
import alpha_zero as az

In [306]:
import importlib
importlib.reload(az)
importlib.reload(ps)

<module 'pulse_sequences' from '/Users/willkaufman/Projects/rl_pulse/rl_pulse/pulse_sequences.py'>

## Define the spin system

In [3]:
delay = 1e-2  # time is relative to chemical shift strength
pulse_width = 5e-3
N = 3  # number of spins
ensemble_size = 5

In [4]:
X, Y, Z = ps.get_collective_spin(N)

In [5]:
Hsys_ensemble = [ps.get_Hsys(N) for _ in range(ensemble_size)]
pulses_ensemble = [
    ps.get_pulses(H, X, Y, Z, pulse_width, delay, rot_error=0.01) for H in Hsys_ensemble
]

In [6]:
Utarget = qt.identity(Hsys_ensemble[0].dims[0])

## Average Hamiltonian theory

To keep track of the average Hamiltonian (to lowest order), I define a frame matrix and apply a rotation matrix to it for each pulse, which determines how $I_z$ transforms during the pulse sequence. The last row of the frame matrix is the current transformed value of $I_z$.
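As an illustration of the bookkeeping, here is a minimal sketch in plain Python. It is not the code in `pulse_sequences.py`: the names `track_frame`, `rot_x`, `rot_y`, the left-multiplication convention, the sign convention of the rotations, and the choice to count the axis once after every step are all assumptions made for this example.

```python
IDENTITY = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]


def matmul(a, b):
    """Multiply two 3x3 integer matrices stored as nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)]
            for i in range(3)]


def rot_x(sign):
    """90-degree rotation about +x (sign=1) or -x (sign=-1)."""
    return [[1, 0, 0], [0, 0, -sign], [0, sign, 0]]


def rot_y(sign):
    """90-degree rotation about +y (sign=1) or -y (sign=-1)."""
    return [[0, 0, sign], [0, 1, 0], [-sign, 0, 0]]


# Assumed action set: a delay plus four 90-degree pulses.
ROTATIONS = {"delay": IDENTITY, "x": rot_x(1), "-x": rot_x(-1),
             "y": rot_y(1), "-y": rot_y(-1)}


def track_frame(sequence):
    """Return (final frame, axis counts, is_cyclic) for a list of pulse names.

    Counts are ordered [+x, -x, +y, -y, +z, -z] and record where the
    transformed I_z (the last row of the frame) points after each step.
    """
    frame = IDENTITY
    counts = [0] * 6
    for name in sequence:
        frame = matmul(ROTATIONS[name], frame)
        iz = frame[2]
        axis = next(i for i in range(3) if iz[i] != 0)
        counts[2 * axis + (0 if iz[axis] > 0 else 1)] += 1
    return frame, counts, frame == IDENTITY


frame, counts, is_cyclic = track_frame(["x", "y", "x", "x", "y", "x"])
# counts -> [1, 1, 1, 1, 1, 1] and is_cyclic -> True under these conventions
```

Under these conventions a sequence satisfies both constraints when `is_cyclic` is true and all six counts are equal.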

In [7]:
ps.count_axes(ps.yxx48)

[8, 8, 8, 8, 8, 8]

In [8]:
ps.get_valid_time_suspension_pulses([0,1,1,], len(ps.pulse_names), 6)

[1, 3, 4]

## Tree search

Define nodes that can be used for tree search, with additional constraints that the lowest-order average Hamiltonian matches the desired Hamiltonian.

(deleted code that implemented tree search with constraints, see GitHub repo commits on 12/8 for code)

For 12-pulse sequences, the search calculated 16 branches at depth 4 in a minute, so about one every 4 seconds. At depth 4 there are $5^4 = 625$ branches, so a full run takes about $4 \times 625 = 2500$ seconds, or roughly 42 minutes. Alternatively, you can generate random pulse sequences until one has the proper lowest-order average and the cyclic property.
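The random alternative is just rejection sampling. A rough sketch follows, with the frame stored as three row vectors and updated directly. The names `step`, `is_valid`, and `random_valid_sequence`, the five-action set, and the rotation sign conventions are assumptions for illustration, not the deleted code or anything in `pulse_sequences.py`.

```python
import random

PULSES = ["delay", "x", "-x", "y", "-y"]  # assumed action set
START = ((1, 0, 0), (0, 1, 0), (0, 0, 1))  # rows of the identity frame


def neg(v):
    """Negate a 3-vector stored as a tuple."""
    return tuple(-c for c in v)


def step(frame, name):
    """Apply one action to the frame (90-degree rotations, assumed signs)."""
    r0, r1, r2 = frame
    if name == "x":
        return (r0, neg(r2), r1)
    if name == "-x":
        return (r0, r2, neg(r1))
    if name == "y":
        return (r2, r1, neg(r0))
    if name == "-y":
        return (neg(r2), r1, r0)
    return frame  # delay leaves the frame unchanged


def is_valid(sequence):
    """True if the sequence is cyclic and I_z visits all six axes equally."""
    frame = START
    visits = {}
    for name in sequence:
        frame = step(frame, name)
        visits[frame[2]] = visits.get(frame[2], 0) + 1
    equal_time = len(visits) == 6 and len(set(visits.values())) == 1
    return frame == START and equal_time


def random_valid_sequence(length, max_tries, rng):
    """Draw random sequences until one passes is_valid, or give up (None)."""
    for _ in range(max_tries):
        candidate = [rng.choice(PULSES) for _ in range(length)]
        if is_valid(candidate):
            return candidate
    return None


found = random_valid_sequence(length=6, max_tries=2000, rng=random.Random(0))
```

This only checks the two AHT constraints. Any sequence it returns would still need to be scored for fidelity against the actual spin system.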

## Smarter search with MCTS

This follows the [supplementary materials description under "Search"](https://science.sciencemag.org/content/sci/suppl/2018/12/05/362.6419.1140.DC1/aar6404-Silver-SM.pdf) to do rollouts and backpropagate information.
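The backpropagation step can be sketched as follows. This is a simplified, hypothetical `Node` and `backpropagate`, not the classes in `alpha_zero.py`. Since pulse sequence design has a single "player", I assume the value is added with the same sign at every node on the path, unlike the two-player game case.

```python
class Node:
    """Minimal search-tree node holding the statistics MCTS needs."""

    def __init__(self, prior):
        self.prior = prior        # policy probability of reaching this node
        self.visit_count = 0      # number of simulations through this node
        self.value_sum = 0.0      # sum of values backed up through this node
        self.children = {}        # action -> Node

    def value(self):
        """Mean backed-up value, or 0 for an unvisited node."""
        return self.value_sum / self.visit_count if self.visit_count else 0.0


def backpropagate(search_path, value):
    """Add one simulation's value to every node visited on the way down."""
    for node in search_path:
        node.visit_count += 1
        node.value_sum += value


root = Node(prior=1.0)
child = Node(prior=0.2)
root.children[1] = child
backpropagate([root, child], value=0.8)  # a simulation that went through child
backpropagate([root], value=0.2)         # a simulation that stopped at the root
```

After these two simulations the root's mean value is the average of both results, while the child only reflects the simulation that passed through it.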

In [11]:
config = az.Config()
config.num_simulations = 500
ps_config = ps.PulseSequenceConfig(N, ensemble_size, pulse_width, delay, 6, Utarget)

In [68]:
# %load_ext snakeviz
# %snakeviz -t az.make_sequence(config, ps_config, None)

In [467]:
# stats = az.make_sequence(config, ps_config, None)

In [468]:
# ps_config.value()

In [469]:
# stats

## Replay buffer

Inspired by [this pytorch tutorial](https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html).
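
For reference, a minimal buffer in the style of that tutorial's replay memory might look like the sketch below. The class name `SimpleReplayBuffer` and its `add` method are hypothetical, and the real `az.ReplayBuffer` may store and sample data differently.

```python
import random
from collections import deque


class SimpleReplayBuffer:
    """Fixed-capacity buffer of (state, probabilities, value) tuples."""

    def __init__(self, capacity):
        # A deque with maxlen silently drops the oldest entry when full.
        self.buffer = deque(maxlen=capacity)

    def add(self, state, probabilities, value):
        self.buffer.append((state, probabilities, value))

    def sample(self, batch_size=1):
        """Return batch_size distinct entries chosen uniformly at random."""
        return random.sample(list(self.buffer), batch_size)

    def __len__(self):
        return len(self.buffer)


buffer = SimpleReplayBuffer(capacity=3)
for i in range(5):
    buffer.add(state=[0, i], probabilities=[0.2] * 5, value=float(i))
batch = buffer.sample(batch_size=2)
```

Sampling uniformly from a large buffer decorrelates the training examples, which would otherwise all come from consecutive steps of the same sequence.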

In [307]:
rb = az.ReplayBuffer(100000)

In [293]:
config = az.Config()
config.num_simulations = 100

To save data in a reasonable way (and to use an RNN), the state is represented as a sequence of integers, where 0 marks the start of the sequence and 1-5 are the possible pulses (1: delay, 2: x, etc.).
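A small sketch of that encoding is below. Only "1: delay, 2: x" is stated above. The remaining assignments (3: -x, 4: y, 5: -y) and the names `PULSE_TOKENS` and `encode` are assumptions for illustration.

```python
START_TOKEN = 0
# Tokens 1 and 2 follow the text above; 3-5 are an assumed ordering.
PULSE_TOKENS = {"delay": 1, "x": 2, "-x": 3, "y": 4, "-y": 5}


def encode(pulse_names):
    """Turn a list of pulse names into the integer state sequence."""
    return [START_TOKEN] + [PULSE_TOKENS[name] for name in pulse_names]


state = encode(["delay", "x"])  # -> [0, 1, 2]
```

The start token means that even the empty sequence has a non-empty state, so the RNN always has at least one input.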

In [308]:
def get_training_data(config, ps_config, replay_buffer, network=None, num_iter=1):
    """Makes sequence using MCTS, then saves to replay buffer
    """
    for _ in range(num_iter):
        ps_config.reset()
        stats = az.make_sequence(config, ps_config, network)
        az.add_stats_to_buffer(stats, replay_buffer)

In [387]:
# get_training_data(config, ps_config, rb, num_iter=9)

In [33]:
# rb.sample()

## MCTS with policy and value networks

In [275]:
config = az.Config()
config.num_simulations = 100
ps_config = ps.PulseSequenceConfig(N, ensemble_size=ensemble_size,
                                   max_sequence_length=48, Utarget=Utarget,
                                   pulse_width=pulse_width, delay=delay)

In [276]:
p = az.Policy()
v = az.Value()
net = az.Network(p, v)

In [284]:
ps_config.reset()

In [278]:
%load_ext snakeviz
%snakeviz -t az.make_sequence(config, ps_config, net)

The snakeviz extension is already loaded. To reload it, use:
  %reload_ext snakeviz
 
*** Profile stats marshalled to file '/var/folders/_y/8j5pxws97d7clk4h9d6bjhtr0000gn/T/tmpoddonvv6'. 
Opening SnakeViz in a new tab...


In [285]:
stats = az.make_sequence(config, ps_config, net)

## Optimizing networks with replay buffer data

So the above code seems to run, but boy does it run slowly... I'll probably want to gather a lot of data with random policy (no network) first, then train from there.

TODO

- [ ] Finish training loop: add stats to buffer, set up training functions
- [ ] Look into multiprocessing, gathering training data and training continuously
- [ ] Set up on Discovery and run!

In [385]:
policy_optimizer = optim.Adam(p.parameters(), lr=1e-5)
value_optimizer = optim.Adam(v.parameters(), lr=1e-5)

In [383]:
def optimize_step(replay_buffer, policy, policy_optimizer, value, value_optimizer, num_iters=1, batch_size=64):
    """Runs num_iters minibatch updates of the policy and value networks
    """
    for i in range(num_iters):
        if i % 10 == 0:
            print(f'On iteration {i}...', end=' ')
        # sample from the buffer passed in, not the global rb
        minibatch = replay_buffer.sample(batch_size=batch_size)
        states, probabilities, values = zip(*minibatch)
        probabilities = torch.cat(probabilities).view(batch_size, -1)
        values = torch.cat(values).view(batch_size, -1)
        packed_states = az.pad_and_pack(states)
        # policy optimization: cross-entropy against the target probabilities
        policy_optimizer.zero_grad()  # clear gradients left over from the previous step
        policy_outputs, __ = policy(packed_states)
        policy_loss = -1 / len(states) * torch.sum(probabilities * torch.log(policy_outputs))
        policy_loss.backward()
        policy_optimizer.step()
        # value optimization: mean squared error against the target values
        value_optimizer.zero_grad()
        value_outputs, __ = value(packed_states)
        value_loss = F.mse_loss(value_outputs, values)
        value_loss.backward()
        value_optimizer.step()
        if i % 10 == 0:
            print(f'policy loss: {policy_loss:.05f}, value loss: {value_loss:.05f}')
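
Written out, the two losses computed in `optimize_step` are as follows, where $B$ is the batch size, $\pi_{b,a}$ are the target probabilities from the buffer, $p_{b,a}$ are the policy network outputs, $z_b$ are the target values, and $v_b$ are the value network outputs. The value expression assumes the value network returns one scalar per state, which I haven't checked against `az.Value`.

$$
L_\text{policy} = -\frac{1}{B} \sum_{b=1}^{B} \sum_{a} \pi_{b,a} \log p_{b,a},
\qquad
L_\text{value} = \frac{1}{B} \sum_{b=1}^{B} \left( v_b - z_b \right)^2
$$

The policy loss is a cross-entropy, so its minimum over $p$ is the entropy of the targets rather than zero.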

In [386]:
optimize_step(rb, p, policy_optimizer, v, value_optimizer, num_iters=250)

On iteration 0... policy loss: 1.55268, value loss: 0.17206
On iteration 10... policy loss: 1.56332, value loss: 0.12074
On iteration 20... policy loss: 1.60301, value loss: 0.12086
On iteration 30... policy loss: 1.57278, value loss: 0.10971
On iteration 40... policy loss: 1.59487, value loss: 0.15085
On iteration 50... policy loss: 1.57941, value loss: 0.11003
On iteration 60... policy loss: 1.55052, value loss: 0.10016
On iteration 70... policy loss: 1.54150, value loss: 0.12018
On iteration 80... policy loss: 1.56801, value loss: 0.17147
On iteration 90... policy loss: 1.52813, value loss: 0.09037
On iteration 100... policy loss: 1.59216, value loss: 0.10013
On iteration 110... policy loss: 1.58188, value loss: 0.09911
On iteration 120... policy loss: 1.58390, value loss: 0.09990
On iteration 130... policy loss: 1.55856, value loss: 0.14168
On iteration 140... policy loss: 1.59692, value loss: 0.12039
On iteration 150... policy loss: 1.56287, value loss: 0.13064
On iteration 160...