In [None]:
import sys 
import matplotlib.pyplot as plt
import numpy as np

from matplotlib.contour import QuadContourSet
from typing import Tuple

# Create a grid of complex values
real = np.linspace(-5, 5, 1000)
imag = np.linspace(-5, 5, 1000)
X, Y = np.meshgrid(real, imag)
Z = X + 1j*Y  # Complex grid

# Identify singular points as an elementwise boolean mask with the same shape as Z.
# (Collapsing it with np.any() produced a single scalar flag, so `Z[~singular]`
# selected the *whole* grid -- or nothing -- instead of masking point by point.)
singular = (np.abs(Z) < 1e-10) | (np.abs(Z) > 1e+10)
f_values = np.zeros_like(Z, dtype=complex)

# float64 machine epsilon, used as a small regulariser further down.
epsilon = sys.float_info.epsilon


def f(z, branches_at_pi4 = True, remove_branch = False):
    """Evaluate the first-stage transfer function on complex input.

    Args:
        z: Complex scalar or array. Every operation is elementwise, so each
            output depends only on the corresponding input sample.
        branches_at_pi4: If True (default) return ``f0 / (z**4 + 1)``, which
            has poles where z**4 = -1 (the odd multiples of pi/4). Otherwise
            return ``1 + 0.5 * (z**4 + 1) * (f0 - 1) + 1j * Im(f0)``.
        remove_branch: If True, ``f0`` is ``-arctanh((z / (1 + epsilon))**4)``
            multiplied by ``1 - exp(-((|z| - 1) / epsilon)**2)``, a factor that
            vanishes on the unit circle |z| = 1. Otherwise ``f0`` is
            ``arctanh(z**4)``, negated at points lying on the real or
            imaginary axis.

    Returns:
        Complex values of the same shape as ``z``; a scalar for scalar input.
    """
    # See: ... (link to research report)
    z4 = z**4
    if remove_branch:
        # |z| is taken elementwise; np.linalg.norm(z) would use the norm of the
        # whole array and couple every sample to all the others.
        window = 1 - np.exp(-((np.abs(z) - 1) / epsilon)**2)
        f0 = window * -np.arctanh((z / (1 + epsilon))**4)
    else:
        # The only conditional operator required is to deal with the sign.
        # Decide it per element: only points on the real/imaginary axis take
        # the negated branch. `[()]` turns a 0-d result back into a scalar.
        # (np.arctanh rather than np.atanh: that alias exists only in NumPy >= 2.)
        h = np.arctanh(z4)
        on_axis = np.isclose(np.real(z), 0) | np.isclose(np.imag(z), 0)
        f0 = np.where(on_axis, -h, h)[()]
    if branches_at_pi4:
        return f0 / (z4 + 1)
    return 1 + 0.5 * (z4 + 1) * (f0 - 1) + np.imag(f0) * 1j

# Evaluate f only at the non-singular points; singular points are set to NaN.
f_values[~singular] = f(Z[~singular], remove_branch=False)
f_values[singular] = np.nan
magnitude = np.abs(f_values)
# Cap |f| at 25, presumably so values near the poles do not dominate the contour
# levels (the -25 lower bound has no effect: a magnitude is never negative).
magnitude = np.clip(magnitude, -25, 25)


def contour_plot(
        M: np.ndarray,
        X: np.ndarray,
        Y: np.ndarray,
        sz: tuple = (10, 8)) -> Tuple[QuadContourSet, QuadContourSet]:
    """Draw a filled contour map of ``M`` over the ``(X, Y)`` grid on the current axes.

    Adds 50 black contour lines with inline labels, axis labels, a grid and a
    colour bar labelled ``|f(z)|``, and resizes the current figure.

    Args:
        M: Values to contour, same shape as ``X`` and ``Y``.
        X: x (real-part) coordinate grid.
        Y: y (imaginary-part) coordinate grid.
        sz: Figure size in inches applied to the current figure.

    Returns:
        ``(filled, lines)``: the 20-level filled contour set (also used for the
        colour bar) and the overlaid 50-level contour-line set.
    """
    ax, fig = plt.gca(), plt.gcf()
    filled = ax.contourf(X, Y, M, 20, cmap='inferno')
    lines = ax.contour(X, Y, M, 50, colors='black', alpha=0.5)

    ax.set_xlabel('Re(z)')
    ax.set_ylabel('Im(z)')
    ax.grid(True)
    lines.clabel(inline=True, fontsize=8, fmt='%.1f')
    fig.set_size_inches(sz)
    fig.colorbar(filled, label='|f(z)|')

    fig.tight_layout()
    # Honour the annotated return type (nothing was returned before).
    return filled, lines

# The eight 8th roots of unity e^{i k pi/4}, k = 0..7.
z4 = np.asarray([np.exp(1j * (np.pi/4) * k) for k in range(0, 8)], dtype=complex)
# Snap floating-point noise (e.g. cos(pi/2) ~ 6e-17) to exact zeros so that the
# on-axis roots count as purely real / purely imaginary (print_crit_values tests
# `imag == 0` / `real == 0` exactly).
z4.imag[np.abs(z4.imag) <= 1e-6] = 0.0
z4.real[np.abs(z4.real) <= 1e-6] = 0.0
zpi4 = np.exp(1j * np.pi/4)
# Points of interest: the roots above plus the origin.
CRIT_VALUES = (*z4, 0j)

contour_plot(magnitude, X, Y)
# NOTE(review): this title shows the obsolete function -ln|(z^4 + 4) atanh(z^2)|,
# not `f` as defined above -- confirm the intended label.
plt.title(r'Contour Plot of $|f(z)| = -\ln(\|(z^4 + 4)\tanh^{-1}(z^2)\|)$')
# Mark the critical points (same set as CRIT_VALUES). Only the axis points get
# labels, and no legend is drawn in this cell.
plt.plot(0, 0, 'ro', markersize=5, label='z=0')
plt.plot(1, 0, 'ro', markersize=5, label='z=1')
plt.plot(-1, 0, 'ro', markersize=5, label='z=-1')
plt.plot(0, 1, 'ro', markersize=5, label='z=i')
plt.plot(0, -1, 'ro', markersize=5, label='z=-i')
plt.plot(np.real(zpi4), np.imag(zpi4), 'ro', markersize=5)
plt.plot(np.real(zpi4), -np.imag(zpi4), 'ro', markersize=5)
plt.plot(-np.real(zpi4), np.imag(zpi4), 'ro', markersize=5)
plt.plot(-np.real(zpi4), -np.imag(zpi4), 'ro', markersize=5)


