In [49]:
import io
import boto3
import sagemaker
import json
from sagemaker import get_execution_role
import os
import sys
from sklearn.datasets import fetch_california_housing
import pandas as pd
from IAM.sagemaker.iam_role_sagemaker import *
from botocore.exceptions import ClientError


In [50]:
# Resolve the AWS region from the default boto3 session
session = boto3.session.Session()
region_name = session.region_name

# SageMaker session, its default S3 bucket, and an S3 resource built on the same boto session
sagemaker_session = sagemaker.Session()
bucket = sagemaker_session.default_bucket()  # replace with your own bucket name if you have one
s3 = sagemaker_session.boto_session.resource("s3")

# IAM client; printed so the cell shows that a client object was created
iam = boto3.client("iam")
print(iam)

# Role that will receive the Athena policy; only the segment after the last "/" is kept as the name
role = "demo-iam-role"
role_name = role.rsplit("/", 1)[-1]

# Dataset file name and S3 key prefix
filename = "california_housing.csv"
prefix = "data/tabular/california_housing"

<botocore.client.IAM object at 0x7f54be908550>


In [51]:
# helper functions to upload data to s3
def write_to_s3(filename, bucket, prefix):
    """Upload a local file to S3 inside its own folder.

    The object key is ``<prefix>/<stem>/<filename>``, where ``<stem>`` is the
    part of ``filename`` before its first ".". Keeping one file per folder is
    helpful if you read and prepare data with Athena.

    Args:
        filename: Path of the local file; also used as the object name.
        bucket: Name of the destination S3 bucket.
        prefix: Key prefix placed in front of the per-file folder.

    Returns:
        Whatever ``Bucket.upload_file`` returns.
    """
    stem = filename.split(".")[0]
    object_key = f"{prefix}/{stem}/{filename}"
    destination = s3.Bucket(bucket)
    return destination.upload_file(filename, object_key)


def upload_to_s3(bucket, prefix, filename):
    """Report the destination and upload ``filename`` to S3 via ``write_to_s3``.

    Args:
        bucket: Name of the destination S3 bucket.
        prefix: Key prefix under which the per-file folder is created.
        filename: Path of the local file to upload.
    """
    # write_to_s3 stores the object under a folder named after the part of the
    # file name before its first ".", so build the same key here; otherwise the
    # message points at a location where the object does not exist.
    filename_key = filename.split(".")[0]
    url = "s3://{}/{}/{}/{}".format(bucket, prefix, filename_key, filename)
    print("Writing to {}".format(url))
    write_to_s3(filename, bucket, prefix)

In [52]:
# Load the California housing dataset and combine features and target in one frame
tabular_data = fetch_california_housing()
tabular_data_full = pd.DataFrame(tabular_data.data, columns=tabular_data.feature_names)
# Assign the target values directly; wrapping them in a DataFrame first adds nothing
tabular_data_full["target"] = tabular_data.target
# Write to the path held in `filename` so the file saved is the file uploaded below
tabular_data_full.to_csv(filename, index=False)

upload_to_s3(bucket, "data/tabular", filename)

Writing to s3://sagemaker-us-east-1-083839308414/data/tabular/california_housing.csv


In [53]:
# IAM policy document with three Allow statements, each on all resources:
# every Athena action, the listed Glue database/table/partition operations,
# and lakeformation:GetDataAccess.
athena_access_role_policy_doc = {
    "Version": "2012-10-17",
    "Statement": [
        {"Effect": "Allow", "Action": granted_actions, "Resource": ["*"]}
        for granted_actions in (
            ["athena:*"],
            [
                "glue:" + operation
                for operation in (
                    "CreateDatabase",
                    "DeleteDatabase",
                    "GetDatabase",
                    "GetDatabases",
                    "UpdateDatabase",
                    "CreateTable",
                    "DeleteTable",
                    "BatchDeleteTable",
                    "UpdateTable",
                    "GetTable",
                    "GetTables",
                    "BatchCreatePartition",
                    "CreatePartition",
                    "DeletePartition",
                    "BatchDeletePartition",
                    "UpdatePartition",
                    "GetPartition",
                    "GetPartitions",
                    "BatchGetPartition",
                )
            ],
            ["lakeformation:GetDataAccess"],
        )
    ],
}

