In [1]:
import pandas as pd
from tldextract import extract
from statistics import mode
from tabulate import tabulate
from retrieval_importance import learn_importance, encode_retrievals, encode_groups, v_grouped

In [2]:
def _normalize_prediction(prediction, observed_labels, known_labels):
    """Map a lower-cased prediction onto a training label where that is unambiguous."""
    # An exact hit must win over the substring heuristic: otherwise "sony" would be
    # rewritten to an earlier-listed label such as "sony ericsson". The empty string
    # is a substring of every label, so it is passed through unchanged as well.
    if not prediction or prediction in known_labels:
        return prediction

    # Fall back to the first training label that contains the prediction.
    for observed_label in observed_labels:
        if prediction in observed_label:
            return observed_label

    return prediction


def create_retrievals(trainset, target, test_or_valid_set, test_or_valid_answers):
    """Build one retrieval record per sample of a validation or test set.

    Every generated answer is lower-cased. An answer that is not itself a
    (lower-cased) training label but is a substring of one is replaced by the
    first such training label.

    Args:
        trainset: DataFrame whose ``target`` column supplies the observed labels.
        target: Name of the label column in ``trainset`` and ``test_or_valid_set``.
        test_or_valid_set: DataFrame with one row per sample; its index values are
            matched against the ``sample_index`` column of the answers.
        test_or_valid_answers: DataFrame with the columns ``sample_index``,
            ``position``, ``answer`` and ``url``.

    Returns:
        A list of dicts with the keys ``sample_index``, ``correct_answers``
        (single-element list holding the lower-cased label of the sample),
        ``retrieved_websites`` and ``generated_answers``; the latter two are
        aligned and ordered by ``position``.
    """
    retrievals = []

    observed_labels = [label.lower() for label in trainset[target].unique()]
    known_labels = set(observed_labels)

    for sample_index, row in test_or_valid_set.iterrows():

        raw_answers_for_sample = test_or_valid_answers[
            test_or_valid_answers.sample_index == sample_index
        ].sort_values(by=['position'])

        retrieved_websites = []
        generated_answers = []

        for _, raw_answer in raw_answers_for_sample.iterrows():
            prediction = _normalize_prediction(
                raw_answer['answer'].lower(), observed_labels, known_labels)

            retrieved_websites.append(raw_answer['url'])
            generated_answers.append(prediction)

        retrievals.append({
            "sample_index": sample_index,
            "correct_answers": [row[target].lower()],
            "retrieved_websites": retrieved_websites,
            "generated_answers": generated_answers,
        })

    return retrievals

In [3]:
def utility(retrieval, prediction):
    """Score a prediction: 1.0 if it is among the retrieval's correct answers, else 0.0."""
    is_hit = prediction in retrieval["correct_answers"]
    return 1.0 if is_hit else 0.0

In [4]:
def group(retrieved):
    """Reduce a retrieved URL to its ``<domain>.<suffix>`` string.

    Args:
        retrieved: URL string of a retrieved website.

    Returns:
        str: the domain and suffix reported by ``tldextract.extract``, joined by a
        dot; the subdomain is dropped.
    """
    # Use attribute access instead of tuple unpacking: the named fields exist on
    # every tldextract release, whereas newer releases (5.x) return a dataclass
    # that can no longer be unpacked like a 3-tuple.
    parts = extract(retrieved)
    return f'{parts.domain}.{parts.suffix}'

In [5]:
def eval_accuracy(retrievals, k):
    """Accuracy of a majority vote over the first ``k`` generated answers per sample.

    Args:
        retrievals: Non-empty list of dicts with the keys ``generated_answers``
            (list of answers) and ``correct_answers`` (collection of accepted answers).
        k: Number of leading generated answers that take part in the vote.

    Returns:
        float: fraction of samples whose most common answer is a correct answer.

    Raises:
        ZeroDivisionError: if ``retrievals`` is empty.
    """
    correct = 0

    for retrieval in retrievals:
        top_answers = retrieval['generated_answers'][:k]
        # Without any answer there is nothing to vote on (mode() would raise on an
        # empty list); such a sample counts as incorrect, matching how
        # eval_accuracy_thresholded treats samples with no remaining answers.
        if not top_answers:
            continue
        if mode(top_answers) in retrieval['correct_answers']:
            correct += 1

    return correct / len(retrievals)


def eval_accuracy_thresholded(retrievals, group_importances, k, threshold, keep_unknown_sources):
    correct = 0

    for retrieval in retrievals:

        if keep_unknown_sources:
            filtered_answers = [
                answer
                    for website, answer 
                    in zip(retrieval['retrieved_websites'], retrieval['generated_answers'])
                    if not group(website) in group_importances or group_importances[group(website)] > threshold
                ]
        else:
            filtered_answers = [
                answer
                    for website, answer 
                    in zip(retrieval['retrieved_websites'], retrieval['generated_answers'])
                    if group(website) in group_importances and group_importances[group(website)] > threshold            
                ]            
            
        if len(filtered_answers) > 0:
            prediction = mode(filtered_answers[:k])
            if prediction in retrieval['correct_answers']:
                correct += 1

    accuracy = correct / len(retrievals)
    return accuracy

