In [1]:
# Load .env variables
from dotenv import load_dotenv
import os

load_dotenv()  # populate the process environment from a .env file, if one is found
NEO4J_URI = os.getenv("NEO4J_URI")
NEO4J_USER = os.getenv("NEO4J_USER")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD")
NEO4J_DATABASE = os.getenv("NEO4J_DATABASE", "neo4j")  # optional; falls back to "neo4j"

# Fail fast with a readable message: otherwise a missing setting only surfaces
# later as an opaque error from GraphDatabase.driver() or from the first query.
_missing_env = [
    name
    for name, value in (
        ("NEO4J_URI", NEO4J_URI),
        ("NEO4J_USER", NEO4J_USER),
        ("NEO4J_PASSWORD", NEO4J_PASSWORD),
    )
    if not value
]
if _missing_env:
    raise RuntimeError(
        "Missing required Neo4j setting(s): " + ", ".join(_missing_env)
        + " (define them in .env or the environment)"
    )

In [2]:
# Open session with Neo4j
from neo4j import GraphDatabase

driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
# The driver connects lazily; check now so an unreachable server or bad
# connection settings fail in this cell rather than in the middle of a query.
driver.verify_connectivity()
session = driver.session(database=NEO4J_DATABASE)
# NOTE: `session` and `driver` stay open for the rest of the notebook;
# call session.close() and driver.close() when finished.

In [None]:
def run_cypher(session, query, parameters=None):
    """Helper function to run Cypher queries."""
    results = []
    with session.begin_transaction() as tx:
        # Split by semicolon to handle multiple statements
        for statement in query.split(';'):
            if statement.strip():  # Skip empty strings
                result = tx.run(statement, parameters)
                # Only fetch data if there are records to consume
                if result.peek():
                    results.append(result.data())
    
    # Return the last result or all results depending on preference, 
    # here we return the last non-empty one or empty list
    return results[-1] if results else []

