# IMPORT LIBRARIES

# LOAD DATA

In [None]:
import sys
sys.path.append("../")
import config
import EALSTM
import os
import random
import math
import time
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
from collections import Counter
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE

# CHANNELS INFO
channels = config.channels
# Model inputs are the weather channels followed by the SF channels.
input_channels = config.weather_channels + config.sf_channels
sf_channels = config.sf_channels
static_channels = config.static_channels

# TIME SERIES INFO
window = config.window

# TRAIN INFO
device = config.device
code_dim = config.code_dim
n_clusters = config.n_clusters
epochs = 200  # hard-coded override of config.epochs
batch_size = config.batch_size
learning_rate = 0.003  # hard-coded override of config.learning_rate
alpha = config.alpha

# MODEL INFO
recon_weight = 1  # hard-coded override of config.recon_weight
static_weight = 1  # hard-coded override of config.static_weight
triplet_weight = 1  # hard-coded override of config.triplet_weight
# Normaliser for the weighted sum of the three loss terms.
sum_weight = recon_weight + static_weight + triplet_weight
architecture = "ATT_NL"  # name of the model class to instantiate; alternative: "LAST"
run = 1  # previously int(sys.argv[1])
temp = 1  # contrastive-loss temperature; previously float(sys.argv[2])
num_hidden = 0
num_layer = 1
Hidden = "Hidden_{}_Serial_contrastive_extend".format(num_hidden)

# MODEL NAME
# Encodes the run's hyperparameters; used for checkpoint and result file names.
model_name = "{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}".format(
    "ALL",
    architecture,
    code_dim,
    len(static_channels),
    run,
    batch_size,
    "{}_NL".format(num_layer),
    Hidden,
    recon_weight,
    static_weight,
    triplet_weight,
    temp,
)
pretrain = None  # checkpoint prefix to warm-start from, or None

print("{} Hyperparameters".format(model_name))
print("Channels : {}".format(channels))
print("Input Channels : {}".format(input_channels))
print("Static Channels : {}".format(static_channels))
print("Code dim : {}".format(code_dim))
print("Epochs : {}".format(epochs))
print("Batch Size : {}".format(batch_size))
print("Learning rate : {}".format(learning_rate))
print("Reconstruction Weight : {}".format(recon_weight))
print("Static Weight : {}".format(static_weight))
print("Triplet Weight : {}".format(triplet_weight))
print("Temperature : {}".format(temp))
print("Pretrain : {}".format(pretrain))

# exist_ok avoids the check-then-create race of os.path.exists() followed by os.makedirs().
os.makedirs(config.MODEL_DIR, exist_ok=True)
os.makedirs(config.RESULT_DIR, exist_ok=True)

In [None]:
# Load the three pre-built splits. The [:, :, :, :-1] slice drops the last entry of axis 3,
# which the training loop below reads as the channel axis (basin, year, time, channel).
train_data, validation_data, test_data = (
    np.load(os.path.join(config.NUMPY_DIR, file_name))[:, :, :, :-1]
    for file_name in (
        "train_data_basin_ealstm.npy",
        "validation_data_basin_ealstm.npy",
        "test_data_basin_ealstm.npy",
    )
)

split_shapes = (train_data.shape, validation_data.shape, test_data.shape)
print("Train Data:{}\tValidation Data:{}\tTest Data:{}".format(*split_shapes))

# NOTE(review): allow_pickle=True unpickles the file's contents; only acceptable for trusted local data.
feature_names = np.load(
    os.path.join(config.DATA_DIR, "RAW_DATA", "feature_names.npy"),
    allow_pickle=True,
)
print("Static features:{}".format(feature_names[config.static_channels]))
print("Weather features:{}".format(feature_names[config.weather_channels]))
print("SF features:{}".format(feature_names[config.sf_channels]))

# BUILD MODEL

