In [None]:
# Fetch the quickstart repository and move into its customer-churn-prediction folder, so
# later cells can refer to the parquet file by a relative path (presumably the repo ships
# raw_telco_data.parquet there — confirm against the repo layout).
# NOTE(review): not idempotent — on a re-run the clone reports that the directory already
# exists, and the relative %cd fails if the kernel is already inside the folder.
!git clone https://github.com/Snowflake-Labs/sfguide-getting-started-snowpark-python
%cd sfguide-getting-started-snowpark-python/customer-churn-prediction

In [None]:
!pip install jupyter==1.0.0
!pip install numpy
!pip install "snowflake-snowpark-python[pandas]"

In [4]:
%%writefile config.py
snowflake_conn_prop = {
   "account": "",
   "user": "",
   "password": "",
   "role": "ACCOUNTADMIN",
   "database": "snowpark_quickstart",
   "schema": "TELCO",
   "warehouse": "sp_qs_wh",
}

Overwriting config.py


## Load Data Using Snowpark Python Client API

- Establish the Snowpark Python session
- Create the database, schema, and warehouse
- Load raw parquet data into Snowflake using Snowpark Python

In [None]:
from snowflake.snowpark.session import Session
from snowflake.snowpark import functions as F
from snowflake.snowpark.types import *

import pandas as pd

from sklearn import linear_model

import matplotlib.pyplot as plt

%matplotlib inline
import datetime as dt
import numpy as np
import seaborn as sns

#Snowflake connection info is saved in config.py
from config import snowflake_conn_prop

# lets import some tranformations functions
from snowflake.snowpark.functions import udf, col, lit, translate, is_null, iff

In [6]:
# Open a Snowpark session, make sure the database, schema and warehouse named in config.py
# exist, and make them the session's current context.
session = Session.builder.configs(snowflake_conn_prop).create()

# session.sql() only builds a lazy DataFrame: nothing is sent to Snowflake until an action
# such as .collect() runs, so every statement below must be collected to take effect.
session.sql("use role accountadmin").collect()
session.sql(f"create database if not exists {snowflake_conn_prop['database']}").collect()
session.sql(f"use database {snowflake_conn_prop['database']}").collect()
session.sql(f"create schema if not exists {snowflake_conn_prop['schema']}").collect()
session.sql(f"use schema {snowflake_conn_prop['schema']}").collect()
# NOTE: "create or replace" drops and recreates the warehouse if it already exists.
session.sql(f"create or replace warehouse {snowflake_conn_prop['warehouse']} with \
                WAREHOUSE_SIZE = XSMALL \
                AUTO_SUSPEND = 120 \
                AUTO_RESUME = TRUE").collect()
# Collected as well — without .collect() this USE WAREHOUSE statement would never execute.
session.sql(f"use warehouse {snowflake_conn_prop['warehouse']}").collect()
print(session.sql('select current_warehouse(), current_database(), current_schema()').collect())

[Row(CURRENT_WAREHOUSE()='SP_QS_WH', CURRENT_DATABASE()='SNOWPARK_QUICKSTART', CURRENT_SCHEMA()='TELCO')]


We are using a sample telecommunications dataset: a parquet file of ~100k rows that is bundled with the notebook. Using the Snowpark session object, we upload that file to an internal stage and then load it into a table in our database, as shown below.

In [7]:
# Names shared by the staging and loading cells below.
filename, stagename, rawtable = (
    "raw_telco_data.parquet",  # local parquet file that gets uploaded
    "rawdata",                 # stage that receives the upload
    "RAW_PARQUET_DATA",        # table the parquet rows are loaded into
)

In [8]:
# Recreate the internal stage (with a directory table enabled), then upload the local
# parquet file into it. The put() result is what the cell displays.
stage_ddl = f"create or replace stage {stagename} DIRECTORY = (ENABLE = TRUE);"
session.sql(stage_ddl).collect()
session.file.put(filename, stagename)

[PutResult(source='raw_telco_data.parquet', target='raw_telco_data.parquet', source_size=3037540, target_size=3037552, source_compression='PARQUET', target_compression='PARQUET', status='UPLOADED', message='')]

