# Creating a Common Feature Table

### Trial environment setup

In [0]:
%pip install databricks-feature-engineering

# Restart the Python process after the install above; presumably required so
# the newly installed package is importable in later cells — verify against
# the Databricks notebook-scoped library docs.
dbutils.library.restartPython()

In [0]:
# -------------------------------------------------------------------------
# Optional: reset state before rebuilding the feature table
# -------------------------------------------------------------------------
# Fully qualified (catalog.schema.table) name of the feature table to reset.
feature_table_name = "apex_bank_demo.analytics.customer_spending_features"

# Drop the table if it exists so the rebuild further down starts from a clean
# schema instead of appending to a stale or broken one. IF EXISTS keeps this
# cell safe to run when the table has not been created yet.
print(f"💣 Dropping old table {feature_table_name} to ensure clean slate...")
spark.sql(f"DROP TABLE IF EXISTS {feature_table_name}")

### Connect data from Unity Catalog

In [0]:
# Unity Catalog location that holds both source tables.
catalog = "apex_bank_demo"
schema = "analytics"
source_namespace = f"{catalog}.{schema}"

# Read the transactions table and the fraud labels table from that schema.
raw_transactions_df = spark.table(f"{source_namespace}.transactions_silver")
labels_df = spark.table(f"{source_namespace}.fraud_labels")

divider = "------------------------------------------------"
print(f"Loaded data from {source_namespace}")

# Print each schema under its own heading so the column names can be checked
# against the logic used in later cells.
for heading, frame in (
    ("Transactions schema: (Verify column names match logic below)", raw_transactions_df),
    ("Labels schema:", labels_df),
):
    print(divider)
    print(heading)
    frame.printSchema()

### Define 7-day rolling spending features per account


In [0]:
from databricks.feature_engineering import FeatureEngineeringClient, FeatureLookup
from pyspark.sql import functions as F
from pyspark.sql.types import TimestampType
from pyspark.sql.window import Window

# Client used in later cells to create the feature table and the training set.
fe = FeatureEngineeringClient()

# Reads the transactions table by its fully qualified name. This rebinds
# raw_transactions_df, which the earlier load cell also assigns.
raw_transactions_df = spark.table("apex_bank_demo.analytics.transactions_silver")

def compute_customer_history(transactions_df):
    """Compute trailing 7-day spending statistics for every transaction.

    For each input row, the ``amount`` values of all rows with the same
    ``account_id`` whose timestamp (in whole seconds) lies within the 7 days
    up to and including that row's timestamp are aggregated.

    Rows with a NULL ``account_id`` or a NULL ``transaction_timestamp`` are
    dropped, because both columns are used as the feature table's primary key.

    Args:
        transactions_df: Spark DataFrame with at least the columns
            ``account_id``, ``transaction_timestamp`` and ``amount``.

    Returns:
        Spark DataFrame with one row per (``account_id``,
        ``transaction_timestamp``) pair and the feature columns
        ``avg_txn_7d``, ``stddev_txn_7d`` and ``count_txn_7d``.
    """
    seven_days_in_seconds = 7 * 24 * 60 * 60  # 604800

    df = (
        transactions_df
        # Cast BEFORE the null check: a value that cannot be converted becomes
        # NULL only after the cast, so filtering first would let it through.
        .withColumn("transaction_timestamp", F.col("transaction_timestamp").cast(TimestampType()))
        .filter(
            F.col("account_id").isNotNull()
            & F.col("transaction_timestamp").isNotNull()
        )
        # Range-based window frames need a numeric ordering column, so express
        # the timestamp as epoch seconds.
        .withColumn("ts_long", F.col("transaction_timestamp").cast("long"))
    )

    # Frame = every row of the same account from 7 days before the current
    # row's second up to (and including) rows sharing the current second.
    w_7d = (
        Window.partitionBy("account_id")
        .orderBy("ts_long")
        .rangeBetween(-seven_days_in_seconds, 0)
    )

    return (
        df
        .withColumn("avg_txn_7d", F.avg("amount").over(w_7d))
        # Sample standard deviation: not defined when the frame holds one row.
        .withColumn("stddev_txn_7d", F.stddev("amount").over(w_7d))
        .withColumn("count_txn_7d", F.count("*").over(w_7d))
        .select(  # Keep only the primary-key columns and the features
            "account_id",
            "transaction_timestamp",
            "avg_txn_7d",
            "stddev_txn_7d",
            "count_txn_7d"
        )
        # The primary key must be unique; rows sharing a timestamp fall in the
        # same frame and therefore carry identical feature values.
        .dropDuplicates(["account_id", "transaction_timestamp"])
    )


### Create feature table in Unity Catalog

In [0]:
# Fully qualified Unity Catalog name of the feature table to create.
feature_table_name = "apex_bank_demo.analytics.customer_spending_features"

# compute_customer_history already returns one row per
# (account_id, transaction_timestamp), so a second dropDuplicates on the same
# keys would only add another shuffle without changing the result.
customer_features_df = compute_customer_history(raw_transactions_df)

