In [1]:
import pandas as pd
import pickle
from tqdm.auto import tqdm
import time

from joblib import Parallel, delayed

%load_ext autoreload
%autoreload 2
from pywikidata import Entity
from kbqa.candidate_selection import QuestionToRankInstanceOf
from kbqa.logger import get_logger

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
logger = get_logger()

dataset_version = 'mintaka' # wdsq, rubq

if dataset_version == 'mintaka':
    answers_df = pd.read_pickle('./mintaka_onehop_answers_no_tree.pkl')
    mgenre_df = pd.read_pickle('./mintaka_onehop_mgenre_no_tree.pkl')
elif dataset_version == 'wdsq':
    answers_df = pd.read_csv('./filtered_test_with_answers.csv')
    mgenre_df = pd.read_pickle('./filtered_test_with_mgenre_no_prefix_tree.pkl')
elif dataset_version == 'rubq':
    answers_df = pd.read_csv('rubq_test_with_answers_no_prefix.csv')
    mgenre_df = pd.read_pickle('./test_rubq2_mgenre.pkl')
else:
    raise ValueError('Wrong dataset_version')

mgenre_df = mgenre_df.loc[:,~mgenre_df.columns.duplicated()]
answers_df = answers_df.loc[:,~answers_df.columns.duplicated()]

mgenre_df.head()