In [None]:
class SimCLR_Loss(torch.nn.Module):
    """NT-Xent (SimCLR) contrastive loss over a batch of paired embeddings.

    ``forward`` takes ``z`` of shape (2*B, D), where row ``i`` and row ``i + B``
    are the two views of the same sample. For every row, the positive logit is
    its cosine similarity to its partner. The negatives are its similarities to
    every other row except itself. All similarities are divided by
    ``temperature``. The result is the mean cross-entropy over the 2*B rows.
    """

    def __init__(self, temperature):
        super(SimCLR_Loss, self).__init__()
        self.temperature = temperature
        self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")
        self.similarity = torch.nn.CosineSimilarity(dim=2)

    def mask_correlated_samples(self, batch_size, device=None):
        """Return a (2*batch_size, 2*batch_size) bool mask selecting the negatives.

        Entries are False on the diagonal (self-similarity) and on the
        positive-pair positions (i, i + batch_size) / (i + batch_size, i);
        True everywhere else. ``device`` selects where the mask is built
        (defaults to the CPU, as before).
        """
        n = 2 * batch_size
        eye = torch.eye(n, dtype=torch.bool, device=device)
        # Rolling the identity right by batch_size marks (i, (i + batch_size) mod n),
        # i.e. exactly the partner of every row.
        return ~(eye | eye.roll(batch_size, dims=1))

    def forward(self, z):
        """Compute the NT-Xent loss for ``z`` of shape (2*B, D).

        Raises:
            ValueError: if the number of rows is odd (rows cannot be paired).
        """
        N = z.shape[0]
        if N % 2:
            raise ValueError("SimCLR_Loss expects an even number of rows, got {}".format(N))
        batch_size = N // 2

        # Pairwise cosine similarity, shape (N, N).
        sim = self.similarity(z.unsqueeze(1), z.unsqueeze(0)) / self.temperature

        # Partner similarities: the +B diagonal covers rows 0..B-1, the -B diagonal rows B..N-1.
        positive_samples = torch.cat(
            (torch.diag(sim, batch_size), torch.diag(sim, -batch_size)), dim=0
        ).reshape(N, 1)
        mask = self.mask_correlated_samples(batch_size, device=sim.device)
        negative_samples = sim[mask].reshape(N, -1)

        # The positive logit sits in column 0 of every row.
        logits = torch.cat((positive_samples, negative_samples), dim=1)
        labels = torch.zeros(N, dtype=torch.long, device=logits.device)
        return self.criterion(logits, labels) / N

