# SRVNet cross validation

In [14]:
from VAMPNet_SRVNet import *
from sklearn.model_selection import KFold
import os
from torch.utils.data.dataloader import DataLoader

### Split the dataset into folds and generate the training and validation datasets for each fold
### The cross validation is done over whole trajectories instead of individual transition pairs

### set number of trajectories in total
num_trajs = 100

### set the lagtimes and the numbers of CVs scanned by the cross validation
### Note if GMRQ is implemented as criterion, need further clustering and MSM construction
lagtime = [5, 10, 20, 40]
num_cvs = [2, 3, 4, 5, 6]

### set training hyper-parameters
training_batch = 50000
num_epochs = 5
learning_rate = 1e-4

### specify the output address
output_dir = 'cross_validation_srvnet_num_cv_lagtime_5epochs_5e4batch_1e-4learning'

### single random seed shared by the global seeding, the KFold shuffling and the output file names
random_seed = 42


### set the device; the lobe and the projector built in the loop below are placed on it
if torch.cuda.is_available():
    device = torch.device("cuda")
    torch.backends.cudnn.benchmark = True
else:
    device = torch.device("cpu")

## Set the random seed to ensure the reproducible training
set_random_seed(random_seed)

### exist_ok=True lets the cell be re-run without failing on the already created directory
### (files with the same names written by a previous run are overwritten)
os.makedirs(output_dir, exist_ok=True)

### the .npy file stores a pickled dict; only its values (the trajectories) are kept, in dict order
trajs = np.load("./alanine_dipeptide_45pairwise_distances_100trajs_0.1ps.npy", allow_pickle=True).item()
trajs = list(trajs.values())

### fail early instead of hitting an IndexError while indexing trajs inside the fold loop
if num_trajs > len(trajs):
    raise ValueError("num_trajs={} but only {} trajectories were loaded".format(num_trajs, len(trajs)))

### KFold splits the trajectory indices, so a trajectory is never shared between training and validation
trajs_idx = list(range(num_trajs))
k_folder = KFold(n_splits=5, shuffle=True, random_state=random_seed)
### Cross-validation loop: for every fold, train one SRVNet per (number of CVs, lagtime) combination
### and write the model, the training/validation scores and the projected trajectories to output_dir.
### NOTE(review): the model file name is tagged "pslag" while the score/trajectory files are tagged
### "nslag", although both are filled with lag/10. The names are kept unchanged here so that anything
### reading these files by name keeps working -- confirm which unit label is intended.
for num_folder, (train_idx, test_idx) in enumerate(k_folder.split(trajs_idx)):
    ### the trajectories of a fold do not depend on num_cv or lag, so select them once per fold
    train_trajs = [trajs[i] for i in train_idx]
    test_trajs = [trajs[i] for i in test_idx]

    for num_cv in num_cvs:
        for lag in lagtime:
            ### time-lagged datasets are rebuilt for every lagtime
            train_data = TimeLaggedDataset(trajs=train_trajs, lagtime=lag)
            test_data = TimeLaggedDataset(trajs=test_trajs, lagtime=lag)
            train_loader = DataLoader(train_data, batch_size=training_batch, shuffle=True)
            ### the whole validation set is served as a single batch
            test_loader = DataLoader(test_data, batch_size=len(test_data), shuffle=True)

            ### a fresh network per run: 45 input features -> 45 -> 20 -> num_cv outputs
            lobe = torch.nn.Sequential(
                torch.nn.BatchNorm1d(45),
                torch.nn.Linear(45, 45), torch.nn.ELU(),
                torch.nn.Linear(45, 20), torch.nn.ELU(),
                torch.nn.Linear(20, num_cv))
            lobe = lobe.to(device=device)

            projector = deep_projector(network_type='SRVNet', lobe=lobe, epsilon=1e-6, learning_rate=learning_rate, device=device)
            print("Training is starting...")
            projector.fit(train_loader=train_loader, num_epochs=num_epochs, validation_loader=test_loader)
            ### torch.save pickles the whole projector object, not only the network weights
            torch.save(projector, output_dir+"/alanine_dipeptide_45pairdistances_srvnet_{}randomseed_{}pslag_{}cvs_{}folder_model.th".format(random_seed, lag/10, num_cv, num_folder))

            ### scores collected on the projector during fit
            training_score = np.array(projector.train_score)
            validation_score = np.array(projector.validate_score)
            np.savetxt(output_dir+"/alanine_dipeptide_45pairdistances_srvnet_{}randomseed_{}nslag_{}cvs_{}folder_validation_score.txt".format(random_seed, lag/10, num_cv, num_folder), validation_score)
            np.savetxt(output_dir+"/alanine_dipeptide_45pairdistances_srvnet_{}randomseed_{}nslag_{}cvs_{}folder_training_score.txt".format(random_seed, lag/10, num_cv, num_folder), training_score)

            ### project ALL loaded trajectories (training and held-out ones) with the trained model
            project_trajs = projector.transform(trajs)
            np.save(output_dir+"/alanine_dipeptide_45pairdistances_srvnet_{}randomseed_{}nslag_{}cvs_{}folder_trajs.npy".format(random_seed, lag/10, num_cv, num_folder), project_trajs)