In [None]:
import functools


def cast_to_inf(z, thresh=1e15):
    """Replace very large complex values with signed complex infinities.

    Elements whose modulus exceeds ``thresh`` are replaced. Each component
    becomes ``+/-inf`` according to its sign, or exactly ``0`` where that
    component was zero.

    Args:
        z: Array-like of complex (or real) values. It is copied, never modified.
        thresh: Modulus above which an element is treated as infinite.

    Returns:
        A new complex ndarray with the large elements replaced.
    """
    out = np.array(z, dtype=complex)  # copy: never mutate the caller's array
    big = np.abs(out) > thresh
    if not np.any(big):
        return out
    big_vals = out[big]
    # Build the components separately. `np.inf * e^{i*angle}` gave NaN wherever
    # a component was zero (inf * 0 = NaN), e.g. a huge real became inf+nanj.
    re = np.where(big_vals.real == 0, 0.0, np.copysign(np.inf, big_vals.real))
    im = np.where(big_vals.imag == 0, 0.0, np.copysign(np.inf, big_vals.imag))
    out.real[big] = re
    out.imag[big] = im
    return out


def print_crit_values(fn, *args, cast_reals=False, cast_imags=False, cast_complex=True):
    """Print ``fn`` evaluated at each complex argument, grouped by kind.

    Arguments are split into purely real (``imag == 0``), purely imaginary
    (``real == 0``) and general complex values. The groups are printed on
    lines prefixed ``R:``, ``I:`` and ``Z:`` as ``input |-> output`` pairs;
    empty groups are skipped. Each ``cast_*`` flag passes that group's outputs
    through ``cast_to_inf``, so huge magnitudes are shown as infinities.
    Divide/invalid floating-point warnings are suppressed during evaluation.

    Raises:
        ValueError: if any argument is not a ``complex`` instance.
    """
    # Split into real, imaginary and complex valued
    cr, ci, z = [], [], []
    for a in args:
        if not isinstance(a, complex):
            # Previously the exception was *returned*, silently printing nothing.
            raise ValueError('Must provide complex valued args only!')
        if a.imag == 0:
            cr.append(a)
        elif a.real == 0:
            ci.append(a)
        else:
            z.append(a)

    with np.errstate(divide='ignore', invalid='ignore'):
        # Lazy maps: fn is evaluated when consumed below, still inside errstate.
        reals = map(fn, cr)
        imags = map(fn, ci)
        comps = map(fn, z)

        if cast_reals:
            reals = cast_to_inf(list(reals))
        if cast_imags:
            imags = cast_to_inf(list(imags))
        if cast_complex:
            comps = cast_to_inf(list(comps))

        reals = [f'{_in} |-> {out}' for (_in, out) in zip(cr, reals)]
        imags = [f'{_in} |-> {out}' for (_in, out) in zip(ci, imags)]
        comps = [f'{_in} |-> {out}' for (_in, out) in zip(z, comps)]

        if reals:
            print(f'R: {" ".join(reals)}')
        if imags:
            print(f'I: {" ".join(imags)}')
        if comps:
            print(f'Z: {" ".join(comps)}')
        

# Let's check some critical values 
# Evaluate f at the eighth roots of unity and the origin: first with the
# axis-dependent sign branch, then with the variant whose prefactor vanishes on |z| = 1.
print_crit_values(functools.partial(f, remove_branch=False), *CRIT_VALUES)
print('\n')
print_crit_values(functools.partial(f, remove_branch=True), *CRIT_VALUES)

# Some random values just to get a sense
# random_values = np.random.uniform(-np.real(zpi4), np.imag(zpi4), (5))
# + 1.j * np.random.uniform(-np.real(zpi4), np.imag(zpi4), (5))
# print(f'\n{f(random_values)}') # A feel for how the function works

In [None]:
from matplotlib import cm


def freal(x, y):
    """Real-valued evaluation of log(|z^2 - 1| / |atanh(i z)|) with z = x + i y.

    ``num`` is |z^2 - 1| written out in Cartesian form:
    |z^2 - 1|^2 = (x^2 - y^2 - 1)^2 + (2 x y)^2.

    Args:
        x: Real part(s); scalar or array broadcastable with ``y``.
        y: Imaginary part(s).
    """
    num = np.sqrt((x**2 - y**2 - 1)**2 + 4*x**2*y**2)
    # np.arctanh (not np.atanh): that alias only exists from NumPy 2.0, and the
    # rest of the notebook already uses np.arctanh.
    denom = np.abs(np.arctanh(1j*(x + 1j*y)))
    return np.log(num / denom)


# Complex-valued surface of f over the full grid.
U = f(X + 1j * Y)
# OR U = freal(X, Y)


def proj_plot(X, Y, U, set_aspect=True):
    """Draw ``U`` as a 3-D surface over the ``(X, Y)`` grid using the jet colormap.

    Only the real part of ``U`` is plotted. It is taken explicitly instead of
    relying on an implicit complex-to-float cast. The previous
    ``except ComplexWarning`` could not work: the name was never imported,
    and warnings are not raised as exceptions by default. Any real error from
    ``plot_surface`` would have surfaced as a NameError.

    Args:
        X: x coordinate grid.
        Y: y coordinate grid.
        U: Surface heights (complex input is reduced to its real part).
        set_aspect: If True, restrict x to [-2, 2] and size the box aspect from
            the current axis limits with an equal aspect ratio.
    """
    ax = plt.axes(projection='3d')
    ax.plot_surface(X, Y, np.real(U), cmap=cm.jet)
    if set_aspect:
        plt.xlim([-2, 2])
        ax.set_box_aspect((np.diff(ax.get_xlim())[0],
                           np.diff(ax.get_ylim())[0],
                           np.diff(ax.get_zlim())[0]))
        ax.set_aspect('equal')


# 3-D view of U without the equal-aspect rescaling.
proj_plot(X, Y, U, set_aspect=False)

In [None]:
def G(r, theta):
    """Return |z^2 - 1|^2 for z = r e^{i theta}, i.e. r^4 - 2 r^2 cos(2 theta) + 1."""
    cross_term = 2*r**2*np.cos(2 * theta)
    return r**4 - cross_term + 1


