In [2]:
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import Image
from math import *
import qutip as qt
import qutip.qip
# set a parameter to see animations in line
from matplotlib import rc
rc('animation', html='jshtml')

# static image plots
%matplotlib inline
# interactive 3D plots
# %matplotlib widget

In [3]:
def photon_count_distribution(state: qt.Qobj):
    """
    Return the basis-state probabilities of ``state`` as a 1-D NumPy array.

    For a single Fock-space mode this is the photon-number distribution P(n).
    For a composite state the entries follow QuTiP's flattened tensor-product
    basis order, so they are joint probabilities rather than the P(n) of one mode.
    No renormalisation is applied.

    - Kets and bras: |amplitude|^2 of each component (conjugation does not change it).
    - Density matrices: the real part of the diagonal (imaginary parts are only
      numerical +0j artifacts for a Hermitian state).
    """
    if state.isket or state.isbra:
        # ravel() flattens both column (ket) and row (bra) vectors to 1-D; a bra used
        # to fall through to diag(), which returns a single element for a 1xN row.
        return np.abs(state.full().ravel()) ** 2
    # Density matrix: populations live on the diagonal.
    return np.real(state.diag())

In [4]:
def apply_cat_state_encoding(input_states: qt.Qobj, qubit_position: int, cv_position: int, vertical_displacement=2.5, N=20) -> qt.Qobj:
    """
    Transfer a qubit onto a CV mode as a two-component cat-state code word.

    With alpha = 1j * vertical_displacement / sqrt(2) and the truncated displacement
    operator D(.) on an N-level Fock space, the code words are
        |cat+> ∝ D(alpha)|0> + D(-alpha)|0>,    |cat-> ∝ D(alpha)|0> - D(-alpha)|0>.
    The encoding operator acts as
        |0>_q |vac>_cv  ->  |0>_q |cat+>_cv
        |1>_q |vac>_cv  ->  |0>_q |cat->_cv
    so the qubit is left in |0> afterwards. The operator is an isometry only on inputs
    whose CV mode is in vacuum: components with the CV mode outside |vac> are
    annihilated, so such inputs lose norm/trace.

    Parameters
    ----------
    input_states : qt.Qobj
        Ket or density matrix of the full register.
    qubit_position : int
        Index of the qubit subsystem (dimension 2). Negative indices work Python-style.
    cv_position : int
        Index of the CV mode; its dimension must equal ``N``.
    vertical_displacement : float
        Sets alpha as above; must be non-zero, otherwise |cat-> is the zero vector.
    N : int
        Fock cutoff used to build the cat states (must be >= 2 for |cat-> to exist).

    Returns
    -------
    qt.Qobj
        ``U |psi>`` for a ket, ``U rho U^dag`` for a density matrix.

    Raises
    ------
    ValueError
        If the two positions coincide, a subsystem dimension does not match, or the
        cat states would be degenerate.
    IndexError
        If a position is out of range.
    """
    dims = input_states.dims[0]
    num_subsystems = len(dims)

    # Normalise (possibly negative) indices so the checks below compare like with like.
    all_positions = range(num_subsystems)
    qubit_position = all_positions[qubit_position]
    cv_position = all_positions[cv_position]
    if qubit_position == cv_position:
        raise ValueError("qubit_position and cv_position must refer to different subsystems.")
    if dims[qubit_position] != 2:
        raise ValueError(
            f"Subsystem {qubit_position} has dimension {dims[qubit_position]}; expected a qubit (2)."
        )
    if dims[cv_position] != N:
        raise ValueError(
            f"CV subsystem {cv_position} has dimension {dims[cv_position]}, but N={N} was given."
        )
    if N < 2 or vertical_displacement == 0:
        # Otherwise D(alpha)|0> == D(-alpha)|0> and the odd cat state is the zero vector.
        raise ValueError("Degenerate cat states: need N >= 2 and a non-zero vertical_displacement.")

    # 1. Prepare the CV code words.
    vacuum = qt.basis(N, 0)
    alpha_coeff = (vertical_displacement / np.sqrt(2)) * 1j

    pos_disp = qt.displace(N, alpha_coeff)
    neg_disp = qt.displace(N, -alpha_coeff)

    logical_zero = (pos_disp * vacuum + neg_disp * vacuum).unit()
    logical_one = (pos_disp * vacuum - neg_disp * vacuum).unit()

    # Mapping operators for the CV mode: |vac> -> |cat+> and |vac> -> |cat->.
    map_zero = logical_zero * vacuum.dag()
    map_one = logical_one * vacuum.dag()

    def build_gate(qubit_state_index: int) -> qt.Qobj:
        """Branch of the encoder acting on qubit input |qubit_state_index>."""
        op_list = [qt.qeye(d) for d in dims]
        if qubit_state_index == 0:
            # |0>_q |vac>_cv  ->  |0>_q |cat+>_cv
            op_list[qubit_position] = qt.basis(2, 0).proj()
            op_list[cv_position] = map_zero
        else:
            # |1>_q |vac>_cv  ->  |0>_q |cat->_cv ; |0><1| resets the qubit to |0>.
            op_list[qubit_position] = qt.basis(2, 0) * qt.basis(2, 1).dag()
            op_list[cv_position] = map_one
        return qt.tensor(op_list)

    # 2. Full encoding operator (isometric only where the CV mode is in vacuum).
    U_encode = build_gate(0) + build_gate(1)

    if input_states.isket:
        return U_encode * input_states
    return U_encode * input_states * U_encode.dag()

