# Example: loading the model and predicting integration methods for a given integrand

In [1]:
import pandas as pd
import random
import torch
import torch.nn as nn

from src.utils.config import load_config
from src.utils.io import load_vocab, load_precomputed_positions
from src.utils.tree_utils import get_prefix_data_with_paths, path_to_index

from src.models.tree_transformer import TreeTransformer

from scripts.inference import maybe_strip_module_prefix
from src.data.dataset import PrefixExpressionDataset


## Load model

In [2]:
# Read the training configuration, then the vocabulary derived from it.
config_path = "configs/train_config.yaml"
cfg = load_config(config_path)

# The vocabulary covers all possible tokens in the dataset.
vocab = load_vocab(cfg)

In [3]:
# Build the (untrained) network; its weights are loaded from a checkpoint afterwards.

# Number of possible methods in the dataset, based on Maple's int methods
# (not including lookup).
num_methods = 12

model = TreeTransformer(
    vocab_size=len(vocab),
    d_model=cfg.model.d_model,
    nhead=cfg.model.heads,
    num_layers=cfg.model.layers,
    dim_feedforward=cfg.model.dim_feedforward,
    num_labels=num_methods,
    # Tree shape parameters, taken straight from the config.
    n=cfg.tree.branching_factor,
    k=cfg.tree.depth,
)

In [4]:
# Read the saved ranking checkpoint from disk.
checkpoint_path = 'models/ranking/ranking_best.pth'

# Force CPU here; training and the inference script normally run on GPU.
device = torch.device("cpu")

# NOTE(security): weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code. Only load checkpoints from a trusted source.
# TODO(review): switch to weights_only=True if the checkpoint holds tensors only.
checkpoint = torch.load(
    checkpoint_path,
    map_location=device,  # remap any GPU-saved tensors onto the CPU
    weights_only=False,
)

In [5]:
# The checkpoint either wraps the weights under "model_state_dict" or is the
# state dict itself; fall back to the whole object in the latter case.
# Checkpoints saved through DataParallel may carry a 'module.' prefix on every
# key; maybe_strip_module_prefix is expected to remove it (helper not shown here).
state_dict = maybe_strip_module_prefix(
    checkpoint.get("model_state_dict", checkpoint)
)
model.load_state_dict(state_dict)

model.to(device)
# Switch to evaluation mode; as the last expression this also displays the model.
model.eval()

