# AlphaZero implementation for pulse sequence design
_Will Kaufman, December 2020_

[Dalgaard et al. (2020)](https://www.nature.com/articles/s41534-019-0241-0) applied this approach to constructing shaped pulses (as I understand it), but in principle it should be just as applicable to pulse sequence design, if not more so. The original [AlphaZero paper](https://science.sciencemag.org/content/362/6419/1140.full) is here.

The general idea behind AlphaZero (as I understand it) is to do a "smart" tree search that balances prior knowledge (the policy), curiosity about unexplored branches, and exploitation of high-value branches. My thought is that this can be improved with average Hamiltonian theory (AHT). By the end of the pulse sequence, the sequence must be cyclic (the overall frame transformation must be the identity), and equal time must be spent on each axis. These hard constraints should (hopefully) speed up the search.

In [None]:
import qutip as qt
import numpy as np
from scipy.spatial.transform import Rotation
import matplotlib.pyplot as plt

## Define the spin system

In [None]:
# Timing parameters; time is measured relative to the chemical-shift strength.
delay = 0.01  # free evolution after each pulse (see get_pulses)
pulse_width = 0.005  # duration of each finite-width pulse

N = 3  # number of spins in the simulated system
ensemble_size = 5  # number of random system Hamiltonians to average over

In [None]:
def _embed_site_ops(site_ops, n_spins):
    """Tensor single-spin operators into the ``n_spins``-spin Hilbert space.

    ``site_ops`` maps a site index to the 2x2 operator acting on that spin;
    every other site gets the 2x2 identity.
    """
    return qt.tensor(
        [site_ops.get(k, qt.identity(2)) for k in range(n_spins)]
    )


def get_Hsys(dipolar_strength=1e-2, n_spins=None):
    """Draw a random system Hamiltonian: chemical shifts plus dipolar coupling.

    Args:
        dipolar_strength: Standard deviation of the random dipolar coupling
            constants (before the 2*pi factor). Chemical shifts are drawn with
            unit standard deviation, so this sets the relative scale.
        n_spins: Number of spins. Defaults to the module-level ``N``, read at
            call time (same as before this parameter existed).

    Returns:
        Qobj ``Hcs + Hdip`` where, with Pauli operators X, Y, Z,
        ``Hcs = sum_i w_i Z_i`` and
        ``Hdip = sum_{i<j} d_ij (2 Z_i Z_j - X_i X_j - Y_i Y_j)``.
    """
    if n_spins is None:
        n_spins = N
    # Chemical shifts are drawn first, then the couplings, so the random
    # stream is consumed in the same order as the original implementation.
    chemical_shifts = 2*np.pi * np.random.normal(scale=1, size=(n_spins,))
    Hcs = sum(
        [_embed_site_ops({i: chemical_shifts[i] * qt.sigmaz()}, n_spins)
         for i in range(n_spins)]
    )
    # dipolar interactions; only the upper triangle (i < j) is used
    dipolar_matrix = 2*np.pi * np.random.normal(
        scale=dipolar_strength, size=(n_spins, n_spins))
    Hdip = sum([
        dipolar_matrix[i, j] * (
            2 * _embed_site_ops({i: qt.sigmaz(), j: qt.sigmaz()}, n_spins)
            - _embed_site_ops({i: qt.sigmax(), j: qt.sigmax()}, n_spins)
            - _embed_site_ops({i: qt.sigmay(), j: qt.sigmay()}, n_spins)
        )
        for i in range(n_spins) for j in range(i + 1, n_spins)
    ])
    return Hcs + Hdip

In [None]:
# Collective spin operators: for each of spin-1/2 Jx, Jy, Jz, sum over spins i
# of that operator acting on spin i with the identity on all other spins.
X, Y, Z = (
    sum([
        qt.tensor(
            [qt.identity(2)] * i + [single] + [qt.identity(2)] * (N - i - 1)
        )
        for i in range(N)
    ])
    for single in (qt.spin_Jx(1/2), qt.spin_Jy(1/2), qt.spin_Jz(1/2))
)

In [None]:
def get_pulses(Hsys, X, Y, Z, pulse_width, delay, rot_error=0):
    """Build one propagator per pulse, in ``pulse_names`` order (d, x, -x, y, -y).

    Each entry is "pulse, then free evolution": the propagator for the pulse
    Hamiltonian (control + ``Hsys``) over ``pulse_width``, followed by the
    ``Hsys`` propagator over ``delay``. A single rotation error ``rot``, drawn
    once from a normal distribution with std ``rot_error``, scales the nominal
    pi/2 angle of every pulse by ``1 + rot``.

    Args:
        Hsys: System Hamiltonian (Qobj).
        X, Y: Collective spin operators used as control Hamiltonians.
        Z: Currently unused (the z pulses are disabled).
        pulse_width: Pulse duration.
        delay: Free-evolution time after each pulse.
        rot_error: Standard deviation of the relative rotation error.

    Returns:
        list of Qobj propagators.
    """
    rot = np.random.normal(scale=rot_error)

    def drive(axis):
        # control Hamiltonian giving a (1 + rot) * pi/2 rotation in pulse_width
        return axis * (np.pi/2) * (1 + rot) / pulse_width

    # 'd' is free evolution for pulse_width; the others add a control field.
    # z pulses (drive(Z), drive(-Z)) are disabled along with 'z'/'-z' names.
    hamiltonians = [Hsys] + [drive(axis) + Hsys for axis in (X, -X, Y, -Y)]
    after_pulse = qt.propagator(Hsys, delay)
    return [after_pulse * qt.propagator(H, pulse_width) for H in hamiltonians]

In [None]:
# Labels for the pulses returned by get_pulses, indexed by pulse number:
# 'd' is free evolution only; '±x' / '±y' are pi/2 pulses about ±x / ±y.
pulse_names = [
    'd',
    'x', '-x',
    'y', '-y',
    # 'z', '-z',  (disabled, like the z pulses in get_pulses)
]

In [None]:
def pulse_sequence_string(pulse_sequence, names=None):
    """Return a comma-separated string of pulse names for a pulse sequence.

    Args:
        pulse_sequence: Iterable of integer pulse indices.
        names: Optional list of pulse labels; defaults to ``pulse_names``.

    Returns:
        str, e.g. ``'x,x,y'`` for ``[1, 1, 3]`` with the default labels, and
        ``''`` for an empty sequence.
    """
    if names is None:
        names = pulse_names
    return ','.join(names[i] for i in pulse_sequence)

In [None]:
def get_pulse_sequence(string, names=None):
    """Parse a comma-separated pulse string into a list of pulse indices.

    Inverse of ``pulse_sequence_string``: whitespace around each name is
    ignored, and an empty (or all-whitespace) string gives ``[]``.

    Args:
        string: Pulse names separated by commas, e.g. ``'d,x,-y'``.
        names: Optional list of pulse labels; defaults to ``pulse_names``.

    Returns:
        list of int indices into ``names`` (first match for duplicates).

    Raises:
        ValueError: if a token is not one of ``names``.
    """
    if names is None:
        names = pulse_names
    if not string.strip():
        return []
    pulse_sequence = []
    for token in string.split(','):
        name = token.strip()
        try:
            pulse_sequence.append(names.index(name))
        except ValueError:
            raise ValueError(
                f'unknown pulse {name!r}; expected one of {list(names)}'
            ) from None
    return pulse_sequence

In [None]:
def get_propagator(pulse_sequence, pulses):
    """Return the total propagator for applying ``pulse_sequence`` in order.

    Args:
        pulse_sequence: Iterable of indices into ``pulses``.
        pulses: List of Qobj propagators; ``pulses[0]`` sets the dimensions
            of the starting identity.

    Returns:
        Qobj ``U_n ... U_2 U_1`` (identity for an empty sequence).
    """
    dims = pulses[0].dims[0]
    total = qt.identity(dims)
    for index in pulse_sequence:
        step = pulses[index]
        # later pulses act after earlier ones, so they multiply on the left
        total = step * total
    return total

In [None]:
# One random system Hamiltonian per ensemble member (each get_Hsys call draws
# new chemical shifts and dipolar couplings).
Hsys_ensemble = [get_Hsys() for _ in range(ensemble_size)]
# Matching pulse propagators. get_pulses draws one rotation error (std 0.01)
# per call, so each ensemble member gets its own pulse-angle error.
pulses_ensemble = [
    get_pulses(H, X, Y, Z, pulse_width, delay, rot_error=0.01) for H in Hsys_ensemble
]

In [None]:
# Target propagator: the identity on the spins' Hilbert space (dims taken from
# the first ensemble Hamiltonian), i.e. the sequence should have no net effect.
Utarget = qt.identity(Hsys_ensemble[0].dims[0])

## Pulse sequences

In [None]:
# Known sequences, written as 3-pulse groups for readability.
# Indices refer to pulse_names: 0='d', 1='x', 2='-x', 3='y', 4='-y'.
ideal6 = [3, 1, 1,
          3, 2, 2]
yxx24 = [
    4, 1, 2,  3, 2, 2,  3, 2, 1,  4, 1, 1,
    3, 2, 1,  4, 1, 1,  4, 1, 2,  3, 2, 2,
]
yxx48 = [
    3, 2, 2,  3, 2, 2,  4, 1, 1,  3, 2, 2,
    4, 1, 1,  4, 1, 1,  3, 2, 2,  3, 2, 2,
    4, 1, 1,  3, 2, 2,  4, 1, 1,  4, 1, 1,
    3, 2, 2,  4, 1, 1,  3, 2, 2,  4, 1, 1,
]

# found by brute-force search
bf6 = [1, 1, 3,
       1, 1, 3]
bf12 = [1, 1, 4,  1, 1, 4,
        2, 2, 4,  2, 2, 4]
bfr12 = [1, 4, 4,  1, 4, 4,
         1, 3, 3,  1, 3, 3]

## Average Hamiltonian theory

To keep track of the average Hamiltonian (to lowest order), I'm defining a frame matrix and applying rotation matrices to the frame matrix, then determining how $I_z$ transforms during the pulse sequence. The last row in the frame matrix corresponds to the current transformed value of $I_z$.

In [None]:
# Frame rotation for each pulse index as a 3x3 matrix: the identity for 'd',
# then ±90° about x and ±90° about y, rounded to exact 0/±1 entries.
rotations = [np.eye(3)] + [
    np.round(Rotation.from_euler(axis, angle, degrees=True).as_matrix())
    for axis, angle in (('x', 90), ('x', -90), ('y', 90), ('y', -90))
    # ('z', 90), ('z', -90) are disabled along with the z pulses
]

In [None]:
def get_rotation(pulse_sequence):
    """Return the net 3x3 frame rotation after applying ``pulse_sequence``.

    Each pulse's matrix from ``rotations`` left-multiplies the running frame,
    so the result is ``R_n @ ... @ R_1`` (the identity for no pulses).
    """
    net = np.eye(3)
    for index in pulse_sequence:
        step = rotations[index]
        net = step @ net
    return net

In [None]:
def is_cyclic(pulse_sequence):
    """Return whether the sequence's net frame rotation is exactly the identity.

    Exact comparison is safe because ``rotations`` holds rounded 0/±1
    matrices, whose products stay exactly representable.
    """
    return np.all(get_rotation(pulse_sequence) == np.eye(3))

In [None]:
def count_axes(pulse_sequence):
    """Count, pulse by pulse, which signed axis the transformed Iz lies along.

    After each pulse the frame is updated and its last row (the transformed
    Iz) is inspected: the column of its nonzero entry gives the axis, and the
    sign of the row sum gives the direction.

    Returns:
        list of 6 ints: counts for [+x, +y, +z, -x, -y, -z].
    """
    counts = [0] * 6
    frame = np.eye(3)
    for index in pulse_sequence:
        frame = rotations[index] @ frame
        iz_row = frame[-1, :]
        column = np.where(iz_row)[0][0]
        offset = 3 if np.sum(iz_row) < 0 else 0
        counts[column + offset] += 1
    return counts

In [None]:
# Sanity check on WHH-4 (d,x,-y,d,y,-x). Counts are [+x, +y, +z, -x, -y, -z];
# working it through by hand gives [2, 2, 2, 0, 0, 0], i.e. equal time on x, y, z.
count_axes(get_pulse_sequence('d,x,-y,d,y,-x'))  # WHH-4

In [None]:
# Axis counts ([+x, +y, +z, -x, -y, -z]) for the 48-pulse yxx48 sequence.
count_axes(yxx48)

In [None]:
def is_valid_dd(subsequence, sequence_length):
    """Check the dipolar-decoupling axis budget for a partial sequence.

    Decoupling dipolar interactions (to lowest order) requires equal time on
    the x, y and z axes. This returns True while no axis (counting both signs
    together) has been used more than ``sequence_length / 3`` times.
    """
    counts = count_axes(subsequence)
    per_axis = np.array([counts[k] + counts[k + 3] for k in range(3)])
    return (per_axis <= sequence_length / 3).all()

In [None]:
def is_valid_time_suspension(subsequence, sequence_length):
    """Check the time-suspension axis budget for a partial sequence.

    Decoupling all interactions requires equal time on each signed axis
    (±x, ±y, ±z). This returns True while no signed axis has been used more
    than ``sequence_length / 6`` times.
    """
    return (np.array(count_axes(subsequence)) <= sequence_length / 6).all()

In [None]:
def get_valid_time_suspension_pulses(subsequence, pulse_names, sequence_length):
    """Return the pulses that can be appended without breaking time suspension.

    Args:
        subsequence: Current list of pulse indices (not modified).
        pulse_names: Only its length is used, to enumerate candidate pulses.
        sequence_length: Target length of the complete sequence.

    Returns:
        list of int pulse indices, in increasing order.
    """
    return [
        candidate for candidate in range(len(pulse_names))
        if is_valid_time_suspension(subsequence + [candidate], sequence_length)
    ]

In [None]:
# Which pulses may follow d,x,x in a 6-pulse time-suspension sequence?
# (By hand: [1, 3, 4], i.e. x, y and -y.)
get_valid_time_suspension_pulses([0,1,1,], pulse_names, 6)

## Tree search

Define nodes that can be used for tree search, with additional constraints that the lowest-order average Hamiltonian matches the desired Hamiltonian.

(deleted code that implemented tree search with constraints, see GitHub repo commits on 12/8 for code)

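For 12-pulse sequences, I calculated 16 branches at depth 4 in a minute, i.e. about one branch every 4 seconds. At depth 4 there are $5^4 = 625$ branches, so a full run would take about $625 \times 4\ \text{s} = 2500\ \text{s} \approx 42$ minutes. (The earlier estimate here was 41 hours, which corresponds to 4 *minutes* per branch; double-check which rate is right.) Alternatively, you can generate random pulse sequences until one has the proper lowest-order average Hamiltonian and is cyclic.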

## Smarter search with MCTS

Following the [supplementary materials description under "Search"](https://science.sciencemag.org/content/sci/suppl/2018/12/05/362.6419.1140.DC1/aar6404-Silver-SM.pdf) to do rollouts and backpropagate information.

In [None]:
class Config:
    """Hyperparameters for the AlphaZero-style search and training."""

    def __init__(self):
        # --- self-"play" (sequence generation) ---
        self.num_actors = 1
        self.num_sampling_moves = 30
        self.max_moves = 48

        # --- Monte Carlo tree search ---
        self.num_simulations = 100  # rollouts per run_mcts call

        # --- exploration noise mixed into the root's priors ---
        self.root_dirichlet_alpha = 0.3
        self.root_exploration_fraction = 0.25

        # --- constants in the UCB formula (see ucb_score) ---
        self.pb_c_base = 100.0
        self.pb_c_init = 1.25

        # --- network training ---
        self.training_steps = 700_000
        self.checkpoint_interval = 1_000
        self.window_size = 1_000_000
        self.batch_size = 4096
        # TODO also weight_decay (1e-4), momentum (.9), learning rate schedule

In [None]:
class Node(object):
    """A node of the pulse sequence tree.

    Each node corresponds to the pulses applied so far (``sequence``). It holds
    the resulting propagator for every ensemble member, plus the MCTS
    statistics accumulated for it.
    """

    def __init__(
            self,
            prior,
            sequence,
            propagators,
            ):
        """Create a node at a given point in the pulse sequence.

        Args:
            prior: Prior probability of selecting this node.
            sequence: Sequence of integers representing the pulses applied so
                far. A copy is stored, so later changes to the caller's list
                do not silently change this node.
            propagators: List of Qobj propagators, one per ensemble member.
        """
        self.prior = prior
        self.sequence = list(sequence)
        # list of Qobj propagators for each element of ensemble
        self.propagators = propagators
        self.depth = len(self.sequence)
        # children keyed by the pulse index that leads to them
        self.children = {}
        self.max_value = -1  # maximum backpropagated value seen so far
        self.visit_count = 0
        self.total_value = 0

    def __repr__(self):
        return (f'{type(self).__name__}(sequence={self.sequence}, '
                f'visits={self.visit_count}, value={self.value():.3g}, '
                f'max_value={self.max_value:.3g})')

    def value(self):
        """Return the mean backpropagated value (0 if never visited)."""
        if self.visit_count > 0:
            return self.total_value / self.visit_count
        else:
            return 0

    def reward(self, Utarget):
        """Return ``-log10(1 - F)``, where F is the mean average gate fidelity
        of the ensemble's propagators against ``Utarget`` (higher is better).

        Each fidelity is clipped to [0, 1], NaNs are ignored in the mean, and
        the 1e-300 offset avoids ``log10(0)`` when F == 1.
        """
        fidelities = [
            np.clip(qt.metrics.average_gate_fidelity(U, Utarget), 0, 1)
            for U in self.propagators
        ]
        fidelity = np.nanmean(fidelities)
        reward = -1.0 * np.log10(1 - fidelity + 1e-300)
        return reward

    def sequence_string(self):
        """Return the pulse indices joined by commas, e.g. ``'1,1,3'``."""
        return ','.join([str(a) for a in self.sequence])

    def has_children(self):
        """Return whether this node has been expanded with any children."""
        return len(self.children) > 0

In [None]:
def run_mcts(config, propagators, sequence, sequence_length, Utarget, network=None):
    """Run MCTS from the node given by ``sequence`` and choose the next pulse.

    Each simulation walks down from the root with ``select_child``. It expands
    every node it passes through (``evaluate`` creates the children) until it
    reaches a node with no valid children. That leaf's value is then
    backpropagated along the path. Finally a pulse is sampled from the root's
    children in proportion to their visit counts.

    In AlphaZero terms, the "game state" is the pulse sequence information
    (sequence, propagators).

    Args:
        config: Config providing num_simulations and exploration settings.
        propagators: List of Qobj propagators at the root.
        sequence: List of ints, represents pulse sequence. It is copied, so
            the caller may keep modifying its own list.
        sequence_length: Maximum length of pulse sequence.
        Utarget: Target propagator used for the terminal reward.
        network: Placeholder for a policy/value network; forwarded to
            ``evaluate``.

    Returns:
        (pulse, root): the selected pulse index and the root Node.
    """
    root = Node(0, list(sequence), propagators)
    root_value = evaluate(root, sequence_length, Utarget, network=network)
    add_exploration_noise(config, root)

    for _sim in range(config.num_simulations):
        node = root
        search_path = [node]
        # If the root has no children, it is itself the leaf.
        value = root_value

        while node.has_children():
            _pulse, node = select_child(config, node)
            search_path.append(node)
            # Expands the node (makes children nodes). When the loop stops,
            # this is the leaf's value, so the leaf's (expensive) reward is not
            # computed a second time.
            # TODO with a NN, only expand the leaf / available nodes
            value = evaluate(node, sequence_length, Utarget, network=network)

        backpropagate(search_path, value)

    return select_action(config, root), root

In [None]:
def evaluate(node, sequence_length, Utarget, network=None):
    """Return the node's value and expand it with every valid child pulse.

    The value is ``node.reward(Utarget)`` once the sequence has reached
    ``sequence_length`` pulses, and 0 otherwise (placeholder for a network
    prediction). Children are created for each pulse allowed by the
    time-suspension constraint, with a uniform prior. Children that already
    exist are left untouched. For each ensemble member ``s``, a child's
    propagators are ``pulses_ensemble[s][pulse] * U``.
    """
    is_complete = len(node.sequence) == sequence_length
    value = node.reward(Utarget) if is_complete else 0  # TODO NN prediction

    allowed = get_valid_time_suspension_pulses(
        node.sequence, pulse_names, sequence_length)
    # uniform policy over the allowed pulses (TODO replace with NN prediction)
    priors = np.ones((len(allowed),)) / len(allowed)
    for prior, pulse in zip(priors, allowed):
        if pulse in node.children:
            continue
        node.children[pulse] = Node(
            prior,
            node.sequence + [pulse],
            [pulses_ensemble[member][pulse] * U
             for member, U in enumerate(node.propagators)],
        )
    return value

In [None]:
def add_exploration_noise(config, node):
    """Mix Dirichlet noise into the priors of ``node``'s children.

    Each child's prior becomes ``prior * (1 - frac) + noise * frac``, with
    ``frac = config.root_exploration_fraction``. The noise is a sample from a
    symmetric Dirichlet(``config.root_dirichlet_alpha``) distribution, built
    by normalizing independent Gamma(alpha, 1) draws. The normalization is
    what makes the noise Dirichlet; without it the noisy priors would not sum
    to 1.

    Does nothing if ``node`` has no children.
    """
    pulses = list(node.children.keys())
    if not pulses:
        return
    noise = np.random.gamma(config.root_dirichlet_alpha, 1, len(pulses))
    total = noise.sum()
    if total > 0:
        noise = noise / total
    else:
        # every draw underflowed to 0 (possible for tiny alpha): use uniform
        noise = np.full(len(pulses), 1 / len(pulses))
    frac = config.root_exploration_fraction
    for p, n in zip(pulses, noise):
        node.children[p].prior = node.children[p].prior * (1 - frac) + n * frac

In [None]:
def select_child(config, node):
    """Pick the child of ``node`` with the highest UCB score.

    Ties in score go to the larger pulse index.

    Returns:
        (pulse, child): the chosen pulse index and its Node.
    """
    best_pulse = max(
        node.children,
        key=lambda pulse: (ucb_score(config, node, node.children[pulse]), pulse),
    )
    return best_pulse, node.children[best_pulse]

In [None]:
def ucb_score(config, parent, child):
    """Return the UCB selection score of ``child`` under ``parent``.

    score = value(child)
            + c * prior(child) * sqrt(N(parent)) / (N(child) + 1),
    with c = log10((N(parent) + pb_c_base + 1) / pb_c_base) + pb_c_init,
    where N(.) is a visit count.

    NOTE(review): the AlphaZero pseudocode uses the natural log for c; log10
    shrinks the visit-dependent growth of the exploration weight. Confirm
    this is intended.
    """
    exploration = (
        np.log10((parent.visit_count + config.pb_c_base + 1) / config.pb_c_base)
        + config.pb_c_init
    )
    exploration *= np.sqrt(parent.visit_count) / (child.visit_count + 1)
    return exploration * child.prior + child.value()

In [None]:
def backpropagate(search_path, value):
    """Credit ``value`` to every node on ``search_path``.

    For each node this adds ``value`` to ``total_value``, raises
    ``max_value`` if ``value`` exceeds it, and increments ``visit_count``.
    """
    for visited in search_path:
        visited.total_value += value
        visited.max_value = max(visited.max_value, value)
        visited.visit_count += 1

In [None]:
def select_action(config, root):
    """Sample a pulse from ``root``'s children in proportion to visit counts.

    If no child has been visited (e.g. ``num_simulations == 0``), the pulse is
    sampled uniformly instead of dividing by zero.

    Args:
        config: Currently unused (kept for the AlphaZero-style interface).
        root: Node whose ``children`` dict is keyed by pulse index.

    Returns:
        The chosen pulse index.

    Raises:
        ValueError: if ``root`` has no children (no valid next pulse).
    """
    pulses = list(root.children.keys())
    if not pulses:
        raise ValueError('root has no children: no valid pulse can be appended')
    visit_counts = np.array(
        [root.children[p].visit_count for p in pulses], dtype=float)
    total = visit_counts.sum()
    if total > 0:
        probabilities = visit_counts / total
    else:
        probabilities = np.full(len(pulses), 1 / len(pulses))
    return np.random.choice(pulses, p=probabilities)

In [None]:
# Default hyperparameters (see Config).
config = Config()

In [None]:
# Use more rollouts than the default of 100 for this single-step demo.
config.num_simulations = 1000

In [None]:
# One MCTS step from the empty sequence, targeting 6-pulse sequences. Every
# ensemble member starts from the identity; sizing the list from
# pulses_ensemble (instead of a hard-coded 5) keeps it in sync with evaluate.
pulse, root = run_mcts(config, [Utarget] * len(pulses_ensemble), [], 6, Utarget)

In [None]:
# Root child whose subtree produced the best value seen. (This previously
# referenced `node`, which is not defined until the next cells.)
max(root.children.values(), key=lambda x: x.max_value)

In [None]:
# Scratch check: tuples compare element-wise, so max() picks the tuple with
# the largest first element. select_child's tuple-based max relies on this.
max([(1,'a'), (2, 'b'), (3, 'c')])

In [None]:
# Greedily follow the child with the highest max_value down to a leaf. This
# should reach a leaf that achieved the highest value seen in the search.
node = root
while node.has_children():
    node = max(node.children.values(), key=lambda x: x.max_value)

In [None]:
# Pulse indices of that leaf.
node.sequence

In [None]:
# Its best backpropagated value (terminal rewards are -log10(1 - fidelity)).
node.max_value

In [None]:
# Visit counts of the root's children (the distribution select_action samples).
[root.children[p].visit_count for p in root.children]

In [None]:
# Mean value of the root child reached by pulse 3 ('y').
root.children[3].value()

In [None]:
def make_sequence(config, sequence_length, network=None):
    """Build a pulse sequence one pulse at a time with MCTS.

    Starting with no pulses, this repeatedly runs ``run_mcts`` from the current
    sequence and appends the pulse it selects. It then continues from that
    child's propagators until ``sequence_length`` pulses are placed.

    Args:
        config: Config controlling the search.
        sequence_length: Number of pulses in the final sequence.
        network: Optional policy/value network, forwarded to ``run_mcts``.

    Returns:
        (sequence, search_statistics): the list of chosen pulse indices, and
        one entry per step of the form
        ``(sequence_before_step, [(pulse, visit_count), ...])``.
    """
    sequence = []
    # every ensemble member starts from the identity
    propagators = [Utarget] * len(pulses_ensemble)
    search_statistics = []
    while len(sequence) < sequence_length:
        pulse, root = run_mcts(config, propagators, sequence, sequence_length,
                               Utarget, network=network)
        print(f'applying pulse {pulse}')
        # Snapshot before appending. root.sequence may be the same list object
        # as `sequence`, and without a copy every entry would end up showing
        # the final sequence.
        search_statistics.append(
            (list(root.sequence),
             [(p, root.children[p].visit_count) for p in root.children])
        )
        sequence.append(pulse)
        propagators = root.children[pulse].propagators
    return sequence, search_statistics

In [None]:
# Rollouts per step when building a full sequence.
config.num_simulations = 1000

In [None]:
# Build a 24-pulse sequence step by step (prints each chosen pulse).
sequence, search_statistics = make_sequence(config, 24)

In [None]:
# The resulting list of pulse indices.
sequence

## Old code (eventually delete)

In [None]:
# Node's signature is Node(prior, sequence, propagators). The old call here
# passed four arguments (including a frame matrix) in a different order and
# raised a TypeError.
root = Node(1, [], [Utarget] * len(pulses_ensemble))

In [None]:
# Old-code search tables for the deleted MCTSNode (presumably visit counts N,
# total values W and mean values Q; the readout cell below indexes Q by the
# comma-joined sequence, then by pulse).
# NOTE(review): rebinding `N` here overwrites the module-level spin count N
# used by get_Hsys and the X/Y/Z construction. Re-run the setup cells before
# using those again.
N = {}
W = {}
Q = {}
reward_dict = {}

In [None]:
# NOTE(review): MCTSNode is not defined in this notebook (it belonged to the
# deleted tree-search code), so this cell raises NameError as-is.
root = MCTSNode([Utarget] * 5, N, W, Q)

In [None]:
# Old rollout loop for MCTSNode (depth-12 sequences); prints every 100th result.
for _ in range(1000):
    output = root.rollout(Utarget, pulse_names, pulses_ensemble, None, max_depth=12)
    if _ % 100 == 0:
        print(output)

In [None]:
# Old greedy readout: at each of 12 steps, pick the pulse with the highest Q
# for the sequence so far (Q is keyed by the comma-joined sequence string).
sequence = []
for i in range(12):
    string = ','.join([str(p) for p in sequence])
    p = max(Q[string].keys(), key=lambda x: Q[string][x])
    sequence.append(p)
print(sequence)

In [None]:
# Axis counts ([+x, +y, +z, -x, -y, -z]) of the sequence read out above.
count_axes(sequence)