In [1]:
import pandas as pd
import numpy as np
from pathlib import Path
from synthcity.utils.serialization import save_to_file, load_from_file

In [2]:
base_path = Path("/code/datasets/ward")

train_static = pd.read_csv(base_path / "ward_static_train_data.csv.gz")
train_temporal = pd.read_csv(base_path / "ward_temporal_train_data_eav.csv.gz")

test_static = pd.read_csv(base_path / "ward_static_test_data.csv.gz")
test_temporal = pd.read_csv(base_path / "ward_temporal_test_data_eav.csv.gz")

In [3]:
import numpy as np
import pandas as pd

def process_temporal(in_static_data, in_temporal_data):
    all_temporal_data = []
    horizons = []
    
    for uid in in_static_data["id"].unique():
        local_temporal_idx = in_temporal_data["id"] == uid
        local_temporal = in_temporal_data[local_temporal_idx]
        columns = sorted(local_temporal["variable"].unique())
        times = sorted(local_temporal["time"].unique())

        temporal_data = pd.DataFrame([], columns = columns)


        for horizon in times:
            local_temporal_horizon_idx = local_temporal["time"] == horizon
            local_temporal_horizon = local_temporal[local_temporal_horizon_idx]

            horizon_data = pd.DataFrame(-1 * np.ones((1, len(columns))), columns = columns)

            proc = local_temporal_horizon[["variable", "value"]]
            proc.index = local_temporal_horizon["variable"]
            proc = proc.drop(columns = ["variable"])
            proc = proc.T.reset_index(drop = True)
            print(len(proc.columns), len(columns))
            horizon_data[proc.columns] = proc

            temporal_data = pd.concat([temporal_data, horizon_data], ignore_index = True)
        temporal_data.index = times

        for col in columns:
            if col not in temporal_data:
                temporal_data[col] = -1
        horizons.append(temporal_data["time"])
        all_temporal_data.append(temporal_data[columns])
        
    assert len(all_temporal_data) == len(in_static_data)

    return all_temporal_data, horizons

def eav_to_wide(df):
    """Transform EAV format to WIDE format.
    
    Args:
        - df: EAV format dataframe
        
    Returns:
        - df_wide: WIDE format dataframe.    
    """
    # Original data needs the following four column name in order.
    col_names = list(df.columns)
    assert col_names[0] == "id"
    assert col_names[1] == "time"
    assert col_names[2] == "variable"
    assert col_names[3] == "value"

    # Convert EAV format to WIDE format
    df_wide = pd.pivot_table(df, index=["id", "time"], columns="variable", values="value").reset_index(level=[0, 1])
    return df_wide


train_temporal_wide = eav_to_wide(train_temporal)
test_temporal_wide = eav_to_wide(test_temporal)

In [4]:
train_temporal_wide[train_temporal_wide["id"] == 1]

variable,id,time,Best Motor Response,Best Verbal Response,CHLORIDE,CREATINEINE,DBP,Eye Opening,GLUCLOSE,Glasgow Coma Scale Score,...,POTASSIUM,Pulse,Respiratory Rate,SBP,SODIUM,SpO2,TOTAL CO2,Temperature,UREA NITROGEN,WHITE BLOOD CELL COUNT
0,1,0.0,5.0,5.0,99.0,0.6,107.00,3.0,133.0,13.0,...,3.8,78.0,12.00,174.0,136.0,100.00,21.0,98.6,8.0,19.75
1,1,2.0,5.0,5.0,,,96.50,3.0,,13.0,...,,85.5,17.00,156.5,,100.00,,98.1,,
2,1,3.0,5.0,5.0,,,104.00,3.0,,13.0,...,,79.0,18.00,169.0,,100.00,,98.1,,
3,1,4.0,5.0,5.0,,,115.00,3.0,,13.0,...,,73.0,18.00,177.0,,100.00,,98.1,,
4,1,6.0,5.0,5.0,,,102.00,3.0,,13.0,...,,76.0,18.00,162.0,,100.00,,98.1,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
75,1,239.0,6.0,5.0,,,65.00,4.0,,15.0,...,,89.0,16.00,110.0,,99.00,,99.5,,
76,1,240.0,,,97.0,0.5,,,103.0,,...,3.7,,,,135.0,,24.5,,6.0,14.99
77,1,241.0,6.0,5.0,,,80.00,4.0,,15.0,...,,89.0,17.00,127.0,,99.00,,97.5,,
78,1,247.0,6.0,5.0,,,68.25,4.0,,15.0,...,,88.5,16.75,113.5,,97.25,,97.7,,