Unnamed: 0,id,lang,question,answerText,category,complexityType,questionEntity,answerEntity,Q,mgenre_results,selected_mgenre_results,selected_entities
0,15aae099,en,Who is George Clooney's current wife?,Amal Clooney,movies,generic,"[{'name': 'Q23844', 'entityType': 'entity', 'l...","[{'name': 'Q16769592', 'label': 'Amal Clooney'}]",Who is George Clooney's current wife?,"[George Clooney >> en, George Clooney >> id, G...",[George Clooney >> en],[Q23844]
1,da14605c,en,Who got the horse head in The Godfather?,Luca Brasi,movies,generic,"[{'name': 'Q47703', 'entityType': 'entity', 'l...","[{'name': 'Q1159541', 'label': 'Luca Brasi'}]",Who got the horse head in The Godfather?,"[The Horse Head in the Godfather >> en, The Ho...","[The Horse Head in the Godfather >> en, The Ho...",[]
2,4962bc3f,en,What movie won the Oscar for Best Picture in 1...,Gandhi,movies,generic,"[{'name': 'Q102427', 'entityType': 'entity', '...","[{'name': 'Q202211', 'label': 'Gandhi'}]",What movie won the Oscar for Best Picture in 1...,[What Movie Won the Oscar for Best Picture in ...,[],[]
3,13eaddb4,en,Who won Best Actress at the Oscars in 2005?,Hilary Swank,movies,generic,"[{'name': 'Q19020', 'entityType': 'entity', 'l...","[{'name': 'Q93187', 'label': 'Hilary Swank'}]",Who won Best Actress at the Oscars in 2005?,"[Anexo:Óscar a la mejor actriz >> es, Anexo:Pr...",[],[]
4,32cd4730,en,Who is the actor that played John Kramer in Saw?,Tobin Bell,movies,generic,"[{'name': 'Q12320195', 'entityType': 'entity',...","[{'name': 'Q310190', 'label': 'Tobin Bell'}]",Who is the actor that played John Kramer in Saw?,"[Actor >> en, John Kramer >> en, John Kramer (...",[Actor >> en],"[Q291840, Q421957, Q557214, Q2473937, Q2823758..."


In [6]:
answer_cols = [c for c in answers_df.columns if 'answer_' in c]

if dataset_version == 'wdsq':
    df = pd.merge(left=answers_df[['Q']+answer_cols], right=mgenre_df[['O', 'Q', 'selected_entities']], left_on='Q', right_on='Q', how='left')[['O', 'Q', 'selected_entities']+answer_cols]
elif dataset_version == 'rubq':
    df = pd.merge(left=answers_df, right=mgenre_df, left_on='Q', right_on='Q', how='left')[['O', 'Q', 'selected_entities']+answer_cols]
elif dataset_version == 'mintaka':
    cols_to_use = ['id'] + answers_df.columns.difference(mgenre_df.columns).tolist()
    df = pd.merge(left=answers_df[cols_to_use], right=mgenre_df, left_on='id', right_on='id', how='left', )
    df['O'] = df['answerEntity'].apply(lambda e: [_e.get('name') for _e in e])

df.head()

Unnamed: 0,id,answer_0,answer_1,answer_10,answer_100,answer_101,answer_102,answer_103,answer_104,answer_105,...,answerText,category,complexityType,questionEntity,answerEntity,Q,mgenre_results,selected_mgenre_results,selected_entities,O
0,15aae099,Cate Blanchett,Angelina Jolie,Sofia Coppola,Denise Richards,Kristin Clooney,Alison Lohman,Jennifer Aniston,Helena Clooney,Jennifer Aniston Clooney,...,Amal Clooney,movies,generic,"[{'name': 'Q23844', 'entityType': 'entity', 'l...","[{'name': 'Q16769592', 'label': 'Amal Clooney'}]",Who is George Clooney's current wife?,"[George Clooney >> en, George Clooney >> id, G...",[George Clooney >> en],[Q23844],[Q16769592]
1,da14605c,Francis Ford Coppola,Francis Ford Coppola,Frank Sinatra,Joe DiMaggio,Charles Bronson,Robert De Niro,Vito Corleone,George C. Scott,Michael Cimino,...,Luca Brasi,movies,generic,"[{'name': 'Q47703', 'entityType': 'entity', 'l...","[{'name': 'Q1159541', 'label': 'Luca Brasi'}]",Who got the horse head in The Godfather?,"[The Horse Head in the Godfather >> en, The Ho...","[The Horse Head in the Godfather >> en, The Ho...",[],[Q1159541]
2,4962bc3f,The Best Years of Our Lives,To Kill a Mockingbird,Night of the Living Dead,The Great Gatsby,The Adventures of Tom Sawyer,The Greatest Showman,Dracula II: The Final Cut,Underworld: Awakening,Goodfellas,...,Gandhi,movies,generic,"[{'name': 'Q102427', 'entityType': 'entity', '...","[{'name': 'Q202211', 'label': 'Gandhi'}]",What movie won the Oscar for Best Picture in 1...,[What Movie Won the Oscar for Best Picture in ...,[],[],[Q202211]
3,13eaddb4,Cate Blanchett,Meryl Streep,Angelina Jolie,Jennifer Aniston,Sandra Bullock,Jessica Lange,Sarah Jessica Parker,Jodie Foster,Kate Winslet,...,Hilary Swank,movies,generic,"[{'name': 'Q19020', 'entityType': 'entity', 'l...","[{'name': 'Q93187', 'label': 'Hilary Swank'}]",Who won Best Actress at the Oscars in 2005?,"[Anexo:Óscar a la mejor actriz >> es, Anexo:Pr...",[],[],[Q93187]
4,32cd4730,John Kramer,John Kramer,John Kramer,Paul Rudd,Bill Cosby,James Earl Jones,Joseph Cotten,Robert De Niro,Paul Newman,...,Tobin Bell,movies,generic,"[{'name': 'Q12320195', 'entityType': 'entity',...","[{'name': 'Q310190', 'label': 'Tobin Bell'}]",Who is the actor that played John Kramer in Saw?,"[Actor >> en, John Kramer >> en, John Kramer (...",[Actor >> en],"[Q291840, Q421957, Q557214, Q2473937, Q2823758...",[Q310190]


In [7]:
import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

In [8]:
def _row_proc(row):
    answers_candidates = []
    for lbl in row[answer_cols].dropna().unique():
        try:
            answers_candidates.extend(Entity.from_label(lbl)[:2])
        except ValueError:
            pass

    question_entities = [Entity(e) for e in row['selected_entities']]

    qtr = QuestionToRankInstanceOf(
        row['Q'],
        question_entities,
        answers_candidates,
        only_forward_one_hop=True,
    )

    answers = qtr.final_answers()
    if len(answers) > 0:
        return answers[0][1].idx
    else:
        return None

# filtered_answers = Parallel(n_jobs=6)(
#     delayed(_row_proc)(row)
#     for _, row in tqdm(df.iterrows(), total=df.index.size)
# )
# filtered_answers = [_row_proc(row) for _, row in tqdm(df.iterrows(), total=df.index.size)]
filtered_answers = []
for _, row in tqdm(df.iterrows(), total=df.index.size):
    filtered_answers.append(_row_proc(row))

100%|██████████| 340/340 [00:46<00:00,  7.27it/s]


In [9]:
import re

df['filtered_answer'] = filtered_answers
df['filtered_answer'] = df['filtered_answer'].apply(lambda x: Entity(x) if x is not None else None)

def _parse(s, pattern=re.compile(r'Q[0-9]+')):
    if isinstance(s, str):
        try:
            return [Entity(e) for e in re.findall(pattern, s)]
        except:
            return []
    else:
        return [Entity(e) for e in s]


df['target'] = df['O'].apply(_parse)

df['is_correct'] = df.apply(
    lambda row: row['filtered_answer'] in row['target'],
    axis=1
)

df[df['is_correct']].index.size / df.index.size

0.18235294117647058

In [28]:
class QuestionToRankInstanceOfHtml(QuestionToRankInstanceOf):
    def _repr_html_(self) -> str:
        html = ["""
        <style>
        .flex-row-container {
            display: flex;
            flex-wrap: wrap;
        }
        .flex-row-container > .flex-row-item {
            flex: 1 0 29%; /*grow | shrink | basis */
        }

        .flex-row-item {
            margin: 10px;
        }

        th {
            word-wrap: break-word;
        }

        td {
            word-wrap: break-word;
        }
        </style>
        """]
        html.append(f'<b>Question:</b> {self.question}')
        if self.target is not None:
            for te in self.target:
                html.append(
                    f'<br><b>Target:</b> <span style="color: green">Entity: {te.idx} ({te.label})</span> (InstanceOf: {"; ".join([f"{e.idx} ({e.label})" for e in te.instance_of])})'
                )
        html.append('<div class="flex-row-container">')

        # FINAL ANSWERS
        html_final_answers = ["<h4>Final answers</h4>"]
        html_final_answers.extend([
            '<table style="border: 1px solid #A9ED2B; width: 900px;">',
            '<tr style="font-size:1rem; font-weight: bold; background-color: #A9ED2B">',
            "<th>Property</th>",
            "<th>P Label</th>",
            "<th>Entity</th>",
            "<th>E Label</th>",
            "<th>InstanceOf</th>",
            "<th>instance of score</th>",
            "<th>forward one hop neighbors score</th>",
            "<th>answers candidates score</th>",
            "<th>property question intersection score</th>",
            "</tr>"
        ])
        for property, entity, instance_of_score, forward_one_hop_neighbors_score, answers_candidates_score, property_question_intersection_score in self.final_answers():
            if self.target is not None and entity in self.target:
                html_final_answers.append('<tr style="background-color: #3CF30F">')
            else:
                html_final_answers.append('<tr>')
            html_final_answers.extend([
                f'<td>{property.idx if property is not None else ""}</td>',
                f'<td>{property.label if property is not None else ""}</td>',
                f'<td>{entity.idx}</td>',
                f'<td>{entity.label}</td>',
                f'<td>{"<br>".join([f"{io.idx} ({io.label})" for io in entity.instance_of])}</td>',
                f'<td>{instance_of_score:.5f}</td>',
                f'<td>{forward_one_hop_neighbors_score:.5f}</td>',
                f'<td>{answers_candidates_score:.5f}</td>',
                f'<td>{property_question_intersection_score:.5f}</td>',
                '</tr>',
            ])
        html_final_answers.append("</table>")
        html_final_answers = "".join(html_final_answers)
        
        # QUESTION
        html_question_entities = []
        for qentity in self.question_entities:
            html_question_entities.append(f"<h4>One hop neighbors for Entity: {qentity.idx} ({qentity.label})</h4>")
            html_question_entities.extend([
                '<table style="width: 700px;">',
                '<tr style="font-size:1rem; font-weight: bold; background-color: #50ADFF">',
                '<th>Dir</th>',
                "<th>Property</th>",
                "<th>P Label</th>",
                "<th>Entity</th>",
                "<th>E Label</th>",
                "<th>InstanceOf</th>",
                "</tr>"
            ])

            if self.only_forward_one_hop:
                neighbors = qentity.forward_one_hop_neighbors
            else:
                neighbors = qentity.forward_one_hop_neighbors + qentity.backward_one_hop_neighbors

            for property, entity in neighbors:
                if self.target is not None and entity in self.target:
                    html_question_entities.append('<tr style="background-color: #3CF30F">')
                elif entity in self.final_answers():
                    html_question_entities.append('<tr style="background-color: #A9ED2B">')
                else:
                    html_question_entities.append('<tr>')
                html_question_entities.extend([
                    f'<td>{"->" if (property, entity) in qentity.forward_one_hop_neighbors else "<-"}</td>'
                    f'<td>{property.idx}</td>',
                    f'<td>{property.label}</td>',
                    f'<td>{entity.idx}</td>',
                    f'<td>{entity.label}</td>',
                    f'<td>{"<br>".join([f"{io.idx} ({io.label})" for io in entity.instance_of])}</td>',
                    '</tr>',
                ])

            html_question_entities.append('</table>')
        
        html_question_entities = '<div class="flex-row-item">' + html_final_answers + "".join(html_question_entities) + '</div>'
        html.append(html_question_entities)

        # ANSWERS_INSTANCE_OF_COUNT
        html_answer_instance_of = ""
        html_answer_instance_of = [
            '<h4>Answers instanceOf count (<b style="color: green;">selected</b>)</h4>',
            "<table>",
            '<tr style="font-size:1rem; font-weight: bold; background-color: #50ADFF">',
            "<th>InstanceOf</th>",
            "<th>Label</th>",
            "<th>Count</th>",
            "</tr>"
        ]
        for instance_of_entity, count in self.answer_instance_of_count:
            if instance_of_entity in self._answer_instance_of:
                html_answer_instance_of.append('<tr style="background-color: #7AE2BC">')
            else:
                html_answer_instance_of.append('<tr>')
            html_answer_instance_of.append(f'<td>{instance_of_entity.idx}</td>')
            html_answer_instance_of.append(f'<td>{instance_of_entity.label}</td>')
            html_answer_instance_of.append(f'<td>{count}</td>')
            html_answer_instance_of.append('</tr>')  
        html_answer_instance_of.append("</table>")

        html_answer_instance_of = "".join(html_answer_instance_of)

        # ANSWERS candidates
        html_answers_candidates = [f"<h4>Seq2Seq answers candidates</h4>"]
        html_answers_candidates.extend([
            "<table>",
            '<tr style="font-size:1rem; font-weight: bold; background-color: #50ADFF">',
            "<th>Entity</th>",
            "<th>E Label</th>",
            "<th>InstanceOf</th>",
            "</tr>"
        ])
        for entity in self.answers_candidates:
            if self.target is not None and entity in self.target:
                html_answers_candidates.append('<tr style="background-color: #3CF30F">')
            else:
                html_answers_candidates.append('<tr>')
            html_answers_candidates.extend([
                f'<td>{entity.idx}</td>',
                f'<td>{entity.label}</td>',
                f'<td>{"<br>".join([f"{io.idx} ({io.label})" for io in entity.instance_of])}</td>',
                '</tr>',
            ])

        html_answers_candidates = '<div class="flex-row-item">' + html_answer_instance_of + "".join(html_answers_candidates) + '</div>'
        html.append(html_answers_candidates)

        return "".join(html) + '</div>'



row = df[df['is_correct']].iloc[11]

# row = df.iloc[5]
answers_candidates = []
for lbl in row[answer_cols].dropna().unique():
    try:
        answers_candidates.extend(Entity.from_label(lbl)[:1])
    except ValueError:
        pass
question_entities = [Entity(e) for e in row['selected_entities']]

qtr = QuestionToRankInstanceOfHtml(
    row['Q'],
    question_entities,
    answers_candidates,
    target_entity=row['target'],
    only_forward_one_hop=True,
)

qtr.final_answers()
qtr

Property,P Label,Entity,E Label,InstanceOf,instance of score,forward one hop neighbors score,answers candidates score,property question intersection score
,,Q4952029,Boy's Town,Q482994 (album),0.83333,0.0,0.83333,0.0
,,Q651231,Boyz II Men,Q120544 (vocal group),0.83333,0.0,0.75,0.0
,,Q101442732,Baekhyun,Q169930 (extended play),0.83333,0.0,0.58333,0.0
,,Q4952711,Boys and Girls Together,Q47461344 (written work),0.83333,0.0,0.5,0.0
,,Q598633,Back to School,Q11424 (film),0.83333,0.0,0.25,0.0
,,Q298548,BTS,Q4167410 (Wikimedia disambiguation page),0.0,0.0,1.0,0.0
,,Q752419,Big Time Rush,Q215380 (musical group) Q216337 (boy band),0.83333,0.0,0.16667,0.0
,,Q61446123,BTSM,Q4167410 (Wikimedia disambiguation page),0.0,0.0,0.91667,0.0
,,Q419756,Boy Meets Girl,Q4167410 (Wikimedia disambiguation page),0.0,0.0,0.66667,0.0
,,Q4189075,Boy Meets Boy,Q24862 (short film),0.0,0.0,0.41667,0.0

InstanceOf,Label,Count
Q482994,album,1.0
Q120544,vocal group,1.0
Q169930,extended play,1.0
Q47461344,written work,1.0
Q24862,short film,1.0
Q11424,film,1.0
Q215380,musical group,1.0
Q216337,boy band,1.0
Q17317604,professional wrestling event,1.0

Entity,E Label,InstanceOf
Q298548,BTS,Q4167410 (Wikimedia disambiguation page)
Q61446123,BTSM,Q4167410 (Wikimedia disambiguation page)
Q4952029,Boy's Town,Q482994 (album)
Q651231,Boyz II Men,Q120544 (vocal group)
Q419756,Boy Meets Girl,Q4167410 (Wikimedia disambiguation page)
Q101442732,Baekhyun,Q169930 (extended play)
Q4952711,Boys and Girls Together,Q47461344 (written work)
Q4189075,Boy Meets Boy,Q24862 (short film)
Q344413,Boys and Girls,Q4167410 (Wikimedia disambiguation page)
Q598633,Back to School,Q11424 (film)


In [91]:
import os

os.environ['SEQ2SEQ_RUN_NAME'] = 'wdsq_tunned_t5_large_ssm_nq'
os.environ['CUDA_VISIBLE_DEVICES'] = '5'
os.environ['SEQ2SEQ_DATASET'] = 'rubq'
os.environ['SEQ2SEQ_MODEL_NAME'] = 'google/t5-large-ssm-nq'

In [22]:
from pathlib import Path
import pandas as pd
from joblib import Parallel, delayed
from tqdm.auto import tqdm
import numpy as np

tqdm.pandas()

from pywikidata import Entity
from kbqa.candidate_selection import QuestionToRankInstanceOf
from seq2seq_dbs_answers_generation import load_params, load_datasets


params, run_name = load_params()
train_dataset, valid_dataset, test_dataset, question_col_name = load_datasets(
    params
)
datasets = {
    'train': train_dataset,
    'valid': valid_dataset,
    'test': test_dataset,
}

run_name = 'wdsq_tunned'
params, run_name

({'seq2seq': {'dataset': 'wdsq',
   'model': {'name': 'google/t5-large-ssm',
    'path': '/mnt/raid/data/kbqa/seq2seq_runs/wdsq_tunned/google_t5-large-ssm/models/',
    'num_return_sequences': 200,
    'num_beams': 200,
    'num_beam_groups': 20,
    'diversity_penalty': 0.1,
    'batch_size': 2}},
  'entity_linking': {'ner': {'path': '/mnt/raid/data/kbqa/ner/spacy_models/wdsq_tuned/model-best'}}},
 'wdsq_tunned')

In [3]:
candidates_main_path = Path(f"/mnt/raid/data/kbqa/datasets/candidates/{run_name}/{params['seq2seq']['dataset']}/")
linked_entities_main_path = Path(f"/mnt/raid/data/kbqa/datasets/linked_entities/{run_name}/{params['seq2seq']['dataset']}/")

split_name = 'valid'
candidates_df = pd.read_pickle(candidates_main_path / f"{split_name}.pkl")
answer_cols = candidates_df.columns
linked_entities_df = pd.read_pickle(linked_entities_main_path / f"{split_name}.pkl")
df = pd.concat([datasets[split_name].to_pandas(), linked_entities_df[['selected_entities']], candidates_df], axis=1)
df.head()

Unnamed: 0,S,P,O,Q,selected_entities,answer_0,answer_1,answer_2,answer_3,answer_4,...,answer_190,answer_191,answer_192,answer_193,answer_194,answer_195,answer_196,answer_197,answer_198,answer_199
0,Q318926,P19,Q1010,where was sasha vujačić born,[Q318926],Zagreb,Belgrade,Dubrovnik,Belgrade,Zagreb,...,Vrbice,ijek,Split,Zagreb,Dubrovnik,Hvar,Skopje,Bihaj,Kos,Split
1,Q2568216,R57,Q14949730,What is a film directed by wiebke von carolsfeld?,[],Idiocracy,The Night of the Living Dead,It's a Wonderful Life,Idiocracy,The Man Who Wasn't There,...,Leopold's Tale,Goethe's Son,Anatomy of a Surgeon,The Story of a Girl,The Swordsman,Leopold's Garden,Leopold's Dream,Star Wars: Episode I,The Man Who Won,Leopold II of Germany
2,Q2275923,P106,Q40348,What was Seymour Parker Gilbert's profession?,[],politician,politician,actor,politician,actor,...,entrepreneur,lawyer,journalist,musician,photographer,singer,singer-songwriter,poet,author,baseball manager
3,Q2856873,P20,Q160927,in what french city did antoine de févin die,[Q2856873],Paris,Paris,Villefranche-Billancourt,Paris,Villefranche-sur-Mer,...,Aix-sur-Oise,Lyon,Neuilly-sur-Mer,La Roche-Bresse,Saint-Cloud,Villefranche-Billancon,Villefranche-sur-Mer,Saint-Lazare,Neuilly en Provence,Périgord
4,Q522966,P106,Q2526255,What job does jamie hewlett have,[Q522966],screenwriter,screenwriter,actor,screenwriter,actor,...,television personality,director of communications,radio DJ,producer,writer,television producer,journalist,politician,film director,entrepreneur


In [4]:
# row = df.iloc[0]

# answers_candidates = []
# for lbl in row[answer_cols].dropna().unique():
#     try:
#         answers_candidates.extend(Entity.from_label(lbl)[:1])
#     except ValueError:
#         pass
# question_entities = [Entity(e) for e in row['selected_entities']]

# qtr = QuestionToRankInstanceOf(
#     row['Q'],
#     question_entities,
#     answers_candidates,
#     # target_entity=row['target'],
#     only_forward_one_hop=True,
# )

# results = qtr.final_answers()
# results_df = pd.DataFrame(
#     results,
#     columns=['property', 'entity', 'instance_of_score', 'forward_one_hop_neighbors_score', 'answers_candidates_score', 'property_question_intersection_score']
# )
# results_df['property'] = results_df['property'].apply(lambda p: p.idx if p is not None else None)
# results_df['entity'] = results_df['entity'].apply(lambda p: p.idx if p is not None else None)
# results_df.head()

In [15]:
only_forward_one_hop = False 

scores_cols = [
    'instance_of_score',
    'forward_one_hop_neighbors_score',
    'answers_candidates_score',
    'property_question_intersection_score',
]

X = []
Y = []
results_is_correct = np.array([None] * df.index.size)
for path in tqdm(
    Path(f"/mnt/raid/data/kbqa/datasets/selected_candidates/{run_name}/{params['seq2seq']['dataset']}/{split_name}/fw_only_{only_forward_one_hop}/").glob("*.json")
):
    index = int(path.name.split('.')[0])
    result_df = pd.read_json(path)
    target_entity = Entity(df.loc[index]['O'])
    is_correct = target_entity == Entity(result_df.iloc[0]['entity'])
    results_is_correct[index] = is_correct
    
    result_df['Y'] = result_df['entity'].apply(lambda e: Entity(e) == target_entity).astype(int).values.tolist()

    result_df = pd.concat([result_df[result_df['Y'] == 1], result_df[result_df['Y'] == 0].iloc[-5:], result_df[result_df['Y'] == 0].iloc[:5]], axis=0)

    X.extend(result_df[scores_cols].values.tolist())
    Y.extend(result_df['Y'].values.tolist())

2821it [01:32, 30.56it/s]


In [16]:
Y = np.array(Y)
X = np.array(X)
cnt = len(Y[Y == 1])
out_Y = np.empty((cnt*2,))
out_X = np.empty((cnt*2, 4))

indices = np.where(Y == 1)[0]
out_Y[:cnt] = Y[indices]
out_X[:cnt] = X[indices]

indices = np.random.choice(np.where(Y == 0)[0], cnt)
out_Y[cnt:] = Y[indices]
out_X[cnt:] = X[indices]

In [17]:
from sklearn.linear_model import LogisticRegression

model = LogisticRegression(penalty='l1', solver='saga')
model.fit(out_X, out_Y)
coefs = np.array(model.coef_[0])
print(coefs)

[ 1.62802934 -0.13551901  2.10342577  5.7296816 ]


In [18]:
results_is_correct[results_is_correct != None].sum() / len(results_is_correct)

0.4551577454803261

In [27]:
results_is_correct = np.array([None] * df.index.size)
for path in tqdm(
    Path(f"/mnt/raid/data/kbqa/datasets/selected_candidates/{run_name}/{params['seq2seq']['dataset']}/{split_name}/fw_only_{only_forward_one_hop}/").glob("*.json")
):
    index = int(path.name.split('.')[0])
    result_df = pd.read_json(path).iloc[:100]

    result_df['score'] = result_df.apply(
        lambda row: row[scores_cols].values @ coefs,
        axis=1
    )
    result_df = result_df.sort_values(by='score', ascending=False)

    is_correct = Entity(df.loc[index]['O']) == Entity(result_df.iloc[0]['entity'])
    results_is_correct[index] = is_correct

2821it [02:04, 22.69it/s]


In [28]:
results_is_correct[results_is_correct != None].sum() / len(results_is_correct)

0.4587025877348458

In [103]:
df = pd.concat([datasets[split_name].to_pandas(), candidates_df], axis=1)
df.head()

Unnamed: 0,object,question,answer_0,answer_1,answer_2,answer_3,answer_4,answer_5,answer_6,answer_7,...,answer_190,answer_191,answer_192,answer_193,answer_194,answer_195,answer_196,answer_197,answer_198,answer_199
0,"[Q7944, Q167903, Q5975740, Q60186, Q2580904, Q...",What can cause a tsunami?,a tsunami,a tsunami,earthquakes,a tsunami,earthquakes,a tsunami,earthquake,tidal waves,...,tectonic earthquakes,torn offshore,tsunami-related flooding,a tsunami amplitude,earthquake,tectonic plate explosion,tectonic plate movement,surfacing of the seafloor,tsunami physics,tectonic instability
1,[Q102513],"Who wrote the novel ""uncle Tom's Cabin""?",Ernest Hemingway,William Henry Draper,Ernest Hemingway,William Henry Draper,A. C. Benson,Philip Van Doren Stern,Robert Louis Stevenson,C. S. Lewis,...,A. Philip Randall,Irwin Allen,James A. McNeill,James A. Blandick,Tom Clancy,John Green,James A. Blanding,James A. Garner,James A. Garfield,Charles Bukowski
2,[Q692],"Who is the author of the play ""Romeo and Juliet""?",William Shakespeare,William Shakespeare,William Shakespeare,William Shakespeare,William Shakespeare,William Shakespeare,William Shakespeare,William Shakespeare,...,Arthur Brooke,Edwin Bolland,Claude Aventine,Claude Aventiaux,Charles Shakespeare,Charles William Bingley,Charles William Connolly,Claude Aventin,Edwina Moore,Isobel Ricci
3,[Q19660],What is the name of the capital of Romania?,Bucharest,Bucharest,Bucharest,Bucharest,Bucharest,Bucharest,Bucharest,Bucharest,...,Romania,Craiova,București,Constanța,Bucharém,Bucharida,Iasi,Timisoara,Iaşi,Bucureşti
4,"[Q6607, Q17172850, Q483994, Q626035, Q2643890]",What instrument did Jimi Hendrix play?,guitar,guitar,guitar,guitar,guitar,guitar,guitar,guitar,...,bass guitar,piano,synthesized bass,harmonica,synthesized guitar,synthesized music,tuba,harp,synthesized instruments,tin whistle


In [104]:
def l2e(l):
    try:
        return Entity.from_label(l)[0]
    except:
        pass

df['t5results'] = df['answer_0'].apply(l2e)

In [105]:
df['targets'] = df['object'].apply(lambda lst: [Entity(o) for o in lst])

In [106]:
# (df['O'].apply(Entity) == t5results).sum() / df.index.size

res = df.apply(
    lambda row: row['t5results'] in row['targets'],
    axis=1
)

res.sum() / res.index.size

0.1779842744817727

In [97]:
params

{'seq2seq': {'dataset': 'rubq',
  'model': {'name': 'google/t5-large-ssm-nq',
   'path': '/mnt/raid/data/kbqa/seq2seq_runs/wdsq_tunned/google_t5-large-ssm/models/',
   'num_return_sequences': 200,
   'num_beams': 200,
   'num_beam_groups': 20,
   'diversity_penalty': 0.1,
   'batch_size': 2}},
 'entity_linking': {'ner': {'path': '/mnt/raid/data/kbqa/ner/spacy_models/wdsq_tuned/model-best'}}}