In [1]:
import os
import pandas as pd
import torch
from tqdm import tqdm

from transformers import AutoTokenizer, AutoModelForSequenceClassification

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from nlinec.data.load import get_positive_data, get_all_types
from nlinec.predict import predict_probabilities
from nlinec.utils import get_results_dir

In [3]:
# Load the table of entity types. Its `full_type` column later becomes the column
# set of the predictions table, and its `type` column is what gets passed to
# `predict_probabilities` as the candidate labels.
# NOTE(review): granularity=-1 presumably selects every level of the type
# hierarchy (the output lists both /person and /person/athlete) — confirm
# against get_all_types.
all_types = get_all_types(granularity=-1)
# Show the table as a quick sanity check
all_types

Unnamed: 0,full_type,type
0,/other,other
1,/other/body_part,body_part
2,/person/title,title
3,/person,person
4,/person/athlete,athlete
...,...,...
84,/organization/stock_exchange,stock_exchange
85,/location/structure/hotel,hotel
86,/location/transit/bridge,bridge
87,/location/transit/railway,railway


In [4]:
# CSV file that persists the dev-set predictions: it is reloaded further down
# if it already exists, and (re)written periodically by the prediction loop.
save_results_file = os.path.join(get_results_dir(), "predictions", "zero-shot", "dev_predictions.csv")

In [5]:
# Load the dev split. The prediction loop reads the `sentence` and
# `mention_span` fields of each row; the `full_type` column is displayed as a
# list of type paths per mention.
# NOTE(review): "positive" presumably means only the annotated types of each
# mention are kept — confirm in nlinec.data.load.
dev_data = get_positive_data("g_dev.json")
# Show the frame as a quick sanity check
dev_data

2202it [00:00, 204369.30it/s]


Unnamed: 0,mention_span,full_type,sentence
0,Friday,[/other],Japan's wholesale prices in September rose 3.3...
1,September,[/other],Japan's wholesale prices in September rose 3.3...
2,Japan,"[/location, /location/country]",Japan's wholesale prices in September rose 3.3...
3,the Bank of Japan,"[/location, /location/structure, /organization...",Japan's wholesale prices in September rose 3.3...
4,3.3 %,[/other],Japan's wholesale prices in September rose 3.3...
...,...,...,...
2197,the Treasury 's,"[/organization, /organization/government]","The non-callable issue, which can be put back ..."
2198,$ 500 million of Remic mortgage securities,[/other],$ 500 million of Remic mortgage securities of...
2199,the Treasury 's,"[/organization, /organization/government]","The issue, which is puttable back to the compa..."
2200,200 basis points,[/other],Among classes for which details were available...


In [6]:
# Load the pretrained MNLI checkpoint together with its matching tokenizer.
#
# The device is picked at runtime instead of being hard-coded: `.to("cuda")`
# raises on machines without a CUDA device, whereas this keeps the exact same
# behaviour on GPU machines and lets the notebook at least load the model on CPU.
# NOTE(review): this assumes `predict_probabilities` places its inputs on the
# model's device — confirm in nlinec.predict before relying on CPU execution.
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("roberta-large-mnli")
model = AutoModelForSequenceClassification.from_pretrained("roberta-large-mnli").to(device)

Some weights of the model checkpoint at roberta-large-mnli were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [7]:
# Resume from an earlier run when a predictions file is present; otherwise start
# from an empty table that the prediction loop fills in row by row.
if not os.path.exists(save_results_file):
    # Fresh start: one row per dev example, one column per full type path,
    # every cell initially missing
    predictions_df = pd.DataFrame(
        index=dev_data.index,
        columns=all_types['full_type'].tolist(),
    )
else:
    # Pick up the previously saved predictions; the first CSV column is the row index
    print("Loading predictions from file")
    predictions_df = pd.read_csv(save_results_file, index_col=0)

In [8]:
# Inspect the predictions table (every cell is missing when no saved file was found)
predictions_df

Unnamed: 0,/other,/other/body_part,/person/title,/person,/person/athlete,/other/art,/other/art/music,/other/event,/other/event/holiday,/other/religion,...,/other/award,/person/coach,/other/language/programming_language,/other/product/computer,/other/event/sports_event,/organization/stock_exchange,/location/structure/hotel,/location/transit/bridge,/location/transit/railway,/other/product/mobile_phone
0,,,,,,,,,,,...,,,,,,,,,,
1,,,,,,,,,,,...,,,,,,,,,,
2,,,,,,,,,,,...,,,,,,,,,,
3,,,,,,,,,,,...,,,,,,,,,,
4,,,,,,,,,,,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2197,,,,,,,,,,,...,,,,,,,,,,
2198,,,,,,,,,,,...,,,,,,,,,,
2199,,,,,,,,,,,...,,,,,,,,,,
2200,,,,,,,,,,,...,,,,,,,,,,


In [9]:
# Number of newly computed rows between two checkpoint writes of the predictions file
SAVE_EVERY = 1000
# Boolean mask over rows: True where at least one type column is still missing,
# i.e. the row still has to be predicted
todo = predictions_df.isna().any(axis=1)
# Fraction of rows that are already complete (0.0 on a fresh run)
(~todo).mean()

0.0

In [10]:
# Make sure the output directory exists *before* spending time on inference.
# `to_csv` does not create missing parent directories, so on a fresh checkout the
# first checkpoint write would raise and all predictions computed so far would be lost.
results_parent_dir = os.path.dirname(save_results_file)
if results_parent_dir:
    os.makedirs(results_parent_dir, exist_ok=True)

# Inference only: no gradients are needed
with torch.no_grad():
    new_predictions_counter = 0
    # Iterate over the dev rows that still have missing predictions
    for row in tqdm(dev_data.loc[todo, :].itertuples(), total=todo.sum()):
        # Score the mention against every candidate type. The slice keeps index 0
        # on the first axis, everything on the second, and the last entry on the
        # final axis.
        # NOTE(review): presumably (batch, type, NLI class) with the last class
        # being "entailment" — confirm against predict_probabilities.
        entailment_probabilities = predict_probabilities(model, tokenizer, row.sentence, row.mention_span, all_types['type'])[0, :, -1]

        # Store one value per type column in this example's row
        predictions_df.loc[row.Index, :] = entailment_probabilities

        # Checkpoint to file every SAVE_EVERY new predictions so an interrupted
        # run can be resumed from the saved file
        new_predictions_counter += 1
        if new_predictions_counter >= SAVE_EVERY:
            predictions_df.to_csv(save_results_file)
            new_predictions_counter = 0

# Save the predictions made since the last checkpoint
predictions_df.to_csv(save_results_file)

100%|██████████| 2202/2202 [03:40<00:00,  9.97it/s]


In [11]:
# TODO: Evaluate whether at least one of the correct types per sentence is predicted