In [5]:
def beamsplitter_general(input_state: qt.Qobj, idx1: int, idx2: int, transmissivity: float) -> qt.Qobj:
    """
    Apply a two-mode beamsplitter between subsystems ``idx1`` and ``idx2``.

    The unitary is U = exp(theta * (a1^dag a2 - a1 a2^dag)) with
    theta = arcsin(sqrt(transmissivity)), acting as identity on every other subsystem.
    Under U a single photon in ``idx1`` becomes cos(theta)|1,0> - sin(theta)|0,1>, so a
    fraction sin^2(theta) = ``transmissivity`` of it is moved into ``idx2``. Despite its
    name, the parameter is therefore the coupling (loss) fraction into ``idx2``; the
    notebook passes ``loss_prob`` here.

    Parameters
    ----------
    input_state : qt.Qobj
        Ket or density matrix of the full register.
    idx1, idx2 : int
        Indices of the two truncated bosonic modes; must be distinct. Negative indices
        are interpreted Python-style.
    transmissivity : float
        Fraction in [0, 1] transferred from ``idx1`` to ``idx2`` (see above).

    Returns
    -------
    qt.Qobj
        ``U |psi>`` for a ket, ``U rho U^dag`` for a density matrix, in the same
        subsystem order as the input.

    Raises
    ------
    ValueError
        If ``idx1`` and ``idx2`` coincide, or ``transmissivity`` lies outside [0, 1]
        (``arcsin(sqrt(.))`` would otherwise silently produce NaNs).
    IndexError
        If an index is out of range.
    """
    dims = input_state.dims[0]
    num_subsystems = len(dims)

    all_positions = range(num_subsystems)
    idx1 = all_positions[idx1]
    idx2 = all_positions[idx2]
    if idx1 == idx2:
        raise ValueError("beamsplitter_general needs two distinct modes (idx1 == idx2).")
    if not 0.0 <= transmissivity <= 1.0:
        raise ValueError(f"transmissivity must lie in [0, 1], got {transmissivity!r}.")

    N1 = dims[idx1]
    N2 = dims[idx2]

    # Ladder operators on the two-mode (N1 x N2) space only. Because the generator is
    # identity on every other subsystem, exp(G (x) I) = exp(G) (x) I, so exponentiating
    # locally is exact and avoids an expm over the whole register.
    a1 = qt.tensor(qt.destroy(N1), qt.qeye(N2))
    a2 = qt.tensor(qt.qeye(N1), qt.destroy(N2))

    theta = np.arcsin(np.sqrt(transmissivity))
    U_local = (theta * (a1.dag() * a2 - a1 * a2.dag())).expm()

    # Move the two modes to the front, act there, then restore the original order.
    others = [k for k in all_positions if k not in (idx1, idx2)]
    order = [idx1, idx2] + others
    inverse_order = np.argsort(order).tolist()

    U_bs = qt.tensor([U_local] + [qt.qeye(dims[k]) for k in others])
    permuted = input_state.permute(order)
    if permuted.isket:
        result = U_bs * permuted
    else:
        result = U_bs * permuted * U_bs.dag()
    return result.permute(inverse_order)

In [6]:
def apply_ideal_cat_state_decoding(input_states: qt.Qobj, qubit_position: int, cv_position: int, vertical_displacement=2.5, N=20) -> qt.Qobj:
    """
    Idealised decoding of a cat-state code word back onto a qubit.

    The code words are the same as in the encoder:
    |cat±> ∝ D(alpha)|0> ± D(-alpha)|0> with alpha = 1j * vertical_displacement / sqrt(2).

    Step A (flip):  |cat+>|q> -> |cat+>|q>,   |cat->|q> -> |cat-> X|q>
    Step B (clean): qubit |0> -> CV projected on <cat+| and reset to |vac>;
                    qubit |1> -> CV projected on <cat-| and reset to |vac>.
    With the target qubit starting in |0> the combined map is
        |cat+>|0> -> |vac>|0>,    |cat->|0> -> |vac>|1>.

    Both steps are built from projectors onto the two *ideal* (unattenuated) cat
    states, so the map is NOT unitary. Any CV component outside span{|cat+>, |cat->}
    is discarded, including the distortion left behind by a lossy beamsplitter. A
    target qubit that does not start in |0> is also discarded. The returned state can
    therefore have norm/trace below 1; renormalise or inspect the trace downstream if
    needed.

    Parameters
    ----------
    input_states : qt.Qobj
        Ket or density matrix of the full register.
    qubit_position : int
        Index of the target qubit (dimension 2), expected to be in |0>.
    cv_position : int
        Index of the CV mode carrying the code word; its dimension must equal ``N``.
    vertical_displacement : float
        Must match the value used for encoding; non-zero.
    N : int
        Fock cutoff used to build the cat states (>= 2).

    Returns
    -------
    qt.Qobj
        ``U |psi>`` for a ket, ``U rho U^dag`` for a density matrix.

    Raises
    ------
    ValueError
        If the positions coincide, a subsystem dimension does not match, or the cat
        states would be degenerate.
    IndexError
        If a position is out of range.
    """
    dims = input_states.dims[0]
    num_subsystems = len(dims)

    # Normalise (possibly negative) indices so the checks below compare like with like.
    all_positions = range(num_subsystems)
    qubit_position = all_positions[qubit_position]
    cv_position = all_positions[cv_position]
    if qubit_position == cv_position:
        raise ValueError("qubit_position and cv_position must refer to different subsystems.")
    if dims[qubit_position] != 2:
        raise ValueError(
            f"Subsystem {qubit_position} has dimension {dims[qubit_position]}; expected a qubit (2)."
        )
    if dims[cv_position] != N:
        raise ValueError(
            f"CV subsystem {cv_position} has dimension {dims[cv_position]}, but N={N} was given."
        )
    if N < 2 or vertical_displacement == 0:
        # Otherwise D(alpha)|0> == D(-alpha)|0> and the odd cat state is the zero vector.
        raise ValueError("Degenerate cat states: need N >= 2 and a non-zero vertical_displacement.")

    # 1. Ideal code words.
    vacuum = qt.basis(N, 0)
    alpha_coeff = (vertical_displacement / np.sqrt(2)) * 1j

    pos_disp = qt.displace(N, alpha_coeff)
    neg_disp = qt.displace(N, -alpha_coeff)
    logical_zero_cv = (pos_disp * vacuum + neg_disp * vacuum).unit()
    logical_one_cv = (pos_disp * vacuum - neg_disp * vacuum).unit()

    def build_flip():
        """Step A: projector onto |cat+> (qubit untouched) + projector onto |cat-> (qubit X)."""
        op_plus = [qt.qeye(d) for d in dims]
        op_plus[cv_position] = logical_zero_cv.proj()

        op_minus = [qt.qeye(d) for d in dims]
        op_minus[cv_position] = logical_one_cv.proj()
        op_minus[qubit_position] = qt.sigmax()

        return qt.tensor(op_plus) + qt.tensor(op_minus)

    def build_clean():
        """Step B: qubit-conditioned |vac><cat±| that returns the CV mode to vacuum."""
        op_zero = [qt.qeye(d) for d in dims]
        op_zero[qubit_position] = qt.basis(2, 0).proj()
        op_zero[cv_position] = vacuum * logical_zero_cv.dag()

        op_one = [qt.qeye(d) for d in dims]
        op_one[qubit_position] = qt.basis(2, 1).proj()
        op_one[cv_position] = vacuum * logical_one_cv.dag()

        return qt.tensor(op_zero) + qt.tensor(op_one)

    # First flip the qubit according to the code word, then reset the CV mode.
    U_total = build_clean() * build_flip()

    if input_states.isket:
        return U_total * input_states
    return U_total * input_states * U_total.dag()

