In [None]:
############################################################################
## VARIABLES "dataset_name", "exp_prefix", "exp_version" AND "model_name" ##
## WILL BE CREATED BY THE nbconvert preprocessor ABOVE                    ##
############################################################################


# Import libraries and initialize the LangSmith client

In [None]:
from sttn.nli.analyst import STTNAnalyst
from langchain.globals import set_debug
from langsmith import Client
from langsmith.schemas import Example, Run
from langsmith.evaluation import evaluate, LangChainStringEvaluator
import backoff
import openai

from typing import List

# Turn off LangChain's global debug logging for the evaluation run.
set_debug(False)

# Langsmith client used below to read the test dataset.
# NOTE(review): credentials are presumably taken from the environment (e.g. LANGSMITH_API_KEY) — confirm.
client = Client()

### Load the LangSmith test dataset

In [None]:
# Read the dataset by name; `dataset_name` is injected by the nbconvert preprocessor (see the header cell).
test_dataset = client.read_dataset(dataset_name=dataset_name)
# Show the dataset metadata so the run log records which dataset was evaluated.
print(test_dataset.dict())

# Create evaluators and make `STTNAnalyst` results compatible with `langsmith.evaluation`

## Wrap `STTNAnalyst.chat` results in a compatible function with backoff

In [70]:
@backoff.on_exception(
    backoff.expo,
    openai.RateLimitError,
    max_tries=6,
    base=8,
    factor=1,
    max_value=60,
)
def get_context_with_backoff(model_name, inputs):
    """Run a fresh STTNAnalyst on the example's question, retrying on OpenAI rate limits.

    Retries use `backoff.expo` (base 8, factor 1, wait capped at 60 s), with at most 6 attempts in total.
    Any exception other than `openai.RateLimitError` propagates immediately.

    Parameters:
    - model_name: name of the model handed to `STTNAnalyst`
    - inputs: example inputs; the "question" entry is used as the user query

    Returns:
    - whatever `STTNAnalyst.chat` returns (the analysis context)
    """
    analyst = STTNAnalyst(model_name=model_name)
    question = inputs["question"]
    return analyst.chat(user_query=question)

In [71]:
def analyst_results(inputs: dict, model_name="gpt-4o") -> dict:
    """Run the STTN Analyst on one dataset example and return LangSmith-comparable outputs.

    ### Parameters:
    - inputs: dict - the example inputs; must contain the "question" key (the user query)
    - model_name: str - model handed to `STTNAnalyst` (default "gpt-4o")

    ### Returns:
    - dict with keys:
        - "data_provider_id": str - '' when the Analyst picked none
        - "data_provider_args": dict - {} when the Analyst produced none
        - "result": the Analyst's result (None if it failed)
        - "executable": bool - False when launching the Analyst raised
    """
    try:
        context = get_context_with_backoff(inputs=inputs, model_name=model_name)
    except Exception as e:
        # Deliberately best-effort: a failing query is scored as non-executable
        # instead of aborting the whole experiment.
        print(f'An error happened while launching Analyst on query: \n"{inputs["question"]}"\n', str(e))
        return {"data_provider_id": "",
                "data_provider_args": {},
                "result": None,
                "executable": False}

    # Normalize missing/falsy values so the exact-match evaluators compare like with like.
    # "analysis_code" is intentionally not read or emitted; the code-scoring evaluator needs it if re-enabled.
    return {"data_provider_id": context.data_provider_id or '',
            "data_provider_args": context.data_provider_args or {},
            "result": context.result,
            "executable": True,
            }

## Create evaluators for `data_provider_id`, `data_provider_args`, `result` and `executable`

