In [38]:
import numpy as np
import matplotlib.pyplot as plt


In [39]:
import numpy as np
from qiskit.quantum_info import Statevector

def maximally_entangled_state(d):
    state_vector = np.zeros(d**2, dtype=complex)
    for i in range(d):
        state_vector[i * d + i] = 1  # |i>|i> components
    state = (state_vector / np.sqrt(d))
    return np.expand_dims(state, axis = 1)

d = 2
state = maximally_entangled_state(d)
# This is |\phi^+><\phi^+|
print(state @ state.T.conj())

[[0.5+0.j 0. +0.j 0. +0.j 0.5+0.j]
 [0. +0.j 0. +0.j 0. +0.j 0. +0.j]
 [0. +0.j 0. +0.j 0. +0.j 0. +0.j]
 [0.5+0.j 0. +0.j 0. +0.j 0.5+0.j]]


In [40]:
a = 0.2-0.3j
if np.abs(a) > 1:
	raise ValueError("a must be less than 1 in magnitude")
b = np.sqrt(1-a**2)
p = prob_matrix(a, b)

In [75]:
prob_matrix = lambda a, b: np.array([
    [(a)**2, np.conjugate(a)*b], 
    [b*np.conjugate(a), (b)**2]
])

pauliX = np.array([
	[0, 1],
	[1, 0]
])

pauliI = np.array([
	[1, 0],
	[0, 1]
])
bit_flip_chainnel = lambda p: pauliX @ p @ pauliX

a = 0.2-0.3j
if np.abs(a) > 1:
	raise ValueError("a must be less than 1 in magnitude")
b = np.sqrt(1-a**2)
# p = prob_matrix(a, b)
p = np.array([
	[0.5, 0.5],
	[0.5, 0.5]
])
print("Initial: ", p)
p1 = bit_flip_chainnel(p)
print("Channel: ", p1)
phi_plus = maximally_entangled_state(2)
choi_matrix = (np.kron(pauliI, p1)) @ (state @ state.T.conj())
print("Choi matrix:", np.round(choi_matrix, 2))


Initial:  [[0.5 0.5]
 [0.5 0.5]]
Channel:  [[0.5 0.5]
 [0.5 0.5]]
Choi matrix: [[0.25+0.j 0.  +0.j 0.  +0.j 0.25+0.j]
 [0.25+0.j 0.  +0.j 0.  +0.j 0.25+0.j]
 [0.25+0.j 0.  +0.j 0.  +0.j 0.25+0.j]
 [0.25+0.j 0.  +0.j 0.  +0.j 0.25+0.j]]


In [78]:
J = 1/2*np.array([
	[0, 0, 0, 0],
	[0, 1, 1, 0],
	[0, 1, 1, 0],
	[0, 0, 0, 0]
])

from pennylane import math
J_inverse = np.linalg.pinv(J)
p2 = math.partial_trace(np.kron(pauliI, p1.T) @ J_inverse, [0])

print("Recover:", np.round(p2, 2))

Recover: [[0.25+0.j 0.25+0.j]
 [0.25+0.j 0.25+0.j]]


In [79]:
print("Original:", np.round(p, 2))

Original: [[0.5 0.5]
 [0.5 0.5]]


In [None]:
matrix = np.array(
	[[1,2,3,4],
 [5,6,7,8],
 [9,10,11,12],
 [13,14,15,16]]
)

math.partial_trace(matrix, indices = [0])

array([[12.+0.j, 14.+0.j],
       [20.+0.j, 22.+0.j]])