# AlphaZero implementation for pulse sequence design
_Will Kaufman, December 2020_

[Dalgaard et al. (2020)](https://www.nature.com/articles/s41534-019-0241-0) applied AlphaZero to constructing shaped pulses (as I understand it), but in principle it should be at least as applicable to pulse sequence design. The original [AlphaZero paper](https://science.sciencemag.org/content/362/6419/1140.full) is here.

The general idea behind AlphaZero (as I understand it) is a "smart" tree search that balances prior knowledge (the policy), curiosity about unexplored branches, and exploitation of high-value branches. My thought is that this can be improved with average Hamiltonian theory (AHT): by the end of the pulse sequence, the sequence must be cyclic (the overall frame transformation must be the identity) and equal time must be spent on each axis. These hard constraints should (hopefully) speed up the search.

In [1]:
import qutip as qt
import numpy as np
import matplotlib.pyplot as plt
import sys, os
from random import sample

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
sys.path.append(os.path.abspath('..'))
import pulse_sequences as ps
import alpha_zero as az

In [475]:
import importlib
importlib.reload(az)
importlib.reload(ps)

<module 'pulse_sequences' from '/Users/willkaufman/Projects/rl_pulse/rl_pulse/pulse_sequences.py'>

## Define the spin system

In [3]:
delay = 1e-2  # time is relative to chemical shift strength
pulse_width = 5e-3
N = 3  # number of spins
ensemble_size = 5

In [4]:
X, Y, Z = ps.get_collective_spin(N)

In [5]:
Hsys_ensemble = [ps.get_Hsys(N) for _ in range(ensemble_size)]
pulses_ensemble = [
    ps.get_pulses(H, X, Y, Z, pulse_width, delay, rot_error=0.01) for H in Hsys_ensemble
]

In [6]:
Utarget = qt.identity(Hsys_ensemble[0].dims[0])

## Average Hamiltonian theory

To keep track of the average Hamiltonian (to lowest order), I'm defining a frame matrix and applying rotation matrices to the frame matrix, then determining how $I_z$ transforms during the pulse sequence. The last row in the frame matrix corresponds to the current transformed value of $I_z$.
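
A minimal NumPy sketch of this bookkeeping, under assumptions that may not match `pulse_sequences.py`: the pulse alphabet is 0: delay, 1: X, 2: -X, 3: Y, 4: -Y (all $\pi/2$ rotations), every pulse is followed by a delay, and the six signed axes are ordered $(+x, -x, +y, -y, +z, -z)$. Each pulse left-multiplies the frame, so row $a$ holds the toggling-frame coefficients of $I_a$; the sign convention is one common choice. `frame_axes`, `count_axes_sketch` and `is_cyclic` are hypothetical names.

```python
import numpy as np

def rotation(axis, angle):
    """3x3 rotation about x or y, rounded to exact integers for multiples of pi/2."""
    c, s = np.cos(angle), np.sin(angle)
    if axis == 'x':
        m = [[1, 0, 0], [0, c, -s], [0, s, c]]
    elif axis == 'y':
        m = [[c, 0, s], [0, 1, 0], [-s, 0, c]]
    else:
        raise ValueError(f'unknown axis {axis}')
    return np.rint(m).astype(int)

# Assumed pulse alphabet: 0 = delay, 1 = X, 2 = -X, 3 = Y, 4 = -Y (pi/2 pulses).
PULSE_ROTATIONS = [
    np.eye(3, dtype=int),
    rotation('x', np.pi / 2),
    rotation('x', -np.pi / 2),
    rotation('y', np.pi / 2),
    rotation('y', -np.pi / 2),
]

def axis_index(v):
    """Map a signed unit vector to an index in (+x, -x, +y, -y, +z, -z)."""
    axis = int(np.argmax(np.abs(v)))
    return 2 * axis + (0 if v[axis] > 0 else 1)

def frame_axes(sequence):
    """Axis that I_z points along during the delay after each pulse, plus the final frame."""
    frame = np.eye(3, dtype=int)
    axes = []
    for pulse in sequence:
        frame = PULSE_ROTATIONS[pulse] @ frame  # apply the pulse to the frame
        axes.append(axis_index(frame[-1]))      # last row = transformed I_z
    return axes, frame

def count_axes_sketch(sequence):
    """Number of delays spent along each of the six signed axes."""
    axes, _ = frame_axes(sequence)
    return [axes.count(i) for i in range(6)]

def is_cyclic(sequence):
    """True if the overall frame transformation is the identity."""
    _, frame = frame_axes(sequence)
    return bool(np.array_equal(frame, np.eye(3, dtype=int)))
```

For example, `count_axes_sketch([1, 1, 3, 1, 1, 3])` gives `[1, 1, 1, 1, 1, 1]` and `is_cyclic([1, 1, 3, 1, 1, 3])` is `True`, so under these conventions that 6-pulse sequence satisfies both constraints.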

In [7]:
ps.count_axes(ps.yxx48)

[8, 8, 8, 8, 8, 8]

In [8]:
ps.get_valid_time_suspension_pulses([0,1,1,], len(ps.pulse_names), 6)

[1, 3, 4]
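
The pruning that `get_valid_time_suspension_pulses` performs might be sketched as follows. This is a guess at the logic, not the actual implementation: a candidate pulse is kept only if the axis it moves $I_z$ onto still has less than its share (`length // 6`) of delays, and a pulse that completes the sequence must also return the frame to the identity. It assumes the pulse alphabet 0: delay, 1: X, 2: -X, 3: Y, 4: -Y ($\pi/2$ rotations), each pulse followed by a delay; `valid_next_pulses` is a hypothetical name. Under these assumptions it returns `[1, 3, 4]` for the prefix `[0, 1, 1]` with 5 pulses and length 6, matching the output above.

```python
import numpy as np

# Assumed pulse alphabet: 0 = delay, 1 = X, 2 = -X, 3 = Y, 4 = -Y (pi/2 pulses).
_X = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])   # pi/2 about x
_Y = np.array([[0, 0, 1], [0, 1, 0], [-1, 0, 0]])   # pi/2 about y
ROTATIONS = [np.eye(3, dtype=int), _X, _X.T, _Y, _Y.T]

def axis_index(v):
    """Map a signed unit vector to an index in (+x, -x, +y, -y, +z, -z)."""
    axis = int(np.argmax(np.abs(v)))
    return 2 * axis + (0 if v[axis] > 0 else 1)

def valid_next_pulses(prefix, num_pulses, length):
    """Hypothetical pruning rule: pulses that keep equal time on all six axes possible."""
    if length % 6 != 0:
        raise ValueError('length must be a multiple of 6')
    max_per_axis = length // 6
    frame = np.eye(3, dtype=int)
    counts = [0] * 6
    for pulse in prefix:  # replay the prefix to get the current frame and axis counts
        frame = ROTATIONS[pulse] @ frame
        counts[axis_index(frame[-1])] += 1
    valid = []
    for pulse in range(num_pulses):
        new_frame = ROTATIONS[pulse] @ frame
        if counts[axis_index(new_frame[-1])] >= max_per_axis:
            continue  # this axis already has its full share of time
        if len(prefix) + 1 == length and not np.array_equal(new_frame, np.eye(3)):
            continue  # the last pulse must make the sequence cyclic
        valid.append(pulse)
    return valid
```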

## Tree search

Define nodes that can be used for tree search, with additional constraints that the lowest-order average Hamiltonian matches the desired Hamiltonian.

(The code that implemented tree search with these constraints has been deleted; see the GitHub repo commits from 12/8.)

For 12-pulse sequences, the constrained search finished 16 branches at depth 4 in a minute, or about one every 4 seconds. At depth 4 there are $5^4 = 625$ branches, so a full run would take roughly $625 \times 4\,\mathrm{s} = 2500\,\mathrm{s}$, about 42 minutes. Alternatively, you can generate random pulse sequences until one has the proper lowest-order average Hamiltonian and is cyclic.
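
Rejection sampling for the random alternative might look like the sketch below: draw uniformly random sequences and return the first one that is cyclic and spends equal time on $\pm x$, $\pm y$ and $\pm z$. It assumes the pulse alphabet 0: delay, 1: X, 2: -X, 3: Y, 4: -Y ($\pi/2$ rotations), each pulse followed by a delay; `is_valid` and `random_valid_sequence` are hypothetical names. `max_tries` caps the search because there are $5^L$ candidates of length $L$ and the acceptance rate falls off quickly as $L$ grows.

```python
import numpy as np

_X = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])   # pi/2 about x
_Y = np.array([[0, 0, 1], [0, 1, 0], [-1, 0, 0]])   # pi/2 about y
ROTATIONS = [np.eye(3, dtype=int), _X, _X.T, _Y, _Y.T]  # delay, X, -X, Y, -Y (assumed)

def is_valid(sequence):
    """Cyclic, with equal time on each of +x, -x, +y, -y, +z, -z."""
    if len(sequence) % 6 != 0:
        return False
    frame = np.eye(3, dtype=int)
    counts = [0] * 6
    for pulse in sequence:
        frame = ROTATIONS[pulse] @ frame
        z = frame[-1]  # toggling-frame direction of I_z during the following delay
        axis = int(np.argmax(np.abs(z)))
        counts[2 * axis + (0 if z[axis] > 0 else 1)] += 1
    return bool(np.array_equal(frame, np.eye(3))) and len(set(counts)) == 1

def random_valid_sequence(length, num_pulses=5, max_tries=100_000, seed=None):
    """Rejection sampling: return the first valid random sequence, or None."""
    rng = np.random.default_rng(seed)
    for _ in range(max_tries):
        candidate = rng.integers(num_pulses, size=length).tolist()
        if is_valid(candidate):
            return candidate
    return None
```

Under these conventions `[1, 1, 3, 1, 1, 3]` is one valid 6-pulse sequence, so a length-6 search should succeed well within `max_tries`; longer sequences are where pruning (or MCTS) pays off.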

## Smarter search with MCTS

Following the [supplementary materials description under "Search"](https://science.sciencemag.org/content/sci/suppl/2018/12/05/362.6419.1140.DC1/aar6404-Silver-SM.pdf) to do rollouts and backpropagate information.
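
The selection and backup steps from that description can be sketched in plain Python. Each node stores the statistics of the edge leading into it: visit count $N$, total value $W$, mean value $Q = W/N$ and prior $P$. Selection maximizes $Q(s,a) + U(s,a)$ with $U(s,a) = C(s)\,P(s,a)\sqrt{N(s)}/(1 + N(s,a))$ and $C(s) = \log\big((1 + N(s) + c_\text{base})/c_\text{base}\big) + c_\text{init}$, and the leaf value is then added to every node on the visited path. The constants `c_base = 19652` and `c_init = 1.25` are the values I recall from the published pseudocode; treat them, and the single-player backup without sign flips, as assumptions rather than what `alpha_zero.py` does.

```python
import math

class Node:
    """Search-tree node; statistics describe the edge leading into this node."""
    def __init__(self, prior):
        self.prior = prior
        self.visit_count = 0
        self.value_sum = 0.0
        self.children = {}  # action -> Node

    def value(self):
        """Mean action value Q; 0 for unvisited nodes."""
        return self.value_sum / self.visit_count if self.visit_count else 0.0

def ucb_score(parent, child, c_base=19652, c_init=1.25):
    """PUCT score Q(s, a) + U(s, a)."""
    c = math.log((1 + parent.visit_count + c_base) / c_base) + c_init
    u = c * child.prior * math.sqrt(parent.visit_count) / (1 + child.visit_count)
    return child.value() + u

def select_child(node):
    """Return the (action, child) pair with the highest PUCT score."""
    return max(node.children.items(), key=lambda item: ucb_score(node, item[1]))

def backpropagate(path, value):
    """Add a leaf value to every node on the path (single-player, no sign flips)."""
    for node in path:
        node.visit_count += 1
        node.value_sum += value
```

With no visits below the root, the prior dominates, so the child with the largest $P$ is tried first; as visits accumulate, $U$ shrinks and $Q$ takes over.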

In [466]:
config = az.Config()
config.num_simulations = 500
ps_config = ps.PulseSequenceConfig(N, ensemble_size, pulse_width, delay, 6, Utarget)

In [68]:
# %load_ext snakeviz
# %snakeviz -t az.make_sequence(config, ps_config, None)

In [467]:
# stats = az.make_sequence(config, ps_config, None)

In [468]:
# ps_config.value()

In [469]:
# stats

## Replay buffer

Inspired by [this PyTorch tutorial](https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html).

In [516]:
class ReplayBuffer(object):
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.position = 0
    
    def add(self, data):
        """Save to replay buffer
        """
        if len(self) < self.capacity:
            self.buffer.append(data)
        else:
            self.buffer[self.position] = data
        self.position = (self.position + 1) % self.capacity
    
    def sample(self, batch_size=1):
        if batch_size > len(self):
            raise ValueError(f'batch_size of {batch_size} should be at most '
                             f'the buffer size of {len(self)}')
        return sample(self.buffer, batch_size)
    
    def __len__(self):
        return len(self.buffer)

In [643]:
rb = ReplayBuffer(100000)

In [644]:
config = az.Config()
config.num_simulations = 100

To store the data compactly (and to feed it to an RNN), the state is represented as a sequence of tokens, where 0 marks the start of the sequence and 1-5 are the possible pulses (1: delay, 2: x, etc.).
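
A small helper makes this convention concrete. It mirrors the tensor manipulations in the loop below, wrapped in a hypothetical `encode_state` function: raw pulse indices (0 = delay, 1 = x, ...) are shifted up by one, a start token 0 is prepended, and the result is one-hot encoded into 6 classes.

```python
import torch
import torch.nn.functional as F

def encode_state(pulses, num_classes=6):
    """One-hot encode a pulse sequence, prepending a start-of-sequence token.

    pulses: raw pulse indices (0 = delay, 1 = x, ...). They are shifted by +1
    so that token 0 is free to mark the start of the sequence.
    Returns a float tensor of shape (len(pulses) + 1, num_classes).
    """
    tokens = torch.cat([torch.zeros(1, dtype=torch.long),
                        torch.as_tensor(pulses, dtype=torch.long) + 1])
    return F.one_hot(tokens, num_classes).float()
```

For instance, `encode_state([0, 1, 4])` has shape `(4, 6)` with ones at token positions 0, 1, 2 and 5; an empty sequence encodes to just the start token.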

In [645]:
for i in range(5):
    print(f'creating pulse sequence {i}')
    ps_config = ps.PulseSequenceConfig(N, ensemble_size, pulse_width, delay, 48, Utarget)
    stats = az.make_sequence(config, ps_config, None)
    for s in stats:
        # each entry is (pulses so far, action probabilities, value)
        state = torch.Tensor(s[0]) + 1  # shift pulses to tokens 1-5
        state = torch.cat([torch.Tensor([0]), state])  # prepend start token 0
        state = F.one_hot(state.long(), 6).float()
        probs = torch.Tensor(s[1])
        value = torch.Tensor([s[2]])
        rb.add((state, probs, value))

creating pulse sequence 0
creating pulse sequence 1
creating pulse sequence 2
creating pulse sequence 3
creating pulse sequence 4


In [634]:
len(rb)

240

In [927]:
# rb.sample()

## Neural networks for policy, value estimation

Batched tensors have shape `B * T * ...`, where `B` is the batch size and `T` is the timestep. This differs from PyTorch's default RNN layout (`T * B * ...`), which is why the LSTMs here use `batch_first=True`, but it's more intuitive to me.
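
A quick shape check illustrates the convention. With `batch_first=True`, `nn.LSTM` takes and returns `B * T * features` tensors, but the returned hidden and cell states keep the `num_layers * B * hidden_size` layout regardless; this matters when passing `h0` and `c0` back in, as the `Policy` network does. The sizes here are arbitrary.

```python
import torch
import torch.nn as nn

B, T, num_features, hidden = 4, 7, 6, 16  # arbitrary sizes for illustration
lstm = nn.LSTM(input_size=num_features, hidden_size=hidden,
               num_layers=1, batch_first=True)

x = torch.zeros(B, T, num_features)  # batch-first input
out, (h, c) = lstm(x)

print(out.shape)  # torch.Size([4, 7, 16]): B * T * hidden
print(h.shape)    # torch.Size([1, 4, 16]): num_layers * B * hidden (not batch-first)

# For a single-layer, unidirectional LSTM the last output step is the final hidden state.
print(torch.allclose(out[:, -1, :], h[0]))  # True
```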

In [892]:
class Policy(nn.Module):
    def __init__(self, input_size=6, lstm_size=16, output_size=6):
        super(Policy, self).__init__()
        self.lstm = nn.LSTM(input_size=input_size,
                            hidden_size=lstm_size,
                            num_layers=1,
                            batch_first=True)
        self.fc1 = nn.Linear(lstm_size, output_size)
    
    def forward(self, x, h0=None, c0=None):
        """Calculates the policy from state x
        
        Args:
            x: The state of the pulse sequence. Either a tensor with 
                shape B*T*(num_actions + 1), or a packed sequence of states.
        """
        if h0 is None or c0 is None:
            x, (h, c) = self.lstm(x)
        else:
            x, (h, c) = self.lstm(x, (h0, c0))
        if type(x) is torch.Tensor:
            x = x[:, -1, :]
        elif type(x) is nn.utils.rnn.PackedSequence:
            # x is PackedSequence, need to get last timestep from each
            x, lengths = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
            idx = (
                lengths.long() - 1
            ).view(
                -1, 1
            ).expand(
                len(lengths), x.size(2)
            ).unsqueeze(1)
            x = x.gather(1, idx).squeeze(1)
        x = F.softmax(self.fc1(x), dim=1)
        return x, (h, c)

In [893]:
p = Policy()

In [666]:
optimizer = optim.Adam(p.parameters())

In [707]:
batch_size = 30
seq_length = 48

In [708]:
inputs = rb.sample(batch_size)

In [894]:
states = [i[0] for i in inputs]

In [895]:
lengths = [s.size(0) for s in states]

In [896]:
packed_states = nn.utils.rnn.pack_padded_sequence(
    nn.utils.rnn.pad_sequence(states, batch_first=True),
    lengths, enforce_sorted=False, batch_first=True
)

In [915]:
output, (h, c) = p(packed_states)

Confirm that the output is the same for packed and individual inputs.

In [916]:
output_individual = torch.cat([p(s.unsqueeze(0))[0] for s in states])

In [918]:
torch.norm(output - output_individual)

tensor(1.2197e-07, grad_fn=<CopyBackwards>)

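Check that the hidden and cell states work properly: running the states again starting from the final `(h, c)` of the first pass should give the same output as running each state concatenated with itself.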

In [919]:
output1, (h1, c1) = p(packed_states, h0=h, c0=c)

In [920]:
doubled_states = [
    torch.cat([s, s])
    for s in states
]

In [921]:
doubled_lengths = [s.size(0) for s in doubled_states]

In [922]:
packed_doubles = nn.utils.rnn.pack_padded_sequence(
    nn.utils.rnn.pad_sequence(doubled_states, batch_first=True),
    doubled_lengths, enforce_sorted=False, batch_first=True
)

In [923]:
output2, (h2, c2) = p(packed_doubles)

In [925]:
torch.norm(output2 - output1)

tensor(2.5810e-08, grad_fn=<CopyBackwards>)

## Optimize policy based on target distribution

In [None]:
targets = 0 # TODO...

In [346]:
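for i in range(1000):
    optimizer.zero_grad()  # clear gradients from the previous step
    outputs, (h, c) = p(packed_states)
    # cross-entropy between the target distribution and the policy
    loss = -torch.sum(targets * torch.log(outputs))
    if i % 100 == 0:
        print(loss)
    loss.backward()
    optimizer.step()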

tensor(80.8050, grad_fn=<NegBackward>)
tensor(77.2775, grad_fn=<NegBackward>)
tensor(73.8556, grad_fn=<NegBackward>)
tensor(71.4682, grad_fn=<NegBackward>)
tensor(69.1231, grad_fn=<NegBackward>)
tensor(67.9690, grad_fn=<NegBackward>)
tensor(66.3810, grad_fn=<NegBackward>)
tensor(65.5521, grad_fn=<NegBackward>)
tensor(65.0454, grad_fn=<NegBackward>)
tensor(65.2894, grad_fn=<NegBackward>)


## Value network

## TODO

- [ ] Value network (below)
- [ ] Bring it all together with MCTS...
- [ ] Set up Discovery environment
- [ ] Run it and (hopefully) rejoice!

In [None]:
class Value(nn.Module):
    """TODO change all below, just copied from Policy network
    """
    def __init__(self, input_size=6, lstm_size=16, output_size=6):
        super(Value, self).__init__()
        self.lstm = nn.LSTM(input_size=input_size,
                            hidden_size=lstm_size,
                            num_layers=1,
                            batch_first=True)
        self.fc1 = nn.Linear(lstm_size, output_size)
    
    def forward(self, x, h0=None, c0=None):
        """Calculates the policy from state x
        
        Args:
            x: The state of the pulse sequence. Either a tensor with 
                shape B*T*(num_actions + 1), or a packed sequence of states.
        """
        if h0 is None or c0 is None:
            x, (h, c) = self.lstm(x)
        else:
            x, (h, c) = self.lstm(x, (h0, c0))
        if type(x) is torch.Tensor:
            x = x[:, -1, :]
        elif type(x) is nn.utils.rnn.PackedSequence:
            # x is PackedSequence, need to get last timestep from each
            x, lengths = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
            idx = (
                lengths.long() - 1
            ).view(
                -1, 1
            ).expand(
                len(lengths), x.size(2)
            ).unsqueeze(1)
            x = x.gather(1, idx).squeeze(1)
        x = F.softmax(self.fc1(x), dim=1)
        return x, (h, c)
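
One way the TODO in the `Value` class could be resolved, sketched under assumptions: the value network keeps the same LSTM layout but maps the final hidden state to a single unbounded scalar (no softmax), and training combines AlphaZero-style losses, mean-squared error on the value plus cross-entropy between the MCTS target distribution and the policy. Using `h[-1]` also works for a `PackedSequence`, since the LSTM's final hidden state is taken at each sequence's true last step. `ValueSketch` and `alphazero_loss` are hypothetical names; the L2 term from the paper is left to the optimizer's `weight_decay`, and `log_softmax` on logits avoids taking `log(0)` of a saturated softmax.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ValueSketch(nn.Module):
    """LSTM mapping a batch of state sequences to one scalar value each."""
    def __init__(self, input_size=6, lstm_size=16):
        super().__init__()
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=lstm_size,
                            num_layers=1, batch_first=True)
        self.fc1 = nn.Linear(lstm_size, 1)

    def forward(self, x):
        """x: B * T * input_size tensor or a PackedSequence; returns shape (B,)."""
        _, (h, _) = self.lstm(x)
        # h is num_layers * B * lstm_size; h[-1] is the last layer's final state
        return self.fc1(h[-1]).squeeze(-1)

def alphazero_loss(policy_logits, target_probs, values, target_values):
    """(z - v)^2 - pi^T log p, averaged over the batch."""
    value_loss = F.mse_loss(values, target_values)
    log_probs = F.log_softmax(policy_logits, dim=1)
    policy_loss = -(target_probs * log_probs).sum(dim=1).mean()
    return value_loss + policy_loss
```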