In [None]:
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots


# Sampling grid: DAC rate (Hz), record length, sample period and time axis (s)
DAC_SR = 4e9
buf_len = 2**14
dt = 1.0 / DAC_SR
X_axis = dt * np.arange(buf_len)
# Centered frequency axis (Hz), in the same fftshift order as the spectra below
freq = np.fft.fftshift(np.fft.fftfreq(buf_len, d=dt))

# Time-domain amplitude constraint: a flat envelope of height E0
E0 = 1.0
G0_t = np.full(buf_len, E0)

# Target spectral lines: carrier plus sidebands at +840 MHz and -1330 MHz
line_freqs = np.array([0.0, 840e6, -1330e6])   # Hz
# Power ratios carrier:+840:-1330 = 5:3:1, converted to field amplitudes.
# If the ratios are already amplitudes, drop the square root.
power_ratios = np.array([5.0, 3.0, 1.0])
field_amps = np.sqrt(power_ratios)

# Each line is modelled as a Gaussian of standard deviation sigma_hz
sigma_hz = 1.0e6   # 1 MHz width (tune as needed)
target_amp = np.zeros_like(freq)
for center_hz, line_amp in zip(line_freqs, field_amps):
    detuning = (freq - center_hz) / sigma_hz
    target_amp += line_amp * np.exp(-0.5 * detuning**2)

# Scale the target so its L2 norm equals that of the spectrum of G0_t. The GSA
# field below always has |field(t)| == G0_t, so by Parseval its spectrum has
# exactly this norm; a target with any other norm could never be matched.
initial_field_w = np.fft.fftshift(np.fft.fft(np.fft.ifftshift(G0_t)))
norm_time_energy = np.linalg.norm(initial_field_w)
norm_target = np.linalg.norm(target_amp)
if norm_target == 0:
    raise ValueError("target amplitude is zero (check line frequencies and sigma).")
target_amp *= norm_time_energy / norm_target

# -----------------------
# GSA iterative loop
# -----------------------
# Gerchberg–Saxton: alternate between the time-domain constraint (amplitude
# fixed to G0_t, phase free) and the frequency-domain constraint (amplitude
# fixed to target_amp, phase free) until the spectrum matches the target.
phi_t = np.random.uniform(0, 2*np.pi, buf_len)  # initial random phase
max_iters = 1000
tol = 1e-6
target_amp_norm = np.linalg.norm(target_amp)  # loop invariant for rel_err
err = np.inf  # keeps the final report well defined even when max_iters == 0

for it in range(max_iters):
    # impose time amplitude constraint: known amplitude G0_t, current phase
    field_t = G0_t * np.exp(1j * phi_t)

    # forward: time -> freq
    field_w = np.fft.fftshift(np.fft.fft(np.fft.ifftshift(field_t)))
    amp_w = np.abs(field_w)
    phase_w = np.angle(field_w)

    # impose frequency amplitude constraint
    new_field_w = target_amp * np.exp(1j * phase_w)

    # inverse: freq -> time. Only the phase is kept; the amplitude is reset
    # to G0_t at the top of the next iteration.
    new_field_t = np.fft.fftshift(np.fft.ifft(np.fft.ifftshift(new_field_w)))
    phi_t = np.angle(new_field_t)

    # convergence metric: relative L2 mismatch between the spectrum achieved
    # at the start of this iteration and the target
    err = np.linalg.norm(amp_w - target_amp) / target_amp_norm
    if (it % 50 == 0) or (err < 1e-4):
        print(f"iter {it:4d}  rel_err={err:.3e}")
    if err < tol:
        print(f"Converged at iter {it}, rel_err={err:.3e}")
        break
else:
    print(f"Finished {max_iters} iters; final rel_err={err:.3e}")

# final recovered objects
final_phi_t = phi_t.copy()
final_field_t = G0_t * np.exp(1j * final_phi_t)
final_field_w = np.fft.fftshift(np.fft.fft(np.fft.ifftshift(final_field_t)))

