In [1]:
from ThreeWToolkit.dataset import ParquetDataset
from ThreeWToolkit.core.base_dataset import ParquetDatasetConfig, EventPrefixEnum

## Download from figshare if not available locally

If the dataset is not found at the given `path`, it will be downloaded automatically.

If you use `force_download=True`, the dataset will be downloaded regardless of whether it already exists locally. If it exists locally, it will be replaced by the new one. `force_download` defaults to `False`.


Downloading from figshare may take a while: the dataset (1.79 GB) takes about 2-3 minutes to download.

Only dataset v2.0.0 is supported.

In [2]:
# Folder where the dataset is (or will be) stored locally; modify this path to match your setup
dataset_path = "../../dataset"

In [3]:
# Load only the events that belong to two classes (0 and 4)
ds_config = ParquetDatasetConfig(
    path=dataset_path,
    target_class=[0, 4],
)
ds = ParquetDataset(ds_config)

[ParquetDataset] Dataset found at ../../dataset
[ParquetDataset] Validating dataset integrity...
[ParquetDataset] Dataset integrity check passed!


Get the number of events:

In [4]:
# Number of events in the loaded dataset
len(ds)

937

## Signal and label split

Our dataset object behaves like a list of dictionaries, one per event.

Let's see the keys of the first element (dictionary).

In [5]:
# Keys available in the dictionary of the first event
ds[0].keys()

dict_keys(['signal', 'label', 'file_name'])

What is in the "signal" key?

In [6]:
# Display the "signal" entry of the first event
ds[0]["signal"]

Unnamed: 0_level_0,ABER-CKGL,ABER-CKP,ESTADO-DHSV,ESTADO-M1,ESTADO-M2,ESTADO-PXO,ESTADO-SDV-GL,ESTADO-SDV-P,ESTADO-W1,ESTADO-W2,...,P-JUS-CKGL,P-JUS-CKP,P-MON-CKP,P-PDG,P-TPT,QGL,T-JUS-CKP,T-MON-CKP,T-PDG,T-TPT
timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2017-03-17 03:00:00,0.0,0.0,0.867921,0.414652,-0.681653,-0.094347,-1.094009,0.312558,0.650525,-0.563169,...,-1.062395,0.0,0.834230,0.0,-0.032904,-0.846093,-0.015430,0.0,0.0,0.662008
2017-03-17 03:00:01,0.0,0.0,0.867921,0.414652,-0.681653,-0.094347,-1.094009,0.312558,0.650525,-0.563169,...,-1.062395,0.0,0.834666,0.0,-0.032937,-0.846093,-0.015437,0.0,0.0,0.661992
2017-03-17 03:00:02,0.0,0.0,0.867921,0.414652,-0.681653,-0.094347,-1.094009,0.312558,0.650525,-0.563169,...,-1.062395,0.0,0.835103,0.0,-0.032970,-0.846093,-0.015443,0.0,0.0,0.661975
2017-03-17 03:00:03,0.0,0.0,0.867921,0.414652,-0.681653,-0.094347,-1.094009,0.312558,0.650525,-0.563169,...,-1.062395,0.0,0.835539,0.0,-0.033004,-0.846093,-0.015450,0.0,0.0,0.661959
2017-03-17 03:00:04,0.0,0.0,0.867921,0.414652,-0.681653,-0.094347,-1.094009,0.312558,0.650525,-0.563169,...,-1.062395,0.0,0.835482,0.0,-0.033038,-0.846093,-0.015457,0.0,0.0,0.661943
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2017-03-17 05:59:19,0.0,0.0,0.867921,0.414652,-0.681653,-0.094347,-1.094009,0.312558,0.650525,-0.563169,...,-1.062484,0.0,0.821254,0.0,0.013815,-0.846093,-0.022338,0.0,0.0,0.672607
2017-03-17 05:59:20,0.0,0.0,0.867921,0.414652,-0.681653,-0.094347,-1.094009,0.312558,0.650525,-0.563169,...,-1.062484,0.0,0.821214,0.0,0.013871,-0.846093,-0.022052,0.0,0.0,0.672604
2017-03-17 05:59:21,0.0,0.0,0.867921,0.414652,-0.681653,-0.094347,-1.094009,0.312558,0.650525,-0.563169,...,-1.062484,0.0,0.821174,0.0,0.013925,-0.846093,-0.021766,0.0,0.0,0.672600
2017-03-17 05:59:22,0.0,0.0,0.867921,0.414652,-0.681653,-0.094347,-1.094009,0.312558,0.650525,-0.563169,...,-1.062484,0.0,0.821133,0.0,0.013981,-0.846093,-0.021480,0.0,0.0,0.672597


Let's see the ids of the classes that we have in our loaded dataset:

In [7]:
# Gather every distinct class id found in the "label" entry of each event.
# `.item()` turns each value returned by `.unique()` into a plain Python scalar.
set_of_labels = {
    class_id.item()
    for event in ds
    for class_id in event["label"]["class"].unique()
}
set_of_labels


{0, 4}

## Customizable label selection

Let's reload the dataset omitting the target column (`target_column=None`):

In [8]:
# Reload without a target column, so no separate "label" entry is expected
ds_config = ParquetDatasetConfig(
    path=dataset_path,
    target_column=None,
)
ds = ParquetDataset(ds_config)

[ParquetDataset] Dataset found at ../../dataset
[ParquetDataset] Validating dataset integrity...
[ParquetDataset] Dataset integrity check passed!


We expect not to have the "label" key.

In [9]:
# Keys of the first event after reloading with target_column=None
ds[0].keys()

dict_keys(['signal', 'file_name'])

In [10]:
# With no target column split off, the signal should still contain the 'class' column
ds[0]["signal"]  # Should contain 'class' column

Unnamed: 0_level_0,ABER-CKGL,ABER-CKP,ESTADO-DHSV,ESTADO-M1,ESTADO-M2,ESTADO-PXO,ESTADO-SDV-GL,ESTADO-SDV-P,ESTADO-W1,ESTADO-W2,...,P-JUS-CKP,P-MON-CKP,P-PDG,P-TPT,QGL,T-JUS-CKP,T-MON-CKP,T-PDG,T-TPT,class
timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2016-07-04 18:00:00,0.0,0.0,0.0,0.0,0.0,0.0,0.914069,0.312558,0.0,0.0,...,0.0,-0.579013,0.0,0.0,0.470780,-0.742003,0.0,0.0,-0.030592,0.0
2016-07-04 18:00:01,0.0,0.0,0.0,0.0,0.0,0.0,0.914069,0.312558,0.0,0.0,...,0.0,-0.594322,0.0,0.0,0.469668,-0.741887,0.0,0.0,-0.030732,0.0
2016-07-04 18:00:02,0.0,0.0,0.0,0.0,0.0,0.0,0.914069,0.312558,0.0,0.0,...,0.0,-0.609631,0.0,0.0,0.468555,-0.741771,0.0,0.0,-0.030871,0.0
2016-07-04 18:00:03,0.0,0.0,0.0,0.0,0.0,0.0,0.914069,0.312558,0.0,0.0,...,0.0,-0.624940,0.0,0.0,0.467442,-0.741654,0.0,0.0,-0.031011,0.0
2016-07-04 18:00:04,0.0,0.0,0.0,0.0,0.0,0.0,0.914069,0.312558,0.0,0.0,...,0.0,-0.597055,0.0,0.0,0.467239,-0.741538,0.0,0.0,-0.031150,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2016-07-06 12:59:56,0.0,0.0,0.0,0.0,0.0,0.0,0.914069,0.312558,0.0,0.0,...,0.0,-0.242209,0.0,0.0,0.462812,-0.890495,0.0,0.0,-0.007008,0.0
2016-07-06 12:59:57,0.0,0.0,0.0,0.0,0.0,0.0,0.914069,0.312558,0.0,0.0,...,0.0,-0.231274,0.0,0.0,0.462761,-0.890539,0.0,0.0,-0.007148,0.0
2016-07-06 12:59:58,0.0,0.0,0.0,0.0,0.0,0.0,0.914069,0.312558,0.0,0.0,...,0.0,-0.220338,0.0,0.0,0.462709,-0.890583,0.0,0.0,-0.007148,0.0
2016-07-06 12:59:59,0.0,0.0,0.0,0.0,0.0,0.0,0.914069,0.312558,0.0,0.0,...,0.0,-0.209403,0.0,0.0,0.462658,-0.890628,0.0,0.0,-0.007148,0.0


-------

## Event type splitting

Defining 2 types of events: drawn and simulated

In [11]:
# Event types to keep: drawn and simulated events
event_types = [EventPrefixEnum.DRAWN, EventPrefixEnum.SIMULATED]

In [12]:
# Load only the events whose type is listed in `event_types`
ds_config = ParquetDatasetConfig(
    path=dataset_path,
    event_type=event_types,
)
ds = ParquetDataset(ds_config)

# Report how many events passed the filter
print(f"Filtering events of types {event_types}, there are {len(ds)} events")

[ParquetDataset] Dataset found at ../../dataset
[ParquetDataset] Validating dataset integrity...
[ParquetDataset] Dataset integrity check passed!
Filtering events of types [<EventPrefixEnum.DRAWN: 'DRAWN'>, <EventPrefixEnum.SIMULATED: 'SIMULATED'>], there are 1109 events


Only real events now:

In [13]:
# Restrict the selection to real events only
event_types = [EventPrefixEnum.REAL]

ds_config = ParquetDatasetConfig(
    path=dataset_path,
    event_type=event_types,
)
ds = ParquetDataset(ds_config)