def H(r, theta):
    """Return log|atanh(i z)| for z = r e^{i theta} (polar form of freal's log-denominator)."""
    # np.arctanh rather than the NumPy>=2.0-only alias np.atanh.
    return np.log(np.abs(np.arctanh(1j * r * np.exp(1j * theta))))


def fpolar(r, theta):
    """Polar form of freal: 0.5 * log(G) - H = log(|z^2 - 1| / |atanh(i z)|), z = r e^{i theta}.

    Divide-by-zero and invalid-value warnings are suppressed, e.g. log(0) at z = +/-1.
    """
    with np.errstate(divide='ignore', invalid='ignore'):
        half_log_g = 1/2 * np.log(G(r, theta))
        value = half_log_g - H(r, theta)
    return value


# Sample the unit disc in polar coordinates (r starts at 0.001, not 0).
r = np.linspace(0.001, 1, 1000)
theta = np.linspace(0, 2*np.pi, 300)
R, Theta = np.meshgrid(r, theta)

# Evaluate in polar form, then map the grid to Cartesian coordinates for plotting.
Z2 = fpolar(R, Theta)
X2 = R * np.cos(Theta)
Y2 = R * np.sin(Theta)

plt.title(r'Contour Plot of polar function in unit circle')
contour_plot(Z2, X2, Y2)

In [None]:
# Now let's try applying the gaussian function
def f2(z, f=None):
    """Gaussian bump exp(-z * conj(z)) = exp(-|z|^2), optionally added to ``f(z)``.

    Args:
        z: Complex scalar or array.
        f: Optional callable; when given, ``f(z)`` is added to the bump.

    Returns:
        ``exp(-z * conj(z))`` (complex dtype, zero imaginary part), plus
        ``f(z)`` if ``f`` is provided.
    """
    bump = np.exp(-z * np.conjugate(z))
    return bump if f is None else bump + f(z)


# Let's check the critical values still hold
# Same points as before, now for f2 = exp(-|z|^2) + f(z).
print_crit_values(functools.partial(f2, f=f), *CRIT_VALUES)

# What does the 'limit' look like around zero ?
# NOTE(review): `epsilon - 0j` has real part +epsilon (only the zero imaginary part
# differs), so -epsilon on the real axis is not actually probed; perhaps
# `-epsilon + 0j` was meant.
print('\nLimit\'s around 0:')
print_crit_values(functools.partial(f2, f=f), epsilon + 0j, epsilon - 0j, epsilon * 1j, epsilon * -1j)

In [None]:
# NOTE(review): this zeros_like allocation is overwritten on the next line; it has no effect.
f_values = np.zeros_like(Z, dtype=complex)
f_values= f2(Z, f)
magnitude = np.abs(f_values)

contour_plot(magnitude, X, Y)
# NOTE(review): the formula in this title (CN(0, sigma^2) - ln|(z^4 + 4) atanh(z^2)|)
# does not match what is plotted, |exp(-|z|^2) + f(z)| -- confirm the intended label.
plt.title(r'Contour Plot of $|f(z)| = \mathcal{CN}(0, \sigma^2) - \ln\left(\left\|(z^4 + 4)\tanh^{-1}(z^2)\right\|\right)$')

In [None]:
"""
By converting to an auxiliary analytic function, so that f(z) is no longer meromorphic, we can:

(1) Use Jensen's formula, which relates the complex logarithm (log|f|) to its Poisson (boundary-integral) expansion;
(2) Use the mean-value property of harmonic functions over a disc to compute the disc-average of log|f|;
(3) Re-normalize the distribution by its area-mean log.
"""

def f4(z):
    """Auxiliary function f(z) / |z|, with |z| taken elementwise.

    ``np.linalg.norm(z)``, used previously, is the norm of the *whole* array.
    Each output therefore depended on every other sample in the same call;
    for example, a 32-point unit-circle grid was divided by sqrt(32) instead
    of 1. For a scalar ``z`` the two are identical. ``z = 0`` divides by zero.
    """
    return f(z) / np.abs(z)


# Real parts of the Gaussian-augmented auxiliary function (f2 with f4) and of f4 alone.
U4 = np.real(f2(X + 1j * Y, f4))
U5 = np.real(f4(X + 1j * Y))
proj_plot(X, Y, U4)
plt.show()
proj_plot(X, Y, U5)


# Critical values of f2(., f=f4): the four axis points, the four diagonal points and the origin.
print_crit_values(functools.partial(f2, f=f4), 1 + 0j, -1 + 0j, 1j, -1j, 1 + 1j, -1 + 1j, -1 - 1j, 1 - 1j, 0 + 0j)

In [None]:
import os
import scipy.signal as signal
import librosa
import matplotlib.pyplot as plt


# Example audio file, resolved relative to the parent of the current working directory.
EXAMPLE_RECORDING = os.path.join(os.path.dirname(os.getcwd()), 'resources', 'walrus.flac')
N_TAPS = 64 # the FIR numerator is designed with N_TAPS + 1 taps (firwin2 call below)
N_FREQS = 32 # Example number of frequencies in a bin (smaller = better). It is possible for this to be graded on an EQ function that gates certain frequencies more
N = 32 # Example number of bands
# Margins of 10*log10(2) (~3.01, i.e. "3 dB") applied to the gain clip bounds.
DB_UNDER = 10 * np.log10(2)
DB_OVER = 10 * np.log10(2)
# DB_UNDER = 0
# DB_OVER = 0
SIGMA = 0.5  # Bandwidth parameter
NORMALIZE_WEIGHTS = False
# Selects the pole radius r below: 1 - epsilon when True, sqrt(2) otherwise.
CAUSAL = False
DC_GAIN_TO_UNITY = False # Having substantial low-frequency (DC/near-DC) content => gain at ω=0 to 1 will keep overall levels reasonable
SCALE_USING_RMS = False
# True: sample f4 (f_z default) and zero a[5]; False: sample arctanh(z**2) instead.
USE_LOG = True

# int32 limits, used by f_z(..., clamp_to_32=True) to replace +/-inf.
MAX_INT = 2**31 - 1
MIN_INT = -2**31

# The example recording that every filter below is applied to.
audio, sr = librosa.load(EXAMPLE_RECORDING, sr=None, mono=True)


