In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from flash_ansr import FlashANSR, get_path, FlashANSRTransformer, ExpressionSpace, GenerationConfig
from flash_ansr.refine import ConvergenceError
from flash_ansr.expressions.utils import codify, num_to_constants

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = 'cpu'
print(f"Using device: {device}")

Using device: cuda


In [2]:
MODEL = 'v7.0'
CHECKPOINT = ''

In [3]:
nsr = FlashANSR(
    expression_space=ExpressionSpace.from_config(get_path('models', 'ansr-models', MODEL, CHECKPOINT, 'expression_space.yaml')),
    flash_ansr_transformer=FlashANSRTransformer.load(get_path('models', 'ansr-models', MODEL, CHECKPOINT))[1].to(device).eval(),
    generation_config=GenerationConfig(method='softmax_sampling'),
    n_restarts=4,
    verbose=True,
).to(device)

print(f'{nsr.flash_ansr_transformer.n_params:,} parameters')

27,137,058 parameters


In [4]:
demo_expression = [
    ('x1**2 + 2*x1 + 1', (2, 1), (1, 5)),
    ('-x + log(x + x**4)', None, (1, 5)),
    ('0.1 * ((1.1 / x)**(12) - (1.2 / x)**6)', (0.1, 1.1, 1.2), (0.8, 2.5)),
    ('5.3 / (1.0 + exp(0.72 * (x - 2.85)))', (5.3, 1, 0.72, 2.85), (-10, 10)),
    ('5.3 / (1.0 + exp(0.72 * (x - 2.85))) + sin(1.5 * x)', (5.3, 1, 0.72, 2.85, 1.5), (-10, 10)),
][0]

In [5]:
expression, constants, xlim = demo_expression

In [6]:
prefix_expression = nsr.expression_space.parse_expression(expression, mask_numbers=True)
prefix_expression_w_num = nsr.expression_space.operators_to_realizations(prefix_expression)
prefix_expression_w_constants, constants_names = num_to_constants(prefix_expression_w_num)
code_string = nsr.expression_space.prefix_to_infix(prefix_expression_w_constants, realization=True)
code = codify(code_string, nsr.expression_space.variables + constants_names)

if constants is None:
    demo_function = lambda x: nsr.expression_space.code_to_lambda(code)(x, 0, 0)
else:
    demo_function = lambda x: nsr.expression_space.code_to_lambda(code)(x, 0, 0, *constants)

In [7]:
x = np.random.uniform(*xlim, 100)
y = demo_function(x)
if isinstance(y, float):
    y = np.full_like(x, y)

x_tensor = torch.tensor(x, dtype=torch.float32).unsqueeze(-1).to(device)
y_tensor = torch.tensor(y, dtype=torch.float32).unsqueeze(-1).to(device)

# Pad the x_tensor with zeros to match the expected maximum input dimension of the set transformer
pad_length = nsr.flash_ansr_transformer.encoder_max_n_variables - x_tensor.shape[-1] - y_tensor.shape[-1]

if pad_length > 0:
    x_tensor = nn.functional.pad(x_tensor, (0, pad_length, 0, 0), value=0)

data_tensor = torch.cat([x_tensor, y_tensor], dim=-1)
print(data_tensor.shape)

torch.Size([100, 4])


In [8]:
beams, scores, is_valid = nsr.generate(data=data_tensor, verbose=True)

Generating 32 sequences (max length: 32): 100%|██████████| 32/32 [00:00<00:00, 68.26it/s]


In [9]:
print(len(beams))

19


In [10]:
for beam, score, valid in zip(beams, scores, is_valid):
    print(beam)
    print(nsr.expression_space.tokenizer.decode(beam, special_tokens='<num>'))
    print(score)
    print(valid)
    print()
    break

[1, 14, 7, 6, 30, 2]
['pow2', '+', '<num>', 'x1']
-1.2111598998287718
True



In [11]:
nsr.fit(x_tensor, y_tensor)

In [12]:
nsr.results

