"""Helpers for downloading real-world example datasets."""

import urllib.request
from io import StringIO

import numpy as np

# Processed datasets are from
# https://github.com/cmu-phil/example-causal-datasets/tree/main
_DATASET_URLS = {
    'sachs': 'https://raw.githubusercontent.com/cmu-phil/example-causal-datasets/main/real/sachs/data/sachs.2005.continuous.txt',
    'boston_housing': 'https://raw.githubusercontent.com/cmu-phil/example-causal-datasets/main/real/boston-housing/data/boston-housing.continuous.txt',
    'airfoil': 'https://raw.githubusercontent.com/cmu-phil/example-causal-datasets/main/real/airfoil-self-noise/data/airfoil-self-noise.continuous.txt',
}


def load_dataset(dataset_name, *, timeout=30.0):
    """
    Load a real-world dataset by name.

    The dataset is downloaded on every call from
    https://github.com/cmu-phil/example-causal-datasets/tree/main

    Parameters
    ----------
    dataset_name : str
        One of 'sachs', 'boston_housing', 'airfoil'.
    timeout : float or None, optional
        Timeout in seconds passed to ``urllib.request.urlopen`` so that a
        stalled connection cannot block the caller forever. ``None`` disables
        the timeout. Default is 30 seconds.

    Returns
    -------
    data : np.ndarray
        Numeric values parsed from every line after the header line.
    labels : list of str
        Column names taken from the tab-separated header (first) line.

    Raises
    ------
    ValueError
        If ``dataset_name`` is not one of the supported names.
    OSError
        If the download fails or times out (``urllib.error.URLError`` is a
        subclass of ``OSError``).
    """
    if dataset_name not in _DATASET_URLS:
        raise ValueError(
            f"Invalid dataset name {dataset_name!r}; "
            f"expected one of {sorted(_DATASET_URLS)}"
        )

    url = _DATASET_URLS[dataset_name]
    with urllib.request.urlopen(url, timeout=timeout) as response:
        content = response.read().decode('utf-8')

    # Wrap the text in StringIO so numpy can read it like a file. The header
    # and the body are parsed separately because the header is text while the
    # remaining rows are numeric.
    header = np.genfromtxt(StringIO(content), delimiter="\t", dtype=str, max_rows=1)
    data = np.loadtxt(StringIO(content), skiprows=1)

    # A single-column header comes back as a 0-d array whose tolist() would be
    # a bare string; atleast_1d guarantees a list in every case.
    labels = np.atleast_1d(header).tolist()

    return data, labels