Working on task https://www.notion.so/msalnikov/b0b68b3db11b4c40a4bada127bfde310?v=635216a0f3d646d58fde31f60cc9e4c9&p=82caba2f68c94f4ea320134e855e7bb4&pm=c

In [19]:
!wget -nc https://raw.githubusercontent.com/amazon-science/mintaka/main/evaluate/evaluate.py
!wget -nc https://github.com/amazon-science/mintaka/raw/main/data/mintaka_test.json

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
File ‘evaluate.py’ already there; not retrieving.

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
--2023-04-20 18:57:14--  https://github.com/amazon-science/mintaka/raw/main/data/mintaka_test.json
Resolving github.com (github.com)... 140.82.121.4
Connecting to github.com (github.com)|140.82.121.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/amazon-science/mintaka/main/data/mintaka_test.json [following]
--2023-

In [2]:
import transformers
import datasets
import evaluate
import requests
import torch
import random
import json
import numpy as np
from pywikidata import Entity
from joblib import Memory
from tqdm.auto import tqdm
from torch.utils.data import DataLoader

from kbqa.seq2seq.utils import convert_to_features

In [3]:
# Seed the Python, PyTorch and NumPy global RNGs so repeated runs are comparable.
# NOTE(review): NumPy is seeded with 0 while the other two use 8 - confirm this is intentional.
random.seed(8)
torch.manual_seed(8)
# Kept last on purpose: it returns None, so the cell produces no output.
np.random.seed(0)

In [4]:
# On-disk memoisation: identical searches are served from /tmp/cache instead of the network.
memory = Memory('/tmp/cache', verbose=0)

@memory.cache
def get_wd_search_results(
    search_string: str,
    max_results: int = 500,
    language: str = 'en',
    mediawiki_api_url: str = "https://www.wikidata.org/w/api.php",
    user_agent: str = None,
) -> list:
    """Search Wikidata entities by text and return their ids in API order.

    Pages through the `wbsearchentities` action until `max_results` ids have
    been collected or the API stops reporting a `search-continue` offset.

    Args:
        search_string: Text to search for.
        max_results: Upper bound on the number of ids returned. Values <= 0
            return an empty list without contacting the API.
        language: Language code sent as the `language` parameter.
        mediawiki_api_url: Endpoint of the MediaWiki API to query.
        user_agent: Value of the `User-Agent` header; "pywikidata" when None.

    Returns:
        A list of at most `max_results` entity ids (the `id` field of each hit).

    Raises:
        requests.HTTPError: If the API answers with an HTTP error status.
        requests.Timeout: If the API does not answer within the timeout.
        Exception: If the JSON reply does not report `success == 1`.
    """
    page_size_cap = 50  # largest page ever requested in a single call
    headers = {
        'User-Agent': "pywikidata" if user_agent is None else user_agent
    }

    results = []
    offset = 0  # value sent as `continue`; 0 requests the first page
    while len(results) < max_results:
        params = {
            'action': 'wbsearchentities',
            'language': language,
            'search': search_string,
            'format': 'json',
            # Never ask for more hits than are still needed.
            'limit': min(page_size_cap, max_results - len(results)),
            'continue': offset,
        }

        # A timeout keeps a stalled connection from blocking the whole loop forever.
        reply = requests.get(
            mediawiki_api_url, params=params, headers=headers, timeout=30
        )
        reply.raise_for_status()
        search_results = reply.json()

        if search_results.get('success') != 1:
            raise Exception('WD search failed')

        page = [hit['id'] for hit in search_results['search']]
        results.extend(page)

        # Stop when the API has no further page, or when a page came back empty
        # (guards against looping on an offset that does not advance).
        if not page or 'search-continue' not in search_results:
            break
        offset = search_results['search-continue']

    return results[:max_results]

### Dataset: MINTAKA

In [5]:
# Load Mintaka through the `datasets` library, then display the DatasetDict
# so the available splits and features are visible.
dataset = datasets.load_dataset(path='AmazonScience/mintaka')
dataset

No config specified, defaulting to: mintaka/en
Found cached dataset mintaka (/root/.cache/huggingface/datasets/AmazonScience___mintaka/en/1.0.0/bb35d95f07aed78fa590601245009c5f585efe909dbd4a8f2a4025ccf65bb11d)


  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'lang', 'question', 'answerText', 'category', 'complexityType', 'questionEntity', 'answerEntity'],
        num_rows: 14000
    })
    validation: Dataset({
        features: ['id', 'lang', 'question', 'answerText', 'category', 'complexityType', 'questionEntity', 'answerEntity'],
        num_rows: 2000
    })
    test: Dataset({
        features: ['id', 'lang', 'question', 'answerText', 'category', 'complexityType', 'questionEntity', 'answerEntity'],
        num_rows: 4000
    })
})