In [None]:
class ATT_NL(torch.nn.Module):
    """Bidirectional-GRU autoencoder with attention pooling.

    ``forward(x)`` takes ``x`` of shape (batch, seq_len, in_channels) and returns:

    * ``code``: attention-pooled embedding, shape (batch, code_dim);
    * ``out``: autoregressive reconstruction of ``x``, shape (batch, seq_len, in_channels);
    * ``static_out``: static-feature prediction from ``code``, shape (batch, stat_channels).

    Args:
        in_channels: number of input features per time step.
        stat_channels: number of static features to predict.
        code_dim: hidden size of the GRUs and size of the embedding.
        device: stored on the instance (kept for compatibility).
        num_layers: encoder GRU depth. Defaults to the notebook-level
            ``num_layer`` setting when None.
    """

    def __init__(self, in_channels, stat_channels, code_dim, device, num_layers=None):
        super(ATT_NL, self).__init__()
        if num_layers is None:
            num_layers = num_layer  # notebook-level setting

        self.code_dim = code_dim
        self.device = device
        # Module creation order is kept as before so state_dict keys and init RNG usage are unchanged.
        self.encoder = torch.nn.GRU(in_channels, code_dim, num_layers=num_layers, batch_first=True, bidirectional=True)
        self.att = torch.nn.Linear(code_dim, 1)
        self.decoder = torch.nn.GRU(in_channels, code_dim, batch_first=True)
        self.out = torch.nn.Linear(code_dim, in_channels)
        self.linear_1 = torch.nn.Linear(code_dim, code_dim)
        self.static_out = torch.nn.Linear(code_dim, stat_channels)
        self.relu = torch.nn.ReLU()

        for m in self.modules():
            if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
                torch.nn.init.xavier_uniform_(m.weight)

    def _decode(self, code, seq_len, in_channels):
        """Autoregressively rebuild ``seq_len`` steps from ``code`` (batch, code_dim).

        The decoder GRU starts from hidden state ``code``. It is fed zeros at the
        first step and its own previous prediction afterwards. Each step's GRU
        output is multiplied elementwise by ``code`` before the output projection.
        """
        h = code.unsqueeze(0)  # (1, batch, code_dim) initial hidden state
        step_input = code.new_zeros((code.shape[0], 1, in_channels))
        steps = []
        for _ in range(seq_len):
            dec_out, h = self.decoder(step_input, h)
            # Index the time axis instead of squeeze() so batch == 1 keeps its batch axis.
            step_output = self.out(dec_out[:, 0] * code)
            steps.append(step_output)
            step_input = step_output.unsqueeze(1)
        if not steps:
            return code.new_zeros((code.shape[0], 0, in_channels))
        return torch.stack(steps, dim=1)

    def forward(self, x):
        batch, seq_len, in_channels = x.shape

        enc_out, _ = self.encoder(x)
        # Merge forward and backward directions by summation -> (batch, seq_len, code_dim).
        enc_out = enc_out[:, :, :self.code_dim] + enc_out[:, :, self.code_dim:]
        # squeeze(-1) removes only the unit score dim, so batch == 1 or seq_len == 1 keep their axes.
        scores = self.att(enc_out).squeeze(-1)  # (batch, seq_len)
        att_weights = torch.softmax(scores, dim=1).unsqueeze(-1)
        code = torch.sum(att_weights * enc_out, dim=1)  # (batch, code_dim)

        static_out = self.static_out(self.relu(self.linear_1(code)))
        out = self._decode(code, seq_len, in_channels)
        return code, out, static_out

    
class LAST(torch.nn.Module):
    """Bidirectional-GRU autoencoder whose embedding is the summed final hidden states.

    ``forward(x)`` takes ``x`` of shape (batch, seq_len, in_channels) and returns:

    * ``code``: shape (batch, code_dim), the sum of the forward and backward final hidden states;
    * ``out``: autoregressive reconstruction, shape (batch, seq_len, in_channels);
    * ``static_out``: linear static-feature prediction, shape (batch, stat_channels).

    ``device`` is stored on the instance for compatibility.
    """

    def __init__(self, in_channels, stat_channels, code_dim, device):
        super(LAST, self).__init__()

        self.code_dim = code_dim
        self.device = device
        # Each module is created once; previously encoder/decoder/out were built twice
        # and the first instances discarded. Registration order (and state_dict keys) is unchanged.
        self.encoder = torch.nn.GRU(in_channels, code_dim, batch_first=True, bidirectional=True)
        self.decoder = torch.nn.GRU(in_channels, code_dim, batch_first=True)
        self.out = torch.nn.Linear(code_dim, in_channels)
        self.static_out = torch.nn.Linear(code_dim, stat_channels)

        for m in self.modules():
            if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
                torch.nn.init.xavier_uniform_(m.weight)

    def _decode(self, code, seq_len, in_channels):
        """Autoregressively rebuild ``seq_len`` steps from ``code`` (batch, code_dim).

        The decoder GRU starts from hidden state ``code``. It is fed zeros at the
        first step and its own previous prediction afterwards. Each step's GRU
        output is multiplied elementwise by ``code`` before the output projection.
        """
        h = code.unsqueeze(0)  # (1, batch, code_dim) initial hidden state
        step_input = code.new_zeros((code.shape[0], 1, in_channels))
        steps = []
        for _ in range(seq_len):
            dec_out, h = self.decoder(step_input, h)
            # Index the time axis instead of squeeze() so batch == 1 keeps its batch axis.
            step_output = self.out(dec_out[:, 0] * code)
            steps.append(step_output)
            step_input = step_output.unsqueeze(1)
        if not steps:
            return code.new_zeros((code.shape[0], 0, in_channels))
        return torch.stack(steps, dim=1)

    def forward(self, x):
        batch, seq_len, in_channels = x.shape

        _, h_n = self.encoder(x)  # (2, batch, code_dim): one final state per direction
        code = torch.sum(h_n, dim=0)  # (batch, code_dim)
        static_out = self.static_out(code)

        out = self._decode(code, seq_len, in_channels)
        return code, out, static_out

