In [1]:
from format_data import (CATEGORICAL_VARIABLES, DATE_VARIABLES, ID_VARIABLES,
                         TARGET_VARIABLE)
from transform_aggregated import SUMMARY_FEATURE_CATEGORICAL_VARIABLES, WINDOW_FEATURE_CATEGORICAL_VARIABLES
from spark_utils import get_spark_session

spark = get_spark_session()

# Aggregated feature tables for the test and train splits.
# NOTE(review): the previous note named transform_latest.py as the producer,
# but these paths end in `_aggregated` -- presumably transform_aggregated.py
# writes them; confirm which script to run if they are missing.
test_data = spark.read.parquet(
    'data_transformed/amex-default-prediction/test_data_aggregated'
)
train_data = spark.read.parquet(
    'data_transformed/amex-default-prediction/train_data_aggregated'
)

# Label and sample-submission tables; run format_data.py if these don't exist.
train_labels = spark.read.parquet(
    'data/amex-default-prediction/train_labels'
)
sample_submission = spark.read.parquet(
    'data/amex-default-prediction/sample_submission'
)

# Attach the labels to the training features by ID.
train_data_labelled = train_data.join(
    train_labels,
    on=ID_VARIABLES,
    how='inner',
)

# The inner join must keep the row count unchanged ...
labelled_row_count = train_data_labelled.count()
assert labelled_row_count == train_data.count()

# ... and must keep the same number of distinct IDs.
labelled_id_count = train_data_labelled.select(ID_VARIABLES).distinct().count()
source_id_count = train_data.select(ID_VARIABLES).distinct().count()
assert labelled_id_count == source_id_count

Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
22/08/02 00:16:29 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
                                                                                

In [2]:
# Columns excluded from the model inputs: the target, the ID columns and the
# date columns (the keys of DATE_VARIABLES).
non_feature_columns = (
    [TARGET_VARIABLE]
    + list(ID_VARIABLES)
    + list(DATE_VARIABLES.keys())
)

# Everything else in the aggregated training table is used as a feature,
# keeping the table's own column order.
feature_columns = [
    name
    for name in train_data.columns
    if name not in non_feature_columns
]

feature_listing = ', '.join(feature_columns)
print(f'Feature columns ({len(feature_columns)}):\n' + feature_listing)

Feature columns (1296):
S_2_days_since_previous, P_2, D_39, B_1, B_2, R_1, S_3, D_41, B_3, D_42, D_43, D_44, B_4, D_45, B_5, R_2, D_46, D_47, D_48, D_49, B_6, B_7, B_8, D_50, D_51, B_9, R_3, D_52, P_3, B_10, D_53, S_5, B_11, S_6, D_54, R_4, S_7, B_12, S_8, D_55, D_56, B_13, R_5, D_58, S_9, B_14, D_59, D_60, D_61, B_15, S_11, D_62, D_63, D_64, D_65, B_16, B_17, B_18, B_19, D_66, B_20, D_68, S_12, R_6, S_13, B_21, D_69, B_22, D_70, D_71, D_72, S_15, B_23, D_73, P_4, D_74, D_75, D_76, B_24, R_7, D_77, B_25, B_26, D_78, D_79, R_8, R_9, S_16, D_80, R_10, R_11, B_27, D_81, D_82, S_17, R_12, B_28, R_13, D_83, R_14, R_15, D_84, R_16, B_29, B_30, S_18, D_86, D_87, R_17, R_18, D_88, B_31, S_19, R_19, B_32, S_20, R_20, R_21, B_33, D_89, R_22, R_23, D_91, D_92, D_93, D_94, R_24, R_25, D_96, S_22, S_23, S_24, S_25, S_26, D_102, D_103, D_104, D_105, D_106, D_107, B_36, B_37, R_26, R_27, B_38, D_108, D_109, D_110, D_111, B_39, D_112, B_40, S_27, D_113, D_114, D_115, D_116, D_117, D_118, D_119, D_120,

In [3]:
from getxy import GetXY

# Categorical columns: the raw categoricals plus the categorical columns
# declared by the window and summary feature transforms.
categorical_columns = [
    *CATEGORICAL_VARIABLES,
    *WINDOW_FEATURE_CATEGORICAL_VARIABLES,
    *SUMMARY_FEATURE_CATEGORICAL_VARIABLES,
]

