# USHCN task

In [None]:
# Notebook setup (IPython magics).
# 'last_expr_or_assign' echoes the value of a cell's final expression or assignment.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
# autoreload mode 2: re-import edited modules before executing each cell.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import numpy as np
from sklearn.preprocessing import LabelEncoder

## What the splits should look like

In [None]:
from tsdm.tasks import KIWI_RUNS_TASK

# Existing task used as a reference; the next cell looks at its split-related attributes.
task = KIWI_RUNS_TASK()

In [None]:
# Attributes the USHCN task defined further down also provides.
task.splits  # dict that holds the actual data
task.split_idx  # dense form
task.split_idx_sparse  # sparse splits

In [None]:
import tsdm

# NOTE(review): the name ``USHCN`` is rebound by the ``class USHCN`` statement
# in the following cell; after that cell runs it no longer refers to this object.
USHCN = tsdm.datasets.USHCN_SmallChunkedSporadic()
ds = USHCN.dataset
# The index levels of the dataset, materialised as columns of a DataFrame.
IDX = ds.index.to_frame()

In [None]:
from collections.abc import Mapping, Sequence
from functools import cached_property
from typing import Any
from torch.nn.utils.rnn import pad_sequence
import torch
from pandas import DataFrame, MultiIndex
from torch.utils.data import DataLoader

from tsdm.datasets import USHCN_SmallChunkedSporadic
from tsdm.tasks import BaseTask


