In [1]:
version = 24880

In [2]:
import torch
import numpy as np
import spock_reg_model

  "torch was imported before juliacall. This may cause a segfault. "


Detected Jupyter notebook. Loading juliacall extension. Set `PYSR_AUTOLOAD_EXTENSIONS=no` to disable.


In [3]:
labels = ['time', 'e+_near', 'e-_near', 'max_strength_mmr_near', 'e+_far', 'e-_far', 'max_strength_mmr_far', 'megno', 'a1', 'e1', 'i1', 'cos_Omega1', 'sin_Omega1', 'cos_pomega1', 'sin_pomega1', 'cos_theta1', 'sin_theta1', 'a2', 'e2', 'i2', 'cos_Omega2', 'sin_Omega2', 'cos_pomega2', 'sin_pomega2', 'cos_theta2', 'sin_theta2', 'a3', 'e3', 'i3', 'cos_Omega3', 'sin_Omega3', 'cos_pomega3', 'sin_pomega3', 'cos_theta3', 'sin_theta3', 'm1', 'm2', 'm3', 'nan_mmr_near', 'nan_mmr_far', 'nan_megno']

# not all of these labels are actually used. for training, these inputs are zeroed out, but still passed in as zeroes.
# ideally, the linear layer ignores them, which does happen if i do l1 regularization to it
skipped = ['nan_mmr_near', 'nan_mmr_far', 'nan_megno', 'e+_near', 'e-_near', 'max_strength_mmr_near', 'e+_far', 'e-_far', 'max_strength_mmr_far', 'megno']

assert len(labels) == 41

In [4]:
def convert_to_latex(label):
    # 1. if it ends in a number, add an underscore
    if label[-1].isdigit():
        label = label[:-1] + '_' + label[-1]
    # 2. replace sin/cos with \sin/\cos
    label = label.replace('sin', '\\sin')
    label = label.replace('cos', '\\cos')
    label = label.replace('_Omega', '\\Omega')
    label = label.replace('_pomega', '\\omega')
    label = label.replace('_theta', '\\theta')
    return label

latex_labels = [convert_to_latex(label) for label in labels]
print(latex_labels)

['time', 'e+_near', 'e-_near', 'max_strength_mmr_near', 'e+_far', 'e-_far', 'max_strength_mmr_far', 'megno', 'a_1', 'e_1', 'i_1', '\\cos\\Omega_1', '\\sin\\Omega_1', '\\cos\\omega_1', '\\sin\\omega_1', '\\cos\\theta_1', '\\sin\\theta_1', 'a_2', 'e_2', 'i_2', '\\cos\\Omega_2', '\\sin\\Omega_2', '\\cos\\omega_2', '\\sin\\omega_2', '\\cos\\theta_2', '\\sin\\theta_2', 'a_3', 'e_3', 'i_3', '\\cos\\Omega_3', '\\sin\\Omega_3', '\\cos\\omega_3', '\\sin\\omega_3', '\\cos\\theta_3', '\\sin\\theta_3', 'm_1', 'm_2', 'm_3', 'nan_mmr_near', 'nan_mmr_far', 'nan_megno']


In [5]:
from sklearn.preprocessing import StandardScaler
import numpy as np
ssX = StandardScaler()
ssX.scale_ = np.array([2.88976974e+03, 6.10019661e-02, 4.03849732e-02, 4.81638693e+01,
           6.72583662e-02, 4.17939679e-02, 8.15995339e+00, 2.26871589e+01,
           4.73612029e-03, 7.09223721e-02, 3.06455099e-02, 7.10726478e-01,
           7.03392022e-01, 7.07873597e-01, 7.06030923e-01, 7.04728204e-01,
           7.09420909e-01, 1.90740659e-01, 4.75502285e-02, 2.77188320e-02,
           7.08891412e-01, 7.05214134e-01, 7.09786887e-01, 7.04371833e-01,
           7.04371110e-01, 7.09828420e-01, 3.33589977e-01, 5.20857790e-02,
           2.84763136e-02, 7.02210626e-01, 7.11815232e-01, 7.10512240e-01,
           7.03646004e-01, 7.08017286e-01, 7.06162814e-01, 2.12569430e-05,
           2.35019125e-05, 2.04211110e-05, 7.51048890e-02, 3.94254400e-01,
           7.11351099e-02])