In [36]:
def data_provider_id_match(run: Run, example: Example) -> dict:
    """
    Per-example evaluator: exact match of the predicted data provider id.
    ### Parameters:
    - run: Run - LangSmith run holding the target's outputs
    - example: Example - dataset example holding the reference outputs

    ### Returns:
    - dict: {"key": "data_provider_match", "score": 1 if the ids are equal else 0}
    """
    expected_id = example.outputs["data_provider_id"]
    predicted_id = run.outputs["data_provider_id"]
    return {"key": "data_provider_match", "score": int(predicted_id == expected_id)}

def data_provider_args_match(run: Run, example: Example) -> dict:
    """
    Per-example evaluator: exact (==) match of the predicted data provider args.
    ### Parameters:
    - run: Run - LangSmith run holding the target's outputs
    - example: Example - dataset example holding the reference outputs

    ### Returns:
    - dict: {"key": "data_provider_args_match", "score": 1 if the args are equal else 0}
    """
    expected_args = example.outputs["data_provider_args"]
    predicted_args = run.outputs["data_provider_args"]
    return {"key": "data_provider_args_match", "score": int(predicted_args == expected_args)}

def result_match(run: Run, example: Example) -> dict:
    """
    Per-example evaluator: exact (==) match of the Analyst's computed result.
    ### Parameters:
    - run: Run - LangSmith run holding the target's outputs
    - example: Example - dataset example holding the reference outputs

    ### Returns:
    - dict: {"key": "result_match", "score": 1 | 0}, or score -1 when the comparison
      raised (e.g. "result" missing from the example outputs)
    """
    metric_key = "result_match"
    try:
        expected_result = example.outputs["result"]
        predicted_result = run.outputs["result"]
        matched = int(predicted_result == expected_result)
    except Exception as e:
        print(f'An error happened while comparing the results: {str(e)}')
        print("Result ATTRIBUTE MIGHT BE MISSING IN THE TEST DATASET")
        return {"key": metric_key, "score": -1}
    return {"key": metric_key, "score": matched}

def executable_match(run: Run, example: Example) -> dict:
    """
    Per-example evaluator: does the run's "executable" flag agree with the reference?
    ### Parameters:
    - run: Run - LangSmith run holding the target's outputs
    - example: Example - dataset example holding the reference outputs

    ### Returns:
    - dict: {"key": "executable", "score": bool} - True when the flags are equal
    """
    expected_executable = example.outputs["executable"]
    predicted_executable = run.outputs["executable"]
    return {"key": "executable", "score": bool(predicted_executable == expected_executable)}

# LLM-as-judge evaluator that grades the generated analysis code against the reference code.
# NOTE: it reads run.outputs["analysis_code"], which analyst_results does not currently return
# (that key is commented out there). This evaluator is also commented out of the evaluate() call,
# so both must be re-enabled together.
analysis_code_score_eval = LangChainStringEvaluator(
    "labeled_score_string", 
    config={
        "criteria": {
            "accuracy": "How accurate is this code compared to the reference on a scale of 1-10?",
            "efficiency": "How efficient is this code compared to the reference on a scale of 1-10?",  # can change to bool
        },
        # presumably rescales the 1-10 judge scores into [0, 1] — TODO confirm against LangChain docs
        "normalize_by": 10,
    },
    # Map LangSmith run/example fields onto the evaluator's prediction / reference / input slots.
    prepare_data=lambda run, example: {
        "prediction": run.outputs["analysis_code"], 
        "reference": example.outputs["analysis_code"],
        "input": example.inputs["question"],
    }  
)


## Create summary evaluators for:

### Specific data_provider_ids

