In [1]:
import torch
from torch.nn.functional import conv1d, pad
from torch.fft import fft
from torchaudio.transforms import FFTConvolve
import time
import numpy as np

# Select the compute device once for the whole notebook: the first CUDA GPU
# visible to PyTorch when there is one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

def objective_torch(x, P):
    """Return the squared residual between ``P`` and the autocorrelation of ``x``.

    Computes ``sum((P - conv_full(x, flip(x)))**2)``, where ``conv_full`` is
    torchaudio's full-length FFT convolution. Convolving ``x`` with its own
    reversal gives the autocorrelation of ``x``.

    Args:
        x: 1-D tensor being optimized. If it does not already require grad,
            gradient tracking is switched on in place so that
            ``loss.backward()`` populates ``x.grad``.
        P: Tensor subtracted elementwise from the convolution. In this
            notebook it is the length ``2 * len(x) - 1`` target built by the
            driver loop; other callers must supply a compatible shape.

    Returns:
        0-dim tensor holding the sum of squared residuals.
    """
    # Only flip the flag when it is off.
    # Assigning ``requires_grad`` on a non-leaf tensor raises RuntimeError,
    # and such tensors always require grad already.
    if not x.requires_grad:
        x.requires_grad_(True)

    # Call the module (not ``.forward``) so any registered nn.Module hooks run.
    autocorr = FFTConvolve("full")(x, torch.flip(x, dims=[0]))
    residual = P - autocorr
    # Sum of squares directly, rather than squaring a square-rooted norm.
    return residual.pow(2).sum()

def _synchronize(dev):
    """Wait for all queued kernels on ``dev`` to finish (no-op on CPU).

    CUDA launches are asynchronous, so wall-clock timing without a sync can
    stop the clock before the GPU work has actually completed.
    """
    if dev.type == "cuda":
        torch.cuda.synchronize(dev)


times = []
final_vals = []
num_iterations = []

# Number of equispaced unit-circle points at which |poly| is sampled to
# estimate its maximum modulus. It is loop-invariant.
# It must be >= the largest N below (2**19), or the FFT length would
# truncate the coefficients.
granularity = 2 ** 25

# Full-length FFT convolution (output length 2N - 1); stateless, so built once.
autoconv = FFTConvolve("full")

for k in range(4, 20):
    N = 2 ** k
    poly = torch.randn(N, device=device)

    # Scale poly so the sampled max of |poly(e^{i*theta})| is 1.
    # ``rfft`` zero-pads to ``n`` itself. For real input its half-spectrum
    # contains every distinct magnitude of the full FFT, so the max is
    # unchanged.
    # Neither the 2**25-element padded copy nor the redundant conjugate half
    # is ever materialized.
    P_norms = torch.fft.rfft(poly, n=granularity).abs()
    poly /= torch.max(P_norms)

    # Target for Q's autocorrelation: the negated autocorrelation of poly,
    # with the centre coefficient (index N - 1) set to 1 - ||poly||^2.
    conv_p_negative = autoconv(poly, torch.flip(poly, dims=[0])) * -1
    conv_p_negative[N - 1] = 1 - torch.norm(poly) ** 2

    # Random unit-norm starting point for Q. Normalizing a grad-free tensor
    # yields a fresh leaf, which can then be marked as an optimizer parameter.
    initial = torch.randn(N, device=device)
    initial = (initial / torch.norm(initial)).requires_grad_(True)

    optimizer = torch.optim.LBFGS([initial], max_iter=1000)

    def closure():
        optimizer.zero_grad()
        loss = objective_torch(initial, conv_p_negative)
        loss.backward()
        return loss

    _synchronize(device)
    t0 = time.perf_counter()
    optimizer.step(closure)
    _synchronize(device)
    total = time.perf_counter() - t0

    # Evaluate the final objective once.
    # Under no_grad this records no graph and leaves gradients untouched.
    with torch.no_grad():
        final = objective_torch(initial, conv_p_negative).item()
    # LBFGS stores its counters in the state of its single parameter.
    n_iter = optimizer.state[initial]['n_iter']

    times.append(total)
    final_vals.append(final)
    num_iterations.append(n_iter)
    print(f'N: {N}')
    print(f'Time: {total}')
    print(f'Final: {final}')
    print(f"# Iterations: {n_iter}")
    print("-----------------------------------------------------")

print(times)
print(final_vals)
print(num_iterations)


N: 16
Time: 0.3235609531402588
Final: 1.030339717544848e-06
# Iterations: 71
-----------------------------------------------------
N: 32
Time: 1.238006830215454
Final: 4.154373164055869e-05
# Iterations: 201
-----------------------------------------------------
N: 64
Time: 3.0356409549713135
Final: 3.266644199584334e-08
# Iterations: 435
-----------------------------------------------------
N: 128
Time: 4.42612099647522
Final: 6.429892891901545e-06
# Iterations: 641
-----------------------------------------------------
N: 256
Time: 6.6678760051727295
Final: 5.054324446973624e-06
# Iterations: 880
-----------------------------------------------------
N: 512
Time: 4.485189914703369
Final: 5.584213795373216e-06
# Iterations: 587
-----------------------------------------------------
N: 1024
Time: 9.548364162445068
Final: 2.3808004243619507e-06
# Iterations: 1000
-----------------------------------------------------
N: 2048
Time: 9.351141214370728
Final: 3.6368119253893383e-06
# Iterations:

KeyboardInterrupt: 