In [1]:
# Load symbolicregression model

import torch
import os, sys
import symbolicregression
import sympytorch
import requests
from sympy.core.rules import Transform
import sympy as sp

model_path = "ckpt/model.pt" 
try:
    if not os.path.isfile(model_path): 
        print("Downloading model...")
        url = "https://dl.fbaipublicfiles.com/symbolicregression/model1.pt"
        r = requests.get(url, allow_redirects=True)
        open(model_path, 'wb').write(r.content)
    if not torch.cuda.is_available():
        sr_model = torch.load(model_path, map_location=torch.device('cpu'))
    else:
        sr_model = torch.load(model_path)
        sr_model = sr_model.cuda()
    print(sr_model.device)
    print("Model successfully loaded!")

except Exception as e:
    print("ERROR: model not loaded! path was: {}".format(model_path))
    print(e)    
    
est = symbolicregression.model.SymbolicTransformerRegressor(
                        model=sr_model,
                        max_input_points=10001,
                        n_trees_to_refine=5,
                        rescale=True
                        )

cuda:0
Model successfully loaded!


In [4]:
import numpy as np
import sympy as sp
from sympy import sympify, lambdify, symbols, integrate, Interval, Symbol, I, S, oo, plot
from IPython.display import display

# Given an expr f (of variable t), returns its integral, together with t's and y's for regression
def integrate_expr(f, min_x=-4.0, max_x=4.0, increment=0.002, verbose=False):
    if verbose:
        print("Running integration on")
        display(f)
    # Compute integration
    x, t = symbols(['x','t'])
    fi = integrate(f, (t, 0, x))
    fi = fi.subs(x, t)
    if verbose:
        display(fi)
        #plot(fi, (t, min_x, max_x))
    # Generate data for symbolic regression
    fl = lambdify((t), fi, "numpy")
    ts = np.arange(min_x, max_x, increment)
    ys = fl(ts)
    return fi, ts, ys
    
integrate_expr(sympify("sin(t)+2.5"), verbose=True)

Running integration on


sin(t) + 2.5

2.5*t - cos(t) + 1

(2.5*t - cos(t) + 1,
 array([-4.   , -3.998, -3.996, ...,  3.994,  3.996,  3.998]),
 array([-8.34635638, -8.33984408, -8.33333441, ..., 11.64317264,
        11.64666559, 11.65015592]))

In [7]:
from utils.utils import *

def round_expr(expr, num_digits=2):
    return expr.xreplace(Transform(lambda x: x.round(num_digits), lambda x: isinstance(x, sp.Float)))

# Run symbolic regression on given data
# Returns: (raw regressed expr, rounded expr, model refined expr)
@timeout(15)
def symbolic_regress(sr_model, xs, ys, generate_refinement=False, verbose=False):
    if verbose:
        print("Running Symbolic Regression...")
    ##Example of data
    xs = np.reshape(xs, (len(xs),1))
    ys = np.reshape(ys, (len(ys),1))
    sr_model.fit(xs,ys)
    #
    replace_ops = {"add": "+", "mul": "*", "sub": "-", "pow": "**", "inv": "1/"}
    model_str = sr_model.retrieve_tree(with_infos=True)["relabed_predicted_tree"].infix()
    for op,replace_op in replace_ops.items():
        model_str = model_str.replace(op,replace_op)
    #
    raw_expr = sp.parse_expr(model_str)
    x_0, t = symbols(['x_0', 't'])
    raw_expr = raw_expr.subs(x_0, t)
    if verbose:
        display(raw_expr)
    #
    expr = sp.expand(raw_expr)
    rounded_expr = round_expr(expr)
    if verbose:
        display(rounded_expr)
    #
    # Encode some input text
    if generate_refinement:
        prompt = str(rounded_expr)
        input_ids = tokenizer.encode(prompt, return_tensors='pt')
        #
        # Generate text
        output = model.generate(input_ids, max_length=50, num_return_sequences=1, temperature=0.1)
        #
        # Decode and print the output
        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
        generated_expr = sympify(generated_text)
        if verbose:
            display(generated_expr)
        #
    else:
        generated_expr = None
    return raw_expr, rounded_expr, generated_expr


fi, ts, ys = integrate_expr(sympify("sin(t)+2.5"), verbose=False)
rounded_fi = round_expr(fi)
display(rounded_fi)
raw_expr, rounded_expr, generated_expr = symbolic_regress(est, ts, ys, verbose=False)
display(generated_expr)

print("Diff1:")
display(rounded_fi-rounded_expr)
# print("Diff2:")
# display(rounded_fi-generated_expr)

2.5*t - cos(t) + 1

  warn_deprecated('grad')


None

Diff1:


-0.01*t**3 - 0.24*t**2 - 0.13*t - cos(t) + 1

In [4]:
import json
import random
import numpy as np

def load_expressions(filepaths):
    lines = []
    for filepath in filepaths:
        fin = open(filepath, 'r')
        lines.extend(fin.readlines())
        fin.close()
    #
    random.shuffle(lines)
    exprs = set()
    for line in lines:
        data = json.loads(line)
        for k,v in data.items():
            if k in ('f_t', 'g_t'):
                try:
                    if 'sqrt' not in v[1]:
                        expr = sympify(v[1])
                        exprs.add(expr)
                except:
                    continue
            elif k in ('original'):
                try:
                    if 'sqrt' not in v:
                        expr = sympify(v)
                        exprs.add(expr)
                except:
                    continue
    return exprs

exprs = load_expressions(['datasets/parametric_equations_pairs.json'])  #'datasets/parametric_equations.json'  'datasets/function_evaluation.json', 

In [5]:
len(exprs)

81647

In [None]:
# Run symbolic regression on each case