class USHCN(BaseTask):
    r"""Forecasting task built on the ``USHCN_SmallChunkedSporadic`` dataset.

    The series IDs are split at random into train/valid/test partitions, once
    per fold.  ``get_dataloader`` passes ``observation_time`` and
    ``prediction_steps`` on to ``TaskDataset``.

    TODO: rescale time minmax.
    """

    observation_time = 150
    prediction_steps = 3

    def __init__(self):
        super().__init__()

        # Unique series identifiers, taken from the "ID" level/column of the dataset.
        self.IDs = self.dataset.reset_index()["ID"].unique()

    @cached_property
    def dataset(self) -> DataFrame:
        r"""Return the underlying dataset (loaded once, then cached)."""
        return USHCN_SmallChunkedSporadic().dataset

    @cached_property
    def folds(self) -> list[dict[str, Sequence]]:
        r"""Return one ``{"train", "valid", "test"}`` mapping of IDs per fold.

        Each fold is an independent random split (10% test, then 20% of the
        remainder as validation); the folds are not a K-fold partition, so
        test sets of different folds may overlap.
        """
        # Imported locally: the notebook only imports ``sklearn.preprocessing`` up front.
        from sklearn.model_selection import train_test_split

        num_folds = 5
        folds = []
        # Fixed seed so the splits are reproducible; note this seeds numpy's global RNG.
        np.random.seed(432)
        for _ in range(num_folds):
            train_idx, test_idx = train_test_split(self.IDs, test_size=0.1)
            train_idx, valid_idx = train_test_split(train_idx, test_size=0.2)
            folds.append(
                {
                    "train": train_idx,
                    "valid": valid_idx,
                    "test": test_idx,
                }
            )

        return folds

    @cached_property
    def split_idx(self) -> DataFrame:
        r"""Return the dense split table.

        Rows are series IDs, columns are fold numbers, and each cell holds the
        name of the partition ("train", "valid" or "test") the ID belongs to.
        """
        from pandas import Index

        fold_idx = Index(list(range(len(self.folds))), name="fold")

        splits = DataFrame(index=self.IDs, columns=fold_idx, dtype="string")

        for k, fold in enumerate(self.folds):
            for key, split in fold.items():
                mask = splits.index.isin(split)
                # ``where`` keeps entries where the condition holds and writes
                # ``key`` everywhere else, i.e. for the IDs in this partition.
                splits[k] = splits[k].where(~mask, key)
        return splits

    @cached_property
    def split_idx_sparse(self) -> DataFrame:
        r"""Return sparse table with indices for each split.

        Every column of ``split_idx`` is expanded into one boolean column per
        partition name, addressed by ``(*column, partition)``.

        Returns
        -------
        DataFrame[bool]
        """
        df = self.split_idx
        columns = df.columns
        is_multi = isinstance(columns, MultiIndex)

        # The distinct partition names occurring in each column.
        categories = {
            col: df[col].astype("category").dtype.categories for col in columns
        }

        if is_multi:
            index_tuples = [
                (*col, cat) for col in columns for cat in categories[col]
            ]
            names = [*columns.names, "partition"]
        else:
            index_tuples = [
                (col, cat) for col in columns for cat in categories[col]
            ]
            names = [columns.name, "partition"]

        new_columns = MultiIndex.from_tuples(index_tuples, names=names)
        result = DataFrame(index=df.index, columns=new_columns, dtype=bool)

        for col in new_columns:
            # The last entry is the partition name, the rest addresses the
            # source column in the dense table.
            source = col[:-1] if is_multi else col[0]
            result[col] = df[source] == col[-1]

        return result

    def test_metric(self):
        """The test metric"""
        # NOTE(review): ``MSE`` is not imported in any visible cell - confirm
        # where it is supposed to come from.
        return MSE()

    @cached_property
    def splits(self) -> Mapping:
        r"""Return the rows of the dataset for every ``(fold, partition)`` key."""
        splits = {}
        for key in self.index:
            mask = self.split_idx_sparse[key]
            ids = self.split_idx_sparse.index[mask]
            splits[key] = self.dataset.loc[ids]
        return splits

    @cached_property
    def index(self) -> MultiIndex:
        r"""Return the available split keys (the columns of ``split_idx_sparse``)."""
        return self.split_idx_sparse.columns

    @cached_property
    def tensors(self) -> Mapping:
        r"""Return a ``(t, x)`` pair of float32 tensors for every series ID.

        ``t`` is built from the index that remains after selecting the ID,
        ``x`` from the corresponding values.
        """
        tensors = {}
        for _id in self.IDs:
            s = self.dataset.loc[_id]
            t = torch.tensor(s.index.values, dtype=torch.float32)
            x = torch.tensor(s.values, dtype=torch.float32)
            tensors[_id] = (t, x)
        return tensors

    def get_dataloader(
        self, key, /, **dataloader_kwargs: Any
    ) -> DataLoader:
        """Return the dataloader for the given key.

        Parameters
        ----------
        key
            A ``(fold, partition)`` pair, e.g. ``(0, "train")``.
        **dataloader_kwargs
            Forwarded to ``DataLoader``.  ``batch_size`` defaults to 32 and
            ``collate_fn`` to ``mycollate``.
        """
        from torch.utils.data import SubsetRandomSampler

        fold, partition = key
        # A set gives O(1) membership tests instead of scanning the ID array.
        ids = set(self.folds[fold][partition])

        dataset = TaskDataset(
            {idx: val for idx, val in self.tensors.items() if idx in ids},
            observation_horizon=self.observation_time,
            forecasting_steps=self.prediction_steps,
        )

        kwargs = {"batch_size": 32, "collate_fn": mycollate, **dataloader_kwargs}
        if "sampler" not in kwargs and "batch_sampler" not in kwargs:
            # TaskDataset is indexed by series ID, not by position, so the
            # default integer samplers would look up keys that do not exist.
            # Sample over the actual keys instead.
            keys = list(dataset.tensors)
            shuffle = kwargs.pop("shuffle", False)
            kwargs["sampler"] = SubsetRandomSampler(keys) if shuffle else keys

        return DataLoader(dataset, **kwargs)

In [None]:
from dataclasses import dataclass
from typing import NamedTuple

from torch import Tensor

from tsdm.utils.strings import repr_namedtuple


class Inputs(NamedTuple):
    r"""The model inputs belonging to a single sample.

    ``TaskDataset.__getitem__`` fills the fields as described below.
    """

    # Timestamps selected by the observation mask.
    t: Tensor
    # Values at those timestamps.
    x: Tensor
    # Timestamps of the entries that are to be forecast.
    t_target: Tensor

    def __repr__(self) -> str:
        # Formatting is delegated to the tsdm helper.
        return repr_namedtuple(self, recursive=False)


class Sample(NamedTuple):
    r"""A single sample of the data, as produced by ``TaskDataset.__getitem__``."""

    # Key under which the series is stored in the dataset.
    key: int
    # Observed part of the series plus the timestamps to forecast.
    inputs: Inputs
    # Values at the target timestamps.  This is a single tensor: it is built
    # as ``x[target_mask]`` and later handed to ``pad_sequence``.
    targets: Tensor
    # The complete, unsplit ``(t, x)`` pair.
    originals: tuple[Tensor, Tensor]

    def __repr__(self) -> str:
        # Formatting is delegated to the tsdm helper.
        return repr_namedtuple(self, recursive=False)