In [9]:
# Register a parquet file format, then create the raw table with the column definitions
# that INFER_SCHEMA derives from the staged file (CREATE TABLE ... USING TEMPLATE).
# This only defines the columns; the rows are loaded by a later cell.
file_format_ddl = "CREATE OR REPLACE FILE FORMAT MY_PARQUET_FORMAT TYPE = PARQUET;"
table_from_template_ddl = f"CREATE OR REPLACE \
            TABLE {rawtable} USING TEMPLATE ( \
                SELECT ARRAY_AGG(OBJECT_CONSTRUCT(*)) \
                FROM \
                    TABLE( INFER_SCHEMA( \
                    LOCATION => '@{stagename}/{filename}', \
                    FILE_FORMAT => 'MY_PARQUET_FORMAT' \
                    ) \
                ) \
            );  "

session.sql(file_format_ddl).collect()
session.sql(table_from_template_ddl).collect()

[Row(status='Table RAW_PARQUET_DATA successfully created.')]

In [10]:
# Remove any rows already present in the raw table so the following COPY starts from empty.
# Table.delete() with no condition deletes every row; dfClear holds delete()'s result
# object, not a DataFrame.
raw_table_ref = session.table(rawtable)
dfClear = raw_table_ref.delete()

In [11]:
# Read the staged parquet file (snappy compression) and bulk-load it into the raw table.
# FORCE=True makes COPY load the file even if it was already loaded before.
staged_parquet_path = f"@{stagename}/{filename}"
dfRaw = (
    session.read
    .option("compression", "snappy")
    .parquet(staged_parquet_path)
)
dfRaw.copy_into_table(rawtable, FORCE=True)

[Row(file='rawdata/raw_telco_data.parquet', status='LOADED', rows_parsed=100000, rows_loaded=100000, error_limit=1, errors_seen=0, first_error=None, first_error_line=None, first_error_character=None, first_error_column_name=None)]

Let’s preview a few random rows of our new table:

In [12]:
# Pull a random sample of 5 rows to the client as a pandas DataFrame for a quick look.
dfR = session.table(rawtable).sample(n=5)
dfR.to_pandas()

Unnamed: 0,COUNTRY,CITY,PHONE SERVICE,MULTIPLE LINES,LATITUDE,ONLINE SECURITY,SENIOR CITIZEN,MONTHLY CHARGES,STREAMING MOVIES,PAYMENT METHOD,...,CHURN SCORE,GENDER,LONGITUDE,ONLINE BACKUP,TOTAL CHARGES,CLTV,CHURN REASON,DEVICE PROTECTION,STATE,ZIP CODE
0,United States,Wendel,Yes,No,40.345949,No,False,70.75,Yes,Credit card (automatic),...,1,Female,-120.081187,No,450.8,2248,Attitude of support person,No,California,96136
1,United States,Wilton,Yes,Yes,38.392559,Yes,False,59.45,No,Credit card (automatic),...,0,Female,-121.225093,Yes,2136.9,3682,do not know,No,California,95693
2,United States,Sylmar,Yes,No,34.321621,Yes,True,85.7,No,Bank transfer (automatic),...,0,Female,-118.399841,Yes,3778.1,5397,do not know,No,California,91342
3,United States,Adin,Yes,No,41.171578,No,False,82.3,No,Electronic check,...,1,Female,-120.913161,Yes,82.3,4278,Competitor had better devices,Yes,California,96006
4,United States,San Jose,Yes,Yes,37.371862,Yes,False,85.4,Yes,Bank transfer (automatic),...,0,Male,-121.860349,No,3297.0,5292,do not know,Yes,California,95133


The Snowpark API provides programming-language constructs for building SQL statements. This developer experience lets us:

- **write code in the language of our choice,**
- **use the tool of our choice, and**
- **rely on lazy execution, which avoids multiple network round trips to the server.**

