# AlphaZero implementation for pulse sequence design
_Will Kaufman, December 2020_

[Dalgaard et al. (2020)](https://www.nature.com/articles/s41534-019-0241-0) applied AlphaZero to constructing shaped pulses (as I understand it), but in principle the approach should be at least as applicable to pulse sequence design. The original [AlphaZero paper](https://science.sciencemag.org/content/362/6419/1140.full) and the [AlphaGo Zero paper](https://www.nature.com/articles/nature24270) are useful resources.

The general idea behind AlphaZero (as I understand it) is to do a "smart" tree search that balances prior knowledge (the policy), curiosity about unexplored branches, and exploitation of high-value branches. My thought is that this can be improved with average Hamiltonian theory (AHT): by the end of the pulse sequence, the sequence must be cyclic (the overall frame transformation must be the identity) and equal time must be spent on each axis. This provides a hard constraint that will (hopefully) speed up the search.
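
To make the AHT constraint concrete, here is a minimal sketch of how it could be checked for a finished sequence. It is not taken from `pulse_sequences.py`. It assumes ideal (infinitely short) 90-degree pulses, unit-length delays, and a hypothetical action encoding of 1: delay, 2: x, 3: -x, 4: y, 5: -y (only "1: delay, 2: x" comes from this notebook, the rest is a guess). The helper names `frame_summary` and `satisfies_aht_constraints` are made up for illustration.

```python
# Rotation matrices for ideal 90-degree pulses about x and y (signed permutations).
RX = ((1, 0, 0), (0, 0, -1), (0, 1, 0))
RY = ((0, 0, 1), (0, 1, 0), (-1, 0, 0))
IDENTITY = ((1, 0, 0), (0, 1, 0), (0, 0, 1))


def transpose(m):
    return tuple(zip(*m))


def matmul(a, b):
    return tuple(
        tuple(sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3))
        for i in range(3)
    )


# Assumed encoding (hypothetical beyond "1: delay, 2: x").
DELAY = 1
ROTATIONS = {2: RX, 3: transpose(RX), 4: RY, 5: transpose(RY)}


def frame_summary(sequence):
    """Return the net rotation and the number of delays spent on each axis."""
    total = IDENTITY  # cumulative rotation R_k ... R_1
    axis_counts = [0, 0, 0]  # delays with toggling-frame z along x, y, z
    for action in sequence:
        if action == DELAY:
            # The toggling-frame z axis is total^T e_z, i.e. the last row of total.
            for axis, component in enumerate(total[2]):
                axis_counts[axis] += abs(component)
        else:
            total = matmul(ROTATIONS[action], total)
    return total, axis_counts


def satisfies_aht_constraints(sequence):
    """True if the sequence is cyclic and spends equal time on x, y, and z."""
    total, axis_counts = frame_summary(sequence)
    return total == IDENTITY and len(set(axis_counts)) == 1


# WAHUHA-style sequence: tau, x, tau, -y, 2 tau, y, tau, -x, tau
wahuha = [1, 2, 1, 5, 1, 1, 4, 1, 3, 1]
print(satisfies_aht_constraints(wahuha))
```

During the search, the same bookkeeping could in principle be used to prune actions that make it impossible to return to the identity frame with balanced axis times before the maximum sequence length is reached.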

## System installation

Make sure the following packages are installed:

- `numpy`
- `scipy`
- `qutip`
- `pytorch`
- `tensorboard`

## TODO
- [ ] Collect all hyperparameters up top or in config (e.g. how many pulse sequences to collect data from)
- [ ] Speed up LSTM (save hidden state, batch parallel pulse sequences, other?)
- [ ] Figure out GPU utilization (if I can...)
- [ ] Look into collecting training data and training continuously
- [ ] Change Dirichlet noise to match number of possible moves (5 for now, eventually 24)
- [ ] Dynamically figure out how many CPUs are available, and set pool to use that
- [ ] Mess around with hyperparameters (e.g. in config object), see if performance improves
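
For the CPU-count item, one possible approach is sketched below using only the standard library. The helper name `available_cpus` is hypothetical. `os.sched_getaffinity` only exists on some platforms (e.g. Linux), so the sketch falls back to `os.cpu_count()` elsewhere.

```python
import os


def available_cpus():
    """Best-effort count of the CPUs this process is allowed to use."""
    if hasattr(os, "sched_getaffinity"):
        # Respects CPU affinity limits (e.g. set by a cluster scheduler).
        return len(os.sched_getaffinity(0))
    # os.cpu_count() can return None, so fall back to a single core.
    return os.cpu_count() or 1


print(available_cpus())
```

The result could then replace the hard-coded `num_cores` value when creating the pool.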

In [1]:
import qutip as qt
import numpy as np
import sys
import os
import multiprocessing as mp

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

In [2]:
sys.path.append(os.path.abspath('..'))
import pulse_sequences as ps
import alpha_zero as az

In [353]:
import importlib
importlib.reload(az)
importlib.reload(ps)

<module 'pulse_sequences' from '/Users/willkaufman/Projects/rl_pulse/rl_pulse/pulse_sequences.py'>

## Define hyperparameters

In [3]:
num_cores = 2  # 32
num_collect = 2  # 1000
num_collect_initial = 2  # 5000
batch_size = 64  #2048
num_iters = 100

max_sequence_length = 48

## Define the spin system

In [4]:
delay = 1e-2  # time is relative to chemical shift strength
pulse_width = 1e-3
N = 3  # number of spins
ensemble_size = 5

In [5]:
# X, Y, Z = ps.get_collective_spin(N)
# Hsys_ensemble = [ps.get_Hsys(N) for _ in range(ensemble_size)]
# pulses_ensemble = [
#     ps.get_pulses(H, X, Y, Z, pulse_width, delay, rot_error=0.01) for H in Hsys_ensemble
# ]
# Utarget = qt.identity(Hsys_ensemble[0].dims[0])

In [6]:
Utarget = qt.tensor([qt.identity(2)] * N)

## Smarter search with MCTS

I'm following the [supplementary materials description under "Search"](https://science.sciencemag.org/content/sci/suppl/2018/12/05/362.6419.1140.DC1/aar6404-Silver-SM.pdf) to do rollouts and backpropagate information. All of the relevant code for the AlphaZero algorithm is in `alpha_zero.py`.
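
As a rough illustration of the selection step in that search, the sketch below scores each child action with the PUCT rule from the AlphaZero pseudocode (mean value plus a prior-weighted exploration bonus). It is not necessarily what `alpha_zero.py` does. The function names are made up, the constants `c_base=19652` and `c_init=1.25` are the published AlphaZero defaults, and the parent visit count is approximated by the sum of the child visit counts.

```python
import math


def puct_scores(priors, visit_counts, value_sums, c_base=19652, c_init=1.25):
    """Return Q + U for each child action of a node."""
    # Simplification: use the total child visits as the parent visit count.
    parent_visits = sum(visit_counts)
    exploration = math.log((parent_visits + c_base + 1) / c_base) + c_init
    scores = []
    for prior, visits, value_sum in zip(priors, visit_counts, value_sums):
        # Q: mean value of the child (0 if it has never been visited).
        q = value_sum / visits if visits > 0 else 0.0
        # U: exploration bonus, large for high-prior, rarely visited children.
        u = exploration * prior * math.sqrt(parent_visits) / (1 + visits)
        scores.append(q + u)
    return scores


def select_action(priors, visit_counts, value_sums):
    """Pick the index of the child with the highest PUCT score."""
    scores = puct_scores(priors, visit_counts, value_sums)
    return max(range(len(scores)), key=scores.__getitem__)


# The first child has been visited 10 times; the others are unexplored.
print(select_action([0.5, 0.3, 0.2], [10, 0, 0], [1.0, 0.0, 0.0]))
```

The bonus term shrinks as a child accumulates visits, so the search gradually shifts from following the policy prior to following the measured values.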

## Replay buffer

Inspired by [this pytorch tutorial](https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html). This will store data collected by MCTS rollouts. The data is the state (the pulse sequence applied so far), the empirical probability distribution based on visit counts, and the value from the end of the pulse sequence.
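
For reference, a buffer of this kind can be sketched in a few lines. This is not the `az.ReplayBuffer` implementation; the class name `SimpleReplayBuffer` and its methods are hypothetical, and each entry is assumed to be a `(state, probabilities, value)` tuple as described above.

```python
import random
from collections import deque


class SimpleReplayBuffer:
    """Fixed-capacity buffer that discards the oldest entries first."""

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def add(self, state, probabilities, value):
        # state: actions applied so far; probabilities: normalized visit
        # counts from MCTS; value: value at the end of the pulse sequence.
        self.buffer.append((state, probabilities, value))

    def sample(self, batch_size, rng=random):
        # Sample without replacement, capped at the current buffer size.
        batch_size = min(batch_size, len(self.buffer))
        return rng.sample(list(self.buffer), batch_size)

    def __len__(self):
        return len(self.buffer)


buffer = SimpleReplayBuffer(capacity=3)
for step in range(5):
    buffer.add([0, step], [0.2] * 5, 0.0)
print(len(buffer))
```

Sampling uniformly from a large buffer breaks up the correlation between consecutive states of the same pulse sequence, which is the main reason to use a buffer rather than training on each sequence as it is generated.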

Based on a rough idea of what was done with AlphaGo Zero and AlphaZero, I think I need to collect a lot of data and train a lot with each iteration. In AlphaGo Zero, they collected 25,000 games per iteration and trained with minibatch sizes of 2048 (training continuously I think).

In [14]:
rb = az.ReplayBuffer(int(1e5))  # 1e6

To save data in a reasonable way (and to use an RNN), the state is represented by a sequence, where 0 indicates the start of the sequence and 1-5 are the possible pulses (1: delay, 2: x, etc.).

## Fill replay buffer with initial data

In [19]:
def collect_data_no_net(x):
#     print(f'collecting data without network ({x})')
    config = az.Config()
    ps_config = ps.PulseSequenceConfig(N=N, ensemble_size=ensemble_size,
                                       max_sequence_length=max_sequence_length,
                                       Utarget=Utarget,
                                       pulse_width=pulse_width, delay=delay)
    return az.make_sequence(config, ps_config, network=None, rng=ps_config.rng)

In [9]:
with mp.Pool(num_cores) as pool:
    output = pool.map(collect_data_no_net, range(num_collect_initial))
for stat in output:
    az.add_stats_to_buffer(stat, rb)

In [10]:
len(rb)

96

## MCTS with policy and value networks

In [10]:
policy = az.Policy()
value = az.Value()
net = az.Network(policy, value)

In [11]:
net.save('network')

## Optimizing networks with replay buffer data

See [this doc](https://pytorch.org/tutorials/intermediate/tensorboard_tutorial.html) for writing training loss to tensorboard data, and [this doc](https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-for-inference) for saving/loading models.

In [12]:
policy_optimizer = optim.Adam(policy.parameters())  #, lr=1e-5
value_optimizer = optim.Adam(value.parameters())  #, lr=1e-5
writer = SummaryWriter()
global_step = 0  # how many minibatches the models have been trained

In [440]:
# global_step = az.train_step(rb, policy, policy_optimizer, value, value_optimizer, writer, global_step=global_step, num_iters=500)

## Multiprocessing

Running this algorithm in parallel is quite important. I'm using [multiprocessing](https://docs.python.org/3/library/multiprocessing.html) to handle the parallelism, and PyTorch provides a similar API (`torch.multiprocessing`) for sharing tensors between processes. With 2 processors on my laptop, speedup is about 90% (not bad...).

Each process needs its own random seed; otherwise every process ends up producing the same results.
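
One way to do the seeding, sketched with the standard library only, is to derive a distinct seed for each task from a single base seed and pass it to the worker as its argument. The function names `task_seeds` and `sample_actions` are hypothetical, and `sample_actions` is just a stand-in for a real worker such as `collect_data`. How `ps.PulseSequenceConfig` creates its `rng` is not shown here, so whether it accepts a seed is an assumption to check.

```python
import random


def task_seeds(base_seed, num_tasks):
    """Derive one distinct seed per task from a single base seed."""
    seeder = random.Random(base_seed)
    return [seeder.getrandbits(64) for _ in range(num_tasks)]


def sample_actions(seed, length=5, num_actions=5):
    """Stand-in for a worker: draw random actions from its own generator."""
    rng = random.Random(seed)  # private generator, independent of global state
    return [rng.randrange(1, num_actions + 1) for _ in range(length)]


seeds = task_seeds(base_seed=0, num_tasks=4)
for seed in seeds:
    print(sample_actions(seed))
```

With a pool, the seeds would be mapped over instead of `range(...)`, e.g. `pool.map(sample_actions, task_seeds(0, num_collect))`. If the generators are NumPy ones, `np.random.SeedSequence(base_seed).spawn(n)` serves the same purpose.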

In [13]:
def collect_data(x):
    print(f'collecting data ({x})')
    config = az.Config()
    config.num_simulations = 250
    ps_config = ps.PulseSequenceConfig(N=N, ensemble_size=ensemble_size,
                                       max_sequence_length=max_sequence_length,
                                       Utarget=Utarget,
                                       pulse_width=pulse_width, delay=delay)
    # load policy and value networks from disk
    policy = az.Policy()
    policy.load_state_dict(torch.load('network/policy'))
    policy.eval()
    value = az.Value()
    value.load_state_dict(torch.load('network/value'))
    value.eval()
    net = az.Network(policy, value)
    return az.make_sequence(config, ps_config, network=net, rng=ps_config.rng)

In [111]:
# with mp.Pool(num_cores) as pool:
#     output = pool.map(collect_data, range(num_collect))
# for stat in output:
#     az.add_stats_to_buffer(stat, rb)

Can't select action: no child actions to perform!
Can't select action: no child actions to perform!


## Bringing it all together: training loop

In [19]:
for i in range(10):
    print(f'on iteration {i}')
    # collect data
    print('collecting data...')
    with mp.Pool(num_cores) as pool:
        output = pool.map(collect_data, range(num_collect))
    for stat in output:
        az.add_stats_to_buffer(stat, rb)
    mean_value = np.mean([o[-1][-1] for o in output])
    for o in output:
        if o[-1][-1] > 1:
            print('Candidate pulse sequence found! Value is ', o[-1][-1], '\n', o[-1][0])
    writer.add_scalar('mean_value', mean_value, global_step=global_step)
    # train models from replay buffer
    print('training model...')
    global_step = az.train_step(rb, policy, policy_optimizer,
                                value, value_optimizer,
                                writer, global_step=global_step,
                                num_iters=num_iters, batch_size=batch_size)
    # write updated weights to file
    net.save('network')