In [1]:
import pandas as pd
from tldextract import extract

# Alternative configuration, currently disabled (restaurant data, predicting the city column):
#dataset = 'restaurant'
#target = 'city'

# Active configuration: `dataset` selects the data/answers sub-folders below,
# `target` is the column whose value is predicted and evaluated.
dataset = 'buy'
target = 'manufacturer'


# Train / validation / test splits, read from paths relative to the working directory.
trainset = pd.read_csv(f'applications/imputation/data/{dataset}/train.csv')
validset = pd.read_csv(f'applications/imputation/data/{dataset}/valid.csv')
testset = pd.read_csv(f'applications/imputation/data/{dataset}/test.csv')

# Answer files for the validation and test splits. The rest of the notebook reads
# the columns sample_index, position, answer and url from these frames.
# NOTE(review): the `_minilm` suffix presumably names the model that produced the answers — confirm.
valid_answers = pd.read_csv(f'applications/imputation/answers/{dataset}_minilm/valid.csv')
test_answers = pd.read_csv(f'applications/imputation/answers/{dataset}_minilm/test.csv')

In [2]:
def create_retrievals(trainset, target, test_or_valid_set, test_or_valid_answers):
    """Assemble one retrieval record per sample of an evaluation split.

    For every row of ``test_or_valid_set`` the answers whose ``sample_index``
    equals the row's index label are collected in ascending ``position`` order.
    Each answer is lower-cased and, if it is a substring of a label observed in
    ``trainset[target]``, replaced by the first such label.

    Args:
        trainset: DataFrame whose ``target`` column supplies the known labels.
        target: Name of the column holding the ground-truth value.
        test_or_valid_set: DataFrame of samples to evaluate; its index labels
            are matched against ``sample_index`` in the answers.
        test_or_valid_answers: DataFrame with the columns ``sample_index``,
            ``position``, ``answer`` and ``url``.

    Returns:
        A list of dicts with the keys ``sample_index``, ``correct_answers``,
        ``retrieved_websites`` and ``generated_answers``. The last two lists
        are aligned element by element.
    """
    # Missing labels are skipped: NaN has no .lower() and can never be matched.
    observed_labels = [label.lower() for label in trainset[target].dropna().unique()]

    retrievals = []

    for sample_index, row in test_or_valid_set.iterrows():

        answers_for_sample = test_or_valid_answers[
            test_or_valid_answers.sample_index == sample_index
        ].sort_values(by=['position'])

        retrieved_websites = []
        generated_answers = []

        for _, raw_answer in answers_for_sample.iterrows():

            answer = raw_answer.answer
            # A missing answer (e.g. an empty CSV cell read as NaN) is kept as ""
            # so that the url list and the answer list stay aligned.
            prediction = "" if pd.isna(answer) else answer.lower()

            # Use the training label if the predicted label is a substring of it.
            # Blank predictions are excluded: "" is a substring of every string and
            # would otherwise be rewritten to the first training label.
            if prediction.strip():
                for observed_label in observed_labels:
                    if prediction in observed_label:
                        prediction = observed_label
                        break

            retrieved_websites.append(raw_answer.url)
            generated_answers.append(prediction)

        retrievals.append({
            "sample_index": sample_index,
            "correct_answers": [row[target].lower()],
            "retrieved_websites": retrieved_websites,
            "generated_answers": generated_answers,
        })

    return retrievals

In [3]:
def utility(retrieval, prediction):
    """Score a prediction against a retrieval record.

    Returns 1.0 when ``prediction`` is one of ``retrieval["correct_answers"]``
    and 0.0 otherwise.
    """
    is_correct = prediction in retrieval["correct_answers"]
    return 1.0 if is_correct else 0.0

In [4]:
def group(retrieved):
    """Return the grouping key for a retrieved URL: ``'<domain>.<suffix>'``.

    The URL is split with ``tldextract.extract`` and the subdomain is dropped,
    giving keys such as ``'kenwood.com'``.

    Args:
        retrieved: URL (or host name) string of a retrieved website.

    Returns:
        The domain and public suffix joined by a dot.
    """
    # Read the parts by attribute instead of unpacking a 3-tuple: newer
    # tldextract releases add a field to the result and no longer support
    # iterable unpacking, while `.domain` / `.suffix` exist in every version.
    parts = extract(retrieved)
    return f'{parts.domain}.{parts.suffix}'

In [5]:
from retrieval_importance import learn_importance, encode_retrievals, encode_groups, v_grouped

# Build per-sample retrieval records (urls + normalised answers) for the validation split.
validset_retrievals = create_retrievals(trainset, target, validset, valid_answers)
# Encode the records for retrieval_importance, naming the two record keys to read and
# passing `utility` (1.0 for a correct prediction, 0.0 otherwise) as the scoring function.
encoded_retrievals, mapping = encode_retrievals(validset_retrievals, "retrieved_websites", "generated_answers", 
                                                utility)