In [7]:
def apply_hadamard(state: qt.Qobj, target_idx: int) -> qt.Qobj:
    """
    Apply a Hadamard gate to subsystem ``target_idx`` (identity on all others).

    Kets are mapped as H|psi>; any other Qobj (density matrix) as H rho H^dag.
    """
    factors = list(map(qt.qeye, state.dims[0]))
    factors[target_idx] = qt.gates.snot()
    hadamard_full = qt.tensor(factors)

    if not state.isket:
        return hadamard_full * state * hadamard_full.dag()
    return hadamard_full * state

def apply_cnot(state: qt.Qobj, control_idx: int, target_idx: int) -> qt.Qobj:
    """
    Apply a CNOT between two qubit subsystems of a multi-partite state.

    Builds CNOT = |0><0|_c (x) I_t + |1><1|_c (x) X_t (identity elsewhere) and applies
    it as U|psi> for a ket or U rho U^dag for a density matrix.

    Raises
    ------
    ValueError
        If control and target are the same subsystem (the operator would not be
        unitary), or either subsystem is not two-dimensional.
    IndexError
        If an index is out of range.
    """
    dims = state.dims[0]
    all_positions = range(len(dims))
    control, target = all_positions[control_idx], all_positions[target_idx]
    if control == target:
        raise ValueError(f"CNOT needs distinct control and target, got {control_idx} and {target_idx}.")
    for idx in (control, target):
        if dims[idx] != 2:
            raise ValueError(f"Subsystem {idx} has dimension {dims[idx]}; CNOT acts on qubits (dimension 2).")

    # Branch where the control is |0>: identity on the target.
    op_list_0 = [qt.qeye(d) for d in dims]
    op_list_0[control] = qt.basis(2, 0).proj()

    # Branch where the control is |1>: X on the target.
    op_list_1 = [qt.qeye(d) for d in dims]
    op_list_1[control] = qt.basis(2, 1).proj()
    op_list_1[target] = qt.sigmax()

    CNOT_total = qt.tensor(op_list_0) + qt.tensor(op_list_1)
    if state.isket:
        return CNOT_total * state
    return CNOT_total * state * CNOT_total.dag()
    
def apply_swap(state: qt.Qobj, idx1: int, idx2: int) -> qt.Qobj:
    """
    Exchange the states of two qubit subsystems.

    Uses the three-CNOT decomposition SWAP = CNOT(1->2) CNOT(2->1) CNOT(1->2), via
    ``apply_cnot``, so it works on kets and density matrices alike. Swapping a
    subsystem with itself is the identity. In that case the input is returned
    unchanged, because the CNOT decomposition is undefined when control and target
    coincide.
    """
    positions = range(len(state.dims[0]))
    if positions[idx1] == positions[idx2]:
        return state
    for control, target in ((idx1, idx2), (idx2, idx1), (idx1, idx2)):
        state = apply_cnot(state, control, target)
    return state

In [8]:
# Single-hop cat-state channel with loss_prob = 0.2.
# Register layout: [tx qubit, channel qubit, CV mode, loss mode, rx qubit].
# NOTE(review): this cell (re)defines notebook globals (N, vertical_displacement,
# loss_prob, all_states, ideal_rho, fid) that later cells also redefine.
ideal_phi_plus = (qt.tensor(qt.basis(2,0), qt.basis(2,0)) + qt.tensor(qt.basis(2,1), qt.basis(2,1))).unit()
ideal_rho = qt.ket2dm(ideal_phi_plus)

# Fock cutoff, cat displacement, and fraction of the CV mode mixed into the loss mode.
N = 16
vertical_displacement = 2
loss_prob = 0.2

all_states = qt.tensor(qt.basis(2, 0), qt.basis(2, 0), qt.basis(N, 0), qt.basis(N, 0), qt.basis(2, 0))
print(all_states.dims)
# Bell pair between the tx qubit (0) and the channel qubit (1).
all_states = apply_hadamard(all_states, 0)
all_states = apply_cnot(all_states, 0, 1)
# Move the channel qubit onto the CV mode (2) as a cat code word; qubit 1 is left in |0>.
all_states = apply_cat_state_encoding(all_states, 1, 2, vertical_displacement, N)
# Drop the now-idle qubit 1; layout becomes [tx, CV, loss, rx] (a density matrix from here on).
all_states = all_states.ptrace([0, 2, 3, 4]) 
# Photon loss: mix the CV mode (1) with the initially-vacuum loss mode (2).
all_states = beamsplitter_general(all_states, 1, 2, loss_prob)
# Decode the CV mode (1) onto the rx qubit (3).
after_ideal_decoding = apply_ideal_cat_state_decoding(all_states, 3, 1, vertical_displacement, N)
# Trace out the CV mode, then the loss mode, keeping [tx, rx].
after_ideal_decoding = after_ideal_decoding.ptrace([0, 2, 3])
edge_qubits_ideal = after_ideal_decoding.ptrace([0, 2])

print(edge_qubits_ideal)
# NOTE(review): the ideal decoding projects onto the unattenuated cat states and is not
# trace-preserving, so edge_qubits_ideal can have trace < 1 (the printed diagonal sums
# to about 0.977). qt.fidelity computes Tr sqrt(sqrt(A) B sqrt(A)) without renormalising
# (confirm for the installed QuTiP), so the number below mixes the lost weight with the
# state infidelity. Consider also reporting edge_qubits_ideal.tr() and the fidelity of
# edge_qubits_ideal.unit().
fid = qt.fidelity(edge_qubits_ideal, ideal_rho)
print(f"Fidelity with Phi+ (ideal): {fid:.4f}")


[[2, 2, 16, 16, 2], [1]]
Quantum object: dims=[[2, 2], [2, 2]], shape=(4, 4), type='oper', dtype=Dense, isherm=True
Qobj data =
[[0.36107417 0.         0.         0.35418697]
 [0.         0.12725752 0.13457297 0.        ]
 [0.         0.13457297 0.14230896 0.        ]
 [0.35418697 0.         0.         0.34743114]]
