In [1]:
import os
os.chdir("../")

In [2]:
import torch
import tslearn
import numpy as np
import torch

In [3]:
from pyts.image import GramianAngularField, MarkovTransitionField

In [4]:
from pyts.datasets.load import load_basic_motions
from typing import Tuple

In [5]:
@torch.jit.script
def minmax_scaler(X: torch.Tensor, range: Tuple[float, float] = (-1, 1)) -> torch.Tensor:
    '''
        Scales the last dimension of X into the given range X has shape (...,d)
        output has the same shape but the last dimension is scaled to the range
    '''
    X_min = X.min(dim=-1, keepdim=True).values
    X_std = (X - X_min)/(X.max(dim=-1, keepdim=True).values - X_min)
    return X_std * (range[1] - range[0]) + range[0]

In [6]:
data = load_basic_motions()

In [7]:
X = data["data_train"]

In [8]:
import matplotlib.pyplot as plt

In [9]:
from s3ts.helper_functions import load_dm, get_parser

In [10]:
parser = get_parser()

In [11]:
args = parser.parse_args("--dataset WISDM --window_size 32 --window_stride 1 --pattern_size 16 --subjects_for_test 30 31 32 33 34 35 --encoder_architecture cnn_gap --encoder_features 20 --decoder_architecture mlp --decoder_features 32 --decoder_layers 1 --max_epochs 1 --batch_size 128 --lr 0.001 --num_workers 8 --reduce_imbalance --normalize --label_mode 1 --num_medoids 1 --compute_n 300 --rho 0.1 --voting 1 --use_medoids --overlap -1 --mode gasf --mtf_bins 50".split(" "))

In [12]:
dm = load_dm(args)

Loaded dataset WISDM with a total of 1097056 observations for window size 32
Using 443730 observations for training and 154471 observations for validation and test


In [13]:
dm_train = enumerate(dm.train_dataloader())

In [35]:
data = next(dm_train)
print(data[1]["transformed"].min(), data[1]["transformed"].max())

tensor(-1.0000) tensor(1.)


In [20]:
problematic = data[1]["series"]

In [21]:
from s3ts.api.gaf_mtf import gaf_compute, mtf_compute

In [23]:
gaf_compute(problematic)[0, 0]

tensor([[-0.9262, -0.8091, -0.5136,  ..., -0.9913, -0.9939, -0.7028],
        [-0.8091,  0.0721,  0.4577,  ..., -0.4116, -0.6185,  0.2333],
        [-0.5136,  0.4577,  0.7699,  ..., -0.0209, -0.2606,  0.5962],
        ...,
        [-0.9913, -0.4116, -0.0209,  ..., -0.7959, -0.9181, -0.2579],
        [-0.9939, -0.6185, -0.2606,  ..., -0.9181, -0.9864, -0.4826],
        [-0.7028,  0.2333,  0.5962,  ..., -0.2579, -0.4826,  0.3883]])

In [198]:
tr = GramianAngularField()

In [199]:
tr.transform(problematic[0])

array([[[1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        ...,
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.]],

       [[1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        ...,
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.]],

       [[1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        ...,
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.]]])

In [189]:
problematic[0]

tensor([[-0.0968, -0.0968, -0.0968, -0.0968, -0.0968, -0.0968, -0.0968, -0.0968,
         -0.0968, -0.0968, -0.0968, -0.0968, -0.0968, -0.0968, -0.0968, -0.0968,
         -0.0968, -0.0968, -0.0968, -0.0968, -0.0968, -0.0968, -0.0968, -0.0968,
         -0.0968, -0.0968, -0.0968, -0.0968, -0.0968, -0.0968, -0.0968, -0.0968],
        [-1.0755, -1.0755, -1.0755, -1.0755, -1.0755, -1.0755, -1.0755, -1.0755,
         -1.0755, -1.0755, -1.0755, -1.0755, -1.0755, -1.0755, -1.0755, -1.0755,
         -1.0755, -1.0755, -1.0755, -1.0755, -1.0755, -1.0755, -1.0755, -1.0755,
         -1.0755, -1.0755, -1.0755, -1.0755, -1.0755, -1.0755, -1.0755, -1.0755],
        [-0.0865, -0.0865, -0.0865, -0.0865, -0.0865, -0.0865, -0.0865, -0.0865,
         -0.0865, -0.0865, -0.0865, -0.0865, -0.0865, -0.0865, -0.0865, -0.0865,
         -0.0865, -0.0865, -0.0865, -0.0865, -0.0865, -0.0865, -0.0865, -0.0865,
         -0.0865, -0.0865, -0.0865, -0.0865, -0.0865, -0.0865, -0.0865, -0.0865]])

In [186]:
gaf_compute(problematic, "s", (-1, 1))[0]

tensor([[[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]]])

In [173]:
data[1]["transformed"]

tensor([[[[    nan,     nan,     nan,  ...,     nan,     nan,     nan],
          [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
          [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
          ...,
          [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
          [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
          [    nan,     nan,     nan,  ...,     nan,     nan,     nan]],

         [[    nan,     nan,     nan,  ...,     nan,     nan,     nan],
          [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
          [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
          ...,
          [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
          [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
          [    nan,     nan,     nan,  ...,     nan,     nan,     nan]],

         [[    nan,     nan,     nan,  ...,     nan,     nan,     nan],
          [    nan,     nan,  