In [1]:
import os
import pickle 
from datetime import datetime

import numpy as np 
import pandas as pd 
# import matplotlib.pyplot as plt

from ts2vec import TS2Vec
import torch
from torch import nn


In [2]:
# Model training parameters
# -- encoder architecture --
OUTPUT_DIMS = 320  # TS2Vec output_dims: size of the encoder's output representation
HIDDEN_DIMS = 64   # TS2Vec hidden_dims: channel width inside the encoder
KERNEL_SIZE = 3    # NOTE(review): not passed to TS2Vec by any visible cell
# -- training loop --
TEMPORAL_UNIT = 2
BATCH_SIZE = 128
N_EPOCHS = 2
# -- checkpointing --
SAVE_CHECK_POINT = False  # NOTE(review): not read by any visible cell

# 1. Model training (short run, only to produce a model for testing ONNX export)

In [3]:
# Load the small train/validation samples used to exercise training and ONNX export.
# NOTE(review): absolute, user-specific location; consider making it configurable.
SAMPLE_DATA_DIR = "/Users/fguo/cmt/ts2vec/sample_data"


def load_split(data_path, motion_names_path):
    """Load one data split: a .npy array of samples plus its parquet table of motion names.

    Args:
        data_path: path to the ``.npy`` array (first axis presumably indexes samples).
        motion_names_path: path to the parquet table describing those samples.

    Returns:
        ``(data, motion_names)`` as ``(np.ndarray, pd.DataFrame)``.

    Raises:
        AssertionError: if the array and the table have different numbers of rows.
    """
    data = np.load(data_path)
    motion_names = pd.read_parquet(motion_names_path)
    # Presumably one motion-name row per sample; a mismatch means the two files
    # do not describe the same subset of the dataset.
    assert len(data) == len(motion_names), (
        f"{data_path}: {len(data)} samples, but {motion_names_path}: {len(motion_names)} rows"
    )
    return data, motion_names


# load train data (the *_dir names are kept for compatibility; they hold file paths)
train_data_dir = os.path.join(SAMPLE_DATA_DIR, "train_data_sample10.npy")
train_motion_names_dir = os.path.join(SAMPLE_DATA_DIR, "train_motion_names_sample10.parquet")
train_data, train_motion_names = load_split(train_data_dir, train_motion_names_dir)
print(train_data.shape, train_motion_names.shape)

# load val data
val_data_dir = os.path.join(SAMPLE_DATA_DIR, "val_data_sample4.npy")
val_motion_names_dir = os.path.join(SAMPLE_DATA_DIR, "val_motion_names_sample4.parquet")
val_data, val_motion_names = load_split(val_data_dir, val_motion_names_dir)
print(val_data.shape, val_motion_names.shape)

# The model's input width is taken from train_data, so validation must match it.
assert val_data.shape[-1] == train_data.shape[-1], "train/val feature dimensions differ"

(10, 900, 8) (10, 2)
(4, 900, 8) (4, 2)


In [4]:
# Build a fresh TS2Vec on CPU. Its input width follows the last axis of train_data.
# Train it briefly: the second positional argument to fit() is the validation set,
# and with verbose=True fit() prints the train and val loss for each epoch.
model = TS2Vec(
    device='cpu',
    input_dims=train_data.shape[-1],
    hidden_dims=HIDDEN_DIMS,
    output_dims=OUTPUT_DIMS,
    temporal_unit=TEMPORAL_UNIT,
    batch_size=BATCH_SIZE,
    after_epoch_callback=None,  # no per-epoch hook
)
loss_log = model.fit(train_data, val_data, n_epochs=N_EPOCHS, verbose=True)

Epoch #0: train loss=39.84284973144531
Epoch #0: val loss=18.48442268371582
 
Epoch #1: train loss=26.8051815032959
Epoch #1: val loss=12.734784126281738
 


In [5]:
# Persist the briefly trained sample model. Section 2 loads a different checkpoint
# (model_18.pkl), so this file is not used again in this notebook.
# NOTE(review): SAVE_CHECK_POINT is defined in the config cell but is not consulted here.
model.save("/Users/fguo/cmt/ts2vec/model_checkpoints/sample_model.pkl")

# 2. Load the trained checkpoint and export it to ONNX

In [None]:
# Checkpoint loaded below: TS2Vec model
# model_run_id: 07406a4af7284df0b8d8f266509161e8
# epoch: 18 (file model_18.pkl)

In [3]:
# Rebuild the encoder with the config-cell constants and load the epoch-18 checkpoint.
# input_dims is hard-coded to 8 so this section does not need the training data loaded.
# The loaded network's first layer is Linear(in_features=8).
# NOTE(review): model.load presumably unpickles the file; only load trusted checkpoints.
model = TS2Vec(
    device='cpu',
    input_dims=8,
    hidden_dims=HIDDEN_DIMS,
    output_dims=OUTPUT_DIMS,
    temporal_unit=TEMPORAL_UNIT,
    batch_size=BATCH_SIZE,
    after_epoch_callback=None,  # no per-epoch hook
)
model.load("/Users/fguo/cmt/ts2vec/model_checkpoints/model_18.pkl")

In [4]:
# Switch the wrapped encoder to inference mode before export (same as .eval());
# the returned module is displayed.
model.net.train(False)