In [37]:
# Taxi
def taxi_dp_id_accuracy_summary_eval(runs: List[Run], examples: List[Example]) -> dict:
    """
    Summary evaluator: data-provider-id accuracy restricted to the examples whose
    reference provider is NycTaxiDataProvider.
    ### Parameters:
    - runs: List[Run] - list of LangChain Run objects
    - examples: List[Example] - list of Langsmith Example objects (from the test dataset)

    ### Returns:
    - dict: {"key": "taxi_dp_id_accuracy",
             "score": float} - fraction of taxi examples whose predicted id matches;
             -1.0 when it cannot be computed (e.g. no taxi examples in the dataset)
    """
    taxi_id = "NycTaxiDataProvider"
    try:
        # One 0/1 id-match score per taxi example, in dataset order.
        id_scores = [
            data_provider_id_match(run, example)['score']
            for run, example in zip(runs, examples)
            if example.outputs["data_provider_id"] == taxi_id
        ]
        return {"key": "taxi_dp_id_accuracy",
                "score": sum(id_scores) / len(id_scores)}
    except Exception as e:
        print(f"An error occurred while evaluating NycTaxiDataProvider accuracy: {str(e)}")
        print("NycTaxiDataProvider MIGHT BE MISSING IN THE TEST DATASET")
        return {"key": "taxi_dp_id_accuracy",
                "score": -1.0}


# LEHD
def lehd_dp_accuracy_summary_eval(runs: List[Run], examples: List[Example]) -> dict:
    """
    Summary evaluator: data-provider-id accuracy restricted to the examples whose
    reference provider is OriginDestinationEmploymentDataProvider (LEHD).
    ### Parameters:
    - runs: List[Run] - list of LangChain Run objects
    - examples: List[Example] - list of Langsmith Example objects (from the test dataset)

    ### Returns:
    - dict: {"key": "lehd_dp_id_accuracy",
             "score": float} - fraction of LEHD examples whose predicted id matches;
             -1.0 when it cannot be computed (e.g. no LEHD examples in the dataset)
    """
    lehd_id = "OriginDestinationEmploymentDataProvider"
    hits, total = 0, 0
    try:
        for run, example in zip(runs, examples):
            if example.outputs["data_provider_id"] != lehd_id:
                continue
            hits += data_provider_id_match(run, example)['score']
            total += 1
        return {"key": "lehd_dp_id_accuracy",
                "score": hits / total}
    except Exception as e:
        print(f"An error occurred while evaluating LEHD data provider accuracy: {str(e)}")
        print("OriginDestinationEmploymentDataProvider MIGHT BE MISSING IN THE TEST DATASET")
        return {"key": "lehd_dp_id_accuracy",
                "score": -1.0}

### Specific data_provider_args (when id predicted correctly)

In [38]:
# Taxi
def taxi_dp_args_accuracy_summary_eval(runs: List[Run], examples: List[Example]) -> dict:
    """
    Evaluator for args accuracy only when NycTaxiDataProvider was correctly picked.
    ### Parameters:
    - runs: List[Run] - list of LangChain Run objects
    - examples: List[Example] - list of Langsmith Example objects (from the test dataset)

    ### Returns:
    - dict: {"key": "taxi_dp_args_accuracy",
             "score": float} - share of correctly-identified NycTaxiDataProvider examples whose
             args also match; -1.0 when the score is undefined (no taxi examples, or the id was
             never predicted correctly) or when evaluation raised
    """
    key = "taxi_dp_args_accuracy"
    try:
        taxi_examples = 0
        sum_id_match = 0
        sum_args_match = 0
        for run, example in zip(runs, examples):
            if example.outputs["data_provider_id"] == "NycTaxiDataProvider":
                taxi_examples += 1
                id_match = data_provider_id_match(run, example)['score']
                sum_id_match += id_match

                # check the args only when id was predicted correctly
                if id_match:
                    sum_args_match += data_provider_args_match(run, example)['score']
    except Exception as e:
        print(f"An error occurred while evaluating NycTaxiDataProvider args accuracy: {str(e)}")
        print("NycTaxiDataProvider MIGHT BE MISSING IN THE TEST DATASET")
        return {"key": key,
                "score": -1.0}

    # Guard the two empty denominators explicitly so each gets an accurate diagnostic
    # (previously both surfaced as a ZeroDivisionError blamed on the dataset).
    if taxi_examples == 0:
        print("NycTaxiDataProvider IS MISSING IN THE TEST DATASET")
        return {"key": key,
                "score": -1.0}
    if sum_id_match == 0:
        print("NycTaxiDataProvider was never predicted correctly, so its args accuracy is undefined")
        return {"key": key,
                "score": -1.0}
    return {"key": key,
            "score": sum_args_match / sum_id_match}