In [5]:
def prepare_temporal(temporal_wide):
    temporal = []
    horizons = []
    for k, v in temporal_wide.groupby("id"):
        h = v["time"].values.tolist()
        local_data = v.drop(columns = ["id", "time"])
        local_data.index = h
        local_data = local_data.dropna()
        
        horizons.append(local_data.index)
        temporal.append(local_data)
    return temporal, horizons

train_temporal, train_horizons = prepare_temporal(train_temporal_wide)
test_temporal, test_horizons = prepare_temporal(test_temporal_wide)

train_outcome = train_static["icu_admission"]
test_outcome = test_static["icu_admission"]
train_static = train_static.drop(columns = ["id", "icu_admission"]).fillna(0)
test_static = test_static.drop(columns = ["id", "icu_admission"]).fillna(0)

In [6]:
assert len(train_temporal) == len(train_static)
assert len(test_temporal) == len(test_static)

In [7]:
from synthcity.plugins.core.dataloader import (
     TimeSeriesDataLoader,
)

dataloader_train = TimeSeriesDataLoader(
    temporal_data=train_temporal,
    temporal_horizons=train_horizons,
    static_data=train_static,
    outcome = train_outcome.to_frame(),
)


dataloader_test = TimeSeriesDataLoader(
    temporal_data=test_temporal,
    temporal_horizons=test_horizons,
    static_data=test_static,
    outcome = test_outcome.to_frame(),
)

<stdin>:1:10: fatal error: cuda.h: No such file or directory
compilation terminated.

<stdin>:1:10: fatal error: cuda.h: No such file or directory
compilation terminated.

<stdin>:1:10: fatal error: cuda.h: No such file or directory
compilation terminated.



In [8]:
train_static_eval, train_temporal_eval, train_horizons_eval, train_outcome_eval = dataloader_train.unpack(as_numpy = True, pad = True)

In [9]:
test_static_eval, test_temporal_eval, test_horizons_eval, test_outcome_eval = dataloader_test.unpack(as_numpy = True, pad = True)

In [10]:
train_temporal[0]

variable,Best Motor Response,Best Verbal Response,CHLORIDE,CREATINEINE,DBP,Eye Opening,GLUCLOSE,Glasgow Coma Scale Score,HEMOGLOBIN,O2 Device: Aerosol mask,...,POTASSIUM,Pulse,Respiratory Rate,SBP,SODIUM,SpO2,TOTAL CO2,Temperature,UREA NITROGEN,WHITE BLOOD CELL COUNT
0.0,5.0,5.0,99.0,0.6,107.0,3.0,133.0,13.0,13.5,0.0,...,3.8,78.0,12.0,174.0,136.0,100.0,21.0,98.6,8.0,19.75
48.0,6.0,5.0,108.0,0.4,90.0,4.0,105.0,15.0,13.6,0.0,...,3.6,85.0,20.0,146.0,138.0,95.0,19.0,98.9,9.0,18.43
83.0,6.0,5.0,102.0,0.4,72.0,4.0,266.5,15.0,11.3,0.0,...,3.3,101.0,18.0,138.0,132.0,95.0,19.0,101.1,7.5,13.55
96.0,6.0,5.0,96.0,0.4,73.0,4.0,428.0,15.0,12.3,0.0,...,3.0,96.5,20.0,129.5,126.0,95.5,19.0,99.4,6.0,16.31
118.0,6.0,5.0,97.0,0.5,81.0,4.0,110.0,15.0,11.9,0.0,...,3.8,91.5,19.0,138.0,135.0,96.0,22.0,98.7,6.0,18.4


In [11]:
from typing import Any, List
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import roc_auc_score
from synthcity.plugins.core.models.survival_analysis.metrics import (
     evaluate_brier_score,
     evaluate_c_index,
     generate_score,
     print_score,
 )
import synthcity.logger as log
import sys

log.remove()
log.add(sink=sys.stderr, level="DEBUG")

