# Reproduction of "First-principles analysis of cross-resonance gate operation"

Malekakhlagh et al., Phys. Rev. A **102**, 042605 (2020)

Note that the drive frequency is set to the undressed (bare) target-qubit frequency rather than the dressed one. Consequently, the calculated results differ at $\mathcal{O}(J^2)$ from what an ideal experiment would observe, since experiments typically drive at the dressed target frequency.

In [None]:
import sys
import os
sys.path.append(os.path.dirname(os.getcwd()))
import time
import pickle
import copy
import logging
from collections import defaultdict
from multiprocessing import Pool
from sympy import Add, Expr, I, Mul, Pow, S, Symbol, diff, pi, simplify
from sympy.physics.quantum import Commutator, IdentityOperator, HermitianOperator, TensorProduct
from symqudit.two_transmon_hamiltonian import (TwoTransmonHamiltonian, sort_block_diagonal,
                                               to_dict, from_dict, dict_product)
from symqudit.schrieffer_wolff_expansion import SWExpansion, integrate_exp_term, integrate_expr
from symqudit.common import ketbra, get_expr_at_order, organize_by_denom

# Show INFO-level timing messages emitted by the helpers in this notebook.
logging.basicConfig(level=logging.INFO)
# No name given -> the root logger; all timing messages are logged through it.
logger = logging.getLogger(name=None)

In [None]:
# Symbolic real time variable; static vs. time-dependent terms are
# distinguished by differentiating with respect to it.
t = Symbol('t', real=True)
# Identity used as the second TensorProduct factor when a dict key has no
# second-transmon component (ket[1] is None).
Id = IdentityOperator()

def expand_opprod(lhs, rhs, blkdiag_only=False) -> dict[tuple, dict[tuple, Expr]]:
    """Expand the operator product ``lhs * rhs`` into nested ket/bra dict form.

    Both operands are looked up in the module-level ``op_dicts`` registry and
    multiplied with ``dict_product`` (coefficients left unexpanded).

    Args:
        lhs: Key of the left operand in ``op_dicts``.
        rhs: Key of the right operand in ``op_dicts``.
        blkdiag_only: Forwarded to ``dict_product`` (presumably restricting it
            to block-diagonal entries). Additionally, each coefficient is
            reduced to its time-independent terms, i.e. the additive terms
            whose derivative w.r.t. ``t`` is exactly ``S.Zero``; entries (and
            whole kets) left with no such term are dropped.

    Returns:
        Nested dict ``{ket: {bra: coeff}}`` describing the product.
    """
    start = time.perf_counter()
    op_prod = dict_product(op_dicts[lhs], op_dicts[rhs],
                           blkdiag_only=blkdiag_only, expand=False)

    if blkdiag_only:
        filtered = {}
        for ket, bra_dict in op_prod.items():
            static_row = {}
            for bra, coeff in bra_dict.items():
                # Keep only the terms with an identically vanishing time derivative.
                static_terms = [term for term in Add.make_args(coeff.expand())
                                if diff(term, t) is S.Zero]
                if static_terms:
                    static_row[bra] = Add(*static_terms)
            if static_row:
                filtered[ket] = static_row
        op_prod = filtered

    # perf_counter is monotonic and intended for measuring durations.
    logger.info('dict_product(%s, %s) in %f seconds', lhs, rhs,
                time.perf_counter() - start)

    return op_prod