def f_z(z, ignore_log=False, clamp_to_32=False):
    """Evaluate the design function on ``z`` and replace non-finite results.

    Args:
        z: Complex samples (e.g. points on the unit circle).
        ignore_log: If True use ``arctanh(z**2)``; otherwise use ``f4(z)``.
        clamp_to_32: If True, map +/-inf to the int32 limits ``MAX_INT`` /
            ``MIN_INT`` instead of np.nan_to_num's default (largest finite floats).

    Returns:
        The evaluated values passed through ``np.nan_to_num`` (NaN -> 0).
    """
    with np.errstate(divide='ignore', invalid='ignore'):
        if ignore_log:
            values = np.arctanh(z**2)
        else:
            values = f4(z)
        if not clamp_to_32:
            return np.nan_to_num(values)
        return np.nan_to_num(values, posinf=MAX_INT, neginf=MIN_INT)
    

# Define poles from f(z) singularities
# NOTE(review): this looks meant to put all eight poles on a circle of radius r,
# but np.linalg.norm(poles) is the norm of the whole vector (sqrt(12) ~= 3.46),
# not |p| per pole. With CAUSAL=False the poles end up at radii ~0.41 (axis points)
# and ~0.58 (diagonal points) rather than sqrt(2). `poles / np.abs(poles)` would
# place them outside the unit circle for CAUSAL=False, which would likely make
# lfilter(b, a, ...) below unstable -- confirm intent before changing.
r = 1 - epsilon if CAUSAL else np.sqrt(2)
poles = np.array([1, -1, 1j, -1j, 1+1j, 1-1j, -1+1j, -1-1j])
poles = r * (poles / np.linalg.norm(poles))

# Frequency bin centers and Gaussian weights
# Each of the N bin centres gets the sum of Gaussian bumps (width SIGMA) centred on the pole angles.
# NOTE(review): np.angle returns values in (-pi, pi] while omega_k spans [0, 2*pi); the
# angular distance is not wrapped, so poles at negative angles are "far" from bins near 2*pi.
omega_k = np.linspace(0, 2*np.pi, N, endpoint=False)
weights = np.array([np.sum(np.exp(-(np.angle(poles) - w)**2 / (2*SIGMA**2))) for w in omega_k])

if NORMALIZE_WEIGHTS:
    weights /= np.max(weights)

print(f'Weight\'s |-> {weights}\n')

# Construct numerator polynomial by sampling f(e^{jω}) * Gaussian weights
# N_FREQS samples on the unit circle, omega in [0, 2*pi).
omega = np.linspace(0, 2*np.pi, N_FREQS, endpoint=False)
z = np.exp(1j * omega)

F_vals = f_z(z, ignore_log=not USE_LOG)
# F_vals = f_z(z, ignore_log=True)
# F_vals = np.ones(N_FREQS, dtype=float)

# Gaussian envelope: N bumps centred on the bin centres omega_k, scaled by `weights`.
G_vals = np.zeros_like(F_vals, dtype=float)
for k in range(N):
    G_vals += weights[k] * np.exp(-(omega - omega_k[k])**2 / (2*SIGMA**2))

with np.errstate(divide='ignore', over='ignore'):
    H_freq = F_vals * G_vals
    # NOTE(review): omega covers the full circle [0, 2*pi) but is paired with
    # frequencies 0..sr/2 (i.e. 0..pi rad/sample), so the upper half of the circle
    # is mapped onto the upper half of the spectrum -- confirm this is intended.
    freq = np.linspace(0, sr / 2, N_FREQS)
    gain = H_freq
    # NOTE(review): the (complex) gain is clipped with bounds built from sample
    # amplitudes offset by dB margins (DB_UNDER / DB_OVER) -- mixed units; verify.
    gain = np.clip(gain, np.min(audio) - DB_UNDER, np.max(audio) + DB_OVER)

# Fit a FIR numerator to approximate H_freq over frequencies
# NOTE(review): `gain` is complex here; check how firwin2 treats complex gain values.
b = signal.firwin2(N_TAPS + 1, freq, gain, fs=sr)
a = np.poly(poles)  # coefficients of D(z) = ∏ (z - p)

if USE_LOG:
    # NOTE(review): np.poly of 8 poles gives 9 coefficients in descending powers
    # (a[0] ~ z^8), so a[5] is the z^3 coefficient and the z^4 term is a[4] --
    # confirm which term is meant to be cancelled.
    a[5] = 0 # The term of 4th degree is cancelled with f(z)

if DC_GAIN_TO_UNITY:
    # Scale b so that H(1) = B(1) / A(1) = 1, i.e. unity gain at omega = 0.
    H0 = np.polyval(b, 1) / np.polyval(a, 1)
    b /= H0

filtered = signal.lfilter(b, a, audio)

if SCALE_USING_RMS:
    # Rescale the filtered signal so its peak amplitude matches the original's.
    # Previously this divided by peak_orig / peak_filt, which applies the inverse
    # correction (an already-quiet output was made quieter still).
    # NOTE(review): despite the flag's name the peak ratio is applied; the RMS
    # values are computed but unused -- use rms_orig / rms_filtered if RMS
    # matching was intended.
    rms_orig = np.sqrt(np.mean(audio**2))
    rms_filtered = np.sqrt(np.mean(filtered**2))
    peak_orig = np.max(np.abs(audio))
    peak_filt = np.max(np.abs(filtered))
    filtered *= peak_orig / (peak_filt + epsilon)

# Print impulse response of filter
# Response of the designed filter to a unit impulse, N_FREQS samples long.
impulse = np.zeros(N_FREQS); impulse[0] = 1
h = signal.lfilter(b, a, impulse)
print(f'Impulse response |-> [{h[:10]} ... {h[-10:]}]')

# Find the diff
diff = audio - filtered
print(f'\n Some stats:\n\t sum of H |-> {np.sum(h)}')

plt.title('Reference signal')
plt.plot(audio, alpha=0.7, label='Original')
plt.show()
plt.title('Impulse response')
# Plot the filter's response h -- not the excitation `impulse`, which is just a
# single spike at n = 0 and was being plotted under this title before.
plt.plot(h, alpha=0.7, label='Impulse')
plt.show()
plt.title('Filtered response')
plt.plot(filtered, alpha=0.7, label='Filtered')
plt.show()
plt.title('Diff. between reference & filtered')
plt.plot(diff, alpha=0.7, label='Diff')
plt.plot(audio, alpha=0.7, label='|ref|')
plt.legend(loc='upper right')
plt.show()

In [None]:
# Let's play the filtered signal as a test

import soundfile as sf
import io
from IPython.display import Audio

