In [None]:
#######################################################################################################
##                                                                                                   ##
##    VARIABLES "dataset_name", "exp_prefix", "exp_version", "model_name", "code_retry_limit" AND    ##
##           "max_concurrency" WILL BE CREATED BY THE nbconvert.ExecutePreprocessor ABOVE            ##
##                                                                                                   ##
#######################################################################################################
# dataset_name='Taxi + LEHD evaluation - 280 examples'
# exp_prefix='baseline-29-01_'
# exp_version='1.0.0'
# model_name='gpt-4o-mini'
# code_retry_limit=2
# max_concurrency=1

# Import libraries and initialize variables

In [None]:
from sttn.nli.analyst import STTNAnalyst

from langchain.globals import set_debug
from langchain_openai import ChatOpenAI
from langsmith import Client, traceable
from langsmith.evaluation import evaluate

import backoff
import openai
import traceback
import gc
import re
import numpy as np
import sttn.algorithms.community.detection

from sttn.eval.evaluators import (Evaluators, SummaryEvaluators, max_tries,
                                                       base,
                                                       factor,
                                                       max_value)

# Keep LangChain's global debug output off during the evaluation run
set_debug(False)

# LangSmith client, used below to read the evaluation dataset
client = Client()

# Judge LLM handed to the evaluators; temperature 0.0 for repeatable grading
eval_llm = ChatOpenAI(
    model="gpt-4o",
    temperature=0.0,
    max_retries=3,
)

### Load the LangSmith dataset

In [None]:
# Fetch the evaluation dataset by name and print each of its metadata fields
test_dataset = client.read_dataset(dataset_name=dataset_name)
print(f"\n\n------------Info about {test_dataset.name} dataset------------\n")
for field_name, field_value in test_dataset.dict().items():
    print(f'{field_name}: {field_value}')

# Wrap `STTNAnalyst.chat` results in a compatible function with backoff

In [None]:
@backoff.on_exception(
    backoff.expo,
    openai.RateLimitError,
    max_tries=max_tries,
    base=base,
    factor=factor,
    max_value=max_value,
)
def get_context_with_backoff(inputs: dict, model_name: str, code_retry_limit: int):
    """Run a fresh STTNAnalyst on one example's question.

    On `openai.RateLimitError` the whole call is retried with exponential
    backoff; the retry limits are imported from `sttn.eval.evaluators`.

    Args:
        inputs: example inputs; only the "question" entry is read here.
        model_name: LLM name passed to STTNAnalyst.
        code_retry_limit: passed through to STTNAnalyst.

    Returns:
        The context returned by `STTNAnalyst.chat`, with `analysis_code`
        normalized to a string ('' when the Analyst produced no code).
    """
    fresh_analyst = STTNAnalyst(model_name=model_name, code_retry_limit=code_retry_limit)
    context = fresh_analyst.chat(user_query=inputs["question"])
    generated_code = context.analysis_code
    context.analysis_code = '' if not generated_code else str(generated_code)
    return context

def _assigned_names(code: str) -> set:
    """Return the variable names that `code` binds (assignment, unpacking, `+=`, for/with targets)."""
    import ast
    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError):
        # Unparseable code: fall back to a textual scan for `name =` (but not `==`)
        return set(re.findall(r'\b([a-zA-Z_]\w*)\b\s?=(?!=)', code))
    # Only Store-context names are real bindings; keyword arguments (`f(x=1)`) and
    # attribute targets (`obj.x = 1`) are deliberately excluded
    return {node.id for node in ast.walk(tree)
            if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Store)}


def delete_generated_temp_vars(analysis_code: str, id: int):
    """Remove from this notebook's globals the variables assigned by Analyst-generated code.

    The premise of this cleanup is that the generated code's variables end up in
    this namespace (TODO confirm against STTNAnalyst). Dropping them releases the
    intermediate objects between queries.

    Args:
        analysis_code: source of the generated code; None or '' is a no-op.
        id: query id of the example, kept for log context (currently unused).
    """
    if not analysis_code:
        return
    namespace = globals()
    for var in _assigned_names(analysis_code):
        # Names that never landed in this namespace (e.g. locals of generated
        # functions) are simply skipped instead of raising
        namespace.pop(var, None)

