In [1]:
from neo4j import GraphDatabase
import pandas as pd
import uuid  # For generating unique IDs
import math 


# Load CSV files
account_booking_train = pd.read_csv('data/account_booking_train.csv')
external_parties_train = pd.read_csv('data/external_parties_train.csv')



In [2]:
from datasketch import MinHash, MinHashLSH
import pandas as pd

class LSHEntityResolution:
    def __init__(self, dataframe):
        """
        Initialize with the DataFrame.
        :param dataframe: Pandas DataFrame containing entity data.
        """
        self.dataframe = dataframe

    # Generate MinHash for a given text
    def generate_minhash(self, text, num_perm=128):
        """
        Create MinHash for a given text.
        :param text: Input string.
        :param num_perm: Number of permutations for MinHash.
        :return: MinHash object.
        """
        if not text or pd.isna(text):
            return None
        mh = MinHash(num_perm=num_perm)
        for token in text.split():
            mh.update(token.encode("utf8"))
        return mh

    # Perform LSH on nodes
    def perform_lsh(self, threshold=0.8, num_perm=128):
        """
        Perform LSH to find similar entities.
        :param threshold: Similarity threshold for LSH.
        :param num_perm: Number of permutations for MinHash.
        :return: List of matched entity pairs.
        """
        lsh = MinHashLSH(threshold=threshold, num_perm=num_perm)
        node_hashes = {}

        # Insert nodes into LSH
        for index, row in self.dataframe.iterrows():
            combined_text = " ".join(
                filter(None, [
                    str(row['parsed_name']), 
                    str(row['parsed_address_street_name']), 
                    str(row['parsed_address_city'])                    
                ])
            )
            mh = self.generate_minhash(combined_text, num_perm)
            if mh:
                lsh.insert(str(index), mh)
                node_hashes[str(index)] = mh

        # Find similar pairs
        matches = []
        for node_id, mh in node_hashes.items():
            similar = lsh.query(mh)
            for sim in similar:
                if int(node_id) < int(sim):  # Avoid duplicate matches (e.g., (A, B) and (B, A))
                    matches.append((int(node_id), int(sim)))

        return matches

    def evaluate_matches(self, matches):
        """
        Evaluate matches based on the 'external_id' column and print detailed results.
        :param matches: List of matched entity pairs.
        :return: Number of correct matches and a list of wrong matches.
        """
        correct_matches = 0
        wrong_matches = []

        for match in matches:
            id1, id2 = match
            external_id1 = self.dataframe.loc[id1, 'external_id']
            external_id2 = self.dataframe.loc[id2, 'external_id']
            if external_id1 == external_id2:
                correct_matches += 1
            else:
                wrong_matches.append((id1, id2, external_id1, external_id2))

        return correct_matches, wrong_matches

lsh_er = LSHEntityResolution(external_parties_train)
matches = lsh_er.perform_lsh(threshold=0.6, num_perm=264)

correct_match_count, wrong_matches = lsh_er.evaluate_matches(matches)

# Print Results
print(f"Number of correct matches: {correct_match_count}")
print(f"Number of wrong matches: {len(wrong_matches)}")


Number of correct matches: 4516
Number of wrong matches: 36


In [3]:
def calculate_total_true_matches(df):
    """
    Calculate the total number of true matches based on external_id.
    """
    total_true_matches = 0
    external_id_groups = df.groupby('external_id').size()
    for count in external_id_groups:
        if count > 1:
            total_true_matches += count * (count - 1) // 2  # nC2 = n * (n - 1) / 2
    return total_true_matches

# Calculate the total true matches and F1 score components
total_true_matches = calculate_total_true_matches(external_parties_train)
false_negatives = total_true_matches - correct_match_count
false_positives = len(wrong_matches)

# Precision, Recall, and F1 Score
precision = correct_match_count / (correct_match_count + false_positives) if (correct_match_count + false_positives) > 0 else 0
recall = correct_match_count / (correct_match_count + false_negatives) if (correct_match_count + false_negatives) > 0 else 0
f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

# Print results
print(f"Precision: {precision}")
print(f"Recall: {recall}")
print(f"F1 Score: {f1_score}")

Precision: 0.992091388400703
Recall: 0.32340303637926093
F1 Score: 0.48779434003024413
