In [1]:
import dysts.flows
import pytorch_lightning as pl
import torch
from dysts.base import DynSys
from numpy.random import rand
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import TensorDataset, random_split, DataLoader

from config import ROOT_DIR
from ecodyna.data import build_in_out_pair_dataset, load_or_generate_and_save, build_data_path
from ecodyna.metrics import ForecastMetricLogger
from ecodyna.models.mutitask_models import MyGRU, MyLSTM
from ecodyna.models.task_modules import ChunkForecaster
from ecodyna.plot import plot_1d_trajectories

In [1]:
params = {
    'data': {
        'attractor': 'Lorenz',
        'trajectory_count': 100,
        'trajectory_length': 1000,
        'resample': True,
        'pts_per_period': 50,
        'ic_noise': 0.01
    },
    'experiment': {
        'max_epochs': 50,
        'train_part': 0.75,
        'random_seed': 26
    },
    'in_out': {
        'n_in': 10,
        'n_out': 5
    },
    'models': {
        'n_hidden': 32,
        'n_layers': 1
    },
    'dataloader': {
        'batch_size': 64,
        'num_workers': 8
    }
}
params['models']['common'].update(params['in_out'])

{'data': {'attractor': 'Lorenz',
  'trajectory_count': 100,
  'trajectory_length': 1000,
  'resample': True,
  'pts_per_period': 0.01,
  'ic_noise': 0.01},
 'experiment': {'max_epochs': 50, 'train_part': 0.75, 'random_seed': 26},
 'in_out': {'n_in': 10, 'n_out': 5},
 'models': {'n_hidden': 32, 'n_layers': 1, 'n_in': 10, 'n_out': 5},
 'dataloader': {'batch_size': 64, 'num_workers': 8, 'n_in': 10, 'n_out': 5}}

In [3]:
# Sets random seed for random, numpy and torch
pl.seed_everything(params['experiment']['random_seed'], workers=True);

Global seed set to 26


In [4]:
attractor: DynSys = getattr(dysts.flows, params['data']['attractor'])()

attractor_x0 = attractor.ic.copy()
space_dim = len(attractor_x0)

data = load_or_generate_and_save(path=build_data_path(**params['data']),
                                 attractor=attractor, verbose=True, **params['data'],
                                 ic_fun=lambda: params['data']['ic_noise'] * (rand(space_dim) - 0.5) + attractor_x0)

train_size = int(params['experiment']['train_part'] * params['data']['trajectory_count'])
val_size = params['data']['trajectory_count'] - train_size

dataset = TensorDataset(data)
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
train_data = data[train_dataset.indices]
val_data = data[val_dataset.indices]

chunk_train_dataset = build_in_out_pair_dataset(train_dataset, **params['in_out'])
chunk_val_dataset = build_in_out_pair_dataset(val_dataset, **params['in_out'])

chunk_train_dl = DataLoader(chunk_train_dataset, **params['dataloader'], shuffle=True)
chunk_val_dl = DataLoader(chunk_val_dataset, **params['dataloader'])

train_dataloader = DataLoader(train_dataset, **params['dataloader'])
val_dataloader = DataLoader(val_dataset, **params['dataloader'])

Generating data for attractor Lorenz


100%|██████████| 100/100 [00:31<00:00,  3.14it/s]


In [5]:
def train_rnn_and_plot_forecasts(rnn):
    n_plots = 4
    forecaster = ChunkForecaster(model=rnn)

    wandb_logger = WandbLogger(
        save_dir=f'{ROOT_DIR}/results',
        project='notebook-lstm-forecast-types-chunkyness-evaluation',
        name=f"{rnn.name()}_{params['data']['attractor']}_{rnn.forecast_type}"
    )

    wandb_logger.experiment.config.update({
        'forecaster': {'name': rnn.name(), **rnn.hyperparams},
        'data': params['data'],
        'dataloader': params['dataloader'],
        'experiment': params['experiment']
    })

    trainer = pl.Trainer(
        max_epochs=params['experiment']['max_epochs'],
        callbacks=[EarlyStopping('val_loss', patience=5), ForecastMetricLogger(train_dataset, val_dataset)],
        logger=wandb_logger)

    trainer.fit(forecaster, train_dataloaders=chunk_train_dl, val_dataloaders=chunk_val_dl)

    for name, data, dataloader in [('train', train_data, train_dataloader), ('validation', val_data, val_dataloader)]:
        labels = [f'ground truth ({name})']
        tensors = [data]
        for func_name, forecast_func in rnn.get_applicable_forecast_functions().items():
            forecaster.prediction_func = forecast_func
            predictions = torch.concat(trainer.predict(forecaster, dataloaders=dataloader))
            labels.append(f'prediction ({func_name})')
            tensors.append(predictions)
        plot_1d_trajectories(labels=labels, tensors=tensors, n_plots=n_plots)
    wandb_logger.experiment.finish(quiet=True)

In [None]:
train_rnn_and_plot_forecasts(
    rnn=MyGRU(forecast_type='one_by_one', space_dim=space_dim, **params['models']))

In [None]:
train_rnn_and_plot_forecasts(
    rnn=MyLSTM(forecast_type='one_by_one', space_dim=space_dim, **params['models']))

In [None]:
train_rnn_and_plot_forecasts(rnn=MyLSTM(forecast_type='multi', space_dim=space_dim, **params['models']))