# Report how many events passed the filter
print(f"Filtering events of types {event_types}, there are {len(ds)} events")

[ParquetDataset] Dataset found at ../../dataset
[ParquetDataset] Validating dataset integrity...
[ParquetDataset] Dataset integrity check passed!
Filtering events of types [<EventPrefixEnum.REAL: 'WELL'>], there are 1119 events


------

## Event class splitting

We can select any combination of class events.

First, let's see how many events the whole dataset has.

In [14]:
# Omitting `target_class` (equivalent to `target_class=None`) loads all classes
ds_config = ParquetDatasetConfig(path=dataset_path)
ds = ParquetDataset(ds_config)
# Total number of events in the whole dataset
len(ds)

[ParquetDataset] Dataset found at ../../dataset
[ParquetDataset] Validating dataset integrity...
[ParquetDataset] Dataset integrity check passed!


2228

Now, let's select only events of one class:

In [15]:
# Load only the events of class 4
target_class = [4]

ds_config = ParquetDatasetConfig(
    path=dataset_path,
    target_class=target_class,
)
ds = ParquetDataset(ds_config)

# Number of events of the selected class
len(ds)

[ParquetDataset] Dataset found at ../../dataset
[ParquetDataset] Validating dataset integrity...
[ParquetDataset] Dataset integrity check passed!


343

Now, events of two classes:

In [16]:
# Load only the events of classes 0 and 4
target_class = [0, 4]

ds_config = ParquetDatasetConfig(
    path=dataset_path,
    target_class=target_class,
)
ds = ParquetDataset(ds_config)

# Number of events across the selected classes
len(ds)

[ParquetDataset] Dataset found at ../../dataset
[ParquetDataset] Validating dataset integrity...
[ParquetDataset] Dataset integrity check passed!


937

-------

## File list splitting

We can use a list to select loaded events. This is useful for customized train/val/test splitting.

In [17]:
# Hand-picked events to load, each given as "<class>/<file name>.parquet"
split = [
    "6/SIMULATED_00012.parquet",
    "6/SIMULATED_00049.parquet",
    "0/WELL-00001_20170201010207.parquet",
    "4/WELL-00001_20170316110203.parquet",
]

In [18]:
# To load only the files listed in the `split` variable, pass it as `file_list`
# and set the `split` parameter to "list"
ds_config = ParquetDatasetConfig(
    path=dataset_path,
    split="list",
    file_list=split,
)
ds = ParquetDataset(ds_config)

# Number of events loaded from the file list
len(ds)

[ParquetDataset] Dataset found at ../../dataset
[ParquetDataset] Validating dataset integrity...
[ParquetDataset] Dataset integrity check passed!


4

In [19]:
# Display the "signal" entry of the first event loaded from the file list
ds[0]["signal"]

Unnamed: 0_level_0,ABER-CKGL,ABER-CKP,ESTADO-DHSV,ESTADO-M1,ESTADO-M2,ESTADO-PXO,ESTADO-SDV-GL,ESTADO-SDV-P,ESTADO-W1,ESTADO-W2,...,P-JUS-CKGL,P-JUS-CKP,P-MON-CKP,P-PDG,P-TPT,QGL,T-JUS-CKP,T-MON-CKP,T-PDG,T-TPT
timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2018-04-01 18:45:12,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,-0.359913,-0.112196,0.283350,0.450170,0.0,0.943161,0.721158,0.0,0.0
2018-04-01 18:45:13,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,-0.359913,-0.112196,0.283350,0.450172,0.0,0.943161,0.721158,0.0,0.0
2018-04-01 18:45:14,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,-0.359913,-0.112196,0.283350,0.450172,0.0,0.943161,0.721157,0.0,0.0
2018-04-01 18:45:15,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,-0.359913,-0.112196,0.283350,0.450172,0.0,0.943160,0.721157,0.0,0.0
2018-04-01 18:45:16,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,-0.359913,-0.112196,0.283350,0.450170,0.0,0.943160,0.721157,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2018-04-02 02:15:06,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,-0.366104,2.534416,0.818176,1.280823,0.0,0.284945,0.361025,0.0,0.0
2018-04-02 02:15:07,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,-0.366104,2.534415,0.818176,1.280823,0.0,0.284944,0.361024,0.0,0.0
2018-04-02 02:15:08,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,-0.366104,2.534415,0.818176,1.280823,0.0,0.284942,0.361023,0.0,0.0
2018-04-02 02:15:09,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,-0.366104,2.534415,0.818176,1.280823,0.0,0.284940,0.361022,0.0,0.0


In [20]:
# Print the distinct class ids found in the label of each selected event
for event in ds:
    classes_in_event = event["label"]["class"].unique()
    print(classes_in_event)

[6]
[6]
[0]
[4]
