diff --git a/docs/en/docs/index.md b/docs/en/docs/index.md index 6d2fa35..6eb27cc 100644 --- a/docs/en/docs/index.md +++ b/docs/en/docs/index.md @@ -53,7 +53,7 @@ read and write databases and other datasource - get_sql(sql) same as read_sql - exe_sql(sql) - to_db(df, tb_name[, how, fast_load, chunksize]) - - df: pd.DataFrame()对象 + - df: pd.DataFrame() - tb_name: target table name - how: option, default append, - fast_load: option, default False; only support MySQL and Clickhouse, export pd.DataFrame to a temp csv then import to db diff --git a/pyqueen/io/data_source.py b/pyqueen/io/data_source.py index 79e4c04..6fd7810 100644 --- a/pyqueen/io/data_source.py +++ b/pyqueen/io/data_source.py @@ -1,3 +1,4 @@ +import inspect import warnings import pandas as pd from pyqueen.io.ds_plugin import DsLog, DsPlugin, DsConfig, DsExt @@ -19,26 +20,30 @@ } __support_conn_type__ = tuple(__conn_type_mapping__.keys()) -print(','.join(__support_conn_type__)) class DataSource(DsLog, DsPlugin, DsConfig, DsExt): - def __init__(self, - conn_type=None, - host=None, - username=None, - password=None, - port=None, - db_name=None, - db_type=None, - file_path=None, - jdbc_url=None, - cache_dir=None, - keep_conn=False, - charset=None, - conn_package=None - ): + def __init__(self, conn_type=None, host=None, username=None, password=None, port=None, db_name=None, db_type=None, file_path=None, jdbc_url=None, + cache_dir=None, keep_conn=False, charset=None, conn_package=None): super().__init__() + + self.conn_type = conn_type + self.host = host + self.username = username + self.password = password + self.port = port + self.db_name = db_name + self.db_type = db_type + self.file_path = file_path + self.jdbc_url = jdbc_url + self.cache_dir = cache_dir + self.keep_conn = keep_conn + self.charset = charset + self.conn_package = conn_package + + init_params = {k: getattr(self, k) for k in list(inspect.signature(self.__init__).parameters.keys())} + init_params = {k: v for k, v in init_params.items() if v is not None} + if conn_type is None and db_type is None: raise Exception('missing conn_type! supported conn_type:' + ','.join(__support_conn_type__)) if conn_type is None and db_type is not None: @@ -48,20 +53,16 @@ def __init__(self, operator = __conn_type_mapping__[conn_type] if conn_type in ('mysql', 'mssql', 'oracle', 'clickhouse', 'sqlite', 'postgresql', 'pgsql', 'jdbc'): from pyqueen.io import sqldb - self.operator = getattr(sqldb, operator)( - host=host, - username=username, - password=password, - port=port, - db_name=db_name, - jdbc_url=jdbc_url, - keep_conn=keep_conn, - charset=charset, - conn_package=conn_package - ) + _class = getattr(sqldb, operator) + req_params = list(inspect.signature(_class).parameters.keys()) + run_param = {k: v for k, v in init_params.items() if k in req_params} + self.operator = _class(**run_param) elif conn_type in ('excel',): from pyqueen.io import excel - self.operator = getattr(excel, operator)(file_path=file_path) + _class = getattr(excel, operator) + req_params = list(inspect.signature(_class).parameters.keys()) + run_param = {k: v for k, v in init_params.items() if k in req_params} + self.operator = _class(**run_param) elif conn_type in ('redis',): from pyqueen.io import kvdb self.operator = getattr(kvdb, operator)(conn_type=conn_type, host=host, port=port, db_name=db_name, keep_conn=keep_conn) diff --git a/pyqueen/io/sqldb.py b/pyqueen/io/sqldb.py index 9816056..8b56e5f 100644 --- a/pyqueen/io/sqldb.py +++ b/pyqueen/io/sqldb.py @@ -221,11 +221,10 @@ def to_db(self, df, tb_name, how, chunksize, fast_load): super().to_db(df=df, tb_name=tb_name, how=how, chunksize=chunksize) -class Sqlite: - def __init__(self, file_path=None, keep_conn=False, jdbc_url=None): +class Sqlite(SqlDB): + def __init__(self, file_path=None, jdbc_url=None): if jdbc_url is None: jdbc_url = 'sqlite:///%s' % str(file_path) super().__init__( - jdbc_url=jdbc_url, - keep_conn=keep_conn + jdbc_url=jdbc_url ) diff --git a/setup.py b/setup.py index 77104d4..a7b7424 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name='pyqueen', - version='1.1.0', + version='1.1.1', url='https://github.com/ts7ming/pyqueen.git', description='Rule your Data', long_description=open("README.md", encoding='utf-8').read(), diff --git a/test/test_data_source.py b/test/test_data_source.py index becbfba..a1fd374 100644 --- a/test/test_data_source.py +++ b/test/test_data_source.py @@ -1,3 +1,5 @@ +import pandas as pd + from pyqueen import DataSource @@ -28,6 +30,33 @@ def test_mysql(): else: print('mysql read_sql not match') +def test_mysql2(): + ds = DataSource(conn_type='mysql', host='olympus', username='root', password='54maqiming', port=3306, db_name='cdc_source') + + sql = 'DROP TABLE IF EXISTS t_table' + ds.exe_sql(sql) + + sql = ''' + CREATE TABLE t_table ( + id INT NOT NULL, + name varchar(100) NULL + ) + ''' + ds.exe_sql(sql) + print('mysql exe_sql is ok') + + import pandas as pd + + df = pd.DataFrame({'id': [1, 2, 3, 4], 'name': ['libnl', 'agds', 'gfrt', 'hhg']}) + ds.to_db(df, tb_name='t_table', fast_load=False) + print('mysql to_db is ok') + + v = ds.get_sql('select count(1) as v from t_table') + if int(v.values[0][0]) == 4: + print('mysql read_sql is ok') + else: + print('mysql read_sql not match') + def test_postgresql(): ds = DataSource(conn_type='pgsql', host='localhost', username='postgres', password='1qaz2wsx!', port=5432, db_name='postgres') @@ -74,7 +103,33 @@ def test_redis(): print('redis get not match') +def test_sqlite(): + ds = DataSource(conn_type='sqlite', file_path='tst.db') + import pandas as pd + + df = pd.DataFrame({'dd': [1, 2, 3], 'gg': ['daf', 'gfytr', 'eee']}) + ds.to_db(df, tb_name='hhhhh') + print('sqlite: to_db is ok') + sql = 'select max(dd) as f from hhhhh' + v = int(ds.get_value(sql)) + if v == 3: + print('sqlite: read_sql is ok') + +def test_excel(): + ds = DataSource(conn_type='excel', file_path='./tst.xlsx') + import pandas as pd + + df = pd.DataFrame({'dd': [1, 2, 3], 'gg': ['daf', 'gfytr', 'eee']}) + ds.to_excel(sheet_list=[[df,'ma']]) + print('excel: to_excel is ok') + + + + if __name__ == '__main__': - test_mysql() - test_postgresql() - test_redis() + # test_mysql() + test_mysql2() + # test_postgresql() + # test_redis() + test_sqlite() + test_excel()