In [1]:
import math
import os
import uuid  # For generating unique IDs

import pandas as pd
from datasketch import MinHash, MinHashLSH
from neo4j import GraphDatabase


# Read the two training CSVs; the unpacking order follows the order of the paths.
account_booking_train, external_parties_train = map(
    pd.read_csv,
    ('data/account_booking_train.csv', 'data/external_parties_train.csv'),
)



In [2]:
# Clean the data: keep only bookings whose transaction reference occurs once.

# How many booking rows carry each `transaction_reference_id`.
transaction_counts = account_booking_train['transaction_reference_id'].value_counts()

# References that appear on exactly one row (single-leg transactions).
single_leg_transactions = transaction_counts.loc[transaction_counts.eq(1)].index

# Restrict the bookings to those single-leg references.
account_booking_train = account_booking_train.loc[
    account_booking_train['transaction_reference_id'].isin(single_leg_transactions)
]


def sanitize_value(value):
    """Return ``None`` for any missing-value marker, otherwise the value itself.

    Neo4j cannot store NaN-like placeholders as meaningful properties, so
    callers use the ``None`` result to drop the property altogether.

    Args:
        value: A single cell value, typically taken from a pandas row.

    Returns:
        ``None`` when ``value`` is ``None``, a float NaN, ``pd.NA``, ``pd.NaT``
        or a NumPy NaN scalar; ``value`` unchanged in every other case
        (including non-scalar containers such as lists).
    """
    if value is None:
        return None
    if isinstance(value, float):
        return None if math.isnan(value) else value
    # pandas also represents missing data with pd.NA / pd.NaT and NumPy NaN
    # scalars that are not Python floats.  The scalar guard keeps pd.isna from
    # returning an array for list-like input.
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return None
    return value



In [None]:
# Neo4j connection details.
# Each setting is read from the environment variable of the same name so real
# credentials never have to be committed with the notebook.  The fallbacks are
# the values this notebook previously hard-coded; override them for any
# database that is not a throwaway local instance.
NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USERNAME = os.environ.get("NEO4J_USERNAME", "neo4j")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "password123")