In [54]:
# create IAM client
iam = boto3.client("iam")
# Create the customer-managed policy from the policy document defined above.
# An already-existing policy is fine (the cell can be re-run); any other error
# is reported and re-raised so later cells do not run against a missing policy.
try:
    response = iam.create_policy(
        PolicyName="myAthenaPolicy", PolicyDocument=json.dumps(athena_access_role_policy_doc)
    )
except ClientError as e:
    if e.response["Error"]["Code"] == "EntityAlreadyExists":
        print("Policy already created.")
    else:
        print("Unexpected error: %s" % e)
        raise

Policy already created.


In [55]:
# Build the ARN of the "myAthenaPolicy" policy from the account id of the current caller
sts = boto3.client("sts")
caller_identity = sts.get_caller_identity()
account_id = caller_identity["Account"]
policy_athena_arn = "arn:aws:iam::{}:policy/myAthenaPolicy".format(account_id)

In [56]:
# Attach the Athena policy to the role.
# NOTE(review): attach_role_policy may simply succeed when the policy is already
# attached, in which case the EntityAlreadyExists branch never fires — confirm
# against the IAM API reference. It is kept for backward compatibility.
try:
    response = iam.attach_role_policy(PolicyArn=policy_athena_arn, RoleName=role_name)
except ClientError as e:
    if e.response["Error"]["Code"] == "EntityAlreadyExists":
        print("Policy is already attached to your role.")
    else:
        # Report and re-raise: continuing without the policy attached would only
        # make the Athena cells below fail with a less obvious error.
        print("Unexpected error: %s" % e)
        raise

Intro to PyAthena
We are going to leverage PyAthena to connect to Athena and run queries. PyAthena is a Python DB API 2.0 (PEP 249) compliant client for Amazon Athena. Note
that you will need to specify the region in which you created the database/table in Athena, and make sure the catalog in that region contains the database.

In [57]:
from pyathena import connect
from pyathena.pandas.cursor import PandasCursor
from pyathena.pandas.util import  as_pandas

In [58]:
# Set Athena database name (interpolated into the CREATE DATABASE, CREATE TABLE and query statements below)
database_name = "tabular_california_housing"

In [59]:
# S3 staging directory -- a temporary location Athena uses for query results
s3_staging_dir = f"s3://{bucket}/athena/staging"

In [60]:
# SQL that creates the database if it is not there yet; printed for inspection
statement = f"CREATE DATABASE IF NOT EXISTS {database_name}"
print(statement)

CREATE DATABASE IF NOT EXISTS tabular_california_housing


In [61]:
# Connect to Athena with PyAthena (s3_staging_dir is passed as the staging location)
# and run the statement prepared above
athena_connection = connect(s3_staging_dir=s3_staging_dir, region_name=region_name)
cursor = athena_connection.cursor()
cursor.execute(statement)

<pyathena.cursor.Cursor at 0x7f54bce585c0>

Register Table with Athena
When you run a CREATE TABLE query in Athena, you register your table with the AWS Glue Data Catalog.

To specify the path to your data in Amazon S3, use the LOCATION property, as shown in the following example: LOCATION 's3://bucketname/folder/'

The LOCATION in Amazon S3 specifies all of the files representing your table. Athena reads all data stored in s3://bucketname/folder/. If you have data that you do not want Athena to read, do not store that data in the same Amazon S3 folder as the data you want Athena to read. If you are using partitioning, your WHERE filter must include the partition column so that Athena scans only the data within the matching partitions. For more information, see Table Location and Partitions.

In [62]:
# NOTE: this rebinds `prefix`, which an earlier cell set to "data/tabular/california_housing".
# Together with `filename_key` it forms the folder part of the table's S3 location.
prefix = "data/tabular"
filename_key = "california_housing"

