# IMPORT LIBRARIES

In [1]:
import sys
import config
import os
import random
import math
import time
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
from collections import Counter
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE

from blitz.modules import BayesianGRU
from blitz.utils import variational_estimator


# CHANNELS INFO
channels = config.channels
# SF channels are deliberately excluded from the inputs ("withoutSF" in the run
# name); to include them use: config.weather_channels + config.sf_channels
input_channels = config.weather_channels
static_channels = config.static_channels

# NOTE(review): the data-loading cell reads the NUMPY / RAW_DATA arrays from DIR,
# while DATA_DIR below ("DATA") is used for MODEL / RESULT outputs; confirm both
# roots are intended.
DIR = 'DATA_DIR/'


# TIME SERIES INFO
window = config.window

# TRAIN INFO
device = config.device
code_dim = config.code_dim
n_clusters = config.n_clusters
epochs = 200  # hard-coded here rather than read from config
batch_size = config.batch_size
learning_rate = config.learning_rate
alpha = 2

# MODEL INFO
# Loss weights are hard-coded for this run; the config equivalents are unused.
recon_weight = 0.1    # config.recon_weight
static_weight = 10    # config.static_weight
triplet_weight = 1    # config.triplet_weight
sum_weight = recon_weight + static_weight + triplet_weight

# Architecture names for the methods compared in this notebook:
#   BIM    : ATT_NL_3
#   IM+VAE : VAE
#   IM+CD  : ATT_NL_CD
# Other Bayesian baselines are also defined in this notebook.
architecture = "ATT_NL_3"
# Forward models: LSTM, EALSTM

# NOTE(review): the run id is the NUMBER of command-line arguments, not a value
# parsed from them (the printed model name below shows 3 for this run).
run = len(sys.argv)
Hidden = "Hidden_0_Serial_reconstruction_loss+triplet_loss_withoutSF"
num_hidden = 0
model_name = "{}_{}_{}_{}_{}_{}_{}_{}".format(
    "ALL", architecture, code_dim, len(static_channels), run, batch_size, "1_NL", Hidden
)

pretrain = None

print("{} Hyperparameters".format(model_name))
print("Channels : {}".format(channels))
print("Input Channels : {}".format(input_channels))
print("Static Channels : {}".format(static_channels))
print("Code dim : {}".format(code_dim))
print("Epochs : {}".format(epochs))
print("Batch Size : {}".format(batch_size))
print("Learning rate : {}".format(learning_rate))
print("Reconstruction Weight : {}".format(recon_weight))
print("Static Weight : {}".format(static_weight))
print("Triplet Weight : {}".format(triplet_weight))
print("Pretrain : {}".format(pretrain))


DATA_DIR = os.path.join("DATA")
NUMPY_DIR = os.path.join(DATA_DIR, "NUMPY")
RESULT_DIR = os.path.join(DATA_DIR, "RESULT")
MODEL_DIR = os.path.join(DATA_DIR, "MODEL")

# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)

ALL_ATT_NL_32_27_3_200_1_NL_Hidden_0_Serial_reconstruction_loss+triplet_loss_withoutSF Hyperparameters
Channels : [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32]
Input Channels : [27, 28, 29, 30, 31]
Static Channels : [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26]
Code dim : 32
Epochs : 200
Batch Size : 200
Learning rate : 0.005
Reconstruction Weight : 0.1
Static Weight : 10
Triplet Weight : 1
Pretrain : None


# LOAD DATA

In [2]:
# All splits live under <DIR>/NUMPY. The last entry of the final axis is dropped
# on load; presumably the arrays are (basin, year, day, feature) — confirm.
numpy_root = os.path.join(DIR, 'NUMPY')
# NOTE(review): with num_hidden == 0 this slice is [-0:] == [0:], i.e. ALL basins
# are kept for the test split (the printed shape shows 531), not none — confirm.
test_basins = slice(-num_hidden, None)

train_data = np.load(os.path.join(numpy_root, "train_data_basin.npy"))[:, :, :, :-1]
validation_data = np.load(os.path.join(numpy_root, "validation_data_basin.npy"))[:, :, :, :-1]
test_data = np.load(os.path.join(numpy_root, "test_data_basin.npy"))[test_basins, :, :, :-1]
hidden_train_data = np.load(os.path.join(numpy_root, "train_hidden_data_basin.npy"))[:, :, :, :-1]
print("Train Data:{}\tValidation Data:{}\tTest Data:{}\tHidden Train Data:{}".format(
    train_data.shape, validation_data.shape, test_data.shape, hidden_train_data.shape))

# Feature names are indexed with the same channel lists used for the model inputs.
feature_names = np.load(os.path.join(DIR, "RAW_DATA", "feature_names.npy"), allow_pickle=True)
for label, channel_ids in (("Static", config.static_channels),
                           ("Weather", config.weather_channels),
                           ("SF", config.sf_channels)):
    print("{} features:{}".format(label, feature_names[channel_ids]))

Train Data:(531, 39, 365, 33)	Validation Data:(531, 9, 365, 33)	Test Data:(531, 19, 365, 33)	Hidden Train Data:(531, 39, 365, 33)
Static features:['p_mean' 'pet_mean' 'p_seasonality' 'frac_snow' 'aridity'
 'high_prec_freq' 'high_prec_dur' 'low_prec_freq' 'low_prec_dur'
 'carbonate_rocks_frac' 'geol_permeability' 'soil_depth_pelletier'
 'soil_depth_statsgo' 'soil_porosity' 'soil_conductivity'
 'max_water_content' 'sand_frac' 'silt_frac' 'clay_frac' 'elev_mean'
 'slope_mean' 'area_gages2' 'frac_forest' 'lai_max' 'lai_diff' 'gvf_max'
 'gvf_diff']
Weather features:['PRCP(mm/day)' 'SRAD(W/m2)' 'Tmax(C)' 'Tmin(C)' 'Vp(Pa)']
SF features:['SF']


# BUILD MODEL

In [3]:
#Bayesian GRU
#Import libraries
from torch import nn
from torch.nn import functional as F
from blitz.modules.base_bayesian_module import BayesianModule, BayesianRNN
from blitz.modules.weight_sampler import TrainableRandomDistribution, PriorWeightDistribution

In [4]:
class BayesianGRU(BayesianRNN):
    """
    Bayesian GRU layer trained with Bayes by Backprop, as proposed in
    "Weight Uncertainty in Neural Networks". It is a regular torch nn.Module and
    can be chained in nn.Sequential models with layers from other libraries.

    Unless frozen, every forward pass samples the input-to-hidden weights, the
    hidden-to-hidden weights and the bias from trainable (mu, rho) posteriors and
    stores ``log_prior`` and ``log_variational_posterior`` for the KL term.

    Parameter layout (H = out_features): every weight/bias has 4*H columns.
    Columns [0, 2H) feed the reset and update gates; columns [2H, 3H) of the
    input weights/bias and [3H, 4H) of the hidden weights/bias feed the
    candidate state. Columns [3H, 4H) of weight_ih and [2H, 3H) of weight_hh
    are sampled (and enter the KL terms) but are not used by the recurrence.

    parameters:
        in_features: int -> incoming features per time step
        out_features: int -> hidden/output features of the layer
        bias: bool -> whether the bias will exist (True) or be zero (False)
        prior_sigma_1: float -> prior sigma on the mixture prior distribution 1
        prior_sigma_2: float -> prior sigma on the mixture prior distribution 2
        prior_pi: float -> pi on the scaled mixture prior
        posterior_mu_init: float -> mean of the normal init of the posterior mu
        posterior_rho_init: float -> mean of the normal init of the posterior rho
        freeze: bool -> whether the layer starts with frozen (deterministic) weights
        prior_dist: optional distribution forwarded to PriorWeightDistribution
        **kwargs: forwarded to BayesianRNN
    """
    def __init__(self,
                 in_features,
                 out_features,
                 bias = True,
                 prior_sigma_1 = 0.1,
                 prior_sigma_2 = 0.002,
                 prior_pi = 1,
                 posterior_mu_init = 0,
                 posterior_rho_init = -6.0,
                 freeze = False,
                 prior_dist = None,
                 **kwargs):

        super().__init__(**kwargs)
        self.in_features = in_features
        self.out_features = out_features
        self.use_bias = bias
        self.freeze = freeze

        self.posterior_mu_init = posterior_mu_init
        self.posterior_rho_init = posterior_rho_init

        self.prior_sigma_1 = prior_sigma_1
        self.prior_sigma_2 = prior_sigma_2
        self.prior_pi = prior_pi
        self.prior_dist = prior_dist

        # Variational parameters and sampler for the input-to-hidden weights
        self.weight_ih_mu = nn.Parameter(torch.Tensor(in_features, out_features * 4).normal_(posterior_mu_init, 0.1))
        self.weight_ih_rho = nn.Parameter(torch.Tensor(in_features, out_features * 4).normal_(posterior_rho_init, 0.1))
        self.weight_ih_sampler = TrainableRandomDistribution(self.weight_ih_mu, self.weight_ih_rho)
        self.weight_ih = None

        # Variational parameters and sampler for the hidden-to-hidden weights
        self.weight_hh_mu = nn.Parameter(torch.Tensor(out_features, out_features * 4).normal_(posterior_mu_init, 0.1))
        self.weight_hh_rho = nn.Parameter(torch.Tensor(out_features, out_features * 4).normal_(posterior_rho_init, 0.1))
        self.weight_hh_sampler = TrainableRandomDistribution(self.weight_hh_mu, self.weight_hh_rho)
        self.weight_hh = None

        # Variational parameters and sampler for the bias
        self.bias_mu = nn.Parameter(torch.Tensor(out_features * 4).normal_(posterior_mu_init, 0.1))
        self.bias_rho = nn.Parameter(torch.Tensor(out_features * 4).normal_(posterior_rho_init, 0.1))
        self.bias_sampler = TrainableRandomDistribution(self.bias_mu, self.bias_rho)
        self.bias = None

        # Prior distributions
        self.weight_ih_prior_dist = PriorWeightDistribution(self.prior_pi, self.prior_sigma_1, self.prior_sigma_2, dist = self.prior_dist)
        self.weight_hh_prior_dist = PriorWeightDistribution(self.prior_pi, self.prior_sigma_1, self.prior_sigma_2, dist = self.prior_dist)
        self.bias_prior_dist = PriorWeightDistribution(self.prior_pi, self.prior_sigma_1, self.prior_sigma_2, dist = self.prior_dist)

        self.init_sharpen_parameters()

        self.log_prior = 0
        self.log_variational_posterior = 0

    def _zero_bias(self):
        """Return an all-zero bias tensor with the 4*H layout (used when bias=False)."""
        # Slicing an int 0 (the previous placeholder) raises TypeError in the recurrence.
        return torch.zeros(self.out_features * 4, device=self.bias_mu.device, dtype=self.bias_mu.dtype)

    def sample_weights(self):
        """Sample weights/bias and record their log prior and log posterior.

        Returns:
            (weight_ih, weight_hh, bias) sampled from the variational posteriors.
        """
        weight_ih = self.weight_ih_sampler.sample()
        weight_hh = self.weight_hh_sampler.sample()

        # With bias, sample it; otherwise use a zero tensor that contributes nothing to the KL.
        if self.use_bias:
            bias = self.bias_sampler.sample()
            b_log_posterior = self.bias_sampler.log_posterior()
            b_log_prior = self.bias_prior_dist.log_prior(bias)
        else:
            bias = self._zero_bias()
            b_log_posterior = 0
            b_log_prior = 0

        # Gather the variational posterior and prior log-likelihoods of all weights
        self.log_variational_posterior = self.weight_hh_sampler.log_posterior() + b_log_posterior + self.weight_ih_sampler.log_posterior()
        self.log_prior = self.weight_ih_prior_dist.log_prior(weight_ih) + b_log_prior + self.weight_hh_prior_dist.log_prior(weight_hh)

        self.ff_parameters = [weight_ih, weight_hh, bias]
        return weight_ih, weight_hh, bias

    def get_frozen_weights(self):
        """Return the deterministic (posterior mean) weights and bias."""
        weight_ih = self.weight_ih_mu
        weight_hh = self.weight_hh_mu
        bias = self.bias_mu if self.use_bias else self._zero_bias()
        return weight_ih, weight_hh, bias

    def _initial_hidden(self, x, hidden_states):
        """Return a (batch, out_features) initial hidden state for input ``x``.

        ``hidden_states`` may be None (zeros are used), a (batch, hidden) tensor,
        or a torch.nn.GRU-style (1, batch, hidden) tensor.
        """
        if hidden_states is None:
            return torch.zeros(x.size(0), self.out_features).to(x.device)
        if hidden_states.dim() == 3:
            if hidden_states.size(0) != 1:
                raise ValueError("BayesianGRU is single-layer and unidirectional; "
                                 "expected hidden_states of shape (1, batch, hidden)")
            return hidden_states[0]
        return hidden_states

    def _unroll(self, x, h_t, weight_ih, weight_hh, bias):
        """Run the GRU recurrence over ``x`` starting from hidden state ``h_t``.

        Args:
            x: tensor of shape (batch, sequence, feature).
            h_t: tensor of shape (batch, out_features).

        Returns:
            (hidden_seq, h_last): all hidden states as (batch, sequence, out_features)
            and the final hidden state as (batch, out_features).
        """
        _, seq_sz, _ = x.size()  # enforces a 3-D (batch, sequence, feature) input
        HS = self.out_features
        hidden_seq = []

        for t in range(seq_sz):
            x_t = x[:, t, :]
            # Reset and update gates computed with one matrix product per weight.
            A_t = x_t @ weight_ih[:, :HS * 2] + h_t @ weight_hh[:, :HS * 2] + bias[:HS * 2]
            r_t = torch.sigmoid(A_t[:, :HS])
            z_t = torch.sigmoid(A_t[:, HS:HS * 2])
            # Candidate state: the reset gate scales only the hidden contribution (as in torch.nn.GRU).
            n_t = torch.tanh(x_t @ weight_ih[:, HS * 2:HS * 3] + bias[HS * 2:HS * 3]
                             + r_t * (h_t @ weight_hh[:, HS * 3:HS * 4] + bias[HS * 3:HS * 4]))
            h_t = (1 - z_t) * n_t + z_t * h_t
            hidden_seq.append(h_t.unsqueeze(0))

        # (sequence, batch, feature) -> (batch, sequence, feature)
        hidden_seq = torch.cat(hidden_seq, dim=0).transpose(0, 1).contiguous()
        return hidden_seq, h_t

    def forward_(self,
                 x,
                 hidden_states,
                 sharpen_loss):
        """Stochastic forward pass with sampled (or sharpened) weights."""
        if self.loss_to_sharpen is not None:
            weight_ih, weight_hh, bias = self.sharpen_posterior(loss=self.loss_to_sharpen, input_shape=x.shape)
        elif sharpen_loss is not None:
            weight_ih, weight_hh, bias = self.sharpen_posterior(loss=sharpen_loss, input_shape=x.shape)
        else:
            weight_ih, weight_hh, bias = self.sample_weights()

        h_t = self._initial_hidden(x, hidden_states)
        return self._unroll(x, h_t, weight_ih, weight_hh, bias)

    def forward_frozen(self,
                       x,
                       hidden_states):
        """Deterministic forward pass using the posterior means."""
        weight_ih, weight_hh, bias = self.get_frozen_weights()
        h_t = self._initial_hidden(x, hidden_states)
        return self._unroll(x, h_t, weight_ih, weight_hh, bias)

    def forward(self,
                x,
                hidden_states=None,
                sharpen_loss=None):
        """Run the layer over ``x`` of shape (batch, sequence, feature).

        Returns:
            (hidden_seq, h_last) with shapes (batch, sequence, out_features)
            and (batch, out_features).
        """
        if self.freeze:
            return self.forward_frozen(x, hidden_states)

        if not self.sharpen:
            sharpen_loss = None

        return self.forward_(x, hidden_states, sharpen_loss)

    


