# AlphaZero implementation for pulse sequence design
_Will Kaufman, December 2020_

[Dalgaard et al. (2020)](https://www.nature.com/articles/s41534-019-0241-0) applied this approach to constructing shaped pulses (as I understand it), but in theory it should be just as applicable to pulse sequence design, if not more so. The original [AlphaZero paper](https://science.sciencemag.org/content/362/6419/1140.full) and the [AlphaGo Zero paper](https://www.nature.com/articles/nature24270) are useful resources.

The general idea behind AlphaZero (as I understand it) is to do a "smart" tree search that balances prior knowledge (the policy), curiosity about unexplored branches, and exploitation of high-value branches. My thought is that this can be improved with average Hamiltonian theory (AHT): by the end of the pulse sequence, the sequence must be cyclic (the overall frame transformation must be the identity) and equal time must be spent along each axis in the toggling frame. These hard constraints rule out branches that cannot satisfy them, which should (hopefully) speed up the search.
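A minimal sketch of how the two AHT constraints could be checked for a candidate sequence. It assumes each action is an optional π/2 pulse followed by one free-evolution period τ, uses hypothetical action labels (`'D'` for a delay with no pulse, `'X'`, `'-X'`, `'Y'`, `'-Y'` for π/2 pulses), and tracks the toggling frame with 3×3 rotation matrices instead of `qutip` propagators. The axis count ignores sign, which is what matters for the bilinear dipolar coupling; the rotation-sign convention is an assumption.

```python
import numpy as np

# Hypothetical action labels: 'D' = delay only, others = pi/2 pulses.
_RX = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])  # +pi/2 about x
_RY = np.array([[0, 0, 1], [0, 1, 0], [-1, 0, 0]])  # +pi/2 about y
PULSE_ROTATIONS = {
    'D': np.eye(3, dtype=int),
    'X': _RX, '-X': _RX.T,
    'Y': _RY, '-Y': _RY.T,
}


def toggling_frame_summary(actions):
    """Return (is_cyclic, axis_counts) for a sequence of actions.

    Assumes each action is a pulse (or none) followed by one delay.
    axis_counts[k] counts delays during which the toggling-frame z axis
    lies along Cartesian axis k (0 = x, 1 = y, 2 = z), ignoring sign.
    """
    frame = np.eye(3, dtype=int)
    axis_counts = np.zeros(3, dtype=int)
    for action in actions:
        frame = PULSE_ROTATIONS[action] @ frame  # newest pulse acts last
        toggled_z = frame[2, :]  # components of U^dagger Iz U
        axis_counts[np.argmax(np.abs(toggled_z))] += 1
    is_cyclic = np.array_equal(frame, np.eye(3, dtype=int))
    return is_cyclic, axis_counts


def satisfies_aht(actions):
    """Cyclic frame and equal time along x, y and z."""
    is_cyclic, counts = toggling_frame_summary(actions)
    return bool(is_cyclic and counts.min() == counts.max())


wahuha = ['D', 'X', '-Y', 'D', 'Y', '-X']  # tau X tau -Y 2tau Y tau -X tau
print(toggling_frame_summary(wahuha))  # (True, array([2, 2, 2]))
print(satisfies_aht(wahuha))           # True
```

In a search, the same bookkeeping could be kept incrementally per node so that branches which can no longer close the frame or balance the axes are skipped.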

## System installation

Make sure the following packages are installed:

- `numpy`
- `scipy`
- `qutip`
- `pytorch` (the `torch` package)
- `tensorboard`

## TODO
- [ ] Collect all hyperparameters up top or in config (e.g. how many pulse sequences to collect data from)
- [ ] Speed up LSTM (save hidden state, batch parallel pulse sequences, other?)
- [ ] Figure out GPU utilization (if I can...)
- [ ] Look into collecting training data and training continuously
- [ ] Change Dirichlet noise to match the number of possible moves (5 for now, eventually 24)
- [ ] Dynamically figure out how many CPUs are available, and size the process pool accordingly
- [ ] Mess around with hyperparameters (e.g. in the config object) to see if performance improves

- [ ] Add other changes to the GitHub project page (lots of documentation of the algorithm runs)
- [ ] Run it on Discovery, hope it works!
- [ ] Clean up code, add tests, make sure everything is working as expected
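On the Dirichlet-noise item: AlphaZero mixes noise into the root prior as $P(s, a) = (1 - \varepsilon) p_a + \varepsilon \eta_a$ with $\eta \sim \mathrm{Dir}(\alpha)$, using $\varepsilon = 0.25$ and an $\alpha$ scaled inversely with the typical number of legal moves (0.3 for chess). A minimal sketch in which the noise vector takes its length from the prior, so it adapts automatically from 5 to 24 moves; `add_dirichlet_noise` and its default values are assumptions, not the `alpha_zero.py` implementation.

```python
import numpy as np


def add_dirichlet_noise(prior, rng, alpha=0.3, epsilon=0.25):
    """Mix Dirichlet noise into a root prior (hypothetical helper).

    The noise has one component per possible move, so the same code
    works for 5 or 24 moves. alpha and epsilon are the chess values
    from the AlphaZero paper and would need tuning here.
    """
    prior = np.asarray(prior, dtype=float)
    noise = rng.dirichlet(np.full(len(prior), alpha))
    return (1 - epsilon) * prior + epsilon * noise


rng = np.random.default_rng(0)
root_prior = np.full(5, 0.2)  # e.g. a uniform prior over 5 moves
noisy_prior = add_dirichlet_noise(root_prior, rng)
```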

In [1]:
from datetime import datetime
import random
from time import sleep
import qutip as qt
import sys
import os
import torch
import torch.multiprocessing as mp
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import importlib

In [2]:
sys.path.append(os.path.abspath('..'))
import pulse_sequences as ps
import alpha_zero as az

In [3]:
# importlib.reload(az)
# importlib.reload(ps)

## Define hyperparameters

In [4]:
collect_no_net_procs = 2  # 15
collect_no_net_count = 5  # 100
collect_procs = 2  # 15

buffer_size = int(1e6)  # 1e6
batch_size = 64  # 2048
num_iters = int(1e2)  # 800e3

max_sequence_length = 48

print_every = 1  # 100
save_every = 10  # 1000

reward_threshold = 2.5
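For the TODO about sizing the process pool dynamically, a sketch that derives the collector count from the CPUs available to this process. `available_cpus` and `suggested_collect_procs` are hypothetical names, and reserving one CPU for the trainer is an assumption. On Linux, `os.sched_getaffinity(0)` respects CPU restrictions set by a scheduler such as SLURM, whereas `os.cpu_count()` reports every CPU on the machine.

```python
import os


def available_cpus():
    """Number of CPUs this process may run on (hypothetical helper)."""
    if hasattr(os, 'sched_getaffinity'):  # Linux: honors cpusets / SLURM
        return len(os.sched_getaffinity(0))
    return os.cpu_count() or 1  # fallback: all CPUs on the machine


# leave one CPU for the training process (an assumption, not a measured optimum)
suggested_collect_procs = max(1, available_cpus() - 1)
print(suggested_collect_procs)
```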

## Define the spin system

In [5]:
dipolar_strength = 1e2
pulse_width = 1e-5  # time is relative to chemical shift strength
delay = 1e-4
N = 3  # number of spins
ensemble_size = 50
rot_error = 0.01

In [6]:
# X, Y, Z = ps.get_collective_spin(N)
# Hsys_ensemble = [ps.get_Hsys(N) for _ in range(ensemble_size)]
# pulses_ensemble = [
#     ps.get_pulses(H, X, Y, Z, pulse_width, delay, rot_error=0.01) for H in Hsys_ensemble
# ]
# Utarget = qt.identity(Hsys_ensemble[0].dims[0])
Utarget = qt.tensor([qt.identity(2)] * N)  # target propagator: identity on all N spins

## Smarter search with MCTS

I follow the [supplementary materials description under "Search"](https://science.sciencemag.org/content/sci/suppl/2018/12/05/362.6419.1140.DC1/aar6404-Silver-SM.pdf) to run simulations (rollouts) and backpropagate their values through the tree. All of the relevant code for the AlphaZero algorithm is in `alpha_zero.py`.
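A minimal sketch of the selection and backup steps described there, using the PUCT rule and constants from the published AlphaZero pseudocode ($c_\text{base} = 19652$, $c_\text{init} = 1.25$). The `Node` class and function names are hypothetical, not the interface of `alpha_zero.py`.

```python
import math
from dataclasses import dataclass, field


@dataclass
class Node:
    """Hypothetical search-tree node (not the alpha_zero.py interface)."""
    prior: float
    visit_count: int = 0
    value_sum: float = 0.0
    children: dict = field(default_factory=dict)  # action -> Node

    def value(self):
        # mean value Q(s, a); unvisited nodes count as 0
        return self.value_sum / self.visit_count if self.visit_count else 0.0


def ucb_score(parent, child, c_base=19652, c_init=1.25):
    # exploration rate C(s) grows slowly with the parent's visit count
    c = math.log((parent.visit_count + c_base + 1) / c_base) + c_init
    # U(s, a) = C(s) P(s, a) sqrt(N(s)) / (1 + N(s, a))
    u = c * child.prior * math.sqrt(parent.visit_count) / (child.visit_count + 1)
    return child.value() + u


def select_child(node):
    # pick the (action, child) pair with the highest Q + U
    return max(node.children.items(), key=lambda item: ucb_score(node, item[1]))


def backpropagate(search_path, value):
    # single-agent search: every node on the path receives the same value
    for node in search_path:
        node.visit_count += 1
        node.value_sum += value


root = Node(prior=1.0, visit_count=10)
root.children = {1: Node(0.5, 10, 5.0), 2: Node(0.3), 3: Node(0.2)}
action, child = select_child(root)  # the unvisited action 2 wins on exploration
backpropagate([root, child], 1.0)
```

Because pulse sequence design is a single-agent problem, the backup adds the same value at every node instead of alternating signs as in two-player games.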

To store data compactly (and to feed it to an RNN), each state is represented as a sequence of integers, where 0 marks the start of the sequence and 1-5 are the possible pulses (1: delay, 2: x, etc.).
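A minimal sketch of turning such states into LSTM input. It assumes a one-hot encoding over the 6 tokens (consistent with the `(1, 10, 6)` dummy input passed to `writer.add_graph` in `train_process`, but not confirmed from `alpha_zero.py`); `encode_state` and `pad_and_pack_sketch` are hypothetical stand-ins for whatever `az.pad_and_pack` does, built on PyTorch's `pad_sequence` and `pack_padded_sequence`.

```python
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence

NUM_TOKENS = 6  # assumed: 0 = start of sequence, 1-5 = pulse choices


def encode_state(pulses):
    """One-hot encode a state: start token 0 followed by pulse indices 1-5."""
    tokens = torch.tensor([0] + list(pulses), dtype=torch.long)
    return F.one_hot(tokens, num_classes=NUM_TOKENS).float()  # (len, 6)


def pad_and_pack_sketch(states):
    """Pad variable-length one-hot states and pack them for an LSTM."""
    lengths = torch.tensor([s.shape[0] for s in states])
    padded = pad_sequence(states, batch_first=True)  # (batch, max_len, 6)
    return pack_padded_sequence(padded, lengths, batch_first=True,
                                enforce_sorted=False)


states = [encode_state([2, 1, 4]), encode_state([3])]
packed = pad_and_pack_sketch(states)
lstm = torch.nn.LSTM(input_size=NUM_TOKENS, hidden_size=8, batch_first=True)
_, (h_n, _) = lstm(packed)
print(h_n.shape)  # torch.Size([1, 2, 8])
```

Packing with `enforce_sorted=False` lets a minibatch of sequences with different lengths go through the LSTM without running over the padding, and `h_n` comes back in the original batch order.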

In [7]:
# output = az.make_sequence(az.Config(), ps.PulseSequenceConfig(Utarget))

## Fill replay buffer with initial data

In [7]:
def collect_data_no_net(proc_num, queue, ps_count, global_step, lock):
    """
    Args:
        proc_num: Which process number this is (for debug purposes)
        queue (Queue): A queue to add the statistics gathered
            from the MCTS rollouts.
        ps_count (Value): Shared count of how many pulse sequences have
            been constructed
        global_step (Value): Shared training-step counter (unused here)
        lock (Lock): Lock guarding the queue and ps_count
    """
    print(datetime.now(), f'collecting data without network ({proc_num})')
    config = az.Config()
    ps_config = ps.PulseSequenceConfig(N=N, ensemble_size=ensemble_size,
                                       max_sequence_length=max_sequence_length,
                                       Utarget=Utarget,
                                       dipolar_strength=dipolar_strength,
                                       pulse_width=pulse_width, delay=delay,
                                       rot_error=rot_error)
    for _ in range(collect_no_net_count):
        ps_config.reset()
        output = az.make_sequence(config, ps_config, network=None,
                                  rng=ps_config.rng)
        if output[-1][2] > reward_threshold:
            print(datetime.now(),
                  f'candidate pulse sequence from {proc_num}',
                  output[-1])
        with lock:
            queue.put(output)
            ps_count.value += 1

In [9]:
# if __name__ == '__main__':
#     with mp.Manager() as manager:
#         buffer = manager.list()  #[None] * buffer_size
#         index = manager.Value(typecode='i', value=0)
#         ps_count = manager.Value(typecode='i', value=0)
#         lock = manager.RLock()
#         workers = []
#         for i in range(4):
#             workers.append(mp.Process(target=collect_data_no_net,
#                                       args=(i, buffer, index, lock,
#                                             buffer_size, ps_count)))
#             workers[-1].start()
#         for w in workers:
#             w.join()
#         print('done gathering initial data!')
#         l = list(buffer)  # to save a non-shared copy...

## Optimizing networks with replay buffer data

See [this tutorial](https://pytorch.org/tutorials/intermediate/tensorboard_tutorial.html) for logging training losses to TensorBoard, and [this tutorial](https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-for-inference) for saving and loading models.
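A minimal sketch of the checkpoint round trip used in `train_process`, following the saving/loading tutorial: save `state_dict()`, load it into a freshly constructed model, and call `eval()` before inference. A `torch.nn.Linear` stands in for `az.Network` (an assumption made only so the sketch is self-contained), and the file name mimics the `{i:07.0f}-network` pattern.

```python
import os
import tempfile

import torch

net = torch.nn.Linear(6, 5)  # stand-in for az.Network()

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, f'{10:07.0f}-network')  # -> '0000010-network'
    torch.save(net.state_dict(), path)  # save only the parameters

    restored = torch.nn.Linear(6, 5)  # must have the same architecture
    restored.load_state_dict(torch.load(path))
    restored.eval()  # switch dropout/batch-norm layers to inference mode
```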

## Multiprocessing

Running this algorithm in parallel is quite important. I'm using [multiprocessing](https://docs.python.org/3/library/multiprocessing.html) to handle the parallelism; `torch.multiprocessing` is a drop-in wrapper around it that moves tensors (e.g. the network weights) between processes through shared memory. With 2 processes on my laptop, the speedup is about 90% (not bad...).

Each process needs its own random seed; otherwise forked processes inherit the same random state and all produce the same results.
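A minimal sketch of per-process seeding, assuming each collector calls it first with its `proc_num`. `seed_process` and `root_seed` are hypothetical names. `SeedSequence.spawn` derives statistically independent child seeds, so worker k gets the same stream on every run for a given root seed. Whether `ps.PulseSequenceConfig` could accept the returned NumPy `Generator` as its `rng` is an assumption.

```python
import random

import numpy as np


def seed_process(proc_num, root_seed):
    """Seed a worker's random streams from (root_seed, proc_num).

    Hypothetical helper: call it at the top of each collector process.
    """
    # spawn() derives independent child seeds; child proc_num is the same
    # on every run for a given root_seed
    child = np.random.SeedSequence(root_seed).spawn(proc_num + 1)[proc_num]
    seed = int(child.generate_state(1)[0])
    random.seed(seed)     # Python's global RNG (e.g. random.sample)
    np.random.seed(seed)  # NumPy's legacy global RNG
    # torch.manual_seed(seed) would seed PyTorch in the same way
    return np.random.default_rng(child)  # independent Generator for this worker


# inside a worker: rng = seed_process(proc_num, root_seed=2020)
draws = [seed_process(k, root_seed=2020).random(3) for k in range(2)]
```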

In [8]:
def collect_data(proc_num, queue, net, ps_count, global_step, lock):
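    """
    Args:
        proc_num: Which process number this is (for debug purposes)
        queue (Queue): A queue to add the statistics gathered
            from the MCTS rollouts.
        net (az.Network): Shared network used to guide the search
        ps_count (Value): A shared count of how many pulse sequences have been
            constructed so far
        global_step (Value): Shared training-step counter; collection stops
            once it reaches num_iters
        lock (Lock): Lock guarding the queue and ps_count
    """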
    print(datetime.now(), f'collecting data ({proc_num})')
    config = az.Config()
    config.num_simulations = 250
    ps_config = ps.PulseSequenceConfig(Utarget=Utarget, N=N,
                                       ensemble_size=ensemble_size,
                                       max_sequence_length=max_sequence_length,
                                       dipolar_strength=dipolar_strength,
                                       pulse_width=pulse_width, delay=delay,
                                       rot_error=rot_error)
    while global_step.value < num_iters:
        ps_config.reset()
        output = az.make_sequence(config, ps_config, network=net,
                                  rng=ps_config.rng)
        if output[-1][2] > reward_threshold:
            print(datetime.now(),
                  f'candidate pulse sequence from {proc_num}',
                  output[-1])
        with lock:
            queue.put(output)
            ps_count.value += 1

In [11]:
# if __name__ == '__main__':
#     with mp.Manager() as manager:
#         buffer = manager.list()  #[None] * 500
#         index = manager.Value(typecode='i', value=0)
#         lock = manager.RLock()
#         # get network
#         net = az.Network()
#         net.share_memory()
#         workers = []
#         for i in range(4):
#             workers.append(mp.Process(target=collect_data,
#                                       args=(i, buffer, index, lock, buffer_size, net)))
#             workers[-1].start()
#         for w in workers:
#             w.join()
#         print('done gathering data!')
#         l = list(buffer)  # save a non-shared copy

## Training process

In [9]:
def train_process(queue, net, global_step, ps_count, lock,
                  c_value=1e0, c_l2=1e-6):
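    """
    Args:
        queue (Queue): Queue holding the statistics gathered
            from the MCTS rollouts.
        net (az.Network): Shared network to train
        global_step (mp.managers.Value): Counter to keep track
            of training iterations
        ps_count (Value): Shared count of constructed pulse sequences
        lock (Lock): Lock guarding the queue and shared counters
        c_value (float): Weight of the value loss
        c_l2 (float): Weight of the L2 regularization term
    """
    writer = SummaryWriter()  # writes losses to the TensorBoard log
    start_time = datetime.now().strftime('%Y%m%d-%H%M%S')
    net_optimizer = optim.Adam(net.parameters())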
    # construct replay buffer locally
    buffer = []
    index = 0
    i = 0
    # write network structure to tensorboard file
    tmp = torch.zeros((1, 10, 6))
    writer.add_graph(net, tmp)
    del tmp
    while global_step.value < num_iters:  # number of training iterations
        # check if queue has new data to add to replay buffer
        with lock:
            while not queue.empty():
                new_stats = queue.get()
                new_stats = az.convert_stats_to_tensors(new_stats)
                for stat in new_stats:
                    if len(buffer) < buffer_size:
                        buffer.append(stat)
                    else:
                        buffer[index] = stat
                    index = index + 1 if index < buffer_size - 1 else 0
        # carry on with training
        if len(buffer) < batch_size:
            print(datetime.now(), 'not enough data yet, sleeping...')
            sleep(5)
            continue
#         elif len(buffer) < 1e4:
#             sleep(.5)  # put on the brakes a bit, don't tear through the data
        # save network parameters to file; checking this after the data check
        # avoids re-saving iteration 0 every time the loop waits for data
        if i % save_every == 0:
            print(datetime.now(), 'saving network...')
            os.makedirs(f'{start_time}-network', exist_ok=True)
            torch.save(net.state_dict(), f'{start_time}-network/{i:07.0f}-network')
        net_optimizer.zero_grad()
        # sample minibatch from replay buffer
        minibatch = random.sample(buffer, batch_size)
        states, probabilities, values = zip(*minibatch)
        probabilities = torch.stack(probabilities)
        values = torch.stack(values)
        packed_states = az.pad_and_pack(states)
        # evaluate network
        policy_outputs, value_outputs, _ = net(packed_states)
        policy_loss = -torch.sum(
            probabilities * torch.log(policy_outputs)) / len(states)
        value_loss = F.mse_loss(value_outputs, values)
        # sum of parameter L2 norms (the AlphaZero paper uses the squared norm)
        l2_reg = torch.tensor(0.)
        for param in net.parameters():
            l2_reg += torch.norm(param)
        loss = policy_loss + c_value * value_loss + c_l2 * l2_reg
        loss.backward()
        net_optimizer.step()
        # write losses to log
        writer.add_scalar('training_policy_loss',
                          policy_loss, global_step=global_step.value)
        writer.add_scalar('training_value_loss',
                          c_value * value_loss, global_step=global_step.value)
        writer.add_scalar('training_l2_reg',
                          c_l2 * l2_reg, global_step=global_step.value)
        
        # every print_every iterations, log progress and a histogram
        # of the replay buffer values
        if i % print_every == 0:
            print(datetime.now(), f'updated network (iteration {i})',
                  f'pulse_sequence_count: {ps_count.value}')
            _, _, values = zip(*list(buffer))
            values = torch.stack(values).squeeze()
            writer.add_histogram('buffer_values', values,
                                 global_step=global_step.value)
            writer.add_scalar('pulse_sequence_count', ps_count.value,
                              global_step.value)
        with lock:
            global_step.value += 1
        i += 1
        sleep(.1)
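The loss in `train_process` follows the AlphaZero objective $l = (z - v)^2 - \pi^\top \log p + c\lVert\theta\rVert^2$. A standalone sketch of that objective, assuming the network outputs normalized probabilities; `alphazero_loss` is a hypothetical name, and the small `eps` guards against `log(0)`. Unlike `train_process`, which sums unsquared `torch.norm` values, it uses the squared parameter norm from the paper.

```python
import torch
import torch.nn.functional as F


def alphazero_loss(policy_probs, value_preds, target_probs, target_values,
                   params, c_value=1.0, c_l2=1e-6, eps=1e-8):
    """AlphaZero-style loss (hypothetical helper, not the az module's API).

    policy_probs, target_probs: (batch, num_moves) probability tensors
    value_preds, target_values: (batch, 1) tensors
    params: iterable of parameters to regularize
    """
    # cross-entropy between the MCTS visit distribution and the network policy
    policy_loss = -(target_probs * torch.log(policy_probs + eps)).sum(dim=1).mean()
    # mean squared error between predicted and observed values
    value_loss = F.mse_loss(value_preds, target_values)
    # squared L2 norm of all parameters, as in the AlphaZero paper
    l2 = sum((p ** 2).sum() for p in params)
    return policy_loss + c_value * value_loss + c_l2 * l2


# usage on a toy network standing in for az.Network
net = torch.nn.Linear(6, 5)
x = torch.randn(4, 6)
policy = torch.softmax(net(x), dim=1)
targets = torch.full((4, 5), 0.2)
loss = alphazero_loss(policy, torch.zeros(4, 1), targets, torch.zeros(4, 1),
                      net.parameters())
loss.backward()  # gradients now populate net.weight.grad
```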

In [18]:
if __name__ == '__main__':
    with mp.Manager() as manager:
        queue = manager.Queue()
        global_step = manager.Value('i', 0)
        ps_count = manager.Value('i', 0)
        lock = manager.Lock()
        # get network
        net = az.Network()
        # optionally load state dict
        # net.load_state_dict(torch.load('network_state'))
        net.share_memory()
        collectors = []
        for i in range(collect_no_net_procs):
            c = mp.Process(target=collect_data_no_net,
                           args=(i, queue, ps_count, global_step, lock))
            c.start()
            collectors.append(c)
        trainer = mp.Process(target=train_process,
                             args=(queue, net,
                                   global_step, ps_count, lock))
        trainer.start()
        # join collectors before starting more
        for c in collectors:
            c.join()
        collectors.clear()
        # start data collectors with network
        for i in range(collect_procs):
            c = mp.Process(target=collect_data,
                           args=(i, queue, net, ps_count, global_step, lock))
            c.start()
            collectors.append(c)
        for c in collectors:
            c.join()
        print('all collectors are joined')
        trainer.join()
        print('trainer is joined')
        print('done!')

It appears that sharing the neural network behaves as expected! Training updates the weights, and those updated weights are reflected in each of the data collection processes. Neat!