# LTL Trace Generation with TensorFlow Transformer and Keras Transformer Trainer


In [None]:
import tensorflow as tf

from ml2.datasets import CSVDataset, SplitDataset
from ml2.ltl import LTLFormula
from ml2.tokenizers import ExprToSeqTokenizer
from ml2.trace import SymbolicTrace, SymTraceToSeqTokenizer
from ml2.train import KerasTransformerTrainer
from ml2.pipelines import TFTransformerPipeline

## Create Pipeline


### Create Input Tokenizer


In [None]:
# Tokenizer that turns LTLFormula objects into token sequences for the encoder.
# pad=128 presumably pads each encoded formula to 128 tokens — keep it in sync
# with max_input_length=128 passed to TFTransformerPipeline.
input_tokenizer = ExprToSeqTokenizer(dtype=LTLFormula, pad=128)

In [None]:
# Load the formulas of the "rft-0/val" split (ltl-strace project) and build the
# input tokenizer's vocabulary from them.
# NOTE(review): only the val split is scanned; any token that occurs solely in
# the train split would be missing from the vocabulary — confirm both splits
# share the same alphabet.
formula_data = CSVDataset.load("rft-0/val", "ltl-strace", dtype=LTLFormula)
input_tokenizer.build_vocabulary(formula_data.generator())

### Create Target Tokenizer


In [None]:
# Tokenizer for the target SymbolicTrace objects, written in infix notation.
# eos=True presumably appends an end-of-sequence token; pad=64 presumably pads
# to 64 tokens — keep it in sync with max_target_length=64 passed to
# TFTransformerPipeline.
target_tokenizer = SymTraceToSeqTokenizer(notation="infix", eos=True, pad=64)

In [None]:
# Load the traces of the same "rft-0/val" split and build the target vocabulary.
# add_start=True presumably reserves a start token for the decoder — confirm.
# NOTE(review): as with the input vocabulary, only the val split is scanned.
target_data = CSVDataset.load("rft-0/val", "ltl-strace", dtype=SymbolicTrace)
target_tokenizer.build_vocabulary(target_data.generator(), add_start=True)

### Configure Model and Assemble Pipeline


In [None]:
# Hyperparameters for the TF Transformer. Keys keep their original order so any
# serialization of the config is unchanged.
model_config = dict(
    # Beam-search decoding settings.
    alpha=0.5,
    beam_size=2,
    custom_pos_enc=True,
    # Embedding / feed-forward widths.
    d_embed_dec=128,
    d_embed_enc=128,
    d_ff=512,
    dropout=0.0,
    # Tensor dtypes used by the model.
    dtype_float=tf.float32,
    dtype_int=tf.int32,
    ff_activation="relu",
    # Attention heads and encoder/decoder depth.
    num_heads=4,
    num_layers_dec=4,
    num_layers_enc=4,
)

In [None]:
# Assemble the pipeline from the two tokenizers and the model config.
# The max lengths mirror the tokenizers' pad values (128 for inputs, 64 for
# targets); change them together.
pipeline = TFTransformerPipeline(
    input_tokenizer=input_tokenizer,
    max_input_length=128,
    target_tokenizer=target_tokenizer,
    max_target_length=64,
    model_config=model_config,
    project="ltl-strace",
    name="t-0",
)

## Create Trainer


In [None]:
# Load the "rft-0" split dataset of the ltl-strace project; its "train" and
# "val" entries are handed to the trainer below.
data = SplitDataset.load("rft-0", "ltl-strace")

In [None]:
# Keras-based trainer for the pipeline. With steps == val_freq == 500,
# validation presumably runs only once, at the end of training — confirm.
train_split, val_split = data["train"], data["val"]
trainer = KerasTransformerTrainer(
    val_freq=500,
    steps=500,
    val_dataset=val_split,
    train_dataset=train_split,
    pipeline=pipeline,
)
del train_split, val_split

In [None]:
# Run training with the settings configured in the trainer cell above.
trainer.train()

In [None]:
# Example formula, parsed from its string form, to feed to the trained pipeline.
formula = LTLFormula.from_str("! X ( a & 1 U b )")

In [None]:
# Run inference on the example formula. The result is iterated in the next cell,
# and entries may be None when decoding fails. Presumably there is one entry per
# beam (beam_size=2 in model_config) — confirm.
preds = pipeline(formula)

In [None]:
# Show every predicted trace; a None entry signals that decoding failed.
for pred in preds:
    print('Decoding Error' if pred is None else pred.to_str())

In [None]:
# Optional: uncomment to call eval_attn_weights on the example formula
# (presumably inspects the model's attention weights — confirm).
#pipeline.eval_attn_weights(formula)