class Batch(NamedTuple):
    r"""A collated batch of samples, as returned by ``mycollate``."""

    T: Tensor
    """B×N: the timestamps."""
    X: Tensor
    """B×N×D: the observations."""
    Y: Tensor
    """B×K×D: the target values."""
    M: Tensor
    """B×N: which t correspond to targets."""

    def __repr__(self) -> str:
        # Formatting is delegated to the tsdm helper.
        return repr_namedtuple(self, recursive=False)


@dataclass
class TaskDataset(torch.utils.data.Dataset):
    r"""Map-style dataset that cuts each series into observations and targets.

    Items are looked up by the keys of ``tensors`` (not by position).  Entries
    with a timestamp up to ``observation_horizon`` become the inputs; the
    targets are the ``forecasting_steps`` entries starting at the position
    equal to the number of observed entries (which is the first unobserved
    entry provided the timestamps are sorted ascending).
    """

    tensors: dict[int, tuple[Tensor, Tensor]]
    observation_horizon: float = 150.0
    forecasting_steps: int = 3

    def __len__(self):
        r"""Return the number of stored series."""
        return len(self.tensors)

    def __getitem__(self, key):
        r"""Return the ``Sample`` for the series stored under ``key``."""
        timestamps, values = self.tensors[key]
        is_observed = timestamps <= self.observation_horizon

        # Number of observed entries doubles as the start of the target window.
        start = is_observed.sum()
        stop = start + self.forecasting_steps
        window = slice(start, stop)

        inputs = Inputs(timestamps[is_observed], values[is_observed], timestamps[window])
        return Sample(
            key=key,
            inputs=inputs,
            targets=values[window],
            originals=(timestamps, values),
        )

    def __repr__(self):
        r"""Return just the class name (the tensors are too large to print)."""
        return type(self).__name__


def mycollate(batch: list[Sample]) -> Batch:
    r"""Collate samples of different lengths into one NaN-padded ``Batch``.

    For every sample the observed timestamps and the target timestamps are
    concatenated and sorted by time.  The values at the target positions are
    filled with NaN, and a boolean mask marks which positions are targets.
    The per-sample sequences are then padded to a common length (NaN for
    times and values, ``False`` for the mask).

    The batch axis is always kept, so a batch holding a single sample still
    yields tensors with a leading dimension of 1.

    Parameters
    ----------
    batch
        The samples to combine.

    Returns
    -------
    Batch
        ``(T, X, Y, M)``: times, values, targets and target mask.
    """
    t_list = []
    x_list = []
    m_list = []
    y_list = []

    for sample in batch:
        t, x, t_target = sample.inputs
        # False for observed positions, True for positions to be predicted.
        mask = torch.cat(
            (torch.zeros_like(t, dtype=bool), torch.ones_like(t_target, dtype=bool))
        )
        # Placeholder values for the target positions: one NaN row per target time.
        x_padder = torch.full((t_target.shape[0], x.shape[-1]), fill_value=torch.nan)
        time = torch.cat((t, t_target))
        values = torch.cat((x, x_padder))
        # Bring observations and targets into one common time order.
        idx = torch.argsort(time)
        t_list.append(time[idx])
        x_list.append(values[idx])
        m_list.append(mask[idx])
        y_list.append(sample.targets)

    # No ``.squeeze()`` here: an argument-less squeeze would drop the batch
    # axis for a single-sample batch (and any other axis of size 1).
    T = pad_sequence(t_list, batch_first=True, padding_value=torch.nan)
    X = pad_sequence(x_list, batch_first=True, padding_value=torch.nan)
    Y = pad_sequence(y_list, batch_first=True, padding_value=torch.nan)
    M = pad_sequence(m_list, batch_first=True, padding_value=False)

    return Batch(T, X, Y, M)

In [None]:
# NOTE(review): ``DS`` is not defined in any visible cell - presumably
# ``TaskDataset`` was meant; confirm before running.
dataset = DS(task.tensors)


dloader = DataLoader(dataset, batch_size=32, collate_fn=mycollate)

In [None]:
# Pull the first collated batch from the loader.
next(iter(dloader))

In [None]:
# Scratch cell: the loop of ``mycollate`` run at notebook level for inspection.
# NOTE(review): ``batch`` is not bound in any visible cell.
t_list = []
x_list = []
m_list = []
y_list = []