def evaluate_ts_classification(
     estimator: Any,
     static: np.ndarray,
     temporal: np.ndarray,
     temporal_horizons: np.ndarray,
     Y: np.ndarray,
     n_folds: int = 3,
     metrics: List[str] = ["aucroc"],
     random_state: int = 0,
     pretrained: bool = False,
):
    results = {
        "aucroc": [],
    }
 
    static = np.asarray(static)
    temporal = np.asarray(temporal, dtype=object)
    temporal_horizons = np.asarray(temporal_horizons, dtype=object)
    Y = np.asarray(Y)

    def _get_metrics(
        cv_idx: int,
        static_train: np.ndarray,
        static_test: np.ndarray,
        temporal_train: np.ndarray,
        temporal_test: np.ndarray,
        temporal_horizons_train: np.ndarray,
        temporal_horizons_test: np.ndarray,
        Y_train: np.ndarray,
        Y_test: np.ndarray,
    ) -> tuple:
        if pretrained:
            model = estimator[cv_idx]
        else:
            model = copy.deepcopy(estimator)

            model.fit(
                static_train, temporal_train, temporal_horizons_train, Y_train
            )
        pred = model.predict(
                static_test, temporal_test, temporal_horizons_test
            )
 
        return roc_auc_score(Y_test, pred)

    if n_folds == 1:
        cv_idx = 0
        (
            static_train,
            static_test,
            temporal_train,
            temporal_test,
            temporal_horizons_train,
            temporal_horizons_test,
            Y_train,
            Y_test,
        ) = train_test_split(
            static, temporal, temporal_horizons, Y, random_state=random_state
        )
 
        aucroc = _get_metrics(
            cv_idx,
            static_train,
            static_test,
            temporal_train,
            temporal_test,
            temporal_horizons_train,
            temporal_horizons_test,
            Y_train,
            Y_test,
        )
        results["aucroc"] = [aucroc]
    else:
        skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=random_state)
 
        cv_idx = 0
        for train_index, test_index in skf.split(temporal, Y):
            static_train = static[train_index]
            temporal_train = temporal[train_index]
            temporal_horizons_train = temporal_horizons[train_index]
            Y_train = Y[train_index]
 
            static_test = static[test_index]
            temporal_test = temporal[test_index]
            temporal_horizons_test = temporal_horizons[test_index]
            Y_test = Y[test_index]
 
 
            aucroc = _get_metrics(
                cv_idx,
                static_train,
                static_test,
                temporal_train,
                temporal_test,
                temporal_horizons_train,
                temporal_horizons_test,
                Y_train,
                Y_test,
            )
            results["aucroc"].append(aucroc)
 
            cv_idx += 1
 
    output: dict = {
        "clf": {},
        "str": {},
    }
 
    for metric in metrics:
        output["clf"][metric] = generate_score(results[metric])
        output["str"][metric] = print_score(output["clf"][metric])
 
    return output

        

In [12]:
# stdlib
from typing import Any, Callable, List, Optional, Tuple

# third party
import numpy as np
import torch
from pydantic import validate_arguments
from torch import nn
from torch.utils.data import DataLoader, TensorDataset, sampler
from tsai.models.gMLP import gMLP
from tsai.models.InceptionTime import InceptionTime
from tsai.models.InceptionTimePlus import InceptionTimePlus
from tsai.models.MINIROCKET_Pytorch import MiniRocket
from tsai.models.MINIROCKETPlus_Pytorch import MiniRocketPlus
from tsai.models.mWDN import mWDNPlus
from tsai.models.OmniScaleCNN import OmniScaleCNN
from tsai.models.ResCNN import ResCNN
from tsai.models.RNN_FCN import MLSTM_FCN
from tsai.models.TCN import TCN
from tsai.models.TransformerModel import TransformerModel
from tsai.models.TSiTPlus import TSiTPlus
from tsai.models.TST import TST
from tsai.models.TSTPlus import TSTPlus
from tsai.models.XceptionTime import XceptionTime
from tsai.models.XCM import XCM

# synthcity absolute
import synthcity.logger as log
from synthcity.plugins.core.models.mlp import MLP, MultiActivationHead, get_nonlin
from synthcity.utils.constants import DEVICE
from synthcity.utils.reproducibility import enable_reproducible_results

