In [1]:
import transformers
import torch

model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

pipeline = transformers.pipeline(
    "text-generation",
    model=model_id,
    model_kwargs={"torch_dtype": torch.bfloat16},
    device_map="cuda",
)

messages = [
    {"role": "system", "content": "You are a BDD test writer"},
    {"role": "user", "content": "Who are you?"},
]

outputs = pipeline(
    messages,
    max_new_tokens=256,
)
print(outputs[0]["generated_text"][-1])


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


{'role': 'assistant', 'content': "I am a Behavior-Driven Development (BDD) test writer. My primary role is to write automated tests in a natural language style, focusing on the behavior of the system under test. I use tools like Cucumber or SpecFlow to write tests in a language that's easy for both developers and non-technical stakeholders to understand.\n\nMy main goal is to ensure that the system meets the requirements and behaves as expected, while also providing a clear and concise description of the expected behavior. I work closely with developers, product owners, and other stakeholders to write tests that cover the desired functionality and edge cases.\n\nIn BDD, I typically write tests in the following format:\n\n* Given (preconditions)\n* When (action or event)\n* Then (expected outcome)\n\nFor example:\n\n```\nFeature: User login\n  As a user\n  I want to login to the system\n  So that I can access my account\n\nScenario: Successful login\n  Given I have a valid username and 

In [2]:
from typing import List, Dict 

messages = messages[:2]
messages[-1]["content"] = """The paths to code follows:
                                "./Data/behave",
                                "/usr/bin",
                                "/home/ciprian/paths"
                                """
def tool_find_path_from_userstr(self, llm_pipe, conv_hist: List[Dict]):

    # Temorarly add the prompt request to extract path
    conv_hist.append({"role": "user", "content": """Group the files given by the last user message in a JSON file as bellow. 
    Do not write any code or variables, just extract the paths and fill the JSON below.

        {
            "paths" : ["path1", "path2", "path3", ...]
        }

        If you do not find any folder path from the user message, provide a list, as below, do not invent one.
        {
            "paths" : []
        }
        """})

    outputs = llm_pipe(
        messages,
        max_new_tokens=512,
    )

    # Remove the last message from the conv hist
    conv_hist = conv_hist[:-1]

    last_gen_msg = outputs[0]["generated_text"][-1]
    res = last_gen_msg["content"] 
    return res

self = None
tool_find_path_from_userstr(self, pipeline, messages)

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


'Here\'s a Python script that groups the paths into a JSON file as requested:\n\n```python\nimport json\nimport os\n\ndef group_paths(paths):\n    """\n    Groups the paths into a JSON file.\n    \n    Args:\n        paths (list): A list of paths.\n    \n    Returns:\n        dict: A dictionary containing the grouped paths.\n    """\n    grouped_paths = {\n        "paths": paths\n    }\n    return grouped_paths\n\ndef main():\n    # Define the paths\n    paths = [\n        "./Data/behave",\n        "/usr/bin",\n        "/home/ciprian/paths"\n    ]\n    \n    # Group the paths\n    grouped_paths = group_paths(paths)\n    \n    # Write the grouped paths to a JSON file\n    with open(\'paths.json\', \'w\') as f:\n        json.dump(grouped_paths, f, indent=4)\n\nif __name__ == "__main__":\n    main()\n```\n\nWhen you run this script, it will create a JSON file named `paths.json` in the same directory as the script with the following content:\n\n```json\n{\n    "paths": [\n        "./Data/b

In [19]:
import os 
from langchain_openai import ChatOpenAI
LLAMA_API = "LA-a2ce4f869d2d48099a6135d60330a4c216e9504626e2492c9da30d762f16af1c"
LANGSMITH_API = "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"

os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGCHAIN_API_KEY"] = LANGSMITH_API
os.environ["LANGCHAIN_PROJECT"] = "SOME NAME"


# model = ChatOpenAI(
#     openai_api_key=LLAMA_API,
#     openai_api_base="https://api.llama-api.com",
#     model="llama3-70b"
# )

model = pipeline.model 

In [34]:
import pandas as pd
import sqlite3


df = pd.read_csv("telco.csv")

df.columns = df.columns.str.replace(' ', '_')

customer_df = df[['Customer_ID', 'Gender', 'Age', 'Under_30', 'Senior_Citizen', 'Married', 'Dependents', 'Number_of_Dependents', 'Country', 'State', 'City', 'Zip_Code', 'Latitude', 'Longitude', 'Population']]
service_df = df[['Customer_ID', 'Phone_Service', 'Multiple_Lines', 'Internet_Service', 'Internet_Type', 'Online_Security', 'Online_Backup', 'Device_Protection_Plan', 'Premium_Tech_Support', 'Streaming_TV', 'Streaming_Movies', 'Streaming_Music', 'Unlimited_Data']]
billing_df = df[['Customer_ID', 'Tenure_in_Months', 'Offer', 'Avg_Monthly_Long_Distance_Charges', 'Avg_Monthly_GB_Download', 'Contract', 'Paperless_Billing', 'Payment_Method', 'Monthly_Charge', 'Total_Charges', 'Total_Refunds', 'Total_Extra_Data_Charges', 'Total_Long_Distance_Charges', 'Total_Revenue']]
referral_df = df[['Customer_ID', 'Referred_a_Friend', 'Number_of_Referrals']]
churn_df = df[['Customer_ID', 'Quarter', 'Satisfaction_Score', 'Customer_Status', 'Churn_Label', 'Churn_Score', 'CLTV', 'Churn_Category', 'Churn_Reason']]

conn = sqlite3.connect('telco.db')

customer_df.to_sql('Customer', conn, if_exists='replace', index=False)
service_df.to_sql('Service', conn, if_exists='replace', index=False)
billing_df.to_sql('Billing', conn, if_exists='replace', index=False)
referral_df.to_sql('Referral', conn, if_exists='replace', index=False)
churn_df.to_sql('Churn', conn, if_exists='replace', index=False)

conn.close()

In [35]:
def query_db(query):
  conn = sqlite3.connect('telco.db')
  try:
    return pd.read_sql_query(query, conn)
  finally:
    conn.close()

In [36]:
DB_DESCRIPTION = """You have access to the following tables and columns in a sqllite3 database:

Customer Table
Customer_ID: A unique ID that identifies each customer.
Gender: The customer’s gender: Male, Female.
Age: The customer’s current age, in years, at the time the fiscal quarter ended.
Under_30: Indicates if the customer is under 30: Yes, No.
Senior_Citizen: Indicates if the customer is 65 or older: Yes, No.
Married: Indicates if the customer is married: Yes, No.
Dependents: Indicates if the customer lives with any dependents: Yes, No.
Number_of_Dependents: Indicates the number of dependents that live with the customer.
Country: The country of the customer’s primary residence. Example: United States.
State: The state of the customer’s primary residence.
City: The city of the customer’s primary residence.
Zip_Code: The zip code of the customer’s primary residence.
Latitude: The latitude of the customer’s primary residence.
Longitude: The longitude of the customer’s primary residence.
Population: A current population estimate for the entire Zip Code area.

Service Table
Customer_ID: A unique ID that identifies each customer (Foreign Key).
Phone_Service: Indicates if the customer subscribes to home phone service with the company: Yes, No.
Multiple_Lines: Indicates if the customer subscribes to multiple telephone lines with the company: Yes, No.
Internet_Service: Indicates if the customer subscribes to Internet service with the company: Yes, No.
Internet_Type: Indicates the type of Internet service: DSL, Fiber Optic, Cable, None.
Online_Security: Indicates if the customer subscribes to an additional online security service provided by the company: Yes, No.
Online_Backup: Indicates if the customer subscribes to an additional online backup service provided by the company: Yes, No.
Device_Protection Plan: Indicates if the customer subscribes to an additional device protection plan for their Internet equipment provided by the company: Yes, No.
Premium_Tech_Support: Indicates if the customer subscribes to an additional technical support plan from the company with reduced wait times: Yes, No.
Streaming_TV: Indicates if the customer uses their Internet service to stream television programming from a third party provider: Yes, No.
Streaming_Movies: Indicates if the customer uses their Internet service to stream movies from a third party provider: Yes, No.
Streaming_Music: Indicates if the customer uses their Internet service to stream music from a third party provider: Yes, No.
Unlimited_Data: Indicates if the customer has paid an additional monthly fee to have unlimited data downloads/uploads: Yes, No.

Billing Table
Customer_ID: A unique ID that identifies each customer (Foreign Key).
Tenure_in_Months: Indicates the total amount of months that the customer has been with the company by the end of the quarter specified above.
Offer: Identifies the last marketing offer that the customer accepted, if applicable. Values include None, Offer A, Offer B, Offer C, Offer D, and Offer E.
Avg_Monthly_Long_Distance_Charges: Indicates the customer’s average long distance charges, calculated to the end of the quarter specified above.
Avg_Monthly_GB_Download: Indicates the customer’s average download volume in gigabytes, calculated to the end of the quarter specified above.
Contract: Indicates the customer’s current contract type: Month-to-Month, One Year, Two Year.
Paperless_Billing: Indicates if the customer has chosen paperless billing: Yes, No.
Payment_Method: Indicates how the customer pays their bill: Bank Withdrawal, Credit Card, Mailed Check.
Monthly_Charge: Indicates the customer’s current total monthly charge for all their services from the company.
Total_Charges: Indicates the customer’s total charges, calculated to the end of the quarter specified above.
Total_Refunds: Indicates the customer’s total refunds, calculated to the end of the quarter specified above.
Total_Extra_Data_Charges: Indicates the customer’s total charges for extra data downloads above those specified in their plan, by the end of the quarter specified above.
Total_Long_Distance_Charges: Indicates the customer’s total charges for long distance above those specified in their plan, by the end of the quarter specified above.
Total_Revenue: The total revenue generated from the customer.

Referral Table
Customer_ID: A unique ID that identifies each customer (Foreign Key).
Referred_a_Friend: Indicates if the customer has ever referred a friend or family member to this company: Yes, No.
Number_of_Referrals: Indicates the number of referrals to date that the customer has made.

Churn Table
Customer_ID: A unique ID that identifies each customer (Foreign Key).
Quarter: The fiscal quarter that the data has been derived from (e.g. Q3).
Satisfaction_Score: A customer’s overall satisfaction rating of the company from 1 (Very Unsatisfied) to 5 (Very Satisfied).
Customer_Status: Indicates the status of the customer at the end of the quarter: Churned, Stayed, Joined.
Churn_Label: Yes = the customer left the company this quarter. No = the customer remained with the company.
Churn_Score: A value from 0-100 that is calculated using the predictive tool IBM SPSS Modeler. The model incorporates multiple factors known to cause churn. The higher the score, the more likely the customer will churn.
CLTV: Customer Lifetime Value. A predicted CLTV is calculated using corporate formulas and existing data. The higher the value, the more valuable the customer. High value customers should be monitored for churn.
Churn_Category: A high-level category for the customer’s reason for churning: Attitude, Competitor, Dissatisfaction, Other, Price.
Churn_Reason: A customer’s specific reason for leaving the company. Directly related to Churn Category.
"""

In [47]:
from langchain_core.prompts import ChatPromptTemplate
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.output_parsers import JsonOutputParser

can_answer_router_prompt = PromptTemplate(
    template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
    You are a database reading bot that can answer users' questions using information from a database. \n

    {data_description} \n\n

    Given the user's question, decide whether the question can be answered using the information in the database. \n\n

    Return a JSON with two keys, 'reasoning' and 'can_answer', and no preamble or explanation.
    Return one of the following JSON:
    
    {{"reasoning": "I can find the average revenue of customers with tenure over 24 months by averaging the Total Revenue column in the Billing table filtered by Tenure in Months > 24", "can_answer":true}}
    {{"reasoning": "I can find customers who signed up during the last 12 month using the Tenure in Months column in the Billing table", "can_answer":true}}
    {{"reasoning": "I can't answer how many customers churned last year because the Churn table doesn't contain a year", "can_answer":false}}
    

    <|eot_id|><|start_header_id|>user<|end_header_id|>
    Question: {question} \n
    <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
    input_variables=["data_description", "question"],
)

messages = [
    {"role": "system", "content": 
        """You are a database reading bot that can answer users' questions using information from a database. \n {data_description} \n\n
    
        Given the user's question, decide whether the question can be answered using the information in the database. \n\n
    
        Return a JSON with two keys, 'reasoning' and 'can_answer', and no preamble or explanation.
        Return one of the following JSON:
        
        {{"reasoning": "I can find the average revenue of customers with tenure over 24 months by averaging the Total Revenue column in the Billing table filtered by Tenure in Months > 24", "can_answer":true}}
        {{"reasoning": "I can find customers who signed up during the last 12 month using the Tenure in Months column in the Billing table", "can_answer":true}}
        {{"reasoning": "I can't answer how many customers churned last year because the Churn table doesn't contain a year", "can_answer":false}}
    
        """},
    {"role": "user", "content": "Question: {question} \n"},
]

##### 
temp = 0
if temp:
    messages[0]["content"] = messages[0]["content"].format(data_description=DB_DESCRIPTION) 
    messages[1]["content"] = messages[1]["content"].format(question="Count customers by zip code. Return the 5 most common zip codes")
    
    outputs = pipeline(messages, max_new_tokens=1024)
    print(outputs[0]["generated_text"][-1])
    
    
else:
    #messages[0]["content"] = messages[0]["content"].format(data_description=DB_DESCRIPTION) 
    #messages[1]["content"] = messages[1]["content"].format(question="Count customers by zip code. Return the 5 most common zip codes")
    
    can_answer_router = can_answer_router_prompt | pipeline | JsonOutputParser()
    can_answer_router.invoke({"question": "Count customers by zip code. Return the 5 most common zip codes", "data_description": DB_DESCRIPTION})

TypeError: can only concatenate str (not "StringPromptValue") to str

In [38]:
can_answer_router_prompt = PromptTemplate(
    template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
    You are a database reading bot that can answer users' questions using information from a database. \n

    {data_description} \n\n

    Given the user's question, decide whether the question can be answered using the information in the database. \n\n

    Return a JSON with two keys, 'reasoning' and 'can_answer', and no preamble or explanation.
    Return one of the following JSON:
    
    {{"reasoning": "I can find the average revenue of customers with tenure over 24 months by averaging the Total Revenue column in the Billing table filtered by Tenure in Months > 24", "can_answer":true}}
    {{"reasoning": "I can find customers who signed up during the last 12 month using the Tenure in Months column in the Billing table", "can_answer":true}}
    {{"reasoning": "I can't answer how many customers churned last year because the Churn table doesn't contain a year", "can_answer":false}}
    

    <|eot_id|><|start_header_id|>user<|end_header_id|>
    Question: {question} \n
    <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
    input_variables=["data_description", "question"],
)

can_answer_router = can_answer_router_prompt | model | JsonOutputParser()

def check_if_can_answer_question(state):
  result = can_answer_router.invoke({"question": state["question"], "data_description": DB_DESCRIPTION})

  return {"plan": result["reasoning"], "can_answer": result["can_answer"]}

def skip_question(state):
  if state["can_answer"]:
    return "no"
  else:
    return "yes"
  
write_query_prompt = PromptTemplate(
    template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
    You are a database reading bot that can answer users' questions using information from a database. \n

    {data_description} \n\n

    In the previous step, you have prepared the following plan: {plan}

    Return an SQL query with no preamble or explanation. Don't include any markdown characters or quotation marks around the query.
    <|eot_id|><|start_header_id|>user<|end_header_id|>
    Question: {question} \n
    <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
    input_variables=["data_description", "question", "plan"],
)

write_query_chain = write_query_prompt | model | StrOutputParser()

def write_query(state):
  result = write_query_chain.invoke({
      "data_description": DB_DESCRIPTION,
      "question": state["question"],
      "plan": state["plan"]
  })

  return {"sql_query": result}

def execute_query(state):
  query = state["sql_query"]

  try:
    return {"sql_result": query_db(query).to_markdown()}
  except Exception as e:
    return {"sql_result", str(e)}
  
  
write_answer_prompt = PromptTemplate(
    template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
    You are a database reading bot that can answer users' questions using information from a database. \n

    In the previous step, you have planned the query as follows: {plan},
    generated the query {sql_query}
    and retrieved the following data:
    {sql_result}

    Return a text answering the user's question using the provided data.
    <|eot_id|><|start_header_id|>user<|end_header_id|>
    Question: {question} \n
    <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
    input_variables=["question", "plan", "sql_query", "sql_result"],
)

write_answer_chain = write_answer_prompt | model | StrOutputParser()

def write_answer(state):
  result = write_answer_chain.invoke({
      "question": state["question"],
      "plan": state["plan"],
      "sql_result": state["sql_result"],
      "sql_query": state["sql_query"]
  })

  return {"answer": result}

cannot_answer_prompt = PromptTemplate(
    template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
    You are a database reading bot that can answer users' questions using information from a database. \n

    You cannot answer the user's questions because of the following problem: {problem}.

    Explain the issue to the user and apologize for the inconvenience.
    <|eot_id|><|start_header_id|>user<|end_header_id|>
    Question: {question} \n
    <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
    input_variables=["question", "problem"],
)

cannot_answer_chain = cannot_answer_prompt | model | StrOutputParser()

def explain_no_answer(state):
  result = cannot_answer_chain.invoke({
      "problem": state["plan"], # the plan contains an explanation of why we can't answer the question
      "question": state["question"]
  })

  return {"answer": result}

In [39]:
from typing_extensions import TypedDict


class WorkflowState(TypedDict):
  question: str
  plan: str
  can_answer: bool
  sql_query: str
  sql_result: str
  answer: str

from langgraph.graph import END, StateGraph


workflow = StateGraph(WorkflowState)

workflow.add_node("check_if_can_answer_question", check_if_can_answer_question)
workflow.add_node("write_query", write_query)
workflow.add_node("execute_query", execute_query)
workflow.add_node("write_answer", write_answer)
workflow.add_node("explain_no_answer", explain_no_answer)

workflow.set_entry_point("check_if_can_answer_question")

workflow.add_conditional_edges(
    "check_if_can_answer_question",
    skip_question, # given the text response from this function,
    { # we choose which node to go to
        "yes": "explain_no_answer",
        "no": "write_query",
    },
)

workflow.add_edge("write_query", "execute_query")
workflow.add_edge("execute_query", "write_answer")

workflow.add_edge("explain_no_answer", END)
workflow.add_edge("write_answer", END)

app = workflow.compile()


In [40]:
inputs = {"question": "Count customers by zip code. Return the 5 most common zip codes"}
app.invoke(inputs)

TypeError: embedding(): argument 'indices' (position 2) must be Tensor, not StringPromptValue

In [7]:
import spacy

# Load the language model
nlp = spacy.load('en_core_web_md')

# Define the words
word1 = "weight"
word2 = "mass"

# Process the words
token1 = nlp(word1)
token2 = nlp(word2)

# Calculate similarity
similarity = token1.similarity(token2)
print(f"Similarity between '{word1}' and '{word2}': {similarity}")


Similarity between 'weight' and 'mass': 0.45311458655115255


In [8]:
import Levenshtein

# Example strings
str1 = "kitten"
str2 = "sitting"

# Calculate Levenshtein distance
distance = Levenshtein.distance(str1, str2)
print(f"Levenshtein Distance: {distance}")

# Calculate similarity ratio
similarity = Levenshtein.ratio(str1, str2)
print(f"Levenshtein Similarity Ratio: {similarity}")

Levenshtein Distance: 3
Levenshtein Similarity Ratio: 0.6153846153846154


In [1]:
import os
import re

def extract_clauses(directory):
    clauses = {'given': [], 'when': [], 'then': []}
    pattern = re.compile(r'@(given|when|then)\(\'(.*?)\'\)')
    
    for root, _, files in os.walk(directory):
        for file in files:
            if file.endswith('.py'):
                with open(os.path.join(root, file), 'r') as f:
                    content = f.read()
                    matches = pattern.findall(content)
                    for match in matches:
                        clauses[match[0]].append(match[1])
    return clauses

# Example usage
directory = './car-behave-master/features/steps'
clauses = extract_clauses(directory)
print(clauses)

{'given': [], 'when': [], 'then': []}


In [3]:
import re

# Define the pattern with named groups
pattern = r"(?P<param1>\w+)\s(?P<param2>\w+)"

# Dictionary with replacement values
replacements = {
    "param1": "Hello",
    "param2": "World"
}

# Function to replace named groups
def replace_named_groups(match):
    group_dict = match.groupdict()
    for key, value in group_dict.items():
        if key in replacements:
            group_dict[key] = replacements[key]
    return " ".join(group_dict.values())

# Text to be processed
text = "foo bar"

# Perform the replacement
result = re.sub(pattern, replace_named_groups, text)
print(result)  # Output: Hello World

Hello World


In [8]:
# Given a string, find the closing index of a matching parenthesis starting as open_pos
def find_matching_parenthesis(s: str, open_pos: int) -> int:
    stack = []
    for i, char in enumerate(s):
        if char == '(':
            stack.append(i)
        elif char == ')':
            if stack:
                start = stack.pop()
                if start == open_pos:
                    return i
            else:
                return -1 # Incorrect, no matching opening parenthesis
    return -1  # Return -1 if no matching parenthesis is found

# Example usage
gherkin_step = "@given(\"the car has (?P<engine_power>\\d+) kw, weighs (?P<weight>\\d+) kg, has a drag coefficient of (?P<drag>[\\.\\d]+)\")"


def fill_regex_with_dictvalues(step_regex: str, replacements: dict) -> str:
    stack = []
    slen = len(step_regex)
    
    filled_str = "" 
    last_unmatched = 0
    
    for i, char in enumerate(step_regex):
        if char == '(':
            remain_sz = slen - i
            if remain_sz > 3 and step_regex[i+1] == '?' and step_regex[i+2] == 'P' and step_regex[i+3] == '<':
                closing_index = find_matching_parenthesis(step_regex[i:], 0)            
                assert closing_index != -1, "Incorrect, no matching closing parenthesis"
                group = step_regex[i:i+closing_index+1]
                
                key_to_replace = group[4:group.index('>')]
                assert key_to_replace in replacements, f"Key {key_to_replace} not found in replacements"
                filled_str += step_regex[last_unmatched:i] + replacements[key_to_replace]
                
                last_unmatched = i+closing_index+1
    
    filled_str += step_regex[last_unmatched:]
    return filled_str
                
                
replacements = {
    "engine_power": "123",
    "weight": "45",
    "drag": "0.3"
}
print(fill_regex_with_dictvalues(gherkin_step, replacements))  # Output: @given("the car has 123 kw, weighs 45 kg, has a drag coefficient of 0.3")        



@given("the car has 123 kw, weighs 45 kg, has a drag coefficient of 0.3")


In [1]:
import Levenshtein
# Example strings
str1 = "kitten"
str2 = "sitting"

# Calculate Levenshtein distance
distance = Levenshtein.distance(str1, str2)
print(f"The Levenshtein distance between '{str1}' and '{str2}' is {distance}")

The Levenshtein distance between 'kitten' and 'sitting' is 3


In [1]:
import transformers
import torch
from typing import List, Dict
from pprint import pprint
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

pipeline = transformers.pipeline(
    "text-generation",
    model=model_id,
    model_kwargs={"torch_dtype": torch.bfloat16},
    device_map="auto",
)

messages = [
    {"role": "system", "content": "You are a BDD testing expert that can write scenarios using Gherkin language, Python language and behave library"},
    {"role": "user", "content": "Show me a scenario for testing Mario like games"},
]

terminators = [
    pipeline.tokenizer.eos_token_id,
    pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

def inference(prompt: str, max_tokens: int ) -> str:
    messages[-1]["content"] = prompt
    outputs = pipeline(
        messages,
        max_new_tokens=max_tokens,
        eos_token_id=terminators,
        do_sample=False,
        #temperature=0.1,
        #top_p=0.9,
    )
    
    return outputs[0]["generated_text"][-1]


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [15]:
USER_INPUT_STEP = "A drag of 123, a mass of 12345 kg, and an engine of 124kw the Yoda's vehicle has!"

match_rewrite_prompt_template= """Given the below Gherkin available steps, check if any is close the input step. If yes, return the closest step in Gherkin syntax, including the @given, @when, or @then before.
    
    Use Json syntax for response. Use the following format if any step can be matched:
    {{
      "found": true,
      "step_found":  the step you found closest
    }}
    
    If no available option is OK, then use:
    {{
        "found": false,
    }}
    
    Do not provide any other information or examples.
    
    ### Input step: 
    {user_input_step_to_match}
    
    ### Available steps:
    {user_input_available_steps}"""

# TODO: take from file with source code 
USER_INPUT_AVAILABLE_STEPS = """@given("the car has (?P<engine_power>\d+) kw, weighs (?P<weight>\d+) kg, has a drag coefficient of (?P<drag>[\.\d]+)")
    
    @given("a frontal area of (?P<area>.+) m\^2")
    
    @when("I accelerate to (?P<speed>\d+) km/h")
    
    @then("the time should be within (?P<precision>[\d\.]+)s of (?P<time>[\d\.]+)s")
    
    @given("that the car is moving at (?P<speed>\d+) m/s")
    
    @when("I brake at (?P<brake_force>\d+)% force")
    
    @step("(?P<seconds>\d+) seconds? pass(?:es)?")
    
    @then("I should have traveled less than (?P<distance>\d+) meters")
    
    @given("that the car's heading is (?P<heading>\d+) deg")
    
    @when("I turn (?P<direction>left|right) at a yaw rate of (?P<rate>\d+) deg/sec for (?P<duration>\d+) seconds")
    
    @then("the car's heading should be (?P<heading>\d+) deg")"""
    
match_rewrite_prompt = match_rewrite_prompt_template.format(user_input_step_to_match=USER_INPUT_STEP, 
                                         user_input_available_steps=USER_INPUT_AVAILABLE_STEPS)

res0 = inference(match_rewrite_prompt, 1024)["content"]

pprint(res0)

Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.


('{\n'
 '  "found": true,\n'
 '  "step_found": "@given(\\"the car has (?P<engine_power>\\\\d+) kw, weighs '
 '(?P<weight>\\\\d+) kg, has a drag coefficient of '
 '(?P<drag>[\\\\.\\\\d]+)\\")"\n'
 '}')


In [16]:

prompt_matching_params_template = """Can you match the parameters in the input text step with the target step ?
    
    Example:
    ### Input: The plane has a travel speed of 123 km/h and a lenght of 500 m
    ### Target: @given(A plane that has a (?P<speed>\d+) km/h, length (?P<size>\d+) m
    Response:
    {{
     "speed" : "123 km/h",
     "size" : "500 m"
    }}
    
    Your task:
    ### Input: {user_input_step}
    ### Target: {step_str}
    
    Response:    your response 
    
    Do not write anything else.
    """
    

step_found_str = '@given("the car has (?P<engine_power>\d+) kw, weighs (?P<weight>\d+) kg, has a drag coefficient of (?P<drag>[\.\d]+)")'
prompt_matching_params = prompt_matching_params_template.format(user_input_step=USER_INPUT_STEP, step_str=step_found_str)


res = inference(prompt_matching_params, 1024)["content"]
pprint(res)
#pprint(inference("tell me a scenario for testing Mario like games", 100))

Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.


'{\n  "engine_power": "124 kw",\n  "weight": "12345 kg",\n  "drag": "123"\n}'


In [2]:
USE_DEBUG = False 
from typing import Union, Tuple

# Takes an input step in natural language and tries to match against a set of available input steps in a Gherkin file.
# Two steps are used in the process:
    # Step 1: try to find the closest in terms of matching
    # Step 2: try to match parameters. Report the error if not succeeded
def _match_input_step_to_set(self, 
                            USER_INPUT_STEP: str, 
                            USER_INPUT_AVAILABLE_STEPS: str,
                            max_generated_tokens : int = 1024) -> Tuple[bool, str]:
    
    match_rewrite_prompt_template= """Given the below Gherkin available steps, check if any can match the input step. 
    
    Use Json syntax for response. Use the following format if any step can be matched:
    {{
      "found": true,
      "step_found":  the step you found closest
    }}
    
    If no available option is OK, then use:
    {{
        "found": false,
    }}
    
    Do not provide any other information or examples.
    
    ### Input step: 
    {user_input_step_to_match}
    
    ### Available steps:
    {user_input_available_steps}"""
    
    prompt_matching_params_template = """Can you match the parameters in the input text step with the target step ?
    
    Example:
    ### Input: The plane has a travel speed of 123 km/h and a lenght of 500 m
    ### Target: @given(A plane that has a (?P<speed>\d+) km/h, length (?P<size>\d+) m
    Response:
    {{
     "speed" : "123 km/h",
     "size" : "500 m"
    }}
    
    Your task:
    ### Input: {user_input_step}
    ### Target: {step_str}
    
    Response:    your response 
    
    Do not write anything else.
    """
    
    match_rewrite_prompt = match_rewrite_prompt_template.format(user_input_step_to_match=USER_INPUT_STEP, 
                                         user_input_available_steps=USER_INPUT_AVAILABLE_STEPS)
    
    import json 
    res = inference(match_rewrite_prompt, max_generated_tokens)["content"]
    res = res.replace("\\", "\\\\") # escape the backslashes
    pprint(f"Plain result: {res}")
    
    
    # Loading in json
    res_json = None
    try:
        dir = json.loads(res)
        res_json = dir
    except Exception as e:
        pprint(f"exception occured while reading the output: {e}")
    
    
    step_found_str = res_json.get("step_found", None)
    
    if step_found_str is not None:
        prompt_matching_params = prompt_matching_params_template.format(user_input_step=USER_INPUT_STEP, step_str=step_found_str)
        res22 = inference(prompt_matching_params, max_generated_tokens)["content"]
        res22 = res22.replace("\\", "\\\\") # escape the backslashes
        
        #RESPONSE_TAG = "Response:"
        #resp_json_begin_index = res22.find(RESPONSE_TAG) + len(RESPONSE_TAG)
        resp_json_begin_index = res22.find("{")
        resp_json_end_index = res22.rfind("}")
        
        if resp_json_begin_index !=-1 and resp_json_end_index!= -1:
            resp_json_str = res22[resp_json_begin_index : resp_json_end_index + 1]
            
            try:
                resp_json = json.loads(resp_json_str)
                pprint(resp_json)

                return (True, resp_json)
            except Exception as e:
                msg = f"Error {e} when parsing for parameters:\n{resp_json_str}"
                pprint(msg)
                
                return (False, msg)
    
    
    msg = "The model didn't find any match. The raw output is {temp_resp}".format(temp_resp=res22)
    pprint(msg)
    return (False, msg)

def match_input_step_to_set(self, 
                            USER_INPUT_STEP: str, 
                            max_generated_tokens : int = 1024) -> Tuple[bool, str]:
    
    # TODO: take from file with source code 
    USER_INPUT_AVAILABLE_STEPS = """@given("the car has (?P<engine_power>\d+) kw, weighs (?P<weight>\d+) kg, has a drag coefficient of (?P<drag>[\.\d]+)")
    
    @given("a frontal area of (?P<area>.+) m\^2")
    
    @when("I accelerate to (?P<speed>\d+) km/h")
    
    @then("the time should be within (?P<precision>[\d\.]+)s of (?P<time>[\d\.]+)s")
    
    @given("that the car is moving at (?P<speed>\d+) m/s")
    
    @when("I brake at (?P<brake_force>\d+)% force")
    
    @step("(?P<seconds>\d+) seconds? pass(?:es)?")
    
    @then("I should have traveled less than (?P<distance>\d+) meters")
    
    @given("that the car's heading is (?P<heading>\d+) deg")
    
    @when("I turn (?P<direction>left|right) at a yaw rate of (?P<rate>\d+) deg/sec for (?P<duration>\d+) seconds")
    
    @then("the car's heading should be (?P<heading>\d+) deg")"""
    
    is_matched, resp = _match_input_step_to_set(self, USER_INPUT_STEP, USER_INPUT_AVAILABLE_STEPS, max_generated_tokens=1024)

    return is_matched, resp 


is_matched, resp = match_input_step_to_set(None, USER_INPUT_STEP = "A drag of 123, a mass of 12345 kg, and an engine of 124kw the Yoda's vehicle has!")
pprint(f"The model matched the step: {is_matched}\n \
            Response: {resp}\n")



Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.


'Plain result: {\n  "found": false,\n}'
('exception occured while reading the output: Expecting property name enclosed '
 'in double quotes: line 3 column 1 (char 20)')


AttributeError: 'NoneType' object has no attribute 'get'

In [23]:
inp = '{\n  "found": true,\n  "step_found":  "the car has (?P<engine_power>\\d+) kw, weighs (?P<weight>\\d+) kg, has a drag coefficient of (?P<drag>[\\.\\d]+)"\n}'

inp = inp.replace("\\", "\\\\")

#r = inp[66:]
#print(r)

import json
json.loads(inp, strict=False)

{'found': True,
 'step_found': 'the car has (?P<engine_power>\\d+) kw, weighs (?P<weight>\\d+) kg, has a drag coefficient of (?P<drag>[\\.\\d]+)'}