# AlphaZero implementation for pulse sequence design
_Will Kaufman, December 2020_

[Dalgaard et al. (2020)](https://www.nature.com/articles/s41534-019-0241-0) applied AlphaZero to constructing shaped pulses (as I understand it), but in theory the approach should be at least as applicable to pulse sequence design. The original [AlphaZero paper](https://science.sciencemag.org/content/362/6419/1140.full) and the [AlphaGo Zero paper](https://www.nature.com/articles/nature24270) are useful resources.

The general idea behind AlphaZero (as I understand it) is to do a "smart" tree search that balances prior knowledge (the policy), curiosity about unexplored branches, and exploitation of high-value branches. My thought is that this can be improved with average Hamiltonian theory (AHT): by the end of the pulse sequence, the sequence must be cyclic (the overall frame transformation must be the identity) and equal time must be spent on each axis. This provides a hard constraint that will (hopefully) speed up the search.
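
The balance between the policy prior, exploration, and value can be sketched with the PUCT selection rule described in the AlphaZero paper: each child is scored by its mean value plus an exploration bonus that grows with the prior and shrinks with the child's visit count. This is only an illustrative sketch. The names `puct_score` and `select_child`, the dictionary layout, and the constant `c_puct=1.25` are assumptions made for this example, not the API or settings used in `alpha_zero.py`.

```python
import math


def puct_score(prior, child_visits, parent_visits, mean_value, c_puct=1.25):
    """Mean value plus an exploration bonus weighted by the policy prior."""
    exploration = c_puct * prior * math.sqrt(parent_visits) / (1 + child_visits)
    return mean_value + exploration


def select_child(children, c_puct=1.25):
    """Return the action whose child has the highest PUCT score.

    `children` maps action -> dict with keys 'prior', 'visits', 'value_sum'
    (a hypothetical layout chosen for this sketch).
    """
    parent_visits = sum(c['visits'] for c in children.values())

    def score(action):
        c = children[action]
        # Unvisited children have no value estimate yet; treat it as 0.
        q = c['value_sum'] / c['visits'] if c['visits'] > 0 else 0.0
        return puct_score(c['prior'], c['visits'], parent_visits, q, c_puct)

    return max(children, key=score)


children = {
    1: {'prior': 0.5, 'visits': 10, 'value_sum': 2.0},
    2: {'prior': 0.3, 'visits': 0, 'value_sum': 0.0},
    3: {'prior': 0.2, 'visits': 2, 'value_sum': 1.6},
}
best_action = select_child(children)
```

In this toy example the unvisited action 2 wins: its exploration bonus outweighs the higher mean value of action 3, which is the "curiosity" term at work.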

## System installation

Make sure the following packages are installed:

- `numpy`
- `scipy`
- `qutip`
- `pytorch` (imported as `torch`)
- `tensorboard`

## TODO

- [ ] Collect all hyperparameters up top or in config (e.g. how many pulse sequences to collect data from)
- [ ] Speed up LSTM (save hidden state, batch parallel pulse sequences, other?)
- [ ] Figure out GPU utilization (if I can...)
- [ ] Look into collecting training data and training continuously
- [ ] Change Dirichlet noise to match number of possible moves (5 for now, eventually 24)
- [ ] Dynamically figure out how many CPUs there are available, and set pool to use that
- [ ] Mess around with hyperparameters (e.g. in config object), see if performance improves
- [ ] Add other changes on GitHub project page (lots of documenting algo run)
- [ ] Run it on Discovery, hope it works!
- [ ] Clean up code, add tests, make sure everything is working as expected

In [1]:
import qutip as qt
import numpy as np
import sys
import os
import torch.multiprocessing as mp
import importlib

from datetime import datetime
import random
from time import sleep
import torch.nn.functional as F

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

In [2]:
sys.path.append(os.path.abspath('..'))
import pulse_sequences as ps
import alpha_zero as az

In [97]:
# importlib.reload(az)
# importlib.reload(ps)

## Define hyperparameters

In [3]:
num_cores = 2  # 32
num_collect = 2  # 2000
num_collect_initial = 2  # 5000
batch_size = 64  #2048
num_iters = 100  # 1000

max_sequence_length = 48

## Define the spin system

In [4]:
delay = 1e-2  # time is relative to chemical shift strength
pulse_width = 1e-3
N = 3  # number of spins
ensemble_size = 5

In [5]:
# X, Y, Z = ps.get_collective_spin(N)
# Hsys_ensemble = [ps.get_Hsys(N) for _ in range(ensemble_size)]
# pulses_ensemble = [
#     ps.get_pulses(H, X, Y, Z, pulse_width, delay, rot_error=0.01) for H in Hsys_ensemble
# ]
# Utarget = qt.identity(Hsys_ensemble[0].dims[0])

In [5]:
Utarget = qt.tensor([qt.identity(2)] * N)

## Smarter search with MCTS

The search follows the ["Search" section of the AlphaZero supplementary materials](https://science.sciencemag.org/content/sci/suppl/2018/12/05/362.6419.1140.DC1/aar6404-Silver-SM.pdf) to do rollouts and backpropagate information up the tree. All of the relevant code for the AlphaZero algorithm is in `alpha_zero.py`.

## Replay buffer

Inspired by [this pytorch tutorial](https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html). This will store data collected by MCTS rollouts. The data is the state (the pulse sequence applied so far), the empirical probability distribution based on visit counts, and the value from the end of the pulse sequence.
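
As a rough illustration of the "empirical probability distribution based on visit counts", the sketch below normalizes the root visit counts into a policy target, with the temperature exponent used in the AlphaGo Zero paper. The function name `visit_counts_to_policy` and the uniform fallback for an unvisited root are assumptions for this example; the actual conversion in `alpha_zero.py` may differ.

```python
def visit_counts_to_policy(visit_counts, temperature=1.0):
    """Normalize visit counts N(a)^(1/T) into a probability distribution."""
    if temperature <= 0:
        raise ValueError('temperature must be positive')
    weights = [n ** (1.0 / temperature) for n in visit_counts]
    total = sum(weights)
    if total == 0:
        # No visits recorded: fall back to a uniform distribution.
        return [1.0 / len(visit_counts)] * len(visit_counts)
    return [w / total for w in weights]


# Visit counts for the 5 possible moves at the root of one search.
counts = [30, 10, 5, 5, 0]
policy = visit_counts_to_policy(counts)
```

Each stored observation would then pair the state with a `policy` like this one and with the value obtained at the end of the pulse sequence.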

Based on a rough idea of what was done with AlphaGo Zero and AlphaZero, I think I need to collect a lot of data and train a lot with each iteration. In AlphaGo Zero, they collected 25,000 games per iteration and trained with minibatch sizes of 2048 (training continuously I think).

In [6]:
# rb = az.ReplayBuffer(int(1e5))  # 1e6
buffer_size = int(1e5)  # 1e6

For the purposes of saving data in a reasonable way (and using an RNN), the state is represented by a sequence of integers, where 0 indicates the start of the sequence and 1-5 are the possible pulses (1: delay, 2: x, etc...).
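
To make the encoding concrete, here is a small pure-Python sketch that prepends the start token and pads a batch of variable-length sequences to a common length. The helpers `encode_sequence` and `pad_batch` are hypothetical names for this example. Presumably `az.pad_and_pack` does something comparable with PyTorch's `torch.nn.utils.rnn` utilities before the states reach the RNN, but that is an assumption rather than a description of its implementation.

```python
START_TOKEN = 0  # marks the start of a sequence, as described above


def encode_sequence(pulses):
    """Prepend the start token to a list of pulse indices (1-5)."""
    if any(p < 1 or p > 5 for p in pulses):
        raise ValueError('pulse indices must be in 1..5')
    return [START_TOKEN] + list(pulses)


def pad_batch(sequences, pad_value=0):
    """Pad encoded sequences to equal length and return the true lengths."""
    lengths = [len(s) for s in sequences]
    max_len = max(lengths)
    padded = [s + [pad_value] * (max_len - len(s)) for s in sequences]
    return padded, lengths


batch = [encode_sequence([1, 2, 1]), encode_sequence([2])]
padded, lengths = pad_batch(batch)
```

The true lengths are returned alongside the padded batch because the pad value coincides with the start token here, so the lengths are what distinguish padding from data.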

## Fill replay buffer with initial data

In [93]:
def collect_data_no_net(proc_num, buffer, index, lock, buffer_size, ps_count):
    """
    Args:
        proc_num: Which process number this is (for debug purposes)
        buffer: A shared replay buffer (list proxy from mp.Manager().list())
        index: The current write index for the buffer (mp.Manager().Value)
        lock: Lock (mp.Manager().RLock) to prevent different processes
            from overwriting each other's data
        buffer_size (int): The maximum size of the buffer
        ps_count: Shared count (mp.Manager().Value) of how many pulse
            sequences have been constructed
    """
    print(datetime.now(), f'collecting data without network ({proc_num})')
    config = az.Config()
    ps_config = ps.PulseSequenceConfig(N=N, ensemble_size=ensemble_size,
                                       max_sequence_length=max_sequence_length,
                                       Utarget=Utarget,
                                       pulse_width=pulse_width, delay=delay)
    output = az.make_sequence(config, ps_config, network=None, rng=ps_config.rng)
    output_tensors = az.convert_stats_to_tensors(output)
    with lock:
        ps_count.value += 1
    for obs in output_tensors:
        with lock:
            if len(buffer) < buffer_size:
                buffer.append(obs)
            else:
                buffer[index.value] = obs
            index.value += 1
            if index.value >= buffer_size:
                index.value = 0

In [30]:
if __name__ == '__main__':
    with mp.Manager() as manager:
        buffer = manager.list()  #[None] * buffer_size
        index = manager.Value(typecode='i', value=0)
        # shared counter required by collect_data_no_net's signature
        ps_count = manager.Value('i', 0)
        lock = manager.RLock()
        workers = []
        for i in range(4):
            workers.append(mp.Process(target=collect_data_no_net,
                                      args=(i, buffer, index, lock, buffer_size, ps_count)))
            workers[-1].start()
        for w in workers:
            w.join()
        print('done gathering initial data!')
        l = list(buffer)  # to save a non-shared copy...

collecting data without network (0)
collecting data without network (1)
collecting data without network (2)
collecting data without network (3)
done gathering initial data!


## MCTS with policy and value networks

In [15]:
net = az.Network()

## Optimizing networks with replay buffer data

See [this doc](https://pytorch.org/tutorials/intermediate/tensorboard_tutorial.html) for writing training loss to tensorboard data, and [this doc](https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-for-inference) for saving/loading models.

## Multiprocessing

Setting up this algorithm to run in parallel is quite important. I'm using [multiprocessing](https://docs.python.org/3/library/multiprocessing.html) to handle the parallelism, through PyTorch's `torch.multiprocessing` wrapper, which exposes the same API and also supports sharing Tensors between processes. With 2 processors on my laptop, speedup is about 90% (not bad...).

Each process needs its own random seed; otherwise the processes all end up producing the same results.
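
One way to avoid identical results is to derive each worker's generator from a shared base seed plus its process number. For example, forked workers copy the parent's NumPy global RNG state, so unseeded workers can draw identical streams. The sketch below uses only the standard library; the helper `make_worker_rng` and the base seed `2020` are illustrative assumptions, and the notebook itself relies on `ps_config.rng`, whose seeding is not shown here.

```python
import random


def make_worker_rng(base_seed, proc_num):
    """Return a random.Random seeded uniquely (and reproducibly) per process."""
    return random.Random(f'{base_seed}:{proc_num}')


# Each worker would call this once at startup with its own proc_num.
rngs = [make_worker_rng(2020, i) for i in range(4)]
streams = [[rng.random() for _ in range(5)] for rng in rngs]
```

Because the seed depends on `proc_num`, the workers explore different sequences, while rerunning with the same base seed reproduces each worker's stream.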

In [92]:
def collect_data(proc_num, buffer, index, lock, buffer_size, net, ps_count):
    """
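    Args:
        proc_num: Which process number this is (for debug purposes)
        buffer, index, lock, buffer_size: Shared replay buffer state, as in
            collect_data_no_net
        net: Shared policy/value network used to guide the search
        ps_count (Value): A shared count of how many pulse sequences have been
            constructed so far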
    """
    print(datetime.now(), f'collecting data ({proc_num})')
    config = az.Config()
    config.num_simulations = 250
    ps_config = ps.PulseSequenceConfig(Utarget=Utarget, N=N, ensemble_size=ensemble_size,
                                       max_sequence_length=max_sequence_length,
                                       pulse_width=pulse_width, delay=delay)
    for _ in range(4):
        ps_config.reset()
        output = az.make_sequence(config, ps_config, network=net, rng=ps_config.rng)
        output_tensors = az.convert_stats_to_tensors(output)
        with lock:
            ps_count.value += 1
        for obs in output_tensors:
            with lock:
                if len(buffer) < buffer_size:
                    buffer.append(obs)
                else:
                    buffer[index.value] = obs
                index.value += 1
                if index.value >= buffer_size:
                    index.value = 0

In [106]:
if __name__ == '__main__':
    with mp.Manager() as manager:
        buffer = manager.list()  #[None] * 500
        index = manager.Value(typecode='i', value=0)
        # shared counter required by collect_data's signature
        ps_count = manager.Value('i', 0)
        lock = manager.RLock()
        # get network
        net = az.Network()
        net.share_memory()
        workers = []
        for i in range(4):
            workers.append(mp.Process(target=collect_data,
                                      args=(i, buffer, index, lock, buffer_size, net, ps_count)))
            workers[-1].start()
        for w in workers:
            w.join()
        print('done gathering data!')
        l = list(buffer)  # save a non-shared copy

collecting data (0)
collecting data (1)
collecting data (2)
collecting data (3)
done gathering initial data!


## Training process

In [94]:
def train_process(proc_num, buffer, net, global_step, ps_count):
    """
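    Args:
        proc_num: Which process number this is (for debug purposes)
        buffer (mp.managers.list): Replay buffer, list of (state, probability, value).
        net: Shared policy/value network to optimize
        global_step (mp.managers.Value): Counter to keep track of training iterations
        ps_count (mp.managers.Value): Shared count of pulse sequences constructed so far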
    """
    print(datetime.now(), f'started training process ({proc_num})')
    writer = SummaryWriter()
    net_optimizer = optim.Adam(net.parameters())
    for i in range(num_iters):  # number of training iterations
        if i % 100 == 0:
            print(datetime.now(), 'saving network...')
            # save network parameters to file
            os.makedirs('network', exist_ok=True)
            torch.save(net.state_dict(), f'network/{i:07.0f}-network')
        if len(buffer) < batch_size:
            print(datetime.now(), 'not enough data yet, sleeping...')
            sleep(5)
            continue
        net_optimizer.zero_grad()
        # sample minibatch from replay buffer
        minibatch = random.sample(list(buffer), batch_size)
        states, probabilities, values = zip(*minibatch)
        probabilities = torch.stack(probabilities)
        values = torch.stack(values)
        packed_states = az.pad_and_pack(states)
        # evaluate network
        policy_outputs, value_outputs, _ = net(packed_states)
        policy_loss = -1 / \
            len(states) * torch.sum(probabilities * torch.log(policy_outputs))
        value_loss = F.mse_loss(value_outputs, values)
        loss = policy_loss + value_loss
        loss.backward()
        net_optimizer.step()
        # write losses to log
        writer.add_scalar('training_policy_loss',
                          policy_loss, global_step=global_step.value)
        writer.add_scalar('training_value_loss',
                          value_loss, global_step=global_step.value)
        # every 10 iterations, log progress and add a histogram of
        # replay buffer values
        if i % 10 == 0:
            print(datetime.now(), f'updated network (iteration {i})',
                  f'pulse_sequence_count: {ps_count.value}')
            _, _, values = zip(*list(buffer))
            values = torch.stack(values).squeeze()
            writer.add_histogram('buffer_values', values, global_step=global_step.value)
            writer.add_scalar('pulse_sequence_count', ps_count.value, global_step.value)
        global_step.value += 1
        sleep(.5)  # TODO remove this

In [96]:
if __name__ == '__main__':
    with mp.Manager() as manager:
        buffer = manager.list()
        index = manager.Value(typecode='i', value=0)
        global_step = manager.Value('i', 0)
        ps_count = manager.Value('i', 0)
        lock = manager.RLock()
        # get network
        net = az.Network()
        net.share_memory()
        collectors = []
        for i in range(4):
            c = mp.Process(target=collect_data_no_net,
                           args=(i, buffer, index, lock, buffer_size, ps_count))
            c.start()
            collectors.append(c)
        trainer = mp.Process(target=train_process,
                             args=(4, buffer, net, global_step, ps_count))
        trainer.start()
        # start data collectors with network
        for i in range(5, 9):
            c = mp.Process(target=collect_data,
                           args=(i, buffer, index, lock, buffer_size, net, ps_count))
            c.start()
            collectors.append(c)
        for p in collectors:
            p.join()
        print('all collectors are joined')
        trainer.join()
        print('trainer is joined')
        l = list(buffer)  # save a non-shared copy
        print('done gathering data!')

It appears that sharing the neural network behaves as expected! Training updates the weights, and those updated weights are reflected in each of the data collection processes. Neat!