In [1]:
########################################
# Baran: The Error Correction System
# Mohammad Mahdavi
# moh.mahdavi.l@gmail.com
# April 2019
# Big Data Management Group
# TU Berlin
# All Rights Reserved
########################################


########################################
import os
import re
import io
import sys
import math
import argparse
import pandas as pd
import json
import time
import shutil
import pickle
import difflib
import unicodedata
import multiprocessing

import bs4
import bz2
import numpy
import py7zr
import mwparserfromhell
import sklearn.svm
import sklearn.ensemble
import sklearn.naive_bayes
import sklearn.linear_model

import raha
import warnings
warnings.filterwarnings("ignore")

from detection import Detection
import signal
from datetime import datetime

def check_string(string: str):
    if re.search(r"-inner_error-", string):
        return "-inner_error-" + string[-6:-4]
    elif re.search(r"-outer_error-", string):
        return "-outer_error-" + string[-6:-4]
    elif re.search(r"-inner_outer_error-", string):
        return "-inner_outer_error-" + string[-6:-4]
    elif re.search(r"-dirty-original_error-", string):
        return "-original_error-" + string[-9:-4]

def handler(signum, frame):
    raise TimeoutError("Time exceeded")
########################################


########################################
class Correction:
    """
    The main class.
    """

    def __init__(self):
        """
        The constructor.
        """
        self.PRETRAINED_VALUE_BASED_MODELS_PATH = ""
        self.VALUE_ENCODINGS = ["identity", "unicode"]
        self.CLASSIFICATION_MODEL = "ABC"   # ["ABC", "DTC", "GBC", "GNB", "KNC" ,"SGDC", "SVC"]
        self.IGNORE_SIGN = "<<<IGNORE_THIS_VALUE>>>"
        self.VERBOSE = False
        self.SAVE_RESULTS = False
        self.ONLINE_PHASE = False
        self.LABELING_BUDGET = 20
        self.MIN_CORRECTION_CANDIDATE_PROBABILITY = 0.0
        self.MIN_CORRECTION_OCCURRENCE = 2
        self.MAX_VALUE_LENGTH = 50
        self.REVISION_WINDOW_SIZE = 5

    @staticmethod
    def _wikitext_segmenter(wikitext):
        """
        This method takes a Wikipedia page revision text in wikitext and segments it recursively.
        """
        def recursive_segmenter(node):
            if isinstance(node, str):
                segments_list.append(node)
            elif isinstance(node, mwparserfromhell.nodes.text.Text):
                segments_list.append(node.value)
            elif not node:
                pass
            elif isinstance(node, mwparserfromhell.wikicode.Wikicode):
                for n in node.nodes:
                    if isinstance(n, str):
                        recursive_segmenter(n)
                    elif isinstance(n, mwparserfromhell.nodes.text.Text):
                        recursive_segmenter(n.value)
                    elif isinstance(n, mwparserfromhell.nodes.heading.Heading):
                        recursive_segmenter(n.title)
                    elif isinstance(n, mwparserfromhell.nodes.tag.Tag):
                        recursive_segmenter(n.contents)
                    elif isinstance(n, mwparserfromhell.nodes.wikilink.Wikilink):
                        if n.text:
                            recursive_segmenter(n.text)
                        else:
                            recursive_segmenter(n.title)
                    elif isinstance(n, mwparserfromhell.nodes.external_link.ExternalLink):
                        # recursive_parser(n.url)
                        recursive_segmenter(n.title)
                    elif isinstance(n, mwparserfromhell.nodes.template.Template):
                        recursive_segmenter(n.name)
                        for p in n.params:
                            # recursive_parser(p.name)
                            recursive_segmenter(p.value)
                    elif isinstance(n, mwparserfromhell.nodes.html_entity.HTMLEntity):
                        segments_list.append(n.normalize())
                    elif not n or isinstance(n, mwparserfromhell.nodes.comment.Comment) or \
                            isinstance(n, mwparserfromhell.nodes.argument.Argument):
                        pass
                    else:
                        sys.stderr.write("Inner layer unknown node found: {}, {}\n".format(type(n), n))
            else:
                sys.stderr.write("Outer layer unknown node found: {}, {}\n".format(type(node), node))

        try:
            parsed_wikitext = mwparserfromhell.parse(wikitext)
        except:
            parsed_wikitext = ""
        segments_list = []
        recursive_segmenter(parsed_wikitext)
        return segments_list

    def extract_revisions(self, wikipedia_dumps_folder):
        """
        This method takes the folder path of Wikipedia page revision history dumps and extracts the value-based corrections.
        """
        rd_folder_path = os.path.join(wikipedia_dumps_folder, "revision-data")
        if not os.path.exists(rd_folder_path):
            os.mkdir(rd_folder_path)
        compressed_dumps_list = [df for df in os.listdir(wikipedia_dumps_folder) if df.endswith(".7z")]
        page_counter = 0
        for file_name in compressed_dumps_list:
            compressed_dump_file_path = os.path.join(wikipedia_dumps_folder, file_name)
            dump_file_name, _ = os.path.splitext(os.path.basename(compressed_dump_file_path))
            rdd_folder_path = os.path.join(rd_folder_path, dump_file_name)
            if not os.path.exists(rdd_folder_path):
                os.mkdir(rdd_folder_path)
            else:
                continue
            archive = py7zr.SevenZipFile(compressed_dump_file_path, mode="r")
            archive.extractall(path=wikipedia_dumps_folder)
            archive.close()
            decompressed_dump_file_path = os.path.join(wikipedia_dumps_folder, dump_file_name)
            decompressed_dump_file = io.open(decompressed_dump_file_path, "r", encoding="utf-8")
            page_text = ""
            for i, line in enumerate(decompressed_dump_file):
                line = line.strip()
                if line == "<page>":
                    page_text = ""
                page_text += "\n" + line
                if line == "</page>":
                    revisions_list = []
                    page_tree = bs4.BeautifulSoup(page_text, "html.parser")
                    previous_text = ""
                    for revision_tag in page_tree.find_all("revision"):
                        revision_text = revision_tag.find("text").text
                        if previous_text:
                            a = [t for t in self._wikitext_segmenter(previous_text) if t]
                            b = [t for t in self._wikitext_segmenter(revision_text) if t]
                            s = difflib.SequenceMatcher(None, a, b)
                            for tag, i1, i2, j1, j2 in s.get_opcodes():
                                if tag == "equal":
                                    continue
                                revisions_list.append({
                                    "old_value": a[i1:i2],
                                    "new_value": b[j1:j2],
                                    "left_context": a[i1 - self.REVISION_WINDOW_SIZE:i1],
                                    "right_context": a[i2:i2 + self.REVISION_WINDOW_SIZE]
                                })
                        previous_text = revision_text
                    if revisions_list:
                        page_counter += 1
                        if self.VERBOSE and page_counter % 100 == 0:
                            for entry in revisions_list:
                                print("----------Page Counter:---------\n", page_counter,
                                      "\n----------Old Value:---------\n", entry["old_value"],
                                      "\n----------New Value:---------\n", entry["new_value"],
                                      "\n----------Left Context:---------\n", entry["left_context"],
                                      "\n----------Right Context:---------\n", entry["right_context"],
                                      "\n==============================")
                        json.dump(revisions_list, open(os.path.join(rdd_folder_path, page_tree.id.text + ".json"), "w"))
            decompressed_dump_file.close()
            os.remove(decompressed_dump_file_path)
            if self.VERBOSE:
                print("{} ({} / {}) is processed.".format(file_name, len(os.listdir(rd_folder_path)), len(compressed_dumps_list)))

    @staticmethod
    def _value_encoder(value, encoding):
        """
        This method represents a value with a specified value abstraction encoding method.
        """
        if encoding == "identity":
            return json.dumps(list(value))
        if encoding == "unicode":
            return json.dumps([unicodedata.category(c) for c in value])

    @staticmethod
    def _to_model_adder(model, key, value):
        """
        This methods incrementally adds a key-value into a dictionary-implemented model.
        """
        if key not in model:
            model[key] = {}
        if value not in model[key]:
            model[key][value] = 0.0
        model[key][value] += 1.0

    def _value_based_models_updater(self, models, ud):
        """
        This method updates the value-based error corrector models with a given update dictionary.
        """
        # TODO: adding jabeja konannde bakhshahye substring
        if self.ONLINE_PHASE or (ud["new_value"] and len(ud["new_value"]) <= self.MAX_VALUE_LENGTH and
                                 ud["old_value"] and len(ud["old_value"]) <= self.MAX_VALUE_LENGTH and
                                 ud["old_value"] != ud["new_value"] and ud["old_value"].lower() != "n/a" and
                                 not ud["old_value"][0].isdigit()):
            remover_transformation = {}
            adder_transformation = {}
            replacer_transformation = {}
            s = difflib.SequenceMatcher(None, ud["old_value"], ud["new_value"])
            for tag, i1, i2, j1, j2 in s.get_opcodes():
                index_range = json.dumps([i1, i2])
                if tag == "delete":
                    remover_transformation[index_range] = ""
                if tag == "insert":
                    adder_transformation[index_range] = ud["new_value"][j1:j2]
                if tag == "replace":
                    replacer_transformation[index_range] = ud["new_value"][j1:j2]
            for encoding in self.VALUE_ENCODINGS:
                encoded_old_value = self._value_encoder(ud["old_value"], encoding)
                if remover_transformation:
                    self._to_model_adder(models[0], encoded_old_value, json.dumps(remover_transformation))
                if adder_transformation:
                    self._to_model_adder(models[1], encoded_old_value, json.dumps(adder_transformation))
                if replacer_transformation:
                    self._to_model_adder(models[2], encoded_old_value, json.dumps(replacer_transformation))
                self._to_model_adder(models[3], encoded_old_value, ud["new_value"])

    def pretrain_value_based_models(self, revision_data_folder):
        """
        This method pretrains value-based error corrector models.
        """
        def _models_pruner():
            for mi, model in enumerate(models):
                for k in list(model.keys()):
                    for v in list(model[k].keys()):
                        if model[k][v] < self.MIN_CORRECTION_OCCURRENCE:
                            models[mi][k].pop(v)
                    if not models[mi][k]:
                        models[mi].pop(k)

        models = [{}, {}, {}, {}]
        rd_folder_path = revision_data_folder
        page_counter = 0
        for folder in os.listdir(rd_folder_path):
            if os.path.isdir(os.path.join(rd_folder_path, folder)):
                for rf in os.listdir(os.path.join(rd_folder_path, folder)):
                    if rf.endswith(".json"):
                        page_counter += 1
                        if page_counter % 100000 == 0:
                            _models_pruner()
                            if self.VERBOSE:
                                print(page_counter, "pages are processed.")
                        try:
                            revision_list = json.load(io.open(os.path.join(rd_folder_path, folder, rf), encoding="utf-8"))
                        except:
                            continue
                        for rd in revision_list:
                            update_dictionary = {
                                "old_value": raha.dataset.Dataset.value_normalizer("".join(rd["old_value"])),
                                "new_value": raha.dataset.Dataset.value_normalizer("".join(rd["new_value"]))
                            }
                            self._value_based_models_updater(models, update_dictionary)
        _models_pruner()
        pretrained_models_path = os.path.join(revision_data_folder, "pretrained_value_based_models.dictionary")
        if self.PRETRAINED_VALUE_BASED_MODELS_PATH:
            pretrained_models_path = self.PRETRAINED_VALUE_BASED_MODELS_PATH
        pickle.dump(models, bz2.BZ2File(pretrained_models_path, "wb"))
        if self.VERBOSE:
            print("The pretrained value-based models are stored in {}.".format(pretrained_models_path))

    def _vicinity_based_models_updater(self, models, ud):
        """
        This method updates the vicinity-based error corrector models with a given update dictionary.
        """
        for j, cv in enumerate(ud["vicinity"]):
            if cv != self.IGNORE_SIGN:
                self._to_model_adder(models[j][ud["column"]], cv, ud["new_value"])

    def _domain_based_model_updater(self, model, ud):
        """
        This method updates the domain-based error corrector model with a given update dictionary.
        """
        self._to_model_adder(model, ud["column"], ud["new_value"])

    def _value_based_corrector(self, models, ed):
        """
        This method takes the value-based models and an error dictionary to generate potential value-based corrections.
        """
        results_list = []
        for m, model_name in enumerate(["remover", "adder", "replacer", "swapper"]):
            model = models[m]
            for encoding in self.VALUE_ENCODINGS:
                results_dictionary = {}
                encoded_value_string = self._value_encoder(ed["old_value"], encoding)
                if encoded_value_string in model:
                    sum_scores = sum(model[encoded_value_string].values())
                    if model_name in ["remover", "adder", "replacer"]:
                        for transformation_string in model[encoded_value_string]:
                            index_character_dictionary = {i: c for i, c in enumerate(ed["old_value"])}
                            transformation = json.loads(transformation_string)
                            for change_range_string in transformation:
                                change_range = json.loads(change_range_string)
                                if model_name in ["remover", "replacer"]:
                                    for i in range(change_range[0], change_range[1]):
                                        index_character_dictionary[i] = ""
                                if model_name in ["adder", "replacer"]:
                                    ov = "" if change_range[0] not in index_character_dictionary else \
                                        index_character_dictionary[change_range[0]]
                                    index_character_dictionary[change_range[0]] = transformation[change_range_string] + ov
                            new_value = ""
                            for i in range(len(index_character_dictionary)):
                                new_value += index_character_dictionary[i]
                            pr = model[encoded_value_string][transformation_string] / sum_scores
                            if pr >= self.MIN_CORRECTION_CANDIDATE_PROBABILITY:
                                results_dictionary[new_value] = pr
                    if model_name == "swapper":
                        for new_value in model[encoded_value_string]:
                            pr = model[encoded_value_string][new_value] / sum_scores
                            if pr >= self.MIN_CORRECTION_CANDIDATE_PROBABILITY:
                                results_dictionary[new_value] = pr
                results_list.append(results_dictionary)
        return results_list

    def _vicinity_based_corrector(self, models, ed):
        """
        This method takes the vicinity-based models and an error dictionary to generate potential vicinity-based corrections.
        """
        results_list = []
        for j, cv in enumerate(ed["vicinity"]):
            results_dictionary = {}
            if j != ed["column"] and cv in models[j][ed["column"]]:
                sum_scores = sum(models[j][ed["column"]][cv].values())
                for new_value in models[j][ed["column"]][cv]:
                    pr = models[j][ed["column"]][cv][new_value] / sum_scores
                    if pr >= self.MIN_CORRECTION_CANDIDATE_PROBABILITY:
                        results_dictionary[new_value] = pr
            results_list.append(results_dictionary)
        return results_list

    def _domain_based_corrector(self, model, ed):
        """
        This method takes a domain-based model and an error dictionary to generate potential domain-based corrections.
        """
        results_dictionary = {}
        sum_scores = sum(model[ed["column"]].values())
        for new_value in model[ed["column"]]:
            pr = model[ed["column"]][new_value] / sum_scores
            if pr >= self.MIN_CORRECTION_CANDIDATE_PROBABILITY:
                results_dictionary[new_value] = pr
        return [results_dictionary]

    def initialize_dataset(self, d):
        """
        This method initializes the dataset.
        """
        self.ONLINE_PHASE = True
        d.results_folder = os.path.join(os.path.dirname(d.path), "raha-baran-results-" + d.name)
        if self.SAVE_RESULTS and not os.path.exists(d.results_folder):
            os.mkdir(d.results_folder)
        d.column_errors = {}
        for cell in d.detected_cells:
            self._to_model_adder(d.column_errors, cell[1], cell)
        d.labeled_tuples = {} if not hasattr(d, "labeled_tuples") else d.labeled_tuples
        d.labeled_cells = {} if not hasattr(d, "labeled_cells") else d.labeled_cells
        d.corrected_cells = {} if not hasattr(d, "corrected_cells") else d.corrected_cells
        return d

    def initialize_models(self, d):
        """
        This method initializes the error corrector models.
        """
        d.value_models = [{}, {}, {}, {}]
        if os.path.exists(self.PRETRAINED_VALUE_BASED_MODELS_PATH):
            d.value_models = pickle.load(bz2.BZ2File(self.PRETRAINED_VALUE_BASED_MODELS_PATH, "rb"))
            if self.VERBOSE:
                print("The pretrained value-based models are loaded.")
        d.vicinity_models = {j: {jj: {} for jj in range(d.dataframe.shape[1])} for j in range(d.dataframe.shape[1])}
        d.domain_models = {}
        for row in d.dataframe.itertuples():
            i, row = row[0], row[1:]
            vicinity_list = [cv if (i, cj) not in d.detected_cells else self.IGNORE_SIGN for cj, cv in enumerate(row)]
            for j, value in enumerate(row):
                if (i, j) not in d.detected_cells:
                    temp_vicinity_list = list(vicinity_list)
                    temp_vicinity_list[j] = self.IGNORE_SIGN
                    update_dictionary = {
                        "column": j,
                        "new_value": value,
                        "vicinity": temp_vicinity_list
                    }
                    self._vicinity_based_models_updater(d.vicinity_models, update_dictionary)
                    self._domain_based_model_updater(d.domain_models, update_dictionary)
        if self.VERBOSE:
            print("The error corrector models are initialized.")

    def sample_tuple(self, d):
        """
        This method samples a tuple.
        """
        remaining_column_erroneous_cells = {}
        remaining_column_erroneous_values = {}
        for j in d.column_errors:
            for cell in d.column_errors[j]:
                if cell not in d.corrected_cells:
                    self._to_model_adder(remaining_column_erroneous_cells, j, cell)
                    self._to_model_adder(remaining_column_erroneous_values, j, d.dataframe.iloc[cell])
        tuple_score = numpy.ones(d.dataframe.shape[0])
        tuple_score[list(d.labeled_tuples.keys())] = 0.0
        for j in remaining_column_erroneous_cells:
            for cell in remaining_column_erroneous_cells[j]:
                value = d.dataframe.iloc[cell]
                column_score = math.exp(len(remaining_column_erroneous_cells[j]) / len(d.column_errors[j]))
                cell_score = math.exp(remaining_column_erroneous_values[j][value] / len(remaining_column_erroneous_cells[j]))
                tuple_score[cell[0]] *= column_score * cell_score
        d.sampled_tuple = numpy.random.choice(numpy.argwhere(tuple_score == numpy.amax(tuple_score)).flatten())
        if self.VERBOSE:
            print("Tuple {} is sampled.".format(d.sampled_tuple))

    def label_with_ground_truth(self, d):
        """
        This method labels a tuple with ground truth.
        """
        d.labeled_tuples[d.sampled_tuple] = 1
        for j in range(d.dataframe.shape[1]):
            cell = (d.sampled_tuple, j)
            error_label = 0
            if d.dataframe.iloc[cell] != d.clean_dataframe.iloc[cell]:
                error_label = 1
            d.labeled_cells[cell] = [error_label, d.clean_dataframe.iloc[cell]]
        if self.VERBOSE:
            print("Tuple {} is labeled.".format(d.sampled_tuple))

    def update_models(self, d):
        """
        This method updates the error corrector models with a new labeled tuple.
        """
        cleaned_sampled_tuple = [d.labeled_cells[(d.sampled_tuple, j)][1] for j in range(d.dataframe.shape[1])]
        for j in range(d.dataframe.shape[1]):
            cell = (d.sampled_tuple, j)
            update_dictionary = {
                "column": cell[1],
                "old_value": d.dataframe.iloc[cell],
                "new_value": cleaned_sampled_tuple[j],
            }
            if d.labeled_cells[cell][0] == 1:
                if cell not in d.detected_cells:
                    d.detected_cells[cell] = self.IGNORE_SIGN
                    self._to_model_adder(d.column_errors, cell[1], cell)
                self._value_based_models_updater(d.value_models, update_dictionary)
                self._domain_based_model_updater(d.domain_models, update_dictionary)
                update_dictionary["vicinity"] = [cv if j != cj else self.IGNORE_SIGN
                                                 for cj, cv in enumerate(cleaned_sampled_tuple)]
            else:
                update_dictionary["vicinity"] = [cv if j != cj and d.labeled_cells[(d.sampled_tuple, cj)][0] == 1
                                                 else self.IGNORE_SIGN for cj, cv in enumerate(cleaned_sampled_tuple)]
            self._vicinity_based_models_updater(d.vicinity_models, update_dictionary)
        if self.VERBOSE:
            print("The error corrector models are updated with new labeled tuple {}.".format(d.sampled_tuple))

    def _feature_generator_process(self, args):
        """
        This method generates features for each data column in a parallel process.
        """
        d, cell = args
        error_dictionary = {"column": cell[1], "old_value": d.dataframe.iloc[cell], "vicinity": list(d.dataframe.iloc[cell[0], :])}
        value_corrections = self._value_based_corrector(d.value_models, error_dictionary)
        vicinity_corrections = self._vicinity_based_corrector(d.vicinity_models, error_dictionary)
        domain_corrections = self._domain_based_corrector(d.domain_models, error_dictionary)
        models_corrections = value_corrections + vicinity_corrections + domain_corrections
        corrections_features = {}
        for mi, model in enumerate(models_corrections):
            for correction in model:
                if correction not in corrections_features:
                    corrections_features[correction] = numpy.zeros(len(models_corrections))
                corrections_features[correction][mi] = model[correction]
        return corrections_features

    def generate_features(self, d):
        """
        This method generates a feature vector for each pair of a data error and a potential correction.
        """
        d.pair_features = {}
        pairs_counter = 0
        process_args_list = [[d, cell] for cell in d.detected_cells]
        pool = multiprocessing.Pool()
        feature_generation_results = pool.map(self._feature_generator_process, process_args_list)
        pool.close()
        for ci, corrections_features in enumerate(feature_generation_results):
            cell = process_args_list[ci][1]
            d.pair_features[cell] = {}
            for correction in corrections_features:
                d.pair_features[cell][correction] = corrections_features[correction]
                pairs_counter += 1
        if self.VERBOSE:
            print("{} pairs of (a data error, a potential correction) are featurized.".format(pairs_counter))

    def predict_corrections(self, d):
        """
        This method predicts
        """
        
        for j in d.column_errors:
            x_train = []
            y_train = []
            x_test = []
            test_cell_correction_list = []
            for k, cell in enumerate(d.column_errors[j]):
                if cell in d.pair_features:
                    candidates_set = []
                    actual_errors = d.clean_dataframe
                    for correction in d.pair_features[cell]:
                        candidates_set.append(correction)
                        if cell in d.labeled_cells and d.labeled_cells[cell][0] == 1:
                            x_train.append(d.pair_features[cell][correction])
                            y_train.append(int(correction == d.labeled_cells[cell][1]))
                            d.corrected_cells[cell] = d.labeled_cells[cell][1]
                        else:
                            x_test.append(d.pair_features[cell][correction])
                            test_cell_correction_list.append([cell, correction])
                    if d.clean_dataframe.iloc[cell[0], cell[1]] in candidates_set:
                        clean_in_cands.append(cell)
            if self.CLASSIFICATION_MODEL == "ABC":
                classification_model = sklearn.ensemble.AdaBoostClassifier(n_estimators=100)
            if self.CLASSIFICATION_MODEL == "DTC":
                classification_model = sklearn.tree.DecisionTreeClassifier(criterion="gini")
            if self.CLASSIFICATION_MODEL == "GBC":
                classification_model = sklearn.ensemble.GradientBoostingClassifier(n_estimators=100)
            if self.CLASSIFICATION_MODEL == "GNB":
                classification_model = sklearn.naive_bayes.GaussianNB()
            if self.CLASSIFICATION_MODEL == "KNC":
                classification_model = sklearn.neighbors.KNeighborsClassifier(n_neighbors=1)
            if self.CLASSIFICATION_MODEL == "SGDC":
                classification_model = sklearn.linear_model.SGDClassifier(loss="hinge", penalty="l2")
            if self.CLASSIFICATION_MODEL == "SVC":
                classification_model = sklearn.svm.SVC(kernel="sigmoid")
            if x_train and x_test:
                if sum(y_train) == 0:
                    predicted_labels = numpy.zeros(len(x_test))
                elif sum(y_train) == len(y_train):
                    predicted_labels = numpy.ones(len(x_test))
                else:
                    classification_model.fit(x_train, y_train)
                    predicted_labels = classification_model.predict(x_test)
                # predicted_probabilities = classification_model.predict_proba(x_test)
                # correction_confidence = {}
                for index, predicted_label in enumerate(predicted_labels):
                    cell, predicted_correction = test_cell_correction_list[index]
                    # confidence = predicted_probabilities[index][1]
                    if predicted_label:
                        # if cell not in correction_confidence or confidence > correction_confidence[cell]:
                        #     correction_confidence[cell] = confidence
                        d.corrected_cells[cell] = predicted_correction
        if self.VERBOSE:
            print("{:.0f}% ({} / {}) of data errors are corrected.".format(100 * len(d.corrected_cells) / len(d.detected_cells),
                                                                           len(d.corrected_cells), len(d.detected_cells)))

    def store_results(self, d):
        """
        This method stores the results.
        """
        ec_folder_path = os.path.join(d.results_folder, "error-correction")
        if not os.path.exists(ec_folder_path):
            os.mkdir(ec_folder_path)
        pickle.dump(d, open(os.path.join(ec_folder_path, "correction.dataset"), "wb"))
        if self.VERBOSE:
            print("The results are stored in {}.".format(os.path.join(ec_folder_path, "correction.dataset")))

    def run(self, d):
        """
        This method runs Baran on an input dataset to correct data errors.
        """
        if self.VERBOSE:
            print("------------------------------------------------------------------------\n"
                  "---------------------Initialize the Dataset Object----------------------\n"
                  "------------------------------------------------------------------------")
        d = self.initialize_dataset(d)
        if self.VERBOSE:
            print("------------------------------------------------------------------------\n"
                  "--------------------Initialize Error Corrector Models-------------------\n"
                  "------------------------------------------------------------------------")
        self.initialize_models(d)
        if self.VERBOSE:
            print("------------------------------------------------------------------------\n"
                  "--------------Iterative Tuple Sampling, Labeling, and Learning----------\n"
                  "------------------------------------------------------------------------")
        # while len(d.labeled_tuples) < self.LABELING_BUDGET:
        #     self.sample_tuple(d)
        #     if d.has_ground_truth:
        #         self.label_with_ground_truth(d)
            # else:
            #   In this case, user should label the tuple interactively as shown in the Jupyter notebook.
        for si in d.labeled_tuples:
            d.sampled_tuple = si
            self.update_models(d)
            self.generate_features(d)
            self.predict_corrections(d)
            if self.VERBOSE:
                print("------------------------------------------------------------------------")
        if self.SAVE_RESULTS:
            if self.VERBOSE:
                print("------------------------------------------------------------------------\n"
                      "---------------------------Storing the Results--------------------------\n"
                      "------------------------------------------------------------------------")
            # self.store_results(d)
        return d.corrected_cells
