# AlphaZero implementation for pulse sequence design
_Will Kaufman, December 2020_

[Dalgaard et al. (2020)](https://www.nature.com/articles/s41534-019-0241-0) applied this approach to constructing shaped pulses (as I understand it), but in theory it should be at least as applicable to pulse sequence design. See also the original [AlphaZero paper](https://science.sciencemag.org/content/362/6419/1140.full).

The general idea behind AlphaZero (as I understand it) is to do a "smart" tree search that balances previous knowledge (the policy), curiosity about unexplored branches, and exploitation of high-value branches. My thought is that this can be improved with average Hamiltonian theory (AHT): by the end of the pulse sequence, the sequence must be cyclic (the overall frame transformation must be the identity) and equal time must be spent on each axis. This provides a hard constraint that will (hopefully) speed up the search.

In [None]:
import qutip as qt
import numpy as np
from scipy.spatial.transform import Rotation
import matplotlib.pyplot as plt

## Define the spin system

In [None]:
# Timing and system-size parameters shared by the rest of the notebook.
delay = 1e-2  # time is relative to chemical shift strength; free evolution appended after each pulse in get_pulses
pulse_width = 5e-3  # duration of each pulse (same time units as delay)
N = 3  # number of spins

In [None]:
def get_Hsys(dipolar_strength=1e-2):
    """Draw one random system Hamiltonian for the N-spin system.

    The Hamiltonian is the sum of a chemical-shift term (a random
    coefficient times sigma_z on every spin) and a pairwise dipolar term
    of the form 2*ZZ - XX - YY for every pair of spins.

    Args:
        dipolar_strength: Scale (standard deviation) of the normally
            distributed dipolar couplings, before the 2*pi factor.

    Returns: The sum of the chemical-shift and dipolar Hamiltonians.
    """
    def on_sites(operator, *sites):
        # `operator` on each listed spin, identity on every other spin
        return qt.tensor([
            operator if k in sites else qt.identity(2)
            for k in range(N)
        ])

    # Chemical shifts are drawn before the couplings, so a seeded RNG
    # gives the same Hamiltonian as before.
    shifts = 2*np.pi * np.random.normal(scale=1, size=(N,))
    couplings = 2*np.pi * np.random.normal(scale=dipolar_strength, size=(N, N))

    Hcs = sum(on_sites(shifts[k] * qt.sigmaz(), k) for k in range(N))

    # dipolar interactions; only the upper triangle (a < b) of the
    # coupling matrix is used
    Hdip = sum(
        couplings[a, b] * (
            2 * on_sites(qt.sigmaz(), a, b)
            - on_sites(qt.sigmax(), a, b)
            - on_sites(qt.sigmay(), a, b)
        )
        for a in range(N) for b in range(a + 1, N)
    )
    return Hcs + Hdip

In [None]:
def get_pulses(Hsys, X, Y, Z, pulse_width, delay, rot_error=0):
    """Build the propagators for every available action.

    The actions are, in order: no pulse, +x, -x, +y, -y. Each one evolves
    for `pulse_width` under the system Hamiltonian plus the pulse term,
    and is then followed by free evolution for `delay`.

    Args:
        Hsys: System Hamiltonian.
        X, Y, Z: Collective spin operators. Z is accepted but unused
            while the z pulses are disabled.
        pulse_width: Duration of each pulse.
        delay: Duration of the free evolution after each pulse.
        rot_error: Scale of the normally distributed fractional rotation
            error; a single draw is shared by all pulses returned.

    Returns: A list of propagators, one per action.
    """
    rot = np.random.normal(scale=rot_error)

    # index 0: evolve under Hsys alone for one pulse width
    during_pulse = [qt.propagator(Hsys, pulse_width)]
    # indices 1-4: pi/2 rotations about +x, -x, +y, -y
    # (+z and -z are currently disabled)
    for axis in (X, -X, Y, -Y):
        during_pulse.append(
            qt.propagator(
                axis * (np.pi/2) * (1 + rot) / pulse_width + Hsys,
                pulse_width
            )
        )

    free_evolution = qt.propagator(Hsys, delay)
    return [free_evolution * u for u in during_pulse]

In [None]:
# Names of the available actions. The position in this list is the integer
# pulse index used by pulse sequences ('d' is the no-pulse/delay action).
pulse_names = [
    'd', 'x', '-x', 'y', '-y', #'z', '-z'
]

In [None]:
def pulse_sequence_string(pulse_sequence):
    """Return the comma-separated pulse names for a sequence of pulse
    indices (positions in pulse_names).
    """
    names = (pulse_names[index] for index in pulse_sequence)
    return ','.join(names)

In [None]:
def get_pulse_sequence(string):
    """Returns a list of integers for the pulse sequence

    Inverse of pulse_sequence_string: each comma-separated name is mapped
    to its position in pulse_names. Whitespace around names is ignored,
    and an empty (or all-whitespace) string gives an empty sequence, so
    that the empty sequence round-trips.

    Raises:
        ValueError: If a name is not in pulse_names.
    """
    names = [name.strip() for name in string.split(',')]
    # ''.split(',') is [''], which would otherwise be looked up as a name
    if names == ['']:
        return []
    return [pulse_names.index(name) for name in names]

In [None]:
def get_propagator(pulse_sequence, pulses):
    """Return the total propagator of a pulse sequence.

    Args:
        pulse_sequence: Iterable of indices into `pulses`.
        pulses: List of propagators, one per action.
    """
    total = qt.identity(pulses[0].dims[0])
    for pulse_index in pulse_sequence:
        # each new pulse multiplies the accumulated propagator from the left
        total = pulses[pulse_index] * total
    return total

In [None]:
# Collective spin operators: the single-spin operator placed on each of
# the N spins in turn (identity elsewhere), summed over all spins.
X, Y, Z = (
    sum(
        qt.tensor(
            [qt.identity(2)]*i
            + [single_spin]
            + [qt.identity(2)]*(N-i-1)
        )
        for i in range(N)
    )
    for single_spin in (qt.spin_Jx(1/2), qt.spin_Jy(1/2), qt.spin_Jz(1/2))
)

In [None]:
# Ensemble of 5 randomly drawn system Hamiltonians, and for each one the
# list of action propagators (rot_error=0.01 is the scale of the random
# fractional rotation error drawn inside get_pulses).
Hsys_ensemble = [get_Hsys() for _ in range(5)]
pulses_ensemble = [
    get_pulses(H, X, Y, Z, pulse_width, delay, rot_error=0.01) for H in Hsys_ensemble
]

In [None]:
# Target propagator: the identity, with the same dimensions as the system Hamiltonian.
Utarget = qt.identity(Hsys_ensemble[0].dims[0])

## Pulse sequences

In [None]:
# Reference pulse sequences. Every entry is a pulse index, i.e. a position
# in pulse_names (0='d', 1='x', 2='-x', 3='y', 4='-y').
ideal6 = [3, 1, 1, 3, 2, 2]
yxx24 = [4, 1, 2, 3, 2, 2, 3, 2, 1, 4, 1, 1, 3, 2, 1, 4, 1, 1, 4, 1, 2, 3, 2, 2]
yxx48 = [
    3, 2, 2, 3, 2, 2, 4, 1, 1, 3, 2, 2, 4, 1, 1, 4, 1, 1, 3, 2, 2, 3, 2, 2,
    4, 1, 1, 3, 2, 2, 4, 1, 1, 4, 1, 1, 3, 2, 2, 4, 1, 1, 3, 2, 2, 4, 1, 1
]

# brute-force search
bf6 = [1, 1, 3, 1, 1, 3]
bf12 = [1, 1, 4, 1, 1, 4, 2, 2, 4, 2, 2, 4]
bfr12 = [1, 4, 4, 1, 4, 4, 1, 3, 3, 1, 3, 3]

## Average Hamiltonian theory

To keep track of the average Hamiltonian (to lowest order), I'm defining a frame matrix and applying rotation matrices to the frame matrix, then determining how $I_z$ transforms during the pulse sequence. The last row in the frame matrix corresponds to the current transformed value of $I_z$.

In [None]:
# Frame rotation matrix for each action, in the same order as pulse_names:
# identity for the delay, then 90-degree rotations about +x, -x, +y, -y.
# Rounding turns the near-zero floating-point entries into exact zeros.
rots = [np.eye(3)] + [
    np.round(Rotation.from_euler(axis, angle, degrees=True).as_matrix())
    for axis, angle in (
        ('x', 90), ('x', -90),
        ('y', 90), ('y', -90),
        # ('z', 90), ('z', -90),
    )
]

In [None]:
def get_rotation(pulse_sequence):
    """Return the 3x3 frame matrix after applying every pulse in the
    sequence, starting from the identity frame.
    """
    total = np.eye(3)
    for pulse_index in pulse_sequence:
        # each pulse's rotation multiplies the frame from the left
        total = rots[pulse_index] @ total
    return total

In [None]:
def is_cyclic(pulse_sequence):
    """Return whether the net frame rotation of the pulse sequence is
    exactly the identity.
    """
    identity = np.eye(3)
    net_rotation = get_rotation(pulse_sequence)
    return (net_rotation == identity).all()

In [None]:
def count_axes(pulse_sequence):
    """Tally which signed axis the last row of the frame matrix points
    along after each pulse of the sequence.

    Returns: A list of 6 counts. Entries 0-2 count the steps where the
        last row's nonzero entry is in column 0, 1 or 2 with positive
        sign; entries 3-5 are the same columns with negative sign.
    """
    tallies = [0] * 6
    orientation = np.eye(3)
    for pulse_index in pulse_sequence:
        orientation = rots[pulse_index] @ orientation
        bottom_row = orientation[-1, :]
        # the row holds a single +1 or -1; find its column
        (nonzero_columns,) = np.where(bottom_row)
        slot = nonzero_columns[0]
        if np.sum(bottom_row) < 0:
            slot += 3  # negative directions use the second half of the list
        tallies[slot] += 1
    return tallies

In [None]:
# indices 0,1,4,0,3,2 correspond to d,x,-y,d,y,-x in pulse_names
count_axes([0,1,4,0,3,2])

In [None]:
# axis tallies for a sequence given by its pulse-name string
count_axes(get_pulse_sequence('d,x,-y,d,y,-x'))  # WHH-4

In [None]:
# axis tallies for the 48-pulse yxx48 sequence
count_axes(yxx48)

In [None]:
def is_valid_dd(subsequence, sequence_length):
    """Check whether the pulse subsequence can still be completed to a
    sequence that decouples dipolar interactions (i.e. equal time spent
    on each axis).

    Args:
        subsequence: Pulse indices applied so far.
        sequence_length: Length of the full sequence being built.

    Returns: True if no axis (either sign) has been visited more than
        sequence_length / 3 times.
    """
    signed_tallies = np.array(count_axes(subsequence))
    # fold the positive and negative direction of each axis together
    per_axis = signed_tallies[:3] + signed_tallies[3:]
    return (per_axis <= sequence_length / 3).all()

In [None]:
def is_valid_time_suspension(subsequence, sequence_length):
    """Check whether the pulse subsequence can still be completed to a
    sequence that decouples all interactions (i.e. equal time spent on
    each ± axis).

    Args:
        subsequence: Pulse indices applied so far.
        sequence_length: Length of the full sequence being built.

    Returns: True if none of the six signed axes has been visited more
        than sequence_length / 6 times.
    """
    per_signed_axis_limit = sequence_length / 6
    signed_tallies = np.array(count_axes(subsequence))
    return np.all(signed_tallies <= per_signed_axis_limit)

In [None]:
def get_valid_time_suspension_pulses(subsequence, pulse_names, sequence_length):
    """Return the pulse indices that can be appended to `subsequence`
    without violating the time-suspension constraint.

    Args:
        subsequence: List of pulse indices applied so far.
        pulse_names: List of available pulses; only its length is used.
        sequence_length: Length of the full sequence being built.
    """
    return [
        candidate
        for candidate in range(len(pulse_names))
        if is_valid_time_suspension(subsequence + [candidate], sequence_length)
    ]

In [None]:
# pulses that may follow d,x,x in a 6-step time-suspension sequence
a = get_valid_time_suspension_pulses([0,1,1,], pulse_names, 6)

## Tree search

Define nodes that can be used for tree search, with additional constraints that the lowest-order average Hamiltonian matches the desired Hamiltonian.

In [None]:
# class Node(object):
#     """A node of the pulse sequence tree. Each node has a particular
#     sequence of pulses applied so far.
#     """
    
#     def __init__(
#             self,
#             propagators,
#             sequence=[],
#             depth=0):
#         """Create a new node with a given propagator
#         """
#         self.propagators = propagators  # list of Qobj propagators for each element of ensemble
#         self.sequence = sequence
#         self.depth = depth
        
#         self.children = {}
    
#     def has_children(self):
#         return len(self.children) > 0
    
#     def evaluate(
#             self,
#             Utarget,
#             pulse_names,
#             pulses_ensemble,
#             reward_dict,
#             print_depth=None,
#             max_depth=6):
#         """If the node isn't at max_depth, then create children and
#         evaluate each individually. If the node is at max_depth, then
#         calculate the reward and add the sequence/reward pair to
#         reward_dict.
        
#         Arguments:
#             pulses: An array of unitary operators representing all actions
#                 that can be applied to the system.
#         Returns: The maximum reward seen by the node or its children, and
#             the corresponding sequence.
            
#         """
#         if self.depth is not None:
#             if self.depth == print_depth:
#                 print(f'At depth {self.depth}, ({pulse_sequence_string(self.sequence)})')
#         if self.depth < max_depth:
#             max_reward = 0
#             max_reward_sequence = None
#             valid_pulses = get_valid_time_suspension_pulses(
#                 self.sequence,
#                 pulse_names,
#                 max_depth
#             )
#             for p in valid_pulses:
#                 sequence = self.sequence + [p]
#                 depth = self.depth + 1
#                 propagators = []
#                 for j in range(len(pulses_ensemble)):
#                     propagators.append(
#                         pulses_ensemble[j][p]
#                         * self.propagators[j])
#                 child = Node(
#                     propagators,
#                     sequence,
#                     depth=depth)
#                 r, s = child.evaluate(
#                     Utarget,
#                     pulse_names,
#                     pulses_ensemble,
#                     reward_dict,
#                     print_depth=print_depth,
#                     max_depth=max_depth
#                 )
#                 if r > max_reward:
#                     max_reward = r
#                     max_reward_sequence = s
#             return max_reward, max_reward_sequence
#         else:
#             fidelities = [np.clip(
#                 qt.metrics.average_gate_fidelity(p, Utarget),
#                 0, 1
#             ) for p in self.propagators]
#             fidelity = np.nanmean(fidelities)
#             reward = - np.log10(1.0 - fidelity + 1e-100)
#             sequence_str = ','.join([str(a) for a in self.sequence])
#             reward_dict[sequence_str] = reward
#             return reward, self.sequence

In [None]:
# root = Node([qt.identity(Hsys_ensemble[0].dims[0])] * len(Hsys_ensemble))
# reward_dict = {}
# root.evaluate(Utarget, pulse_names, pulses_ensemble, reward_dict, print_depth=1, max_depth=6)

In [None]:
# primitives = sorted(reward_dict.keys(), key=lambda x: reward_dict[x], reverse=True)
# primitives = [i.split(',') for i in primitives]
# primitives = [[int(j) for j in i] for i in primitives]

In [None]:
# rewards = sorted(reward_dict.values(), reverse=True)

For 12-pulse sequences, the search evaluated 16 branches at depth 4 in a minute, so about one every 4 seconds. At depth 4 there are $5^4 = 625$ branches, so a full run will take about $4 \times 625 = 2500$ seconds (roughly 41 minutes). Alternatively, you can generate random pulse sequences until there's one that has the proper lowest-order average and cyclic property.

In [None]:
# Rejection sampling: keep drawing random 12-pulse sequences until one
# satisfies both the time-suspension constraint and cyclicity.
rand = np.random.choice(len(pulse_names), 12)
while not (is_valid_time_suspension(rand, 12) and is_cyclic(rand)):
    rand = np.random.choice(len(pulse_names), 12)
count_axes(rand)

In [None]:
# show the randomly found sequence as pulse names
pulse_sequence_string(rand)

## Smarter search with MCTS

This section follows the [supplementary materials description under "Search"](https://science.sciencemag.org/content/sci/suppl/2018/12/05/362.6419.1140.DC1/aar6404-Silver-SM.pdf) to do rollouts and backpropagate information.


## TODO

- [ ] Vary the exploration parameter for calculating U
- [ ] Add neural networks for policy/value estimations
- [ ] Save propagators to speed up MCTS? Or too memory intensive?
- [ ] Are there smarter ways to constrain search? Need cyclic and equal time on each axis.

In [None]:
class MCTSNode(object):
    """A node of the pulse sequence tree. Each node has a particular
    sequence of pulses applied so far.

    The search statistics are not stored per node: every node holds
    references to the same N/W/Q dictionaries and looks its own entries
    up by its state string.
    """

    def __init__(
            self,
            propagators,
            N,
            W,
            Q,
            #P,  # prior probability pi(a|s)
            sequence=None,
            depth=0):
        """Create a new node with a given propagator

        The state-action statistics (N, W, Q) are stored as dictionaries
        of dictionaries, where the first key is the state, and the second
        key is the action.

        Args:
            propagators: List of propagators accumulated so far, one for
                each element of the ensemble.
            N: Visit counts
            W: Total action-value
            Q: Mean action-value
            sequence: List of pulse indices applied so far. Defaults to
                a new empty list.
            depth: Number of pulses applied so far.
        """
        self.propagators = propagators
        self.N = N
        self.W = W
        self.Q = Q
        # A None default avoids sharing one list object between all nodes
        # that are created without an explicit sequence.
        self.sequence = [] if sequence is None else sequence
        # the state is just a string-version of the pulse sequence
        self.state = self.sequence_string()
        self.depth = depth

        self.children = {}

    def sequence_string(self):
        """Return the pulse indices joined by commas."""
        return ','.join([str(a) for a in self.sequence])

    def _action_stats(self, table):
        """Return this state's action->value dictionary inside `table`,
        creating an empty one if the state has not been seen yet.
        """
        return table.setdefault(self.state, {})

    def get_N(self, action):
        """Get the state-action visit count, and
        initialize it to zero if the keys don't exist
        """
        return self._action_stats(self.N).setdefault(action, 0)

    def increment_N(self, action, increment=1):
        """Add `increment` to the state-action visit count."""
        old_value = self.get_N(action)
        self.N[self.state][action] = old_value + increment

    def get_W(self, action):
        """Get the total action-value, and
        initialize it to zero if the keys don't exist
        """
        return self._action_stats(self.W).setdefault(action, 0)

    def update_W(self, action, new_value):
        """Add `new_value` to the total action-value."""
        old_value = self.get_W(action)
        self.W[self.state][action] = old_value + new_value

    def get_Q(self, action):
        """Get the mean action-value, and
        initialize it to zero if the keys don't exist
        """
        return self._action_stats(self.Q).setdefault(action, 0)

    def update_Q(self, action):
        """Recompute the mean action-value as W / N.

        Raises:
            ZeroDivisionError: If the visit count for the action is
                still zero (call increment_N first).
        """
        W = self.get_W(action)
        N = self.get_N(action)
        self._action_stats(self.Q)[action] = 1.0 * W / N

    def get_U(self, action):
        r"""Exploration function, allegedly a variant of PUCT.

        U(s, a) = C(s) P(s, a) \sqrt{N(s)} / (1 + N(s, a))

        A small normally distributed noise term (scale 0.1) is added to
        the returned value, so repeated calls give different results.
        """
        # get_N also guarantees that self.N[self.state] exists below
        N = self.get_N(action)
        N_total = sum(self.N[self.state].values())
        C = np.log10((1 + N_total + 1e2)/1e2) + 1
        P = 1  # TODO will need to change this eventually
        U = (C * P * np.sqrt(N_total) / (1 + N))
        return U + np.random.normal(scale=.1)

    def _terminal_reward(self, Utarget, reward_dict=None):
        """Score the finished sequence and optionally record the reward
        in reward_dict under the sequence string.
        """
        fidelities = [np.clip(
            qt.metrics.average_gate_fidelity(U, Utarget),
            0, 1
        ) for U in self.propagators]
        fidelity = np.nanmean(fidelities)
        # the 1e-100 offset keeps log10 finite when the fidelity is exactly 1
        reward = - np.log10(1.0 - fidelity + 1e-100)
        if reward_dict is not None:
            reward_dict[self.sequence_string()] = reward
        return reward, self.sequence

    def rollout(
            self,
            Utarget,
            pulse_names,
            pulses_ensemble,
            reward_dict=None,
            max_depth=6):
        """Sample one rollout from this node down to max_depth and update
        the N, W, Q statistics along the path.

        At every step the next pulse is the valid pulse (according to
        get_valid_time_suspension_pulses) that maximizes Q + U.

        Args:
            Utarget: Target propagator used to compute the fidelity.
            pulse_names: List of available pulses.
            pulses_ensemble: For each ensemble element, the list of
                propagators for every pulse.
            reward_dict: Optional dictionary that receives the reward of
                the finished sequence, keyed by its sequence string.
            max_depth: Length of a complete pulse sequence.

        Returns: The reward and the pulse sequence that produced it. If
            no valid pulse is left before max_depth is reached, the
            reward is -1 and the sequence is the one applied so far.
        """
        if self.depth >= max_depth:
            return self._terminal_reward(Utarget, reward_dict)

        valid_pulses = get_valid_time_suspension_pulses(
            self.sequence,
            pulse_names,
            max_depth
        )
        if len(valid_pulses) == 0:
            # dead end: the constraint cannot be satisfied from here
            return -1, self.sequence
        # select a valid pulse that maximizes Q+U
        Q = np.array([self.get_Q(a) for a in valid_pulses])
        U = np.array([self.get_U(a) for a in valid_pulses])
        p = valid_pulses[np.argmax(Q + U)]
#         print('valid pulses:\t', valid_pulses)
#         print('N:\t', [self.get_N(a) for a in valid_pulses])
#         print('Q:\t', Q)
#         print('U:\t', U)
#         print('Q + U:\t', Q + U)
        # calculate propagators for child node
        propagators = [
            ensemble_pulses[p] * self.propagators[j]
            for j, ensemble_pulses in enumerate(pulses_ensemble)
        ]
        child = MCTSNode(
            propagators,
            self.N, self.W, self.Q,
            sequence=self.sequence + [p],
            depth=self.depth + 1)
        reward, sequence = child.rollout(
            Utarget,
            pulse_names,
            pulses_ensemble,
            reward_dict=reward_dict,
            max_depth=max_depth
        )
        # update statistics for N, W, Q
        self.increment_N(p)
        self.update_W(p, reward)
        self.update_Q(p)
        return reward, sequence

In [None]:
# Shared search statistics for MCTS, keyed by state string and then action.
# NOTE(review): this rebinds the module-level name N, which earlier in the
# notebook holds the number of spins and is read by get_Hsys and by the
# X/Y/Z operator construction. Re-running those cells after this one will
# fail until N is set back to the spin count.
N = {}
W = {}
Q = {}
reward_dict = {}

In [None]:
# Root of the search tree: no pulses applied yet, so every ensemble element
# starts from Utarget (the identity). One propagator is needed per ensemble
# element, so size the list from pulses_ensemble instead of hard-coding it.
root = MCTSNode([Utarget] * len(pulses_ensemble), N, W, Q)

In [None]:
# Run 500 rollouts of 48-pulse sequences from the root and print every
# 50th (reward, sequence) result. The loop index is actually used, so it
# gets a real name instead of `_` (which also shadows the notebook's
# last-output variable).
for rollout_index in range(500):
    output = root.rollout(Utarget, pulse_names, pulses_ensemble, None, max_depth=48)
    if rollout_index % 50 == 0:
        print(output)