Fidelity with Phi+ (ideal): 0.8417


In [9]:
def initial_channel_qubit_states(cutoff: int | None = None) -> list[qt.Qobj]:
    """
    Fresh initial states for one channel block, in register order:
    [qubit |0>, CV mode |0>, CV mode |0>, qubit |0>].

    In this notebook the block is used as (encode qubit, cat-carrying mode, mode it is
    mixed with on the beamsplitter, decode qubit).

    Parameters
    ----------
    cutoff : int, optional
        Fock-space cutoff for the two CV modes. Defaults to the notebook-global ``N``,
        read at call time (the original behaviour).
    """
    levels = N if cutoff is None else cutoff
    return [qt.basis(2, 0), qt.basis(levels, 0), qt.basis(levels, 0), qt.basis(2, 0)]


In [10]:
# Same single-hop experiment, but with the register laid out by the position helpers:
# [tx qubits | channel blocks | rx qubits].
ideal_phi_plus = (qt.tensor(qt.basis(2,0), qt.basis(2,0)) + qt.tensor(qt.basis(2,1), qt.basis(2,1))).unit()
ideal_rho = qt.ket2dm(ideal_phi_plus)

N = 16
vertical_displacement = 2
# Near-zero loss: this run acts as a sanity check that the chain reaches fidelity ~1.
loss_prob = 1E-7


NUM_CHANNEL_QUBITS = 1
num_tx_qubits = 2
num_rx_qubits = 1

# Each channel block is [qubit, CV mode, CV mode, qubit].
initial_channel_states = initial_channel_qubit_states()
states_per_channel_qubit = len(initial_channel_states)

# NOTE(review): kept as a ket here (later cells use ket2dm); total dimension is
# 2**5 * N**2 = 8192 for N = 16.
all_states = qt.tensor([qt.basis(2,0)]*num_tx_qubits + initial_channel_states*NUM_CHANNEL_QUBITS + [qt.basis(2,0)]*num_rx_qubits)

def tx_qubit_positions() -> list[int]:
    """Register indices of the transmitter (TX) qubits, which occupy the front of the register."""
    return [index for index in range(num_tx_qubits)]
    
def channel_states_positions() -> list[int]:
    """
    Register indices of every channel-block subsystem, block by block.

    Block ``q`` occupies ``states_per_channel_qubit`` consecutive slots starting right
    after the TX qubits, so entry ``q * states_per_channel_qubit + i`` is slot ``i`` of
    block ``q``. The previous comprehension iterated the slot index in the outer loop,
    which interleaved the blocks whenever NUM_CHANNEL_QUBITS > 1.
    """
    starting_offset = len(tx_qubit_positions())
    return [
        starting_offset + num_channel_qubit * states_per_channel_qubit + i
        for num_channel_qubit in range(NUM_CHANNEL_QUBITS)
        for i in range(states_per_channel_qubit)
    ]

def rx_qubit_positions() -> list[int]:
    """Register indices of the receiver (RX) qubits, placed right after the last channel block."""
    first_rx = num_tx_qubits + NUM_CHANNEL_QUBITS * states_per_channel_qubit
    return [first_rx + offset for offset in range(num_rx_qubits)]

# Bell pair between the two tx qubits.
all_states = apply_hadamard(all_states, tx_qubit_positions()[0])
all_states = apply_cnot(all_states, tx_qubit_positions()[0], tx_qubit_positions()[1])

# Hand tx qubit 1 to the channel block's input qubit (slot 0).
all_states = apply_swap(all_states, tx_qubit_positions()[1], channel_states_positions()[0])

# Encode the channel qubit (slot 0) onto the CV mode (slot 1).
all_states = apply_cat_state_encoding(all_states, channel_states_positions()[0], channel_states_positions()[1], vertical_displacement, N)

# Loss: mix the CV mode (slot 1) with the vacuum mode in slot 2, then decode slot 1 onto the qubit in slot 3.
all_states = beamsplitter_general(all_states, channel_states_positions()[1], channel_states_positions()[2], loss_prob)
all_states = apply_ideal_cat_state_decoding(all_states, channel_states_positions()[3], channel_states_positions()[1], vertical_displacement, N)

# Move the decoded qubit to the rx qubit.
# NOTE(review): channel_states_positions()[-1] is the same slot as [3] only because
# NUM_CHANNEL_QUBITS == 1.
all_states = apply_swap(all_states, channel_states_positions()[-1], rx_qubit_positions()[0])

edge_qubits = all_states.ptrace([tx_qubit_positions()[0], rx_qubit_positions()[0]])

print(edge_qubits)
fid = qt.fidelity(edge_qubits, ideal_rho)
print(f"Fidelity with Phi+ (ideal): {fid:.4f}")

Quantum object: dims=[[2, 2], [2, 2]], shape=(4, 4), type='oper', dtype=Dense, isherm=True
Qobj data =
[[4.99999904e-01 0.00000000e+00 0.00000000e+00 4.99999900e-01]
 [0.00000000e+00 9.64027353e-08 9.99999794e-08 0.00000000e+00]
 [0.00000000e+00 9.99999794e-08 1.03731454e-07 0.00000000e+00]
 [4.99999900e-01 0.00000000e+00 0.00000000e+00 4.99999896e-01]]
Fidelity with Phi+ (ideal): 1.0000


In [None]:
# Density-matrix version of the helper-based experiment at a small Fock cutoff
# (N = 4), presumably to keep the full density matrix small.
ideal_phi_plus = (qt.tensor(qt.basis(2,0), qt.basis(2,0)) + qt.tensor(qt.basis(2,1), qt.basis(2,1))).unit()
ideal_rho = qt.ket2dm(ideal_phi_plus)

N = 4
vertical_displacement = 1
loss_prob = 1E-2

NUM_CHANNEL_QUBITS = 1
num_tx_qubits = 2
num_rx_qubits = 1

# Channel block written inline: [qubit, CV mode, CV mode, qubit], all in |0>.
initial_channel_states = [qt.basis(2, 0), qt.basis(N, 0), qt.basis(N, 0), qt.basis(2, 0)]
states_per_channel_qubit = len(initial_channel_states)

# Start from a density matrix so every later gate uses U rho U^dag.
all_states = qt.ket2dm(qt.tensor([qt.basis(2,0)]*num_tx_qubits + initial_channel_states*NUM_CHANNEL_QUBITS + [qt.basis(2,0)]*num_rx_qubits))

