In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset

import numpy as np

import math

import time

import dataloader
import models
import training_fun

import optuna

import joblib

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ---------------------------------------------------------------------------
# Notebook-wide configuration. Every value below is read as a global by the
# later cells, so the names (including the `storge_device` spelling, which is
# also the keyword used by the project's dataloader) must stay as they are.
# ---------------------------------------------------------------------------

# Prefer the GPU whenever one is visible to PyTorch.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Sequence geometry: total window, the part that is predicted, and the
# remaining leading part.
SEQ_LENGTH = 2 * 365
TARGET_SEQ_LENGTH = 365
BASE_LENGTH = SEQ_LENGTH - TARGET_SEQ_LENGTH

FORCING_DIM = 3

N_CATCHMENTS = 749

# Training hyperparameters.
EPOCHS = 500
TRAIN_YEAR = 19
PATIENCE = 20

# Runtime switches.
use_amp = True
compile_model = False

if compile_model:
    torch.set_float32_matmul_precision("high")

# With memory saving on, data is stored on the CPU and moved to the computing
# device on demand; otherwise storage and computation share DEVICE.
memory_saving = False
computing_device = DEVICE
storge_device = "cpu" if memory_saving else DEVICE
if memory_saving:
    # Only the memory-saving validation routine consumes this value.
    VAL_STEPS = 500

In [4]:
# Build the training and validation datasets (in that order). Both share the
# same window geometry and storage device; only the CSV file and its
# record length differ.
dtrain, dval = (
    dataloader.Forcing_Data(
        csv_path,
        record_length=n_records,
        storge_device=storge_device,
        seq_length=SEQ_LENGTH,
        target_seq_length=TARGET_SEQ_LENGTH,
        base_length=BASE_LENGTH,
    )
    for csv_path, n_records in (
        ("data/data3f_train.csv", 5843),
        ("data/data3f_val.csv", 2191),
    )
)


In [5]:
class Objective:
    """Optuna objective wrapping one embedding/decoder training run.

    The instance only stores ``model_builder``. Datasets (``dtrain``,
    ``dval``), devices and training constants (``EPOCHS``, ``TRAIN_YEAR``,
    ``PATIENCE``, ``N_CATCHMENTS``, ``use_amp``, ``compile_model``,
    ``memory_saving`` ...) are read from the notebook's global namespace.
    """

    def __init__(self, model_builder):
        """Keep the builder used to create models for each trial.

        Args:
            model_builder: object whose ``define_model(trial)`` returns an
                ``(embedding, decoder)`` pair.
        """
        self.model_builder = model_builder

    def objective(self, trial):
        """Train one sampled configuration and return its best validation loss.

        Samples ``lr_embedding``, ``lr_decoder`` and ``batch_size_power`` from
        ``trial``, trains for up to ``EPOCHS`` epochs, reports the validation
        loss after every epoch and stops early when the early stopper says so.

        Args:
            trial: the Optuna trial supplying hyperparameters and receiving
                intermediate reports.

        Returns:
            ``early_stopper.min_validation_loss`` at the time training ends.

        Raises:
            optuna.exceptions.TrialPruned: when ``trial.should_prune()`` is
                true after an epoch.
        """
        # prepare early stopper
        early_stopper = training_fun.EarlyStopper(patience=PATIENCE, min_delta=0)

        # define model
        embedding, decoder = self.model_builder.define_model(trial)
        embedding = embedding.to(computing_device)
        decoder = decoder.to(computing_device)

        if compile_model:
            # PyTorch 2.0 feature: compile the models for faster training
            embedding, decoder = torch.compile(embedding), torch.compile(decoder)

        # define optimizers
        lr_embedding = trial.suggest_float("lr_embedding", 5e-5, 1e-2, log=True)
        embedding_optimizer = optim.Adam(embedding.parameters(), lr=lr_embedding)

        lr_decoder = trial.suggest_float("lr_decoder", 5e-5, 1e-2, log=True)
        decoder_optimizer = optim.Adam(decoder.parameters(), lr=lr_decoder)

        # Mixed precision is requested through `use_amp` but is only switched
        # on when the computation really runs on CUDA; on a CPU-only host the
        # scaler and autocast are explicitly disabled instead of relying on
        # PyTorch to warn and disable them.
        amp_enabled = bool(use_amp) and torch.device(computing_device).type == "cuda"
        scaler = torch.cuda.amp.GradScaler(enabled=amp_enabled)

        # define batch size
        batch_size_power = trial.suggest_int("batch_size_power", 4, 8)
        batch_size = 2**batch_size_power

        # Only full mini-batches are used; the leftover catchments of the
        # (freshly shuffled) permutation are skipped for that pass.
        n_batches = N_CATCHMENTS // batch_size

        # train model
        for epoch in range(EPOCHS):

            # put the models into training mode (validation below switches
            # them to eval mode at the end of every epoch)
            decoder.train()
            embedding.train()

            # Per the original notebook comment, get_random_batch generates a
            # batch that contains one year of data for each catchment; it is
            # drawn TRAIN_YEAR times to finish an epoch.
            for _ in range(TRAIN_YEAR):

                x_batch, y_batch = dtrain.get_random_batch()

                if memory_saving:
                    x_batch = x_batch.to(computing_device)
                    y_batch = y_batch.to(computing_device)

                # shuffle the catchment order to add randomness
                catchment_index = torch.randperm(
                    N_CATCHMENTS, device=computing_device
                )

                # iterate over mini-batches of catchments
                for i in range(n_batches):

                    selected_catchments = catchment_index[
                        i * batch_size : (i + 1) * batch_size
                    ]

                    # Select the rows that belong to the shuffled catchment
                    # ids so every embedding code is paired with its own
                    # catchment's data. (The previous sequential slice paired
                    # the codes of random catchments with the data of rows
                    # i*batch_size..(i+1)*batch_size.)
                    # Assumes row k of x_batch/y_batch is catchment k -- TODO
                    # confirm against dataloader.Forcing_Data.get_random_batch.
                    x_sub = x_batch[selected_catchments, :, :]
                    y_sub = y_batch[selected_catchments, :]

                    # clear gradients left over from the previous step
                    decoder_optimizer.zero_grad()
                    embedding_optimizer.zero_grad()

                    # forward pass and loss under (optional) autocast;
                    # device_type stays "cuda" because amp_enabled is only
                    # ever true on a CUDA device
                    with torch.autocast(
                        device_type="cuda", dtype=torch.float16, enabled=amp_enabled
                    ):
                        code = embedding(selected_catchments)
                        out = decoder.decode(code, x_sub)
                        loss = training_fun.mse_loss_with_nans(out, y_sub)

                    # backprop and optimizer steps through the gradient scaler
                    scaler.scale(loss).backward()
                    scaler.step(embedding_optimizer)
                    scaler.step(decoder_optimizer)
                    scaler.update()

            # validate model after each epoch
            decoder.eval()
            embedding.eval()

            if memory_saving:
                val_loss = training_fun.val_model_mem_saving(
                    embedding=embedding,
                    decoder=decoder,
                    dataset=dval,
                    storge_device=storge_device,
                    computing_device=computing_device,
                    use_amp=use_amp,
                    val_metric=training_fun.mse_loss_with_nans,
                    return_summary=True,
                    val_steps=VAL_STEPS,
                )
            else:
                val_loss = training_fun.val_model(
                    embedding=embedding,
                    decoder=decoder,
                    dataset=dval,
                    storge_device=storge_device,
                    computing_device=computing_device,
                    use_amp=use_amp,
                    val_metric=training_fun.mse_loss_with_nans,
                    return_summary=True,
                )

            # Normalise to a plain Python float in both branches so the value
            # handed to Optuna and to the early stopper never keeps a tensor
            # (or device memory) alive.
            val_loss = float(val_loss)

            # Handle pruning based on the intermediate value
            trial.report(val_loss, epoch)

            if trial.should_prune():
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                raise optuna.exceptions.TrialPruned()

            # Early stop using early_stopper, break for loop
            if early_stopper.early_stop(val_loss):
                break

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return early_stopper.min_validation_loss