def calculate_commutator(fore, back):
    """Return the entrywise difference ``fore - back`` of two nested operator dicts.

    Used to form a commutator ``[A, B]`` from the expanded products
    ``fore = A*B`` and ``back = B*A``. Both arguments map
    ``ket -> {bra: coeff}``; entries present only in ``back`` appear negated.
    Neither argument is modified.

    The result contains only picklable containers (no lambda-based
    ``defaultdict``), so it can be returned from ``multiprocessing`` workers.

    Args:
        fore: Nested dict of the first product.
        back: Nested dict of the product to subtract.

    Returns:
        A new nested dict holding ``fore - back``.
    """
    # Two-level shallow copy: coefficients are immutable, so only the dict
    # containers need copying (a deepcopy would also walk every expression).
    # copy.copy keeps the original container types.
    result = copy.copy(fore)
    for ket in result:
        result[ket] = copy.copy(result[ket])

    for ket, bra_dict in back.items():
        row = result.setdefault(ket, {})
        for bra, coeff in bra_dict.items():
            # Explicit membership test: the inner dicts may be plain dicts,
            # where `row[bra] -= coeff` on a missing key raises KeyError.
            if bra in row:
                row[bra] -= coeff
            else:
                row[bra] = -coeff

    return result


def expand_expr(expr):
    """Return ``expr.expand()``.

    Kept as a named module-level function (not a lambda) so it can be passed
    to ``multiprocessing.Pool.map``, which pickles the callable.
    """
    expanded = expr.expand()
    return expanded


def sort_static(coeff):
    """Split a coefficient into its time-independent and time-dependent parts.

    ``coeff`` is expanded, and each additive term is classified by whether its
    derivative with respect to ``t`` is exactly ``S.Zero``.

    Returns:
        ``(static, dynamic)``. Each element is the sum of the corresponding
        terms, or ``None`` when there is no such term.
    """
    static_parts, dynamic_parts = [], []
    for term in Add.make_args(coeff.expand()):
        bucket = static_parts if diff(term, t) is S.Zero else dynamic_parts
        bucket.append(term)

    static_sum = Add(*static_parts) if static_parts else None
    dynamic_sum = Add(*dynamic_parts) if dynamic_parts else None
    return static_sum, dynamic_sum

In [None]:
def from_dict_parallel(terms):
    """Convert a nested ``{ket: {bra: coeff}}`` dict back into a sympy expression.

    Each key is a pair ``(first, second)``. When ``ket[1]`` is None, the
    operator is ``TensorProduct(|ket[0]><bra[0]|, Id)``; otherwise it is
    ``TensorProduct(|ket[0]><bra[0]|, |ket[1]><bra[1]|)``. Coefficients are
    expanded in a process pool before being multiplied onto their operators.

    Returns:
        The sum of ``coeff * operator`` over all entries (0 for an empty dict).
    """
    entries = [(ket, bra, coeff)
               for ket, bra_dict in terms.items()
               for bra, coeff in bra_dict.items()]

    operators = [
        TensorProduct(ketbra(ket[0], bra[0]),
                      Id if ket[1] is None else ketbra(ket[1], bra[1]))
        for ket, bra, _ in entries
    ]

    with Pool() as pool:
        expanded = pool.map(expand_expr, [coeff for _, _, coeff in entries])

    return Add(*(coeff * op for coeff, op in zip(expanded, operators)))


def sort_block_diagonal_parallel(all_terms):
    """Split nested ket/bra terms into a static block-diagonal part and the rest.

    An entry whose ket and bra share the same first index (``ket[0] == bra[0]``)
    is a block-diagonal candidate. Its coefficient is split with
    ``sort_static`` in a process pool. The static part goes to ``blkdiag`` and
    the time-dependent part to ``nonblkdiag``. Entries with differing first
    indices go to ``nonblkdiag`` unchanged.

    Returns:
        ``(blkdiag, nonblkdiag)``, both ``defaultdict(dict)`` in
        ``{ket: {bra: coeff}}`` form.
    """
    nonblkdiag = defaultdict(dict)
    candidates = []
    for ket, bra_dict in all_terms.items():
        for bra, coeff in bra_dict.items():
            same_block = ket[0] == bra[0]
            if not same_block:
                nonblkdiag[ket][bra] = coeff
                continue
            candidates.append((ket, bra, coeff))

    blkdiag = defaultdict(dict)
    if not candidates:
        return blkdiag, nonblkdiag

    with Pool() as pool:
        split = pool.map(sort_static, [coeff for _, _, coeff in candidates])

    for (ket, bra, _), (static_part, dynamic_part) in zip(candidates, split):
        if static_part is not None:
            blkdiag[ket][bra] = static_part
        if dynamic_part is not None:
            nonblkdiag[ket][bra] = dynamic_part

    return blkdiag, nonblkdiag