# -----------------------
# Plotly figures
# -----------------------
# Sample times in ns for the time-domain panels (X_axis is the sample-time
# grid in seconds built during setup).
t_ns = X_axis * 1e9

fig = make_subplots(rows=2, cols=2,
                    subplot_titles=(
                        "Time amplitude G0(t)",
                        "Frequency amplitude (target vs achieved)",
                        "Recovered phase φ(t)",
                        "Frequency-domain phase of final field"
                    ))

# (1,1) Time amplitude G0(t)
fig.add_trace(go.Scatter(x=t_ns, y=G0_t, name="G0(t)"), row=1, col=1)
fig.update_xaxes(title_text="Time (ns)", row=1, col=1)
fig.update_yaxes(title_text="Amplitude", row=1, col=1)

# (1,2) Frequency amplitude target vs achieved
fig.add_trace(go.Scatter(x=freq / 1e9, y=target_amp, name="target amp"), row=1, col=2)
fig.add_trace(go.Scatter(x=freq / 1e9, y=np.abs(final_field_w), name="achieved amp",
                          line=dict(dash='dash')), row=1, col=2)
fig.update_xaxes(title_text="Frequency (GHz)", range=[-3, 3], row=1, col=2)
fig.update_yaxes(title_text="Amplitude", row=1, col=2)

# (2,1) Recovered phase φ(t) (wrapped, as returned by np.angle)
fig.add_trace(go.Scatter(x=t_ns, y=final_phi_t, name="φ(t)"), row=2, col=1)
fig.update_xaxes(title_text="Time (ns)", row=2, col=1)
fig.update_yaxes(title_text="Phase (rad)", row=2, col=1)

# (2,2) Frequency-domain phase of final field
fig.add_trace(go.Scatter(x=freq / 1e9, y=np.angle(final_field_w), name="phase"), row=2, col=2)
fig.update_xaxes(title_text="Frequency (GHz)", row=2, col=2)
fig.update_yaxes(title_text="Phase (rad)", row=2, col=2)

fig.update_layout(height=800, width=1100, title_text="Gerchberg–Saxton Results")
fig.show()

# -----------------------
# EOM drive waveform (what to send to the modulator)
# -----------------------
# The phase drive is proportional to φ(t). Without Vπ, show a zero-mean, unit-max waveform.
# NOTE(review): final_phi_t comes from np.angle, so it is wrapped to [-π, π];
# its 2π jumps are kept as-is here. Unwrap first if the modulator/driver
# cannot reproduce those resets.
eom_drive = final_phi_t - np.mean(final_phi_t)
maxabs = np.max(np.abs(eom_drive))
eom_drive_norm = eom_drive / maxabs if maxabs > 0 else eom_drive

# Sample times in ns (X_axis is the sample-time grid in seconds from setup)
t_ns = X_axis * 1e9

fig_eom = go.Figure()
fig_eom.add_trace(go.Scatter(x=t_ns, y=eom_drive_norm, name="Normalized EOM drive"))
fig_eom.update_layout(title="Final EOM Drive Waveform (normalized)",
                      xaxis_title="Time (ns)",
                      yaxis_title="Normalized voltage (arb.)")
fig_eom.show()

In [None]:
from __future__ import annotations

import time
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import sawtooth, square


def _to_db(x: np.ndarray) -> np.ndarray:
    """Return ``10 * log10(x)`` with a 1e-20 floor so empty bins stay finite.

    NOTE(review): the inputs are FFT *magnitudes*, for which dB is usually
    20*log10; 10*log10 is kept so the plots match the existing convention.
    """
    return 10 * np.log10(x + 1e-20)


