# Декомпозиция Клементса

In [1]:
"""
Clements unitary decomposition
"""

from typing import NamedTuple, List, Tuple
import torch
import numpy as np
from torch import Tensor


def mzi_matrix(
    dim: int,
    target: Tuple[int, int],
    theta: float,
    phi: float,
    dtype=torch.complex128,
    device='cpu'
) -> Tensor:
    """Matrix form of a Mach-Zehnder interferometer using PyTorch."""
    m, n = target
    if n - m != 1:
        raise ValueError("m and n must be consecutive integers.")

    u_mzi = torch.eye(dim, dtype=dtype, device=device)

    theta_t = torch.tensor(theta, dtype=torch.float32, device=device, requires_grad=False)
    phi_t = torch.tensor(phi, dtype=torch.float32, device=device, requires_grad=False)

    global_phase = torch.exp(1j * theta_t)
    cos_t = torch.cos(theta_t)
    sin_t = torch.sin(theta_t)
    exp_phi = torch.exp(1j * phi_t)

    u_mzi[m, m] = global_phase * exp_phi * cos_t
    u_mzi[m, n] = global_phase * 1j * sin_t
    u_mzi[n, m] = global_phase * exp_phi * 1j * sin_t
    u_mzi[n, n] = global_phase * cos_t

    return u_mzi


class MachZehnder(NamedTuple):
    theta: float
    phi: float
    target: Tuple[int, int]


class Decomposition(NamedTuple):
    D: Tensor
    circuit: List[MachZehnder]


def _is_unitary(U: Tensor, atol=1e-6) -> bool:
    I = torch.eye(U.shape[0], dtype=U.dtype, device=U.device)
    return torch.allclose(U @ U.conj().T, I, atol=atol)

def clements_decomposition(U: Tensor) -> Decomposition:
    if not isinstance(U, torch.Tensor):
        raise TypeError("U must be a torch.Tensor.")
    if not _is_unitary(U):
        raise ValueError("U must be a unitary matrix.")

    dim = U.shape[0]
    dtype = U.dtype
    device = U.device

    right_sequence: List[MachZehnder] = []
    left_sequence: List[MachZehnder] = []
    U = U.clone()

    for k in range(1, dim):
        if k % 2 == 1:
            for i, j in zip(range(dim - 1, dim - k - 1, -1), range(k - 1, -1, -1)):
                if torch.abs(U[i, j]) < 1e-12:
                    continue
                if torch.abs(U[i, j + 1]) < 1e-12:
                    phi = 0.0
                else:
                    phi = (torch.angle(U[i, j] / U[i, j + 1]) - torch.pi / 2) % (2 * torch.pi)
                theta = torch.arctan2(torch.abs(U[i, j]), torch.abs(U[i, j + 1]))
                right_sequence.append(MachZehnder(theta, phi, (j, j + 1)))
                mzi = mzi_matrix(dim, (j, j + 1), theta, phi, dtype=dtype, device=device)
                U = U @ mzi.conj().T
        else:
            for i, j in zip(range(dim - k, dim), range(0, k)):
                if torch.abs(U[i, j]) < 1e-12:
                    continue
                if torch.abs(U[i - 1, j]) < 1e-12:
                    phi = 0.0
                else:
                    phi = (torch.angle(U[i, j] / U[i - 1, j]) + torch.pi / 2) % (2 * torch.pi)
                theta = torch.arctan2(torch.abs(U[i, j]), torch.abs(U[i - 1, j]))
                left_sequence.insert(0, MachZehnder(theta, phi, (i - 1, i)))
                mzi = mzi_matrix(dim, (i - 1, i), theta, phi, dtype=dtype, device=device)
                U = mzi @ U

    D = torch.diag(U).clone()
    new_left_sequence: List[MachZehnder] = []

    for theta, phi, (m, n) in left_sequence:
        new_phi = (torch.pi + torch.angle(D[m]) - torch.angle(D[n])) % (2 * torch.pi)
        new_left_sequence.insert(0, MachZehnder(theta, new_phi, (m, n)))
        new_beta = torch.angle(D[n]) - 2 * theta
        new_alpha = torch.angle(D[n]) - torch.pi - phi - 2 * theta
        D[m] = torch.exp(1j * new_alpha)
        D[n] = torch.exp(1j * new_beta)

    return Decomposition(D=D, circuit=new_left_sequence + right_sequence[::-1])