def integrate_expr_parallel(nonblkdiag):
    """Integrate every coefficient of a nested ``{ket: {bra: coeff}}`` dict term by term.

    Each coefficient is expanded, split into its additive terms, and every term
    is passed to ``integrate_exp_term``. The integrated terms are then summed
    back per ``(ket, bra)`` entry. Both stages run in one shared process pool.

    Returns:
        ``defaultdict(dict)`` with the same keys as ``nonblkdiag``, holding
        the integrated coefficients.
    """
    entries = [(ket, bra, coeff)
               for ket, bra_dict in nonblkdiag.items()
               for bra, coeff in bra_dict.items()]

    integrated = defaultdict(dict)
    if not entries:
        # Nothing to integrate; avoid starting worker processes.
        return integrated

    # A single pool serves both stages: starting worker processes is costly.
    with Pool() as pool:
        expanded = pool.map(expand_expr, [coeff for _, _, coeff in entries])

        # Flatten all additive terms so integration is parallelized per term;
        # bounds[i]:bounds[i + 1] delimits the terms of entries[i].
        flat_terms = []
        bounds = [0]
        for coeff in expanded:
            flat_terms.extend(Add.make_args(coeff))
            bounds.append(len(flat_terms))

        int_terms = pool.map(integrate_exp_term, flat_terms)

    for i, (ket, bra, _) in enumerate(entries):
        integrated[ket][bra] = Add(*int_terms[bounds[i]:bounds[i + 1]])

    return integrated

In [None]:
# Level truncation passed to tth.h_dirac(cutoff=...); it also bounds the
# transition-amplitude substitution loop.
cutoff = 4
# Control / target transmon parameters. TwoTransmonHamiltonian receives the
# first entries as one pair and the second entries as another; presumably
# (frequency, anharmonicity-like parameter) in common units -- TODO confirm.
c_params = (100.0, 0.1)
t_params = (104.0, 0.1)

In [None]:
tth = TwoTransmonHamiltonian((c_params[0], t_params[0]), (c_params[1], t_params[1]))

# Interaction-picture (Dirac) Hamiltonian truncated at `cutoff` levels.
h_dirac = tth.h_dirac(cutoff=cutoff).expand()
# Drive at the bare target frequency (level-1 eigenvalue of the target
# transmon) with drive phase pi.
drive_subs = {tth.drive_freq: tth.qt.symbolic_eigenvalue(1).doit(),
              tth.drive_phase: pi}
h_dirac = h_dirac.subs(drive_subs)
h_dirac = tth.subs_delta(h_dirac).expand()

# Replace every transition-amplitude symbol by a real Symbol that has the
# same LaTeX name.
subs = {}
for transmon in (tth.qc, tth.qt):
    for level in range(cutoff):
        nu = transmon._transition_amp(level, level + 1)
        subs[nu] = Symbol(nu._latex(None), real=True)
h_dirac = h_dirac.subs(subs)

# Bookkeeping parameter counting the perturbative order of the expansion.
int_scale = Symbol('lambda', real=True, nonnegative=True)
# Placeholder operator that stands for h_dirac in the symbolic SW expansion.
h_i = HermitianOperator('H_I')

