In [0]:
from pyspark.sql import functions as F

# Draw a sample of roughly TARGET_SAMPLE_ROWS rows from the fraud ABT and
# persist it as vr_demo.fraud.fraud_abt_sample.
#
# NOTE: the sample is *proportional* (it keeps the table's fraud / non-fraud
# ratio), not a 50/50 split, and DataFrame.sample() is probabilistic, so the
# resulting row count is only approximately TARGET_SAMPLE_ROWS.
TARGET_SAMPLE_ROWS = 10000
SAMPLE_SEED = 42  # fixed seed so the draw is repeatable for unchanged input

# Read the table
df_fraud_abt = spark.table("vr_demo.fraud.fraud_abt")

# Split by label once; the same frames are used for counting and sampling.
# Rows whose fraud_report is neither 'Y' nor 'N' are left out of the sample.
fraud_rows = df_fraud_abt.filter(F.col('fraud_report') == 'Y')
non_fraud_rows = df_fraud_abt.filter(F.col('fraud_report') == 'N')

fraud_count = fraud_rows.count()
non_fraud_count = non_fraud_rows.count()
labelled_count = fraud_count + non_fraud_count
if labelled_count == 0:
    raise ValueError("vr_demo.fraud.fraud_abt has no rows with fraud_report in ('Y', 'N')")

# Per-class row targets, proportional to each class's share of the table.
fraud_sample_size = int(TARGET_SAMPLE_ROWS * (fraud_count / labelled_count))
non_fraud_sample_size = TARGET_SAMPLE_ROWS - fraud_sample_size

# Sampling without replacement needs a fraction in [0, 1]: clamp it for tables
# smaller than the target, and use 0.0 for an empty class instead of dividing by zero.
fraud_fraction = min(1.0, fraud_sample_size / fraud_count) if fraud_count else 0.0
non_fraud_fraction = min(1.0, non_fraud_sample_size / non_fraud_count) if non_fraud_count else 0.0

fraud_sample = fraud_rows.sample(withReplacement=False, fraction=fraud_fraction, seed=SAMPLE_SEED)
non_fraud_sample = non_fraud_rows.sample(withReplacement=False, fraction=non_fraud_fraction, seed=SAMPLE_SEED)

balanced_sample = fraud_sample.union(non_fraud_sample)

balanced_sample.write.mode('overwrite').saveAsTable("vr_demo.fraud.fraud_abt_sample")

In [0]:
%sql
-- Copy of fraud_abt_sample with the fraud_report label re-encoded as an integer:
-- 'Y' becomes 1; every other value (including NULL) becomes 0.
-- "or replace" keeps the cell re-runnable after the upstream sample is rewritten.
create or replace table vr_demo.fraud.fraud_abt_sample2 as
select
  * except (fraud_report),
  case when fraud_report = 'Y' then 1 else 0 end as fraud_report
from vr_demo.fraud.fraud_abt_sample

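# V2: sampled copies of the visits, customers and locations tables in `vr_demo.fraud2`, exported as Parquet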

In [0]:
%sql
-- Target schema for the sampled V2 tables; "if not exists" makes the cell safe to re-run.
create database if not exists vr_demo.fraud2

In [0]:
from pyspark.sql import functions as F

# Draw a sample of roughly TARGET_SAMPLE_ROWS rows from the visits table and
# persist it as vr_demo.fraud2.visits.
#
# NOTE: the sample is *proportional* (it keeps the table's fraud / non-fraud
# ratio), not a 50/50 split, and DataFrame.sample() is probabilistic, so the
# resulting row count is only approximately TARGET_SAMPLE_ROWS.
TARGET_SAMPLE_ROWS = 10000
SAMPLE_SEED = 42  # fixed seed so the draw is repeatable for unchanged input

# Read the table (the variable name is kept from the earlier cell; here it holds visits).
df_fraud_abt = spark.table("vr_demo.fraud.visits")

# Split by label once; the same frames are used for counting and sampling.
# Rows whose fraud_report is neither 'Y' nor 'N' are left out of the sample.
fraud_rows = df_fraud_abt.filter(F.col('fraud_report') == 'Y')
non_fraud_rows = df_fraud_abt.filter(F.col('fraud_report') == 'N')

fraud_count = fraud_rows.count()
non_fraud_count = non_fraud_rows.count()
labelled_count = fraud_count + non_fraud_count
if labelled_count == 0:
    raise ValueError("vr_demo.fraud.visits has no rows with fraud_report in ('Y', 'N')")

# Per-class row targets, proportional to each class's share of the table.
fraud_sample_size = int(TARGET_SAMPLE_ROWS * (fraud_count / labelled_count))
non_fraud_sample_size = TARGET_SAMPLE_ROWS - fraud_sample_size

# Sampling without replacement needs a fraction in [0, 1]: clamp it for tables
# smaller than the target, and use 0.0 for an empty class instead of dividing by zero.
fraud_fraction = min(1.0, fraud_sample_size / fraud_count) if fraud_count else 0.0
non_fraud_fraction = min(1.0, non_fraud_sample_size / non_fraud_count) if non_fraud_count else 0.0

fraud_sample = fraud_rows.sample(withReplacement=False, fraction=fraud_fraction, seed=SAMPLE_SEED)
non_fraud_sample = non_fraud_rows.sample(withReplacement=False, fraction=non_fraud_fraction, seed=SAMPLE_SEED)

balanced_sample = fraud_sample.union(non_fraud_sample)

balanced_sample.write.mode('overwrite').saveAsTable("vr_demo.fraud2.visits")

In [0]:
%sql
-- Keep only the customers that appear in the sampled visits.
-- A semi join returns each customer row at most once; an inner join would
-- repeat the customer for every matching visit in the sample.
-- "or replace" keeps the cell re-runnable when the visits sample is regenerated.
create or replace table vr_demo.fraud2.customers as
select c.* from vr_demo.fraud.customers c
left semi join vr_demo.fraud2.visits v
on c.customer_id = v.customer_id

In [0]:
%sql
-- Keep only the locations (ATMs) that appear in the sampled visits.
-- A semi join returns each location row at most once; an inner join would
-- repeat the location for every matching visit in the sample.
-- "or replace" keeps the cell re-runnable when the visits sample is regenerated.
create or replace table vr_demo.fraud2.locations as
select loc.* from vr_demo.fraud.locations loc
left semi join vr_demo.fraud2.visits v
on loc.atm_id = v.atm_id

In [0]:
%sql
-- Volume that receives the Parquet exports; "if not exists" makes the cell safe to re-run.
create volume if not exists vr_demo.fraud2.export

In [0]:
# Export the sampled visits as Parquet; repartition(1) funnels the data into one partition before writing.
# mode('overwrite') replaces a previous export instead of failing because the path already exists.
spark.table('vr_demo.fraud2.visits').repartition(1).write.format('parquet').mode('overwrite').save('/Volumes/vr_demo/fraud2/export/visits')

In [0]:
# Export the filtered customers as Parquet; repartition(1) funnels the data into one partition before writing.
# mode('overwrite') replaces a previous export instead of failing because the path already exists.
spark.table('vr_demo.fraud2.customers').repartition(1).write.format('parquet').mode('overwrite').save('/Volumes/vr_demo/fraud2/export/customers')

In [0]:
# Export the filtered locations as Parquet; repartition(1) funnels the data into one partition before writing.
# mode('overwrite') replaces a previous export instead of failing because the path already exists.
spark.table('vr_demo.fraud2.locations').repartition(1).write.format('parquet').mode('overwrite').save('/Volumes/vr_demo/fraud2/export/locations')