print("Creating feature table with rolling windows...")
fe.create_table(
    name=feature_table_name,
    primary_keys=["account_id", "transaction_timestamp"],
    timeseries_columns="transaction_timestamp",
    df=customer_features_df,
    description="True 7-day rolling stats"
)

# Sanity check: read the table back and report how many rows were written.
ft_count = spark.table(feature_table_name).count()
print(f"Feature table rebuilt. Row count: {ft_count}")

if ft_count == 0:
    print("Error: The feature table is empty. Check source data for null timestamps.")

### Build "fraud" and "non-fraud" training dataframes

In [0]:
from pyspark.sql import Window
from pyspark.sql import functions as F

# Fixed seed for the random down-sampling below. Without it every Spark action
# (count, display, training) may re-draw a different sample.
SAMPLE_SEED = 42

txns = raw_transactions_df.alias("t")
lbls = labels_df.alias("l")

# Pair each transaction with the label row(s) of its account and keep only the
# transactions made on or before the investigation date.
joined_df = (
    txns.join(lbls, on="account_id")
    .filter(F.col("t.transaction_timestamp") <= F.col("l.investigation_date"))
)

# Newest transaction first within each account.
window_spec = Window.partitionBy("account_id").orderBy(F.col("t.transaction_timestamp").desc())

# Positive examples: per labelled account, the latest transaction on or before
# the investigation date (treated here as the transaction that triggered the
# fraud).
fraud_training_df = (
    joined_df
    .withColumn("rank", F.row_number().over(window_spec))
    .filter(F.col("rank") == 1)
    .select(
        F.col("account_id"),
        F.col("transaction_timestamp").cast(TimestampType()),
        F.lit(1).alias("is_fraud")
    )
)

# Count once and reuse the value: each count() launches a Spark job that
# recomputes the join and the window.
fraud_row_count = fraud_training_df.count()
print(f"Training labels ready. Row count: {fraud_row_count}")

# Teach the model what "Good" behavior looks like
# (random transactions from accounts that are NOT in the fraud list)
non_fraud_training_df = (
    raw_transactions_df
    .join(labels_df, on="account_id", how="left_anti") # "Left Anti" = Exclude accounts present in labels_df
    .sample(fraction=0.1, seed=SAMPLE_SEED) # Roughly 10% random sample, repeatable on unchanged input
    .select(
        F.col("account_id"),
        F.col("transaction_timestamp").cast(TimestampType()),
        F.lit(0).alias("is_fraud")
    )
    .limit(fraud_row_count * 5) # Cap negatives at 5x the positives to keep balance
)

# Combine the two sets. union() matches columns by position; both selects
# above emit (account_id, transaction_timestamp, is_fraud) in that order.
final_training_set_df = fraud_training_df.union(non_fraud_training_df)

print(f"Final training set prepared: {final_training_set_df.count()} rows")
display(final_training_set_df)

### Create training set (clean timestamps)

In [0]:
# Lookup into the spending feature table keyed by account_id, with
# transaction_timestamp as the timestamp lookup key.
spending_feature_lookup = FeatureLookup(
    table_name=feature_table_name,
    lookup_key=["account_id"],
    timestamp_lookup_key="transaction_timestamp",
)

# Attach the looked-up features to the labelled rows; the key columns are
# excluded from the resulting training set.
training_set = fe.create_training_set(
    df=final_training_set_df,
    feature_lookups=[spending_feature_lookup],
    label="is_fraud",
    exclude_columns=["account_id", "transaction_timestamp"],
)

training_df = training_set.load_df()

# Null check: count the rows whose avg_txn_7d feature is NULL and report it
# next to the total row count.
missing_avg_condition = F.col("avg_txn_7d").isNull()
nulls = training_df.filter(missing_avg_condition).count()
total_rows = training_df.count()
print(f"Result: {total_rows} rows. {nulls} rows have NULL features.")

display(training_df)


## Proceed to train the model

In [0]:
%pip install databricks-automl
# NOTE(review): Databricks documents that notebook state can be reset by %pip
# commands that modify the environment, which would discard variables from
# earlier cells (e.g. training_df). Consider moving this install into the
# first cell — TODO confirm against the notebook-scoped library docs.

In [0]:
from databricks.automl import classifier

# -------------------------------------------------------------------------
# Train the model with AutoML
# -------------------------------------------------------------------------
# Drop the keys (account_id, transaction_timestamp) so they are not used as
# model inputs; keep the features (avg, stddev, count) and the label.
# DataFrame.drop ignores column names that are not present, so this stays safe
# even though the training set was created with these columns excluded.
train_data = training_df.drop("account_id", "transaction_timestamp")

# Run AutoML classification against the is_fraud label with a 5-minute budget,
# logging trials to the shared experiment.
summary = classifier.fit(
    dataset=train_data,
    target_col="is_fraud",
    timeout_minutes=5,
    experiment_name="/Shared/Apex_Fraud_Experiment"
)

print(f"✅ Model Trained. Best Trial: {summary.best_trial.model_path}")