In [5]:
from blitz.modules import BayesianLinear

class ATT(torch.nn.Module):
    """Attention GRU autoencoder with a linear static-feature head.

    A bidirectional GRU encodes the sequence; its two directions are summed and
    an attention layer pools the time steps into one code vector. The code
    predicts the static features through a linear layer and seeds an
    autoregressive GRU decoder that reconstructs the input sequence.

    Args:
        in_channels: features per time step of the input.
        stat_channels: number of static features predicted from the code.
        code_dim: size of the code vector / GRU hidden state.
        device: device on which the decoder buffers are allocated.
    """
    def __init__(self, in_channels, stat_channels, code_dim, device):
        super(ATT, self).__init__()

        self.code_dim = code_dim
        self.device = device
        self.encoder = torch.nn.GRU(in_channels, code_dim, batch_first=True, bidirectional=True)
        self.att = torch.nn.Linear(code_dim, 1)
        self.decoder = torch.nn.GRU(in_channels, code_dim, batch_first=True)
        self.out = torch.nn.Linear(code_dim, in_channels)
        self.static_out = torch.nn.Linear(code_dim, stat_channels)

        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)

    def forward(self, x):
        """Encode ``x``, predict static features and reconstruct the sequence.

        Args:
            x: tensor of shape (batch, seq_len, in_channels).

        Returns:
            Tuple ``(code, reconstruction, static_out)`` with shapes
            (batch, code_dim), (batch, seq_len, in_channels), (batch, stat_channels).
        """
        batch, seq_len, in_channels = x.shape

        code_vec, _ = self.encoder(x)
        # Sum the forward and backward halves of the bidirectional encoder output.
        code_vec = torch.add(code_vec[:, :, :self.code_dim], code_vec[:, :, self.code_dim:])
        # Squeeze only the trailing score axis so batch == 1 / seq_len == 1 keep their dims.
        att_scores = self.att(code_vec).squeeze(-1)
        att_weights = torch.nn.functional.softmax(att_scores, dim=1).unsqueeze(-1)
        # Attention-pool over time, keeping a leading axis of 1 for the GRU h0 layout.
        code_vec = torch.sum(att_weights * code_vec, dim=1).unsqueeze(0)  # (1, batch, code_dim)
        code = code_vec.squeeze(0)  # (batch, code_dim)

        static_out = self.static_out(code)

        out = torch.zeros(batch, seq_len, in_channels).to(self.device)
        h = code_vec
        step_input = torch.zeros((batch, 1, in_channels)).to(self.device)
        # Autoregressive decoding: each step's prediction is the next step's input.
        for step in range(seq_len):
            dec_out, h = self.decoder(step_input, h)
            output = self.out(torch.mul(dec_out.squeeze(1), code))
            out[:, step] = output
            step_input = output.unsqueeze(1)

        return code.view(batch, -1), out, static_out

@variational_estimator
class ATT_NL(torch.nn.Module):
    """Attention GRU autoencoder whose decoder is a BayesianGRU.

    A bidirectional GRU encodes the sequence; its directions are summed and an
    attention layer pools time steps into one code vector. The code predicts the
    static features through a ReLU hidden layer and seeds a Bayesian GRU decoder
    that autoregressively reconstructs the input sequence.

    Args:
        in_channels: features per time step of the input.
        stat_channels: number of static features predicted from the code.
        code_dim: size of the code vector / GRU hidden state.
        device: device on which the decoder buffers are allocated.
    """
    def __init__(self, in_channels, stat_channels, code_dim, device):
        super(ATT_NL, self).__init__()

        self.code_dim = code_dim
        self.device = device
        self.encoder = torch.nn.GRU(in_channels, code_dim, batch_first=True, bidirectional=True)
        self.att = torch.nn.Linear(code_dim, 1)
        # Bayesian variant of the deterministic torch.nn.GRU decoder.
        self.decoder = BayesianGRU(in_channels, code_dim)
        self.out = torch.nn.Linear(code_dim, in_channels)
        self.linear_1 = torch.nn.Linear(code_dim, code_dim)
        self.static_out = torch.nn.Linear(code_dim, stat_channels)
        self.relu = torch.nn.ReLU()

        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)

    def forward(self, x):
        """Encode ``x``, predict static features and reconstruct the sequence.

        Args:
            x: tensor of shape (batch, seq_len, in_channels).

        Returns:
            Tuple ``(code, reconstruction, static_out)`` with shapes
            (batch, code_dim), (batch, seq_len, in_channels), (batch, stat_channels).
        """
        batch, seq_len, in_channels = x.shape

        code_vec, _ = self.encoder(x)
        # Sum the forward and backward halves of the bidirectional encoder output.
        code_vec = torch.add(code_vec[:, :, :self.code_dim], code_vec[:, :, self.code_dim:])
        # Squeeze only the trailing score axis so batch == 1 / seq_len == 1 keep their dims.
        att_scores = self.att(code_vec).squeeze(-1)
        att_weights = torch.nn.functional.softmax(att_scores, dim=1).unsqueeze(-1)
        code_vec = torch.sum(att_weights * code_vec, dim=1).unsqueeze(0)  # (1, batch, code_dim)
        code = code_vec.squeeze(0)  # (batch, code_dim)

        static_out = self.relu(self.linear_1(code))
        static_out = self.static_out(static_out)

        out = torch.zeros(batch, seq_len, in_channels).to(self.device)
        # BayesianGRU builds its own default state as (batch, out_features), so seed it in that layout.
        h = code
        step_input = torch.zeros((batch, 1, in_channels)).to(self.device)
        # Autoregressive decoding: each step's prediction is the next step's input.
        for step in range(seq_len):
            dec_out, h = self.decoder(step_input, h)
            output = self.out(torch.mul(dec_out.squeeze(1), code))
            out[:, step] = output
            step_input = output.unsqueeze(1)

        return code.view(batch, -1), out, static_out

