In [3]:
import torch
import pickle
from matplotlib import pyplot as plt
import numpy as np
import pysr

In [4]:
labels = ['time', 'e+_near', 'e-_near', 'max_strength_mmr_near', 'e+_far', 'e-_far', 'max_strength_mmr_far', 'megno', 'a1', 'e1', 'i1', 'cos_Omega1', 'sin_Omega1', 'cos_pomega1', 'sin_pomega1', 'cos_theta1', 'sin_theta1', 'a2', 'e2', 'i2', 'cos_Omega2', 'sin_Omega2', 'cos_pomega2', 'sin_pomega2', 'cos_theta2', 'sin_theta2', 'a3', 'e3', 'i3', 'cos_Omega3', 'sin_Omega3', 'cos_pomega3', 'sin_pomega3', 'cos_theta3', 'sin_theta3', 'm1', 'm2', 'm3', 'nan_mmr_near', 'nan_mmr_far', 'nan_megno']

# not all of these labels are actually used. for training, these inputs are zeroed out, but still passed in as zeroes.
# ideally, the linear layer ignores them, which does happen if i do l1 regularization to it
skipped = ['nan_mmr_near', 'nan_mmr_far', 'nan_megno', 'e+_near', 'e-_near', 'max_strength_mmr_near', 'e+_far', 'e-_far', 'max_strength_mmr_far', 'megno']

In [5]:
# l1 reg = 2: 95944
# feature_nn = torch.load('models/95944_feature_nn.pt')

# l1 reg = 0.2: 92122
# feature_nn = torch.load('models/92122_feature_nn.pt')

# l1 nonabs reg: 63524
# feature_nn = torch.load('models/63524_feature_nn.pt')

# topk 2 pruned masked linear
feature_nn = torch.load('models/52410_feature_nn.pt')

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# this gives the (n_features, n_inputs) matrix of the linear transformation used as f1
input_linear = feature_nn.weight.detach().numpy()
input_bias = feature_nn.bias.detach().numpy()
input_linear.shape

AttributeError: 'MaskedLinear' object has no attribute 'weight'

In [7]:
input_linear = feature_nn.mask * feature_nn.linear.weight
input_linear = input_linear.detach().numpy()
input_bias = feature_nn.linear.bias.detach().numpy()
torch.save(input_linear, 'models/52410_input_linear.pt')
torch.save(input_bias, 'models/52410_input_bias.pt')

In [8]:

# m_i is the mean of the i'th feature, s_i is the standard deviation
# get the linear transformation that creates feature i
def linear_transformation(i):
    return input_linear[i]

In [9]:

# let's make the linear transformation a bit easier to read
def format_num(x):
    if abs(x) > 0.1:
        return f'{x:.2f}'
    if abs(x) > 0.01:
        return f'{x:.3f}'
    elif abs(x) > 0.001:
        return f'{x:.4f}'
    else:
        return f'{x:.2e}'

format_vec = np.vectorize(format_num)

In [13]:

# now we can write it as a combination of the input features
# we'll sort the features by their absolute value to make it a bit easier to read
def feature_equation(i):
    transformation = linear_transformation(i)
    bias = input_bias[i]
    sorted_ixs = np.argsort(np.abs(transformation))[::-1]
    return [format_num(bias)] + [format_num(transformation[i]) + ' * ' + labels[i] for i in sorted_ixs if transformation[i] != 0]

In [16]:

for i in range(input_linear.shape[0]):
    print(f'feature {i}:')
    # print(' +\n'.join(feature_equation(i)[:10]))
    print(' + '.join(feature_equation(i)))
    # print('+ ... (smaller terms omitted)')
    print()

feature 0:
-0.0055 + -1.53 * m1 + 0.35 * a3

feature 1:
0.0080 + 0.78 * e3 + -0.22 * e2

feature 2:
0.0083 + 0.85 * e2 + -0.0058 * i1

feature 3:
0.0015 + -0.28 * sin_Omega2 + 0.11 * sin_Omega1

feature 4:
-0.0014 + 1.81 * a1 + -0.011 * e1

feature 5:
0.0084 + -3.93 * a2 + 3.71 * a3

feature 6:
-3.17e-05 + -1.11e-04 * i2 + 1.71e-05 * e2

feature 7:
-2.67e-04 + -5.35e-05 * m1 + 1.34e-05 * a2

feature 8:
-0.0017 + -0.42 * i3 + -0.0025 * i1

feature 9:
-0.011 + -1.11 * m2 + 0.011 * i1

feature 10:
-5.91e-04 + -0.29 * sin_Omega3 + 0.053 * sin_Omega2

feature 11:
-0.0014 + 0.21 * cos_Omega2 + -0.16 * cos_Omega3

feature 12:
2.27e-05 + 2.78e-05 * cos_Omega2 + -1.71e-05 * a2

feature 13:
0.013 + 1.32 * m3 + -0.26 * a3

feature 14:
8.41e-05 + 6.71e-05 * sin_Omega2 + -1.99e-06 * a2

feature 15:
-2.57e-05 + 0.0010 * sin_pomega3 + 4.77e-05 * sin_Omega1

feature 16:
-4.55e-04 + 0.86 * e1 + -0.35 * e2

feature 17:
0.0070 + 0.86 * a3 + 0.36 * a2

feature 18:
0.0015 + 0.44 * i2 + 0.054 * i1

feature 