This Notebook will compare different ways to query csv data from S3 from SageMaker.
This was tested with SageMaker SDK v2.160.0.

In Amazon SageMaker Studio, use "Data Science" image and ml.t3.medium instance. 

Please add also to the Studio user-profile role the correct policy in order to have the permission to query Athena.
https://docs.aws.amazon.com/sagemaker/latest/dg/security-iam-awsmanpol.html

The csv file is 1.8 MB in size and will fetched from the sklearn datasets.

The 3 options on quering the data from S3 detailed here are:
1. PyAthena library (check here more info: https://pypi.org/project/pyathena/)
2. Data Wrangler library
3. Direct query the data from S3

In [4]:
%pip install -qU sagemaker PyAthena awswrangler
#%pip install -qU 'sagemaker>=2.15.0' 'PyAthena==1.10.7' #'awswrangler==2.14.0'

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m23.1.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [11]:
import io
import boto3
import sagemaker
import json
from sagemaker import get_execution_role
import os
import sys
from sklearn.datasets import fetch_california_housing #20640 samples
from sklearn.datasets import fetch_kddcup99 #4898431 samples
import pandas as pd
from botocore.exceptions import ClientError

# Get region
session = boto3.session.Session()
region_name = session.region_name

# Get SageMaker session & default S3 bucket
sagemaker_session = sagemaker.Session()
bucket = sagemaker_session.default_bucket()  # replace with your own bucket name if you have one
iam = boto3.client("iam")
s3 = sagemaker_session.boto_session.resource("s3")
role = sagemaker.get_execution_role()
role_name = role.split("/")[-1]
prefix = "data/tabular/california_housing"
prefix2 = "data/tabular/kddcup99"
filename = "california_housing.csv"
filename2 = "mnist_train.csv"

#Check the version for each libs
print("SageMaker version: " +sagemaker.__version__)

SageMaker version: 2.160.0


In [7]:
# helper functions to upload data to s3
def write_to_s3(filename, bucket, prefix):
    # put one file in a separate folder. This is helpful if you read and prepare data with Athena
    filename_key = filename.split(".")[0]
    key = "{}/{}/{}".format(prefix, filename_key, filename)
    return s3.Bucket(bucket).upload_file(filename, key)


def upload_to_s3(bucket, prefix, filename):
    url = "s3://{}/{}/{}".format(bucket, prefix, filename)
    print("Writing to {}".format(url))
    write_to_s3(filename, bucket, prefix)

In [8]:
tabular_data = fetch_california_housing()
tabular_data_full = pd.DataFrame(tabular_data.data, columns=tabular_data.feature_names)
tabular_data_full["target"] = pd.DataFrame(tabular_data.target)
tabular_data_full.to_csv("california_housing.csv", index=False)

upload_to_s3(bucket, "data/tabular", filename)

Writing to s3://sagemaker-eu-west-1-174976546647/data/tabular/california_housing.csv


In [12]:
#tabular_data2 = fetch_kddcup99()
#tabular_data_full2 = pd.DataFrame(tabular_data2.data, columns=tabular_data2.feature_names)
#tabular_data_full2["target"] = pd.DataFrame(tabular_data2.target)
df = pd.read_csv(filename2)
tabular_data_full2 = pd.DataFrame(df)
tabular_data_full2.to_csv(filename2, index=False)

upload_to_s3(bucket, "data/tabular", filename2)

Writing to s3://sagemaker-eu-west-1-174976546647/data/tabular/mnist_train.csv


In [13]:
# check if IAM policy is attached
try:
    existing_policies = iam.list_attached_role_policies(RoleName=role_name)["AttachedPolicies"]
    if "IAMFullAccess" not in [po["PolicyName"] for po in existing_policies]:
        print(
            "ERROR: You need to attach the IAMFullAccess policy in order to attach policy to the role"
        )
    else:
        print("IAMFullAccessPolicy Already Attached")
except ClientError as e:
    if e.response["Error"]["Code"] == "AccessDenied":
        print(
            "ERROR: You need to attach the IAMFullAccess policy in order to attach policy to the role."
        )
    else:
        print("Unexpected error: %s" % e)

IAMFullAccessPolicy Already Attached


In [14]:
athena_access_role_policy_doc = {
    "Version": "2012-10-17",
    "Statement": [
        {"Effect": "Allow", "Action": ["athena:*"], "Resource": ["*"]},
        {
            "Effect": "Allow",
            "Action": [
                "glue:CreateDatabase",
                "glue:DeleteDatabase",
                "glue:GetDatabase",
                "glue:GetDatabases",
                "glue:UpdateDatabase",
                "glue:CreateTable",
                "glue:DeleteTable",
                "glue:BatchDeleteTable",
                "glue:UpdateTable",
                "glue:GetTable",
                "glue:GetTables",
                "glue:BatchCreatePartition",
                "glue:CreatePartition",
                "glue:DeletePartition",
                "glue:BatchDeletePartition",
                "glue:UpdatePartition",
                "glue:GetPartition",
                "glue:GetPartitions",
                "glue:BatchGetPartition",
            ],
            "Resource": ["*"],
        },
        {"Effect": "Allow", "Action": ["lakeformation:GetDataAccess"], "Resource": ["*"]},
    ],
}

In [15]:
# create IAM client
iam = boto3.client("iam")
# create a policy
try:
    response = iam.create_policy(
        PolicyName="myAthenaPolicy", PolicyDocument=json.dumps(athena_access_role_policy_doc)
    )
except ClientError as e:
    if e.response["Error"]["Code"] == "EntityAlreadyExists":
        print("Policy already created.")
    else:
        print("Unexpected error: %s" % e)

Policy already created.


In [16]:
# get policy ARN
sts = boto3.client("sts")
account_id = sts.get_caller_identity()["Account"]
policy_athena_arn = f"arn:aws:iam::{account_id}:policy/myAthenaPolicy"

In [17]:
# Attach a role policy
try:
    response = iam.attach_role_policy(PolicyArn=policy_athena_arn, RoleName=role_name)
except ClientError as e:
    if e.response["Error"]["Code"] == "EntityAlreadyExists":
        print("Policy is already attached to your role.")
    else:
        print("Unexpected error: %s" % e)
        

In [18]:
from pyathena import connect
from pyathena.pandas.cursor import PandasCursor
from pyathena.pandas.util import as_pandas

In [19]:
# Set Athena database name
database_name = "tabular_california_housing"
database_name2 = "mnist_train"

In [20]:
# Set S3 staging directory -- this is a temporary directory used for Athena queries
s3_staging_dir = "s3://{0}/athena/staging".format(bucket)

In [21]:
# write the SQL statement to execute
statement = "CREATE DATABASE IF NOT EXISTS {}".format(database_name)
statement2 = "CREATE DATABASE IF NOT EXISTS {}".format(database_name2)
print(statement)
print(statement2)

CREATE DATABASE IF NOT EXISTS tabular_california_housing
CREATE DATABASE IF NOT EXISTS mnist_train


In [23]:
# connect to s3 using PyAthena
cursor = connect(region_name=region_name, s3_staging_dir=s3_staging_dir).cursor()
cursor.execute(statement)

cursor2 = connect(region_name=region_name, s3_staging_dir=s3_staging_dir).cursor()
cursor2.execute(statement2)

<pyathena.cursor.Cursor at 0x7f73bcee4a10>

In [24]:
prefix = "data/tabular"
filename_key = "california_housing"
filename_key2 = "mnist_train"

In [26]:
data_s3_location = "s3://{}/{}/{}/".format(bucket, prefix, filename_key)
data_s3_location2 = "s3://{}/{}/{}/".format(bucket, prefix, filename_key2)
print(data_s3_location)
print(data_s3_location2)

s3://sagemaker-eu-west-1-174976546647/data/tabular/california_housing/
s3://sagemaker-eu-west-1-174976546647/data/tabular/mnist_train/


In [27]:
table_name_csv = "california_housing_athena"
table_name_csv2 = "mnist_train_athena"

In [None]:
# SQL statement to execute

statement = """CREATE EXTERNAL TABLE IF NOT EXISTS {}.{}(
        MedInc double,
        HouseAge double,
        AveRooms double,
        AveBedrms double,
        Population double,
        AveOccup double,
        Latitude double,
        Longitude double,
        MedValue double

) ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' LINES TERMINATED BY '\\n' LOCATION '{}'
TBLPROPERTIES ('skip.header.line.count'='1')""".format(
    database_name, table_name_csv, data_s3_location
)

In [28]:
statement2 = """CREATE EXTERNAL TABLE IF NOT EXISTS {}.{}(
        label double,
        1x1 double,
        1x2 double,
        1x3 double,
        1x4 double,
        1x5 double,
        1x6 double,
        1x7 double,
        1x8 double

) ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' LINES TERMINATED BY '\\n' LOCATION '{}'
TBLPROPERTIES ('skip.header.line.count'='1')""".format(
    database_name2, table_name_csv2, data_s3_location2
)

In [30]:
# Execute statement using connection cursor
cursor = connect(region_name=region_name, s3_staging_dir=s3_staging_dir).cursor()
cursor.execute(statement)

cursor2 = connect(region_name=region_name, s3_staging_dir=s3_staging_dir).cursor()
cursor2.execute(statement2)

<pyathena.cursor.Cursor at 0x7f73aef29dd0>

In [31]:
# verify the table has been created
statement = "SHOW TABLES in {}".format(database_name)
cursor.execute(statement)

df_show = as_pandas(cursor)
df_show.head(5)

Unnamed: 0,tab_name
0,california_housing_athena


In [32]:
# verify the table has been created
statement2 = "SHOW TABLES in {}".format(database_name2)
cursor2.execute(statement2)

df_show2 = as_pandas(cursor2)
df_show2.head(5)

Unnamed: 0,tab_name
0,mnist_train_athena


In [33]:
%%time
# run a sample query

#statement = """SELECT * FROM {}.{} LIMIT 100""".format(database_name, table_name_csv)
statement = """SELECT * FROM {}.{} """.format(database_name, table_name_csv)
# Execute statement using connection cursor
cursor = connect(region_name=region_name, s3_staging_dir=s3_staging_dir).cursor()
cursor.execute(statement)

CPU times: user 296 ms, sys: 7.19 ms, total: 303 ms
Wall time: 1.55 s


<pyathena.cursor.Cursor at 0x7f73bd73d950>

In [34]:
df = as_pandas(cursor)
df.head(5)

Unnamed: 0,medinc,houseage,averooms,avebedrms,population,aveoccup,latitude,longitude,medvalue
0,8.3252,41.0,6.984127,1.02381,322.0,2.555556,37.88,-122.23,4.526
1,8.3014,21.0,6.238137,0.97188,2401.0,2.109842,37.86,-122.22,3.585
2,7.2574,52.0,8.288136,1.073446,496.0,2.80226,37.85,-122.24,3.521
3,5.6431,52.0,5.817352,1.073059,558.0,2.547945,37.85,-122.25,3.413
4,3.8462,52.0,6.281853,1.081081,565.0,2.181467,37.85,-122.25,3.422


In [35]:
%%time
# run a sample query

#statement = """SELECT * FROM {}.{} LIMIT 100""".format(database_name, table_name_csv)
statement2 = """SELECT * FROM {}.{} """.format(database_name2, table_name_csv2)
# Execute statement using connection cursor
cursor2 = connect(region_name=region_name, s3_staging_dir=s3_staging_dir).cursor()
cursor2.execute(statement2)

CPU times: user 125 ms, sys: 7.76 ms, total: 133 ms
Wall time: 2.38 s


<pyathena.cursor.Cursor at 0x7f73bc557a50>

In [36]:
df2 = as_pandas(cursor2)
df2.head(5)

Unnamed: 0,label,1x1,1x2,1x3,1x4,1x5,1x6,1x7,1x8
0,3.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,4.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,9.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,5.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [37]:
#read all data using PyAthena
import pandas as pd

In [38]:
%%time
conn_query = connect(region_name=region_name, s3_staging_dir=s3_staging_dir)
df = pd.read_sql_query("SELECT * FROM tabular_california_housing.california_housing_athena", conn_query)

CPU times: user 1.03 s, sys: 35.6 ms, total: 1.07 s
Wall time: 4.49 s


In [39]:
print(df.head())

   medinc  houseage  averooms  avebedrms  population  aveoccup  latitude  \
0  2.5474      40.0  4.887955   1.036415      1293.0  3.621849     34.09   
1  4.0987      35.0  4.434783   0.967391       887.0  3.213768     34.09   
2  2.9196      33.0  4.824528   1.090566      1715.0  3.235849     34.09   
3  3.7222      34.0  4.546135   1.139651      1485.0  3.703242     34.09   
4  3.7066      32.0  4.279627   0.956059      2316.0  3.083888     34.09   

   longitude  medvalue  
0    -118.07     1.981  
1    -118.07     2.024  
2    -118.08     2.088  
3    -118.08     2.072  
4    -118.08     2.068  


In [40]:
%%time
conn_query2 = connect(region_name=region_name, s3_staging_dir=s3_staging_dir)
df2 = pd.read_sql_query("SELECT * FROM mnist_train.mnist_train_athena", conn_query2)

CPU times: user 2.5 s, sys: 126 ms, total: 2.63 s
Wall time: 11.5 s


Alternative using Data Wrangler

In [41]:
import awswrangler as wr

In [42]:
for table in wr.catalog.get_tables(database=database_name):
    print(table["Name"])

california_housing_athena


In [43]:
for table in wr.catalog.get_tables(database=database_name2):
    print(table["Name"])

mnist_train_athena


In [45]:
%%time
df = wr.athena.read_sql_query(
    sql="SELECT * FROM {} LIMIT 100".format(table_name_csv), database=database_name
)

CPU times: user 222 ms, sys: 8.08 ms, total: 230 ms
Wall time: 2.68 s


In [None]:
df.head(5)

In [44]:
%%time
df2 = wr.athena.read_sql_query(
    sql="SELECT * FROM {} LIMIT 100".format(table_name_csv2), database=database_name2
)

CPU times: user 239 ms, sys: 15.1 ms, total: 254 ms
Wall time: 3.04 s


Alternative using direct S3 Query

In [50]:
#list the files
conn = boto3.client('s3')
contents = conn.list_objects(Bucket=bucket, Prefix=prefix)['Contents']
for f in contents:
    last_file = f['Key']
    print(f['Key'])

data/tabular/california_housing/california_housing.csv
data/tabular/mnist_train/mnist_train.csv


In [58]:
print(contents[0]['Key'])

data/tabular/california_housing/california_housing.csv


In [59]:
%%time

df = pd.read_csv( 's3://{}/{}'.format(bucket,contents[0]['Key']))

CPU times: user 106 ms, sys: 15.8 ms, total: 121 ms
Wall time: 216 ms


In [60]:
df.head(5)

Unnamed: 0,MedInc,HouseAge,AveRooms,AveBedrms,Population,AveOccup,Latitude,Longitude,target
0,8.3252,41.0,6.984127,1.02381,322.0,2.555556,37.88,-122.23,4.526
1,8.3014,21.0,6.238137,0.97188,2401.0,2.109842,37.86,-122.22,3.585
2,7.2574,52.0,8.288136,1.073446,496.0,2.80226,37.85,-122.24,3.521
3,5.6431,52.0,5.817352,1.073059,558.0,2.547945,37.85,-122.25,3.413
4,3.8462,52.0,6.281853,1.081081,565.0,2.181467,37.85,-122.25,3.422


In [61]:
%%time

df2 = pd.read_csv( 's3://{}/{}'.format(bucket,contents[1]['Key']))

CPU times: user 3.44 s, sys: 837 ms, total: 4.28 s
Wall time: 5.13 s
