In [2]:
from utils import *


In [None]:
import numpy as np
import numpy.linalg as la
import scipy.signal as sig
from dataclasses import dataclass
from typing import List, Tuple


# -----------------------------
# 0) 유틸: QAM / OFDM
# -----------------------------
def qam_constellation(M: int) -> np.ndarray:
    """Return the unit-average-power constellation of a square M-QAM.

    Points lie on an m x m grid (m = sqrt(M)) whose coordinates on both axes
    are the odd integers -(m-1), ..., -1, 1, ..., (m-1). The grid is flattened
    row by row (real part varies fastest) and scaled so that the mean of
    |point|**2 equals 1.

    Args:
        M: Constellation size; must be a perfect square of at least 4
            (e.g. 4, 16, 64, 256).

    Returns:
        1-D complex array of length M.

    Raises:
        ValueError: If M is smaller than 4 or is not a perfect square.
    """
    # M = 0 would give an empty array and M = 1 a single point at the origin;
    # both make the power normalization below divide by zero, so reject them.
    if M < 4:
        raise ValueError("Square QAM만 지원합니다. (예: 16, 64, 256)")
    # Round instead of truncating so a float sqrt that lands just below the
    # exact integer root does not cause a false rejection.
    m = int(round(float(np.sqrt(M))))
    if m * m != M:
        raise ValueError("Square QAM만 지원합니다. (예: 16, 64, 256)")
    levels = np.arange(-(m - 1), m, 2)
    xv, yv = np.meshgrid(levels, levels)
    const = xv.flatten() + 1j * yv.flatten()
    # Normalize to unit average power.
    return const / np.sqrt(np.mean(np.abs(const) ** 2))

def gen_qam_symbols(M: int, n: int, rng=None) -> np.ndarray:
    """Draw random symbols from the square M-QAM constellation.

    Args:
        M: Constellation size, forwarded to ``qam_constellation``.
        n: Passed as ``size`` to ``rng.integers`` (number/shape of symbols).
        rng: Object providing ``integers(low, high, size=...)``. When None, a
            fresh ``np.random.default_rng()`` is created.

    Returns:
        The constellation points selected by the drawn indices.
    """
    if rng is None:
        rng = np.random.default_rng()
    points = qam_constellation(M)
    # Index draw over [0, len(points)).
    picks = rng.integers(0, len(points), size=n)
    return points[picks]

