In [1]:
import torch
import pickle
from matplotlib import pyplot as plt
import numpy as np
import pysr
import math

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
labels = ['time', 'e+_near', 'e-_near', 'max_strength_mmr_near', 'e+_far', 'e-_far', 'max_strength_mmr_far', 'megno', 'a1', 'e1', 'i1', 'cos_Omega1', 'sin_Omega1', 'cos_pomega1', 'sin_pomega1', 'cos_theta1', 'sin_theta1', 'a2', 'e2', 'i2', 'cos_Omega2', 'sin_Omega2', 'cos_pomega2', 'sin_pomega2', 'cos_theta2', 'sin_theta2', 'a3', 'e3', 'i3', 'cos_Omega3', 'sin_Omega3', 'cos_pomega3', 'sin_pomega3', 'cos_theta3', 'sin_theta3', 'm1', 'm2', 'm3', 'nan_mmr_near', 'nan_mmr_far', 'nan_megno']

# not all of these labels are actually used. for training, these inputs are zeroed out, but still passed in as zeroes.
# ideally, the linear layer ignores them, which does happen if i do l1 regularization to it
skipped = ['nan_mmr_near', 'nan_mmr_far', 'nan_megno', 'e+_near', 'e-_near', 'max_strength_mmr_near', 'e+_far', 'e-_far', 'max_strength_mmr_far', 'megno']

In [3]:
# l1 reg = 2: 95944
# feature_nn = torch.load('models/95944_feature_nn.pt')

# l1 reg = 0.2: 92122
# feature_nn = torch.load('models/92122_feature_nn.pt')

# l1 nonabs reg: 63524
# feature_nn = torch.load('models/63524_feature_nn.pt')

# this gives the (n_features, n_inputs) matrix of the linear transformation used as f1
# input_linear = feature_nn.weight.detach().numpy()
# input_bias = feature_nn.bias.detach().numpy()
# input_linear.shape

In [4]:
# topk 2 pruned masked linear
# feature_nn = torch.load('models/52410_feature_nn.pt')

# topk2 pruned masked linear 3750
# feature_nn = torch.load('models/57964_feature_nn.pt')

# topk2 pruned masked linear 7500 again
# feature_nn = torch.load('models/2762_feature_nn.pt')

# another one
feature_nn = torch.load('models/55549_feature_nn.pt')


In [25]:
input_linear.shape
input_linear = feature_nn.linear.weight * feature_nn.mask
l = input_linear.abs().sum(dim=-1)
l[l < 1] = 0
print(l[l > 1])
_, ixs = torch.topk(input_linear.abs().sum(dim=-1), 10)
l[~ixs] = 0
print(l)


tensor([ 2.0622,  1.0224,  2.5198,  1.1172,  2.0941,  1.5463, 10.0008,  1.6823,
         1.3679,  1.3427,  2.2789,  1.5476,  1.6634,  1.0618,  1.1225,  1.5119,
         1.2386], grad_fn=<IndexBackward0>)
tensor([ 2.0622,  0.0000,  2.5198,  0.0000,  1.1172,  0.0000,  0.0000, 10.0008,
         0.0000,  1.3679,  1.3427,  0.0000,  0.0000,  0.0000,  0.0000,  1.0618,
         0.0000,  0.0000,  1.5119,  0.0000], grad_fn=<IndexPutBackward0>)


In [7]:
input_linear = feature_nn.linear.weight * feature_nn.mask
input_linear = input_linear.detach().numpy()
if feature_nn.linear.bias is not None:
    input_bias = feature_nn.linear.bias.detach().numpy()
else:
    input_bias = np.zeros(input_linear.shape[0])

In [8]:
# m_i is the mean of the i'th feature, s_i is the standard deviation
# get the linear transformation that creates feature i
def linear_transformation(i):
    return input_linear[i]

In [7]:

# let's make the linear transformation a bit easier to read
def format_num(x):
    if abs(x) > 0.1:
        return f'{x:.2f}'
    if abs(x) > 0.01:
        return f'{x:.3f}'
    elif abs(x) > 0.001:
        return f'{x:.4f}'
    else:
        return f'{x:.2e}'

format_vec = np.vectorize(format_num)

In [11]:
# now we can write it as a combination of the input features
# we'll sort the features by their absolute value to make it a bit easier to read
def feature_equation(i):
    transformation = linear_transformation(i)
    bias = input_bias[i]
    sorted_ixs = np.argsort(np.abs(transformation))[::-1]
    features = [format_num(transformation[i]) + ' * ' + labels[i] for i in sorted_ixs if transformation[i] != 0]
    if bias != 0:
        features = [format_num(bias)] + features

    return features

