In [0]:
import numpy as np
import pandas as pd
# from pyspark.sql import SparkSession
from databricks.connect import DatabricksSession
from utils import load_config

In [0]:
# Get or create the Databricks Connect session used for every table read/write below
spark = DatabricksSession.builder.getOrCreate()

In [0]:
# Load the project configuration and keep the catalog/schema names,
# which prefix every table referenced in this notebook
config = load_config("../project_config.yml")
catalog_name = config.catalog_name
schema_name = config.schema_name

[32m2025-04-04 15:50:57.875[0m | [1mINFO    [0m | [36mutils[0m:[36mload_config[0m:[36m66[0m - [1mLoaded configuration from ../project_config.yml[0m


In [0]:
# Row-count bookkeeping:
#   37354 is the original number of rows in features_balanced after the first SMOTE
#   100 is the number of synthetic rows generated each time this notebook runs
# Pull features_balanced into pandas and collect every Id already in use (as int)
features_balanced = spark.table(f"{catalog_name}.{schema_name}.features_balanced").toPandas()
existing_ids = {int(raw_id) for raw_id in features_balanced["Id"]}

In [0]:
# Number of distinct Ids collected from features_balanced
len(existing_ids)

37454

In [0]:
# Smallest and largest Id in use; new synthetic Ids are numbered after the largest one
min(existing_ids), max(existing_ids)

(1, 43454)

In [0]:
# Generate a dataframe with unique values for each column with few unique values
# to identify the discrete values
def generate_unique_values_dataframe(df, columns):
    """Summarise the distinct non-null values of selected columns.

    Args:
        df: pandas DataFrame to inspect.
        columns: iterable of column names present in ``df``.

    Returns:
        A one-row DataFrame with one column per requested name; each cell
        holds the list of that column's unique non-null values, in order of
        first appearance.
    """
    distinct_by_column = {}
    for name in columns:
        non_null = df[name].dropna()
        distinct_by_column[name] = non_null.unique().tolist()
    return pd.DataFrame([distinct_by_column])


# Load the train and test splits into pandas and stack them into a single frame
train_set = spark.table(f"{catalog_name}.{schema_name}.train_set").toPandas()
test_set = spark.table(f"{catalog_name}.{schema_name}.test_set").toPandas()
combined_set = pd.concat([train_set, test_set], ignore_index=True)

# Columns with few unique values (Age is the largest with 56 unique values)
# NOTE(review): the name `columns` is rebound later in the notebook when "Id" is moved
# to the front of synthetic_df, so this list is not available after that cell runs
columns = ["Sex", "Education", "Marriage", "Age", "Pay_0", "Pay_2", "Pay_3", "Pay_4", "Pay_5", "Pay_6", "Default"]

# One-row summary: each cell holds the list of distinct values of that column
result = generate_unique_values_dataframe(combined_set, columns)
print(result)

      Sex     Education  ...                     Pay_6 Default
0  [2, 1]  [1, 2, 3, 4]  ...  [0, 4, 2, 3, 7, 5, 6, 8]  [0, 1]

[1 rows x 11 columns]


In [0]:
def create_synthetic_data(df, num_rows=100, existing_ids=None, random_state=None):
    """Generate ``num_rows`` synthetic rows that mimic the columns of ``df``.

    Only numeric and datetime columns are synthesised; any other column
    (and the source "Id" column) is skipped. Per column:

    * numeric with 10 or fewer distinct values: treated as discrete and
      sampled uniformly from the observed values;
    * numeric whose name starts with "Pay_amt": drawn from a normal
      distribution with the column's mean/std, made non-negative with
      ``abs`` and truncated to whole numbers;
    * any other numeric column (e.g. the Bill_amt columns): drawn from the
      same kind of normal fit and rounded to whole numbers. The normal fit
      ignores the real shape and bounds of the data, which deliberately
      introduces some data drift;
    * datetime: random timestamps between the column maximum and the
      current time, or a constant when no such interval exists.

    A string "Id" column is appended last, numbered consecutively from one
    past the largest existing Id (or from 1 when there are none).

    Args:
        df: source pandas DataFrame.
        num_rows: number of synthetic rows to create.
        existing_ids: iterable of Ids already in use (anything ``int()``
            accepts). When omitted, the notebook-level ``existing_ids``
            variable is used, which is what earlier versions of this
            function always did.
        random_state: ``None`` (default) keeps using NumPy's global random
            state, so every call yields different data. Pass an int seed or
            a ``numpy.random.RandomState`` for reproducible numeric columns
            (timestamps still depend on the current time).

    Returns:
        pandas DataFrame with ``num_rows`` rows: the synthesised columns in
        the order they appear in ``df``, followed by "Id".

    Raises:
        NameError: if ``existing_ids`` is not passed and no notebook-level
            ``existing_ids`` exists.
    """
    if existing_ids is None:
        # Backward compatibility: the function used to read this global directly.
        try:
            existing_ids = globals()["existing_ids"]
        except KeyError:
            raise NameError(
                "existing_ids was not passed and no notebook-level 'existing_ids' is defined"
            ) from None

    # The `np.random` module exposes the same choice/normal/randint API as a
    # RandomState instance, so the default path still uses the global RNG.
    if random_state is None:
        rng = np.random
    elif isinstance(random_state, np.random.RandomState):
        rng = random_state
    else:
        rng = np.random.RandomState(random_state)

    synthetic_data = pd.DataFrame()

    for column in df.columns:
        if pd.api.types.is_numeric_dtype(df[column]) and column != "Id":
            unique_values = df[column].unique()
            if len(unique_values) <= 10:
                # Few distinct values: treat as categorical codes and resample them
                synthetic_data[column] = rng.choice(unique_values, num_rows)
            elif column.startswith("Pay_amt"):
                # Payment amounts must stay non-negative
                mean, std = df[column].mean(), df[column].std()
                synthetic_data[column] = np.abs(rng.normal(mean, std, num_rows)).astype(int).astype(float)
            else:
                # Unbounded normal draw: source of the intentional data drift
                mean, std = df[column].mean(), df[column].std()
                synthetic_data[column] = np.round(rng.normal(mean, std, num_rows)).astype(int).astype(float)

        elif pd.api.types.is_datetime64_any_dtype(df[column]):
            min_date, max_date = df[column].min(), df[column].max()
            if min_date < max_date:
                # New rows are stamped between the latest known timestamp and now
                current_time = pd.to_datetime("now")
                if max_date < current_time:
                    # Bounds are nanoseconds since the epoch; request int64 explicitly
                    # because they overflow a 32-bit default integer.
                    synthetic_data[column] = pd.to_datetime(
                        rng.randint(max_date.value, current_time.value, num_rows, dtype=np.int64)
                    )
                else:
                    synthetic_data[column] = [max_date] * num_rows
            else:
                synthetic_data[column] = [min_date] * num_rows

    # Every Id from max + 1 upwards is free by construction, so a plain range suffices.
    first_new_id = max((int(value) for value in existing_ids), default=0) + 1
    synthetic_data["Id"] = [str(new_id) for new_id in range(first_new_id, first_new_id + num_rows)]

    return synthetic_data


# Create synthetic data (default of 100 rows) modelled on the combined train + test frame
synthetic_df = create_synthetic_data(combined_set)

# Move "Id" to the first position
# NOTE: this rebinds the notebook-level name `columns`, which earlier held the
# list of low-cardinality column names
columns = ["Id"] + [col for col in synthetic_df.columns if col != "Id"]
synthetic_df = synthetic_df[columns]

In [0]:
# Inspect the last few generated rows
synthetic_df.tail()

Unnamed: 0,Id,Limit_bal,Sex,Education,Marriage,Age,Pay_0,Pay_2,Pay_3,Pay_4,Pay_5,Pay_6,Bill_amt1,Bill_amt2,Bill_amt3,Bill_amt4,Bill_amt5,Bill_amt6,Pay_amt1,Pay_amt2,Pay_amt3,Pay_amt4,Pay_amt5,Pay_amt6,Default,Update_timestamp_utc
95,43550,413409.0,2,2,3,35.0,6,5,4,2,4,5,-26706.0,12926.0,28517.0,116337.0,26948.0,2789.0,21443.0,10296.0,3465.0,49329.0,15811.0,19563.0,1,2025-04-04 15:39:01.397376499
96,43551,157063.0,1,4,2,17.0,8,8,1,4,7,7,90208.0,812.0,56751.0,36949.0,-16655.0,131872.0,4925.0,13592.0,305.0,3415.0,5510.0,11661.0,0,2025-04-04 15:38:18.388616934
97,43552,55968.0,1,3,3,42.0,1,7,2,3,7,0,192931.0,4935.0,81068.0,55208.0,138627.0,55194.0,15084.0,6727.0,10345.0,19466.0,6511.0,1195.0,0,2025-04-04 15:45:58.539260652
98,43553,247942.0,2,3,3,25.0,8,7,8,2,6,2,98777.0,56354.0,43757.0,25849.0,-34255.0,-1745.0,29691.0,6634.0,8003.0,32033.0,14128.0,10004.0,1,2025-04-04 15:35:09.449300281
99,43554,9141.0,2,1,1,27.0,4,4,0,2,2,3,31020.0,-41799.0,-143495.0,93446.0,70034.0,124695.0,18955.0,11521.0,13996.0,22022.0,2369.0,3449.0,0,2025-04-04 15:32:14.135841129


In [0]:
# List every generated Id (strings) to check they are consecutive
list(synthetic_df.Id)

['43455',
 '43456',
 '43457',
 '43458',
 '43459',
 '43460',
 '43461',
 '43462',
 '43463',
 '43464',
 '43465',
 '43466',
 '43467',
 '43468',
 '43469',
 '43470',
 '43471',
 '43472',
 '43473',
 '43474',
 '43475',
 '43476',
 '43477',
 '43478',
 '43479',
 '43480',
 '43481',
 '43482',
 '43483',
 '43484',
 '43485',
 '43486',
 '43487',
 '43488',
 '43489',
 '43490',
 '43491',
 '43492',
 '43493',
 '43494',
 '43495',
 '43496',
 '43497',
 '43498',
 '43499',
 '43500',
 '43501',
 '43502',
 '43503',
 '43504',
 '43505',
 '43506',
 '43507',
 '43508',
 '43509',
 '43510',
 '43511',
 '43512',
 '43513',
 '43514',
 '43515',
 '43516',
 '43517',
 '43518',
 '43519',
 '43520',
 '43521',
 '43522',
 '43523',
 '43524',
 '43525',
 '43526',
 '43527',
 '43528',
 '43529',
 '43530',
 '43531',
 '43532',
 '43533',
 '43534',
 '43535',
 '43536',
 '43537',
 '43538',
 '43539',
 '43540',
 '43541',
 '43542',
 '43543',
 '43544',
 '43545',
 '43546',
 '43547',
 '43548',
 '43549',
 '43550',
 '43551',
 '43552',
 '43553',
 '43554']

In [0]:
# Value range of Bill_amt2 in the original (train + test) data
combined_set.Bill_amt2.min(), combined_set.Bill_amt2.max()

(-139883.0, 983931.0)

In [0]:
# Value range of Bill_amt2 in the synthetic data, for comparison with the original
# range shown above; values falling outside the original range would indicate data drift
synthetic_df.Bill_amt2.min(), synthetic_df.Bill_amt2.max()

(-90425.0, 191342.0)

In [0]:
# Check column dtypes and non-null counts before converting to a Spark DataFrame
synthetic_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 100 entries, 0 to 99
Data columns (total 26 columns):
 #   Column                Non-Null Count  Dtype         
---  ------                --------------  -----         
 0   Id                    100 non-null    object        
 1   Limit_bal             100 non-null    float64       
 2   Sex                   100 non-null    int32         
 3   Education             100 non-null    int32         
 4   Marriage              100 non-null    int32         
 5   Age                   100 non-null    float64       
 6   Pay_0                 100 non-null    int32         
 7   Pay_2                 100 non-null    int32         
 8   Pay_3                 100 non-null    int32         
 9   Pay_4                 100 non-null    int32         
 10  Pay_5                 100 non-null    int32         
 11  Pay_6                 100 non-null    int32         
 12  Bill_amt1             100 non-null    float64       
 13  Bill_amt2            

In [0]:
# Create source_data table with the same schema as train_set
train_set_schema = spark.table(f"{catalog_name}.{schema_name}.train_set").schema

# Create an empty DataFrame with the same schema
empty_source_data_df = spark.createDataFrame(data=[], schema=train_set_schema)

# Create an empty source_data table
# NOTE(review): mode("overwrite") replaces whatever source_data already contains, so
# re-running this cell presumably discards rows appended by earlier runs — confirm
# that resetting the table on every run is intended
empty_source_data_df.write.mode("overwrite").saveAsTable(f"{catalog_name}.{schema_name}.source_data")

print(f"Empty table '{catalog_name}.{schema_name}.source_data' created successfully.")

Empty table 'credit.default.source_data' created successfully.


In [0]:
# Read the schema of the source_data table
existing_schema = spark.table(f"{catalog_name}.{schema_name}.source_data").schema

# Convert the synthetic pandas frame to a Spark DataFrame using that schema
# NOTE(review): presumably relies on synthetic_df's column order and dtypes lining up
# with the table schema — confirm against train_set
synthetic_spark_df = spark.createDataFrame(synthetic_df, schema=existing_schema)

# Append synthetic data as new data to source_data table
synthetic_spark_df.write.mode("append").saveAsTable(f"{catalog_name}.{schema_name}.source_data")