In [2]:
import scipy
import numpy as np
from scipy.stats import unitary_group
import itertools

In [3]:
U = unitary_group.rvs(6)

In [4]:
def get_alpha(index: list, unitary: np.ndarray):

    input_mode_occupations, output_mode_occupations = index[0:3] + [1,1,1], index[3:6] + [1,1,1]

    n_input = sum(input_mode_occupations)
    n_output = sum(output_mode_occupations)

    if n_input != n_output:
        return 0

    occupied_input_modes = [index for index, occupation in enumerate(input_mode_occupations) if occupation==1]

    idk_how_to_name_this = []
    for mode, occupation in enumerate(output_mode_occupations):
        for _ in range(occupation):
            idk_how_to_name_this.append(mode)

    permutations = list(itertools.permutations(idk_how_to_name_this))

    alpha = 0

    for permutation in permutations:
        poly = 1
        for index, mode in enumerate(occupied_input_modes):
            poly *= unitary[mode, permutation[index]]

        alpha += poly
    
    return alpha

In [19]:
def partition_min_max(n, k, l, m):
    """
    n: The integer to partition
    k: The length of partitions
    l: The minimum partition element size
    m: The maximum partition element size
    """
    if k < 1:
        return []
    if k == 1:
        if l <= n <= m:
            return [(n,)]
        return []
    result = []
    for i in range(l, m + 1):
        sub_partitions = partition_min_max(n - i, k - 1, i, m)
        for sub_partition in sub_partitions:
            result.append(sub_partition + (i,))
    return result

def get_partitions_permutations(n, k):
    partitions = partition_min_max(n, k, 0, n)

    permutations = []
    for partition in partitions:
        permutations.extend(list(itertools.permutations(partition)))
    
    return map(list, list(set(permutations)))
            
# Example usage:
N = 3
M = 3
# partitions = list(partition_min_max(N, M, 0, N))
permutations = get_partitions_permutations(N, M)
# print(partitions)
print(permutations)

[[0, 2, 1], [0, 0, 3], [3, 0, 0], [2, 1, 0], [1, 2, 0], [0, 3, 0], [2, 0, 1], [0, 1, 2], [1, 0, 2], [1, 1, 1]]


In [20]:
def loss_function_ccz_dual_rail(U):
    desired_gate_loss = 0
    + np.abs(get_alpha([0,0,0,0,0,0], U) - get_alpha([0,1,1,0,1,1], U))**2 
    + np.abs(get_alpha([0,1,1,0,1,1], U) - get_alpha([1,0,1,1,0,1], U))**2
    + np.abs(get_alpha([1,0,1,1,0,1], U) - get_alpha([1,1,0,1,1,0], U))**2
    + np.abs(get_alpha([1,1,0,1,1,0], U) + get_alpha([0,0,1,0,0,1], U))**2
    + np.abs(get_alpha([0,0,1,0,0,1], U) - get_alpha([0,1,0,0,1,0], U))**2
    + np.abs(get_alpha([0,1,0,0,1,0], U) - get_alpha([1,0,0,1,0,0], U))**2
    + np.abs(get_alpha([1,0,0,1,0,0], U) - get_alpha([1,1,1,1,1,1], U))**2

    undesired_gate_loss = 0

    for input_state in [[0,0,0], [0,0,1], [0,1,0], [0,1,1], [1,0,0], [1,0,1], [1,1,0], [1,1,1]]:
        
        particle_number = np.sum(input_state)

        output_states = get_partitions_permutations(n=particle_number, k=3)

        for output_state in output_states:
            if input_state != output_state:
                
                index=input_state + list(output_state) 
                undesired_gate_loss += np.abs(get_alpha(index=index, unitary=U))**2

    loss = desired_gate_loss + undesired_gate_loss

    return loss

In [7]:
def get_success_prob(U):
    return np.abs(get_alpha(index=[0,0,0,0,0,0], unitary=U))**2

In [24]:
import time

best_loss = np.infty
best_prob = 0
best_U = None

timer = time.time()

try:
    for _ in range(int(1e8)):

        U = reunitary_group.rvs(6)

        loss = loss_function_ccz_dual_rail(U=U)
        prob = get_success_prob(U)

        if loss < best_loss:
            best_loss = loss
            best_prob = prob
            best_U = U

    print("Calculated in", time.time()-timer, "seconds")
    print(best_loss, best_prob)
    print(U)
except KeyboardInterrupt:
    print("Calculated in", time.time()-timer, "seconds")
    print(best_loss, best_prob)
    print(U)

Calculated in 1995.5432305335999 seconds
0.01824173501706873 0.0011346220744473752
[[ 0.01401236-0.26330385j  0.03953893+0.01694316j  0.03437104+0.18470981j
   0.39745623+0.19229393j -0.02613052+0.58432327j -0.276804  +0.52880989j]
 [-0.27928651+0.13308178j -0.10915402-0.14853236j  0.4194344 +0.44481936j
  -0.17934838-0.10352693j -0.03598381-0.42711186j -0.39133817+0.34172173j]
 [-0.25229345+0.33671837j -0.13855627+0.50099044j -0.18555919+0.08898744j
  -0.01356569-0.5732786j  -0.28901845+0.29148239j -0.11138537-0.02637775j]
 [-0.2667837 -0.40637399j -0.190603  +0.39918127j  0.52111967-0.05218217j
  -0.21943342+0.16685237j  0.15340581+0.21828107j -0.07127902-0.37612678j]
 [-0.08463353+0.17307007j -0.30701702-0.61752428j  0.05843474-0.34484659j
  -0.23341558-0.18963529j  0.12744026+0.41029072j -0.28186077-0.10239669j]
 [ 0.52468438-0.33016455j  0.05387608+0.14074934j  0.18913932-0.34082886j
   0.10190986-0.50647291j  0.05275128-0.21709058j -0.3402899 +0.09185864j]]


In [145]:
print(best_loss, best_prob)
print(U)

0.07625325382063168 0.004201985572322366
[[ 0.35386082+0.09370338j -0.11781769+0.1126784j  -0.59085513+0.12103865j
  -0.24731624-0.22348081j  0.34500912-0.437993j    0.23122927+0.01481521j]
 [ 0.05147172-0.34501342j -0.61937862-0.20981538j  0.15924725-0.19038506j
   0.40548284+0.04569049j -0.04232555-0.28167285j  0.36699666-0.08207511j]
 [-0.00536796-0.40540113j -0.1896524 -0.03166948j -0.47094997-0.07977378j
  -0.22194236+0.47062341j  0.08917585+0.51562436j  0.00107132-0.16101437j]
 [ 0.2145497 -0.32452371j -0.33890926-0.06335179j  0.16423531+0.05956836j
  -0.16663138-0.28107991j  0.1844601 -0.02745294j -0.73554127+0.12917484j]
 [ 0.52238862+0.32440679j -0.32911184+0.50325446j  0.21671374+0.18323404j
   0.11193591+0.20973073j -0.04669289+0.30659866j  0.07758255+0.1450259j ]
 [ 0.21748837-0.05622874j  0.16668428-0.02372708j -0.47853037-0.08014661j
   0.53096638+0.08017722j -0.43980321-0.09894416j -0.30082001+0.32204433j]]


In [24]:
permutations = list(itertools.permutations([1, 2, 3]))
print(permutations)

[(1, 2, 3), (1, 3, 2), (2, 1, 3), (2, 3, 1), (3, 1, 2), (3, 2, 1)]
