In [None]:
import torch
import os
import re
import itertools

import json
import numpy as np
import pandas as pd
import phonlp
from underthesea import dependency_parse
from typing import List, Dict, Tuple
from itertools import chain
from vncorenlp import VnCoreNLP
from tqdm import tqdm
from difflib import SequenceMatcher

In [None]:
# Length units that are split off a directly attached number ("5km" -> "5 km").
LENGTH_UNIT = ['mm', 'cm', 'dm', 'm', 'km']
# Vietnamese approximation markers ("about", "approximately", "nearly", "around");
# not referenced by any visible cell yet.
APPROXIMATION = ['khoảng', 'xấp xỉ', 'gần', 'cỡ']
# Vietnamese negation markers used by the checker's negation rule.
# Each word appears once (membership tests are unaffected by duplicates).
NEGATION_WORD = ["không", "chưa", "chẳng", "chả", "khỏi", "đâu", "chớ"]
# Model / resource locations (Kaggle working and input directories).
VNCORENLP_PATH = "/kaggle/working/vncorenlp/VnCoreNLP-1.2.jar"
PHONLP_PATH = '/kaggle/working/phonlp'
ACRONYM_PATH = '/kaggle/input/acronym-dict/acronym.json'

In [None]:
# Pipeline files, relative to the notebook's working directory.
# PhoNLP annotation of the claims (written by the claim-annotation cell).
CLAIM_FILE = "../result/private_test_claim_annotation.txt"
# Retrieval output: a dict of samples, each with 'claim' and 'evidence' entries.
RETRIEVAL_FILE = '../result/retrieval_result/private_test_retrieval_v1_top5_top_5.json'
# PhoNLP annotation of the flattened evidence sentences.
# NOTE(review): batch_inference() names its output after RETRIEVAL_FILE's
# basename and writes it into the current working directory, not ../result/ —
# confirm the file is moved here (or passed explicitly) before loading.
EVIDENCE_FILE = '../result/private_test_retrieval_v1_top5_top_5_annotation.txt'
# Final JSON: the retrieval data with a 'verdict' list added to every sample.
OUT_FILE = '../result/private_test_retrieval_v1_top5_top_5_checked.json'

In [None]:
# Shared NLP models for the batch-annotation cells: the VnCoreNLP word
# segmenter (started from the configured .jar with a 16 GB max JVM heap)
# and the PhoNLP model, moved to the GPU.
vncorenlp_model = VnCoreNLP(VNCORENLP_PATH, max_heap_size='-Xmx16g')
phonlp_model = phonlp.load(save_dir=PHONLP_PATH)
phonlp_model = phonlp_model.cuda()

In [None]:
class Annotator:
    """Word-segment text with VnCoreNLP, then annotate it with PhoNLP.

    NOTE(review): the visible notebook cells never instantiate this class
    (they use the module-level ``vncorenlp_model`` / ``phonlp_model`` and build
    ``Checker(None, None, ...)``). Constructing it starts another VnCoreNLP
    instance, and unlike the module-level model, the PhoNLP model loaded here
    is not moved to the GPU with ``.cuda()``.
    """
    def __init__(self, vncorenlp_path: str, phonlp_path: str):
        """
        Parameters
        ----------
        vncorenlp_path: str
            Path passed to ``VnCoreNLP`` (the notebook config points at a
            ``.jar``); started with ``max_heap_size='-Xmx16g'``.
        phonlp_path: str
            Directory passed to ``phonlp.load(save_dir=...)``.
        """
        self.vncorenlp_model = VnCoreNLP(vncorenlp_path, max_heap_size='-Xmx16g')
        self.phonlp_model= phonlp.load(save_dir=phonlp_path)
    
    def annotate(self, text: str):
        """
        Word-segment ``text`` with VnCoreNLP and annotate it with PhoNLP.

        The output of VnCoreNLP's ``tokenize`` is iterated as a sequence of
        token sequences. All tokens are re-joined with single spaces into one
        string, which is passed to ``phonlp_model.annotate``. PhoNLP's batched
        result is squeezed with NumPy and converted back to nested lists.

        Parameters
        ----------
        text: str
            Input text (the notebook formats text with ``Formatter`` first).

        Returns
        ----------
        annotation: List[List, List, List, List[List]]
            The lists are in the following order:
                - Segmented words
                - POS Tag (flattened one level with ``itertools.chain``)
                - Named Entity Tag
                - Dependency Tag
            NOTE(review): this layout presumes PhoNLP returned exactly one
            sentence. With several sentences or ragged nesting, ``np.squeeze``
            may not remove the batch axis, or NumPy may refuse to build the
            array. Confirm this before relying on it.
        """
        # Tokens of every segmented sentence are joined into one space-separated string.
        segmented_input = " ".join([" ".join(x) for x in self.vncorenlp_model.tokenize(text)])
        batch_annotation = self.phonlp_model.annotate(text=segmented_input)
        
        # Drop singleton (batch) dimensions, then convert back to Python lists.
        annotation = np.squeeze(batch_annotation).tolist()
        # Flatten the POS entry by one level (presumably per-sentence lists of tags — TODO confirm).
        annotation[1] = list(chain.from_iterable(annotation[1]))

        # phonlp_model.print_out(batch_annotation)

        return annotation

