# 3WM gain and TWPAnalysis

This tutorial demonstrates how to use the `twpasolver.TWPAnalysis` class to simulate the gain response of a TWPA by solving the Coupled Mode Equations (CMEs) with a minimal model.

### Overview
`twpasolver.TWPAnalysis` automates the procedures required to compute the gain characteristics of a TWPA model.

Implemented features:

* Calculation of phase-matching profile, gain and bandwidth.
* Automatic caching of the results of the respective `phase_matching()`, `gain()` and `bandwidth()` methods.
* Sweeps over lists of input parameters for the analysis functions.
* Simple plots.

### Example - Setup

Let's start by importing the required libraries and initializing the `TWPAnalysis` instance. This object requires:
* A `TWPA` instance or the name of a json file containing the definition of the model
* A frequency span that defines the full frequency region considered in the analysis, given either as a list of values or as a `(start, stop, step)` tuple that is passed to `numpy.arange`. To unwrap the phase response of the device correctly, the frequency array must be dense enough and must start at frequencies well below the stopband.
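
To see why both precautions matter, here is a generic NumPy illustration (not twpasolver's internal code). It wraps an ideal linear phase, as for a line with a hypothetical constant electrical delay `DELAY_NS`, and recovers it with `numpy.unwrap` on three grids built with `numpy.arange`. It assumes the phase grows monotonically from roughly zero at low frequency, which is a simplification of a real TWPA response.

```python
import numpy as np

DELAY_NS = 5.0  # hypothetical electrical delay; phase = 2*pi*f*DELAY_NS with f in GHz


def phase_check(start, stop, step):
    """Return (true phase, phase recovered by np.unwrap) on an np.arange grid."""
    freqs = np.arange(start, stop, step)  # same expansion as f_arange=(start, stop, step)
    true_phase = 2 * np.pi * freqs * DELAY_NS
    wrapped = np.angle(np.exp(1j * true_phase))  # phase folded into (-pi, pi]
    return true_phase, np.unwrap(wrapped)


# Dense grid starting at low frequency: unwrapping recovers the true phase
true_ok, unw_ok = phase_check(0.05, 10, 0.5e-3)
# Coarse grid: the phase advances by 1.5*pi per sample and gets folded back
true_coarse, unw_coarse = phase_check(0.05, 10, 0.15)
# Dense grid starting too high: the first sample is already wrapped, leaving a 2*pi-multiple offset
true_late, unw_late = phase_check(0.35, 10, 0.5e-3)

print(np.allclose(unw_ok, true_ok))  # True
print(unw_coarse[-1] < unw_coarse[0])  # True: the recovered phase even decreases
print(round(float((true_late - unw_late)[0] / np.pi)))  # 4
```

A coarse grid lets the phase advance by more than π between samples, so `np.unwrap` picks the wrong branch; a grid that starts too high inherits an unknown offset of a whole multiple of 2π.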

The computed response and the results of the analysis functions below are stored in the internal `data` dictionary, which can be accessed directly or saved to an HDF5 file with the `save_data()` method.
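
One way to picture the caching behaviour is a decorator that stores each analysis result in `self.data` under the method's name, so that plotting methods can read it back without recomputing. The sketch below is an assumption about the general pattern, not twpasolver's implementation; `cache_in_data` and `ToyAnalysis` are hypothetical names.

```python
import functools


def cache_in_data(func):
    """Store the return value of an analysis method in self.data[func.__name__]."""

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        result = func(self, *args, **kwargs)
        self.data[func.__name__] = result
        return result

    return wrapper


class ToyAnalysis:
    """Hypothetical stand-in for TWPAnalysis, used only to illustrate the pattern."""

    def __init__(self):
        self.data = {}

    @cache_in_data
    def gain(self, signal_freqs, Is0=1e-6):
        # placeholder result; a real analysis would solve the CMEs here
        return {"signal_freqs": list(signal_freqs), "Is0": Is0}

    def plot_gain(self):
        # plotting reads the cached result instead of recomputing it
        return self.data["gain"]["signal_freqs"]


toy = ToyAnalysis()
_ = toy.gain([1.0, 2.0], Is0=3e-5)
print(toy.plot_gain())  # [1.0, 2.0]
```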

In [None]:
import logging

import matplotlib.pyplot as plt
import numpy as np

from twpasolver import TWPAnalysis
from twpasolver.logger import log
from twpasolver.mathutils import dBm_to_I

log.setLevel(logging.WARNING)

plt.rcParams["font.size"] = 13.5
plt.rcParams["axes.axisbelow"] = True

twpa_file = "model_cpw_dartwars_13nm_Lk8_5.json"
a = TWPAnalysis(twpa=twpa_file, f_arange=(0.05, 10, 0.5e-3))
a.update_base_data()  # compute response, estimate stopband position and optimal pump frequency
ax = a.plot_response(pump_freq=a.data["optimal_pump_freq"])
a.twpa.model_dump()

### Phase matching

This is the first analysis function implemented by the class. It computes the phase matching condition as a function of pump and signal frequency. By default, the signal range is chosen from the start of the total frequency span to the beginning of the stopband, while the pump range is chosen from the end of the stopband to the maximum of the full span.

In [None]:
_ = a.phase_matching()
ax = a.plot_phase_matching(thin=5)
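
For three-wave mixing, the quantity being mapped is the mismatch Δk = k_p - k_s - k_i, with the idler fixed by energy conservation, f_i = f_p - f_s. The sketch below evaluates it on a pump/signal grid for a hypothetical weakly dispersive line, k = f(1 + αf²) in arbitrary units; the real analysis uses the dispersion relation of the loaded TWPA model, which this toy function does not reproduce. For this toy dispersion Δk reduces to 3α f_s f_i f_p, which makes the result easy to check.

```python
import numpy as np


def wavenumber(f_ghz, alpha=0.0):
    """Hypothetical weakly dispersive line: k = f * (1 + alpha * f**2), arbitrary units."""
    return f_ghz * (1 + alpha * f_ghz**2)


def phase_mismatch(signal_freqs, pump_freqs, alpha=0.0):
    """3WM mismatch k_p - k_s - k_i on a (pump, signal) grid, with f_i = f_p - f_s."""
    fs, fp = np.meshgrid(signal_freqs, pump_freqs)  # rows: pump, columns: signal
    fi = fp - fs  # idler frequency fixed by energy conservation
    return wavenumber(fp, alpha) - wavenumber(fs, alpha) - wavenumber(fi, alpha)


signal_freqs = np.arange(1.0, 4.0, 0.5)
pump_freqs = np.arange(8.0, 9.0, 0.25)
dk = phase_mismatch(signal_freqs, pump_freqs, alpha=0.01)
print(dk.shape)  # (4, 6): one row per pump frequency
```

A dispersionless line (α = 0) is perfectly phase matched everywhere; dispersion is what makes the mismatch depend on the pump and signal choice.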

### Gain and bandwidth
Since `twpasolver` uses a numba-based implementation of the Runge-Kutta algorithm to solve the CMEs, compiling all the functions may take a few seconds the first time the `gain` method is called; subsequent calls are much faster.

In [None]:
s_arange = np.arange(1, 7, 0.05)
_ = a.gain(signal_freqs=s_arange, Is0=1e-6)
a.plot_gain()
plt.show()
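
The CMEs are integrated along the line with a Runge-Kutta scheme. As a plain-NumPy illustration (no numba, and not twpasolver's solver), the sketch below applies classic fixed-step RK4 to a hypothetical phase-matched, undepleted-pump signal/idler pair, da_s/dx = γ a_i*, da_i/dx = γ a_s*, whose signal amplitude should grow as cosh(γx).

```python
import numpy as np


def rk4(deriv, y0, x_end, n_steps):
    """Classic fixed-step 4th-order Runge-Kutta from x = 0 to x = x_end."""
    h = x_end / n_steps
    y = np.asarray(y0, dtype=complex)
    x = 0.0
    for _ in range(n_steps):
        k1 = deriv(x, y)
        k2 = deriv(x + h / 2, y + h / 2 * k1)
        k3 = deriv(x + h / 2, y + h / 2 * k2)
        k4 = deriv(x + h, y + h * k3)
        y = y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        x += h
    return y


def toy_cmes(gamma):
    """Hypothetical phase-matched, undepleted-pump signal/idler equations."""

    def deriv(x, y):
        a_s, a_i = y
        return np.array([gamma * np.conj(a_i), gamma * np.conj(a_s)])

    return deriv


gamma, length, a_s0 = 0.5, 6.0, 1e-6
a_s, a_i = rk4(toy_cmes(gamma), [a_s0, 0.0], length, 600)
gain_db = 20 * np.log10(abs(a_s) / a_s0)
expected_db = 20 * np.log10(np.cosh(gamma * length))
print(f"RK4: {gain_db:.4f} dB, analytic: {expected_db:.4f} dB")
```

Because the pump is held fixed in this toy model, the gain never saturates; capturing compression requires evolving the pump as well, as the default CME system does.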

In [None]:
"""
Small signal stability analysis for TWPAs.

Linearizes the coupled mode equations around operating point and analyzes eigenvalues
to determine stability margins and oscillation thresholds.
"""

from typing import Dict, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
from scipy.linalg import eigvals
from scipy.optimize import fsolve


class TWPAStabilityAnalyzer:
    """Small signal stability analysis for TWPA systems."""

    def __init__(self, twpa_analysis):
        """Initialize with existing TWPAnalysis instance."""
        self.twpa = twpa_analysis

    def linearize_cmes(
        self,
        operating_point: Dict[str, complex],
        pump_freq: float,
        mode_array_config: str = "basic_3wm",
    ) -> np.ndarray:
        """
        Linearize coupled mode equations around operating point.

        Args:
            operating_point: Steady-state current amplitudes {mode: amplitude}
            pump_freq: Pump frequency
            mode_array_config: Mode array configuration

        Returns:
            Jacobian matrix of linearized system
        """
        mode_array = self.twpa.get_mode_array(mode_array_config)
        mode_labels = list(mode_array.modes.keys())
        n_modes = len(mode_labels)

        # Update frequencies
        pump_idx = mode_labels.index("p")
        mode_array.update_frequencies({"p": pump_freq})

        # Get mixing terms
        terms_3wm = mode_array.get_rwa_terms(power=2)
        terms_4wm = mode_array.get_rwa_terms(power=3)

        # Build Jacobian matrix
        jacobian = np.zeros((n_modes, n_modes), dtype=complex)

        # Extract operating point currents
        I_op = np.array([operating_point.get(label, 0) for label in mode_labels])

        # Get mode parameters
        kappas = np.array([mode_array.get_mode(label).k for label in mode_labels])

        # Compute Jacobian elements from 3WM and 4WM terms
        epsilon = self.twpa.twpa.epsilon
        xi = self.twpa.twpa.xi

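        for term in terms_3wm:
            mode_idx, comb, mode_name, rhs_terms, coeff = term

            # Derivative contributions from 3WM terms: d(I_a I_b)/dI_a = I_b,
            # so each entry carries the product of the remaining amplitudes
            for i, idx in enumerate(comb):
                if idx < n_modes:  # Forward mode
                    other_amplitudes = np.prod(
                        [
                            I_op[j] if j < n_modes else np.conj(I_op[j - n_modes])
                            for k, j in enumerate(comb)
                            if k != i
                        ]
                    )
                    jacobian[mode_idx, idx] += (
                        1j * epsilon / 4 * coeff * kappas[mode_idx] * other_amplitudes
                    )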

        for term in terms_4wm:
            mode_idx, comb, mode_name, rhs_terms, coeff = term

            # Derivative contributions from 4WM terms
            for i, idx in enumerate(comb):
                if idx < n_modes:  # Forward mode
                    # Compute partial derivatives of nonlinear terms
                    other_amplitudes = np.prod(
                        [
                            I_op[j] if j < n_modes else np.conj(I_op[j - n_modes])
                            for k, j in enumerate(comb)
                            if k != i
                        ]
                    )
                    jacobian[mode_idx, idx] += (
                        1j * xi / 8 * coeff * kappas[mode_idx] * other_amplitudes
                    )

        return jacobian

    def find_steady_state(
        self,
        pump_current: float,
        pump_freq: float,
        signal_freq: Optional[float] = None,
        mode_array_config: str = "basic_3wm",
    ) -> Dict[str, complex]:
        """
        Find steady-state operating point by solving nonlinear equations.

        Args:
            pump_current: Pump current amplitude
            pump_freq: Pump frequency
            signal_freq: Signal frequency (optional)
            mode_array_config: Mode array configuration

        Returns:
            Steady-state current amplitudes
        """
        mode_array = self.twpa.get_mode_array(mode_array_config)
        mode_labels = list(mode_array.modes.keys())

        # Set up frequencies
        freq_dict = {"p": pump_freq}
        if signal_freq is not None:
            freq_dict["s"] = signal_freq

        mode_array.update_frequencies(freq_dict)

        def steady_state_equations(currents_real):
            """Convert to complex and evaluate steady-state equations."""
            n_modes = len(currents_real) // 2
            currents = currents_real[:n_modes] + 1j * currents_real[n_modes:]

            # Set pump current
            currents[0] = pump_current  # Assume pump is first mode

            # Evaluate derivatives (should be zero at steady state)
            # This would use the actual CME functions from the library
            # For now, simplified placeholder
            derivs = np.zeros(n_modes, dtype=complex)

            # Return real and imaginary parts separately
            return np.concatenate([derivs.real, derivs.imag])

        # Initial guess
        n_modes = len(mode_labels)
        initial_guess = np.zeros(2 * n_modes)
        initial_guess[0] = pump_current  # Real part of pump

        # Solve for steady state
        solution = fsolve(steady_state_equations, initial_guess)
        n_modes = len(solution) // 2
        steady_currents = solution[:n_modes] + 1j * solution[n_modes:]

        return {label: steady_currents[i] for i, label in enumerate(mode_labels)}

    def stability_vs_pump_power(
        self,
        pump_powers_dbm: np.ndarray,
        pump_freq: float,
        mode_array_config: str = "basic_3wm",
    ) -> Dict:
        """
        Analyze stability versus pump power.

        Args:
            pump_powers_dbm: Array of pump powers in dBm
            pump_freq: Pump frequency
            mode_array_config: Mode array configuration

        Returns:
            Dictionary with eigenvalues and stability metrics
        """
        from twpasolver.mathutils import dBm_to_I

        pump_currents = [dBm_to_I(p) for p in pump_powers_dbm]
        eigenvals = []
        max_real_parts = []

        for pump_current in pump_currents:
            # Find operating point
            try:
                op_point = self.find_steady_state(
                    pump_current, pump_freq, mode_array_config=mode_array_config
                )

                # Linearize around operating point
                jacobian = self.linearize_cmes(op_point, pump_freq, mode_array_config)

                # Compute eigenvalues
                eigs = eigvals(jacobian)
                eigenvals.append(eigs)
                max_real_parts.append(np.max(eigs.real))

            except Exception:
                # Handle convergence failures (without swallowing KeyboardInterrupt)
                eigenvals.append(np.array([np.nan]))
                max_real_parts.append(np.nan)

        return {
            "pump_powers_dbm": pump_powers_dbm,
            "pump_currents": pump_currents,
            "eigenvalues": eigenvals,
            "max_real_parts": max_real_parts,
            "unstable_threshold": self._find_instability_threshold(
                pump_powers_dbm, max_real_parts
            ),
        }

    def _find_instability_threshold(self, powers, max_real_parts):
        """Return the highest swept power that is still stable (max Re(λ) < 0), or None."""
        stable_powers = [
            p for p, r in zip(powers, max_real_parts) if not np.isnan(r) and r < 0
        ]
        return max(stable_powers) if stable_powers else None

    def plot_stability_analysis(self, stability_data: Dict):
        """Plot stability analysis results."""
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8))

        powers = stability_data["pump_powers_dbm"]
        max_reals = stability_data["max_real_parts"]

        # Plot maximum real part of eigenvalues
        ax1.plot(powers, max_reals, "b-", linewidth=2)
        ax1.axhline(y=0, color="r", linestyle="--", alpha=0.7)
        ax1.set_ylabel("Max Re(λ)")
        ax1.set_title("Stability vs Pump Power")
        ax1.grid(True, alpha=0.3)

        # Highlight unstable region
        unstable_mask = np.array(max_reals) > 0
        if np.any(unstable_mask):
            ax1.fill_between(
                powers,
                max_reals,
                0,
                where=unstable_mask,
                alpha=0.3,
                color="red",
                label="Unstable",
            )

        # Plot all eigenvalues for a subset of powers, colored by their imaginary part
        valid_eigs = [e for e in stability_data["eigenvalues"] if not np.any(np.isnan(e))]
        norm = None
        if valid_eigs:
            imag_all = np.concatenate([e.imag for e in valid_eigs])
            norm = plt.Normalize(vmin=imag_all.min(), vmax=imag_all.max())

        step = max(1, len(powers) // 10)  # avoid a zero slice step for short sweeps
        scatter = None
        for idx in range(0, len(powers), step):
            eigs = stability_data["eigenvalues"][idx]
            if not np.any(np.isnan(eigs)):
                scatter = ax2.scatter(
                    [powers[idx]] * len(eigs),
                    eigs.real,
                    c=eigs.imag,
                    cmap="viridis",
                    norm=norm,
                    s=30,
                )

        ax2.axhline(y=0, color="r", linestyle="--", alpha=0.7)
        ax2.set_xlabel("Pump Power (dBm)")
        ax2.set_ylabel("Re(λ)")
        ax2.set_title("Eigenvalue Real Parts")
        ax2.grid(True, alpha=0.3)

        # Add a shared colorbar for the imaginary parts (only if anything was plotted)
        if scatter is not None:
            plt.colorbar(scatter, ax=ax2, label="Im(λ)")

        plt.tight_layout()
        return fig

    def parametric_oscillation_threshold(
        self,
        pump_freq_range: np.ndarray,
        signal_freq: float,
        mode_array_config: str = "basic_3wm",
    ) -> Dict:
        """
        Find parametric oscillation threshold vs pump frequency.

        Args:
            pump_freq_range: Array of pump frequencies to test
            signal_freq: Signal frequency
            mode_array_config: Mode array configuration

        Returns:
            Dictionary with oscillation thresholds
        """
        thresholds = []

        for pump_freq in pump_freq_range:
            # Linear scan over pump power: record the first power that is unstable
            power_range = np.linspace(-60, -10, 50)  # dBm

            for power_dbm in power_range:
                stability_data = self.stability_vs_pump_power(
                    np.array([power_dbm]), pump_freq, mode_array_config
                )

                if stability_data["max_real_parts"][0] > 0:
                    thresholds.append(power_dbm)
                    break
            else:
                thresholds.append(np.nan)  # No instability found

        return {
            "pump_frequencies": pump_freq_range,
            "oscillation_thresholds_dbm": thresholds,
        }


# Example usage
def analyze_twpa_stability(twpa_analysis):
    """Example stability analysis workflow."""

    analyzer = TWPAStabilityAnalyzer(twpa_analysis)

    # 1. Stability vs pump power at optimal frequency
    pump_powers = np.linspace(-50, -10, 30)
    optimal_pump = twpa_analysis.data.get("optimal_pump_freq", 7.5)

    stability_data = analyzer.stability_vs_pump_power(pump_powers, optimal_pump)
    print(stability_data)
    # 2. Plot results
    fig = analyzer.plot_stability_analysis(stability_data)

    # 3. Find oscillation threshold vs frequency
    pump_freqs = np.linspace(optimal_pump - 0.5, optimal_pump + 0.5, 20)
    oscillation_data = analyzer.parametric_oscillation_threshold(
        pump_freqs, optimal_pump / 2
    )

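    threshold = stability_data["unstable_threshold"]
    if threshold is None:
        print("No stable operating point found in the swept pump-power range.")
    else:
        print(f"Stability threshold (highest stable pump power): {threshold:.1f} dBm")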

    return stability_data, oscillation_data

In [None]:
analyze_twpa_stability(a)
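
The stability cell relies on a simple criterion: for a linear system dx/dz = Jx, perturbations grow when the largest real part of the eigenvalues of J is positive. A toy two-mode example with a hypothetical coupling `gamma` and distributed loss `alpha` makes the threshold explicit: the eigenvalues are -α ± γ, so the growth rate crosses zero at γ = α.

```python
import numpy as np


def max_growth_rate(gamma, alpha):
    """Largest Re(eigenvalue) of d/dz [a_s, a_i*] = J [a_s, a_i*] for a toy lossy 3WM pair."""
    jacobian = np.array([[-alpha, gamma], [gamma, -alpha]], dtype=complex)
    return np.max(np.linalg.eigvals(jacobian).real)


alpha = 0.2  # hypothetical loss rate
for gamma in (0.1, 0.2, 0.3):
    rate = max_growth_rate(gamma, alpha)
    state = "grows" if rate > 1e-12 else "decays or is marginal"
    print(f"gamma={gamma}: max Re(lambda)={rate:+.3f} -> {state}")
```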

The `bandwidth` analysis function finds all the frequency regions where the gain lies between its maximum value and a threshold, by default 3 dB below the maximum. It is designed to track the asymmetry of the gain profile and the formation of lobes caused by pump depletion and by the choice of a high pump frequency, so the bandwidth may be computed over discontinuous regions.

In [None]:
_ = a.gain(signal_freqs=s_arange, Is0=3e-5, pump=a.data["optimal_pump_freq"] + 0.15)
a.plot_gain()
_ = a.bandwidth()
for edg in a.data["bandwidth"]["bandwidth_edges"]:
    plt.axvline(edg, color="black", ls="--")
plt.axhline(a.data["bandwidth"]["reduced_gain"], c="black", ls="--")
plt.axhline(a.data["bandwidth"]["mean_gain"], c="black", ls="--")
plt.show()
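
The region search described above can be sketched as a boolean mask over the gain array whose rising and falling transitions mark the region edges. This is an illustrative reimplementation under that assumption, not the library's `bandwidth()` code; `bandwidth_regions` is a hypothetical name, and it measures each region between its outermost samples.

```python
import numpy as np


def bandwidth_regions(freqs, gain_db, drop_db=3.0):
    """Edges and summed width of every region where gain_db >= max(gain_db) - drop_db."""
    freqs = np.asarray(freqs, dtype=float)
    gain_db = np.asarray(gain_db, dtype=float)
    threshold = gain_db.max() - drop_db
    above = gain_db >= threshold
    # pad with False so every region has both a rising and a falling transition
    padded = np.concatenate([[False], above, [False]]).astype(int)
    changes = np.flatnonzero(np.diff(padded))
    starts, stops = changes[::2], changes[1::2] - 1  # inclusive sample indices
    edges = [(freqs[i], freqs[j]) for i, j in zip(starts, stops)]
    total_width = sum(f_hi - f_lo for f_lo, f_hi in edges)
    return edges, total_width, threshold


# Two lobes above the 3 dB threshold (17 dB here), separated by a dip
freqs = np.arange(10.0)
gain = [0, 18, 20, 19, 5, 5, 18.5, 19.5, 0, 0]
edges, width, thr = bandwidth_regions(freqs, gain)
print(edges, width, thr)
```

Summing over all regions, rather than taking only the lobe that contains the maximum, is what allows a discontinuous bandwidth.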

### Parameter sweeps

The `parameter_sweep` analysis method runs one of the other analysis functions repeatedly while sweeping one of its input variables. In its basic usage, the first three positional arguments are the name of the target function, the name of the variable to sweep, and the list of values it takes.
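
The idea behind a sweep can be sketched as calling the target once per value and stacking each output key into an array with one row per swept value, which matches how the results are indexed below (for example `pumpsweep_res["gain_db"]` as a 2D map). This sketch takes a callable rather than a method name, and `parameter_sweep`/`toy_gain` here are hypothetical stand-ins, not the library's implementation.

```python
import numpy as np


def parameter_sweep(func, var_name, values, **fixed_kwargs):
    """Call func(**{var_name: v}, **fixed_kwargs) for each v and stack the result dicts key-wise."""
    results = [func(**{var_name: v}, **fixed_kwargs) for v in values]
    return {key: np.array([r[key] for r in results]) for key in results[0]}


def toy_gain(pump, signal_freqs):
    """Hypothetical analysis function returning a dict of arrays, peaked at pump / 2."""
    signal_freqs = np.asarray(signal_freqs)
    return {"signal_freqs": signal_freqs, "gain_db": 10 - (signal_freqs - pump / 2) ** 2}


res = parameter_sweep(toy_gain, "pump", [8.0, 8.2, 8.4], signal_freqs=np.arange(3, 6, 0.5))
print(res["gain_db"].shape)  # (3, 6)
```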

#### Gain as a function of pump frequency

In [None]:
optimal = a.data["optimal_pump_freq"]
pumps = np.arange(optimal - 0.2, optimal + 0.5, 0.02)
pumpsweep_res = a.parameter_sweep(
    "gain", "pump", pumps, signal_freqs=s_arange, save_name="pump_profile"
)

plt.pcolor(pumpsweep_res["signal_freqs"][0], pumps, pumpsweep_res["gain_db"])
plt.xlabel("Signal frequency [GHz]")
plt.ylabel("Pump frequency [GHz]")
c = plt.colorbar()
c.set_label("Gain [dB]")
plt.show()

#### Compression point

Since the default CME system includes pump-depletion effects, power-dependent measurements such as the 1 dB compression point can also be simulated.

In [None]:
signals_db = np.arange(-80, -40, 1)
signals = dBm_to_I(signals_db)
s_arange = a.data["bandwidth"]["bw_freqs"]  # reuse the signal frequencies of the bandwidth found above
compression_res = a.parameter_sweep(
    "gain", "Is0", signals, signal_freqs=s_arange, thin=1000, save_name="compression"
)

In [None]:
mean_gains = np.mean(compression_res["gain_db"], axis=1)
reduced_1db = mean_gains[0] - 1
cp_1db = np.interp(reduced_1db, mean_gains[::-1], signals_db[::-1])
plt.scatter(signals_db, mean_gains)
plt.xlabel("Signal power [dBm]")
plt.ylabel("Gain [dB]")
plt.axhline(reduced_1db, c="black", ls="--")
plt.axvline(cp_1db, c="black", ls="--")
plt.show()
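
The cell above reads the 1 dB compression point off the mean gain curve by linear interpolation. To check that recipe against a known answer, here is a sketch using a hypothetical saturating-gain model, G = G0 / (1 + P_in / P_sat) in linear units, for which the 1 dB point has a closed form. `G0_DB` and `PSAT_DBM` are made-up values, not outputs of the TWPA model.

```python
import numpy as np

G0_DB = 20.0  # hypothetical small-signal gain
PSAT_DBM = -50.0  # hypothetical saturation power


def toy_gain_db(p_in_dbm):
    """Hypothetical saturating gain: G = G0 / (1 + P_in / P_sat) in linear units."""
    ratio = 10 ** ((np.asarray(p_in_dbm) - PSAT_DBM) / 10)
    return G0_DB - 10 * np.log10(1 + ratio)


signals_db = np.arange(-80, -40, 1.0)
gains = toy_gain_db(signals_db)
reduced_1db = gains[0] - 1
# np.interp needs increasing sample points; gain falls with power, so reverse both arrays
cp_1db = np.interp(reduced_1db, gains[::-1], signals_db[::-1])

# Closed-form 1 dB point of the toy model, relative to the lowest-power gain gains[0]
r0 = 10 ** ((signals_db[0] - PSAT_DBM) / 10)
exact = PSAT_DBM + 10 * np.log10(10**0.1 * (1 + r0) - 1)
print(f"interpolated: {cp_1db:.2f} dBm, closed form: {exact:.2f} dBm")
```

With a 1 dB power step the linear interpolation agrees with the closed form to a few hundredths of a dB; a finer power grid tightens this further.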