def play_sig(sig, sr):
    """Round-trip ``sig`` through an in-memory FLAC file and load it back.

    The signal is written with soundfile at sample rate ``sr`` and read back
    with ``librosa.load(..., sr=None)``. Nothing is played here; the callers
    pass the result to ``IPython.display.Audio``.

    Returns:
        ``(y, sr)`` as returned by ``librosa.load``.
    """
    flac_buffer = io.BytesIO()
    sf.write(flac_buffer, sig, sr, format='flac')
    flac_buffer.seek(0)
    decoded, rate = librosa.load(flac_buffer, sr=None)
    return decoded, rate

# Round-trip the filtered signal through FLAC and listen to it.
# Note: this rebinds the globals `y` and `sr`.
y, sr = play_sig(filtered, sr)
Audio(data=y, rate=sr) # filtered

In [None]:
# The unmodified recording, for comparison.
Audio(data=audio, rate=sr) # original

In [None]:
"""
Now consider that we have 2 'metrics', each encoded in 2 unit vectors.

To compute how well the metrics align we can take the dot product of their first vectors, i.e. m_{1,1} . m_{2,1}.
Because we are working in complex coordinates we can take a second one, i.e. m_{1,2} . m_{2,2},
and embed both in the complex number z = k * [(m_{1,1} . m_{2,1}) + (m_{1,2} . m_{2,2}) i].

I.e. Z = k * (M . [1, i])
where M is the column vector of the two dot products:
[
    m_{1,1}^T m_{2,1}
    m_{1,2}^T m_{2,2}
]

Consider this:
If m_{1,1} . m_{2,1} = -1 or 1, or m_{1,2} . m_{2,2} = -1 or 1, our number goes to infinity.

[*] This is because the combination of any 2 metrics that are strongly correlated (irrespective of sign)
should represent a global maximum.

If m_{1,1} . m_{2,1} = 0 or m_{1,2} . m_{2,2} = 0 (but not both -- that is the origin, treated below), our number goes to negative infinity.

[*] This is because if either of the 2 metrics has no correlation at all, this should represent a global minimum.

Finally, there is a special case at the origin, because the magnitude of the complex vector is 0 if and only if
M = 0. I will return to what this magnitude means later; for now, just know that it is chosen such that values closer to the origin
are graded higher by a 3D Gaussian function.
"""


def construct_M(z1, z2, z3, z4, squash=False):
    """Plain (unconjugated) inner products of two vector pairs.

    Every argument is flattened to 1-D first.

    Args:
        z1, z2: First pair; their product is sum(z1 * z2).
        z3, z4: Second pair; their product is sum(z3 * z4).
        squash: If True, combine both into one complex number
            ``<z1, z2> + 1j * <z3, z4>`` (for matching pair lengths).

    Returns:
        ``[<z1, z2>, <z3, z4>]`` (length 2) when ``squash`` is False,
        otherwise a length-1 array holding the combined complex value.
    """
    a, b, c, d = (np.asarray(v).ravel() for v in (z1, z2, z3, z4))
    if not squash:
        # einsum 'ij,ij->i' = row-wise inner product
        lhs = np.stack((a, c))
        rhs = np.stack((b, d))
        return np.einsum('ij,ij->i', lhs, rhs)
    row = np.concatenate((a, 1j * c)).reshape(1, -1)
    return row @ np.concatenate((b, d))


def construct_Mfield(v1, v2, v3, v4, fine=True, chunk=3):
    """
    Produces a cross-correlation field between <v1, v2>, <v3, v4>

    The products are plain sums of elementwise products (see construct_M);
    no complex conjugate is taken.

    Args:
        v1, v2, v3, v4: Equal-length 1-D vectors (real or complex).
        fine: If True, compute the products window by window ("fine
            structure"); otherwise once over the whole vectors ("macro structure").
        chunk: Window length for the fine field (default 3, the previous fixed value).

    Returns:
        fine=True: complex array of shape (ceil(len(v1) / chunk), 2) with one
        ``[<v1, v2>, <v3, v4>]`` row per window; shape (0, 2) for empty input.
        fine=False: ``construct_M(v1, v2, v3, v4)``, a length-2 array.

    Raises:
        ValueError: if ``chunk`` is smaller than 1 (fine=True only).
    """
    if not fine:
        # Macro structure: one pair of products over the full vectors.
        return construct_M(v1, v2, v3, v4, squash=False)

    if chunk < 1:
        raise ValueError('chunk must be a positive integer')
    # Fine structure: products per window, embedding more 'time-sensitive' information.
    rows = [
        construct_M(v1[i:i + chunk], v2[i:i + chunk],
                    v3[i:i + chunk], v4[i:i + chunk],
                    squash=False)
        for i in range(0, len(v1), chunk)
    ]
    if not rows:
        # Keep two columns so callers can still index field[:, 0] / field[:, 1].
        return np.empty((0, 2), dtype=complex)
    return np.asarray(rows, dtype=complex)



In [None]:
import numpy as np
import librosa, librosa.display
from scipy import signal
import matplotlib.pyplot as plt


# Number of smallest / largest evaluated field values printed by design_metrics.
K_INDS = 10