class Neo4jLoader:
    """Loads external parties and account bookings into a Neo4j graph.

    Each external-party row becomes an ``External_Party`` node, each booked
    account a ``UBS_Client`` node, and each booking row a ``TRANSACTION``
    relationship from the client to the party that shares its
    ``transaction_reference_id``.
    """

    # External-party columns copied onto every External_Party node.
    _PARTY_FIELDS = (
        "transaction_reference_id",
        "party_role",
        "party_info_unstructured",
        "parsed_name",
        "parsed_address_street_name",
        "parsed_address_street_number",
        "parsed_address_unit",
        "parsed_address_postal_code",
        "parsed_address_city",
        "parsed_address_state",
        "parsed_address_country",
        "party_iban",
        "party_phone",
        "external_id",
    )

    # Booking columns read from every account-booking row.
    _BOOKING_FIELDS = (
        "transaction_reference_id",
        "debit_credit_indicator",
        "account_id",
        "transaction_amount",
        "transaction_currency",
        "transaction_date",
    )

    # Subset of the booking columns stored on the TRANSACTION relationship.
    _RELATIONSHIP_FIELDS = (
        "debit_credit_indicator",
        "transaction_amount",
        "transaction_currency",
        "transaction_date",
    )

    def __init__(self, uri, username, password):
        """Initialize the Neo4j driver.

        Args:
            uri: Connection URI of the Neo4j server.
            username: Database user name.
            password: Password for ``username``.
        """
        self.driver = GraphDatabase.driver(uri, auth=(username, password))

    def close(self):
        """Close the Neo4j driver connection."""
        if self.driver:
            self.driver.close()

    def create_constraints(self):
        """Create the uniqueness constraint on ``External_Party.id``.

        Failures are reported on stdout instead of being raised.
        """
        with self.driver.session() as session:
            try:
                # consume() forces the statement to finish here, so a server
                # error is caught by this handler rather than on session close.
                session.run("""
                CREATE CONSTRAINT unique_party_id IF NOT EXISTS
                FOR (p:External_Party) REQUIRE p.id IS UNIQUE
                """).consume()
                print("Constraint 'unique_party_id' created successfully.")
            except Exception as e:
                print(f"Error creating 'unique_party_id' constraint: {e}")

    def load_external_parties(self, session, external_parties_train):
        """Load external parties as ``External_Party`` nodes.

        Every row receives a freshly generated UUID as its ``id``; because that
        id is part of the MERGE pattern, each row produces its own node.
        Missing values are omitted from the node instead of being stored.

        Args:
            session: Open Neo4j session used to run the statements.
            external_parties_train: DataFrame with one external party per row.
        """
        for _, row in external_parties_train.iterrows():
            try:
                properties = {"id": str(uuid.uuid4())}
                properties.update(
                    {field: sanitize_value(row.get(field)) for field in self._PARTY_FIELDS}
                )
                # Remove None values from properties
                non_null_properties = {k: v for k, v in properties.items() if v is not None}

                # The keys come from the fixed field list above, so building
                # the property map by string formatting is safe; all values
                # are still passed as query parameters.
                property_map = ", ".join(f"{key}: ${key}" for key in non_null_properties)
                query = f"MERGE (p:External_Party {{{property_map}}})"
                session.run(query, non_null_properties).consume()
            except Exception as e:
                print(f"Error loading External_Party node: {e}")

    def load_transactions_as_relationships(self, session, account_booking_train):
        """Load transactions as relationships between UBS_Client and External_Party nodes.

        Rows without ``account_id`` or ``transaction_reference_id`` are skipped.
        Optional booking fields that are missing are left off the relationship;
        the query only references parameters that are actually supplied, so a
        single missing optional value no longer makes the whole row fail.

        Args:
            session: Open Neo4j session used to run the statements.
            account_booking_train: DataFrame with one booking per row.
        """
        create_ubs_client_query = """
                MERGE (c:UBS_Client {account_id: $account_id})
                """
        for _, row in account_booking_train.iterrows():
            try:
                properties = {
                    field: sanitize_value(row.get(field)) for field in self._BOOKING_FIELDS
                }
                # Remove None values from properties
                non_null_properties = {k: v for k, v in properties.items() if v is not None}

                # Both identifiers are needed to find the relationship's end points.
                if "account_id" not in non_null_properties or "transaction_reference_id" not in non_null_properties:
                    print(f"Skipping row due to missing required fields: {row}")
                    continue

                # Only mention relationship properties that have a value: a
                # parameter that is referenced but not supplied is rejected by
                # the server, and MERGE cannot match on null property values.
                relationship_map = ", ".join(
                    f"{key}: ${key}"
                    for key in self._RELATIONSHIP_FIELDS
                    if key in non_null_properties
                )
                relationship = (
                    f"[:TRANSACTION {{{relationship_map}}}]" if relationship_map else "[:TRANSACTION]"
                )
                create_transaction_relationship_query = f"""
                MATCH (c:UBS_Client {{account_id: $account_id}}),
                    (p:External_Party {{transaction_reference_id: $transaction_reference_id}})
                MERGE (c)-{relationship}->(p)
                """

                # Execute queries
                session.run(
                    create_ubs_client_query,
                    {"account_id": non_null_properties["account_id"]},
                ).consume()
                session.run(create_transaction_relationship_query, non_null_properties).consume()
            except Exception as e:
                print(f"Error creating transaction relationship: {e}")

    def load_data(self, external_parties_train, account_booking_train):
        """Load all data into Neo4j: party nodes first, then the transactions.

        Args:
            external_parties_train: DataFrame with one external party per row.
            account_booking_train: DataFrame with one booking per row.
        """
        with self.driver.session() as session:
            # Parties must exist before the relationship MATCH can find them.
            self.load_external_parties(session, external_parties_train)
            self.load_transactions_as_relationships(session, account_booking_train)

# Build the loader from the connection settings defined above.
loader = Neo4jLoader(
    uri=NEO4J_URI,
    username=NEO4J_USERNAME,
    password=NEO4J_PASSWORD,
)