seen_exprs = set()
fin = open("datasets/parametric_equations_randomized_polynomial_integral_results.json", "r")
lines = fin.readlines()
for line in lines:
    result = json.loads(line)
    expr = result["original"]
    seen_exprs.add(expr)
fin.close()
print(f"{len(seen_exprs)} exprs loaded")

fout = open("datasets/parametric_equations_randomized_polynomial_integral_results.json", "a")
num_seen = 0
num_seen_changed = False

for f in exprs:
    if str(f) in seen_exprs:
        num_seen+=1
        num_seen_changed = True
        continue
    else:
        num_seen_changed = False
    if num_seen_changed:
        print(f"{num_seen} exprs ignored")
    #print("Original expr and its integral:")
    #display(f)
    #print(f)
    try:
        fi, xs, ys = integrate_expr(f, verbose=False)
        x, t = symbols(['x','t'])
        fi = fi.subs({x:t})
        rounded_fi = round_expr(fi)
        #display(rounded_fi)
        raw_expr, rounded_expr, generated_expr = symbolic_regress(est, xs, ys, generate_refinement=False, verbose=False)
#         print("Generated expr:")
#         display(generated_expr)
        results = {"original":str(f),
                   "integral":str(fi),
                   "rounded_integral":str(rounded_fi),
                   "regressed":str(raw_expr),
                   "rounded_regressed":str(rounded_expr),
                   #"generated_regressed":str(generated_expr),
                   "diff_rounded": str(rounded_fi-rounded_expr),
                   #"diff_generated": str(rounded_fi-generated_expr)
                  }
        fout.write(json.dumps(results))
        fout.write('\n')
        fout.flush()
    except:
        print("Failed to run symbolic regression")
        continue
    #     print("Diff1:")
    #     display(rounded_fi-rounded_expr)
    #     print("Diff2:")
    #     display(rounded_fi-generated_expr)

fout.close()

4097 exprs loaded


  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')


In [8]:
# Check the accuracy of symbolic regression

import json
from sympy import evalf, N
from utils.utils import *


# Check if f1 and f2 are almost equal.
# Note: Relative error is defined based on f1. Please use the original expression as f1.
def almost_equal(f1, f2, max_abs_error=0.011, max_relative_error=0.011, verbose=False):
    expr = f1-f2
    coeff_pairs = None
    try:
        coeff_pairs = get_coefficients_and_exponents(expr)
    except:
        print("Cannot get_coefficients_and_exponents")
        print(str(expr))
    if coeff_pairs is None:
        constants = get_all_constants(expr)
    else:
        constants = [p[0] for p in coeff_pairs]
    # Check if all diffs are within max_abs_error
    violators = [c for c in constants if c == sp.nan or abs(c) > max_abs_error]
    if verbose:
        print("Violating constants:", violators)
    if len(violators) == 0:
        return True
    # Check if all violating diffs are within max_relative_error
    try:
        coeffs1 = get_polynomial_coeffs(f1)
        coeffs2 = get_polynomial_coeffs(f2)
    except:
        return False
    for i in range(len(coeffs1)):
        if abs(coeffs1[i] - coeffs2[i]) > max_abs_error and \
           abs(coeffs1[i] - coeffs2[i]) > max_relative_error*abs(coeffs1[i]):
            return False
    return True
    

# fin = open("datasets/parametric_equations_polynomial_integral_results.json", "r")
# lines = fin.readlines()

# num_total, qualified_rounded, qualified_generated = 0, 0, 0
# for line in lines:
#     result = json.loads(line)
#     if "diff_rounded" not in result or "diff_generated" not in result:
#         continue
#     rounded_integral = N(sympify(result["rounded_integral"]))
#     rounded_regressed = N(sympify(result["rounded_regressed"]))
# #     display(rounded_integral)
# #     display(rounded_regressed)
#     try:
#         rounded_regressed = filter_non_polynomial(rounded_regressed)
#     except:
#         print("Cannot filter non-polynomials on", str(rounded_regressed))
#     #generated_regressed = N(sympify(result["generated_regressed"]))
#     diff_rounded = rounded_integral - rounded_regressed
#     #diff_generated = sympify(result["diff_generated"])
#     num_total += 1
#     if almost_equal(rounded_integral, rounded_regressed, verbose=False):
#         qualified_rounded += 1
#     else:
#         display(rounded_integral)
#         display(rounded_regressed)
#         print(rounded_regressed)
#         display(diff_rounded)
# #     if is_close_to_zero(diff_generated, True):
# #         qualified_generated += 1
    
# fin.close()

# print(num_total, qualified_rounded, qualified_generated)

In [10]:
# for i in range(1000):
#     f = originals[i]
#     x, t = symbols(['x','t'])
#     fi = integrate(f, t)
#     g = integrals[i]
#     #fs, gs, diff = get_diff(fi, g, t)
#     # print(np.mean(np.abs(fs)))
#     # print(np.mean(np.abs(gs)))
#     # print(np.mean(np.abs(diff)))
#     relative_diff = get_avg_diff(fi, g, t)
#     if relative_diff > 0.01:
#         print("CASE", i)
#         display(fi)
#         display(g)
#         print(relative_diff)

In [11]:
# Generate data for regression to infer the rules for integral
# The data is in the following format:
# Let us suppose there are N regression problems, and MAX_POWER=6, which is 
#   the highest power allowed in our polynomials.
#   data_series is a list of 7 items, each corresponding to a particular power
#   Each item is a tuple of 2 elements, with the first being x's and the second
#   being y's for the regression problem for the corresponding power.
#   The y's is a 1-D array of size N, which is the coefficient of that particular power for
#   each of N regression problems.
#   The each of the x's is an array of size 14, which contains 
#   [0, coeff-power-0, 1, coeff-power-1, 2, coeff-power-2, ...]
# If we want to find the connection between x's and y's, we should find that
#   y_1 = x_1 (coeff-power-0)
#   y_2 = x_3 (coeff-power-1)
#   y_3 = x_5 (coeff-power-2)
#   y_4 = x_7 (coeff-power-3)
#   y_5 = x_9 (coeff-power-4)