In [None]:
@traceable
def analyst_results(inputs: dict, model_name=model_name, code_retry_limit=code_retry_limit) -> dict:
    """
    Wrapper function to get the results from the Analyst and return them in a dictionary
    Args:
        inputs: dict, the inputs to the Analyst
        model_name: str, the name of the model to use
        code_retry_limit: int, the number of times to retry the code
    Returns:
        dict, the results from the Analyst
    """
    
    # Initialize an empty result dictionary
    empty_result_dict = {"data_provider_id": "",
                         "data_provider_args": {},
                         "result": None,
                         "executable": False,
                         "analysis_code": "NO CODE FROM ANALYST, RETURN 0"}
    
    print(f"\nQuery_ID: {inputs['id']}, INFO:\n\tQuery:', {inputs['question']}\n")
    # Get the context from the Analyst
    try:
        context = get_context_with_backoff(inputs=inputs, model_name=model_name, code_retry_limit=code_retry_limit)
    except Exception as e:
        print(f"\n\nQuery_ID: {inputs['id']}, ERROR:\n\tAn error happened while launching Analyst (return empty dict instead)\n\tError message:")
        traceback.print_exc()
        print("\n\t|| END OF ERROR ||\n\n")
        return empty_result_dict
    
    try:
        #print(f"\nQuery_ID: {inputs['id']}, DEBUG:\n\tChecking context for query ID {inputs['id']}:")
        # Get the data provider from the context if it exists
        if not context.data_provider or not context.feasible:
            return empty_result_dict
        else:
            #print(f'\tData provider: {context.data_provider}')
            data_provider = context.data_provider

        # Get the data provider id and args from the context if they both exist
        if context.data_provider_id and context.data_provider_args:
            data_provider_id = context.data_provider_id
            # lowercase the values of context.data_provider_args dictionary
            data_provider_args = context.data_provider_args
            data_provider_args = {k: v.lower().strip() if isinstance(v, str) else v for k, v in data_provider_args.items()}
            #print(f"\tData provider ID: {data_provider_id}\n\tData provider args: {data_provider_args}")
        else:
            return empty_result_dict
        
        # Turn the result into a float if possible else None
        try:
            #print(f"\tResult: {context.result}\n")
            context.result = float(context.result)
            result = round(context.result, 5)
            if np.isnan(result):
                result = None
        except:
            print(f"\nQuery_ID: {inputs['id']}, WARNING:\n\tError while converting result to float (return None instead)\n\tError message:{e}\n\t|| END OF ERROR ||\n")
            result = None
        
        # Add an evaluation query to the analysis code for the geospatial/temporal awareness
        if context.analysis_code:
            analysis_code = f"We have the following data structure: {data_provider.__doc__} \
                            \n{data_provider.get_data.__doc__}\
                            \nWe retrieved the SpatioTemporalNetwork with the following arguments {context.data_provider_args}\
                            \nThe code looks like this:\n" + str(context.analysis_code) 
        else:
            analysis_code= "NO CODE FROM ANALYST, RETURN 0"
        
        # extract temporary names from the analysis code and delete them
        #print(f"Query_ID: {inputs['id']}, DEBUG:\n\tDelete temp vars...")
        delete_generated_temp_vars(context.analysis_code, inputs['id'])

        output = {"data_provider_id": data_provider_id,
                "data_provider_args": data_provider_args,
                "result": result,
                "executable": True,
                "analysis_code": analysis_code,
                }
        
        # print(f"Query_ID: {inputs['id']}, DEBUG:\n\tFinal output:", output)
        return output

    except Exception as e:
        print(f"\nQuery_ID: {inputs['id']}, ERROR:\n\tUnexpected error appeared while transforming Analyst's output\n\tError message:{e}\n\t|| END OF ERROR ||\n")
        return empty_result_dict

# Evaluate the system

## Initialize evaluators

In [None]:
# Per-example evaluators, given the judge LLM `eval_llm` (presumably used by the LLM-graded awareness scores)
evaluators = Evaluators(eval_llm)
# Summary (whole-experiment) evaluators, constructed from the per-example evaluators
summary_evaluators = SummaryEvaluators(evaluators)

## Run the evaluation

In [None]:
# `model_name` is injected by the runner script at the beginning of the notebook.
# NOTE(review): this instance is not referenced below — `analyst_results` builds its
# own STTNAnalyst for every query. Kept unchanged.
analyst = STTNAnalyst(model_name=model_name, verbose=False)

