Commit
SNOW-135902 added more efficient way to ingest a pandas.Dataframe into Snowflake, located in snowflake.connector.pandas_tools
1 parent 0e95dd5 · commit 3783c5f
Showing 5 changed files with 227 additions and 4 deletions.
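For orientation, here is a minimal sketch of the new API (connection parameters are placeholders; see the docstrings below for details):

import pandas
import snowflake.connector
from snowflake.connector.pandas_tools import write_pandas

# Placeholder credentials; fill in real Snowflake account parameters
cnx = snowflake.connector.connect(user='<user>', password='<password>', account='<account>')
df = pandas.DataFrame([('Mark', 10), ('Luke', 20)], columns=['name', 'balance'])
# Returns (all files loaded?, number of chunks, number of rows, raw COPY INTO results)
success, nchunks, nrows, _ = write_pandas(cnx, df, 'customers')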
snowflake/connector/pandas_tools.py
@@ -0,0 +1,157 @@
import os
import string
import random
from tempfile import TemporaryDirectory
from typing import Optional, Sequence, TypeVar, Iterator, Tuple, Union, Iterable

from snowflake.connector import ProgrammingError
from snowflake.connector.options import pandas

MYPY = False
if MYPY:  # from typing import TYPE_CHECKING once 3.5 is deprecated
    from .connection import SnowflakeConnection

try:
    import sqlalchemy
except ImportError:
    sqlalchemy = None

T = TypeVar('T', bound=Sequence)


def chunk_helper(lst: T, n: int) -> Iterator[Tuple[int, T]]:
    """Helper generator that chunks a sequence into pieces of size n, yielding each chunk together with
    its index, as if enumerate had been called on the sequence of chunks."""
    for i in range(0, len(lst), n):
        yield int(i / n), lst[i:i + n]
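# For illustration: list(chunk_helper([1, 2, 3, 4, 5], 2)) yields
# [(0, [1, 2]), (1, [3, 4]), (2, [5])]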


def write_pandas(conn: 'SnowflakeConnection',
                 df: 'pandas.DataFrame',
                 table_name: str,
                 database: Optional[str] = None,
                 schema: Optional[str] = None,
                 chunk_size: Optional[int] = None,
                 compression: str = 'gzip',
                 on_error: str = 'abort_statement',
                 parallel: int = 4
                 ) -> Tuple[bool, int, int,
                            Sequence[Tuple[str, str, int, int, int, int, Optional[str], Optional[int],
                                           Optional[int], Optional[str]]]]:
    """
    Writes a pandas DataFrame back to Snowflake efficiently by dumping the DataFrame into Parquet files,
    uploading them and finally copying their data into the table. Returns whether all files were ingested
    correctly, the number of chunks uploaded, and the number of rows ingested, along with the COPY INTO
    command's full output so ingestion can be verified and debugged.
    :Example:
        import pandas
        from snowflake.connector.pandas_tools import write_pandas
        df = pandas.DataFrame([('Mark', 10), ('Luke', 20)], columns=['name', 'balance'])
        success, nchunks, nrows, _ = write_pandas(cnx, df, 'customers')
    @param conn: connection to be used to communicate with Snowflake
    @param df: DataFrame we'd like to write back
    @param table_name: table name we want to insert into
    @param database: database the schema and table are in; if not provided, the default one will be used
    @param schema: schema the table is in; if not provided, the default one will be used
    @param chunk_size: number of elements to insert at once; if not provided, all elements will be dumped at once
    @param compression: compression used on the Parquet files; can only be gzip or snappy. Gzip supposedly gives
        better compression, while snappy is faster. Use whichever is more appropriate.
    @param on_error: action to take when COPY INTO statements fail; the default follows the documentation at:
        https://docs.snowflake.com/en/sql-reference/sql/copy-into-table.html#copy-options-copyoptions
    @param parallel: number of threads to use when uploading chunks; the default follows the documentation at:
        https://docs.snowflake.com/en/sql-reference/sql/put.html#optional-parameters
    @return: tuple of whether all chunks were ingested correctly, # of chunks, # of ingested rows, and ingest's output
    """
    if database is not None and schema is None:
        raise ProgrammingError("Schema has to be provided to write_pandas when a database is provided")
    # This dictionary maps the compression algorithm to the Snowflake COPY INTO file format type
    # https://docs.snowflake.com/en/sql-reference/sql/copy-into-table.html#type-parquet
    compression_map = {
        'gzip': 'auto',
        'snappy': 'snappy'
    }
    if compression not in compression_map.keys():
        raise ProgrammingError("Invalid compression '{}', only acceptable values are: {}".format(
            compression,
            compression_map.keys()
        ))
    # Note: the parentheses around each conditional matter; without them Python's operator precedence
    # would fold the later terms into the first conditional's else branch
    location = ((('"' + database + '".') if database else '') +
                (('"' + schema + '".') if schema else '') +
                ('"' + table_name + '"'))
    if chunk_size is None:
        chunk_size = len(df)
    cursor = conn.cursor()
    stage_name = None  # Forward declaration
    while True:
        try:
            # Create a temporary stage under a short random name, retrying on the (unlikely) name collision
            stage_name = ''.join(random.choice(string.ascii_lowercase) for _ in range(5))
            cursor.execute('create temporary stage /* Python:snowflake.connector.pandas_tools.write_pandas() */ '
                           '"{stage_name}"'.format(stage_name=stage_name), _is_internal=True).fetchall()
            break
        except ProgrammingError as pe:
            if pe.msg.endswith('already exists.'):
                continue
            raise

    with TemporaryDirectory() as tmp_folder:
        for i, chunk in chunk_helper(df, chunk_size):
            chunk_path = '{}/file{}.txt'.format(tmp_folder, i)
            # Dump chunk into a parquet file
            chunk.to_parquet(chunk_path, compression=compression)
            # Upload the parquet file
            cursor.execute('PUT /* Python:snowflake.connector.pandas_tools.write_pandas() */ '
                           'file://{path} @"{stage_name}" PARALLEL={parallel}'.format(
                               path=chunk_path,
                               stage_name=stage_name,
                               parallel=parallel
                           ), _is_internal=True)
            # Remove the chunk file
            os.remove(chunk_path)
    copy_results = cursor.execute((
        'COPY INTO {location} /* Python:snowflake.connector.pandas_tools.write_pandas() */ '
        'FROM @"{stage_name}" FILE_FORMAT=(TYPE=PARQUET COMPRESSION={compression}) '
        'MATCH_BY_COLUMN_NAME=CASE_SENSITIVE PURGE=TRUE ON_ERROR={on_error}'
    ).format(
        location=location,
        stage_name=stage_name,
        compression=compression_map[compression],
        on_error=on_error
    ), _is_internal=True).fetchall()
    cursor.close()
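    # Each COPY INTO result row has the form (file, status, rows_parsed, rows_loaded, ...),
    # so e[1] is the per-file load status and e[3] the number of rows loaded from that file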
    return (all((e[1] == 'LOADED' for e in copy_results)),
            len(copy_results),
            sum((e[3] for e in copy_results)),
            copy_results)


def pd_writer(table: pandas.io.sql.SQLTable,
              conn: Union['sqlalchemy.engine.Engine', 'sqlalchemy.engine.Connection'],
              keys: Iterable,
              data_iter: Iterable) -> None:
    """
    A wrapper around write_pandas that makes it compatible with pandas' to_sql method.
    :Example:
        import pandas as pd
        from snowflake.connector.pandas_tools import pd_writer
        sf_connector_version_df = pd.DataFrame([('snowflake-connector-python', '1.2.23')],
                                               columns=['NAME', 'NEWEST_VERSION'])
        sf_connector_version_df.to_sql('driver_versions', engine, index=False, method=pd_writer)
    @param table: pandas package's table object
    @param conn: SQLAlchemy engine object used to talk to Snowflake
    @param keys: column names that we are trying to insert
    @param data_iter: iterator over the rows
    @return: None
    """
    sf_connection = conn.connection.connection
    df = pandas.DataFrame(data_iter, columns=keys)
    write_pandas(conn=sf_connection,
                 df=df,
                 # Note: our sqlalchemy connector creates tables case insensitively
                 table_name=table.name.upper(),
                 schema=table.schema)
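For context, a minimal end-to-end sketch of driving pd_writer through pandas' to_sql (the connection URL is a placeholder and assumes snowflake-sqlalchemy is installed):

import pandas as pd
from sqlalchemy import create_engine
from snowflake.connector.pandas_tools import pd_writer

# Placeholder URL; fill in real account credentials
engine = create_engine('snowflake://<user>:<password>@<account>/<database>/<schema>')
df = pd.DataFrame([('Mark', 10), ('Luke', 20)], columns=['NAME', 'BALANCE'])
# to_sql hands each batch of rows to pd_writer, which funnels them through write_pandas
df.to_sql('customers', engine, index=False, method=pd_writer)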
@@ -0,0 +1,59 @@
import math
from typing import Callable, Generator

import pandas
import pytest

from snowflake.connector.pandas_tools import write_pandas

MYPY = False
if MYPY:  # from typing import TYPE_CHECKING once 3.5 is deprecated
    from snowflake.connector import SnowflakeConnection

sf_connector_version_data = [
    ('snowflake-connector-python', '1.2.23'),
    ('snowflake-sqlalchemy', '1.1.1'),
    ('snowflake-connector-go', '0.0.1'),
    ('snowflake-go', '1.0.1'),
    ('snowflake-odbc', '3.12.3'),
]

sf_connector_version_df = pandas.DataFrame(sf_connector_version_data, columns=['name', 'newest_version'])


@pytest.mark.parametrize('chunk_size', [5, 4, 3, 2, 1])
@pytest.mark.parametrize('compression', ['gzip', 'snappy'])
# Note: since the files will be too small to chunk, this only tests the PUT command's syntax
@pytest.mark.parametrize('parallel', [4, 99])
def test_write_pandas(conn_cnx: Callable[..., Generator['SnowflakeConnection', None, None]],
                      compression: str,
                      parallel: int,
                      chunk_size: int):
    num_of_chunks = math.ceil(len(sf_connector_version_data) / chunk_size)
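    # e.g. 5 rows with chunk_size=2 means ceil(5 / 2) = 3 chunks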

    with conn_cnx() as cnx:  # type: SnowflakeConnection
        table_name = 'driver_versions'
        cnx.execute_string('CREATE OR REPLACE TABLE "{}"("name" STRING, "newest_version" STRING)'.format(table_name))
        try:
            success, nchunks, nrows, _ = write_pandas(cnx,
                                                      sf_connector_version_df,
                                                      table_name,
                                                      compression=compression,
                                                      parallel=parallel,
                                                      chunk_size=chunk_size)
            if num_of_chunks == 1:
                # Note: since we used one chunk, order is conserved
                assert (cnx.cursor().execute('SELECT * FROM "{}"'.format(table_name)).fetchall() ==
                        sf_connector_version_data)
            else:
                # Note: since we used more than one chunk, order is NOT conserved
                assert (set(cnx.cursor().execute('SELECT * FROM "{}"'.format(table_name)).fetchall()) ==
                        set(sf_connector_version_data))
            # Make sure all files were loaded and no error occurred
            assert success
            # Make sure overall as many rows were ingested as we tried to insert
            assert nrows == len(sf_connector_version_data)
            # Make sure we uploaded as many chunks as we intended to
            assert nchunks == num_of_chunks
        finally:
            # Quote the table name so the case-sensitive table created above is actually dropped
            cnx.execute_string('DROP TABLE IF EXISTS "{}"'.format(table_name))