In [2]:
from labelling.labels import LABELS
from chat_helper import Chat
import pandas as pd

label_string = ""
for label, desc in LABELS.items():
    label_string += f"{label}: {desc}\n"

prompts = [
    """You are going to be provided a series of interactions from a user regarding questions about finite state automatons.
    Each message has to be labelled, according to the following labels: 
    
    {labels}
    
    You only need to answer with the corresponding label you've identified.
    Do not explain the reasoning, do not use different terms from the labels you've received now.
    Interaction: 
    {text}
    Label: 
    """,
    """You are an AI assistant trained to classify questions into the following categories:
    
    {labels}
    
    Please classify the following question:
    {text}
    Category: 
    """
]

for r_i, prompt in enumerate(prompts):
    a = prompt.replace("{labels}", label_string)

    print(a)

    prompts[r_i] = a

df = pd.read_json("../filtered_data.json")

You are going to be provided a series of interactions from a user regarding questions about finite state automatons.
    Each message has to be labelled, according to the following labels: 
    
    START: Initial greetings or meta-questions, such as 'hi' or 'hello'.
GEN_INFO: General questions about the automaton that don't focus on specific components or functionalities.
STATE_COUNT: Questions asking about the number of states in the automaton.
FINAL_STATE: Questions about final states of the automaton.
STATE_ID: Questions about the identity of a particular state.
TRANS_DETAIL: General questions about the transitions within the automaton.
SPEC_TRANS: Specific questions about particular transitions or arcs between states.
TRANS_BETWEEN: Specific question about a transition between two states
LOOPS: Questions about loops or self-referencing transitions within the automaton.
GRAMMAR: Questions about the language or grammar recognized by the automaton.
INPUT_QUERY: Questions about the in

In [3]:
from tqdm import tqdm

# ollama_models = ["llama3.1:8b", "gemma:7b", "qwen:7b"]
ollama_models = ["gemma2:9b", "llama3.1:8b"]

# We are initializing a new dataframe with the same index as the original one
res_df = pd.DataFrame(index=df.index)

for model in ollama_models:
    chat = Chat(model=model)

    dataset_size = len(df)

    for p_i, prompt_version in enumerate(prompts):
        progress_bar = tqdm(total=dataset_size, desc=f"Asking {model} with prompt {p_i}", unit="rows")

        for r_i, row in df.iterrows():
            text = row["Text"]

            prompt = prompts[0].replace("{text}", text)

            inferred_label = chat.interact(prompt, stream=True, print_output=False, use_in_context=False)
            inferred_label = inferred_label.strip().replace("'", "")

            res_df.at[r_i, f"{model} {p_i}"] = inferred_label
            progress_bar.update()

        print(progress_bar.format_dict["elapsed"])
        progress_bar.close()

Asking gemma2:9b with prompt 0: 100%|██████████| 290/290 [06:17<00:00,  1.30s/rows]


377.2468931674957


Asking gemma2:9b with prompt 1:  21%|██        | 61/290 [01:17<04:31,  1.19s/rows]

KeyboardInterrupt: 

# Cleanup

In [18]:
valid_labels = list(LABELS.keys())
default_label = "INVALID"


def find_keyword(possible_label: str) -> str:
    for l in valid_labels:
        if l in possible_label:
            return l

    return default_label


def clean_row(row: pd.Series):
    for i in range(1, len(row)):
        row.iloc[i] = find_keyword(row.iloc[i])
        
    return row


c_df = res_df.copy()
c_df.apply(clean_row, axis=1)

Unnamed: 0,gemma2:9b 0,gemma2:9b 1,llama3.1:8b 0,llama3.1:8b 1
0,START,START,START,START
1,GEN_INFO,GEN_INFO,GEN_INFO,GEN_INFO
2,SPEC_TRANS,SPEC_TRANS,TRANS_BETWEEN,TRANS_BETWEEN
3,SPEC_TRANS,SPEC_TRANS,TRANS_BETWEEN,TRANS_BETWEEN
4,Please provide the interaction. \nLABEL: START,START,START,START
...,...,...,...,...
285,OPT_REP,OPT_REP,OPT_REP,OPT_REP
286,GRAMMAR,GRAMMAR,GRAMMAR,GRAMMAR
287,REPETITIVE_PAT,REPETITIVE_PAT,REPETITIVE_PAT,REPETITIVE_PAT
288,TRANS_DETAIL,TRANS_DETAIL,TRANS_DETAIL,GEN_INFO


# Evaluation

In [1]:
from collections import Counter


def majority_vote(row: pd.Series):
    label_counts = Counter(row)
    majority_label = label_counts.most_common(1)[0][0]
    return majority_label


m_df = c_df.apply(majority_vote, axis=1)

a_df = pd.concat([df, m_df], axis=1)
a_df

NameError: name 'pd' is not defined

In [25]:
a_df.to_json(f"labelled_data_aggregated_{pd.Timestamp.now()}.json")