In [3]:
import torch
import numpy as np
import spock_reg_model

In [4]:
labels = ['time', 'e+_near', 'e-_near', 'max_strength_mmr_near', 'e+_far', 'e-_far', 'max_strength_mmr_far', 'megno', 'a1', 'e1', 'i1', 'cos_Omega1', 'sin_Omega1', 'cos_pomega1', 'sin_pomega1', 'cos_theta1', 'sin_theta1', 'a2', 'e2', 'i2', 'cos_Omega2', 'sin_Omega2', 'cos_pomega2', 'sin_pomega2', 'cos_theta2', 'sin_theta2', 'a3', 'e3', 'i3', 'cos_Omega3', 'sin_Omega3', 'cos_pomega3', 'sin_pomega3', 'cos_theta3', 'sin_theta3', 'm1', 'm2', 'm3', 'nan_mmr_near', 'nan_mmr_far', 'nan_megno']

# not all of these labels are actually used. for training, these inputs are zeroed out, but still passed in as zeroes.
# ideally, the linear layer ignores them, which does happen if i do l1 regularization to it
skipped = ['nan_mmr_near', 'nan_mmr_far', 'nan_megno', 'e+_near', 'e-_near', 'max_strength_mmr_near', 'e+_far', 'e-_far', 'max_strength_mmr_far', 'megno']

In [51]:
version = 10290  # eps = 0.001
# version = 9259  # eps = 0.1

In [52]:
# load when on the cluster
model = spock_reg_model.load(version)
feature_nn = model.feature_nn

In [53]:
# load when local
# feature_nn = torch.load(f'models/{version}_feature_nn.pt')

In [54]:
input_linear = feature_nn.linear.weight * feature_nn.mask
input_linear = input_linear.detach().numpy()
if feature_nn.linear.bias is not None:
    input_bias = feature_nn.linear.bias.detach().numpy()
else:
    input_bias = np.zeros(input_linear.shape[0])

In [55]:
# m_i is the mean of the i'th feature, s_i is the standard deviation
# get the linear transformation that creates feature i
def linear_transformation(i):
    return input_linear[i]

In [56]:
# let's make the linear transformation a bit easier to read
def format_num(x):
    if abs(x) > 0.1:
        return f'{x:.2f}'
    if abs(x) > 0.01:
        return f'{x:.3f}'
    elif abs(x) > 0.001:
        return f'{x:.4f}'
    else:
        return f'{x:.2e}'

format_vec = np.vectorize(format_num)

In [57]:
# now we can write it as a combination of the input features
# we'll sort the features by their absolute value to make it a bit easier to read
def feature_equation(i):
    transformation = linear_transformation(i)
    bias = input_bias[i]
    sorted_ixs = np.argsort(np.abs(transformation))[::-1]
    features = [format_num(transformation[i]) + ' * ' + labels[i] for i in sorted_ixs if transformation[i] != 0]
    if bias != 0:
        features = [format_num(bias)] + features

    return features

In [58]:
for i in range(input_linear.shape[0]):
    print(str(i) + ": " + ' + '.join(feature_equation(i)))

0: -0.63 * e2 + 0.38 * e1
1: 1.82 * m1 + 0.99 * a1
2: -0.70 * a3 + 0.30 * m1
3: -0.59 * sin_Omega2 + 0.41 * sin_Omega3
4: -5.25 * a2 + 4.66 * a3
5: 1.52 * e3 + 0.10 * e1
6: 1.04 * i2 + 0.37 * i3
7: 1.37 * m2 + -0.26 * e3
8: 1.50 * e1 + -0.23 * e3
9: 1.32 * m3 + 0.068 * e1


In [59]:
def convert_to_latex(label):
    # 1. if it ends in a number, add an underscore
    if label[-1].isdigit():
        label = label[:-1] + '_' + label[-1]
    # 2. replace sin/cos with \sin/\cos
    label = label.replace('sin', '\\sin')
    label = label.replace('cos', '\\cos')
    label = label.replace('_Omega', '\\Omega')
    label = label.replace('_pomega', '\\omega')
    return label

latex_labels = [convert_to_latex(label) for label in labels]
print(latex_labels)

['time', 'e+_near', 'e-_near', 'max_strength_mmr_near', 'e+_far', 'e-_far', 'max_strength_mmr_far', 'megno', 'a_1', 'e_1', 'i_1', '\\cos\\Omega_1', '\\sin\\Omega_1', '\\cos\\omega_1', '\\sin\\omega_1', '\\cos_theta_1', '\\sin_theta_1', 'a_2', 'e_2', 'i_2', '\\cos\\Omega_2', '\\sin\\Omega_2', '\\cos\\omega_2', '\\sin\\omega_2', '\\cos_theta_2', '\\sin_theta_2', 'a_3', 'e_3', 'i_3', '\\cos\\Omega_3', '\\sin\\Omega_3', '\\cos\\omega_3', '\\sin\\omega_3', '\\cos_theta_3', '\\sin_theta_3', 'm_1', 'm_2', 'm_3', 'nan_mmr_near', 'nan_mmr_far', 'nan_megno']


In [60]:
def latex_line(i):
    transformation = linear_transformation(i)
    bias = input_bias[i]
    sorted_ixs = np.argsort(np.abs(transformation))[::-1]
    features = [format_num(transformation[i]) + ' ' + latex_labels[i] for i in sorted_ixs if transformation[i] != 0]
    if bias != 0:
        features = [format_num(bias)] + features

    line = ' + '.join(features)
    line = line.replace('+ -', '- ')
    return line

def latex_string():
    s = '''\\begin{align*}
    f_1& \\text{ features:} \\\\ \n'''

    for i in range(input_linear.shape[0]):
        s += f'    &{i}: {latex_line(i)} \\\\ \n'

    s += '''\end{align*}'''
    return s


