In [1]:
from pyspark.sql.types import *
from pyspark.sql.functions import *
import pandas as pd
from pyspark.mllib.stat import Statistics


In [2]:
# Explicit schema for the user-log CSV: a string user id and date, five
# integer counters, and a double for seconds played. Every field is declared
# non-nullable.
user_log_schema = (
    StructType()
    .add("msno", StringType(), nullable=False)
    .add("date", StringType(), nullable=False)
    .add("num_25", IntegerType(), nullable=False)
    .add("num_50", IntegerType(), nullable=False)
    .add("num_985", IntegerType(), nullable=False)
    .add("num_100", IntegerType(), nullable=False)
    .add("num_unq", IntegerType(), nullable=False)
    .add("total_secs", DoubleType(), nullable=False)
)

# All four inputs are comma-separated files with a header row.
csv_options = dict(header=True, sep=",")

# Needed for the target variable.
train = spark.read.csv("/FileStore/tables/train_v2.csv", **csv_options)
# Member information.
members = spark.read.csv("/FileStore/tables/members_v3.csv", **csv_options)
# Transaction information.
transactions = spark.read.csv("/FileStore/tables/transactions_v2.csv", **csv_options)
# User listening logs; the only input read with an explicit schema.
user_logs = spark.read.csv(
    "/FileStore/tables/user_logs_v2.csv",
    schema=user_log_schema,
    **csv_options
)

In [3]:
# Per-user aggregation of the listening logs.
# `count`, `sum` and `mean` here are the Spark column functions brought in by
# the star import from pyspark.sql.functions, not the Python builtins.
per_user_metrics = [
    count("msno").alias('no_transactions'),   # number of log rows per user
    sum('num_25').alias('Total25'),
    sum('num_100').alias('Total100'),
    mean('num_unq').alias('UniqueSongs'),     # average, not a total
    mean('total_secs').alias('TotalSecHeard'),  # average, despite the "Total" name
]
agg_user_logs = user_logs.groupBy('msno').agg(*per_user_metrics)

agg_user_logs.show()

In [4]:

from pyspark.sql import functions

# Join the per-user log aggregates with transactions, the train table and the
# member table. Joining on the column name "msno" performs an inner join and
# keeps a single "msno" column.
# NOTE(review): if `transactions` holds several rows per msno, each user is
# repeated once per transaction in the result -- confirm this is intended.
z = (
    agg_user_logs
    .join(transactions, "msno")
    .join(train, "msno")
    .join(members, "msno")
)

# Data format conversion.
# transactions/train/members were read without a schema, so their numeric
# columns have to be cast explicitly.
integer_columns = [
    "payment_method_id",
    "payment_plan_days",
    "plan_list_price",
    "actual_amount_paid",
    "is_auto_renew",
    "is_cancel",
    "is_churn",
    "bd",
    "registered_via",
]
for column_name in integer_columns:
    z = z.withColumn(column_name, col(column_name).cast("integer"))

# Re-render the yyyyMMdd date strings as yyyy-MM-dd strings by parsing them to
# epoch seconds and formatting the resulting timestamp.
date_columns = [
    "transaction_date",
    "membership_expire_date",
    "registration_init_time",
]
for column_name in date_columns:
    epoch_seconds = unix_timestamp(column_name, 'yyyyMMdd')
    z = z.withColumn(
        column_name,
        date_format(from_unixtime(epoch_seconds), 'yyyy-MM-dd'),
    )

# Days between registration and membership expiry.
z = z.withColumn(
    "DaysOnBoard",
    datediff(col("membership_expire_date"), col("registration_init_time")),
)

z.show()

In [5]:
# Drop the columns that are not needed downstream, then remove age outliers
# by keeping only rows whose `bd` lies between 15 and 100 (inclusive, per SQL
# BETWEEN). Rows with a null `bd` are removed by the same filter.
columns_to_drop = ['membership_expire_date', 'registration_init_time']
final_data = (
    z
    .drop(*columns_to_drop)
    .where("bd between 15 and 100")
)
final_data.show()

In [6]:
# A significant number of users did not want to reveal their gender; show the
# row count per gender value (including null) to see how many.
(
    final_data
    .groupBy("gender")
    .count()
    .show()
)


# Optional: drop rows with no gender. Note that a "gender != null" comparison
# never evaluates to true in Spark SQL; use "gender is not null" instead.
#final_data = final_data.filter("gender is not null")
#final_data.groupBy("gender").count().show()

In [7]:
# Pearson correlation matrix over the numeric feature columns.
df = final_data
feature_names = [
    "Total25",
    "Total100",
    "UniqueSongs",
    "TotalSecHeard",
    "payment_plan_days",
    "plan_list_price",
    "actual_amount_paid",
    "DaysOnBoard",
    "bd",
]
features = df.select(*feature_names)

# Statistics.corr works on an RDD of numeric vectors, so turn each Row into a
# plain tuple of its values (a full slice of a Row yields a tuple).
features1 = features.rdd.map(lambda row: row[0:])
corr_mat = Statistics.corr(features1, method="pearson")

# Wrap the returned matrix in a pandas DataFrame for labelling and display.
corr_df = pd.DataFrame(corr_mat)


In [8]:
# Label both axes of the correlation matrix with the feature names, in the
# same order the features were selected.
corr_labels = (
    "Total25",
    "Total100",
    "UniqueSongs",
    "TotalSecHeard",
    "payment_plan_days",
    "plan_list_price",
    "actual_amount_paid",
    "DaysOnBoard",
    "bd",
)
corr_df.index = corr_labels
corr_df.columns = corr_labels

In [9]:
# Render the whole correlation matrix as plain text and print it.
corr_text = corr_df.to_string()
print(corr_text)

In [10]:
from pyspark.sql import Row

# Small hand-built table of churn counts by gender, exposed to SQL as the
# temporary view "test_sales_table".
#
# Every Row must carry the same four fields: the DataFrame schema is inferred
# from the rows, so a row with a missing field (the first row previously
# lacked `category`) no longer matches the others.
# NOTE(review): (female, is_churn=1) appears twice below; the last row may
# have been meant as is_churn=0 -- confirm against the source figures.
churn_rows = [
    Row(category="gender", product="male", is_churn=1, Count=58),
    Row(category="gender", product="female", is_churn=1, Count=9),
    Row(category="gender", product="male", is_churn=0, Count=87),
    Row(category="gender", product="female", is_churn=1, Count=91),
]

# createDataFrame accepts a local list directly, so no explicit
# sc.parallelize step is needed.
fdataframes = spark.createDataFrame(churn_rows)

# createOrReplaceTempView is the non-deprecated replacement for
# registerTempTable; like it, it replaces any existing view of the same name.
fdataframes.createOrReplaceTempView("test_sales_table")
display(spark.sql("select * from test_sales_table"))


Count,category,is_churn,product
58,gender,1,male
9,gender,1,female
87,gender,0,male
91,gender,1,female