def design_metrics(sig):
    """Build two complex feature sets from ``sig`` and print similarity statistics.

    Feature set A is two STFT frames of ``sig`` (frames 0 and 3). Feature set B
    is two consecutive 1024-sample chunks of its analytic signal. The four
    vectors are trimmed to a common length and correlated with
    ``construct_Mfield`` (fine: per window; macro: whole vectors). The
    function then prints the fine/macro similarities, the cross-set
    similarity, the angle between the fine and macro similarities, and the
    K_INDS smallest/largest values of ``f`` evaluated on the fine field.

    Args:
        sig: 1-D audio signal. It needs at least 4 STFT frames (frame index 3
            is read); presumably also at least 2 * 1024 samples for feature set B.

    Returns:
        None; results are printed.
    """
    # ------------------------------------------------------------------
    # 1.  feature-set A  – two STFT frames (complex spectra)
    # ------------------------------------------------------------------

    n_fft = 1024
    hop   = 512
    S     = librosa.stft(sig, n_fft=n_fft, hop_length=hop)
    vecA1 = S[:, 0]                                          # frame-0 spectrum
    vecA2 = S[:, 3]                                          # frame-3 spectrum

    # ------------------------------------------------------------------
    # 2.  feature-set B  – two analytic-signal snippets (complex time)
    # ------------------------------------------------------------------

    # Use the `sig` argument. This previously read the global `y` (last set by the
    # FLAC round-trip cell), so feature set B came from a different signal than A.
    analytic = signal.hilbert(sig)                            # complex analytic waveform
    chunk    = n_fft                                          # 1024 samples; all vectors are trimmed below
    vecB1    = analytic[0: chunk]
    vecB2    = analytic[chunk: 2*chunk]

    # make all vectors the same length (trim to the shortest)
    L = min(map(len, (vecA1, vecA2, vecB1, vecB2)))
    vecA1, vecA2, vecB1, vecB2 = (v[:L] for v in (vecA1, vecA2, vecB1, vecB2))

    # ------------------------------------------------------------------
    # 3.  determine the correlation field between separate metrics
    # ------------------------------------------------------------------

    # Fine-structured cross correlation field: one [<A1, A2>, <B1, B2>] row per window
    field_fine = construct_Mfield(vecA1, vecA2, vecB1, vecB2, fine=True)
    simV1V2f = np.mean(field_fine[:, 0])
    simV3V4f = np.mean(field_fine[:, 1])
    # Macro-structured field: the same two products over the whole vectors
    field_macro = construct_Mfield(vecA1, vecA2, vecB1, vecB2, fine=False)
    simV1V2m, simV3V4m = field_macro
    # cross-similarity between sets (np.vdot conjugates its first argument)
    simX = np.vdot(np.hstack((vecA1, vecA2)),
                   np.hstack((vecB1, vecB2)))

    # ------------------------------------------------------------------
    # 4.  similarity between fine & macro structures
    # ------------------------------------------------------------------

    def determine_fine_macro_similarity(z1, z2):
        """Angle from z1 to z2, treating each as the 2-D vector (Re, Im).

        Returns (theta, cos_theta, sin_theta), with theta from np.arctan2.
        The results are NaN if either input is 0 (0 / 0).
        """
        x1, y1 = np.real(z1), np.imag(z1)
        x2, y2 = np.real(z2), np.imag(z2)

        # |v1||v2|sinθ
        det = x1*y2 - y1*x2
        # I.e.
        # det = np.linalg.det(
        #     np.column_stack((z1, z2))
        # )

        # |v1||v2|cosθ
        dot = x1*x2 + y1*y2
        # I.e.
        # dot = np.dot(z1, z2)

        d = np.hypot(x1, y1) * np.hypot(x2, y2)
        sin_theta = det / d
        cos_theta = dot / d

        # Compute the angle θ in radians
        theta = np.arctan2(sin_theta, cos_theta)

        return theta, cos_theta, sin_theta


    fine_structure = determine_fine_macro_similarity(simV1V2f, simV1V2m)
    macro_structure = determine_fine_macro_similarity(simV3V4f, simV3V4m)


    # ------------------------------------------------------------------
    # 5.  Upstream adjustment of fine vs macro weight vector
    # ------------------------------------------------------------------
    # TODO: use some sort of regressive model to adjust the structure_w
    structure_w = np.array([0.5, 0.5]).transpose()
    # structure = np.array([fine_structure, macro_structure], dtype=complex).transpose()
    # structure = np.vdot(structure_w, structure)
    structure = None

    # NOTE(review): evaluated_field is complex, and numpy sorts complex values by
    # real part first, then imaginary part -- confirm that is the intended ranking.
    evaluated_field = f(np.vstack(field_fine).ravel())
    evaluated_field_srted = np.argsort(evaluated_field)
    min_k_inds = evaluated_field_srted[:K_INDS]
    max_k_inds = evaluated_field_srted[-K_INDS:][::-1]

    print((
        'Some metric stats:'
        f'\n\tFirst metric (stft) fine similarity: {simV1V2f}'
        f'\n\tSecond metric (hilbert) fine similarity: {simV3V4f}'
        f'\n\tFirst metric (stft) macro similarity: {simV1V2m}'
        f'\n\tSecond metric (hilbert) macro similarity: {simV3V4m}'
        f'\n\tCross-set similarity: {simX}'
        f'\n\n\tField structure residuals:'
        f'\n\t\t *** fine (f): {fine_structure} ***'
        f'\n\t\t *** macro (m): {macro_structure} ***'
        f'\n\t\t *** overall (T): {structure} ***'
        f'\n\n\tEvaluated field for first {K_INDS} vals (ascending): {evaluated_field[min_k_inds]}'
        f'\n\n\tEvaluated field for first {K_INDS} vals (descending): {evaluated_field[max_k_inds]}'
    ))


def design_Hz(N=8,              # number of Gaussian bins  (≥2)
              mu=0.0,           # Gaussian centre (rad)
              sigma=np.pi/3,    # Gaussian width  (rad)
              rho=0.9,         # pole radius (<1 ⇒ strictly stable)
              gain_at_mu=1.0,
              norm = 'none'):  # desired gain at Ω=mu

    """Return (b, a) in z⁻¹ form for lfilter."""
    i        = np.arange(1, N+1)
    omega_i  = (2*i - 1) * np.pi / (2*N)          # bin-centres
    w        = np.exp(-((omega_i - mu)**2) / (2*sigma**2))
    w       /= w.max()                            # max-normalise

    # ---- numerator  Σ w_i (1 − 2cosΩ z⁻¹ + z⁻²) ------------------
    b = np.zeros(3)
    b[0] = b[2] = w.sum()                         # z⁰ & z⁻²
    b[1] = (-2*w*np.cos(omega_i)).sum()           # z⁻¹

    # ---- denominator  ∏ (1 − 2ρcosΩ z⁻¹ + ρ² z⁻²) ----------------
    a = np.array([1.])
    scaffold = np.convolve([1., -np.sqrt(2), 1.],  # e^{±jπ/4}
                       [1.,  np.sqrt(2), 1.])  # e^{±j3π/4}
    scaffold = np.convolve(scaffold, [1., 0.])     #   z = 0  →  z⁻¹ factor
    a = np.convolve(scaffold, a) 
    for ω in omega_i:
        a = np.convolve(a, [1., -2*rho*np.cos(ω), rho**2])

    # ---- normalise overall gain at Ω = μ -------------------------
    z_mu  = np.exp(1j*mu)
    B_mu  = np.polyval(b[::-1], z_mu)             # z⁻¹→z route
    A_mu  = np.polyval(a[::-1], z_mu)

    match norm:
        case 'dc':
            z_mu = np.exp(1j*mu)
            k    = (np.polyval(a[::-1], z_mu) /
                    np.polyval(b[::-1], z_mu)).real
        case 'peak':
            w_grid, h_grid = signal.freqz(b, a, worN=4096)
            k = np.abs(h_grid).max()
        case 'none':
            k = gain_at_mu * (A_mu/B_mu).real
        case _:
            assert False, 'Unreachable'
    b *= k

    # pad b so len(b)==len(a) (not required for lfilter, just neat)
    if len(b) < len(a):
        b = np.pad(b, (0, len(a)-len(b)))
    return b.astype(float), a.astype(float)