########################################

In [7]:
########################################
# if __name__ == "__main__":
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--clean_path', type=str, default=None)
#     parser.add_argument('--dirty_path', type=str, default=None)
#     parser.add_argument('--task_name', type=str, default=None)
#     args = parser.parse_args()
#     dirty_path = args.dirty_path
#     clean_path = args.clean_path
#     task_name = args.task_name

# task_name = "tax1"
# clean_path = "../data_with_rules/tax/split_data/tax-dirty-original_error-1.csv"
# dirty_path = "../data_with_rules/tax/split_data/tax-dirty-original_error-1.csv"

# clean_path = "../data_with_rules/hospital/hospital_clean_bclean.csv"
# dirty_path = "../data_with_rules/hospital/hospital-inner_outer_error-02.csv"
# rule_path = "../data_with_rules/hospital/dc_rules-validate-fd-horizon.txt"
# task_name = "hospital_bclean2"

dirty_path = "../../BClean/dataset/flights/flights-inner_outer_error-02.csv"
clean_path = "../../BClean/dataset/flights/flights_clean.csv"
# rule_path = "./data_with_rules/flights/dc_rules-validate-fd-horizon.txt"
task_name = "flights_bclean2"

# clean_path = 'D:/WorkSpace/Code/2024/Automatic-Data-Repair-main/data_with_rules/tax/split_data/tax-clean-clean_data_ori-0010k.csv'
# dirty_path = "./data_with_rules/tax/split_data/tax-dirty-original_error-0010k.csv"

