In [62]:
import pandas as pd
import numpy as np
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch

In [63]:
from nlinec import get_positive_data, get_all_types, get_results_dir, get_type, color_palette, get_models_dir, get_granularity, combine_premise_hypothesis, construct_hypothesis
from nlinec.predict import predict_heads

In [64]:
# Logarithmic colour normalisation for plots
from matplotlib.colors import LogNorm

# A darker variant of the first colour of the viridis colormap: every RGB channel is
# scaled by 0.8 and the alpha channel is fixed at 1 (fully opaque).
dark_viridis = tuple(channel * 0.8 for channel in plt.get_cmap('viridis')(0)[:3]) + (1,)

## Setup

In [65]:
# Which split to analyse, which model's predictions to load, and whether those
# predictions were produced in the hypothesis-only setting.
SPLIT = "test"
DATASET = f'g_{SPLIT}.json'
MODEL = "nlinec-D-1"  # alternatively: "roberta-large-mnli"
HYPOTHESIS_ONLY = False

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Local checkpoint directory of MODEL, and the CSV holding its predictions for SPLIT
# (hypothesis-only predictions carry an "_ho" suffix in the file name).
SAVE_MODEL_TO = os.path.join(get_models_dir(), MODEL)
SAVE_PREDICTIONS_TO = os.path.join(
    get_results_dir(),
    MODEL,
    f"{SPLIT}_predictions{'_ho' if HYPOTHESIS_ONLY else ''}.csv",
)

## Load data and predictions

In [66]:
# Load the examples of the chosen split (one row per mention, with the list of its
# types in `full_type`) and display them.
positive_data = get_positive_data(DATASET)
positive_data

Loading g_test.json: 8963it [00:00, 42437.73it/s]


Unnamed: 0,mention_span,full_type,sentence,label
0,Valley Federal Savings & Loan Association,"[/organization, /organization/company]",Valley Federal Savings & Loan Association sai...,2
1,Valley Federal,"[/organization, /organization/company]","Terms weren't disclosed, but Valley Federal ha...",2
2,Valley Federal,"[/organization, /organization/company]",Valley Federal said Friday that it is conside...,2
3,"Valley Federal , with assets of $ 3.3 billion ,","[/organization, /organization/company]","Valley Federal , with assets of $ 3.3 billion...",2
4,Imperial Corp. of America,"[/organization, /organization/company]",Valley Federal Savings & Loan Association said...,2
...,...,...,...,...
8958,"Fridays in general , which tend to be strong d...",[/other],Another study found that the 82 Fridays the 13...,2
8959,stocks,[/other],"But the date tends to be a plus, not a minus, ...",2
8960,the 1962 - 85 period,[/other],"But their study, which spanned the 1962 - 85 p...",2
8961,professors,[/person],"Robert Kolb and Ricardo Rodriguez, professors ...",2


In [67]:
# For each granularity level i, collect the candidate types whose granularity is
# exactly i; gran_types[i - 1] holds the types of level i.
gran_types = []
for i in [1, 2, 3]:
    all_types = get_all_types(granularity=i)
    # .assign returns a new frame, so the frame returned by get_all_types is not
    # mutated; get_granularity is passed directly instead of through a wrapper lambda.
    all_types = all_types.assign(granularity=all_types['full_type'].apply(get_granularity))
    gran_types.append(all_types[all_types['granularity'] == i])

In [68]:
# Load the stored predictions: one probability column per candidate type, expected to
# share positive_data's index.
predictions_df = pd.read_csv(SAVE_PREDICTIONS_TO, index_col=0)

# `join` is a left join on the index, so an example without a stored prediction would
# silently get NaN probabilities. Fail loudly if the prediction file does not cover
# every example (e.g. it was produced for a different split or dataset).
assert positive_data.index.isin(predictions_df.index).all(), (
    f"{SAVE_PREDICTIONS_TO} has no predictions for some examples of {DATASET}"
)

# Attach the predictions to the examples. Each type of `full_type` then gets its own
# row: an example with k types becomes k rows sharing the same predictions. Every row
# is tagged with the granularity of its type, and the rows are renumbered 0..n-1.
data_with_predictions = (
    positive_data
    .join(predictions_df)
    .explode('full_type')
    .assign(granularity=lambda frame: frame['full_type'].apply(get_granularity))
    .reset_index(drop=True)
)

In [69]:
# Optional (disabled): keep a single row per (full_type, mention_span) pair with
# data_with_predictions.drop_duplicates(subset=['full_type', 'mention_span'])