In [63]:
# S3 folder (with trailing slash) used as the table LOCATION
data_s3_location = f"s3://{bucket}/{prefix}/{filename_key}/"

In [64]:
# Name of the Athena table that the CREATE EXTERNAL TABLE statement below registers
table_name_csv = "california_housing_athena"

In [65]:
# DDL that registers the CSV folder as an external table: nine double columns,
# comma-delimited rows, and the first (header) line of each file skipped.

statement = f"""CREATE EXTERNAL TABLE IF NOT EXISTS {database_name}.{table_name_csv}(
        MedInc double,
        HouseAge double,
        AveRooms double,
        AveBedrms double,
        Population double,
        AveOccup double,
        Latitude double,
        Longitude double,
        MedValue double

) ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' LINES TERMINATED BY '\\n' LOCATION '{data_s3_location}'
TBLPROPERTIES ('skip.header.line.count'='1')"""

In [66]:
# Open a fresh Athena connection and run the CREATE TABLE statement through its cursor
athena_connection = connect(s3_staging_dir=s3_staging_dir, region_name=region_name)
cursor = athena_connection.cursor()
cursor.execute(statement)

<pyathena.cursor.Cursor at 0x7f54bd522390>

In [67]:
# List the tables in the database to confirm the new table is registered
statement = f"SHOW TABLES in {database_name}"
cursor.execute(statement)

# Convert the cursor result to a DataFrame and show its first rows
df_show = as_pandas(cursor)
df_show.head()

Unnamed: 0,tab_name
0,california_housing_athena


In [68]:
# Sample query: the first 100 rows of the table
statement = f"SELECT * FROM {database_name}.{table_name_csv}\nLIMIT 100"
# Run it through the cursor of a new Athena connection
athena_connection = connect(s3_staging_dir=s3_staging_dir, region_name=region_name)
cursor = athena_connection.cursor()
cursor.execute(statement)

<pyathena.cursor.Cursor at 0x7f54bc2f31d0>

In [69]:
# Convert the sample-query result to a DataFrame and preview the first five rows
df = as_pandas(cursor)
df.head()

Unnamed: 0,medinc,houseage,averooms,avebedrms,population,aveoccup,latitude,longitude,medvalue
0,8.3252,41.0,6.984127,1.02381,322.0,2.555556,37.88,-122.23,4.526
1,8.3014,21.0,6.238137,0.97188,2401.0,2.109842,37.86,-122.22,3.585
2,7.2574,52.0,8.288136,1.073446,496.0,2.80226,37.85,-122.24,3.521
3,5.6431,52.0,5.817352,1.073059,558.0,2.547945,37.85,-122.25,3.413
4,3.8462,52.0,6.281853,1.081081,565.0,2.181467,37.85,-122.25,3.422


Alternative: use AWS Data Wrangler (`awswrangler`) to query the data

In [70]:
import awswrangler as wr

# Print the name of every table the catalog reports for the database
for table in wr.catalog.get_tables(database=database_name):
    table_name = table["Name"]
    print(table_name)

california_housing_athena


In [71]:
%%time
df = wr.athena.read_sql_query(
    database=database_name,
    sql=f"SELECT * FROM {table_name_csv} LIMIT 100",
)

CPU times: user 113 ms, sys: 0 ns, total: 113 ms
Wall time: 11.3 s


In [72]:
# Preview the first five rows returned by the awswrangler query
df.head()

Unnamed: 0,medinc,houseage,averooms,avebedrms,population,aveoccup,latitude,longitude,medvalue
0,8.3252,41.0,6.984127,1.02381,322.0,2.555556,37.88,-122.23,4.526
1,8.3014,21.0,6.238137,0.97188,2401.0,2.109842,37.86,-122.22,3.585
2,7.2574,52.0,8.288136,1.073446,496.0,2.80226,37.85,-122.24,3.521
3,5.6431,52.0,5.817352,1.073059,558.0,2.547945,37.85,-122.25,3.413
4,3.8462,52.0,6.281853,1.081081,565.0,2.181467,37.85,-122.25,3.422