ssX.mean_ = np.array([ 4.95458585e+03,  5.67411891e-02,  3.83176945e-02,  2.97223474e+00,
           6.29733979e-02,  3.50074471e-02,  6.72845676e-01,  9.92794768e+00,
           9.99628430e-01,  5.39591547e-02,  2.92795061e-02,  2.12480714e-03,
          -1.01500319e-02,  1.82667162e-02,  1.00813201e-02,  5.74404197e-03,
           6.86570242e-03,  1.25316320e+00,  4.76946516e-02,  2.71326280e-02,
           7.02054326e-03,  9.83378673e-03, -5.70616748e-03,  5.50782881e-03,
          -8.44213953e-04,  2.05958338e-03,  1.57866569e+00,  4.31476211e-02,
           2.73316392e-02,  1.05505555e-02,  1.03922250e-02,  7.36865006e-03,
          -6.00523246e-04,  6.53016990e-03, -1.72038113e-03,  1.24807860e-05,
           1.60314173e-05,  1.21732696e-05,  5.67292645e-03,  1.92488263e-01,
           5.08607199e-03])
ssX.var_ = ssX.scale_**2

In [6]:
# load when on the cluster
model = spock_reg_model.load(version)
feature_nn = model.feature_nn

In [7]:
# load when local
# feature_nn = torch.load(f'models/{version}_feature_nn.pt')

In [8]:
input_linear = feature_nn.linear.weight * feature_nn.mask
input_linear = input_linear.detach().numpy()
if feature_nn.linear.bias is not None:
    input_bias = feature_nn.linear.bias.detach().numpy()
else:
    input_bias = np.zeros(input_linear.shape[0])

In [9]:
# m_i is the mean of the i'th feature, s_i is the standard deviation
# get the linear transformation that creates feature i
def linear_transformation(i):
    return input_linear[i]