Once the customer data is available in the raw table (`RAW_PARQUET_DATA` in the `TELCO` schema), we can use Snowpark to create dimension and fact tables. From that raw table we will create the following tables:

- DEMOGRAPHICS
- LOCATION
- STATUS
- SERVICES

We will also transform and clean the data using the Snowpark DataFrame API.

In [14]:
# Lazily reference the full raw table; the transformation cells below select from it.
dfR = session.table(name=rawtable)

In [15]:
# DEMOGRAPHICS: customer-profile columns of each raw record, written to a table.
# Missing GENDER values (SQL NULL, or the literal string 'NULL') default to "Male".
# IFF is required here: TRANSLATE(GENDER, 'NULL', 'Male') substitutes characters one-for-one
# (N->M, U->a, L->l), so it neither fills NULLs nor replaces the word 'NULL'.
gender_missing = is_null(col("GENDER")) | (col("GENDER") == lit("NULL"))

dfDemographics = dfR.select(col("CUSTOMERID"),
                            col("COUNT"),
                            iff(gender_missing, lit("Male"), col("GENDER")).alias("GENDER"),
                            col("SENIOR CITIZEN").alias("SENIORCITIZEN"),
                            col("PARTNER"),
                            col("DEPENDENTS"))

dfDemographics.write.mode('overwrite').save_as_table('DEMOGRAPHICS')
dfDemographics.show()


----------------------------------------------------------------------------------
|"CUSTOMERID"  |"COUNT"  |"GENDER"  |"SENIORCITIZEN"  |"PARTNER"  |"DEPENDENTS"  |
----------------------------------------------------------------------------------
|7090-ZyCMx    |1        |Female    |False            |False      |True          |
|1364-wJXMS    |1        |Female    |False            |False      |True          |
|6564-sLgIC    |1        |Male      |True             |False      |True          |
|7853-2xheR    |1        |Male      |False            |False      |True          |
|8457-E9FuW    |1        |Female    |False            |False      |True          |
|5718-ykxBT    |1        |Male      |False            |False      |True          |
|7092-gCJX5    |1        |Male      |False            |False      |False         |
|8249-GOs7s    |1        |Male      |True             |False      |False         |
|9445-kPPEc    |1        |Male      |False            |False      |False         |
|158

In [16]:
# LOCATION: geographic attributes of each raw record, written to a table.
# Missing ZIP codes default to 0. IFF(IS_NULL(...)) is used because TRANSLATE only
# substitutes characters one-for-one and can never turn a NULL into a value.
# NOTE(review): assumes "ZIP CODE" is numeric in the raw table (the sample shows integer
# values); ZIPCODE therefore keeps the source type instead of TRANSLATE's VARCHAR.
dfLocation = dfR.select(col("CUSTOMERID"),
                        col("COUNTRY").name("COUNTRY"),
                        col("STATE").name("STATE"),
                        col("CITY").name("CITY"),
                        iff(is_null(col("ZIP CODE")), lit(0), col("ZIP CODE")).name("ZIPCODE"),
                        col("LAT LONG").name("LATLONG"),
                        col("LATITUDE").name("LATITUDE"),
                        col("LONGITUDE").name("LONGITUDE"))

dfLocation.write.mode('overwrite').save_as_table('LOCATION')
dfLocation.show()

-------------------------------------------------------------------------------------------------------------------------------
|"CUSTOMERID"  |"COUNTRY"      |"STATE"     |"CITY"           |"ZIPCODE"  |"LATLONG"               |"LATITUDE"  |"LONGITUDE"  |
-------------------------------------------------------------------------------------------------------------------------------
|7090-ZyCMx    |United States  |California  |Los Angeles      |90005      |34.059281, -118.30742   |34.059281   |-118.307420  |
|1364-wJXMS    |United States  |California  |Los Angeles      |90006      |34.048013, -118.293953  |34.048013   |-118.293953  |
|6564-sLgIC    |United States  |California  |Los Angeles      |90065      |34.108833, -118.229715  |34.108833   |-118.229715  |
|7853-2xheR    |United States  |California  |La Habra         |90631      |33.940619, -117.9513    |33.940619   |-117.951300  |
|8457-E9FuW    |United States  |California  |Glendale         |91206      |34.162515, -118.203869  |34.1