from sympy import evalf, N
from utils.utils import *

fin = open("datasets/parametric_equations_polynomial_integral_results.json", "r")
lines = fin.readlines()

MAX_POWER = 6
MAX_AVG_DIFF = 0.01

data_series = [([],[])] * (MAX_POWER+1)
data_series = []
originals = []
integrals = []
t = Symbol('t')

for i in range(MAX_POWER+1):
    data_series.append((list(),list()))

for line in lines:
    result = json.loads(line)
    if "rounded_regressed" not in result:
        continue
    #original = N(sympify(result["original"]))
    original_integral = N(sympify(result["rounded_integral"]))
    integral = N(sympify(result["rounded_regressed"]))
    try:
        original_integral = filter_non_polynomial(original_integral)
        integral = filter_non_polynomial(integral)
        #original_integral = integrate(original, t)
        avg_diff = get_avg_diff(original_integral, integral, t)
        if avg_diff > MAX_AVG_DIFF:
            #print("Skipping due to diff=", avg_diff)
            #display(original_integral)
            #display(integral)
            continue
        else:
            
    except:
        print("Cannot filter non-polynomials on", str(integral))
        continue
    try:
        coeffs_original = get_polynomial_coeffs(original)
        coeffs_integral = get_polynomial_coeffs(integral)
    except:
        print("Cannot get_coefficients_and_exponents")
        display(integral)
        continue
    if original.is_constant():
        print("Skipping", line)
        continue
    xs = list()
    for i in range(MAX_POWER+1):
        xs.append(i)
        xs.append(coeffs_original[i])
    for i in range(MAX_POWER+1):
        data_series[i][0].append(xs.copy())
        data_series[i][1].append(coeffs_integral[i])
    if len(data_series[0][1]) % 100 == 0:
        print(len(data_series[0][1]), "processed")
#     if len(data_series[0][1]) == 3844:
#         display(original)
#         print(original.is_constant())
#         display(sympify(result["original"]))
#         display(integral)
    originals.append(original)
    integrals.append(integral)
    
fin.close()


100 processed
Skipping due to diff= 0.01992431786496669


0.0833333333333333*t**3 - 0.5*t**2 + 1.0*t

0.08*t**3 - 0.5*t**2 + 1.0*t

200 processed
Skipping due to diff= 0.010420435344398904


0.0833333333333333*t**3 - 0.875*t**2 + 3.0625*t

0.08*t**3 - 0.88*t**2 + 3.06*t

Skipping due to diff= 0.01244764075765293


0.375*t**2 + 1.0*t

0.38*t**2 + 1.0*t

Skipping due to diff= 0.015052592500927274


0.0208333333333333*t**3 - 0.1875*t**2 + 0.5625*t

0.02*t**3 - 0.19*t**2 + 0.56*t

300 processed
400 processed
Skipping due to diff= 0.012956870871648679


0.0833333333333333*t**3 + 0.75*t**2 + 2.25*t

0.08*t**3 + 0.75*t**2 + 2.25*t

500 processed
600 processed
Skipping due to diff= 0.011008578433168069


0.108843537414966*t**3 + 0.244897959183673*t**2 + 0.183673469387755*t

0.11*t**3 + 0.24*t**2 + 0.18*t

Skipping due to diff= 0.01733481217157357


0.244897959183673*t**3 + 0.489795918367347*t**2 + 0.326530612244898*t

0.24*t**3 + 0.49*t**2 + 0.33*t

Skipping due to diff= 0.016706664499452455


0.037037037037037*t**3 - 0.444444444444444*t**2 + 1.77777777777778*t

0.04*t**3 - 0.44*t**2 + 1.78*t

700 processed
Skipping due to diff= 0.032897073331140954


0.0833333333333333*t**3 + 0.25*t**2 + 0.25*t

0.08*t**3 + 0.25*t**2 + 0.25*t

Skipping due to diff= 0.02396960508643259


0.037037037037037*t**3 - 0.333333333333333*t**2 + 1.0*t

0.04*t**3 - 0.33*t**2 + 1.0*t

Skipping due to diff= 0.015349585447345281


0.0833333333333333*t**3 - 0.75*t**2 + 2.25*t

0.08*t**3 - 0.75*t**2 + 2.24*t

800 processed
Skipping due to diff= 0.01188636918430256


0.1875*t**3 + 0.1875*t**2 + 0.0625*t

0.19*t**3 + 0.19*t**2 + 0.06*t

900 processed
1000 processed
1100 processed
Skipping due to diff= 0.010658980220653829


0.0833333333333333*t**3 - 0.875*t**2 + 3.0625*t

0.08*t**3 - 0.87*t**2 + 3.06*t

1200 processed
Skipping due to diff= 0.013367843137255286


-0.142857142857143*t**2 + 0.714285714285714*t

-0.14*t**2 + 0.71*t

Skipping due to diff= 0.012043182690164276


0.037037037037037*t**3 + 0.555555555555556*t**2 + 2.77777777777778*t

0.04*t**3 + 0.56*t**2 + 2.78*t

1300 processed
1400 processed
1500 processed
Skipping due to diff= 0.01999999999999871


0.214285714285714*t**2

0.21*t**2

1600 processed
Skipping due to diff= 0.012022935865913421


0.1875*t**3 - 0.375*t**2 + 0.25*t

0.19*t**3 - 0.38*t**2 + 0.25*t

Skipping due to diff= 0.03999999999999875


0.0208333333333333*t**3 + 0.0625*t**2 + 0.0625*t

