In [None]:
import os
# Work from the parent of the notebook's starting directory so the relative
# "./datasets/..." paths used below resolve. The starting directory is
# remembered on the first run: a bare os.chdir("../") would climb one more
# level every time this cell is re-executed.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, ".."))

In [None]:
import numpy as np
import torch

In [None]:
from storage.har_datasets import WISDMDataset, sts_medoids, split_by_test_subject
from s3ts.api.dms.har_datasets import LDFDataset, DFDataset
from storage.label_mappings import *
from s3ts.api.nets.methods import create_model_from_DM, train_model, test_model
from torchvision.transforms import Normalize

In [None]:
# The WISDM recordings come at 20 Hz; wsize=20 is the window size passed to the
# dataset (presumably in samples, i.e. ~1 s per window -- confirm in WISDMDataset).
ds = WISDMDataset(
    "./datasets/WISDM/",
    wsize=20,
    normalize=True,
    label_mapping=None,
)
print(len(ds))

In [None]:
# Cache the medoid patterns on disk so later runs can skip sts_medoids.
# NOTE: despite the ".npz" suffix, the file holds a single array in .npy
# format (written with np.save to an open file, so no suffix is appended);
# np.load picks the format from the file header, not the extension. The path
# is kept unchanged so existing caches keep loading.
MEDS_PATH = "./datasets/WISDM/meds.npz"

if os.path.exists(MEDS_PATH):
    meds = np.load(MEDS_PATH)
else:
    meds = sts_medoids(ds, n=500)
    # Write to a temporary file and atomically rename it into place, so an
    # interrupted run cannot leave a truncated cache that later runs would
    # find via os.path.exists and fail to load.
    tmp_path = MEDS_PATH + ".tmp"
    with open(tmp_path, "wb") as f:
        np.save(f, meds)
    os.replace(tmp_path, MEDS_PATH)

dfds = DFDataset(ds, patterns=meds, w=0.1, dm_transform=None, ram=False)

In [None]:
# Estimate normalization statistics from a random sample of DMs returned by
# `dfds`, then attach a Normalize transform to the dataset.
#
# The transform is reset to None (its construction value) first: otherwise a
# re-run of this cell would sample already-normalized DMs (mean ~0, std ~1)
# and replace the real normalization with a near-identity one. This assumes
# `dfds` applies `dm_transform` when indexed -- which is what the assignment
# at the end of this cell relies on.
dfds.dm_transform = None

N_NORM_SAMPLES = 500
np.random.seed(42)  # legacy global seed, kept so the same sample is drawn
# Indices are drawn with replacement (np.random.choice default).
norm_idx = np.random.choice(np.arange(len(dfds)), N_NORM_SAMPLES)
# Each item is a 3-tuple; only its first element (the DM) is needed. A
# comprehension keeps loop names from leaking into the notebook namespace.
DM = torch.stack([dfds[i][0] for i in norm_idx])

# Reducing over dims 0, 2, 3 leaves one statistic per index of dim 1 --
# assumes each DM is (channels, H, W), so these are per-channel stats (confirm).
dm_transform = Normalize(mean=DM.mean(dim=[0, 2, 3]), std=DM.std(dim=[0, 2, 3]))
dfds.dm_transform = dm_transform

In [None]:
# Split the windows by test subject (35 is the value handed to the helper --
# presumably the held-out subject id; confirm in split_by_test_subject),
# then wrap the DF dataset and that split in the data module used below.
data_split = split_by_test_subject(ds, 35)
dm = LDFDataset(
    dfds,
    data_split=data_split,
    batch_size=128,
    random_seed=42,
    num_workers=8,
)

In [None]:
# Print the size of each split, one per line: train, validation, test.
for split_ds in (dm.ds_train, dm.ds_val, dm.ds_test):
    print(len(split_ds))

In [None]:
# Build the model from the data module: CNN architecture, "img" data source,
# "cls" task, no explicit name.
model = create_model_from_DM(
    dm,
    name=None,
    dsrc="img",
    arch="cnn",
    task="cls",
)

In [None]:
# Train for 2 epochs; train_model hands back the trained model plus a second
# result object, which is printed.
model, data = train_model(
    dm,
    model,
    max_epochs=2,
)
print(data)