def circuit_reconstruction(decomposition: Decomposition) -> Tensor:
    if not isinstance(decomposition, Decomposition):
        raise TypeError("decomposition must be a Decomposition object.")

    D = decomposition.D
    dim = D.shape[0]
    dtype = D.dtype
    device = D.device

    reconstructed_matrix = torch.diag(D)

    for theta, phi, (m, n) in decomposition.circuit:
        mzi = mzi_matrix(dim, (m, n), theta, phi, dtype=dtype, device=device)
        reconstructed_matrix = reconstructed_matrix @ mzi

    return reconstructed_matrix

# Разложение Река

# Проверка

In [3]:
torch.manual_seed(42)

def fidellity(U_exp: Tensor, U: Tensor):
    return torch.abs(torch.einsum('ii', U_exp.adjoint() @ U)).item()/U.shape[0]

def haar_measure(N):
    """Generate a Haar-random matrix using the QR decomposition."""
    A, B = torch.randn(size=(N, N), dtype=torch.float64), torch.randn(size=(N, N), dtype=torch.float64)
    Z = A + 1j * B
    Q, R = torch.linalg.qr(Z)
    Lambda = torch.diag(torch.tensor([R[i, i] / torch.abs(R[i, i]) for i in range(N)]))

    return Q @ Lambda

N = torch.randint(low=2, high=20, size = ())
U = haar_measure(N)

print("Initial Unitary:\n", U.numpy().round(4))

# Apply Clements decomposition
decomposition = clements_decomposition(U)

# Print the decomposition parameters
print("\nCircuit:\n")
num_layer = 1
temp = 0
print(f'Layer {num_layer}\n')
for theta, phi, target in decomposition.circuit:
    if temp > target[0]:
        num_layer += 1
        print(f'\nLayer {num_layer}\n')
    print(f"theta: {theta:.3f}, phi: {phi:.2f}, target: {target}")
    temp = target[0]

# Reconstruct the unitary from the decomposition
reconstructed_unitary = circuit_reconstruction(decomposition)

# Print the reconstructed unitary and assert initial matrix
print("\nReconstructed Unitary:\n", reconstructed_unitary.numpy().round(4))
print("\nFidellity:\n", fidellity(reconstructed_unitary, U))
assert torch.allclose(U, reconstructed_unitary, atol=1e-06), "Reconstructed unitary does not match original."

Initial Unitary:
 [[-0.2006+0.0945j  0.5269-0.099j  -0.2327+0.2757j  0.008 +0.1147j
  -0.0506+0.092j  -0.1041-0.2304j -0.3972+0.4805j -0.0954-0.2177j]
 [-0.1295-0.1587j  0.1815+0.1986j -0.0996-0.1983j -0.3697+0.4263j
  -0.1729-0.019j  -0.4346+0.3077j  0.2055+0.1765j  0.1376+0.3344j]
 [-0.3724-0.5003j -0.1125-0.2725j -0.0402+0.2994j -0.3849-0.0267j
   0.393 +0.2562j  0.1423+0.1155j -0.0635-0.1336j -0.0753+0.0523j]
 [-0.1528+0.2489j -0.0655+0.3764j  0.1486+0.1059j -0.2019-0.1234j
  -0.3107+0.3058j  0.3503+0.3919j  0.0303+0.1637j -0.4256-0.0662j]
 [-0.0238-0.1727j -0.3071+0.078j   0.1853-0.2987j  0.0207+0.0415j
   0.0188+0.4263j  0.1027-0.1532j  0.0292+0.4863j  0.4608-0.2788j]
 [ 0.1624+0.2351j  0.2311-0.0187j  0.1415-0.227j  -0.1423-0.3131j
   0.5125-0.1013j -0.0614+0.5134j -0.2805+0.1268j  0.1911-0.0567j]
 [-0.0476+0.3983j -0.4043-0.1464j -0.6899-0.1172j -0.0803-0.118j
   0.0564+0.1781j -0.0128-0.0217j -0.0923+0.1006j  0.0269+0.2987j]
 [ 0.3913+0.092j  -0.2228-0.1365j  0.0504-0.0933j -0

  theta_t = torch.tensor(theta, dtype=torch.float32, device=device, requires_grad=False)
  phi_t = torch.tensor(phi, dtype=torch.float32, device=device, requires_grad=False)