def tx_qubit_positions() -> list[int]:
    """Register indices of the transmitter (TX) qubits, which occupy the front of the register."""
    return [index for index in range(num_tx_qubits)]
    
def channel_states_positions():
    """
    Register indices of every channel-block subsystem, block by block.

    The blocks are contiguous and start right after the TX qubits, so this is a single
    consecutive run of NUM_CHANNEL_QUBITS * states_per_channel_qubit indices.
    """
    first = len(tx_qubit_positions())
    span = NUM_CHANNEL_QUBITS * states_per_channel_qubit
    return list(range(first, first + span))

def rx_qubit_positions() -> list[int]:
    """Register indices of the receiver (RX) qubits, placed right after the last channel block."""
    first_rx = num_tx_qubits + NUM_CHANNEL_QUBITS * states_per_channel_qubit
    return [first_rx + offset for offset in range(num_rx_qubits)]

def ptrace_away_positions(states: qt.Qobj, positions: int | list[int]) -> qt.Qobj:
    """
    Trace out the subsystems at ``positions`` and return the reduced state of the rest.

    The kept subsystems retain their original relative order. ``positions`` may be a
    single index or an iterable of indices. Negative indices are interpreted
    Python-style; out-of-range indices raise IndexError.
    """
    if isinstance(positions, (int, np.integer)):
        positions = [positions]

    # Count subsystems on `states` itself rather than on a notebook global, so the
    # helper works on any state (the old version read `all_states.dims`).
    all_positions = range(len(states.dims[0]))
    traced = {all_positions[p] for p in positions}
    keep_indices = [i for i in all_positions if i not in traced]
    return states.ptrace(keep_indices)

# Bell pair between the tx qubits, then hand tx qubit 1 to the first channel block.
all_states = apply_hadamard(all_states, tx_qubit_positions()[0])
all_states = apply_cnot(all_states, tx_qubit_positions()[0], tx_qubit_positions()[1])

all_states = apply_swap(all_states, tx_qubit_positions()[1], channel_states_positions()[0])

# Per channel block, slots are (0) qubit, (1) CV mode, (2) loss mode, (3) decode qubit.
# Encode slot 0 onto slot 1.
for channel_qubit_index in range(NUM_CHANNEL_QUBITS):
    all_states = apply_cat_state_encoding(all_states, channel_states_positions()[channel_qubit_index*states_per_channel_qubit+0], channel_states_positions()[channel_qubit_index*states_per_channel_qubit+1], vertical_displacement, N)


# Loss: mix slot 1 with slot 2.
for channel_qubit_index in range(NUM_CHANNEL_QUBITS):
    # NOTE(review): leftover debug output; consider removing it (or using tqdm).
    print(f"beamsplitter {channel_qubit_index}")
    all_states = beamsplitter_general(all_states, channel_states_positions()[channel_qubit_index*states_per_channel_qubit+1], channel_states_positions()[channel_qubit_index*states_per_channel_qubit+2], loss_prob)

# Decode slot 1 onto slot 3.
for channel_qubit_index in range(NUM_CHANNEL_QUBITS):
    all_states = apply_ideal_cat_state_decoding(all_states, channel_states_positions()[channel_qubit_index*states_per_channel_qubit+3], channel_states_positions()[channel_qubit_index*states_per_channel_qubit+1], vertical_displacement, N)

print(all_states.dims)
# NOTE(review): only block 0 receives the tx qubit, and only block 0's decode qubit
# (index states_per_channel_qubit-1 of the channel list) is swapped to the rx qubit.
# Further blocks (NUM_CHANNEL_QUBITS > 1) are not chained into the link.
# ptrace_away_positions is defined in this cell but never used.
all_states = apply_swap(all_states, channel_states_positions()[states_per_channel_qubit-1], rx_qubit_positions()[0])
edge_qubits = all_states.ptrace([tx_qubit_positions()[0], rx_qubit_positions()[0]])

print(edge_qubits)
fid = qt.fidelity(edge_qubits, ideal_rho)
print(f"Fidelity with Phi+ (ideal): {fid:.4f}")

beamsplitter 0
[[2, 2, 2, 4, 4, 2, 2], [2, 2, 2, 4, 4, 2, 2]]
Quantum object: dims=[[2, 2], [2, 2]], shape=(4, 4), type='oper', dtype=Dense, isherm=True
Qobj data =
[[0.4989119  0.         0.         0.49671574]
 [0.         0.00103041 0.00237249 0.        ]
 [0.         0.00237249 0.00546258 0.        ]
 [0.49671574 0.         0.         0.49452943]]
Fidelity with Phi+ (ideal): 0.9967


In [4]:
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import Image
from math import *
import qutip as qt
import qutip.qip
# set a parameter to see animations in line
from matplotlib import rc
rc('animation', html='jshtml')

# static image plots
%matplotlib inline
# interactive 3D plots
# %matplotlib widget



