In [30]:
# Saved model versions, keyed by the `k` value that the original notebook
# noted next to each one.
# NOTE(review): the meaning of `k` is not visible in this notebook — confirm
# against the training configuration.
VERSION_BY_K = {
    20: 30886,
    15: 68491,
    10: 70470,
    5: 74535,
    2: 44991,
}

# Change SELECTED_K (and nothing else) to analyse a different model.
SELECTED_K = 2
version = VERSION_BY_K[SELECTED_K]

In [31]:
import torch
import numpy as np
import spock_reg_model

In [None]:
# Placeholder labels m0..m9. `labels` is reassigned further down from the
# model's 'latent' hyperparameter, so this only matters if that cell is skipped.
labels = [f'm{i}' for i in range(10)]

In [33]:
# load when on the cluster
# `version` (set in the first cell) selects which saved model is loaded.
# `spock_reg_model.load` is a project-local helper; what it reads from disk is
# not visible in this notebook.
model = spock_reg_model.load(version)
# Keep a handle on the model's `regress_nn` attribute; the cells below read
# its `linear` layer and `mask`.
# NOTE(review): `nn` is also the customary alias for `torch.nn` — consider a
# more specific name if `torch.nn` is ever needed in this notebook.
nn = model.regress_nn


In [41]:
# Number of label pairs, taken from the model's 'latent' hyperparameter.
n = model.hparams['latent']
# All 'm<i>' labels first (m0 .. m{n-1}), followed by all 's<i>' labels.
labels = [f'{prefix}{i}' for prefix in ('m', 's') for i in range(n)]

In [42]:
# Effective weights of the linear layer: the raw weight multiplied by the mask.
masked_weight = nn.linear.weight * nn.mask
# .cpu() keeps the conversion working when the model lives on a GPU
# (Tensor.numpy() only accepts CPU tensors); it changes nothing on CPU.
input_linear = masked_weight.detach().cpu().numpy()
if nn.linear.bias is not None:
    input_bias = nn.linear.bias.detach().cpu().numpy()
else:
    # Layer has no bias: use one zero per row of `input_linear`.
    input_bias = np.zeros(input_linear.shape[0])

In [43]:
# m_i is the mean of the i'th feature, s_i is the standard deviation
# get the linear transformation that creates feature i
def linear_transformation(i):
    """Return row ``i`` of the module-level ``input_linear`` array.

    Args:
        i: row index into ``input_linear``.

    Returns:
        The weights stored in that row.
    """
    row = input_linear[i]
    return row

In [44]:
# let's make the linear transformation a bit easier to read
def format_num(x):
    """Render a number as a short string whose precision follows its magnitude.

    - ``|x| > 0.1``   -> two decimals
    - ``|x| > 0.01``  -> three decimals
    - ``|x| > 0.001`` -> four decimals
    - anything else   -> scientific notation with two decimals

    Args:
        x: the number to format.

    Returns:
        str: the formatted number.
    """
    magnitude = abs(x)
    # Thresholds are checked from largest to smallest; the first match wins.
    for threshold, spec in ((0.1, '.2f'), (0.01, '.3f'), (0.001, '.4f')):
        if magnitude > threshold:
            return format(x, spec)
    return format(x, '.2e')


# Array-friendly wrapper that applies format_num to every element.
format_vec = np.vectorize(format_num)

In [45]:
# now we can write it as a combination of the input features
# we'll sort the features by their absolute value to make it a bit easier to read
def feature_equation(i):
    """List the terms of row ``i`` of the linear layer as readable strings.

    The bias (when non-zero) comes first, followed by one
    ``'<weight> * <label>'`` entry per non-zero weight, ordered from the
    largest absolute weight to the smallest.

    Args:
        i: row index passed to ``linear_transformation`` and ``input_bias``.

    Returns:
        list[str]: the formatted terms.
    """
    weights = linear_transformation(i)
    offset = input_bias[i]
    # argsort is ascending, so reverse it to get largest magnitude first.
    order = np.argsort(np.abs(weights))[::-1]

    terms = []
    if offset != 0:
        terms.append(format_num(offset))
    for col in order:
        coeff = weights[col]
        if coeff != 0:
            terms.append(format_num(coeff) + ' * ' + labels[col])

    return terms

In [46]:
# Print one "<row>: term + term + ..." line for every row of the linear layer.
n_rows = input_linear.shape[0]
for i in range(n_rows):
    equation = ' + '.join(feature_equation(i))
    print(f'{i}: {equation}')

0: 0.46 + -31.19 * s6 + -0.82 * m1
1: -0.36 + -0.21 * m1 + -0.0048 * m4


