# Load and store the datasets

Because other libraries such as SDV use pandas, we save a copy of each dataset as a CSV file so that it can be reused in later steps.

In [1]:
import os
from pathlib import Path

from common.config import *

In [2]:
# Read data from SDV
# Base URL hosting the SDV dataset archives; each archive is fetched from
# "<SDV_BASE_URL>/<dataset name>.zip" by download_sdv_dataset.
SDV_BASE_URL = "https://sdv-datasets.s3.amazonaws.com"

# Dataset configuration, looked up by dataset name. `get_datsets_config` is
# not defined in this notebook -- presumably it comes from the star import of
# `common.config`; confirm there.
datasets_config = get_datsets_config()

import urllib.request
import zipfile

from pandas import read_csv

from ydata.connectors import GCSConnector
from ydata.connectors.filetype import FileType
from ydata.dataset import Dataset
from ydata.utils.formats import read_json


def get_from_gcs(gs_filepath, credential_path=CREDENTIALS_PATH):
    """Read a CSV file through the GCS connector.

    Args:
        gs_filepath: Location of the file, handed unchanged to
            ``GCSConnector.read_file``.
        credential_path: Path of the JSON key file, loaded with ``read_json``
            and given to the connector as ``keyfile_dict``.

    Returns:
        Whatever ``GCSConnector.read_file`` returns for ``FileType.CSV``.
    """
    keyfile = read_json(credential_path)
    gcs = GCSConnector("ydatasynthetic", keyfile_dict=keyfile)
    csv_content = gcs.read_file(gs_filepath, file_type=FileType.CSV)
    return csv_content

def download_sdv_dataset(dataset_info, output_folder=Path(DOWNLOADED_DATASETS)):
    """Download and unpack the zip archive of an SDV dataset.

    The archive ``<SDV_BASE_URL>/<name>.zip`` is saved as
    ``<output_folder>/<name>.zip`` and then extracted into ``output_folder``.
    Nothing happens when the dataset's ``origin`` is not ``'sdv'``.

    Args:
        dataset_info: Mapping describing the dataset; the ``'origin'`` and
            ``'name'`` entries are read.
        output_folder: Folder receiving the archive and its extracted
            content. It is created when missing; a ``str`` is accepted too.
    """
    if dataset_info.get('origin') != 'sdv':
        return

    # Accept plain strings and make sure the target folder exists:
    # urlretrieve cannot write into a missing directory.
    output_folder = Path(output_folder)
    output_folder.mkdir(parents=True, exist_ok=True)

    archive_name = f"{dataset_info['name']}.zip"
    output_file_path = output_folder / archive_name
    print(output_file_path)
    urllib.request.urlretrieve(f"{SDV_BASE_URL}/{archive_name}", output_file_path)
    with zipfile.ZipFile(output_file_path, 'r') as zip_ref:
        zip_ref.extractall(output_folder)

def get_sdv_dataset(dataset_info, as_fabric: bool = True, output_folder=Path(DOWNLOADED_DATASETS)):
    """Load an SDV dataset from disk, downloading it first when needed.

    The data is expected at ``<output_folder>/<name>/<name>.csv``. When that
    file is missing, ``download_sdv_dataset`` is called to fetch it.

    Args:
        dataset_info: Mapping describing the dataset; the ``'name'`` entry is
            read here and the mapping is forwarded to ``download_sdv_dataset``.
        as_fabric: When true, wrap the loaded frame in a ``Dataset``;
            otherwise return the pandas DataFrame.
        output_folder: Folder holding (or receiving) the downloaded datasets.

    Returns:
        A ``Dataset`` when ``as_fabric`` is true, else a pandas DataFrame.
    """
    output_folder = Path(output_folder)
    name = f"{dataset_info['name']}"
    data_file = output_folder / name / f"{name}.csv"
    if not data_file.is_file():
        # Download into the folder we are about to read from; the default
        # folder is only correct when the caller did not override it.
        download_sdv_dataset(dataset_info, output_folder=output_folder)
    data = read_csv(data_file)
    return Dataset(data) if as_fabric else data
            
def get_fabric_dataset(dataset_info: dict, credential_path=None):
    """Read a dataset either from a URL or from a platform data source.

    Args:
        dataset_info: Mapping describing the dataset. When it has a ``'url'``
            entry the file is read with ``get_from_gcs``; otherwise its
            ``'uuid'`` and ``'namespace'`` entries identify a data source.
        credential_path: Path of the credentials file forwarded to
            ``get_from_gcs``. When ``None``, the default of ``get_from_gcs``
            is used.

    Returns:
        The object returned by ``get_from_gcs`` or by ``datasource.read()``.
    """
    if 'url' in dataset_info:
        # For now I assume GCS
        if credential_path is None:
            # Do not forward None explicitly: it would override the default
            # credentials path of get_from_gcs.
            return get_from_gcs(dataset_info['url'])
        return get_from_gcs(dataset_info['url'], credential_path=credential_path)

    # Assume it is from the platform.
    # NOTE(review): DataSources is not imported in the visible cells --
    # presumably provided by the star import of common.config; confirm.
    datasource = DataSources.get(uid=dataset_info['uuid'], namespace=dataset_info['namespace'])
    return datasource.read()
                
def get_dataset(name: str, as_fabric=True, credential_path=CREDENTIALS_PATH, output_folder=Path(DOWNLOADED_DATASETS)):
    """Load a configured dataset by name.

    Args:
        name: Key of the dataset in the configuration returned by
            ``get_datsets_config``.
        as_fabric: When true, return the fabric dataset object; otherwise
            return its pandas form.
        credential_path: Credentials file used for datasets whose origin is
            missing or ``'fabric'``.
        output_folder: Download folder used for datasets of origin ``'sdv'``.

    Returns:
        The loaded dataset, in fabric or pandas form depending on
        ``as_fabric``.

    Raises:
        Exception: If ``name`` is not configured or its origin is unknown.
    """
    config = get_datsets_config()
    if name not in config:
        raise Exception("Unknown dataset")

    info = config[name]
    origin = info.get('origin')

    if origin == 'sdv':
        return get_sdv_dataset(info, as_fabric=as_fabric, output_folder=output_folder)

    if origin not in (None, 'fabric'):
        raise Exception("Unknown origin")

    # A missing origin is treated the same as 'fabric'.
    fabric_dataset = get_fabric_dataset(info, credential_path=credential_path)
    if as_fabric:
        return fabric_dataset
    return fabric_dataset.to_pandas()

In [3]:
datsets_config = get_datsets_config()

# Export every configured dataset to <DATASET_PATH>/<name>.csv, skipping the
# ones that were already written by a previous run.
for name in datsets_config:
    print(f'# Get {name}')
    dataset_path = Path(DATASET_PATH) / f'{name}.csv'
    if os.path.isfile(dataset_path):
        print(' -> Skip as already exists localy...')
        continue

    dataset = get_dataset(name).to_pandas()
    # Drop stray whitespace around the column names before saving.
    dataset.columns = [column.strip() for column in dataset.columns]
    dataset.to_csv(dataset_path, index=False)

# Get sdv.adult
 -> Skip as already exists localy...