In [None]:
class Formatter:
    """Light, rule-based text cleanup applied before word segmentation."""

    def __init__(self, LENGTH_UNIT):
        """
        Parameters
        ----------
        LENGTH_UNIT: List[str]
            Unit suffixes that are separated from a directly attached number
            (``"5km"`` -> ``"5 km"``). When several units could match a word,
            the first one in list order wins.
        """
        self.LENGTH_UNIT = LENGTH_UNIT

    def __preprocess_text(self, text: str) -> str:
        """Replace selected punctuation and newlines with spaces and collapse whitespace.

        The characters ``' " . ? : - !`` become spaces. Newlines also become
        spaces, so words on consecutive lines are not glued together. Runs of
        whitespace collapse to one space and the ends are trimmed. Empty or
        punctuation-only input therefore yields ``""`` instead of raising
        ``IndexError``.
        """
        text = re.sub(r"['\"\.\?:\-!]", " ", text)
        text = text.replace("\n", " ")
        # split()/join also strips leading and trailing whitespace.
        return " ".join(text.split())

    def __clean(self, text: str) -> str:
        """Split a number from an attached length unit in every word ("2dm" -> "2 dm")."""
        def split_number_with_unit(word: str) -> str:
            # Only whole words of the form <digits><unit> are rewritten.
            for unit in self.LENGTH_UNIT:
                if re.search(f"^[0-9]+{unit}$", word) is not None:
                    return re.sub(f"{unit}", f" {unit}", word)
            return word

        return " ".join(split_number_with_unit(word) for word in text.split())

    def __normalize(self, text: str) -> str:
        """Placeholder normalisation step; currently returns ``text`` unchanged.

        TODO: convert number words to digits, except when they denote a year.
        """
        return text

    def format_text(self, text: str) -> str:
        """Apply preprocessing, number/unit splitting and normalisation, in that order."""
        text = self.__preprocess_text(text)
        text = self.__clean(text)
        return self.__normalize(text)

