In [None]:
# Time different m_k solvers and verify that they give the correct (same) answer.
# We originally used newton; this notebook confirms that brentq gives the same answer in much less time,
# provided each solver is set up appropriately (newton starts at the lower end of the bracket; brentq/brenth use the bracket itself).
import numpy as np
from numpy import pi, inf
from scipy.optimize import root_scalar
import time

In [68]:
# Parameter grid for the sweep: m0 values, h values (powers of two), and root indices k.
m0s = [0.001, 0.01, 0.1, 1, 10, 100]
hs = [2 ** p for p in range(7)]   # 1, 2, 4, ..., 64
ks = [*range(500)]                # k = 0 .. 499
times = dict()                    # solver method name -> elapsed seconds for the full sweep

In [None]:
def m_k_entry(m0, h, k, method):
    """Return m_k, the k-th root of  x*tan(x) = -m0*h*tanh(m0*h)  (with x = m_k*h) divided by h.

    For k >= 1 the root x is searched between just above (k - 1/2)*pi and
    just above k*pi, and x / h is returned.

    Parameters
    ----------
    m0 : float
        Returned unchanged when k == 0. If ``inf``, the closed-form limit
        ``(k - 1/2) * pi / h`` is returned without calling a solver.
    h : float
        Scale relating x and m_k (x = m_k * h); assumed nonzero.
    k : int
        Root index; 0 short-circuits to ``m0``. Assumed >= 1 otherwise.
    method : str
        Method name forwarded to ``scipy.optimize.root_scalar`` (this notebook
        uses "newton", "brentq" and "brenth").

    Returns
    -------
    float
        The root found by the solver, divided by h.
    """
    if k == 0:
        return m0
    elif m0 == inf:
        return ((k - 1/2) * pi)/h

    # The m0 term does not depend on the unknown, so compute it once instead of
    # on every residual evaluation inside the solver (same expression, same value).
    m0_term = m0 * h * np.tanh(m0 * h)

    def m_k_h_err(m_k_h):
        # Residual in x = m_k * h; zero at the desired root.
        return m_k_h * np.tan(m_k_h) + m0_term

    # On ((k - 1/2)*pi, k*pi) tan(x) rises from -inf to 0, so x*tan(x) rises from
    # -inf to 0 and the residual changes sign once when m0_term > 0.
    m_k_h_lower = np.nextafter(pi * (k - 1/2), np.inf)
    m_k_h_upper = np.nextafter(pi * k, np.inf)
    if method == "newton":
        # newton does not use the bracket; starting at the lower end keeps it on the
        # k-th branch (other starting points could give wrong answers).
        m_k_initial_guess = m_k_h_lower
    else:
        # Bracketing methods (brentq/brenth) work from `bracket`; x0 is not used by them.
        m_k_initial_guess = (m_k_h_upper + m_k_h_lower) / 2

    result = root_scalar(m_k_h_err, x0=m_k_initial_guess, method=method,
                         bracket=[m_k_h_lower, m_k_h_upper])
    return result.root / h

In [61]:
def solve_all_modes(method):
    """Solve m_k_entry over the full (m0, h, k) grid with `method`, timing the sweep.

    Returns ``(solutions, seconds)``: solutions maps labels of the form
    "m0: {m0:2g}, h: {h}, k: {k}" to the computed m_k, and seconds is the
    wall-clock time of the whole sweep measured with time.perf_counter.
    """
    start = time.perf_counter()
    solutions = {
        f"m0: {m0:2g}, h: {h}, k: {k}" : m_k_entry(m0, h, k, method)
        for k in ks for h in hs for m0 in m0s
    }
    stop = time.perf_counter()
    return solutions, stop - start


# One sweep per solver, kept under separate names so the results can be compared.
newton, times["newton"] = solve_all_modes("newton")
brentq, times["brentq"] = solve_all_modes("brentq")
brenth, times["brenth"] = solve_all_modes("brenth")

In [None]:
# Cross-check brentq and brenth against the newton reference.
# Good if no outputs; each printed line is a mismatch: diff_brentq, diff_brenth, m0, h, k.
AGREEMENT_TOL = 1e-12
for m0 in m0s:
    for h in hs:
        for k in ks:
            label = f"m0: {m0:2g}, h: {h}, k: {k}"
            diff1, diff2 = brentq[label] - newton[label], brenth[label] - newton[label]
            # Written as `not (... <= tol)` so a NaN from a failed solve is reported;
            # `abs(nan) > tol` is False and would silently pass.
            if not (abs(diff1) <= AGREEMENT_TOL and abs(diff2) <= AGREEMENT_TOL):
                print(diff1, diff2, m0, h, k)

In [None]:
# Elapsed seconds for each solver's full sweep, one per line (newton is by far the slowest).
for solver_name in ("newton", "brentq", "brenth"):
    print(times[solver_name])

80.70836249506101
1.2815738030476496
1.3919195890193805