In [17]:
def fill_null(source: str, default, alias: str):
    """Return column `source` with SQL NULLs replaced by the literal `default`, renamed to `alias`."""
    return iff(is_null(col(source)), lit(default), col(source)).name(alias)


# SERVICES: subscribed services, contract/billing terms and charges of each raw record.
# Yes/No service flags all default to "No" when missing, so no stray 'N' category appears
# next to the "Yes"/"No" values present in the data.
# NOTE(review): PAPERLESS BILLING keeps its 'Y' default — confirm the column's type/values.
dfServices = dfR.select(
    col("CUSTOMERID"),
    col("TENURE MONTHS").name("TENUREMONTHS"),
    fill_null("PHONE SERVICE", "No", "PHONESERVICE"),
    fill_null("MULTIPLE LINES", "No", "MULTIPLELINES"),
    fill_null("INTERNET SERVICE", "No", "INTERNETSERVICE"),
    fill_null("ONLINE SECURITY", "No", "ONLINESECURITY"),
    fill_null("ONLINE BACKUP", "No", "ONLINEBACKUP"),
    fill_null("DEVICE PROTECTION", "No", "DEVICEPROTECTION"),
    fill_null("TECH SUPPORT", "No", "TECHSUPPORT"),
    fill_null("STREAMING TV", "No", "STREAMINGTV"),
    fill_null("STREAMING MOVIES", "No", "STREAMINGMOVIES"),
    fill_null("CONTRACT", "Month-to-month", "CONTRACT"),
    fill_null("PAPERLESS BILLING", "Y", "PAPERLESSBILLING"),
    fill_null("PAYMENT METHOD", "Mailed check", "PAYMENTMETHOD"),
    col("MONTHLY CHARGES").name("MONTHLYCHARGES"),
    col("TOTAL CHARGES").name("TOTALCHARGES"),
    col("CHURN VALUE").name("CHURNVALUE"),
)

dfServices.write.mode('overwrite').save_as_table('SERVICES')
dfServices.show()

----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|"CUSTOMERID"  |"TENUREMONTHS"  |"PHONESERVICE"  |"MULTIPLELINES"  |"INTERNETSERVICE"  |"ONLINESECURITY"     |"ONLINEBACKUP"       |"DEVICEPROTECTION"   |"TECHSUPPORT"        |"STREAMINGTV"        |"STREAMINGMOVIES"    |"CONTRACT"      |"PAPERLESSBILLING"  |"PAYMENTMETHOD"   |"MONTHLYCHARGES"  |"TOTALCHARGES"  |"CHURNVALUE"  |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|7090-ZyCMx  

In [18]:
# STATUS: churn outcome and customer-value columns of each raw record, written to a table.
# A missing CHURN LABEL defaults to 'N'; a missing CHURN REASON defaults to "do not know".
# NOTE(review): CHURNLABEL prints as true/false, so the 'N' literal is presumably coerced
# to a boolean — confirm the column type.
churn_label = iff(is_null(col("CHURN LABEL")), lit('N'), col("CHURN LABEL"))
churn_reason = iff(is_null(col("CHURN REASON")), lit("do not know"), col("CHURN REASON"))

dfStatus = dfR.select(
    col("CUSTOMERID"),
    churn_label.name("CHURNLABEL"),
    col("CHURN VALUE").name("CHURNVALUE"),
    col("CHURN SCORE").name("CHURNSCORE"),
    col("CLTV").name("CLTV"),
    churn_reason.name("CHURNREASON"),
)

dfStatus.write.mode('overwrite').save_as_table('STATUS')
dfStatus.show()

