# train__embedding_similar_binned_distribution

In [1]:
import datetime
import logging
import os
import numpy as np
import torch
from importlib import reload
import matplotlib.pyplot as plt

from models import dapm
from scripts.data_loader import *
from scripts.train_dapm import train
from utils.metrics import normalize_mat
from params import Param
from utils.logging_utils import *

import warnings
warnings.filterwarnings('ignore')


In [2]:

def dapm_main(param, **kwargs):
    """Fine-tune a DAPM model starting from the previous period's checkpoint.

    The function searches ``kwargs['model_dir']`` for a ``.pkl`` file whose
    name contains ``{param.last_year}___#{param.last_month}#``, loads it,
    attaches the train/val/test labels to the data object and hands
    everything to ``train``.  When no such checkpoint exists, a message is
    printed and nothing is trained.

    Args:
        param: parameter object.  ``last_year`` and ``last_month`` are read,
            ``generate_model_name()`` and ``generate_ae_model_name()`` are
            called, and the object is forwarded to ``load_data``,
            ``load_locations`` and ``train``.
        **kwargs: must contain ``model_dir``, ``log_dir``, ``run_dir``,
            ``train_val_test`` and ``device``.  The keys ``model_name``,
            ``model_file``, ``log_file``, ``run_file`` and ``ae_model_file``
            are added here before ``train`` is called.

    Returns:
        None.
    """
    model_name = param.generate_model_name()
    ae_model_name = param.generate_ae_model_name()
    model_dir = kwargs['model_dir']

    # Locate the previous model to fine-tune from.  If several files match,
    # the last one in directory-listing order wins.
    checkpoint_tag = f'{param.last_year}___#{param.last_month}#'
    previous_model_name = ''
    for f in os.listdir(model_dir):
        if checkpoint_tag in f and '.pkl' in f:
            previous_model_name = f
    if not previous_model_name:
        # Make the skip visible instead of returning silently.
        print(f'No previous checkpoint matching {checkpoint_tag} found in {model_dir}; skipping {model_name}.')
        return

    print(model_name)
    print(previous_model_name)

    # NOTE: torch.load unpickles arbitrary Python objects; only load
    # checkpoints from a trusted model_dir.
    # map_location lets a checkpoint that was saved on another GPU be restored
    # on the requested device.
    device = kwargs['device']
    model = torch.load(os.path.join(model_dir, previous_model_name), map_location=device).to(device)

    # Run-directory suffix: day, hour, minute ('%M'; '%m' would be the month).
    run_stamp = datetime.datetime.now().strftime('%d%H%M')

    kwargs['model_name'] = model_name
    kwargs['model_file'] = os.path.join(model_dir, model_name + '.pkl')
    kwargs['log_file'] = os.path.join(kwargs['log_dir'], model_name + '.log')
    kwargs['run_file'] = os.path.join(kwargs['run_dir'], model_name + '_run_{}'.format(run_stamp))
    kwargs['ae_model_file'] = os.path.join('./data/ae_models_2/models/', ae_model_name + '.pkl')

    # Load the data and attach the train/val/test locations and labels.
    data_dir = '/home/yijun/notebooks/training_data/'
    data_obj = load_data(data_dir, param)
    train_loc, val_loc, test_loc = load_locations(kwargs['train_val_test'], param)

    data_obj.train_loc = train_loc
    data_obj.train_y = data_obj.gen_train_val_test_label(data_obj.label_mat, data_obj.train_loc)
    data_obj.val_loc = val_loc
    data_obj.val_y = data_obj.gen_train_val_test_label(data_obj.label_mat, data_obj.val_loc)
    data_obj.test_loc = test_loc
    data_obj.test_y = data_obj.gen_train_val_test_label(data_obj.label_mat, data_obj.test_loc)

    # Logging starts.
    start_logging(kwargs['log_file'], model_name)
    data_logging(data_obj)

    # The network is not rebuilt here: training continues from the checkpoint
    # loaded above.
    train(model, data_obj, param, **kwargs)

    # Logging ends.
    end_logging(model_name)
    

In [3]:
from scipy.optimize import curve_fit