modes = [
    "LSTM",
    "GRU",
    "RNN",
    "MLSTM_FCN",
    "TCN",
    "InceptionTime",
    "InceptionTimePlus",
    "XceptionTime",
    "ResCNN",
    "OmniScaleCNN",
    "TST",
    "TSTPlus",
    "XCM",
    "gMLP",
    "MiniRocket",
    "MiniRocketPlus",
    "TransformerModel",
    "TSiTPlus",
    "mWDNPlus",
]


class TimeSeriesModel(nn.Module):
    @validate_arguments(config=dict(arbitrary_types_allowed=True))
    def __init__(
        self,
        task_type: str,  # regression, classification
        n_static_units_in: int,
        n_temporal_units_in: int,
        n_temporal_window: int,
        output_shape: List[int],
        n_static_units_hidden: int = 102,
        n_static_layers_hidden: int = 2,
        n_temporal_units_hidden: int = 102,
        n_temporal_layers_hidden: int = 2,
        n_iter: int = 500,
        mode: str = "RNN",
        n_iter_print: int = 10,
        batch_size: int = 150,
        lr: float = 1e-3,
        weight_decay: float = 1e-3,
        window_size: int = 1,
        device: Any = DEVICE,
        dataloader_sampler: Optional[sampler.Sampler] = None,
        nonlin_out: Optional[List[Tuple[str, int]]] = None,
        loss: Optional[Callable] = None,
        dropout: float = 0,
        nonlin: Optional[str] = "relu",
        random_state: int = 0,
        clipping_value: int = 1,
        use_horizon_condition: bool = True,
    ) -> None:
        super(TimeSeriesModel, self).__init__()

        enable_reproducible_results(random_state)

        assert task_type in ["classification", "regression"]
        assert mode in modes, f"Unsupported mode {mode}. Available: {modes}"
        assert len(output_shape) > 0

        self.task_type = task_type

        if loss is not None:
            self.loss = loss
        elif task_type == "regression":
            self.loss = nn.MSELoss()
        elif task_type == "classification":
            self.loss = nn.CrossEntropyLoss()

        self.n_iter = n_iter
        self.n_iter_print = n_iter_print
        self.batch_size = batch_size
        self.n_static_units_in = n_static_units_in
        self.n_temporal_units_in = n_temporal_units_in
        self.n_temporal_window = n_temporal_window
        self.n_static_units_hidden = n_static_units_hidden
        self.n_temporal_units_hidden = n_temporal_units_hidden
        self.n_static_layers_hidden = n_static_layers_hidden
        self.n_temporal_layers_hidden = n_temporal_layers_hidden
        self.device = device
        self.window_size = window_size
        self.dataloader_sampler = dataloader_sampler
        self.lr = lr
        self.output_shape = output_shape
        self.n_units_out = np.prod(self.output_shape)
        self.clipping_value = clipping_value
        self.use_horizon_condition = use_horizon_condition

        self.temporal_layer = TimeSeriesLayer(
            n_static_units_in=n_static_units_in,
            n_temporal_units_in=n_temporal_units_in
            + int(use_horizon_condition),  # measurements + horizon
            n_temporal_window=n_temporal_window,
            n_units_out=self.n_units_out,
            n_static_units_hidden=n_static_units_hidden,
            n_static_layers_hidden=n_static_layers_hidden,
            n_temporal_units_hidden=n_temporal_units_hidden,
            n_temporal_layers_hidden=n_temporal_layers_hidden,
            mode=mode,
            window_size=window_size,
            device=device,
            dropout=dropout,
            nonlin=nonlin,
            random_state=random_state,
        )

        self.mode = mode

        self.out_activation: Optional[nn.Module] = None
        self.n_act_out: Optional[int] = None

        if nonlin_out is not None:
            self.n_act_out = 0
            activations = []
            for nonlin, nonlin_len in nonlin_out:
                self.n_act_out += nonlin_len
                activations.append((get_nonlin(nonlin), nonlin_len))

            if self.n_units_out % self.n_act_out != 0:
                raise RuntimeError(
                    f"Shape mismatch for the output layer. Expected length {self.n_units_out}, but got {nonlin_out} with length {self.n_act_out}"
                )
            self.out_activation = MultiActivationHead(activations, device=device)
        elif self.task_type == "classification":
            self.n_act_out = self.n_units_out
            self.out_activation = MultiActivationHead(
                [(nn.Softmax(dim=-1), self.n_units_out)], device=device
            )

        self.optimizer = torch.optim.Adam(
            self.parameters(),
            lr=lr,
            weight_decay=weight_decay,
        )  # optimize all rnn parameters

    @validate_arguments(config=dict(arbitrary_types_allowed=True))
    def forward(
        self,
        static_data: torch.Tensor,
        temporal_data: torch.Tensor,
        temporal_horizons: torch.Tensor,
    ) -> torch.Tensor:
        # x shape (batch, time_step, input_size)
        # r_out shape (batch, time_step, output_size)

        assert torch.isnan(static_data).sum() == 0
        assert torch.isnan(temporal_data).sum() == 0
        assert torch.isnan(temporal_horizons).sum() == 0

        if self.use_horizon_condition:
            temporal_data_merged = torch.cat(
                [temporal_data, temporal_horizons.unsqueeze(2)], dim=2
            )
        else:
            temporal_data_merged = temporal_data

        assert torch.isnan(temporal_data_merged).sum() == 0

        pred = self.temporal_layer(static_data, temporal_data_merged)

        if self.out_activation is not None:
            pred = pred.reshape(-1, self.n_act_out)
            pred = self.out_activation(pred)

        pred = pred.reshape(-1, *self.output_shape)

        return pred

    @validate_arguments(config=dict(arbitrary_types_allowed=True))
    def predict(
        self,
        static_data: np.ndarray,
        temporal_data: np.ndarray,
        temporal_horizons: np.ndarray,
    ) -> np.ndarray:
        self.eval()
        with torch.no_grad():
            temporal_data_t = self._check_tensor(temporal_data).float()
            temporal_horizons_t = self._check_tensor(temporal_horizons).float()
            static_data_t = self._check_tensor(static_data).float()

            yt = self(static_data_t, temporal_data_t, temporal_horizons_t)

            if self.task_type == "classification":
                return np.argmax(yt.cpu().numpy(), -1)
            else:
                return yt.cpu().numpy()

    def score(
        self,
        static_data: np.ndarray,
        temporal_data: np.ndarray,
        temporal_horizons: np.ndarray,
        outcome: np.ndarray,
    ) -> float:
        y_pred = self.predict(static_data, temporal_data, temporal_horizons)
        if self.task_type == "classification":
            return np.mean(y_pred == outcome)
        else:
            return np.mean(np.inner(outcome - y_pred, outcome - y_pred) / 2.0)

    @validate_arguments(config=dict(arbitrary_types_allowed=True))
    def fit(
        self,
        static_data: np.ndarray,
        temporal_data: np.ndarray,
        temporal_horizons: np.ndarray,
        outcome: np.ndarray,
    ) -> Any:
        temporal_data_t = self._check_tensor(temporal_data).float()
        temporal_horizons_t = self._check_tensor(temporal_horizons).float()
        static_data_t = self._check_tensor(static_data).float()
        outcome_t = self._check_tensor(outcome).float()
        if self.task_type == "classification":
            outcome_t = outcome_t.long()

        return self._train(
            static_data_t, temporal_data_t, temporal_horizons_t, outcome_t
        )

    @validate_arguments(config=dict(arbitrary_types_allowed=True))
    def _train(
        self,
        static_data: Optional[torch.Tensor],
        temporal_data: torch.Tensor,
        temporal_horizons: torch.Tensor,
        outcome: torch.Tensor,
    ) -> Any:
        loader = self.dataloader(static_data, temporal_data, temporal_horizons, outcome)
        # training and testing
        for it in range(self.n_iter):
            loss = self._train_epoch(loader)
            if it % self.n_iter_print == 0:
                log.info(f"Epoch:{it}| train loss: {loss}")

        return self

    def _train_epoch(self, loader: DataLoader) -> float:
        losses = []
        for step, (static_mb, temporal_mb, horizons_mb, y_mb) in enumerate(loader):
            self.optimizer.zero_grad()  # clear gradients for this training step

            pred = self(static_mb, temporal_mb, horizons_mb)  # rnn output
            loss = self.loss(pred, y_mb)

            loss.backward()  # backpropagation, compute gradients
            if self.clipping_value > 0:
                torch.nn.utils.clip_grad_norm_(self.parameters(), self.clipping_value)
            self.optimizer.step()  # apply gradients

            losses.append(loss.detach().cpu())

        return np.mean(losses)

    def dataloader(
        self,
        static_data: torch.Tensor,
        temporal_data: torch.Tensor,
        temporal_horizons: torch.Tensor,
        outcome: torch.Tensor,
    ) -> DataLoader:
        dataset = TensorDataset(static_data, temporal_data, temporal_horizons, outcome)

        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            sampler=self.dataloader_sampler,
            pin_memory=False,
        )

    def _check_tensor(self, X: torch.Tensor) -> torch.Tensor:
        if isinstance(X, torch.Tensor):
            return X.to(self.device)
        else:
            return torch.from_numpy(np.asarray(X)).to(self.device)