# LEHD
def lehd_dp_args_accuracy_summary_eval(runs: List[Run], examples: List[Example]) -> dict:
    """
    Evaluator for args accuracy only when OriginDestinationEmploymentDataProvider was correctly picked.
    ### Parameters:
    - runs: List[Run] - list of LangChain Run objects
    - examples: List[Example] - list of Langsmith Example objects (from the test dataset)

    ### Returns:
    - dict: {"key": "lehd_dp_args_accuracy",
             "score": float} - share of correctly-identified OriginDestinationEmploymentDataProvider
             examples whose args also match; -1.0 when the score is undefined (no LEHD examples, or
             the id was never predicted correctly) or when evaluation raised
    """
    key = "lehd_dp_args_accuracy"
    try:
        lehd_examples = 0
        sum_id_match = 0
        sum_args_match = 0
        for run, example in zip(runs, examples):
            if example.outputs["data_provider_id"] == "OriginDestinationEmploymentDataProvider":
                lehd_examples += 1
                id_match = data_provider_id_match(run, example)['score']
                sum_id_match += id_match

                # check the args only when id was predicted correctly
                if id_match:
                    sum_args_match += data_provider_args_match(run, example)['score']
    except Exception as e:
        print(f"An error occurred while evaluating LEHD data provider args accuracy: {str(e)}")
        print("OriginDestinationEmploymentDataProvider MIGHT BE MISSING IN THE TEST DATASET")
        return {"key": key,
                "score": -1.0}

    # Guard the two empty denominators explicitly so each gets an accurate diagnostic
    # (previously both surfaced as a ZeroDivisionError blamed on the dataset).
    if lehd_examples == 0:
        print("OriginDestinationEmploymentDataProvider IS MISSING IN THE TEST DATASET")
        return {"key": key,
                "score": -1.0}
    if sum_id_match == 0:
        print("OriginDestinationEmploymentDataProvider was never predicted correctly, so its args accuracy is undefined")
        return {"key": key,
                "score": -1.0}
    return {"key": key,
            "score": sum_args_match / sum_id_match}

### Complexities

In [39]:
# Low complexity
def low_complexity_accuracy_summary_eval(runs: List[Run], examples: List[Example]) -> dict:
    """
    Summary evaluator: share of low-complexity queries where both the data provider id
    and its args were predicted exactly.
    ### Parameters:
    - runs: List[Run] - list of LangChain Run objects
    - examples: List[Example] - list of Langsmith Example objects (from the test dataset)

    ### Returns:
    - dict: {"key": "low_compl_accuracy",
             "score": float} - accuracy over the low complexity queries;
             -1.0 when it cannot be computed (e.g. no "complexity" labels / no low queries)
    """
    try:
        fully_correct = 0
        total = 0
        for run, example in zip(runs, examples):
            if example.outputs["complexity"] != "low":
                continue
            total += 1
            # args only count when the id was right, so this tallies "id AND args" matches
            if data_provider_id_match(run, example)['score']:
                fully_correct += data_provider_args_match(run, example)['score']
        return {"key": "low_compl_accuracy",
                "score": fully_correct / total}
    except Exception as e:
        print(f"An error occurred while evaluating low complexity accuracy: {str(e)}")
        print("LOW COMPLEXITY FEATURE MIGHT BE MISSING IN THE TEST DATASET")
        return {"key": "low_compl_accuracy",
                "score": -1.0}