load data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:01<00:00, 18.97it/s]
load data: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:00<00:00, 19.16it/s]


Training is starting...
==>epoch=0, training process=33.33%, the training loss function=1.3315860033035278, eigenvalues:[0.49641174 0.2918243 ];
==>epoch=0, training process=33.33%, the validation loss function=2.059948444366455;
==>epoch=1, training process=66.67%, the training loss function=1.4312770366668701, eigenvalues:[0.5502906  0.35840952];
==>epoch=1, training process=66.67%, the validation loss function=2.0941214561462402;
==>epoch=2, training process=100.00%, the training loss function=1.5323909521102905, eigenvalues:[0.597474  0.4188267];
==>epoch=2, training process=100.00%, the validation loss function=2.12217378616333;


load data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:01<00:00, 18.76it/s]
  arr = np.asanyarray(arr)
load data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:01<00:00, 18.33it/s]
load data: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:00<00:00, 19.67it/s]


Training is starting...
==>epoch=0, training process=33.33%, the training loss function=1.3564165830612183, eigenvalues:[0.5460421  0.24135987];
==>epoch=0, training process=33.33%, the validation loss function=1.561816692352295;
==>epoch=1, training process=66.67%, the training loss function=1.4790363311767578, eigenvalues:[0.636523  0.2717991];
==>epoch=1, training process=66.67%, the validation loss function=1.6239337921142578;
==>epoch=2, training process=100.00%, the training loss function=1.585392951965332, eigenvalues:[0.70429194 0.29894125];
==>epoch=2, training process=100.00%, the validation loss function=1.6807923316955566;


load data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:01<00:00, 19.16it/s]
load data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:01<00:00, 19.21it/s]
load data: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:00<00:00, 19.84it/s]


Training is starting...
==>epoch=0, training process=33.33%, the training loss function=1.3413852453231812, eigenvalues:[0.57477534 0.10496929];
==>epoch=0, training process=33.33%, the validation loss function=1.6470134258270264;
==>epoch=1, training process=66.67%, the training loss function=1.4816646575927734, eigenvalues:[0.6846812  0.11347363];
==>epoch=1, training process=66.67%, the validation loss function=1.6921864748001099;
==>epoch=2, training process=100.00%, the training loss function=1.6051278114318848, eigenvalues:[0.7685863  0.12001208];
==>epoch=2, training process=100.00%, the validation loss function=1.7248921394348145;


load data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:01<00:00, 19.50it/s]
load data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:01<00:00, 19.55it/s]
load data: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:00<00:00, 19.65it/s]