0.02*t**3 + 0.06*t**2 + 0.06*t

1700 processed
Skipping due to diff= 0.18323427621227703


0.00680272108843537*t**3 - 0.0408163265306122*t**2 + 0.0816326530612245*t

0.01*t**3 - 0.04*t**2 + 0.08*t

1800 processed
Skipping due to diff= 0.015215175369217203


0.244897959183673*t**3 - 0.857142857142857*t**2 + 1.0*t

0.25*t**3 - 0.86*t**2 + 1.0*t

1900 processed
Skipping due to diff= 0.03164821259095167


0.0833333333333333*t**3 - 0.25*t**2 + 0.25*t

0.08*t**3 - 0.25*t**2 + 0.25*t

2000 processed
2100 processed
Skipping due to diff= 0.02817749808059334


0.0272108843537415*t**3 - 0.244897959183673*t**2 + 0.73469387755102*t

0.03*t**3 - 0.25*t**2 + 0.73*t

2200 processed
2300 processed
Skipping due to diff= 0.020000000000000552


0.0612244897959184*t**3 - 0.122448979591837*t**2 + 0.0816326530612245*t

0.06*t**3 - 0.12*t**2 + 0.08*t

2400 processed
Skipping due to diff= 0.025502225453370938


0.037037037037037*t**3 + 0.333333333333333*t**2 + 1.0*t

0.04*t**3 + 0.33*t**2 + 1.0*t

Skipping due to diff= 0.010737257215813669


0.037037037037037*t**3 + 0.666666666666667*t**2 + 4.0*t

0.04*t**3 + 0.66*t**2 + 4.01*t

Skipping due to diff= 0.011906051039736052


0.1875*t**3 - 0.1875*t**2 + 0.0625*t

0.19*t**3 - 0.19*t**2 + 0.06*t

2500 processed
2600 processed
2700 processed
2800 processed
2900 processed
3000 processed
Skipping due to diff= 0.03615993287995145


0.037037037037037*t**3 + 0.222222222222222*t**2 + 0.444444444444444*t

0.04*t**3 + 0.22*t**2 + 0.44*t

Skipping due to diff= 0.03125527327270995


-0.125*t**2 + 0.5*t

-0.12*t**2 + 0.5*t

Skipping due to diff= 0.03831743477761161


-0.125*t**2 + 0.25*t

-0.12*t**2 + 0.25*t

3100 processed
Skipping due to diff= 0.015536477487326955


0.0533333333333333*t**3 - 0.56*t**2 + 1.96*t

0.05*t**3 - 0.56*t**2 + 1.96*t

3200 processed
Skipping due to diff= 0.013740319307972315


0.285714285714286*t**2 - 0.714285714285714*t

0.29*t**2 - 0.71*t

Skipping due to diff= 0.02000000000000051


0.0612244897959184*t**3 + 0.0612244897959184*t**2 + 0.0204081632653061*t

0.06*t**3 + 0.06*t**2 + 0.02*t

3300 processed
3400 processed
3500 processed
3600 processed
Skipping due to diff= 0.06020546902042363


0.037037037037037*t**3 + 0.111111111111111*t**2 + 0.111111111111111*t

0.04*t**3 + 0.11*t**2 + 0.11*t

3700 processed
Skipping due to diff= 0.014671812790908095


0.037037037037037*t**3 - 0.555555555555556*t**2 + 2.77777777777778*t

0.04*t**3 - 0.55*t**2 + 2.79*t + 0.01

Skipping due to diff= 0.010295552994719528


0.108843537414966*t**3 + 0.0816326530612245*t**2 + 0.0204081632653061*t

0.11*t**3 + 0.08*t**2 + 0.02*t

3800 processed
Skipping due to diff= 0.037970314237388614


0.0272108843537415*t**3 - 0.204081632653061*t**2 + 0.510204081632653*t

0.03*t**3 - 0.2*t**2 + 0.51*t

Skipping due to diff= 0.011371833839919455


-0.142857142857143*t**2 - 0.857142857142857*t

-0.14*t**2 - 0.86*t

3900 processed
4000 processed
Skipping due to diff= 0.015624999999998423


0.213333333333333*t**3

0.21*t**3

4100 processed
4200 processed
4300 processed
Skipping due to diff= 0.010171690194581199


0.444444444444444*t**3 + 0.666666666666667*t**2 + 0.333333333333333*t

0.44*t**3 + 0.67*t**2 + 0.33*t

4400 processed
Skipping due to diff= 0.010513296227581839


0.108843537414966*t**3 + 0.326530612244898*t**2 + 0.326530612244898*t

0.11*t**3 + 0.33*t**2 + 0.33*t

Skipping due to diff= 0.016496008005847773


0.00680272108843537*t**3 - 0.183673469387755*t**2 + 1.6530612244898*t

0.01*t**3 - 0.18*t**2 + 1.64*t

Skipping due to diff= 0.010353921118106632


0.166666666666667*t**3 + 1.0*t**2 + 2.0*t

0.17*t**3 + 1.0*t**2 + 2.0*t

4500 processed
Skipping due to diff= 0.01561239421573437


0.166666666666667*t**2 - 0.666666666666667*t

0.17*t**2 - 0.67*t

Skipping due to diff= 0.20560670831963118


0.0133333333333333*t**3 + 0.04*t**2 + 0.04*t

0.01*t**3 + 0.04*t**2 + 0.04*t

Skipping due to diff= 0.11694581989683726


0.00680272108843537*t**3 - 0.0612244897959184*t**2 + 0.183673469387755*t

0.01*t**3 - 0.06*t**2 + 0.18*t

4600 processed
Skipping due to diff= 0.06397783861644858


0.00680272108843537*t**3 + 0.102040816326531*t**2 + 0.510204081632653*t

