Merge pull request #1515 from henrykironde/kagle-master
Update and clean Kaggle support PR
henrykironde committed Sep 16, 2020
2 parents 89fd22f + 0705cc2 commit b8eaf50
Showing 5 changed files with 122 additions and 30 deletions.
1 change: 1 addition & 0 deletions Dockerfile
@@ -39,6 +39,7 @@ RUN pip install pylint
RUN pip install flake8 -U
RUN pip install h5py
RUN pip install Pillow
RUN pip install kaggle

# Install Postgis after Python is setup
RUN apt-get install -y --force-yes postgis
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,6 +1,7 @@
future
xlrd>=0.7
argcomplete
kaggle
PyMySQL>=0.4
psycopg2-binary
numpydoc
1 change: 1 addition & 0 deletions retriever/lib/defaults.py
@@ -13,6 +13,7 @@
RETRIEVER_REPOSITORY = RETRIEVER_MASTER_BRANCH
ENCODING = 'utf-8'
HOME_DIR = os.path.expanduser('~/.retriever/')
KAGGLE_TOKEN_PATH = os.path.expanduser('~/.kaggle/kaggle.json')
RETRIEVER_DIR = 'retriever'
if os.path.exists(os.path.join(HOME_DIR, 'retriever_path.txt')):
    with open(os.path.join(HOME_DIR, 'retriever_path.txt'), 'r') as f:
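
A note on the new KAGGLE_TOKEN_PATH default: the Kaggle API client reads credentials either from this token file or from the KAGGLE_USERNAME and KAGGLE_KEY environment variables. A minimal sketch of provisioning the token file, assuming an API token generated from a Kaggle account page (the username and key values below are placeholders, not part of this commit):

import json
import os
import stat

# Hypothetical placeholder credentials; substitute the values from the
# kaggle.json file generated by "Create New API Token" on kaggle.com.
credentials = {"username": "your_username", "key": "your_api_key"}

token_path = os.path.expanduser('~/.kaggle/kaggle.json')
os.makedirs(os.path.dirname(token_path), exist_ok=True)
with open(token_path, 'w') as f:
    json.dump(credentials, f)
# The Kaggle client warns when the token is readable by other users.
os.chmod(token_path, stat.S_IRUSR | stat.S_IWUSR)
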
121 changes: 92 additions & 29 deletions retriever/lib/engine.py
Expand Up @@ -18,7 +18,7 @@
from tqdm import tqdm

from retriever.lib.cleanup import no_cleanup
from retriever.lib.defaults import DATA_DIR, DATA_SEARCH_PATHS, DATA_WRITE_PATH, ENCODING
from retriever.lib.defaults import DATA_DIR, DATA_SEARCH_PATHS, DATA_WRITE_PATH, ENCODING, KAGGLE_TOKEN_PATH
from retriever.lib.tools import (
open_fr,
open_fw,
@@ -283,8 +283,8 @@ def auto_get_datatypes(self, pk, source, columns):
                if column_types[i][0] == 'double':
                    try:
                        val = float(val)
                        if "e" in str(val) or ("." in str(val) and len(
                                str(val).split(".")[1]) > 10):
                            column_types[i] = ("decimal", "50,30")
                    except Exception as _:
                        column_types[i] = ('char', max_lengths[i])
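
For reference, the reflowed condition above promotes a float to the wider decimal(50, 30) type when its string form uses scientific notation or carries more than ten fractional digits. A standalone re-implementation of the check, as an illustrative sketch only:

def needs_decimal(val):
    """Mirror the decimal-promotion test in auto_get_datatypes (sketch)."""
    val = float(val)
    return "e" in str(val) or ("." in str(val) and len(str(val).split(".")[1]) > 10)

print(needs_decimal("1.5"))               # False -> stays double
print(needs_decimal("0.12345678901234"))  # True  -> decimal(50, 30)
print(needs_decimal("1e-10"))             # True  -> decimal(50, 30)
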
Expand Down Expand Up @@ -509,6 +509,62 @@ def download_file(self, url, filename):
progbar.close()
return True

    def download_from_kaggle(
            self,
            data_source,
            dataset_name,
            archive_dir,
            archive_full_path,
    ):
        """Download files from Kaggle into the raw data directory"""
        kaggle_token = os.path.isfile(KAGGLE_TOKEN_PATH)
        kaggle_username = os.getenv('KAGGLE_USERNAME', "").strip()
        kaggle_key = os.getenv('KAGGLE_KEY', "").strip()

        if kaggle_token or (kaggle_username and kaggle_key):
            from kaggle.api.kaggle_api_extended import KaggleApi
            from kaggle.rest import ApiException
        else:
            print(f"Could not find kaggle.json. Make sure it is located at "
                  f"{KAGGLE_TOKEN_PATH}, or that the KAGGLE_USERNAME and "
                  f"KAGGLE_KEY environment variables are set. For more "
                  f"information, see "
                  f"https://github.com/Kaggle/kaggle-api#api-credentials")
            return

        api = KaggleApi()
        api.authenticate()

        if data_source == "dataset":
            archive_full_path = archive_full_path + ".zip"
            try:
                api.dataset_download_files(dataset=dataset_name,
                                           path=archive_dir,
                                           quiet=False,
                                           force=True)
                file_names = self.extract_zip(archive_full_path, archive_dir)
            except ApiException:
                print(f"The dataset '{dataset_name}' isn't currently available "
                      f"in the Retriever.\nRun 'retriever ls' to see a "
                      f"list of currently available datasets.")
                return []

        else:
            archive_full_path = archive_full_path.replace("kaggle:competition:",
                                                          "") + ".zip"
            try:
                api.competition_download_files(competition=dataset_name,
                                               path=archive_dir,
                                               quiet=False,
                                               force=True)
                file_names = self.extract_zip(archive_full_path, archive_dir)
            except ApiException:
                print(f"The competition '{dataset_name}' isn't currently available "
                      f"in the Retriever.\nRun 'retriever ls' to see a "
                      f"list of currently available datasets.")
                return []

        return file_names
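
As a usage illustration, here is a minimal sketch of calling the new method directly; the engine instance, the target directory, and the uciml/iris identifier mirror the tests added below and are assumptions, not part of this commit:

import os
from retriever.lib.engine import Engine

engine = Engine()
archive_dir = os.path.expanduser('~/.retriever/raw_data/iris/')  # hypothetical location
file_names = engine.download_from_kaggle(
    data_source="dataset",      # "dataset" or "competition"
    dataset_name="uciml/iris",  # Kaggle <owner>/<dataset> identifier
    archive_dir=archive_dir,    # where the zip archive is downloaded
    archive_full_path=os.path.join(archive_dir, "iris"),  # ".zip" is appended internally
)
# Returns the extracted file names, [] on an ApiException,
# or None when no Kaggle credentials are configured.
print(file_names)
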

    def download_files_from_archive(
            self,
            url,
@@ -534,34 +590,41 @@ def download_files_from_archive(
        if not os.path.exists(archive_dir):
            os.makedirs(archive_dir)

        if hasattr(self.script, "kaggle"):
            file_names = self.download_from_kaggle(data_source=self.script.data_source,
                                                   dataset_name=url,
                                                   archive_dir=archive_dir,
                                                   archive_full_path=archive_full_path)

        else:
            if not file_names:
                self.download_file(url, archive_name)
                if archive_type in ('tar', 'tar.gz'):
                    file_names = self.extract_tar(archive_full_path, archive_dir,
                                                  archive_type)
                elif archive_type == 'zip':
                    file_names = self.extract_zip(archive_full_path, archive_dir)
                elif archive_type == 'gz':
                    file_names = self.extract_gz(archive_full_path, archive_dir)
                return file_names

            archive_downloaded = bool(self.data_path)
            for file_name in file_names:
                archive_full_path = self.format_filename(archive_name)
                if not self.find_file(os.path.join(archive_dir, file_name)):
                    # if no local copy, download the data
                    self.create_raw_data_dir()
                    if not archive_downloaded:
                        self.download_file(url, archive_name)
                        archive_downloaded = True
                    if archive_type == 'zip':
                        self.extract_zip(archive_full_path, archive_dir, file_name)
                    elif archive_type == 'gz':
                        self.extract_gz(archive_full_path, archive_dir, file_name)
                    elif archive_type in ('tar', 'tar.gz'):
                        self.extract_tar(archive_full_path, archive_dir, archive_type,
                                         file_name)
        return file_names

    def drop_statement(self, object_type, object_name):
        """Return drop table or database SQL statement."""
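
One detail worth noting in the new download_from_kaggle method above: competition references arrive with a "kaggle:competition:" prefix on the archive path, which is stripped before ".zip" is appended, while plain dataset paths only get the extension. A small sketch of that normalization in isolation (the sample values are hypothetical):

def kaggle_archive_path(archive_full_path):
    # Mirrors the path handling in download_from_kaggle (sketch).
    return archive_full_path.replace("kaggle:competition:", "") + ".zip"

print(kaggle_archive_path("kaggle:competition:titanic"))  # titanic.zip
print(kaggle_archive_path("raw_data/iris"))               # raw_data/iris.zip
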
28 changes: 27 additions & 1 deletion test/test_retriever.py
Expand Up @@ -26,7 +26,7 @@
from retriever.lib.engine_tools import create_file
from retriever.lib.engine_tools import file_2list
from retriever.lib.datapackage import clean_input, is_empty
from retriever.lib.defaults import HOME_DIR, RETRIEVER_DATASETS, RETRIEVER_REPOSITORY
from retriever.lib.defaults import HOME_DIR, RETRIEVER_DATASETS, RETRIEVER_REPOSITORY, KAGGLE_TOKEN_PATH

# Create simple engine fixture
test_engine = Engine()
@@ -80,6 +80,12 @@
tar_gz_url = os.path.normpath(achive_url.format(file_path='sample_tar.tar.gz'))
gz_url = os.path.normpath(achive_url.format(file_path='sample.gz'))

kaggle_datasets = [
    # test_name, data_source, dataset_identifier, repath, expected
    ("kaggle_competition", "competition", "titanic", "titanic",
     ["gender_submission.csv", "test.csv", "train.csv"]),
    ("kaggle_dataset", "dataset", "uciml/iris", "iris", ['Iris.csv', 'database.sqlite']),
    ("kaggle_unknown", "competition", "non_existent_dataset", "non_existent_dataset", []),
]

def setup_module():
""""Automatically sets up the environment before the module runs.
@@ -244,6 +250,26 @@ def test_drop_statement():
        'TABLE', 'tablename') == "DROP TABLE IF EXISTS tablename"


@pytest.mark.parametrize("test_name, data_source, dataset_identifier, repath, expected",
                         kaggle_datasets)
def test_download_kaggle_dataset(test_name, data_source, dataset_identifier, repath, expected):
    """Test downloading a dataset from Kaggle."""
    setup_functions()
    files = test_engine.download_from_kaggle(
        data_source=data_source,
        dataset_name=dataset_identifier,
        archive_dir=raw_dir_files,
        archive_full_path=os.path.join(raw_dir_files, repath)
    )

    kaggle_token = os.path.isfile(KAGGLE_TOKEN_PATH)
    kaggle_username = os.getenv('KAGGLE_USERNAME', "").strip()
    kaggle_key = os.getenv('KAGGLE_KEY', "").strip()
    if kaggle_token or (kaggle_username and kaggle_key):
        assert files == expected
    else:
        assert files is None
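
Because the assertions above pass trivially when no credentials are configured, one option (not part of this commit) would be to skip the Kaggle tests outright; a sketch using pytest's skipif, reusing the same credential check and the os, pytest, and KAGGLE_TOKEN_PATH imports already present in this module:

HAS_KAGGLE_CREDENTIALS = os.path.isfile(KAGGLE_TOKEN_PATH) or (
    os.getenv('KAGGLE_USERNAME', '').strip() and os.getenv('KAGGLE_KEY', '').strip())

@pytest.mark.skipif(not HAS_KAGGLE_CREDENTIALS,
                    reason="no Kaggle API credentials configured")
def test_download_kaggle_dataset_with_credentials():
    ...  # hypothetical credential-gated variant of the test above
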


def test_download_archive_gz_known():
    """Download and extract known files