In [6]:
# Peek at the first five validation rows as a DataFrame.
dataset['validation'].to_pandas().head(n=5)

Unnamed: 0,id,lang,question,answerText,category,complexityType,questionEntity,answerEntity
0,9ace9041,en,What is the fourth book in the Twilight series?,Breaking Dawn,books,ordinal,"[{'name': 'Q44523', 'entityType': 'entity', 'l...","[{'name': 'Q53945', 'label': 'Breaking Dawn'}]"
1,88bdb808,en,How many games are in the Uncharted series?,6,videogames,count,"[{'name': 'Q1064135', 'entityType': 'entity', ...","[{'name': 'Q17150', 'label': 'Uncharted: Drake..."
2,ecfd471d,en,"As of 2015, which group held the record for th...",U2,music,generic,"[{'name': 'Q41254', 'entityType': 'entity', 'l...","[{'name': 'Q396', 'label': 'U2'}]"
3,5d8dc3ff,en,Who is the oldest person to ever win an Academ...,James Ivory,movies,superlative,"[{'name': 'Q19020', 'entityType': 'entity', 'l...","[{'name': 'Q51577', 'label': 'James Ivory'}]"
4,118daa85,en,Which Mario Kart games do not feature Link as ...,"Super Mario Kart, Mario Kart 64, Mario Kart: S...",videogames,difference,"[{'name': 'Q188196', 'entityType': 'entity', '...","[{'name': 'Q1061560', 'label': 'Super Mario Ka..."


In [7]:
# Distinct question complexity types present in the training split.
dataset['train'].to_pandas().loc[:, 'complexityType'].unique()

array(['ordinal', 'intersection', 'generic', 'superlative', 'yesno',
       'comparative', 'multihop', 'difference', 'count'], dtype=object)

### Zero-shot learning

In [8]:
# Run configuration. Other checkpoints that were tried are kept for reference:
# model_checkpoint = 'google/t5-3b-ssm'
# model_checkpoint = 'google/t5-large-ssm'
model_checkpoint = '/mnt/storage/QA_System_Project/seq2seq_runs/mintaka_only_experiments_mintaka_tunned/model_t5_large_ssm_nq/models/checkpoint-7000/'
device = torch.device('cuda', 2)  # same device as 'cuda:2'
batch_size = 8
dataset_split = 'test'

# Tokenizer and model are loaded from the same checkpoint; the model is moved to `device`.
tokenizer = transformers.AutoTokenizer.from_pretrained(model_checkpoint)
model = transformers.AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
model = model.to(device)

# NOTE(review): `seq2seq_pipeline` is not referenced by the later cells shown here,
# which call `model.generate` directly - confirm whether it is still needed.
seq2seq_pipeline = transformers.pipeline(
    'text2text-generation',
    model=model,
    tokenizer=tokenizer,
    device=device,
)

In [9]:
def _to_model_features(batch):
    """Run `convert_to_features` on a batch with the shared tokenizer and `answerText` as label feature."""
    return convert_to_features(batch, tokenizer, label_feature_name="answerText")


# Applied batch-wise to every split of the DatasetDict.
dataset = dataset.map(_to_model_features, batched=True)

# Only these columns are returned as torch tensors when rows are fetched.
columns = ["input_ids", "labels", "attention_mask"]
dataset.set_format(type="torch", columns=columns)

Loading cached processed dataset at /root/.cache/huggingface/datasets/AmazonScience___mintaka/en/1.0.0/bb35d95f07aed78fa590601245009c5f585efe909dbd4a8f2a4025ccf65bb11d/cache-d82ee0eb58e1802d.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/AmazonScience___mintaka/en/1.0.0/bb35d95f07aed78fa590601245009c5f585efe909dbd4a8f2a4025ccf65bb11d/cache-c8955b1d3baff169.arrow


Map:   0%|          | 0/4000 [00:00<?, ? examples/s]

In [10]:
# Generate one answer string per example of the selected split, in dataset order.
dataloader = DataLoader(dataset[dataset_split], batch_size=batch_size)

generated_text = []
for batch in tqdm(dataloader):
    outputs = model.generate(
        batch['input_ids'].to(device),
        # Pass the prepared mask explicitly rather than relying on generate()
        # inferring it from the pad token id.
        attention_mask=batch['attention_mask'].to(device),
        max_new_tokens=64,
        return_dict_in_generate=True,
        # output_scores=True,
    )

    # Decode token ids to text, dropping special tokens.
    _generated_text = tokenizer.batch_decode(
        outputs.sequences,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )

    generated_text.extend(_generated_text)

  0%|          | 0/500 [00:00<?, ?it/s]

