In [1]:
import os

import numpy as np
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from torch.utils.data import random_split
from torch.utils.data.dataloader import DataLoader

from tsdart.dataprocessing import Preprocessing
from tsdart.loss import Prototypes
from tsdart.model import TSDART, TSDARTLayer, TSDARTEstimator
from tsdart.utils import set_random_seed

In [2]:
# Torch device used for the model and for TSDART training below.
# Set to 'cuda' to use a GPU, if one is available.
DEVICE_TYPE = 'cpu'
device = torch.device(DEVICE_TYPE)

In [8]:
# Load the ten WT trajectories in order (indices 0..9). Each printed shape is
# (n_rows, 1000), per the recorded output of this cell.
data = []
for traj_idx in range(10):
    traj = np.load(rf'./data/spec_osasis_filter_1000_select_CA_pairwise_distances_WT_{traj_idx}_traj.npy')
    # Report each shape as soon as its file is loaded.
    print(np.shape(traj))
    data.append(traj)

(11701, 1000)
(11930, 1000)
(11925, 1000)
(11948, 1000)
(11911, 1000)
(12008, 1000)
(11997, 1000)
(11952, 1000)
(12053, 1000)
(11933, 1000)


In [5]:
# Wrap the float32-cast trajectories into a dataset consumable by random_split /
# DataLoader. It presumably holds frame pairs lag_time steps apart; confirm in
# tsdart.dataprocessing.
LAG_TIME = 10
pre = Preprocessing(dtype=np.float32)
dataset = pre.create_dataset(data=data, lag_time=LAG_TIME)

In [6]:
# Seed once before all runs. The project helper is assumed to seed torch's global
# RNG, which random_split and weight initialisation draw from; TODO confirm.
set_random_seed(0)

N_RUNS = 10          # independent training runs, saved to folders 1..N_RUNS
VAL_FRACTION = 0.10  # share of `dataset` held out for validation in each run

# Resolve the output root once, as an absolute path. Calling
# os.chdir('./2states-wt-1/{i}') inside the loop without changing back made
# run 2 resolve ./2states-wt-1/1/2states-wt-1/2. It also left the kernel in a
# different working directory for every later cell, e.g. the ./data/... loads.
OUTPUT_ROOT = os.path.join(os.getcwd(), '2states-wt-1')


def _to_object_array(arrays):
    """Return ``arrays`` as a 1-D object array with one entry per element.

    ``np.array(list_of_arrays, dtype=object)`` is only 1-D when the arrays are
    ragged. For equal-shaped arrays it builds a higher-rank array instead, and
    when only the leading dimensions match it raises. Filling a preallocated
    object array element by element avoids both. An ndarray input is
    converted with ``astype(object)``, the same as ``np.array(x, dtype=object)``.
    """
    if isinstance(arrays, np.ndarray):
        return arrays.astype(object)
    packed = np.empty(len(arrays), dtype=object)
    for k, arr in enumerate(arrays):
        packed[k] = arr
    return packed


for i in range(1, N_RUNS + 1):
    out_dir = os.path.join(OUTPUT_ROOT, str(i))
    os.makedirs(out_dir, exist_ok=True)

    # A new random train/validation split for every run.
    val = int(len(dataset) * VAL_FRACTION)
    train_data, val_data = random_split(dataset, [len(dataset) - val, val])

    loader_train = DataLoader(train_data, batch_size=1000, shuffle=True)
    # The whole validation split is evaluated as a single batch.
    loader_val = DataLoader(val_data, batch_size=len(val_data), shuffle=False)

    # First width matches the 1000 features per row of each trajectory. The
    # last width presumably has to equal feat_dim=2 below; TODO confirm.
    lobe = TSDARTLayer([1000, 550, 250, 100, 50, 25, 2], n_states=2)
    lobe = lobe.to(device=device)

    tsdart = TSDART(lobe=lobe, learning_rate=1e-3, device=device, mode='regularize',
                    beta=0.01, feat_dim=2, n_states=2, pretrain=50)
    tsdart_model = tsdart.fit(loader_train, n_epochs=60, validation_loader=loader_val).fetch_model()

    # Training/validation metrics exposed by the TSDART object after fit.
    validation_vamp = tsdart.validation_vamp
    validation_dis = tsdart.validation_dis
    validation_prototypes = tsdart.validation_prototypes

    training_vamp = tsdart.training_vamp
    training_dis = tsdart.training_dis

    for stem, values in (
        ('validation_vamp', validation_vamp),
        ('validation_dis', validation_dis),
        ('validation_prototypes', validation_prototypes),
        ('training_vamp', training_vamp),
        ('training_dis', training_dis),
    ):
        np.save(os.path.join(out_dir, f'{stem}.npy'), values)

    hypersphere_embs = tsdart_model.transform(data=data, return_type='hypersphere_embs')
    metastable_states = tsdart_model.transform(data=data, return_type='states')

    # Fit the estimator once and read both results from that single fit.
    # Calling fit(data) twice repeats a full pass over every trajectory, and
    # could pair results from two different fits if fitting is not deterministic.
    tsdart_estimator = TSDARTEstimator(tsdart_model)
    fitted_estimator = tsdart_estimator.fit(data)
    ood_scores = fitted_estimator.ood_scores
    state_centers = fitted_estimator.state_centers

    # Pack per-trajectory outputs so np.save stores one entry per trajectory
    # regardless of whether the trajectories have equal lengths.
    hypersphere_embs = _to_object_array(hypersphere_embs)
    metastable_states = _to_object_array(metastable_states)
    ood_scores = _to_object_array(ood_scores)

    np.save(os.path.join(out_dir, 'hypersphere_embs.npy'), hypersphere_embs, allow_pickle=True)
    np.save(os.path.join(out_dir, 'metastable_states.npy'), metastable_states, allow_pickle=True)
    np.save(os.path.join(out_dir, 'ood_scores.npy'), ood_scores, allow_pickle=True)
    np.save(os.path.join(out_dir, 'state_centers.npy'), state_centers)

    torch.save(tsdart_model.lobe.state_dict(), os.path.join(out_dir, 'model.pt'))

                                                      