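We classically simulate the cutting of an $R_{XX}(\theta)$ gate, i.e. the replacement of the two-qubit gate by a quasi-probability mixture of single-qubit (local) operations, and check that the mixture reproduces the gate's action.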

Each qubit is described by a $2\times 2$ density matrix.

In [1]:
import numpy as np
import scipy

In [2]:
# Basic definitions: computational- and X-basis kets (column vectors) and single-qubit gates
ket0 = np.array([[1], [0]], dtype=complex)
ket1 = np.array([[0], [1]], dtype=complex)
ketplus = (1 / np.sqrt(2)) * (ket0 + ket1)   # |+> = (|0> + |1>)/sqrt(2)
ketminus = (1 / np.sqrt(2)) * (ket0 - ket1)  # |-> = (|0> - |1>)/sqrt(2); must be orthogonal to |+>

X = np.array([[0, 1], [1, 0]])                      # Pauli X
S = np.array([[1, 0], [0, 1j]])                     # phase gate S = diag(1, i)
Z = np.array([[1, 0], [0, -1]])                     # Pauli Z
H = (1 / np.sqrt(2)) * np.array([[1, 1], [1, -1]])  # Hadamard

In [6]:
# Matrix elements rho_ij = <i|rho|j> of the two single-qubit input states
a00, a01, a10, a11 = 1, 0, 0, 0   # qubit A (here |0><0|)
b00, b01, b10, b11 = 0, 0, 0, 1   # qubit B (here |1><1|)
theta = 0.3 * np.pi               # RXX rotation angle in radians
# sum_ij c_ij |i><j| is simply the complex 2x2 matrix of the coefficients
rhoA = np.array([[a00, a01], [a10, a11]], dtype=complex)
rhoB = np.array([[b00, b01], [b10, b11]], dtype=complex)

In [7]:
def rxx_channel(rho1, rho2, theta):
    """Reconstruct R_XX(theta) (rho1 ⊗ rho2) R_XX(theta)^† from single-qubit operations.

    Uses the gate-cutting (quasi-probability) decomposition of
    R_XX(theta) = exp(-i theta/2 X⊗X) = cos(theta/2) I - i sin(theta/2) X⊗X
    acting on the product input rho1 ⊗ rho2:

          cos^2(theta/2) * rho1 ⊗ rho2
        + sin^2(theta/2) * (X rho1 X) ⊗ (X rho2 X)
        + sin(theta)/2   * [ M(rho1) ⊗ R(rho2) + R(rho1) ⊗ M(rho2) ]

    where M(rho) = P+ rho P+ - P- rho P- (P± = |±><±|) is a signed X-basis
    measure-and-prepare, and R(rho) = V rho V^† - V^† rho V with
    V = exp(-i pi/4 X) is the difference of a +pi/2 and a -pi/2 rotation
    about X. Every term acts on one qubit at a time.

    Parameters
    ----------
    rho1, rho2 : array_like, shape (2, 2)
        Density matrices of the first and second qubit.
    theta : float
        Rotation angle in radians.

    Returns
    -------
    numpy.ndarray, shape (4, 4), complex
        The two-qubit density matrix after the gate; it matches the exact
        unitary evolution U (rho1 ⊗ rho2) U^†.
    """
    pauli_x = np.array([[0, 1], [1, 0]], dtype=complex)
    # V = exp(-i pi/4 X) = (I - iX)/sqrt(2), i.e. sqrt(X) up to a global phase
    v = (np.eye(2, dtype=complex) - 1j * pauli_x) / np.sqrt(2)
    v_dag = v.conj().T
    plus = np.array([[1], [1]], dtype=complex) / np.sqrt(2)
    minus = np.array([[1], [-1]], dtype=complex) / np.sqrt(2)
    proj_plus = plus @ plus.conj().T
    proj_minus = minus @ minus.conj().T

    def _signed_x_measure(rho):
        # P+ rho P+ = <+|rho|+> |+><+|: measure X, re-prepare the outcome, weight by its sign
        return proj_plus @ rho @ proj_plus - proj_minus @ rho @ proj_minus

    def _signed_x_rotation(rho):
        # Equals i[rho, X]. The rotation must be about X (not the Z-axis phase gate S),
        # and the second term must use V^† ... V: (S rho S^†)^† equals S rho S^† for a
        # Hermitian rho, so it cannot supply the opposite rotation.
        return v @ rho @ v_dag - v_dag @ rho @ v

    identity_term = np.kron(rho1, rho2)
    flip_term = np.kron(pauli_x @ rho1 @ pauli_x, pauli_x @ rho2 @ pauli_x)
    interference_term = (np.kron(_signed_x_measure(rho1), _signed_x_rotation(rho2))
                         + np.kron(_signed_x_rotation(rho1), _signed_x_measure(rho2)))

    return (np.cos(theta / 2) ** 2 * identity_term
            + np.sin(theta / 2) ** 2 * flip_term
            + 0.5 * np.sin(theta) * interference_term)

def ZZ_measurement(rho):
    """Return the expectation value Tr[(Z ⊗ Z) rho] of a two-qubit density matrix.

    The value is returned exactly as the trace produces it (complex when rho is
    complex); for a Hermitian rho the imaginary part is zero up to rounding.
    """
    return (np.kron(Z, Z) @ rho).trace()

In [8]:
# Apply the cut-gate reconstruction to the product input rhoA ⊗ rhoB
rho_channel = rxx_channel(rhoA, rhoB, theta)
ZZ_meas = ZZ_measurement(rho_channel)
print(np.real(ZZ_meas))

# <ZZ> alone cannot validate the cut: Z⊗Z commutes with X⊗X, so R_XX leaves <ZZ>
# unchanged and any interference term of the form i[rho, X⊗X] has zero <ZZ>.
# Compare the full reconstructed state with exact unitary evolution instead.
rxx_exact = np.cos(theta / 2) * np.eye(4) - 1j * np.sin(theta / 2) * np.kron(X, X)
rho_exact = rxx_exact @ np.kron(rhoA, rhoB) @ rxx_exact.conj().T
print("max |rho_channel - rho_exact| =", np.max(np.abs(rho_channel - rho_exact)))

-1.0
