# NL to FOL Translation with `LiteLLM`

This notebook uses `LiteLLM` to translate natural-language sentences into first-order logic (FOL) formulas. Routing the calls through `LiteLLM` makes it easy to swap the underlying LLM.

## 0. Setup

In [29]:
import os
import json
import pandas as pd
from litellm import completion
from rich import print as rprint

from prompt import PromptTemplate, TRANSLATE_TEMPLATE_V1, TRANSLATE_EXAMPLES_V1, CORRECT_TEMPLATE_V1
from metrics import is_syntactically_valid_with_timeout

In [3]:
# Load environment variables via python-dotenv's load_dotenv().
# The translation loop later in this notebook reads LITELLM_MODEL,
# LITELLM_API_KEY and LITELLM_BASE_URL from os.environ.
from dotenv import load_dotenv
load_dotenv()

True

In [6]:
# Build the two prompt builders used by the translation loop: one for the
# initial NL -> FOL request and one for asking the model to fix bad premises.
trans_prompt_template, correct_prompt_template = (
    PromptTemplate(raw_template)
    for raw_template in (TRANSLATE_TEMPLATE_V1, CORRECT_TEMPLATE_V1)
)

In [8]:
# Read the tab-separated input file and pull the `quy_che` column out as a
# plain Python list; the FOL results are collected in a separate list.
df = pd.read_csv('data/quy_che.tsv', delimiter='\t')

all_quy_che_fols = []
all_quy_che = df.loc[:, 'quy_che'].tolist()

In [48]:
# Translate every sentence in `all_quy_che` to FOL, asking the model to repair
# its own output until every premise passes the syntax check.

# Give up on a sentence after this many model round-trips so that a model which
# never produces usable output cannot keep the loop (and the API calls) running
# forever.
MAX_ATTEMPTS = 5

# Start from an empty result list so that re-running this cell does not append
# a second copy of every translation.
all_quy_che_fols = []

for nl_str in all_quy_che:
    messages = []

    # NOTE(review): `existing_predicates_str` receives a list although its name
    # suggests a string -- confirm how PromptTemplate renders it.
    trans_prompt = trans_prompt_template(
        nl_str=nl_str,
        existing_predicates_str=[],
        examples_str=TRANSLATE_EXAMPLES_V1
    )
    messages.append({
        "role": "user",
        "content": trans_prompt
    })
    rprint(messages[-1]) # FIXME: delete this line

    for _attempt in range(MAX_ATTEMPTS):
        response = completion(
            model=os.environ['LITELLM_MODEL'],
            messages=messages,
            api_key=os.environ['LITELLM_API_KEY'],
            base_url=os.environ['LITELLM_BASE_URL'],
            stream=False
        )

        # Guard against a missing content field so the string handling below
        # cannot fail with AttributeError.
        result = response.choices[0].message.content or ''
        messages.append({
            "role": "assistant",
            "content": result
        })
        rprint(messages[-1]) # FIXME: delete this line

        # Remove an optional Markdown code fence (``` or ```json). Whitespace
        # is trimmed first; otherwise a leading newline would shield the
        # opening fence from being stripped.
        json_str = result.strip().strip('`').strip()
        if json_str.startswith('json'):
            json_str = json_str[4:].strip()

        try:
            parsed_json = json.loads(json_str)
            premises = parsed_json['premises']
        except (json.JSONDecodeError, KeyError, TypeError) as err:
            # The reply was not a JSON object with a 'premises' entry: report
            # the problem to the model and retry instead of aborting the run.
            messages.append({
                "role": "user",
                "content": (
                    "Your previous reply could not be parsed as a JSON object "
                    f"with a 'premises' field ({err!r}). "
                    "Reply again with only the JSON object."
                )
            })
            rprint(messages[-1]) # FIXME: delete this line
            continue

        invalid_premises = [
            premise
            for premise in premises
            if not is_syntactically_valid_with_timeout(premise)
        ]

        if not invalid_premises:
            all_quy_che_fols.append(parsed_json)
            break

        # Feed the rejected premises back so the model can correct them.
        messages.append({
            "role": "user",
            "content": correct_prompt_template(invalid_premises_str='- ' + '\n- '.join(invalid_premises))
        })
        rprint(messages[-1]) # FIXME: delete this line
    else:
        # Every attempt was used up without a fully valid translation.
        raise RuntimeError(
            f"No syntactically valid FOL translation after {MAX_ATTEMPTS} attempts for: {nl_str!r}"
        )
    

In [None]:
# Show each collected FOL translation under a 1-based heading.
for number, fol_result in enumerate(all_quy_che_fols, start=1):
    rprint(f"Quy chế {number}:")
    rprint(fol_result)