class ATT_NL_0(torch.nn.Module):
    """Deterministic attention GRU autoencoder with a non-linear static head.

    Baseline with no Bayesian layers. A bidirectional GRU encodes the sequence;
    its directions are summed and an attention layer pools time steps into one
    code vector. The code predicts the static features through a ReLU hidden
    layer and seeds an autoregressive GRU decoder that reconstructs the input.

    Args:
        in_channels: features per time step of the input.
        stat_channels: number of static features predicted from the code.
        code_dim: size of the code vector / GRU hidden state.
        device: device on which the decoder buffers are allocated.
    """
    def __init__(self, in_channels, stat_channels, code_dim, device):
        super(ATT_NL_0, self).__init__()

        self.code_dim = code_dim
        self.device = device
        self.encoder = torch.nn.GRU(in_channels, code_dim, batch_first=True, bidirectional=True)
        self.att = torch.nn.Linear(code_dim, 1)
        self.decoder = torch.nn.GRU(in_channels, code_dim, batch_first=True)
        self.out = torch.nn.Linear(code_dim, in_channels)
        self.linear_1 = torch.nn.Linear(code_dim, code_dim)
        self.static_out = torch.nn.Linear(code_dim, stat_channels)
        self.relu = torch.nn.ReLU()

        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)

    def forward(self, x):
        """Encode ``x``, predict static features and reconstruct the sequence.

        Args:
            x: tensor of shape (batch, seq_len, in_channels).

        Returns:
            Tuple ``(code, reconstruction, static_out)`` with shapes
            (batch, code_dim), (batch, seq_len, in_channels), (batch, stat_channels).
        """
        batch, seq_len, in_channels = x.shape

        code_vec, _ = self.encoder(x)
        # Sum the forward and backward halves of the bidirectional encoder output.
        code_vec = torch.add(code_vec[:, :, :self.code_dim], code_vec[:, :, self.code_dim:])
        # Squeeze only the trailing score axis so batch == 1 / seq_len == 1 keep their dims.
        att_scores = self.att(code_vec).squeeze(-1)
        att_weights = torch.nn.functional.softmax(att_scores, dim=1).unsqueeze(-1)
        # Attention-pool over time, keeping a leading axis of 1 for the GRU h0 layout.
        code_vec = torch.sum(att_weights * code_vec, dim=1).unsqueeze(0)  # (1, batch, code_dim)
        code = code_vec.squeeze(0)  # (batch, code_dim)

        static_out = self.relu(self.linear_1(code))
        static_out = self.static_out(static_out)

        out = torch.zeros(batch, seq_len, in_channels).to(self.device)
        h = code_vec
        step_input = torch.zeros((batch, 1, in_channels)).to(self.device)
        # Autoregressive decoding: each step's prediction is the next step's input.
        for step in range(seq_len):
            dec_out, h = self.decoder(step_input, h)
            output = self.out(torch.mul(dec_out.squeeze(1), code))
            out[:, step] = output
            step_input = output.unsqueeze(1)

        return code.view(batch, -1), out, static_out

    
@variational_estimator
class ATT_NL_1(torch.nn.Module):
    """Attention GRU autoencoder whose encoder is a (unidirectional) BayesianGRU.

    The Bayesian encoder's hidden states are pooled over time by an attention
    layer into one code vector. The code predicts the static features through a
    ReLU hidden layer and seeds an autoregressive GRU decoder that reconstructs
    the input sequence.

    Args:
        in_channels: features per time step of the input.
        stat_channels: number of static features predicted from the code.
        code_dim: size of the code vector / GRU hidden state.
        device: device on which the decoder buffers are allocated.
    """
    def __init__(self, in_channels, stat_channels, code_dim, device):
        super(ATT_NL_1, self).__init__()

        self.code_dim = code_dim
        self.device = device
        # Bayesian replacement for the bidirectional torch.nn.GRU encoder; it is unidirectional.
        self.encoder = BayesianGRU(in_channels, code_dim)
        self.att = torch.nn.Linear(code_dim, 1)
        self.decoder = torch.nn.GRU(in_channels, code_dim, batch_first=True)
        self.out = torch.nn.Linear(code_dim, in_channels)
        self.linear_1 = torch.nn.Linear(code_dim, code_dim)
        self.static_out = torch.nn.Linear(code_dim, stat_channels)
        self.relu = torch.nn.ReLU()

        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)

    def forward(self, x):
        """Encode ``x``, predict static features and reconstruct the sequence.

        Args:
            x: tensor of shape (batch, seq_len, in_channels).

        Returns:
            Tuple ``(code, reconstruction, static_out)`` with shapes
            (batch, code_dim), (batch, seq_len, in_channels), (batch, stat_channels).
        """
        batch, seq_len, in_channels = x.shape

        # The BayesianGRU encoder is unidirectional: its output already has code_dim
        # features, so there is no backward half to add (slicing [code_dim:] would
        # be empty and the add would fail to broadcast).
        code_vec, _ = self.encoder(x)
        # Squeeze only the trailing score axis so batch == 1 / seq_len == 1 keep their dims.
        att_scores = self.att(code_vec).squeeze(-1)
        att_weights = torch.nn.functional.softmax(att_scores, dim=1).unsqueeze(-1)
        # Attention-pool over time, keeping a leading axis of 1 for the GRU h0 layout.
        code_vec = torch.sum(att_weights * code_vec, dim=1).unsqueeze(0)  # (1, batch, code_dim)
        code = code_vec.squeeze(0)  # (batch, code_dim)

        static_out = self.relu(self.linear_1(code))
        static_out = self.static_out(static_out)

        out = torch.zeros(batch, seq_len, in_channels).to(self.device)
        h = code_vec
        step_input = torch.zeros((batch, 1, in_channels)).to(self.device)
        # Autoregressive decoding: each step's prediction is the next step's input.
        for step in range(seq_len):
            dec_out, h = self.decoder(step_input, h)
            output = self.out(torch.mul(dec_out.squeeze(1), code))
            out[:, step] = output
            step_input = output.unsqueeze(1)

        return code.view(batch, -1), out, static_out

    
@variational_estimator
class ATT_NL_3(torch.nn.Module):
    """Attention GRU autoencoder with a Bayesian attention layer (BIM).

    A bidirectional GRU encodes the sequence; its directions are summed and a
    BayesianLinear attention layer scores the time steps, which are pooled into
    one code vector. The code predicts the static features through a ReLU hidden
    layer and seeds an autoregressive GRU decoder that reconstructs the input.

    Args:
        in_channels: features per time step of the input.
        stat_channels: number of static features predicted from the code.
        code_dim: size of the code vector / GRU hidden state.
        device: device on which the decoder buffers are allocated.
    """
    def __init__(self, in_channels, stat_channels, code_dim, device):
        super(ATT_NL_3, self).__init__()

        self.code_dim = code_dim
        self.device = device
        self.encoder = torch.nn.GRU(in_channels, code_dim, batch_first=True, bidirectional=True)
        # Bayesian replacement for the deterministic torch.nn.Linear attention scorer.
        self.att = BayesianLinear(code_dim, 1)
        self.decoder = torch.nn.GRU(in_channels, code_dim, batch_first=True)
        self.out = torch.nn.Linear(code_dim, in_channels)
        self.linear_1 = torch.nn.Linear(code_dim, code_dim)
        self.static_out = torch.nn.Linear(code_dim, stat_channels)
        self.relu = torch.nn.ReLU()

        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)

    def forward(self, x):
        """Encode ``x``, predict static features and reconstruct the sequence.

        Args:
            x: tensor of shape (batch, seq_len, in_channels).

        Returns:
            Tuple ``(code, reconstruction, static_out)`` with shapes
            (batch, code_dim), (batch, seq_len, in_channels), (batch, stat_channels).
        """
        batch, seq_len, in_channels = x.shape

        code_vec, _ = self.encoder(x)
        # Sum the forward and backward halves of the bidirectional encoder output.
        code_vec = torch.add(code_vec[:, :, :self.code_dim], code_vec[:, :, self.code_dim:])
        # Squeeze only the trailing score axis so batch == 1 / seq_len == 1 keep their dims.
        att_scores = self.att(code_vec).squeeze(-1)
        att_weights = torch.nn.functional.softmax(att_scores, dim=1).unsqueeze(-1)
        # Attention-pool over time, keeping a leading axis of 1 for the GRU h0 layout.
        code_vec = torch.sum(att_weights * code_vec, dim=1).unsqueeze(0)  # (1, batch, code_dim)
        code = code_vec.squeeze(0)  # (batch, code_dim)

        static_out = self.relu(self.linear_1(code))
        static_out = self.static_out(static_out)

        out = torch.zeros(batch, seq_len, in_channels).to(self.device)
        h = code_vec
        step_input = torch.zeros((batch, 1, in_channels)).to(self.device)
        # Autoregressive decoding: each step's prediction is the next step's input.
        for step in range(seq_len):
            dec_out, h = self.decoder(step_input, h)
            output = self.out(torch.mul(dec_out.squeeze(1), code))
            out[:, step] = output
            step_input = output.unsqueeze(1)

        return code.view(batch, -1), out, static_out



@variational_estimator
class ATT_NL_4(torch.nn.Module):
    """Attention GRU autoencoder with a Bayesian output projection.

    A bidirectional GRU encodes the sequence; its directions are summed and an
    attention layer pools time steps into one code vector. The code predicts the
    static features through a ReLU hidden layer and seeds an autoregressive GRU
    decoder whose per-step output projection is a BayesianLinear layer.

    Args:
        in_channels: features per time step of the input.
        stat_channels: number of static features predicted from the code.
        code_dim: size of the code vector / GRU hidden state.
        device: device on which the decoder buffers are allocated.
    """
    def __init__(self, in_channels, stat_channels, code_dim, device):
        super(ATT_NL_4, self).__init__()

        self.code_dim = code_dim
        self.device = device
        self.encoder = torch.nn.GRU(in_channels, code_dim, batch_first=True, bidirectional=True)
        self.att = torch.nn.Linear(code_dim, 1)
        self.decoder = torch.nn.GRU(in_channels, code_dim, batch_first=True)
        # Bayesian replacement for the deterministic torch.nn.Linear output layer.
        self.out = BayesianLinear(code_dim, in_channels)
        self.linear_1 = torch.nn.Linear(code_dim, code_dim)
        self.static_out = torch.nn.Linear(code_dim, stat_channels)
        self.relu = torch.nn.ReLU()

        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)

    def forward(self, x):
        """Encode ``x``, predict static features and reconstruct the sequence.

        Args:
            x: tensor of shape (batch, seq_len, in_channels).

        Returns:
            Tuple ``(code, reconstruction, static_out)`` with shapes
            (batch, code_dim), (batch, seq_len, in_channels), (batch, stat_channels).
        """
        batch, seq_len, in_channels = x.shape

        code_vec, _ = self.encoder(x)
        # Sum the forward and backward halves of the bidirectional encoder output.
        code_vec = torch.add(code_vec[:, :, :self.code_dim], code_vec[:, :, self.code_dim:])
        # Squeeze only the trailing score axis so batch == 1 / seq_len == 1 keep their dims.
        att_scores = self.att(code_vec).squeeze(-1)
        att_weights = torch.nn.functional.softmax(att_scores, dim=1).unsqueeze(-1)
        # Attention-pool over time, keeping a leading axis of 1 for the GRU h0 layout.
        code_vec = torch.sum(att_weights * code_vec, dim=1).unsqueeze(0)  # (1, batch, code_dim)
        code = code_vec.squeeze(0)  # (batch, code_dim)

        static_out = self.relu(self.linear_1(code))
        static_out = self.static_out(static_out)

        out = torch.zeros(batch, seq_len, in_channels).to(self.device)
        h = code_vec
        step_input = torch.zeros((batch, 1, in_channels)).to(self.device)
        # Autoregressive decoding: each step's prediction is the next step's input.
        for step in range(seq_len):
            dec_out, h = self.decoder(step_input, h)
            output = self.out(torch.mul(dec_out.squeeze(1), code))
            out[:, step] = output
            step_input = output.unsqueeze(1)

        return code.view(batch, -1), out, static_out
    