In [7]:
# Builder that creates the LSTM embedding/decoder pair for a trial, sized from
# the notebook-level configuration.
LSTM_model_builder = training_fun.LSTM_model_builder(
    forcing_dim=FORCING_DIM,
    base_length=BASE_LENGTH,
    n_catchments=N_CATCHMENTS,
)

# Bound method handed to Optuna as the function to minimise.
LSTM_objective = Objective(model_builder=LSTM_model_builder).objective

In [8]:
# Create an in-memory study that minimises the objective; NopPruner means
# trial.should_prune() never stops a trial, so only early stopping ends runs.
study = optuna.create_study(
    direction="minimize",
    study_name="base_model",
    pruner=optuna.pruners.NopPruner(),
)

# Run the hyperparameter search.
study.optimize(func=LSTM_objective, n_trials=10)

# Persist the finished study so it can be inspected without re-running.
joblib.dump(value=study, filename="complete_LSTM_study.pkl")

[32m[I 2022-12-16 14:33:27,525][0m A new study created in memory with name: base_model[0m
[33m[W 2022-12-16 14:33:45,755][0m Trial 0 failed because of the following error: KeyboardInterrupt()[0m
Traceback (most recent call last):
  File "/Users/yang/opt/anaconda3/envs/pytorch1.13/lib/python3.10/site-packages/optuna/study/_optimize.py", line 196, in _run_trial
    value_or_values = func(trial)
  File "/var/folders/0j/tmjcqbl14mz0t6hplcmbj7yh0000gn/T/ipykernel_20010/1314539915.py", line 79, in objective
    scaler.scale(loss).backward()
  File "/Users/yang/opt/anaconda3/envs/pytorch1.13/lib/python3.10/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/Users/yang/opt/anaconda3/envs/pytorch1.13/lib/python3.10/site-packages/torch/autograd/__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
KeyboardInterrupt


KeyboardInterrupt: 