In [1]:
import dysts.flows
import pytorch_lightning as pl
import torch
from dysts.base import DynSys
from numpy.random import rand
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import TensorDataset, random_split, DataLoader

from config import ROOT_DIR
from ecodyna.data import build_in_out_pair_dataset, load_or_generate_and_save
from ecodyna.metrics import ForecastMetricLogger
from ecodyna.mutitask_models import MultiTaskRNN
from ecodyna.pl_wrappers import LightningForecaster
from ecodyna.plot import plot_1d_trajectories

In [2]:
attractor_name = 'Lorenz'

# Trajectory generation settings. Key order matters: the data-loading cell joins
# these items (in order) into the data file path.
data_parameters = dict(trajectory_count=100, trajectory_length=1000, resample=True, pts_per_period=50)
# Window sizes shared by the in/out pair datasets and by every model.
in_out_parameters = dict(n_in=10, n_out=5)
common_model_parameters = dict(n_hidden=32, n_layers=1, **in_out_parameters)
experiment_parameters = dict(max_epochs=50, train_part=0.75, random_seed=26)
dataloader_parameters = dict(batch_size=64, num_workers=8)

# Split the trajectories: the truncated `train_part` fraction trains, the rest validates.
train_size = int(data_parameters['trajectory_count'] * experiment_parameters['train_part'])
val_size = data_parameters['trajectory_count'] - train_size

In [3]:
# One call seeds Python's `random`, NumPy and torch. Per the Lightning docs,
# `workers=True` additionally sets up seeding for dataloader worker processes.
# The trailing `;` hides the returned seed in the cell output.
pl.seed_everything(seed=experiment_parameters['random_seed'], workers=True);

Global seed set to 26


In [4]:
# Instantiate the chosen dysts flow and keep a copy of its default initial condition.
attractor: DynSys = getattr(dysts.flows, attractor_name)()
attractor_x0 = attractor.ic.copy()
space_dim = len(attractor_x0)


def _random_initial_condition():
    """Return the attractor's default initial condition shifted by a per-coordinate uniform offset in [-0.5, 0.5)."""
    return rand(space_dim) - 0.5 + attractor_x0


# Load the trajectories from disk, or generate and save them if missing.
data = load_or_generate_and_save(
    path=f'{ROOT_DIR}/data/attractor={attractor_name}-{"-".join([f"{k}={v}" for k, v in data_parameters.items()])}',
    attractor=attractor,
    verbose=True,
    **data_parameters,
    ic_fun=_random_initial_condition,
)

# Random train/validation split over whole trajectories (seeded by the previous cell).
dataset = TensorDataset(data)
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
train_data, val_data = data[train_dataset.indices], data[val_dataset.indices]

# In/out pair datasets used for training; only the training pairs are shuffled.
chunk_train_dataset, chunk_val_dataset = (
    build_in_out_pair_dataset(split, **in_out_parameters) for split in (train_dataset, val_dataset)
)
chunk_train_dl = DataLoader(chunk_train_dataset, shuffle=True, **dataloader_parameters)
chunk_val_dl = DataLoader(chunk_val_dataset, **dataloader_parameters)

# Whole-trajectory loaders used for prediction. They are left unshuffled so batch order
# follows the subset indices, i.e. matches `train_data` / `val_data`.
train_dataloader, val_dataloader = (
    DataLoader(split, **dataloader_parameters) for split in (train_dataset, val_dataset)
)

In [5]:
def train_rnn_and_plot_different_forecast_types(rnn, n_plots=4):
    """Train ``rnn`` on the chunked in/out pairs, then plot each of its forecast types.

    The model is wrapped in a ``LightningForecaster`` and trained on ``chunk_train_dl`` /
    ``chunk_val_dl`` with early stopping on ``val_loss``. Metrics go to a Weights & Biases run.
    After training, every forecast function listed by
    ``rnn.get_applicable_forecast_functions()`` is run over the full train and validation
    trajectories and plotted next to the ground truth.

    This function reads notebook-level globals defined in earlier cells: the datasets and
    dataloaders, ``train_data`` / ``val_data``, ``attractor_name`` and the parameter dicts.

    Args:
        rnn: model to train. It is used through ``name()``, ``hyperparams``,
            ``forecast_type`` and ``get_applicable_forecast_functions()``.
        n_plots: forwarded as ``n_plots`` to ``plot_1d_trajectories`` (default 4).
    """
    forecaster = LightningForecaster(model=rnn)

    wandb_logger = WandbLogger(
        save_dir=f'{ROOT_DIR}/results',
        project='notebook-lstm-forecast-types-chunkyness-evaluation',
        name=f'{rnn.name()}_{attractor_name}_{rnn.forecast_type}'
    )
    # Start (or attach to) the W&B run once, outside the `try`, so that `finally`
    # never re-triggers run initialisation if that is what failed.
    run = wandb_logger.experiment

    try:
        run.config.update({
            'forecaster': {'name': rnn.name(), **rnn.hyperparams},
            'data': {'attractor': attractor_name, **data_parameters},
            'dataloader': dataloader_parameters,
            'experiment': experiment_parameters
        })

        trainer = pl.Trainer(
            max_epochs=experiment_parameters['max_epochs'],
            callbacks=[EarlyStopping('val_loss', patience=5), ForecastMetricLogger(train_dataset, val_dataset)],
            logger=wandb_logger)

        trainer.fit(forecaster, train_dataloaders=chunk_train_dl, val_dataloaders=chunk_val_dl)

        splits = [('train', train_data, train_dataloader), ('validation', val_data, val_dataloader)]
        for split_name, ground_truth, dataloader in splits:
            labels = [f'ground truth ({split_name})']
            tensors = [ground_truth]
            for forecast_function_name in rnn.get_applicable_forecast_functions().keys():
                # Select which forecast function the forecaster's predict step should use.
                forecaster.prediction_func_name = forecast_function_name
                predictions = torch.concat(trainer.predict(forecaster, dataloaders=dataloader))
                labels.append(f'prediction ({forecast_function_name})')
                tensors.append(predictions)
            plot_1d_trajectories(labels=labels, tensors=tensors, n_plots=n_plots)
    finally:
        # Always close the run, even when training, prediction or plotting raises.
        # Lightning's WandbLogger reuses a run that is still in progress, so leaving it
        # open would send the next model's logs and config into this run.
        run.finish(quiet=True)

In [None]:
# GRU that forecasts one step at a time.
train_rnn_and_plot_different_forecast_types(
    rnn=MultiTaskRNN(
        model='GRU',
        forecast_type='one_by_one',
        space_dim=space_dim,
        **common_model_parameters,
    )
)

In [None]:
# LSTM that forecasts one step at a time.
train_rnn_and_plot_different_forecast_types(
    rnn=MultiTaskRNN(
        model='LSTM',
        forecast_type='one_by_one',
        space_dim=space_dim,
        **common_model_parameters,
    )
)

In [None]:
# LSTM using the 'multi' forecast type.
train_rnn_and_plot_different_forecast_types(
    rnn=MultiTaskRNN(
        model='LSTM',
        forecast_type='multi',
        space_dim=space_dim,
        **common_model_parameters,
    )
)