Unnamed: 0,log_prob,fvu,score,expression,complexity,target_complexity,numeric_prediction,raw_beam,beam,raw_beam_decoded,function,refiner,beam_id,fit_constants,fit_covariances,fit_loss
0,-1.211160,1.134832e-14,0.0004,"[pow2, +, <num>, x1]",4,,,"[1, 14, 7, 6, 30, 2]","[14, 7, 6, 30]","[pow2, +, <num>, x1]",<function <lambda> at 0x7f2f00c0d260>,"Refiner(expression=['pow2', '+', '<num>', 'x1'...",0,[0.999999981338544],[[1.4164727827735667e-16]],9.658664e-13
1,-1.211160,1.134832e-14,0.0004,"[pow2, +, <num>, x1]",4,,,"[1, 14, 7, 6, 30, 2]","[14, 7, 6, 30]","[pow2, +, <num>, x1]",<function <lambda> at 0x7f2f00c0d260>,"Refiner(expression=['pow2', '+', '<num>', 'x1'...",0,[0.9999999813370153],[[1.4164729321730692e-16]],9.658664e-13
2,-1.211160,1.134832e-14,0.0004,"[pow2, +, <num>, x1]",4,,,"[1, 14, 7, 6, 30, 2]","[14, 7, 6, 30]","[pow2, +, <num>, x1]",<function <lambda> at 0x7f2f00c0d260>,"Refiner(expression=['pow2', '+', '<num>', 'x1'...",0,[0.9999999813370158],[[1.4164727976504408e-16]],9.658664e-13
3,-1.211160,1.134832e-14,0.0004,"[pow2, +, <num>, x1]",4,,,"[1, 14, 7, 6, 30, 2]","[14, 7, 6, 30]","[pow2, +, <num>, x1]",<function <lambda> at 0x7f2f00c0d260>,"Refiner(expression=['pow2', '+', '<num>', 'x1'...",0,[0.9999999813370158],[[1.416472906535827e-16]],9.658664e-13
4,-2.377273,1.134832e-14,0.0004,"[pow2, -, <num>, x1]",4,,,"[1, 14, 8, 6, 30, 2]","[14, 8, 6, 30]","[pow2, -, <num>, x1]",<function <lambda> at 0x7f2f00c0f380>,"Refiner(expression=['pow2', '-', '<num>', 'x1'...",1,[-0.9999999813370151],[[1.4164727822310249e-16]],9.658664e-13
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
60,-17.038419,inf,inf,"[+, *, x1, +, <num>, x1, abs, sin, pow1_3, +, ...",17,,,"[1, 7, 10, 30, 7, 6, 30, 12, 22, 19, 7, 14, 30...","[7, 10, 30, 7, 6, 30, 12, 22, 19, 7, 14, 30, 2...","[+, *, x1, +, <num>, x1, abs, sin, pow1_3, +, ...",<function <lambda> at 0x7f2f00cb2160>,"Refiner(expression=['+', '*', 'x1', '+', '<num...",15,[2.1017350632001635],[[9.659248864034937e-05]],
61,-15.624030,,,"[pow3, pow1_2, +, <num>, +, <num>, -, x1, *, <...",11,,,"[1, 15, 18, 7, 6, 7, 6, 8, 30, 10, 6, 30, 2]","[15, 18, 7, 6, 7, 6, 8, 30, 10, 6, 30]","[pow3, pow1_2, +, <num>, +, <num>, -, x1, *, <...",<function <lambda> at 0x7f2f00cb28e0>,"Refiner(expression=['pow3', 'pow1_2', '+', '<n...",16,"[-1.4374481021564485, -1.500309383001398, 1.26...","[[inf, inf, inf], [inf, inf, inf], [inf, inf, ...",
62,-15.624030,,,"[pow3, pow1_2, +, <num>, +, <num>, -, x1, *, <...",11,,,"[1, 15, 18, 7, 6, 7, 6, 8, 30, 10, 6, 30, 2]","[15, 18, 7, 6, 7, 6, 8, 30, 10, 6, 30]","[pow3, pow1_2, +, <num>, +, <num>, -, x1, *, <...",<function <lambda> at 0x7f2f00cb28e0>,"Refiner(expression=['pow3', 'pow1_2', '+', '<n...",16,"[295.27240052075945, -295.2096405879673, -1.13...","[[16981078559566.346, -16981078549761.629, -22...",1.534844e-01
63,-15.624030,,,"[pow3, pow1_2, +, <num>, +, <num>, -, x1, *, <...",11,,,"[1, 15, 18, 7, 6, 7, 6, 8, 30, 10, 6, 30, 2]","[15, 18, 7, 6, 7, 6, 8, 30, 10, 6, 30]","[pow3, pow1_2, +, <num>, +, <num>, -, x1, *, <...",<function <lambda> at 0x7f2f00cb28e0>,"Refiner(expression=['pow3', 'pow1_2', '+', '<n...",16,"[22.72171128753537, -22.658951410326324, -1.13...","[[76759592052722.94, -76759591945368.47, -4594...",1.534844e-01