def run_metrics():
    """Print the design metrics for the loaded example recording (global ``audio``)."""
    return design_metrics(audio)


def run_filter(draw_plots=True):
    """Filter the global ``audio`` with ``design_Hz(N=16)`` and optionally plot the results.

    Args:
        draw_plots: If True, show the original vs filtered waveforms, the
            filter's magnitude response, and a bar chart of the filtered
            signal's value histogram.

    Returns:
        The filtered signal, ``signal.lfilter(b, a, audio)``.
    """
    b, a = design_Hz(N=16)
    filt = signal.lfilter(b, a, audio)

    if draw_plots:
        # Time-domain
        plt.figure(figsize=(14, 7))
        plt.subplot(2,1,1)
        t = np.arange(len(audio)) / sr  # sample index -> seconds (global sample rate `sr`)
        plt.plot(t, audio,       alpha=.6, label='original')
        plt.plot(t, filt,  alpha=.6, label='filtered')
        plt.xlabel('time [s]')
        plt.legend()
        plt.title('Waveforms')

        # Magnitude response
        plt.subplot(2,1,2)
        w, h = signal.freqz(b, a, worN=4096)
        plt.semilogy(w * sr / (2*np.pi), np.abs(h))  # rad/sample -> Hz
        plt.xlabel('frequency [Hz]')
        plt.ylabel('|H(e^{jΩ})|')
        plt.title('Magnitude response')
        plt.tight_layout()
        plt.show()

        # NOTE(review): `N` is the module-level band count from the filter-design
        # config, not design_Hz's N=16. With bins=N, `[-N:]` selects every bin,
        # so the "most-common" bars are simply the full histogram.
        counts, bin_edges  = np.histogram(filt, bins=N)
        top_idx            = np.argsort(counts)[-N:]
        centres            = 0.5 * (bin_edges[1:] + bin_edges[:-1])

        plt.figure(figsize=(10,4))
        plt.bar(centres[top_idx], counts[top_idx],
                width=np.diff(bin_edges)[top_idx], color='crimson', label=f'top {N}')
        plt.xlabel('value')
        plt.ylabel('count')
        plt.title(f'{N} most-common histogram bins')
        plt.legend()
        plt.tight_layout()
        plt.show()

    return filt

In [None]:
# Print the STFT / Hilbert similarity metrics for the example recording.
run_metrics()

In [None]:
# Filter the recording with the design_Hz filter and draw the diagnostic plots.
filtered2 = run_filter()

In [None]:
# Listen to the design_Hz-filtered signal (rebinds the globals `y` and `sr`).
y, sr = play_sig(filtered2, sr)
Audio(data=y, rate=sr) # filtered2

In [None]:
import sympy as sp
from pylatex import Document, Math, NoEscape, NewLine


"""
Generate the LaTeX in Markdown format, and in raw-LaTeX form, to embed elsewhere.
This is done dynamically (where possible) rather than statically, so that the
evolution of this notebook / filter design over time is explicit.
"""


# Items for the write-up; populated further down in this cell.
eqns = []

# Previously discussed function definitions
# NOTE(review): `poles` and `k` rebind names used by the filter-design cells (the pole
# array and a loop index); `indx_symb`, `poles`, `transfer` and `k` are not used below.
Z_symb = sp.Symbol('Z', complex=True)
indx_symb = sp.Symbol('i')
epsilon_symb = sp.Symbol('epsilon')
rho_symb = sp.Symbol('rho')
sgn = sp.Function('sgn')(Z_symb)
little_h_symb = sp.Function('h')(Z_symb, rho_symb)
little_h_rho1 = sp.Function('h')(Z_symb, 1)
poles = sp.Function('f')(Z_symb)
transfer = sp.Function('H')(Z_symb)

# Vector, matrix eqn's
# Four symbolic 2x1 vectors; M stacks their pairwise products m^T m into a 2x1 column.
k = sp.symbols('k')
m11 = sp.MatrixSymbol('m_{1,i}', 2, 1)
m21 = sp.MatrixSymbol('m_{2,i}', 2, 1)
m12 = sp.MatrixSymbol('m_{1,i+1}', 2, 1)
m22 = sp.MatrixSymbol('m_{2,i+1}', 2, 1)
comp_unit_vec = sp.Matrix([[1], [1j]])

M = sp.Matrix([
    [sp.Transpose(m11) * m21],
    [sp.Transpose(m12) * m22]
])

# Mention the old pole-filter design in the report
transfer_function_obsolete = sp.Lambda(Z_symb, -sp.ln(
        sp.Abs(
            (Z_symb**4 + 4) * sp.atanh(Z_symb**2)
        )
))

# atanh(z**4) part of transfer function
little_h = sp.Lambda((Z_symb, rho_symb), sp.atanh(Z_symb ** 4 / rho_symb))

# This is the final design of the first part (I) of the filter design
transfer_function_first_part = sp.Lambda(Z_symb, (sgn * little_h_rho1 / (Z_symb**4 + 1)))

# General form
# Gaussian prefactor 1 - exp(-((|Z| - 1) / epsilon)^2) times h(Z, rho) * sgn / (Z^4 + 1).
transfer_function_general_first_part = sp.Lambda((Z_symb, rho_symb), 
                                 (1 - sp.exp(-((sp.Abs(Z_symb) - 1) / epsilon_symb) ** 2)) * \
                                  little_h_symb * sgn / (Z_symb**4 + 1))


# Latex-ify
# nsimplify is applied so the float 1j is presumably printed as an exact imaginary unit.
latex_comp_unit_vec = sp.latex(sp.nsimplify(comp_unit_vec))

latex_M = sp.latex(M)

latex_M_comp = sp.latex(
    sp.nsimplify(M.dot(comp_unit_vec))
)
# Swap the square-bracket matrix delimiters for a pmatrix environment.
latex_M_comp = latex_M_comp.replace(r'\left[', r'\begin{pmatrix}') \
             .replace(r'\right]', r'\end{pmatrix}')