# Schema first, then data; the driver is closed whether or not loading succeeds.
try:
    print("Creating constraints in Neo4j...")
    loader.create_constraints()

    print("Loading data into Neo4j...")
    loader.load_data(
        external_parties_train=external_parties_train,
        account_booking_train=account_booking_train,
    )
    print("Data loaded successfully!")
finally:
    loader.close()


In [None]:
class LSHEntityResolution:
    """Finds likely duplicate ``External_Party`` nodes with MinHash LSH.

    Nodes are compared on the whitespace tokens of their name, street and
    city; candidate pairs are written back to Neo4j as ``POSSIBLE_DUPLICATE``
    relationships and can optionally be merged.
    """

    def __init__(self, uri, user, password):
        """Open a Neo4j driver for ``uri`` authenticated as ``user``."""
        self.driver = GraphDatabase.driver(uri, auth=(user, password))

    def close(self):
        """Close the Neo4j driver connection."""
        self.driver.close()

    # Fetch data from Neo4j
    def fetch_nodes(self):
        """Return ``(node_id, name, street, city)`` for every External_Party node."""
        query = """
        MATCH (n:External_Party)
        RETURN id(n) AS node_id, n.parsed_name AS name, n.parsed_address_street_name AS street, n.parsed_address_city AS city
        """
        with self.driver.session() as session:
            result = session.run(query)
            return [(record["node_id"], record["name"], record["street"], record["city"]) for record in result]

    # Generate MinHash for a given text
    def generate_minhash(self, text, num_perm=128):
        """Build a MinHash over the whitespace-separated tokens of ``text``.

        Args:
            text: Text to hash.
            num_perm: Number of permutations used by the MinHash.

        Returns:
            The MinHash, or ``None`` when ``text`` is empty or ``None``.
        """
        if not text:
            return None
        mh = MinHash(num_perm=num_perm)
        for token in text.split():
            mh.update(token.encode("utf8"))
        return mh

    # Perform LSH on nodes
    def perform_lsh(self, nodes, threshold=0.8, num_perm=128):
        """Return pairs of node ids whose MinHashes collide in the LSH index.

        Args:
            nodes: Iterable of ``(node_id, name, street, city)`` tuples.
            threshold: Jaccard threshold handed to ``MinHashLSH``.
            num_perm: Number of permutations for the MinHashes and the index.

        Returns:
            List of ``(int, int)`` node-id pairs; each unordered pair appears once.
        """
        lsh = MinHashLSH(threshold=threshold, num_perm=num_perm)
        node_hashes = {}

        # Insert nodes into LSH
        for node_id, name, street, city in nodes:
            # str() keeps the join from failing if a stored property is not text.
            combined_text = " ".join(str(part) for part in (name, street, city) if part)
            mh = self.generate_minhash(combined_text, num_perm)
            if mh:
                lsh.insert(str(node_id), mh)
                node_hashes[str(node_id)] = mh

        # Find similar pairs
        matches = []
        for node_id, mh in node_hashes.items():
            for sim in lsh.query(mh):
                # Keys are strings, so this is a string comparison.  It drops
                # self-matches and keeps one orientation of every pair.
                if node_id < sim:
                    matches.append((int(node_id), int(sim)))

        return matches

    # Save matches to Neo4j
    def save_matches_to_neo4j(self, matches):
        """Create a ``POSSIBLE_DUPLICATE`` relationship for every matched pair.

        Args:
            matches: Iterable of ``(node_id, node_id)`` pairs.
        """
        pairs = [list(pair) for pair in matches]
        if not pairs:
            # Nothing to write; avoid a pointless round trip to the database.
            return
        query = """
        UNWIND $pairs AS pair
        MATCH (n1), (n2)
        WHERE id(n1) = pair[0] AND id(n2) = pair[1]
        MERGE (n1)-[:POSSIBLE_DUPLICATE]->(n2)
        """
        with self.driver.session() as session:
            session.run(query, pairs=pairs).consume()

    @staticmethod
    def _group_duplicates(pairs):
        """Group linked node ids: ``{surviving_id: [ids to fold into it]}``.

        Ids connected directly or through a chain of pairs end up in the same
        group; the smallest id of each group is the survivor.
        """
        parent = {}

        def find(node):
            parent.setdefault(node, node)
            while parent[node] != node:
                parent[node] = parent[parent[node]]  # path halving
                node = parent[node]
            return node

        for first, second in pairs:
            root_a, root_b = find(first), find(second)
            if root_a != root_b:
                keep, drop = (root_a, root_b) if root_a < root_b else (root_b, root_a)
                parent[drop] = keep

        groups = {}
        for node in sorted(parent):
            root = find(node)
            if root != node:
                groups.setdefault(root, []).append(node)
        return groups

    @staticmethod
    def _merge_group(tx, keep_id, drop_ids):
        """Fold the nodes ``drop_ids`` into node ``keep_id`` inside transaction ``tx``."""
        params = {"keep_id": keep_id, "drop_ids": drop_ids}

        # Step 1: copy properties the survivor does not have yet.  Existing
        # values are never overwritten, so the survivor keeps its own unique
        # `id` (copying a duplicate's id onto it would break the constraint).
        keep_props = {}
        for record in tx.run(
            "MATCH (keep) WHERE id(keep) = $keep_id RETURN properties(keep) AS props",
            params,
        ):
            keep_props = record["props"]
        missing = {}
        for record in tx.run(
            """
            MATCH (dup) WHERE id(dup) IN $drop_ids
            WITH dup ORDER BY id(dup)
            RETURN properties(dup) AS props
            """,
            params,
        ):
            for key, value in record["props"].items():
                if key not in keep_props:
                    missing.setdefault(key, value)
        if missing:
            tx.run(
                "MATCH (keep) WHERE id(keep) = $keep_id SET keep += $props",
                {"keep_id": keep_id, "props": missing},
            ).consume()

        # Steps 2 and 3: re-create every relationship of the duplicates on the
        # survivor.  Cypher needs the relationship type as literal query text,
        # so one statement per type and direction is generated here.
        rel_types = [
            record["rel_type"]
            for record in tx.run(
                """
                MATCH (dup)-[rel]-()
                WHERE id(dup) IN $drop_ids AND type(rel) <> 'POSSIBLE_DUPLICATE'
                RETURN DISTINCT type(rel) AS rel_type
                """,
                params,
            )
        ]
        for rel_type in rel_types:
            quoted = "`" + rel_type.replace("`", "``") + "`"
            # CREATE (not MERGE) so parallel relationships with different
            # properties are all preserved instead of collapsing into one.
            tx.run(
                f"""
                MATCH (keep) WHERE id(keep) = $keep_id
                MATCH (dup)<-[rel:{quoted}]-(other)
                WHERE id(dup) IN $drop_ids AND id(other) <> $keep_id AND NOT id(other) IN $drop_ids
                CREATE (keep)<-[moved:{quoted}]-(other)
                SET moved = properties(rel)
                """,
                params,
            ).consume()
            tx.run(
                f"""
                MATCH (keep) WHERE id(keep) = $keep_id
                MATCH (dup)-[rel:{quoted}]->(other)
                WHERE id(dup) IN $drop_ids AND id(other) <> $keep_id AND NOT id(other) IN $drop_ids
                CREATE (keep)-[moved:{quoted}]->(other)
                SET moved = properties(rel)
                """,
                params,
            ).consume()

        # Step 4: remove the duplicates together with their old relationships,
        # which includes the POSSIBLE_DUPLICATE links.
        tx.run("MATCH (dup) WHERE id(dup) IN $drop_ids DETACH DELETE dup", params).consume()

    # Merge duplicate nodes in Neo4j
    def merge_duplicate_nodes(self):
        """Merge every group of ``POSSIBLE_DUPLICATE``-linked nodes into one node.

        Within each group the node with the smallest internal id survives.  It
        gains any properties it was missing and all relationships of the other
        nodes; those nodes are then deleted.  Each group is handled in its own
        transaction, so a failure leaves that group untouched.

        Returns:
            Number of nodes that were merged away (deleted).
        """
        pairs_query = """
        MATCH (a)-[:POSSIBLE_DUPLICATE]->(b)
        RETURN id(a) AS first_id, id(b) AS second_id
        """
        with self.driver.session() as session:
            pairs = [(record["first_id"], record["second_id"]) for record in session.run(pairs_query)]

            merged_nodes = 0
            for keep_id, drop_ids in self._group_duplicates(pairs).items():
                # Leaving the block without commit() rolls the group back.
                with session.begin_transaction() as tx:
                    self._merge_group(tx, keep_id, drop_ids)
                    tx.commit()
                merged_nodes += len(drop_ids)
            return merged_nodes

    # Full workflow
    def run_lsh_entity_resolution(self, threshold=0.8):
        """Fetch nodes, find likely duplicates and store them as relationships.

        Args:
            threshold: Jaccard threshold handed to ``perform_lsh``.
        """
        print("Fetching nodes from Neo4j...")
        nodes = self.fetch_nodes()
        print(f"Fetched {len(nodes)} nodes.")

        print(f"Performing LSH with threshold {threshold}...")
        matches = self.perform_lsh(nodes, threshold=threshold)
        print(f"Found {len(matches)} matches.")

        print("Saving matches to Neo4j...")
        self.save_matches_to_neo4j(matches)

        # The merge step is left disabled here; call merge_duplicate_nodes()
        # separately once the POSSIBLE_DUPLICATE links have been reviewed.
        #print("Merging duplicate nodes in Neo4j...")
        #merged_nodes = self.merge_duplicate_nodes()
        #print(f"Total nodes merged: {merged_nodes}")

