In [None]:
import math
import os
import time

import HydroErr
import joblib
import numpy as np
import optuna
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset

import dataloader
import models
import training_fun

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Sequence layout: the full window is twice the target window, and the part
# of the window that is not target is the base segment.
TARGET_SEQ_LENGTH = 365
SEQ_LENGTH = 2 * TARGET_SEQ_LENGTH
BASE_LENGTH = SEQ_LENGTH - TARGET_SEQ_LENGTH

FORCING_DIM = 3
N_CATCHMENTS = 559

# Training hyperparameters.
UPDATES = 1000  # upper bound on optimizer updates per fine-tuning run
TRAIN_YEAR = 8
PATIENCE = 10  # patience handed to the early stopper

# Runtime switches.
use_amp = True
compile_model = False
if compile_model:
    torch.set_float32_matmul_precision("high")

# Device placement. The "storge_device" spelling is kept on purpose: it
# matches the keyword argument name used when building the datasets.
memory_saving = False
computing_device = DEVICE
storge_device = "cpu" if memory_saving else DEVICE
if memory_saving:
    # VAL_STEPS is only defined in the memory-saving configuration.
    VAL_STEPS = 500

In [None]:
# Load the pretrained embedding and decoder onto the CPU first, then move
# them to the computing device.
# NOTE(review): torch.load unpickles arbitrary Python objects, so only point
# it at trusted files. Newer PyTorch releases may also need an explicit
# `weights_only=False` to load whole modules like this -- confirm against the
# installed version.
embedding = torch.load(
    "data/final_lstm_embedding.pt", map_location=torch.device("cpu")
).to(computing_device)
decoder = torch.load(
    "data/final_lstm_decoder.pt", map_location=torch.device("cpu")
).to(computing_device)

# Both pretrained modules are only used for inference at notebook level.
embedding.eval()
decoder.eval()

# The first parameter tensor of the embedding holds the learned codes; its
# second dimension is the latent size reused for newly created embeddings.
catchment_embeddings = next(iter(embedding.parameters())).data
LATENT_dim = catchment_embeddings.shape[1]

In [None]:
def _make_forcing_data(csv_path, record_length):
    """Build a `dataloader.Forcing_Data` with the notebook-wide sequence settings.

    Only the CSV path and `record_length` differ between the splits; the
    storage device and the three sequence lengths come from the config cell.
    """
    return dataloader.Forcing_Data(
        csv_path,
        record_length=record_length,
        storge_device=storge_device,
        seq_length=SEQ_LENGTH,
        target_seq_length=TARGET_SEQ_LENGTH,
        base_length=BASE_LENGTH,
    )


# Combined train+validation split, then the individual train / val / test splits.
dtrain_val = _make_forcing_data("data/camels_train_val.csv", 3652)
dtrain = _make_forcing_data("data/camels_train.csv", 2922)
dval = _make_forcing_data("data/camels_val.csv", 1095)
dtest = _make_forcing_data("data/camels_test.csv", 4383)

In [None]:
def get_optimal_update(study):
    """Return how many updates the best trial needed to reach its lowest reported value.

    Reads ``intermediate_values`` (a step -> value mapping) from the first entry
    of ``study.best_trials``, finds the step with the smallest value (the first
    such step wins on ties) and returns that step plus one, i.e. a count rather
    than a zero-based index.
    """
    reported = study.best_trials[0].intermediate_values
    best_step, _ = min(reported.items(), key=lambda step_and_value: step_and_value[1])
    return best_step + 1

