SNOW-135902 added a more efficient way to ingest a pandas.DataFrame into Snowflake, located in snowflake.connector.pandas_tools
sfc-gh-stakeda authored and sfc-gh-abhatnagar committed Apr 30, 2020
1 parent 0e95dd5 commit 3783c5f
Showing 5 changed files with 227 additions and 4 deletions.
2 changes: 2 additions & 0 deletions options.py
@@ -19,6 +19,8 @@ def __getattr__(self, item):

try:
import pandas
# since we enable relative imports without dots, this import gives us an issue when run from the test directory
from pandas import DataFrame # NOQA
import pyarrow
installed_pandas = True
except ImportError:
157 changes: 157 additions & 0 deletions pandas_tools.py
@@ -0,0 +1,157 @@
import os
import string
import random
from tempfile import TemporaryDirectory
from typing import Optional, Sequence, TypeVar, Iterator, Tuple, Union, Iterable

from snowflake.connector import ProgrammingError
from snowflake.connector.options import pandas

MYPY = False
if MYPY: # from typing import TYPE_CHECKING once 3.5 is deprecated
from .connection import SnowflakeConnection

try:
import sqlalchemy
except ImportError:
sqlalchemy = None

T = TypeVar('T', bound=Sequence)


def chunk_helper(lst: T, n: int) -> Iterator[Tuple[int, T]]:
"""Helper generator to chunk a sequence efficiently with current index like if enumerate was called on sequence"""
for i in range(0, len(lst), n):
yield i // n, lst[i:i + n]


def write_pandas(conn: 'SnowflakeConnection',
df: 'pandas.DataFrame',
table_name: str,
database: Optional[str] = None,
schema: Optional[str] = None,
chunk_size: Optional[int] = None,
compression: str = 'gzip',
on_error: str = 'abort_statement',
parallel: int = 4
) -> Tuple[bool, int, int,
Sequence[Tuple[str, str, int, int, int, int, Optional[str], Optional[int],
Optional[int], Optional[str]]]]:
"""
Allows users to most efficiently write back a pandas DataFrame to Snowflake by dumping the DataFrame into Parquet
files, uploading them and finally copying their data into the table. Returns the COPY INTO command's results to
verify ingestion.
Returns whether all files were ingested correctly, number of chunks uploaded, and number of rows ingested
with all of the COPY INTO command's output for debugging purposes.
:Example:
import pandas
from snowflake.connector.pandas_tools import write_pandas_all
df = pandas.DataFrame([('Mark', 10), ('Luke', 20)], columns=['name', 'balance'])
success, nchunks, nrows, _ = write_pandas_all(cnx, df, 'customers')
@param conn: connection to be used to communicate with Snowflake
@param df: Dataframe we'd like to write back
@param table_name: Table name where we want to insert into
@param database: Database schema and table is in, if not provided the default one will be used
@param schema: Schema table is in, if not provided the default one will be used
@param chunk_size: Number of elements to be inserted once, if not provided all elements will be dumped once
@param compression: The compression used on the Parquet files, can only be gzip, or snappy. Gzip gives supposedly a
better compression, while snappy is faster. Use whichever is more appropriate.
@param on_error: Action to take when COPY INTO statements fail, default follows documentation at:
https://docs.snowflake.com/en/sql-reference/sql/copy-into-table.html#copy-options-copyoptions
@param parallel: Number of threads to be used when uploading chunks, default follows documentation at:
https://docs.snowflake.com/en/sql-reference/sql/put.html#optional-parameters
@return: tuple of whether all chunks were ingested correctly, # of chunks, # of ingested rows, and ingest's output
"""
if database is not None and schema is None:
raise ProgrammingError("Schema has to be provided to write_pandas_all when a database is provided")
# This dictionary maps the compression algorithm to the Snowflake COPY INTO file format type
# https://docs.snowflake.com/en/sql-reference/sql/copy-into-table.html#type-parquet
compression_map = {
'gzip': 'auto',
'snappy': 'snappy'
}
if compression not in compression_map.keys():
raise ProgrammingError("Invalid compression '{}', only acceptable values are: {}".format(
compression,
compression_map.keys()
))
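# Build the fully qualified, quoted location of the target table; database and schema are optional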
location = ((('"' + database + '".') if database else '') +
            (('"' + schema + '".') if schema else '') +
            ('"' + table_name + '"'))
if chunk_size is None:
chunk_size = len(df)
cursor = conn.cursor()
stage_name = None # Forward declaration
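# Pick a random stage name and retry if it happens to collide with an existing stage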
while True:
try:
stage_name = ''.join(random.choice(string.ascii_lowercase) for _ in range(5))
cursor.execute('create temporary stage /* Python:snowflake.connector.pandas_tools.write_pandas() */ '
'"{stage_name}"'.format(stage_name=stage_name), _is_internal=True).fetchall()
break
except ProgrammingError as pe:
if pe.msg.endswith('already exists.'):
continue
raise

with TemporaryDirectory() as tmp_folder:
for i, chunk in chunk_helper(df, chunk_size):
chunk_path = '{}/file{}.txt'.format(tmp_folder, i)
# Dump chunk into parquet file
chunk.to_parquet(chunk_path, compression=compression)
# Upload parquet file
cursor.execute('PUT /* Python:snowflake.connector.pandas_tools.write_pandas() */ '
'file://{path} @"{stage_name}" PARALLEL={parallel}'.format(
path=chunk_path,
stage_name=stage_name,
parallel=parallel
), _is_internal=True)
# Remove chunk file
os.remove(chunk_path)
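# Copy the staged Parquet files into the target table and collect the per-file load results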
copy_results = cursor.execute((
'COPY INTO {location} /* Python:snowflake.connector.pandas_tools.write_pandas() */ '
'FROM @"{stage_name}" FILE_FORMAT=(TYPE=PARQUET COMPRESSION={compression}) '
'MATCH_BY_COLUMN_NAME=CASE_SENSITIVE PURGE=TRUE ON_ERROR={on_error}'
).format(
location=location,
stage_name=stage_name,
compression=compression_map[compression],
on_error=on_error
), _is_internal=True).fetchall()
cursor.close()
return (all((e[1] == 'LOADED' for e in copy_results)),
len(copy_results),
sum((e[3] for e in copy_results)),
copy_results)


def pd_writer(table: pandas.io.sql.SQLTable,
conn: Union['sqlalchemy.engine.Engine', 'sqlalchemy.engine.Connection'],
keys: Iterable,
data_iter: Iterable) -> None:
"""
This is a wrapper on top of write_pandas_all to make it compatible with to_sql method in pandas.
:Example:
import pandas as pd
from snowflake.connector.pandas_utils import pf_writer
sf_connector_version_df = pd.DataFrame([('snowflake-connector-python',)], columns=['NAME', 'NEWEST_VERSION'])
sf_connector_version_df.to_sql('driver_versions', engine, index=False, method=pd_writer)
@param table: Pandas package's table object
@param conn: SQLAlchemy engine object to talk to Snowflake
@param keys: Column names that we are trying to insert
@param data_iter: Iterator over the rows
@return: None
"""
sf_connection = conn.connection.connection
df = pandas.DataFrame(data_iter, columns=keys)
write_pandas(conn=sf_connection,
df=df,
# Note: our SQLAlchemy connector creates tables case-insensitively, hence the upper-cased name
table_name=table.name.upper(),
schema=table.schema)
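
For context, a minimal usage sketch of the two entry points added by this commit. The connection parameters and the SQLAlchemy URL are placeholders, the 'customers' table is assumed to already exist, and the to_sql path assumes snowflake-sqlalchemy is installed; none of this is part of the commit itself.

import pandas as pd
from sqlalchemy import create_engine

import snowflake.connector
from snowflake.connector.pandas_tools import write_pandas, pd_writer

# Placeholder credentials -- substitute your own account details.
cnx = snowflake.connector.connect(
    account='my_account', user='my_user', password='my_password',
    database='MY_DB', schema='PUBLIC', warehouse='MY_WH',
)

# Direct path: write_pandas dumps the DataFrame to Parquet files, PUTs them to a
# temporary stage, and COPYs them into the (pre-existing) 'customers' table.
customers = pd.DataFrame([('Mark', 10), ('Luke', 20)], columns=['name', 'balance'])
success, nchunks, nrows, copy_output = write_pandas(cnx, customers, 'customers')

# Via pandas' to_sql with a snowflake-sqlalchemy engine: pd_writer routes the
# insert through write_pandas instead of row-by-row INSERT statements.
engine = create_engine('snowflake://my_user:my_password@my_account/MY_DB/PUBLIC?warehouse=MY_WH')
versions = pd.DataFrame([('snowflake-connector-python', '1.2.23')], columns=['NAME', 'NEWEST_VERSION'])
versions.to_sql('driver_versions', engine, index=False, method=pd_writer)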
11 changes: 8 additions & 3 deletions test/conftest.py
@@ -9,6 +9,7 @@
import uuid
from contextlib import contextmanager
from logging import getLogger
from typing import Callable, Generator

import pytest
from generate_test_files import generate_k_lines_of_n_files
@@ -18,6 +19,10 @@
from snowflake.connector.compat import IS_WINDOWS, TO_UNICODE
from snowflake.connector.connection import DefaultConverterClass

MYPY = False
if MYPY: # from typing import TYPE_CHECKING once 3.5 is deprecated
from snowflake.connector import SnowflakeConnection

try:
from parameters import CONNECTION_PARAMETERS_S3
except ImportError:
@@ -257,7 +262,7 @@ def fin():
request.addfinalizer(fin)


def create_connection(**kwargs):
def create_connection(**kwargs) -> 'SnowflakeConnection':
"""
Creates a connection using the parameters defined in JDBC connect string
"""
@@ -268,7 +273,7 @@ def create_connection(**kwargs):


@contextmanager
def db(**kwargs):
def db(**kwargs) -> Generator['SnowflakeConnection', None, None]:
if not kwargs.get(u'timezone'):
kwargs[u'timezone'] = u'UTC'
if not kwargs.get(u'converter_class'):
@@ -307,7 +312,7 @@ def fin():


@pytest.fixture()
def conn_cnx():
def conn_cnx() -> Callable[..., Generator['SnowflakeConnection', None, None]]:
return db


59 changes: 59 additions & 0 deletions test/pandas/test_pandas_tools.py
@@ -0,0 +1,59 @@
import math
from typing import Callable, Generator
import pandas

import pytest

from snowflake.connector.pandas_tools import write_pandas

MYPY = False
if MYPY: # from typing import TYPE_CHECKING once 3.5 is deprecated
from snowflake.connector import SnowflakeConnection

sf_connector_version_data = [
('snowflake-connector-python', '1.2.23'),
('snowflake-sqlalchemy', '1.1.1'),
('snowflake-connector-go', '0.0.1'),
('snowflake-go', '1.0.1'),
('snowflake-odbc', '3.12.3'),
]

sf_connector_version_df = pandas.DataFrame(sf_connector_version_data, columns=['name', 'newest_version'])


@pytest.mark.parametrize('chunk_size', [5, 4, 3, 2, 1])
@pytest.mark.parametrize('compression', ['gzip', 'snappy'])
# Note: since the file will be too small to chunk, this only tests the PUT command's syntax
@pytest.mark.parametrize('parallel', [4, 99])
def test_write_pandas(conn_cnx: Callable[..., Generator['SnowflakeConnection', None, None]],
compression: str,
parallel: int,
chunk_size: int):
num_of_chunks = math.ceil(len(sf_connector_version_data) / chunk_size)

with conn_cnx() as cnx: # type: SnowflakeConnection
table_name = 'driver_versions'
cnx.execute_string('CREATE OR REPLACE TABLE "{}"("name" STRING, "newest_version" STRING)'.format(table_name))
try:
success, nchunks, nrows, _ = write_pandas(cnx,
sf_connector_version_df,
table_name,
compression=compression,
parallel=parallel,
chunk_size=chunk_size)
if num_of_chunks == 1:
# Note: since we used one chunk, order is preserved
assert (cnx.cursor().execute('SELECT * FROM "{}"'.format(table_name)).fetchall() ==
sf_connector_version_data)
else:
# Note: since we used more than one chunk, order is NOT preserved
assert (set(cnx.cursor().execute('SELECT * FROM "{}"'.format(table_name)).fetchall()) ==
set(sf_connector_version_data))
# Make sure all files were loaded and no error occurred
assert success
# Make sure overall as many rows were ingested as we tried to insert
assert nrows == len(sf_connector_version_data)
# Make sure we uploaded as many chunks as we intended to
assert nchunks == num_of_chunks
finally:
cnx.execute_string("DROP TABLE IF EXISTS {}".format(table_name))
2 changes: 1 addition & 1 deletion tox.ini
@@ -95,4 +95,4 @@ include_trailing_comma = True
force_grid_wrap = 0
line_length = 120
known_first_party =snowflake,parameters
known_third_party =Cryptodome,OpenSSL,asn1crypto,azure,boto3,botocore,certifi,cryptography,dateutil,generate_test_files,jwt,mock,numpy,pendulum,pyasn1,pyasn1_modules,pytest,pytz,requests,setuptools,urllib3
known_third_party =Cryptodome,OpenSSL,asn1crypto,azure,boto3,botocore,certifi,cryptography,dateutil,generate_test_files,jwt,mock,numpy,pandas,pendulum,pyasn1,pyasn1_modules,pytest,pytz,requests,setuptools,urllib3
