Skip to content

Commit

Permalink
Merge pull request #76 from MarcCoru/master
Browse files Browse the repository at this point in the history
Fixed parsing of UEA UCR datasets
  • Loading branch information
rtavenar committed Oct 18, 2018
2 parents 9ea50ae + 3c157e8 commit f414a0c
Showing 1 changed file with 36 additions and 24 deletions.
60 changes: 36 additions & 24 deletions tslearn/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,27 @@ def extract_from_zip_url(url, target_dir=None, verbose=False):
sys.stderr.write("Corrupted zip file encountered, aborting.\n")
return None

def in_file_string_replace(filename, old_string, new_string):
    """Replace every occurrence of a string within a text file, in place.

    It is used to fix typos in a downloaded csv file.
    The code was modified from
    "https://stackoverflow.com/questions/4128144/replace-string-within-file-contents"

    Parameters
    ----------
    filename : str
        Path to the file where strings should be replaced.
    old_string : str
        The string to be replaced in the file.
    new_string : str
        The new string that will replace old_string.
    """
    with open(filename) as f:
        contents = f.read()

    # Compute the new contents *before* reopening the file for writing:
    # mode 'w' truncates the file immediately, so a failure in replace()
    # (e.g. a non-string argument) must not happen while the file is empty.
    updated = contents.replace(old_string, new_string)

    # Nothing to fix: leave the file untouched rather than rewriting it.
    if updated == contents:
        return

    with open(filename, 'w') as f:
        f.write(updated)


class UCR_UEA_datasets(object):
"""A convenience class to access UCR/UEA time series datasets.
Expand Down Expand Up @@ -89,29 +110,14 @@ def __init__(self, use_cache=True):
url_baseline = "http://www.timeseriesclassification.com/singleTrainTest.csv"
self._baseline_scores_filename = os.path.join(self._data_dir, os.path.basename(url_baseline))
urlretrieve(url_baseline, self._baseline_scores_filename)

# fix typos in that CSV to match with the name in the download link
in_file_string_replace(self._baseline_scores_filename, "CinCECGtorso", "CinCECGTorso")
in_file_string_replace(self._baseline_scores_filename, "StarlightCurves", "StarLightCurves")
except:
self._baseline_scores_filename = None

self._ignore_list = ["Data Descriptions"]
# File names for datasets for which it is not obvious
self._filenames = {"CinCECGtorso": "CinC_ECG_torso",
"CricketX": "Cricket_X",
"CricketY": "Cricket_Y",
"CricketZ": "Cricket_Z",
"FiftyWords": "50words",
"Lightning2": "Lighting2",
"Lightning7": "Lighting7",
"NonInvasiveFatalECGThorax1": "NonInvasiveFetalECG_Thorax1",
"NonInvasiveFatalECGThorax2": "NonInvasiveFetalECG_Thorax2",
"GunPoint": "Gun_Point",
"SonyAIBORobotSurface1": "SonyAIBORobotSurface",
"SonyAIBORobotSurface2": "SonyAIBORobotSurfaceII",
"SyntheticControl": "synthetic_control",
"TwoPatterns": "Two_Patterns",
"UWaveGestureLibraryX": "UWaveGestureLibrary_X",
"UWaveGestureLibraryY": "UWaveGestureLibrary_Y",
"UWaveGestureLibraryZ": "UWaveGestureLibrary_Z",
"WordSynonyms": "WordsSynonyms"}

def baseline_accuracy(self, list_datasets=None, list_methods=None):
"""Report baseline performances as provided by UEA/UCR website.
Expand Down Expand Up @@ -209,23 +215,29 @@ def load_dataset(self, dataset_name):
(1000, 128, 1)
>>> print(y_train.shape)
(1000,)
>>> X_train, y_train, X_test, y_test = UCR_UEA_datasets().load_dataset("StarLightCurves")
>>> print(X_train.shape)
(1000, 1024, 1)
>>> X_train, y_train, X_test, y_test = UCR_UEA_datasets().load_dataset("CinCECGTorso")
>>> print(X_train.shape)
(40, 1639, 1)
>>> X_train, y_train, X_test, y_test = UCR_UEA_datasets().load_dataset("DatasetThatDoesNotExist")
>>> print(X_train)
None
"""
full_path = os.path.join(self._data_dir, dataset_name)
fname_train = self._filenames.get(dataset_name, dataset_name) + "_TRAIN.txt"
fname_test = self._filenames.get(dataset_name, dataset_name) + "_TEST.txt"
fname_train = dataset_name + "_TRAIN.txt"
fname_test = dataset_name + "_TEST.txt"
if not os.path.exists(os.path.join(full_path, fname_train)) or \
not os.path.exists(os.path.join(full_path, fname_test)):
url = "http://www.timeseriesclassification.com/Downloads/%s.zip" % dataset_name
for fname in [fname_train, fname_test]:
if os.path.exists(os.path.join(full_path, fname)):
os.remove(os.path.join(full_path, fname))
extract_from_zip_url(url, target_dir=self._data_dir, verbose=False)
extract_from_zip_url(url, target_dir=full_path, verbose=False)
try:
data_train = numpy.loadtxt(os.path.join(full_path, fname_train), delimiter=",")
data_test = numpy.loadtxt(os.path.join(full_path, fname_test), delimiter=",")
data_train = numpy.loadtxt(os.path.join(full_path, fname_train), delimiter=None)
data_test = numpy.loadtxt(os.path.join(full_path, fname_test), delimiter=None)
except:
return None, None, None, None
X_train = to_time_series_dataset(data_train[:, 1:])
Expand Down

0 comments on commit f414a0c

Please sign in to comment.