In [1]:
import torch
import pickle
from matplotlib import pyplot as plt
import numpy as np
import pysr

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
labels = ['time', 'e+_near', 'e-_near', 'max_strength_mmr_near', 'e+_far', 'e-_far', 'max_strength_mmr_far', 'megno', 'a1', 'e1', 'i1', 'cos_Omega1', 'sin_Omega1', 'cos_pomega1', 'sin_pomega1', 'cos_theta1', 'sin_theta1', 'a2', 'e2', 'i2', 'cos_Omega2', 'sin_Omega2', 'cos_pomega2', 'sin_pomega2', 'cos_theta2', 'sin_theta2', 'a3', 'e3', 'i3', 'cos_Omega3', 'sin_Omega3', 'cos_pomega3', 'sin_pomega3', 'cos_theta3', 'sin_theta3', 'm1', 'm2', 'm3', 'nan_mmr_near', 'nan_mmr_far', 'nan_megno']

# not all of these labels are actually used. for training, these inputs are zeroed out, but still passed in as zeroes.
# ideally, the linear layer ignores them, which does happen if i do l1 regularization to it
skipped = ['nan_mmr_near', 'nan_mmr_far', 'nan_megno', 'e+_near', 'e-_near', 'max_strength_mmr_near', 'e+_far', 'e-_far', 'max_strength_mmr_far', 'megno']

In [10]:
# l1 reg = 2: 95944
# feature_nn = torch.load('models/95944_feature_nn.pt')

# l1 reg = 0.2: 92122
# feature_nn = torch.load('models/92122_feature_nn.pt')

# l1 nonabs reg: 63524
# feature_nn = torch.load('models/63524_feature_nn.pt')

# topk 2 pruned masked linear
# feature_nn = torch.load('models/52410_feature_nn.pt')

feature_nn = torch.load('models/250_feature_nn.pt')

In [11]:
# this gives the (n_features, n_inputs) matrix of the linear transformation used as f1
input_linear = feature_nn.weight.detach().numpy()
input_bias = feature_nn.bias.detach().numpy()
input_linear.shape

(20, 41)

In [7]:
input_linear = feature_nn.mask * feature_nn.linear.weight
input_linear = input_linear.detach().numpy()
input_bias = feature_nn.linear.bias.detach().numpy()
torch.save(input_linear, 'models/52410_input_linear.pt')
torch.save(input_bias, 'models/52410_input_bias.pt')

In [6]:

# m_i is the mean of the i'th feature, s_i is the standard deviation
# get the linear transformation that creates feature i
def linear_transformation(i):
    return input_linear[i]

In [7]:

# let's make the linear transformation a bit easier to read
def format_num(x):
    if abs(x) > 0.1:
        return f'{x:.2f}'
    if abs(x) > 0.01:
        return f'{x:.3f}'
    elif abs(x) > 0.001:
        return f'{x:.4f}'
    else:
        return f'{x:.2e}'

format_vec = np.vectorize(format_num)

In [8]:
# now we can write it as a combination of the input features
# we'll sort the features by their absolute value to make it a bit easier to read
def feature_equation(i):
    transformation = linear_transformation(i)
    bias = input_bias[i]
    sorted_ixs = np.argsort(np.abs(transformation))[::-1]
    return [format_num(bias)] + [format_num(transformation[i]) + ' * ' + labels[i] for i in sorted_ixs if transformation[i] != 0]

In [12]:
for i in range(input_linear.shape[0]):
    print(f'feature {i}:')
    # print(' +\n'.join(feature_equation(i)[:10]))
    print(' + '.join(feature_equation(i)))
    # print('+ ... (smaller terms omitted)')
    print()

feature 0:
8.66e-05 + 0.023 * a3 + 0.017 * megno + 0.014 * e+_near + -0.013 * nan_mmr_near + -0.012 * a2 + -0.0096 * m2 + -0.0093 * nan_megno + 0.0092 * e+_far + 0.0090 * max_strength_mmr_far + 0.0086 * max_strength_mmr_near + -0.0063 * m3 + -0.0050 * e-_far + -0.0025 * m1 + 0.0024 * i1 + -0.0019 * e1 + 0.0016 * i3 + 0.0016 * e-_near + -0.0016 * i2 + -0.0014 * nan_mmr_far + -0.0012 * sin_Omega2 + -0.0012 * sin_pomega1 + -0.0012 * e2 + -8.62e-04 * cos_pomega3 + 8.52e-04 * sin_pomega3 + -7.53e-04 * cos_pomega2 + -7.28e-04 * cos_Omega1 + 7.05e-04 * cos_Omega2 + 6.06e-04 * e3 + 5.49e-04 * cos_Omega3 + 4.21e-04 * a1 + -4.05e-04 * sin_pomega2 + -3.77e-04 * sin_Omega1 + 3.19e-04 * sin_theta3 + -1.94e-04 * sin_theta2 + -1.45e-04 * cos_theta2 + 1.20e-04 * time + -8.99e-05 * sin_theta1 + -6.72e-05 * cos_theta1 + -6.05e-05 * cos_theta3 + -3.81e-05 * sin_Omega3 + -1.05e-05 * cos_pomega1

feature 1:
0.030 + -1.21 * i2 + -0.12 * e2 + 0.12 * i3 + 0.093 * i1 + -0.071 * e3 + 0.070 * a3 + -0.055 * m2 + 