Training is starting...
==>epoch=0, training process=33.33%, the training loss function=1.3630365133285522, eigenvalues:[0.59686005 0.08242927];
==>epoch=0, training process=33.33%, the validation loss function=1.5962040424346924;
==>epoch=1, training process=66.67%, the training loss function=1.440946102142334, eigenvalues:[0.65696025 0.09669174];
==>epoch=1, training process=66.67%, the validation loss function=1.6254653930664062;
==>epoch=2, training process=100.00%, the training loss function=1.5035536289215088, eigenvalues:[0.7013246  0.10815507];
==>epoch=2, training process=100.00%, the validation loss function=1.647284984588623;


load data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:01<00:00, 19.43it/s]
load data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:01<00:00, 19.67it/s]
load data: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:00<00:00, 19.53it/s]


Training is starting...
==>epoch=0, training process=33.33%, the training loss function=1.6305315494537354, eigenvalues:[0.73952067 0.26052237 0.12557407];
==>epoch=0, training process=33.33%, the validation loss function=2.25898814201355;
==>epoch=1, training process=66.67%, the training loss function=1.7089447975158691, eigenvalues:[0.769935   0.30732703 0.14729223];
==>epoch=1, training process=66.67%, the validation loss function=2.2866737842559814;
==>epoch=2, training process=100.00%, the training loss function=1.7942153215408325, eigenvalues:[0.79580367 0.36609617 0.16396789];
==>epoch=2, training process=100.00%, the validation loss function=2.3086047172546387;


load data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:01<00:00, 19.43it/s]
load data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:01<00:00, 19.41it/s]
load data: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:00<00:00, 19.62it/s]


Training is starting...
==>epoch=0, training process=33.33%, the training loss function=1.4903067350387573, eigenvalues:[0.6462383  0.2644483  0.05243928];
==>epoch=0, training process=33.33%, the validation loss function=2.071075916290283;
==>epoch=1, training process=66.67%, the training loss function=1.584977388381958, eigenvalues:[0.6950179  0.31509635 0.05139849];
==>epoch=1, training process=66.67%, the validation loss function=2.0917582511901855;
==>epoch=2, training process=100.00%, the training loss function=1.668588638305664, eigenvalues:[0.73315424 0.35845453 0.05083131];
==>epoch=2, training process=100.00%, the validation loss function=2.108994960784912;


load data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:01<00:00, 19.49it/s]
load data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:01<00:00, 19.63it/s]
load data: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:00<00:00, 19.69it/s]


Training is starting...
==>epoch=0, training process=33.33%, the training loss function=1.3708285093307495, eigenvalues:[0.5232181  0.31148428 0.00699098];
==>epoch=0, training process=33.33%, the validation loss function=1.9269356727600098;
==>epoch=1, training process=66.67%, the training loss function=1.5181307792663574, eigenvalues:[0.6356645  0.33767366 0.00615727];
==>epoch=1, training process=66.67%, the validation loss function=1.9292519092559814;
==>epoch=2, training process=100.00%, the training loss function=1.6488748788833618, eigenvalues:[0.7233968  0.35431325 0.005836  ];
==>epoch=2, training process=100.00%, the validation loss function=1.928896188735962;


load data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:01<00:00, 19.46it/s]
load data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:01<00:00, 19.71it/s]
load data: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:00<00:00, 19.46it/s]


Training is starting...
==>epoch=0, training process=33.33%, the training loss function=1.0868654251098633, eigenvalues:[0.22616443 0.188238   0.01677885];
==>epoch=0, training process=33.33%, the validation loss function=1.657677173614502;
==>epoch=1, training process=66.67%, the training loss function=1.1920018196105957, eigenvalues:[0.37477717 0.22601333 0.02149165];
==>epoch=1, training process=66.67%, the validation loss function=1.6744961738586426;
==>epoch=2, training process=100.00%, the training loss function=1.3398652076721191, eigenvalues:[0.52876496 0.24421231 0.02516141];
==>epoch=2, training process=100.00%, the validation loss function=1.6874561309814453;


load data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:01<00:00, 19.53it/s]
load data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:01<00:00, 19.46it/s]


KeyboardInterrupt: 