In [None]:
# Display the value of each cell's last expression -- or of a trailing assignment -- automatically.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'
# Render inline matplotlib figures as SVG.
%config InlineBackend.figure_format = 'svg'
# Reload imported modules (e.g. tsdm) before each cell executes, so library edits are picked up.
%load_ext autoreload
%autoreload 2
%matplotlib inline
# %config Completer.use_jedi = True

In [None]:
import logging

# Configure the root logger (stderr handler, DEBUG threshold) unless it already has
# handlers, then grab a logger named after this module for the notebook's own messages.
logging.basicConfig(level="DEBUG")
LOGGER = logging.getLogger(name=__name__)

In [None]:
import tsdm

In [None]:
import os
import subprocess
from getpass import getpass

import pandas as pd

from tsdm.datasets.base import BaseDataset, SimpleDataset, Dataset


# class MIMIC_III(SimpleDataset):
#     base_url: str = r"https://physionet.org/content/mimiciii/get-zip/1.4/"
#     info_url: str = r"https://physionet.org/content/mimiciii/1.4/"
#     # dataset_files = "observations.feather"
#     rawdata_files = "mimic-iii-clinical-database-1.4.zip"

#     def _clean(self):
#         ts_file = self.rawdata_dir / "complete_tensor.csv"
#         if not ts_file.exists():
#             raise RuntimeError(
#                 "Please apply the preprocessing code found at "
#                 "https://github.com/edebrouwer/gru_ode_bayes/."
#                 f"\nPut the resulting file 'complete_tensor.csv' in {self.rawdata_dir}."
#             )

#         ts = pd.read_csv(ts_file)
#         ts = ts.sort_values(by=["UNIQUE_ID", "TIME_STAMP"])
#         ts = ts.astype(
#             {
#                 "UNIQUE_ID": "int16",
#                 "LABEL_CODE": "int16",
#                 "TIME_STAMP": "int16",
#                 "LABEL_CODE": "float32",
#                 "MEAN": "float32",
#                 "STD": "float32",
#             }
#         )

#         means = ts.groupby("LABEL_CODE").mean()["VALUENUM"].rename("MEANS")
#         stdvs = ts.groupby("LABEL_CODE").std()["VALUENUM"].rename("STDVS")
#         stats = pd.DataFrame([means, stdvs]).T.reset_index()
#         stats = stats.astype(
#             {
#                 "LABEL_CODE": "int16",
#                 "MEANS": "float32",
#                 "STDVS": "float32",
#             }
#         )

#         ts = ts[["UNIQUE_ID", "TIME_STAMP", "LABEL_CODE", "VALUENORM"]]
#         ts = ts.reset_index(drop=True)
#         stats.to_feather(self.dataset_dir / "stats.feather")
#         ts.to_feather(self.dataset_dir / "MIMIC_III.feather")

#     def _load(self):
#         # return NotImplemented
#         return pd.read_feather(self.dataset_dir / self.dataset_files)

#     def _download(self):
#         cut_dirs = self.base_url.count("/") - 3
#         user = input("MIMIC-III username: ")
#         password = getpass(prompt="MIMIC-III password: ", stream=None)

#         os.environ["PASSWORD"] = password

#         subprocess.run(
#             f"wget --user {user} --password $PASSWORD -c -r -np -nH -N "
#             + f"--cut-dirs {cut_dirs} -P '{self.rawdata_dir}' {self.base_url} ",
#             shell=True,
#             check=True,
#         )

#         file = self.rawdata_dir / "index.html"
#         os.rename(file, self.rawdata_files)