# Look the architecture class up by name among this notebook's globals, build it, and move it to `device`.
model = globals()[architecture](
    in_channels=len(input_channels),
    stat_channels=len(static_channels),
    code_dim=code_dim,
    device=device,
).to(device)

# Element-wise MSE: the training loop applies its own sums/means/masks.
criterion = torch.nn.MSELoss(reduction="none")
contrastive_criterion = SimCLR_Loss(temperature=temp)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

pytorch_total_params = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)
print("Total trainable parameters:{}".format(pytorch_total_params))

# LOAD PRETRAINED MODEL

In [None]:
# Optionally warm-start from a previously saved checkpoint named "<pretrain>_<code_dim>.pt".
if pretrain:
    checkpoint_path = os.path.join(config.MODEL_DIR, "{}_{}.pt".format(pretrain, code_dim))
    # map_location lets a checkpoint saved on another device (e.g. GPU) load onto `device`.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location=device)
    # strict=False tolerates key mismatches; report them instead of hiding them silently.
    incompatible = model.load_state_dict(state_dict, strict=False)
    print("Pretrain missing keys:{}\tunexpected keys:{}".format(
        incompatible.missing_keys, incompatible.unexpected_keys))
    model.eval()

# TRAIN MODEL

In [None]:
def _shuffled_years(dataset):
    """Return two independent per-basin permutations of the year axis of ``dataset``.

    Each result is an int64 array of shape (basins, years). Row ``b`` is a random
    ordering of ``range(years)``. All of permutation 1 is drawn before permutation 2,
    matching the order of ``random`` calls in the previous implementation.
    """
    basins, years = dataset.shape[0], dataset.shape[1]
    perm_1 = np.array([random.sample(range(years), years) for _ in range(basins)], dtype=np.int64).reshape(basins, years)
    perm_2 = np.array([random.sample(range(years), years) for _ in range(basins)], dtype=np.int64).reshape(basins, years)
    return perm_1, perm_2


def _pairs_for_year(dataset, years_1, years_2, year):
    """Select (anchor, positive) samples for column ``year`` of the two permutations.

    A pair is dropped if it has either of these properties:
    * both samples come from the same year;
    * either sample has ``config.unknown`` at any time step in its last channel
      (streamflow, per the original comments).
    """
    rows = np.arange(dataset.shape[0])
    anchor = dataset[rows, years_1[:, year]]
    positive = dataset[rows, years_2[:, year]]

    keep = years_1[:, year] != years_2[:, year]
    anchor, positive = anchor[keep], positive[keep]

    keep = ((anchor[:, :, -1] != config.unknown).all(axis=1)
            & (positive[:, :, -1] != config.unknown).all(axis=1))
    return anchor[keep], positive[keep]


def _batch_losses(code, reconstruction, static_reconstruction, input_data, static_data):
    """Return the (total, reconstruction, contrastive, static) loss tensors for one batch."""
    # Per-time-step squared error summed over channels, restricted to steps with a known last channel.
    recon_loss = torch.sum(criterion(reconstruction, input_data), axis=2)
    recon_loss = torch.mean(torch.masked_select(recon_loss, input_data[:, :, -1] != config.unknown))
    contrastive_loss = contrastive_criterion(code)
    static_loss = torch.mean(torch.mean(criterion(static_reconstruction, static_data), axis=1))
    loss = (recon_weight * recon_loss + triplet_weight * contrastive_loss + static_weight * static_loss) / sum_weight
    return loss, recon_loss, contrastive_loss, static_loss


