forked from sfc-gh-bhess/lab_data_api_python
/
snowpark.py
87 lines (78 loc) · 3.33 KB
/
snowpark.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import datetime
import logging
import os

from flask import Blueprint, request, abort, make_response, jsonify
# Make the Snowflake connection
from snowflake.snowpark import Session
import snowflake.snowpark.functions as f
def connect() -> Session:
    """Create a Snowpark session configured from environment variables.

    Two authentication modes are supported:

    * If the file ``/snowflake/session/token`` exists, OAuth is used with the
      token read from that file, connecting over HTTPS to
      ``SNOWFLAKE_HOST``/``SNOWFLAKE_PORT``.
    * Otherwise ``SNOWFLAKE_USER`` / ``SNOWFLAKE_PASSWORD`` are used.

    Both modes read ``SNOWFLAKE_ACCOUNT``, ``SNOWFLAKE_WAREHOUSE``,
    ``SNOWFLAKE_DATABASE`` and ``SNOWFLAKE_SCHEMA`` and set
    ``client_session_keep_alive``. Unset variables are passed as ``None``.

    Returns:
        A connected :class:`Session`. Errors from reading the token file or
        from ``Session.builder...create()`` propagate to the caller.
    """
    token_path = "/snowflake/session/token"

    # Settings common to both authentication modes.
    creds = {
        'account': os.getenv('SNOWFLAKE_ACCOUNT'),
        'warehouse': os.getenv('SNOWFLAKE_WAREHOUSE'),
        'database': os.getenv('SNOWFLAKE_DATABASE'),
        'schema': os.getenv('SNOWFLAKE_SCHEMA'),
        'client_session_keep_alive': True,
    }

    if os.path.isfile(token_path):
        # Read the OAuth token with a context manager so the file handle is
        # closed immediately rather than leaked until garbage collection.
        with open(token_path, 'r') as token_file:
            token = token_file.read()
        creds.update({
            'host': os.getenv('SNOWFLAKE_HOST'),
            'port': os.getenv('SNOWFLAKE_PORT'),
            'protocol': "https",
            'authenticator': "oauth",
            'token': token,
        })
    else:
        creds.update({
            'user': os.getenv('SNOWFLAKE_USER'),
            'password': os.getenv('SNOWFLAKE_PASSWORD'),
        })

    return Session.builder.configs(creds).create()
# Single module-level Snowpark session, created once when this module is
# imported; both route handlers below query through it. There is no error
# handling here, so any connection failure surfaces as an import-time exception.
session = connect()
# Make the API endpoints
# Flask blueprint grouping the routes in this module; it must be registered
# on an application elsewhere (registration is not visible in this file).
snowpark = Blueprint('snowpark', __name__)
## Top 10 customers in date range
# strptime format for the start_range / end_range query parameters (YYYY-MM-DD).
dateformat = '%Y-%m-%d'
@snowpark.route('/customers/top10')
def customers_top10():
    """Return the 10 customers with the highest total order value in a date range.

    Query parameters (format ``YYYY-MM-DD``, both bounds inclusive):
        start_range: lower bound on ``O_ORDERDATE``; defaults to
            ``1995-01-01`` when missing or empty.
        end_range: upper bound on ``O_ORDERDATE``; defaults to
            ``1995-03-31`` when missing or empty.

    Returns:
        A JSON list of rows with ``O_CUSTKEY`` and ``SUM_TOTALPRICE``, ordered
        by ``SUM_TOTALPRICE`` descending, at most 10 entries.

    Aborts with 400 if either date cannot be parsed, and with 500 (after
    logging the underlying exception) if the Snowflake query fails.
    """
    # Validate arguments
    sdt_str = request.args.get('start_range') or '1995-01-01'
    edt_str = request.args.get('end_range') or '1995-03-31'
    try:
        sdt = datetime.datetime.strptime(sdt_str, dateformat)
        edt = datetime.datetime.strptime(edt_str, dateformat)
    except ValueError:
        # strptime signals any malformed date with ValueError; catching only
        # that avoids masking unrelated errors behind a 400.
        abort(400, "Invalid start and/or end dates.")
    try:
        df = (
            session.table('snowflake_sample_data.tpch_sf10.orders')
            .filter((f.col('O_ORDERDATE') >= sdt) & (f.col('O_ORDERDATE') <= edt))
            .group_by(f.col('O_CUSTKEY'))
            .agg(f.sum(f.col('O_TOTALPRICE')).alias('SUM_TOTALPRICE'))
            .sort(f.col('SUM_TOTALPRICE').desc())
            .limit(10)
        )
        # The DataFrame is lazy; the query actually runs while iterating here.
        return make_response(jsonify([x.as_dict() for x in df.to_local_iterator()]))
    except Exception:
        # Record the real cause so the "Check the logs" message is actionable.
        # Exception (not a bare except) leaves KeyboardInterrupt/SystemExit alone.
        logging.getLogger(__name__).exception("Top-10 customers query failed")
        abort(500, "Error reading from Snowflake. Check the logs for details.")
## Monthly sales for a clerk in a year
@snowpark.route('/clerk/<clerkid>/yearly_sales/<year>')
def clerk_montly_sales(clerkid, year):
    """Return a clerk's total order value per month for one calendar year.

    (The function name keeps its original spelling: Flask derives the
    endpoint name ``snowpark.clerk_montly_sales`` from it.)

    Args:
        clerkid: Digits only. They are appended verbatim to ``"Clerk#"`` and
            matched exactly against ``O_CLERK``. NOTE(review): TPC-H clerk
            values appear to be zero-padded (e.g. ``Clerk#000000001``), so
            callers presumably must pass the padded form. Confirm.
        year: Calendar year, parsed with ``int()``.

    Returns:
        A JSON list of rows with ``O_CLERK``, ``MONTH`` and ``SUM_TOTALPRICE``,
        sorted by clerk then month. Months with no orders are absent.

    Aborts with 400 on an invalid year or clerk id, and with 500 (after
    logging the underlying exception) if the Snowflake query fails.
    """
    # Validate arguments
    try:
        year_int = int(year)
    except ValueError:
        # int() on a str reports unparseable input with ValueError only.
        abort(400, "Invalid year.")
    if not clerkid.isdigit():
        abort(400, "Clerk ID can only contain numbers.")
    clerkid_str = f"Clerk#{clerkid}"
    try:
        df = (
            session.table('snowflake_sample_data.tpch_sf10.orders')
            .filter(f.year(f.col('O_ORDERDATE')) == year_int)
            .filter(f.col('O_CLERK') == clerkid_str)
            .with_column('MONTH', f.month(f.col('O_ORDERDATE')))
            .group_by(f.col('O_CLERK'), f.col('MONTH'))
            .agg(f.sum(f.col('O_TOTALPRICE')).alias('SUM_TOTALPRICE'))
            .sort(f.col('O_CLERK'), f.col('MONTH'))
        )
        # The DataFrame is lazy; the query actually runs while iterating here.
        return make_response(jsonify([x.as_dict() for x in df.to_local_iterator()]))
    except Exception:
        # Record the real cause so the "Check the logs" message is actionable.
        # Exception (not a bare except) leaves KeyboardInterrupt/SystemExit alone.
        logging.getLogger(__name__).exception("Clerk monthly sales query failed")
        abort(500, "Error reading from Snowflake. Check the logs for details.")