diff --git a/options.py b/options.py
index 3886bb3ad..12313b78c 100644
--- a/options.py
+++ b/options.py
@@ -19,6 +19,8 @@ def __getattr__(self, item):
 
 try:
     import pandas
+    # Since we enable relative imports without dots, this import gives us issues when run from the test directory
+    from pandas import DataFrame  # NOQA
     import pyarrow
     installed_pandas = True
 except ImportError:
diff --git a/pandas_tools.py b/pandas_tools.py
new file mode 100644
index 000000000..e64dd0659
--- /dev/null
+++ b/pandas_tools.py
@@ -0,0 +1,157 @@
+import os
+import string
+import random
+from tempfile import TemporaryDirectory
+from typing import Optional, Sequence, TypeVar, Iterator, Tuple, Union, Iterable
+
+from snowflake.connector import ProgrammingError
+from snowflake.connector.options import pandas
+
+MYPY = False
+if MYPY:  # from typing import TYPE_CHECKING once 3.5 is deprecated
+    from .connection import SnowflakeConnection
+
+    try:
+        import sqlalchemy
+    except ImportError:
+        sqlalchemy = None
+
+T = TypeVar('T', bound=Sequence)
+
+
+def chunk_helper(lst: T, n: int) -> Iterator[Tuple[int, T]]:
+    """Helper generator that chunks a sequence into pieces of size n and yields the chunk index with each chunk,
+    as if enumerate were called on the sequence of chunks."""
+    for i in range(0, len(lst), n):
+        yield i // n, lst[i:i + n]
+
+
+def write_pandas(conn: 'SnowflakeConnection',
+                 df: 'pandas.DataFrame',
+                 table_name: str,
+                 database: Optional[str] = None,
+                 schema: Optional[str] = None,
+                 chunk_size: Optional[int] = None,
+                 compression: str = 'gzip',
+                 on_error: str = 'abort_statement',
+                 parallel: int = 4
+                 ) -> Tuple[bool, int, int,
+                            Sequence[Tuple[str, str, int, int, int, int, Optional[str], Optional[int],
+                                           Optional[int], Optional[str]]]]:
+    """
+    Writes a pandas DataFrame back to Snowflake efficiently by dumping the DataFrame into Parquet files, uploading
+    them and finally copying their data into the table. Returns the COPY INTO command's results to verify ingestion.
+
+    Returns whether all files were ingested correctly, the number of chunks uploaded and the number of rows ingested,
+    together with all of the COPY INTO command's output for debugging purposes.
+
+    :Example:
+
+        import pandas
+        from snowflake.connector.pandas_tools import write_pandas
+
+        df = pandas.DataFrame([('Mark', 10), ('Luke', 20)], columns=['name', 'balance'])
+        success, nchunks, nrows, _ = write_pandas(cnx, df, 'customers')
+
+    @param conn: Connection to be used to communicate with Snowflake
+    @param df: DataFrame we'd like to write back
+    @param table_name: Table name where we want to insert into
+    @param database: Database the schema and table are in, if not provided the default one will be used
+    @param schema: Schema the table is in, if not provided the default one will be used
+    @param chunk_size: Number of elements to insert at once, if not provided all elements will be dumped at once
+    @param compression: Compression used on the Parquet files, can only be gzip or snappy. Gzip supposedly gives
+        better compression, while snappy is faster. Use whichever is more appropriate.
+    @param on_error: Action to take when COPY INTO statements fail, default follows documentation at:
+        https://docs.snowflake.com/en/sql-reference/sql/copy-into-table.html#copy-options-copyoptions
+    @param parallel: Number of threads to be used when uploading chunks, default follows documentation at:
+        https://docs.snowflake.com/en/sql-reference/sql/put.html#optional-parameters
+    @return: Tuple of whether all chunks were ingested correctly, # of chunks, # of ingested rows, and ingest's output
+    """
+    if database is not None and schema is None:
+        raise ProgrammingError("Schema has to be provided to write_pandas when a database is provided")
+    # This dictionary maps the compression algorithm to Snowflake put copy into command type
+    # https://docs.snowflake.com/en/sql-reference/sql/copy-into-table.html#type-parquet
+    compression_map = {
+        'gzip': 'auto',
+        'snappy': 'snappy'
+    }
+    if compression not in compression_map.keys():
+        raise ProgrammingError("Invalid compression '{}', only acceptable values are: {}".format(
+            compression,
+            compression_map.keys()
+        ))
+    location = ((('"' + database + '".') if database else '') +
+                (('"' + schema + '".') if schema else '') +
+                ('"' + table_name + '"'))
+    if chunk_size is None:
+        chunk_size = len(df)
+    cursor = conn.cursor()
+    stage_name = None  # Forward declaration
+    while True:
+        try:
+            stage_name = ''.join(random.choice(string.ascii_lowercase) for _ in range(5))
+            cursor.execute('create temporary stage /* Python:snowflake.connector.pandas_tools.write_pandas() */ '
+                           '"{stage_name}"'.format(stage_name=stage_name), _is_internal=True).fetchall()
+            break
+        except ProgrammingError as pe:
+            if pe.msg.endswith('already exists.'):
+                continue
+            raise
+
+    with TemporaryDirectory() as tmp_folder:
+        for i, chunk in chunk_helper(df, chunk_size):
+            chunk_path = '{}/file{}.txt'.format(tmp_folder, i)
+            # Dump chunk into parquet file
+            chunk.to_parquet(chunk_path, compression=compression)
+            # Upload parquet file
+            cursor.execute('PUT /* Python:snowflake.connector.pandas_tools.write_pandas() */ '
+                           'file://{path} @"{stage_name}" PARALLEL={parallel}'.format(
+                               path=chunk_path,
+                               stage_name=stage_name,
+                               parallel=parallel
+                           ), _is_internal=True)
+            # Remove chunk file
+            os.remove(chunk_path)
+    copy_results = cursor.execute((
+        'COPY INTO {location} /* Python:snowflake.connector.pandas_tools.write_pandas() */ '
+        'FROM @"{stage_name}" FILE_FORMAT=(TYPE=PARQUET COMPRESSION={compression}) '
+        'MATCH_BY_COLUMN_NAME=CASE_SENSITIVE PURGE=TRUE ON_ERROR={on_error}'
+    ).format(
+        location=location,
+        stage_name=stage_name,
+        compression=compression_map[compression],
+        on_error=on_error
+    ), _is_internal=True).fetchall()
+    cursor.close()
+    return (all((e[1] == 'LOADED' for e in copy_results)),
+            len(copy_results),
+            sum((e[3] for e in copy_results)),
+            copy_results)
+
+
+def pd_writer(table: pandas.io.sql.SQLTable,
+              conn: Union['sqlalchemy.engine.Engine', 'sqlalchemy.engine.Connection'],
+              keys: Iterable,
+              data_iter: Iterable) -> None:
+    """
+    This is a wrapper on top of write_pandas to make it compatible with pandas' to_sql method.
+
+    :Example:
+
+        import pandas as pd
+        from snowflake.connector.pandas_tools import pd_writer
+
+        sf_connector_version_df = pd.DataFrame([('snowflake-connector-python', '1.2.23')],
+                                               columns=['NAME', 'NEWEST_VERSION'])
+        sf_connector_version_df.to_sql('driver_versions', engine, index=False, method=pd_writer)
+
+    @param table: Pandas package's table object
+    @param conn: SQLAlchemy engine or connection object used to talk to Snowflake
+    @param keys: Column names that we are trying to insert
+    @param data_iter: Iterator over the rows
+    @return: None
+    """
+    sf_connection = conn.connection.connection
+    df = pandas.DataFrame(data_iter, columns=keys)
+    write_pandas(conn=sf_connection,
+                 df=df,
+                 # Note: Our SQLAlchemy connector creates tables case insensitively, so upper-case the name here
+                 table_name=table.name.upper(),
+                 schema=table.schema)
diff --git a/test/conftest.py b/test/conftest.py
index a18a7b959..b87227631 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -9,6 +9,7 @@
 import uuid
 from contextlib import contextmanager
 from logging import getLogger
+from typing import Callable, Generator
 
 import pytest
 from generate_test_files import generate_k_lines_of_n_files
@@ -18,6 +19,10 @@
 from snowflake.connector.compat import IS_WINDOWS, TO_UNICODE
 from snowflake.connector.connection import DefaultConverterClass
 
+MYPY = False
+if MYPY:  # from typing import TYPE_CHECKING once 3.5 is deprecated
+    from snowflake.connector import SnowflakeConnection
+
 try:
     from parameters import CONNECTION_PARAMETERS_S3
 except ImportError:
@@ -257,7 +262,7 @@ def fin():
     request.addfinalizer(fin)
 
 
-def create_connection(**kwargs):
+def create_connection(**kwargs) -> 'SnowflakeConnection':
     """
     Creates a connection using the parameters defined in JDBC connect string
     """
@@ -268,7 +273,7 @@ def create_connection(**kwargs):
 
 
 @contextmanager
-def db(**kwargs):
+def db(**kwargs) -> Generator['SnowflakeConnection', None, None]:
     if not kwargs.get(u'timezone'):
         kwargs[u'timezone'] = u'UTC'
     if not kwargs.get(u'converter_class'):
@@ -307,7 +312,7 @@ def fin():
 
 
 @pytest.fixture()
-def conn_cnx():
+def conn_cnx() -> Callable[..., Generator['SnowflakeConnection', None, None]]:
     return db
diff --git a/test/pandas/test_pandas_tools.py b/test/pandas/test_pandas_tools.py
new file mode 100644
index 000000000..989756fd7
--- /dev/null
+++ b/test/pandas/test_pandas_tools.py
@@ -0,0 +1,59 @@
+import math
+from typing import Callable, Generator
+
+import pandas
+import pytest
+
+from snowflake.connector.pandas_tools import write_pandas
+
+MYPY = False
+if MYPY:  # from typing import TYPE_CHECKING once 3.5 is deprecated
+    from snowflake.connector import SnowflakeConnection
+
+sf_connector_version_data = [
+    ('snowflake-connector-python', '1.2.23'),
+    ('snowflake-sqlalchemy', '1.1.1'),
+    ('snowflake-connector-go', '0.0.1'),
+    ('snowflake-go', '1.0.1'),
+    ('snowflake-odbc', '3.12.3'),
+]
+
+sf_connector_version_df = pandas.DataFrame(sf_connector_version_data, columns=['name', 'newest_version'])
+
+
+@pytest.mark.parametrize('chunk_size', [5, 4, 3, 2, 1])
+@pytest.mark.parametrize('compression', ['gzip', 'snappy'])
+# Note: since the file will be too small to chunk, this only tests the PUT command's syntax
+@pytest.mark.parametrize('parallel', [4, 99])
+def test_write_pandas(conn_cnx: Callable[..., Generator['SnowflakeConnection', None, None]],
+                      compression: str,
+                      parallel: int,
+                      chunk_size: int):
+    num_of_chunks = math.ceil(len(sf_connector_version_data) / chunk_size)
+
+    with conn_cnx() as cnx:  # type: SnowflakeConnection
+        table_name = 'driver_versions'
+        cnx.execute_string('CREATE OR REPLACE TABLE "{}"("name" STRING, "newest_version" STRING)'.format(table_name))
+        try:
+            success, nchunks, nrows, _ = write_pandas(cnx,
+                                                      sf_connector_version_df,
+                                                      table_name,
+                                                      compression=compression,
+                                                      parallel=parallel,
+                                                      chunk_size=chunk_size)
+            if num_of_chunks == 1:
+                # Note: since we used one chunk, order is conserved
+                assert (cnx.cursor().execute('SELECT * FROM "{}"'.format(table_name)).fetchall() ==
+                        sf_connector_version_data)
+            else:
+                # Note: since we used more than one chunk, order is NOT conserved
+                assert (set(cnx.cursor().execute('SELECT * FROM "{}"'.format(table_name)).fetchall()) ==
+                        set(sf_connector_version_data))
+            # Make sure all files were loaded and no error occurred
+            assert success
+            # Make sure overall as many rows were ingested as we tried to insert
+            assert nrows == len(sf_connector_version_data)
+            # Make sure we uploaded as many chunks as we wanted to
+            assert nchunks == num_of_chunks
+        finally:
+            cnx.execute_string('DROP TABLE IF EXISTS "{}"'.format(table_name))
diff --git a/tox.ini b/tox.ini
index df9afef4c..ec0c2906d 100644
--- a/tox.ini
+++ b/tox.ini
@@ -95,4 +95,4 @@ include_trailing_comma = True
 force_grid_wrap = 0
 line_length = 120
 known_first_party =snowflake,parameters
-known_third_party =Cryptodome,OpenSSL,asn1crypto,azure,boto3,botocore,certifi,cryptography,dateutil,generate_test_files,jwt,mock,numpy,pendulum,pyasn1,pyasn1_modules,pytest,pytz,requests,setuptools,urllib3
+known_third_party =Cryptodome,OpenSSL,asn1crypto,azure,boto3,botocore,certifi,cryptography,dateutil,generate_test_files,jwt,mock,numpy,pandas,pendulum,pyasn1,pyasn1_modules,pytest,pytz,requests,setuptools,urllib3
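
For reviewers, here is a minimal end-to-end sketch of how the new write_pandas API in this patch would be used. The account, credentials and the CUSTOMERS table below are placeholders, and the target table is assumed to already exist with column names matching the DataFrame (COPY INTO runs with MATCH_BY_COLUMN_NAME=CASE_SENSITIVE); only the write_pandas call itself comes from this change.

# Illustrative sketch only; connection parameters and the CUSTOMERS table are placeholders, not part of this patch.
import pandas

import snowflake.connector
from snowflake.connector.pandas_tools import write_pandas

df = pandas.DataFrame([('Mark', 10), ('Luke', 20)], columns=['name', 'balance'])

# Placeholder credentials; CUSTOMERS is assumed to exist as CUSTOMERS("name" STRING, "balance" NUMBER).
cnx = snowflake.connector.connect(
    account='my_account',
    user='my_user',
    password='my_password',
    database='my_database',
    schema='my_schema',
)
try:
    # write_pandas dumps the DataFrame to Parquet files, PUTs them onto a temporary stage,
    # then COPYs them into the table and returns the COPY INTO output for verification.
    success, nchunks, nrows, _ = write_pandas(cnx, df, 'CUSTOMERS', chunk_size=1, compression='gzip')
    assert success and nrows == len(df)
finally:
    cnx.close()

The pd_writer helper is used as shown in its docstring example: pass it as the method argument to DataFrame.to_sql together with a snowflake-sqlalchemy engine, and it delegates to write_pandas under the hood.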