class TimeSeriesLayer(nn.Module):
    def __init__(
        self,
        n_static_units_in: int,
        n_temporal_units_in: int,
        n_temporal_window: int,
        n_units_out: int,
        n_static_units_hidden: int = 100,
        n_static_layers_hidden: int = 2,
        n_temporal_units_hidden: int = 100,
        n_temporal_layers_hidden: int = 2,
        mode: str = "RNN",
        window_size: int = 1,
        device: Any = DEVICE,
        dropout: float = 0,
        nonlin: Optional[str] = "relu",
        random_state: int = 0,
    ) -> None:
        super(TimeSeriesLayer, self).__init__()
        temporal_params = {
            "input_size": n_temporal_units_in,
            "hidden_size": n_temporal_units_hidden,
            "num_layers": n_temporal_layers_hidden,
            "dropout": 0 if n_temporal_layers_hidden == 1 else dropout,
            "batch_first": True,
        }
        temporal_models = {
            "RNN": nn.RNN,
            "LSTM": nn.LSTM,
            "GRU": nn.GRU,
        }

        if mode in ["RNN", "LSTM", "GRU"]:
            self.temporal_layer = temporal_models[mode](**temporal_params)
        elif mode == "MLSTM_FCN":
            self.temporal_layer = MLSTM_FCN(
                c_in=n_temporal_units_in,
                c_out=n_temporal_units_hidden,
                hidden_size=n_temporal_units_hidden,
                rnn_layers=n_temporal_layers_hidden,
                fc_dropout=dropout,
                seq_len=n_temporal_window,
                shuffle=False,
            )
        elif mode == "TCN":
            self.temporal_layer = TCN(
                c_in=n_temporal_units_in,
                c_out=n_temporal_units_hidden,
                fc_dropout=dropout,
            )
        elif mode == "InceptionTime":
            self.temporal_layer = InceptionTime(
                c_in=n_temporal_units_in,
                c_out=n_temporal_units_hidden,
                depth=n_temporal_layers_hidden,
                seq_len=n_temporal_window,
            )
        elif mode == "InceptionTimePlus":
            self.temporal_layer = InceptionTimePlus(
                c_in=n_temporal_units_in,
                c_out=n_temporal_units_hidden,
                depth=n_temporal_layers_hidden,
                seq_len=n_temporal_window,
            )
        elif mode == "XceptionTime":
            self.temporal_layer = XceptionTime(
                c_in=n_temporal_units_in,
                c_out=n_temporal_units_hidden,
            )
        elif mode == "ResCNN":
            self.temporal_layer = ResCNN(
                c_in=n_temporal_units_in,
                c_out=n_temporal_units_hidden,
            )
        elif mode == "OmniScaleCNN":
            self.temporal_layer = OmniScaleCNN(
                c_in=n_temporal_units_in,
                c_out=n_temporal_units_hidden,
                seq_len=max(n_temporal_window, 10),
            )
        elif mode == "TST":
            self.temporal_layer = TST(
                c_in=n_temporal_units_in,
                c_out=n_temporal_units_hidden,
                seq_len=n_temporal_window,
                max_seq_len=n_temporal_window,
                n_layers=n_temporal_layers_hidden,
            )
        elif mode == "XCM":
            self.temporal_layer = XCM(
                c_in=n_temporal_units_in,
                c_out=n_temporal_units_hidden,
                seq_len=n_temporal_window,
                fc_dropout=dropout,
            )
        elif mode == "gMLP":
            self.temporal_layer = gMLP(
                c_in=n_temporal_units_in,
                c_out=n_temporal_units_hidden,
                seq_len=n_temporal_window,
                depth=n_temporal_layers_hidden,
            )
        elif mode == "MiniRocket":
            self.temporal_layer = MiniRocket(
                c_in=n_temporal_units_in,
                c_out=n_temporal_units_hidden,
                seq_len=n_temporal_window,
                random_state=random_state,
                fc_dropout=dropout,
            )
        elif mode == "MiniRocketPlus":
            self.temporal_layer = MiniRocketPlus(
                c_in=n_temporal_units_in,
                c_out=n_temporal_units_hidden,
                seq_len=n_temporal_window,
                fc_dropout=dropout,
            )
        elif mode == "TransformerModel":
            self.temporal_layer = TransformerModel(
                c_in=n_temporal_units_in,
                c_out=n_temporal_units_hidden,
                dropout=dropout,
                n_layers=n_temporal_layers_hidden,
            )
        elif mode == "TSiTPlus":
            self.temporal_layer = TSiTPlus(
                c_in=n_temporal_units_in,
                c_out=n_temporal_units_hidden,
                seq_len=n_temporal_window,
                depth=n_temporal_layers_hidden,
                dropout=dropout,
            )
        elif mode == "TSTPlus":
            self.temporal_layer = TSTPlus(
                c_in=n_temporal_units_in,
                c_out=n_temporal_units_hidden,
                seq_len=n_temporal_window,
                n_layers=n_temporal_layers_hidden,
                dropout=dropout,
            )
        elif mode == "mWDNPlus":
            self.temporal_layer = mWDNPlus(
                c_in=n_temporal_units_in,
                c_out=n_temporal_units_hidden,
                seq_len=n_temporal_window,
            )
        else:
            raise RuntimeError(f"Unknown TS mode {mode}")

        self.device = device
        self.mode = mode

        if mode in ["RNN", "LSTM", "GRU"]:
            self.out = WindowLinearLayer(
                n_static_units_in=n_static_units_in,
                n_temporal_units_in=n_temporal_units_hidden,
                window_size=window_size,
                n_units_out=n_units_out,
                n_layers=n_static_layers_hidden,
                dropout=dropout,
                nonlin=nonlin,
                device=device,
            )
        else:
            self.out = MLP(
                task_type="regression",
                n_units_in=n_static_units_in + n_temporal_units_hidden,
                n_units_out=n_units_out,
                n_layers_hidden=n_static_layers_hidden,
                n_units_hidden=n_static_units_hidden,
                dropout=dropout,
                nonlin=nonlin,
                device=device,
            )

        self.temporal_layer.to(device)
        self.out.to(device)

    def forward(
        self, static_data: torch.Tensor, temporal_data: torch.Tensor
    ) -> torch.Tensor:
        if self.mode in ["RNN", "LSTM", "GRU"]:
            X_interm, _ = self.temporal_layer(temporal_data)

            assert torch.isnan(X_interm).sum() == 0

            return self.out(static_data, X_interm)
        else:
            X_interm = self.temporal_layer(torch.swapaxes(temporal_data, 1, 2))

            assert torch.isnan(X_interm).sum() == 0

            return self.out(torch.cat([static_data, X_interm], dim=1))