In [None]:
class Checker:
    """Rule-based comparison of a claim with one piece of evidence.

    ``check`` returns ``"equal"`` when the two texts are identical,
    ``'not_equal'`` when a refutation rule fires, and ``"unknown"`` otherwise.
    The rules are a position-wise proper-noun/number/unit/acronym mismatch, or
    a negation word present on only one side.

    An annotation is read as ``annotation[0]`` = list of words and
    ``annotation[1]`` = list of POS tags parallel to the words. Other entries
    are not used here.
    """

    # Minimum difflib similarity for the same-length rule to apply.
    SIMILARITY_THRESHOLD = 0.5

    def __init__(self, annotator, formatter, acronym_path, negation_words=None):
        """
        Parameters
        ----------
        annotator:
            Object exposing ``annotate(text)``. It is only used by ``check``
            when raw texts are passed, and may be ``None`` otherwise.
        formatter:
            Object exposing ``format_text(text)``; same usage as ``annotator``.
        acronym_path: str
            JSON file mapping a word to a collection of words (see the
            acronym rule in ``__check_same_length``).
        negation_words: iterable of str, optional
            Negation markers. ``None`` (default) uses the notebook-level
            ``NEGATION_WORD`` list.
        """
        self.formatter = formatter
        self.annotator = annotator
        with open(acronym_path) as acronym_file:
            self.acronym_dict = json.load(acronym_file)
        self.negation_words = list(NEGATION_WORD if negation_words is None else negation_words)

    def __similar(self, a, b):
        """Return the difflib similarity ratio (0.0-1.0) of two strings."""
        return SequenceMatcher(None, a, b).ratio()

    def __check_same_length(self, annotation_1, annotation_2):
        """Compare two equal-length annotations position by position.

        Returns ``'not_equal'`` if, at any position where the words differ,
        one of these holds:
          - both tags are ``'Np'`` or both are ``'M'``, and neither word
            occurs anywhere in the other sentence;
          - both tags are ``'Nu'``;
          - the pair is listed in the acronym dictionary.
        Otherwise returns ``'unknown'``.

        NOTE(review): the acronym rule *refutes* when ``word_2`` is listed
        under ``word_1``. If the JSON maps acronyms to their expansions, this
        flags equivalent wordings as different — confirm the intended semantics.
        """
        def get_different_index_from_two_list(list_1, list_2):
            if len(list_1) != len(list_2):
                raise Exception("Only accept 2 lists with the same length")
            return [x for x in range(len(list_1)) if list_1[x] != list_2[x]]

        def isArcronym(word_1, word_2):
            # Always a bool (the previous version fell through to None).
            return word_2 in self.acronym_dict.get(word_1, ())

        words_1, tags_1 = annotation_1[0], annotation_1[1]
        words_2, tags_2 = annotation_2[0], annotation_2[1]
        different_index = get_different_index_from_two_list(words_1, words_2)

        check_list = list()
        for index in different_index:
            # Neither word may appear anywhere in the other sentence (reordering is not a mismatch).
            not_shared = words_1[index] not in words_2 and words_2[index] not in words_1
            if tags_1[index] == tags_2[index] == 'Np' and not_shared:
                print("PROPER NOUND")
                check_list.append('not_equal')
            elif tags_1[index] == tags_2[index] == 'M' and not_shared:
                print("NUMBER")
                check_list.append('not_equal')
            elif tags_1[index] == tags_2[index] == 'Nu':
                print("UNIT")
                check_list.append('not_equal')
            elif isArcronym(words_1[index], words_2[index]):
                print("ACRONYM")
                check_list.append('not_equal')

        return 'not_equal' if 'not_equal' in check_list else 'unknown'

    def __check_negation(self, annotation_1, annotation_2):
        """Return ``'not_equal'`` if exactly one side has extra words, all of which are negation words.

        Otherwise returns ``'unknown'``.
        """
        def list_subtraction(list_1, list_2):
            # Multiset difference list_2 - list_1: one occurrence removed per item of list_1.
            ret_list = list_2.copy()
            for word in list_1:
                if word in ret_list:
                    ret_list.remove(word)
            return ret_list

        def check_contain_negation(list_1, list_2):
            only_in_2 = list_subtraction(list_1, list_2)
            only_in_1 = list_subtraction(list_2, list_1)

            print(only_in_2)
            print(only_in_1)

            only_in_2_is_negation = bool(only_in_2) and all(x in self.negation_words for x in only_in_2)
            only_in_1_is_negation = bool(only_in_1) and all(x in self.negation_words for x in only_in_1)

            # Exactly one side negated; negation on both sides cancels out.
            return only_in_2_is_negation ^ only_in_1_is_negation

        if check_contain_negation(annotation_1[0], annotation_2[0]):
            print("NEGATION")
            return 'not_equal'

        return 'unknown'

    def __check_interval(self):
        """Placeholder for a numeric-interval rule; always ``'unknown'`` (TODO)."""
        return 'unknown'

    def check(self, text_1=None, text_2=None, annotation_1=None, annotation_2=None):
        """Compare a claim with one piece of evidence.

        There are two calling modes:
          - Pass raw ``text_1``/``text_2``. They are formatted and annotated
            here, which requires ``formatter`` and ``annotator``.
          - Pass precomputed ``annotation_1``/``annotation_2``. The texts are
            then rebuilt by joining their word lists with spaces.

        Returns
        -------
        str
            ``"equal"``, ``'not_equal'`` or ``"unknown"``.
        """
        if annotation_1 is None and annotation_2 is None:
            text_1 = self.formatter.format_text(text_1)
            text_2 = self.formatter.format_text(text_2)
            annotation_1 = self.annotator.annotate(text_1)
            annotation_2 = self.annotator.annotate(text_2)
        else:
            text_1 = " ".join(annotation_1[0])
            text_2 = " ".join(annotation_2[0])

        # Case 1: the two sentences are exactly the same.
        if text_1 == text_2:
            print("EXACT MATCH")
            return "equal"

        words_1, words_2 = annotation_1[0], annotation_2[0]

        # Case 2: same number of words, reasonably similar text, and a
        # position-wise mismatch rule fires.
        if (
            isinstance(words_1, list) and isinstance(words_2, list)
            and len(words_1) == len(words_2)
            and self.__similar(text_1, text_2) >= self.SIMILARITY_THRESHOLD
            and self.__check_same_length(annotation_1, annotation_2) == 'not_equal'
        ):
            return 'not_equal'

        # Case 3: same first word and a negation on only one side.
        # The emptiness guard prevents IndexError on empty sentences.
        if (
            words_1 and words_2
            and words_1[0] == words_2[0]
            and self.__check_negation(annotation_1, annotation_2) == 'not_equal'
        ):
            return 'not_equal'

        if self.__check_interval() == 'not_equal':
            return 'not_equal'

        return "unknown"