@variational_estimator
class ATT_NL_5(torch.nn.Module):
    """Attention GRU autoencoder with a Bayesian hidden layer in the static head.

    A bidirectional GRU encodes the sequence; its directions are summed and an
    attention layer pools time steps into one code vector. The code predicts the
    static features through a BayesianLinear + ReLU hidden layer and seeds an
    autoregressive GRU decoder that reconstructs the input sequence.

    Args:
        in_channels: features per time step of the input.
        stat_channels: number of static features predicted from the code.
        code_dim: size of the code vector / GRU hidden state.
        device: device on which the decoder buffers are allocated.
    """
    def __init__(self, in_channels, stat_channels, code_dim, device):
        super(ATT_NL_5, self).__init__()

        self.code_dim = code_dim
        self.device = device
        self.encoder = torch.nn.GRU(in_channels, code_dim, batch_first=True, bidirectional=True)
        self.att = torch.nn.Linear(code_dim, 1)
        self.decoder = torch.nn.GRU(in_channels, code_dim, batch_first=True)
        self.out = torch.nn.Linear(code_dim, in_channels)
        # Bayesian replacement for the deterministic torch.nn.Linear hidden layer.
        self.linear_1 = BayesianLinear(code_dim, code_dim)
        self.static_out = torch.nn.Linear(code_dim, stat_channels)
        self.relu = torch.nn.ReLU()

        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)

    def forward(self, x):
        """Encode ``x``, predict static features and reconstruct the sequence.

        Args:
            x: tensor of shape (batch, seq_len, in_channels).

        Returns:
            Tuple ``(code, reconstruction, static_out)`` with shapes
            (batch, code_dim), (batch, seq_len, in_channels), (batch, stat_channels).
        """
        batch, seq_len, in_channels = x.shape

        code_vec, _ = self.encoder(x)
        # Sum the forward and backward halves of the bidirectional encoder output.
        code_vec = torch.add(code_vec[:, :, :self.code_dim], code_vec[:, :, self.code_dim:])
        # Squeeze only the trailing score axis so batch == 1 / seq_len == 1 keep their dims.
        att_scores = self.att(code_vec).squeeze(-1)
        att_weights = torch.nn.functional.softmax(att_scores, dim=1).unsqueeze(-1)
        # Attention-pool over time, keeping a leading axis of 1 for the GRU h0 layout.
        code_vec = torch.sum(att_weights * code_vec, dim=1).unsqueeze(0)  # (1, batch, code_dim)
        code = code_vec.squeeze(0)  # (batch, code_dim)

        static_out = self.relu(self.linear_1(code))
        static_out = self.static_out(static_out)

        out = torch.zeros(batch, seq_len, in_channels).to(self.device)
        h = code_vec
        step_input = torch.zeros((batch, 1, in_channels)).to(self.device)
        # Autoregressive decoding: each step's prediction is the next step's input.
        for step in range(seq_len):
            dec_out, h = self.decoder(step_input, h)
            output = self.out(torch.mul(dec_out.squeeze(1), code))
            out[:, step] = output
            step_input = output.unsqueeze(1)

        return code.view(batch, -1), out, static_out
    

    
@variational_estimator
class ATT_NL_6(torch.nn.Module):
    """Attention GRU autoencoder with a Bayesian static-feature output layer.

    A bidirectional GRU encodes the sequence; its directions are summed and an
    attention layer pools time steps into one code vector. The code predicts the
    static features through a ReLU hidden layer followed by a BayesianLinear
    output layer, and seeds an autoregressive GRU decoder that reconstructs the
    input sequence.

    Args:
        in_channels: features per time step of the input.
        stat_channels: number of static features predicted from the code.
        code_dim: size of the code vector / GRU hidden state.
        device: device on which the decoder buffers are allocated.
    """
    def __init__(self, in_channels, stat_channels, code_dim, device):
        super(ATT_NL_6, self).__init__()

        self.code_dim = code_dim
        self.device = device
        self.encoder = torch.nn.GRU(in_channels, code_dim, batch_first=True, bidirectional=True)
        self.att = torch.nn.Linear(code_dim, 1)
        self.decoder = torch.nn.GRU(in_channels, code_dim, batch_first=True)
        self.out = torch.nn.Linear(code_dim, in_channels)
        self.linear_1 = torch.nn.Linear(code_dim, code_dim)
        # Bayesian replacement for the deterministic torch.nn.Linear static head.
        self.static_out = BayesianLinear(code_dim, stat_channels)
        self.relu = torch.nn.ReLU()

        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)

    def forward(self, x):
        """Encode ``x``, predict static features and reconstruct the sequence.

        Args:
            x: tensor of shape (batch, seq_len, in_channels).

        Returns:
            Tuple ``(code, reconstruction, static_out)`` with shapes
            (batch, code_dim), (batch, seq_len, in_channels), (batch, stat_channels).
        """
        batch, seq_len, in_channels = x.shape

        code_vec, _ = self.encoder(x)
        # Sum the forward and backward halves of the bidirectional encoder output.
        code_vec = torch.add(code_vec[:, :, :self.code_dim], code_vec[:, :, self.code_dim:])
        # Squeeze only the trailing score axis so batch == 1 / seq_len == 1 keep their dims.
        att_scores = self.att(code_vec).squeeze(-1)
        att_weights = torch.nn.functional.softmax(att_scores, dim=1).unsqueeze(-1)
        # Attention-pool over time, keeping a leading axis of 1 for the GRU h0 layout.
        code_vec = torch.sum(att_weights * code_vec, dim=1).unsqueeze(0)  # (1, batch, code_dim)
        code = code_vec.squeeze(0)  # (batch, code_dim)

        static_out = self.relu(self.linear_1(code))
        static_out = self.static_out(static_out)

        out = torch.zeros(batch, seq_len, in_channels).to(self.device)
        h = code_vec
        step_input = torch.zeros((batch, 1, in_channels)).to(self.device)
        # Autoregressive decoding: each step's prediction is the next step's input.
        for step in range(seq_len):
            dec_out, h = self.decoder(step_input, h)
            output = self.out(torch.mul(dec_out.squeeze(1), code))
            out[:, step] = output
            step_input = output.unsqueeze(1)

        return code.view(batch, -1), out, static_out
    

    
class LAST(torch.nn.Module):
    """GRU autoencoder that uses the encoder's final hidden states as the code.

    The final hidden states of both directions of a bidirectional GRU are summed
    into one code vector. The code predicts the static features through a linear
    layer and seeds an autoregressive GRU decoder that reconstructs the input.

    Args:
        in_channels: features per time step of the input.
        stat_channels: number of static features predicted from the code.
        code_dim: size of the code vector / GRU hidden state.
        device: device on which the decoder buffers are allocated.
    """
    def __init__(self, in_channels, stat_channels, code_dim, device):
        super(LAST, self).__init__()

        self.code_dim = code_dim
        self.device = device
        # Each submodule is built exactly once (the encoder/decoder/out layers
        # used to be constructed twice, discarding the first instances).
        self.encoder = torch.nn.GRU(in_channels, code_dim, batch_first=True, bidirectional=True)
        self.decoder = torch.nn.GRU(in_channels, code_dim, batch_first=True)
        self.out = torch.nn.Linear(code_dim, in_channels)
        self.static_out = torch.nn.Linear(code_dim, stat_channels)

        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)

    def forward(self, x):
        """Encode ``x``, predict static features and reconstruct the sequence.

        Args:
            x: tensor of shape (batch, seq_len, in_channels).

        Returns:
            Tuple ``(code, reconstruction, static_out)`` with shapes
            (batch, code_dim), (batch, seq_len, in_channels), (batch, stat_channels).
        """
        batch, seq_len, in_channels = x.shape

        _, h_n = self.encoder(x)  # (2, batch, code_dim): one final state per direction
        code_vec = torch.sum(h_n, dim=0).unsqueeze(0)  # (1, batch, code_dim)
        # Squeeze the explicit axis so batch == 1 keeps its dimension.
        code = code_vec.squeeze(0)  # (batch, code_dim)
        static_out = self.static_out(code)

        out = torch.zeros(batch, seq_len, in_channels).to(self.device)
        h = code_vec
        step_input = torch.zeros((batch, 1, in_channels)).to(self.device)
        # Autoregressive decoding: each step's prediction is the next step's input.
        for step in range(seq_len):
            dec_out, h = self.decoder(step_input, h)
            output = self.out(torch.mul(dec_out.squeeze(1), code))
            out[:, step] = output
            step_input = output.unsqueeze(1)

        return code.view(batch, -1), out, static_out


@variational_estimator
class ATT_NL_CD(torch.nn.Module):
    """Attention-pooled GRU autoencoder with a ``CD`` layer (the "IM+CD" variant).

    Steps:

    1. The bidirectional GRU encoder's per-step outputs are merged by adding
       the forward and backward halves.
    2. The merged outputs are weighted by a softmax attention over time.
    3. The weighted outputs are passed through ``self.cd`` together with
       ``self.cd_layer``. ``CD`` is defined elsewhere in the notebook; its
       regulariser arguments suggest concrete dropout (TODO confirm).
    4. The result is summed over time into a per-sample code.

    The code predicts the static characteristics through a Linear-ReLU-Linear
    head. It also initialises and gates an autoregressive GRU decoder that
    reconstructs the input sequence.

    Args:
        in_channels: number of dynamic input features per time step.
        stat_channels: number of static features predicted from the code.
        code_dim: size of the code and of the GRU hidden states.
        device: device on which the decoder's working buffers are allocated.
    """

    def __init__(self, in_channels, stat_channels, code_dim, device):
        super(ATT_NL_CD, self).__init__()

        self.code_dim = code_dim
        self.device = device
        # Registration order defines the state_dict layout; keep it stable.
        self.encoder = torch.nn.GRU(in_channels, code_dim, batch_first=True, bidirectional=True)
        self.att = torch.nn.Linear(code_dim, 1)  # BayesianLinear was tried here too
        self.cd = CD(
            weight_regulariser=1e-6,
            dropout_regulariser=1e-3,
        )
        self.cd_layer = torch.nn.Linear(code_dim, code_dim)
        self.decoder = torch.nn.GRU(in_channels, code_dim, batch_first=True)  # BayesianGRU was tried here too
        self.out = torch.nn.Linear(code_dim, in_channels)
        self.linear_1 = torch.nn.Linear(code_dim, code_dim)
        self.static_out = torch.nn.Linear(code_dim, stat_channels)
        self.relu = torch.nn.ReLU()

        for m in self.modules():
            if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
                torch.nn.init.xavier_uniform_(m.weight)

    def forward(self, x):
        """Encode ``x`` and reconstruct it.

        Args:
            x: tensor indexed as (batch, seq_len, in_channels).

        Returns:
            ``(code, reconstruction, static_out)``. ``code`` is reshaped to
            (batch, -1) and the reconstruction is (batch, seq_len, in_channels).
            Explicit squeeze axes keep the batch dimension when batch == 1.
        """
        batch, seq_len, in_channels = x.shape

        code_vec, _ = self.encoder(x)  # (batch, seq_len, 2 * code_dim)
        # Merge forward and backward directions by addition.
        code_vec = code_vec[:, :, :self.code_dim] + code_vec[:, :, self.code_dim:]
        # Softmax attention over time. Squeeze only the trailing score axis so
        # batch == 1 / seq_len == 1 keep the (batch, seq_len) layout.
        att_weights = torch.nn.functional.softmax(self.att(code_vec).squeeze(-1), dim=1).unsqueeze(-1)
        code_vec = att_weights * code_vec

        code_vec = self.cd(code_vec, self.cd_layer)

        # Pool over time. The leading axis of size 1 makes this the decoder's
        # initial hidden state; `code` is the same tensor without that axis.
        code_vec = torch.sum(code_vec, dim=1).unsqueeze(0)
        code = code_vec.squeeze(0)
        static_out = self.static_out(self.relu(self.linear_1(code)))

        out = torch.zeros(batch, seq_len, in_channels, device=self.device)
        h = code_vec
        step_input = torch.zeros((batch, 1, in_channels), device=self.device)
        for step in range(seq_len):
            dec_out, h = self.decoder(step_input, h)
            # Gate the decoder output with the code before projecting back.
            output = self.out(dec_out.squeeze(1) * code)
            out[:, step] = output
            step_input = output.unsqueeze(1)  # feed prediction back in

        return code.view(batch, -1), out, static_out

    
