In [None]:
from pytorch_pretrained_bert import BertModel, BertTokenizer 
import torch
import sys
from allennlp.models.model import Model
from allennlp.training.metrics.drop_em_and_f1 import DropEmAndF1
from allennlp.data.token_indexers.wordpiece_indexer import WordpieceIndexer
from allennlp.data.iterators.bucket_iterator import BucketIterator
from allennlp.data.iterators.basic_iterator import BasicIterator
from allennlp.data.vocabulary import Vocabulary
from allennlp.training.trainer import Trainer
from allennlp.nn.util import move_to_device
from typing import Sequence, Union
from tqdm import tqdm

import logging
logging.basicConfig(level=logging.ERROR)
from tqdm import tqdm_notebook

from drop_bert.data_processing import BertDropTokenizer, BertDropReader, BertDropTokenIndexer
import drop_bert.nhelpers
from drop_bert.augmented_bert_templated_old import NumericallyAugmentedBERTT
from drop_bert.augmented_bert_plus_old import NumericallyAugmentedBERTPlus
from drop_bert.augmented_bert_old import NumericallyAugmentedBERT

In [None]:
# GPU used for every evaluation below; the integer index is also needed by
# allennlp's move_to_device, which takes a CUDA device number.
device_num = 0
device = torch.device('cuda', device_num)

In [None]:
from tqdm import tqdm
from collections import defaultdict

def filter_by_answer_choice(abert, model, reader, answer_type,
                            dev_path='data/drop_dataset_dev.json'):
    """Run ``abert`` over the DROP dev set and bucket batches by predicted answer type.

    Loads the checkpoint at ``model`` into ``abert`` (mapped to the global
    ``device``), reads ``dev_path`` with ``reader`` restricted by
    ``answer_type``, and runs a no-grad forward pass over every batch
    (batch size 1).

    Args:
        abert: the model to evaluate; called as ``abert(**batch)`` and expected
            to return a dict with an ``"answer"`` list of dicts that each carry
            an ``'answer_type'`` key.
        model: path to a saved state dict for ``abert``.
        reader: dataset reader; its ``answer_type`` attribute is set for the
            duration of the read and restored afterwards.
        answer_type: value assigned to ``reader.answer_type`` (e.g. ``None`` or
            a list such as ``['date']``; presumably a gold answer-type filter,
            with ``None`` meaning no filtering -- confirm in BertDropReader).
        dev_path: JSON file to read; defaults to the DROP dev set path used
            previously.

    Returns:
        A ``defaultdict(list)`` mapping each predicted ``'answer_type'`` to the
        list of (device-resident) batches the model assigned to it.
    """
    abert.load_state_dict(torch.load(model, map_location=device))
    abert.to(device)
    abert.eval()

    # Temporarily configure the shared reader, and restore it so later users
    # of the same reader are not silently left with this filter applied.
    previous_answer_type = getattr(reader, 'answer_type', None)
    reader.answer_type = answer_type
    try:
        dev = reader.read(dev_path)
        iterator = BasicIterator(batch_size=1)
        iterator.index_with(Vocabulary())
        # Materialize every batch while the filter is still in effect.
        dev_batches = list(iterator(dev, num_epochs=1))
    finally:
        reader.answer_type = previous_answer_type

    dev_batches = move_to_device(dev_batches, device_num)

    filtered = defaultdict(list)
    with torch.no_grad():
        for batch in tqdm(dev_batches):
            out = abert(**batch)
            # batch_size=1, so this normally files each batch exactly once,
            # under the answer type reported in the model's output.
            for answer in out["answer"]:
                filtered[answer['answer_type']].append(batch)
    torch.cuda.empty_cache()
    return filtered

In [None]:
from tqdm import tqdm
from collections import defaultdict

def get_metrics(abert, model, batches, answer_type, answer_choice):
    """Evaluate ``abert`` on ``batches`` and print its DROP metrics.

    Loads the checkpoint at ``model`` into ``abert`` (mapped to the global
    ``device``) and runs a no-grad forward pass over every batch. It then
    prints ``answer_type``, ``answer_choice``, the number of batches and the
    value of ``abert._drop_metrics.get_metric``.

    Args:
        abert: model exposing a ``_drop_metrics`` metric object (presumably a
            DropEmAndF1 updated inside ``forward`` -- confirm in the model).
        model: path to a saved state dict for ``abert``.
        batches: batches already on ``device``, passed as ``abert(**batch)``.
        answer_type: label used only in the printed line.
        answer_choice: label used only in the printed line.

    Returns:
        The metric values that were printed (previously ``None``).
    """
    abert.load_state_dict(torch.load(model, map_location=device))
    abert.to(device)
    abert.eval()
    # The metric object lives on the model and is not cleared by
    # load_state_dict. Without this, the printed numbers would also include
    # every earlier forward pass: previous buckets and the filtering pass.
    abert._drop_metrics.reset()

    with torch.no_grad():
        for batch in tqdm(batches):
            abert(**batch)
        # reset=True leaves a clean accumulator for whoever evaluates next.
        metrics = abert._drop_metrics.get_metric(reset=True)
        print(answer_type, answer_choice, len(batches), metrics)
    torch.cuda.empty_cache()
    return metrics

In [None]:
filtered = {}

# Tokenizer / indexer / model / reader stack shared by every filtering run.
tokenizer = BertDropTokenizer('bert-base-uncased')
token_indexer = BertDropTokenIndexer('bert-base-uncased')
abert = NumericallyAugmentedBERTT(Vocabulary(), 'bert-base-uncased', special_numbers=[100, 1])
reader = BertDropReader(tokenizer, {'tokens': token_indexer},
                        extra_numbers=[100, 1], exp_search='template')
model = '/home/ubuntu/storage/nabert_t_full_attn/best.th'

# Key under which each run is stored -> value assigned to reader.answer_type.
ANSWER_TYPE_FILTERS = {
    'all': None,
    'date': ['date'],
    'multiple_span': ['multiple_span'],
    'single_span': ['single_span'],
    'number': ['number'],
}

for split_name, reader_filter in ANSWER_TYPE_FILTERS.items():
    filtered[split_name] = filter_by_answer_choice(abert, model, reader, reader_filter)

In [None]:
# Report DROP metrics for every (reader filter, predicted answer type) bucket
# produced above, then for the union of all buckets under that filter.
for answer_type in filtered:
    full = []
    for answer_choice in filtered[answer_type]:
        full += filtered[answer_type][answer_choice]
        print(answer_type, answer_choice)
        # get_metrics takes the model instance first and the checkpoint path
        # second; omitting `abert` raised a TypeError on the first call.
        get_metrics(abert, model, filtered[answer_type][answer_choice], answer_type, answer_choice)
    get_metrics(abert, model, full, answer_type, 'all')