In [None]:
def calc_heff_at_order(order, blkdiag_only=False):
    """Evaluate the order-``order`` Schrieffer-Wolff terms and store them in module globals.

    The order-``order`` part of ``expansion`` (in powers of ``int_scale``),
    plus ``sw.gdot_n[order]``, is taken. Nested commutators are reduced with
    the placeholders registered in ``commutator_subs``, and each remaining
    commutator ``[A, B]`` is evaluated explicitly from the dict forms in
    ``op_dicts``.

    Results are written to module-level dicts:

    * ``states_expr[order]``: the full order-``order`` expression as a sum of
      ``TensorProduct`` terms.
    * ``hieff[order]``: with ``blkdiag_only``, simply ``states_expr[order]``
      (for order > 1 its products were already restricted by
      ``expand_opprod``). Otherwise, the static block-diagonal part returned
      by ``sort_block_diagonal_parallel``.
    * ``gdot[order]`` / ``g[order]`` (only when ``blkdiag_only`` is False):
      the remaining part and its term-wise integral from
      ``integrate_expr_parallel``. Both are also registered in ``op_dicts``
      under ``sw.gdot_n[order]`` / ``sw.g_n[order]``.

    When ``blkdiag_only`` is False, each evaluated commutator is also
    registered as a placeholder ``C_k`` in ``commutator_subs``/``op_dicts``
    for reuse at higher orders.

    Args:
        order: Perturbative order (power of ``int_scale``).
        blkdiag_only: Compute only the effective-Hamiltonian part and skip
            registering placeholders and computing ``g``/``gdot``.

    Raises:
        RuntimeError: If nested commutators cannot be reduced with the
            placeholders currently in ``commutator_subs``.
    """
    comm_expr = get_expr_at_order(expansion, int_scale, order) + sw.gdot_n[order]

    start = time.time()
    if order == 1:
        # No commutator for Order λ -> just substitute H_I
        states_expr[order] = comm_expr.subs({h_i: h_dirac})
        all_terms = to_dict(states_expr[order])

    else:
        # Substitute already-evaluated commutators (placeholders registered at
        # lower orders) until no term has a Commutator as a direct argument
        # of its commutator.
        while True:
            if isinstance(comm_expr, Add):
                terms = comm_expr.args
            else:
                terms = [comm_expr]
            for term in terms:
                comm = term.args_cnc()[1][0]
                if isinstance(comm.args[0], Commutator) or isinstance(comm.args[1], Commutator):
                    break
            else:
                # No nested commutator left; `terms` matches the final comm_expr.
                break
            new_expr = comm_expr.subs(commutator_subs)
            if new_expr == comm_expr:
                # Format eagerly: exceptions do not apply logging-style %-args.
                raise RuntimeError(f'Cannot reduce commutator {comm_expr}')
            comm_expr = new_expr

        # Split each term into a scalar coefficient and a commutator [A, B],
        # and queue both orderings A*B and B*A for explicit expansion.
        coeffs = []
        commutators = []
        opprod_terms = []
        for term in terms:
            c, nc = term.args_cnc()
            coeff = Mul(*c)
            comm = nc[0]
            if isinstance(comm, Mul):
                # NOTE(review): assumes args[0] is the scalar factor and
                # args[1] the Commutator -- confirm.
                coeff *= comm.args[0]
                comm = comm.args[1]
            coeffs.append(coeff)
            commutators.append(comm)
            opprod_terms.append((comm.args[0], comm.args[1]))
            opprod_terms.append((comm.args[1], comm.args[0]))

        if blkdiag_only:
            args = [arg + (True,) for arg in opprod_terms]
        else:
            args = opprod_terms
        with Pool() as pool:
            opprods = pool.starmap(expand_opprod, args)

        end = time.time()
        logger.info('expand_opprod for order %d in %f seconds', order, end - start)
        start = end

        # opprods alternates A*B, B*A for each commutator: pair them up.
        comm_dicts_source = list(zip(opprods[0::2], opprods[1::2]))
        with Pool() as pool:
            comm_dicts = pool.starmap(calculate_commutator, comm_dicts_source)

        end = time.time()
        logger.info('calculate_commutator for order %d in %f seconds', order, end - start)
        start = end

        all_terms = defaultdict(lambda: defaultdict(lambda: S.Zero))
        for coeff, comm, comm_dict in zip(coeffs, commutators, comm_dicts):
            if not blkdiag_only:
                # Register the evaluated commutator for reuse at higher orders.
                placeholder = HermitianOperator(f'C_{len(commutator_subs)}')
                commutator_subs[comm] = placeholder
                op_dicts[placeholder] = comm_dict

            for ket, bra_dict in comm_dict.items():
                for bra, c in bra_dict.items():
                    all_terms[ket][bra] += coeff * c

        end = time.time()
        logger.info('all_terms compilation for order %d in %f seconds', order, end - start)
        start = end

        # Expression in terms of TensorProduct(OuterProduct, OuterProduct or Identity)
        states_expr[order] = from_dict_parallel(all_terms)

    if blkdiag_only:
        hieff[order] = states_expr[order]
        return

    blkdiag, nonblkdiag = sort_block_diagonal_parallel(all_terms)
    hieff[order] = from_dict_parallel(blkdiag)
    gdot[order] = from_dict_parallel(nonblkdiag)

    end = time.time()
    logger.info('sort_block_diagonal(%d) in %f seconds', order, end - start)
    start = end

    integrated = integrate_expr_parallel(nonblkdiag)
    g[order] = from_dict_parallel(integrated)

    end = time.time()
    logger.info('integrate_expr(%d) in %f seconds', order, end - start)

    op_dicts.update({sw.g_n[order]: to_dict(g[order]), sw.gdot_n[order]: to_dict(gdot[order])})