In [None]:
# Text normaliser shared by the claim and evidence annotation cells.
formatter = Formatter(LENGTH_UNIT=LENGTH_UNIT)

In [None]:
# Annotate the claims: flatten, normalise, word-segment, then run PhoNLP in
# batch over a temporary one-claim-per-line file.
with open(RETRIEVAL_FILE) as retrieval_file:
    data = json.load(retrieval_file)
# One string per claim, in key order. List-valued entries are flattened and plain
# strings are kept whole (chain(*) would split a string into characters).
claim = [
    sentence
    for value in data.values()
    for sentence in (value['claim'] if isinstance(value['claim'], list) else [value['claim']])
]
# Word Segmentation
print("Word Segmentating")
claim = [formatter.format_text(x) for x in claim]
claim = [" ".join([" ".join(x) for x in vncorenlp_model.tokenize(text)]) for text in claim]
# NOTE(review): a claim that formats to "" becomes an empty line here; confirm
# PhoNLP still emits one sentence for it, or claim/annotation indices drift.
with open('temp.txt', 'w') as f:
    f.write('\n'.join(claim))
# Annotation
print("Annotating")
phonlp_model.annotate(input_file='temp.txt', output_file=CLAIM_FILE, batch_size=8)
torch.cuda.empty_cache()

In [None]:
def batch_inference(file_name, output_file=None):
    """Segment and PhoNLP-annotate every evidence sentence of a retrieval JSON.

    Uses the notebook-level ``formatter``, ``vncorenlp_model`` and
    ``phonlp_model``, and overwrites ``temp.txt`` in the working directory.

    Parameters
    ----------
    file_name: str
        Retrieval JSON whose values each have an ``'evidence'`` entry that is
        either one string or a list of strings.
    output_file: str, optional
        Annotation file for PhoNLP to write. ``None`` (default) derives it as
        ``<file basename without extension>_annotation.txt`` in the current
        working directory.

    Returns
    -------
    str
        Path of the annotation file that was written.
    """
    if output_file is None:
        # splitext keeps dotted stems intact ("a.b.json" -> "a.b"), unlike split('.')[0].
        stem = os.path.splitext(os.path.basename(file_name))[0]
        output_file = stem + "_annotation.txt"

    with open(file_name) as f:
        data = json.load(f)

    # One string per evidence sentence, in key order. List-valued entries are
    # flattened and plain strings kept whole (chain(*) would split a string into characters).
    evidence = [
        sentence
        for value in data.values()
        for sentence in (value['evidence'] if isinstance(value['evidence'], list) else [value['evidence']])
    ]
    print(len(evidence))

    # Word Segmentation
    print("Word Segmentating")
    evidence = [formatter.format_text(x) for x in evidence]
    evidence = [" ".join([" ".join(x) for x in vncorenlp_model.tokenize(text)]) for text in evidence]

    # PhoNLP reads its batch input from a file, one sentence per line.
    with open('temp.txt', 'w') as f:
        f.write('\n'.join(evidence))

    # Annotation
    print("Annotating")
    phonlp_model.annotate(input_file='temp.txt', output_file=output_file, batch_size=8)
    torch.cuda.empty_cache()
    return output_file

In [None]:
# Annotate the flattened evidence sentences of the retrieval file.
# NOTE(review): the annotation file is named after RETRIEVAL_FILE's basename
# and written to the current working directory, while the loading cell reads
# EVIDENCE_FILE from ../result/ — confirm that both refer to the same file.
batch_inference(file_name=RETRIEVAL_FILE)

In [None]:
def create_annotation_format(annotation_list):
    """Convert one sentence's token rows into the ``[words, pos, ner, dp]`` layout.

    Parameters
    ----------
    annotation_list: list of list of str
        Token rows split on tabs. Columns 1-5 are read as word, POS tag, NER
        tag, dependency head and dependency label; column 0 (token index) is
        ignored.

    Returns
    -------
    list
        ``[words, pos_tags, ner_tags, [[head, label], ...]]``, each parallel to
        the input rows. Empty input gives ``[[], [], [], []]``.
    """
    words = [row[1] for row in annotation_list]
    pos = [row[2] for row in annotation_list]
    ner = [row[3] for row in annotation_list]
    dp = [[row[4], row[5]] for row in annotation_list]
    return [words, pos, ner, dp]