class WindowLinearLayer(nn.Module):
    @validate_arguments(config=dict(arbitrary_types_allowed=True))
    def __init__(
        self,
        n_static_units_in: int,
        n_temporal_units_in: int,
        window_size: int,
        n_units_out: int,
        n_units_hidden: int = 100,
        n_layers: int = 1,
        dropout: float = 0,
        nonlin: Optional[str] = "relu",
        device: Any = DEVICE,
    ) -> None:
        super(WindowLinearLayer, self).__init__()

        self.device = device
        self.window_size = window_size
        self.model = MLP(
            task_type="regression",
            n_units_in=n_static_units_in + n_temporal_units_in * window_size,
            n_units_out=n_units_out,
            n_layers_hidden=n_layers,
            n_units_hidden=n_units_hidden,
            dropout=dropout,
            nonlin=nonlin,
            device=device,
        )

    @validate_arguments(config=dict(arbitrary_types_allowed=True))
    def forward(
        self, static_data: torch.Tensor, temporal_data: torch.Tensor
    ) -> torch.Tensor:
        assert len(static_data) == len(temporal_data)
        batch_size, seq_len, n_feats = temporal_data.shape
        temporal_batch = temporal_data[:, seq_len - self.window_size :, :].reshape(
            batch_size, n_feats * self.window_size
        )
        batch = torch.cat([static_data, temporal_batch], axis=1)

        return self.model(batch).to(self.device)