-----------------------------------------------------------------------------------------------------------------
|"CUSTOMERID"  |"CHURNLABEL"  |"CHURNVALUE"  |"CHURNSCORE"  |"CLTV"  |"CHURNREASON"                             |
-----------------------------------------------------------------------------------------------------------------
|7090-ZyCMx    |true          |1.0           |1             |2701    |Moved                                     |
|1364-wJXMS    |true          |1.0           |1             |5372    |Moved                                     |
|6564-sLgIC    |true          |1.0           |1             |3179    |Competitor made better offer              |
|7853-2xheR    |true          |1.0           |1             |4415    |Product dissatisfaction                   |
|8457-E9FuW    |true          |1.0           |1             |5142    |Price too high                            |
|5718-ykxBT    |true          |1.0           |1             |2484    |Poor expertise of 

In [19]:
# Quick sanity check: total revenue (sum of TOTALCHARGES) per city and contract term,
# computed by inner-joining LOCATION and SERVICES on CUSTOMERID.

dfLoc = session.table("LOCATION")
dfServ = session.table("SERVICES")

same_customer = dfLoc["CUSTOMERID"] == dfServ["CUSTOMERID"]
dfJoin = dfLoc.join(dfServ, same_customer)

dfResult = (
    dfJoin
    .select(col("CITY"), col("CONTRACT"), col("TOTALCHARGES"))
    .group_by(col("CITY"), col("CONTRACT"))
    .sum(col("TOTALCHARGES"))
)

dfResult.show()

----------------------------------------------------------
|"CITY"           |"CONTRACT"      |"SUM(TOTALCHARGES)"  |
----------------------------------------------------------
|Los Angeles      |Month-to-month  |3931004.7            |
|La Habra         |Month-to-month  |6828.35              |
|Glendale         |Month-to-month  |460483.05            |
|Burbank          |Month-to-month  |378354.4             |
|Ontario          |Two year        |57487.6              |
|Alpine           |Month-to-month  |69186.04999999999    |
|Borrego Springs  |Month-to-month  |94737.0              |
|Oceanside        |Month-to-month  |49559.5              |
|Niland           |Month-to-month  |24946.0              |
|San Bernardino   |Month-to-month  |253583.3             |
----------------------------------------------------------



In [20]:
### Let's create a view for the data science team to begin data analysis.
### To do so, join the `DEMOGRAPHICS` and `SERVICES` tables on `CUSTOMERID`.
# Snowpark DataFrames are immutable: select() returns a NEW DataFrame. The projection must
# be kept and the view created from it; otherwise TRAIN_DATASET would expose every joined
# column (CUSTOMERID, COUNT, PHONESERVICE, ...) instead of the feature list below.

dfD = session.table('DEMOGRAPHICS')
dfS = session.table('SERVICES')
dfJ = dfD.join(dfS, using_columns='CUSTOMERID', join_type='left')
dfTrain = dfJ.select(col('GENDER'),
                     col('SENIORCITIZEN'),
                     col('PARTNER'),
                     col('DEPENDENTS'),
                     col('MULTIPLELINES'),
                     col('INTERNETSERVICE'),
                     col('ONLINESECURITY'),
                     col('ONLINEBACKUP'),
                     col('DEVICEPROTECTION'),
                     col('TECHSUPPORT'),
                     col('STREAMINGTV'),
                     col('STREAMINGMOVIES'),
                     col('CONTRACT'),
                     col('PAPERLESSBILLING'),
                     col('PAYMENTMETHOD'),
                     col('TENUREMONTHS'),
                     col('MONTHLYCHARGES'),
                     col('TOTALCHARGES'),
                     col('CHURNVALUE'))
dfTrain.create_or_replace_view('TRAIN_DATASET')

[Row(status='View TRAIN_DATASET successfully created.')]

In [21]:
%%time

raw = session.table('TRAIN_DATASET').sample(n = 20)
data = raw.toPandas()

CPU times: user 16.1 ms, sys: 8.51 ms, total: 24.6 ms
Wall time: 825 ms


In [25]:
# Clean up: step back up two levels out of the customer-churn-prediction folder entered by
# the first cell (assumes the kernel is still there), then delete the cloned repository.
# NOTE(review): the Snowpark session is not closed; consider session.close() once done.
%cd ../..
!rm -rf sfguide-getting-started-snowpark-python