In [61]:
from IPython.display import display, Markdown
display(Markdown('$$\n' + latex_string() + '\n$$'))

$$
\begin{align*}
    f_1& \text{ features:} \\ 
    &0: -0.63 e_2 + 0.38 e_1 \\ 
    &1: 1.82 m_1 + 0.99 a_1 \\ 
    &2: -0.70 a_3 + 0.30 m_1 \\ 
    &3: -0.59 \sin\Omega_2 + 0.41 \sin\Omega_3 \\ 
    &4: -5.25 a_2 + 4.66 a_3 \\ 
    &5: 1.52 e_3 + 0.10 e_1 \\ 
    &6: 1.04 i_2 + 0.37 i_3 \\ 
    &7: 1.37 m_2 - 0.26 e_3 \\ 
    &8: 1.50 e_1 - 0.23 e_3 \\ 
    &9: 1.32 m_3 + 0.068 e_1 \\ 
\end{align*}
$$

In [62]:
version = 24880
model = spock_reg_model.load(version)
feature_nn = model.feature_nn

input_linear = feature_nn.linear.weight * feature_nn.mask
input_linear = input_linear.detach().numpy()
if feature_nn.linear.bias is not None:
    input_bias = feature_nn.linear.bias.detach().numpy()
else:
    input_bias = np.zeros(input_linear.shape[0])

In [63]:
def get_nonzero(arr):
    return arr[arr.nonzero()], arr.nonzero()

def set_nonzero(arr, values, indices):
    arr[indices] = values

In [64]:
def simplify(x, y, epsilon=0.1, zeroing_allowed=True):
    if x == 0 and y == 0:
        return (0, 0, 1), 0
    if x == 0:
        return (0, 1, 1/y), 0
    if y == 0:
        return (1, 0, 1/x), 0

    best_simplification, best_magnitude, best_error = None, None, None
    possible_values = list(range(-10, 11))
    candidate_ratios = [(a, c) for a in possible_values for c in possible_values]

    for a, b in candidate_ratios:
        if not zeroing_allowed and (a == 0 or b == 0):
            continue

        k = 1
        if a != 0:
            k = x / a
        if b != 0 and (a == 0 or abs(y) > abs(x)):
            k = y / b

        if k < 0: continue
        x2, y2 = k * a, k * b

        error = abs(x - x2) + abs(y - y2)
        # should be measured with the normalized values, not the original.
        if error >= epsilon: continue

        magnitude = abs(a) + abs(b)

        if best_error is None or magnitude < best_magnitude or magnitude == best_magnitude and error < best_error:
            best_simplification, best_error, best_magnitude = (a, b, x2, y2), error, magnitude

    return best_simplification, best_error

In [69]:
def simplify_stuff(input_linear, epsilon=0.1, normalize=True, zeroing_allowed=True):
    input_linear2 = input_linear.copy()
    for i in range(input_linear.shape[0]):
        nonzero, indices = get_nonzero(input_linear[i])
        x, y = nonzero
        print('original:\t', f'{x:.3f} {y:.3f}')

        if normalize:
            l1 = abs(x) + abs(y)
            x, y = x / l1, y / l1
            print('normalized:\t', f'{x:.3f} {y:.3f}')

        simplification, error = simplify(x, y, epsilon=epsilon, zeroing_allowed=zeroing_allowed)
        if simplification is None:
            print("no simplification found")
            nonzero = [x, y]
        else:
            a, b, x2, y2 = simplification
            print("new values:\t", f"{x2:.3f} {y2:.3f}", "with error", f"{error:.3f}")
            print("ratio:\t\t", f"{a} {b}")
            nonzero = [x2, y2]

            if normalize:
                l1 = abs(x2) + abs(y2)
                x2, y2 = x2 / l1, y2 / l1
                print("final normed:\t", f"{x2:.3f} {y2:.3f}")
                nonzero = [x2, y2]


        print()
        set_nonzero(input_linear2[i], nonzero, indices)

    input_linear2 = torch.tensor(input_linear2)
    feature_nn.linear.weight = torch.nn.Parameter(input_linear2)

    s = '24880_feature_nn_simplified_v3_'
    # if normalize:
        # s += 'norm_'
    # if not zeroing_allowed:
        # s += 'nozero_'
    s += f'eps={epsilon}.pt'

    torch.save(feature_nn, s)
    print(f'saved to', s)

In [71]:
simplify_stuff(input_linear, epsilon=0.001, normalize=True, zeroing_allowed=False)

original:	 0.996 -1.664
normalized:	 0.375 -0.625
new values:	 0.375 -0.625 with error 0.001
ratio:		 3 -5
final normed:	 0.375 -0.625

original:	 0.994 1.816
normalized:	 0.354 0.646
no simplification found

original:	 -1.341 0.575
normalized:	 -0.700 0.300
new values:	 -0.700 0.300 with error 0.000
ratio:		 -7 3
final normed:	 -0.700 0.300

original:	 -0.776 0.544
normalized:	 -0.588 0.412
new values:	 -0.588 0.412 with error 0.000
ratio:		 -10 7
final normed:	 -0.588 0.412

original:	 -5.254 4.657
normalized:	 -0.530 0.470
no simplification found

original:	 0.104 1.517
normalized:	 0.064 0.936
no simplification found

original:	 1.040 0.373
normalized:	 0.736 0.264
no simplification found

original:	 -0.256 1.373
normalized:	 -0.157 0.843
no simplification found

original:	 1.500 -0.233
normalized:	 0.865 -0.135
no simplification found

original:	 0.068 1.321
normalized:	 0.049 0.951
no simplification found

saved to 24880_feature_nn_simplified_v3_eps=0.001.pt