def apply_cat_state_encoding(input_states: qt.Qobj, qubit_position: int, cv_position: int, vertical_displacement=2.5, N=20) -> qt.Qobj:
    """
    Transfer a qubit onto a CV mode as a two-component cat-state code word.

    With alpha = 1j * vertical_displacement / sqrt(2) and the truncated displacement
    operator D(.) on an N-level Fock space, the code words are
        |cat+> ∝ D(alpha)|0> + D(-alpha)|0>,    |cat-> ∝ D(alpha)|0> - D(-alpha)|0>.
    The encoding operator acts as
        |0>_q |vac>_cv  ->  |0>_q |cat+>_cv
        |1>_q |vac>_cv  ->  |0>_q |cat->_cv
    so the qubit is left in |0> afterwards. The operator is an isometry only on inputs
    whose CV mode is in vacuum: components with the CV mode outside |vac> are
    annihilated, so such inputs lose norm/trace.

    Parameters
    ----------
    input_states : qt.Qobj
        Ket or density matrix of the full register.
    qubit_position : int
        Index of the qubit subsystem (dimension 2). Negative indices work Python-style.
    cv_position : int
        Index of the CV mode; its dimension must equal ``N``.
    vertical_displacement : float
        Sets alpha as above; must be non-zero, otherwise |cat-> is the zero vector.
    N : int
        Fock cutoff used to build the cat states (must be >= 2 for |cat-> to exist).

    Returns
    -------
    qt.Qobj
        ``U |psi>`` for a ket, ``U rho U^dag`` for a density matrix.

    Raises
    ------
    ValueError
        If the two positions coincide, a subsystem dimension does not match, or the
        cat states would be degenerate.
    IndexError
        If a position is out of range.
    """
    dims = input_states.dims[0]
    num_subsystems = len(dims)

    # Normalise (possibly negative) indices so the checks below compare like with like.
    all_positions = range(num_subsystems)
    qubit_position = all_positions[qubit_position]
    cv_position = all_positions[cv_position]
    if qubit_position == cv_position:
        raise ValueError("qubit_position and cv_position must refer to different subsystems.")
    if dims[qubit_position] != 2:
        raise ValueError(
            f"Subsystem {qubit_position} has dimension {dims[qubit_position]}; expected a qubit (2)."
        )
    if dims[cv_position] != N:
        raise ValueError(
            f"CV subsystem {cv_position} has dimension {dims[cv_position]}, but N={N} was given."
        )
    if N < 2 or vertical_displacement == 0:
        # Otherwise D(alpha)|0> == D(-alpha)|0> and the odd cat state is the zero vector.
        raise ValueError("Degenerate cat states: need N >= 2 and a non-zero vertical_displacement.")

    # 1. Prepare the CV code words.
    vacuum = qt.basis(N, 0)
    alpha_coeff = (vertical_displacement / np.sqrt(2)) * 1j

    pos_disp = qt.displace(N, alpha_coeff)
    neg_disp = qt.displace(N, -alpha_coeff)

    logical_zero = (pos_disp * vacuum + neg_disp * vacuum).unit()
    logical_one = (pos_disp * vacuum - neg_disp * vacuum).unit()

    # Mapping operators for the CV mode: |vac> -> |cat+> and |vac> -> |cat->.
    map_zero = logical_zero * vacuum.dag()
    map_one = logical_one * vacuum.dag()

    def build_gate(qubit_state_index: int) -> qt.Qobj:
        """Branch of the encoder acting on qubit input |qubit_state_index>."""
        op_list = [qt.qeye(d) for d in dims]
        if qubit_state_index == 0:
            # |0>_q |vac>_cv  ->  |0>_q |cat+>_cv
            op_list[qubit_position] = qt.basis(2, 0).proj()
            op_list[cv_position] = map_zero
        else:
            # |1>_q |vac>_cv  ->  |0>_q |cat->_cv ; |0><1| resets the qubit to |0>.
            op_list[qubit_position] = qt.basis(2, 0) * qt.basis(2, 1).dag()
            op_list[cv_position] = map_one
        return qt.tensor(op_list)

    # 2. Full encoding operator (isometric only where the CV mode is in vacuum).
    U_encode = build_gate(0) + build_gate(1)

    if input_states.isket:
        return U_encode * input_states
    return U_encode * input_states * U_encode.dag()

def apply_ideal_cat_state_decoding(input_states: qt.Qobj, qubit_position: int, cv_position: int, vertical_displacement=2.5, N=20) -> qt.Qobj:
    """
    Idealised decoding of a cat-state code word back onto a qubit.

    The code words are the same as in the encoder:
    |cat±> ∝ D(alpha)|0> ± D(-alpha)|0> with alpha = 1j * vertical_displacement / sqrt(2).

    Step A (flip):  |cat+>|q> -> |cat+>|q>,   |cat->|q> -> |cat-> X|q>
    Step B (clean): qubit |0> -> CV projected on <cat+| and reset to |vac>;
                    qubit |1> -> CV projected on <cat-| and reset to |vac>.
    With the target qubit starting in |0> the combined map is
        |cat+>|0> -> |vac>|0>,    |cat->|0> -> |vac>|1>.

    Both steps are built from projectors onto the two *ideal* (unattenuated) cat
    states, so the map is NOT unitary. Any CV component outside span{|cat+>, |cat->}
    is discarded, including the distortion left behind by a lossy beamsplitter. A
    target qubit that does not start in |0> is also discarded. The returned state can
    therefore have norm/trace below 1; renormalise or inspect the trace downstream if
    needed.

    Parameters
    ----------
    input_states : qt.Qobj
        Ket or density matrix of the full register.
    qubit_position : int
        Index of the target qubit (dimension 2), expected to be in |0>.
    cv_position : int
        Index of the CV mode carrying the code word; its dimension must equal ``N``.
    vertical_displacement : float
        Must match the value used for encoding; non-zero.
    N : int
        Fock cutoff used to build the cat states (>= 2).

    Returns
    -------
    qt.Qobj
        ``U |psi>`` for a ket, ``U rho U^dag`` for a density matrix.

    Raises
    ------
    ValueError
        If the positions coincide, a subsystem dimension does not match, or the cat
        states would be degenerate.
    IndexError
        If a position is out of range.
    """
    dims = input_states.dims[0]
    num_subsystems = len(dims)

    # Normalise (possibly negative) indices so the checks below compare like with like.
    all_positions = range(num_subsystems)
    qubit_position = all_positions[qubit_position]
    cv_position = all_positions[cv_position]
    if qubit_position == cv_position:
        raise ValueError("qubit_position and cv_position must refer to different subsystems.")
    if dims[qubit_position] != 2:
        raise ValueError(
            f"Subsystem {qubit_position} has dimension {dims[qubit_position]}; expected a qubit (2)."
        )
    if dims[cv_position] != N:
        raise ValueError(
            f"CV subsystem {cv_position} has dimension {dims[cv_position]}, but N={N} was given."
        )
    if N < 2 or vertical_displacement == 0:
        # Otherwise D(alpha)|0> == D(-alpha)|0> and the odd cat state is the zero vector.
        raise ValueError("Degenerate cat states: need N >= 2 and a non-zero vertical_displacement.")

    # 1. Ideal code words.
    vacuum = qt.basis(N, 0)
    alpha_coeff = (vertical_displacement / np.sqrt(2)) * 1j

    pos_disp = qt.displace(N, alpha_coeff)
    neg_disp = qt.displace(N, -alpha_coeff)
    logical_zero_cv = (pos_disp * vacuum + neg_disp * vacuum).unit()
    logical_one_cv = (pos_disp * vacuum - neg_disp * vacuum).unit()

    def build_flip():
        """Step A: projector onto |cat+> (qubit untouched) + projector onto |cat-> (qubit X)."""
        op_plus = [qt.qeye(d) for d in dims]
        op_plus[cv_position] = logical_zero_cv.proj()

        op_minus = [qt.qeye(d) for d in dims]
        op_minus[cv_position] = logical_one_cv.proj()
        op_minus[qubit_position] = qt.sigmax()

        return qt.tensor(op_plus) + qt.tensor(op_minus)

    def build_clean():
        """Step B: qubit-conditioned |vac><cat±| that returns the CV mode to vacuum."""
        op_zero = [qt.qeye(d) for d in dims]
        op_zero[qubit_position] = qt.basis(2, 0).proj()
        op_zero[cv_position] = vacuum * logical_zero_cv.dag()

        op_one = [qt.qeye(d) for d in dims]
        op_one[qubit_position] = qt.basis(2, 1).proj()
        op_one[cv_position] = vacuum * logical_one_cv.dag()

        return qt.tensor(op_zero) + qt.tensor(op_one)

    # First flip the qubit according to the code word, then reset the CV mode.
    U_total = build_clean() * build_flip()

    if input_states.isket:
        return U_total * input_states
    return U_total * input_states * U_total.dag()