0.01*t**3 + 0.1*t**2 + 0.51*t

4700 processed
Skipping due to diff= 0.013195038981960427


0.375*t**2 + 0.25*t

0.38*t**2 + 0.25*t

Skipping due to diff= 0.019345007889606972


0.0133333333333333*t**3 - 0.28*t**2 + 1.96*t

0.01*t**3 - 0.28*t**2 + 1.96*t

Skipping due to diff= 0.013203960396039544


-0.125*t**2 + 1.25*t

-0.12*t**2 + 1.25*t

4800 processed
Skipping due to diff= 0.02258807588075889


-0.125*t**2 - 0.75*t

-0.12*t**2 - 0.75*t

4900 processed
Skipping due to diff= 0.06249999999999938


0.0533333333333333*t**3

0.05*t**3

Skipping due to diff= 0.01879939700900003


-0.166666666666667*t**2 + 0.333333333333333*t

-0.17*t**2 + 0.33*t

5000 processed
5100 processed
5200 processed
5300 processed
5400 processed
Skipping due to diff= 0.010876757110165856


-0.214285714285714*t**2 + 1.28571428571429*t

-0.21*t**2 + 1.29*t

Skipping due to diff= 0.031131746664010477


0.0533333333333333*t**3 - 0.32*t**2 + 0.64*t

0.05*t**3 - 0.32*t**2 + 0.64*t

5500 processed
Skipping due to diff= 0.01960784313725307


0.166666666666667*t**3

0.17*t**3

5600 processed
5700 processed
Skipping due to diff= 0.010513296227581861


0.108843537414966*t**3 - 0.326530612244898*t**2 + 0.326530612244898*t

0.11*t**3 - 0.33*t**2 + 0.33*t

5800 processed
5900 processed
Skipping due to diff= 0.016182360515297735


0.166666666666667*t**3 + 0.5*t**2 + 0.5*t

0.17*t**3 + 0.5*t**2 + 0.5*t

6000 processed
6100 processed
Skipping due to diff= 0.013104521728193192


-0.375*t**2 + 0.5*t

-0.37*t**2 + 0.5*t

6200 processed
6300 processed
Skipping due to diff= 0.015577604624074301


0.166666666666667*t**3 - 0.5*t**2 + 0.5*t

0.17*t**3 - 0.5*t**2 + 0.5*t

Skipping due to diff= 0.0392189154217153


-0.125*t**2 - 0.25*t

-0.12*t**2 - 0.25*t

Skipping due to diff= 0.016195578065647923


-0.142857142857143*t**2 - 0.571428571428571*t

-0.14*t**2 - 0.57*t

6400 processed
Skipping due to diff= 0.011721738445614574


0.375*t**2 + 1.25*t

0.38*t**2 + 1.25*t

Skipping due to diff= 1.0


0

13.44*t**3 - 44.0*t**2 + 48.0*t

6500 processed
Skipping due to diff= 0.012772478259203883


0.375*t**2 - 0.75*t

0.37*t**2 - 0.75*t

6600 processed
Skipping due to diff= 0.011194099092742552


0.213333333333333*t**3 + 4.8*t**2 + 36.0*t

0.21*t**3 + 4.76*t**2 + 36.47*t + 0.01

6700 processed
Skipping due to diff= 0.010874390623379188


0.0833333333333333*t**3 + 0.875*t**2 + 3.0625*t

0.08*t**3 + 0.87*t**2 + 3.06*t

6800 processed
6900 processed
Skipping due to diff= 0.019343212473703325


0.0533333333333333*t**3 - 0.48*t**2 + 1.44*t

0.05*t**3 - 0.48*t**2 + 1.44*t

Skipping due to diff= 0.020000000000000337


0.0612244897959184*t**3 + 0.183673469387755*t**2 + 0.183673469387755*t

0.06*t**3 + 0.18*t**2 + 0.18*t

7000 processed
Skipping due to diff= 0.010149051198805631


0.1875*t**3 + 0.5625*t**2 + 0.5625*t

0.19*t**3 + 0.56*t**2 + 0.56*t

7100 processed
7200 processed
Skipping due to diff= 0.016753632244961448


0.0833333333333333*t**3 - 0.625*t**2 + 1.5625*t

0.08*t**3 - 0.62*t**2 + 1.56*t

7300 processed
Skipping due to diff= 0.01561239421573437


-0.166666666666667*t**2 + 0.666666666666667*t

-0.17*t**2 + 0.67*t

7400 processed
7500 processed
7600 processed
7700 processed
Skipping due to diff= 0.0773728498948133


0.0133333333333333*t**3 - 0.12*t**2 + 0.36*t

0.01*t**3 - 0.12*t**2 + 0.36*t

Skipping due to diff= 0.05514758900121737


0.0272108843537415*t**3 - 0.122448979591837*t**2 + 0.183673469387755*t

0.03*t**3 - 0.12*t**2 + 0.18*t

7800 processed
Skipping due to diff= 0.016182360515297735


0.166666666666667*t**3 + 0.5*t**2 + 0.5*t

0.17*t**3 + 0.5*t**2 + 0.5*t

7900 processed
Skipping due to diff= 0.012914151492990607


-0.375*t**2 - 0.75*t

-0.38*t**2 - 0.75*t

8000 processed
Skipping due to diff= 0.01879939700900003


0.166666666666667*t**2 - 0.333333333333333*t

0.17*t**2 - 0.33*t

8100 processed
8200 processed
Skipping due to diff= 0.24999999999999814


0.0133333333333333*t**3

0.01*t**3

8300 processed
8400 processed
Skipping due to diff= 0.039999999999999626


0.0833333333333333*t**3

0.08*t**3

8500 processed
Skipping due to diff= 0.03430914034720834


0.0208333333333333*t**3 + 0.125*t**2 + 0.25*t

