In [None]:
import json
import random
import re
import copy

# Relation metadata keyed by relation type. The generation cell reads
# relations[type]['template'] (a sequence of question templates) and
# relations[type]['mode'].
# JSON text is UTF-8 by spec; don't depend on the platform default encoding.
with open('relations.json', 'r', encoding='utf-8') as f:
    relations = json.load(f)
# Explicit check instead of `assert`, which disappears under `python -O`.
if not isinstance(relations, dict):
    raise TypeError(
        'relations.json must contain a JSON object, got {}'.format(type(relations).__name__))

# Month number (1-12) -> three-letter English abbreviation; used by Time to
# render values like 'Mar 2010' and to parse them back.
mapping = dict(enumerate(
    ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
     'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'],
    start=1,
))
    
class Time(object):
    """A coarse calendar point: a year plus an optional month and day.

    ``month == 0`` (together with ``date == 0``) marks a year-only value.
    The constructor collapses 1 January, and any input whose month or day
    is 0, to year-only. Comparisons are lexicographic on
    ``(year, month, date)``, so a year-only value sorts before every month
    of the same year. Instances are plain mutable attribute holders
    (``add_one_month`` edits a copy in place), so they are left unhashable.
    """

    def __init__(self, time_str):
        """Build from ``'Y-M-D'`` (integer fields); years below 1 are clamped to 1."""
        parts = [int(p) for p in time_str.split('-')]

        self.year = max(parts[0], 1)
        self.month = parts[1]
        self.date = parts[2]

        # 1 January, or a missing (0) month/day, is treated as "year only".
        if (self.month == 1 and self.date == 1) or self.month == 0 or self.date == 0:
            self.month = 0
            self.date = 0

        assert self.year > 0

    # Defining __eq__ already disables hashing; state it explicitly.
    __hash__ = None

    def _key(self):
        """Tuple that every comparison operator is based on."""
        return (self.year, self.month, self.date)

    # Non-Time operands yield NotImplemented (instead of failing an assert),
    # so `t == None` is False and `t < 3` raises the usual TypeError.
    def __eq__(self, other):
        if not isinstance(other, Time):
            return NotImplemented
        return self._key() == other._key()

    def __lt__(self, other):
        if not isinstance(other, Time):
            return NotImplemented
        return self._key() < other._key()

    def __le__(self, other):
        if not isinstance(other, Time):
            return NotImplemented
        return self._key() <= other._key()

    def __gt__(self, other):
        if not isinstance(other, Time):
            return NotImplemented
        return self._key() > other._key()

    def __ge__(self, other):
        if not isinstance(other, Time):
            return NotImplemented
        return self._key() >= other._key()

    def __repr__(self):
        """Render as ``'2010'`` (year only) or ``'Mar 2010'``; the day is never shown."""
        if self.month == 0:
            return '{}'.format(self.year)
        return '{} {}'.format(mapping[self.month], str(self.year))

    def __str__(self):
        return self.__repr__()

    @classmethod
    def parse(cls, time):
        """Parse the ``__repr__`` format: ``'2010'`` or ``'Mar 2010'``.

        A month string is built with day 1, so ``'Jan 2010'`` collapses to
        the year-only value 2010 (see ``__init__``).
        """
        assert isinstance(time, str)
        if ' ' not in time:
            return cls(f'{time}-0-0')
        month, year = time.split(' ')
        for k in mapping:
            if mapping[k] == month:
                month = k
                break
        return cls(f'{year}-{month}-1')

    # The year-shifting helpers rebuild through __init__, so its normalization
    # applies again (e.g. a value with month set but date 0 becomes year-only).
    @classmethod
    def minus_one_year(cls, time):
        """Return a new Time one year before *time*."""
        return cls('{}-{}-{}'.format(time.year - 1, time.month, time.date))

    @classmethod
    def minus_k_year(cls, time, k):
        """Return a new Time *k* years before *time*, with the year floored at 2."""
        # NOTE(review): floor of 2 differs from __init__'s floor of 1; kept as-is.
        return cls('{}-{}-{}'.format(max(time.year - k, 2), time.month, time.date))

    @classmethod
    def add_one_year(cls, time):
        """Return a new Time one year after *time*."""
        return cls('{}-{}-{}'.format(time.year + 1, time.month, time.date))

    @classmethod
    def add_k_year(cls, time, k):
        """Return a new Time *k* years after *time*."""
        return cls('{}-{}-{}'.format(time.year + k, time.month, time.date))

    @classmethod
    def add_one_month(cls, time):
        """Return a deep copy of *time* advanced by one month (day unchanged).

        A year-only value (month 0) becomes January of the same year;
        December rolls over to January of the next year. The copy is not
        re-normalized through ``__init__``.
        """
        new_time = copy.deepcopy(time)
        if new_time.month < 12:
            new_time.month += 1
        else:
            new_time.month = 1
            new_time.year += 1
        return new_time