@variational_estimator
class VAE(torch.nn.Module):
    """Attention-pooled GRU autoencoder with a stochastic code (the "IM+VAE" variant).

    Steps:

    1. The bidirectional GRU encoder's per-step outputs are merged by adding
       the forward and backward halves.
    2. The merged outputs are weighted by a softmax attention over time.
    3. Each step is mapped through ``mu`` and ``logvar`` heads and perturbed
       with Gaussian noise.
    4. The result is summed over time into a per-sample code.

    The code predicts the static characteristics through a Linear-ReLU-Linear
    head. It also initialises and gates an autoregressive GRU decoder that
    reconstructs the input sequence.

    Args:
        in_channels: number of dynamic input features per time step.
        stat_channels: number of static features predicted from the code.
        code_dim: size of the code and of the GRU hidden states.
        device: device on which the decoder's working buffers are allocated.
    """

    def __init__(self, in_channels, stat_channels, code_dim, device):
        super(VAE, self).__init__()

        self.code_dim = code_dim
        self.device = device
        # Registration order defines the state_dict layout; keep it stable.
        self.encoder = torch.nn.GRU(in_channels, code_dim, batch_first=True, bidirectional=True)
        self.att = torch.nn.Linear(code_dim, 1)  # BayesianLinear was tried here too
        self.mu = torch.nn.Linear(code_dim, code_dim)
        self.logvar = torch.nn.Linear(code_dim, code_dim)
        self.decoder = torch.nn.GRU(in_channels, code_dim, batch_first=True)  # BayesianGRU was tried here too
        self.out = torch.nn.Linear(code_dim, in_channels)
        self.linear_1 = torch.nn.Linear(code_dim, code_dim)
        self.static_out = torch.nn.Linear(code_dim, stat_channels)
        self.relu = torch.nn.ReLU()

        for m in self.modules():
            if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
                torch.nn.init.xavier_uniform_(m.weight)

    def forward(self, x):
        """Encode ``x`` with a noisy code and reconstruct it.

        A fresh noise sample is drawn on every call, including in eval mode.

        Args:
            x: tensor indexed as (batch, seq_len, in_channels).

        Returns:
            ``(code, reconstruction, static_out)``. ``code`` is reshaped to
            (batch, -1) and the reconstruction is (batch, seq_len, in_channels).
            Explicit squeeze axes keep the batch dimension when batch == 1.
        """
        batch, seq_len, in_channels = x.shape

        code_vec, _ = self.encoder(x)  # (batch, seq_len, 2 * code_dim)
        # Merge forward and backward directions by addition.
        code_vec = code_vec[:, :, :self.code_dim] + code_vec[:, :, self.code_dim:]
        # Softmax attention over time. Squeeze only the trailing score axis so
        # batch == 1 / seq_len == 1 keep the (batch, seq_len) layout.
        att_weights = torch.nn.functional.softmax(self.att(code_vec).squeeze(-1), dim=1).unsqueeze(-1)
        code_vec = att_weights * code_vec

        mu, logvar = self.mu(code_vec), self.logvar(code_vec)
        # NOTE(review): the `logvar` head is used directly as the noise scale.
        # A textbook reparameterisation would use torch.exp(0.5 * logvar), which
        # was left commented out in earlier revisions. It is kept as-is so
        # existing results and checkpoints stay comparable. The training loop in
        # this notebook adds no KL term on mu/logvar.
        var_val = logvar
        code_vec = mu + var_val * torch.randn_like(var_val)

        # Pool over time. The leading axis of size 1 makes this the decoder's
        # initial hidden state; `code` is the same tensor without that axis.
        code_vec = torch.sum(code_vec, dim=1).unsqueeze(0)
        code = code_vec.squeeze(0)
        static_out = self.static_out(self.relu(self.linear_1(code)))

        out = torch.zeros(batch, seq_len, in_channels, device=self.device)
        h = code_vec
        step_input = torch.zeros((batch, 1, in_channels), device=self.device)
        for step in range(seq_len):
            dec_out, h = self.decoder(step_input, h)
            # Gate the decoder output with the code before projecting back.
            output = self.out(dec_out.squeeze(1) * code)
            out[:, step] = output
            step_input = output.unsqueeze(1)  # feed prediction back in

        return code.view(batch, -1), out, static_out

# Instantiate the architecture named in the configuration cell (looked up as a
# class in this notebook's namespace). Also build the losses and the optimizer;
# the training loop reads these as module-level names.
model = globals()[architecture](
    in_channels=len(input_channels),
    stat_channels=len(static_channels),
    code_dim=code_dim,
    device=device,
).to(device)

# Element-wise squared error; reductions are applied explicitly by the caller.
criterion = torch.nn.MSELoss(reduction="none")
# Per-triplet margin loss (margin `alpha`, Euclidean distance); reduced by the caller.
triplet_criterion = torch.nn.TripletMarginLoss(
    margin=alpha,
    p=2.0,
    eps=1e-06,
    reduction="none",
)
# Alternative kept for reference:
# triplet_criterion = torch.nn.TripletMarginWithDistanceLoss(distance_function=torch.nn.PairwiseDistance(p=2.0, eps=1e-06), margin=alpha, reduction="none")
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Count only parameters the optimizer will actually update.
pytorch_total_params = sum(
    weight.numel()
    for weight in filter(lambda tensor: tensor.requires_grad, model.parameters())
)
print("Total trainable parameters:{}".format(pytorch_total_params))

Total trainable parameters:29089


# LOAD PRETRAINED MODEL

In [6]:
# Optionally warm-start from a checkpoint named "<pretrain>_<code_dim>.pt"
# under DIR/MODEL.
if pretrain:
    checkpoint_path = os.path.join(DIR, 'MODEL', "{}_{}.pt".format(pretrain, code_dim))
    # map_location lets a checkpoint saved on one device (e.g. a GPU) load on
    # the configured `device` (e.g. a CPU-only host).
    state_dict = torch.load(checkpoint_path, map_location=device)
    # strict=False tolerates architecture differences; report them so a wrong
    # checkpoint does not load silently.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print("Pretrain load - missing keys: {}\tunexpected keys: {}".format(
            incompatible.missing_keys, incompatible.unexpected_keys))
    model.eval()

# TRAIN MODEL

In [8]:
def _sample_triplet_indices(dataset):
    """Draw the random pairings used to build triplets for one pass over ``dataset``.

    ``dataset`` is indexed as [basin, year, ...]. Returns three int64 arrays of
    shape (n_basins, n_years):

    * ``positive_years[b, y]``: year (same basin ``b``) used as the positive for (b, y).
    * ``negative_basins[b, y]``: basin that supplies the negative for (b, y).
    * ``negative_years[b, y]``: year of that negative sample.

    Rows of the year arrays and columns of the basin array are random
    permutations. The draw order is fixed so seeded runs stay reproducible.
    """
    n_basins, n_years = dataset.shape[0], dataset.shape[1]
    positive_years = np.zeros((n_basins, n_years), dtype=np.int64)
    for basin in range(n_basins):
        positive_years[basin] = random.sample(range(n_years), n_years)
    negative_basins = np.zeros((n_basins, n_years), dtype=np.int64)
    for year in range(n_years):
        negative_basins[:, year] = random.sample(range(n_basins), n_basins)
    negative_years = np.zeros((n_basins, n_years), dtype=np.int64)
    for basin in range(n_basins):
        negative_years[basin] = random.sample(range(n_years), n_years)
    return positive_years, negative_basins, negative_years


def _year_triplets(dataset, year, positive_years, negative_basins, negative_years):
    """Return filtered (anchor, positive, negative) sample arrays for one year.

    Two kinds of triplet are dropped:

    * those whose negative comes from the anchor's own basin;
    * those where any member has ``config.unknown`` in its last channel at
      some time step.
    """
    basins = np.arange(len(dataset))
    anchor = dataset[:, year]
    positive = dataset[basins, positive_years[:, year]]
    negative = dataset[negative_basins[:, year], negative_years[:, year]]

    keep = basins != negative_basins[:, year]
    anchor, positive, negative = anchor[keep], positive[keep], negative[keep]

    keep = ((anchor[:, :, -1] != config.unknown).all(axis=1)
            & (positive[:, :, -1] != config.unknown).all(axis=1)
            & (negative[:, :, -1] != config.unknown).all(axis=1))
    return anchor[keep], positive[keep], negative[keep]


