Skip to content

Commit

Permalink
Tomme#56 Tomme#121 Tomme#124: Initializes boto3 session globally to s…
Browse files Browse the repository at this point in the history
…upport configured AWS profile when calling boto3
  • Loading branch information
Marco Salazar committed Aug 29, 2022
1 parent 83f7756 commit 915175d
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 10 deletions.
4 changes: 2 additions & 2 deletions dbt/adapters/athena/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from dbt.adapters.sql import SQLConnectionManager
from dbt.exceptions import RuntimeException, FailedToConnectException
from dbt.events import AdapterLogger
from dbt.adapters.athena.session import get_boto3_session

import tenacity
from tenacity.retry import retry_if_exception
Expand Down Expand Up @@ -140,13 +141,12 @@ def open(cls, connection: Connection) -> Connection:
handle = AthenaConnection(
s3_staging_dir=creds.s3_staging_dir,
endpoint_url=creds.endpoint_url,
region_name=creds.region_name,
schema_name=creds.schema,
work_group=creds.work_group,
cursor_class=AthenaCursor,
formatter=AthenaParameterFormatter(),
poll_interval=creds.poll_interval,
profile_name=creds.aws_profile_name,
session=get_boto3_session(connection),
retry_config=RetryConfig(
attempt=creds.num_retries,
exceptions=(
Expand Down
14 changes: 6 additions & 8 deletions dbt/adapters/athena/impl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import agate
import re
import boto3
from botocore.exceptions import ClientError
from itertools import chain
from threading import Lock
Expand Down Expand Up @@ -59,10 +58,9 @@ def clean_up_partitions(
# Look up Glue partitions & clean up
conn = self.connections.get_thread_connection()
client = conn.handle

with boto3_client_lock:
glue_client = boto3.client('glue', region_name=client.region_name)
s3_resource = boto3.resource('s3', region_name=client.region_name)
glue_client = client.session.client('glue')
s3_resource = client.session.resource('s3')
partitions = glue_client.get_partitions(
# CatalogId='123456789012', # Need to make this configurable if it is different from default AWS Account ID
DatabaseName=database_name,
Expand All @@ -87,7 +85,7 @@ def clean_up_table(
conn = self.connections.get_thread_connection()
client = conn.handle
with boto3_client_lock:
glue_client = boto3.client('glue', region_name=client.region_name)
glue_client = client.session.client('glue')
try:
table = glue_client.get_table(
DatabaseName=database_name,
Expand All @@ -105,7 +103,7 @@ def clean_up_table(
if m is not None:
bucket_name = m.group(1)
prefix = m.group(2)
s3_resource = boto3.resource('s3', region_name=client.region_name)
s3_resource = client.session.resource('s3')
s3_bucket = s3_resource.Bucket(bucket_name)
s3_bucket.objects.filter(Prefix=prefix).delete()

Expand Down Expand Up @@ -152,7 +150,7 @@ def _get_data_catalog(self, catalog_name):
conn = self.connections.get_thread_connection()
client = conn.handle
with boto3_client_lock:
athena_client = boto3.client('athena', region_name=client.region_name)
athena_client = client.session.client('athena')

response = athena_client.get_data_catalog(Name=catalog_name)
return response['DataCatalog']
Expand All @@ -172,7 +170,7 @@ def list_relations_without_caching(
conn = self.connections.get_thread_connection()
client = conn.handle
with boto3_client_lock:
glue_client = boto3.client('glue', region_name=client.region_name)
glue_client = client.session.client('glue')
paginator = glue_client.get_paginator('get_tables')

kwargs = {
Expand Down
19 changes: 19 additions & 0 deletions dbt/adapters/athena/session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import boto3.session
from dbt.contracts.connection import Connection


__BOTO3_SESSION__: boto3.session.Session = None


def get_boto3_session(connection: Connection) -> boto3.session.Session:
def init_session():
global __BOTO3_SESSION__
__BOTO3_SESSION__ = boto3.session.Session(
region_name=connection.credentials.region_name,
profile_name=connection.credentials.aws_profile_name,
)

if not __BOTO3_SESSION__:
init_session()

return __BOTO3_SESSION__

0 comments on commit 915175d

Please sign in to comment.