0.02*t**3 + 0.13*t**2 + 0.25*t

8600 processed
8700 processed
8800 processed
Skipping due to diff= 0.010053873864444594


0.108843537414966*t**3 - 0.244897959183673*t**2 + 0.183673469387755*t

0.11*t**3 - 0.24*t**2 + 0.18*t

Skipping due to diff= 0.03164821259095167


0.0833333333333333*t**3 - 0.25*t**2 + 0.25*t

0.08*t**3 - 0.25*t**2 + 0.25*t

Skipping due to diff= 0.013563747611938794


0.285714285714286*t**2 + 0.857142857142857*t

0.29*t**2 + 0.86*t

8900 processed


In [12]:
import pickle

# fout = open('/tmp/test_data_series.pkl', 'wb')
# pickle.dump(data_series, fout)
# fout.close()

fin = open('/tmp/test_data_series.pkl', 'rb')
data_series = pickle.load(fin)
fin.close()

# Check if the results are as expected
n = len(data_series[1][1])
for y_idx in range(3, 4):
    xs = np.asarray(data_series[y_idx][0])
    x1s = xs[:,(y_idx-1)*2+1]
    ys = data_series[y_idx][1]
    for i in range(n):
        if abs(ys[i] - x1s[i]/y_idx) > 0.02 and abs(ys[i]/(x1s[i]/y_idx)-1) > 0.02:
            print(f"y_idx={y_idx}, i={i}, x={x1s[i]}, y={ys[i]}")
            display(originals[i])
            display(integrals[i])

y_idx=3, i=10, x=0.3333333333333333, y=0.08


-2.57142857142857*t - 5.28571428571429

-1.29*t**2 - 5.29*t

y_idx=3, i=62, x=0.3333333333333333, y=0.0


65.3333333333333*t**2 - 121.333333333333*t + 56.3333333333333

21.78*t**3 - 60.67*t**2 + 56.33*t

y_idx=3, i=402, x=0.4444444444444444, y=0.17


75.0*t**2 - 120.0*t + 48.0

25.0*t**3 - 60.0*t**2 + 48.0*t

y_idx=3, i=513, x=0.0625, y=0.0


32.1111111111111*t**2 - 98.2222222222222*t + 75.1111111111111

10.7*t**3 - 49.11*t**2 + 75.11*t

y_idx=3, i=679, x=0.16, y=0.0


-3.42857142857143*t - 0.142857142857143

-1.71*t**2 - 0.14*t

y_idx=3, i=854, x=0.16, y=0.03


2.6*t - 3.0

1.3*t**2 - 3.0*t

y_idx=3, i=942, x=0.5, y=0.05


2.56*t**2 + 12.16*t + 14.44

0.85*t**3 + 6.08*t**2 + 14.44*t

y_idx=3, i=1028, x=0.7346938775510204, y=0.0


7.8*t + 5.2

3.9*t**2 + 5.2*t

y_idx=3, i=1103, x=0.1111111111111111, y=0.0


7.2*t + 6.2

3.6*t**2 + 6.2*t

y_idx=3, i=1121, x=0.16, y=0.01


8.8*t - 6.6

4.4*t**2 - 6.6*t

y_idx=3, i=1138, x=0.0625, y=0.0


72.25*t**2 - 17.0*t + 1.0

24.08*t**3 - 8.5*t**2 + 1.0*t

y_idx=3, i=1165, x=0.1111111111111111, y=0.0


7.8*t - 15.0

3.9*t**2 - 15.0*t

y_idx=3, i=1205, x=0.25, y=0.0


-3.4*t - 8.8

-1.7*t**2 - 8.8*t

y_idx=3, i=1215, x=0.08163265306122448, y=0.0


-8.6*t - 4.2

-4.3*t**2 - 4.2*t

y_idx=3, i=1230, x=0.5, y=0.0


6.85714285714286*t - 1.57142857142857

3.43*t**2 - 1.57*t

y_idx=3, i=1237, x=1.3333333333333333, y=0.28


-2.57142857142857*t - 6.14285714285714

-1.29*t**2 - 6.14*t

y_idx=3, i=1243, x=0.3333333333333333, y=0.0


0.510204081632653*t**2 + 3.26530612244898*t + 5.22448979591837

0.17*t**3 + 1.63*t**2 + 5.22*t

y_idx=3, i=1402, x=0.16, y=0.01


1.96*t**2 - 12.88*t + 21.16

0.65*t**3 - 6.44*t**2 + 21.16*t

y_idx=3, i=1572, x=0.64, y=0.0


9.0*t**2 - 43.7142857142857*t + 53.0816326530612

3.0*t**3 - 21.86*t**2 + 53.08*t

y_idx=3, i=1647, x=0.08163265306122448, y=0.0


8.33333333333333*t + 3.0

4.17*t**2 + 3.0*t

y_idx=3, i=1856, x=0.16, y=0.02


1.0 - 4.4*t

-2.2*t**2 + 1.0*t

y_idx=3, i=1944, x=0.16, y=0.01


39.5102040816327*t**2 + 8.97959183673469*t + 0.510204081632653

13.17*t**3 + 4.49*t**2 + 0.51*t

y_idx=3, i=1949, x=0.0625, y=0.0


23.04*t**2 - 28.8*t + 9.0

7.68*t**3 - 14.4*t**2 + 9.0*t

y_idx=3, i=2173, x=0.1111111111111111, y=0.0


0.111111111111111*t**2 + 3.11111111111111*t + 21.7777777777778

0.03*t**3 + 1.54*t**2 + 21.88*t

y_idx=3, i=2282, x=0.3333333333333333, y=0.0


17.1632653061224*t**2 + 16.5714285714286*t + 4.0

5.72*t**3 + 8.29*t**2 + 4.0*t