stra_path = "../data_with_rules/" + task_name[:-1] + "/noise/raha-baran-results-" + task_name + check_string(dirty_path)

if os.path.exists(stra_path):
    shutil.rmtree(stra_path)
stra_path = "../DATASET/data_with_rules/" + task_name[:-1] + "/noise/raha-baran-results-" + task_name + check_string(dirty_path)
if os.path.exists(stra_path):
    shutil.rmtree(stra_path)
stra_path = "../data_with_rules/tax/split_data/raha-baran-results-" + task_name + check_string(dirty_path)
if os.path.exists(stra_path):
    shutil.rmtree(stra_path)
stra_path = "../data_with_rules/tax/split_data/raha-baran-results-" + task_name + check_string(dirty_path)
if os.path.exists(stra_path):
    shutil.rmtree(stra_path)
dataset_name = task_name

dataset_dictionary = {
    "name": task_name + check_string(dirty_path),
    "path": dirty_path,
    "clean_path": clean_path
}
time_limit = 24*3600
# signal.signal(signal.SIGALRM, handler)
# signal.alarm(time_limit)
clean_in_cands = []
try:
    # get Raha result
    time_start = time.time()
    app1 = Detection()
    detected_cells = app1.run(dataset_dictionary)
    p, r, f = app1.d.get_data_cleaning_evaluation(detected_cells)[:3]
    time_end = time.time()

    out_path = "../Exp_result/raha_baran/" + task_name[:-1] + "/onlyED_" + task_name + check_string(dirty_path) + ".txt"
    if not os.path.exists(os.path.dirname(out_path)):
        os.makedirs(os.path.dirname(out_path))
    
    f = open(out_path, 'w')
    # sys.stdout = f
    print("{}\n{}\n{}".format(p, r, f))
    print(time_end-time_start)
    f.close()

    time_start = time.time()
    data = raha.dataset.Dataset(dataset_dictionary)
    app = Correction()
    correction_dictionary = app.run(app1.d)
    p, r, f = data.get_data_cleaning_evaluation(correction_dictionary)[-3:]

    out_path = "../Exp_result/raha_baran/" + task_name[:-1] + "/oriED+EC_" + task_name + check_string(dirty_path) + ".txt"
    res_path = "../Repaired_res/raha_baran/" + task_name[:-1] + "/repaired_" + task_name + check_string(dirty_path) + ".csv"
    if not os.path.exists(os.path.dirname(res_path)):
        os.makedirs(os.path.dirname(res_path))
    repaired_df = pd.read_csv(dirty_path)
    for cell, value in correction_dictionary.items():
        repaired_df.iloc[cell[0], cell[1]] = value
    repaired_df.to_csv(res_path, index=False, columns=list(repaired_df.columns))

    f = open(out_path, 'w')
    # sys.stdout = f
    print("{}\n{}\n{}".format(p, r, str(f)))
    time_end = time.time()
    print(time_end-time_start)
    f.close()
    # --------------------
    # app.extract_revisions(wikipedia_dumps_folder="../wikipedia-data")
    # app.pretrain_value_based_models(revision_data_folder="../wikipedia-data/revision-data")

    # sys.stdout = sys.__stdout__ 
    out_path = "../Exp_result/raha_baran/" + task_name[:-1] + "/all_compute_" + task_name + check_string(dirty_path) + ".txt"
    f = open(out_path, 'w')
    # sys.stdout = f
    actual_errors = data.get_actual_errors_dictionary()
    actual_errors_list = list(actual_errors.keys())
    repaired_cells = list(correction_dictionary.keys())
    right2wrong = 0
    right2right = 0
    wrong2right = 0
    wrong2wrong = 0
    rep_right = 0
    rec_right = 0

    rep_total = len(repaired_cells)
    wrong_cells = len(actual_errors_list)
    repair_right_cells = []
    for cell in repaired_cells:
        try:
            if correction_dictionary[cell] == actual_errors[cell]:
                repair_right_cells.append(cell)
                rep_right += 1
        except:
            continue

    for cell in actual_errors_list:
        try:
            if cell in repaired_cells:
                if correction_dictionary[cell] == actual_errors[cell]:
                    rec_right += 1
        except:
            continue

    for cell in repair_right_cells:
        if cell in actual_errors_list:
            wrong2right += 1
        else:
            right2right += 1

    print("rep_right:"+str(rep_right))
    print("rec_right:"+str(rec_right))
    print("wrong_cells:"+str(wrong_cells))
    print("prec:"+str(p))
    print("rec:"+str(r))
    print("wrong2right:"+str(wrong2right))
    print("right2right:"+str(right2right))
    repair_wrong_cells = [i for i in repaired_cells if i not in repair_right_cells]
    for cell in repair_wrong_cells:
        if cell in actual_errors_list:
            wrong2wrong += 1
        else:
            right2wrong += 1
    print("wrong2wrong:"+str(wrong2wrong))
    print("right2wrong:"+str(right2wrong))
    print("clean_in_cands BEFORE filter:"+str(len(clean_in_cands)))
    clean_in_cands = list(set(clean_in_cands))
    for cell in clean_in_cands:
        if cell not in repaired_cells:
            clean_in_cands.remove(cell)
    print("clean_in_cands AFTER filter:"+str(len(clean_in_cands)))
    print("proportion of clean value in candidates:"+str(len(clean_in_cands)/(rep_total+1e-8)))
    clean_in_cands_repair_right = []
    for cell in clean_in_cands:
        if cell in repair_right_cells:
            clean_in_cands_repair_right.append(cell)
    print("proportion of clean value in candidates and selected correctly:"+str(len(clean_in_cands_repair_right)/(len(clean_in_cands)+1e-8)))
    f.close()