for sample in batch:
    t, x, t_target = sample.inputs
    # False for observed positions, True for target positions.
    mask = torch.cat(
        (torch.zeros_like(t, dtype=bool), torch.ones_like(t_target, dtype=bool))
    )
    # NaN rows standing in for the values at the target timestamps.
    x_padder = torch.full((t_target.shape[0], x.shape[-1]), fill_value=torch.nan)
    time = torch.cat((t, t_target))
    values = torch.cat((x, x_padder))
    # Sort everything by time.
    idx = torch.argsort(time)
    t_list.append(time[idx])
    x_list.append(values[idx])
    m_list.append(mask[idx])
    y_list.append(sample.targets)


# Pad the per-sample sequences to a common length.
T = pad_sequence(t_list, batch_first=True, padding_value=torch.nan)
X = pad_sequence(x_list, batch_first=True, padding_value=torch.nan)
M = pad_sequence(m_list, batch_first=True, padding_value=False)
Y = pad_sequence(y_list, batch_first=True, padding_value=torch.nan)

In [None]:
# Shapes of (t, x, t_target) for the third sample of ``batch``.
[x.shape for x in batch[2].inputs]

In [None]:
# Scratch: concatenate target and observed timestamps, then compute the sorting permutation.
zz = torch.cat((t_target, t))

idx = torch.argsort(zz)

In [None]:
# The concatenated timestamps in sorted order.
zz[idx]

In [None]:
# Placeholder stub: the body is ``...``, so calling it returns None.
# The implemented collate function in this notebook is ``mycollate``.
def my_collate(batch: list[Sample]): ...

In [None]:
# NOTE(review): an earlier cell binds ``ds`` to ``USHCN.dataset``; the
# ``.originals`` access suggests ``ds`` was a TaskDataset when this was run - confirm.
ds[0].originals

In [None]:
t[211]

In [None]:
# Binds the name ``self`` at notebook level - presumably to step through
# method bodies of the class interactively.
self = USHCN()

In [None]:
# Example split key: fold 0, training partition.
key = 0, "train"
fold, partition = key

In [None]:
def get_sample():
    """Scratch prototype of the observation/target split.

    Reads the notebook-level names ``t``, ``x`` and ``index`` and returns
    ``(obs_mask, val_idx, targets)``.
    """
    obs_mask = t <= 150  # first 3 years are observations
    val_idx = index[t > 150][:3]  # next 3 observations are targets
    # NOTE(review): the cell originally ended on a dangling ``targets =``
    # (a SyntaxError).  Selecting the values at the target indices follows the
    # comment above - confirm this is what was intended.
    targets = x[val_idx]
    return obs_mask, val_idx, targets

In [None]:
# Rows of the test partition of fold 0 and the series IDs occurring in it.
ts = task.splits[0, "test"]
ids = ts.reset_index()["ID"].unique()
ts

In [None]:
# Print the size of each series in the partition.
for i in ids:
    print(ts.loc[i].size)

In [None]:
# NOTE(review): ``mask`` here is whatever an earlier scratch cell left bound.
task.split_idx_sparse.index[mask]

In [None]:
f0 = task.folds[0]["train"]
ts.loc[f0]

In [None]:
task.index

In [None]:
IDS = IDX["ID"].unique()

# Same fold count and seed as used inside ``USHCN.folds``.
num_folds = 5
np.random.seed(432)

In [None]:
# NOTE(review): ``splits`` is not bound at notebook level in any visible cell.
splits

In [None]:
from sklearn.model_selection import *

In [None]:
# Flatten the index into columns; the ID column serves as the group label.
ts = ds.reset_index()
groups = ts.ID

In [None]:
# ``GroupKFold`` is in scope through the wildcard import from sklearn.model_selection.
GroupKFold(n_splits=5)

In [None]:
ds.reset_index().groupby("ID").ngroup()

In [None]:
# NOTE(review): ``folds`` is not bound at notebook level in any visible cell, and
# ``folds[0].train`` uses attribute access whereas ``USHCN.folds`` holds dicts
# keyed by "train" - this cell looks like it predates the class; confirm.
df = ds.copy().loc[ds["ID"].isin(folds[0].train)]
df.ID = LabelEncoder().fit_transform(df.ID)
df = df.sort_values(["Time", "ID"]).set_index("ID")

In [None]:
from gru_ode_bayes.data_utils import ODE_Dataset

In [None]:
# Build the gru_ode_bayes dataset, presumably for comparison with ``df`` above.
ode_ds = ODE_Dataset(panda_df=ds, idx=folds[0].train)
ode_ds.df = ode_ds.df.sort_values(["Time", "ID"])
ode_ds.df

In [None]:
df.loc[df.index == 0]

In [None]:
ode_ds[0]["path"]