xy_extractor = GetXY(
    spark=spark,
    feature_columns=feature_columns,
    categorical_columns=categorical_columns,
    target_column=TARGET_VARIABLE,
)

# NOTE(review): fitted on `train_data` rather than `train_data_labelled`;
# confirm GetXY.fit does not need the target column to be present.
getxy = xy_extractor.fit(train_data)


In [4]:
# Rough batch-size heuristic.
# The "latest" training table is used as a reference shape (presumably a size
# that is known to fit in memory in one piece -- confirm). Its row count is
# scaled by the ratio of its column count to the number of feature columns,
# so a batch of the wider aggregated table holds roughly as many cells.
known_good_df = spark.read.parquet('data_transformed/amex-default-prediction/train_data_latest')

# count() is a Spark action (it runs a job), so evaluate it once here and
# reuse the stored shape below instead of counting a second time.
known_good_shape = (known_good_df.count(), len(known_good_df.columns))
target_shape = (train_data.count(), len(feature_columns))

# Same arithmetic as before (rows * (reference columns / feature columns));
# the result is a float and is passed on unchanged.
batch_size = known_good_shape[0] * (known_good_shape[1] / target_shape[1])
print(batch_size)

67633.01157407407


In [5]:
from batched import Batched, BatchedLGBMClassifier

# Split the labelled training data and the test data into batches, using a
# separate Batched instance for each table.
train_batcher = Batched(batch_size=batch_size)
train_data_labelled_batches = train_batcher.fit_transform(train_data_labelled)

test_batcher = Batched(batch_size=batch_size)
test_data_batches = test_batcher.fit_transform(test_data)

# Fit over the training batches with empty LightGBM parameters, then predict
# over the test batches.
classifier = BatchedLGBMClassifier(lgb_params={}, getxy=getxy)
m = classifier.fit(dfs=train_data_labelled_batches)
pred_df = m.predict(dfs=test_data_batches, id_variables=ID_VARIABLES)

# The predictions must contain exactly one row per test row.
predicted_row_count = len(pred_df)
assert predicted_row_count == test_data.count()

pred_df


22/08/02 00:17:04 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
                                                                                

Fitting 0/7 with 65626 rows


                                                                                

Fitting 1/7 with 65553 rows


                                                                                

Fitting 2/7 with 65544 rows


                                                                                

Fitting 3/7 with 65447 rows


                                                                                

Fitting 4/7 with 65658 rows


                                                                                

Fitting 5/7 with 65668 rows


                                                                                

Fitting 6/7 with 65417 rows


                                                                                

Predicting 0/14 with 66251 rows


                                                                                

Predicting 1/14 with 66168 rows


                                                                                

Predicting 2/14 with 66128 rows


                                                                                

Predicting 3/14 with 65895 rows


                                                                                

Predicting 4/14 with 66274 rows


                                                                                

Predicting 5/14 with 65483 rows


                                                                                

Predicting 6/14 with 66154 rows


                                                                                

Predicting 7/14 with 65866 rows


                                                                                

Predicting 8/14 with 66242 rows


                                                                                

Predicting 9/14 with 66222 rows


                                                                                

Predicting 10/14 with 66073 rows


                                                                                

Predicting 11/14 with 66024 rows


                                                                                

Predicting 12/14 with 65759 rows


                                                                                

Predicting 13/14 with 66082 rows


                                                                                

Unnamed: 0,customer_ID,prediction
0,13ae6d2445b57e6450eb92d7d94836552d67dbdff6d4f8...,9.908358e-01
1,3ddb8e3505eb72e98e93c6cae6bda2d1dea4edec761da4...,6.252160e-01
2,84d568ba2f5702b882b6d41bb25e52b11c5f1e5c7353ba...,8.791473e-03
3,e8931cde8d349326c11d0b2b4a9be6c970e7ffb8faf550...,2.972499e-03
4,8ab2264ae98c13904380300fb2b2d75f8612189f8948e8...,0.000000e+00
...,...,...
66077,174487e698a6d249ea82b4ab49bd31fc27c8a82912f48d...,0.000000e+00
66078,7476ec92e97278969092d74a4598f4fc92bdffeb5b5835...,2.380968e-04
66079,a31c6de090acc8a6a67251c2e7a9886f01d963b3f60eb0...,1.455082e-03
66080,50fe2dee1a3480a9f375f481f9a8014c4acd9697876cfc...,2.058596e-11


In [6]:
# Stop the Spark session that was created at the top of the notebook.
spark.stop()