In [None]:
def load_annotation_single_evidence(file_name):
    """Parse a PhoNLP annotation file into one annotation per sentence.

    Each non-blank line is a tab-separated token row. A row whose first column
    is ``'1'`` starts a new sentence. Blank or whitespace-only lines are skipped.

    Parameters
    ----------
    file_name: str
        Path to the annotation file written by ``phonlp_model.annotate``.

    Returns
    -------
    list
        One ``create_annotation_format`` result per sentence, in file order.
        An empty file yields ``[]``.
    """
    sentences = list()
    current = list()
    with open(file_name) as f:
        for line in f:
            if not line.strip():
                continue
            # rstrip also drops a stray "\r" from CRLF files.
            row = line.rstrip("\r\n").split("\t")
            if row[0] == '1' and current:
                sentences.append(current)
                current = list()
            current.append(row)
    # Flush the last sentence; skip it when nothing was read.
    if current:
        sentences.append(current)
    return [create_annotation_format(x) for x in sentences]

In [None]:
def load_annotation_multiple_evidence(file_name, original_evidence):
    """Parse a PhoNLP annotation file and regroup its sentences per sample.

    Each non-blank line is a tab-separated token row. A row whose first column
    is ``'1'`` starts a new sentence. Blank or whitespace-only lines are
    skipped. The sentences are assumed to appear in the same order as the
    flattened evidence.

    Parameters
    ----------
    file_name: str
        Path to the annotation file.
    original_evidence: list
        Per-sample ``'evidence'`` values from the retrieval JSON. Each value is
        either a list of sentences or a single sentence string.

    Returns
    -------
    list
        One list per sample, holding a ``create_annotation_format`` result per
        evidence sentence. The last sample receives all remaining sentences.
    """
    sentences = list()
    current = list()
    with open(file_name) as f:
        for line in f:
            if not line.strip():
                continue
            # rstrip also drops a stray "\r" from CRLF files.
            row = line.rstrip("\r\n").split("\t")
            if row[0] == '1' and current:
                sentences.append(current)
                current = list()
            current.append(row)
    if current:
        sentences.append(current)

    new_annotation = list()
    start = 0
    last_index = len(original_evidence) - 1
    for i, evidence_item in enumerate(original_evidence):
        # A bare string is one sentence; len() on it would count characters.
        length = len(evidence_item) if isinstance(evidence_item, list) else 1
        end = start + length
        group = sentences[start:] if i == last_index else sentences[start:end]
        new_annotation.append([create_annotation_format(x) for x in group])
        start = end

    return new_annotation

In [None]:
# Reload the retrieval JSON and both annotation files; evidence annotations
# are grouped per sample to mirror data[key]['evidence'].
with open(RETRIEVAL_FILE) as retrieval_file:
    data = json.load(retrieval_file)
evidence = [data[key]['evidence'] for key in data.keys()]
claim_annotation = load_annotation_single_evidence(CLAIM_FILE)
evidence_annotation = load_annotation_multiple_evidence(EVIDENCE_FILE, evidence)
# Fail early with a clear message instead of an IndexError in later cells.
assert len(claim_annotation) == len(evidence_annotation) == len(data), (
    f"claims={len(claim_annotation)}, evidence groups={len(evidence_annotation)}, samples={len(data)}"
)

In [None]:
# Annotations come precomputed from files, so no annotator/formatter is needed;
# check() must therefore always be called with annotation_1/annotation_2.
checker = Checker(annotator=None, formatter=None, acronym_path=ACRONYM_PATH)

In [None]:
# One checker result per (claim, evidence) pair, grouped per sample.
check_result = [
    [
        checker.check(annotation_1=claim_ann, annotation_2=evidence_ann)
        for evidence_ann in evidence_annotation[sample_index]
    ]
    for sample_index, claim_ann in enumerate(claim_annotation)
]

In [None]:
def return_verdict(text):
    """Map a ``Checker.check`` result to a verdict label.

    ``'not_equal'`` -> ``"REFUTED"``, ``'equal'`` -> ``"SUPPORTED"``; anything
    else -> ``''`` (no verdict).
    """
    if text == 'not_equal':
        return "REFUTED"
    if text == 'equal':
        return "SUPPORTED"
    return ''

In [None]:
# Attach one verdict label per evidence to every sample, in the JSON's key order.
for idx, key in enumerate(data):
    data[key]['verdict'] = list(map(return_verdict, check_result[idx]))

In [None]:
# Persist the retrieval results, now carrying verdicts, as JSON.
with open(OUT_FILE, mode="w") as outfile:
    json.dump(obj=data, fp=outfile)