def random_pop(time_range):
    """Pick a random point inside a (start, end) pair of Time values.

    Starting one month after ``time_range[0]``, step forward one month at a
    time while not past ``time_range[1]`` (under Time ordering), and choose
    uniformly among those points. If no such point exists, fall back to a
    random choice between the two entries of ``time_range`` themselves.
    """
    start, stop = time_range[0], time_range[1]

    steps = []
    probe = Time.add_one_month(start)
    while not (probe > stop):
        steps.append(probe)
        probe = Time.add_one_month(probe)

    return random.choice(steps) if steps else random.choice(time_range)

def too_close(time1, time2):
    """Return True when ``time2`` is at most two months after ``time1``.

    Only ``year`` and ``month`` are compared (days are ignored). A year-only
    value counts as month 0. If ``time2`` is earlier than ``time1``, the
    result is also True.
    """
    months_apart = 12 * (time2.year - time1.year) + (time2.month - time1.month)
    return months_apart <= 2

def prop(time, first_last=None, difficulty='easy'):
    """Render a temporal specifier phrase for a question template.

    A non-sequence *time* is rendered as ``'in <time>'``. A 2-item
    tuple/list of Time values is rendered as ``'in <start>'`` when the two
    are at most two months apart. Otherwise the phrase depends on the
    difficulty:

    * ``'easy'``: ``'from <start> to <end>'``.
    * ``'hard'``: a random choice of ``'in ...'``, ``'between ... and ...'``,
      or (for ``first_last`` equal to ``'first'``/``'last'``) ``'before ...'``
      or ``'after ...'``, using dates drawn via ``random_pop``.

    Raises ValueError for an unknown *difficulty* or *first_last*.
    """
    if not isinstance(time, (tuple, list)):
        return 'in {}'.format(str(time))

    assert len(time) == 2, time
    start, end = time[0], time[1]
    assert isinstance(start, Time) and isinstance(end, Time)

    if too_close(start, end):
        return 'in {}'.format(str(start))

    if difficulty == 'easy':
        # Single-option choice kept as-is; it may still consume RNG state.
        option = random.choice(['between'])
    elif difficulty == 'hard':
        if first_last == 'first':
            extra = ['before']
        elif first_last == 'last':
            extra = ['after']
        elif first_last is None:
            extra = []
        else:
            raise ValueError()
        option = random.choice(['in', 'between-subset'] + extra)
    else:
        raise ValueError()

    if option == 'between':
        return 'from {} to {}'.format(str(start), str(end))

    if option == 'in':
        phrases = ['in {}'.format(str(random_pop(time)))]
        # When the range crosses a decade boundary, also offer
        # "early <end decade>s" / "late <start decade>s" phrasings.
        if end.year // 10 > start.year // 10:
            if end.year % 10 >= 3:
                phrases.append('in early {}s'.format(end.year // 10 * 10))
            if start.year % 10 <= 7:
                phrases.append('in late {}s'.format(start.year // 10 * 10))
        return random.choice(phrases)

    if option == 'between-subset':
        lower = random_pop(time)
        upper = random_pop((lower, end))
        return 'between {} and {}'.format(str(lower), str(upper))

    if option in ('before', 'after'):
        return '{} {}'.format(option, str(random_pop(time)))

    raise ValueError('Not Existing')

def link_2_name(string):
    """Turn a link such as ``'/wiki/Barack_Obama'`` into ``'Barack Obama'``.

    Every ``'/wiki/'`` occurrence is removed and underscores become spaces.
    Percent-escapes (e.g. ``%C3%A9``) are left untouched.
    """
    without_prefix = string.replace('/wiki/', '')
    return without_prefix.replace('_', ' ')

## ETC Train/Test Data

In [None]:
from tqdm import tqdm
import gzip
import json

# Dataset splits to build, and the question difficulties generated for each
# (`prop` also accepts 'easy').
splits = 'train dev test'.split()
difficulties = ['hard']

def enc(string, split):
    """Serialize *string* (any JSON-serializable object) as one JSON line.

    Returns UTF-8 ``bytes`` for the ``'train'`` split (whose output file is
    opened in binary gzip mode) and ``str`` for every other split.
    """
    line = json.dumps(string) + '\n'
    return line.encode() if split == 'train' else line

def split_paragraphs(paras, *, chunk_words=100, min_words=15, max_ctxs=100):
    """Group raw paragraphs into titled passages.

    ``paras[0]`` is the title of the first passage. Word counts below are
    counts of ``' '``-separated pieces.

    * A later paragraph starting with an uppercase character and having at
      most four pieces is treated as a section heading. It closes the current
      passage (kept only if it has more than ``min_words`` pieces) and, with
      ``' .'`` stripped, becomes the title of the next one.
    * Any other paragraph is appended (space-prefixed) to the current
      passage. If that would exceed ``chunk_words`` pieces, the current
      passage is emitted instead when long enough (otherwise its text is
      dropped), and the paragraph is emitted in ``chunk_words``-piece chunks.
    * Empty paragraphs are skipped.

    Args:
        paras: list of paragraph strings; the first is the document title.
        chunk_words: maximum pieces per emitted chunk of a long paragraph.
        min_words: passages with at most this many pieces are dropped when
            closed early.
        max_ctxs: maximum number of passages returned.

    Returns:
        A list of ``{"title": str, "text": str}`` dicts, at most ``max_ctxs``
        long; ``[]`` for an empty *paras*.
    """
    if not paras:
        return []

    def n_words(text):
        return len(text.split(' '))

    ctxs = []
    buffer = {"title": paras[0], "text": ""}
    for para in paras[1:]:
        if not para:
            # Nothing to add, and para[0] below would raise IndexError.
            continue
        if para[0].isupper() and n_words(para) <= 4:
            if n_words(buffer["text"]) > min_words:
                ctxs.append(buffer)
            buffer = {"title": para.strip(' .'), "text": ""}
        elif n_words(buffer['text']) + n_words(para) > chunk_words:
            if n_words(buffer['text']) > min_words:
                ctxs.append(buffer)
                buffer = {"title": ctxs[-1]['title'], "text": ""}
            tokens = para.split(' ')
            for j in range(0, len(tokens), chunk_words):
                buffer['text'] = ' '.join(tokens[j: j + chunk_words])
                ctxs.append(buffer)
                buffer = {"title": ctxs[-1]['title'], "text": ""}
        else:
            buffer['text'] += ' ' + para

    if buffer['text']:
        ctxs.append(buffer)
    return ctxs[:max_ctxs]
    
# Generate the QA files: one output per (split, difficulty). Train output is
# gzip-compressed and written as bytes; dev/test are plain JSON-lines text
# (`enc` returns the matching type).
for split in splits:
    with open(f'dataset/annotated_{split}.json', 'r', encoding='utf-8') as f:
        data = json.load(f)
    for difficulty in difficulties:
        if split == 'train':
            out = gzip.open(f'dataset/{split}.{difficulty}.json.gzip', 'wb')
        else:
            out = open(f'dataset/{split}.{difficulty}.json', 'w', encoding='utf-8')

        # `with` guarantees the file is closed (and the gzip trailer written)
        # even if one of the data assertions below fails part-way through.
        with out:
            for d in tqdm(data, desc=f'{split}-{difficulty}'):
                assert isinstance(d['type'], str)

                paragraphs = split_paragraphs(d['paras'])
                assert isinstance(paragraphs, list)

                # One template per document, with the subject ($1) filled in.
                templates = relations[d['type']]['template']
                template = random.choice(templates)
                template = template.replace('$1', link_2_name(d['link']))

                qas = []
                last_index = len(d['questions']) - 1
                for i, entry in enumerate(d['questions']):
                    # Raw string: '\?' in a plain literal is an invalid escape sequence.
                    assert len(re.findall(r'\?$', template)) == 1, template
                    time_step = [Time.parse(entry[0][0]), Time.parse(entry[0][1])]

                    assert isinstance(entry[1], list), entry[1]
                    assert isinstance(entry[1][0], dict), entry[1]

                    if i == 0:
                        position = 'first'
                    elif i == last_index:
                        position = 'last'
                    else:
                        position = None
                    specifier = prop(time_step, position, difficulty)

                    if '$4' in template:
                        question = template.replace('$4', specifier)
                    elif '$2' in template:
                        question = template.replace('$2', specifier)
                    else:
                        # Must raise an exception instance; raising a str is a TypeError.
                        raise ValueError(f"It's not a template: {template!r}")
                    qas.append((question, entry[1]))

                # For 'accumulate' relations (hard only), pad to at least three
                # questions with empty-answer ones about ~10 years before the first
                # annotated period or after the last one (capped at 2020).
                while len(qas) < 3 and difficulty == 'hard' and relations[d['type']]['mode'] == 'accumulate':
                    first_start = Time.parse(d['questions'][0][0][0])
                    last_end = Time.parse(d['questions'][-1][0][1])

                    options = [(Time.minus_k_year(first_start, 10), first_start)]

                    recent = Time('2020-0-0')
                    if last_end < recent:
                        options.append((last_end, min(recent, Time.add_k_year(last_end, 10))))

                    choice = random.choice(range(len(options)))
                    if choice == 0:
                        specifier = prop(options[0], 'first', difficulty)
                    else:
                        specifier = prop(options[1], 'last', difficulty)

                    assert '$4' in template
                    question = template.replace('$4', specifier)
                    qas.append((question, [{'para': 0, 'from': 0, 'end': 0, 'answer': ''}]))

                context = ' '.join(d['paras'])
                # Character offset of each paragraph's start inside `context`.
                offset = [0]
                for para in d['paras'][:-1]:
                    offset.append(len(para) + 1 + offset[-1])

                for q_index, (question, spans) in enumerate(qas):
                    answers = [span['answer'] for span in spans]
                    idx = d['index'] + '#' + str(q_index)
                    if split in ['dev', 'test']:
                        tmp = {'idx': idx, 'question': question, 'context': context,
                               'targets': answers, 'paragraphs': paragraphs}
                    else:
                        # Shift per-paragraph answer offsets into `context` coordinates.
                        starts = [offset[span['para']] + span['from'] for span in spans]
                        ends = [offset[span['para']] + span['end'] for span in spans]
                        assert context[starts[0]: ends[0]] == spans[0]['answer'], \
                            context[starts[0]: ends[0]] + ' # ' + spans[0]['answer']
                        tmp = {'idx': idx, 'question': question, 'context': context,
                               'targets': answers, 'from': starts, 'end': ends,
                               'paragraphs': paragraphs}
                    out.write(enc(tmp, split))

In [None]:
import json

# Count unanswerable vs. answerable examples in the generated hard dev set.
with open('dataset/dev.hard.json', 'r') as f:
    data = [json.loads(line) for line in f]

# An example whose only target is the empty string counts as unanswerable.
na = sum(d['targets'] == [''] for d in data)
answerable = len(data) - na

print(na, answerable)

In [None]:
# Reload the generated hard dev set and report how many examples it has.
with open('dataset/dev.hard.json', 'r') as f:
    data = [json.loads(line) for line in f]
print(len(data))

In [None]:
import datasets

# Score saved model predictions against the dev-set targets.
# NOTE(review): this reads dev.easy.json, but the generation cell uses
# difficulties = ['hard'] and so only writes *.hard.json files; confirm which
# file is intended.
dataset = datasets.load_dataset('json', data_files={'dev': 'dataset/dev.easy.json'})['dev']
references = {entry['idx']: entry['targets'] for entry in dataset}

import json

with open('outputs/2021-07-27/11-25-35/output.json') as f:
    pred = json.load(f)

# The model's '[unanswerable]' output is scored as an empty prediction.
prediction = {
    row['idx']: ('' if row['output'] == '[unanswerable]' else row['output'])
    for row in pred
}

from utils import get_raw_scores

get_raw_scores(prediction, references)