def _run_epoch(dataset, train, epoch):
    """Make one pass over every (year, batch) of ``dataset``.

    When ``train`` is True, the model is updated after each batch, and batches
    whose embedding contains NaN are skipped and recorded in ``nan_batch``.
    When ``train`` is False, the model is run in eval mode without building
    autograd graphs.

    Returns a numpy array of the mean (total, recon, contrastive, static) losses
    over the batches actually processed, or NaNs if no batch was processed.
    """
    model.train(train)
    years_1, years_2 = _shuffled_years(dataset)
    sums = np.zeros(4)
    n_batches = 0

    for year in range(dataset.shape[1]):
        anchor_data, positive_data = _pairs_for_year(dataset, years_1, years_2, year)
        order = random.sample(range(anchor_data.shape[0]), anchor_data.shape[0])
        for offset in range(0, len(order), batch_size):
            idx = order[offset:offset + batch_size]
            anchor = torch.from_numpy(anchor_data[idx]).to(device)
            positive = torch.from_numpy(positive_data[idx]).to(device)

            # Anchors first, positives second: SimCLR_Loss pairs row i with row i + len(idx).
            input_data = torch.cat((anchor[:, :, input_channels], positive[:, :, input_channels]), dim=0)
            static_data = torch.cat((anchor[:, 0, static_channels], positive[:, 0, static_channels]), dim=0)
            if train and input_data.isnan().any():
                print("True")
                print(torch.isnan(input_data))

            with torch.set_grad_enabled(train):
                code, reconstruction, static_reconstruction = model(input_data)
                if train and code.isnan().any():
                    print("NAN_TEST_CODE", code)
                    print("NAN_TEST_RECONSTRUCTION", reconstruction.isnan().any())
                    print("NAN_TEST_STATIC_RECON", static_reconstruction.isnan().any())
                    nan_batch.append((epoch, year, offset // batch_size))
                    break  # skip the rest of this year, as before
                losses = _batch_losses(code, reconstruction, static_reconstruction, input_data, static_data)

            if train:
                optimizer.zero_grad()
                losses[0].backward()
                optimizer.step()

            sums += [value.item() for value in losses]
            n_batches += 1

    # Average over the batches actually run; per-year batch counts differ after filtering.
    if n_batches == 0:
        return np.full(4, np.nan)
    return sums / n_batches


train_loss = []
validation_loss = []
min_val = 10000
nan_batch = []  # (epoch, year, batch) of training batches skipped because the code contained NaN
for epoch in range(epochs):
    start = time.time()

    train_means = _run_epoch(train_data, True, epoch)
    print('Epoch:{}\tTrain Loss:{:.4f}\tRecon Loss:{:.4f}\tTriplet Loss:{:.4f}\tStatic Loss:{:.4f}'.format(epoch, *train_means), end="\t")
    train_loss.append(float(train_means[0]))

    val_means = _run_epoch(validation_data, False, epoch)
    print('Val Loss:{:.4f}\tRecon Loss:{:.4f}\tTriplet Loss:{:.4f}\tStatic Loss:{:.4f}'.format(*val_means), end="\t")
    validation_loss.append(float(val_means[0]))

    # Checkpoint on a new best positive validation loss (a NaN loss never compares smaller).
    if min_val > validation_loss[-1] and validation_loss[-1] > 0:
        min_val = validation_loss[-1]
        torch.save(model.state_dict(), os.path.join(config.MODEL_DIR, "{}.pt".format(model_name)))

    end = time.time()
    print("Time:{:.4f}".format(end - start))

plt.figure(figsize=(10, 10))
plt.xlabel("#Epoch", fontsize=50)
plt.plot(train_loss, linewidth=4, label="Train")
plt.plot(validation_loss, linewidth=4, label="Validation")
plt.legend(fontsize=30)
plt.tight_layout(pad=0.0, h_pad=0.0, w_pad=0.0)
plt.savefig(os.path.join(config.RESULT_DIR, "{}_LOSS.png".format(model_name)), format="png")
plt.close()