<a href="https://colab.research.google.com/github/xy2119/Causal_Knowledge_GNN/blob/main/notebooks/3_causal_without_memory.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Estimating the average treatment effect with GCN and GAT, using BN connectivity and causally weighted feature embeddings
The graph dataset is built in memory as a plain list of `Data` objects and is never saved to disk (the `InMemoryDataset` class is deliberately not used).

In [None]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from gensim.models import word2vec
from causalml.metrics import *
import pandas as pd
import numpy as np 
import random
# Seed every RNG the notebook draws from: PyTorch (model weight initialisation and
# dropout in the GCN below), Python's `random`, and NumPy, so Restart & Run All reproduces results.
SEED = 2022
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

  from pandas import (to_datetime, Int64Index, DatetimeIndex, Period,
  from pandas import (to_datetime, Int64Index, DatetimeIndex, Period,
pandas.Int64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.


In [None]:
# Colab setup: install gensim and PyTorch Geometric plus its extensions
# (the extension wheels are fetched from the torch-1.10.0 + cu113 index).
# NOTE(review): the import cell above already imports gensim and torch_geometric, so on a
# fresh runtime this cell must run first; the gensim upgrade only takes effect after a
# kernel restart. Versions are unpinned — pin them for reproducible runs.
!pip install --upgrade gensim
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu113.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-1.10.0+cu113.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git



In [None]:
# Load the pre-sampled Criteo train/test splits; the first CSV column is used as the row index.
df_train, df_test = (
    pd.read_csv(split_path, index_col=0)
    for split_path in ("criteo_sampled/criteo_train.csv", "criteo_sampled/criteo_test.csv")
)

In [None]:
# load causal weighting: one ATE per model input, looked up by name in the exported table.
# Order = leading feature columns, then "visit", then the treatment flag (named "treatment"
# in the ATE table; presumably the same variable as column "T" of the data frames — confirm).
ate = pd.read_excel("criteo_sampled/feats_ate_x13.xlsx", index_col=0)
ate_features = list(df_train.columns[:-4]) + ["visit", "treatment"]
ate_list = []
for name in ate_features:
    # Select the single matching row explicitly: float() on a 1-element array is deprecated
    # in NumPy, and a missing/duplicated feature should fail with a clear message.
    matches = ate.loc[ate["Feature"] == name, "ATE"]
    if len(matches) != 1:
        raise ValueError(f"expected exactly one ATE row for feature {name!r}, found {len(matches)}")
    ate_list.append(float(matches.iloc[0]))
ate_list

[-0.01388679662724193,
 0.05071538371534517,
 0.001328853128405821,
 -0.009331114629769072,
 0.1933108932875426,
 0.0008912893595676433,
 -0.0002958249515911775,
 -0.009201954095841196,
 0.1034574925533489,
 0.0741312293943781,
 -0.01664079180275199,
 0.1331562899243786,
 0.9310119697738223,
 0.007447354046319239]

In [None]:
# load importance weighting: one importance score per model input, looked up by name in
# the "Feature"/"Importance" columns of the exported table.
imp = pd.read_excel("criteo_sampled/feats_imp_criteo.xlsx", index_col=0)
# NOTE(review): this order is columns[:-3] + ["visit"], while the model inputs are built from
# columns[:-4] + ["visit", "T"]; entries 12 and 13 may therefore not line up with the inputs
# they are multiplied with — confirm which column sits at position -4 of df_train.
imp_features = list(df_train.columns[:-3]) + ["visit"]
imp_list = []
for name in imp_features:
    # Select the single matching row explicitly: float() on a 1-element array is deprecated
    # in NumPy, and a missing/duplicated feature should fail with a clear message.
    matches = imp.loc[imp["Feature"] == name, "Importance"]
    if len(matches) != 1:
        raise ValueError(f"expected exactly one importance row for feature {name!r}, found {len(matches)}")
    imp_list.append(float(matches.iloc[0]))
imp_list

[2.156023947463837e-05,
 0.0,
 0.0,
 0.0003188499540556222,
 0.0001497909397585317,
 0.0,
 8.849588630255312e-05,
 0.0,
 4.08560226787813e-05,
 0.0,
 2.923487772932276e-05,
 0.0,
 0.9985944032669067,
 0.0007568388828076422]

In [None]:
# load edge index: the CSV has two columns and one row per edge; transposing gives the
# 2 x num_edges tensor layout passed to torch_geometric's Data(edge_index=...).
edge_index = torch.from_numpy(
    np.array(pd.read_csv('criteo_sampled/edge_index_criteo.csv')).T
)
edge_index

tensor([[ 8,  8,  8,  9,  9,  9,  9,  6,  6,  6,  6,  6,  6,  6,  6,  2,  2,  2,
          2,  1,  1,  1, 12,  1,  0,  0,  0,  0,  0,  0,  0,  1, 11,  1,  1,  5,
          5,  5,  5,  5,  5,  7,  7,  7,  3,  3,  3,  9,  5,  3,  4,  7, 10,  1,
          8,  9, 12,  4,  5, 10,  3,  1,  4,  7, 10,  1,  8,  9,  4, 10,  1,  9,
          6,  2, 12,  4, 10,  1,  4, 10,  3,  1,  2, 11,  4, 10,  3,  1,  4, 10,
          1,  4, 10,  1,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13,
         13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13],
        [ 9,  5,  3,  4,  7, 10,  1,  8,  9, 12,  4,  5, 10,  3,  1,  4,  7, 10,
          1,  8,  9,  4, 10,  1,  9,  6,  2, 12,  4, 10,  1,  4, 10,  3,  1,  2,
         11,  4, 10,  3,  1,  4, 10,  1,  4, 10,  1,  8,  8,  8,  9,  9,  9,  9,
          6,  6,  6,  6,  6,  6,  6,  6,  2,  2,  2,  2,  1,  1,  1, 12,  1,  0,
          0,  0,  0,  0,  0,  0,  1, 11,  1,  1,  5,  5,  5,  5,  5,  5,  7,  7,
          7,  3,  3,  3, 13, 13, 13, 13, 13, 13, 13

In [None]:
# load node embedding: the DeepWalk model (10-d per its filename) provides one vector per
# graph node 0..13. A Node2Vec model (criteo_sampled/Node2Vec_10d_x13.model) can be loaded
# the same way and swapped in.
model_dw = word2vec.Word2Vec.load("criteo_sampled/deepwalk_10d_x13.model")
lst_dw = [model_dw.wv[node] for node in range(14)]
len(lst_dw)

14

In [None]:
from scipy.stats import entropy
def R2(y_predicted, y_actual):
    """Coefficient of determination, 1 - SS_res / SS_tot.

    Both inputs are coerced to float arrays; SS_tot is measured around the mean of `y_actual`.
    """
    predicted = np.asarray(y_predicted, dtype=float)
    actual = np.asarray(y_actual, dtype=float)
    ss_res = np.sum(np.square(actual - predicted))
    ss_tot = np.sum(np.square(actual - np.mean(actual)))
    return 1 - ss_res / ss_tot

def MAE(y_predicted, y_actual):
    """Mean absolute error between predictions and targets (both coerced to float arrays)."""
    predicted, actual = (np.asarray(v, dtype=float) for v in (y_predicted, y_actual))
    return np.abs(actual - predicted).mean()

def RMSE(y_predicted, y_actual):
    """Root-mean-squared error between predictions and targets (both coerced to float arrays)."""
    predicted, actual = (np.asarray(v, dtype=float) for v in (y_predicted, y_actual))
    squared_error = np.square(actual - predicted)
    return np.sqrt(squared_error.mean())

def CVRMSE(y_predicted, y_actual):
    """Coefficient of variation of the RMSE: the RMSE divided by the mean of `y_actual`."""
    predicted, actual = (np.asarray(v, dtype=float) for v in (y_predicted, y_actual))
    rmse = np.sqrt(np.mean(np.square(actual - predicted)))
    return rmse / np.mean(actual)

def MAPE(y_predicted, y_actual):
    """Mean absolute percentage error, expressed as a fraction (not multiplied by 100).

    Each error is divided by its own `y_actual` value, so zero targets produce inf/nan terms.
    """
    predicted, actual = (np.asarray(v, dtype=float) for v in (y_predicted, y_actual))
    relative_error = (actual - predicted) / actual
    return np.mean(np.abs(relative_error))

def MSE(y_predicted, y_actual):
    """Mean squared error between predictions and targets (both coerced to float arrays)."""
    predicted, actual = (np.asarray(v, dtype=float) for v in (y_predicted, y_actual))
    residual = actual - predicted
    return np.mean(np.abs(residual) ** 2)

def kl_divergence_discrete(p, q):
    """KL divergence sum(p * log(p / q)) between two discrete distributions.

    Terms where p == 0 contribute 0. Inputs are used as given (not renormalised); a q entry
    of 0 where p > 0 yields inf.
    """
    # Renamed from a duplicate `kl_divergence` definition that the histogram-based
    # `kl_divergence` below immediately shadowed, which made this helper unreachable.
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    # Mask before taking the log so 0 * log(0 / q) is never evaluated (no nan/RuntimeWarning).
    support = p != 0
    return np.sum(p[support] * np.log(p[support] / q[support]))

def kl_divergence(y_predicted, y_actual):
    """KL divergence D(pred || actual) between histogram estimates of two samples.

    Both samples are binned on the same 100 equally spaced edges spanning the 0.1th to
    99.9th percentile of the pooled values. Bin probabilities are clipped to [0.001, 0.999]
    so empty bins stay finite, then passed to scipy's `entropy`.
    """
    pooled = np.hstack((y_predicted, y_actual))
    edges = np.linspace(np.percentile(pooled, 0.1), np.percentile(pooled, 99.9), 100)

    def _clipped_pmf(sample):
        # Histogram counts -> probabilities, floored/capped to avoid log(0) in entropy().
        counts = np.histogram(sample, bins=edges)[0]
        return np.clip(counts / counts.sum(), 0.001, 0.999)

    return entropy(_clipped_pmf(y_predicted), _clipped_pmf(y_actual))

In [None]:
# Original Table of Results
result=pd.DataFrame(columns=[['AUUC']])
result.loc['S Learner(LR)','AUUC']=0.497983
result.loc['S Learner(XGB)','AUUC']=0.875572
result.loc['S Learner(LGBM)','AUUC']=0.883033

result.loc['GCN (Struct)','AUUC']=0.501865
result.loc['GCN (Struct+Feature)','AUUC']=0.721959
result.loc['GCN (Struct+Causal Weighting)','AUUC']=0.732616


result.loc['GAT (Struct)','AUUC']=0.544286
result.loc['GAT (Struct+Feature)','AUUC']=0.847630
result.loc['GAT (Struct+Causal Weighting)','AUUC']=0.8807

result

Unnamed: 0,AUUC
S Learner(LR),0.497983
S Learner(XGB),0.875572
S Learner(LGBM),0.883033
GCN (Struct),0.501865
GCN (Struct+Feature),0.721959
GCN (Struct+Causal Weighting),0.732616
GAT (Struct),0.544286
GAT (Struct+Feature),0.84763
GAT (Struct+Causal Weighting),0.8807


In [None]:
# Input the feature mode and get corresponding dimensions.
# Each sample becomes a 14-node graph (one node per entry of its input vector X) that shares
# the edge_index loaded above. Node j's feature row is
#   [X[j]] (1) + DeepWalk embedding lst_dw[j] (10) + weighted copies of the whole X (14 each),
# so every node of a graph carries the same sample-level weighted vector(s).

feats_mode = 'causal+imp'

# Node-feature width per mode.
FEATS_MODE_DIMS = {
    'causal+imp': 39,   # 1+10+14+14
    'noweighting': 11,  # 1+10
    'causal*imp': 25,   # 1+10+14
    'causal': 25,       # 1+10+14
    'imp': 25,          # 1+10+14
    'equal': 25,        # 1+10+14
}
if feats_mode not in FEATS_MODE_DIMS:
    # Fail fast instead of continuing with an undefined in_dim.
    raise ValueError(f"Input Dimension Error: unknown feats_mode {feats_mode!r}; "
                     f"expected one of {sorted(FEATS_MODE_DIMS)}")
in_dim = FEATS_MODE_DIMS[feats_mode]

# Subset the data for a quick demo; set N_DEMO_ROWS = None to train on the full dataset.
N_DEMO_ROWS = 1000
df_train = df_train[:N_DEMO_ROWS]
df_test = df_test[:N_DEMO_ROWS]

col = df_train.columns
# Model inputs: the leading feature columns, then "visit" and the treatment flag "T".
feature_cols = list(col[:-4]) + ["visit", "T"]

GRAPH_BATCH_SIZE = 200                    # graphs per mini-batch for both loaders
Edge_index = edge_index.type(torch.long)  # same connectivity for every sample graph
ate_arr = np.array(ate_list)
imp_arr = np.array(imp_list)


def _weighted_parts(x):
    """Weight one sample's float32 input vector `x` according to the global `feats_mode`.

    Returns (parts, logged): `parts` is the list of weighted vectors appended, in order, to
    every node's feature row; `logged` is the vector recorded in `weighted_feats`, or None
    when nothing is recorded ('noweighting').
    """
    if feats_mode == 'causal':
        w = np.multiply(x, ate_arr)
        return [w], w
    if feats_mode == 'imp':
        w = np.multiply(x, imp_arr)
        return [w], w
    if feats_mode == 'causal*imp':
        w = np.multiply(np.multiply(x, imp_arr), ate_arr)
        return [w], w
    if feats_mode == 'causal+imp':
        imp_w = np.multiply(x, imp_arr)
        causal_w = np.multiply(x, ate_arr)
        # Nodes get [causal | importance]; the recorded vector is their element-wise sum.
        return [causal_w, imp_w], imp_w + causal_w
    if feats_mode == 'equal':
        return [x], x
    return [], None  # 'noweighting': node value + embedding only


def _build_graphs(df):
    """Convert every row of `df` into a torch_geometric `Data` graph.

    Returns (graphs, logged_feats, x, y): `x` is the (n_rows, 14) float32 tensor of
    `feature_cols`, `y` the (n_rows, 1) float32 tensor of column 'y'. Every row is used,
    including any remainder that a fixed fold size would not cover.
    """
    x = torch.from_numpy(np.array(df[feature_cols])).to(torch.float32)
    y = torch.from_numpy(np.array(df['y'])).reshape(df.shape[0], 1).to(torch.float32)
    node_emb = np.stack(lst_dw)  # row j = embedding of node j
    graphs, logged_feats = [], []
    for x_row, y_row in zip(x, y):
        x_np = np.array(x_row)
        parts, logged = _weighted_parts(x_np)
        if logged is not None:
            logged_feats.append(logged)
        n_nodes = x_np.shape[0]
        # Concatenate in float64 and cast once to float32.
        node_rows = np.concatenate(
            [x_np.reshape(-1, 1).astype(np.float64), node_emb]
            + [np.broadcast_to(p, (n_nodes, p.shape[0])) for p in parts],
            axis=1,
        )
        t = torch.from_numpy(node_rows).to(torch.float32)
        assert t.shape[1] == in_dim, f"node features have width {t.shape[1]}, expected in_dim={in_dim}"
        graphs.append(Data(x=t, edge_index=Edge_index, y=y_row.reshape(-1, 1).to(torch.float32)))
    return graphs, logged_feats, x, y


train_graphs, train_weighted_feats, x_train, y_train = _build_graphs(df_train)
# NOTE(review): shuffle=False also for training, so mini-batches follow file order every epoch.
train_loader = DataLoader(train_graphs, batch_size=GRAPH_BATCH_SIZE, shuffle=False, num_workers=0)

# `data_list` / `weighted_feats` hold the test-set graphs and weighted vectors.
data_list, weighted_feats, x_test, y_test = _build_graphs(df_test)
# Keep test order fixed: predictions are written back to df_test row by row.
test_loader = DataLoader(data_list, batch_size=GRAPH_BATCH_SIZE, shuffle=False, num_workers=0)

# GCN

In [None]:
import torch
from torch_scatter import scatter_add
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class Net(torch.nn.Module):
    """Two GCN layers over each sample graph, followed by a small fully connected head.

    Every graph is assumed to have exactly `num_nodes` nodes in a fixed order: after the
    second GCN layer, the node outputs of one graph are flattened into a single
    `num_nodes * node_out` vector, which `f1`/`f2` map to one scalar per graph.

    Args:
        in_channels: node-feature width; defaults to the notebook-level `in_dim`.
        num_nodes: nodes per graph (14 by default, one per model input).
        hidden: width of the first GCN layer.
        node_out: per-node output width of the second GCN layer.
    """
    def __init__(self, in_channels=None, num_nodes=14, hidden=64, node_out=10):
        super(Net, self).__init__()
        if in_channels is None:
            in_channels = in_dim  # read at construction time, as before
        self.flat_dim = num_nodes * node_out  # 14 * 10 = 140 by default
        self.conv1 = GCNConv(in_channels, hidden)
        self.conv2 = GCNConv(hidden, node_out)
        self.f1 = torch.nn.Linear(self.flat_dim, 32)
        self.f2 = torch.nn.Linear(32, 1)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, training=self.training)
        x = F.relu(self.conv2(x, edge_index))
        # (batch * num_nodes, node_out) -> (batch, num_nodes * node_out); assumes the batch
        # stores each graph's nodes contiguously and every graph has num_nodes nodes.
        x = x.reshape(-1, self.flat_dim)
        # NOTE(review): no non-linearity between f1 and f2, so the head is effectively linear.
        return self.f2(self.f1(x))

def train():
    """Run one training epoch over `train_loader` and print fit metrics.

    Uses the globals `model`, `optimizer`, `crit`, `device` and `train_loader`.
    Returns (loss, r2, mse, kl): `loss` is the mean of `crit` per graph over the epoch
    (each batch loss weighted by its number of graphs). The metrics compare the predictions
    made during the epoch, before each optimizer step, with the labels.
    """
    model.train()
    loss_sum = 0.0
    n_graphs = 0
    y_actual = []
    y_predicted = []
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        label = data.y.to(device)
        loss = crit(output, label)
        loss.backward()
        loss_sum += data.num_graphs * loss.item()
        n_graphs += data.num_graphs
        optimizer.step()
        y_actual += label.cpu().detach().ravel().tolist()
        y_predicted += output.cpu().detach().ravel().tolist()

    # Normalise by the graphs actually seen rather than len(df_train), so the value is a
    # true per-graph mean even if the loader does not cover every row of df_train.
    loss = loss_sum / n_graphs
    r2 = R2(y_predicted, y_actual)
    mse = MSE(y_predicted, y_actual)
    kl = kl_divergence(y_predicted, y_actual)

    print("R2:%f" % r2, end='  ')
    print("MSE:%f" % mse, end='  ')
    print("KL:%f" % kl, end='  ')
    print("MAE:%f" % (MAE(y_predicted, y_actual)), end='  ')
    print("RMSE:%f" % (RMSE(y_predicted, y_actual)), end='  ')
    print("CVRMSE:%f" % (CVRMSE(y_predicted, y_actual)), end='  ')

    return loss, r2, mse, kl


def val():
    """Evaluate `model` on `test_loader` without gradient tracking and print metrics.

    Uses the globals `model`, `crit`, `device` and `test_loader`.
    Returns (loss, r2, mse, kl), where `loss` is the mean of `crit` per test graph.
    """
    model.eval()
    loss_sum = 0.0
    n_graphs = 0
    y_actual = []
    y_predicted = []
    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            output = model(data)
            label = data.y.to(device)
            # Weight each batch-mean loss by its batch size. Summing raw batch means and
            # dividing by the number of rows under-reports the loss by ~batch_size.
            loss_sum += data.num_graphs * crit(output, label).item()
            n_graphs += data.num_graphs
            y_actual += label.cpu().ravel().tolist()
            y_predicted += output.cpu().ravel().tolist()

    loss = loss_sum / n_graphs
    r2 = R2(y_predicted, y_actual)
    mse = MSE(y_predicted, y_actual)
    kl = kl_divergence(y_predicted, y_actual)

    print("R2:%f" % r2, end='  ')
    print("MSE:%f" % mse, end='  ')
    print("KL:%f" % kl, end='  ')
    print("MAE:%f" % (MAE(y_predicted, y_actual)), end='  ')
    print("RMSE:%f" % (RMSE(y_predicted, y_actual)), end='  ')
    print("CVRMSE:%f" % (CVRMSE(y_predicted, y_actual)), end='  ')

    return loss, r2, mse, kl

In [None]:
num_epochs = 2560
batch_size = 512  # NOTE(review): unused — both DataLoaders were built with batch_size=200
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.002, weight_decay=5e-4)
crit = F.mse_loss

# train() prints its metrics every epoch; every 5th epoch also reports the losses and
# runs a test-set evaluation.
for epoch in range(num_epochs):
    loss, r2, mse, kl = train()
    if epoch % 5 == 0:
        print('train_loss:')
        print(loss)

        loss, r2, mse, kl = val()
        print('test_loss:')
        print(loss)

# Final test-set predictions. train() leaves the model in training mode, so switch to eval
# mode (disables dropout) and skip autograd. test_loader is unshuffled, so the predictions
# line up row for row with df_test.
model.eval()
y_predicted = []
with torch.no_grad():
    for data in test_loader:
        data = data.to(device)
        y_predicted += model(data).cpu().ravel().tolist()
df_test["y_hat"] = y_predicted

R2:-2.908061  MSE:0.015570  KL:3.440364  MAE:0.100714  RMSE:0.124779  CVRMSE:31.194666  train_loss:
0.01556971613317728
R2:-0.494224  MSE:0.004469  KL:5.698337  MAE:0.044478  RMSE:0.066852  CVRMSE:22.284086  test_loss:
2.2346122306771575e-05
R2:-0.374339  MSE:0.005475  KL:3.332566  MAE:0.035122  RMSE:0.073996  CVRMSE:18.498932  R2:-0.292356  MSE:0.005149  KL:3.348468  MAE:0.030706  RMSE:0.071755  CVRMSE:17.938691  R2:-0.122749  MSE:0.004473  KL:3.603329  MAE:0.026905  RMSE:0.066881  CVRMSE:16.720185  R2:-0.120190  MSE:0.004463  KL:4.105466  MAE:0.025523  RMSE:0.066804  CVRMSE:16.701117  R2:-0.062155  MSE:0.004232  KL:3.634430  MAE:0.019868  RMSE:0.065051  CVRMSE:16.262735  train_loss:
0.0042316249164287
R2:0.007847  MSE:0.002968  KL:0.603239  MAE:0.011034  RMSE:0.054475  CVRMSE:18.158345  test_loss:
1.4837646922387649e-05
R2:-0.026566  MSE:0.004090  KL:3.250882  MAE:0.015840  RMSE:0.063952  CVRMSE:15.987959  R2:-0.034696  MSE:0.004122  KL:2.936221  MAE:0.013471  RMSE:0.064205  CVRMSE:1

R2:0.117864  MSE:0.003514  KL:0.122063  MAE:0.008886  RMSE:0.059283  CVRMSE:14.820656  R2:0.069812  MSE:0.003706  KL:0.085802  MAE:0.008235  RMSE:0.060876  CVRMSE:15.218962  R2:0.084583  MSE:0.003647  KL:0.068984  MAE:0.008189  RMSE:0.060391  CVRMSE:15.097646  R2:0.118763  MSE:0.003511  KL:0.063959  MAE:0.007876  RMSE:0.059252  CVRMSE:14.813101  R2:0.165962  MSE:0.003323  KL:0.061132  MAE:0.007667  RMSE:0.057644  CVRMSE:14.410947  train_loss:
0.003322806270443834
R2:0.115118  MSE:0.002647  KL:0.034933  MAE:0.006709  RMSE:0.051446  CVRMSE:17.148634  test_loss:
1.3233405610662885e-05
R2:0.094332  MSE:0.003608  KL:0.088619  MAE:0.008576  RMSE:0.060068  CVRMSE:15.017035  R2:0.060457  MSE:0.003743  KL:0.082488  MAE:0.008996  RMSE:0.061181  CVRMSE:15.295299  R2:0.102251  MSE:0.003577  KL:0.068393  MAE:0.008354  RMSE:0.059805  CVRMSE:14.951236  R2:0.091780  MSE:0.003618  KL:0.056414  MAE:0.007735  RMSE:0.060153  CVRMSE:15.038178  R2:0.105252  MSE:0.003565  KL:0.094235  MAE:0.008858  RMSE:0.05

R2:0.066590  MSE:0.003719  KL:0.058717  MAE:0.008439  RMSE:0.060981  CVRMSE:15.245299  R2:0.082922  MSE:0.003654  KL:0.060904  MAE:0.008092  RMSE:0.060445  CVRMSE:15.111333  R2:0.100674  MSE:0.003583  KL:0.044700  MAE:0.007907  RMSE:0.059857  CVRMSE:14.964366  R2:0.088255  MSE:0.003632  KL:0.061769  MAE:0.007852  RMSE:0.060269  CVRMSE:15.067335  R2:0.102949  MSE:0.003574  KL:0.068961  MAE:0.008220  RMSE:0.059782  CVRMSE:14.945422  train_loss:
0.003573849791428074
R2:0.113245  MSE:0.002652  KL:0.038173  MAE:0.006402  RMSE:0.051500  CVRMSE:17.166781  test_loss:
1.3261425352538936e-05
R2:0.124950  MSE:0.003486  KL:0.065308  MAE:0.008355  RMSE:0.059044  CVRMSE:14.761013  R2:0.078797  MSE:0.003670  KL:0.058840  MAE:0.008186  RMSE:0.060581  CVRMSE:15.145283  R2:0.124342  MSE:0.003489  KL:0.064926  MAE:0.007956  RMSE:0.059065  CVRMSE:14.766142  R2:0.104044  MSE:0.003569  KL:0.060977  MAE:0.008289  RMSE:0.059745  CVRMSE:14.936301  R2:0.103424  MSE:0.003572  KL:0.047500  MAE:0.007715  RMSE:0.05

R2:0.064706  MSE:0.003726  KL:0.041890  MAE:0.008309  RMSE:0.061043  CVRMSE:15.260671  R2:0.107516  MSE:0.003556  KL:0.059740  MAE:0.007718  RMSE:0.059629  CVRMSE:14.907327  R2:0.117406  MSE:0.003516  KL:0.045557  MAE:0.007739  RMSE:0.059298  CVRMSE:14.824507  R2:0.095793  MSE:0.003602  KL:0.041743  MAE:0.007611  RMSE:0.060020  CVRMSE:15.004918  R2:0.084522  MSE:0.003647  KL:0.047863  MAE:0.007862  RMSE:0.060393  CVRMSE:15.098146  train_loss:
0.003647264403116424
R2:0.106845  MSE:0.002671  KL:0.041724  MAE:0.006103  RMSE:0.051686  CVRMSE:17.228619  test_loss:
1.3357140647713095e-05
R2:0.100948  MSE:0.003582  KL:0.067639  MAE:0.007492  RMSE:0.059848  CVRMSE:14.962082  R2:0.105369  MSE:0.003564  KL:0.065588  MAE:0.007955  RMSE:0.059701  CVRMSE:14.925249  R2:0.106319  MSE:0.003560  KL:0.067845  MAE:0.008027  RMSE:0.059669  CVRMSE:14.917326  R2:0.103050  MSE:0.003573  KL:0.085731  MAE:0.008133  RMSE:0.059778  CVRMSE:14.944581  R2:0.078020  MSE:0.003673  KL:0.082387  MAE:0.007982  RMSE:0.06

R2:0.111912  MSE:0.003538  KL:0.059287  MAE:0.007727  RMSE:0.059482  CVRMSE:14.870569  R2:0.129947  MSE:0.003466  KL:0.051091  MAE:0.007334  RMSE:0.058875  CVRMSE:14.718800  R2:0.128180  MSE:0.003473  KL:0.056740  MAE:0.008111  RMSE:0.058935  CVRMSE:14.733739  R2:0.109676  MSE:0.003547  KL:0.076655  MAE:0.008117  RMSE:0.059557  CVRMSE:14.889276  R2:0.106724  MSE:0.003559  KL:0.059415  MAE:0.008565  RMSE:0.059656  CVRMSE:14.913943  train_loss:
0.003558811113180127
R2:0.111988  MSE:0.002656  KL:0.040777  MAE:0.006494  RMSE:0.051537  CVRMSE:17.178937  test_loss:
1.3280213897814974e-05
R2:0.115774  MSE:0.003523  KL:0.060326  MAE:0.008127  RMSE:0.059353  CVRMSE:14.838206  R2:0.129016  MSE:0.003470  KL:0.058406  MAE:0.007467  RMSE:0.058907  CVRMSE:14.726678  R2:0.123156  MSE:0.003493  KL:0.064170  MAE:0.008344  RMSE:0.059105  CVRMSE:14.776137  R2:0.123680  MSE:0.003491  KL:0.070634  MAE:0.008322  RMSE:0.059087  CVRMSE:14.771722  R2:0.128452  MSE:0.003472  KL:0.059232  MAE:0.008458  RMSE:0.05

R2:0.091537  MSE:0.003619  KL:0.059777  MAE:0.008286  RMSE:0.060161  CVRMSE:15.040193  R2:0.098906  MSE:0.003590  KL:0.057110  MAE:0.007975  RMSE:0.059916  CVRMSE:14.979068  R2:0.129664  MSE:0.003467  KL:0.054495  MAE:0.007702  RMSE:0.058885  CVRMSE:14.721195  R2:0.135871  MSE:0.003443  KL:0.057974  MAE:0.007859  RMSE:0.058674  CVRMSE:14.668609  R2:0.103876  MSE:0.003570  KL:0.049235  MAE:0.008305  RMSE:0.059751  CVRMSE:14.937704  train_loss:
0.0035701600805623456
R2:0.110075  MSE:0.002662  KL:0.042375  MAE:0.006234  RMSE:0.051592  CVRMSE:17.197433  test_loss:
1.3308826179127209e-05
R2:0.121194  MSE:0.003501  KL:0.058440  MAE:0.007583  RMSE:0.059171  CVRMSE:14.792653  R2:0.104143  MSE:0.003569  KL:0.061142  MAE:0.008043  RMSE:0.059742  CVRMSE:14.935473  R2:0.111513  MSE:0.003540  KL:0.062028  MAE:0.008423  RMSE:0.059496  CVRMSE:14.873915  R2:0.107262  MSE:0.003557  KL:0.059863  MAE:0.008343  RMSE:0.059638  CVRMSE:14.909450  R2:0.147805  MSE:0.003395  KL:0.064529  MAE:0.007666  RMSE:0.0

R2:0.127500  MSE:0.003476  KL:0.052702  MAE:0.007912  RMSE:0.058958  CVRMSE:14.739492  R2:0.119070  MSE:0.003510  KL:0.058064  MAE:0.008195  RMSE:0.059242  CVRMSE:14.810526  R2:0.111349  MSE:0.003540  KL:0.064630  MAE:0.008024  RMSE:0.059501  CVRMSE:14.875283  R2:0.098202  MSE:0.003593  KL:0.045203  MAE:0.007681  RMSE:0.059940  CVRMSE:14.984913  R2:0.133554  MSE:0.003452  KL:0.057617  MAE:0.007773  RMSE:0.058753  CVRMSE:14.688259  train_loss:
0.0034519194086897187
R2:0.108934  MSE:0.002665  KL:0.040800  MAE:0.006335  RMSE:0.051625  CVRMSE:17.208460  test_loss:
1.332589915546123e-05
R2:0.111700  MSE:0.003539  KL:0.058449  MAE:0.007748  RMSE:0.059489  CVRMSE:14.872348  R2:0.126091  MSE:0.003482  KL:0.052303  MAE:0.008040  RMSE:0.059006  CVRMSE:14.751386  R2:0.107868  MSE:0.003554  KL:0.065734  MAE:0.008384  RMSE:0.059618  CVRMSE:14.904394  R2:0.126633  MSE:0.003479  KL:0.066438  MAE:0.008260  RMSE:0.058987  CVRMSE:14.746808  R2:0.098481  MSE:0.003592  KL:0.056723  MAE:0.008132  RMSE:0.05

R2:0.126353  MSE:0.003481  KL:0.052236  MAE:0.007598  RMSE:0.058997  CVRMSE:14.749170  R2:0.128072  MSE:0.003474  KL:0.058380  MAE:0.007689  RMSE:0.058939  CVRMSE:14.734654  R2:0.132630  MSE:0.003456  KL:0.049160  MAE:0.007490  RMSE:0.058784  CVRMSE:14.696090  R2:0.118656  MSE:0.003511  KL:0.059305  MAE:0.007566  RMSE:0.059256  CVRMSE:14.814006  R2:0.136925  MSE:0.003438  KL:0.060941  MAE:0.007730  RMSE:0.058639  CVRMSE:14.659661  train_loss:
0.0034384906568448058
R2:0.110231  MSE:0.002661  KL:0.048359  MAE:0.006368  RMSE:0.051588  CVRMSE:17.195926  test_loss:
1.3306494423886762e-05
R2:0.133971  MSE:0.003450  KL:0.058130  MAE:0.007850  RMSE:0.058739  CVRMSE:14.684725  R2:0.129437  MSE:0.003468  KL:0.061734  MAE:0.007847  RMSE:0.058892  CVRMSE:14.723117  R2:0.116306  MSE:0.003521  KL:0.049210  MAE:0.007671  RMSE:0.059335  CVRMSE:14.833740  R2:0.134868  MSE:0.003447  KL:0.041821  MAE:0.007374  RMSE:0.058708  CVRMSE:14.677124  R2:0.118564  MSE:0.003512  KL:0.054525  MAE:0.007808  RMSE:0.0

R2:0.099974  MSE:0.003586  KL:0.047803  MAE:0.007814  RMSE:0.059881  CVRMSE:14.970187  R2:0.125201  MSE:0.003485  KL:0.054030  MAE:0.007369  RMSE:0.059036  CVRMSE:14.758895  R2:0.113668  MSE:0.003531  KL:0.051331  MAE:0.008128  RMSE:0.059423  CVRMSE:14.855863  R2:0.107064  MSE:0.003557  KL:0.058499  MAE:0.008056  RMSE:0.059644  CVRMSE:14.911108  R2:0.127877  MSE:0.003475  KL:0.045219  MAE:0.007969  RMSE:0.058945  CVRMSE:14.736306  train_loss:
0.0034745395059871953
R2:0.115219  MSE:0.002646  KL:0.046930  MAE:0.006109  RMSE:0.051443  CVRMSE:17.147658  test_loss:
1.3231898177764378e-05
R2:0.118555  MSE:0.003512  KL:0.047653  MAE:0.007673  RMSE:0.059259  CVRMSE:14.814850  R2:0.110451  MSE:0.003544  KL:0.058141  MAE:0.008007  RMSE:0.059531  CVRMSE:14.882801  R2:0.118073  MSE:0.003514  KL:0.052062  MAE:0.008068  RMSE:0.059276  CVRMSE:14.818898  R2:0.114832  MSE:0.003527  KL:0.061381  MAE:0.007874  RMSE:0.059384  CVRMSE:14.846106  R2:0.131679  MSE:0.003459  KL:0.061975  MAE:0.007843  RMSE:0.0

R2:0.115406  MSE:0.003524  KL:0.055199  MAE:0.007779  RMSE:0.059365  CVRMSE:14.841287  R2:0.094982  MSE:0.003606  KL:0.070616  MAE:0.007722  RMSE:0.060047  CVRMSE:15.011643  R2:0.117848  MSE:0.003514  KL:0.067524  MAE:0.007742  RMSE:0.059283  CVRMSE:14.820794  R2:0.130409  MSE:0.003464  KL:0.057844  MAE:0.008098  RMSE:0.058860  CVRMSE:14.714897  R2:0.106459  MSE:0.003560  KL:0.063170  MAE:0.008253  RMSE:0.059665  CVRMSE:14.916160  train_loss:
0.0035598692498751915
R2:0.113950  MSE:0.002650  KL:0.053659  MAE:0.006499  RMSE:0.051480  CVRMSE:17.159950  test_loss:
1.3250875781523064e-05
R2:0.115138  MSE:0.003525  KL:0.056345  MAE:0.008239  RMSE:0.059374  CVRMSE:14.843536  R2:0.117901  MSE:0.003514  KL:0.053112  MAE:0.007817  RMSE:0.059281  CVRMSE:14.820348  R2:0.110078  MSE:0.003545  KL:0.057945  MAE:0.007729  RMSE:0.059544  CVRMSE:14.885923  R2:0.145407  MSE:0.003405  KL:0.054291  MAE:0.007620  RMSE:0.058350  CVRMSE:14.587446  R2:0.103578  MSE:0.003571  KL:0.050819  MAE:0.007974  RMSE:0.0

R2:0.121065  MSE:0.003502  KL:0.055827  MAE:0.007777  RMSE:0.059175  CVRMSE:14.793741  R2:0.119740  MSE:0.003507  KL:0.063266  MAE:0.008111  RMSE:0.059220  CVRMSE:14.804891  R2:0.105103  MSE:0.003565  KL:0.069581  MAE:0.008772  RMSE:0.059710  CVRMSE:14.927471  R2:0.146587  MSE:0.003400  KL:0.059486  MAE:0.007991  RMSE:0.058309  CVRMSE:14.577374  R2:0.141483  MSE:0.003420  KL:0.052191  MAE:0.007429  RMSE:0.058484  CVRMSE:14.620898  train_loss:
0.0034203300412627867
R2:0.119247  MSE:0.002634  KL:0.054868  MAE:0.007090  RMSE:0.051326  CVRMSE:17.108585  test_loss:
1.3171665981644764e-05
R2:0.139125  MSE:0.003430  KL:0.049622  MAE:0.007973  RMSE:0.058564  CVRMSE:14.640965  R2:0.111388  MSE:0.003540  KL:0.069201  MAE:0.008321  RMSE:0.059500  CVRMSE:14.874955  R2:0.122727  MSE:0.003495  KL:0.061395  MAE:0.008741  RMSE:0.059119  CVRMSE:14.779752  R2:0.119414  MSE:0.003508  KL:0.053580  MAE:0.008013  RMSE:0.059231  CVRMSE:14.807628  R2:0.125743  MSE:0.003483  KL:0.049648  MAE:0.007832  RMSE:0.0

R2:0.115357  MSE:0.003524  KL:0.061731  MAE:0.007615  RMSE:0.059367  CVRMSE:14.841699  R2:0.132789  MSE:0.003455  KL:0.070429  MAE:0.007772  RMSE:0.058779  CVRMSE:14.694745  R2:0.137221  MSE:0.003437  KL:0.060278  MAE:0.008203  RMSE:0.058629  CVRMSE:14.657151  R2:0.129318  MSE:0.003469  KL:0.068487  MAE:0.008371  RMSE:0.058896  CVRMSE:14.724120  R2:0.111508  MSE:0.003540  KL:0.052573  MAE:0.008317  RMSE:0.059496  CVRMSE:14.873955  train_loss:
0.003539752356300596
R2:0.113259  MSE:0.002652  KL:0.053986  MAE:0.006216  RMSE:0.051500  CVRMSE:17.166639  test_loss:
1.3261205778690055e-05
R2:0.126394  MSE:0.003480  KL:0.066055  MAE:0.007645  RMSE:0.058995  CVRMSE:14.748823  R2:0.105787  MSE:0.003563  KL:0.052999  MAE:0.007765  RMSE:0.059687  CVRMSE:14.921768  R2:0.142542  MSE:0.003416  KL:0.054670  MAE:0.007780  RMSE:0.058448  CVRMSE:14.611881  R2:0.133510  MSE:0.003452  KL:0.051423  MAE:0.007934  RMSE:0.058755  CVRMSE:14.688633  R2:0.129037  MSE:0.003470  KL:0.060402  MAE:0.008197  RMSE:0.05

R2:0.134133  MSE:0.003450  KL:0.035264  MAE:0.008334  RMSE:0.058733  CVRMSE:14.683351  R2:0.134923  MSE:0.003446  KL:0.062293  MAE:0.008023  RMSE:0.058707  CVRMSE:14.676655  R2:0.125643  MSE:0.003483  KL:0.042859  MAE:0.007927  RMSE:0.059021  CVRMSE:14.755167  R2:0.144488  MSE:0.003408  KL:0.050977  MAE:0.007709  RMSE:0.058381  CVRMSE:14.595288  R2:0.135453  MSE:0.003444  KL:0.044697  MAE:0.007864  RMSE:0.058689  CVRMSE:14.672160  train_loss:
0.0034443566284608094
R2:0.121237  MSE:0.002628  KL:0.047529  MAE:0.006262  RMSE:0.051268  CVRMSE:17.089239  test_loss:
1.3141892733983695e-05
R2:0.144535  MSE:0.003408  KL:0.045864  MAE:0.007758  RMSE:0.058380  CVRMSE:14.594889  R2:0.112397  MSE:0.003536  KL:0.049042  MAE:0.007908  RMSE:0.059466  CVRMSE:14.866514  R2:0.110667  MSE:0.003543  KL:0.040738  MAE:0.007799  RMSE:0.059524  CVRMSE:14.880990  R2:0.108495  MSE:0.003552  KL:0.040554  MAE:0.007366  RMSE:0.059597  CVRMSE:14.899151  R2:0.112091  MSE:0.003537  KL:0.048413  MAE:0.007893  RMSE:0.0

R2:0.127837  MSE:0.003475  KL:0.041554  MAE:0.007633  RMSE:0.058947  CVRMSE:14.736638  R2:0.135186  MSE:0.003445  KL:0.052724  MAE:0.007629  RMSE:0.058698  CVRMSE:14.674422  R2:0.116589  MSE:0.003520  KL:0.059172  MAE:0.008120  RMSE:0.059325  CVRMSE:14.831363  R2:0.102521  MSE:0.003576  KL:0.047972  MAE:0.008322  RMSE:0.059796  CVRMSE:14.948993  R2:0.140208  MSE:0.003425  KL:0.055017  MAE:0.007951  RMSE:0.058527  CVRMSE:14.631752  train_loss:
0.0034254106663865968
R2:0.120420  MSE:0.002631  KL:0.044439  MAE:0.006012  RMSE:0.051292  CVRMSE:17.097190  test_loss:
1.315412604890298e-05
R2:0.134719  MSE:0.003447  KL:0.050456  MAE:0.007544  RMSE:0.058714  CVRMSE:14.678383  R2:0.141822  MSE:0.003419  KL:0.038357  MAE:0.007540  RMSE:0.058472  CVRMSE:14.618015  R2:0.147638  MSE:0.003396  KL:0.049361  MAE:0.007631  RMSE:0.058274  CVRMSE:14.568396  R2:0.107456  MSE:0.003556  KL:0.055558  MAE:0.008181  RMSE:0.059631  CVRMSE:14.907836  R2:0.102437  MSE:0.003576  KL:0.059360  MAE:0.008338  RMSE:0.05

R2:0.138181  MSE:0.003433  KL:0.050037  MAE:0.007996  RMSE:0.058596  CVRMSE:14.648992  R2:0.139672  MSE:0.003428  KL:0.048994  MAE:0.007736  RMSE:0.058545  CVRMSE:14.636317  R2:0.147196  MSE:0.003398  KL:0.058974  MAE:0.008153  RMSE:0.058289  CVRMSE:14.572175  R2:0.145969  MSE:0.003402  KL:0.054567  MAE:0.008335  RMSE:0.058331  CVRMSE:14.582655  R2:0.093404  MSE:0.003612  KL:0.040899  MAE:0.008068  RMSE:0.060099  CVRMSE:15.024724  train_loss:
0.003611877288494725
R2:0.122341  MSE:0.002625  KL:0.046602  MAE:0.006417  RMSE:0.051236  CVRMSE:17.078506  test_loss:
1.3125391415087507e-05
R2:0.134311  MSE:0.003449  KL:0.038294  MAE:0.007871  RMSE:0.058727  CVRMSE:14.681846  R2:0.144721  MSE:0.003407  KL:0.041911  MAE:0.007347  RMSE:0.058373  CVRMSE:14.593303  R2:0.138457  MSE:0.003432  KL:0.059626  MAE:0.007499  RMSE:0.058587  CVRMSE:14.646645  R2:0.142111  MSE:0.003418  KL:0.059188  MAE:0.007508  RMSE:0.058462  CVRMSE:14.615550  R2:0.147324  MSE:0.003397  KL:0.061075  MAE:0.007832  RMSE:0.05

R2:0.115149  MSE:0.003525  KL:0.058611  MAE:0.007787  RMSE:0.059374  CVRMSE:14.843449  R2:0.115926  MSE:0.003522  KL:0.046678  MAE:0.007597  RMSE:0.059348  CVRMSE:14.836924  R2:0.132195  MSE:0.003457  KL:0.055472  MAE:0.007600  RMSE:0.058799  CVRMSE:14.699779  R2:0.130432  MSE:0.003464  KL:0.059295  MAE:0.007777  RMSE:0.058859  CVRMSE:14.714700  R2:0.132491  MSE:0.003456  KL:0.050828  MAE:0.008167  RMSE:0.058789  CVRMSE:14.697271  train_loss:
0.0034561565102194437
R2:0.119431  MSE:0.002634  KL:0.060294  MAE:0.006498  RMSE:0.051320  CVRMSE:17.106795  test_loss:
1.316890976158902e-05
R2:0.125934  MSE:0.003482  KL:0.057581  MAE:0.008120  RMSE:0.059011  CVRMSE:14.752707  R2:0.113582  MSE:0.003531  KL:0.053130  MAE:0.007827  RMSE:0.059426  CVRMSE:14.856585  R2:0.126424  MSE:0.003480  KL:0.045550  MAE:0.007491  RMSE:0.058994  CVRMSE:14.748577  R2:0.110315  MSE:0.003545  KL:0.047165  MAE:0.008003  RMSE:0.059536  CVRMSE:14.883940  R2:0.119112  MSE:0.003509  KL:0.048589  MAE:0.007735  RMSE:0.05

R2:0.142596  MSE:0.003416  KL:0.045199  MAE:0.007586  RMSE:0.058446  CVRMSE:14.611422  R2:0.143309  MSE:0.003413  KL:0.046297  MAE:0.008127  RMSE:0.058421  CVRMSE:14.605345  R2:0.124198  MSE:0.003489  KL:0.043981  MAE:0.008470  RMSE:0.059069  CVRMSE:14.767355  R2:0.108153  MSE:0.003553  KL:0.052645  MAE:0.008157  RMSE:0.059608  CVRMSE:14.902013  R2:0.144784  MSE:0.003407  KL:0.043895  MAE:0.007396  RMSE:0.058371  CVRMSE:14.592765  train_loss:
0.0034071807080181316
R2:0.127005  MSE:0.002611  KL:0.037724  MAE:0.006852  RMSE:0.051099  CVRMSE:17.033066  test_loss:
1.3055640301899985e-05
R2:0.109227  MSE:0.003549  KL:0.041480  MAE:0.008037  RMSE:0.059572  CVRMSE:14.893036  R2:0.121899  MSE:0.003498  KL:0.049334  MAE:0.007649  RMSE:0.059147  CVRMSE:14.786726  R2:0.121796  MSE:0.003499  KL:0.066297  MAE:0.007831  RMSE:0.059150  CVRMSE:14.787588  R2:0.149849  MSE:0.003387  KL:0.044154  MAE:0.007996  RMSE:0.058198  CVRMSE:14.549492  R2:0.109283  MSE:0.003549  KL:0.052085  MAE:0.007716  RMSE:0.0

R2:0.142558  MSE:0.003416  KL:0.045211  MAE:0.008131  RMSE:0.058447  CVRMSE:14.611744  R2:0.139981  MSE:0.003426  KL:0.049750  MAE:0.008129  RMSE:0.058535  CVRMSE:14.633683  R2:0.117098  MSE:0.003517  KL:0.039530  MAE:0.007381  RMSE:0.059308  CVRMSE:14.827090  R2:0.139256  MSE:0.003429  KL:0.055574  MAE:0.007993  RMSE:0.058559  CVRMSE:14.639849  R2:0.120323  MSE:0.003505  KL:0.052917  MAE:0.008227  RMSE:0.059200  CVRMSE:14.799983  train_loss:
0.0035046313496422954
R2:0.125132  MSE:0.002617  KL:0.047216  MAE:0.006362  RMSE:0.051154  CVRMSE:17.051328  test_loss:
1.3083649158943444e-05
R2:0.145535  MSE:0.003404  KL:0.036522  MAE:0.008086  RMSE:0.058345  CVRMSE:14.586360  R2:0.130368  MSE:0.003465  KL:0.044605  MAE:0.007749  RMSE:0.058861  CVRMSE:14.715244  R2:0.131247  MSE:0.003461  KL:0.048704  MAE:0.007706  RMSE:0.058831  CVRMSE:14.707803  R2:0.157015  MSE:0.003358  KL:0.041950  MAE:0.007658  RMSE:0.057952  CVRMSE:14.488035  R2:0.120229  MSE:0.003505  KL:0.043877  MAE:0.008284  RMSE:0.0

R2:0.120224  MSE:0.003505  KL:0.057558  MAE:0.007975  RMSE:0.059203  CVRMSE:14.800818  R2:0.132316  MSE:0.003457  KL:0.044911  MAE:0.007929  RMSE:0.058795  CVRMSE:14.698749  R2:0.142001  MSE:0.003418  KL:0.054257  MAE:0.007818  RMSE:0.058466  CVRMSE:14.616489  R2:0.137215  MSE:0.003437  KL:0.052692  MAE:0.008115  RMSE:0.058629  CVRMSE:14.657195  R2:0.138624  MSE:0.003432  KL:0.048021  MAE:0.007769  RMSE:0.058581  CVRMSE:14.645225  train_loss:
0.003431721934612142
R2:0.116262  MSE:0.002643  KL:0.045997  MAE:0.005796  RMSE:0.051413  CVRMSE:17.137546  test_loss:
1.3216296632890589e-05
R2:0.129358  MSE:0.003469  KL:0.047910  MAE:0.007260  RMSE:0.058895  CVRMSE:14.723782  R2:0.190663  MSE:0.003224  KL:0.048530  MAE:0.007855  RMSE:0.056784  CVRMSE:14.195947  R2:0.134381  MSE:0.003449  KL:0.054578  MAE:0.008475  RMSE:0.058725  CVRMSE:14.681250  R2:0.147167  MSE:0.003398  KL:0.050467  MAE:0.008749  RMSE:0.058290  CVRMSE:14.572423  R2:0.147738  MSE:0.003395  KL:0.046476  MAE:0.008102  RMSE:0.05

R2:0.147706  MSE:0.003396  KL:0.042044  MAE:0.007780  RMSE:0.058271  CVRMSE:14.567815  R2:0.155036  MSE:0.003366  KL:0.034970  MAE:0.007734  RMSE:0.058020  CVRMSE:14.505031  R2:0.134815  MSE:0.003447  KL:0.039252  MAE:0.007605  RMSE:0.058710  CVRMSE:14.677567  R2:0.133988  MSE:0.003450  KL:0.042040  MAE:0.007572  RMSE:0.058738  CVRMSE:14.684580  R2:0.141696  MSE:0.003419  KL:0.037007  MAE:0.007377  RMSE:0.058476  CVRMSE:14.619087  train_loss:
0.0034194833002402446
R2:0.125618  MSE:0.002615  KL:0.039970  MAE:0.006142  RMSE:0.051140  CVRMSE:17.046590  test_loss:
1.307638073922135e-05
R2:0.155729  MSE:0.003364  KL:0.044165  MAE:0.007751  RMSE:0.057996  CVRMSE:14.499084  R2:0.142712  MSE:0.003415  KL:0.054336  MAE:0.007782  RMSE:0.058442  CVRMSE:14.610432  R2:0.164609  MSE:0.003328  KL:0.062905  MAE:0.007674  RMSE:0.057691  CVRMSE:14.422634  R2:0.136506  MSE:0.003440  KL:0.051472  MAE:0.007993  RMSE:0.058653  CVRMSE:14.663216  R2:0.161007  MSE:0.003343  KL:0.047561  MAE:0.007621  RMSE:0.05

R2:0.119819  MSE:0.003507  KL:0.034380  MAE:0.007969  RMSE:0.059217  CVRMSE:14.804228  R2:0.143771  MSE:0.003411  KL:0.033917  MAE:0.007307  RMSE:0.058406  CVRMSE:14.601405  R2:0.128018  MSE:0.003474  KL:0.052293  MAE:0.007750  RMSE:0.058940  CVRMSE:14.735115  R2:0.150205  MSE:0.003386  KL:0.048091  MAE:0.008100  RMSE:0.058186  CVRMSE:14.546442  R2:0.150152  MSE:0.003386  KL:0.043150  MAE:0.008307  RMSE:0.058188  CVRMSE:14.546895  train_loss:
0.0033857949601951987
R2:0.123246  MSE:0.002622  KL:0.049946  MAE:0.006661  RMSE:0.051209  CVRMSE:17.069698  test_loss:
1.3111856154864654e-05
R2:0.136420  MSE:0.003441  KL:0.054888  MAE:0.008292  RMSE:0.058656  CVRMSE:14.663947  R2:0.128501  MSE:0.003472  KL:0.049466  MAE:0.007703  RMSE:0.058924  CVRMSE:14.731029  R2:0.133993  MSE:0.003450  KL:0.051225  MAE:0.007749  RMSE:0.058738  CVRMSE:14.684537  R2:0.133936  MSE:0.003450  KL:0.034456  MAE:0.007499  RMSE:0.058740  CVRMSE:14.685028  R2:0.124000  MSE:0.003490  KL:0.047907  MAE:0.007343  RMSE:0.0

R2:0.159801  MSE:0.003347  KL:0.036753  MAE:0.007422  RMSE:0.057856  CVRMSE:14.464077  R2:0.151192  MSE:0.003382  KL:0.043880  MAE:0.007602  RMSE:0.058152  CVRMSE:14.537994  R2:0.139150  MSE:0.003430  KL:0.041948  MAE:0.008084  RMSE:0.058563  CVRMSE:14.640757  R2:0.142555  MSE:0.003416  KL:0.043128  MAE:0.008256  RMSE:0.058447  CVRMSE:14.611770  R2:0.151014  MSE:0.003382  KL:0.041210  MAE:0.007688  RMSE:0.058158  CVRMSE:14.539513  train_loss:
0.0033823593956185507
R2:0.127215  MSE:0.002611  KL:0.040794  MAE:0.006176  RMSE:0.051093  CVRMSE:17.031020  test_loss:
1.3052504655206577e-05
R2:0.165695  MSE:0.003324  KL:0.044796  MAE:0.007764  RMSE:0.057653  CVRMSE:14.413255  R2:0.142439  MSE:0.003417  KL:0.050094  MAE:0.008107  RMSE:0.058451  CVRMSE:14.612756  R2:0.161697  MSE:0.003340  KL:0.038700  MAE:0.007941  RMSE:0.057791  CVRMSE:14.447750  R2:0.122810  MSE:0.003495  KL:0.050908  MAE:0.007459  RMSE:0.059116  CVRMSE:14.779050  R2:0.125455  MSE:0.003484  KL:0.040362  MAE:0.007619  RMSE:0.0

R2:0.142812  MSE:0.003415  KL:0.050499  MAE:0.007617  RMSE:0.058438  CVRMSE:14.609583  R2:0.136906  MSE:0.003439  KL:0.046489  MAE:0.007814  RMSE:0.058639  CVRMSE:14.659825  R2:0.147290  MSE:0.003397  KL:0.045963  MAE:0.007840  RMSE:0.058285  CVRMSE:14.571365  R2:0.167976  MSE:0.003315  KL:0.052181  MAE:0.007552  RMSE:0.057574  CVRMSE:14.393537  R2:0.125991  MSE:0.003482  KL:0.041247  MAE:0.007562  RMSE:0.059009  CVRMSE:14.752226  train_loss:
0.003482050531601999
R2:0.133050  MSE:0.002593  KL:0.045818  MAE:0.006044  RMSE:0.050922  CVRMSE:16.973996  test_loss:
1.2965243498911149e-05
R2:0.141991  MSE:0.003418  KL:0.040742  MAE:0.007740  RMSE:0.058466  CVRMSE:14.616575  R2:0.117926  MSE:0.003514  KL:0.049616  MAE:0.007612  RMSE:0.059281  CVRMSE:14.820133  R2:0.142846  MSE:0.003415  KL:0.047662  MAE:0.007549  RMSE:0.058437  CVRMSE:14.609293  R2:0.135270  MSE:0.003445  KL:0.049791  MAE:0.007814  RMSE:0.058695  CVRMSE:14.673706  R2:0.145685  MSE:0.003404  KL:0.049790  MAE:0.008043  RMSE:0.05

R2:0.146985  MSE:0.003398  KL:0.048607  MAE:0.007520  RMSE:0.058296  CVRMSE:14.573975  R2:0.109651  MSE:0.003547  KL:0.044397  MAE:0.007883  RMSE:0.059558  CVRMSE:14.889488  R2:0.168233  MSE:0.003314  KL:0.038683  MAE:0.007556  RMSE:0.057565  CVRMSE:14.391314  R2:0.135788  MSE:0.003443  KL:0.045608  MAE:0.007205  RMSE:0.058677  CVRMSE:14.669313  R2:0.152416  MSE:0.003377  KL:0.044492  MAE:0.007668  RMSE:0.058110  CVRMSE:14.527507  train_loss:
0.0033767756453016774
R2:0.126819  MSE:0.002612  KL:0.051163  MAE:0.006415  RMSE:0.051105  CVRMSE:17.034883  test_loss:
1.3058426309726201e-05
R2:0.131322  MSE:0.003461  KL:0.043789  MAE:0.007863  RMSE:0.058829  CVRMSE:14.707165  R2:0.155142  MSE:0.003366  KL:0.044935  MAE:0.008027  RMSE:0.058017  CVRMSE:14.504127  R2:0.153748  MSE:0.003371  KL:0.042084  MAE:0.007421  RMSE:0.058064  CVRMSE:14.516086  R2:0.142560  MSE:0.003416  KL:0.038909  MAE:0.007593  RMSE:0.058447  CVRMSE:14.611729  R2:0.124473  MSE:0.003488  KL:0.046850  MAE:0.007783  RMSE:0.0

R2:0.138705  MSE:0.003431  KL:0.044433  MAE:0.007691  RMSE:0.058578  CVRMSE:14.644540  R2:0.127856  MSE:0.003475  KL:0.050767  MAE:0.007372  RMSE:0.058946  CVRMSE:14.736481  R2:0.157713  MSE:0.003356  KL:0.053395  MAE:0.007319  RMSE:0.057928  CVRMSE:14.482040  R2:0.162874  MSE:0.003335  KL:0.044967  MAE:0.007359  RMSE:0.057750  CVRMSE:14.437603  R2:0.140877  MSE:0.003423  KL:0.058282  MAE:0.007965  RMSE:0.058504  CVRMSE:14.626061  train_loss:
0.0034227468218887224
R2:0.128953  MSE:0.002605  KL:0.048462  MAE:0.006128  RMSE:0.051042  CVRMSE:17.014049  test_loss:
1.3026504399022087e-05
R2:0.125609  MSE:0.003484  KL:0.038985  MAE:0.007901  RMSE:0.059022  CVRMSE:14.755451  R2:0.177646  MSE:0.003276  KL:0.045005  MAE:0.007360  RMSE:0.057239  CVRMSE:14.309650  R2:0.153698  MSE:0.003372  KL:0.050032  MAE:0.007760  RMSE:0.058066  CVRMSE:14.516515  R2:0.148003  MSE:0.003394  KL:0.041478  MAE:0.008187  RMSE:0.058261  CVRMSE:14.565279  R2:0.147257  MSE:0.003397  KL:0.045587  MAE:0.008152  RMSE:0.0

R2:0.170207  MSE:0.003306  KL:0.047753  MAE:0.007642  RMSE:0.057497  CVRMSE:14.374233  R2:0.120757  MSE:0.003503  KL:0.037131  MAE:0.008250  RMSE:0.059185  CVRMSE:14.796338  R2:0.157996  MSE:0.003355  KL:0.041384  MAE:0.007597  RMSE:0.057918  CVRMSE:14.479610  R2:0.158592  MSE:0.003352  KL:0.031110  MAE:0.007034  RMSE:0.057898  CVRMSE:14.474479  R2:0.118091  MSE:0.003514  KL:0.031975  MAE:0.007782  RMSE:0.059275  CVRMSE:14.818746  train_loss:
0.0035135235113557426
R2:0.136429  MSE:0.002583  KL:0.039748  MAE:0.006197  RMSE:0.050823  CVRMSE:16.940876  test_loss:
1.2914696199004539e-05
R2:0.170486  MSE:0.003305  KL:0.042813  MAE:0.007509  RMSE:0.057487  CVRMSE:14.371811  R2:0.128985  MSE:0.003470  KL:0.029804  MAE:0.007579  RMSE:0.058908  CVRMSE:14.726938  R2:0.142961  MSE:0.003414  KL:0.051116  MAE:0.008043  RMSE:0.058433  CVRMSE:14.608306  R2:0.161812  MSE:0.003339  KL:0.045338  MAE:0.007802  RMSE:0.057787  CVRMSE:14.446759  R2:0.140981  MSE:0.003422  KL:0.051981  MAE:0.007508  RMSE:0.0

R2:0.142337  MSE:0.003417  KL:0.030521  MAE:0.007661  RMSE:0.058454  CVRMSE:14.613623  R2:0.170386  MSE:0.003305  KL:0.046601  MAE:0.007682  RMSE:0.057491  CVRMSE:14.372676  R2:0.202007  MSE:0.003179  KL:0.038242  MAE:0.007617  RMSE:0.056384  CVRMSE:14.096112  R2:0.137862  MSE:0.003435  KL:0.035926  MAE:0.007654  RMSE:0.058607  CVRMSE:14.651701  R2:0.148482  MSE:0.003392  KL:0.032073  MAE:0.007427  RMSE:0.058245  CVRMSE:14.561182  train_loss:
0.003392448023078032
R2:0.135926  MSE:0.002584  KL:0.046182  MAE:0.006139  RMSE:0.050837  CVRMSE:16.945813  test_loss:
1.2922224457724952e-05
R2:0.110328  MSE:0.003544  KL:0.039722  MAE:0.007685  RMSE:0.059535  CVRMSE:14.883831  R2:0.171899  MSE:0.003299  KL:0.033095  MAE:0.007586  RMSE:0.057438  CVRMSE:14.359568  R2:0.130651  MSE:0.003463  KL:0.042476  MAE:0.007784  RMSE:0.058851  CVRMSE:14.712846  R2:0.118161  MSE:0.003513  KL:0.045150  MAE:0.007635  RMSE:0.059273  CVRMSE:14.818159  R2:0.143773  MSE:0.003411  KL:0.039028  MAE:0.007154  RMSE:0.05

R2:0.179372  MSE:0.003269  KL:0.039051  MAE:0.007800  RMSE:0.057178  CVRMSE:14.294623  R2:0.136853  MSE:0.003439  KL:0.043327  MAE:0.008536  RMSE:0.058641  CVRMSE:14.660271  R2:0.116093  MSE:0.003521  KL:0.028634  MAE:0.007751  RMSE:0.059342  CVRMSE:14.835528  R2:0.156970  MSE:0.003359  KL:0.036599  MAE:0.007403  RMSE:0.057954  CVRMSE:14.488425  R2:0.206996  MSE:0.003159  KL:0.030131  MAE:0.007655  RMSE:0.056208  CVRMSE:14.051974  train_loss:
0.0031593279163644183
R2:0.134536  MSE:0.002589  KL:0.044585  MAE:0.006426  RMSE:0.050878  CVRMSE:16.959437  test_loss:
1.2943011533934623e-05
R2:0.154821  MSE:0.003367  KL:0.034027  MAE:0.008603  RMSE:0.058028  CVRMSE:14.506877  R2:0.148255  MSE:0.003393  KL:0.027152  MAE:0.009059  RMSE:0.058252  CVRMSE:14.563121  R2:0.167488  MSE:0.003317  KL:0.041394  MAE:0.007655  RMSE:0.057591  CVRMSE:14.397758  R2:0.158587  MSE:0.003352  KL:0.038432  MAE:0.007700  RMSE:0.057898  CVRMSE:14.474525  R2:0.166059  MSE:0.003322  KL:0.031851  MAE:0.007571  RMSE:0.0

R2:0.168293  MSE:0.003314  KL:0.034224  MAE:0.007623  RMSE:0.057563  CVRMSE:14.390801  R2:0.174527  MSE:0.003289  KL:0.038189  MAE:0.007750  RMSE:0.057347  CVRMSE:14.336761  R2:0.127201  MSE:0.003477  KL:0.042916  MAE:0.008542  RMSE:0.058968  CVRMSE:14.742015  R2:0.127344  MSE:0.003477  KL:0.038737  MAE:0.008374  RMSE:0.058963  CVRMSE:14.740809  R2:0.158483  MSE:0.003353  KL:0.034772  MAE:0.007264  RMSE:0.057902  CVRMSE:14.475417  train_loss:
0.003352603541861754
R2:0.135602  MSE:0.002585  KL:0.051648  MAE:0.006519  RMSE:0.050847  CVRMSE:16.948986  test_loss:
1.292706557433121e-05
R2:0.141418  MSE:0.003421  KL:0.033272  MAE:0.007638  RMSE:0.058486  CVRMSE:14.621450  R2:0.168357  MSE:0.003313  KL:0.045098  MAE:0.007186  RMSE:0.057561  CVRMSE:14.390242  R2:0.140773  MSE:0.003423  KL:0.034693  MAE:0.007661  RMSE:0.058508  CVRMSE:14.626943  R2:0.153581  MSE:0.003372  KL:0.039983  MAE:0.007475  RMSE:0.058070  CVRMSE:14.517519  R2:0.182115  MSE:0.003258  KL:0.036827  MAE:0.007218  RMSE:0.057

R2:0.139650  MSE:0.003428  KL:0.038593  MAE:0.007902  RMSE:0.058546  CVRMSE:14.636499  R2:0.140605  MSE:0.003424  KL:0.039759  MAE:0.008062  RMSE:0.058513  CVRMSE:14.628372  R2:0.174440  MSE:0.003289  KL:0.039842  MAE:0.007618  RMSE:0.057350  CVRMSE:14.337518  R2:0.188980  MSE:0.003231  KL:0.034064  MAE:0.007383  RMSE:0.056843  CVRMSE:14.210703  R2:0.152702  MSE:0.003376  KL:0.032525  MAE:0.007884  RMSE:0.058100  CVRMSE:14.525052  train_loss:
0.003375634195981547
R2:0.132101  MSE:0.002596  KL:0.053993  MAE:0.006317  RMSE:0.050950  CVRMSE:16.983277  test_loss:
1.2979427468962968e-05
R2:0.146851  MSE:0.003399  KL:0.049273  MAE:0.008219  RMSE:0.058300  CVRMSE:14.575123  R2:0.164176  MSE:0.003330  KL:0.039833  MAE:0.008045  RMSE:0.057705  CVRMSE:14.426369  R2:0.122633  MSE:0.003495  KL:0.040135  MAE:0.007330  RMSE:0.059122  CVRMSE:14.780542  R2:0.158305  MSE:0.003353  KL:0.037073  MAE:0.007511  RMSE:0.057908  CVRMSE:14.476950  R2:0.152373  MSE:0.003377  KL:0.041963  MAE:0.007646  RMSE:0.05

R2:0.133649  MSE:0.003452  KL:0.028721  MAE:0.007981  RMSE:0.058750  CVRMSE:14.687457  R2:0.186448  MSE:0.003241  KL:0.030451  MAE:0.007676  RMSE:0.056931  CVRMSE:14.232869  R2:0.177169  MSE:0.003278  KL:0.032962  MAE:0.007350  RMSE:0.057255  CVRMSE:14.313804  R2:0.166624  MSE:0.003320  KL:0.032901  MAE:0.007471  RMSE:0.057621  CVRMSE:14.405226  R2:0.165294  MSE:0.003325  KL:0.035336  MAE:0.007413  RMSE:0.057667  CVRMSE:14.416718  train_loss:
0.003325468364346307
R2:0.141308  MSE:0.002568  KL:0.026789  MAE:0.005474  RMSE:0.050679  CVRMSE:16.892959  test_loss:
1.2841743271565065e-05
R2:0.170621  MSE:0.003304  KL:0.032118  MAE:0.007338  RMSE:0.057483  CVRMSE:14.370640  R2:0.119625  MSE:0.003507  KL:0.030370  MAE:0.007885  RMSE:0.059223  CVRMSE:14.805853  R2:0.181710  MSE:0.003260  KL:0.043413  MAE:0.007981  RMSE:0.057097  CVRMSE:14.274253  R2:0.122487  MSE:0.003496  KL:0.037794  MAE:0.008031  RMSE:0.059127  CVRMSE:14.781772  R2:0.161079  MSE:0.003342  KL:0.036711  MAE:0.007145  RMSE:0.05

R2:0.171316  MSE:0.003301  KL:0.044312  MAE:0.007707  RMSE:0.057458  CVRMSE:14.364624  R2:0.178112  MSE:0.003274  KL:0.043002  MAE:0.007727  RMSE:0.057222  CVRMSE:14.305598  R2:0.166158  MSE:0.003322  KL:0.037402  MAE:0.007660  RMSE:0.057637  CVRMSE:14.409256  R2:0.184988  MSE:0.003247  KL:0.044055  MAE:0.007522  RMSE:0.056983  CVRMSE:14.245633  R2:0.171780  MSE:0.003300  KL:0.039879  MAE:0.007433  RMSE:0.057442  CVRMSE:14.360594  train_loss:
0.003299626688385615
R2:0.136870  MSE:0.002582  KL:0.043826  MAE:0.006191  RMSE:0.050810  CVRMSE:16.936559  test_loss:
1.2908115662867204e-05
R2:0.179990  MSE:0.003267  KL:0.030093  MAE:0.008085  RMSE:0.057157  CVRMSE:14.289242  R2:0.197523  MSE:0.003197  KL:0.042688  MAE:0.007946  RMSE:0.056543  CVRMSE:14.135655  R2:0.102107  MSE:0.003577  KL:0.042148  MAE:0.008289  RMSE:0.059810  CVRMSE:14.952439  R2:0.121618  MSE:0.003499  KL:0.026803  MAE:0.007633  RMSE:0.059156  CVRMSE:14.789084  R2:0.171434  MSE:0.003301  KL:0.033571  MAE:0.007125  RMSE:0.05

R2:0.123148  MSE:0.003493  KL:0.033669  MAE:0.008776  RMSE:0.059105  CVRMSE:14.776200  R2:0.159783  MSE:0.003347  KL:0.036737  MAE:0.008102  RMSE:0.057857  CVRMSE:14.464235  R2:0.162269  MSE:0.003338  KL:0.024605  MAE:0.007042  RMSE:0.057771  CVRMSE:14.442817  R2:0.159920  MSE:0.003347  KL:0.030078  MAE:0.007445  RMSE:0.057852  CVRMSE:14.463054  R2:0.135095  MSE:0.003446  KL:0.044134  MAE:0.007685  RMSE:0.058701  CVRMSE:14.675195  train_loss:
0.0034457819696399384
R2:0.132516  MSE:0.002595  KL:0.040371  MAE:0.005828  RMSE:0.050938  CVRMSE:16.979215  test_loss:
1.2973218035767786e-05
R2:0.159410  MSE:0.003349  KL:0.041924  MAE:0.007681  RMSE:0.057870  CVRMSE:14.467442  R2:0.150779  MSE:0.003383  KL:0.035140  MAE:0.007977  RMSE:0.058166  CVRMSE:14.541529  R2:0.143912  MSE:0.003411  KL:0.035135  MAE:0.007432  RMSE:0.058401  CVRMSE:14.600200  R2:0.160759  MSE:0.003344  KL:0.036747  MAE:0.007431  RMSE:0.057823  CVRMSE:14.455829  R2:0.167259  MSE:0.003318  KL:0.029912  MAE:0.007426  RMSE:0.0

R2:0.198980  MSE:0.003191  KL:0.035662  MAE:0.007607  RMSE:0.056491  CVRMSE:14.122820  R2:0.177233  MSE:0.003278  KL:0.036862  MAE:0.007782  RMSE:0.057253  CVRMSE:14.313244  R2:0.099864  MSE:0.003586  KL:0.026117  MAE:0.007905  RMSE:0.059884  CVRMSE:14.971100  R2:0.161458  MSE:0.003341  KL:0.034653  MAE:0.007295  RMSE:0.057799  CVRMSE:14.449812  R2:0.184859  MSE:0.003248  KL:0.035673  MAE:0.007403  RMSE:0.056987  CVRMSE:14.246755  train_loss:
0.0032475206855451686
R2:0.140427  MSE:0.002571  KL:0.039351  MAE:0.006387  RMSE:0.050705  CVRMSE:16.901625  test_loss:
1.2854923144914209e-05
R2:0.165121  MSE:0.003326  KL:0.039801  MAE:0.007840  RMSE:0.057673  CVRMSE:14.418213  R2:0.186939  MSE:0.003239  KL:0.028024  MAE:0.008082  RMSE:0.056914  CVRMSE:14.228569  R2:0.136065  MSE:0.003442  KL:0.027854  MAE:0.007760  RMSE:0.058668  CVRMSE:14.666960  R2:0.150327  MSE:0.003385  KL:0.035412  MAE:0.007525  RMSE:0.058182  CVRMSE:14.545399  R2:0.167053  MSE:0.003318  KL:0.031776  MAE:0.007621  RMSE:0.0

R2:0.168167  MSE:0.003314  KL:0.036545  MAE:0.007709  RMSE:0.057568  CVRMSE:14.391887  R2:0.160380  MSE:0.003345  KL:0.042083  MAE:0.007700  RMSE:0.057836  CVRMSE:14.459093  R2:0.176550  MSE:0.003281  KL:0.038445  MAE:0.007503  RMSE:0.057277  CVRMSE:14.319180  R2:0.182287  MSE:0.003258  KL:0.041028  MAE:0.007698  RMSE:0.057077  CVRMSE:14.269218  R2:0.145694  MSE:0.003404  KL:0.046482  MAE:0.008150  RMSE:0.058340  CVRMSE:14.584995  train_loss:
0.0034035534867143725
R2:0.134119  MSE:0.002590  KL:0.039224  MAE:0.006269  RMSE:0.050891  CVRMSE:16.963525  test_loss:
1.294925413094461e-05
R2:0.141634  MSE:0.003420  KL:0.041563  MAE:0.007975  RMSE:0.058478  CVRMSE:14.619617  R2:0.171227  MSE:0.003302  KL:0.029245  MAE:0.007131  RMSE:0.057462  CVRMSE:14.365394  R2:0.172758  MSE:0.003296  KL:0.038352  MAE:0.007511  RMSE:0.057408  CVRMSE:14.352121  R2:0.128168  MSE:0.003473  KL:0.037534  MAE:0.008102  RMSE:0.058935  CVRMSE:14.733848  R2:0.191470  MSE:0.003221  KL:0.031963  MAE:0.007748  RMSE:0.05

R2:0.135252  MSE:0.003445  KL:0.021190  MAE:0.007741  RMSE:0.058695  CVRMSE:14.673866  R2:0.167141  MSE:0.003318  KL:0.043435  MAE:0.007940  RMSE:0.057603  CVRMSE:14.400761  R2:0.174676  MSE:0.003288  KL:0.026814  MAE:0.007738  RMSE:0.057342  CVRMSE:14.335465  R2:0.207086  MSE:0.003159  KL:0.029621  MAE:0.007006  RMSE:0.056205  CVRMSE:14.051177  R2:0.167096  MSE:0.003318  KL:0.025504  MAE:0.007749  RMSE:0.057605  CVRMSE:14.401148  train_loss:
0.0033182886822032743
R2:0.138513  MSE:0.002577  KL:0.032726  MAE:0.006199  RMSE:0.050761  CVRMSE:16.920428  test_loss:
1.2883538147434592e-05
R2:0.211321  MSE:0.003142  KL:0.033357  MAE:0.007580  RMSE:0.056054  CVRMSE:14.013603  R2:0.205729  MSE:0.003164  KL:0.029133  MAE:0.007770  RMSE:0.056253  CVRMSE:14.063196  R2:0.195714  MSE:0.003204  KL:0.027363  MAE:0.007844  RMSE:0.056606  CVRMSE:14.151582  R2:0.120288  MSE:0.003505  KL:0.034126  MAE:0.008070  RMSE:0.059201  CVRMSE:14.800281  R2:0.176738  MSE:0.003280  KL:0.028369  MAE:0.007088  RMSE:0.0

In [None]:
# Area under the uplift curve (AUUC) for the model's scores ('y_hat')
# compared with a random ranking.
# Only the first occurrence of each column name is kept so that auuc_score
# sees a single 'y', 'T' and 'y_hat' column.
first_occurrence = ~df_test.columns.duplicated()
uplift = df_test.copy().loc[:, first_occurrence]

auuc = auuc_score(uplift, outcome_col='y', treatment_col='T')
selected_scores = auuc[["y_hat", "Random"]]
gcn_auuc = pd.DataFrame(selected_scores, columns=['auuc'])
gcn_auuc

Unnamed: 0,auuc
y_hat,0.777582
Random,0.489027


In [None]:
# Report which feature weighting produced the GCN predictions, then their test-set MSE.
print('Feature mode on GCN:', feats_mode)
gcn_test_mse = MSE(uplift['y'], uplift['y_hat'])
print('MSE:', gcn_test_mse)

Feature mode on GCN: causal+imp
MSE: 0.00259513338732548


In [None]:

# Store the GCN's AUUC in the comparison table under a row label that names
# the feature weighting used. Modes not listed here (e.g. plain structure or
# causal-only weighting) are recorded by other cells and are skipped.
GCN_ROW_LABELS = {
    'imp': 'GCN (Struct+Important Weighting)',
    'causal+imp': 'GCN (Struct+Causal+Important Weighting)',
    'causal*imp': 'GCN (Struct+(Causal*Important) Weighting)',
}
if feats_mode in GCN_ROW_LABELS:
    result.loc[GCN_ROW_LABELS[feats_mode], 'AUUC'] = auuc["y_hat"]
result

Unnamed: 0,AUUC
S Learner(LR),0.497983
S Learner(XGB),0.875572
S Learner(LGBM),0.883033
GCN (Struct),0.501865
GCN (Struct+Feature),0.721959
GCN (Struct+Causal Weighting),0.732616
GAT (Struct),0.544286
GAT (Struct+Feature),0.84763
GAT (Struct+Causal Weighting),0.8807
GCN (Struct+Causal+Important Weighting),0.777582


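# GAT
Graph Attention Network (GAT): GCN-style message passing in which each node weights its neighbours with learned attention coefficients instead of a fixed degree-based normalisation.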

In [None]:
from uuid import RFC_4122
import torch
import math
from torch_geometric.nn import MessagePassing
from torch_geometric.nn import GATConv

from torch_geometric.utils import add_self_loops,degree
from torch_geometric.datasets import Planetoid
import ssl
import torch.nn.functional as F


class Net(torch.nn.Module):
    """Two GATConv layers followed by a two-layer linear head; one scalar per graph.

    The second GAT layer's node embeddings are flattened per graph
    (``num_nodes * gat_out`` values per row) and mapped to a single output.
    The defaults reproduce the original hard-coded architecture
    (in_dim -> 8x8 heads -> 10 -> flatten 140 -> 32 -> 1).

    Args:
        in_channels: node-feature size. Defaults to the notebook-global ``in_dim``.
        num_nodes: nodes per graph. Every graph in a batch must have exactly this
            many nodes, otherwise the per-graph reshape in ``forward`` mixes graphs.
        hidden: output channels per head of the first GAT layer.
        heads: number of attention heads in the first GAT layer (concatenated).
        gat_out: output channels of the second (single-head) GAT layer.
        mlp_hidden: width of the hidden linear layer.
        dropout: attention-coefficient dropout passed to both GATConv layers.
    """

    def __init__(self, in_channels=None, num_nodes=14, hidden=8, heads=8,
                 gat_out=10, mlp_hidden=32, dropout=0.6):
        super(Net, self).__init__()
        if in_channels is None:
            in_channels = in_dim
        self.num_nodes = num_nodes
        self.gat_out = gat_out
        self.gat1 = GATConv(in_channels=in_channels, out_channels=hidden,
                            heads=heads, dropout=dropout)
        # gat1 concatenates its heads, so gat2 receives hidden * heads features.
        self.gat2 = GATConv(in_channels=hidden * heads, out_channels=gat_out,
                            heads=1, dropout=dropout)
        self.f1 = torch.nn.Linear(num_nodes * gat_out, mlp_hidden)
        self.f2 = torch.nn.Linear(mlp_hidden, 1)

    def forward(self, data):
        """Return a ``(num_graphs, 1)`` tensor of predictions for a batched ``data``."""
        x, edge_index = data.x, data.edge_index
        # NOTE(review): no non-linearity is applied between the GAT layers or
        # between f1 and f2 (the reference GAT uses ELU after the first layer);
        # kept as-is to preserve results — confirm this is intended.
        x = self.gat1(x, edge_index)
        x = self.gat2(x, edge_index)
        # One row per graph: concatenate the embeddings of its num_nodes nodes.
        x = x.reshape(-1, self.num_nodes * self.gat_out)
        x = self.f1(x)
        x = self.f2(x)
        return x

# NOTE(review): this replaces the process-wide default HTTPS context with an
# unverified one. TLS certificate checks are disabled for every later connection
# that relies on Python's default context (e.g. urllib) in this kernel.
# Nothing in this cell downloads anything (Planetoid is imported but unused),
# so this looks like a leftover — consider removing it.
ssl._create_default_https_context = ssl._create_unverified_context
def train():
    """Run one optimisation epoch over ``train_loader`` and print epoch metrics.

    Uses the notebook globals ``model``, ``optimizer``, ``crit``, ``device``,
    ``train_loader`` and ``df_train``.

    Returns:
        (loss, r2, mse, kl): ``loss`` is the per-graph mean of ``crit`` over the
        epoch (each batch-mean loss weighted by its number of graphs, summed, and
        divided by ``len(df_train)``). ``r2``, ``mse`` and ``kl`` are computed on
        the predictions and targets collected over the epoch.
    """
    model.train()
    loss_all = 0
    y_actual = []
    y_predicted = []
    for data in train_loader:
        # Moves every tensor attribute of the batch, including data.y.
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        label = data.y
        loss = crit(output, label)
        loss.backward()
        optimizer.step()
        # crit returns a batch mean; weight it by batch size before summing.
        loss_all += data.num_graphs * loss.item()
        y_actual += label.detach().cpu().ravel().tolist()
        y_predicted += output.detach().cpu().ravel().tolist()

    loss = loss_all / len(df_train)
    # NOTE(review): metrics are called as f(predicted, actual); R2 and KL are not
    # symmetric, so confirm this matches the argument order of their definitions.
    r2 = R2(y_predicted, y_actual)
    mse = MSE(y_predicted, y_actual)
    kl = kl_divergence(y_predicted, y_actual)

    print("R2:%f" % r2, end='  ')
    print("MSE:%f" % mse, end='  ')
    print("KL:%f" % kl, end='  ')
    print("MAE:%f" % MAE(y_predicted, y_actual), end='  ')
    print("RMSE:%f" % RMSE(y_predicted, y_actual), end='  ')
    print("CVRMSE:%f" % CVRMSE(y_predicted, y_actual), end='  ')

    return loss, r2, mse, kl


def val():
    """Evaluate ``model`` on ``test_loader`` without updating weights; print metrics.

    Uses the notebook globals ``model``, ``crit``, ``device``, ``test_loader``
    and ``df_test``.

    Returns:
        (loss, r2, mse, kl): ``loss`` is computed exactly as in ``train()``:
        each batch-mean loss is weighted by its number of graphs, summed, and
        divided by ``len(df_test)``. Train and test losses are therefore on the
        same per-graph scale. ``r2``, ``mse`` and ``kl`` are computed on the
        collected predictions and targets.
    """
    model.eval()
    y_actual = []
    y_predicted = []
    loss_all = 0
    # Evaluation needs no gradients; skipping autograd saves memory and time.
    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            output = model(data)
            label = data.y
            loss = crit(output, label)
            # Weight each batch-mean loss by batch size so that dividing by
            # len(df_test) yields a per-graph mean, matching train().
            loss_all += data.num_graphs * loss.item()
            y_actual += label.cpu().ravel().tolist()
            y_predicted += output.cpu().ravel().tolist()

    loss = loss_all / len(df_test)
    # NOTE(review): metrics are called as f(predicted, actual); R2 and KL are not
    # symmetric, so confirm this matches the argument order of their definitions.
    r2 = R2(y_predicted, y_actual)
    mse = MSE(y_predicted, y_actual)
    kl = kl_divergence(y_predicted, y_actual)

    print("R2:%f" % r2, end='  ')
    print("MSE:%f" % mse, end='  ')
    print("KL:%f" % kl, end='  ')
    print("MAE:%f" % MAE(y_predicted, y_actual), end='  ')
    print("RMSE:%f" % RMSE(y_predicted, y_actual), end='  ')
    print("CVRMSE:%f" % CVRMSE(y_predicted, y_actual), end='  ')

    return loss, r2, mse, kl



num_epochs = 2560
# NOTE(review): batch_size is not used in this cell; train_loader/test_loader
# were built in an earlier cell with their own batch size.
batch_size = 512
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.002, weight_decay=5e-4)
crit = F.mse_loss

for epoch in range(num_epochs):
    loss, r2, mse, kl = train()
    # Every 5 epochs: report the training loss and evaluate on the test set.
    if epoch % 5 == 0:
        print('train_loss:')
        print(loss)

        loss, r2, mse, kl = val()
        print('test_loss:')
        print(loss)

# Final test-set predictions. train() leaves the model in training mode, and
# val() does not run on the last epoch. Switch to eval mode explicitly so the
# GATConv attention dropout is disabled and predictions are deterministic.
# NOTE(review): assigning by position assumes test_loader does not shuffle.
model.eval()
y_predicted = []
with torch.no_grad():
    for data in test_loader:
        data = data.to(device)
        output = model(data)
        y_predicted += output.cpu().ravel().tolist()
df_test["y_hat"] = y_predicted

R2:-25.335778  MSE:0.104922  KL:2.934289  MAE:0.231009  RMSE:0.323916  CVRMSE:80.979063  train_loss:
0.10492173954844475
R2:-5.547298  MSE:0.019583  KL:5.300995  MAE:0.128627  RMSE:0.139939  CVRMSE:46.646385  test_loss:
9.791483543813229e-05
R2:-4.704450  MSE:0.022727  KL:2.975788  MAE:0.112224  RMSE:0.150753  CVRMSE:37.688300  R2:-2.593554  MSE:0.014317  KL:3.096289  MAE:0.077948  RMSE:0.119652  CVRMSE:29.913121  R2:-2.125688  MSE:0.012453  KL:3.174985  MAE:0.072824  RMSE:0.111592  CVRMSE:27.897962  R2:-1.425888  MSE:0.009665  KL:3.316360  MAE:0.064801  RMSE:0.098309  CVRMSE:24.577348  R2:-1.203343  MSE:0.008778  KL:3.245919  MAE:0.051748  RMSE:0.093692  CVRMSE:23.422904  train_loss:
0.00877811862155795
R2:-0.280573  MSE:0.003830  KL:4.146775  MAE:0.029076  RMSE:0.061889  CVRMSE:20.629518  test_loss:
1.915096596349031e-05
R2:-0.535084  MSE:0.006116  KL:3.096771  MAE:0.042126  RMSE:0.078203  CVRMSE:19.550851  R2:-0.519482  MSE:0.006054  KL:3.467679  MAE:0.042579  RMSE:0.077805  CVRMSE:

R2:0.073197  MSE:0.003692  KL:1.906742  MAE:0.012750  RMSE:0.060765  CVRMSE:15.191247  R2:0.055547  MSE:0.003763  KL:2.069550  MAE:0.013575  RMSE:0.061341  CVRMSE:15.335211  R2:0.064832  MSE:0.003726  KL:2.190748  MAE:0.013246  RMSE:0.061039  CVRMSE:15.259647  R2:0.079990  MSE:0.003665  KL:1.953211  MAE:0.013357  RMSE:0.060542  CVRMSE:15.135473  R2:0.030448  MSE:0.003863  KL:2.637734  MAE:0.013480  RMSE:0.062151  CVRMSE:15.537648  train_loss:
0.003862696239957586
R2:0.060736  MSE:0.002809  KL:1.772939  MAE:0.008871  RMSE:0.053003  CVRMSE:17.667730  test_loss:
1.4046691547264344e-05
R2:0.078295  MSE:0.003672  KL:1.731978  MAE:0.012444  RMSE:0.060598  CVRMSE:15.149407  R2:0.053030  MSE:0.003773  KL:2.020650  MAE:0.012861  RMSE:0.061423  CVRMSE:15.355638  R2:0.073174  MSE:0.003692  KL:1.819657  MAE:0.011708  RMSE:0.060766  CVRMSE:15.191431  R2:0.074954  MSE:0.003685  KL:2.143875  MAE:0.011791  RMSE:0.060707  CVRMSE:15.176835  R2:0.074131  MSE:0.003689  KL:1.791869  MAE:0.012531  RMSE:0.06

R2:0.063399  MSE:0.003731  KL:1.745479  MAE:0.011737  RMSE:0.061085  CVRMSE:15.271333  R2:0.051365  MSE:0.003779  KL:2.214662  MAE:0.012121  RMSE:0.061477  CVRMSE:15.369132  R2:0.085828  MSE:0.003642  KL:2.053928  MAE:0.011799  RMSE:0.060349  CVRMSE:15.087372  R2:0.045529  MSE:0.003803  KL:2.555929  MAE:0.012466  RMSE:0.061665  CVRMSE:15.416329  R2:0.081976  MSE:0.003657  KL:2.604995  MAE:0.012314  RMSE:0.060476  CVRMSE:15.119123  train_loss:
0.0036574056051904336
R2:0.063152  MSE:0.002802  KL:1.359929  MAE:0.010644  RMSE:0.052935  CVRMSE:17.644990  test_loss:
1.4010555823915637e-05
R2:0.076547  MSE:0.003679  KL:1.738804  MAE:0.012550  RMSE:0.060655  CVRMSE:15.163766  R2:0.091242  MSE:0.003620  KL:2.227534  MAE:0.012932  RMSE:0.060171  CVRMSE:15.042633  R2:0.075062  MSE:0.003685  KL:1.859187  MAE:0.013302  RMSE:0.060704  CVRMSE:15.175952  R2:0.048309  MSE:0.003792  KL:2.102163  MAE:0.013757  RMSE:0.061575  CVRMSE:15.393868  R2:0.056083  MSE:0.003761  KL:1.879141  MAE:0.012742  RMSE:0.0

R2:0.061536  MSE:0.003739  KL:2.009654  MAE:0.011585  RMSE:0.061146  CVRMSE:15.286513  R2:0.065561  MSE:0.003723  KL:2.253682  MAE:0.011012  RMSE:0.061015  CVRMSE:15.253694  R2:0.074196  MSE:0.003688  KL:1.271068  MAE:0.010982  RMSE:0.060732  CVRMSE:15.183054  R2:0.088513  MSE:0.003631  KL:1.318648  MAE:0.011258  RMSE:0.060261  CVRMSE:15.065198  R2:0.076339  MSE:0.003680  KL:1.663432  MAE:0.011487  RMSE:0.060662  CVRMSE:15.165474  train_loss:
0.0036798652072320692
R2:0.071642  MSE:0.002777  KL:2.256644  MAE:0.008974  RMSE:0.052695  CVRMSE:17.564864  test_loss:
1.3883600535336881e-05
R2:0.071019  MSE:0.003701  KL:2.397695  MAE:0.011505  RMSE:0.060836  CVRMSE:15.209084  R2:0.082415  MSE:0.003656  KL:2.140418  MAE:0.011850  RMSE:0.060462  CVRMSE:15.115508  R2:0.075781  MSE:0.003682  KL:1.902803  MAE:0.011849  RMSE:0.060680  CVRMSE:15.170053  R2:0.074117  MSE:0.003689  KL:1.740508  MAE:0.011571  RMSE:0.060735  CVRMSE:15.183704  R2:0.082990  MSE:0.003653  KL:1.750887  MAE:0.011736  RMSE:0.0

R2:0.058339  MSE:0.003752  KL:1.417276  MAE:0.011152  RMSE:0.061250  CVRMSE:15.312531  R2:0.051720  MSE:0.003778  KL:1.253885  MAE:0.010617  RMSE:0.061465  CVRMSE:15.366249  R2:0.075091  MSE:0.003685  KL:1.510083  MAE:0.010690  RMSE:0.060703  CVRMSE:15.175716  R2:0.052969  MSE:0.003773  KL:1.330129  MAE:0.011195  RMSE:0.061425  CVRMSE:15.356132  R2:0.079825  MSE:0.003666  KL:1.072021  MAE:0.010448  RMSE:0.060547  CVRMSE:15.136832  train_loss:
0.0036659790755948054
R2:0.071316  MSE:0.002778  KL:0.376055  MAE:0.010049  RMSE:0.052704  CVRMSE:17.567946  test_loss:
1.3888472414691933e-05
R2:0.065618  MSE:0.003723  KL:1.472733  MAE:0.011170  RMSE:0.061013  CVRMSE:15.253237  R2:0.059306  MSE:0.003748  KL:1.575106  MAE:0.011285  RMSE:0.061219  CVRMSE:15.304669  R2:0.066564  MSE:0.003719  KL:0.757120  MAE:0.010005  RMSE:0.060982  CVRMSE:15.245511  R2:0.077082  MSE:0.003677  KL:1.437039  MAE:0.011099  RMSE:0.060637  CVRMSE:15.159369  R2:0.056082  MSE:0.003761  KL:1.452230  MAE:0.011353  RMSE:0.0

R2:0.073286  MSE:0.003692  KL:1.600850  MAE:0.010918  RMSE:0.060762  CVRMSE:15.190520  R2:0.086032  MSE:0.003641  KL:1.341743  MAE:0.010864  RMSE:0.060343  CVRMSE:15.085691  R2:0.064917  MSE:0.003725  KL:0.965657  MAE:0.010357  RMSE:0.061036  CVRMSE:15.258956  R2:0.061425  MSE:0.003739  KL:1.960768  MAE:0.011408  RMSE:0.061150  CVRMSE:15.287417  R2:0.068915  MSE:0.003709  KL:1.110657  MAE:0.010695  RMSE:0.060905  CVRMSE:15.226296  train_loss:
0.00370944227615837
R2:0.069861  MSE:0.002782  KL:1.010422  MAE:0.010387  RMSE:0.052745  CVRMSE:17.581704  test_loss:
1.3910234323702752e-05
R2:0.082556  MSE:0.003655  KL:1.441905  MAE:0.010823  RMSE:0.060457  CVRMSE:15.114350  R2:0.070104  MSE:0.003705  KL:2.005222  MAE:0.010891  RMSE:0.060866  CVRMSE:15.216573  R2:0.060571  MSE:0.003743  KL:1.030157  MAE:0.010547  RMSE:0.061177  CVRMSE:15.294370  R2:0.073104  MSE:0.003693  KL:1.487385  MAE:0.011112  RMSE:0.060768  CVRMSE:15.192005  R2:0.079269  MSE:0.003668  KL:1.853018  MAE:0.010442  RMSE:0.060

R2:0.064197  MSE:0.003728  KL:0.900390  MAE:0.010466  RMSE:0.061059  CVRMSE:15.264826  R2:0.080421  MSE:0.003664  KL:1.208596  MAE:0.010434  RMSE:0.060528  CVRMSE:15.131925  R2:0.074193  MSE:0.003688  KL:1.212466  MAE:0.010714  RMSE:0.060732  CVRMSE:15.183084  R2:0.066289  MSE:0.003720  KL:1.383042  MAE:0.010274  RMSE:0.060991  CVRMSE:15.247756  R2:0.067404  MSE:0.003715  KL:1.106068  MAE:0.010541  RMSE:0.060955  CVRMSE:15.238648  train_loss:
0.003715462374384515
R2:0.071339  MSE:0.002778  KL:0.558796  MAE:0.010088  RMSE:0.052703  CVRMSE:17.567729  test_loss:
1.38881287566619e-05
R2:0.060517  MSE:0.003743  KL:1.380997  MAE:0.010724  RMSE:0.061179  CVRMSE:15.294814  R2:0.071157  MSE:0.003701  KL:1.749723  MAE:0.009962  RMSE:0.060832  CVRMSE:15.207956  R2:0.052818  MSE:0.003774  KL:1.325961  MAE:0.010681  RMSE:0.061429  CVRMSE:15.357352  R2:0.062819  MSE:0.003734  KL:1.196113  MAE:0.010269  RMSE:0.061104  CVRMSE:15.276060  R2:0.081276  MSE:0.003660  KL:0.891600  MAE:0.010112  RMSE:0.0605

R2:0.068921  MSE:0.003709  KL:1.213961  MAE:0.010926  RMSE:0.060905  CVRMSE:15.226247  R2:0.061900  MSE:0.003737  KL:0.947767  MAE:0.010375  RMSE:0.061134  CVRMSE:15.283551  R2:0.074222  MSE:0.003688  KL:0.992033  MAE:0.010408  RMSE:0.060731  CVRMSE:15.182844  R2:0.060598  MSE:0.003743  KL:1.112001  MAE:0.010167  RMSE:0.061177  CVRMSE:15.294156  R2:0.067994  MSE:0.003713  KL:1.238153  MAE:0.010475  RMSE:0.060935  CVRMSE:15.233827  train_loss:
0.0037131118384422734
R2:0.070826  MSE:0.002779  KL:1.847758  MAE:0.009618  RMSE:0.052718  CVRMSE:17.572577  test_loss:
1.3895796524593607e-05
R2:0.082176  MSE:0.003657  KL:1.677656  MAE:0.010320  RMSE:0.060470  CVRMSE:15.117477  R2:0.072321  MSE:0.003696  KL:1.000589  MAE:0.010289  RMSE:0.060794  CVRMSE:15.198421  R2:0.061377  MSE:0.003739  KL:1.135105  MAE:0.010285  RMSE:0.061151  CVRMSE:15.287806  R2:0.094347  MSE:0.003608  KL:1.300014  MAE:0.010019  RMSE:0.060068  CVRMSE:15.016909  R2:0.065846  MSE:0.003722  KL:1.902910  MAE:0.011332  RMSE:0.0

R2:0.064261  MSE:0.003728  KL:0.931239  MAE:0.010087  RMSE:0.061057  CVRMSE:15.264306  R2:0.091149  MSE:0.003621  KL:0.776253  MAE:0.009804  RMSE:0.060174  CVRMSE:15.043404  R2:0.066512  MSE:0.003719  KL:0.946012  MAE:0.010174  RMSE:0.060984  CVRMSE:15.245935  R2:0.091200  MSE:0.003621  KL:0.649132  MAE:0.009751  RMSE:0.060172  CVRMSE:15.042982  R2:0.070150  MSE:0.003705  KL:1.553845  MAE:0.011197  RMSE:0.060865  CVRMSE:15.216200  train_loss:
0.003704523667693138
R2:0.076839  MSE:0.002761  KL:1.883140  MAE:0.007740  RMSE:0.052547  CVRMSE:17.515629  test_loss:
1.3805875059915707e-05
R2:0.085045  MSE:0.003645  KL:0.745116  MAE:0.009810  RMSE:0.060375  CVRMSE:15.093829  R2:0.066406  MSE:0.003719  KL:1.629925  MAE:0.011274  RMSE:0.060987  CVRMSE:15.246803  R2:0.052783  MSE:0.003774  KL:1.499640  MAE:0.010083  RMSE:0.061431  CVRMSE:15.357639  R2:0.062727  MSE:0.003734  KL:1.532112  MAE:0.010497  RMSE:0.061107  CVRMSE:15.276810  R2:0.073525  MSE:0.003691  KL:1.346048  MAE:0.010248  RMSE:0.06

R2:0.071212  MSE:0.003700  KL:1.331899  MAE:0.009867  RMSE:0.060830  CVRMSE:15.207506  R2:0.055258  MSE:0.003764  KL:1.145540  MAE:0.010554  RMSE:0.061350  CVRMSE:15.337561  R2:0.066350  MSE:0.003720  KL:0.908672  MAE:0.009856  RMSE:0.060989  CVRMSE:15.247260  R2:0.097173  MSE:0.003597  KL:1.230511  MAE:0.010507  RMSE:0.059974  CVRMSE:14.993459  R2:0.069323  MSE:0.003708  KL:1.199846  MAE:0.010410  RMSE:0.060892  CVRMSE:15.222961  train_loss:
0.003707816798123531
R2:0.072446  MSE:0.002774  KL:1.802711  MAE:0.009290  RMSE:0.052672  CVRMSE:17.557249  test_loss:
1.3871565868612379e-05
R2:0.064019  MSE:0.003729  KL:1.421394  MAE:0.010337  RMSE:0.061065  CVRMSE:15.266283  R2:0.081833  MSE:0.003658  KL:1.860595  MAE:0.009936  RMSE:0.060481  CVRMSE:15.120302  R2:0.073161  MSE:0.003693  KL:1.334106  MAE:0.010291  RMSE:0.060766  CVRMSE:15.191540  R2:0.073309  MSE:0.003692  KL:0.786284  MAE:0.010129  RMSE:0.060761  CVRMSE:15.190326  R2:0.087443  MSE:0.003636  KL:1.368239  MAE:0.009957  RMSE:0.06

R2:0.080106  MSE:0.003665  KL:1.364617  MAE:0.010062  RMSE:0.060538  CVRMSE:15.134516  R2:0.066716  MSE:0.003718  KL:1.071723  MAE:0.010070  RMSE:0.060977  CVRMSE:15.244265  R2:0.059987  MSE:0.003745  KL:1.701927  MAE:0.010008  RMSE:0.061196  CVRMSE:15.299122  R2:0.064801  MSE:0.003726  KL:0.912280  MAE:0.009900  RMSE:0.061040  CVRMSE:15.259897  R2:0.065074  MSE:0.003725  KL:0.712110  MAE:0.009998  RMSE:0.061031  CVRMSE:15.257676  train_loss:
0.0037247472049784847
R2:0.073232  MSE:0.002772  KL:1.945290  MAE:0.008558  RMSE:0.052649  CVRMSE:17.549815  test_loss:
1.3859820828656666e-05
R2:0.075747  MSE:0.003682  KL:1.285083  MAE:0.009933  RMSE:0.060681  CVRMSE:15.170336  R2:0.064837  MSE:0.003726  KL:1.364091  MAE:0.009878  RMSE:0.061038  CVRMSE:15.259605  R2:0.090701  MSE:0.003623  KL:1.158952  MAE:0.009747  RMSE:0.060188  CVRMSE:15.047109  R2:0.065661  MSE:0.003722  KL:1.160685  MAE:0.010601  RMSE:0.061012  CVRMSE:15.252879  R2:0.075402  MSE:0.003684  KL:1.064503  MAE:0.009548  RMSE:0.0

R2:0.080503  MSE:0.003663  KL:0.908751  MAE:0.010119  RMSE:0.060525  CVRMSE:15.131248  R2:0.069009  MSE:0.003709  KL:0.800130  MAE:0.010101  RMSE:0.060902  CVRMSE:15.225532  R2:0.073812  MSE:0.003690  KL:1.231965  MAE:0.010066  RMSE:0.060745  CVRMSE:15.186206  R2:0.079193  MSE:0.003668  KL:0.987022  MAE:0.010342  RMSE:0.060568  CVRMSE:15.142025  R2:0.076303  MSE:0.003680  KL:0.912272  MAE:0.009921  RMSE:0.060663  CVRMSE:15.165771  train_loss:
0.003680009082017932
R2:0.073968  MSE:0.002770  KL:1.893397  MAE:0.008685  RMSE:0.052629  CVRMSE:17.542842  test_loss:
1.3848809379851446e-05
R2:0.062033  MSE:0.003737  KL:0.987296  MAE:0.010155  RMSE:0.061130  CVRMSE:15.282465  R2:0.069354  MSE:0.003708  KL:1.279499  MAE:0.010197  RMSE:0.060891  CVRMSE:15.222710  R2:0.089830  MSE:0.003626  KL:1.413243  MAE:0.009770  RMSE:0.060217  CVRMSE:15.054309  R2:0.070996  MSE:0.003701  KL:1.042860  MAE:0.010342  RMSE:0.060837  CVRMSE:15.209272  R2:0.080739  MSE:0.003662  KL:1.228366  MAE:0.009684  RMSE:0.06

R2:0.073167  MSE:0.003693  KL:0.893448  MAE:0.009947  RMSE:0.060766  CVRMSE:15.191493  R2:0.073958  MSE:0.003689  KL:0.763328  MAE:0.009596  RMSE:0.060740  CVRMSE:15.185010  R2:0.070683  MSE:0.003702  KL:0.776542  MAE:0.009553  RMSE:0.060847  CVRMSE:15.211838  R2:0.077795  MSE:0.003674  KL:0.756692  MAE:0.009539  RMSE:0.060614  CVRMSE:15.153514  R2:0.072808  MSE:0.003694  KL:1.246101  MAE:0.009980  RMSE:0.060778  CVRMSE:15.194434  train_loss:
0.003693933234899305
R2:0.074583  MSE:0.002768  KL:1.606246  MAE:0.008599  RMSE:0.052611  CVRMSE:17.537020  test_loss:
1.3839618593920022e-05
R2:0.076637  MSE:0.003679  KL:0.927367  MAE:0.009934  RMSE:0.060652  CVRMSE:15.163028  R2:0.082462  MSE:0.003655  KL:1.571876  MAE:0.009920  RMSE:0.060460  CVRMSE:15.115122  R2:0.069881  MSE:0.003706  KL:1.653509  MAE:0.009868  RMSE:0.060874  CVRMSE:15.218394  R2:0.074415  MSE:0.003688  KL:1.755649  MAE:0.010097  RMSE:0.060725  CVRMSE:15.181260  R2:0.079537  MSE:0.003667  KL:1.005279  MAE:0.009821  RMSE:0.06

R2:0.077906  MSE:0.003674  KL:1.006079  MAE:0.009817  RMSE:0.060610  CVRMSE:15.152601  R2:0.086925  MSE:0.003638  KL:1.573865  MAE:0.009804  RMSE:0.060313  CVRMSE:15.078314  R2:0.073426  MSE:0.003691  KL:1.500452  MAE:0.009926  RMSE:0.060757  CVRMSE:15.189371  R2:0.061525  MSE:0.003739  KL:0.979830  MAE:0.009967  RMSE:0.061146  CVRMSE:15.286603  R2:0.079322  MSE:0.003668  KL:0.843438  MAE:0.009831  RMSE:0.060564  CVRMSE:15.140963  train_loss:
0.0036679799799458125
R2:0.074301  MSE:0.002769  KL:1.156740  MAE:0.008683  RMSE:0.052619  CVRMSE:17.539688  test_loss:
1.3843829292454756e-05
R2:0.075278  MSE:0.003684  KL:1.430846  MAE:0.009737  RMSE:0.060697  CVRMSE:15.174182  R2:0.076763  MSE:0.003678  KL:0.903855  MAE:0.009760  RMSE:0.060648  CVRMSE:15.161993  R2:0.082808  MSE:0.003654  KL:1.069556  MAE:0.009831  RMSE:0.060449  CVRMSE:15.112275  R2:0.075294  MSE:0.003684  KL:1.395459  MAE:0.010100  RMSE:0.060696  CVRMSE:15.174052  R2:0.083287  MSE:0.003652  KL:1.216685  MAE:0.009766  RMSE:0.0

R2:0.063765  MSE:0.003730  KL:0.782298  MAE:0.010122  RMSE:0.061073  CVRMSE:15.268348  R2:0.062257  MSE:0.003736  KL:0.781200  MAE:0.009916  RMSE:0.061123  CVRMSE:15.280644  R2:0.064624  MSE:0.003727  KL:0.763163  MAE:0.010078  RMSE:0.061045  CVRMSE:15.261342  R2:0.068797  MSE:0.003710  KL:1.030970  MAE:0.009860  RMSE:0.060909  CVRMSE:15.227262  R2:0.068598  MSE:0.003711  KL:0.672212  MAE:0.009821  RMSE:0.060916  CVRMSE:15.228894  train_loss:
0.0037107076961547135
R2:0.074564  MSE:0.002768  KL:1.886772  MAE:0.008044  RMSE:0.052612  CVRMSE:17.537191  test_loss:
1.3839887615176849e-05
R2:0.083913  MSE:0.003650  KL:0.849532  MAE:0.009625  RMSE:0.060413  CVRMSE:15.103170  R2:0.073168  MSE:0.003692  KL:0.783470  MAE:0.009882  RMSE:0.060766  CVRMSE:15.191485  R2:0.075576  MSE:0.003683  KL:0.978354  MAE:0.009575  RMSE:0.060687  CVRMSE:15.171738  R2:0.063811  MSE:0.003730  KL:0.889150  MAE:0.009993  RMSE:0.061072  CVRMSE:15.267978  R2:0.080105  MSE:0.003665  KL:0.677076  MAE:0.009740  RMSE:0.0

R2:0.090240  MSE:0.003624  KL:1.253924  MAE:0.010237  RMSE:0.060204  CVRMSE:15.050925  R2:0.079972  MSE:0.003665  KL:0.917311  MAE:0.010182  RMSE:0.060542  CVRMSE:15.135619  R2:0.085001  MSE:0.003645  KL:0.760556  MAE:0.010113  RMSE:0.060377  CVRMSE:15.094199  R2:0.079469  MSE:0.003667  KL:0.786179  MAE:0.010263  RMSE:0.060559  CVRMSE:15.139757  R2:0.061392  MSE:0.003739  KL:1.086327  MAE:0.010401  RMSE:0.061151  CVRMSE:15.287690  train_loss:
0.0037394155981019138
R2:0.076317  MSE:0.002763  KL:2.257170  MAE:0.008091  RMSE:0.052562  CVRMSE:17.520579  test_loss:
1.381368181318976e-05
R2:0.079839  MSE:0.003666  KL:0.544097  MAE:0.009804  RMSE:0.060547  CVRMSE:15.136715  R2:0.063830  MSE:0.003730  KL:0.948446  MAE:0.009937  RMSE:0.061071  CVRMSE:15.267817  R2:0.082673  MSE:0.003655  KL:0.719740  MAE:0.009522  RMSE:0.060454  CVRMSE:15.113386  R2:0.076546  MSE:0.003679  KL:0.717191  MAE:0.009822  RMSE:0.060655  CVRMSE:15.163773  R2:0.083471  MSE:0.003651  KL:1.051766  MAE:0.009593  RMSE:0.06

R2:0.083574  MSE:0.003651  KL:0.964940  MAE:0.009667  RMSE:0.060424  CVRMSE:15.105961  R2:0.071642  MSE:0.003699  KL:1.630153  MAE:0.010010  RMSE:0.060816  CVRMSE:15.203985  R2:0.075052  MSE:0.003685  KL:0.817280  MAE:0.009846  RMSE:0.060704  CVRMSE:15.176035  R2:0.064646  MSE:0.003726  KL:0.937925  MAE:0.009834  RMSE:0.061045  CVRMSE:15.261168  R2:0.070646  MSE:0.003703  KL:0.881351  MAE:0.009645  RMSE:0.060849  CVRMSE:15.212135  train_loss:
0.0037025452838861383
R2:0.074512  MSE:0.002768  KL:2.154639  MAE:0.008455  RMSE:0.052613  CVRMSE:17.537691  test_loss:
1.3840676270774565e-05
R2:0.061828  MSE:0.003738  KL:1.217753  MAE:0.009740  RMSE:0.061137  CVRMSE:15.284138  R2:0.088718  MSE:0.003631  KL:1.056764  MAE:0.009609  RMSE:0.060254  CVRMSE:15.063510  R2:0.075839  MSE:0.003682  KL:1.210618  MAE:0.010122  RMSE:0.060678  CVRMSE:15.169578  R2:0.065122  MSE:0.003725  KL:0.664975  MAE:0.009661  RMSE:0.061029  CVRMSE:15.257280  R2:0.063343  MSE:0.003732  KL:1.399820  MAE:0.010075  RMSE:0.0

R2:0.087055  MSE:0.003637  KL:0.651928  MAE:0.010087  RMSE:0.060309  CVRMSE:15.077242  R2:0.078245  MSE:0.003672  KL:1.252471  MAE:0.009987  RMSE:0.060599  CVRMSE:15.149822  R2:0.091070  MSE:0.003621  KL:0.642763  MAE:0.009979  RMSE:0.060176  CVRMSE:15.044055  R2:0.066389  MSE:0.003720  KL:0.714182  MAE:0.010299  RMSE:0.060988  CVRMSE:15.246936  R2:0.064967  MSE:0.003725  KL:2.255125  MAE:0.009982  RMSE:0.061034  CVRMSE:15.258549  train_loss:
0.0037251724395900966
R2:0.074926  MSE:0.002767  KL:2.094606  MAE:0.008584  RMSE:0.052601  CVRMSE:17.533767  test_loss:
1.3834484867402352e-05
R2:0.087584  MSE:0.003635  KL:0.930175  MAE:0.009851  RMSE:0.060292  CVRMSE:15.072875  R2:0.058017  MSE:0.003753  KL:0.964339  MAE:0.009654  RMSE:0.061261  CVRMSE:15.315150  R2:0.051575  MSE:0.003779  KL:1.615798  MAE:0.009618  RMSE:0.061470  CVRMSE:15.367426  R2:0.070902  MSE:0.003702  KL:1.446932  MAE:0.009420  RMSE:0.060840  CVRMSE:15.210046  R2:0.063398  MSE:0.003731  KL:0.538741  MAE:0.009481  RMSE:0.0

R2:0.102284  MSE:0.003577  KL:0.633623  MAE:0.009518  RMSE:0.059804  CVRMSE:14.950964  R2:0.066099  MSE:0.003721  KL:0.692666  MAE:0.010165  RMSE:0.060997  CVRMSE:15.249305  R2:0.066202  MSE:0.003720  KL:0.648844  MAE:0.009657  RMSE:0.060994  CVRMSE:15.248463  R2:0.074088  MSE:0.003689  KL:2.270722  MAE:0.009893  RMSE:0.060736  CVRMSE:15.183939  R2:0.091034  MSE:0.003621  KL:1.816439  MAE:0.009286  RMSE:0.060177  CVRMSE:15.044350  train_loss:
0.003621319212834351
R2:0.073455  MSE:0.002771  KL:0.893976  MAE:0.009159  RMSE:0.052643  CVRMSE:17.547702  test_loss:
1.3856482211849653e-05
R2:0.057487  MSE:0.003755  KL:1.445147  MAE:0.010042  RMSE:0.061278  CVRMSE:15.319455  R2:0.069372  MSE:0.003708  KL:1.454641  MAE:0.009678  RMSE:0.060890  CVRMSE:15.222562  R2:0.096137  MSE:0.003601  KL:2.370307  MAE:0.009534  RMSE:0.060008  CVRMSE:15.002066  R2:0.073071  MSE:0.003693  KL:1.886965  MAE:0.009623  RMSE:0.060769  CVRMSE:15.192276  R2:0.067894  MSE:0.003714  KL:1.771696  MAE:0.009524  RMSE:0.06

R2:0.082653  MSE:0.003655  KL:0.669695  MAE:0.009467  RMSE:0.060454  CVRMSE:15.113546  R2:0.090853  MSE:0.003622  KL:1.114014  MAE:0.009762  RMSE:0.060183  CVRMSE:15.045846  R2:0.069488  MSE:0.003707  KL:0.956960  MAE:0.009899  RMSE:0.060886  CVRMSE:15.221609  R2:0.063721  MSE:0.003730  KL:0.724877  MAE:0.010272  RMSE:0.061075  CVRMSE:15.268705  R2:0.061314  MSE:0.003740  KL:0.907361  MAE:0.009832  RMSE:0.061153  CVRMSE:15.288323  train_loss:
0.0037397254709503613
R2:0.075367  MSE:0.002766  KL:2.766876  MAE:0.008128  RMSE:0.052589  CVRMSE:17.529589  test_loss:
1.3827891569235362e-05
R2:0.075945  MSE:0.003681  KL:1.738683  MAE:0.009852  RMSE:0.060675  CVRMSE:15.168712  R2:0.075568  MSE:0.003683  KL:1.171304  MAE:0.009703  RMSE:0.060687  CVRMSE:15.171805  R2:0.073578  MSE:0.003691  KL:0.882466  MAE:0.009841  RMSE:0.060752  CVRMSE:15.188123  R2:0.078521  MSE:0.003671  KL:0.612998  MAE:0.009759  RMSE:0.060590  CVRMSE:15.147552  R2:0.089243  MSE:0.003628  KL:1.383563  MAE:0.009661  RMSE:0.0

R2:0.083272  MSE:0.003652  KL:0.745198  MAE:0.010186  RMSE:0.060434  CVRMSE:15.108449  R2:0.090609  MSE:0.003623  KL:0.634632  MAE:0.009695  RMSE:0.060191  CVRMSE:15.047866  R2:0.073846  MSE:0.003690  KL:1.642647  MAE:0.010275  RMSE:0.060744  CVRMSE:15.185924  R2:0.071273  MSE:0.003700  KL:1.050172  MAE:0.010049  RMSE:0.060828  CVRMSE:15.207009  R2:0.087498  MSE:0.003635  KL:1.357899  MAE:0.010134  RMSE:0.060294  CVRMSE:15.073589  train_loss:
0.0036354096024297177
R2:0.076738  MSE:0.002761  KL:2.497467  MAE:0.008055  RMSE:0.052550  CVRMSE:17.516589  test_loss:
1.3807389797875657e-05
R2:0.089658  MSE:0.003627  KL:2.293119  MAE:0.009741  RMSE:0.060223  CVRMSE:15.055734  R2:0.077506  MSE:0.003675  KL:0.898941  MAE:0.010042  RMSE:0.060624  CVRMSE:15.155889  R2:0.089202  MSE:0.003629  KL:1.325555  MAE:0.009934  RMSE:0.060238  CVRMSE:15.059507  R2:0.084824  MSE:0.003646  KL:1.959454  MAE:0.010048  RMSE:0.060383  CVRMSE:15.095654  R2:0.069160  MSE:0.003708  KL:1.652823  MAE:0.010247  RMSE:0.0

R2:0.067571  MSE:0.003715  KL:1.037940  MAE:0.010023  RMSE:0.060949  CVRMSE:15.237282  R2:0.074773  MSE:0.003686  KL:0.757290  MAE:0.009915  RMSE:0.060713  CVRMSE:15.178328  R2:0.065049  MSE:0.003725  KL:0.566725  MAE:0.009669  RMSE:0.061032  CVRMSE:15.257876  R2:0.078052  MSE:0.003673  KL:1.190416  MAE:0.009459  RMSE:0.060606  CVRMSE:15.151408  R2:0.088701  MSE:0.003631  KL:2.779590  MAE:0.009500  RMSE:0.060255  CVRMSE:15.063649  train_loss:
0.003630615698057227
R2:0.075991  MSE:0.002764  KL:2.981424  MAE:0.008001  RMSE:0.052571  CVRMSE:17.523667  test_loss:
1.3818549021380022e-05
R2:0.067390  MSE:0.003716  KL:1.091980  MAE:0.009751  RMSE:0.060955  CVRMSE:15.238761  R2:0.089537  MSE:0.003627  KL:0.524364  MAE:0.009669  RMSE:0.060227  CVRMSE:15.056735  R2:0.065764  MSE:0.003722  KL:0.884734  MAE:0.010143  RMSE:0.061008  CVRMSE:15.252039  R2:0.078845  MSE:0.003670  KL:0.892988  MAE:0.009975  RMSE:0.060580  CVRMSE:15.144890  R2:0.074785  MSE:0.003686  KL:0.886892  MAE:0.009877  RMSE:0.06

R2:0.070031  MSE:0.003705  KL:1.163930  MAE:0.009866  RMSE:0.060869  CVRMSE:15.217170  R2:0.093687  MSE:0.003611  KL:1.178048  MAE:0.009502  RMSE:0.060090  CVRMSE:15.022378  R2:0.063904  MSE:0.003729  KL:1.147671  MAE:0.009733  RMSE:0.061069  CVRMSE:15.267220  R2:0.092199  MSE:0.003617  KL:0.744735  MAE:0.009739  RMSE:0.060139  CVRMSE:15.034710  R2:0.060532  MSE:0.003743  KL:1.047695  MAE:0.010054  RMSE:0.061179  CVRMSE:15.294689  train_loss:
0.003742840196355246
R2:0.076903  MSE:0.002761  KL:2.743711  MAE:0.007976  RMSE:0.052545  CVRMSE:17.515016  test_loss:
1.3804909933242016e-05
R2:0.071831  MSE:0.003698  KL:0.846555  MAE:0.009802  RMSE:0.060810  CVRMSE:15.202433  R2:0.091431  MSE:0.003620  KL:0.681108  MAE:0.009980  RMSE:0.060164  CVRMSE:15.041063  R2:0.076464  MSE:0.003679  KL:0.657520  MAE:0.009659  RMSE:0.060658  CVRMSE:15.164449  R2:0.074761  MSE:0.003686  KL:1.171913  MAE:0.009964  RMSE:0.060714  CVRMSE:15.178424  R2:0.093386  MSE:0.003612  KL:0.659254  MAE:0.009832  RMSE:0.06

R2:0.082967  MSE:0.003653  KL:0.665124  MAE:0.009521  RMSE:0.060444  CVRMSE:15.110963  R2:0.076492  MSE:0.003679  KL:1.529619  MAE:0.009723  RMSE:0.060657  CVRMSE:15.164221  R2:0.074174  MSE:0.003688  KL:1.691725  MAE:0.009622  RMSE:0.060733  CVRMSE:15.183237  R2:0.066253  MSE:0.003720  KL:0.564355  MAE:0.009814  RMSE:0.060992  CVRMSE:15.248052  R2:0.076863  MSE:0.003678  KL:0.861786  MAE:0.009661  RMSE:0.060645  CVRMSE:15.161171  train_loss:
0.003677778132259846
R2:0.076147  MSE:0.002763  KL:1.907656  MAE:0.008057  RMSE:0.052567  CVRMSE:17.522194  test_loss:
1.3816229824442417e-05
R2:0.083809  MSE:0.003650  KL:0.627155  MAE:0.009696  RMSE:0.060416  CVRMSE:15.104021  R2:0.077081  MSE:0.003677  KL:0.978137  MAE:0.009605  RMSE:0.060638  CVRMSE:15.159380  R2:0.061651  MSE:0.003738  KL:1.363839  MAE:0.009778  RMSE:0.061142  CVRMSE:15.285578  R2:0.074078  MSE:0.003689  KL:0.889738  MAE:0.009732  RMSE:0.060736  CVRMSE:15.184024  R2:0.062914  MSE:0.003733  KL:0.887685  MAE:0.009568  RMSE:0.06

R2:0.078933  MSE:0.003670  KL:2.903198  MAE:0.009631  RMSE:0.060577  CVRMSE:15.144160  R2:0.073092  MSE:0.003693  KL:1.138508  MAE:0.009817  RMSE:0.060768  CVRMSE:15.192107  R2:0.076149  MSE:0.003681  KL:1.075215  MAE:0.009716  RMSE:0.060668  CVRMSE:15.167036  R2:0.068776  MSE:0.003710  KL:1.648552  MAE:0.009587  RMSE:0.060910  CVRMSE:15.227432  R2:0.062589  MSE:0.003735  KL:1.086396  MAE:0.009568  RMSE:0.061112  CVRMSE:15.277935  train_loss:
0.003734644426731393
R2:0.076065  MSE:0.002763  KL:2.174523  MAE:0.007767  RMSE:0.052569  CVRMSE:17.522964  test_loss:
1.3817441678838804e-05
R2:0.081567  MSE:0.003659  KL:0.650194  MAE:0.009393  RMSE:0.060490  CVRMSE:15.122490  R2:0.090003  MSE:0.003625  KL:0.581547  MAE:0.009487  RMSE:0.060212  CVRMSE:15.052880  R2:0.079876  MSE:0.003666  KL:1.648548  MAE:0.009544  RMSE:0.060546  CVRMSE:15.136405  R2:0.075038  MSE:0.003685  KL:1.173027  MAE:0.009637  RMSE:0.060705  CVRMSE:15.176147  R2:0.076619  MSE:0.003679  KL:0.496788  MAE:0.009803  RMSE:0.06

R2:0.082716  MSE:0.003654  KL:0.770019  MAE:0.010228  RMSE:0.060452  CVRMSE:15.113031  R2:0.089504  MSE:0.003627  KL:0.697995  MAE:0.009877  RMSE:0.060228  CVRMSE:15.057009  R2:0.083191  MSE:0.003653  KL:1.783602  MAE:0.010071  RMSE:0.060436  CVRMSE:15.109114  R2:0.085616  MSE:0.003643  KL:1.972415  MAE:0.009800  RMSE:0.060356  CVRMSE:15.089120  R2:0.073761  MSE:0.003690  KL:1.420272  MAE:0.009907  RMSE:0.060747  CVRMSE:15.186626  train_loss:
0.0036901377563481217
R2:0.076399  MSE:0.002762  KL:2.072607  MAE:0.008358  RMSE:0.052559  CVRMSE:17.519799  test_loss:
1.3812450240948238e-05
R2:0.096913  MSE:0.003598  KL:0.607697  MAE:0.009960  RMSE:0.059982  CVRMSE:14.995619  R2:0.073299  MSE:0.003692  KL:0.792506  MAE:0.009893  RMSE:0.060762  CVRMSE:15.190408  R2:0.079221  MSE:0.003668  KL:1.112987  MAE:0.009983  RMSE:0.060567  CVRMSE:15.141799  R2:0.089789  MSE:0.003626  KL:0.582091  MAE:0.009738  RMSE:0.060219  CVRMSE:15.054652  R2:0.090657  MSE:0.003623  KL:0.694912  MAE:0.009826  RMSE:0.0

R2:0.080025  MSE:0.003665  KL:0.659173  MAE:0.009711  RMSE:0.060541  CVRMSE:15.135181  R2:0.077722  MSE:0.003674  KL:1.344947  MAE:0.009985  RMSE:0.060616  CVRMSE:15.154112  R2:0.081024  MSE:0.003661  KL:0.600791  MAE:0.009759  RMSE:0.060508  CVRMSE:15.126964  R2:0.082855  MSE:0.003654  KL:1.140093  MAE:0.009972  RMSE:0.060448  CVRMSE:15.111889  R2:0.090630  MSE:0.003623  KL:0.671984  MAE:0.009923  RMSE:0.060191  CVRMSE:15.047694  train_loss:
0.0036229295379598624
R2:0.076631  MSE:0.002762  KL:2.664018  MAE:0.008238  RMSE:0.052553  CVRMSE:17.517595  test_loss:
1.3808976567815989e-05
R2:0.071602  MSE:0.003699  KL:0.694319  MAE:0.010112  RMSE:0.060817  CVRMSE:15.204308  R2:0.064345  MSE:0.003728  KL:0.788374  MAE:0.010094  RMSE:0.061054  CVRMSE:15.263620  R2:0.065742  MSE:0.003722  KL:2.053634  MAE:0.009920  RMSE:0.061009  CVRMSE:15.252219  R2:0.086315  MSE:0.003640  KL:0.896270  MAE:0.009580  RMSE:0.060333  CVRMSE:15.083350  R2:0.079647  MSE:0.003667  KL:0.782526  MAE:0.009803  RMSE:0.0

R2:0.080049  MSE:0.003665  KL:1.002675  MAE:0.009799  RMSE:0.060540  CVRMSE:15.134989  R2:0.059105  MSE:0.003749  KL:1.462302  MAE:0.009888  RMSE:0.061225  CVRMSE:15.306301  R2:0.068724  MSE:0.003710  KL:0.522185  MAE:0.009588  RMSE:0.060911  CVRMSE:15.227864  R2:0.083695  MSE:0.003651  KL:3.176628  MAE:0.009452  RMSE:0.060420  CVRMSE:15.104964  R2:0.074031  MSE:0.003689  KL:0.839740  MAE:0.009310  RMSE:0.060738  CVRMSE:15.184410  train_loss:
0.0036890609786496497
R2:0.074783  MSE:0.002767  KL:2.247864  MAE:0.008076  RMSE:0.052605  CVRMSE:17.535118  test_loss:
1.3836616271873936e-05
R2:0.083262  MSE:0.003652  KL:0.753076  MAE:0.009567  RMSE:0.060434  CVRMSE:15.108535  R2:0.079016  MSE:0.003669  KL:1.876314  MAE:0.009510  RMSE:0.060574  CVRMSE:15.143481  R2:0.081982  MSE:0.003657  KL:1.045006  MAE:0.009666  RMSE:0.060476  CVRMSE:15.119078  R2:0.069413  MSE:0.003707  KL:0.696048  MAE:0.009806  RMSE:0.060889  CVRMSE:15.222223  R2:0.073754  MSE:0.003690  KL:2.148534  MAE:0.009670  RMSE:0.0

R2:0.061842  MSE:0.003738  KL:1.140172  MAE:0.009714  RMSE:0.061136  CVRMSE:15.284026  R2:0.082887  MSE:0.003654  KL:1.205269  MAE:0.009366  RMSE:0.060446  CVRMSE:15.111620  R2:0.088973  MSE:0.003630  KL:1.089434  MAE:0.009587  RMSE:0.060246  CVRMSE:15.061399  R2:0.092718  MSE:0.003615  KL:0.646835  MAE:0.009577  RMSE:0.060122  CVRMSE:15.030410  R2:0.075737  MSE:0.003682  KL:2.468070  MAE:0.009962  RMSE:0.060682  CVRMSE:15.170416  train_loss:
0.0036822639929596336
R2:0.075611  MSE:0.002765  KL:2.298773  MAE:0.008540  RMSE:0.052582  CVRMSE:17.527276  test_loss:
1.3824242807459086e-05
R2:0.084431  MSE:0.003648  KL:2.786178  MAE:0.009974  RMSE:0.060396  CVRMSE:15.098900  R2:0.072960  MSE:0.003693  KL:1.371979  MAE:0.009794  RMSE:0.060773  CVRMSE:15.193189  R2:0.059737  MSE:0.003746  KL:1.325208  MAE:0.009739  RMSE:0.061205  CVRMSE:15.301163  R2:0.070514  MSE:0.003703  KL:1.501702  MAE:0.009453  RMSE:0.060853  CVRMSE:15.213219  R2:0.085772  MSE:0.003642  KL:1.114420  MAE:0.009228  RMSE:0.0

R2:0.074979  MSE:0.003685  KL:0.761835  MAE:0.009944  RMSE:0.060707  CVRMSE:15.176635  R2:0.083035  MSE:0.003653  KL:0.560898  MAE:0.009698  RMSE:0.060442  CVRMSE:15.110405  R2:0.077057  MSE:0.003677  KL:0.683354  MAE:0.009965  RMSE:0.060638  CVRMSE:15.159576  R2:0.066618  MSE:0.003719  KL:1.903902  MAE:0.010155  RMSE:0.060980  CVRMSE:15.245068  R2:0.099489  MSE:0.003588  KL:1.050065  MAE:0.009662  RMSE:0.059897  CVRMSE:14.974222  train_loss:
0.003587637083546724
R2:0.074671  MSE:0.002768  KL:2.315850  MAE:0.008775  RMSE:0.052609  CVRMSE:17.536179  test_loss:
1.3838291575666516e-05
R2:0.073695  MSE:0.003690  KL:0.777033  MAE:0.010214  RMSE:0.060749  CVRMSE:15.187162  R2:0.095672  MSE:0.003603  KL:0.956561  MAE:0.009802  RMSE:0.060024  CVRMSE:15.005923  R2:0.076078  MSE:0.003681  KL:0.955021  MAE:0.010288  RMSE:0.060670  CVRMSE:15.167616  R2:0.086296  MSE:0.003640  KL:1.192915  MAE:0.009973  RMSE:0.060334  CVRMSE:15.083507  R2:0.094603  MSE:0.003607  KL:0.922219  MAE:0.010150  RMSE:0.06

R2:0.078761  MSE:0.003670  KL:0.873030  MAE:0.009741  RMSE:0.060582  CVRMSE:15.145578  R2:0.076086  MSE:0.003681  KL:0.620403  MAE:0.009819  RMSE:0.060670  CVRMSE:15.167549  R2:0.069043  MSE:0.003709  KL:0.682612  MAE:0.009890  RMSE:0.060901  CVRMSE:15.225248  R2:0.084445  MSE:0.003648  KL:1.142852  MAE:0.009710  RMSE:0.060395  CVRMSE:15.098784  R2:0.085447  MSE:0.003644  KL:0.814258  MAE:0.009672  RMSE:0.060362  CVRMSE:15.090519  train_loss:
0.003643580300558824
R2:0.075644  MSE:0.002765  KL:2.847683  MAE:0.008106  RMSE:0.052581  CVRMSE:17.526960  test_loss:
1.3823744913679547e-05
R2:0.088439  MSE:0.003632  KL:1.785529  MAE:0.009810  RMSE:0.060263  CVRMSE:15.065815  R2:0.079159  MSE:0.003669  KL:1.017895  MAE:0.009817  RMSE:0.060569  CVRMSE:15.142308  R2:0.057632  MSE:0.003754  KL:0.654660  MAE:0.010094  RMSE:0.061273  CVRMSE:15.318277  R2:0.085834  MSE:0.003642  KL:0.844250  MAE:0.009828  RMSE:0.060349  CVRMSE:15.087320  R2:0.086695  MSE:0.003639  KL:1.672835  MAE:0.009901  RMSE:0.06

R2:0.070704  MSE:0.003702  KL:0.912556  MAE:0.009632  RMSE:0.060847  CVRMSE:15.211666  R2:0.079176  MSE:0.003669  KL:1.203029  MAE:0.009391  RMSE:0.060569  CVRMSE:15.142162  R2:0.079002  MSE:0.003669  KL:0.484180  MAE:0.009451  RMSE:0.060574  CVRMSE:15.143596  R2:0.080152  MSE:0.003665  KL:0.751964  MAE:0.009693  RMSE:0.060537  CVRMSE:15.134142  R2:0.081421  MSE:0.003660  KL:1.046709  MAE:0.009568  RMSE:0.060495  CVRMSE:15.123699  train_loss:
0.0036596204896341077
R2:0.075098  MSE:0.002766  KL:2.684044  MAE:0.008560  RMSE:0.052596  CVRMSE:17.532136  test_loss:
1.3831910255248658e-05
R2:0.086227  MSE:0.003640  KL:0.771580  MAE:0.010210  RMSE:0.060336  CVRMSE:15.084084  R2:0.073691  MSE:0.003690  KL:0.740266  MAE:0.009877  RMSE:0.060749  CVRMSE:15.187195  R2:0.088736  MSE:0.003630  KL:1.095699  MAE:0.010118  RMSE:0.060253  CVRMSE:15.063361  R2:0.070949  MSE:0.003701  KL:2.012112  MAE:0.009702  RMSE:0.060839  CVRMSE:15.209661  R2:0.073229  MSE:0.003692  KL:0.824736  MAE:0.009729  RMSE:0.0

R2:0.088873  MSE:0.003630  KL:1.107552  MAE:0.009534  RMSE:0.060249  CVRMSE:15.062222  R2:0.077192  MSE:0.003676  KL:1.638042  MAE:0.009855  RMSE:0.060634  CVRMSE:15.158467  R2:0.075959  MSE:0.003681  KL:0.557669  MAE:0.009772  RMSE:0.060674  CVRMSE:15.168595  R2:0.072636  MSE:0.003695  KL:1.406523  MAE:0.009726  RMSE:0.060783  CVRMSE:15.195847  R2:0.075185  MSE:0.003684  KL:0.968908  MAE:0.009707  RMSE:0.060700  CVRMSE:15.174945  train_loss:
0.003684463200625032
R2:0.076452  MSE:0.002762  KL:2.123804  MAE:0.008122  RMSE:0.052558  CVRMSE:17.519294  test_loss:
1.3811654614983126e-05
R2:0.059216  MSE:0.003748  KL:0.941630  MAE:0.009813  RMSE:0.061222  CVRMSE:15.305398  R2:0.065936  MSE:0.003721  KL:0.670094  MAE:0.009288  RMSE:0.061003  CVRMSE:15.250639  R2:0.058913  MSE:0.003749  KL:2.922605  MAE:0.009297  RMSE:0.061231  CVRMSE:15.307863  R2:0.076699  MSE:0.003678  KL:0.911395  MAE:0.009190  RMSE:0.060650  CVRMSE:15.162522  R2:0.076213  MSE:0.003680  KL:1.644114  MAE:0.009132  RMSE:0.06

R2:0.083182  MSE:0.003653  KL:1.807029  MAE:0.009771  RMSE:0.060437  CVRMSE:15.109191  R2:0.077840  MSE:0.003674  KL:1.888180  MAE:0.010038  RMSE:0.060613  CVRMSE:15.153143  R2:0.084574  MSE:0.003647  KL:0.792415  MAE:0.009857  RMSE:0.060391  CVRMSE:15.097715  R2:0.078948  MSE:0.003669  KL:1.509077  MAE:0.009886  RMSE:0.060576  CVRMSE:15.144041  R2:0.082988  MSE:0.003653  KL:0.662187  MAE:0.009974  RMSE:0.060443  CVRMSE:15.110789  train_loss:
0.003653375504654832
R2:0.076951  MSE:0.002761  KL:2.605620  MAE:0.007993  RMSE:0.052544  CVRMSE:17.514560  test_loss:
1.3804190413793548e-05
R2:0.068371  MSE:0.003712  KL:0.581270  MAE:0.009700  RMSE:0.060923  CVRMSE:15.230750  R2:0.066682  MSE:0.003718  KL:0.589206  MAE:0.009743  RMSE:0.060978  CVRMSE:15.244543  R2:0.094617  MSE:0.003607  KL:1.200696  MAE:0.009666  RMSE:0.060059  CVRMSE:15.014675  R2:0.074782  MSE:0.003686  KL:2.654540  MAE:0.009629  RMSE:0.060713  CVRMSE:15.178251  R2:0.094865  MSE:0.003606  KL:1.873257  MAE:0.009554  RMSE:0.06

R2:0.092775  MSE:0.003614  KL:0.547981  MAE:0.009162  RMSE:0.060120  CVRMSE:15.029935  R2:0.071439  MSE:0.003699  KL:1.351894  MAE:0.009594  RMSE:0.060823  CVRMSE:15.205649  R2:0.079680  MSE:0.003667  KL:1.490936  MAE:0.009556  RMSE:0.060552  CVRMSE:15.138020  R2:0.084971  MSE:0.003645  KL:0.716499  MAE:0.009765  RMSE:0.060378  CVRMSE:15.094447  R2:0.066989  MSE:0.003717  KL:0.656775  MAE:0.009742  RMSE:0.060968  CVRMSE:15.242038  train_loss:
0.003717115122708492
R2:0.077014  MSE:0.002761  KL:2.332803  MAE:0.007935  RMSE:0.052542  CVRMSE:17.513962  test_loss:
1.380324964702595e-05
R2:0.078677  MSE:0.003671  KL:1.944178  MAE:0.009549  RMSE:0.060585  CVRMSE:15.146270  R2:0.075598  MSE:0.003683  KL:0.794287  MAE:0.009492  RMSE:0.060686  CVRMSE:15.171554  R2:0.093655  MSE:0.003611  KL:1.970718  MAE:0.009592  RMSE:0.060091  CVRMSE:15.022650  R2:0.069386  MSE:0.003708  KL:1.234006  MAE:0.009708  RMSE:0.060890  CVRMSE:15.222444  R2:0.072010  MSE:0.003697  KL:0.586908  MAE:0.009769  RMSE:0.060

R2:0.077116  MSE:0.003677  KL:1.313480  MAE:0.009704  RMSE:0.060636  CVRMSE:15.159091  R2:0.076000  MSE:0.003681  KL:2.670909  MAE:0.009638  RMSE:0.060673  CVRMSE:15.168255  R2:0.069227  MSE:0.003708  KL:0.859866  MAE:0.009715  RMSE:0.060895  CVRMSE:15.223744  R2:0.070212  MSE:0.003704  KL:2.651074  MAE:0.009684  RMSE:0.060863  CVRMSE:15.215690  R2:0.073652  MSE:0.003691  KL:1.016216  MAE:0.009405  RMSE:0.060750  CVRMSE:15.187516  train_loss:
0.0036905699846101923
R2:0.076964  MSE:0.002761  KL:2.247942  MAE:0.007642  RMSE:0.052543  CVRMSE:17.514442  test_loss:
1.3804005851852708e-05
R2:0.082217  MSE:0.003656  KL:0.439661  MAE:0.009252  RMSE:0.060469  CVRMSE:15.117145  R2:0.089794  MSE:0.003626  KL:1.186621  MAE:0.009444  RMSE:0.060218  CVRMSE:15.054614  R2:0.080344  MSE:0.003664  KL:1.121370  MAE:0.009506  RMSE:0.060530  CVRMSE:15.132560  R2:0.085954  MSE:0.003642  KL:0.799144  MAE:0.009576  RMSE:0.060345  CVRMSE:15.086335  R2:0.076487  MSE:0.003679  KL:0.923383  MAE:0.009791  RMSE:0.0

In [None]:
# Area under the uplift curve (AUUC) of the model's `y_hat` ranking, with a random ranking
# as the baseline. Duplicated column labels are dropped (first occurrence kept) before scoring.
keep_mask = ~df_test.columns.duplicated()
uplift = df_test.copy().loc[:, keep_mask]

auuc = auuc_score(uplift,
                  outcome_col='y',
                  treatment_col='T',
                  treatment_effect_col='tau')
gat_auuc = pd.DataFrame(auuc[["y_hat", "Random"]], columns=['auuc'])
gat_auuc

Unnamed: 0,auuc
y_hat,0.830929
Random,0.489027


In [None]:
# Report which feature-weighting mode produced these predictions, then their MSE on the test rows.
print('Feature mode on GAT:', feats_mode)
mse_on_test = MSE(uplift['y'], uplift['y_hat'])
print('MSE:', mse_on_test)

MSE: 0.0027703335521529373


In [None]:
# Record this run's AUUC in the shared comparison table, under a row label chosen by feats_mode.
# Modes without an entry here are not recorded.
GAT_RESULT_LABELS = {
    'imp': 'GAT (Struct+Important Weighting)',
    'causal+imp': 'GAT (Struct+Causal+Important Weighting)',
    # closing parenthesis was missing from this label, producing a malformed row name
    'causal*imp': 'GAT (Struct+(Causal*Important) Weighting)',
}
if feats_mode in GAT_RESULT_LABELS:
    result.loc[GAT_RESULT_LABELS[feats_mode], 'AUUC'] = auuc["y_hat"]
result

Unnamed: 0,AUUC
S Learner(LR),0.497983
S Learner(XGB),0.875572
S Learner(LGBM),0.883033
GCN (Struct),0.501865
GCN (Struct+Feature),0.721959
GCN (Struct+Causal Weighting),0.732616
GAT (Struct),0.544286
GAT (Struct+Feature),0.84763
GAT (Struct+Causal Weighting),0.8807
GCN (Struct+Causal+Important Weighting),0.843556


# GAT
Built-in PyTorch Geometric `GATConv` with its default attention mechanism and no Bayesian Network connectivity: each sample graph's edge index contains only self-loops.

In [None]:
# Self-loop-only edge list (2 x NUM_NODES) for the 14 feature nodes of each sample graph.
# PyG node ids are 0-based, so a 14-node graph has ids 0..13; an id of 14 would point outside
# the graph (and into the next graph once samples are batched together).
NUM_NODES = 14
node_ids = torch.arange(NUM_NODES, dtype=torch.long)
edge_index = torch.stack([node_ids, node_ids])

edge_index

tensor([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14]])

In [None]:
# Input the feature mode and get corresponding dimensions.
# Supported modes: 'noweighting', 'equal', 'causal', 'imp', 'causal*imp', 'causal+imp'.
feats_mode = 'causal+imp'

NUM_NODES = 14   # one graph node per input column: col[:-4] + ["visit", "T"]
EMBED_DIM = 10   # width of each prior node embedding lst_dw[j] (DeepWalk/Node2Vec)
BATCH_SIZE = 200

# Node-feature width per mode: node value (1) + node embedding (EMBED_DIM)
# + one NUM_NODES-wide copy of the weighted sample row per weighting scheme.
IN_DIMS = {
    'noweighting': 1 + EMBED_DIM,                 # 11
    'equal': 1 + EMBED_DIM + NUM_NODES,           # 25
    'causal': 1 + EMBED_DIM + NUM_NODES,          # 25
    'imp': 1 + EMBED_DIM + NUM_NODES,             # 25
    'causal*imp': 1 + EMBED_DIM + NUM_NODES,      # 25
    'causal+imp': 1 + EMBED_DIM + 2 * NUM_NODES,  # 39
}
if feats_mode not in IN_DIMS:
    # Fail fast: an unknown mode would otherwise leave `in_dim` undefined (or stale from a prior run).
    raise ValueError("Input Dimension Error: unknown feats_mode %r; expected one of %s"
                     % (feats_mode, sorted(IN_DIMS)))
in_dim = IN_DIMS[feats_mode]

# Subset the data for a quick demo; set N_SAMPLES = None to use the full dataset.
N_SAMPLES = 1000
if N_SAMPLES is not None:
    df_train = df_train[:N_SAMPLES]
    df_test = df_test[:N_SAMPLES]

col = df_train.columns
feature_cols = [c for c in col[:-4]] + ["visit", "T"]


def _weighted_copies(x_row, mode):
    """Return (copies, logged) for one sample row.

    copies: list of NUM_NODES-wide vectors appended to every node's features, in order.
    logged: the vector recorded in `weighted_feats` for this row (None for 'noweighting').
    """
    x = np.array(x_row)
    if mode == 'causal':
        causal_weighted = np.multiply(x, np.array(ate_list))
        return [causal_weighted], causal_weighted
    if mode == 'imp':
        imp_weighted = np.multiply(x, np.array(imp_list))
        return [imp_weighted], imp_weighted
    if mode == 'causal*imp':
        weighted = np.multiply(np.multiply(x, np.array(imp_list)), np.array(ate_list))
        return [weighted], weighted
    if mode == 'causal+imp':
        imp_weighted = np.multiply(x, np.array(imp_list))
        causal_weighted = np.multiply(x, np.array(ate_list))
        return [causal_weighted, imp_weighted], imp_weighted + causal_weighted
    if mode == 'equal':
        return [x], x
    return [], None  # 'noweighting': node value + embedding only


def _build_graphs(df, mode):
    """Turn every row of `df` into one NUM_NODES-node PyG graph with target `y`.

    Node j's features are [x_j, *lst_dw[j], *copy_1, *copy_2, ...] where the copies come
    from `_weighted_copies`. Every row is used (no rows are dropped), in original order.

    Returns:
        (graphs, weighted_feats): list of `Data` objects and the per-row logged vectors.
    """
    x_all = torch.from_numpy(np.array(df[feature_cols])).to(torch.float32)
    y_all = torch.from_numpy(np.array(df['y'])).reshape(df.shape[0], 1).to(torch.float32)
    edges = edge_index.type(torch.long)

    graphs, logged = [], []
    for x_row, y_row in zip(x_all, y_all):
        copies, weighted = _weighted_copies(x_row, mode)
        if weighted is not None:
            logged.append(weighted)
        node_rows = [
            np.concatenate(([float(x_row[j])], list(lst_dw[j]), *[list(c) for c in copies]))
            for j in range(NUM_NODES)
        ]
        t = torch.from_numpy(np.stack(node_rows)).to(torch.float32)
        if tuple(t.shape) != (NUM_NODES, in_dim):
            raise ValueError("node features for one sample have shape %s, expected (%d, %d)"
                             % (tuple(t.shape), NUM_NODES, in_dim))
        graphs.append(Data(x=t, edge_index=edges, y=y_row.reshape(-1, 1).to(torch.float32)))
    return graphs, logged


# NOTE(review): training batches are not shuffled (kept as-is); consider shuffle=True for training.
train_data_list, train_weighted_feats = _build_graphs(df_train, feats_mode)
train_loader = DataLoader(train_data_list, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# Test graphs; `data_list` / `weighted_feats` end up holding the test set, as later cells expect.
data_list, weighted_feats = _build_graphs(df_test, feats_mode)
test_loader = DataLoader(data_list, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

In [None]:
from uuid import RFC_4122
import torch
import math
from torch_geometric.nn import MessagePassing
from torch_geometric.nn import GATConv

from torch_geometric.utils import add_self_loops,degree
from torch_geometric.datasets import Planetoid
import ssl
import torch.nn.functional as F


class Net(torch.nn.Module):
    """Two-layer GAT over each sample's feature nodes, followed by a two-layer linear head.

    gat1: in_channels -> `heads` heads x `hidden` channels (concatenated: hidden*heads, 64 by default)
    gat2: hidden*heads -> out_channels (single head)
    The out_channels-wide outputs of all nodes of a graph are flattened (num_nodes*out_channels,
    140 by default) and mapped to one prediction per graph.

    Args:
        in_channels: node-feature width; defaults to the global `in_dim` set by the data cell.
        num_nodes: nodes per graph (every graph must have exactly this many).
        hidden, heads, out_channels, dropout: GAT layer sizes / attention dropout.
    """

    def __init__(self, in_channels=None, num_nodes=14, hidden=8, heads=8,
                 out_channels=10, dropout=0.6):
        super(Net, self).__init__()
        if in_channels is None:
            in_channels = in_dim
        self.num_nodes = num_nodes
        self.out_channels = out_channels
        self.gat1 = GATConv(in_channels=in_channels, out_channels=hidden, heads=heads, dropout=dropout)
        self.gat2 = GATConv(in_channels=hidden * heads, out_channels=out_channels, heads=1, dropout=dropout)
        self.f1 = torch.nn.Linear(num_nodes * out_channels, 32)
        self.f2 = torch.nn.Linear(32, 1)
        # NOTE(review): no nonlinearity between layers (kept as-is); the GAT paper uses ELU.

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.gat1(x, edge_index)
        x = self.gat2(x, edge_index)
        # One row per graph: relies on PyG batching concatenating each graph's nodes in order
        # and every graph having exactly num_nodes nodes.
        x = x.reshape(-1, self.num_nodes * self.out_channels)
        x = self.f1(x)
        x = self.f2(x)
        return x

# NOTE(review): this disables HTTPS certificate verification for the whole process. Nothing in
# this cell downloads data (Planetoid is imported but unused); consider removing this line.
ssl._create_default_https_context = ssl._create_unverified_context


def train():
    """Run one training epoch over `train_loader` and print regression metrics.

    Uses the globals `model`, `optimizer`, `crit`, `device` and `train_loader`.

    Returns:
        (loss, r2, mse, kl): mean per-sample training loss over the epoch, plus R2 / MSE / KL
        from the notebook's metric helpers on all predictions of the epoch.

    Raises:
        ValueError: if `train_loader` yields no graphs.
    """
    model.train()
    y_actual, y_predicted = [], []
    loss_all = 0.0
    n_samples = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        label = data.y.to(device)
        loss = crit(output, label)
        loss.backward()
        optimizer.step()
        # crit returns a batch mean: weight by batch size so the epoch loss is a per-sample mean
        # over the graphs actually seen (not over len(df_train), which may differ).
        loss_all += data.num_graphs * loss.item()
        n_samples += data.num_graphs
        y_actual += label.cpu().detach().ravel().tolist()
        y_predicted += output.cpu().detach().ravel().tolist()

    if n_samples == 0:
        raise ValueError("train_loader is empty")
    loss = loss_all / n_samples

    # NOTE(review): metrics are called as (predicted, actual); confirm against the helpers'
    # signatures, since R2 is not symmetric in its arguments.
    r2 = R2(y_predicted, y_actual)
    mse = MSE(y_predicted, y_actual)
    kl = kl_divergence(y_predicted, y_actual)

    print("R2:%f" % r2, end='  ')
    print("MSE:%f" % mse, end='  ')
    print("KL:%f" % kl, end='  ')
    print("MAE:%f" % MAE(y_predicted, y_actual), end='  ')
    print("RMSE:%f" % RMSE(y_predicted, y_actual), end='  ')
    print("CVRMSE:%f" % CVRMSE(y_predicted, y_actual), end='  ')

    return loss, r2, mse, kl


def val():
    """Evaluate the model on `test_loader` without gradient tracking and print metrics.

    Uses the globals `model`, `crit`, `device` and `test_loader`. Leaves the model in eval mode.

    Returns:
        (loss, r2, mse, kl): mean per-sample test loss (same scale as train()'s loss), plus
        R2 / MSE / KL on all test predictions.

    Raises:
        ValueError: if `test_loader` yields no graphs.
    """
    model.eval()
    y_actual, y_predicted = [], []
    loss_all = 0.0
    n_samples = 0
    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            output = model(data)
            label = data.y.to(device)
            loss = crit(output, label)
            # crit returns a batch mean: weight by batch size before averaging over samples,
            # exactly as train() does, so train and test losses are comparable.
            loss_all += data.num_graphs * loss.item()
            n_samples += data.num_graphs
            y_actual += label.cpu().ravel().tolist()
            y_predicted += output.cpu().ravel().tolist()

    if n_samples == 0:
        raise ValueError("test_loader is empty")
    loss = loss_all / n_samples

    r2 = R2(y_predicted, y_actual)
    mse = MSE(y_predicted, y_actual)
    kl = kl_divergence(y_predicted, y_actual)

    print("R2:%f" % r2, end='  ')
    print("MSE:%f" % mse, end='  ')
    print("KL:%f" % kl, end='  ')
    print("MAE:%f" % MAE(y_predicted, y_actual), end='  ')
    print("RMSE:%f" % RMSE(y_predicted, y_actual), end='  ')
    print("CVRMSE:%f" % CVRMSE(y_predicted, y_actual), end='  ')

    return loss, r2, mse, kl



num_epochs = 2560
batch_size = 512  # NOTE(review): unused here; the DataLoaders are built with batch_size=200
EVAL_EVERY = 5    # run val() on the test set every EVAL_EVERY epochs
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(2022)  # seed torch like random/numpy at the top, for reproducible init/dropout
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.002, weight_decay=5e-4)
crit = F.mse_loss

for epoch in range(num_epochs):
    loss, r2, mse, kl = train()
    if epoch % EVAL_EVERY == 0:
        print('train_loss:')
        print(loss)

        loss, r2, mse, kl = val()
        print('test_loss:')
        print(loss)

# Final test-set predictions. When the last epoch is not an evaluation epoch
# (2559 % 5 != 0 here) the model is still in train mode with dropout active,
# so switch to eval mode explicitly and disable autograd.
model.eval()
y_predicted = []
with torch.no_grad():
    for data in test_loader:
        data = data.to(device)
        output = model(data)
        y_predicted += output.cpu().ravel().tolist()

if len(y_predicted) != len(df_test):
    raise ValueError("got %d predictions for %d test rows" % (len(y_predicted), len(df_test)))
df_test["y_hat"] = y_predicted

R2:-85.316482  MSE:0.343885  KL:2.721755  MAE:0.426810  RMSE:0.586417  CVRMSE:146.604242  train_loss:
0.3438848674297333
R2:-0.116688  MSE:0.003340  KL:2.672313  MAE:0.016765  RMSE:0.057793  CVRMSE:19.264286  test_loss:
1.67000723304227e-05
R2:-25.708243  MSE:0.106406  KL:2.699752  MAE:0.243195  RMSE:0.326199  CVRMSE:81.549694  R2:-15.054149  MSE:0.063960  KL:2.782090  MAE:0.185219  RMSE:0.252903  CVRMSE:63.225652  R2:-8.684555  MSE:0.038583  KL:2.845542  MAE:0.142577  RMSE:0.196426  CVRMSE:49.106560  R2:-4.573907  MSE:0.022206  KL:2.933564  MAE:0.102578  RMSE:0.149018  CVRMSE:37.254567  R2:-3.834650  MSE:0.019261  KL:3.019548  MAE:0.095120  RMSE:0.138785  CVRMSE:34.696224  train_loss:
0.019261247664690017
R2:-0.096358  MSE:0.003279  KL:4.352970  MAE:0.018290  RMSE:0.057264  CVRMSE:19.088119  test_loss:
1.6396035527577625e-05
R2:-2.103998  MSE:0.012366  KL:3.013953  MAE:0.077458  RMSE:0.111204  CVRMSE:27.800999  R2:-1.334576  MSE:0.009301  KL:3.128081  MAE:0.063581  RMSE:0.096441  CVRM

R2:-0.110161  MSE:0.004423  KL:2.983040  MAE:0.022371  RMSE:0.066505  CVRMSE:16.626183  R2:-0.166780  MSE:0.004648  KL:3.085367  MAE:0.022277  RMSE:0.068180  CVRMSE:17.044887  R2:-0.052469  MSE:0.004193  KL:2.953662  MAE:0.022898  RMSE:0.064754  CVRMSE:16.188418  R2:-0.035050  MSE:0.004124  KL:3.093516  MAE:0.024888  RMSE:0.064216  CVRMSE:16.053893  R2:-0.063143  MSE:0.004236  KL:3.029905  MAE:0.022221  RMSE:0.065081  CVRMSE:16.270299  train_loss:
0.0042355618672445415
R2:0.062870  MSE:0.002803  KL:0.907885  MAE:0.007426  RMSE:0.052943  CVRMSE:17.647646  test_loss:
1.401477454783162e-05
R2:-0.083659  MSE:0.004317  KL:2.977087  MAE:0.021282  RMSE:0.065706  CVRMSE:16.426533  R2:-0.071701  MSE:0.004270  KL:2.834494  MAE:0.020797  RMSE:0.065343  CVRMSE:16.335653  R2:-0.014515  MSE:0.004042  KL:3.276501  MAE:0.022445  RMSE:0.063575  CVRMSE:15.893843  R2:-0.130863  MSE:0.004505  KL:3.072630  MAE:0.025306  RMSE:0.067122  CVRMSE:16.780495  R2:0.053323  MSE:0.003772  KL:2.838896  MAE:0.021799  

R2:-0.039364  MSE:0.004141  KL:2.728677  MAE:0.018998  RMSE:0.064349  CVRMSE:16.087313  R2:-0.079586  MSE:0.004301  KL:3.077791  MAE:0.020627  RMSE:0.065583  CVRMSE:16.395634  R2:0.060919  MSE:0.003741  KL:2.839260  MAE:0.018947  RMSE:0.061166  CVRMSE:15.291539  R2:-0.052275  MSE:0.004192  KL:2.959470  MAE:0.021894  RMSE:0.064748  CVRMSE:16.186926  R2:-0.015086  MSE:0.004044  KL:2.816052  MAE:0.019051  RMSE:0.063593  CVRMSE:15.898318  train_loss:
0.004044104128843174
R2:0.067353  MSE:0.002790  KL:1.067622  MAE:0.009588  RMSE:0.052816  CVRMSE:17.605388  test_loss:
1.394773632637225e-05
R2:-0.063957  MSE:0.004239  KL:2.943142  MAE:0.022124  RMSE:0.065106  CVRMSE:16.276529  R2:-0.067508  MSE:0.004253  KL:2.933718  MAE:0.021986  RMSE:0.065215  CVRMSE:16.303668  R2:-0.107601  MSE:0.004413  KL:3.096445  MAE:0.020844  RMSE:0.066428  CVRMSE:16.607004  R2:-0.075733  MSE:0.004286  KL:2.967427  MAE:0.024347  RMSE:0.065465  CVRMSE:16.366350  R2:-0.130242  MSE:0.004503  KL:2.953884  MAE:0.021570  R

R2:0.007444  MSE:0.003954  KL:2.687121  MAE:0.017704  RMSE:0.062884  CVRMSE:15.720892  R2:-0.017262  MSE:0.004053  KL:2.659385  MAE:0.018943  RMSE:0.063661  CVRMSE:15.915346  R2:0.021258  MSE:0.003899  KL:2.635626  MAE:0.017917  RMSE:0.062444  CVRMSE:15.611108  R2:-0.008295  MSE:0.004017  KL:2.963068  MAE:0.017388  RMSE:0.063380  CVRMSE:15.845047  R2:0.002502  MSE:0.003974  KL:2.858534  MAE:0.017322  RMSE:0.063040  CVRMSE:15.759980  train_loss:
0.003974031552206725
R2:0.064887  MSE:0.002797  KL:3.936264  MAE:0.010085  RMSE:0.052886  CVRMSE:17.628645  test_loss:
1.3984609613544308e-05
R2:-0.002705  MSE:0.003995  KL:2.654734  MAE:0.016967  RMSE:0.063204  CVRMSE:15.801060  R2:0.006568  MSE:0.003958  KL:2.728623  MAE:0.016722  RMSE:0.062911  CVRMSE:15.727831  R2:0.050084  MSE:0.003784  KL:2.728821  MAE:0.017949  RMSE:0.061518  CVRMSE:15.379502  R2:-0.024208  MSE:0.004080  KL:2.646087  MAE:0.016470  RMSE:0.063878  CVRMSE:15.969591  R2:0.048019  MSE:0.003793  KL:2.842224  MAE:0.015776  RMSE:

R2:0.042903  MSE:0.003813  KL:2.346539  MAE:0.014015  RMSE:0.061750  CVRMSE:15.437522  R2:0.063279  MSE:0.003732  KL:2.626891  MAE:0.014058  RMSE:0.061089  CVRMSE:15.272315  R2:0.092267  MSE:0.003616  KL:2.973653  MAE:0.016132  RMSE:0.060137  CVRMSE:15.034144  R2:0.039112  MSE:0.003828  KL:2.726526  MAE:0.016087  RMSE:0.061872  CVRMSE:15.468068  R2:0.078625  MSE:0.003671  KL:2.788879  MAE:0.015023  RMSE:0.060587  CVRMSE:15.146691  train_loss:
0.0036707561172079294
R2:0.071392  MSE:0.002777  KL:0.115838  MAE:0.008770  RMSE:0.052702  CVRMSE:17.567224  test_loss:
1.3887331209843978e-05
R2:0.036200  MSE:0.003840  KL:2.451581  MAE:0.014923  RMSE:0.061966  CVRMSE:15.491486  R2:-0.016997  MSE:0.004052  KL:2.397114  MAE:0.015132  RMSE:0.063653  CVRMSE:15.913273  R2:0.028873  MSE:0.003869  KL:2.390065  MAE:0.015050  RMSE:0.062201  CVRMSE:15.550258  R2:0.063515  MSE:0.003731  KL:2.785472  MAE:0.014534  RMSE:0.061082  CVRMSE:15.270387  R2:0.037103  MSE:0.003836  KL:2.619903  MAE:0.015928  RMSE:0.

R2:0.030137  MSE:0.003864  KL:1.954964  MAE:0.012547  RMSE:0.062161  CVRMSE:15.540134  R2:0.020588  MSE:0.003902  KL:1.954551  MAE:0.012265  RMSE:0.062466  CVRMSE:15.616449  R2:0.073862  MSE:0.003690  KL:1.757951  MAE:0.012969  RMSE:0.060743  CVRMSE:15.185792  R2:0.111331  MSE:0.003540  KL:2.249584  MAE:0.014085  RMSE:0.059502  CVRMSE:14.875440  R2:0.068594  MSE:0.003711  KL:3.103063  MAE:0.015224  RMSE:0.060916  CVRMSE:15.228920  train_loss:
0.003710719832452014
R2:0.075701  MSE:0.002765  KL:0.106575  MAE:0.008292  RMSE:0.052579  CVRMSE:17.526419  test_loss:
1.3822891589370556e-05
R2:0.051848  MSE:0.003777  KL:2.275493  MAE:0.014926  RMSE:0.061461  CVRMSE:15.365219  R2:0.042153  MSE:0.003816  KL:2.238380  MAE:0.015299  RMSE:0.061774  CVRMSE:15.443569  R2:0.022516  MSE:0.003894  KL:2.462049  MAE:0.014333  RMSE:0.062404  CVRMSE:15.601073  R2:0.007416  MSE:0.003954  KL:2.883078  MAE:0.013562  RMSE:0.062884  CVRMSE:15.721110  R2:0.022968  MSE:0.003892  KL:1.994694  MAE:0.011933  RMSE:0.06

R2:0.088244  MSE:0.003632  KL:2.695881  MAE:0.012058  RMSE:0.060270  CVRMSE:15.067419  R2:0.067701  MSE:0.003714  KL:2.204575  MAE:0.012073  RMSE:0.060945  CVRMSE:15.236222  R2:0.061852  MSE:0.003738  KL:2.005825  MAE:0.012984  RMSE:0.061136  CVRMSE:15.283943  R2:0.049001  MSE:0.003789  KL:2.659525  MAE:0.013234  RMSE:0.061553  CVRMSE:15.388265  R2:0.084290  MSE:0.003648  KL:2.108882  MAE:0.011940  RMSE:0.060400  CVRMSE:15.100063  train_loss:
0.0036481907867710107
R2:0.068310  MSE:0.002787  KL:0.085021  MAE:0.010139  RMSE:0.052789  CVRMSE:17.596353  test_loss:
1.3933422349509783e-05
R2:0.054364  MSE:0.003767  KL:2.241646  MAE:0.012350  RMSE:0.061379  CVRMSE:15.344814  R2:0.038804  MSE:0.003829  KL:2.177188  MAE:0.010916  RMSE:0.061882  CVRMSE:15.470543  R2:0.068812  MSE:0.003710  KL:2.327615  MAE:0.011365  RMSE:0.060909  CVRMSE:15.227141  R2:0.026036  MSE:0.003880  KL:1.974281  MAE:0.012357  RMSE:0.062292  CVRMSE:15.572956  R2:0.026279  MSE:0.003879  KL:1.893826  MAE:0.012402  RMSE:0.0

R2:0.055813  MSE:0.003762  KL:2.978070  MAE:0.010962  RMSE:0.061332  CVRMSE:15.333056  R2:0.060286  MSE:0.003744  KL:1.736100  MAE:0.011705  RMSE:0.061187  CVRMSE:15.296693  R2:0.032492  MSE:0.003855  KL:1.928597  MAE:0.011477  RMSE:0.062085  CVRMSE:15.521258  R2:0.064262  MSE:0.003728  KL:2.493428  MAE:0.011338  RMSE:0.061057  CVRMSE:15.264298  R2:0.048375  MSE:0.003791  KL:2.067672  MAE:0.010729  RMSE:0.061573  CVRMSE:15.393333  train_loss:
0.003791274703689851
R2:0.069630  MSE:0.002783  KL:0.117321  MAE:0.008011  RMSE:0.052752  CVRMSE:17.583880  test_loss:
1.39136779907858e-05
R2:0.076690  MSE:0.003678  KL:1.157966  MAE:0.010697  RMSE:0.060650  CVRMSE:15.162590  R2:0.004066  MSE:0.003968  KL:2.516351  MAE:0.010763  RMSE:0.062990  CVRMSE:15.747621  R2:0.072555  MSE:0.003695  KL:1.702976  MAE:0.010477  RMSE:0.060786  CVRMSE:15.196505  R2:0.071428  MSE:0.003699  KL:1.342515  MAE:0.011124  RMSE:0.060823  CVRMSE:15.205737  R2:0.050301  MSE:0.003784  KL:1.640759  MAE:0.010882  RMSE:0.0615

R2:0.054587  MSE:0.003767  KL:1.490843  MAE:0.010564  RMSE:0.061372  CVRMSE:15.343008  R2:0.030555  MSE:0.003862  KL:1.634562  MAE:0.010481  RMSE:0.062147  CVRMSE:15.536785  R2:0.031503  MSE:0.003858  KL:2.219246  MAE:0.010336  RMSE:0.062117  CVRMSE:15.529194  R2:0.058186  MSE:0.003752  KL:1.002634  MAE:0.009779  RMSE:0.061255  CVRMSE:15.313777  R2:0.080418  MSE:0.003664  KL:1.086383  MAE:0.010257  RMSE:0.060528  CVRMSE:15.131950  train_loss:
0.003663614571269136
R2:0.065222  MSE:0.002796  KL:0.093133  MAE:0.007699  RMSE:0.052876  CVRMSE:17.625492  test_loss:
1.3979608229419683e-05
R2:0.070490  MSE:0.003703  KL:3.190263  MAE:0.010210  RMSE:0.060854  CVRMSE:15.213418  R2:0.055459  MSE:0.003763  KL:0.987317  MAE:0.010726  RMSE:0.061344  CVRMSE:15.335932  R2:0.070432  MSE:0.003703  KL:0.860525  MAE:0.009613  RMSE:0.060856  CVRMSE:15.213889  R2:0.067075  MSE:0.003717  KL:1.090760  MAE:0.009413  RMSE:0.060965  CVRMSE:15.241334  R2:0.051618  MSE:0.003778  KL:1.010176  MAE:0.010105  RMSE:0.06

R2:0.091937  MSE:0.003618  KL:1.130230  MAE:0.010206  RMSE:0.060148  CVRMSE:15.036875  R2:0.073090  MSE:0.003693  KL:0.947770  MAE:0.009975  RMSE:0.060769  CVRMSE:15.192125  R2:0.025101  MSE:0.003884  KL:1.018970  MAE:0.010009  RMSE:0.062322  CVRMSE:15.580433  R2:0.033042  MSE:0.003852  KL:1.164924  MAE:0.009888  RMSE:0.062067  CVRMSE:15.516843  R2:-0.006728  MSE:0.004011  KL:0.816332  MAE:0.010164  RMSE:0.063331  CVRMSE:15.832727  train_loss:
0.004010803697747178
R2:0.065465  MSE:0.002795  KL:0.084895  MAE:0.006124  RMSE:0.052870  CVRMSE:17.623202  test_loss:
1.3975976420624648e-05
R2:0.043145  MSE:0.003812  KL:1.292535  MAE:0.010090  RMSE:0.061742  CVRMSE:15.435575  R2:0.039822  MSE:0.003825  KL:1.264857  MAE:0.011207  RMSE:0.061849  CVRMSE:15.462348  R2:0.042380  MSE:0.003815  KL:1.650451  MAE:0.010216  RMSE:0.061767  CVRMSE:15.441745  R2:0.077248  MSE:0.003676  KL:1.746867  MAE:0.010306  RMSE:0.060632  CVRMSE:15.158009  R2:0.074119  MSE:0.003689  KL:1.113549  MAE:0.010802  RMSE:0.0

R2:0.076910  MSE:0.003678  KL:0.927160  MAE:0.010128  RMSE:0.060643  CVRMSE:15.160783  R2:0.071628  MSE:0.003699  KL:0.903368  MAE:0.010192  RMSE:0.060816  CVRMSE:15.204102  R2:0.071554  MSE:0.003699  KL:0.798879  MAE:0.009288  RMSE:0.060819  CVRMSE:15.204706  R2:0.057367  MSE:0.003755  KL:0.711407  MAE:0.010528  RMSE:0.061282  CVRMSE:15.320432  R2:0.080366  MSE:0.003664  KL:1.045678  MAE:0.009395  RMSE:0.060530  CVRMSE:15.132381  train_loss:
0.003663823232636787
R2:0.070128  MSE:0.002781  KL:0.096133  MAE:0.006386  RMSE:0.052738  CVRMSE:17.579179  test_loss:
1.3906239058997016e-05
R2:0.082680  MSE:0.003655  KL:2.652028  MAE:0.009383  RMSE:0.060453  CVRMSE:15.113329  R2:0.029376  MSE:0.003867  KL:0.994849  MAE:0.010621  RMSE:0.062185  CVRMSE:15.546234  R2:0.039247  MSE:0.003828  KL:0.532001  MAE:0.010150  RMSE:0.061868  CVRMSE:15.466980  R2:0.062682  MSE:0.003734  KL:4.395259  MAE:0.009612  RMSE:0.061109  CVRMSE:15.277179  R2:0.053572  MSE:0.003771  KL:0.830901  MAE:0.009986  RMSE:0.06

R2:0.029658  MSE:0.003866  KL:1.035616  MAE:0.009821  RMSE:0.062176  CVRMSE:15.543972  R2:0.069941  MSE:0.003705  KL:0.925191  MAE:0.009339  RMSE:0.060872  CVRMSE:15.217908  R2:0.062031  MSE:0.003737  KL:1.410717  MAE:0.010015  RMSE:0.061130  CVRMSE:15.282481  R2:0.023629  MSE:0.003890  KL:1.402401  MAE:0.009667  RMSE:0.062369  CVRMSE:15.592188  R2:0.045321  MSE:0.003803  KL:1.015610  MAE:0.010197  RMSE:0.061672  CVRMSE:15.418012  train_loss:
0.003803441472700797
R2:0.061706  MSE:0.002806  KL:0.094678  MAE:0.006986  RMSE:0.052976  CVRMSE:17.658608  test_loss:
1.4032189137651584e-05
R2:0.063949  MSE:0.003729  KL:1.520266  MAE:0.009502  RMSE:0.061067  CVRMSE:15.266852  R2:0.069679  MSE:0.003706  KL:0.646205  MAE:0.009567  RMSE:0.060880  CVRMSE:15.220051  R2:0.048352  MSE:0.003791  KL:2.630955  MAE:0.008783  RMSE:0.061574  CVRMSE:15.393518  R2:0.041498  MSE:0.003819  KL:1.190138  MAE:0.009049  RMSE:0.061795  CVRMSE:15.448855  R2:0.057038  MSE:0.003757  KL:3.884857  MAE:0.010369  RMSE:0.06

R2:0.027006  MSE:0.003876  KL:1.949136  MAE:0.008879  RMSE:0.062261  CVRMSE:15.565205  R2:0.095669  MSE:0.003603  KL:1.079720  MAE:0.009295  RMSE:0.060024  CVRMSE:15.005948  R2:0.065305  MSE:0.003724  KL:3.681542  MAE:0.009946  RMSE:0.061023  CVRMSE:15.255787  R2:0.078100  MSE:0.003673  KL:3.827237  MAE:0.009634  RMSE:0.060604  CVRMSE:15.151011  R2:0.052164  MSE:0.003776  KL:2.423293  MAE:0.008770  RMSE:0.061451  CVRMSE:15.362656  train_loss:
0.003776179147826042
R2:0.068719  MSE:0.002785  KL:0.104625  MAE:0.007706  RMSE:0.052777  CVRMSE:17.592493  test_loss:
1.392731252417434e-05
R2:0.076958  MSE:0.003677  KL:1.543249  MAE:0.009019  RMSE:0.060642  CVRMSE:15.160388  R2:0.047735  MSE:0.003794  KL:0.623550  MAE:0.009305  RMSE:0.061594  CVRMSE:15.398504  R2:0.046324  MSE:0.003799  KL:2.591773  MAE:0.009290  RMSE:0.061640  CVRMSE:15.409909  R2:0.086317  MSE:0.003640  KL:0.733340  MAE:0.009170  RMSE:0.060333  CVRMSE:15.083334  R2:0.052754  MSE:0.003774  KL:4.016186  MAE:0.010541  RMSE:0.061

R2:0.034502  MSE:0.003847  KL:1.182963  MAE:0.010571  RMSE:0.062020  CVRMSE:15.505125  R2:0.014681  MSE:0.003926  KL:2.904872  MAE:0.009979  RMSE:0.062654  CVRMSE:15.663475  R2:0.046047  MSE:0.003801  KL:1.018340  MAE:0.009165  RMSE:0.061649  CVRMSE:15.412150  R2:0.068470  MSE:0.003711  KL:0.912199  MAE:0.008696  RMSE:0.060920  CVRMSE:15.229935  R2:0.092836  MSE:0.003614  KL:0.531964  MAE:0.009516  RMSE:0.060118  CVRMSE:15.029429  train_loss:
0.0036141402204521
R2:0.063500  MSE:0.002801  KL:0.093133  MAE:0.007328  RMSE:0.052925  CVRMSE:17.641715  test_loss:
1.4005353863467463e-05
R2:0.065551  MSE:0.003723  KL:1.029081  MAE:0.009444  RMSE:0.061015  CVRMSE:15.253779  R2:0.053231  MSE:0.003772  KL:0.557231  MAE:0.009654  RMSE:0.061416  CVRMSE:15.354005  R2:0.082160  MSE:0.003657  KL:1.490650  MAE:0.008977  RMSE:0.060470  CVRMSE:15.117609  R2:0.043573  MSE:0.003810  KL:2.057144  MAE:0.009603  RMSE:0.061728  CVRMSE:15.432122  R2:0.103343  MSE:0.003572  KL:0.958082  MAE:0.008634  RMSE:0.0597

R2:0.031833  MSE:0.003857  KL:0.775937  MAE:0.010280  RMSE:0.062106  CVRMSE:15.526541  R2:0.034129  MSE:0.003848  KL:0.998659  MAE:0.010270  RMSE:0.062032  CVRMSE:15.508122  R2:0.061404  MSE:0.003739  KL:0.816288  MAE:0.009532  RMSE:0.061150  CVRMSE:15.287593  R2:0.046274  MSE:0.003800  KL:3.565397  MAE:0.010191  RMSE:0.061641  CVRMSE:15.410310  R2:0.056341  MSE:0.003760  KL:1.007782  MAE:0.008997  RMSE:0.061315  CVRMSE:15.328770  train_loss:
0.0037595391404465772
R2:0.062164  MSE:0.002805  KL:0.099853  MAE:0.007934  RMSE:0.052963  CVRMSE:17.654294  test_loss:
1.4025334778125398e-05
R2:0.035983  MSE:0.003841  KL:0.757901  MAE:0.009834  RMSE:0.061973  CVRMSE:15.493232  R2:0.078274  MSE:0.003672  KL:2.073332  MAE:0.009469  RMSE:0.060598  CVRMSE:15.149577  R2:0.075630  MSE:0.003683  KL:4.008907  MAE:0.010068  RMSE:0.060685  CVRMSE:15.171294  R2:0.054310  MSE:0.003768  KL:0.973953  MAE:0.010400  RMSE:0.061381  CVRMSE:15.345254  R2:0.085911  MSE:0.003642  KL:2.648598  MAE:0.008706  RMSE:0.0

R2:0.064274  MSE:0.003728  KL:0.928981  MAE:0.009855  RMSE:0.061057  CVRMSE:15.264196  R2:0.076498  MSE:0.003679  KL:1.186211  MAE:0.009447  RMSE:0.060657  CVRMSE:15.164169  R2:0.055261  MSE:0.003764  KL:1.225326  MAE:0.009672  RMSE:0.061350  CVRMSE:15.337537  R2:0.044639  MSE:0.003806  KL:0.830429  MAE:0.009512  RMSE:0.061694  CVRMSE:15.423516  R2:0.043351  MSE:0.003811  KL:1.223984  MAE:0.009797  RMSE:0.061736  CVRMSE:15.433913  train_loss:
0.003811290602607187
R2:0.066850  MSE:0.002791  KL:0.087214  MAE:0.008373  RMSE:0.052830  CVRMSE:17.610136  test_loss:
1.3955261078081093e-05
R2:0.004810  MSE:0.003965  KL:0.851933  MAE:0.009704  RMSE:0.062967  CVRMSE:15.741740  R2:0.028396  MSE:0.003871  KL:0.861061  MAE:0.009088  RMSE:0.062216  CVRMSE:15.554078  R2:0.040575  MSE:0.003822  KL:0.824824  MAE:0.009130  RMSE:0.061825  CVRMSE:15.456284  R2:0.056150  MSE:0.003760  KL:1.306124  MAE:0.009284  RMSE:0.061321  CVRMSE:15.330318  R2:0.055562  MSE:0.003763  KL:1.188174  MAE:0.009551  RMSE:0.06

R2:0.063010  MSE:0.003733  KL:3.744343  MAE:0.009288  RMSE:0.061098  CVRMSE:15.274503  R2:0.043340  MSE:0.003811  KL:2.161951  MAE:0.009522  RMSE:0.061736  CVRMSE:15.433997  R2:0.078740  MSE:0.003670  KL:1.914874  MAE:0.008836  RMSE:0.060583  CVRMSE:15.145752  R2:0.060974  MSE:0.003741  KL:2.687342  MAE:0.009705  RMSE:0.061164  CVRMSE:15.291089  R2:0.067822  MSE:0.003714  KL:1.696858  MAE:0.009018  RMSE:0.060941  CVRMSE:15.235234  train_loss:
0.0037137980762054212
R2:0.064288  MSE:0.002799  KL:0.104482  MAE:0.009391  RMSE:0.052903  CVRMSE:17.634291  test_loss:
1.3993570246384478e-05
R2:0.076649  MSE:0.003679  KL:2.258666  MAE:0.010250  RMSE:0.060652  CVRMSE:15.162925  R2:0.054545  MSE:0.003767  KL:1.025996  MAE:0.009343  RMSE:0.061373  CVRMSE:15.343347  R2:0.058307  MSE:0.003752  KL:1.329397  MAE:0.009354  RMSE:0.061251  CVRMSE:15.312790  R2:0.119660  MSE:0.003507  KL:3.720418  MAE:0.009833  RMSE:0.059222  CVRMSE:14.805563  R2:0.029627  MSE:0.003866  KL:2.394256  MAE:0.009735  RMSE:0.0

R2:0.032463  MSE:0.003855  KL:3.908559  MAE:0.009875  RMSE:0.062086  CVRMSE:15.521490  R2:0.058164  MSE:0.003752  KL:0.650125  MAE:0.009621  RMSE:0.061256  CVRMSE:15.313952  R2:0.089381  MSE:0.003628  KL:0.340886  MAE:0.008658  RMSE:0.060232  CVRMSE:15.058023  R2:0.052129  MSE:0.003776  KL:3.702110  MAE:0.010203  RMSE:0.061452  CVRMSE:15.362938  R2:0.053097  MSE:0.003772  KL:0.404500  MAE:0.008759  RMSE:0.061420  CVRMSE:15.355089  train_loss:
0.003772459847095888
R2:0.064868  MSE:0.002797  KL:0.091851  MAE:0.006720  RMSE:0.052886  CVRMSE:17.628829  test_loss:
1.3984901597723365e-05
R2:0.071914  MSE:0.003697  KL:3.914706  MAE:0.008777  RMSE:0.060807  CVRMSE:15.201754  R2:0.036509  MSE:0.003839  KL:2.974262  MAE:0.009006  RMSE:0.061956  CVRMSE:15.489008  R2:0.027820  MSE:0.003873  KL:2.158834  MAE:0.008177  RMSE:0.062235  CVRMSE:15.558692  R2:0.050045  MSE:0.003785  KL:2.941955  MAE:0.008523  RMSE:0.061519  CVRMSE:15.379819  R2:0.060532  MSE:0.003743  KL:1.527281  MAE:0.008855  RMSE:0.06

R2:0.044018  MSE:0.003809  KL:1.277638  MAE:0.009886  RMSE:0.061714  CVRMSE:15.428527  R2:0.030763  MSE:0.003861  KL:2.602778  MAE:0.009929  RMSE:0.062140  CVRMSE:15.535121  R2:0.057713  MSE:0.003754  KL:1.525397  MAE:0.009208  RMSE:0.061270  CVRMSE:15.317620  R2:0.032690  MSE:0.003854  KL:1.020880  MAE:0.009947  RMSE:0.062079  CVRMSE:15.519673  R2:0.061812  MSE:0.003738  KL:0.871200  MAE:0.009717  RMSE:0.061137  CVRMSE:15.284270  train_loss:
0.0037377425760496408
R2:0.064809  MSE:0.002797  KL:0.084866  MAE:0.007737  RMSE:0.052888  CVRMSE:17.629382  test_loss:
1.398577988584293e-05
R2:0.035493  MSE:0.003843  KL:0.634977  MAE:0.010519  RMSE:0.061989  CVRMSE:15.497167  R2:0.075232  MSE:0.003684  KL:0.902850  MAE:0.009705  RMSE:0.060698  CVRMSE:15.174555  R2:0.059314  MSE:0.003748  KL:0.793772  MAE:0.009114  RMSE:0.061218  CVRMSE:15.304604  R2:0.079530  MSE:0.003667  KL:0.796577  MAE:0.009405  RMSE:0.060557  CVRMSE:15.139252  R2:0.024736  MSE:0.003885  KL:0.851340  MAE:0.009630  RMSE:0.06

R2:0.057524  MSE:0.003755  KL:1.610821  MAE:0.010838  RMSE:0.061277  CVRMSE:15.319155  R2:0.063494  MSE:0.003731  KL:0.791710  MAE:0.009172  RMSE:0.061082  CVRMSE:15.270557  R2:0.087660  MSE:0.003635  KL:0.878999  MAE:0.008731  RMSE:0.060289  CVRMSE:15.072249  R2:0.056994  MSE:0.003757  KL:3.661216  MAE:0.009849  RMSE:0.061294  CVRMSE:15.323463  R2:0.060007  MSE:0.003745  KL:1.512642  MAE:0.008701  RMSE:0.061196  CVRMSE:15.298965  train_loss:
0.003744932934932876
R2:0.069404  MSE:0.002783  KL:0.107678  MAE:0.007087  RMSE:0.052758  CVRMSE:17.586024  test_loss:
1.39170695329085e-05
R2:0.065333  MSE:0.003724  KL:0.671546  MAE:0.008713  RMSE:0.061022  CVRMSE:15.255557  R2:0.042924  MSE:0.003813  KL:1.900340  MAE:0.008937  RMSE:0.061749  CVRMSE:15.437358  R2:0.058661  MSE:0.003750  KL:1.925753  MAE:0.008191  RMSE:0.061240  CVRMSE:15.309912  R2:0.099368  MSE:0.003588  KL:0.518570  MAE:0.008445  RMSE:0.059901  CVRMSE:14.975222  R2:0.076640  MSE:0.003679  KL:2.740803  MAE:0.009241  RMSE:0.0606

R2:0.050021  MSE:0.003785  KL:1.081455  MAE:0.008895  RMSE:0.061520  CVRMSE:15.380013  R2:0.075069  MSE:0.003685  KL:0.542947  MAE:0.009259  RMSE:0.060704  CVRMSE:15.175898  R2:0.051035  MSE:0.003781  KL:0.666831  MAE:0.010037  RMSE:0.061487  CVRMSE:15.371800  R2:0.055696  MSE:0.003762  KL:0.958263  MAE:0.010162  RMSE:0.061336  CVRMSE:15.334006  R2:0.042984  MSE:0.003813  KL:0.815618  MAE:0.009939  RMSE:0.061747  CVRMSE:15.436872  train_loss:
0.0038127526539028624
R2:0.064783  MSE:0.002797  KL:0.107678  MAE:0.008809  RMSE:0.052889  CVRMSE:17.629625  test_loss:
1.398616496589966e-05
R2:0.073120  MSE:0.003693  KL:0.525403  MAE:0.009518  RMSE:0.060767  CVRMSE:15.191873  R2:0.019963  MSE:0.003904  KL:3.459775  MAE:0.009164  RMSE:0.062486  CVRMSE:15.621438  R2:0.089911  MSE:0.003626  KL:2.094700  MAE:0.008145  RMSE:0.060215  CVRMSE:15.053640  R2:0.049005  MSE:0.003789  KL:0.726266  MAE:0.009428  RMSE:0.061553  CVRMSE:15.388231  R2:0.029001  MSE:0.003868  KL:1.440374  MAE:0.008804  RMSE:0.06

R2:0.064263  MSE:0.003728  KL:0.743541  MAE:0.009432  RMSE:0.061057  CVRMSE:15.264293  R2:0.020223  MSE:0.003903  KL:0.597015  MAE:0.010107  RMSE:0.062477  CVRMSE:15.619365  R2:0.059336  MSE:0.003748  KL:0.820033  MAE:0.009304  RMSE:0.061218  CVRMSE:15.304419  R2:0.079954  MSE:0.003665  KL:4.857251  MAE:0.010815  RMSE:0.060543  CVRMSE:15.135766  R2:0.095384  MSE:0.003604  KL:3.512470  MAE:0.009535  RMSE:0.060033  CVRMSE:15.008306  train_loss:
0.0036039880200405606
R2:0.067033  MSE:0.002791  KL:0.104625  MAE:0.005740  RMSE:0.052825  CVRMSE:17.608405  test_loss:
1.3952516957942861e-05
R2:0.057578  MSE:0.003755  KL:2.315626  MAE:0.008379  RMSE:0.061275  CVRMSE:15.318714  R2:0.082977  MSE:0.003653  KL:2.192314  MAE:0.009442  RMSE:0.060444  CVRMSE:15.110877  R2:0.044458  MSE:0.003807  KL:2.205476  MAE:0.009195  RMSE:0.061700  CVRMSE:15.424979  R2:0.091177  MSE:0.003621  KL:1.721926  MAE:0.008733  RMSE:0.060173  CVRMSE:15.043167  R2:0.060529  MSE:0.003743  KL:1.422019  MAE:0.010222  RMSE:0.0

R2:0.064806  MSE:0.003726  KL:0.553049  MAE:0.010108  RMSE:0.061039  CVRMSE:15.259863  R2:0.049017  MSE:0.003789  KL:1.089180  MAE:0.009776  RMSE:0.061553  CVRMSE:15.388137  R2:0.036109  MSE:0.003840  KL:4.343658  MAE:0.010140  RMSE:0.061969  CVRMSE:15.492219  R2:0.064839  MSE:0.003726  KL:0.746884  MAE:0.009398  RMSE:0.061038  CVRMSE:15.259592  R2:0.048808  MSE:0.003790  KL:0.700708  MAE:0.009576  RMSE:0.061559  CVRMSE:15.389832  train_loss:
0.003789549946668558
R2:0.058781  MSE:0.002815  KL:0.093133  MAE:0.009793  RMSE:0.053058  CVRMSE:17.686107  test_loss:
1.4075927712838165e-05
R2:0.034694  MSE:0.003846  KL:1.084582  MAE:0.010242  RMSE:0.062014  CVRMSE:15.503583  R2:0.032270  MSE:0.003855  KL:2.335028  MAE:0.010082  RMSE:0.062092  CVRMSE:15.523037  R2:0.072862  MSE:0.003694  KL:3.706241  MAE:0.010473  RMSE:0.060776  CVRMSE:15.193988  R2:0.066055  MSE:0.003721  KL:3.434580  MAE:0.009376  RMSE:0.060999  CVRMSE:15.249666  R2:0.049153  MSE:0.003788  KL:0.974735  MAE:0.009014  RMSE:0.06

R2:0.097273  MSE:0.003596  KL:0.974448  MAE:0.008674  RMSE:0.059971  CVRMSE:14.992635  R2:0.032286  MSE:0.003855  KL:1.547697  MAE:0.009473  RMSE:0.062092  CVRMSE:15.522911  R2:0.029685  MSE:0.003866  KL:1.219859  MAE:0.008503  RMSE:0.062175  CVRMSE:15.543762  R2:0.042739  MSE:0.003814  KL:0.705162  MAE:0.009424  RMSE:0.061755  CVRMSE:15.438844  R2:0.030807  MSE:0.003861  KL:0.651774  MAE:0.009255  RMSE:0.062139  CVRMSE:15.534766  train_loss:
0.0038612629054114223
R2:0.063931  MSE:0.002800  KL:0.093133  MAE:0.006736  RMSE:0.052913  CVRMSE:17.637655  test_loss:
1.3998909351357724e-05
R2:0.063071  MSE:0.003733  KL:0.806371  MAE:0.009338  RMSE:0.061096  CVRMSE:15.274007  R2:0.044740  MSE:0.003806  KL:0.727914  MAE:0.010043  RMSE:0.061691  CVRMSE:15.422701  R2:0.044772  MSE:0.003806  KL:3.303666  MAE:0.009079  RMSE:0.061690  CVRMSE:15.422441  R2:0.062258  MSE:0.003736  KL:0.695954  MAE:0.008742  RMSE:0.061123  CVRMSE:15.280630  R2:0.034295  MSE:0.003847  KL:3.701260  MAE:0.009211  RMSE:0.0

R2:0.056072  MSE:0.003761  KL:3.337007  MAE:0.009039  RMSE:0.061324  CVRMSE:15.330952  R2:0.071902  MSE:0.003698  KL:1.460373  MAE:0.008926  RMSE:0.060807  CVRMSE:15.201856  R2:0.099521  MSE:0.003588  KL:4.077177  MAE:0.009922  RMSE:0.059896  CVRMSE:14.973954  R2:0.047311  MSE:0.003796  KL:0.465740  MAE:0.009234  RMSE:0.061608  CVRMSE:15.401936  R2:0.048542  MSE:0.003791  KL:0.926855  MAE:0.009023  RMSE:0.061568  CVRMSE:15.391982  train_loss:
0.0037906097451923413
R2:0.068907  MSE:0.002785  KL:0.086596  MAE:0.007576  RMSE:0.052772  CVRMSE:17.590715  test_loss:
1.3924495462561026e-05
R2:0.082504  MSE:0.003655  KL:1.239657  MAE:0.009397  RMSE:0.060459  CVRMSE:15.114777  R2:0.039392  MSE:0.003827  KL:0.816853  MAE:0.009851  RMSE:0.061863  CVRMSE:15.465816  R2:0.038119  MSE:0.003832  KL:0.646123  MAE:0.009365  RMSE:0.061904  CVRMSE:15.476056  R2:0.046046  MSE:0.003801  KL:1.151253  MAE:0.009506  RMSE:0.061649  CVRMSE:15.412155  R2:0.109617  MSE:0.003547  KL:0.884662  MAE:0.009301  RMSE:0.0

R2:0.042011  MSE:0.003817  KL:3.586949  MAE:0.009549  RMSE:0.061779  CVRMSE:15.444715  R2:0.084000  MSE:0.003649  KL:0.387979  MAE:0.008467  RMSE:0.060410  CVRMSE:15.102451  R2:0.070463  MSE:0.003703  KL:3.796819  MAE:0.008143  RMSE:0.060855  CVRMSE:15.213634  R2:0.039458  MSE:0.003827  KL:0.208642  MAE:0.009479  RMSE:0.061861  CVRMSE:15.465285  R2:0.092687  MSE:0.003615  KL:0.905351  MAE:0.008754  RMSE:0.060123  CVRMSE:15.030664  train_loss:
0.0036147334350971503
R2:0.065061  MSE:0.002796  KL:0.111006  MAE:0.008362  RMSE:0.052881  CVRMSE:17.627009  test_loss:
1.3982016564114019e-05
R2:0.108004  MSE:0.003554  KL:2.786633  MAE:0.009221  RMSE:0.059613  CVRMSE:14.903258  R2:0.083208  MSE:0.003653  KL:3.024462  MAE:0.009207  RMSE:0.060436  CVRMSE:15.108981  R2:0.096798  MSE:0.003598  KL:1.728555  MAE:0.008382  RMSE:0.059986  CVRMSE:14.996573  R2:0.066732  MSE:0.003718  KL:1.039256  MAE:0.009527  RMSE:0.060977  CVRMSE:15.244141  R2:0.024247  MSE:0.003887  KL:0.646127  MAE:0.010425  RMSE:0.0

R2:0.058471  MSE:0.003751  KL:3.288569  MAE:0.010437  RMSE:0.061246  CVRMSE:15.311453  R2:0.066325  MSE:0.003720  KL:0.370866  MAE:0.009604  RMSE:0.060990  CVRMSE:15.247457  R2:0.055450  MSE:0.003763  KL:3.660316  MAE:0.009104  RMSE:0.061344  CVRMSE:15.336004  R2:0.107224  MSE:0.003557  KL:0.396615  MAE:0.009554  RMSE:0.059639  CVRMSE:14.909771  R2:0.104484  MSE:0.003568  KL:0.670744  MAE:0.009609  RMSE:0.059731  CVRMSE:14.932631  train_loss:
0.0035677358304383234
R2:0.064925  MSE:0.002797  KL:0.084866  MAE:0.009621  RMSE:0.052885  CVRMSE:17.628292  test_loss:
1.3984049393911846e-05
R2:0.089570  MSE:0.003627  KL:0.985053  MAE:0.010052  RMSE:0.060226  CVRMSE:15.056464  R2:0.099172  MSE:0.003589  KL:0.902794  MAE:0.010087  RMSE:0.059907  CVRMSE:14.976855  R2:0.039580  MSE:0.003826  KL:2.570146  MAE:0.010135  RMSE:0.061857  CVRMSE:15.464302  R2:0.050302  MSE:0.003784  KL:0.842579  MAE:0.010217  RMSE:0.061511  CVRMSE:15.377733  R2:0.044814  MSE:0.003805  KL:1.152912  MAE:0.009233  RMSE:0.0

R2:0.063791  MSE:0.003730  KL:4.293331  MAE:0.008893  RMSE:0.061073  CVRMSE:15.268135  R2:0.057715  MSE:0.003754  KL:0.661197  MAE:0.008579  RMSE:0.061270  CVRMSE:15.317599  R2:0.037009  MSE:0.003837  KL:2.408324  MAE:0.008687  RMSE:0.061940  CVRMSE:15.484984  R2:0.056148  MSE:0.003760  KL:1.146096  MAE:0.008965  RMSE:0.061321  CVRMSE:15.330331  R2:0.053890  MSE:0.003769  KL:0.824632  MAE:0.008294  RMSE:0.061395  CVRMSE:15.348661  train_loss:
0.0037693022997700608
R2:0.064261  MSE:0.002799  KL:0.094555  MAE:0.008005  RMSE:0.052904  CVRMSE:17.634548  test_loss:
1.3993977510835976e-05
R2:0.058762  MSE:0.003750  KL:0.137671  MAE:0.009163  RMSE:0.061236  CVRMSE:15.309094  R2:0.048986  MSE:0.003789  KL:2.622394  MAE:0.008701  RMSE:0.061554  CVRMSE:15.388389  R2:0.055008  MSE:0.003765  KL:3.558079  MAE:0.008815  RMSE:0.061358  CVRMSE:15.339588  R2:0.042273  MSE:0.003816  KL:0.462296  MAE:0.008969  RMSE:0.061770  CVRMSE:15.442605  R2:0.059467  MSE:0.003747  KL:0.526279  MAE:0.008925  RMSE:0.0

R2:0.074990  MSE:0.003685  KL:1.189576  MAE:0.009374  RMSE:0.060706  CVRMSE:15.176547  R2:0.040698  MSE:0.003822  KL:0.727105  MAE:0.009718  RMSE:0.061821  CVRMSE:15.455300  R2:0.045169  MSE:0.003804  KL:0.531220  MAE:0.009264  RMSE:0.061677  CVRMSE:15.419238  R2:0.052713  MSE:0.003774  KL:0.510633  MAE:0.009756  RMSE:0.061433  CVRMSE:15.358206  R2:0.101574  MSE:0.003579  KL:0.542121  MAE:0.009164  RMSE:0.059828  CVRMSE:14.956877  train_loss:
0.0035793308561551385
R2:0.062504  MSE:0.002804  KL:0.097889  MAE:0.009320  RMSE:0.052953  CVRMSE:17.651096  test_loss:
1.4020253816852347e-05
R2:0.061018  MSE:0.003741  KL:0.590272  MAE:0.010193  RMSE:0.061163  CVRMSE:15.290729  R2:0.035768  MSE:0.003841  KL:1.310440  MAE:0.009292  RMSE:0.061980  CVRMSE:15.494958  R2:0.043973  MSE:0.003809  KL:0.523136  MAE:0.009129  RMSE:0.061716  CVRMSE:15.428894  R2:0.061432  MSE:0.003739  KL:3.490346  MAE:0.009536  RMSE:0.061149  CVRMSE:15.287362  R2:0.082763  MSE:0.003654  KL:0.732036  MAE:0.009445  RMSE:0.0

R2:0.050390  MSE:0.003783  KL:0.321444  MAE:0.009212  RMSE:0.061508  CVRMSE:15.377027  R2:0.043483  MSE:0.003811  KL:0.971786  MAE:0.008819  RMSE:0.061731  CVRMSE:15.432848  R2:0.073428  MSE:0.003691  KL:0.432930  MAE:0.009049  RMSE:0.060757  CVRMSE:15.189356  R2:0.034520  MSE:0.003846  KL:0.388746  MAE:0.009647  RMSE:0.062020  CVRMSE:15.504984  R2:0.072667  MSE:0.003694  KL:0.320571  MAE:0.008678  RMSE:0.060782  CVRMSE:15.195590  train_loss:
0.003694495327363256
R2:0.065412  MSE:0.002795  KL:0.090818  MAE:0.006922  RMSE:0.052871  CVRMSE:17.623701  test_loss:
1.3976767339045182e-05
R2:0.088743  MSE:0.003630  KL:4.652061  MAE:0.009056  RMSE:0.060253  CVRMSE:15.063297  R2:0.072198  MSE:0.003696  KL:2.947611  MAE:0.009326  RMSE:0.060798  CVRMSE:15.199435  R2:0.079772  MSE:0.003666  KL:0.245401  MAE:0.008662  RMSE:0.060549  CVRMSE:15.137261  R2:0.062466  MSE:0.003735  KL:4.305034  MAE:0.009475  RMSE:0.061116  CVRMSE:15.278937  R2:0.023162  MSE:0.003892  KL:0.526218  MAE:0.009618  RMSE:0.06

R2:0.067238  MSE:0.003716  KL:1.572669  MAE:0.008208  RMSE:0.060960  CVRMSE:15.240002  R2:0.052163  MSE:0.003776  KL:2.006263  MAE:0.008589  RMSE:0.061451  CVRMSE:15.362666  R2:0.104842  MSE:0.003566  KL:1.030424  MAE:0.008647  RMSE:0.059719  CVRMSE:14.929649  R2:0.112284  MSE:0.003537  KL:0.864460  MAE:0.008404  RMSE:0.059470  CVRMSE:14.867457  R2:0.076084  MSE:0.003681  KL:2.981603  MAE:0.009329  RMSE:0.060670  CVRMSE:15.167566  train_loss:
0.00368088114191778
R2:0.069681  MSE:0.002783  KL:0.107678  MAE:0.006660  RMSE:0.052750  CVRMSE:17.583401  test_loss:
1.391291910840664e-05
R2:0.106696  MSE:0.003559  KL:0.801660  MAE:0.008396  RMSE:0.059657  CVRMSE:14.914179  R2:0.033313  MSE:0.003851  KL:1.174768  MAE:0.009222  RMSE:0.062059  CVRMSE:15.514670  R2:0.057189  MSE:0.003756  KL:1.614929  MAE:0.009235  RMSE:0.061288  CVRMSE:15.321881  R2:0.038853  MSE:0.003829  KL:1.235683  MAE:0.009666  RMSE:0.061881  CVRMSE:15.470153  R2:0.027055  MSE:0.003876  KL:0.582865  MAE:0.009249  RMSE:0.0622

R2:0.038342  MSE:0.003831  KL:2.497895  MAE:0.009056  RMSE:0.061897  CVRMSE:15.474266  R2:0.063754  MSE:0.003730  KL:1.083698  MAE:0.009280  RMSE:0.061074  CVRMSE:15.268437  R2:0.058796  MSE:0.003750  KL:1.383202  MAE:0.008510  RMSE:0.061235  CVRMSE:15.308815  R2:0.094088  MSE:0.003609  KL:1.411249  MAE:0.008822  RMSE:0.060076  CVRMSE:15.019060  R2:0.106125  MSE:0.003561  KL:0.747395  MAE:0.009751  RMSE:0.059676  CVRMSE:14.918939  train_loss:
0.0035611961560789497
R2:0.064982  MSE:0.002797  KL:0.093133  MAE:0.009290  RMSE:0.052883  CVRMSE:17.627757  test_loss:
1.3983201337396167e-05
R2:0.066828  MSE:0.003718  KL:0.645140  MAE:0.010215  RMSE:0.060973  CVRMSE:15.243356  R2:0.125329  MSE:0.003485  KL:0.561853  MAE:0.009151  RMSE:0.059031  CVRMSE:14.757813  R2:0.082319  MSE:0.003656  KL:0.688139  MAE:0.009838  RMSE:0.060465  CVRMSE:15.116302  R2:0.053757  MSE:0.003770  KL:1.584186  MAE:0.009886  RMSE:0.061399  CVRMSE:15.349743  R2:0.101122  MSE:0.003581  KL:1.410532  MAE:0.009188  RMSE:0.0

R2:0.121402  MSE:0.003500  KL:1.379493  MAE:0.008308  RMSE:0.059164  CVRMSE:14.790909  R2:0.107329  MSE:0.003556  KL:0.320819  MAE:0.009181  RMSE:0.059636  CVRMSE:14.908890  R2:-0.011449  MSE:0.004030  KL:0.644449  MAE:0.010095  RMSE:0.063479  CVRMSE:15.869808  R2:0.067867  MSE:0.003714  KL:0.313115  MAE:0.009312  RMSE:0.060939  CVRMSE:15.234865  R2:0.090566  MSE:0.003623  KL:4.443797  MAE:0.010470  RMSE:0.060193  CVRMSE:15.048225  train_loss:
0.003623185437754728
R2:0.065100  MSE:0.002796  KL:0.085546  MAE:0.009078  RMSE:0.052880  CVRMSE:17.626638  test_loss:
1.3981425625388511e-05
R2:0.077449  MSE:0.003675  KL:0.590933  MAE:0.009072  RMSE:0.060625  CVRMSE:15.156354  R2:0.039980  MSE:0.003825  KL:2.575139  MAE:0.008535  RMSE:0.061844  CVRMSE:15.461077  R2:0.114567  MSE:0.003528  KL:0.193281  MAE:0.008291  RMSE:0.059393  CVRMSE:14.848330  R2:0.059333  MSE:0.003748  KL:3.624522  MAE:0.009537  RMSE:0.061218  CVRMSE:15.304446  R2:0.070041  MSE:0.003705  KL:1.557528  MAE:0.008991  RMSE:0.0

R2:0.071042  MSE:0.003701  KL:0.643650  MAE:0.009270  RMSE:0.060836  CVRMSE:15.208896  R2:0.043836  MSE:0.003809  KL:0.287170  MAE:0.008692  RMSE:0.061720  CVRMSE:15.430000  R2:0.045612  MSE:0.003802  KL:0.951391  MAE:0.009057  RMSE:0.061663  CVRMSE:15.415660  R2:0.050990  MSE:0.003781  KL:0.413661  MAE:0.008841  RMSE:0.061489  CVRMSE:15.372167  R2:0.104680  MSE:0.003567  KL:0.461448  MAE:0.009608  RMSE:0.059724  CVRMSE:14.931000  train_loss:
0.0035669565026182682
R2:0.061072  MSE:0.002808  KL:0.107678  MAE:0.010058  RMSE:0.052994  CVRMSE:17.664574  test_loss:
1.4041672882740385e-05
R2:0.088445  MSE:0.003632  KL:0.295389  MAE:0.009955  RMSE:0.060263  CVRMSE:15.065762  R2:0.070647  MSE:0.003703  KL:0.324241  MAE:0.008784  RMSE:0.060849  CVRMSE:15.212130  R2:0.035858  MSE:0.003841  KL:0.422809  MAE:0.009323  RMSE:0.061977  CVRMSE:15.494233  R2:0.025035  MSE:0.003884  KL:2.177576  MAE:0.008466  RMSE:0.062324  CVRMSE:15.580958  R2:0.109694  MSE:0.003547  KL:0.442515  MAE:0.008247  RMSE:0.0

R2:0.066790  MSE:0.003718  KL:5.197562  MAE:0.010228  RMSE:0.060975  CVRMSE:15.243660  R2:0.051152  MSE:0.003780  KL:0.821095  MAE:0.008477  RMSE:0.061483  CVRMSE:15.370855  R2:0.054480  MSE:0.003767  KL:3.440649  MAE:0.008449  RMSE:0.061376  CVRMSE:15.343875  R2:0.064914  MSE:0.003725  KL:0.486185  MAE:0.008968  RMSE:0.061036  CVRMSE:15.258975  R2:0.050743  MSE:0.003782  KL:0.953128  MAE:0.008709  RMSE:0.061497  CVRMSE:15.374163  train_loss:
0.003781838115537539
R2:0.066428  MSE:0.002792  KL:0.097889  MAE:0.006343  RMSE:0.052842  CVRMSE:17.614113  test_loss:
1.3961564174678643e-05
R2:0.045806  MSE:0.003802  KL:0.371011  MAE:0.008991  RMSE:0.061656  CVRMSE:15.414095  R2:0.071018  MSE:0.003701  KL:3.530188  MAE:0.009095  RMSE:0.060836  CVRMSE:15.209090  R2:0.057942  MSE:0.003753  KL:0.323252  MAE:0.009389  RMSE:0.061263  CVRMSE:15.315761  R2:0.063674  MSE:0.003730  KL:1.173493  MAE:0.008879  RMSE:0.061076  CVRMSE:15.269088  R2:0.060966  MSE:0.003741  KL:0.357154  MAE:0.009796  RMSE:0.06

R2:0.092128  MSE:0.003617  KL:0.373534  MAE:0.009382  RMSE:0.060141  CVRMSE:15.035295  R2:0.076413  MSE:0.003680  KL:2.905778  MAE:0.009445  RMSE:0.060659  CVRMSE:15.164863  R2:0.055442  MSE:0.003763  KL:0.677907  MAE:0.009352  RMSE:0.061344  CVRMSE:15.336065  R2:0.058172  MSE:0.003752  KL:0.963593  MAE:0.008686  RMSE:0.061256  CVRMSE:15.313889  R2:0.048891  MSE:0.003789  KL:0.575346  MAE:0.009367  RMSE:0.061557  CVRMSE:15.389153  train_loss:
0.003789216351287905
R2:0.067548  MSE:0.002789  KL:0.090699  MAE:0.007885  RMSE:0.052811  CVRMSE:17.603552  test_loss:
1.3944826627266593e-05
R2:0.045290  MSE:0.003804  KL:0.783272  MAE:0.008843  RMSE:0.061673  CVRMSE:15.418258  R2:0.083115  MSE:0.003653  KL:1.415706  MAE:0.007949  RMSE:0.060439  CVRMSE:15.109743  R2:0.050998  MSE:0.003781  KL:1.692590  MAE:0.008347  RMSE:0.061488  CVRMSE:15.372098  R2:0.029017  MSE:0.003868  KL:2.731953  MAE:0.008626  RMSE:0.062196  CVRMSE:15.549107  R2:0.051978  MSE:0.003777  KL:1.493638  MAE:0.008679  RMSE:0.06

R2:0.064275  MSE:0.003728  KL:3.723473  MAE:0.009685  RMSE:0.061057  CVRMSE:15.264194  R2:0.073252  MSE:0.003692  KL:2.702214  MAE:0.009207  RMSE:0.060763  CVRMSE:15.190792  R2:0.065320  MSE:0.003724  KL:0.752980  MAE:0.009071  RMSE:0.061023  CVRMSE:15.255666  R2:0.101009  MSE:0.003582  KL:1.283122  MAE:0.008819  RMSE:0.059846  CVRMSE:14.961580  R2:0.041530  MSE:0.003819  KL:3.572134  MAE:0.009917  RMSE:0.061794  CVRMSE:15.448594  train_loss:
0.0038185446901479736
R2:0.070825  MSE:0.002779  KL:0.087214  MAE:0.006225  RMSE:0.052718  CVRMSE:17.572590  test_loss:
1.3895817115553654e-05
R2:0.043856  MSE:0.003809  KL:0.628827  MAE:0.009361  RMSE:0.061719  CVRMSE:15.429838  R2:0.114098  MSE:0.003529  KL:2.659083  MAE:0.008843  RMSE:0.059409  CVRMSE:14.852260  R2:0.117797  MSE:0.003515  KL:1.122354  MAE:0.009723  RMSE:0.059285  CVRMSE:14.821221  R2:0.040036  MSE:0.003824  KL:2.206689  MAE:0.008856  RMSE:0.061843  CVRMSE:15.460628  R2:0.047042  MSE:0.003797  KL:2.004125  MAE:0.008642  RMSE:0.0

In [None]:
# Estimate the area under the uplift curve (AUUC) on the test predictions.
# Work on a copy so df_test itself is untouched, and keep only the first
# occurrence of any repeated column label so each column passed to
# auuc_score ('y', 'T', 'tau', and the prediction columns) is unique.
uplift = df_test.copy()
keep_cols = ~uplift.columns.duplicated()
uplift = uplift.loc[:, keep_cols]

auuc = auuc_score(
    uplift,
    outcome_col='y',
    treatment_col='T',
    treatment_effect_col='tau',
)
# Compare the model's score ('y_hat') against the random-targeting baseline.
gat_auuc = pd.DataFrame(auuc.loc[["y_hat", "Random"]], columns=['auuc'])
gat_auuc

Unnamed: 0,auuc
y_hat,0.590529
Random,0.489027


In [None]:
# Report which feature-weighting mode this GAT run used, then the test-set
# MSE between observed outcomes ('y') and model predictions ('y_hat').
print('Feature mode on GAT without BN connectivity:', feats_mode)
test_mse = MSE(uplift['y'], uplift['y_hat'])
print('MSE:', test_mse)

Feature mode on GAT without BN connectivity: causal+imp
MSE: 0.002852495611763611


In [None]:
# Record this run's AUUC ('y_hat' entry of `auuc`) in the shared `result`
# table, under a row label determined by the feature-weighting mode.
# A mapping replaces the former copy-pasted if-chain; an unknown mode now
# fails loudly instead of silently leaving `result` unchanged.
GAT_ROW_LABELS = {
    'causal': 'GAT (Causal Weighting)',
    'imp': 'GAT (Important Weighting)',
    'equal': 'GAT (Feature)',
    'noweighting': 'GAT',
    'causal+imp': 'GAT (Causal+Important Weighting)',
    # Closing parenthesis placed like the other labels
    # (previously 'GAT (Causal*Important) Weighting').
    'causal*imp': 'GAT (Causal*Important Weighting)',
}
if feats_mode not in GAT_ROW_LABELS:
    raise ValueError(
        f"Unknown feats_mode {feats_mode!r}; expected one of {sorted(GAT_ROW_LABELS)}"
    )
result.loc[GAT_ROW_LABELS[feats_mode], 'AUUC'] = auuc["y_hat"]
result

Unnamed: 0,AUUC
S Learner(LR),0.497983
S Learner(XGB),0.875572
S Learner(LGBM),0.883033
GCN (Struct),0.501865
GCN (Struct+Feature),0.721959
GCN (Struct+Causal Weighting),0.732616
GAT (Struct),0.544286
GAT (Struct+Feature),0.84763
GAT (Struct+Causal Weighting),0.8807
GCN (Struct+Causal+Important Weighting),0.777582