In [6]:
def run_experiment(dataset, target, k, keep_unknown_sources, *, threshold=0.5,
                   learning_rate=0.1, num_steps=500, data_root='applications/imputation'):
    """Learn source-group importances on the validation set and evaluate with and without filtering.

    Reads the train/valid/test CSV files of ``dataset`` and the pre-generated
    answers for the validation and test sets, learns importances per source
    group from the validation retrievals, and reports majority-vote accuracy
    before and after dropping answers from low-importance groups.

    Args:
        dataset: Name of the dataset directory below ``<data_root>/data``; the
            answers are read from ``<data_root>/answers/<dataset>_minilm``.
        target: Name of the label column to impute.
        k: Number of answers used for learning and for the majority vote.
        keep_unknown_sources: Passed to ``eval_accuracy_thresholded``; whether
            answers from groups without a learned importance are kept.
        threshold: Importance threshold used for both the validation and the
            test evaluation.
        learning_rate: Forwarded to ``learn_importance``.
        num_steps: Forwarded to ``learn_importance``.
        data_root: Directory containing the ``data`` and ``answers`` folders;
            the default is relative to the current working directory.

    Returns:
        tuple: ``(validation_accuracy, validation_accuracy_thresholded,
        test_accuracy, test_accuracy_thresholded)``.
    """
    data_dir = f'{data_root}/data/{dataset}'
    answers_dir = f'{data_root}/answers/{dataset}_minilm'

    trainset = pd.read_csv(f'{data_dir}/train.csv')
    validset = pd.read_csv(f'{data_dir}/valid.csv')
    testset = pd.read_csv(f'{data_dir}/test.csv')

    valid_answers = pd.read_csv(f'{answers_dir}/valid.csv')
    test_answers = pd.read_csv(f'{answers_dir}/test.csv')

    # Importances are learned on the validation retrievals only.
    validset_retrievals = create_retrievals(trainset, target, validset, valid_answers)
    encoded_retrievals, mapping = encode_retrievals(validset_retrievals, "retrieved_websites",
                                                    "generated_answers", utility)
    grouping, group_mapping = encode_groups(mapping, group)

    v = learn_importance(encoded_retrievals, k=k, learning_rate=learning_rate,
                         num_steps=num_steps, n_jobs=-1, grouping=grouping)

    group_importances = v_grouped(v, grouping, group_mapping)

    validation_accuracy = eval_accuracy(validset_retrievals, k=k)
    validation_accuracy_thresholded = eval_accuracy_thresholded(
        validset_retrievals, group_importances,
        k=k, threshold=threshold, keep_unknown_sources=keep_unknown_sources)

    # The test set is evaluated with the importances learned on the validation set.
    testset_retrievals = create_retrievals(trainset, target, testset, test_answers)

    test_accuracy = eval_accuracy(testset_retrievals, k=k)
    test_accuracy_thresholded = eval_accuracy_thresholded(
        testset_retrievals, group_importances,
        k=k, threshold=threshold, keep_unknown_sources=keep_unknown_sources)

    return validation_accuracy, validation_accuracy_thresholded, test_accuracy, test_accuracy_thresholded

In [7]:
# 'buy' dataset: impute the manufacturer; answers from sources without a learned
# importance are dropped in the thresholded evaluation.
buy_scores = run_experiment(dataset='buy', target='manufacturer', k=10, keep_unknown_sources=False)
val_acc_buy, val_acc_clean_buy, test_acc_buy, test_acc_clean_buy = buy_scores

# 'restaurant' dataset: impute the city; answers from sources without a learned
# importance are kept in the thresholded evaluation.
restaurant_scores = run_experiment(dataset='restaurant', target='city', k=10, keep_unknown_sources=True)
val_acc_rest, val_acc_clean_rest, test_acc_rest, test_acc_clean_rest = restaurant_scores

In [8]:
# The 'GPT-3 0-shot (test)' figures are hard-coded reference values; they are not
# computed anywhere in this notebook.
gpt3_zero_shot_test = {'buy': 0.846, 'restaurant': 0.709}

summary_headers = [
    'task', 'retr (val)', 'retr+clean (val)', 'GPT-3 0-shot (test)',
    'retr (test)', 'retr+clean (test)',
]
summary_rows = [
    ('buy', val_acc_buy, val_acc_clean_buy, gpt3_zero_shot_test['buy'],
     test_acc_buy, test_acc_clean_buy),
    ('restaurant', val_acc_rest, val_acc_clean_rest, gpt3_zero_shot_test['restaurant'],
     test_acc_rest, test_acc_clean_rest),
]

print(tabulate(summary_rows, headers=summary_headers))


task          retr (val)    retr+clean (val)    GPT-3 0-shot (test)    retr (test)    retr+clean (test)
----------  ------------  ------------------  ---------------------  -------------  -------------------
buy             0.888889            0.923077                  0.846       0.876923             0.892308
restaurant      0.717949            0.762821                  0.709       0.77907              0.790698