if __name__ == "__main__":
    # Connect with the settings defined above.
    lsh_er = LSHEntityResolution(
        uri=NEO4J_URI,
        user=NEO4J_USERNAME,
        password=NEO4J_PASSWORD,
    )

    # Run the duplicate search; the driver is closed even if it fails.
    try:
        lsh_er.run_lsh_entity_resolution(threshold=0.8)
    finally:
        lsh_er.close()



In [None]:

class Neo4jReset:
    """Wipes a Neo4j database: all data first, then its constraints and indexes."""

    def __init__(self, uri, username, password):
        """Open a Neo4j driver for ``uri`` authenticated as ``username``."""
        self.driver = GraphDatabase.driver(uri, auth=(username, password))

    def close(self):
        """Close the Neo4j driver connection."""
        self.driver.close()

    @staticmethod
    def _quote(name):
        """Backtick-quote a schema name so it is safe to embed in Cypher text."""
        return "`" + str(name).replace("`", "``") + "`"

    def reset_database(self):
        """Delete every node and relationship, then drop constraints and indexes.

        Token LOOKUP indexes are kept: Neo4j creates them itself and uses them
        to find nodes by label and relationships by type.
        """
        with self.driver.session() as session:
            # Delete all nodes and relationships
            session.run("MATCH (n) DETACH DELETE n").consume()

            # Drop all constraints.  The names are collected before any DROP
            # is issued so the listing is complete before the schema changes.
            constraint_names = [record["name"] for record in session.run("SHOW CONSTRAINTS")]
            for constraint_name in constraint_names:
                session.run(f"DROP CONSTRAINT {self._quote(constraint_name)} IF EXISTS").consume()

            # Drop the remaining indexes.  IF EXISTS tolerates an index that
            # already disappeared together with the constraint that owned it.
            index_names = [
                record["name"]
                for record in session.run("SHOW INDEXES")
                if record.get("type") != "LOOKUP"
            ]
            for index_name in index_names:
                session.run(f"DROP INDEX {self._quote(index_name)} IF EXISTS").consume()

# Build the resetter from the connection settings defined above.
# Running this cell deletes every node and relationship in the database.
resetter = Neo4jReset(
    uri=NEO4J_URI,
    username=NEO4J_USERNAME,
    password=NEO4J_PASSWORD,
)

# The driver is closed whether or not the reset succeeds.
try:
    print("Resetting Neo4j database...")
    resetter.reset_database()
    print("Database reset successfully!")
finally:
    resetter.close()