In [70]:
# For each granularity level, the predicted type of a row is the candidate type of
# that level with the highest probability (i.e. the name of the arg-max column).
for i in [1, 2, 3]:
    # Rows whose correct type has granularity i
    granularity_mask = data_with_predictions['granularity'] == i
    candidate_types = list(gran_types[i - 1]['full_type'])
    data_with_predictions.loc[granularity_mask, 'predicted_type'] = (
        data_with_predictions.loc[granularity_mask, candidate_types].idxmax(axis=1)
    )

# Compare once over the whole frame so that `correct` is a plain boolean column.
# Filling it piecewise through .loc on a not-yet-existing column first creates the
# column with NaN, which can leave it with a non-boolean (object) dtype.
# Rows whose type has a granularity outside 1-3 (if any) have no predicted type and
# are marked False.
data_with_predictions['correct'] = (
    data_with_predictions['predicted_type'] == data_with_predictions['full_type']
)

In [71]:
# Rows whose sentence has at most SHORT_SENTENCE_MAX_CHARS characters
SHORT_SENTENCE_MAX_CHARS = 50
short_mask = data_with_predictions['sentence'].str.len() <= SHORT_SENTENCE_MAX_CHARS

In [72]:
# Row label in data_with_predictions of the single example inspected in the cells below
index = 138

In [73]:
# Inspect one example: its text, the correct vs. predicted type, and the stored
# probabilities for the four top-level types.
data_with_predictions.loc[
    index,
    [
        'sentence', 'mention_span', 'full_type', 'predicted_type', 'correct',
        '/other', '/person', '/location', '/organization',
    ],
]

sentence           A new president wasn't named.
mention_span                     A new president
full_type                                /person
predicted_type                           /person
correct                                     True
/other                                  0.036021
/person                                 0.920342
/location                               0.002709
/organization                           0.027321
Name: 138, dtype: object

In [74]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification

In [75]:
# The tokenizer always comes from the 'roberta-large-mnli' checkpoint. The weights come
# from MODEL itself when it is that base checkpoint, and otherwise from the local
# checkpoint directory SAVE_MODEL_TO.
tokenizer = AutoTokenizer.from_pretrained('roberta-large-mnli')
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL if MODEL == 'roberta-large-mnli' else SAVE_MODEL_TO
).to(DEVICE)

In [80]:
# Show the exact model input for the inspected example, with a literal "[type]"
# placeholder where each candidate type gets substituted.
combine_premise_hypothesis(
    premise=data_with_predictions.at[index, 'sentence'],
    hypothesis=construct_hypothesis(
        entity=data_with_predictions.at[index, 'mention_span'],
        type='[type]',
    ),
    hypothesis_only=HYPOTHESIS_ONLY,
)

" A new president wasn't named.</s><s>A new president is a [type]."

In [76]:
# Score the inspected example against each of the four top-level types separately:
# one premise/hypothesis pair per type. `[0]` keeps the first entry of predict_heads'
# output (presumably the only item of a batch of one — confirm against predict_heads).
probabilities = dict(
    (
        candidate,
        predict_heads(
            model,
            tokenizer,
            combine_premise_hypothesis(
                premise=data_with_predictions.at[index, 'sentence'],
                hypothesis=construct_hypothesis(
                    entity=data_with_predictions.at[index, 'mention_span'],
                    type=candidate,
                ),
                hypothesis_only=HYPOTHESIS_ONLY,
            ),
        )[0],
    )
    for candidate in ('other', 'person', 'location', 'organization')
)

In [77]:
# Display the per-type probability arrays computed in the previous cell
probabilities

{'other': array([2.7951819e-04, 9.6369970e-01, 3.6020804e-02], dtype=float32),
 'person': array([3.2144846e-04, 7.9337031e-02, 9.2034149e-01], dtype=float32),
 'location': array([0.00191595, 0.9953752 , 0.0027089 ], dtype=float32),
 'organization': array([3.1682427e-04, 9.7236246e-01, 2.7320687e-02], dtype=float32)}

In [78]:
# One column per candidate type, one row per NLI label. Position j of each probability
# array is assumed to be the probability of label id j (the usual Hugging Face
# convention — confirm for predict_heads). The row labels are therefore built from
# id2label by ascending id rather than relying on the dict's insertion order.
pd.DataFrame(
    probabilities,
    index=[model.config.id2label[label_id] for label_id in range(model.config.num_labels)],
)

Unnamed: 0,other,person,location,organization
CONTRADICTION,0.00028,0.000321,0.001916,0.000317
NEUTRAL,0.9637,0.079337,0.995375,0.972362
ENTAILMENT,0.036021,0.920341,0.002709,0.027321