In [None]:
# Module-level state read and written by calc_heff_at_order.
# NOTE(review): expand_opprod runs in Pool workers and reads op_dicts as a
# global; this presumably relies on the 'fork' start method so workers see
# the parent's state -- confirm on platforms that default to 'spawn'.
# op_dicts: operator symbol -> nested {ket: {bra: coeff}} dict.
op_dicts = {h_i: to_dict(h_dirac)}
# states_expr: order -> full expression at that order.
states_expr = {}
# commutator_subs: evaluated Commutator -> placeholder operator C_k.
commutator_subs = {}
# g / gdot / hieff: order -> generator, its derivative part, effective-Hamiltonian part.
g = {}
gdot = {}
hieff = {}

sw = SWExpansion()
# Symbolic SW expansion of int_scale * H_I; orders 1-4 are evaluated below
# (third argument presumably the maximum order).
expansion = sw.expand(int_scale * h_i, int_scale, 4)

## Order λ
calc_heff_at_order(1)

# Special substitution
# Replace the symbol gdot_n[1] by h_i in the remaining expansion terms before
# the higher orders are evaluated.
expansion = expansion.subs({sw.gdot_n[1]: h_i})

## Order λ^2
calc_heff_at_order(2)

In [None]:
compos = tth.pauli_components(hieff[2], 2, 2)

# Selected components of the second-order effective Hamiltonian; presumably
# (control, target) Pauli indices with 0=I, 1=X, 3=Z, i.e. IX, ZI, ZX, ZZ -- confirm.
for idx in ((0, 1), (3, 0), (3, 1), (3, 3)):
    display(compos[idx])

In [None]:
# Order λ^3 (also computes g[3] / gdot[3] and registers them in op_dicts)
calc_heff_at_order(order=3)

In [None]:
# Order λ^4: only the effective-Hamiltonian part hieff[4] (no g/gdot needed).
calc_heff_at_order(order=4, blkdiag_only=True)

In [None]:
# Components of the fourth-order effective Hamiltonian, with the same
# arguments as for hieff[2] (meaning of the `2, 2` arguments: TODO confirm).
compos = tth.pauli_components(hieff[4], 2, 2)

In [None]:
# Persist the per-order results: effective-Hamiltonian terms (hieff),
# integrated generators (g) and their un-integrated counterparts (gdot).
with open('first_principle_qutrit.pkl', 'wb') as pkl_file:
    pickle.dump((hieff, g, gdot), file=pkl_file)