def gaussian(h, r, s, n=0):
    """Gaussian-shaped semivariogram curve.

    Evaluates ``n + s * (1 - exp(-h**2 / (r / 2)**2))``.

    Args:
        h: lag value(s) at which the curve is evaluated.
        r: range-like parameter; ``r / 2`` is used as the length scale.
        s: factor multiplying the saturating term (sill-like parameter).
        n: constant offset added to the curve (nugget-like parameter).

    Returns:
        The curve value(s) at ``h``.
    """
    length_scale_sq = (r / 2.) ** 2
    saturation = 1. - np.exp(-(h ** 2 / length_scale_sq))
    return n + s * saturation

def exponential(h, r, s, n=0):
    """Exponential-shaped semivariogram curve.

    Evaluates ``n + s * (1 - exp(-h / (r / 3)))``.

    Args:
        h: lag value(s) at which the curve is evaluated.
        r: range-like parameter; ``r / 3`` is used as the length scale.
        s: factor multiplying the saturating term (sill-like parameter).
        n: constant offset added to the curve (nugget-like parameter).

    Returns:
        The curve value(s) at ``h``.
    """
    length_scale = r / 3.
    saturation = 1. - np.exp(-(h / length_scale))
    return n + s * saturation

def get_fit_bounds(x, y):
    """Build the ``(lower, upper)`` bounds used when fitting a curve to (x, y).

    The lower bound is the scalar 0 for every parameter.  The upper bounds
    are, in parameter order (r, s, n): the NaN-ignoring maximum of ``x``, the
    NaN-ignoring maximum of ``y`` and the NaN-ignoring minimum of ``y``.

    Returns:
        Tuple ``(0, [max_x, max_y, min_y])``.
    """
    upper = [np.nanmax(x), np.nanmax(y), np.nanmin(y)]
    return (0, upper)


def get_fit_func(x, y, model):
    """Fit ``model`` to the points (x, y) and return the fitted parameters.

    The bounds come from ``get_fit_bounds``; the upper bounds also serve as
    the initial guess.  The fit is best-effort: if anything raises while the
    bounds are built or the curve is fitted, ``[0, 0, 0]`` is returned.

    Args:
        x: abscissa values.
        y: ordinate values.
        model: callable ``model(h, r, s, n)`` passed to ``curve_fit``.

    Returns:
        The optimal parameters found by ``curve_fit``, or ``[0, 0, 0]`` on
        failure.
    """
    fallback = [0, 0, 0]
    try:
        lower, upper = get_fit_bounds(x, y)
        fitted, _covariance = curve_fit(model, x, y, method='trf', p0=upper, bounds=(lower, upper))
    except Exception:
        return fallback
    return fitted


def gen_semivariogram(distances, variances, bins, thr):
    """Bin (distance, variance) pairs and fit a Gaussian curve to the bin means.

    For each pair of consecutive bin edges ``[left, right)`` that contains
    more than 10 distances, the NaN-ignoring means of the variances and of the
    distances in that bin form one point.  ``get_fit_func`` then fits
    ``gaussian`` to these points.

    Args:
        distances: NumPy array of pairwise distances.
        variances: NumPy array of the matching semivariances.
        bins: sequence of ascending bin edges.
        thr: accepted for interface compatibility; not used by this function.

    Returns:
        The parameters returned by ``get_fit_func``.
    """
    min_count = 10  # a bin is used only when it holds more than this many pairs
    mean_distances, mean_variances = [], []
    for lower_edge, upper_edge in zip(bins[:-1], bins[1:]):
        in_bin = (distances >= lower_edge) & (distances < upper_edge)
        if np.count_nonzero(in_bin) <= min_count:
            continue
        mean_variances.append(np.nanmean(variances[in_bin]))
        mean_distances.append(np.nanmean(distances[in_bin]))

    return get_fit_func(np.array(mean_distances), np.array(mean_variances), model=gaussian)
    

In [11]:
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as dat
from tensorboardX import SummaryWriter
from torch import autograd