In [11]:
# Link every generated answer string to the top Wikidata search hit;
# None marks answers for which the search returned nothing.
generated_entities_idxs = []
for answer_text in tqdm(generated_text):
    candidates = get_wd_search_results(answer_text, max_results=1)
    top_candidate = candidates[0] if len(candidates) > 0 else None
    generated_entities_idxs.append(top_candidate)

  0%|          | 0/4000 [00:00<?, ?it/s]

# ALARM: Set `--path_to_test_set` and `--path_to_predictions` in the evaluation command below to the correct files

In [20]:
with open(f'preds_{dataset_split}.json', 'w') as f:
    json.dump(
        dict(zip(dataset[dataset_split]['id'], generated_entities_idxs)),
        f
    )

## ALARM: Edit path_to_test_file and path_to_predictions to correct
!python evaluate.py --mode kg --path_to_test_set ./mintaka_test.json --path_to_predictions ./preds_test.json --lang en

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Exact Match: 0.1705
F1: 0.1722
Hits@1: 0.1732


In [34]:
# NOTE(review): presumably resolves to the evaluate.py downloaded next to the notebook -
# verify it is not shadowed by another installed `evaluate` package.
from evaluate import format_predictions, calculate_h1

# Per-example flags, aligned with `generated_entities_idxs`:
#   is_hit      - truthiness of calculate_h1 for the prediction vs. the gold entity ids
#   is_type_hit - the predicted entity shares at least one `instance_of` value
#                 with the gold answer entities (False when there is no prediction)
is_hit, is_type_hit = [], []
gold_entities_per_example = dataset[dataset_split]['answerEntity']
for predicted_idx, gold_entities in tqdm(
    zip(generated_entities_idxs, gold_entities_per_example),
    total=len(generated_entities_idxs),
):
    gold_ids = [gold['name'] for gold in gold_entities]
    pred = format_predictions(predicted_idx, 'kg')
    is_hit.append(bool(calculate_h1(pred, gold_ids, 'kg')))

    # Nothing to compare types against when no entity was predicted.
    if pred is None or len(pred) == 0:
        is_type_hit.append(False)
        continue

    gold_types = set()
    for gold_id in gold_ids:
        gold_types.update(Entity(gold_id).instance_of)

    pred_types = set(Entity(pred[0]).instance_of)
    is_type_hit.append(not gold_types.isdisjoint(pred_types))


hits_1 = sum(is_hit) / len(generated_entities_idxs)
print("Hits@1: ", hits_1)

  0%|          | 0/4000 [00:00<?, ?it/s]

Hits@1:  0.16925


In [35]:
# Attach the per-example flags and predictions to the split's DataFrame.
df = dataset[dataset_split].to_pandas().assign(
    is_hit=is_hit,
    is_type_hit=is_type_hit,
    generated_entity=generated_entities_idxs,
)

# Hits@1 broken down by question complexity type.
for complexity_type, group in df.groupby('complexityType'):
    group_hits_1 = group['is_hit'].sum() / len(group)
    print(f"Hits@1 {complexity_type:12s} = {group_hits_1}")

Hits@1 comparative  = 0.27
Hits@1 count        = 0.0
Hits@1 difference   = 0.1425
Hits@1 generic      = 0.19
Hits@1 intersection = 0.285
Hits@1 multihop     = 0.09
Hits@1 ordinal      = 0.15
Hits@1 superlative  = 0.375
Hits@1 yesno        = 0.0


In [44]:
def _type_hit_share_among_errors(frame):
    """Share of wrong predictions (`is_hit` False) whose type still matched (`is_type_hit` True).

    Returns NaN when `frame` contains no wrong predictions, instead of dividing by zero.
    """
    errors = frame[~frame['is_hit']]
    if len(errors) == 0:
        return float('nan')
    return int(errors['is_type_hit'].sum()) / len(errors)


# The value counts errors where the entity is wrong but the type is RIGHT,
# so the label says "correct type".
print(
    'Proportion of errors with correct type                =',
    _type_hit_share_among_errors(df)
)

for complexity_type, group in df.groupby('complexityType'):
    print(
        f'Proportion of errors with correct type ({complexity_type:12s}) =',
        _type_hit_share_among_errors(group)
    )

Proportion of errors with incorrect type                = 0.31266927475173034
Proportion of errors with incorrect type (comparative ) = 0.22602739726027396
Proportion of errors with incorrect type (count       ) = 0.0
Proportion of errors with incorrect type (difference  ) = 0.3848396501457726
Proportion of errors with incorrect type (generic     ) = 0.3395061728395062
Proportion of errors with incorrect type (intersection) = 0.7027972027972028
Proportion of errors with incorrect type (multihop    ) = 0.3516483516483517
Proportion of errors with incorrect type (ordinal     ) = 0.43529411764705883
Proportion of errors with incorrect type (superlative ) = 0.576
Proportion of errors with incorrect type (yesno       ) = 0.0
