In [1]:
import os
import pandas as pd
import torch
from tqdm import tqdm

from transformers import AutoTokenizer, AutoModelForSequenceClassification

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from nlinec import get_positive_data, get_all_types, get_results_dir, get_models_dir
from nlinec.predict import predict_probabilities

## Setup

In [3]:
# --- What to run on, and where the artefacts live ---
DATASET = 'g_test.json'
MODEL = "nlinec-2-logging"

# Path the model weights are loaded from
SAVE_MODEL_TO = os.path.join(get_models_dir(), MODEL)
# CSV file that receives the predictions
SAVE_PREDICTIONS_TO = os.path.join(
    get_results_dir(), "predictions", MODEL, "test_predictions.csv"
)

# --- Prediction parameters ---
HYPOTHESIS_ONLY = True  # forwarded to predict_probabilities as hypothesis_only
SAVE_EVERY = 100_000    # number of newly predicted rows between checkpoint writes

# --- Compute device: CUDA when available, CPU otherwise ---
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# The predictions file is written later, so its parent directory has to exist
os.makedirs(os.path.split(SAVE_PREDICTIONS_TO)[0], exist_ok=True)

## Load models & data

In [4]:
# The tokenizer comes from the 'roberta-large-mnli' checkpoint, while the model
# weights are read from the path SAVE_MODEL_TO.
tokenizer = AutoTokenizer.from_pretrained('roberta-large-mnli')
model = AutoModelForSequenceClassification.from_pretrained(SAVE_MODEL_TO)
model = model.to(DEVICE)  # move the weights onto the selected device

In [5]:
# Make entailment predictions for all types and filter out the relevant ones later in the analysis.
# The 'type' column is what gets passed to predict_probabilities, and 'full_type'
# provides the column names of the predictions dataframe.
# NOTE(review): granularity=-1 presumably selects every granularity level — confirm against get_all_types.
all_types = get_all_types(granularity=-1)
all_types

Unnamed: 0,full_type,type
0,/other,other
1,/location/country,country
2,/location,location
3,/other/scientific,scientific
4,/location/city,city
...,...,...
84,/location/geography/body_of_water,body_of_water
85,/location/geograpy/island,island
86,/location/geograpy,geograpy
87,/other/legal,legal


In [6]:
# Load the dataset named by DATASET. The prediction loop reads the `sentence` and
# `mention_span` fields of each row and reuses this frame's index for the predictions.
# NOTE(review): presumably only positive (correctly typed) examples are returned — confirm against get_positive_data.
data = get_positive_data(DATASET)
data

Loading g_test.json: 8963it [00:00, 85365.18it/s]


Unnamed: 0,mention_span,full_type,sentence,label
0,Valley Federal Savings & Loan Association,"[/organization, /organization/company]",Valley Federal Savings & Loan Association sai...,2
1,Valley Federal,"[/organization, /organization/company]","Terms weren't disclosed, but Valley Federal ha...",2
2,Valley Federal,"[/organization, /organization/company]",Valley Federal said Friday that it is conside...,2
3,"Valley Federal , with assets of $ 3.3 billion ,","[/organization, /organization/company]","Valley Federal , with assets of $ 3.3 billion...",2
4,Imperial Corp. of America,"[/organization, /organization/company]",Valley Federal Savings & Loan Association said...,2
...,...,...,...,...
8958,"Fridays in general , which tend to be strong d...",[/other],Another study found that the 82 Fridays the 13...,2
8959,stocks,[/other],"But the date tends to be a plus, not a minus, ...",2
8960,the 1962 - 85 period,[/other],"But their study, which spanned the 1962 - 85 p...",2
8961,professors,[/person],"Robert Kolb and Ricardo Rodriguez, professors ...",2


In [7]:
# Resume from an earlier run when a predictions file is present; otherwise start empty
if not os.path.exists(SAVE_PREDICTIONS_TO):
    # Fresh start: one row per data row, one (still empty) column per full type
    predictions_df = pd.DataFrame(index=data.index, columns=list(all_types['full_type']))
else:
    # A previous run left predictions behind, so pick them up again
    print("Loading predictions from file")
    predictions_df = pd.read_csv(SAVE_PREDICTIONS_TO, index_col=0)

In [8]:
# Display the predictions frame (all-NaN when no earlier predictions file was found)
predictions_df

Unnamed: 0,/other,/location/country,/location,/other/scientific,/location/city,/other/product,/other/event/sports_event,/other/event,/other/art,/other/art/broadcast,...,/organization/stock_exchange,/location/transit/bridge,/organization/company/broadcast,/organization/transit,/location/structure/theater,/location/geography/body_of_water,/location/geograpy/island,/location/geograpy,/other/legal,/other/product/mobile_phone
0,,,,,,,,,,,...,,,,,,,,,,
1,,,,,,,,,,,...,,,,,,,,,,
2,,,,,,,,,,,...,,,,,,,,,,
3,,,,,,,,,,,...,,,,,,,,,,
4,,,,,,,,,,,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
8958,,,,,,,,,,,...,,,,,,,,,,
8959,,,,,,,,,,,...,,,,,,,,,,
8960,,,,,,,,,,,...,,,,,,,,,,
8961,,,,,,,,,,,...,,,,,,,,,,


In [9]:
# Find out which predictions still need to be made:
# a row is outstanding as long as any of its type columns is still missing
todo = predictions_df.isna().any(axis=1)
# The ':.2%' format spec already multiplies by 100 and appends the percent sign,
# so the raw fraction of completed rows is passed in unscaled
print(f'Progress: {(~todo).mean():.2%}')

Progress: 0.00%


## Predict

In [10]:
# Make predictions for the remaining rows.
# The loop runs inside try/finally so that everything predicted so far is written
# to disk even when the run is interrupted or a row raises; the next run then
# resumes from the saved file instead of starting over.

# Keep track of how many predictions have been made since the last save
new_predictions_counter = 0

try:
    with torch.no_grad():  # Disable gradient calculation for speed
        # Iterate over all rows of the dataset that still lack predictions
        for row in tqdm(data.loc[todo, :].itertuples(), total=todo.sum()):

            # Predict the type of the mention
            entailment_probabilities = predict_probabilities(
                model,
                tokenizer,
                row.sentence,
                row.mention_span,
                all_types['type'],
                hypothesis_only=HYPOTHESIS_ONLY)[0, :, -1]  # -1 is the entailment class

            # Store the prediction in the row that belongs to this mention
            predictions_df.loc[row.Index, :] = entailment_probabilities

            # Save the predictions to file every SAVE_EVERY predictions
            new_predictions_counter += 1
            if new_predictions_counter >= SAVE_EVERY:
                predictions_df.to_csv(SAVE_PREDICTIONS_TO)
                new_predictions_counter = 0
finally:
    # Save the remaining predictions to file (also reached on error or interrupt)
    predictions_df.to_csv(SAVE_PREDICTIONS_TO)

100%|██████████| 8963/8963 [05:19<00:00, 28.02it/s]
