# Exploring PyStarburst with Starburst Galaxy and the TPC-H dataset

## Getting started

### Sign up for a Galaxy account and set up the sample catalog

You'll need a Starburst Galaxy account (https://www.starburst.io/platform/starburst-galaxy/start/) configured with the TPC-H sample catalog (https://docs.starburst.io/starburst-galaxy/working-with-data/create-catalogs/sample-data-sets/tpch.html).

### Open the DataFrame (DF) API documentation

Open https://pystarburst.eng.starburstdata.net/ in a browser window; the cells below link to specific pages of this API documentation.

## Explore via code examples

Let's go!

In [None]:
#
# Install the library
#

%pip install pystarburst

In [None]:
#
# Define connection properties
#  get the host and other information from the cluster list
#

import getpass

host = input("Host name: ")
username = input("User name: ")
password = getpass.getpass("Password: ")

In [None]:
#
# Import dependencies
#

from pystarburst import Session
from pystarburst import functions as F
from pystarburst.functions import *
from pystarburst.window import Window as W

import trino

session_properties = {
    "host":host,
    "port": 443,
    # Needed for https secured clusters
    "http_scheme": "https",
    # Set up authentication with a username and password, or any other supported authentication method
    # See docs: https://github.com/trinodb/trino-python-client#authentication-mechanisms
    "auth": trino.auth.BasicAuthentication(username, password)
}

session = Session.builder.configs(session_properties).create()

In [None]:
#
# Validate connectivity to the cluster
#

session.sql("select 1 as b").collect()

In [None]:
#
# Ensure we have access to the TPC-H dataset by listing the tables in the tiny schema
#  https://pystarburst.eng.starburstdata.net/session.html#pystarburst.session.Session.sql
#

session.sql("show tables from tpch.tiny").collect()

In [None]:
#
# What columns make up the lineitem table
#  https://pystarburst.eng.starburstdata.net/session.html#pystarburst.session.Session.table
#  https://pystarburst.eng.starburstdata.net/dataframe.html#pystarburst.dataframe.DataFrame.schema
#

# Create a Dataframe for the lineitem table
tli = session.table("tpch.tiny.lineitem")

# Show the columns
print(tli.schema)

In [None]:
#
# That was pretty busy, let's try that again...
#  loop through the fields of the schema and print them out
#

for field in tli.schema.fields:
    print(field.name +" , "+str(field.datatype))
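
For a wide schema, aligned columns can be easier to scan than concatenated strings. The sketch below assumes only that each field exposes `name` and `datatype` attributes, as the loop above does. The helper `format_fields` and the stand-in `Field` tuple are hypothetical names used for illustration so the block runs without a cluster.

```python
from collections import namedtuple


def format_fields(fields):
    """Return one line per field with the names padded to a common width."""
    fields = list(fields)
    width = max((len(f.name) for f in fields), default=0)
    return [f"{f.name.ljust(width)}  {f.datatype}" for f in fields]


# Stand-in objects (hypothetical) that mimic the two attributes used above
Field = namedtuple("Field", ["name", "datatype"])
sample = [Field("orderkey", "LongType"), Field("comment", "StringType")]

for line in format_fields(sample):
    print(line)
```

With the notebook's DataFrame, this would be called as `format_fields(tli.schema.fields)`.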

In [None]:
#
# Show the data
#  tli is the DataFrame (DF) that we defined earlier
#  the show() command will list out up to 10 rows
#    pass it an argument for something longer
#  https://pystarburst.eng.starburstdata.net/dataframe.html#pystarburst.dataframe.DataFrame.show
#

tli.show()

In [None]:
#
# That was pretty busy, let's try that again...
#  use the select method on an existing DF identifying just the columns to keep
#   https://pystarburst.eng.starburstdata.net/dataframe.html#pystarburst.dataframe.DataFrame.select

tli_projected = tli.select("orderkey", "linenumber", "quantity", "extendedprice", "linestatus")
tli_projected.show()

In [None]:
#
# Add a simple sort
#  https://pystarburst.eng.starburstdata.net/dataframe.html#pystarburst.dataframe.DataFrame.sort
#

tli_projected.sort("orderkey").show()

In [None]:
#
# Multiple column sort
#

tli_projected.sort("orderkey", "linenumber").show()

In [None]:
#
# Filter some of the data
#  https://pystarburst.eng.starburstdata.net/dataframe.html#pystarburst.dataframe.DataFrame.filter
#

# pfs = projected & filtered & sorted
tli_pfs = tli_projected.filter("orderkey <= 5").sort("orderkey", "linenumber")
tli_pfs.show()

In [None]:
#
# Are there no lineitem rows for orderkeys 4 or 5?
#  there are; show() just stops at its default of 10 rows
#   https://pystarburst.eng.starburstdata.net/dataframe.html#pystarburst.dataframe.DataFrame.count

# How many rows are present?
print(tli_pfs.count())

In [None]:
#
# Use limit() to only have a specific number of rows
#  https://pystarburst.eng.starburstdata.net/dataframe.html#pystarburst.dataframe.DataFrame.limit
# 

print(tli_pfs.limit(5).count())

In [None]:
#
# To see them all (or just more than 10) add an argument to show()
#

tli_pfs.show(50)

In [None]:
#
# The DataFrame API also lets you just write SQL
# 

tli_pfs_sql = session.sql(" \
    SELECT orderkey, linenumber, quantity, extendedprice, linestatus \
      FROM tpch.tiny.lineitem \
     WHERE orderkey <= 5 \
     ORDER BY orderkey, linenumber")
tli_pfs_sql.show(50)
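
A triple-quoted string is one way to avoid the backslash continuations used above. This sketch builds the same query text with the standard library's `textwrap.dedent` and needs no cluster to run; the variable name `query` is only illustrative.

```python
import textwrap

# Triple-quoted literal: no trailing backslashes needed on each line
query = textwrap.dedent("""
    SELECT orderkey, linenumber, quantity, extendedprice, linestatus
      FROM tpch.tiny.lineitem
     WHERE orderkey <= 5
     ORDER BY orderkey, linenumber
""").strip()

print(query)
```

The resulting string can be passed to `session.sql(query)` in place of the backslash-continued literal.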

In [None]:
#
# You can also mix and match SQL and the API
# 

session.sql("SELECT orderkey, linenumber, quantity, extendedprice, linestatus \
               FROM tpch.tiny.lineitem") \
     .filter("orderkey <= 5").sort("orderkey", "linenumber").show(50)

In [None]:
#
# Let's verify that the DF created by the API's methods is the same 
#  as the DF created by writing SQL
#   https://pystarburst.eng.starburstdata.net/dataframe.html#pystarburst.dataframe.DataFrame.except_
#    returns a new DF that contains all the rows from the current DF except 
#    for the rows that also appear in the other DataFrame
#    (THERE SHOULD BE NO ROWS PRESENT AS THEY ARE THE SAME)

tli_pfs.except_(tli_pfs_sql).show()
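
An empty result from a single `except_` call only shows that every row of the first DataFrame also appears in the second. Rows present only in the second would go unnoticed, so checking both directions closes that gap. The sketch below models this with plain Python sets of tuples, assuming `except_` behaves like SQL `EXCEPT` (a distinct set difference); the row values are made up for illustration.

```python
# Rows modeled as tuples; values are illustrative, not real TPC-H data
api_rows = {(1, 1, "O"), (1, 2, "O")}
sql_rows = {(1, 1, "O"), (1, 2, "O"), (2, 1, "F")}

# Difference in each direction
only_in_api = api_rows - sql_rows
only_in_sql = sql_rows - api_rows

print(only_in_api)
print(only_in_sql)

# The two sets match only when both differences are empty
same = not only_in_api and not only_in_sql
print(same)
```

In the notebook, the equivalent check would be to run both `tli_pfs.except_(tli_pfs_sql)` and `tli_pfs_sql.except_(tli_pfs)` and confirm that each returns no rows.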

In [None]:
#
# You saw that select() was a way to specifically call out the columns you want
#  from an existing DF, but what if there was a bunch of columns and you wanted
#  almost all of them?
#
# The drop() method is the reverse; you identify the columns you'd like to eliminate
#  https://pystarburst.eng.starburstdata.net/dataframe.html#pystarburst.dataframe.DataFrame.drop
#

tli.drop("comment", "shipmode", "shipinstruct").show()

In [None]:
#
# We saw that you can order by multiple columns already.  When you need to have
#  multiple predicates, just chain the filter() methods back to back
#

tli.filter("discount > 0.05") \
   .filter("returnflag = 'A'") \
   .filter("suppkey IN (55, 60, 88)") \
   .filter("shipinstruct LIKE 'TAKE BACK%'") \
   .select("orderkey", "linenumber", "suppkey", "discount", "shipinstruct") \
   .sort("discount", "suppkey", "orderkey", "linenumber").show()
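
Chained `filter()` calls narrow the result step by step, which is logically the same as joining the predicates with `AND`. The sketch below checks that equivalence with the standard library's `sqlite3` module on a small made-up table, so it runs without a cluster. The table contents and the helper name `combine` are hypothetical.

```python
import sqlite3


def combine(*predicates):
    """Join SQL predicates with AND, parenthesizing each one."""
    return " AND ".join(f"({p})" for p in predicates)


conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE lineitem (orderkey INTEGER, discount REAL, returnflag TEXT)"
)
conn.executemany(
    "INSERT INTO lineitem VALUES (?, ?, ?)",
    [(1, 0.06, "A"), (2, 0.02, "A"), (3, 0.08, "R"), (4, 0.07, "A")],
)

predicates = ["discount > 0.05", "returnflag = 'A'"]

# Step-by-step filtering: each predicate wraps the previous result
chained = "SELECT * FROM lineitem"
for p in predicates:
    chained = f"SELECT * FROM ({chained}) WHERE {p}"

# Single filter with all predicates combined
single = f"SELECT * FROM lineitem WHERE {combine(*predicates)}"

chained_rows = sorted(conn.execute(chained).fetchall())
single_rows = sorted(conn.execute(single).fetchall())
print(chained_rows == single_rows, single_rows)
```

In PyStarburst terms, the four filters above could be written as a single `tli.filter(...)` call that takes the combined predicate string.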

In [None]:
#
# Use standard SQL to see what the shipmode options are
#

session.sql("SELECT DISTINCT(shipmode) FROM tpch.tiny.lineitem").show()

In [None]:
#
# You can get the same thing from the API
#  https://pystarburst.eng.starburstdata.net/dataframe.html#pystarburst.dataframe.DataFrame.distinct
#

tli.select("shipmode").distinct().show()

In [None]:
#
# Use standard SQL to find out how many lineitems for each shipmode
#

session.sql(" \
     SELECT shipmode, count(*) \
       FROM tpch.tiny.lineitem \
      GROUP BY shipmode \
      ORDER BY shipmode").show()

In [None]:
#
# You can get the same thing from the API
#  https://pystarburst.eng.starburstdata.net/dataframe.html#pystarburst.dataframe.DataFrame.groupBy
#

tli.group_by("shipmode").count().sort("shipmode").show()
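
As a mental model, `group_by("shipmode").count()` followed by a sort produces what `collections.Counter` would produce over a local list of values, ordered by key. The values below are made up for illustration and are not real TPC-H counts.

```python
from collections import Counter

# Illustrative shipmode values, one per (hypothetical) lineitem row
shipmodes = ["AIR", "RAIL", "AIR", "SHIP", "RAIL", "AIR"]

# Count rows per distinct value, then order by the grouping key
counts = sorted(Counter(shipmodes).items())
for mode, n in counts:
    print(mode, n)
```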

In [None]:
#
# Oh... the DF API almost always has at least 2 ways to perform the same action!
#  here's TWO more ways for this example
#   https://pystarburst.eng.starburstdata.net/dataframe_grouping_functions.html#pystarburst.relational_grouped_dataframe.RelationalGroupedDataFrame
#

tli.group_by("shipmode").agg((col("*"), "count")).sort("shipmode").show()

tli.group_by("shipmode").function("count")("*").sort("shipmode").show()

In [None]:
#
# You can also calculate multiple aggregate functions for a single group_by
#

tli.group_by("shipmode").agg( \
     (col("shipmode"), "count"), \
     (col("quantity"), "sum"), \
     (col("extendedprice"), "avg"), \
     (col("discount"), "max") \
).sort("count(shipmode)", ascending=False).show()

In [None]:
#
# Probably no surprise that this is the equivalent SQL to the last cell
#

session.sql(" \
     SELECT shipmode, count(shipmode), sum(quantity), avg(extendedprice), max(discount) \
       FROM tpch.tiny.lineitem \
      GROUP BY shipmode \
      ORDER BY 2 DESC").show()

In [None]:
#
# Show some basic statistics for all columns
#  https://pystarburst.eng.starburstdata.net/dataframe.html#pystarburst.dataframe.DataFrame.describe
#

tli.describe().show()

In [None]:
#
# Yep, that was busy -- let's just look at a few fields
#

tli.describe().select("summary", "quantity", "extendedprice", "discount", "tax").show()

In [None]:
#
# Exercise some of the Trino string functions
#  https://pystarburst.eng.starburstdata.net/dataframe_functions.htm
#

str_test1 = session.sql("SELECT shipmode, shipinstruct FROM tpch.tiny.lineitem") \
     .withColumn("ship_dets", concat_ws(lit(" > "), "shipmode", "shipinstruct")) \
     .withColumn("ship_dets_lc", lower("ship_dets"))
str_test1.show()

str_test2 = session.table("tpch.tiny.lineitem").select("comment") \
     .withColumn("unusual_comment", starts_with("comment", lit("unusual"))) \
     .filter("unusual_comment = true") \
     .withColumn("comment_mod", replace("comment", lit("unusual"), lit("WEIRD")))
str_test2.show()

In [None]:
# 
# Let's join some tables
#  https://pystarburst.eng.starburstdata.net/dataframe.html#pystarburst.dataframe.DataFrame.join
#

ordersDF = session.table("tpch.tiny.orders")
lineitemDF = session.table("tpch.tiny.lineitem").rename("orderkey", "li_ok")

joinedDF = lineitemDF.join(ordersDF, ordersDF.orderkey == lineitemDF.li_ok) \
     .select("orderkey", "linenumber", "extendedprice", "linestatus", "custkey") \
     .sort("orderkey", "linenumber")
joinedDF.show()

In [None]:
#
# Let's join 4 tables together and determine the average lineitem price by nation name
#  note: renaming the (logical) FK columns up front avoids the confusion caused by auto-renaming of duplicate column names
# 

smaller_orders_lineitems = joinedDF.drop("linenumber", "linestatus") \
     .rename("custkey", "sol_ck").filter("orderkey BETWEEN 100 AND 199")

customerDF = session.sql("SELECT custkey, nationkey AS c_nk FROM tpch.tiny.customer") 

o_li_c = smaller_orders_lineitems.join(customerDF, \
                smaller_orders_lineitems.sol_ck == customerDF.custkey)

nationDF = session.table("tpch.tiny.nation").drop("regionkey").drop("comment")

nation_avg_price = o_li_c.join(nationDF, o_li_c.c_nk == nationDF.nationkey) \
     .rename("name", "nation_name") \
     .select("nation_name", "extendedprice") \
     .group_by("nation_name").avg("extendedprice") \
     .with_column("avg_price", round("avg(extendedprice)", lit(2))) \
     .select("nation_name", "avg_price") \
     .sort("avg_price", ascending=False)
nation_avg_price.show()

In [None]:
#
# SQL version of the above cell, plus verifying the results are identical
#  by showing the except_() output is empty
# 

nation_avg_price_sql = session.sql(" \
     SELECT n.name AS nation_name, \
            ROUND(AVG(li.extendedprice), 2) AS avg_price \
       FROM tpch.tiny.lineitem li \
       JOIN tpch.tiny.orders o   ON (li.orderkey = o.orderkey) \
       JOIN tpch.tiny.customer c ON (o.custkey = c.custkey) \
       JOIN tpch.tiny.nation n   ON (c.nationkey = n.nationkey) \
      WHERE o.orderkey BETWEEN 100 and 199 \
      GROUP BY n.name \
      ORDER BY avg_price DESC")
nation_avg_price_sql.show()

nation_avg_price.except_(nation_avg_price_sql).show()

## You definitely have some **optionality** with the DataFrame API.

## That's enough examples for this notebook :)