In [12]:
for i in range(input_linear.shape[0]):
    print(str(i) + ": " + ' + '.join(feature_equation(i)))

0: -1.55 * m1 + 0.51 * a3
1: 0.60 * i2 + 0.42 * i1
2: 1.36 * e2 + -1.16 * e3
3: -0.0072 * sin_Omega3 + -0.0025 * i3
4: 1.08 * i3 + -0.041 * cos_Omega1
5: -1.41 * m3 + 0.68 * a3
6: 0.80 * sin_pomega2 + -0.74 * sin_pomega1
7: -5.02 * a3 + 4.98 * a2
8: 1.37 * a2 + -0.32 * m2
9: -0.74 * cos_Omega3 + 0.63 * cos_Omega2
10: -0.91 * sin_Omega3 + 0.43 * sin_Omega2
11: -1.58 * e1 + 0.70 * e2
12: -0.020 * i2 + -0.0043 * e1
13: -1.06 * e3 + -0.48 * e2
14: 1.43 * a1 + -0.23 * m3
15: 0.91 * sin_Omega1 + -0.16 * sin_Omega3
16: 5.36e-04 * sin_Omega2 + -5.04e-04 * m1
17: 0.91 * sin_Omega2 + -0.22 * cos_Omega2
18: 1.23 * m2 + -0.28 * e1
19: -0.98 * sin_pomega3 + -0.26 * cos_pomega2


In [52]:
def features_used(i):
    transformation = linear_transformation(i)
    sorted_ixs = np.argsort(np.abs(transformation))[::-1]
    return ', '.join([labels[i] for i in sorted_ixs if transformation[i] != 0])

def feature_equation2(i):
    transformation = linear_transformation(i)
    bias = input_bias[i]
    sorted_ixs = np.argsort(np.abs(transformation))[::-1]
    v = [format_num(bias)]
    for i in sorted_ixs:
        if transformation[i] != 0:
            v.append(format_num(transformation[i]))
            v.append(labels[i])

    return v

In [60]:
ixs = [(i, features_used(i)) for i in range(input_linear.shape[0])]
sorted_ixs = sorted(ixs, key=lambda x: x[1])
values = []
for i, _ in sorted_ixs:
    print(str(i) + ", " + ' , '.join(feature_equation(i)))
    values.append([i] + feature_equation2(i))

4, -0.0014 , 1.81 , a1 , -0.011 , e1
5, 0.0084 , -3.93 , a2 , 3.71 , a3
17, 0.0070 , 0.86 , a3 , 0.36 , a2
12, 2.27e-05 , 2.78e-05 , cos_Omega2 , -1.71e-05 , a2
11, -0.0014 , 0.21 , cos_Omega2 , -0.16 , cos_Omega3
16, -4.55e-04 , 0.86 , e1 , -0.35 , e2
2, 0.0083 , 0.85 , e2 , -0.0058 , i1
1, 0.0080 , 0.78 , e3 , -0.22 , e2
19, -0.0016 , -0.44 , i1 , -0.017 , cos_Omega1
6, -3.17e-05 , -1.11e-04 , i2 , 1.71e-05 , e2
18, 0.0015 , 0.44 , i2 , 0.054 , i1
8, -0.0017 , -0.42 , i3 , -0.0025 , i1
7, -2.67e-04 , -5.35e-05 , m1 , 1.34e-05 , a2
0, -0.0055 , -1.53 , m1 , 0.35 , a3
9, -0.011 , -1.11 , m2 , 0.011 , i1
13, 0.013 , 1.32 , m3 , -0.26 , a3
14, 8.41e-05 , 6.71e-05 , sin_Omega2 , -1.99e-06 , a2
3, 0.0015 , -0.28 , sin_Omega2 , 0.11 , sin_Omega1
10, -5.91e-04 , -0.29 , sin_Omega3 , 0.053 , sin_Omega2
15, -2.57e-05 , 0.0010 , sin_pomega3 , 4.77e-05 , sin_Omega1


In [61]:
# values3 = values
# values2 = values
# values1 = values

In [74]:
feature_dict = {}
all_values = values3 + values2 + values1
for value_list, i in (values3, 3), (values2, 2), (values1, 1):
    for l in value_list:
        f = l[-3] + ', ' + l[-1]
        if f not in feature_dict:
            feature_dict[f] = [(1, 0), (2, 0), (3, 0)]
        feature_dict[f][i-1] = (i, [l[1], l[2], l[4]])