In [4]:
# Named Cypher scripts for the churn workflow, executed by later cells via
# run_cypher. A value may contain several ';'-separated statements; run_cypher
# splits them and runs them one by one inside a single transaction.
queries = {
    # Drop the projected graph, pipeline and model left over from a previous run.
    # NOTE(review): the `false` argument is presumably GDS's failIfMissing flag,
    # making these no-ops on a fresh database — confirm for the GDS version in use.
    "cleanup": """
        CALL gds.graph.drop('churnGraph', false);
        CALL gds.pipeline.drop('churnPipeline', false);
        CALL gds.model.drop('churnModel', false);
    """,
    # Strip the TrainingCohort label from every node so the cohort can be resampled.
    # Note: SHOPPED_AT relationships created for earlier cohorts are not deleted
    # by any query in this dict.
    "remove_temp_labels": """
        MATCH (u:TrainingCohort) REMOVE u:TrainingCohort;
    """,
    # Class balancing by undersampling the majority class: label every user with
    # churned = 1 plus a random ~5% of the remaining users as TrainingCohort.
    # rand() is unseeded, so the cohort (and downstream metrics) change every run.
    "correct_undersampling": """
        // Apply a "TrainingCohort" label to:
        // - All Churned Users (100%)
        // - A tiny fraction of Active Users (5%)
        MATCH (u:User)
        WHERE u.churned = 1 OR rand() < 0.05
        SET u:TrainingCohort;
    """,
    # Materialize a User -> Merchant SHOPPED_AT edge for each cohort member, with
    # weight = number of that user's transactions (across all owned cards) at the
    # merchant. MERGE reuses an existing edge and overwrites its weight.
    "create_shopped_at": """
        // Ensure the relationship exists for these users (Materialize it)
        // This makes the projection lightning fast and avoids complex matching later
        MATCH (u:TrainingCohort)-[:OWNS]->(:Card)-[:PERFORMED]->(t:Transaction)-[:TO]->(m:Merchant)
        WITH u, m, count(t) AS weight
        MERGE (u)-[r:SHOPPED_AT]->(m)
        SET r.weight = weight;
    """,
    # Project the in-memory GDS graph 'churnGraph': only TrainingCohort users
    # (with their numeric properties) and Merchants, joined by undirected,
    # weighted SHOPPED_AT edges. Users outside the cohort are not in the graph.
    "project_graph": """
        CALL gds.graph.project(
            'churnGraph',
            {
                TrainingCohort: {
                    properties: ['churned', 'yearly_income', 'total_debt', 'credit_score']
                },
                Merchant: {
                    properties: [] 
                }
            },
            {
                SHOPPED_AT: {
                    orientation: 'UNDIRECTED',
                    properties: 'weight'
                }
            }
        );
    """,
    # Despite the key name, this does not create the pipeline. It computes FastRP
    # embeddings ('embedding_fastrp') and PageRank scores ('score_pagerank') and
    # mutates them onto churnGraph's nodes so "configure_pipeline" can select them
    # as features. No randomSeed is set for FastRP (a random-projection method),
    # so embeddings presumably vary between runs.
    "create_pipeline": """
        CALL gds.fastRP.mutate('churnGraph', {
            embeddingDimension: 64, // Small dimension for small dataset
            relationshipWeightProperty: 'weight',
            mutateProperty: 'embedding_fastrp'
        });

        CALL gds.pageRank.mutate('churnGraph', {
            relationshipWeightProperty: 'weight',
            mutateProperty: 'score_pagerank'
        });
    """,
    # Create 'churnPipeline', select the node-property features, configure the
    # split (20% test, 2 validation folds) and add one random-forest candidate.
    "configure_pipeline": """
        CALL gds.beta.pipeline.nodeClassification.create('churnPipeline');

        CALL gds.beta.pipeline.nodeClassification.selectFeatures('churnPipeline', [
            'embedding_fastrp', 
            'score_pagerank', 
            'yearly_income', 
            'total_debt', 
            'credit_score'
        ]);

        // Reduced folds to 2 because we only have ~16 churners
        CALL gds.beta.pipeline.nodeClassification.configureSplit('churnPipeline', {
            testFraction: 0.2,
            validationFolds: 2
        });

        // Tuned Random Forest for small data
        CALL gds.beta.pipeline.nodeClassification.addRandomForest('churnPipeline', {
            numberOfDecisionTrees: 50,
            maxDepth: 5,
            minLeafSize: 1,
            minSplitSize: 2
        });
    """,
    # Train the pipeline on TrainingCohort nodes to predict 'churned' and store the
    # result as 'churnModel'. NOTE(review): each returned metric is presumably a
    # map of per-split scores (train/validation/test), not a single number — confirm.
    "train_model": """
        CALL gds.beta.pipeline.nodeClassification.train(
            'churnGraph',
            {
                pipeline: 'churnPipeline',
                // FIX: We target the new specific label we created
                targetNodeLabels: ['TrainingCohort'], 
                modelName: 'churnModel',
                targetProperty: 'churned',
                metrics: ['ACCURACY', 'OUT_OF_BAG_ERROR', 'PRECISION(class=1)', 'RECALL(class=1)']
            }
        )
        YIELD modelInfo, modelSelectionStats
        RETURN 
            modelInfo.metrics.ACCURACY AS Accuracy,
            modelInfo.metrics['PRECISION(class=1)'] AS Precision,
            modelInfo.metrics['RECALL(class=1)'] AS Recall;
    """,
    # Score nodes of churnGraph with the trained model and return the 50 highest-risk
    # users whose churned = 0.
    # NOTE(review): churnGraph contains only TrainingCohort nodes, so this scores the
    # ~5% sample of non-churned users that was also used for training/testing — not
    # all users and not unseen data. predictedProbabilities[1] is assumed to be the
    # probability of class 1; confirm the model's class order.
    "predict_churn": """
        CALL gds.beta.pipeline.nodeClassification.predict.stream(
            'churnGraph',
            {
                modelName: 'churnModel',
                includePredictedProbabilities: true
            }
        )
        YIELD nodeId, predictedClass, predictedProbabilities
        WITH 
            gds.util.asNode(nodeId) AS user, 
            predictedClass,
            predictedProbabilities[1] AS Risk_Score

        // Filter: Show 'Active' users that the model thinks are risky
        WHERE user.churned = 0

        RETURN 
            user.id AS UserID, 
            user.yearly_income AS Income,
            user.credit_score AS Credit_Score,
            Risk_Score, 
            predictedClass
        ORDER BY Risk_Score DESC
        LIMIT 50;
    """
}

In [5]:
# Reset leftovers from previous runs, then rebuild the training cohort and its
# weighted SHOPPED_AT edges. Order matters: the cohort must be relabelled
# before its edges are (re)built.
for step_name in (
    "cleanup",                # drop old projected graph / pipeline / model
    "remove_temp_labels",     # clear the previous TrainingCohort sample
    "correct_undersampling",  # draw a new TrainingCohort sample
    "create_shopped_at",      # weighted User -> Merchant edges for the cohort
):
    run_cypher(session, queries[step_name])

CypherSyntaxError: {neo4j_code: Neo.ClientError.Statement.SyntaxError} {message: Expected exactly one statement per query but got: 3 (line 0, column 0 (offset: 0))
""
 ^} {gql_status: 42001} {gql_status_description: error: syntax error or access rule violation - invalid syntax}

In [None]:
# Build the in-memory 'churnGraph' projection and report what the procedure returned.
projection_query = queries["project_graph"]
res = run_cypher(session, projection_query)
print("Projected graph:", res)