# Relaxation Problem 2 for Branch Cut

Given integer $N$ and array $P$ of length $2^N$.

We want to evaluate the answer of the following problem:
$$
    \max_{Q \in \mathbb{F}_2^{n \times n}, c \in \mathbb{F}_2^{n}} (-1)^{x^\top Q x} i^{c^\top x} P_x
$$

## naive method

In [11]:
import math
from tqdm.auto import tqdm
from numba import njit


@njit(cache=True)
def xQc_to_coeff(k, x, Q, c):
    cx = 0
    xQx = 0
    Q_idx = 0
    for i in range(k):
        cx ^= ((x >> i) & 1) * ((c >> i) & 1)
        for j in range(i, k):
            if ((Q >> Q_idx) & 1) & ((x >> i) & 1) & ((x >> j) & 1):
                xQx ^= 1
            Q_idx += 1
    return (-1) ** xQx * (complex(0, 1) ** cx)


@njit(cache=True)
def calc_abs_from_bQc(k, b, Q, c):
    ret = complex(0.0, 0.0)
    for x in range(1 << k):
        ret += xQc_to_coeff(k, x, Q, c) * b[x]
    return abs(ret)


def solveSlow(N: int, Xs: list):
    assert len(Xs) == 1 << N
    k = N
    maxAbs = 0.0
    for Q in tqdm(range(1 << (k * (k + 1) // 2))):
        for c in range(1 << k):
            absVal = calc_abs_from_bQc(k, Xs, Q, c)
            maxAbs = max(maxAbs, absVal)
    return maxAbs

## upper bound 1

In [12]:
def threshold1(N: int, Xs: list):
    return sum(np.abs(Xs))

## upper bound 2

Please refer to `relaxation_problem_1_for_branch_cut.ipynb` for detail.

In [13]:
def rotate(x):
    if x.imag >= 0:
        if x.real >= 0:
            return x
        else:
            return x * -1j
    else:
        if x.real <= 0:
            return -x
        else:
            return x * 1j


def threshold2(N: int, Xs: list):
    assert len(Xs) == 1 << N
    rotated_Xs = [rotate(X) for X in Xs]
    Ys = rotated_Xs.copy()
    Ys.sort(key=lambda Y: (Y.imag / Y.real) if Y.real != 0 else 1e9)
    sumYsReal = sum(Y.real for Y in Ys)
    sumYsImag = sum(Y.imag for Y in Ys)
    maxAbs2 = sumYsReal**2 + sumYsImag**2
    for i in range((1 << N) - 1):
        sumYsReal += -Ys[i].real - Ys[i].imag
        sumYsImag += -Ys[i].imag + Ys[i].real
        absVal2 = sumYsReal**2 + sumYsImag**2
        maxAbs2 = max(maxAbs2, absVal2)
        Ys[i] *= 1j
    return maxAbs2**0.5

## upper bound 3

In [14]:
import matplotlib.pyplot as plt



def vis(N: int, Xs: list, title=""):
    for i in range(1 << N):
        plt.plot([0, Xs[i].real], [0, Xs[i].imag], label=f"{i}")
    plt.gca().set_aspect("equal", adjustable="box")
    if N <= 3:
        plt.legend()


    plt.title(title)
    plt.show()

In [21]:
from exputils.math.popcount import popcount


def solveFast(N: int, Xs: list):
    assert len(Xs) == 1 << N
    maxAbs2 = 0.0
    Ys = [0.0] * (1 << N)
    for c in range(1 << N):
        for i in range(1 << N):
            Ys[i] = (1j ** (popcount(i & c) & 1)) * Xs[i]
            if Ys[i].imag < 0:
                Ys[i] *= -1
        Ys.sort(key=lambda Y: -(Y.real / Y.imag) if Y.imag != 0 else 1e9)
        # vis(N, Ys)
        sumYsReal = sum(Y.real for Y in Ys)
        sumYsImag = sum(Y.imag for Y in Ys)
        maxAbs2 = max(maxAbs2, sumYsReal**2 + sumYsImag**2)
        for i in range((1 << N) - 1):
            sumYsReal += -Ys[i].real - Ys[i].imag
            sumYsImag += -Ys[i].imag + Ys[i].real
            absVal2 = sumYsReal**2 + sumYsImag**2
            maxAbs2 = max(maxAbs2, absVal2)
            Ys[i] *= -1
            # vis(N, Ys)
    return maxAbs2**0.5

## check performance

In [22]:
import numpy as np

N = 8
Xs = (np.random.normal(size=1 << N) + 1j * np.random.normal(size=1 << N)).tolist()

# slow = solveSlow(N, Xs)
# print(f"{slow=}")
t1 = threshold1(N, Xs)
print(f"{t1=}")
t2 = threshold2(N, Xs)
print(f"{t2=}")
fast = solveFast(N, Xs)
print(f"{fast=}")

t1=318.161949826417
t2=288.25320521966063
fast=286.9314519884549