# Evaluators applied to each example's run
per_example_evaluators = [
    evaluators.data_provider_id_match,
    evaluators.data_provider_args_match,
    evaluators.executable_match,
    evaluators.result_match,
    evaluators.get_geosp_aware_eval_score,
    evaluators.get_temp_aware_eval_score,
]

# Summary evaluators applied once to the whole experiment
experiment_summary_evaluators = [
    summary_evaluators.taxi_dp_id_accuracy_summary_eval,
    summary_evaluators.taxi_dp_args_accuracy_summary_eval,
    summary_evaluators.taxi_dp_result_accuracy_summary_eval,
    summary_evaluators.lehd_dp_accuracy_summary_eval,
    summary_evaluators.lehd_dp_args_accuracy_summary_eval,
    summary_evaluators.lehd_dp_result_accuracy_summary_eval,
    summary_evaluators.geospatial_awr_result_accuracy_summary_eval,
    summary_evaluators.temporal_awr_result_accuracy_summary_eval,
    summary_evaluators.comm_det_result_accuracy_summary_eval,
    summary_evaluators.pagerank_result_accuracy_summary_eval,
    summary_evaluators.net_dens_result_accuracy_summary_eval,
    summary_evaluators.cen_deg_result_accuracy_summary_eval,
    summary_evaluators.clust_coef_result_accuracy_summary_eval,
    summary_evaluators.poorly_written_args_accuracy_summary_eval,
    summary_evaluators.poorly_written_result_accuracy_summary_eval,
]

chain_results = evaluate(
    analyst_results,                                  # system under test: returns the Analyst output as a dict
    data=test_dataset.name,                           # dataset to predict and grade over
    evaluators=per_example_evaluators,
    summary_evaluators=experiment_summary_evaluators,
    experiment_prefix=exp_prefix,                     # prefix that makes experiment names easy to identify
    metadata={"version": f"{exp_version}"},
    max_concurrency=max_concurrency,                  # maximum number of concurrent evaluations
)


# Get evaluation metrics

In [None]:
print("\n------------------Evaluation metrics------------------")
# Per-example metrics: running score totals and score counts, keyed by metric name.
# Summary-evaluator scores are also stored in `results_dict` further down.
results_dict = {}
counts_dict = {}
try:
    # NOTE(review): `_results` / `_summary_results` are private attributes of the
    # object returned by `evaluate`; re-check them when upgrading langsmith.
    for example_res in chain_results._results:
        for eval_metric in example_res['evaluation_results']['results']:
            # Skip placeholder metrics and unscored ones: a None score would make
            # `+` raise TypeError and abort the whole report
            if eval_metric.key == '__ignore__' or eval_metric.score is None:
                continue
            results_dict[eval_metric.key] = results_dict.get(eval_metric.key, 0) + eval_metric.score
            counts_dict[eval_metric.key] = counts_dict.get(eval_metric.key, 0) + 1

    # Per-example metrics are reported as the mean over the examples that were scored
    mean_scores = {metric: total / counts_dict[metric] for metric, total in results_dict.items()}
    for metric, mean_score in mean_scores.items():
        if mean_score >= 0:
            print(f"'{metric}': {mean_score}")
    print('\n')

    summary_scores = {}
    for sum_eval_metric in chain_results._summary_results['results']:
        results_dict[sum_eval_metric.key] = sum_eval_metric.score
        summary_scores[sum_eval_metric.key] = sum_eval_metric.score
        if sum_eval_metric.score is not None and sum_eval_metric.score >= 0:
            print(f"'{sum_eval_metric.key}': {sum_eval_metric.score}")

    # Negative scores mark empty/ignored metrics; unscored (None) summary metrics are
    # listed too. Per-example metrics are shown as means, like in the listing above.
    print('\nEmpty/ignored metrics below:')
    for metric, score in list(mean_scores.items()) + list(summary_scores.items()):
        if metric.startswith('__ignore'):
            continue
        if score is None or score < 0:
            print(f"'{metric}': {score}")

except Exception as e:
    print(f'\nERROR:\n\tAn error occurred while processing the results: {str(e)}')
    print("Please check the results by the link below.")
    print("Please check the results by the link below.")

# End