TreeTransformer(
  (embedding): Embedding(100, 40)
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-3): 4 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=40, out_features=40, bias=True)
        )
        (linear1): Linear(in_features=40, out_features=128, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=128, out_features=40, bias=True)
        (norm1): LayerNorm((40,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((40,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (classifier): Linear(in_features=40, out_features=12, bias=True)
  (cls_dropout): Dropout(p=0.1, inplace=False)
  (cls_norm): LayerNorm((40,), eps=1e-05, elementwise_affine=True)
)

## Load Dataset

In [6]:
# Each row keeps the integrand as a Maple-readable string, but the model
# consumes the prefix-expression form stored in the `prefix` column.
train_data_path = 'data/processed/train_data.parquet'
df = pd.read_parquet(train_data_path)
df.head()

Unnamed: 0,integrand,prefix,integral,label_original,source,label
0,2*2^(1/2)*x^(1/2)*tan(1)^(1/2),"[[CLS], mul, INT+, 2, mul, pow, INT+, 2, div, ...",4/3*x^(3/2)*2^(1/2)*tan(1)^(1/2),"[72, 72, -1, 115, -1, -1, -1, -1, -1, -1, 72, 72]",elementary,"[0.0, 0.0, -1.0, 1.0, -1.0, -1.0, -1.0, -1.0, ..."
1,2*x^2*arctanh(x)+ln(x),"[[CLS], add, mul, INT+, 2, mul, pow, x, INT+, ...",2/3*x^3*arctanh(x)+1/3+1/3*x^2+x*ln(x)+2/3*ln(...,"[139, -1, 139, 143, -1, -1, 130, -1, -1, -1, -...",elementary,"[0.6923076923076923, -1.0, 0.6923076923076923,..."
2,1/x*sin(cosh(5)),"[[CLS], mul, pow, x, INT-, 1, sin, cosh, INT+,...",sin(cosh(5))*ln(x),"[83, -1, -1, 83, 83, -1, 83, -1, -1, -1, -1, -1]",elementary,"[0.0, -1.0, -1.0, 0.0, 0.0, -1.0, 0.0, -1.0, -..."
3,-3+Pi-x,"[[CLS], add, INT-, CONST1, add, Pi, mul, INT-,...",1/2*x*(-x+2*Pi-6),"[54, -1, 54, 54, 58, -1, 58, -1, -1, -1, 54, 54]",elementary,"[0.0, -1.0, 0.0, 0.0, 1.0, -1.0, 1.0, -1.0, -1..."
4,16*x^2/sin(2)^2,"[[CLS], mul, INT+, CONST2, mul, pow, x, INT+, ...",16/3*x^3/sin(2)^2,"[62, -1, -1, 62, 62, -1, 62, -1, -1, -1, 62, 62]",elementary,"[0.0, -1.0, -1.0, 0.0, 0.0, -1.0, 0.0, -1.0, -..."


In [7]:
# Integration method names, listed in label order: a method's position in this
# tuple is its integer label.
method_names = (
    'default',
    'ddivides',
    'parts',
    'risch',
    'norman',
    'trager',
    'parallelrisch',
    'meijerg',
    'elliptic',
    'pseudoelliptic',
    'gosper',
    'orering',
)

# Dictionary of methods to integer labels
method_dict = {name: label for label, name in enumerate(method_names)}

In [214]:
# Row of the dataset used as the worked example. Change sample_n to try
# different examples.
sample_n = 793966

# Print the Maple string, its prefix-token form, and the per-method label vector.
for column in ('integrand', 'prefix', 'label'):
    print(df[column].iloc[sample_n])

x^2*(3*x^3*cosh(2*x^3)+sinh(2*x^3))
['[CLS]' 'mul' 'pow' 'x' 'INT+' '2' 'add' 'mul' 'INT+' 'CONST1' 'mul'
 'pow' 'x' 'INT+' 'CONST1' 'cosh' 'mul' 'INT+' '2' 'pow' 'x' 'INT+'
 'CONST1' 'sinh' 'mul' 'INT+' '2' 'pow' 'x' 'INT+' 'CONST1']
[-1.         -1.         -1.          0.         -1.         -1.
 -1.          0.37755102 -1.         -1.         -1.          1.        ]


For the given integrand, three methods produce an answer and the rest fail (denoted by -1). Reading the label vector through our method dictionary, the non-negative entries are at index 3 (risch, 0), index 7 (meijerg, ≈0.378) and index 11 (orering, 1). So risch produces the answer with the smallest DAG size, meijerg produces an answer a bit larger than risch, and orering produces the largest answer. Let's take the integrand and integrate it in Maple with each of these methods.

#### risch result

![risch result](notebook_images/f1_risch.png)

#### meijerg result

![meijerg result](notebook_images/f1_meijerg.png)

#### orering result

![orering result](notebook_images/f1_orering.png)

## Prepare Input for model

In [210]:
# Prefix-token sequence of the selected example.
expr = df['prefix'].iloc[sample_n]

# Split into tokens plus, for each token, its path from the root of the tree.
# The exact path encoding is defined by get_prefix_data_with_paths (not shown
# here); the original note describes it as a one-hot encoding.
tokens, path_list = get_prefix_data_with_paths(expr)

# Convert tokens to vocabulary ids; tokens missing from the vocab map to id 1 (unknown).
unknown_id = 1
id_list = [vocab.get(tok, unknown_id) for tok in tokens]
token_ids = torch.tensor(id_list, dtype=torch.long)
token_ids

tensor([ 2,  3,  6,  9, 10,  8,  3,  6, 19,  8, 16,  8,  3,  6,  8,  6,  3,  9,
        17,  7, 16,  8,  9, 10, 19,  8,  6,  4, 17,  6,  8,  6, 18,  8, 16,  8])

In [None]:
# Tree positional encoding: one precomputed row per token, looked up by the
# token's path and truncated to the configured tree depth.
def _position_for(path, table, max_depth):
    """Return the precomputed position entry for `path`, cut off at `max_depth`."""
    depth = min(len(path), max_depth)
    return table[depth][path_to_index(path[:depth])]


k = cfg.tree.depth
positions = load_precomputed_positions(cfg)
pos_tensor = torch.stack([_position_for(path, positions, k) for path in path_list])

# Padding mask: only used during training to denote pad tokens, hence all zeros here.
token_mask = torch.zeros(token_ids.shape, dtype=torch.bool)

pos_tensor

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 1.,  ..., 0., 0., 0.],
        ...,
        [1., 0., 1.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.]])

In [215]:
# Give every input a leading batch dimension of size 1 and move it to `device`.
#
# The batch axis is added only while a tensor is still unbatched, so re-running
# this cell is harmless. A plain `unsqueeze(0)` would stack an extra leading
# dimension on every re-run and break the model call.
def _with_batch_dim(tensor, unbatched_ndim):
    """Return `tensor` with a leading batch axis, adding it only if still missing.

    Args:
        tensor: the tensor to batch.
        unbatched_ndim: number of dimensions `tensor` has before batching.

    Returns:
        `tensor.unsqueeze(0)` when `tensor` has `unbatched_ndim` dimensions,
        otherwise `tensor` unchanged.
    """
    if tensor.dim() == unbatched_ndim:
        return tensor.unsqueeze(0)
    return tensor


token_ids   = _with_batch_dim(token_ids, 1).to(device)     # [1, L], long
pos_tensor  = _with_batch_dim(pos_tensor, 2).to(device)    # [1, L, pos_dim], float (original note calls the last dim d_model; TODO confirm)
token_mask  = _with_batch_dim(token_mask, 1).to(device)    # [1, L], bool

In [None]:
# Run the forward pass without autograd: this is pure inference, so there is no
# need to build a graph (and the displayed tensor no longer carries a grad_fn).
with torch.no_grad():
    # squeeze() drops the size-1 batch dimension, leaving one raw score per method.
    res = model(token_ids, pos_tensor, token_mask).squeeze()  # [num_labels]

# Sigmoid maps the raw scores into (0, 1). Entries follow the method_dict label
# order; a smaller value means a smaller predicted scaled DAG size, so methods
# are tried in ascending order of score.
torch.round(torch.sigmoid(res), decimals=4)

tensor([0.4194, 0.2343, 0.3913, 0.6850, 0.7282, 0.6633, 0.0702, 0.9550, 0.9121,
        0.4445, 0.0958, 0.9919], grad_fn=<RoundBackward1>)

Note that the model is predicting scaled DAG size, so a smaller value is better.

We see from the output that the lowest scores are assigned to methods that actually fail on this integrand (for example parallelrisch at 0.0702 and gosper at 0.0958), so ranking by score alone would try failing methods first. This is where the classifier stage comes in: it filters out the methods that are likely to fail (this will be included in the examples later). Among the methods that do succeed, the model predicts trying risch first (0.6850), then meijerg (0.9550), then orering (0.9919). We therefore correctly predict the method with the smallest DAG size, risch. Although only the first method actually matters, we note that the model also gets the complete ordering of the successful methods right.