In [19]:
from synthcity.utils.samplers import ImbalancedDatasetSampler
sampler = ImbalancedDatasetSampler(train_outcome_eval.squeeze().tolist())

model = TimeSeriesModel(
        task_type = "classification",  # regression, classification
         n_static_units_in = train_static_eval.shape[-1],
         n_temporal_units_in = train_temporal_eval.shape[-1],
         n_temporal_window = train_temporal_eval.shape[1],
         output_shape = [2],
         dataloader_sampler = sampler,
            n_iter = 100,
    )

model.fit(train_static_eval, train_temporal_eval, train_horizons_eval, train_outcome_eval.squeeze())

model.predict(train_static_eval, train_temporal_eval, train_horizons_eval).sum() 


[2022-07-16T10:49:15.691159+0300][73514][INFO] Epoch:0| train loss: 0.6926623582839966
[2022-07-16T10:49:20.919168+0300][73514][INFO] Epoch:10| train loss: 0.6813757419586182
[2022-07-16T10:49:25.948066+0300][73514][INFO] Epoch:20| train loss: 0.6724405288696289
[2022-07-16T10:49:30.973634+0300][73514][INFO] Epoch:30| train loss: 0.6722961664199829
[2022-07-16T10:49:36.192089+0300][73514][INFO] Epoch:40| train loss: 0.669145405292511
[2022-07-16T10:49:41.225527+0300][73514][INFO] Epoch:50| train loss: 0.6708655953407288
[2022-07-16T10:49:46.458650+0300][73514][INFO] Epoch:60| train loss: 0.6705731749534607
[2022-07-16T10:49:51.496984+0300][73514][INFO] Epoch:70| train loss: 0.6675920486450195
[2022-07-16T10:49:56.535480+0300][73514][INFO] Epoch:80| train loss: 0.6683081984519958
[2022-07-16T10:50:01.765552+0300][73514][INFO] Epoch:90| train loss: 0.6665168404579163