In [26]:
def convert_to_latex(label):
    """Convert a plain feature label into a LaTeX-friendly string.

    Steps, in order:
      1. A trailing run of digits becomes a subscript: ``'a1' -> 'a_1'``.
         More than one digit is wrapped in braces so the whole index is
         subscripted: ``'m10' -> 'm_{10}'``.
      2. ``sin`` / ``cos`` become the LaTeX commands.
      3. ``_Omega``, ``_pomega`` and ``_theta`` become Greek-letter commands.

    Labels that match none of these rules (for example
    ``'max_strength_mmr_near'``) are returned unchanged, underscores included.

    Args:
        label (str): the label to convert; may be empty.

    Returns:
        str: the converted label.
    """
    # 1. Split off every trailing digit, not just the last character, so
    #    multi-digit indices stay together. An empty label falls through
    #    untouched instead of raising IndexError.
    cut = len(label)
    while cut > 0 and label[cut - 1].isdigit():
        cut -= 1
    stem, index = label[:cut], label[cut:]
    if index:
        subscript = index if len(index) == 1 else '{' + index + '}'
        label = stem + '_' + subscript

    # 2. replace sin/cos with \sin/\cos, then the angle names with Greek letters
    replacements = (
        ('sin', '\\sin'),
        ('cos', '\\cos'),
        ('_Omega', '\\Omega'),
        ('_pomega', '\\omega'),
        ('_theta', '\\theta'),
    )
    for plain, latex in replacements:
        label = label.replace(plain, latex)
    return label

# LaTeX version of every entry in `labels`, in the same order.
latex_labels = list(map(convert_to_latex, labels))
print(latex_labels)

['time', 'e+_near', 'e-_near', 'max_strength_mmr_near', 'e+_far', 'e-_far', 'max_strength_mmr_far', 'megno', 'a_1', 'e_1', 'i_1', '\\cos\\Omega_1', '\\sin\\Omega_1', '\\cos\\omega_1', '\\sin\\omega_1', '\\cos\\theta_1', '\\sin\\theta_1', 'a_2', 'e_2', 'i_2', '\\cos\\Omega_2', '\\sin\\Omega_2', '\\cos\\omega_2', '\\sin\\omega_2', '\\cos\\theta_2', '\\sin\\theta_2', 'a_3', 'e_3', 'i_3', '\\cos\\Omega_3', '\\sin\\Omega_3', '\\cos\\omega_3', '\\sin\\omega_3', '\\cos\\theta_3', '\\sin\\theta_3', 'm_1', 'm_2', 'm_3', 'nan_mmr_near', 'nan_mmr_far', 'nan_megno']


In [27]:
def latex_line(i):
    """Format row ``i`` of the linear layer as one line of LaTeX.

    The bias (when non-zero) comes first, then one ``'<weight> <label>'`` term
    per non-zero weight, largest absolute weight first. Terms are joined with
    ``' + '`` and every ``'+ -'`` is collapsed into ``'- '``.

    Args:
        i: row index passed to ``linear_transformation`` and ``input_bias``.

    Returns:
        str: the formatted line.
    """
    weights = linear_transformation(i)
    offset = input_bias[i]
    # argsort is ascending, so reverse it to get largest magnitude first.
    order = np.argsort(np.abs(weights))[::-1]

    terms = []
    if offset != 0:
        terms.append(format_num(offset))
    for col in order:
        if weights[col] != 0:
            terms.append(format_num(weights[col]) + ' ' + latex_labels[col])

    joined = ' + '.join(terms)
    # "a + -b" reads better as "a - b".
    return joined.replace('+ -', '- ')

def latex_string():
    """Build the LaTeX ``align*`` block listing every row of ``input_linear``.

    Reads the module-level ``version`` and ``input_linear`` and calls
    ``latex_line`` once per row.

    Returns:
        str: the complete ``align*`` environment, without surrounding ``$$``.
    """
    # NOTE(review): the '&' inside 've&rsion' looks like a deliberate
    # alignment-point hack — confirm before changing it.
    header = ('\\begin{align*}\n'
              + f've&rsion={version}\\\\'
              + 'f_1& \\text{ features:} \\\\ \n')

    rows = [f'    &{i}: {latex_line(i)} \\\\ \n'
            for i in range(input_linear.shape[0])]

    # Written with an escaped backslash: a bare "\e" is not a valid escape
    # sequence and triggers a SyntaxWarning on recent Python versions.
    return header + ''.join(rows) + '\\end{align*}'


In [28]:
from IPython.display import display, Markdown
# Wrap the align* block in $$ ... $$ so the Markdown renderer typesets it as math.
latex_markdown = '$$\n' + latex_string() + '\n$$'
display(Markdown(latex_markdown))

$$
\begin{align*}
ve&rsion=44991\\f_1& \text{ features:} \\ 
    &0: 1.33e-05 \sin\Omega_2 - 1.24e-05 \sin\Omega_3 \\ 
    &1: -0.82 m_2 - 0.34 a_3 \\ 
    &2: -7.20e-11 m_3 + 5.77e-11 \sin\omega_3 \\ 
    &3: 2.61e-06 max_strength_mmr_near - 1.39e-10 e_3 \\ 
    &4: -0.21 m_3 + 0.11 m_1 \\ 
    &5: 1.48e-05 \cos\Omega_2 - 1.22e-05 \cos\Omega_3 \\ 
    &6: 6.47 a_3 - 0.28 a_1 \\ 
    &7: -1.69e-06 megno + 7.18e-11 e_1 \\ 
    &8: -7.50e-11 i_3 - 1.22e-12 a_1 \\ 
    &9: -8.44e-06 \cos\Omega_3 + 8.39e-06 \cos\Omega_1 \\ 
\end{align*}
$$