y_idx=3, i=2471, x=0.4444444444444444, y=0.0


12.0*t**2 + 24.0*t + 12.0

4.0*t**3 + 12.0*t**2 + 12.0*t

y_idx=3, i=2629, x=0.1836734693877551, y=0.0


8.33333333333333 - t

-0.5*t**2 + 8.33*t

y_idx=3, i=2699, x=0.0625, y=0.0


6.25*t**2 - 25.0*t + 25.0

2.08*t**3 - 12.5*t**2 + 25.0*t

y_idx=3, i=2717, x=0.1836734693877551, y=0.0


-t - 6.5

-0.5*t**2 - 6.5*t

y_idx=3, i=2958, x=0.1111111111111111, y=0.01


3.0 - 6.5*t

-3.25*t**2 + 3.0*t

y_idx=3, i=3139, x=0.16, y=0.02


9.0*t**2 - 29.1428571428571*t + 23.5918367346939

3.0*t**3 - 14.57*t**2 + 23.59*t

y_idx=3, i=3231, x=0.25, y=0.04


4.5*t**2 + 63.0*t + 220.5

1.5*t**3 + 31.5*t**2 + 220.5*t

y_idx=3, i=3260, x=0.1111111111111111, y=0.0


4.0*t**2 - 12.0*t + 9.0

1.33*t**3 - 6.0*t**2 + 9.0*t

y_idx=3, i=3363, x=0.08163265306122448, y=0.0


53.7777777777778*t**2 - 4.88888888888889*t + 0.111111111111111

17.93*t**3 - 2.44*t**2 + 0.11*t

y_idx=3, i=3697, x=0.08163265306122448, y=0.0


21.3333333333333*t**2 + 10.6666666666667*t + 1.33333333333333

7.11*t**3 + 5.33*t**2 + 1.33*t

y_idx=3, i=4300, x=0.25, y=0.05


1.57142857142857*t - 5.42857142857143

0.79*t**2 - 5.43*t

y_idx=3, i=4400, x=0.16, y=0.02


10.5625*t**2 - 50.375*t + 60.0625

3.52*t**3 - 25.19*t**2 + 60.06*t

y_idx=3, i=4420, x=0.0625, y=0.0


-3.0*t - 1.8

-1.5*t**2 - 1.8*t

y_idx=3, i=4504, x=0.0625, y=0.0


56.3333333333333*t**2

18.78*t**3

y_idx=3, i=4561, x=0.1111111111111111, y=0.0


1.71428571428571*t + 1.85714285714286

0.86*t**2 + 1.86*t

y_idx=3, i=5176, x=0.32653061224489793, y=0.0


1.0 - 5.42857142857143*t

-2.71*t**2 + 1.0*t

y_idx=3, i=5183, x=0.5102040816326531, y=0.11


3.0625*t**2 - 21.0*t + 36.0

1.02*t**3 - 10.5*t**2 + 36.0*t

y_idx=3, i=5594, x=0.16, y=0.0


8.0*t**2 + 84.0*t + 220.5

2.67*t**3 + 42.0*t**2 + 220.5*t

y_idx=3, i=5840, x=0.08163265306122448, y=0.0


43.56*t**2 - 105.6*t + 64.0

14.52*t**3 - 52.8*t**2 + 64.0*t

y_idx=3, i=5916, x=0.5625, y=0.0


3.8*t + 5.6

1.9*t**2 + 5.6*t

y_idx=3, i=6036, x=0.5102040816326531, y=0.0


13.4444444444444*t**2 + 26.8888888888889*t + 13.4444444444444

4.48*t**3 + 13.44*t**2 + 13.44*t

y_idx=3, i=6248, x=0.25, y=0.0


7.28571428571429*t + 0.142857142857143

3.64*t**2 + 0.14*t

y_idx=3, i=6472, x=0.16, y=0.03


0.333333333333333*t**2 + 8.66666666666667*t + 56.3333333333333

0.11*t**3 + 4.33*t**2 + 56.33*t

y_idx=3, i=6536, x=0.08163265306122448, y=0.0


8.4*t + 2.8

4.2*t**2 + 2.8*t

y_idx=3, i=6571, x=0.36, y=0.0


4.0 - 4.42857142857143*t

-2.21*t**2 + 4.0*t

y_idx=3, i=7230, x=0.5, y=0.0


33.0625*t**2 + 34.5*t + 9.0

11.02*t**3 + 17.25*t**2 + 9.0*t

y_idx=3, i=7425, x=0.7346938775510204, y=0.06


56.25*t**2 - 225.0*t + 225.0

18.75*t**3 - 112.5*t**2 + 225.0*t

y_idx=3, i=7936, x=1.3061224489795917, y=0.33


29.16*t**2 - 30.24*t + 7.84

9.72*t**3 - 15.12*t**2 + 7.84*t

y_idx=3, i=8023, x=0.1836734693877551, y=0.04


t**2 - 12.6666666666667*t + 40.1111111111111

0.33*t**3 - 6.33*t**2 + 40.11*t

y_idx=3, i=8086, x=0.1111111111111111, y=0.0


5.5*t - 4.25

2.75*t**2 - 4.25*t

y_idx=3, i=8311, x=0.0625, y=0.0


72.25*t**2 + 34.0*t + 4.0

24.08*t**3 + 17.0*t**2 + 4.0*t

y_idx=3, i=8330, x=0.16, y=0.0


-8.8*t - 1.6

-4.4*t**2 - 1.6*t

In [18]:
y_idx=1

n = len(data_series[1][1])

est_simple = symbolicregression2.model.SymbolicTransformerRegressor(
                        model=sr_model,
                        max_input_points=10001,
                        n_trees_to_refine=100,
                        rescale=True
                        )


