In [45]:
import sys

sys.path.append('../')

import os
import torch
import pickle
import numpy as np
import pandas as pd
from tqdm.auto import tqdm, trange
import random

from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import KFold

from scipy.integrate import solve_ivp
from scipy.integrate import odeint

import torch
from torch.utils.data import Dataset, DataLoader

import gzip
import json

from infrastructure.randutils import *
from infrastructure.misc import *

# Seed both global RNGs up front so every random draw in this notebook
# (NumPy factors / permutations, stdlib `random`) is reproducible.
random.seed(a=0)
np.random.seed(seed=0)


def _simulate_pairwise_trajectories(t_min, t_max, steps, dim, rank):
    """Integrate one scalar ODE per (i, j) index pair and stack the samples.

    Factor matrices ``U0`` and ``U1`` (each ``dim x rank``) are drawn from the
    global NumPy RNG, ``U0`` first and then ``U1``. For pair (i, j) the inner
    product ``U0[i] . U1[j]`` is both the initial value y(0) and the rate
    inside the ``tanh`` forcing term.

    Returns:
        Array with one row per sample, laid out as ``[i, j, t, y(t)]``.

    Raises:
        RuntimeError: if ``solve_ivp`` reports an unsuccessful integration.
    """
    U0 = np.random.rand(dim, rank)
    U1 = np.random.rand(dim, rank)

    # The evaluation grid and integration span are the same for every pair.
    # The ODE starts at t=0, so t_eval (from t_min) must lie inside [0, t_max].
    t_eval = np.linspace(t_min, t_max, steps)
    t_span = np.array([0.0, t_max])
    t_col = t_eval.reshape([steps, 1])

    def dynamics(t, y, ui, uj):
        # y(0) = ui.uj >= 0 and the right-hand side is non-negative for t >= 0,
        # so y is expected to stay non-negative and y**(2/3) real-valued.
        term1 = np.log(1 + y**2 + t * np.power(y, 2 / 3.))
        term2 = np.tanh(np.sum(ui * uj) * t)
        return term1 + term2

    blocks = []
    for i in range(dim):
        for j in range(dim):
            ui = U0[i, :]
            uj = U1[j, :]
            y0 = np.expand_dims(np.sum(ui * uj), axis=0)

            soln = solve_ivp(fun=dynamics, t_span=t_span, y0=y0, t_eval=t_eval,
                             method='DOP853', args=(ui, uj))
            if not soln.success:
                # Otherwise soln.y may be short and the reshape below fails
                # with an unhelpful shape error.
                raise RuntimeError(
                    f"solve_ivp failed for pair ({i}, {j}): {soln.message}")

            ind_n = np.tile(np.array([i, j]), [steps, 1])
            obs_n = soln.y.reshape([steps, 1])
            blocks.append(np.hstack([ind_n, t_col, obs_n]))

    return np.vstack(blocks)


def _time_split_masks(t, t_max, test_split, extrap):
    """Return boolean ``(train_mask, test_mask)`` partitioning samples by time.

    extrap=True:  test = times after ``t_max * (1 - test_split)``.
    extrap=False: test = the centred window
                  ``[(0.5 - test_split/2) * t_max, (0.5 + test_split/2) * t_max)``.
    """
    if extrap:
        t_split = t_max - t_max * test_split
        return t <= t_split, t > t_split

    t_split1 = (0.5 - 0.5 * test_split) * t_max
    t_split2 = (0.5 + 0.5 * test_split) * t_max
    tr_idx = (t < t_split1) | (t >= t_split2)
    te_idx = (t >= t_split1) & (t < t_split2)
    return tr_idx, te_idx


def generate_simulate_data(t_min, t_max, steps, folds, test_split, max_samples, extrap):
    """Generate a synthetic 2-mode temporal dataset and pickle it under ``processed/``.

    For every index pair (i, j) of a 3 x 3 grid, a scalar ODE is integrated and
    sampled at ``steps`` evenly spaced times in [t_min, t_max]. Observations
    are standardised, split by time into train/test, shuffled with the global
    NumPy RNG, and replicated into ``folds`` identical folds.

    Args:
        t_min: First evaluation time. Must lie in [0, t_max], because the ODE
            is integrated from t=0.
        t_max: Last evaluation time and end of the integration span.
        steps: Number of evaluation times per index pair.
        folds: Number of (identical) train/test folds to store.
        test_split: Fraction of the time axis held out for testing.
        max_samples: Currently unused; kept for interface compatibility.
        extrap: If True, hold out the tail of the time axis (extrapolation).
            If False, hold out a centred window (interpolation).

    Returns:
        The dict that is pickled, with keys 'nvec', 'nmod', 'train_folds',
        'test_folds', 't_min' and 't_max'. Each fold array has rows
        ``[i, j, t, y_standardised]``.

    Raises:
        RuntimeError: if any ODE integration fails.
    """
    R = 3
    dim = 3

    data = _simulate_pairwise_trajectories(t_min, t_max, steps, dim, R)

    # NOTE(review): the scaler is fit on train and test observations together,
    # so test statistics leak into the normalisation. Confirm this is intended.
    scaler_y = StandardScaler()
    data[:, -1] = scaler_y.fit_transform(data[:, -1].reshape([-1, 1])).squeeze()

    tr_idx, te_idx = _time_split_masks(data[:, -2], t_max, test_split, extrap)

    print(data)

    data_tr = data[tr_idx]
    data_te = data[te_idx]

    # Shuffle rows. The train permutation is drawn before the test one, so the
    # global RNG sequence matches earlier versions of this function.
    perm_tr = np.random.permutation(data_tr.shape[0])
    perm_te = np.random.permutation(data_te.shape[0])
    data_tr = data_tr[perm_tr, :]
    data_te = data_te[perm_te, :]

    cprint('r', data_tr.shape)
    cprint('b', data_te.shape)

    # Every fold references the same shuffled arrays (no per-fold resplit).
    train_list = [data_tr for _ in range(folds)]
    test_list = [data_te for _ in range(folds)]

    D = {
        'nvec': [dim, dim],
        'nmod': 2,
        'train_folds': train_list,
        'test_folds': test_list,
        't_min': t_min,
        't_max': t_max,
    }

    save_path = 'processed'
    pickle_name = ('SimuExtrap' if extrap else 'SimuInterp') + '.pickle'

    create_path(save_path)

    with open(os.path.join(save_path, pickle_name), 'wb') as handle:
        pickle.dump(D, handle, protocol=pickle.HIGHEST_PROTOCOL)

    return D

# Interpolation setting: the test set is a centred time window whose width
# is test_split * t_max.
Data = generate_simulate_data(
    extrap=False,
    test_split=0.333,
    # Time grid: 100 evenly spaced evaluation points on [t_min, t_max].
    t_min=0.001,
    t_max=10.0,
    steps=100,
    # Dataset layout.
    folds=5,
    max_samples=15000,
)

[[ 0.00000000e+00  0.00000000e+00  1.00000000e-03 -1.31303472e+00]
 [ 0.00000000e+00  0.00000000e+00  1.02000000e-01 -1.30867217e+00]
 [ 0.00000000e+00  0.00000000e+00  2.03000000e-01 -1.30309090e+00]
 ...
 [ 2.00000000e+00  2.00000000e+00  9.79800000e+00  1.81415610e+00]
 [ 2.00000000e+00  2.00000000e+00  9.89900000e+00  1.85937337e+00]
 [ 2.00000000e+00  2.00000000e+00  1.00000000e+01  1.90472340e+00]]
[31m(612, 4)[0m
[34m(288, 4)[0m
Directory 'processed' created successfully