def _run_epoch(model, dataset, train):
    """Run one pass of triplet mini-batches over ``dataset``.

    With ``train=True`` the model is put in train mode and the global
    ``optimizer`` is stepped after every batch. Otherwise the model is put in
    eval mode and the pass runs without autograd.

    Returns:
        ``(means, last_input)``. ``means`` is the tuple of batch-averaged
        (total, recon, triplet, static) losses, averaged over the batches
        actually processed; it is all-NaN when no batch was produced.
        ``last_input`` is the input tensor of the final batch, or ``None``.
    """
    model.train(train)
    positive_years, negative_basins, negative_years = _sample_triplet_indices(dataset)

    totals = [0.0, 0.0, 0.0, 0.0]
    n_batches = 0
    input_data = None
    with torch.set_grad_enabled(train):
        for year in range(dataset.shape[1]):
            anchor_data, positive_data, negative_data = _year_triplets(
                dataset, year, positive_years, negative_basins, negative_years)

            for lo in range(0, anchor_data.shape[0], batch_size):
                hi = lo + batch_size
                parts = [torch.from_numpy(arr[lo:hi]).to(device)
                         for arr in (anchor_data, positive_data, negative_data)]
                n = parts[0].shape[0]

                # Stack anchors, positives and negatives into one forward pass.
                input_data = torch.cat([p[:, :, input_channels] for p in parts], dim=0)
                static_data = torch.cat([p[:, 0, static_channels] for p in parts], dim=0)

                if train:
                    optimizer.zero_grad()
                code, reconstruction, static_reconstruction = model(input_data)

                # Reconstruction: squared error summed over channels, averaged
                # over steps whose last input channel is known.
                recon_loss = criterion(reconstruction, input_data).sum(dim=2)
                recon_loss = recon_loss.masked_select(input_data[:, :, -1] != config.unknown).mean()

                # Triplet loss on the [anchor | positive | negative] code stack.
                triplet_loss = triplet_criterion(code[:n], code[n:2 * n], code[2 * n:]).mean()

                static_loss = criterion(static_reconstruction, static_data).mean(dim=1).mean()

                loss = (recon_weight * recon_loss
                        + triplet_weight * triplet_loss
                        + static_weight * static_loss) / sum_weight

                for k, value in enumerate((loss, recon_loss, triplet_loss, static_loss)):
                    totals[k] += value.item()
                n_batches += 1

                if train:
                    loss.backward()
                    optimizer.step()

    if n_batches == 0:
        return (float("nan"),) * 4, None
    return tuple(total / n_batches for total in totals), input_data