########################################
except TimeoutError as e: 
    print("Time exceeded:", e, task_name, dirty_path)
    out_file = open("../aggre_results/timeout_log.txt", "a")
    now = datetime.now()
    out_file.write(now.strftime("%Y-%m-%d %H:%M:%S"))
    out_file.write("Baran with Raha.py: ")
    out_file.write(f" {task_name}")
    out_file.write(f" {dirty_path}\n")
    out_file.close()


0.5176767676767676
1.0
<_io.TextIOWrapper name='../Exp_result/raha_baran/flights_bclean/onlyED_flights_bclean2-inner_outer_error-02.txt' mode='w' encoding='UTF-8'>
8.111816883087158
0.11788617886178862
0.005894308943089431
<_io.TextIOWrapper name='../Exp_result/raha_baran/flights_bclean/oriED+EC_flights_bclean2-inner_outer_error-02.txt' mode='w' encoding='UTF-8'>
46.18794345855713
rep_right:29
rec_right:29
wrong_cells:4920
prec:0.11788617886178862
rec:0.005894308943089431
wrong2right:29
right2right:0
wrong2wrong:158
right2wrong:59
clean_in_cands BEFORE filter:204
clean_in_cands AFTER filter:112
proportion of clean value in candidates:0.45528455282702096
proportion of clean value in candidates and selected correctly:0.25892857140545283