# Medium complexity
def medium_complexity_accuracy_summary_eval(runs: List[Run], examples: List[Example]) -> dict:
    """
    Summary evaluator: share of medium-complexity queries where both the data provider id
    and its args were predicted exactly.
    ### Parameters:
    - runs: List[Run] - list of LangChain Run objects
    - examples: List[Example] - list of Langsmith Example objects (from the test dataset)

    ### Returns:
    - dict: {"key": "medium_compl_accuracy",
             "score": float} - accuracy over the medium complexity queries;
             -1.0 when it cannot be computed (e.g. no "complexity" labels / no medium queries)
    """
    try:
        # 1 for an "id AND args" match, 0 otherwise; args are only checked when the id matched.
        per_query = [
            data_provider_args_match(run, example)['score']
            if data_provider_id_match(run, example)['score'] else 0
            for run, example in zip(runs, examples)
            if example.outputs["complexity"] == "medium"
        ]
        return {"key": "medium_compl_accuracy",
                "score": sum(per_query) / len(per_query)}
    except Exception as e:
        print(f"An error occurred while evaluating medium complexity accuracy: {str(e)}")
        print("MEDIUM COMPLEXITY FEATURE MIGHT BE MISSING IN THE TEST DATASET")
        return {"key": "medium_compl_accuracy",
                "score": -1.0}


### Bad queries (bad grammar)

In [40]:
def bad_grammar_accuracy_summary_eval(runs: List[Run], examples: List[Example]) -> dict:
    """
    Summary evaluator: share of bad-grammar queries where both the data provider id
    and its args were predicted exactly.
    ### Parameters:
    - runs: List[Run] - list of LangChain Run objects
    - examples: List[Example] - list of Langsmith Example objects (from the test dataset)

    ### Returns:
    - dict: {"key": "bad_grammar_accuracy",
             "score": float} - accuracy over the examples flagged "bad_grammar";
             -1.0 when it cannot be computed (e.g. missing flag / no flagged queries)
    """
    try:
        fully_correct = 0
        flagged = 0
        for run, example in zip(runs, examples):
            if not example.outputs["bad_grammar"]:
                continue
            flagged += 1
            if data_provider_id_match(run, example)['score']:
                fully_correct += data_provider_args_match(run, example)['score']
        return {"key": "bad_grammar_accuracy",
                "score": fully_correct / flagged}
    except Exception as e:
        print(f"An error occurred while evaluating bad grammar accuracy: {str(e)}")
        print("BAD_GRAMMAR FEATURE MIGHT BE MISSING IN THE TEST DATASET")
        return {"key": "bad_grammar_accuracy",
                "score": -1.0}


# Evaluate the system

In [None]:
# `evaluate` calls the target with the example inputs only, so passing `analyst_results`
# directly would always run its "gpt-4o" default. Bind the notebook-level `model_name`
# (initialized by the runner script at the beginning of the notebook) explicitly instead.
def analyst_target(inputs: dict) -> dict:
    """Evaluation target: run `analyst_results` on one example with the configured `model_name`."""
    return analyst_results(inputs, model_name=model_name)


chain_results = evaluate(
    analyst_target,  # AI system (wrapped function with outputs as dict)
    data=test_dataset.name,  # The dataset name to predict and grade over
    evaluators=[
        data_provider_id_match,
        data_provider_args_match,
        executable_match,
        result_match,
        # analysis_code_score_eval,  # requires "analysis_code" in analyst_results outputs
        ],  # The per-example evaluators
    summary_evaluators=[
        taxi_dp_id_accuracy_summary_eval,
        taxi_dp_args_accuracy_summary_eval,
        lehd_dp_accuracy_summary_eval,
        lehd_dp_args_accuracy_summary_eval,
        low_complexity_accuracy_summary_eval,
        medium_complexity_accuracy_summary_eval,
        bad_grammar_accuracy_summary_eval,
        ],  # Summary evaluators scoring the whole experiment
    experiment_prefix=exp_prefix,  # A prefix for your experiment names to easily identify them
    metadata={
      "version": f"{exp_version}",
      "model_name": model_name,  # record which model was actually evaluated
    }
)


# End