latex_transfer_function_obsolete = sp.latex(
    transfer_function_obsolete(Z_symb)
)

latex_transfer_function_first_part = sp.latex(
    transfer_function_first_part(Z_symb)
)

latex_transfer_function_general_first_part = sp.latex(
    transfer_function_general_first_part(Z_symb, rho_symb)
)

latex_little_h = sp.latex(
    little_h(Z_symb, rho_symb)
)


# ---- Eqn's start here -------------------------
# Items are plain LaTeX strings, or tuples handled by format_eqns:
# ('$$', title) opens a titled display-math section; (False, obj) passes obj through as-is.
# LaTeX literals use raw strings: '\m' in a normal literal is an invalid escape
# sequence (SyntaxWarning on Python 3.12+).

eqns.append(('$$', 'Basic representation of transfer function (part I)'))

# Eqn for M
eqns.append((False, NewLine()))
eqns.append('Z = ')
eqns.append(latex_M_comp)
eqns.append(' = ')
eqns.append(latex_M)
eqns.append(latex_comp_unit_vec)

# Obsolete first part
eqns.append((False, NewLine()))
eqns.append(r'\text{An original proposal was the following: }')
eqns.append(r'\mathcal{H_{old}}(Z) = ')
eqns.append(latex_transfer_function_obsolete)

# Eqn for atanh part of function
eqns.append((False, NewLine()))
eqns.append(r'\mathcal{h}(Z) = ')
eqns.append(latex_little_h)

# Eqn for transfer function
eqns.append((False, NewLine()))
eqns.append(r'\mathcal{H}(Z) = ')
eqns.append(latex_transfer_function_first_part)

# Eqn for general transfer function form
eqns.append((False, NewLine()))
eqns.append(r'\mathcal{H_g}(Z, \rho) = ')
eqns.append(r'\lim_{\epsilon \to 0}')
eqns.append(latex_transfer_function_general_first_part)
eqns.append(r'\text{Where the system becomes non-causal for any $\rho \geq 1$}')

eqns.append('$$')

# ---- Eqn's end here --------------------------


def format_eqns(eqns: list, markdown: bool = False):
    eqns_formatted = []
    for eqn in eqns:

        if isinstance(eqn, tuple):
            match eqn[0]:
                case '$$':
                    if markdown:
                        eqns_formatted.append(NoEscape(NewLine()))
                        eqns_formatted.append(NoEscape(f'## {eqn[1]}'))
                    else:
                        eqns_formatted.append(NoEscape(f'\section {{{eqn[1]}}}'))
                    eqns_formatted.append(NoEscape('$$'))
                    continue
                case False:
                    eqns_formatted.append(eqn[1])
                    continue
                case _:
                    assert False, 'Unreachable case'

        eqns_formatted.append(NoEscape(eqn))
        
    return eqns_formatted


# Format for Markdown output and wrap all items in a single Math element of a
# standalone pylatex Document (its dumps() output is parsed by display_equations).
eqns = format_eqns(eqns, markdown=True)
doc = Document(documentclass='standalone')
doc.append(Math(data=eqns))

In [None]:
import re
from IPython.display import display, Markdown
from pathlib import Path


"""
Run some utility functions to create the write-up.
"""


# Base name of the Markdown write-up written below (<NB_NAME>.md in the working directory).
NB_NAME = 'filter_design'


def add_section_title(title: str, content: list):
    """Return a new list with a ``# title`` Markdown heading prepended to ``content``.

    ``content`` itself is left unmodified; previously it was mutated in place
    via ``insert``.
    """
    return [f'# {title}\n', *content]


def join_markdown(content: list):
    """Concatenate Markdown fragments into one string.

    Fragments ending in ``$$`` get a blank line appended so the next fragment
    starts a new paragraph. Leading whitespace of the combined result is stripped.
    """
    pieces = []
    for chunk in content:
        closes_math = chunk.lstrip().endswith('$$')
        pieces.append(chunk + '\n\n' if closes_math else chunk)
    return ''.join(pieces).lstrip()


def replace_operatorname_macro(content_str: str):
    r"""Rewrite every ``\operatorname{X}`` as ``\mathop{\text{X}}``.

    GitHub's Markdown math rejects the macro ("The following macros are not
    allowed: operatorname"); ``\mathop`` + ``\text`` renders the same
    operator name. The match is non-greedy, so nested braces inside the
    argument are not supported.
    """
    operatorname = re.compile(r'\\operatorname\{(.*?)\}')
    return operatorname.sub(r'\\mathop{\\text{\1}}', content_str)


def display_equations(depth=None, fp = None, title = ''):
    """Render each section of ``doc`` as Markdown and optionally write it to disk.

    Sections are found in ``doc.dumps()`` as a Markdown heading followed by a
    ``$$ ... $$`` block; each one is displayed in the notebook.

    Args:
        depth: Deepest heading level to match (``#`` up to ``#`` * depth).
            ``None`` keeps the previous behaviour of matching levels 1-2. Any
            other value used to raise UnboundLocalError, because the regex was
            only built when ``depth`` was None.
        fp: If given, write the collected sections to this path, prefixed with
            a ``# title`` heading and with operatorname macros rewritten.
        title: Top-level heading used when writing to ``fp``.

    Raises:
        ValueError: if ``depth`` is not None and not a positive integer.
    """
    if depth is None:
        depth = 2
    if not isinstance(depth, int) or depth < 1:
        raise ValueError(f'depth must be a positive integer, got {depth!r}')

    latex_str = doc.dumps()
    # Group 1: the heading line; group 2: the body of the following $$ ... $$ block.
    section_re = re.compile(rf'(#{{1,{depth}}}.+?)\s*\$\$(.*?)\$\$', re.DOTALL)

    content = []
    for match in section_re.finditer(latex_str):
        section_title = match.group(1)
        equation = match.group(2)
        markdown_content = f'{section_title}\n\n$$\n{equation}\n$$'
        display(Markdown(markdown_content))
        content.append(markdown_content)

    if fp is not None:
        content = add_section_title(title, content)
        to_write = replace_operatorname_macro(
             join_markdown(content)
        )
        Path(fp).write_text(to_write, encoding='utf-8')


# Show the sections in the notebook and write them to <cwd>/<NB_NAME>.md.
# Relies on `os` imported in the filter-design cell.
display_equations(fp=os.path.join(os.getcwd(), f'{NB_NAME}.md'),
                   title='Designing the filter function')