In [8]:
print('hello')

hello


In [9]:
# 打印出，把错误修改正确的样本
w2r = []
for cell in repair_right_cells:
    if cell in actual_errors_list:
        wrong2right += 1
        w2r.append(cell)
    else:
        right2right += 1
print(w2r)

[(127, 2), (73, 2), (131, 2), (127, 3), (73, 3), (173, 3), (225, 3), (360, 3), (414, 3), (460, 3), (512, 3), (652, 3), (704, 3), (1093, 3), (1145, 3), (1207, 3), (1246, 3), (1300, 3), (1510, 3), (1564, 3), (1973, 3), (2025, 3), (2178, 3), (2230, 3), (127, 4), (127, 5), (414, 5), (1300, 5), (1564, 5)]


In [10]:
len(w2r)

29

In [11]:
a = [(i[0], repaired_df.columns[i[1]]) for i in w2r]
print(f"baran = {a}")

baran = [(127, 'sched_dep_time'), (73, 'sched_dep_time'), (131, 'sched_dep_time'), (127, 'act_dep_time'), (73, 'act_dep_time'), (173, 'act_dep_time'), (225, 'act_dep_time'), (360, 'act_dep_time'), (414, 'act_dep_time'), (460, 'act_dep_time'), (512, 'act_dep_time'), (652, 'act_dep_time'), (704, 'act_dep_time'), (1093, 'act_dep_time'), (1145, 'act_dep_time'), (1207, 'act_dep_time'), (1246, 'act_dep_time'), (1300, 'act_dep_time'), (1510, 'act_dep_time'), (1564, 'act_dep_time'), (1973, 'act_dep_time'), (2025, 'act_dep_time'), (2178, 'act_dep_time'), (2230, 'act_dep_time'), (127, 'sched_arr_time'), (127, 'act_arr_time'), (414, 'act_arr_time'), (1300, 'act_arr_time'), (1564, 'act_arr_time')]


