## 随机shuffle扩充数据

In [1]:
import json
import numpy as np
import random

In [2]:
with open("reclor_data/train_qtype.json","r") as file:
    train_data = json.load(file)

In [3]:
train_data[0]

{'context': "In rheumatoid arthritis, the body' s immune system misfunctions by attacking healthy cells in the joints causing the release of a hormone that in turn causes pain and swelling. This hormone is normally activated only in reaction to injury or infection. A new arthritis medication will contain a protein that inhibits the functioning of the hormone that causes pain and swelling in the joints.",
 'question': 'The statements above, if true, most strongly support which one of the following conclusions?',
 'answers': ['Unlike aspirin and other medications that reduce pain and swelling and that are currently available, the new medication would repair existing cell damage that had been caused by rheumatoid arthritis.',
  'A patient treated with the new medication for rheumatoid arthritis could sustain a joint injury without becoming aware of it.',
  'Joint diseases other than rheumatoid arthritis would not be affected by the new medication.',
  "The benefits to rheumatoid arthritis

In [16]:
new_train_data = []
for d in train_data:
    new_d = {}
    c_sent_list = d['context'].split('.')[:-1]
    c_sent_list_copy = c_sent_list.copy()
    random.shuffle(c_sent_list_copy)
    new_context = [s.strip()+'.' for s in c_sent_list_copy]
    new_d['context'] = ' '.join(new_context)
    new_d['question'] = d['question']
    
    gt_answer = d['answers'][d['label']]
    new_answers = d['answers'].copy()
    random.shuffle(new_answers)
    new_d['answers'] = new_answers
    new_d['label'] = new_d['answers'].index(gt_answer)
    new_d['qtype'] = d['qtype']
    new_d['id_string'] = ''
    new_train_data.append(new_d)

In [17]:
with open("reclor_data/train_qtype_shuffle.json",'w') as file:
    json.dump(new_train_data, file, indent=4)

## 数据增强

In [4]:
def is_whitespace(c):
    if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
        return True
    return False

def _is_stopwords(word, stopwords):
    if word in stopwords:
        is_stopwords_flag = True
    else:
        is_stopwords_flag = False
    return is_stopwords_flag

def _head_tail_is_stopwords(span, stopwords):
    if span[0] in stopwords or span[-1] in stopwords:
        return True
    else:
        return False

def _with_septoken(ngram, tokenizer):
    if tokenizer.bos_token in ngram or tokenizer.sep_token in ngram or tokenizer.eos_token in ngram:
        flag = True
    else: flag = False
    return flag

def _is_argument_words(seq, argument_words):
    pattern = None
    arg_words = list(argument_words.keys())
    if seq.strip() in arg_words:
        pattern = argument_words[seq.strip()]
    return pattern

def _is_exist(exists:list, start:int, end:int):
    flag = False
    for estart, eend in exists:
        if estart <= start and eend >= end:
            flag = True
            break
    return flag

def _find_punct(tokens, punctuations):
    punct_ids = [0] * len(tokens)
    for i, token in enumerate(tokens):
        if token in punctuations:
            punct_ids[i] = 1
    return punct_ids

def _find_arg_ngrams(tokens, max_gram):
    n_tokens = len(tokens)
    global_arg_start_end = []
    argument_words = {}
    argument_ids = [0] * n_tokens
    for n in range(max_gram, 0, -1):  # loop over n-gram.
        for i in range(n_tokens - n):  # n-gram window sliding.
            window_start, window_end = i, i + n
            ngram = " ".join(tokens[window_start:window_end])
            pattern = _is_argument_words(ngram, relations)
            if pattern:
                if not _is_exist(global_arg_start_end, window_start, window_end):
                    global_arg_start_end.append((window_start, window_end))
                    argument_ids[window_start:window_end] = [pattern] * (window_end - window_start)
                    argument_words[ngram] = (window_start, window_end)

    return argument_words, argument_ids

def _find_dom_ngrams_2(tokens, max_gram):
    '''
    1. 判断 stopwords 和 sep token
    2. 先遍历一遍，记录 n-gram 的重复次数和出现位置
    3. 遍历记录的 n-gram, 过滤掉 n-gram 子序列（直接比较 str）
    4. 赋值 domain_ids.

    '''

    stemmed_tokens = [token_stem(token) for token in tokens]

    ''' 1 & 2'''
    n_tokens = len(tokens)
    d_ngram = {}
    domain_words_stemmed = {}
    domain_words_orin = {}
    domain_ids = [0] * n_tokens
    for n in range(max_gram, 0, -1):  # loop over n-gram.
        for i in range(n_tokens - n):  # n-gram window sliding.

            window_start, window_end = i, i+n
            stemmed_span = stemmed_tokens[window_start:window_end]
            stemmed_ngram = " ".join(stemmed_span)
            orin_span = tokens[window_start:window_end]
            orin_ngram = " ".join(orin_span)

            if _is_stopwords(orin_ngram, stopwords): continue
            if _head_tail_is_stopwords(orin_span, stopwords): continue
            if _with_septoken(orin_ngram, tokenizer): continue

            if not stemmed_ngram in d_ngram:
                d_ngram[stemmed_ngram] = []
            d_ngram[stemmed_ngram].append((window_start, window_end))

    ''' 3 '''
    d_ngram = dict(filter(lambda e: len(e[1]) > 1, d_ngram.items()))
    raw_domain_words = list(d_ngram.keys())
    raw_domain_words.sort(key=lambda s: len(s), reverse=True)  # sort by len(str).
    domain_words_to_remove = []
    for i in range(0, len(d_ngram)):
        for j in range(i+1, len(d_ngram)):
            if raw_domain_words[i] in raw_domain_words[j]:
                domain_words_to_remove.append(raw_domain_words[i])
            if raw_domain_words[j] in raw_domain_words[i]:
                domain_words_to_remove.append(raw_domain_words[j])
    for r in domain_words_to_remove:
        try:
            del d_ngram[r]
        except:
            pass

    ''' 4 '''
    d_id = 0
    for stemmed_ngram, start_end_list in d_ngram.items():
        d_id += 1
        for start, end in start_end_list:
            domain_ids[start:end] = [d_id] * (end - start)
            rebuilt_orin_ngram = " ".join(tokens[start: end])
            if not stemmed_ngram in domain_words_stemmed:
                domain_words_stemmed[stemmed_ngram] = []
            if not rebuilt_orin_ngram in domain_words_orin:
                domain_words_orin[rebuilt_orin_ngram] = []
            domain_words_stemmed[stemmed_ngram] +=  [(start, end)]
            domain_words_orin[rebuilt_orin_ngram] += [(start, end)]


    return domain_words_stemmed, domain_words_orin, domain_ids

In [249]:
import copy
import json
import numpy as np
import re
from transformers import AutoTokenizer
import gensim
import random
from itertools import combinations
from NLP_Data_Augmentation_main.back_translate import back_translate
tokenizer = AutoTokenizer.from_pretrained("roberta-large")
np.random.seed(42)
random.seed(42)
import nltk
nltk.download('averaged_perceptron_tagger')
from nltk import pos_tag,word_tokenize
from collections import defaultdict

[nltk_data] Error loading averaged_perceptron_tagger: <urlopen error
[nltk_data]     [Errno 11004] getaddrinfo failed>


In [250]:
def is_whitespace(c):
    if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
        return True
    return False

def _is_stopwords(word, stopwords):
    if word in stopwords:
        is_stopwords_flag = True
    else:
        is_stopwords_flag = False
    return is_stopwords_flag

def _head_tail_is_stopwords(span, stopwords):
    if span[0] in stopwords or span[-1] in stopwords:
        return True
    else:
        return False

def _with_septoken(ngram, tokenizer):
    if tokenizer.bos_token in ngram or tokenizer.sep_token in ngram or tokenizer.eos_token in ngram:
        flag = True
    else: flag = False
    return flag

def _is_argument_words(seq, argument_words):
    pattern = None
    arg_words = list(argument_words.keys())
    if seq.strip() in arg_words:
        pattern = argument_words[seq.strip()]
    return pattern

def _is_exist(exists:list, start:int, end:int):
    flag = False
    for estart, eend in exists:
        if estart <= start and eend >= end:
            flag = True
            break
    return flag

def _find_punct(tokens, punctuations):
    punct_ids = [0] * len(tokens)
    for i, token in enumerate(tokens):
        if token in punctuations:
            punct_ids[i] = 8
        if token and token[-1] in [".","!","?"]:
            punct_ids[i] = 9
    return punct_ids

def _find_arg_ngrams(tokens, max_gram, relations):
    n_tokens = len(tokens)
    global_arg_start_end = []
    argument_words = {}
    argument_ids = [0] * n_tokens
    for n in range(max_gram, 0, -1):  # loop over n-gram.
        for i in range(n_tokens - n):  # n-gram window sliding.
            window_start, window_end = i, i + n
            ngram = " ".join(tokens[window_start:window_end]).strip()
            pattern = _is_argument_words(ngram, relations)
            if pattern:
                if not _is_exist(global_arg_start_end, window_start, window_end):
                    global_arg_start_end.append((window_start, window_end))
                    argument_ids[window_start:window_end] = [pattern] * (window_end - window_start)
                    argument_words[ngram] = (window_start, window_end)

    return argument_words, argument_ids

def _find_dom_ngrams_2(tokens, max_gram):
    '''
    1. 判断 stopwords 和 sep token
    2. 先遍历一遍，记录 n-gram 的重复次数和出现位置
    3. 遍历记录的 n-gram, 过滤掉 n-gram 子序列（直接比较 str）
    4. 赋值 domain_ids.

    '''

    stemmed_tokens = [token_stem(token) for token in tokens]

    ''' 1 & 2'''
    n_tokens = len(tokens)
    d_ngram = {}
    domain_words_stemmed = {}
    domain_words_orin = {}
    domain_ids = [0] * n_tokens
    for n in range(max_gram, 0, -1):  # loop over n-gram.
        for i in range(n_tokens - n):  # n-gram window sliding.

            window_start, window_end = i, i+n
            stemmed_span = stemmed_tokens[window_start:window_end]
            stemmed_ngram = " ".join(stemmed_span)
            orin_span = tokens[window_start:window_end]
            orin_ngram = " ".join(orin_span)

            if _is_stopwords(orin_ngram, stopwords): continue
            if _head_tail_is_stopwords(orin_span, stopwords): continue
            if _with_septoken(orin_ngram, tokenizer): continue

            if not stemmed_ngram in d_ngram:
                d_ngram[stemmed_ngram] = []
            d_ngram[stemmed_ngram].append((window_start, window_end))

    ''' 3 '''
    d_ngram = dict(filter(lambda e: len(e[1]) > 1, d_ngram.items()))
    raw_domain_words = list(d_ngram.keys())
    raw_domain_words.sort(key=lambda s: len(s), reverse=True)  # sort by len(str).
    domain_words_to_remove = []
    for i in range(0, len(d_ngram)):
        for j in range(i+1, len(d_ngram)):
            if raw_domain_words[i] in raw_domain_words[j]:
                domain_words_to_remove.append(raw_domain_words[i])
            if raw_domain_words[j] in raw_domain_words[i]:
                domain_words_to_remove.append(raw_domain_words[j])
    for r in domain_words_to_remove:
        try:
            del d_ngram[r]
        except:
            pass

    ''' 4 '''
    d_id = 0
    for stemmed_ngram, start_end_list in d_ngram.items():
        d_id += 1
        for start, end in start_end_list:
            domain_ids[start:end] = [d_id] * (end - start)
            rebuilt_orin_ngram = " ".join(tokens[start: end])
            if not stemmed_ngram in domain_words_stemmed:
                domain_words_stemmed[stemmed_ngram] = []
            if not rebuilt_orin_ngram in domain_words_orin:
                domain_words_orin[rebuilt_orin_ngram] = []
            domain_words_stemmed[stemmed_ngram] +=  [(start, end)]
            domain_words_orin[rebuilt_orin_ngram] += [(start, end)]
    return domain_words_stemmed, domain_words_orin, domain_ids

### Step 1: Read Data

In [281]:
# load reclor data
with open("reclor_data/train_qtype.json","r") as file:
    train_data = json.load(file)

# load logiqa data
# with open("logiqa_data/Train_v1.txt","r",encoding='utf-8') as file:
#     lines = file.readlines()
# n_examples = int(len(lines)//8)
# train_data = []
# for i in range(n_examples):
#     dataDict = {}
#     dataDict['context'] = lines[i*8+2].strip()
#     dataDict['question'] = lines[i*8+3].strip()
#     dataDict['answers'] = [lines[i*8+j].strip()[2:].strip() for j in range(4,8)]
#     dataDict['label'] = lines[i*8+1].strip()
#     dataDict['qtype'] = ''
#     train_data.append(dataDict)

In [282]:
with open("reclor_data/relation_words_set.json","r") as file:
    relations = json.load(file)
punctuations = [',', '.', ';', ':']
stopwords = list(gensim.parsing.preprocessing.STOPWORDS) + punctuations

### Step 2: Get Atoms

In [283]:
"""
 1: Causal   2: onlyif   3: if
 4: Opposite   5: Follow   7: Fact
"""
train_data_da = train_data.copy()
def split_into_sentences(token_list, id_list):
    """
        function: split the context by the punctuations into sentences
        input: token_list, id_list
        output: sentence_list   List[List]
    """
    split_id_indices = np.where(np.array(id_list) == 9)[0].tolist()
    sentence_list = []
    for i in range(len(split_id_indices)):
        if i==0:
            sentence_list.append(token_list[0:split_id_indices[i]])
        else:
            if split_id_indices[i-1]+1 != split_id_indices[i]:   # address the '. . .' at the end of the sentence
                sentence_list.append(token_list[split_id_indices[i-1]+1:split_id_indices[i]])
    return sentence_list

def check_variables(var_list):
    """
        function: check variables from the sentence
        input: sentence, space_id
        output: adjusted variable_list
    """
    variable_list = []
    for var in var_list:
        variable = var
        if len(var) != 0:
            if var[0] in ['.',',','?','!',':',';']:
                variable = var[1:]
            if var[-1] in ['.',',','?','!',':',';']:
                variable = var[:-1]
        variable_list.append(variable)
    return variable_list

def get_atom(trigger_id, trigger_argument, sentence, space_id, space_ids_list):
    """
        function: convert the sentence into the atom form
            [ ( <id>, <trigger> ), [ [<variable>], [<variable>], [<variable>] ] ]
        input: trigger_id, trigger_argument, sentence, space_id, space_ids_list
        output: atom form
    """
    global num_1
    global num_2
    global num_3
    atom = [(trigger_id, trigger_argument),]
    if trigger_id == 1:   # cause -->  result
        if 0 in space_ids_list[:space_id[0]] and 0 in space_ids_list[space_id[1]:]:   # trigger is in the middle
            if trigger_argument in ["because", "since", "due to", "because of"]:    # result + trigger + cause
                var_list = [sentence[space_id[1]:], sentence[space_id[0]:space_id[1]], sentence[0:space_id[0]]]
            else:       # cause + trigger + result
                var_list = [sentence[0:space_id[0]], sentence[space_id[0]:space_id[1]], sentence[space_id[1]:]]

        elif 0 in space_ids_list[space_id[1]:]:  # trigger in the front
            if trigger_argument in ["there fore", "t there fore", "thus", "so", "hence", "as a result", "consequently"]:  
                # only result, no cause
                var_list = [[], [trigger_argument], sentence[space_id[1]:]]
            else:    # trigger + cause + , + result
                comma_id = None
                for i in range(len(sentence)):
                    if i>=space_id[1] and sentence[i]==',':   # find the closest ',' in the sentence
                        comma_id = i
                if comma_id:
                    var_list = [sentence[space_id[1]:comma_id], [trigger_argument], sentence[comma_id+1:]]
                else:
                    var_list = [[], [trigger_argument], sentence[space_id[1]:]]
        else:   # trigger in the back
            var_list = [sentence]

        atom.append(check_variables(var_list))

    elif trigger_id == 2:  # premise --->  hypothesis  (only if)
        if 0 in space_ids_list[:space_id[0]] and 0 in space_ids_list[space_id[1]:]:   # trigger is not in the front
            if trigger_argument == "unless" and "not" in sentence[0:space_id[0]]: #process "unless" and negation
                adjust_hypo = sentence[0:space_id[0]].copy()
                adjust_hypo.remove("not")
                var_list = [sentence[space_id[1]:], sentence[space_id[0]:space_id[1]], adjust_hypo]
            else:
                var_list = [sentence[space_id[1]:], sentence[space_id[0]:space_id[1]], sentence[0:space_id[0]]]

        elif 0 in space_ids_list[space_id[1]:]:   # trigger in the front
            comma_id = None
            for i in range(len(sentence)):
                if i>=space_id[1] and sentence[i]==',':   # find the closest ',' in the sentence
                    comma_id = i
            if comma_id:
                if trigger_argument == "unless" and "not" in sentence[0:space_id[0]]:#process "unless" and negation
                    adjust_hypo = sentence[comma_id+1:].copy()
                    adjust_hypo.remove("not")
                    var_list = [sentence[space_id[1]:comma_id], [trigger_argument], adjust_hypo]
                else:
                    var_list = [sentence[space_id[1]:comma_id], [trigger_argument], sentence[comma_id+1:]]
            else:
                var_list = [[], [trigger_argument], sentence[space_id[1]:]]
        else:
            var_list = [sentence]
        atom.append(check_variables(var_list))

    elif trigger_id == 3:  # premise --->  hypothesis   (if)
        if 0 in space_ids_list[:space_id[0]] and 0 in space_ids_list[space_id[1]:]:   # trigger is not in the front
            if trigger_argument in ['if', 'once', 'as long as', 'as soon as']:
                var_list = [sentence[space_id[1]:], sentence[space_id[0]:space_id[1]], sentence[0:space_id[0]]]
            else:
                var_list = [sentence[0:space_id[0]], sentence[space_id[0]:space_id[1]], sentence[space_id[1]:]]
        elif 0 in space_ids_list[space_id[1]:]:   # trigger in the front or in the back
            if trigger_argument in ['if', 'once', 'as long as', 'as soon as']:
                comma_id = None
                for i in range(len(sentence)):
                    if i>=space_id[1] and sentence[i]==',':   # find the closest ',' in the sentence
                        comma_id = i
                if comma_id:
                    var_list = [sentence[space_id[1]:comma_id], [trigger_argument], sentence[comma_id+1:]]
                else:
                    var_list = [[], [trigger_argument], sentence[space_id[1]:]]
            else:
                print(trigger_argument)
        else:
            var_list = [sentence]
        atom.append(check_variables(var_list))


    else:
        var_list = [sentence]
        atom.append(check_variables(var_list))
    return atom


""" get the atom forms for each instance """
atoms = []
atoms_text_dict = []
for d in train_data_da:
    bpe_tokens = tokenizer.tokenize(d['context'])
    bare_tokens = [token[1:].lower() if "Ġ" in token else token.lower() for token in bpe_tokens]
    bare_tokens = bare_tokens[1:] if bare_tokens[0]=='.' else bare_tokens
    argument_words, argument_space_ids = _find_arg_ngrams(bare_tokens, max_gram=5, relations=relations)
    punct_space_ids = _find_punct(bare_tokens, punctuations)
    space_ids = [a+b for a,b in zip(argument_space_ids,punct_space_ids)]
    sentence_list = split_into_sentences(bare_tokens, space_ids)
    atom_sent = []
    atom_text_dict = {}
    for index, s in enumerate(sentence_list):
        argument_words, argument_space_ids = _find_arg_ngrams(s, 5, relations)
        punct_space_ids = _find_punct(s, punctuations)
        space_ids = [a+b for a,b in zip(argument_space_ids,punct_space_ids)]
        if argument_words:
            argument_words_list = list(argument_words.keys())
            argument_words_list = [w.strip() for w in argument_words_list]
            argument_ids_list = [relations[argu] for argu in argument_words_list]
            trigger_index = argument_ids_list.index(min(argument_ids_list))
            trigger_argument = argument_words_list[trigger_index]
            trigger_id = relations[trigger_argument]  # the id number of the trigger in the dictionary
            atom = get_atom(trigger_id, trigger_argument, s, argument_words[trigger_argument], space_ids)
            atom_sent.append(atom)  # get atom forms
            atom_text_dict[index] = s
        else:
            atom_sent.append([(7,"fact"), [s]])   # get atom forms
            atom_text_dict[index] = s
    atoms.append(atom_sent)
    atoms_text_dict.append(atom_text_dict)
atoms

[[[(1, 'causes'),
   [['in',
     'r',
     'he',
     'umat',
     'oid',
     'arthritis',
     ',',
     'the',
     'body',
     "'",
     's',
     'immune',
     'system',
     'mis',
     'fun',
     'ctions',
     'by',
     'attacking',
     'healthy',
     'cells',
     'in',
     'the',
     'joints',
     'causing',
     'the',
     'release',
     'of',
     'a',
     'hormone',
     'that',
     'in',
     'turn'],
    ['causes'],
    ['pain', 'and', 'swelling']]],
  [(7, 'fact'),
   [['this',
     'hormone',
     'is',
     'normally',
     'activated',
     'only',
     'in',
     'reaction',
     'to',
     'injury',
     'or',
     'infection']]],
  [(1, 'causes'),
   [['a',
     'new',
     'arthritis',
     'medication',
     'will',
     'contain',
     'a',
     'protein',
     'that',
     'inhibits',
     'the',
     'functioning',
     'of',
     'the',
     'hormone',
     'that'],
    ['causes'],
    ['pain', 'and', 'swelling', 'in', 'the', 'joints']]]],
 [[(

### Step 3: Get Variables and Tags

In [284]:
def has_same_logical_component(set1, set2):
    has_same = False
    overlap = -1
    if len(set1) > 1 and len(set2) > 1:
        overlap = len(set1 & set2)/max(min(len(set1), len(set2)), 1)
        if overlap >= 0.6:  # hyper-parameter:0.5
            has_same = True
    return has_same, overlap

def tag_variables(tags, i, j, sents_list, max_tag, map_dict):
    """
        function: tag variables
        input: tag_list, current position i, j, sent_set, max_tag, map_dict
        output: tagged list
    """
    flag = False   # has same or not
    max_tag += 1
    current_sent = set(sents_list[i][j])-set(stopwords)
    for m in range(i+1, len(tags)):
        for n in range(len(tags[m])):
            comp_sent = set(sents_list[m][n])-set(stopwords)
            has_same,_ = has_same_logical_component(current_sent, comp_sent)
            if has_same:   # same variable
                if tags[m][n] == -1:
                    tags[i][j] = max_tag
                    tags[m][n] = max_tag
                    flag = True
                    break
    if not flag:
        tags[i][j] = max_tag
    
#     if tags[i][j] in map_dict.keys():     # write into the variable_text_map_dict
#         map_dict[tags[i][j]].append(sents_list[i][j])
#     else:
#         map_dict[tags[i][j]] = sents_list[i][j]
    
    return tags,max_tag,map_dict

""" tag the same variables """
variable_tags, negation_tags, variable_text_dict = [], [], []
for atom in atoms:
    tags, sent_list, negation_tag = [], [], []
    for atom_sent in atom:
        if len(atom_sent[1])==1:
            tags.append([-1])
            sent_list.append(atom_sent[1])
                
        elif len(atom_sent[1])==3:
            if len(atom_sent[1][0]) != 0:
                tags.append([-1,-1])
                sent_list.append([atom_sent[1][0], atom_sent[1][2]])
            else:
                tags.append([-1])
                sent_list.append([atom_sent[1][2]])
        
        ''' identify the negation word '''
        neg_list = []
        for index, s in enumerate(atom_sent[1]):
            if len(s) != 0 and index != 1:
                if "not" in s:
                    neg_list.append(1)
                else:
                    neg_list.append(0)
        negation_tag.append(neg_list)
    max_tag = -1
    map_dict = {}
    for i in range(len(tags)):
        for j in range(len(tags[i])):
            if tags[i][j] == -1:
                tags, max_tag, map_dict = tag_variables(tags,i,j,sent_list,max_tag,map_dict)   # return the tagged list
    map_dict = defaultdict(list)
    for i in range(len(tags)):
        for j in range(len(tags[i])):
            map_dict[tags[i][j]].append(sent_list[i][j])
    variable_tags.append(tags)
    negation_tags.append(negation_tag)
    variable_text_dict.append(map_dict)
# variable_tags
# variable_text_dict

### Step 4: Get Extended Variables and Tags

In [285]:
""" logic equivalence """

""" get atoms in a variable forms"""
atom_variable = []
for atom, var in zip(atoms, variable_tags):
    atom_var_instance = []
    for a, b in zip(atom, var):
        atom_var_instance.append([a[0],b])
    atom_variable.append(atom_var_instance)
    
def reverse_neg(num):
    '''
        function: reverse the negation tags
    '''
    if num == 0:
        return 1
    else:
        return 0

def is_duplicated_atom(cur_atom, cur_list):
    '''
        function: judge the duplicate of the current atom
    '''
    for a in cur_list:
        if a[0][0]==cur_atom[0][0] and a[1]==cur_atom[1]:
            return True
    return False

def get_extended_atoms(atom_input, var_input, neg_input):
    """
        function: get extended atoms of one instance
        input: atom, var_tags, neg_tags
        output: extend_atom, extend_var_tags, extend_neg_tags
    """
    atom = atom_input.copy()
    var = var_input.copy()
    neg = neg_input.copy()
    assert len(atom) == len(var)
    assert len(var) == len(neg)
    
    extend_atom, extend_var_tags, extend_neg_tags = [], [], []
    extend_atom_acc, extend_var_tags_acc, extend_neg_tags_acc = [], [], []
    iter_num = 5
    atom_acc = atom.copy()   # real-time atom for tracing the updated atom list to avoid duplicated
    for iteration in range(1,iter_num+1):
        for j in range(len(atom)):
            if atom[j][0][0] == 1 and len(var[j])==2:   # single atom equalvalence (causal)
                cur_atom = [atom[j][0],[var[j][1], var[j][0]]]
                if not is_duplicated_atom(cur_atom, atom+extend_atom):
                    extend_atom.append([atom[j][0],[var[j][1], var[j][0]]])
                    extend_var_tags.append([var[j][1], var[j][0]])
                    extend_neg_tags.append([reverse_neg(neg[j][0]), reverse_neg(neg[j][1])])
                    atom_acc.append([atom[j][0],[var[j][1], var[j][0]]])

            if atom[j][0][0] == 2 and len(var[j])==2:   # single atom equalvalence (only if)
                cur_atom = [(3,'if'), var[j]]
                if not is_duplicated_atom(cur_atom, atom+extend_atom):   # avoid duplicated
                    extend_atom.append([(3,'if'), var[j]])
                    extend_var_tags.append(var[j])
                    extend_neg_tags.append([reverse_neg(neg[j][0]), reverse_neg(neg[j][1])])
                    atom_acc.append([(3,'if'), var[j]])

            if atom[j][0][0] <= 3 and len(var[j])==2 and j!=len(atom)-1:   # two atom equalvalence 
                for k in range(j+1, len(atom)):
                    if len(var[k]) == 2 and atom[k][0][0]<=3 and var[j][1]==var[k][0] \
                        and neg[j][1]==neg[k][0] and var[j][0]!=var[k][1]:   # j(A,B) conj k(B,C)
                        cur_atom = [(atom[j][0][0], atom[j][0][1]), [var[j][0], var[k][1]]]
                        if not is_duplicated_atom(cur_atom, atom+extend_atom):# avoid duplicated
                            extend_atom.append([(atom[j][0][0], atom[j][0][1]), [var[j][0], var[k][1]]])
                            extend_var_tags.append([var[j][0], var[k][1]])
                            extend_neg_tags.append([neg[j][0], neg[k][1]])
                            atom_acc.append([(atom[j][0][0], atom[j][0][1]), [var[j][0], var[k][1]]])

                    if len(var[k]) == 2 and atom[k][0][0]<=3 and var[j][0]==var[k][1] \
                        and neg[j][0]==neg[k][1] and var[j][1]!=var[k][0]:   # k(A,B) conj j(B,C)
                        cur_atom = [(atom[j][0][0], atom[j][0][1]), [var[k][0], var[j][1]]]
                        if not is_duplicated_atom(cur_atom, atom+extend_atom):# avoid duplicated
                            extend_atom.append([(atom[j][0][0], atom[j][0][1]), [var[k][0], var[j][1]]])
                            extend_var_tags.append([var[k][0], var[j][1]])
                            extend_neg_tags.append([neg[k][0], neg[j][1]])
                            atom_acc.append([(atom[j][0][0], atom[j][0][1]), [var[k][0], var[j][1]]])

            if len(var[j])==1 :   # fact equalvalence
                for k in range(len(atom)):
                    if j!=k and len(var[k])==2 and atom[k][0][0] in [2,3] \
                        and var[j][0]==var[k][0] and neg[j][0]==neg[k][0]:
                        cur_atom = [(7, atom[j][0][1]), [var[k][1]]]
                        if not is_duplicated_atom(cur_atom, atom+extend_atom):  # avoid duplicated
                            extend_atom.append([(7, atom[j][0][1]), [var[k][1]]])
                            extend_var_tags.append([var[k][1]])
                            extend_neg_tags.append([neg[k][1]])
                            atom_acc.append([(7, atom[j][0][1]), [var[k][1]]])
                            
        ''' after each iteration, append extended atoms into the list. And search within the new atom list 
            in the next round '''
#         print(extend_atom)
        atom += extend_atom
        var += extend_var_tags
        neg += extend_neg_tags
        extend_atom_acc += extend_atom
        extend_var_tags_acc += extend_var_tags
        extend_neg_tags_acc += extend_neg_tags
        extend_atom, extend_var_tags, extend_neg_tags = [], [], []
    return extend_atom_acc, extend_var_tags_acc, extend_neg_tags_acc
    
""" predefined rules """
extend_atom_variable, extend_variable_tags, extend_negation_tags = [], [], []
for i, (atom, var, neg) in enumerate(zip(atom_variable, variable_tags, negation_tags)):
    extend_atom, extend_var_tags, extend_neg_tags = get_extended_atoms(atom, var, neg)
    extend_atom_variable.append(extend_atom)
    extend_variable_tags.append(extend_var_tags)
    extend_negation_tags.append(extend_neg_tags)

In [286]:
# obtain atoms text dict
for i in range(len(atom_variable)):
    assert len(atoms_text_dict[i]) == len(atom_variable[i]) ,"not match in number"
    for j in range(len(atom_variable[i])):
        atoms_text_dict[i][(atom_variable[i][j][0],tuple(atom_variable[i][j][1]))] = atoms_text_dict[i].pop(j)

### Step 5: Get Equivalent Logics

In [287]:
# get the extended atoms (original + extended)
full_atom_variable = [atom+extend_atom for atom,extend_atom in zip(atom_variable,extend_atom_variable)]
full_variable_tags = [v_tag+extend_v_tag for v_tag,extend_v_tag in zip(variable_tags,extend_variable_tags)]
full_negation_tags = [n_tag+extend_n_tag for n_tag,extend_n_tag in zip(negation_tags, extend_negation_tags)]

def not_duplicate(new_atom, new_atom_list):
    for a in new_atom_list:
        if sorted(new_atom)==sorted(a):
            return False
    return True

def is_same_logic(new_atom, origin_atom):
    if sorted(new_atom)==sorted(origin_atom):
        return True
    else:
        return False

def sample_possible_logics(atom, var, neg, min_len, max_len, origin_num):
    """
        function: sample all the possible logics from the full atom list
        input: atom_var, var_tag, neg_tag
        output: sample_atoms(list), sample_var_tags(list), sample_neg_tags(list)
    """
    atom_num = len(atom)
    sample_atoms, sample_var_tags, sample_neg_tags = [], [], []
    for n in range(min_len, max_len+1):
        sample_index = list(combinations(range(atom_num), n))
        for index in sample_index:
            if index != tuple(np.arange(origin_num)):   # avoid duplicate with the original
                sample_atom = [atom[k] for k in index]
                sample_var_tag = [var[k] for k in index]
                sample_neg_tag = [neg[k] for k in index]
                sample_atoms.append(sample_atom)
                sample_var_tags.append(sample_var_tag)
                sample_neg_tags.append(sample_neg_tag)
    return sample_atoms, sample_var_tags, sample_neg_tags

a = 0
equal_atom_variable, equal_variable_tags, equal_negation_tags = [], [], []
for i, (atom, var, neg) in enumerate(zip(full_atom_variable, full_variable_tags, full_negation_tags)):
#     print(atom_variable[2])
    min_len = len(atom_variable[0])  # original atoms
#     min_len = len(atom)
#     max_len = len(atom)
    max_len = min_len + 2
    origin_num = len(atom_variable[0])
    sample_atoms, sample_var_tags, sample_neg_tags = sample_possible_logics(atom, var, neg, min_len, max_len, origin_num)
    new_atom_list, new_var_tag_list, new_neg_tag_list = [], [], []
    for j, (sp_atom, sp_var_tag, sp_neg_tag) in enumerate(zip(sample_atoms, sample_var_tags, sample_neg_tags)):
        extend_atom, extend_var_tags, extend_neg_tags = get_extended_atoms(sp_atom, sp_var_tag, sp_neg_tag)
        new_atom = sp_atom + extend_atom
        new_var_tags = sp_var_tag + extend_var_tags
        new_neg_tags = sp_neg_tag + extend_neg_tags
#         print(sp_atom)
#         print(sp_neg_tag)
#         print(extend_atom)
#         print(extend_neg_tags)
        if is_same_logic(new_atom, full_atom_variable[i]) and not_duplicate(sp_atom, new_atom_list):  # transform to dict
            new_atom_list.append(sp_atom)
            new_var_tag_list.append(sp_var_tag)
            new_neg_tag_list.append(sp_neg_tag)
#     if sample_atoms != new_atom_list:
#         print("==========")
#         a += 1
#         print(i)
#         print(sample_atoms)
#         print(new_atom_list)
#         print("========")
    equal_atom_variable.append(new_atom_list)
    equal_variable_tags.append(new_var_tag_list)
    equal_negation_tags.append(new_neg_tag_list)
    if i % 100 == 0:
        print(i)

0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
2200
2300
2400
2500
2600
2700
2800
2900
3000
3100
3200
3300
3400
3500
3600
3700
3800
3900
4000
4100
4200
4300
4400
4500
4600


In [289]:
equal_atom_variable[2]

[[[(2, 'only if'), [0, 1]],
  [(2, 'unless'), [2, 0]],
  [(4, 'however'), [2]],
  [(1, 't there fore'), [1]]],
 [[(2, 'only if'), [0, 1]],
  [(2, 'unless'), [2, 0]],
  [(4, 'however'), [2]],
  [(1, 't there fore'), [1]],
  [(3, 'if'), [0, 1]]],
 [[(2, 'only if'), [0, 1]],
  [(2, 'unless'), [2, 0]],
  [(4, 'however'), [2]],
  [(1, 't there fore'), [1]],
  [(2, 'only if'), [2, 1]]],
 [[(2, 'only if'), [0, 1]],
  [(2, 'unless'), [2, 0]],
  [(4, 'however'), [2]],
  [(1, 't there fore'), [1]],
  [(3, 'if'), [2, 0]]],
 [[(2, 'only if'), [0, 1]],
  [(2, 'unless'), [2, 0]],
  [(4, 'however'), [2]],
  [(1, 't there fore'), [1]],
  [(7, 'however'), [0]]],
 [[(2, 'only if'), [0, 1]],
  [(2, 'unless'), [2, 0]],
  [(4, 'however'), [2]],
  [(1, 't there fore'), [1]],
  [(3, 'if'), [2, 1]]],
 [[(2, 'only if'), [0, 1]],
  [(2, 'unless'), [2, 0]],
  [(4, 'however'), [2]],
  [(1, 't there fore'), [1]],
  [(7, 'however'), [1]]]]

### Step 6-1: Generate Text From Extend Atoms (Optional) 

In [238]:
""" generate text from extended atoms """

""" define template """
causal_template = ["Because [cause], [result].", "Since [cause], [result].", "[result], because [result].", \
                  "Due to [cause], [result]."]

onlyif_template = ["Only if [premise], [hypothesis].", "[hypothesis], only if [premise].", "Only when [premise], [hypothesis]."]

if_template = ["If [premise], then [hypothesis].", "[hypothesis], if [premise].", "If [premise], [hypothesis].", \
              "As long as [premise], then [hypothesis]."]

fact_template = ["In fact, [fact].", "Therefore, [fact].", "Actually, [fact].", "So [fact]."]

def add_negation(s):
    text = word_tokenize(s) #分词
    for word,tag in pos_tag(text):
        if tag in ['VB','VBD','VBG','VBN','VBZ']:
            if word in ['is','was','are','were']:
                return s.replace(word, word+" not")
            else:    
                return s.replace(word,"not "+word)
    return "not " + s


""" fill in the template """
extend_text = []
for i, (atom, var, neg) in enumerate(zip(extend_atom_variable, extend_variable_tags, extend_negation_tags)):
    text_list = []
    for j in range(len(atom)):
        random_number = np.random.choice([n for n in range(len(variable_text_dict[i][var[j][0]]))])
        text_a = ' '.join(variable_text_dict[i][var[j][0]][random_number])  
        neg_a = neg[j][0]
        if len(var[j]) == 2:
            random_number = np.random.choice([n for n in range(len(variable_text_dict[i][var[j][1]]))])
            text_b = ' '.join(variable_text_dict[i][var[j][1]][random_number]) # default the first
            neg_b = neg[j][1]
        
        if neg[j][0]==0 and "not" in text_a:
            text_a.replace("not ","")   # delete negation in text_a
        if len(neg[j])==2 and neg[j][1]==0 and "not" in text_b:
            text_b.replace("not ","")   # delete negation in text_b
        if neg[j][0]==1 and "not" not in text_a:
            text_a = add_negation(text_a)    # add negation in text_a
        if len(neg[j])==2 and neg[j][1]==1 and "not" not in text_b:
            text_b = add_negation(text_b)    # add negation in text_b
           
        if atom[j][0][0] == 1:
            selected_template = np.random.choice(causal_template)
            text = selected_template.replace("[cause]",text_a).replace("[result]",text_b)
            text_list.append(text)
        if atom[j][0][0] == 2:
            selected_template = np.random.choice(onlyif_template)
            text = selected_template.replace("[premise]",text_a).replace("[hypothesis]",text_b)
            text_list.append(text)
                
        if atom[j][0][0] == 3:
            selected_template = np.random.choice(if_template)
            text = selected_template.replace("[premise]",text_a).replace("[hypothesis]",text_b)
            text_list.append(text)
            
        if atom[j][0][0] == 7:
            selected_template = np.random.choice(fact_template)
            text = selected_template.replace("[fact]",text_a)
            text_list.append(text)
            
    extend_text.append(text_list)

In [239]:
# generate new train data
np.random.seed(42)
random.seed(42)
max_extend = 3

labelmap = {"a": 0, "b": 1, "c": 2, "d": 3}
extend_train_data = []
for index in range(1):
    for i,d in enumerate(train_data):
        new_d = {}
        if len(extend_text[i]) != 0:
            extended_text = [s.capitalize() for s in extend_text[i]]
            if len(extended_text)<=max_extend:   # maintain the maximum
                new_d['context'] = d['context'] + ' ' + ' '.join(extended_text)
            else:
                new_d['context'] = d['context'] + ' ' + ' '.join(np.random.choice(extended_text,max_extend,False))

            new_d['question'] = d['question']
            gt_answer = d['answers'][d['label']]
            new_answers = d['answers'].copy()
            random.shuffle(new_answers)
            new_d['answers'] = new_answers
            new_d['label'] = new_d['answers'].index(gt_answer)
            new_d['qtype'] = d['qtype']
            new_d['id_string'] = ''
            extend_train_data.append(new_d)

### Step 6-2: Generate Text From Equal Atoms (Optional) 

In [243]:
""" generate text from equal atoms """

""" define template """
causal_template = ["Because [cause], [result].", "Since [cause], [result].", "[result], because [result].", \
                  "Due to [cause], [result]."]

onlyif_template = ["Only if [premise], [hypothesis].", "[hypothesis], only if [premise].", "Only when [premise], [hypothesis]."]

if_template = ["If [premise], then [hypothesis].", "[hypothesis], if [premise].", "If [premise], [hypothesis].", \
              "As long as [premise], then [hypothesis]."]

fact_template = ["In fact, [fact].", "Therefore, [fact].", "Actually, [fact].", "So [fact]."]


def add_negation(s):
    text = word_tokenize(s) #分词
    for word,tag in pos_tag(text):
        if tag in ['VB','VBD','VBG','VBN','VBZ']:
            if word in ['is','was','are','were']:
                return s.replace(word, word+" not")
            else:    
                return s.replace(word,"not "+word)
    return "not " + s

""" fill in the template """
equal_text = []
for i, (atom, var, neg) in enumerate(zip(equal_atom_variable, equal_variable_tags, equal_negation_tags)):
    atom_text_list = []
    if len(atom) == 0:
        equal_text.append(atom_text_list)
    else:
        for k in range(len(atom)):   # iter for each equal logic of one atom
            text_list = []
            for j in range(len(atom[k])):   # iter for each sentence of one logic
                random_number = np.random.choice([n for n in range(len(variable_text_dict[i][var[k][j][0]]))])
                text_a = ' '.join(variable_text_dict[i][var[k][j][0]][random_number])  # default the first 
                neg_a = neg[k][j][0]
                if len(var[k][j]) == 2:
                    random_number = np.random.choice([n for n in range(len(variable_text_dict[i][var[k][j][1]]))])
                    text_b = ' '.join(variable_text_dict[i][var[k][j][1]][random_number])   # default the first
                    neg_b = neg[k][j][1]

                if neg[k][j][0]==0 and "not" in text_a:
                    text_a.replace("not ","")   # delete negation in text_a
                if len(neg[k][j])==2 and neg[k][j][1]==0 and "not" in text_b:
                    text_b.replace("not ","")   # delete negation in text_b
                if neg[k][j][0]==1 and "not" not in text_a:
                    text_a = add_negation(text_a)  # add negation in text_a
                if len(neg[k][j])==2 and neg[k][j][1]==1 and "not" not in text_b:
                    text_b = add_negation(text_b)   # add negation in text_b

                if atom[k][j][0][0] == 1:
                    selected_template = np.random.choice(causal_template)
                    text = selected_template.replace("[cause]",text_a).replace("[result]",text_b)
                    text_list.append(text)
                if atom[k][j][0][0] == 2:
                    selected_template = np.random.choice(onlyif_template)
                    text = selected_template.replace("[premise]",text_a).replace("[hypothesis]",text_b)
                    text_list.append(text)

                if atom[k][j][0][0] == 3:
                    selected_template = np.random.choice(if_template)
                    text = selected_template.replace("[premise]",text_a).replace("[hypothesis]",text_b)
                    text_list.append(text)

                if atom[k][j][0][0] == 7:
                    selected_template = np.random.choice(fact_template)
                    text = selected_template.replace("[fact]",text_a)
                    text_list.append(text)
            atom_text_list.append(text_list)
        equal_text.append(atom_text_list)

In [246]:
# generate new train data from equal text
np.random.seed(42)
random.seed(42)
max_limit = True  # max limited new atoms for each instance

equal_train_data = []
for i,d in enumerate(train_data):
    if len(equal_text[i]) != 0 and len(equal_text[i][0])!=0:
        if max_limit:   # max 3 new atoms for each instance
            select_index = np.random.choice(np.arange(len(equal_text[i])), min(3,len(equal_text[i])), replace=False) 
        else:
            select_index = np.arange(len(equal_text[i]))
        for j in select_index:
            new_d = {}
            temp_text = [s.capitalize() for s in equal_text[i][j]]
            new_d['context'] = ' '.join(temp_text)
            new_d['question'] = d['question']
            gt_answer = d['answers'][d['label']]
            new_answers = d['answers'].copy()
            random.shuffle(new_answers)
            new_d['answers'] = new_answers
            new_d['label'] = new_d['answers'].index(gt_answer)
            new_d['qtype'] = d['qtype']
            new_d['id_string'] = ''
            equal_train_data.append(new_d)

In [247]:
with open("reclor_data/train_qtype_equal3_max2.json",'w') as file:
    json.dump(train_data+equal_train_data, file, indent=4)

In [248]:
len(equal_train_data), len(extend_train_data)

(4394, 2132)

### Test Split Extend

In [258]:
""" generate text from extended atoms """

""" define template """
causal_template = ["Because [cause], [result].", "Since [cause], [result].", "[result], because [result].", \
                  "Due to [cause], [result]."]

onlyif_template = ["Only if [premise], [hypothesis].", "[hypothesis], only if [premise].", "Only when [premise], [hypothesis]."]

if_template = ["If [premise], then [hypothesis].", "[hypothesis], if [premise].", "If [premise], [hypothesis].", \
              "As long as [premise], then [hypothesis]."]

fact_template = ["In fact, [fact].", "Therefore, [fact].", "Actually, [fact].", "So [fact]."]

def add_negation(s):
    text = word_tokenize(s) #分词
    for word,tag in pos_tag(text):
        if tag in ['VB','VBD','VBG','VBN','VBZ']:
            if word in ['is','was','are','were']:
                return s.replace(word, word+" not")
            else:    
                return s.replace(word,"not "+word)
    return "not " + s


""" fill in the template """
extend_text = []
for i, (atom, var, neg) in enumerate(zip(extend_atom_variable, extend_variable_tags, extend_negation_tags)):
    text_list = []
    for j in range(len(atom)):
        random_number = np.random.choice([n for n in range(len(variable_text_dict[i][var[j][0]]))])
        text_a = ' '.join(variable_text_dict[i][var[j][0]][random_number])  
        neg_a = neg[j][0]
        if len(var[j]) == 2:
            random_number = np.random.choice([n for n in range(len(variable_text_dict[i][var[j][1]]))])
            text_b = ' '.join(variable_text_dict[i][var[j][1]][random_number]) # default the first
            neg_b = neg[j][1]
        
        if neg[j][0]==0 and "not" in text_a:
            text_a.replace("not ","")   # delete negation in text_a
        if len(neg[j])==2 and neg[j][1]==0 and "not" in text_b:
            text_b.replace("not ","")   # delete negation in text_b
        if neg[j][0]==1 and "not" not in text_a:
            text_a = add_negation(text_a)    # add negation in text_a
        if len(neg[j])==2 and neg[j][1]==1 and "not" not in text_b:
            text_b = add_negation(text_b)    # add negation in text_b
           
        if atom[j][0][0] == 1:
            selected_template = np.random.choice(causal_template)
            text = selected_template.replace("[cause]",text_a).replace("[result]",text_b)
            text_list.append(text)
        if atom[j][0][0] == 2:
            selected_template = np.random.choice(onlyif_template)
            text = selected_template.replace("[premise]",text_a).replace("[hypothesis]",text_b)
            text_list.append(text)
                
        if atom[j][0][0] == 3:
            selected_template = np.random.choice(if_template)
            text = selected_template.replace("[premise]",text_a).replace("[hypothesis]",text_b)
            text_list.append(text)
            
        if atom[j][0][0] == 7:
            selected_template = np.random.choice(fact_template)
            text = selected_template.replace("[fact]",text_a)
            text_list.append(text)
            
    extend_text.append(text_list)

In [259]:
# generate new test data
np.random.seed(42)
random.seed(42)
max_extend = 3

labelmap = {"a": 0, "b": 1, "c": 2, "d": 3}
extend_train_data = []
for index in range(2):
    for i,d in enumerate(train_data):
        new_d = {}
        if len(extend_text[i]) != 0:
            extended_text = [s.capitalize() for s in extend_text[i]]
            if len(extended_text)<=max_extend:   # maintain the maximum
                new_d['context'] = d['context'] + ' ' + ' '.join(extended_text)
            else:
                new_d['context'] = d['context'] + ' ' + ' '.join(np.random.choice(extended_text,max_extend,False))
            new_d['context'] = normalize_text(new_d['context'])
            new_d['question'] = d['question']
            new_d['answers'] = d['answers']
            new_d['qtype'] = d['qtype']
            new_d['id_string'] = 'test_'+str(i)
            extend_train_data.append(new_d)
len(extend_train_data)

964

### Test Split Equal

In [260]:
""" generate text from equal atoms """

""" define template """
causal_template = ["Because [cause], [result].", "Since [cause], [result].", "[result], because [result].", \
                  "Due to [cause], [result]."]

onlyif_template = ["Only if [premise], [hypothesis].", "[hypothesis], only if [premise].", "Only when [premise], [hypothesis]."]

if_template = ["If [premise], then [hypothesis].", "[hypothesis], if [premise].", "If [premise], [hypothesis].", \
              "As long as [premise], then [hypothesis]."]

fact_template = ["In fact, [fact].", "Therefore, [fact].", "Actually, [fact].", "So [fact]."]


def add_negation(s):
    text = word_tokenize(s) #分词
    for word,tag in pos_tag(text):
        if tag in ['VB','VBD','VBG','VBN','VBZ']:
            if word in ['is','was','are','were']:
                return s.replace(word, word+" not")
            else:    
                return s.replace(word,"not "+word)
    return "not " + s

""" fill in the template """
equal_text = []
for i, (atom, var, neg) in enumerate(zip(equal_atom_variable, equal_variable_tags, equal_negation_tags)):
    atom_text_list = []
    if len(atom) == 0:
        equal_text.append(atom_text_list)
    else:
        for k in range(len(atom)):   # iter for each equal logic of one atom
            text_list = []
            for j in range(len(atom[k])):   # iter for each sentence of one logic
                random_number = np.random.choice([n for n in range(len(variable_text_dict[i][var[k][j][0]]))])
                text_a = ' '.join(variable_text_dict[i][var[k][j][0]][random_number])  # default the first 
                neg_a = neg[k][j][0]
                if len(var[k][j]) == 2:
                    random_number = np.random.choice([n for n in range(len(variable_text_dict[i][var[k][j][1]]))])
                    text_b = ' '.join(variable_text_dict[i][var[k][j][1]][random_number])   # default the first
                    neg_b = neg[k][j][1]

                if neg[k][j][0]==0 and "not" in text_a:
                    text_a.replace("not ","")   # delete negation in text_a
                if len(neg[k][j])==2 and neg[k][j][1]==0 and "not" in text_b:
                    text_b.replace("not ","")   # delete negation in text_b
                if neg[k][j][0]==1 and "not" not in text_a:
                    text_a = add_negation(text_a)  # add negation in text_a
                if len(neg[k][j])==2 and neg[k][j][1]==1 and "not" not in text_b:
                    text_b = add_negation(text_b)   # add negation in text_b

                if atom[k][j][0][0] == 1:
                    selected_template = np.random.choice(causal_template)
                    text = selected_template.replace("[cause]",text_a).replace("[result]",text_b)
                    text_list.append(text)
                if atom[k][j][0][0] == 2:
                    selected_template = np.random.choice(onlyif_template)
                    text = selected_template.replace("[premise]",text_a).replace("[hypothesis]",text_b)
                    text_list.append(text)

                if atom[k][j][0][0] == 3:
                    selected_template = np.random.choice(if_template)
                    text = selected_template.replace("[premise]",text_a).replace("[hypothesis]",text_b)
                    text_list.append(text)

                if atom[k][j][0][0] == 7:
                    selected_template = np.random.choice(fact_template)
                    text = selected_template.replace("[fact]",text_a)
                    text_list.append(text)
            atom_text_list.append(text_list)
        equal_text.append(atom_text_list)

In [261]:
# generate new test data from equal text
np.random.seed(42)
random.seed(42)
max_limit = True  # max limited new atoms for each instance

equal_train_data = []
for i,d in enumerate(train_data):
    if len(equal_text[i]) != 0 and len(equal_text[i][0])!=0:
        if max_limit:   # max 3 new atoms for each instance
            select_index = np.random.choice(np.arange(len(equal_text[i])), min(3,len(equal_text[i])), replace=False) 
        else:
            select_index = np.arange(len(equal_text[i]))
        for j in select_index:
            new_d = {}
            temp_text = [s.capitalize() for s in equal_text[i][j]]
            new_d['context'] = normalize_text(' '.join(temp_text))
            new_d['question'] = d['question']
            new_d['answers'] = d['answers']
            new_d['qtype'] = d['qtype']
            new_d['id_string'] = 'test_' + str(i)
            equal_train_data.append(new_d)
len(equal_train_data)

1322

In [262]:
extra_test_data = extend_train_data+equal_train_data
with open("reclor_data/test_extra_data.json",'w') as file:
    json.dump(extra_test_data, file, indent=4)
len(extra_test_data)

2286

In [265]:
# obtain the ensemble results
def softmax(z):
    return np.exp(z)/sum(np.exp(z))

def select_best(x):
    from collections import Counter
    stat = Counter(x)
    most = stat.most_common()
    if most[0][1]-stat[x[0]]>2:
        return most[0][0]
    else:
        return x[0]
    

pred_label_extra = np.load("reclor_data/predictions_extra.npy")
pred_label = np.load("reclor_data/predictions.npy")

ensemble_results = []
for i in pred_label:
    ensemble_results.append([i])
for i,label_extra in enumerate(pred_label_extra):
    test_index = int(extra_test_data[i]['id_string'].split('_')[-1])
    ensemble_results[test_index] += [label_extra]

final_results = []
for r in ensemble_results:
    final_results.append(select_best(r))

for i in range(len(final_results)):
    if final_results[i] != pred_label[i]:
        print(final_results[i], pred_label[i])
        print(ensemble_results[i])
        print("========")
np.save("reclor_data/final_predictions.npy",final_results)

3 2
[2, 3, 3, 3, 3, 1]
2 3
[3, 2, 2, 2, 2, 2]
3 1
[1, 3, 0, 3, 3, 3]
3 1
[1, 3, 3, 3, 3, 3]
0 1
[1, 0, 0, 0, 3, 0]
1 2
[2, 1, 1, 1, 1, 0]
3 1
[1, 3, 3, 3, 3, 3]
3 0
[0, 3, 3, 3, 2, 3]
2 1
[1, 2, 2, 2, 2]
2 3
[3, 0, 2, 2, 2, 2]
1 3
[3, 1, 1, 1, 1]
2 3
[3, 2, 2, 2, 2]
3 1
[1, 3, 3, 3, 3, 3]
0 2
[2, 0, 0, 0, 0, 1]
3 0
[0, 3, 3, 3, 2, 3]
2 3
[3, 2, 0, 2, 2, 2]
2 3
[3, 2, 2, 2, 1, 2]
0 1
[1, 0, 0, 0, 0]
3 0
[0, 3, 3, 3, 3, 2]
3 0
[0, 3, 3, 3, 3, 2]
3 1
[1, 3, 3, 3, 3, 3]
1 2
[2, 1, 1, 0, 1, 1]
3 1
[1, 3, 3, 0, 3, 3]
2 3
[3, 2, 2, 2, 2, 2]
1 0
[0, 1, 1, 1, 1]
1 0
[0, 1, 1, 1, 1]
0 1
[1, 0, 2, 0, 0, 0]
0 3
[3, 0, 1, 0, 0, 0]
1 2
[2, 1, 1, 1, 1, 3]


# Filter Data on ReClor

In [196]:
labelmap2 = {0:"a",1:'b',2:'c',3:'d'}

# write extra data
extra_data = extend_train_data + equal_train_data
with open("reclor_data/extra_data_v4.json","w") as file:
    json.dump(extra_data, file, indent=4)
len(extend_train_data), len(equal_train_data), len(extra_data)

(4264, 4698, 8962)

In [217]:
# obtain the ensemble results
def softmax(z):
    return np.exp(z)/sum(np.exp(z))

def normalization(z):
    return (z-min(z))/(max(z)-min(z))

with open("reclor_data/extra_data_v4.json","r") as file:
    extra_data = json.load(file)
print(len(extra_data))
pred_prob = np.load("reclor_data/eval_predictions_prob.npy")
pred_label = np.load("reclor_data/eval_predictions.npy")

filtered_data = []
for p_prob, p_label, d in zip(pred_prob,pred_label,extra_data):
    if p_label==d['label'] and p_prob[p_label]>1:
        filtered_data.append(d)
len(filtered_data)

8962


6745

In [210]:
with open("reclor_data/filtered_data_1_v4.json",'w',encoding='utf-8') as file:
    json.dump(train_data + filtered_data, file, indent=4)

# Filter Data on LogiQA

In [101]:
labelmap2 = {0:"a",1:'b',2:'c',3:'d'}
def normalize_text(text):
    if text[0] == ".":
        text = text[1:]
    text = text.replace(". . .", ".")
    text = text.replace(". .", ".")
    text = text.replace(". \"", "\"")
    text = text.replace(".\"", "\"")
    text = text.replace("! \"", "\"")
    text = text.replace("!\"", "\"")
    text = text.replace(".).", ").")
    text = text.replace(". ).", ").")
    text = text.replace(".!!",".")
    text = text.replace(". !!", ".")
    text = text.replace("!\".","\".")
    text = text.replace("! \".","\".")
    
    return text

# write extra data
extra_data = extend_train_data + equal_train_data
with open("logiqa_data/extra_data_v3.txt","w",encoding='utf-8') as file:
    for d in extra_data:
        file.write('\n')
        file.write(labelmap2[d['label']]+'\n')
        file.write(normalize_text(d['context'])+'\n')
        file.write(normalize_text(d['question'])+'\n')
        for j,option in enumerate(['a.','b.','c.','d.']):
            file.write(option+normalize_text(d['answers'][j])+'\n')
# len(extra_data)

In [18]:
with open("logiqa_data/extra_data_v3.txt","r",encoding='utf-8') as file:
    lines = file.readlines()
n_examples = int(len(lines)//8)
extra_data = []
for i in range(n_examples):
    dataDict = {}
    dataDict['context'] = lines[i*8+2].strip()
    dataDict['question'] = lines[i*8+3].strip()
    dataDict['answers'] = [lines[i*8+j].strip()[2:].strip() for j in range(4,8)]
    dataDict['label'] = lines[i*8+1].strip()
    dataDict['qtype'] = ''
    extra_data.append(dataDict)
    
print(len(extra_data))
pred_prob = np.load("logiqa_data/eval_predictions_prob.npy")
pred_label = np.load("logiqa_data/eval_predictions.npy")

filtered_data = []
for p_prob, p_label, d in zip(pred_prob,pred_label,extra_data):
    if p_label==labelmap[d['label']] and p_prob[p_label]>0.5:
        filtered_data.append(d)
len(filtered_data)

4700


1658

In [163]:
new_data = train_data + filtered_data
with open("logiqa_data/filtered_data_0.5_v2.txt",'w',encoding='utf-8') as file:
    for d in new_data:
        file.write('\n')
        file.write(d['label']+'\n')
        file.write(normalize_text(d['context'])+'\n')
        file.write(normalize_text(d['question'])+'\n')
        for j,option in enumerate(['a.','b.','c.','d.']):
            file.write(option+normalize_text(d['answers'][j])+'\n')
len(new_data)

11661

# LogiQA 数据集修正

In [81]:
input_file = "logiqa_data/Train.txt"
with open(input_file, "r") as f:
    lines = f.readlines()
n_examples = int(len(lines) / 8)
n_examples

UnicodeDecodeError: 'gbk' codec can't decode byte 0x9a in position 1444: illegal multibyte sequence

# 对extend sample进行回译

In [80]:
new_train_data = equal_train_data.copy()
new_train_data_bt = new_train_data.copy()
from tqdm import tqdm
for i in tqdm(range(len(new_train_data_bt))):
    new_train_data_bt[i]['context'] = back_translate(new_train_data_bt[i]['context'], 'en','zh')

  1%|▍                                                                             | 18/3469 [01:39<5:18:13,  5.53s/it]


KeyboardInterrupt: 

In [112]:
# generate back translation new data
np.random.seed(42)
random.seed(42)

extend_train_data_bt = []
for i,d in enumerate(new_train_data_bt):
    new_d = {}
    new_d['context'] = d['context']
    new_d['question'] = d['question']
    gt_answer = d['answers'][d['label']]
    new_answers = d['answers'].copy()
    random.shuffle(new_answers)
    new_d['answers'] = new_answers
    new_d['label'] = new_d['answers'].index(gt_answer)
    new_d['qtype'] = d['qtype']
    new_d['id_string'] = ''
    extend_train_data_bt.append(new_d)

In [113]:
with open("reclor_data/train_qtype_extend_bt_da.json",'w') as file:
    json.dump(train_data+extend_train_data_bt, file, indent=4)

# 备份

In [None]:
def get_atom(trigger_id, trigger_argument, sentence, space_id, space_ids_list):
    """
        function: convert the sentence into the atom form
            [ ( <id>, <trigger> ), [ [<variable>], [<variable>], [<variable>] ] ]
        input: trigger_id, trigger_argument, sentence, space_id, space_ids_list
        output: atom form
    """
    atom = [(trigger_id, trigger_argument),]
    
    if trigger_id == 1:   # cause -->  result
        if 0 in space_ids_list[:space_id[0]]:   # trigger is not in the front
            if trigger_argument in ["because", "since", "due to", "because of"]:    # result + trigger + cause
                var_list = [sentence[space_id[1]:], sentence[space_id[0]:space_id[1]], sentence[0:space_id[0]]]
            else:       # cause + trigger + result
                var_list = [sentence[0:space_id[0]], sentence[space_id[0]:space_id[1]], sentence[space_id[1]:]]
        else:  # trigger in the front
            if trigger_argument in ["there fore", "t there fore", "thus", "so", "hence"]:   # only result, no cause
                var_list = [[], [trigger_argument], sentence[space_id[1]:]]
            else:    # trigger + cause + , + result
                comma_id = None
                for i in range(len(space_ids_list)):
                    if i>=space_id[1] and space_ids_list[i]==8:   # find the closest ',' in the sentence
                        comma_id = i
                        break
                if comma_id:
                    var_list = [sentence[space_id[1]:comma_id], [trigger_argument], sentence[comma_id+1:]]
                else:
                    var_list = [[], [trigger_argument], sentence[space_id[1]:]]
        atom.append(check_variables(var_list))

    elif trigger_id == 2:  # premise --->  hypothesis  (only if)
        if 0 in space_ids_list[:space_id[0]]:   # trigger is not in the front
            if trigger_argument == "unless" and "not" in sentence[0:space_id[0]]: #process "unless" and negation
                adjust_hypo = sentence[0:space_id[0]].copy()
                adjust_hypo.remove("not")
                var_list = [sentence[space_id[1]:], sentence[space_id[0]:space_id[1]], adjust_hypo]
            else:
                var_list = [sentence[space_id[1]:], sentence[space_id[0]:space_id[1]], sentence[0:space_id[0]]]
        else:   # trigger in the front
            comma_id = None
            for i in range(len(space_ids_list)):
                if i>=space_id[1] and space_ids_list[i]==8:   # find the closest ',' in the sentence
                    comma_id = i
                    break
            if comma_id:
                if trigger_argument == "unless" and "not" in sentence[0:space_id[0]]:#process "unless" and negation
                    adjust_hypo = sentence[comma_id+1:].copy()
                    adjust_hypo.remove("not")
                    var_list = [sentence[space_id[1]:comma_id], [trigger_argument], adjust_hypo]
                else:
                    var_list = [sentence[space_id[1]:comma_id], [trigger_argument], sentence[comma_id+1:]]
            else:
                var_list = [[], [trigger_argument], sentence[space_id[1]:]]
        atom.append(check_variables(var_list))
        
    elif trigger_id == 3:  # premise --->  hypothesis   (if)
        if 0 in space_ids_list[:space_id[0]]:   # trigger is not in the front
            if trigger_argument in ['if', 'once', 'as long as', 'as soon as']:
                var_list = [sentence[space_id[1]:], sentence[space_id[0]:space_id[1]], sentence[0:space_id[0]]]
            else:
                var_list = [sentence[0:space_id[0]], sentence[space_id[0]:space_id[1]], sentence[space_id[1]:]]
        else:   # trigger in the front
            if trigger_argument in ['if', 'once', 'as long as', 'as soon as']:
                comma_id = None
                for i in range(len(space_ids_list)):
                    if i>=space_id[1] and space_ids_list[i]==8:   # find the closest ',' in the sentence
                        comma_id = i
                        break
                if comma_id:
                    var_list = [sentence[space_id[1]:comma_id], [trigger_argument], sentence[comma_id+1:]]
                else:
                    var_list = [[], [trigger_argument], sentence[space_id[1]:]]
            else:
                print(trigger_argument)
        atom.append(check_variables(var_list))

        
    else:
        var_list = [sentence]
        atom.append(check_variables(var_list))
    return atom