In [6]:
import json
import random
import numpy as np
import sympy as sp
from sympy import sympify, Function, dsolve, Eq, Derivative, sin, cos, symbols, Symbol,lambdify
from sympy.abc import x

def load_expressions(filepaths, seed=None):
    """Load sympy expressions from the ``f_t`` / ``g_t`` fields of JSON-lines files.

    Every non-blank line of every file in ``filepaths`` is pooled, shuffled,
    and then parsed as a JSON object. For each object, each entry whose key is
    ``'f_t'`` or ``'g_t'`` contributes ``sympify(value[1])`` to the result.
    NOTE(review): this assumes the value is a sequence whose index 1 holds the
    expression string; confirm against the dataset schema.

    Args:
        filepaths: iterable of paths to JSON-lines files.
        seed: optional seed for the shuffle. ``None`` (the default) keeps the
            previous behaviour of using the module-level ``random`` state; pass
            an int for a reproducible order.

    Returns:
        list of sympy expressions, in shuffled line order.
    """
    lines = []
    for filepath in filepaths:
        # ``with`` closes the handle even if reading raises.
        with open(filepath, 'r') as fin:
            lines.extend(fin.readlines())
    # Skip blank lines (e.g. a trailing newline at EOF); json.loads('') raises.
    lines = [line for line in lines if line.strip()]
    if seed is None:
        random.shuffle(lines)
    else:
        random.Random(seed).shuffle(lines)
    exprs = []
    for line in lines:
        data = json.loads(line)
        for key, value in data.items():
            if key in ('f_t', 'g_t'):
                exprs.append(sympify(value[1]))
    return exprs

# 'datasets/function_evaluation.json' can be appended to this list to pool both datasets.
exprs = load_expressions(
    ['datasets/parametric_equations.json'],
)

In [7]:
# Scratch check: how sympify parses and renders a quadratic in t with a float constant.
f = sp.sympify('t**2/3+2*t+3.3333')
f

t**2/3 + 2*t + 3.3333

In [2]:
from sympy.calculus.util import continuous_domain, function_range
from sympy import Interval, S

# Default sampling window and grid spacing for the parameter t.
DEFAULT_LEFT = -4.0
DEFAULT_RIGHT = 4.0
DEFAULT_STEP = 0.001

def get_data_points(f, left=DEFAULT_LEFT, right=DEFAULT_RIGHT, step=DEFAULT_STEP):
    """Sample ``f(t)`` on a regular grid where ``f`` is continuous inside ``[left, right]``.

    The continuous domain of ``f`` in ``t`` over the reals is intersected with
    ``[left, right]``. ``t`` is sampled from the infimum up to (excluding) the
    supremum of that set with spacing ``step``, and ``f`` is evaluated there
    through a numpy ``lambdify``.

    Args:
        f: sympy expression in the symbol ``t``.
        left, right: bounds of the sampling window (defaults -4.0 and 4.0).
        step: grid spacing (default 0.001).

    Returns:
        (ts, ys): 1-D numpy arrays of equal length.

    Raises:
        ValueError: if ``f`` has no continuous domain inside ``[left, right]``.
    """
    t = Symbol('t')
    fl = lambdify(t, f, 'numpy')
    dom = continuous_domain(f, t, S.Reals)
    intv = dom.intersect(Interval(left, right))
    if intv.is_empty:
        raise ValueError("{} has no continuous domain in [{}, {}]".format(f, left, right))
    # inf/sup exist on every non-empty set, including a Union (e.g. the domain
    # of 1/t). .left/.right exist only on Interval.
    ts = np.arange(float(intv.inf), float(intv.sup), step)
    # An open lower bound (e.g. log(t) -> (0, 4]) is not itself in the domain.
    if ts.size and float(ts[0]) not in intv:
        ts = ts[1:]
    # lambdify returns a plain scalar for a constant expression; broadcast it
    # so ys always has one value per sample in ts.
    ys = np.array(np.broadcast_to(fl(ts), ts.shape))
    return ts, ys

In [3]:
import torch
import os, sys
import symbolicregression
from symbolicregression import model as symbolicregression_model
import sympytorch
import requests

model_path = "ckpt/model.pt"
try:
    if not os.path.isfile(model_path):
        print("Downloading model...")
        url = "https://dl.fbaipublicfiles.com/symbolicregression/model1.pt"
        # Bounded socket timeout so a stalled connection cannot hang the kernel.
        r = requests.get(url, allow_redirects=True, timeout=60)
        # Never cache an HTTP error page as the checkpoint. The isfile() guard
        # above would skip the download on every later run.
        r.raise_for_status()
        os.makedirs(os.path.dirname(model_path) or ".", exist_ok=True)
        # Write to a temp name and rename, so an interrupted write does not leave
        # a truncated file at model_path.
        tmp_path = model_path + ".part"
        with open(tmp_path, 'wb') as fh:
            fh.write(r.content)
        os.replace(tmp_path, model_path)
    # SECURITY NOTE: torch.load unpickles the checkpoint, which can execute
    # arbitrary code. Only load files from a trusted source.
    if not torch.cuda.is_available():
        model = torch.load(model_path, map_location=torch.device('cpu'))
    else:
        model = torch.load(model_path)
        model = model.cuda()
    print(model.device)
    print("Model successfully loaded!")

except Exception as e:
    print("ERROR: model not loaded! path was: {}".format(model_path))
    print(e)
    # Without `model` the estimator below cannot be built. Re-raise the real
    # cause instead of failing a few lines later with a NameError.
    raise

# Estimator wrapping the pretrained model; fit() and retrieve_tree() are used below.
# get_data_points' default grid ([-4, 4) with step 0.001) yields about 8000 samples,
# which stays under max_input_points.
est = symbolicregression_model.SymbolicTransformerRegressor(
                        model=model,
                        max_input_points=10001,
                        n_trees_to_refine=5,
                        rescale=True
                        )

cuda:0
Model successfully loaded!


In [4]:
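# NOTE(review): this SIGALRM-based timeout decorator is kept for reference only.
# The next cell's `@timeout(15)` is expected to come from utils.utils; confirm that
# module holds the live copy. SIGALRM works only on Unix and in the main thread.
# import errno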
# import os
# import signal
# import functools

# class TimeoutError(Exception):
#     pass

# def timeout(seconds=15, error_message=os.strerror(errno.ETIME)):
#     def decorator(func):
#         def _handle_timeout(signum, frame):
#             raise TimeoutError(error_message)

#         @functools.wraps(func)
#         def wrapper(*args, **kwargs):
#             signal.signal(signal.SIGALRM, _handle_timeout)
#             signal.alarm(seconds)
#             try:
#                 result = func(*args, **kwargs)
#             finally:
#                 signal.alarm(0)
#             return result

#         return wrapper

#     return decorator

In [5]:
from utils.utils import timeout
from sympy.core.rules import Transform

@timeout(15)
def symbolicregression(f, verbose=True, num_digits=2):
    """Fit the pretrained symbolic regressor to samples of ``f(t)``.

    NOTE: this function's name shadows the ``symbolicregression`` package
    imported earlier in the notebook. Re-running that import cell rebinds the
    name to the module.

    Args:
        f: sympy expression in ``t`` to sample (via ``get_data_points``).
        verbose: when True, display the input and the rounded result.
        num_digits: decimal places kept for Float coefficients in the rounded result.

    Returns:
        (expr, rounded_expr): ``expr`` is the expanded regressed expression, still
        in the model's variable ``x_0``. ``rounded_expr`` has its Floats rounded
        to ``num_digits`` places and ``x_0`` substituted by ``t``.

    Raises:
        ValueError: if ``f`` yields no finite samples to fit.
    """
    if verbose:
        display(f)
    ts, ys = get_data_points(f)
    if np.isnan(ys).any():
        print("Has NaN")
    # Non-finite samples (NaN/inf) would poison the fit; keep only finite points.
    finite = np.isfinite(ys)
    ts, ys = ts[finite], ys[finite]
    if len(ts) == 0:
        raise ValueError("no finite samples to fit for {}".format(f))
    # The estimator is fed column vectors: one feature (t) and one target per row.
    ts = np.reshape(ts, (len(ts), 1))
    ys = np.reshape(ys, (len(ts), 1))
    est.fit(ts, ys)
    # Map the tree's operator names onto Python/sympy syntax so parse_expr can read it.
    replace_ops = {"add": "+", "mul": "*", "sub": "-", "pow": "**", "inv": "1/"}
    model_str = est.retrieve_tree(with_infos=True)["relabed_predicted_tree"].infix()
    for op, replace_op in replace_ops.items():
        model_str = model_str.replace(op, replace_op)
    x_0, t = symbols(['x_0', 't'])
    raw_expr = sp.parse_expr(model_str)
    expr = sp.expand(raw_expr)
    # Round every Float leaf to num_digits places; other nodes are untouched.
    rounded_expr = expr.xreplace(Transform(lambda num: num.round(num_digits),
                                           lambda node: isinstance(node, sp.Float)))
    rounded_expr = rounded_expr.subs(x_0, t)
    if verbose:
        display(rounded_expr)
    return expr, rounded_expr

SBR_RESULTS_PATH = "datasets/parametric_equations_sbr.json"

# Resume support: skip originals already regressed in a previous run.
# The results file does not exist yet on the very first run.
seen_exprs = set()
if os.path.isfile(SBR_RESULTS_PATH):
    with open(SBR_RESULTS_PATH, 'r') as fin:
        for line in fin:
            if line.strip():
                seen_exprs.add(json.loads(line)['original'])

# Append one JSON line per successfully regressed expression. Flushing after
# each write keeps finished results if the loop is interrupted.
with open(SBR_RESULTS_PATH, 'a') as fout:
    for f in exprs:
        key = str(f)
        if key in seen_exprs:
            continue
        # Mark as seen up front so duplicates in `exprs` are not re-run
        # (each failed attempt can cost the full timeout).
        seen_exprs.add(key)
        try:
            expr, rounded_expr = symbolicregression(f)
        except Exception as e:
            # Timeouts and fit/parse failures are expected for some inputs.
            # Report them and move on, but let KeyboardInterrupt stop the loop.
            print("skipping {}: {!r}".format(key, e))
            continue
        result = {'original': key, 'regressed': str(expr), 'rounded_regressed': str(rounded_expr)}
        fout.write(json.dumps(result) + '\n')
        fout.flush()

13*t/3 + 26/3

  warn_deprecated('grad')


4.33*t + 8.67

8*t**2 - 240*t + 1797

  warn_deprecated('grad')


8.0*t**2 - 240.0*t + 1797.0

(399 - 176*t)**2/2401

  warn_deprecated('grad')


12.9*t**2 - 58.5*t + 66.31

1600*t**2 - 16*t*(-1560 + 20*sqrt(3))/3 - 832*sqrt(3)/3 + 32464/3

  warn_deprecated('grad')


1600.0*t**2 + 8135.25*t + 10340.98

sqrt(2)*(-11*t**2 - 462*t + 4853)/4

  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')


-3.89*t**2 - 163.34*t + 1715.79

-45*t**2 - 450*t - 1134

  warn_deprecated('grad')


-45.0*t**2 - 450.0*t - 1134.0

0.25*(237 - 56*t)**2

  warn_deprecated('grad')


784.0*t**2 - 6636.0*t + 14042.25

392*t**2 - 1680*t + 1801

  warn_deprecated('grad')


392.0*t**2 - 1680.0*t + 1801.0

144*(t - 8)**2

  warn_deprecated('grad')


144.0*t**2 - 2304.0*t + 9216.0

-14.2886297376093*t**2 - 230.816326530612*t - 929.857142857143

  warn_deprecated('grad')


-14.29*t**2 - 230.82*t - 929.86

-98*t**2 - 420*t - 443

  warn_deprecated('grad')


-98.0*t**2 - 420.0*t - 443.0

0.5625*(42 - 5*t)**2

  warn_deprecated('grad')


14.06*t**2 - 236.25*t + 992.25

(64*t**2 - 480*t + 897)**2

  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')


4096.0*t**4 - 61440.0*t**3 + 345216.0*t**2 - 861120.0*t + 804609.01

125*t**2 - 750*t + 1131

  warn_deprecated('grad')


125.0*t**2 - 750.0*t + 1131.0

(45*t**2 + 450*t + 1124)**2

  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')


2025.0*t**4 + 40500.0*t**3 + 303660.0*t**2 + 1011600.0*t + 1263376.0

4*(5*t**2 + 450*t + 10197)**2/729

  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')
  warn_deprecated('grad')


0.14*t**4 + 24.61*t**3 + 1669.85*t**2 + 50352.98*t + 570525.54

0.015625*(133*t + 368)**2

  warn_deprecated('grad')


276.39*t**2 + 1529.5*t + 2116.0

2116*(23*t**2 - 210*t + 476)**2/2401

64*t/3 + 112

(32*t**2 + 240*t + 447)**2


KeyboardInterrupt

