In [None]:
import json
from tqdm import tqdm
from allennlp.data.tokenizers import Token, Tokenizer, WordTokenizer
from pytorch_pretrained_bert import BertTokenizer
from allennlp.data.dataset_readers.reading_comprehension.util import split_tokens_by_hyphen
from collections import defaultdict

def get_answer_type(answers):
    """Classify an answer annotation by the first populated field.

    Fields are checked in priority order: ``'number'``, then ``'spans'``,
    then ``'date'``.

    Args:
        answers: Mapping with a ``'number'`` entry, a ``'spans'`` sequence
            and a ``'date'`` mapping. Later keys are only read when the
            earlier ones are empty.

    Returns:
        ``'number'``, ``'single_span'``, ``'multiple_span'`` or ``'date'``;
        ``None`` when every field is empty.
    """
    if answers['number']:
        return 'number'

    span_list = answers['spans']
    if span_list:
        return 'single_span' if len(span_list) == 1 else 'multiple_span'

    # A date counts as present as soon as any of its components is non-empty.
    if any(answers['date'].values()):
        return 'date'

    return None

def get_specs(file_path, tokenizer, wordpiece_tokenizer):
    """Print size and vocabulary statistics for a passage/question dataset.

    The JSON file is expected to hold a dict whose values are passage records
    with a ``"passage"`` string and a ``"qa_pairs"`` list. Each QA pair has a
    ``"question"`` string and optionally an ``"answer"`` dict and a
    ``"validated_answers"`` list.

    Args:
        file_path: Path to the dataset JSON file.
        tokenizer: Word-level tokenizer; ``tokenize(text)`` must return tokens
            that ``split_tokens_by_hyphen`` accepts and that expose ``.text``.
        wordpiece_tokenizer: Sub-word tokenizer; ``tokenize(text)`` must
            return a sized iterable of hashable tokens.

    Returns:
        None. Thirteen values are printed, one per line, in this order:
        mean words per passage, mean wordpieces per passage, mean words per
        question, mean wordpieces per question, mean answers per question,
        mean questions per passage, passage count, question count, passage
        word vocabulary size, passage wordpiece vocabulary size, question
        word vocabulary size, question wordpiece vocabulary size, and a list
        of ``(answer_type, fraction_of_questions)`` pairs.
    """
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(file_path, encoding="utf-8") as dataset_file:
        dataset = json.load(dataset_file)

    def tokenize_both(text):
        """Return (hyphen-split word strings, wordpiece tokens) for text."""
        words = [
            token.text
            for token in split_tokens_by_hyphen(tokenizer.tokenize(text))
        ]
        return words, wordpiece_tokenizer.tokenize(text)

    def ratio(numerator, denominator):
        """Return numerator / denominator, or 0.0 when the denominator is 0."""
        return float(numerator) / denominator if denominator else 0.0

    num_passage_words = 0
    num_passage_tokens = 0
    passage_count = 0
    num_question_words = 0
    num_question_tokens = 0
    question_count = 0
    num_answers = 0
    passage_vocab = set()
    passage_vocab_token = set()
    question_vocab = set()
    question_vocab_token = set()
    answer_counts = defaultdict(int)

    for passage_info in tqdm(dataset.values()):
        passage_words, passage_tokens = tokenize_both(
            passage_info["passage"].strip()
        )

        num_passage_words += len(passage_words)
        num_passage_tokens += len(passage_tokens)
        passage_vocab.update(passage_words)
        passage_vocab_token.update(passage_tokens)
        passage_count += 1

        # Process questions from this passage
        for question_answer in passage_info["qa_pairs"]:
            question_words, question_tokens = tokenize_both(
                question_answer["question"].strip()
            )

            num_question_words += len(question_words)
            num_question_tokens += len(question_tokens)
            question_vocab.update(question_words)
            question_vocab_token.update(question_tokens)
            question_count += 1

            # The primary answer determines the question's answer type;
            # validated answers only add to the annotation count.
            if "answer" in question_answer:
                answer_counts[get_answer_type(question_answer["answer"])] += 1
                num_answers += 1
            if "validated_answers" in question_answer:
                num_answers += len(question_answer["validated_answers"])

    # Averages (0.0 instead of ZeroDivisionError for an empty dataset).
    print(ratio(num_passage_words, passage_count))
    print(ratio(num_passage_tokens, passage_count))
    print(ratio(num_question_words, question_count))
    print(ratio(num_question_tokens, question_count))
    print(ratio(num_answers, question_count))
    print(ratio(question_count, passage_count))
    # Raw counts.
    print(passage_count)
    print(question_count)
    # Vocabulary sizes.
    print(len(passage_vocab))
    print(len(passage_vocab_token))
    print(len(question_vocab))
    print(len(question_vocab_token))
    # Share of questions per primary answer type.
    print([(key, ratio(count, question_count)) for key, count in answer_counts.items()])
           

In [None]:
# Report statistics for the dev split.
get_specs(
    file_path='data/drop_dataset_dev.json',
    tokenizer=WordTokenizer(),
    wordpiece_tokenizer=BertTokenizer.from_pretrained('bert-base-uncased'),
)

In [None]:
# Report statistics for the train split.
get_specs(
    file_path='data/drop_dataset_train.json',
    tokenizer=WordTokenizer(),
    wordpiece_tokenizer=BertTokenizer.from_pretrained('bert-base-uncased'),
)