# let's make the linear transformation a bit easier to read
def format_num(x, latex=False):
    if abs(x) > 1000:
        x2 = 100 * (x // 100)
        return str(x2)
    # if abs(x) > 10:
        # return f'{x:.0f}'
    # if abs(x) > 1:
        # return f'{x:.2f}'
    # if abs(x) > 0.1:
        # return f'{x:.2f}'
    # if abs(x) > 0.01:
        # return f'{x:.3f}'
    # elif abs(x) > 0.001:
        # return f'{x:.4f}'
    else:
        return f'{x:.3g}'

format_vec = np.vectorize(format_num)

In [10]:
import sympy as sp
sym_vars = {lbl: sp.Symbol(lbl, real=True) for lbl in labels}

def simplify_scaled_feature(transformation, bias=0, include_ssx_bias=True):
    # Create symbolic variables for each feature

    expr = bias

    # Add each transformed feature (unscaled)
    for f_idx in range(len(labels)):
        c = transformation[f_idx]
        if c != 0:
            label = labels[f_idx]
            mean_j = ssX.mean_[f_idx] if include_ssx_bias else 0.0
            scale_j = ssX.scale_[f_idx]
            expr += c * (sym_vars[label] - mean_j) / scale_j

    expr = sp.simplify(expr)
    return expr

def format_sympy_expr(expr, latex=False):
    # replace labels with latex labels (change character from labels[i] to latex_labels[i])
    if latex:
        for lbl, sym in sym_vars.items():
            i = labels.index(lbl)
            new_lbl = latex_labels[i]
            expr = expr.subs(sym, sp.Symbol(new_lbl, real=True))

    coeffs = expr.as_coefficients_dict()

    terms_str = []
    const_str = None
    for var, coef in coeffs.items():
        if var == 1:
            const_str = format_num(coef, latex)
        else:
            times = '' if latex else '*'
            terms_str.append(f'{format_num(coef, latex)} {times} {var}')

    if const_str is not None:
        terms_str.append(const_str)

    return ' + '.join(terms_str)

def format_transformation(transformation, bias, latex):
    sorted_ixs = np.argsort(np.abs(transformation))[::-1]
    times = '' if latex else '*'
    used_labels = latex_labels if latex else labels
    features = [f'{format_num(transformation[i], latex)} {times} {used_labels[i]}' for i in sorted_ixs if transformation[i] != 0]
    if bias != 0:
        features = [format_num(bias, latex)] + features
    return ' + '.join(features)

In [11]:
def get_scaled_feature_bias(i):
    transformation = linear_transformation(i)
    bias = input_bias[i]
    expr = simplify_scaled_feature(transformation, bias)
    return expr.as_coefficients_dict().get(1, 0)

In [12]:
# load pysr f2 equations
import pickle
reg = pickle.load(open('sr_results/11003.pkl', 'rb'))
results = reg.equations_[0]

In [13]:
entry = results.iloc[8]

In [14]:
expr = entry.sympy_format

In [15]:
def add_bias_to_mean_terms(expr):
    replacements = {}
    for symbol in expr.free_symbols:
        if symbol.name.startswith('m') and symbol.name[1:].isdigit():
            i = symbol.name[1:]  # get the number after 'm'
            replacements[symbol] = symbol + get_scaled_feature_bias(int(i))
    return expr.xreplace(replacements)

In [16]:
results['sympy_format'] = results['sympy_format'].apply(add_bias_to_mean_terms)

In [17]:
from pysr.export_latex import sympy2latextable

In [18]:
def get_important_complexities(results, loss_gap = 0.25):
    complexities = list(results['complexity'])
    losses = list(results['loss'])
    assert sorted(losses) == losses[::-1]

    # important complexities are those that decrease the loss by more than loss_gap since the previous important complexity.
    important_complexities = [complexities[0]]
    current_loss = losses[0]

    for i in range(1, len(complexities)):
        if current_loss - losses[i] > loss_gap:
            important_complexities.append(complexities[i])
            current_loss = losses[i]

    # automatically include the highest complexity too
    if complexities[-1] != important_complexities[-1]:
        important_complexities.append(complexities[-1])

    return important_complexities

In [19]:
important_complexities = get_important_complexities(results, loss_gap=0.25)

In [20]:
# important_complexities = [3, 5, 14, 30]
important_complexities = [1, 3, 4, 7, 11, 14, 22, 26]

In [21]:
important_indices = [i for i, c in enumerate(results['complexity']) if c in important_complexities]

In [22]:
important_indices

[0, 1, 2, 4, 6, 8, 15, 19]

In [23]:
s = sympy2latextable(reg.equations_[0], precision=2, columns=['equation', 'complexity', 'loss'], indices=important_indices)

In [24]:
mapping_dict = dict(zip([2, 4, 1, 6, 7, 8, 0, 3, 5, 9], range(9)))

In [25]:
import re
def remap_latex_str(s):
    # use regex so that we don't replace multiple times. from o3-mini-high
    mapping_str = {str(old): str(new) for old, new in mapping_dict.items()}
    pattern = re.compile(r'([ms])_\{([^}]+)\}')
    def repl(match):
        prefix, key = match.groups()
        return f"{prefix}_{{{mapping_str[key]}}}" if key in mapping_str else match.group(0)
    return pattern.sub(repl, s)

In [26]:
# replace m_{ with \mu_{ and s_{ with \sigma_{ and y = with \log_{10} T_{\text{inst}} =
s = remap_latex_str(s)
s = s.replace('m_{', '\\mu_{')
s = s.replace('s_{', '\\sigma_{')
s = s.replace('y =', '\\log_{10} T_{\\text{inst}} =')

In [27]:
print(s)

\begin{table}[h]
\begin{center}
\begin{tabular}{@{}ccc@{}}
\toprule
Equation & Complexity & Loss \\
\midrule
$\log_{10} T_{\text{inst}} = 7.0$ & $1$ & $5.7$ \\
$\log_{10} T_{\text{inst}} = 0.98 - \mu_{0}$ & $3$ & $5.0$ \\
$\log_{10} T_{\text{inst}} = 7.2 - \sin{\left(\mu_{0} + 6.0 \right)}$ & $4$ & $4.9$ \\
$\log_{10} T_{\text{inst}} = - \mu_{0} + \frac{3.6}{\sigma_{1}^{0.16}} - 6.0$ & $7$ & $3.3$ \\
\begin{minipage}{0.8\linewidth} \vspace{-1em} \begin{dmath*} \log_{10} T_{\text{inst}} = 0.14^{\sigma_{2}} \left(- \mu_{0} + \sigma_{1}^{-0.31} - 6.0\right) + 3.7 \end{dmath*} \end{minipage} & $11$ & $2.7$ \\
\begin{minipage}{0.8\linewidth} \vspace{-1em} \begin{dmath*} \log_{10} T_{\text{inst}} = 0.059^{\sigma_{2}} \left(\mu_{4} + \sigma_{1}^{-0.33} - \sin{\left(\mu_{0} + 6.0 \right)} - 0.72\right) + 3.7 \end{dmath*} \end{minipage} & $14$ & $2.5$ \\
\begin{minipage}{0.8\linewidth} \vspace{-1em} \begin{dmath*} \log_{10} T_{\text{inst}} = 0.067^{\sigma_{2}} \left(\mu_{4} - \sigma_{5} + \left

In [28]:
# now we can write it as a combination of the input features
# we'll sort the features by their absolute value to make it a bit easier to read
def feature_string(i, include_ssx=False, latex=False, include_ssx_bias=True):
    transformation = linear_transformation(i)
    bias = input_bias[i]

    if include_ssx:
        expr = simplify_scaled_feature(transformation, bias, include_ssx_bias=include_ssx_bias)
        s = format_sympy_expr(expr, latex)
    else:
        s = format_transformation(transformation, bias, latex)

    # change + -'s to -'s
    s = s.replace(' + -', ' - ')
    return s


In [29]:
for i in range(input_linear.shape[0]):
    print(simplify_scaled_feature(linear_transformation(i), include_ssx_bias=True))

14.0481961363884*e1 - 34.9908344619683*e2 + 0.910846870277527
209.781235963898*a1 + 85414.8542688018*m1 - 210.769332067401
-4.01863051350855*a3 + 27051.5748154689*m1 + 6.00644919622817
-1.10101459088907*sin_Omega2 + 0.763710665324014*sin_Omega3 + 0.00289048960447445
-27.545810485898*a2 + 13.9602160286755*a3 + 12.4808819456433
1.46652031263173*e1 + 29.1184195269148*e3 - 1.33552272919815
37.5296375943142*i2 + 13.1045896905103*i3 - 1.37644761310641
-4.91574016561418*e3 + 58430.5224859651*m2 - 0.724621594937567
21.1436690517468*e1 - 4.47539705411287*e3 - 0.94779177292589
0.957226289341714*e1 + 64678.2447691033*m3 - 0.83899683225858


In [30]:
input_linear.shape

(10, 41)

In [31]:
def feature_coeffs(i, include_ssx=False, latex=False, include_ssx_bias=True):
    transformation = linear_transformation(i)
    bias = input_bias[i]

    if include_ssx:
        expr = simplify_scaled_feature(transformation, bias, include_ssx_bias=include_ssx_bias)
        s = format_sympy_expr(expr, latex)
    else:
        s = format_transformation(transformation, bias, latex)

    # change + -'s to -'s
    s = s.replace(' + -', ' - ')
    return s

In [32]:
input_linear = feature_nn.linear.weight * feature_nn.mask
input_linear = input_linear.detach().numpy()

In [33]:
for i in range(input_linear.shape[0]):
    print(str(i) + ": " + feature_string(i, include_ssx=True, latex=False, include_ssx_bias=False))

0: 14.0 * e1 - 35.0 * e2
1: 210 * a1 + 85400 * m1
2: 27000 * m1 - 4.02 * a3
3: 0.764 * sin_Omega3 - 1.10 * sin_Omega2
4: 14.0 * a3 - 27.5 * a2
5: 29.1 * e3 + 1.47 * e1
6: 13.1 * i3 + 37.5 * i2
7: 58400 * m2 - 4.92 * e3
8: 21.1 * e1 - 4.48 * e3
9: 0.957 * e1 + 64600 * m3


In [34]:
def latex_string(include_ssx=False, include_ssx_bias=True):
    s = ('\\begin{align*}\n'
        + f'ver&sion={version}\\\\'
        + 'f_1& \\text{ features:} \\\\ \n')

    for i in range(input_linear.shape[0]):
        s += f'    &{i}: {feature_string(i, include_ssx, latex=True, include_ssx_bias=include_ssx_bias)} \\\\ \n'

    s += '''\end{align*}'''
    return s


In [35]:
print(latex_string(include_ssx=True, include_ssx_bias=False))

\begin{align*}
ver&sion=24880\\f_1& \text{ features:} \\ 
    &0: 14.0  e_1 - 35.0  e_2 \\ 
    &1: 210  a_1 + 85400  m_1 \\ 
    &2: 27000  m_1 - 4.02  a_3 \\ 
    &3: 0.764  \sin\Omega_3 - 1.10  \sin\Omega_2 \\ 
    &4: 14.0  a_3 - 27.5  a_2 \\ 
    &5: 29.1  e_3 + 1.47  e_1 \\ 
    &6: 13.1  i_3 + 37.5  i_2 \\ 
    &7: 58400  m_2 - 4.92  e_3 \\ 
    &8: 21.1  e_1 - 4.48  e_3 \\ 
    &9: 0.957  e_1 + 64600  m_3 \\ 
\end{align*}


In [36]:
from IPython.display import display, Markdown
display(Markdown('$$\n' + latex_string(include_ssx=True) + '\n$$'))

$$
\begin{align*}
ver&sion=24880\\f_1& \text{ features:} \\ 
    &0: 14.0  e_1 - 35.0  e_2 + 0.911 \\ 
    &1: 210  a_1 + 85400  m_1 - 211 \\ 
    &2: 27000  m_1 - 4.02  a_3 + 6.01 \\ 
    &3: 0.764  \sin\Omega_3 - 1.10  \sin\Omega_2 + 0.00289 \\ 
    &4: 14.0  a_3 - 27.5  a_2 + 12.5 \\ 
    &5: 29.1  e_3 + 1.47  e_1 - 1.34 \\ 
    &6: 13.1  i_3 + 37.5  i_2 - 1.38 \\ 
    &7: 58400  m_2 - 4.92  e_3 - 0.725 \\ 
    &8: 21.1  e_1 - 4.48  e_3 - 0.948 \\ 
    &9: 0.957  e_1 + 64600  m_3 - 0.839 \\ 
\end{align*}
$$