AveragedModel(
  (module): TSEncoder(
    (input_fc): Linear(in_features=8, out_features=64, bias=True)
    (feature_extractor): DilatedConvEncoder(
      (net): Sequential(
        (0): ConvBlock(
          (conv1): SamePadConv(
            (conv): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))
          )
          (conv2): SamePadConv(
            (conv): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))
          )
        )
        (1): ConvBlock(
          (conv1): SamePadConv(
            (conv): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
          )
          (conv2): SamePadConv(
            (conv): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
          )
        )
        (2): ConvBlock(
          (conv1): SamePadConv(
            (conv): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(4,), dilation=(4,))
          )
          (conv2): SamePadConv(
            (conv): Conv1d(64, 64, ke

In [5]:
# Verify eval mode before exporting. Checking every submodule, not just the top-level
# wrapper, catches a child module that is still in training mode. Fail loudly
# instead of relying on someone reading the displayed flag.
assert not any(m.training for m in model.net.modules()), "model.net has submodules in training mode"
model.net.training

False

In [6]:
# Trace the eval-mode encoder with one random float32 sample of shape (1, 1, 8) on CPU,
# then write the resulting ONNX graph to disk.
# NOTE(review): dynamo_export presumably traces with static shapes unless
# export_options=torch.onnx.ExportOptions(dynamic_shapes=True) is given. If so, this graph
# only accepts inputs of exactly shape (1, 1, 8), not e.g. the 900-step sequences loaded
# for training. Confirm before deploying.
dummy_input = torch.randn((1, 1, 8), device="cpu")
onnx_program = torch.onnx.dynamo_export(
    model.net,
    dummy_input,
)
onnx_program.save(
    "/Users/fguo/cmt/ts2vec/model_checkpoints/model_18.onnx"
)



# 3. Load the ONNX model and compare its output with PyTorch

In [7]:
import onnxruntime as ort 

In [8]:
# Open the exported graph with ONNX Runtime on CPU, matching device='cpu' for the PyTorch
# model. GPU-enabled onnxruntime builds raise an error if providers are not given,
# so pass them explicitly.
# The "Removing initializer ... not used by any node" warnings refer to unused constants
# in the graph; unused initializers cannot affect the computed outputs.
ort_session = ort.InferenceSession(
    "/Users/fguo/cmt/ts2vec/model_checkpoints/model_18.onnx",
    providers=["CPUExecutionProvider"],
)

[0;93m2024-03-28 09:06:52.501937 [W:onnxruntime:, graph.cc:3593 CleanUnusedInitializersAndNodeArgs] Removing initializer '_if_then_branch__inlfunc_aten_index_put_bool_tmp_7'. It is not used by any node and should be removed from the model.[m
[0;93m2024-03-28 09:06:52.501953 [W:onnxruntime:, graph.cc:3593 CleanUnusedInitializersAndNodeArgs] Removing initializer '_inlfunc_aten_index_put_bool_cond_6'. It is not used by any node and should be removed from the model.[m
[0;93m2024-03-28 09:06:52.501957 [W:onnxruntime:, graph.cc:3593 CleanUnusedInitializersAndNodeArgs] Removing initializer '_inlfunc_aten_index_put_bool_tmp_3'. It is not used by any node and should be removed from the model.[m
[0;93m2024-03-28 09:06:52.501961 [W:onnxruntime:, graph.cc:3593 CleanUnusedInitializersAndNodeArgs] Removing initializer '_inlfunc_aten_index_put_bool_tmp_2'. It is not used by any node and should be removed from the model.[m
[0;93m2024-03-28 09:06:52.501965 [W:onnxruntime:, graph.cc:3593 CleanU

In [9]:
# One hand-picked sample with 8 input features, shaped (1, 1, 8) like the export dummy
# input (presumably batch=1, timesteps=1) and cast to float32 for both runtimes.
input_data = np.asarray(
    [-0.5066, -0.0651, -0.1815, -0.8371, -0.5118, -0.0956, -0.2307, -0.4243],
    dtype=np.float32,
).reshape(1, 1, 8)

In [10]:
# Run the same sample through PyTorch and ONNX Runtime.
# torch.from_numpy would share memory with input_data. Give PyTorch a copy instead,
# so an in-place write inside the encoder cannot alter the array later fed to ONNX
# Runtime. (The exported graph contains index_put ops, i.e. masked assignment.)
with torch.no_grad():  # inference only: no autograd graph needed
    torch_out = model.net(torch.tensor(input_data)).numpy()

input_name = ort_session.get_inputs()[0].name
output_name = ort_session.get_outputs()[0].name
outputs = ort_session.run([output_name], {input_name: input_data})
ort_out = outputs[0]

In [11]:
# Strict element-wise mask, kept for inspecting the largest discrepancies.
# np.isclose uses rtol=1e-5, atol=1e-8 by default (tighter than float32 round-off) and
# silently broadcasts mismatched shapes. The entries it flags here differ by only ~1e-7.
close_out = np.isclose(torch_out, ort_out)

# Hard parity check at float32-appropriate tolerances. It also fails on any shape mismatch.
np.testing.assert_allclose(ort_out, torch_out, rtol=1e-3, atol=1e-5)

In [12]:
# PyTorch values at the positions np.isclose flagged (default tolerances).
torch_out[np.logical_not(close_out)]

array([-0.00075907, -0.0053235 , -0.00055186], dtype=float32)

In [13]:
# ONNX Runtime values at the same flagged positions, for side-by-side comparison.
ort_out[np.logical_not(close_out)]

array([-0.00075901, -0.0053236 , -0.00055191], dtype=float32)