In [None]:
%%capture
!pip -q install syncode transformers accelerate sentencepiece

In [None]:
# Load LLM
# Module-level cache for the loaded model; get_lm() fills it on first use.
LM = None
# Grammar handed to Syncode to constrain generation to one object of the form
#   {"operands":[<number>,<number>],"operator":"<op>"}
# with optional whitespace between tokens and <op> one of "+", "-", "*", "/".
# NOTE: the `number` rule also accepts literals such as `12.`, `.34` and `012`,
# which are not valid JSON numbers, so a strict JSON parser can reject output
# that this grammar allows.
GRAMMAR = r"""
start: obj

obj: "{" ws? "\"operands\"" ws? ":" ws? "[" ws? number ws? "," ws? number ws? "]" ws? "," ws? "\"operator\"" ws? ":" ws? op ws? "}"

op: "\"+\"" | "\"-\"" | "\"*\"" | "\"/\""

# supports:
#  - integers: 12
#  - decimals: 12.34, 12., .34
#  - optional exponent: 1.2e-3
number: /-?(?:\d+\.\d*|\d*\.\d+|\d+)(?:[eE][+-]?\d+)?/

ws: /[ \t\r\n]+/
"""

def get_lm():
    """Return the shared Syncode model, constructing it on first use.

    The instance is stored in the module-level ``LM`` so that later calls
    reuse it instead of loading the model again.

    Returns:
        The cached ``Syncode`` instance, configured with ``GRAMMAR``.
    """
    global LM
    # Fast path: return the cached model before touching the heavy imports,
    # so repeated calls do no import work at all.
    if LM is not None:
        return LM

    import torch
    from syncode import Syncode

    # Run on the GPU when one is available, otherwise fall back to the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    LM = Syncode(
        model="microsoft/phi-2",
        grammar=GRAMMAR,
        mode="grammar_strict",
        parse_output_only=True,
        device=device,
        max_new_tokens=128,
    )
    return LM


In [None]:
# Function to call

# Few-shot prompt prefix: task instructions followed by three worked
# Query/Output examples (multiplication, addition, division). run_syncode()
# appends the user's query and a trailing "Output:" to this text.
# NOTE(review): no example covers subtraction or a negative operand; the
# recorded run of "What is -8.9 minus 3?" extracted -3.0 as the second
# operand. Consider adding such an example; effect on the model is unverified.
FEW_SHOT = """You are an automated system to extract information from the verbal description of arithmetics.
You should extract exactly two floating-point operands and exactly one operator from the query.
You should output a one-line JSON, in the exact format as demonstrated by the following examples!

Query: What is 327. multiplied by 11.0?
Output: {"operands":[327.0,11.0],"operator":"*"}

Query: What is 45.1 plus 23.54?
Output: {"operands":[45.1,23.54],"operator":"+"}

Query: What is 120.4 divided by 4.0?
Output: {"operands":[120.4,4.0],"operator":"/"}
"""

def _parse_syncode_output(out):
    """Extract ``(operands, operator)`` from the raw model output.

    Strict JSON is tried first. The generation grammar, however, also allows
    number literals that JSON forbids (``12.``, ``.34``, leading zeros), so on
    a JSON decode error the text is matched against a pattern that mirrors the
    grammar instead. If that does not match either, the original
    ``json.JSONDecodeError`` is re-raised.
    """
    import json
    import re

    try:
        obj = json.loads(out)
    except json.JSONDecodeError:
        # Same number shape as the grammar's `number` rule.
        number = r"-?(?:\d+\.\d*|\d*\.\d+|\d+)(?:[eE][+-]?\d+)?"
        ws = r"[ \t\r\n]*"
        pattern = (
            ws + r"\{" + ws + r'"operands"' + ws + ":" + ws + r"\[" + ws
            + "(" + number + ")" + ws + "," + ws + "(" + number + ")" + ws + r"\]"
            + ws + "," + ws + r'"operator"' + ws + ":" + ws + r'"([-+*/])"'
            + ws + r"\}" + ws
        )
        match = re.fullmatch(pattern, out)
        if match is None:
            raise
        # float() accepts every literal the number pattern can match.
        return [match.group(1), match.group(2)], match.group(3)
    return obj["operands"], obj["operator"]


def run_syncode(textual_query):
    """Extract two operands and an operator from a query and evaluate them.

    The query is appended to ``FEW_SHOT``, sent to the model returned by
    ``get_lm()``, and the first returned completion is parsed.

    Args:
        textual_query: Natural-language arithmetic question; surrounding
            whitespace is stripped before it is placed in the prompt.

    Returns:
        A tuple ``([a, b], operator, result)`` where ``a`` and ``b`` are
        floats, ``operator`` is one of ``"+"``, ``"-"``, ``"*"``, ``"/"`` and
        ``result`` is the float outcome. Division by zero yields
        ``float("inf")``.

    Raises:
        ValueError: If the model output cannot be parsed into two numeric
            operands and a supported operator.
    """
    llm = get_lm()

    prompt = FEW_SHOT + "\nQuery: " + textual_query.strip() + "\nOutput:"
    out = llm.infer(prompt)[0]

    try:
        operands, operator = _parse_syncode_output(out)
        a = float(operands[0]); b = float(operands[1])
        if operator not in {"+", "-", "*", "/"}:
            raise ValueError(f"Bad operator: {operator}")
    except Exception as e:
        # Chain explicitly so the underlying parse error stays in the traceback.
        raise ValueError(f"Failed to parse Syncode output: {out!r}\nError: {e}") from e

    if operator == "+":
        ans = a + b
    elif operator == "-":
        ans = a - b
    elif operator == "*":
        ans = a * b
    elif operator == "/":
        # A zero divisor is reported as +inf rather than raising.
        ans = a / b if b != 0.0 else float("inf")
    else:
        raise AssertionError("unreachable")

    return ([a, b], operator, float(ans))


In [None]:
# Test cases
def local_test():
    """Send a fixed set of sample queries through ``run_syncode`` and print each result."""
    sample_queries = (
        "What is 1919.8 multiplied by 81.0?",
        "What is 11.4 plus 51.4?",
        "What is 6.0 divided by 7.0?",
        "What is -8.9 minus 3?",
    )

    for query in sample_queries:
        outcome = run_syncode(query)
        print(query, "=>", outcome)
# local_test()

What is 1919.8 multiplied by 81.0? => ([1919.8, 81.0], '*', 155503.8)
What is 11.4 plus 51.4? => ([11.4, 51.4], '+', 62.8)
What is 6.0 divided by 7.0? => ([6.0, 7.0], '/', 0.8571428571428571)
What is -8.9 minus 3? => ([-8.9, -3.0], '-', -5.9)
