In [None]:
# Uncomment to install the dependency into the running kernel's environment:
# %pip install tldextract

In [1]:
import json
import tldextract
from pprint import pp
from retrieval_importance import learn_importance, encode_retrievals, encode_groups, v_grouped, \
    most_important_groups, least_important_groups

In [2]:
# Load the evaluation records: one JSON object per line (JSONL).
# Read as UTF-8 explicitly so the result does not depend on the platform's
# locale encoding, and skip blank lines (e.g. an extra empty line at the end
# of the file), which json.loads would reject with a JSONDecodeError.
with open('test_data/wikifact_author.jsonl', encoding='utf-8') as f:
    retrievals = [json.loads(line) for line in f if line.strip()]

In [3]:
# Pretty-print the first record to inspect which fields each entry carries.
pp(retrievals[0], compact=True, width=100)

{'question': 'The author of Nimmer on Copyright is',
 'correct_answers': ['David Nimmer', 'David R. Nimmer', 'David Richard Nimmer', 'Melville Nimmer',
                     'Melville Bernard Nimmer', 'Melville B. Nimmer', 'Melville Bernard Nimmer'],
 'retrieved_websites': ['en.wikipedia.org', 'books.google.com', 'www.goodreads.com',
                        'library.law.yale.edu', 'store.lexisnexis.com', 'www.copyright-protect.net',
                        'books.google.com', 'www.top-law-schools.com',
                        'scholarship.law.vanderbilt.edu', 'www.goodreads.com',
                        'lawcat.berkeley.edu', 'openlibrary.org', 'wikimili.com',
                        'lawcat.berkeley.edu', 'www.pbookshop.com', 'openlibrary.org',
                        'www.wipo.int', 'www.worldcat.org', 'store.lexisnexis.com',
                        'youbookinc.com', 'link.law.upenn.edu', 'www.copyright.gov',
                        'www.worldcat.org', 'www.casemine.com', 'www.overdri

In [4]:
def utility(retrieval, prediction):
    """Score one prediction against the gold answers of a retrieval record.

    Args:
        retrieval: a record with a ``"correct_answers"`` entry (a list of
            accepted answer strings in the loaded data).
        prediction: the predicted answer to check.

    Returns:
        1.0 if ``prediction`` is a member of ``retrieval["correct_answers"]``
        (plain ``in`` test, so exact and case-sensitive for strings),
        otherwise 0.0.
    """
    is_correct = prediction in retrieval["correct_answers"]
    return 1.0 if is_correct else 0.0

In [5]:
def group(retrieved):
    """Map a retrieved website to its domain + public suffix, used as the group key.

    ``tldextract.extract`` splits a host into subdomain / domain / suffix; the
    subdomain is dropped, so e.g. ``'en.wikipedia.org'`` and
    ``'www.wikipedia.org'`` both map to ``'wikipedia.org'``.

    Args:
        retrieved: a hostname or URL string (the data holds hostnames such as
            ``'books.google.com'``).

    Returns:
        ``'<domain>.<suffix>'``; when either part is empty (e.g. an IP address
        or ``localhost`` has no public suffix) only the non-empty part is
        returned, without a dangling ``'.'``.
    """
    url_parts = tldextract.extract(retrieved)
    # Join only non-empty parts so an empty suffix/domain cannot leave a stray dot.
    return '.'.join(part for part in (url_parts.domain, url_parts.suffix) if part)

In [6]:
# Encode the raw records for learn_importance. Positional arguments: the records,
# the field listing each record's retrieved sources ("retrieved_websites"), a
# second per-record field ("retrieved_answers" -- presumably the answer produced
# from each retrieved source; TODO confirm against encode_retrievals), and the
# `utility` scoring function defined earlier.
encoded_retrievals, mapping = encode_retrievals(retrievals, "retrieved_websites", "retrieved_answers", utility)
# Assign every encoded source a group key via `group` (domain + suffix), so that
# different hosts of the same domain presumably share one group -- TODO confirm
# the exact contract of encode_groups.
grouping, group_mapping = encode_groups(mapping, group)

In [7]:
%%time
# Learn importance values from the encoded retrievals, using the grouping built
# above. k, learning_rate, num_steps and n_jobs are hyperparameters of
# learn_importance (n_jobs presumably = number of parallel workers -- TODO confirm).
# NOTE(review): consider hoisting these magic numbers into named constants in a
# config cell, and fixing a seed if learn_importance is stochastic.
v = learn_importance(encoded_retrievals, k=10, learning_rate=0.1, num_steps=50, n_jobs=4, grouping=grouping)

CPU times: user 779 ms, sys: 30.1 ms, total: 809 ms
Wall time: 227 ms


In [8]:
# Aggregate the learned values into one score per group key (domain); the
# result is what most_important_groups / least_important_groups rank.
v_per_group = v_grouped(v, grouping, group_mapping)

In [9]:
# Top 5 groups (domains) by learned importance, as (group, score) pairs, highest first.
most_important_groups(v_per_group, 5)

[('wikiwand.com', 0.5508036219874283),
 ('goodreads.com', 0.5401960523473347),
 ('amazon.com', 0.5277891005927581),
 ('wikimili.com', 0.522149887005473),
 ('openlibrary.org', 0.5211981876261937)]

In [10]:
# Bottom 5 groups (domains) by learned importance, as (group, score) pairs, lowest first.
least_important_groups(v_per_group, 5)

[('imdb.com', 0.4877848228428647),
 ('researchgate.net', 0.48954101212472395),
 ('litcharts.com', 0.4909785382466099),
 ('reddit.com', 0.4933304393088931),
 ('youbookinc.com', 0.4945495060356605)]