# KIWI_FINAL_PRODUCT: dataloader, encoder, and collate-function exploration

In [None]:
# Display the last expression of each cell -- or, if the cell ends with an
# assignment, the assigned value.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # show last expr or assigned value.
# Render inline matplotlib figures as SVG.
%config InlineBackend.figure_format = 'svg'
# Reload imported modules before running each cell (mode 2: all modules not
# excluded via %aimport), so edits to library code such as tsdm take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2
%matplotlib inline

import logging

# Configure the root logger to emit INFO-level (and above) records to stderr.
logging.basicConfig(level=logging.INFO)

In [None]:
import numpy as np

# Print floats with exactly 4 fixed decimals and never in scientific notation
# (values that round to zero at this precision print as zero).
np.set_printoptions(suppress=True, floatmode="fixed", precision=4)
# Unseeded generator: draws differ between kernel sessions.
rng = np.random.default_rng(seed=None)

In [None]:
import pandas

from tsdm.tasks import KIWI_FINAL_PRODUCT

In [None]:
# Instantiate the task; later cells use its timeseries, splits, and dataloaders.
task = KIWI_FINAL_PRODUCT()

In [None]:
# Dataloader for split (0, "train"). shuffle=False presumably keeps the sample
# order fixed so the lookups below are reproducible -- confirm in tsdm.
dloader = task.get_dataloader((0, "train"), shuffle=False)

# Keep a direct handle on the sampler; it is indexed by the key (439, 15325) below.
sampler = dloader.sampler

In [None]:
# Full task timeseries, and the first remaining index value for key
# (439, 15325) -- a timestamp, given the Timedelta arithmetic that follows.
ts = task.timeseries
t0 = ts.loc[439, 15325].index[0]

In [None]:
# Offset of the sampler's "right" entries for key (439, 15325) from t0, in hours.
(sampler[439, 15325]["right"] - t0) / pandas.Timedelta(hours=1)

In [None]:
import torch

from tsdm.encoders import *

ts, md = task.splits[0, "train"]

# Choose the device once: use the GPU when one is present, otherwise fall back
# to the CPU instead of failing on machines without CUDA. Shared by both
# encoders below so their outputs live on the same device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Encoder pipeline for the timeseries. NOTE(review): the order in which
# ChainedEncoder applies its stages during encode/decode is defined by tsdm --
# confirm before relying on it.
encoder = ChainedEncoder(
    TensorEncoder(device=DEVICE, names=("time", "value", "index")),
    DataFrameEncoder(
        column_encoders={
            "value": IdentityEncoder(),
            tuple(ts.columns): FloatEncoder("float32"),
        },
        # DateTimeEncoder(unit="h") composed with a MinMaxScaler; presumably
        # timestamps -> hours, then rescaled.
        index_encoders=MinMaxScaler() @ DateTimeEncoder(unit="h"),
    ),
    TripletEncoder(sparse=True),
    Standardizer(),
)
# Fit on the training split with index levels 0 and 1 dropped.
encoder.fit(ts.reset_index([0, 1], drop=True))

# Column position of the target, used to select the matching component of the
# last stage (the Standardizer). NOTE(review): computed on
# task.timeseries.columns while the encoder was fit on the split `ts`; these
# presumably share the same column order -- confirm.
task.target_idx = task.timeseries.columns.get_loc(task.target)
target_encoder = (
    TensorEncoder(device=DEVICE) @ FloatEncoder() @ encoder[-1][task.target_idx]
)

In [None]:
from typing import NamedTuple

import torch
from torch import Tensor

from tsdm.utils.strings import *


class Batch(NamedTuple):
    r"""A collated batch, as produced by `mycollate`."""

    # Stacked ``idx[0]`` of every sample.
    index: Tensor
    # One encoded timeseries per sample; mycollate keeps these as a list.
    timeseries: list
    # Raw per-sample metadata, as returned by the dataset.
    metadata: list
    # Raw targets of all samples, joined with pandas.concat.
    targets: "pandas.DataFrame | pandas.Series"
    # Encoded targets of all samples, concatenated along dim 0.
    encoded_targets: Tensor

    def __repr__(self):
        """Pretty-print the fields using tsdm's mapping/array formatters."""
        return repr_mapping(
            self._asdict(), title=self.__class__.__name__, repr_fun=repr_array
        )


def mycollate(batch: list, *, ts_encoder=None, tgt_encoder=None) -> Batch:
    r"""Collate dataset samples into a `Batch`.

    Parameters
    ----------
    batch:
        Samples of the form ``(idx, (ts_data, (md_data, target)))``; only
        ``idx[0]`` of each sample is kept as its index.
    ts_encoder:
        Object whose ``encode`` is applied to each sample's timeseries.
        Defaults to the notebook-global ``encoder``.
    tgt_encoder:
        Object whose ``encode`` is applied to each sample's target.
        Defaults to the notebook-global ``target_encoder``.

    Raises
    ------
    ValueError
        If `batch` is empty.
    """
    if not batch:
        raise ValueError("mycollate received an empty batch.")
    # Resolve defaults at call time, so re-fitting the global encoders in an
    # earlier cell is picked up without redefining this function.
    if ts_encoder is None:
        ts_encoder = encoder
    if tgt_encoder is None:
        tgt_encoder = target_encoder

    index = []
    timeseries = []
    metadata = []
    targets = []
    encoded_targets = []

    for idx, (ts_data, (md_data, target)) in batch:
        index.append(torch.tensor(idx[0]))
        timeseries.append(ts_encoder.encode(ts_data))
        metadata.append(md_data)
        targets.append(target)
        encoded_targets.append(tgt_encoder.encode(target))

    return Batch(
        index=torch.stack(index),
        timeseries=timeseries,
        metadata=metadata,
        targets=pandas.concat(targets),
        encoded_targets=torch.concat(encoded_targets),
    )

In [None]:
# Rebind `dloader` to the batch loader for split (0, "train"). NOTE: `sampler`
# was taken from the earlier loader and still refers to that loader's sampler.
dloader = task.batchloaders[0, "train"]

In [None]:
# Inspect `mean` of the last ChainedEncoder stage (presumably the fitted Standardizer).
encoder[-1].mean

In [None]:
# Swap in the custom collate function. NOTE(review): this mutates the loader
# object held by task.batchloaders; if that mapping returns a cached instance,
# every other user of it is affected too -- confirm.
dloader.collate_fn = mycollate

In [None]:
# Fetch the first batch from the loader (collated by its current collate_fn).
batch = next(iter(dloader))

In [None]:
# Look up the sample for the first key yielded by the sampler and unpack it.
# NOTE: `key` ends up bound to the key stored inside the sample, and `ts`/`md`
# (the training split from the encoder cell) are overwritten.
sample = dloader.dataset[next(iter(dloader.sampler))]
(key, slc), (ts, (md, target)) = sample

In [None]:
# Key of the sample unpacked above (first element of its (key, slc) pair).
key

In [None]:
# Raw (unencoded) target of that sample, as returned by the dataset.
target