In [None]:
class MIMIC_III(Dataset):
    """MIMIC-III clinical database, as preprocessed by the GRU-ODE-Bayes pipeline.

    ``_download`` fetches the PhysioNet archive (prompting for a username and
    password). The time series itself is *not* derived from that archive here:
    ``_clean`` expects the file ``complete_tensor.csv`` produced by the
    preprocessing code at https://github.com/edebrouwer/gru_ode_bayes to be
    placed in ``rawdata_dir``.

    Keys
    ----
    observations
        Columns ``UNIQUE_ID``, ``TIME_STAMP``, ``LABEL_CODE``, ``VALUENORM``,
        sorted by ``UNIQUE_ID`` then ``TIME_STAMP``.
    stats
        Per-``LABEL_CODE`` mean (``MEANS``) and sample standard deviation
        (``STDVS``) of the raw ``VALUENUM`` column.
    """

    base_url: str = r"https://physionet.org/content/mimiciii/get-zip/1.4/"
    info_url: str = r"https://physionet.org/content/mimiciii/1.4/"
    # Cleaned output file per key, resolved through ``self.dataset_paths``.
    dataset_files = {"observations": "observations.feather", "stats": "stats.feather"}
    # Name given to the downloaded archive inside ``rawdata_dir``.
    rawdata_files = "mimic-iii-clinical-database-1.4.zip"
    index = ["observations", "stats"]
    default_key = "observations"

    def _clean(self, key):
        """Convert ``complete_tensor.csv`` into the ``observations`` and ``stats`` files.

        Both feather files are written on every call, whatever ``key`` is, because
        they are derived from the same CSV in a single pass.

        Raises
        ------
        RuntimeError
            If ``complete_tensor.csv`` is missing from ``rawdata_dir``.
        """
        ts_file = self.rawdata_dir / "complete_tensor.csv"
        if not ts_file.exists():
            raise RuntimeError(
                "Please apply the preprocessing code found at "
                "https://github.com/edebrouwer/gru_ode_bayes/."
                f"\nPut the resulting file 'complete_tensor.csv' in {self.rawdata_dir}."
            )

        ts = pd.read_csv(ts_file)
        ts = ts.sort_values(by=["UNIQUE_ID", "TIME_STAMP"])
        # NOTE(review): int16 caps UNIQUE_ID / TIME_STAMP / LABEL_CODE at 32767;
        # confirm the preprocessed data stays within that range.
        ts = ts.astype(
            {
                "UNIQUE_ID": "int16",
                "TIME_STAMP": "int16",
                "LABEL_CODE": "int16",
                "VALUENORM": "float32",
                "MEAN": "float32",
                "STD": "float32",
            }
        )

        # Per-channel statistics of the raw (un-normalized) values. Selecting
        # VALUENUM before aggregating avoids reducing (and possibly failing on
        # non-numeric) columns that are thrown away anyway.
        stats = (
            ts.groupby("LABEL_CODE")["VALUENUM"]
            .agg(MEANS="mean", STDVS="std")
            .reset_index()
            .astype(
                {
                    "LABEL_CODE": "int16",
                    "MEANS": "float32",
                    "STDVS": "float32",
                }
            )
        )

        # Feather requires a default RangeIndex, hence the reset after sorting.
        ts = ts[["UNIQUE_ID", "TIME_STAMP", "LABEL_CODE", "VALUENORM"]]
        ts = ts.reset_index(drop=True)
        stats.to_feather(self.dataset_paths["stats"])
        ts.to_feather(self.dataset_paths["observations"])

    def _load(self, key):
        """Read the cleaned feather file registered for ``key``."""
        return pd.read_feather(self.dataset_paths[key])

    def _download(self):
        """Download the MIMIC-III archive from PhysioNet into ``rawdata_dir`` via ``wget``.

        Prompts interactively for the username and password.

        Raises
        ------
        subprocess.CalledProcessError
            If ``wget`` exits with a non-zero status.
        """
        # -nH drops the host and --cut-dirs drops every path component of
        # base_url (assumes base_url ends with "/"), so wget writes straight
        # into rawdata_dir.
        cut_dirs = self.base_url.count("/") - 3
        user = input("MIMIC-III username: ")
        password = getpass(prompt="MIMIC-III password: ", stream=None)

        # Pass an argument list without a shell: the username and path reach
        # wget verbatim, so spaces/quotes/metacharacters cannot break or inject
        # into the command. The password is no longer left in os.environ for the
        # rest of the session. NOTE: as with the shell-expanded $PASSWORD before,
        # it still appears in wget's argv (visible via `ps` on the local host).
        subprocess.run(
            [
                "wget",
                f"--user={user}",
                f"--password={password}",
                "-c",
                "-r",
                "-np",
                "-nH",
                "-N",
                f"--cut-dirs={cut_dirs}",
                "-P",
                str(self.rawdata_dir),
                self.base_url,
            ],
            check=True,
        )

        # The directory-style URL is saved as index.html. Rename it to the
        # expected archive name *inside* rawdata_dir; the target used to be the
        # bare filename, which resolved relative to the current working directory.
        downloaded = self.rawdata_dir / "index.html"
        os.replace(downloaded, self.rawdata_dir / self.rawdata_files)

In [None]:
# Instantiate the dataset defined above. NOTE(review): whether construction already
# downloads / cleans / loads data is decided by tsdm's `Dataset` base class (not shown) -- confirm.
ds = MIMIC_III()

In [None]:
# Show the "observations" table (columns UNIQUE_ID, TIME_STAMP, LABEL_CODE, VALUENORM as
# written by MIMIC_III._clean). Presumably the base class maps this attribute to
# _load("observations") -- verify.
ds.observations

In [None]:
def f(x, *, a=1):
    """Return ``x`` unchanged.

    ``a`` is a keyword-only parameter that is accepted but never read.
    """
    identity = x
    return identity