xs = np.asarray(data_series[y_idx][0])
x1s = np.reshape(xs[:,y_idx], (n,1))
ys = np.reshape(data_series[y_idx][1], (n,1))
ys = ys + 2*x1s*x1s + 3.5

print(xs[0:10, y_idx])
print(ys[0:10])

est_simple.fit(x1s,ys)

replace_ops = {"add": "+", "mul": "*", "sub": "-", "pow": "**", "inv": "1/"}
model_str = est_simple.retrieve_tree(with_infos=True)["relabed_predicted_tree"].infix()
for op,replace_op in replace_ops.items():
    model_str = model_str.replace(op,replace_op)
    
print(model_str)

raw_expr = sp.parse_expr(model_str)
sp.simplify(sp.expand(round_expr(raw_expr)))

[16.      1.      5.75   45.5625 64.      2.25   -2.5    -1.25   -5.
 -3.25  ]
[[5.31610000e+02]
 [6.50000000e+00]
 [7.53750000e+01]
 [4.20094281e+03]
 [8.25950000e+03]
 [1.58750000e+01]
 [1.35000000e+01]
 [5.37500000e+00]
 [4.85000000e+01]
 [2.13750000e+01]]


  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')


(((1.554575343737481 + (5.640700005269877 * (-0.4280303776973979 + (0.019889758680590723 * x_0)))) + (1.696035343737901 + (62.61638890529857 * (((0.06962404147058532 + (0.16090522472991203 * (-0.4280303776973979 + (0.019889758680590723 * x_0)))) * (-55.84349567032667 + (2.9520097485743936e-05 * (-0.4280303776973979 + (0.019889758680590723 * x_0))))))**2))) - (-1.4310353437381038 + (2.6206999947306824 * (-0.4280303776973979 + (0.019889758680590723 * x_0)))))


  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')


2.0*x_0**2 + 0.97998046875*x_0 + 3.4959716796875

In [19]:
y_idx=2
n = len(data_series[1][1])

xs = np.asarray(data_series[y_idx][0])
x3s = np.reshape(xs[:,2*y_idx+1], (n,1))
ys = np.reshape(data_series[y_idx][1], (n,1))
ys = ys + 2*x3s*x3s + 3.5

# print(xs[0:10, 1])
# print(ys[0:10])

est_simple.fit(x3s,ys)

replace_ops = {"add": "+", "mul": "*", "sub": "-", "pow": "**", "inv": "1/"}
model_str = est_simple.retrieve_tree(with_infos=True)["relabed_predicted_tree"].infix()
for op,replace_op in replace_ops.items():
    model_str = model_str.replace(op,replace_op)
    
print(model_str)

raw_expr = sp.parse_expr(model_str)
sp.simplify(sp.expand(round_expr(raw_expr)))

  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')


(3.3392761051514017 + (0.21293595651491945 * ((((-1.301679949865295 * (-0.6085491829353974 + (0.04605635689950794 * x_0))) - (0.8040539237955435 + (0.5666800290262324 * (-0.6085491829353974 + (0.04605635689950794 * x_0))))) + (((2.1713199788039668 * (-0.6085491829353974 + (0.04605635689950794 * x_0))) - (-43.995946222924566 + (8.948680207527886 * (-0.6085491829353974 + (0.04605635689950794 * x_0))))) + ((0.28953627372418606 + (76.37131988767095 * (-0.6085491829353974 + (0.04605635689950794 * x_0)))) + (1.3659462479783977 + (0.32960227394313424 * ((10.33464441154438 + (-0.03083960010993443 * ((-4.059956923118578 + (-0.001071832130220887 * ((2.4162290705595773 + (-4.504784126501213 * ((-3.316814557241166 + (-0.33719902144061215 * ((2.412266701884242 + (0.006376601025858294 * (-0.6085491829353974 + (0.04605635689950794 * x_0)))))**2)))**2)))**2)))**2)))**3))))))**2))


  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')
  alpha1 = min(1.0, 1.01*2*(phi0 - old_phi0)/derphi0)
  alpha1 = min(1.0, 1.01*2*(phi0 - old_phi0)/derphi0)


2.8222274780273438*x_0**2 - 1.9871347770094872*x_0 + 3.689521593041718

In [23]:
import scipy

y_idx=1
ys = np.asarray(data_series[y_idx][1])
xs = np.asarray(data_series[y_idx][0])[:,1]

print(scipy.stats.linregress(xs, ys))

diff = (ys-xs).flatten()
for i in range(len(diff)):
    if abs(diff[i]) > 0.02:
        print(i, diff[i])

LinregressResult(slope=1.0002448476609136, intercept=-0.0014890801122113828, rvalue=0.9999992237781083, pvalue=0.0, stderr=1.3262767772178022e-05, intercept_stderr=0.0007253298787952417)
0 0.10999999999999943
10 0.7066666666667061
20 0.05408163265309973
31 -0.14333333333332998
46 -0.06000000000000005
62 1.7966666666669937
72 0.03999999999999915
183 0.020000000000010232
185 0.04750000000000032
186 0.02999999999999936
203 -0.04306122448979721
210 0.06265306122449843
286 -0.07428571428571029
295 0.05918367346940201
402 -0.38444444444439796
411 0.08750000000000036
431 -0.060000000000002274
513 0.3099999999999987
517 0.03666666666670082
590 0.1038775510203962
612 0.10999999999999943
621 -0.0611111111111029
642 0.03999999999999915
679 0.5999999999999943
680 -0.028979591836729934
687 -0.05999999999999961
724 -0.03249999999999886
757 -0.020546769020477473
854 0.1699999999999946
861 0.040000000000000036
873 -0.029999999999999805
888 0.0799999999999983
921 -0.03999999999999915
942 0.75
945 -0.02

In [28]:
originals[4]

72.25*t**2 + 106.25*t + 39.0625