1629

In [18]:
model.predict(train_static_eval, train_temporal_eval, train_horizons_eval).sum()

0

In [None]:
len(train_outcome_eval)

In [20]:
from synthcity.plugins.core.models.ts_model import modes

def eval_model(**kwargs):
    n_folds = 3
    model = TimeSeriesModel(
        task_type = "classification",  # regression, classification
         n_static_units_in = train_static.shape[-1],
         n_temporal_units_in = train_temporal[0].shape[-1],
         n_temporal_window = max([len(t) for t in train_temporal]),
         output_shape = [2],
        n_iter = 100,
         **kwargs,
    )

    model.fit(train_static_eval, train_temporal_eval, train_horizons_eval, train_outcome_eval.squeeze())
        
    score = evaluate_ts_classification(
        [model] * n_folds, 
        test_static, test_temporal, test_horizons, test_outcome_eval.squeeze(), 
        pretrained = True,
        n_folds = n_folds
    )
    return score


In [21]:
import tabulate

headers = ["Model", "AUCROC"]
results = pd.DataFrame([], columns = headers)

for mode in ["RNN"]:
    try:
        score = eval_model(mode = mode)["str"]
        
        print(mode, score)
    except BaseException as e:
        print("failed", mode, e)
        continue
    local_results = pd.DataFrame([[f"TimeSeriesModel[{mode}]", score["aucroc"]]], columns = headers)
    results = pd.concat([results, local_results], ignore_index = True)
    
    break
    
tabulate.tabulate(results, tablefmt='html')

failed RNN input.size(-1) must be equal to input_size. Expected 39, got 77


''

In [None]:
train_outcome_eval.squeeze()