def run_model(model_object, train_data, validation_data, test_data, model_save_path=None):
    """Train ``model_object`` on the triplet + reconstruction + static objective.

    Each of ``epochs`` epochs runs one training pass over ``train_data`` and
    then evaluation passes over ``validation_data`` and ``test_data``. Whenever
    the validation loss improves, the weights are saved to
    ``MODEL_DIR/<model_name>.pt``.

    After training:

    * the train/test loss curves are saved to ``RESULT_DIR/<model_name>_LOSS.png``;
    * the output spread over 10 repeated forward passes on the last test batch
      is printed;
    * the final weights are saved to ``model_save_path``, whose parent
      directory is created if needed. This step is skipped when the path is
      ``None``.

    Relies on notebook globals: ``epochs``, ``batch_size``, ``device``,
    ``input_channels``, ``static_channels``, the loss weights, ``criterion``,
    ``triplet_criterion``, ``optimizer``, ``MODEL_DIR``, ``RESULT_DIR`` and
    ``model_name``.

    Returns:
        ``(best_val_loss, test_recon_loss, test_triplet_loss, test_static_loss,
        static_pred_std, reconstruction_std)``. The test losses come from the
        final epoch.
    """
    model = model_object
    train_loss = []
    test_loss = []
    validation_loss = []
    min_val = 10000
    test_means = (float("nan"),) * 4
    last_test_input = None

    for epoch in range(epochs):
        start = time.time()

        train_means, _ = _run_epoch(model, train_data, train=True)
        print('Epoch:{}\tTrain Loss:{:.4f}\tRecon Loss:{:.4f}\tTriplet Loss:{:.4f}\tStatic Loss:{:.4f}'.format(epoch, *train_means), end="\t")
        train_loss.append(train_means[0])

        val_means, _ = _run_epoch(model, validation_data, train=False)
        print('\nVal Loss:{:.4f}\tRecon Loss:{:.4f}\tTriplet Loss:{:.4f}\tStatic Loss:{:.4f}\n\n'.format(*val_means), end="\t")
        validation_loss.append(val_means[0])
        # NaN compares False, so an empty validation pass never triggers a save.
        if min_val > validation_loss[-1] and validation_loss[-1] > 0:
            min_val = validation_loss[-1]
            torch.save(model.state_dict(), os.path.join(MODEL_DIR, "{}.pt".format(model_name)))

        end = time.time()
        print("Time:{:.4f}".format(end - start))

        test_means, last_test_input = _run_epoch(model, test_data, train=False)
        print('\nTest Loss:{:.4f}\tRecon Loss:{:.4f}\tTriplet Loss:{:.4f}\tStatic Loss:{:.4f}\n\n'.format(*test_means), end="\t")
        test_loss.append(test_means[0])

    plt.figure(figsize=(10, 10))
    plt.xlabel("#Epoch", fontsize=50)
    plt.plot(train_loss, linewidth=4)
    plt.plot(test_loss, linewidth=4)
    plt.tight_layout(pad=0.0, h_pad=0.0, w_pad=0.0)
    plt.savefig(os.path.join(RESULT_DIR, "{}_LOSS.png".format(model_name)), format="png")
    plt.close()

    # Spread of the outputs over repeated forward passes on the last test batch.
    if last_test_input is None:
        preds_std = preds_std_reconstruction = float("nan")
    else:
        with torch.no_grad():
            samples = [model(last_test_input) for _ in range(10)]
        preds_std = torch.stack([s[2] for s in samples]).std(dim=0).mean().item()
        preds_std_reconstruction = torch.stack([s[1] for s in samples]).std(dim=0).mean().item()
    print('\nUnc Est in Static characteristics:{:.4f}\tUnc Est in Streamflow, dynamic:{:.4f}\t\n\n'.format(preds_std, preds_std_reconstruction), end="\t")

    if model_save_path is not None:
        save_dir = os.path.dirname(model_save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        torch.save(model.state_dict(), model_save_path)

    return (min_val, test_means[1], test_means[2], test_means[3],
            preds_std, preds_std_reconstruction)

In [9]:
# Train each listed architecture five times from scratch and collect the
# summary tuple returned by run_model for every run.
models = ["ATT_NL_0", "ATT_NL_3"]  # KGSSL : ATT_NL, BIM: ATT_NL_3
results = []

for a_model in models:
    for run_idx in range(5):
        print(a_model)
        model = globals()[a_model](
            in_channels=len(input_channels),
            stat_channels=len(static_channels),
            code_dim=code_dim,
            device=device,
        ).to(device)
        # run_model reads these as module-level names, so rebuild them for
        # every fresh model (the optimizer must track the new parameters).
        criterion = torch.nn.MSELoss(reduction="none")
        triplet_criterion = torch.nn.TripletMarginLoss(
            margin=alpha, p=2.0, eps=1e-06, reduction="none"
        )
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        results.append([
            a_model,
            run_model(
                model, train_data, validation_data, test_data,
                os.path.join(RESULT_DIR, "{}_{}.pt".format(a_model, run_idx)),
            ),
        ])

print(results)



ATT_NL_0
Epoch:0	Train Loss:0.5471	Recon Loss:4.0932	Triplet Loss:0.3954	Static Loss:0.5268	
Val Loss:0.4180	Recon Loss:4.0309	Triplet Loss:0.2686	Static Loss:0.3968

	Time:51.6790

Test Loss:1.2908	Recon Loss:11.9336	Triplet Loss:0.8622	Static Loss:1.2272

	Epoch:1	Train Loss:0.3804	Recon Loss:3.9588	Triplet Loss:0.2418	Static Loss:0.3584	
Val Loss:0.3762	Recon Loss:4.0082	Triplet Loss:0.2043	Static Loss:0.3570

	Time:51.6685

Test Loss:1.1712	Recon Loss:11.8830	Triplet Loss:0.7004	Static Loss:1.1112

	Epoch:2	Train Loss:0.3490	Recon Loss:3.9496	Triplet Loss:0.2091	Static Loss:0.3270	
Val Loss:0.3498	Recon Loss:3.9913	Triplet Loss:0.2013	Static Loss:0.3283

	Time:50.2824

Test Loss:1.1034	Recon Loss:11.8446	Triplet Loss:0.6396	Static Loss:1.0423

	Epoch:3	Train Loss:0.3311	Recon Loss:3.9086	Triplet Loss:0.1858	Static Loss:0.3098	
Val Loss:0.3390	Recon Loss:3.8896	Triplet Loss:0.1677	Static Loss:0.3206

	Time:49.5650

Test Loss:1.0613	Recon Loss:11.5713	Triplet Loss:0.6305	Static Loss:

OSError: [Errno 122] Disk quota exceeded: 'DATA/RESULT/ATT_NL_3_1.pt'

In [25]:
# Persist the collected per-run summaries (one row per entry of `results`).
pd.DataFrame(data=results).to_csv(
    'results1.csv',
    sep=',',
    index=False,
)

In [None]:
# a_model = "ATT_NL_0"

# model.load_state_dict(torch.load(os.path.join(RESULT_DIR, a_model+'.pt')), strict=False)
# model.eval()

In [11]:
# Sweep over input-corruption settings (and optionally loss weights). Each model
# is trained from scratch on the corrupted train/validation splits and
# evaluated on the matching clean test split. Final weights go to
# RESULT_DIR/corrupt_<percentage>_<stderr>/<model>.pt, one directory per
# setting, so settings do not overwrite each other's checkpoints.
models = ["ATT_NL_0", "ATT_NL_3"]  # disabled candidates: "ATT_NL_4", "ATT_NL_5", "ATT_NL_6", "ATT_NL", "ATT_NL_1"
results = []
# Wider grid (disabled):
# corruption_percentage = [ 1, 5, 10, 20, 50]
# corruption_stderr = {1:[10],5:[10], 10:[0.1, 0.5, 1, 5], 20:[0.1, 0.5, 1, 5], 50:[0.1, 0.5, 1, 5]}
# reconstruction_weight_list = [0, 0.1]
# static_weight_list = [0, 10]
# triplet_weight_list = [0, 1]

corruption_percentage = [50, 5]
corruption_stderr = {1: [10], 5: [10], 10: [0.1, 0.5, 1, 5], 20: [0.1, 0.5, 1, 5], 50: [1]}

reconstruction_weight_list = [0.1]
static_weight_list = [10]
triplet_weight_list = [1]

for recon_weight in reconstruction_weight_list:
    for static_weight in static_weight_list:
        for triplet_weight in triplet_weight_list:
            # The loss weights are module-level names read inside run_model.
            sum_weight = recon_weight + static_weight + triplet_weight
            if sum_weight <= 0:
                continue  # the normalised loss divides by sum_weight
            for corr_per in corruption_percentage:
                for corr_std in corruption_stderr[corr_per]:
                    setting = "{}_{}".format(corr_per, corr_std)
                    # The trailing feature channel is dropped, as in the other
                    # loading cells (its meaning is not shown here; TODO confirm).
                    train_data = np.load(os.path.join(DIR, 'NUMPY', "train_data_basin_corrupted_{}.npy".format(setting)))[:, :, :, :-1]
                    validation_data = np.load(os.path.join(DIR, 'NUMPY', "validation_data_basin_corrupted_{}.npy".format(setting)))[:, :, :, :-1]
                    test_data = np.load(os.path.join(DIR, 'NUMPY', "test_data_basin_clean_{}.npy".format(setting)))[:, :, :, :-1]
                    print("Train Data:{}\tValidation Data:{}\tTest Data:{}".format(train_data.shape, validation_data.shape, test_data.shape))

                    save_dir = os.path.join(RESULT_DIR, "corrupt_{}".format(setting))
                    os.makedirs(save_dir, exist_ok=True)

                    for a_model in models:
                        print("recon weight:{}\tstatic weight:{}\ttriplet weight:{}\tcorr_per:{}\tcorr_std:{}\tmodel:{}".format(recon_weight, static_weight, triplet_weight, corr_per, corr_std, a_model))
                        model = globals()[a_model](in_channels=len(input_channels), stat_channels=len(static_channels), code_dim=code_dim, device=device)
                        model = model.to(device)
                        # run_model reads these as module-level names.
                        criterion = torch.nn.MSELoss(reduction="none")
                        triplet_criterion = torch.nn.TripletMarginLoss(margin=alpha, p=2.0, eps=1e-06, reduction="none")
                        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
                        results.append([corr_per, corr_std, a_model,
                                        run_model(model, train_data, validation_data, test_data,
                                                  os.path.join(save_dir, a_model + '.pt'))])

for row in results:
    print(row)

Train Data:(531, 39, 365, 33)	Validation Data:(531, 9, 365, 33)	Test Data:(531, 19, 365, 33)
recon weight:0.1	static weight:10	triplet weight:1	corr_per:50	corr_std:1	model:ATT_NL_0
Epoch:0	Train Loss:1.0409	Recon Loss:4.1464	Triplet Loss:0.3989	Static Loss:1.0741	
Val Loss:0.9000	Recon Loss:3.9999	Triplet Loss:0.2564	Static Loss:0.9334

	Time:54.4322

Test Loss:1.3625	Recon Loss:11.9434	Triplet Loss:0.8200	Static Loss:1.3109

	Epoch:1	Train Loss:0.8597	Recon Loss:3.9762	Triplet Loss:0.2519	Static Loss:0.8893	
Val Loss:0.8385	Recon Loss:3.9724	Triplet Loss:0.2183	Static Loss:0.8691

	Time:59.8261

Test Loss:1.1840	Recon Loss:11.8857	Triplet Loss:0.7027	Static Loss:1.1251

	Epoch:2	Train Loss:0.8114	Recon Loss:3.9711	Triplet Loss:0.2122	Static Loss:0.8397	
Val Loss:0.8113	Recon Loss:3.9589	Triplet Loss:0.2113	Static Loss:0.8398

	Time:53.3018

Test Loss:1.1772	Recon Loss:11.8554	Triplet Loss:0.6546	Static Loss:1.1227

	Epoch:3	Train Loss:0.7857	Recon Loss:3.9569	Triplet Loss:0.2000	Stat

In [12]:
# Sweep over input-corruption settings (and optionally loss weights). Each model
# is trained from scratch on the corrupted train/validation splits and
# evaluated on the matching clean test split. Final weights go to
# RESULT_DIR/corrupt_<percentage>_<stderr>/<model>.pt, one directory per
# setting, so this run no longer overwrites the corrupt_50_1 checkpoints.
models = ["ATT_NL_0", "ATT_NL_3"]  # disabled candidates: "ATT_NL_4", "ATT_NL_5", "ATT_NL_6", "ATT_NL", "ATT_NL_1"
results = []
# Wider grid (disabled):
# corruption_percentage = [ 1, 5, 10, 20, 50]
# corruption_stderr = {1:[10],5:[10], 10:[0.1, 0.5, 1, 5], 20:[0.1, 0.5, 1, 5], 50:[0.1, 0.5, 1, 5]}
# reconstruction_weight_list = [0, 0.1]
# static_weight_list = [0, 10]
# triplet_weight_list = [0, 1]

corruption_percentage = [1]
corruption_stderr = {1: [10], 5: [10], 10: [0.1, 0.5, 1, 5], 20: [0.1, 0.5, 1, 5], 50: [1]}

reconstruction_weight_list = [0.1]
static_weight_list = [10]
triplet_weight_list = [1]

for recon_weight in reconstruction_weight_list:
    for static_weight in static_weight_list:
        for triplet_weight in triplet_weight_list:
            # The loss weights are module-level names read inside run_model.
            sum_weight = recon_weight + static_weight + triplet_weight
            if sum_weight <= 0:
                continue  # the normalised loss divides by sum_weight
            for corr_per in corruption_percentage:
                for corr_std in corruption_stderr[corr_per]:
                    setting = "{}_{}".format(corr_per, corr_std)
                    # The trailing feature channel is dropped, as in the other
                    # loading cells (its meaning is not shown here; TODO confirm).
                    train_data = np.load(os.path.join(DIR, 'NUMPY', "train_data_basin_corrupted_{}.npy".format(setting)))[:, :, :, :-1]
                    validation_data = np.load(os.path.join(DIR, 'NUMPY', "validation_data_basin_corrupted_{}.npy".format(setting)))[:, :, :, :-1]
                    test_data = np.load(os.path.join(DIR, 'NUMPY', "test_data_basin_clean_{}.npy".format(setting)))[:, :, :, :-1]
                    print("Train Data:{}\tValidation Data:{}\tTest Data:{}".format(train_data.shape, validation_data.shape, test_data.shape))

                    save_dir = os.path.join(RESULT_DIR, "corrupt_{}".format(setting))
                    os.makedirs(save_dir, exist_ok=True)

                    for a_model in models:
                        print("recon weight:{}\tstatic weight:{}\ttriplet weight:{}\tcorr_per:{}\tcorr_std:{}\tmodel:{}".format(recon_weight, static_weight, triplet_weight, corr_per, corr_std, a_model))
                        model = globals()[a_model](in_channels=len(input_channels), stat_channels=len(static_channels), code_dim=code_dim, device=device)
                        model = model.to(device)
                        # run_model reads these as module-level names.
                        criterion = torch.nn.MSELoss(reduction="none")
                        triplet_criterion = torch.nn.TripletMarginLoss(margin=alpha, p=2.0, eps=1e-06, reduction="none")
                        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
                        results.append([corr_per, corr_std, a_model,
                                        run_model(model, train_data, validation_data, test_data,
                                                  os.path.join(save_dir, a_model + '.pt'))])

for row in results:
    print(row)

Train Data:(531, 39, 365, 33)	Validation Data:(531, 9, 365, 33)	Test Data:(531, 19, 365, 33)
recon weight:0.1	static weight:10	triplet weight:1	corr_per:1	corr_std:10	model:ATT_NL_0
Epoch:0	Train Loss:0.6438	Recon Loss:4.2146	Triplet Loss:0.4526	Static Loss:0.6272	
Val Loss:0.4825	Recon Loss:4.0079	Triplet Loss:0.2940	Static Loss:0.4661

	Time:51.4242

Test Loss:1.4219	Recon Loss:11.9927	Triplet Loss:0.9341	Static Loss:1.3649

	Epoch:1	Train Loss:0.4325	Recon Loss:3.9984	Triplet Loss:0.2795	Static Loss:0.4121	
Val Loss:0.4044	Recon Loss:3.9739	Triplet Loss:0.2492	Static Loss:0.3843

	Time:52.1167

Test Loss:1.1923	Recon Loss:11.9042	Triplet Loss:0.7115	Static Loss:1.1333

	Epoch:2	Train Loss:0.3826	Recon Loss:3.9779	Triplet Loss:0.2358	Static Loss:0.3613	
Val Loss:0.3776	Recon Loss:3.9270	Triplet Loss:0.2156	Static Loss:0.3583

	Time:50.9636

Test Loss:1.1563	Recon Loss:11.7502	Triplet Loss:0.6973	Static Loss:1.0963

	Epoch:3	Train Loss:0.3610	Recon Loss:3.9354	Triplet Loss:0.2176	Stat

# FORWARD MODEL: STREAMFLOW PREDICTION

In [None]:
"""
This file is part of the accompanying code to our manuscript:
Kratzert, F., Klotz, D., Shalev, G., Klambauer, G., Hochreiter, S., Nearing, G., "Benchmarking
a Catchment-Aware Long Short-Term Memory Network (LSTM) for Large-Scale Hydrological Modeling".
submitted to Hydrol. Earth Syst. Sci. Discussions (2019)
You should have received a copy of the Apache-2.0 license along with the code. If not,
see <https://opensource.org/licenses/Apache-2.0>
"""

from typing import Tuple

import torch
import torch.nn as nn


class EALSTM(nn.Module):
    """Entity-Aware LSTM (EA-LSTM).

    Reference: Kratzert, F., Klotz, D., Shalev, G., Klambauer, G., Hochreiter, S.,
    Nearing, G., "Benchmarking a Catchment-Aware Long Short-Term Memory Network (LSTM)
    for Large-Scale Hydrological Modeling", Hydrol. Earth Syst. Sci. Discuss. (2019).

    Unlike a standard LSTM, the input gate is computed once per sequence from the
    static features alone. The forget gate, output gate and cell candidate are
    driven by the dynamic input and the previous hidden state at every step.

    Parameters
    ----------
    input_size_dyn : int
        Number of dynamic features fed to the cell at every time step.
    input_size_stat : int
        Number of static features; they only modulate the input gate.
    hidden_size : int
        Number of hidden/memory cells.
    batch_first : bool, optional
        If True (default), inputs and outputs are laid out as [batch, seq, features];
        otherwise as [seq, batch, features].
    initial_forget_bias : int, optional
        Initial value of the forget-gate bias, by default 0.
    """

    def __init__(self,
                 input_size_dyn: int,
                 input_size_stat: int,
                 hidden_size: int,
                 batch_first: bool = True,
                 initial_forget_bias: int = 0):
        super().__init__()

        self.input_size_dyn = input_size_dyn
        self.input_size_stat = input_size_stat
        self.hidden_size = hidden_size
        self.batch_first = batch_first
        self.initial_forget_bias = initial_forget_bias

        # The three stacked gate blocks are: forget, output, cell candidate.
        gate_width = 3 * hidden_size

        # Allocated uninitialised; reset_parameters() fills them in. The registration
        # order below fixes the order of parameters() / state_dict() entries.
        self.weight_ih = nn.Parameter(torch.FloatTensor(input_size_dyn, gate_width))
        self.weight_hh = nn.Parameter(torch.FloatTensor(hidden_size, gate_width))
        self.weight_sh = nn.Parameter(torch.FloatTensor(input_size_stat, hidden_size))
        self.bias = nn.Parameter(torch.FloatTensor(gate_width))
        self.bias_s = nn.Parameter(torch.FloatTensor(hidden_size))

        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialise every learnable tensor.

        The dynamic and static input weights are orthogonal. The recurrent weight is
        three identity matrices side by side. All biases are zero except the
        forget-gate slice, which receives ``initial_forget_bias`` when it is non-zero.
        """
        # Both calls draw from the global RNG; keep this order for reproducibility.
        for weight in (self.weight_ih, self.weight_sh):
            nn.init.orthogonal_(weight)

        self.weight_hh.data = torch.eye(self.hidden_size).repeat(1, 3)

        for bias in (self.bias, self.bias_s):
            nn.init.zeros_(bias)
        if self.initial_forget_bias:
            # The first hidden_size entries feed the forget gate (first chunk in forward).
            self.bias.data[:self.hidden_size] = self.initial_forget_bias

    def forward(self, x_d: torch.Tensor, x_s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Unroll the EA-LSTM over a batch of sequences.

        Parameters
        ----------
        x_d : torch.Tensor
            Dynamic features for every time step, laid out according to ``batch_first``.
        x_s : torch.Tensor
            Static features, shape [batch, input_size_stat].

        Returns
        -------
        h_n : torch.Tensor
            Hidden state at every time step, in the same layout as ``x_d``.
        c_n : torch.Tensor
            Cell state at every time step, in the same layout as ``x_d``.
        """
        # Work time-major internally: [seq, batch, features].
        steps = x_d.transpose(0, 1) if self.batch_first else x_d
        _, batch_size, _ = steps.size()

        h_t = steps.new_zeros(batch_size, self.hidden_size)
        c_t = steps.new_zeros(batch_size, self.hidden_size)

        bias_batch = self.bias.expand(batch_size, -1)

        # Static-only input gate: computed once and reused at every time step.
        input_gate = torch.sigmoid(
            torch.addmm(self.bias_s.expand(batch_size, -1), x_s, self.weight_sh))

        hidden_seq, cell_seq = [], []
        for x_t in steps:
            gates = torch.addmm(bias_batch, h_t, self.weight_hh) + torch.mm(x_t, self.weight_ih)
            forget, output, candidate = gates.chunk(3, 1)

            c_t = torch.sigmoid(forget) * c_t + input_gate * torch.tanh(candidate)
            h_t = torch.sigmoid(output) * torch.tanh(c_t)

            hidden_seq.append(h_t)
            cell_seq.append(c_t)

        h_n = torch.stack(hidden_seq)
        c_n = torch.stack(cell_seq)
        if self.batch_first:
            h_n, c_n = h_n.transpose(0, 1), c_n.transpose(0, 1)
        return h_n, c_n

In [None]:
# parameters and data for EALSTM
# Load the three splits and drop the last channel of every array on load.
# NOTE(review): the training loop below indexes these as [basin?, year, time, channel];
# the axis meanings are inferred from that indexing — confirm against preprocessing.
train_data, validation_data, test_data = (
    np.load(os.path.join(DIR, 'NUMPY', file_name))[:, :, :, :-1]
    for file_name in ("train_data_basin_ealstm.npy",
                      "validation_data_basin_ealstm.npy",
                      "test_data_basin_ealstm.npy")
)
print("Train Data:{}\tValidation Data:{}\tTest Data:{}".format(train_data.shape, validation_data.shape, test_data.shape))

# feature_names is an object array, hence allow_pickle (local data file).
feature_names = np.load(os.path.join(DIR, "RAW_DATA", "feature_names.npy"), allow_pickle=True)
for feature_template, feature_idx in (("Static features:{}", config.static_channels),
                                      ("Weather features:{}", config.weather_channels),
                                      ("SF features:{}", config.sf_channels)):
    print(feature_template.format(feature_names[feature_idx]))

In [None]:
# CHANNELS INFO -- channel index lists into the last axis of each sample.
# Streamflow (sf) channels are targets here, not inputs, so the inputs are weather only.
channels, input_channels, sf_channels, static_channels = (
    config.channels, config.weather_channels, config.sf_channels, config.static_channels)

# TIME SERIES INFO
window = config.window

# TRAIN INFO -- taken from config unless explicitly overridden for this experiment.
device = config.device
code_dim = 320          # overrides config.code_dim
n_clusters = config.n_clusters
epochs = 10             # overrides config.epochs
batch_size = config.batch_size
learning_rate = 3e-4    # overrides config.learning_rate
alpha = config.alpha

# Loss weights (override config.recon_weight / static_weight / triplet_weight).
recon_weight, static_weight, triplet_weight = 0.1, 10, 1
sum_weight = recon_weight + static_weight + triplet_weight

num_layers, run = 1, 5

In [None]:
class LSTM(nn.Module):
    """Streamflow forward model: an EA-LSTM encoder followed by a Bayesian linear head.

    Parameters
    ----------
    in_channels : int
        Number of dynamic (per-time-step) input features fed to the EA-LSTM.
    stat_channels : int
        Number of static features used by the EA-LSTM input gate.
    code_dim : int
        Hidden size of the EA-LSTM.
    num_layers : int
        Stored on the instance only; the EA-LSTM used here is single-layer and this
        value is not read in the forward pass.
    device :
        Stored on the instance only; the module itself never reads it.
    """

    def __init__(self, in_channels, stat_channels, code_dim, num_layers, device):
        super(LSTM, self).__init__()

        self.in_channels = in_channels
        self.stat_channels = stat_channels
        self.code_dim = code_dim
        self.device = device
        self.num_layers = num_layers

        self.lstm = EALSTM(in_channels, stat_channels, code_dim, batch_first=True)
        # NOTE(review): the visible notebook header imports only BayesianGRU from
        # blitz.modules; confirm BayesianLinear is imported in an earlier cell.
        self.out = BayesianLinear(code_dim, 1)

        # Xavier-initialise any plain Conv2d/Linear submodules.
        # NOTE(review): EALSTM holds raw Parameters, and BayesianLinear is presumably
        # not an nn.Linear subclass, so this loop may match nothing — confirm.
        for m in self.modules():
            if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
                torch.nn.init.xavier_uniform_(m.weight)

    def forward(self, xd, xs):
        """Run the EA-LSTM over the sequence and project each hidden state.

        Parameters
        ----------
        xd : torch.Tensor
            Dynamic inputs, batch-first: [batch, seq, in_channels].
        xs : torch.Tensor
            Static inputs: [batch, stat_channels].

        Returns
        -------
        torch.Tensor
            Output of the Bayesian linear head applied to the hidden state of every
            time step (presumably [batch, seq, 1] — TODO confirm).
        """
        # EALSTM returns (hidden states, cell states); only the hidden states are used.
        output, _ = self.lstm(xd, xs)
        return self.out(output)

In [None]:
# Build the EA-LSTM forward model together with its losses and optimiser.
model = LSTM(in_channels=len(input_channels), stat_channels=len(static_channels),
             code_dim=code_dim, num_layers=num_layers, device=device)
model = model.to(device)
# reduction="none" keeps per-element errors so that entries with an unknown target
# can be masked out before averaging.
criterion = torch.nn.MSELoss(reduction="none")
# NOTE(review): the reconstruction-only EA-LSTM training loop in this section does
# not use triplet_criterion.
triplet_criterion = torch.nn.TripletMarginLoss(margin=alpha, p=2.0, eps=1e-06, reduction="none")
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Total trainable parameters:{}".format(pytorch_total_params))

In [None]:
# File stem for checkpoints and figures: Only_EALSTM_<num_hidden>_<run>_<code_dim>
model_name = f"Only_EALSTM_{num_hidden}_{run}_{code_dim}"

In [None]:
# Train the EA-LSTM forward model. Each epoch makes one pass over the training split
# (with gradient updates) and one over the validation split (forward only). It logs
# the averaged losses and checkpoints the weights whenever validation loss improves.
train_loss = []
validation_loss = []
min_val = 10000

for epoch in range(epochs):
    start = time.time()

    model.train()

    #############################################################
    # RUN ON TRAIN DATA
    dataset = train_data
    total_loss = 0
    total_recon_loss = 0
    # This model has no triplet/static terms: these stay 0 and are only logged so the
    # output format matches the other experiments in this notebook.
    total_triplet_loss = 0
    total_static_loss = 0
    n_batches = 0  # batches that actually contributed to the totals
    for year in range(dataset.shape[1]):
        data = dataset[:, year]
        for batch in range(math.ceil(data.shape[0] / batch_size)):
            optimizer.zero_grad()

            batch_data = torch.from_numpy(data[batch * batch_size:(batch + 1) * batch_size]).to(device)
            input_data = batch_data[:, :, input_channels]    # dynamic inputs per time step
            static_data = batch_data[:, 0, static_channels]  # static attributes, read at step 0
            batch_sf_data = batch_data[:, :, sf_channels]    # streamflow targets

            known = batch_sf_data != config.unknown
            if not torch.any(known):
                # No observed target in this batch: the masked mean would be NaN and
                # would poison the epoch average, so skip the batch.
                continue

            reconstruction = model(input_data, static_data)

            # Masked MSE: average only over entries with a known streamflow value.
            recon_loss = criterion(reconstruction, batch_sf_data)
            recon_loss = torch.mean(torch.masked_select(recon_loss, known))
            loss = recon_weight * recon_loss
            total_loss += loss.item()
            total_recon_loss += recon_loss.item()
            n_batches += 1
            loss.backward()
            optimizer.step()
    n_seen = max(n_batches, 1)  # avoid ZeroDivisionError on an empty split
    print('Epoch:{}\tTrain Loss:{:.4f}\tRecon Loss:{:.4f}\tTriplet Loss:{:.4f}\tStatic Loss:{:.4f}'.format(epoch, total_loss/n_seen, total_recon_loss/n_seen, total_triplet_loss/n_seen, total_static_loss/n_seen), end="\t")
    train_loss.append(total_loss / n_seen)
    model.eval()

    #############################################################
    # RUN ON VALIDATION DATA (forward only: no autograd graph, no optimizer use)
    dataset = validation_data
    total_loss = 0
    total_recon_loss = 0
    total_triplet_loss = 0
    total_static_loss = 0
    n_batches = 0
    with torch.no_grad():
        for year in range(dataset.shape[1]):
            data = dataset[:, year]
            for batch in range(math.ceil(data.shape[0] / batch_size)):
                batch_data = torch.from_numpy(data[batch * batch_size:(batch + 1) * batch_size]).to(device)
                input_data = batch_data[:, :, input_channels]
                static_data = batch_data[:, 0, static_channels]
                batch_sf_data = batch_data[:, :, sf_channels]

                known = batch_sf_data != config.unknown
                if not torch.any(known):
                    continue  # all targets missing; see the training loop above

                reconstruction = model(input_data, static_data)

                recon_loss = criterion(reconstruction, batch_sf_data)
                recon_loss = torch.mean(torch.masked_select(recon_loss, known))
                loss = recon_weight * recon_loss
                total_loss += loss.item()
                total_recon_loss += recon_loss.item()
                n_batches += 1
    n_seen = max(n_batches, 1)
    print('Val Loss:{:.4f}\tRecon Loss:{:.4f}\tTriplet Loss:{:.4f}\tStatic Loss:{:.4f}'.format(total_loss/n_seen, total_recon_loss/n_seen, total_triplet_loss/n_seen, total_static_loss/n_seen), end="\t")
    validation_loss.append(total_loss / n_seen)
    # Checkpoint on a new best positive validation loss (NaN never compares as smaller).
    if 0 < validation_loss[-1] < min_val:
        min_val = validation_loss[-1]
        torch.save(model.state_dict(), os.path.join(MODEL_DIR, "{}.pt".format(model_name)))
    end = time.time()
    print("Time:{:.4f}".format(end - start))

# Loss curves for both splits, saved next to the other results.
plt.figure(figsize=(10, 10))
plt.xlabel("#Epoch", fontsize=50)
plt.plot(train_loss, linewidth=4, label="Train")
plt.plot(validation_loss, linewidth=4, label="Validation")
plt.legend(fontsize=30)
plt.tight_layout(pad=0.0, h_pad=0.0, w_pad=0.0)
plt.savefig(os.path.join(RESULT_DIR, "{}_LOSS.png".format(model_name)), format="png")
plt.close()