In [12]:
# baran 发挥的不稳定！

In [15]:
repaired_df

Unnamed: 0,src,flight,sched_dep_time,act_dep_time,sched_arr_time,act_arr_time
0,aa,AA-3859-IAH-ORD,7:10 a.m.,7:16 a.m.,9:40 a.m.,9:32 a.m.
1,aa,AA-1733-ORD-PHX,7:45 p.m.,7:58 p.m.,10:30 p.m.,
2,aa,AA-1640-MIA-MCO,6:30 p.m.,,7:25 p.m.,
3,aa,AA-518-MIA-JFK,6:40 a.m.,6:54 a.m.,9:25 a.m.,9:28 a.m.
4,aa,AA-3756-ORD-SLC,12:15 p.m.,12:41 p.m.,2:45 p.m.,2:50 p.m.
...,...,...,...,...,...,...
2371,world-flight-tracker,UA-3099-PHX-PHL,11:55 a.m.,11:43 a.m.,6:17 p.m.,5:38 p.m.
2372,world-flight-tracker,AA-4198-ORD-CLE,10:40 a.m.,10:54 a.m.,12:55 p.m.,12:50 p.m.
2373,world-flight-tracker,CO-45-EWR-MIA,4:00 p.m.,3:58 p.m.,7:05 p.m.,6:36 p.m.
2374,world-flight-tracker,AA-3809-PHX-LAX,6:00 a.m.,6:10 a.m.,6:40 a.m.,6:19 a.m.