def make_active_bins(nfft: int, nsc: int) -> np.ndarray:
    """Return fftshift-ordered indices of ``nsc`` active subcarriers around DC.

    Offsets relative to DC are -n_neg, ..., -1, 1, ..., n_pos (DC itself is
    skipped), where n_neg = nsc // 2 and n_pos = nsc - n_neg. For an even
    ``nsc`` both sides get nsc / 2 bins; for an odd ``nsc`` the extra bin goes
    to the positive side. The offsets are then moved to fftshifted positions
    by adding ``nfft // 2``.

    Args:
        nfft: FFT size.
        nsc: Number of active subcarriers (DC excluded).

    Returns:
        1-D integer array of length ``nsc`` with fftshifted bin indices.

    Raises:
        ValueError: If ``nsc`` is negative or the positive-side offsets would
            run past the last fftshifted bin.
    """
    n_neg = nsc // 2
    n_pos = nsc - n_neg
    # In fftshifted order DC sits at index nfft // 2, leaving
    # nfft - nfft // 2 - 1 bins above it. n_pos >= n_neg, so checking the
    # positive side is enough to rule out wrap-around duplicates.
    if nsc < 0 or n_pos > nfft - nfft // 2 - 1:
        raise ValueError("nsc active subcarriers do not fit in nfft bins")
    rel = np.concatenate([np.arange(-n_neg, 0), np.arange(1, n_pos + 1)])  # skip DC (0)
    return (rel + nfft // 2) % nfft

def ofdm_modulate(symbols: np.ndarray, nfft: int, cp_len: int, active_bins: np.ndarray) -> np.ndarray:
    """OFDM-modulate a grid of subcarrier symbols into one time-domain waveform.

    Each row of ``symbols`` is placed on ``active_bins`` of an otherwise zero,
    fftshift-ordered spectrum, converted to natural FFT order, inverse
    transformed, and prefixed with its last ``cp_len`` samples. The resulting
    symbols are concatenated in row order.

    Args:
        symbols: Array of shape (n_sym, n_active), columns in the same order
            as ``active_bins``.
        nfft: FFT size.
        cp_len: Cyclic-prefix length in samples, 0 <= cp_len <= nfft.
        active_bins: fftshift-ordered indices of the occupied bins.

    Returns:
        1-D complex array of length n_sym * (nfft + cp_len).

    Raises:
        ValueError: If the number of symbol columns differs from
            ``len(active_bins)`` or ``cp_len`` is outside [0, nfft].
    """
    n_sym, n_act = symbols.shape
    if n_act != len(active_bins):
        raise ValueError("symbols.shape[1] != len(active_bins)")
    if not 0 <= cp_len <= nfft:
        raise ValueError("cp_len must satisfy 0 <= cp_len <= nfft")

    # Build all spectra at once: one row per OFDM symbol, fftshifted layout.
    grid = np.zeros((n_sym, nfft), dtype=complex)
    grid[:, active_bins] = symbols
    body = np.fft.ifft(np.fft.ifftshift(grid, axes=1), n=nfft, axis=1)

    # Slice the prefix from an explicit start index: body[:, -cp_len:] would
    # return the whole symbol when cp_len == 0 and double the symbol length.
    prefix = body[:, nfft - cp_len:]
    return np.concatenate([prefix, body], axis=1).reshape(-1)

def ofdm_demodulate(wave: np.ndarray, nfft: int, cp_len: int, active_bins: np.ndarray, n_sym: int) -> np.ndarray:
    """Recover the active-subcarrier values of ``n_sym`` OFDM symbols.

    The first n_sym * (nfft + cp_len) samples of ``wave`` are cut into
    symbols; for each one the leading ``cp_len`` samples are dropped, an
    ``nfft``-point FFT is taken, the spectrum is fftshifted and the entries at
    ``active_bins`` are kept.

    Args:
        wave: Time-domain samples, starting at a symbol boundary.
        nfft: FFT size.
        cp_len: Cyclic-prefix length in samples.
        active_bins: fftshift-ordered indices of the bins to extract.
        n_sym: Number of symbols to demodulate.

    Returns:
        Complex array of shape (n_sym, len(active_bins)).

    Raises:
        ValueError: If ``wave`` holds fewer than n_sym * (nfft + cp_len) samples.
    """
    block = nfft + cp_len
    needed = n_sym * block
    if len(wave) < needed:
        raise ValueError("wave가 너무 짧습니다.")

    out = np.zeros((n_sym, len(active_bins)), dtype=complex)
    # One row per OFDM symbol, cyclic prefix still attached.
    frames = np.asarray(wave)[:needed].reshape(n_sym, block)
    spectra = np.fft.fft(frames[:, cp_len:], n=nfft, axis=1)
    centered = np.fft.fftshift(spectra, axes=1)
    out[...] = centered[:, active_bins]
    return out


# -----------------------------
# 1) MP / GMP 모델 (linear-in-coeff)
# -----------------------------
@dataclass
class MPConfig:
    """Settings of the memory-polynomial (MP) basis built by build_mp_basis."""
    # Polynomial orders p, e.g. [1,3,5,7,9]; build_mp_basis emits one column
    # x * |x|**(p-1) per order and per delay tap.
    p_list: List[int]   # polynomial orders, e.g. [1,3,5,7,9]
    # Memory depth M; build_mp_basis uses delay taps 0 .. M-1.
    mem_depth: int      # M taps

@dataclass
class GMPConfig:
    """Settings of the generalized memory-polynomial (GMP) basis built by build_gmp_basis."""
    # aligned MP-like terms: signal and envelope share the same delay
    p_list_align: List[int]   # polynomial orders of the aligned terms
    L: int                    # number of signal delay taps (0 .. L-1), used by both term groups
    # lagging envelope cross terms: envelope delayed by an extra m samples
    p_list_lag: List[int]     # polynomial orders of the cross terms
    m_list: List[int]   # envelope lag taps (positive)


def build_mp_basis(x: np.ndarray, cfg: MPConfig) -> Tuple[np.ndarray, int]:
    """Build the memory-polynomial regression matrix for signal ``x``.

    Column layout is tap-major: for every delay tap m = 0 .. mem_depth-1 and,
    inside it, every order p of ``cfg.p_list`` there is one column
    x[n-m] * |x[n-m]|**(p-1).

    Args:
        x: Input samples.
        cfg: Provides ``mem_depth`` and ``p_list``.

    Returns:
        (basis, first): ``basis`` is a complex matrix with one row per sample
        index n = first .. len(x)-1, and ``first`` = mem_depth - 1 is the
        first sample for which every delayed tap exists.
    """
    depth = cfg.mem_depth
    orders = cfg.p_list
    n_orders = len(orders)
    total = len(x)
    first = depth - 1
    basis = np.zeros((total - first, depth * n_orders), dtype=complex)

    for tap in range(depth):
        delayed = x[first - tap:total - tap]
        envelope = np.abs(delayed)
        # Columns belonging to this tap start at tap * n_orders.
        for k, order in enumerate(orders):
            basis[:, tap * n_orders + k] = delayed * envelope ** (order - 1)
    return basis, first

def apply_mp(x: np.ndarray, w: np.ndarray, cfg: MPConfig) -> np.ndarray:
    """Evaluate a memory-polynomial model with coefficients ``w`` on ``x``.

    Args:
        x: Input samples.
        w: Coefficient vector, ordered like the columns of ``build_mp_basis``.
        cfg: Model configuration forwarded to ``build_mp_basis``.

    Returns:
        Complex array shaped like ``x``. The leading samples that
        ``build_mp_basis`` skips (its returned start index) are left at zero.
    """
    basis, offset = build_mp_basis(x, cfg)
    out = np.zeros_like(x, dtype=complex)
    # The model output only exists from `offset` on; earlier samples stay 0.
    out[offset:] = basis @ w
    return out

def build_gmp_basis(x: np.ndarray, cfg: GMPConfig) -> Tuple[np.ndarray, int]:
    """Build the generalized memory-polynomial regression matrix for ``x``.

    Columns come in two groups, in this order:
      1. aligned terms x[n-l] * |x[n-l]|**(p-1) for l = 0 .. L-1 and p in
         ``cfg.p_list_align`` (l outer, p inner);
      2. lagging-envelope cross terms x[n-l] * |x[n-l-m]|**(p-1) for
         l = 0 .. L-1, m in ``cfg.m_list`` and p in ``cfg.p_list_lag``
         (l outer, then m, then p).

    Args:
        x: Input samples.
        cfg: Provides ``L``, ``m_list``, ``p_list_align`` and ``p_list_lag``.

    Returns:
        (basis, first): ``basis`` is a complex matrix with one row per sample
        index n = first .. len(x)-1, and ``first`` = (L - 1) + max(m_list)
        (max taken as 0 for an empty ``m_list``).
    """
    total = len(x)
    depth = cfg.L
    env_lags = cfg.m_list
    first = (depth - 1) + max(env_lags, default=0)
    n_aligned = depth * len(cfg.p_list_align)
    n_cross = depth * len(env_lags) * len(cfg.p_list_lag)
    basis = np.zeros((total - first, n_aligned + n_cross), dtype=complex)

    def delayed(shift):
        # Samples x[n - shift] for n = first .. total-1.
        return x[first - shift:total - shift]

    k = 0
    # Group 1: signal and envelope taken at the same delay.
    for tap in range(depth):
        branch = delayed(tap)
        envelope = np.abs(branch)
        for order in cfg.p_list_align:
            basis[:, k] = branch * envelope ** (order - 1)
            k += 1

    # Group 2: envelope delayed by an additional `lag` samples.
    for tap in range(depth):
        branch = delayed(tap)
        for lag in env_lags:
            envelope = np.abs(delayed(tap + lag))
            for order in cfg.p_list_lag:
                basis[:, k] = branch * envelope ** (order - 1)
                k += 1

    return basis, first

def apply_gmp(x: np.ndarray, w: np.ndarray, cfg: GMPConfig) -> np.ndarray:
    """Evaluate a generalized memory-polynomial model with coefficients ``w``.

    Args:
        x: Input samples.
        w: Coefficient vector, ordered like the columns of ``build_gmp_basis``.
        cfg: Model configuration forwarded to ``build_gmp_basis``.

    Returns:
        Complex array shaped like ``x``. The leading samples that
        ``build_gmp_basis`` skips (its returned start index) are left at zero.
    """
    basis, offset = build_gmp_basis(x, cfg)
    out = np.zeros_like(x, dtype=complex)
    # The model output only exists from `offset` on; earlier samples stay 0.
    out[offset:] = basis @ w
    return out


# -----------------------------
# 2) LS / ILA 학습
# -----------------------------
def ls_solve(X: np.ndarray, d: np.ndarray, ridge: float = 0.0) -> np.ndarray:
    """Solve min_w ||X w - d||^2 + ridge * ||w||^2.

    Args:
        X: Regression matrix (rows = observations, columns = coefficients).
        d: Target vector.
        ridge: Regularization weight. Values <= 0 select plain least squares
            via ``numpy.linalg.lstsq``; positive values solve the regularized
            normal equations (X^H X + ridge * I) w = X^H d.

    Returns:
        Coefficient vector w.
    """
    if ridge <= 0:
        # Only the solution is needed; residuals, rank and singular values
        # returned by lstsq are dropped.
        return la.lstsq(X, d, rcond=None)[0]

    X_h = X.conj().T
    gram = X_h @ X + ridge * np.eye(X.shape[1], dtype=complex)
    rhs = X_h @ d
    return la.solve(gram, rhs)

def ila_train_with_gain_subset(
    u: np.ndarray,
    pa_fn,
    model_type: str,
    cfg,
    n_iter: int = 4,
    ridge: float = 1e-4,
    train_slice: slice = slice(None),
) -> np.ndarray:
    """
    Indirect learning architecture (ILA): fit a post-inverse of the PA and
    reuse its coefficients as the predistorter.

    Each iteration drives ``pa_fn`` with the current predistorted signal,
    divides the PA output by its least-squares complex gain G (training on
    y/G, which the original note credits with better stability and attributes
    to Ding 2004), fits model coefficients mapping the normalized output back
    to the PA input with ``ls_solve``, and then regenerates the predistorted
    signal by applying those coefficients to the training segment of ``u``.

    Args:
        u: Desired (undistorted) signal.
        pa_fn: Callable mapping a PA input array to the PA output array.
        model_type: "mp" or "gmp", selecting the basis/apply function pair.
        cfg: Configuration object for the selected model.
        n_iter: Number of ILA iterations.
        ridge: Regularization weight passed to ``ls_solve``.
        train_slice: Slice of ``u`` used for training.

    Returns:
        Coefficients from the last iteration (None when no iteration runs).

    Raises:
        ValueError: If an iteration runs with an unknown ``model_type``.
    """
    target = u[train_slice]
    pa_in = target.copy()
    coeffs = None

    for _ in range(n_iter):
        pa_out = pa_fn(pa_in)
        # Least-squares complex gain of the PA; tiny offsets guard against /0.
        gain = (np.vdot(pa_in, pa_out)) / (np.vdot(pa_in, pa_in) + 1e-30)
        normalized = pa_out / (gain + 1e-30)

        if model_type == "mp":
            build_basis, run_model = build_mp_basis, apply_mp
        elif model_type == "gmp":
            build_basis, run_model = build_gmp_basis, apply_gmp
        else:
            raise ValueError("model_type은 'mp' 또는 'gmp'")

        # Post-inverse fit: normalized PA output -> PA input.
        basis, skip = build_basis(normalized, cfg)
        coeffs = ls_solve(basis, pa_in[skip:], ridge=ridge)
        # Copy the post-inverse in front of the PA for the next iteration.
        pa_in = run_model(target, coeffs, cfg)

    return coeffs


# -----------------------------
# 3) PSD / ACLR(=ACPR) / EVM
# -----------------------------
def compute_psd_welch(x: np.ndarray, fs: float, nperseg: int = 8192):
    """Estimate a two-sided Welch PSD, returned in ascending frequency order.

    Uses a Hann window, 50 % overlap (``nperseg // 2``) and density scaling.
    The segment length is capped at ``len(x)``.

    Args:
        x: Input samples.
        fs: Sampling rate.
        nperseg: Requested segment length.

    Returns:
        (f, Pxx): frequency axis sorted ascending and the PSD values reordered
        to match.
    """
    seg_len = min(nperseg, len(x))
    welch_opts = dict(
        fs=fs,
        window="hann",
        nperseg=seg_len,
        noverlap=seg_len // 2,
        return_onesided=False,
        scaling="density",
    )
    freqs, psd = sig.welch(x, **welch_opts)
    # Reorder so that frequencies ascend.
    order = freqs.argsort()
    return freqs[order], psd[order]

def bandpower_from_psd(f, Pxx, f1, f2):
    """Integrate a PSD over the band [f1, f2] with the trapezoidal rule.

    Args:
        f: Frequency axis (array); samples are selected with f1 <= f <= f2.
        Pxx: PSD values matching ``f``.
        f1: One band edge.
        f2: The other band edge; the two edges are swapped if given in
            descending order.

    Returns:
        The integral of ``Pxx`` over the selected samples, or 0.0 when fewer
        than two samples fall inside the band.
    """
    if f1 > f2:
        f1, f2 = f2, f1
    mask = (f >= f1) & (f <= f2)
    if np.count_nonzero(mask) < 2:
        return 0.0
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; prefer
    # the new name and only fall back to the old one on NumPy 1.x.
    integrate = getattr(np, "trapezoid", None)
    if integrate is None:
        integrate = np.trapz
    return integrate(Pxx[mask], f[mask])

def compute_aclr(f, Pxx, B, off):
    """
    Adjacent-channel power relative to main-channel power, in dB, obtained by
    integrating the PSD (the ACPR definition the original note attributes to
    Xia 2024).

    Bands:
        main : [-B/2, B/2]
        lower: [-off - B/2, -off + B/2]
        upper: [ off - B/2,  off + B/2]

    Args:
        f: Frequency axis of the PSD.
        Pxx: PSD values matching ``f``.
        B: Channel bandwidth.
        off: Centre-frequency offset of the adjacent channels.

    Returns:
        (lower, upper): 10*log10 of each adjacent-band power over the
        main-band power.
    """
    floor = 1e-30  # keeps the ratio and the log finite for empty bands
    half_bw = B / 2
    p_main = bandpower_from_psd(f, Pxx, -half_bw, half_bw) + floor
    p_lower = bandpower_from_psd(f, Pxx, -off - half_bw, -off + half_bw) + floor
    p_upper = bandpower_from_psd(f, Pxx, off - half_bw, off + half_bw) + floor
    lower_db = 10 * np.log10(p_lower / p_main)
    upper_db = 10 * np.log10(p_upper / p_main)
    return lower_db, upper_db

def compute_evm(tx_syms: np.ndarray, rx_syms: np.ndarray):
    """RMS EVM after a best-fit single-tap complex-gain equalizer.

    Both inputs are flattened; the complex gain g that best maps the
    transmitted symbols onto the received ones (least squares) is estimated,
    and the error r - g*s is measured relative to the power of g*s.

    Args:
        tx_syms: Reference (transmitted) symbols.
        rx_syms: Received symbols, same number of elements as ``tx_syms``.

    Returns:
        (evm_rms, evm_db): linear RMS EVM and 20*log10 of it, both as floats.
    """
    ref = tx_syms.reshape(-1)
    meas = rx_syms.reshape(-1)
    # Least-squares one-tap gain; 1e-30 guards against an all-zero reference.
    gain = (np.vdot(ref, meas)) / (np.vdot(ref, ref) + 1e-30)
    fitted = gain * ref
    err_power = np.mean(np.abs(meas - fitted) ** 2)
    ref_power = np.mean(np.abs(fitted) ** 2)
    rms = np.sqrt(err_power / (ref_power + 1e-30))
    return float(rms), float(20 * np.log10(rms + 1e-30))

def estimate_integer_delay(x_ref: np.ndarray, y: np.ndarray, max_lag: int = 2000) -> int:
    """Estimate the integer-sample delay of ``y`` relative to ``x_ref``.

    Both signals are scaled to unit RMS, cross-correlated (FFT method, full
    mode), and the lag with the largest correlation magnitude inside
    [-max_lag, max_lag] is returned.

    Args:
        x_ref: Reference signal.
        y: Signal to be aligned against the reference.
        max_lag: Largest absolute lag considered.

    Returns:
        The lag, as a Python int, at which |correlation| peaks.
    """
    def unit_rms(v):
        # 1e-12 avoids division by zero for an all-zero signal.
        return v / (np.sqrt(np.mean(np.abs(v) ** 2)) + 1e-12)

    ref = unit_rms(x_ref)
    obs = unit_rms(y)
    xcorr = sig.correlate(obs, ref, mode="full", method="fft")
    # Lag axis of a full correlation: -(len(ref)-1) .. len(obs)-1.
    all_lags = np.arange(1 - len(ref), len(obs))
    window = np.abs(all_lags) <= max_lag
    peak = np.argmax(np.abs(xcorr[window]))
    return int(all_lags[window][peak])

def align_by_delay(y: np.ndarray, d: int) -> np.ndarray:
    """Shift ``y`` by ``d`` samples, zero-filling, without changing its length.

    Args:
        y: Input samples.
        d: Delay to remove. A positive value drops the first ``d`` samples and
            appends zeros; a negative value prepends ``-d`` zeros and drops
            the tail; zero returns ``y`` itself unchanged.

    Returns:
        For non-zero ``d``, a new array of ``len(y)`` samples (complex zeros
        are used as padding). When ``|d|`` is at least ``len(y)`` the result
        is all zeros.
    """
    n = len(y)
    if d > 0:
        # Clamp the shift: without it, |d| > len(y) would return |d| samples.
        k = min(d, n)
        return np.concatenate([y[k:], np.zeros(k, dtype=complex)])
    if d < 0:
        k = min(-d, n)
        # y[:n - k] rather than y[:-k] so that k == n gives an empty slice.
        return np.concatenate([np.zeros(k, dtype=complex), y[:n - k]])
    return y


# -----------------------------
# 4) "실험": OFDM → (PA) → DPD(MP/GMP) → Metric
# -----------------------------
def main():
    # (A) 신호 파라미터: Xia 2024 실험과 유사하게 122.88MSPS, 4096FFT, 600subc(≈18MHz)
    fs = 122.88e6
    nfft = 4096
    cp_len = 288
    nsc = 600
    M_qam = 64
    n_sym = 50

    B = 18.02e6   # 메인대역 폭(예: LTE 20MHz E-TM 3.1의 OBW 근사)
    off = 20e6    # 인접대역 오프셋(±20MHz)

    rng = np.random.default_rng(1)
    active_bins = make_active_bins(nfft, nsc)
    tx_syms = gen_qam_symbols(M_qam, n_sym*len(active_bins), rng=rng).reshape(n_sym, -1)
    u = ofdm_modulate(tx_syms, nfft, cp_len, active_bins)
    u = u / np.sqrt(np.mean(np.abs(u)**2))  # RMS=1
    drive = 0.6
    u = drive * u

    # (B) PA: 일부러 "cross-term(=GMP)" 성분이 강하게 들어가게 만든 예제 PA
    #     (Morgan 2006의 GMP 형태를 흉내) 
    cfg_pa = GMPConfig(
        p_list_align=[1,3,5],
        L=3,
        p_list_lag=[3,5,7,9],
        m_list=[1,2,3,4],
    )
    # 계수 벡터(순서: aligned(l,p) → lag(l,m,p))
    w_pa = []
    for l in range(cfg_pa.L):
        for p in cfg_pa.p_list_align:
            if p == 1:
                w_pa.append([1.0+0j, 0.03*np.exp(1j*0.2), 0.01*np.exp(1j*0.5)][l])
            elif p == 3:
                w_pa.append((-0.05+0.02j) * (0.5**l))
            elif p == 5:
                w_pa.append((0.01-0.005j) * (0.5**l))
    for l in range(cfg_pa.L):
        for m in cfg_pa.m_list:
            for p in cfg_pa.p_list_lag:
                if p == 3:   base = (-0.25+0.10j)
                elif p == 5: base = (0.12-0.08j)
                elif p == 7: base = (-0.03+0.02j)
                elif p == 9: base = (0.01-0.005j)
                w_pa.append(base * (0.7**l) * (0.85**(m-1)))
    w_pa = np.array(w_pa, dtype=complex)

    def pa_fn(x):  # PA block
        return apply_gmp(x, w_pa, cfg_pa)

    # (C) baseline (no DPD)
    y_no = pa_fn(u)
    f_no, P_no = compute_psd_welch(y_no, fs)
    aclr_no = compute_aclr(f_no, P_no, B, off)

    # (D) DPD-MP 학습 (ILA + LS)
    cfg_dpd_mp = MPConfig(p_list=[1,3,5,7,9], mem_depth=7)
    train_slice = slice(20000, 20000+60000)  # 너무 길면 행렬이 커져서 느려짐 → 일부만 학습
    w_mp = ila_train_with_gain_subset(u, pa_fn, "mp", cfg_dpd_mp, n_iter=5, ridge=1e-4, train_slice=train_slice)
    x_mp = apply_mp(u, w_mp, cfg_dpd_mp)
    y_mp = pa_fn(x_mp)
    f_mp, P_mp = compute_psd_welch(y_mp, fs)
    aclr_mp = compute_aclr(f_mp, P_mp, B, off)

    # (E) DPD-GMP 학습 (cross-term 포함)
    cfg_dpd_gmp = GMPConfig(
        p_list_align=[1,3,5,7,9],
        L=5,
        p_list_lag=[3,5,7,9],
        m_list=[1,2,3,4],
    )
    train_slice2 = slice(20000, 20000+30000)
    w_gmp = ila_train_with_gain_subset(u, pa_fn, "gmp", cfg_dpd_gmp, n_iter=4, ridge=1e-4, train_slice=train_slice2)
    x_gmp = apply_gmp(u, w_gmp, cfg_dpd_gmp)
    y_gmp = pa_fn(x_gmp)
    f_gmp, P_gmp = compute_psd_welch(y_gmp, fs)
    aclr_gmp = compute_aclr(f_gmp, P_gmp, B, off)

    # (F) EVM(OFDM demod 후 best-fit gain)
    burn_sym = 5
    sym_len = nfft + cp_len
    eval_syms = n_sym - burn_sym
    start = burn_sym * sym_len

    def evm_from(y):
        u_seg = u[start:start+eval_syms*sym_len]
        y_seg = y[start:start+eval_syms*sym_len]
        d = estimate_integer_delay(u_seg, y_seg, max_lag=200)
        y_al = align_by_delay(y_seg, d)
        rx = ofdm_demodulate(y_al, nfft, cp_len, active_bins, eval_syms)
        tx = tx_syms[burn_sym:burn_sym+eval_syms]
        return compute_evm(tx, rx)

    evm_no = evm_from(y_no)
    evm_mp = evm_from(y_mp)
    evm_gmp = evm_from(y_gmp)

    print("=== ACLR(dBc) lower/upper ===")
    print("No DPD :", aclr_no)
    print("MP DPD :", aclr_mp)
    print("GMP DPD:", aclr_gmp)
    print("\n=== EVM ===")
    print("No DPD : evm_rms=%.4f  evm_dB=%.2f dB" % evm_no)
    print("MP DPD : evm_rms=%.4f  evm_dB=%.2f d