# Group the retrieved sources by the key that group() derives from each url.
# NOTE(review): assumes `mapping` enumerates the distinct retrieved urls — confirm against retrieval_importance.
grouping, group_mapping = encode_groups(mapping, group)

# Learn importance weights from the validation records, with weights shared per group.
# NOTE(review): n_jobs=-1 presumably means "use all cores" — confirm against retrieval_importance.
v = learn_importance(encoded_retrievals, k=10, learning_rate=0.1, num_steps=500, n_jobs=-1, grouping=grouping)

# Express the learned weights per group; later cells use the result as a
# mapping from domain string to importance value.
group_importances = v_grouped(v, grouping, group_mapping)

In [6]:
from operator import itemgetter
# Ten domains with the highest learned importance, highest first.
sorted(group_importances.items(), key=lambda item: item[1], reverse=True)[:10]

[('kenwood.com', 0.542688825629156),
 ('tacomaworld.com', 0.542688825629156),
 ('appleiphonereview.com', 0.5363816321662923),
 ('internetrader.com', 0.5363816321662923),
 ('huntoffice.ie', 0.5363378455875544),
 ('thinkcomputers.org', 0.5363378455875544),
 ('omnimount.com', 0.5356277998146017),
 ('mooz.reviews', 0.5329131550708474),
 ('qualitymobilevideo.com', 0.5317792209735135),
 ('ldproducts.com', 0.5312337329329281)]

In [7]:
# Ten domains with the lowest learned importance, lowest first.
sorted(group_importances.items(), key=lambda item: item[1])[:10]

[('nikonimgsupport.com', 0.4316519523395776),
 ('bobvila.com', 0.44547142554403796),
 ('homechit.com', 0.4488742787499718),
 ('dvdfab.org', 0.45874378505397057),
 ('superuser.com', 0.4626051883264587),
 ('quora.com', 0.46492755983875594),
 ('askanydifference.com', 0.46547241617616625),
 ('cheapassgamer.com', 0.46547241617616625),
 ('manuals.plus', 0.4694615495112489),
 ('columbia.edu', 0.46971805159528907)]

In [8]:
from statistics import mode

def eval_accuracy(retrievals, k):
    """Accuracy of a majority vote over the first ``k`` answers of each sample.

    Args:
        retrievals: Sequence of records with the keys ``generated_answers``
            (ordered list of answers) and ``correct_answers``.
        k: Number of leading answers that take part in the vote.

    Returns:
        Fraction of records whose most frequent answer among the first ``k``
        is one of the record's correct answers.

    Raises:
        ZeroDivisionError: If ``retrievals`` is empty.
    """
    correct = 0

    for retrieval in retrievals:
        top_answers = retrieval['generated_answers'][:k]
        # A sample without any answer cannot be voted on (mode([]) would raise);
        # count it as incorrect, as eval_accuracy_thresholded does.
        if not top_answers:
            continue
        if mode(top_answers) in retrieval['correct_answers']:
            correct += 1

    return correct / len(retrievals)


def eval_accuracy_thresholded(retrievals, group_importances, k, threshold):
    """Majority-vote accuracy using only answers from sufficiently important groups.

    An answer takes part in the vote only if the group of its website is a key
    of ``group_importances`` and that importance is strictly greater than
    ``threshold``. The vote is taken over the first ``k`` surviving answers.
    A record with no surviving answers is counted as incorrect.

    Returns:
        Number of correctly voted records divided by ``len(retrievals)``.
    """
    num_correct = 0

    for record in retrievals:
        kept = []
        for website, answer in zip(record['retrieved_websites'], record['generated_answers']):
            domain = group(website)
            if domain not in group_importances:
                continue
            if group_importances[domain] > threshold:
                kept.append(answer)

        if not kept:
            continue

        if mode(kept[:k]) in record['correct_answers']:
            num_correct += 1

    return num_correct / len(retrievals)

In [9]:
# Baseline on the validation split: majority vote over the first 10 answers, no source filtering.
eval_accuracy(validset_retrievals, k=10)

0.8888888888888888

In [10]:
# Validation accuracy when only answers from domains with importance > 0.5 may vote.
# Note: the importances were learned from these same validation retrievals.
eval_accuracy_thresholded(validset_retrievals, group_importances, k=10, threshold=0.5)

0.9230769230769231

In [11]:
# Build the retrieval records for the test split (labels still normalised against the training set).
testset_retrievals = create_retrievals(trainset, target, testset, test_answers)

In [12]:
# Baseline on the test split: majority vote over the first 10 answers, no source filtering.
eval_accuracy(testset_retrievals, k=10)

0.8769230769230769

In [13]:
# Test accuracy when only answers from domains with importance > 0.5 may vote;
# group_importances is the mapping computed from the validation retrievals.
eval_accuracy_thresholded(testset_retrievals, group_importances, k=10, threshold=0.5)

0.8923076923076924