In [None]:
class FINE_TUNE:
    """Per-catchment fine-tuning of the saved decoder together with a new embedding.

    For the selected catchment a single-row ``nn.Embedding`` is created and
    trained jointly with a freshly loaded copy of the saved decoder.

    * ``fine_tune`` is the Optuna objective: it trains on ``dtrain`` and
      monitors the loss on ``dval`` after every update.
    * ``test_final_model`` runs the search, retrains with the best settings on
      ``dtrain_val`` and scores the result on ``dtest``.

    The class reads notebook-level globals: ``computing_device``,
    ``LATENT_dim``, ``use_amp``, ``UPDATES``, ``PATIENCE``, the four datasets
    and ``get_optimal_update``.
    """

    def __init__(self, selected_catchment=0, eval_fun=HydroErr.kge_2009):
        """Store the catchment to fine-tune and the evaluation metric.

        Args:
            selected_catchment: catchment identifier forwarded to the dataset
                batch getters.
            eval_fun: metric called as
                ``eval_fun(simulated_array=..., observed_array=...)``.
        """
        self.selected_catchment = selected_catchment
        self.eval_fun = eval_fun

    @staticmethod
    def _autocast():
        """Return the mixed-precision context used for every forward pass."""
        return torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp)

    @staticmethod
    def _empty_cuda_cache():
        """Release cached GPU memory, skipping the call when CUDA is unavailable."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @staticmethod
    def _build_models(batch_size):
        """Load a fresh decoder and create a new single-row embedding.

        Returns:
            ``(decoder, embedding, embedding_input)`` where ``embedding_input``
            is a length-``batch_size`` tensor of zeros, so every sample in a
            training batch looks up the one embedding row.
        """
        # NOTE: torch.load unpickles arbitrary objects; only use trusted files.
        decoder = torch.load(
            "data/final_lstm_decoder.pt", map_location=torch.device("cpu")
        ).to(computing_device)
        embedding = nn.Embedding(1, LATENT_dim).to(computing_device)
        embedding_input = torch.zeros(
            size=(batch_size,), dtype=torch.long, device=computing_device
        )
        return decoder, embedding, embedding_input

    @staticmethod
    def _decode(decoder, embedding, x):
        """Decode ``x`` using embedding row 0 for every sample.

        The caller is responsible for the surrounding autocast / no_grad context.
        """
        index = torch.zeros(size=(x.shape[0],), dtype=torch.long, device=computing_device)
        return decoder.decode(embedding(index), x)

    def _train_step(
        self,
        dataset,
        batch_size,
        decoder,
        embedding,
        embedding_input,
        decoder_optimizer,
        embedding_optimizer,
        scaler,
    ):
        """Run one optimisation update on a random batch drawn from ``dataset``."""
        decoder.train()
        embedding.train()

        decoder_optimizer.zero_grad()
        embedding_optimizer.zero_grad()

        # Draw a random batch for the selected catchment and move it to the device.
        x_batch, y_batch, _ = dataset.get_catchment_random_batch(
            selected_catchment=self.selected_catchment, batch_size=batch_size
        )
        x_batch = x_batch.to(computing_device)
        y_batch = y_batch.to(computing_device)

        with self._autocast():
            code = embedding(embedding_input)
            out = decoder.decode(code, x_batch)
            loss = training_fun.mse_loss_with_nans(out, y_batch)

        # A single scaler drives both optimizers; update() is called once per step.
        scaler.scale(loss).backward()
        scaler.step(embedding_optimizer)
        scaler.step(decoder_optimizer)
        scaler.update()

    def fine_tune(self, trial):
        """Optuna objective: fine-tune on ``dtrain`` and monitor the loss on ``dval``.

        Suggests ``batch_size_power`` (batch size is ``2**power``),
        ``lr_embedding`` and ``lr_decoder``. After each update the validation
        loss is reported to the trial, so the reported step index is the
        zero-based update number.

        Args:
            trial: the Optuna trial supplying hyperparameters.

        Returns:
            The lowest validation loss tracked by the early stopper.

        Raises:
            optuna.exceptions.TrialPruned: when the study's pruner asks to stop.
        """
        batch_size = 2 ** trial.suggest_int("batch_size_power", 4, 8)

        decoder, embedding, embedding_input = self._build_models(batch_size)

        # Validation data for the selected catchment.
        x_val, y_val = dval.get_catchment_val_batch(self.selected_catchment)
        x_val, y_val = x_val.to(computing_device), y_val.to(computing_device)

        # Separate learning rates for the new embedding and the decoder.
        lr_embedding = trial.suggest_float("lr_embedding", 5e-5, 1e-2, log=True)
        embedding_optimizer = optim.Adam(embedding.parameters(), lr=lr_embedding)

        lr_decoder = trial.suggest_float("lr_decoder", 5e-5, 1e-2, log=True)
        decoder_optimizer = optim.Adam(decoder.parameters(), lr=lr_decoder)

        scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
        early_stopper = training_fun.EarlyStopper(patience=PATIENCE, min_delta=0)

        for update in range(UPDATES):
            self._train_step(
                dtrain,
                batch_size,
                decoder,
                embedding,
                embedding_input,
                decoder_optimizer,
                embedding_optimizer,
                scaler,
            )

            # Validate after every update.
            decoder.eval()
            embedding.eval()
            with self._autocast(), torch.no_grad():
                out = self._decode(decoder, embedding, x_val)
                val_loss = (
                    training_fun.mse_loss_with_nans(out, y_val).detach().cpu().numpy()
                )

            # Report the intermediate value so the pruner can act on it.
            trial.report(val_loss, update)
            if trial.should_prune():
                self._empty_cuda_cache()
                raise optuna.exceptions.TrialPruned()

            if early_stopper.early_stop(val_loss):
                break

        self._empty_cuda_cache()
        return early_stopper.min_validation_loss

    def test_final_model(self, n_trials=200, return_model=False):
        """Search hyperparameters, retrain with the best ones, and score on ``dtest``.

        The study is stored on ``self.study``. The final model is trained on
        ``dtrain_val`` for the number of updates returned by
        ``get_optimal_update`` and evaluated with ``self.eval_fun``.

        Args:
            n_trials: number of Optuna trials to run.
            return_model: when true, also return the trained embedding and decoder.

        Returns:
            The metric value, or ``(metric, embedding, decoder)`` when
            ``return_model`` is true.
        """
        self.study = optuna.create_study(
            study_name="fine_tune",
            direction="minimize",
            pruner=optuna.pruners.NopPruner(),
        )
        optuna.logging.set_verbosity(optuna.logging.WARNING)
        self.study.optimize(self.fine_tune, n_trials=n_trials)

        # Best settings found by the search.
        updates = get_optimal_update(self.study)
        best_params = self.study.best_params
        lr_decoder = best_params["lr_decoder"]
        lr_embedding = best_params["lr_embedding"]
        batch_size = 2 ** best_params["batch_size_power"]

        decoder, embedding, embedding_input = self._build_models(batch_size)

        decoder_optimizer = optim.Adam(decoder.parameters(), lr=lr_decoder)
        embedding_optimizer = optim.Adam(embedding.parameters(), lr=lr_embedding)

        # Test data for the selected catchment.
        x_test, y_test = dtest.get_catchment_val_batch(self.selected_catchment)
        x_test = x_test.to(computing_device).contiguous()
        y_test = y_test.to(computing_device).contiguous()

        scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
        for _ in range(updates):
            self._train_step(
                dtrain_val,
                batch_size,
                decoder,
                embedding,
                embedding_input,
                decoder_optimizer,
                embedding_optimizer,
                scaler,
            )

        decoder.eval()
        embedding.eval()
        with self._autocast(), torch.no_grad():
            pred = self._decode(decoder, embedding, x_test)

        # Under float16 autocast the prediction may be half precision; cast to
        # float32 so the metric is not computed on float16 values.
        pred = pred.view(-1).float().detach().cpu().numpy()
        ob = y_test.view(-1).detach().cpu().numpy()

        gof = self.eval_fun(simulated_array=pred, observed_array=ob)

        self._empty_cuda_cache()

        if return_model:
            return gof, embedding, decoder
        return gof

In [None]:
# Create the output folder up front so a missing directory cannot fail the
# first save after a full hyperparameter search has already been paid for.
os.makedirs("data/fine_tune", exist_ok=True)

# One score per catchment, filled in as each catchment is fine-tuned.
calibrated_KGES = np.ones(N_CATCHMENTS)

for i in range(N_CATCHMENTS):

    fine_tune = FINE_TUNE(i)
    # Note: this rebinds the notebook-level names `embedding` and `decoder`.
    calibrated_KGES[i], embedding, decoder = fine_tune.test_final_model(n_trials=200, return_model=True)

    torch.save(embedding.cpu(), f"data/fine_tune/embedding{i}.pt")
    # NOTE(review): "ecoder" looks like a typo for "decoder"; the name is left
    # unchanged because other code may already read these files -- confirm
    # before renaming.
    torch.save(decoder.cpu(), f"data/fine_tune/ecoder{i}.pt")

    joblib.dump(fine_tune.study, f"data/fine_tune/study{i}.pkl")