from utils.early_stopping import EarlyStopping
from utils.metrics import compute_error
from models.spatial_loss_func import SpatialLossFunc

        
def train(dapm, data_obj, args, **kwargs):
    """Train ``dapm`` on index-based mini-batches with early stopping.

    For every epoch the function runs one pass over the training indices,
    combines the NaN-masked MSE with the optional terms selected by
    ``args.model_type`` ('sp', 'ae', 'sc', 'esc', 'acc'), logs the averaged
    losses, calls the early-stopping helper with the averaged validation loss
    and, on some epochs, evaluates the model against ``data_obj.test_y``.

    Args:
        dapm: model; ``dapm(x)`` must return a 5-tuple, unpacked here as
            ``(out, masked_x, _, de_x, em)``.  ``dapm.mask_layer`` is read
            when 'sp' is in ``args.model_type``.
        data_obj: data container; ``train_y``, ``val_y``, ``test_y``,
            ``dynamic_x`` and ``static_x`` are read.
        args: hyper-parameters; ``seq_len``, ``batch_size``, ``sp_neighbor``,
            ``lr``, ``epochs``, ``model_type``, ``alpha``, ``beta``, ``gamma``
            and ``eta`` are read.
        **kwargs: ``run_file`` (SummaryWriter target), ``device``,
            ``model_file`` (passed to the early-stopping helper) and
            ``model_name`` (used in the final log message).

    Returns:
        None.

    Note:
        The 'acc' term calls ``gen_semivariogram``, which must be defined at
        module level.
    """

    # construct index-based data loader
    # Each sample is a time index i whose input window is dynamic_x[i - seq_len : i + 1].
    # Indices start at seq_len + 1, so i - 1 (used by the 'esc' term) also has a full window.
    idx = np.array([i for i in range(args.seq_len + 1, data_obj.train_y.shape[0])])
    idx_dat = dat.TensorDataset(torch.tensor(idx, dtype=torch.int32))
    train_idx_data_loader = dat.DataLoader(dataset=idx_dat, batch_size=args.batch_size, shuffle=True)

    # Test indices are visited one at a time and in order, so the concatenated
    # predictions line up with test_y[seq_len + 1:].
    idx = np.array([i for i in range(args.seq_len + 1, data_obj.test_y.shape[0])])
    idx_dat = dat.TensorDataset(torch.tensor(idx, dtype=torch.int32))
    test_idx_data_loader = dat.DataLoader(dataset=idx_dat, batch_size=1, shuffle=False)

    """ set writer, loss function, and optimizer """
    writer = SummaryWriter(kwargs['run_file'])
    loss_func = nn.MSELoss()
    spatial_loss_func = SpatialLossFunc(sp_neighbor=args.sp_neighbor)
    optimizer = optim.Adam(dapm.parameters(), lr=args.lr, weight_decay=1e-8)
    early_stopping = EarlyStopping(patience=5, verbose=True)

    def get_current_learning_rate(optimizer):
        """Return the learning rate of the optimizer's first parameter group."""
        for param_group in optimizer.param_groups:
            return param_group['lr']

    def adjust_learning_rate(optimizer, epoch):
        """Set the learning rate of every parameter group and print it.

        For ``epoch > 20`` the rate becomes ``args.lr * 0.1 ** (epoch // 10)``;
        otherwise the current rate is kept.  This helper is currently unused:
        its only call in the epoch loop is commented out.
        """
        if epoch > 20:
            lr = args.lr * (0.1 ** (epoch // 10))
        else:
            lr = get_current_learning_rate(optimizer)
        print(lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    for epoch in range(args.epochs):

        dapm.train()
        # Per-batch loss values of this epoch; lists of disabled terms stay empty.
        total_losses, train_losses, val_losses, l1_losses, ae_losses, ac_losses, tp_losses, sp_losses = [], [], [], [], [], [], [], []

#         adjust_learning_rate(optimizer, epoch)

        for _, idx in enumerate(train_idx_data_loader):
            # The loader yields a one-element list; its item is the tensor of time indices.
            batch_idx = idx[0]

            ############################
            # construct sequence input #
            ############################

            # NOTE(review): both helpers are re-defined on every batch, and the
            # test evaluation further down reuses construct_sequence_x from this
            # scope, so it only exists once at least one training batch has run.
            def construct_sequence_x(idx_list, dynamic_x, static_x):
                # Stack one window of seq_len + 1 time steps per index, then append
                # the static features along axis 2.
                # assumes static_x has a leading singleton axis so that axis 1 of
                # s_x is the time axis -- TODO confirm
                d_x = [dynamic_x[i - args.seq_len: i + 1, ...] for i in idx_list]
                d_x = np.stack(d_x, axis=0)
                s_x = np.expand_dims(static_x, axis=0)
                s_x = np.repeat(s_x, args.seq_len + 1, axis=1)  # (t, c, h, w)
                s_x = np.repeat(s_x, len(idx_list), axis=0)  # (b, t, c, h, w)
                x = np.concatenate([d_x, s_x], axis=2)
                return torch.tensor(x, dtype=torch.float).to(kwargs['device'])

            def construct_y(idx_list, output_y):
                # Gather the label slice of every index into one float tensor on the device.
                y = [output_y[i] for i in idx_list]
                y = np.stack(y, axis=0)
                return torch.tensor(y, dtype=torch.float).to(kwargs['device'])

            batch_x = construct_sequence_x(batch_idx, data_obj.dynamic_x, data_obj.static_x)  # x = (b, t, c, h, w)
            batch_y = construct_y(batch_idx, data_obj.train_y)  # y = (b, c, h, w)
            batch_val_y = construct_y(batch_idx, data_obj.val_y)

            ###################
            # train the model #
            ###################

            out, masked_x, _, de_x, em = dapm(batch_x)
            # MSE over the entries whose training label is not NaN.
            train_loss = loss_func(batch_y[~torch.isnan(batch_y)], out[~torch.isnan(batch_y)])
            train_losses.append(train_loss.item())

            # add loss according to the model type
            # NOTE(review): total_loss aliases train_loss and `+=` on tensors is
            # in-place, so the train_loss tensor is modified as well; its value
            # has already been recorded above.
            total_loss = train_loss
            if 'sp' in args.model_type:
                # L1 norm of all mask-layer parameters, weighted by alpha.
                mask_layer_params = torch.cat([x.view(-1) for x in dapm.mask_layer.parameters()])
                l1_regularization = torch.norm(mask_layer_params, 1)
                l1_losses.append(l1_regularization.item())
                total_loss += l1_regularization * args.alpha

            if 'ae' in args.model_type:
                # MSE between masked_x and de_x (both returned by the model), weighted by gamma.
                ae_loss = loss_func(masked_x, de_x)
                ae_losses.append(ae_loss.item())
                total_loss += ae_loss * args.gamma

            # NOTE(review): exact membership when model_type is a list; if it were
            # a string, 'sc' would also match 'esc'.
            if 'sc' in args.model_type:
                sp_loss = spatial_loss_func(out)
                sp_losses.append(sp_loss.item())
                total_loss += sp_loss * args.beta

            if 'esc' in args.model_type:
                # 1-step temporal neighboring loss
                # Second forward pass on the windows shifted back by one time step.
                pre_batch_idx = batch_idx - torch.ones_like(batch_idx)
                pre_batch_x = construct_sequence_x(pre_batch_idx, data_obj.dynamic_x, data_obj.static_x)  # x = (b, t, c, h, w)
                _, _, _, _, pre_em = dapm(pre_batch_x)
                tp_loss = torch.mean(torch.mean((em - pre_em) ** 2, axis=1))

                # 1-step spatial neighboring loss
                # Squared differences to the two diagonal, the vertical and the
                # horizontal neighbours along the last two axes.
                sp_loss = 0.
                sp_loss += torch.sum(torch.mean((em[..., 1:, 1:] - em[..., :-1, :-1]) ** 2, axis=1))
                sp_loss += torch.sum(torch.mean((em[..., :-1, 1:] - em[..., 1:, :-1]) ** 2, axis=1))
                sp_loss += torch.sum(torch.mean((em[..., 1:, :] - em[..., :-1, :]) ** 2, axis=1))
                sp_loss += torch.sum(torch.mean((em[..., :, 1:] - em[..., :, :-1]) ** 2, axis=1))
                # NOTE(review): normalised by args.batch_size, not by the actual size
                # of this batch (the last batch of an epoch may be smaller).
                sp_loss = sp_loss / args.batch_size / (em.shape[-1] - 1) / (em.shape[-2] - 1)

                tp_losses.append(tp_loss.item())
                sp_losses.append(sp_loss.item())
                total_loss += (tp_loss + sp_loss) * args.beta

            if 'acc' in args.model_type:

                # One sub-loss per sample of the batch.
                for t in range(batch_x.shape[0]):
                    # flatten embeddings, labels, and predictions
#                     em_flatten = em.permute(1, 0, 2, 3).reshape(64, -1)  # [64, 27968]
                    # NOTE(review): 64 is hard-coded; assumes em[t] holds 64 embedding
                    # channels -- TODO confirm
                    em_flatten = em[t, ...].reshape(64, -1)  # [64, 27968]
                    y_flatten = torch.flatten(batch_y[t, ...])
                    out_flatten = torch.flatten(out[t, ...])

                    # pairs of labeled locations
                    # Embedding distance = mean squared difference over the first axis;
                    # semivariance = half the squared label difference.
                    y_mask = (~torch.isnan(y_flatten)).nonzero().view(-1)
                    y_mask_pairs = torch.combinations(y_mask)
                    p1, p2 = y_mask_pairs[:, 0], y_mask_pairs[:, 1]
                    y_em_sim = torch.mean((em_flatten[:, p1] - em_flatten[:, p2]) ** 2, dim=0)
                    y_var = (y_flatten[p1] - y_flatten[p2]) ** 2 / 2

                    # Same quantities for pairs among 100 randomly drawn locations,
                    # using the model output instead of labels.
                    out_mask = torch.randint(out_flatten.shape[0], (100,)).to(kwargs['device'])
                    out_mask_pairs = torch.combinations(out_mask)
                    p1, p2 = out_mask_pairs[:, 0], out_mask_pairs[:, 1]
                    out_em_sim = torch.mean((em_flatten[:, p1] - em_flatten[:, p2]) ** 2, dim=0)
                    out_var = (out_flatten[p1] - out_flatten[p2]) ** 2 / 2

                    # rescale the distance
                    # The helper modifies its tensor argument in place and returns it.
                    def max_min_rescale(t, min_t, max_t):
                        t -= min_t
                        t /= (max_t - min_t)
                        return t

                    # Common min/max over both pair sets, detached from the graph.
                    max_sim = torch.max(torch.cat([out_em_sim, y_em_sim])).detach()
                    min_sim = torch.min(torch.cat([out_em_sim, y_em_sim])).detach()
                    y_em_sim = max_min_rescale(y_em_sim, min_sim, max_sim)
                    out_em_sim = max_min_rescale(out_em_sim, min_sim, max_sim)

                    # generate semivariograms for labeled data
                    # Bin edges 0.0, 0.1, ..., 1.0; thr is 1% of the labeled pair
                    # count divided by the number of edges.
                    bins = [i / 10 for i in range(11)]
                    thr = y_mask_pairs.shape[0] / len(bins) * 0.01

                    # Square root of the rescaled squared distance, as NumPy arrays.
                    dis_np = y_em_sim.detach().cpu().data.numpy() ** 0.5
                    var_np = y_var.detach().cpu().data.numpy()
                    popt = gen_semivariogram(dis_np, var_np, bins, thr)
                    r, s, n = popt

                    # generate semivariograms for unlabeled data
    #                 dis_np1 = out_em_sim.detach().cpu().data.numpy() ** 0.5
    #                 var_np1 = out_var.detach().cpu().data.numpy()
    #                 popt1 = gen_semivariogram(dis_np1, var_np1, bins, thr)
    #                 r1, s1, n1 = popt1

    #                 if 100 in batch_idx.detach().cpu().data.numpy() and epoch % 5 == 0:
    #                     plt.figure(epoch, figsize=(10, 6))
    #                     plt.scatter(dis_np, var_np, s=0.8, label='Label')
    #                     plt.scatter(dis_np, gaussian(dis_np, *popt), s=0.8, label=f'Empirical Label : {r:.2f}, {s:.2f}, {n:.2f}')
    # #                     plt.scatter(dis_np1, var_np1, s=0.8, label='Prediction')
    #                     plt.scatter(dis_np1, gaussian(dis_np1, *popt1), s=0.8, label=f'Empirical Prediction : {r1:.2f}, {s1:.2f}, {n1:.2f}')
    #                     plt.legend()
    #                     plt.show()

                    sub_loss = 0.
                    # Only bin edges below the fitted r are considered.
                    valid_bins = [b for b in bins if b < r]
                    # Skip the term unless the fit produced positive parameters with n < s.
                    if r > 0 and s > 0 and n > 0 and n < s:
                        for b in range(len(valid_bins) - 1):
                            left, right = valid_bins[b], valid_bins[b + 1]
                            # The similarities are squared distances, hence the squared edges.
                            mask1 = (y_em_sim >= left ** 2) & (y_em_sim < right ** 2)
                            mask2 = (out_em_sim >= left ** 2) & (out_em_sim < right ** 2)
                            if mask1.sum() > thr and mask2.sum() > thr:
                                # var1 / var2 hold standard deviations; they are squared below.
                                mu1, var1 = torch.mean(y_var[mask1]), torch.std(y_var[mask1])
                                mu2, var2 = torch.mean(out_var[mask2]), torch.std(out_var[mask2])
                                if var1 > 0 and var2 > 0:
                                    # KL divergence between two univariate Gaussians,
                                    # KL(N(mu1, var1^2) || N(mu2, var2^2)).
                                    sub_loss += (torch.log(var2 ** 2 / var1 ** 2) - 1 + (var1 ** 2 + (mu1 - mu2) ** 2) / var2 ** 2) * 0.5

                    # Averaged over the number of valid bin edges, not over the bins
                    # that actually contributed.
                    sub_loss = sub_loss / len(valid_bins) if len(valid_bins) > 0 else 0.
                    total_loss += sub_loss * args.eta

                    # sub_loss is a plain float when no bin contributed, so .item()
                    # fails and 0.0 is recorded instead.
                    # NOTE(review): the bare except also hides any other error here.
                    try:
                        ac_losses.append(sub_loss.item())
                    except:
                        ac_losses.append(0.0)

            total_losses.append(total_loss.item())

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            ######################
            # validate the model #
            ######################

            # Uses `out` from the forward pass above (computed before the optimizer
            # step), restricted to the entries whose validation label is not NaN.
            val_loss = loss_func(batch_val_y[~torch.isnan(batch_val_y)], out[~torch.isnan(batch_val_y)])
            val_losses.append(val_loss.item())

        # Epoch averages of the per-batch values collected above.
        avg_total_loss = np.average(total_losses)
        avg_train_loss = np.average(train_losses)
        avg_val_loss = np.average(val_losses)
        avg_l1_loss = np.average(l1_losses)
        avg_ae_loss = np.average(ae_losses)
        avg_ac_loss = np.average(ac_losses)
        avg_tp_loss = np.average(tp_losses)
        avg_sp_loss = np.average(sp_losses)


        # write for tensorboard visualization
        # Note: the 'data/train_loss' scalar is the averaged *total* loss.
        writer.add_scalar('data/train_loss', avg_total_loss, epoch)
        writer.add_scalar('data/val_loss', avg_val_loss, epoch)

        logging.info(f'Epoch [{epoch}/{args.epochs}] total_loss = {avg_total_loss:.4f}, train_loss = {avg_train_loss:.4f}, valid_loss = {avg_val_loss:.4f}.')
        logging.info(f'l1_loss = {avg_l1_loss:.4f}, ae_loss = {avg_ae_loss:.4f}, ac_loss = {avg_ac_loss:.4f}, tp_loss = {avg_tp_loss:.4f}, sp_loss = {avg_sp_loss:.4f}.')

        ##################
        # early_stopping #
        ##################

        # presumably stores the best model at kwargs['model_file'] -- verify in
        # utils.early_stopping
        early_stopping(avg_val_loss, dapm, kwargs['model_file'])

        #########################
        # evaluate testing data #
        #########################

        # Evaluate only on even epochs while the early-stopping counter is below 2.
        if early_stopping.counter < 2 and epoch % 2 == 0:

            dapm.eval()
            predictions = []

            with torch.no_grad():
                for i, data in enumerate(test_idx_data_loader):
                    batch_idx = data[0]
                    batch_x = construct_sequence_x(batch_idx, data_obj.dynamic_x, data_obj.static_x)  # x = (b, t, c, h, w)
                    out, _, _, _, _ = dapm(batch_x)
                    predictions.append(out.cpu().data.numpy())

            # Predictions start at time index seq_len + 1, matching the label slice below.
            prediction = np.concatenate(predictions)
            rmse, mape, r2 = compute_error(data_obj.test_y[args.seq_len + 1:, ...], prediction)
            writer.add_scalar('data/test_rmse', rmse, epoch)
            logging.info(f'Testing: RMSE = {rmse:.4f}, MAPE = {mape:.4f}, R2 = {r2:.4f}.')

        if early_stopping.early_stop:
            logging.info(kwargs['model_name'] + f' val_loss = {early_stopping.val_loss_min:.4f}.')
            logging.info('Early stopping')
            break


In [12]:
"""
    define directory
"""

# `json` is imported explicitly so this cell does not depend on a star import.
import json

base_dir = 'data/los_angeles_500m_acc_1234_tp1_2/'
train_val_test_file = '/home/yijun/notebooks/training_data/train_val_test_los_angeles_500m_fine_tune_1234.json'
# Use GPU index 3 whenever CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:3" if torch.cuda.is_available() else 'cpu')  # the gpu device

# Load the train/val/test locations; the context manager closes the file
# handle, which the previous bare open() never did.
with open(train_val_test_file, 'r') as f:
    train_val_test = json.load(f)

# Keyword arguments shared by every dapm_main call.
kwargs = {
    'model_dir': os.path.join(base_dir, 'models1/'),
    'log_dir': os.path.join(base_dir, 'logs1/'),
    'run_dir': os.path.join(base_dir, 'runs/'),
    'train_val_test': train_val_test,
    'device': device
}


In [13]:
# Fine-tune month by month: one Param per month 2..12 of 2018, each handed to dapm_main.
for month in range(2, 13):
    param = Param(
        [month],
        2018,
        alpha=1,
        beta=10,
        gamma=5,
        eta=0.1,
        sp_neighbor=1,
        lr=0.0002,
        batch_size=8,
        model_type=['sp', 'ae', 'esc', 'acc'],
    )
    dapm_main(param, **kwargs)
    

dapm___sp_ae_esc_acc___los_angeles_500m_2018___#02#___6_00001_1___1_10_5_01___16_13
dapm___sp_ae_esc_acc___los_angeles_500m_2018___#01#___6_00001_1___1_20_5_01___16_13.pkl
dapm___sp_ae_esc_acc___los_angeles_500m_2018___#03#___6_00001_1___1_10_5_01___16_13
dapm___sp_ae_esc_acc___los_angeles_500m_2018___#02#___6_00001_1___1_10_5_01___16_13.pkl
dapm___sp_ae_esc_acc___los_angeles_500m_2018___#04#___6_00001_1___1_10_5_01___16_13
dapm___sp_ae_esc_acc___los_angeles_500m_2018___#03#___6_00001_1___1_10_5_01___16_13.pkl
dapm___sp_ae_esc_acc___los_angeles_500m_2018___#05#___6_00001_1___1_10_5_01___16_13
dapm___sp_ae_esc_acc___los_angeles_500m_2018___#04#___6_00001_1___1_10_5_01___16_13.pkl
dapm___sp_ae_esc_acc___los_angeles_500m_2018___#06#___6_00001_1___1_10_5_01___16_13
dapm___sp_ae_esc_acc___los_angeles_500m_2018___#05#___6_00001_1___1_10_5_01___16_13.pkl
dapm___sp_ae_esc_acc___los_angeles_500m_2018___#07#___6_00001_1___1_10_5_01___16_13
dapm___sp_ae_esc_acc___los_angeles_500m_2018___#06#___6_

KeyboardInterrupt: 