def _concatenated_serrodyne_phase(t: np.ndarray) -> np.ndarray:
    """Return the target modulation phase, in units of pi, sampled at times *t* (s).

    Three sawtooth ramps (250 Hz, 900 Hz and -100 Hz) are each gated on by a
    5 Hz square wave for a fraction p1, p2, p3 (summing to 1) of every gating
    period. With the phase offsets below the three windows tile each period
    without overlap, in the order ramp 1, ramp 3, ramp 2.
    """
    # Serrodyne ramp frequencies and the slow gating frequency (Hz)
    f_mod1 = 250
    f_mod2 = 900
    f_mod3 = -100
    f_mod_square = 5

    # Duration fractions (sum to 1)
    p1 = 0.1
    p2 = 0.3
    p3 = 1 - p1 - p2

    # scipy.signal.square(x, duty) is +1 while (x mod 2*pi) < 2*pi*duty and -1
    # otherwise, so (1 + square) / 2 is a 0/1 gate.
    phi_mod1 = sawtooth(2 * np.pi * f_mod1 * t, width=1) * (
        1 + square(2 * np.pi * f_mod_square * t, duty=p1)
    ) / 2
    phi_mod2 = sawtooth(2 * np.pi * f_mod2 * t, width=1) * (
        1 + square(2 * np.pi * f_mod_square * t + 2 * np.pi * p2, duty=p2)
    ) / 2
    phi_mod3 = sawtooth(2 * np.pi * f_mod3 * t, width=1) * (
        1 + square(2 * np.pi * f_mod_square * t + 2 * np.pi * (p2 + p3), duty=p3)
    ) / 2

    return phi_mod1 + phi_mod2 + phi_mod3