In [19]:
loc = (56, 'act_arr_time')
print(f" location: {loc} | baran_repaired: {repaired_df[loc[1]][loc[0]]}")

 location: (56, 'act_arr_time') | baran_repaired: 5:14 p.m.


In [20]:
loc = (2, 'act_arr_time')
print(f" location: {loc} | baran_repaired: {repaired_df[loc[1]][loc[0]]}")

 location: (2, 'act_arr_time') | baran_repaired: nan


In [22]:
loc = (127, 'sched_dep_time')
print(f" location: {loc} | baran_repaired: {repaired_df[loc[1]][loc[0]]}")

 location: (127, 'sched_dep_time') | baran_repaired: 2:55 p.m.


In [23]:
repaired_df

Unnamed: 0,src,flight,sched_dep_time,act_dep_time,sched_arr_time,act_arr_time
0,aa,AA-3859-IAH-ORD,7:10 a.m.,7:16 a.m.,9:40 a.m.,9:32 a.m.
1,aa,AA-1733-ORD-PHX,7:45 p.m.,7:58 p.m.,10:30 p.m.,
2,aa,AA-1640-MIA-MCO,6:30 p.m.,,7:25 p.m.,
3,aa,AA-518-MIA-JFK,6:40 a.m.,6:54 a.m.,9:25 a.m.,9:28 a.m.
4,aa,AA-3756-ORD-SLC,12:15 p.m.,12:41 p.m.,2:45 p.m.,2:50 p.m.
...,...,...,...,...,...,...
2371,world-flight-tracker,UA-3099-PHX-PHL,11:55 a.m.,11:43 a.m.,6:17 p.m.,5:38 p.m.
2372,world-flight-tracker,AA-4198-ORD-CLE,10:40 a.m.,10:54 a.m.,12:55 p.m.,12:50 p.m.
2373,world-flight-tracker,CO-45-EWR-MIA,4:00 p.m.,3:58 p.m.,7:05 p.m.,6:36 p.m.
2374,world-flight-tracker,AA-3809-PHX-LAX,6:00 a.m.,6:10 a.m.,6:40 a.m.,6:19 a.m.