def beamsplitter_general(input_state: qt.Qobj, idx1: int, idx2: int, transmissivity: float) -> qt.Qobj:
    """
    Apply a two-mode beamsplitter between subsystems ``idx1`` and ``idx2``.

    The unitary is U = exp(theta * (a1^dag a2 - a1 a2^dag)) with
    theta = arcsin(sqrt(transmissivity)), acting as identity on every other subsystem.
    Under U a single photon in ``idx1`` becomes cos(theta)|1,0> - sin(theta)|0,1>, so a
    fraction sin^2(theta) = ``transmissivity`` of it is moved into ``idx2``. Despite its
    name, the parameter is therefore the coupling (loss) fraction into ``idx2``; the
    notebook passes ``loss_prob`` here.

    Parameters
    ----------
    input_state : qt.Qobj
        Ket or density matrix of the full register.
    idx1, idx2 : int
        Indices of the two truncated bosonic modes; must be distinct. Negative indices
        are interpreted Python-style.
    transmissivity : float
        Fraction in [0, 1] transferred from ``idx1`` to ``idx2`` (see above).

    Returns
    -------
    qt.Qobj
        ``U |psi>`` for a ket, ``U rho U^dag`` for a density matrix, in the same
        subsystem order as the input.

    Raises
    ------
    ValueError
        If ``idx1`` and ``idx2`` coincide, or ``transmissivity`` lies outside [0, 1]
        (``arcsin(sqrt(.))`` would otherwise silently produce NaNs).
    IndexError
        If an index is out of range.
    """
    dims = input_state.dims[0]
    num_subsystems = len(dims)

    all_positions = range(num_subsystems)
    idx1 = all_positions[idx1]
    idx2 = all_positions[idx2]
    if idx1 == idx2:
        raise ValueError("beamsplitter_general needs two distinct modes (idx1 == idx2).")
    if not 0.0 <= transmissivity <= 1.0:
        raise ValueError(f"transmissivity must lie in [0, 1], got {transmissivity!r}.")

    N1 = dims[idx1]
    N2 = dims[idx2]

    # Ladder operators on the two-mode (N1 x N2) space only. Because the generator is
    # identity on every other subsystem, exp(G (x) I) = exp(G) (x) I, so exponentiating
    # locally is exact and avoids an expm over the whole register.
    a1 = qt.tensor(qt.destroy(N1), qt.qeye(N2))
    a2 = qt.tensor(qt.qeye(N1), qt.destroy(N2))

    theta = np.arcsin(np.sqrt(transmissivity))
    U_local = (theta * (a1.dag() * a2 - a1 * a2.dag())).expm()

    # Move the two modes to the front, act there, then restore the original order.
    others = [k for k in all_positions if k not in (idx1, idx2)]
    order = [idx1, idx2] + others
    inverse_order = np.argsort(order).tolist()

    U_bs = qt.tensor([U_local] + [qt.qeye(dims[k]) for k in others])
    permuted = input_state.permute(order)
    if permuted.isket:
        result = U_bs * permuted
    else:
        result = U_bs * permuted * U_bs.dag()
    return result.permute(inverse_order)

def apply_hadamard(state: qt.Qobj, target_idx: int) -> qt.Qobj:
    """
    Apply a Hadamard gate to subsystem ``target_idx`` (identity on all others).

    Kets are mapped as H|psi>; any other Qobj (density matrix) as H rho H^dag.
    """
    factors = list(map(qt.qeye, state.dims[0]))
    factors[target_idx] = qt.gates.snot()
    hadamard_full = qt.tensor(factors)

    if not state.isket:
        return hadamard_full * state * hadamard_full.dag()
    return hadamard_full * state

def apply_cnot(state: qt.Qobj, control_idx: int, target_idx: int) -> qt.Qobj:
    """
    Apply a CNOT between two qubit subsystems of a multi-partite state.

    Builds CNOT = |0><0|_c (x) I_t + |1><1|_c (x) X_t (identity elsewhere) and applies
    it as U|psi> for a ket or U rho U^dag for a density matrix.

    Raises
    ------
    ValueError
        If control and target are the same subsystem (the operator would not be
        unitary), or either subsystem is not two-dimensional.
    IndexError
        If an index is out of range.
    """
    dims = state.dims[0]
    all_positions = range(len(dims))
    control, target = all_positions[control_idx], all_positions[target_idx]
    if control == target:
        raise ValueError(f"CNOT needs distinct control and target, got {control_idx} and {target_idx}.")
    for idx in (control, target):
        if dims[idx] != 2:
            raise ValueError(f"Subsystem {idx} has dimension {dims[idx]}; CNOT acts on qubits (dimension 2).")

    # Branch where the control is |0>: identity on the target.
    op_list_0 = [qt.qeye(d) for d in dims]
    op_list_0[control] = qt.basis(2, 0).proj()

    # Branch where the control is |1>: X on the target.
    op_list_1 = [qt.qeye(d) for d in dims]
    op_list_1[control] = qt.basis(2, 1).proj()
    op_list_1[target] = qt.sigmax()

    CNOT_total = qt.tensor(op_list_0) + qt.tensor(op_list_1)
    if state.isket:
        return CNOT_total * state
    return CNOT_total * state * CNOT_total.dag()
    