def gsa_demo(
    iteration_num: int = 150,
    *,
    pause: float = 0.05,
    seed: int | None = 0,
) -> tuple[np.ndarray, np.ndarray]:
    """Run an interactive Gerchberg–Saxton (GSA) phase-retrieval demo.

    A unit-amplitude light field is phase modulated with concatenated
    serrodyne ramps to build a target spectrum. GSA then searches for a
    time-domain phase whose spectral magnitude matches that target, plotting
    progress live with matplotlib.

    Args:
        iteration_num: Number of GSA iterations (default 150).
        pause: Seconds passed to ``plt.pause`` after each live update. Values
            <= 0 skip the pause; the live window may then not refresh until
            the end.
        seed: Seed for the random initial phase (``None`` for a fresh seed).

    Returns:
        ``(phi, err)``: the final retrieved time-domain phase (length ``L``)
        and the per-iteration L1 spectral error.
    """
    # Sampling parameters
    Fs = 4e6  # Sampling frequency (Hz)
    T = 1.0 / Fs  # Sampling period (s)
    L = 2**15  # Length of signal
    t = np.arange(L) * T  # Time vector

    # Source field (amplitude of light field): exp(-1j * 0 * t) == 1
    S_o = np.exp(-1j * 0.00 * t)
    So_f = np.fft.fftshift(np.fft.fft(S_o))
    G_of = np.abs(So_f) / L
    # Reference level for the normalized (dBc) plots
    g_ref = np.max(G_of) + 1e-20

    # Target phase from concatenated serrodynes
    phi_mod = _concatenated_serrodyne_phase(t)

    t0 = time.time()

    # (concatenated) serrodyne modulation
    S_t = S_o * np.exp(1j * np.pi * 1 * phi_mod)
    # Target magnitude spectrum, in fftshift (centered) order
    G_tf = np.abs(np.fft.fftshift(np.fft.fft(S_t))) / L

    # Frequency axis for plotting (mirrors MATLAB Fs/L*(-L/2:L/2-1))
    freqs = (Fs / L) * (np.arange(-L // 2, L // 2))

    plt.figure(figsize=(10, 10))
    plt.subplot(4, 1, 1)
    plt.plot(freqs, _to_db(G_tf), linewidth=2.0, label="Target")
    plt.xlim([-1200, 1200])
    plt.xlabel("Frequency")
    plt.ylabel("FFT magnitude (dBc)")
    plt.title("Spectrum of modulated light field")
    plt.tight_layout()

    # Gerchberg–Saxton algorithm. A local Generator keeps the run
    # reproducible without touching NumPy's global RNG state.
    rng = np.random.default_rng(seed)
    phi_i = 2 * np.pi * rng.standard_normal(L)  # random initial phase
    Err = np.zeros(iteration_num)

    # NOTE: optional negative-frequency suppression is not implemented. The
    # envelope used for it in the MATLAB reference,
    # (1 - square(t / (16*L*T))) / 2, is identically zero here (the square-wave
    # argument never exceeds 1/16 rad), so applying it would erase the
    # spectrum. Build an explicit per-bin mask if suppression is needed.

    plt.ion()
    for i in range(1, iteration_num + 1):
        # Time-domain constraint: fixed source amplitude, current phase
        At = S_o * np.exp(1j * phi_i)
        Bf = np.fft.fft(At)  # natural (unshifted) FFT order
        phi_jf = np.angle(Bf)
        G_jf = np.abs(Bf) / L

        # Deliberate MATLAB-compatible convention: the target G_tf is in
        # fftshift order but is compared with / imposed on natural-order bins.
        # The retrieved field's spectrum is therefore the target shifted by
        # L/2 bins (Fs/2); the plots and the final phase comparison below
        # account for that shift.
        Err[i - 1] = np.sum(np.abs(G_tf - G_jf))

        # Frequency-domain constraint: target magnitude, current phase
        Cf = G_tf * np.exp(1j * phi_jf)
        Dt = np.fft.ifft(Cf)

        # Live plots. Natural-order |Bf| is drawn against the centered axis
        # `freqs`, which lines up with the target because of the shift above.
        plt.subplot(4, 1, 2)
        plt.cla()
        plt.plot(freqs, _to_db(G_jf / g_ref), linewidth=2)
        plt.xlim([-1200, 1200])
        plt.title(f"IterationNum = {i}")
        plt.xlabel("Frequency")
        plt.ylabel("FFT magnitude (dBc)")

        plt.subplot(4, 1, 3)
        plt.cla()
        plt.plot(np.arange(1, i + 1), Err[:i], linewidth=1.0)
        plt.xlim([1, i + 1])
        plt.xlabel("Iteration number")
        plt.ylabel("Error")

        if pause > 0:
            plt.pause(pause)

        # Update time-domain phase for next iteration
        phi_i = np.angle(Dt)

    # Evaluate the field from the *last* phase update so the retrieval
    # spectrum, the residual and the phase plot all describe one estimate.
    At_final = S_o * np.exp(1j * phi_i)
    g_jf_final = (np.abs(np.fft.fft(At_final)) / L) / g_ref

    # Final overlay on first subplot (natural FFT order, see note in the loop)
    plt.subplot(4, 1, 1)
    plt.plot(freqs, _to_db(g_jf_final), linewidth=1.0, label="Retrieval")
    plt.legend()
    plt.xlim([-1200, 1200])

    # Residual spectrum
    plt.subplot(4, 1, 4)
    target_norm = G_tf / g_ref
    resid = np.abs(target_norm - g_jf_final)
    plt.plot(freqs, _to_db(resid), linewidth=1.0)
    plt.xlabel("Frequency")
    plt.ylabel("Residual FFT magnitude (dBc)")
    plt.xlim([-1200, 1200])

    elapsed = time.time() - t0
    print(f"GSA completed in {elapsed:.2f} s")

    # Phase plots. Multiplying by exp(1j*2*pi*(Fs/2)*t) == (-1)**n undoes the
    # Fs/2 spectral shift so the retrieved phase is comparable with the
    # target. GSA solutions carry ambiguities (e.g. a global phase offset),
    # so an exact overlap is not expected.
    plt.figure()
    plt.plot(t, np.angle(S_t), label="Target phase")
    PHI0 = 0.0
    plt.plot(
        t,
        np.angle(At_final * np.exp(1j * 2 * np.pi * (Fs / 2 * t - PHI0))),
        label="Retrieved phase",
    )
    plt.title("Retrieval phase")
    plt.xlabel("Time")
    plt.ylabel("phi(t) (rad)")
    plt.legend()

    plt.ioff()
    plt.show()

    return phi_i, Err


def main() -> None:
    """Entry point: launch the interactive Gerchberg–Saxton demo."""
    gsa_demo()


# Run the demo when executed as a script. In IPython/Jupyter, __name__ is also
# "__main__", so executing this cell starts the demo as well.
if __name__ == "__main__":
    main()