In [78]:
for f in feature_dict:
    print(f)
    for i in range(3):
        print('\t' + str(feature_dict[f][i]))





a1, e1
	(1, ['-0.0014', '1.81', '-0.011'])
	(2, 0)
	(3, ['-0.026', '1.38', '-0.0086'])
a2, cos_Omega3
	(1, 0)
	(2, 0)
	(3, ['-0.0030', '0.0044', '0.0038'])
a3, a2
	(1, ['0.0070', '0.86', '0.36'])
	(2, 0)
	(3, ['-0.49', '-0.83', '-0.67'])
cos_Omega3, cos_Omega2
	(1, 0)
	(2, ['-0.014', '-0.80', '0.66'])
	(3, ['-0.0056', '-0.87', '0.40'])
e1, e2
	(1, ['-4.55e-04', '0.86', '-0.35'])
	(2, ['-0.16', '-0.86', '-0.45'])
	(3, ['0.085', '1.53', '-0.79'])
e2, e1
	(1, 0)
	(2, ['0.065', '1.46', '-1.00'])
	(3, ['0.21', '1.19', '0.33'])
e3, e2
	(1, ['0.0080', '0.78', '-0.22'])
	(2, ['0.088', '1.53', '-0.50'])
	(3, ['-0.11', '-1.49', '0.72'])
i1, i2
	(1, 0)
	(2, 0)
	(3, ['0.030', '0.79', '0.048'])
i2, i3
	(1, 0)
	(2, 0)
	(3, ['0.0046', '0.61', '0.29'])
i3, e3
	(1, 0)
	(2, 0)
	(3, ['0.016', '0.80', '0.014'])
i3, m1
	(1, 0)
	(2, 0)
	(3, ['0.0041', '-0.0100', '0.0045'])
i3, m3
	(1, 0)
	(2, 0)
	(3, ['0.0047', '0.0042', '1.54e-04'])
m1, a3
	(1, ['-0.0055', '-1.53', '0.35'])
	(2, 0)
	(3, ['-0.15', '-1.81', 

In [87]:
s = [(f, sum([0 if not l[1] else sum([abs(float(f1)) for f1 in l[1]]) for l in feature_dict[f]])) for f in feature_dict]

In [88]:
s = sorted(s, key=lambda x: -x[1])

In [89]:
for t in s:
    f = t[0]
    print(f)
    for i in range(3):
        print('\t' + str(feature_dict[f][i]))


a2, a3
	(1, ['0.0084', '-3.93', '3.71'])
	(2, ['0.54', '-5.01', '4.71'])
	(3, 0)
e3, e2
	(1, ['0.0080', '0.78', '-0.22'])
	(2, ['0.088', '1.53', '-0.50'])
	(3, ['-0.11', '-1.49', '0.72'])
e1, e2
	(1, ['-4.55e-04', '0.86', '-0.35'])
	(2, ['-0.16', '-0.86', '-0.45'])
	(3, ['0.085', '1.53', '-0.79'])
m1, a3
	(1, ['-0.0055', '-1.53', '0.35'])
	(2, 0)
	(3, ['-0.15', '-1.81', '0.59'])
e2, e1
	(1, 0)
	(2, ['0.065', '1.46', '-1.00'])
	(3, ['0.21', '1.19', '0.33'])
m3, a3
	(1, ['0.013', '1.32', '-0.26'])
	(2, 0)
	(3, ['0.045', '1.59', '-0.58'])
sin_Omega2, sin_Omega1
	(1, ['0.0015', '-0.28', '0.11'])
	(2, ['-0.0025', '-0.71', '0.71'])
	(3, ['-0.022', '-0.87', '0.59'])
a1, e1
	(1, ['-0.0014', '1.81', '-0.011'])
	(2, 0)
	(3, ['-0.026', '1.38', '-0.0086'])
sin_Omega3, sin_Omega2
	(1, ['-5.91e-04', '-0.29', '0.053'])
	(2, ['-0.025', '-0.93', '0.47'])
	(3, ['-0.0034', '-0.96', '0.50'])
a3, a2
	(1, ['0.0070', '0.86', '0.36'])
	(2, 0)
	(3, ['-0.49', '-0.83', '-0.67'])
cos_Omega3, cos_Omega2
	(1, 0)
	(