def apply_swap(state: qt.Qobj, idx1: int, idx2: int) -> qt.Qobj:
    """
    Exchange the states of two qubit subsystems.

    Uses the three-CNOT decomposition SWAP = CNOT(1->2) CNOT(2->1) CNOT(1->2), via
    ``apply_cnot``, so it works on kets and density matrices alike. Swapping a
    subsystem with itself is the identity. In that case the input is returned
    unchanged, because the CNOT decomposition is undefined when control and target
    coincide.
    """
    positions = range(len(state.dims[0]))
    if positions[idx1] == positions[idx2]:
        return state
    for control, target in ((idx1, idx2), (idx2, idx1), (idx1, idx2)):
        state = apply_cnot(state, control, target)
    return state





# Two-channel-block variant at a very small Fock cutoff.
ideal_phi_plus = (qt.tensor(qt.basis(2,0), qt.basis(2,0)) + qt.tensor(qt.basis(2,1), qt.basis(2,1))).unit()
ideal_rho = qt.ket2dm(ideal_phi_plus)

# NOTE(review): N = 2 truncates each CV mode to {|0>, |1>}. Since |0> is the only even
# and |1> the only odd Fock state, the even/odd "cat" states are just |0> and |1> up to
# a phase, and the beamsplitter becomes a single-excitation model. Results at this
# cutoff are unlikely to reflect the continuous-variable physics; TODO check
# convergence in N.
N = 2
vertical_displacement = 1
loss_prob = 1E-2

NUM_CHANNEL_QUBITS = 2
num_tx_qubits = 2
num_rx_qubits = 1

# Channel block written inline: [qubit, CV mode, CV mode, qubit], all in |0>.
initial_channel_states = [qt.basis(2, 0), qt.basis(N, 0), qt.basis(N, 0), qt.basis(2, 0)]
states_per_channel_qubit = len(initial_channel_states)

# Start from a density matrix so every later gate uses U rho U^dag.
all_states = qt.ket2dm(qt.tensor([qt.basis(2,0)]*num_tx_qubits + initial_channel_states*NUM_CHANNEL_QUBITS + [qt.basis(2,0)]*num_rx_qubits))

def tx_qubit_positions() -> list[int]:
    """Register indices of the transmitter (TX) qubits, which occupy the front of the register."""
    return [index for index in range(num_tx_qubits)]
    
def channel_states_positions():
    """
    Register indices of every channel-block subsystem, block by block.

    The blocks are contiguous and start right after the TX qubits, so this is a single
    consecutive run of NUM_CHANNEL_QUBITS * states_per_channel_qubit indices.
    """
    first = len(tx_qubit_positions())
    span = NUM_CHANNEL_QUBITS * states_per_channel_qubit
    return list(range(first, first + span))

def rx_qubit_positions() -> list[int]:
    """Register indices of the receiver (RX) qubits, placed right after the last channel block."""
    first_rx = num_tx_qubits + NUM_CHANNEL_QUBITS * states_per_channel_qubit
    return [first_rx + offset for offset in range(num_rx_qubits)]

def ptrace_away_positions(states: qt.Qobj, positions: int | list[int]) -> qt.Qobj:
    """
    Trace out the subsystems at ``positions`` and return the reduced state of the rest.

    The kept subsystems retain their original relative order. ``positions`` may be a
    single index or an iterable of indices. Negative indices are interpreted
    Python-style; out-of-range indices raise IndexError.
    """
    if isinstance(positions, (int, np.integer)):
        positions = [positions]

    # Count subsystems on `states` itself rather than on a notebook global, so the
    # helper works on any state (the old version read `all_states.dims`).
    all_positions = range(len(states.dims[0]))
    traced = {all_positions[p] for p in positions}
    keep_indices = [i for i in all_positions if i not in traced]
    return states.ptrace(keep_indices)

# Bell pair between the tx qubits, then hand tx qubit 1 to the first channel block.
all_states = apply_hadamard(all_states, tx_qubit_positions()[0])
all_states = apply_cnot(all_states, tx_qubit_positions()[0], tx_qubit_positions()[1])

all_states = apply_swap(all_states, tx_qubit_positions()[1], channel_states_positions()[0])

# Per channel block, slots are (0) qubit, (1) CV mode, (2) loss mode, (3) decode qubit.
# Encode slot 0 onto slot 1.
for channel_qubit_index in range(NUM_CHANNEL_QUBITS):
    all_states = apply_cat_state_encoding(all_states, channel_states_positions()[channel_qubit_index*states_per_channel_qubit+0], channel_states_positions()[channel_qubit_index*states_per_channel_qubit+1], vertical_displacement, N)


# Loss: mix slot 1 with slot 2.
for channel_qubit_index in range(NUM_CHANNEL_QUBITS):
    all_states = beamsplitter_general(all_states, channel_states_positions()[channel_qubit_index*states_per_channel_qubit+1], channel_states_positions()[channel_qubit_index*states_per_channel_qubit+2], loss_prob)

# Decode slot 1 onto slot 3.
for channel_qubit_index in range(NUM_CHANNEL_QUBITS):
    all_states = apply_ideal_cat_state_decoding(all_states, channel_states_positions()[channel_qubit_index*states_per_channel_qubit+3], channel_states_positions()[channel_qubit_index*states_per_channel_qubit+1], vertical_displacement, N)

# NOTE(review): with NUM_CHANNEL_QUBITS = 2 only block 0 carries the tx qubit. Block 1's
# input qubit stays |0>, and its decode qubit is never swapped out, so it is simply traced
# away. The two blocks are not chained into a two-hop link; TODO if that was the intent.
all_states = apply_swap(all_states, channel_states_positions()[states_per_channel_qubit-1], rx_qubit_positions()[0])
edge_qubits = all_states.ptrace([tx_qubit_positions()[0], rx_qubit_positions()[0]])

print(edge_qubits)
fid = qt.fidelity(edge_qubits, ideal_rho)
print(f"Fidelity with Phi+ (ideal): {fid:.4f}")

Quantum object: dims=[[2, 2], [2, 2]], shape=(4, 4), type='oper', dtype=Dense, isherm=True
Qobj data =
[[0.5        0.         0.         0.49749372]
 [0.         0.         0.         0.        ]
 [0.         0.         0.005      0.        ]
 [0.49749372 0.         0.         0.495     ]]
Fidelity with Phi+ (ideal): 0.9975
