In [1]:
!pip install wandb

Collecting wandb
  Using cached wandb-0.12.5-py2.py3-none-any.whl (1.7 MB)
  Downloading wandb-0.12.4-py2.py3-none-any.whl (1.7 MB)
[K     |████████████████████████████████| 1.7 MB 1.3 MB/s eta 0:00:01
You should consider upgrading via the '/home/user/conda/bin/python3.7 -m pip install --upgrade pip' command.[0m


In [2]:
import pandas as pd
import numpy as np
from tqdm import tqdm

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.autograd import Variable

from sklearn.decomposition import PCA

from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error
from sklearn.metrics import mean_absolute_percentage_error

import wandb 

In [3]:
graph = pd.read_csv("graph_abakan_F_encoded.csv", index_col = 0)
X = pd.read_csv("abakan_full_routes_final_weather_L_NaN_filtered_FIXED.csv", index_col = 0).dropna().drop(["real_dist", "speed", "time", "length", "route_type","start_timestamp"], axis = 1).reset_index(drop=True)

In [4]:
X = X[(X["rebuildCount"] <= 1) & (X["RTA"] > 30) & (X["RTA"] < 3000)]
X = X.drop(["rebuildCount"], axis = 1)

In [6]:
X.reset_index(drop = True, inplace = True)

In [8]:
route_edges = np.array(X["edges"].map(lambda x: x.replace("'", "")))

In [9]:
wind_dir_classes = pd.get_dummies(X.wind_dir, prefix='wind_dir_class')
day_classes = pd.get_dummies(X.day_period, prefix='day_class')
Y = X["RTA"]
X = X.join(wind_dir_classes).join(day_classes).drop(["edges", "wind_dir", "day_period", "RTA"], axis = 1)

In [10]:
dense = X.iloc[:, 0:7].join(X[["temperature", "pressure", "wind_speed", "clouds", "snow"]])
sparse = X.iloc[:, 7::].drop(["temperature", "pressure", "wind_speed", "clouds", "snow"], axis =1)

In [11]:
pca = PCA(n_components=5)
sparse_compressed = pca.fit_transform(sparse)
deep_input = np.array(dense.join(pd.DataFrame(sparse_compressed))).astype("float32")

In [12]:
joined = dense.join(sparse)
features = list(joined.columns)
cross_product_features = []

for i in range(len(features)):
    for j in range(i+1, len(features)):
            cross_product_features.append([features[i], features[j]])

for i in range(len(cross_product_features)):
    joined[cross_product_features[i][0] + "_AND_" + cross_product_features[i][1]] = \
    joined[cross_product_features[i][0]]*joined[cross_product_features[i][1]]
    
wide_input = np.array(joined).astype("float32")

  # This is added back by InteractiveShellApp.init_path()


In [13]:
graph_features = graph.drop(["edge_id", "adjacent"], axis = 1)
embs = np.array(graph_features)

In [14]:
embs_dict = {}
for i in tqdm(range(len(route_edges))):
    arr = route_edges[i].split(",")
    route_embeddings = []
    for j in range(len(arr)):
        ind = graph[graph["edge_id"] == int(arr[j])].index
        if (len(ind) != 0):
            ind = ind[0]
        else:
#                print(str(i) + " " + str(j) + " " + arr[j] + " "+ str(graph[graph["edge_id"] == int(arr[j])].index[0]))
            print("index not found")
        route_embeddings.append(embs[ind])
    embs_dict[i] = route_embeddings

100%|██████████| 82202/82202 [12:06<00:00, 113.10it/s]


In [15]:
lens = set([len(embs_dict[i]) for i in range(len(embs_dict))])
batch_dict = {}
for i in tqdm(range(len(embs_dict))):
    if (len(embs_dict[i])) in batch_dict:
        batch_dict[len(embs_dict[i])].append([embs_dict[i], Y[i], i])
    else:
        batch_dict[len(embs_dict[i])] = []
        batch_dict[len(embs_dict[i])].append([embs_dict[i], Y[i], i])

100%|██████████| 82202/82202 [00:00<00:00, 150135.57it/s]


In [16]:
batch_list = list(batch_dict.values())

In [16]:
print(np.array(wide_input).shape, np.array(deep_input).shape)

(82202, 325) (82202, 17)


In [110]:
import wandb
wandb.init(project='ETA_second_stage', entity='eighonet')
wandb.run.name = "wdr_begin_lstm_full_deep_wide_final_filtered_SAMPLED"#"gs_3_ffd_3_128_test_MAEloss_lr_10^{-4}_10000"
wandb.run.save()



CondaEnvException: Unable to determine environment

Please re-run this command with one of the following options:

* Provide an environment name via --name or -n
* Re-run this command inside an activated conda environment.





True

In [101]:
class WDR(nn.Module):
    def __init__(self, 
                 recurrent_input_size, wide_input_size, deep_input_size, 
                 lstm_hidden_size, lstm_num_layers, 
                 device):
        super(WDR, self).__init__()

        self.device = device
        
        # Recurrent part
        self.num_layers = lstm_num_layers
        self.hidden_size = lstm_hidden_size
        
        self.linear_preprocess_lstm = nn.Linear(recurrent_input_size, 256)
        self.lstm = nn.LSTM(input_size=256, hidden_size=lstm_hidden_size,
                    num_layers=lstm_num_layers, batch_first=True)
        
        # Deep part
        self.linear_1 = nn.Linear(deep_input_size, 256)
        self.linear_2 = nn.Linear(256, 256)
        self.linear_2_1 = nn.Linear(256, 256)
        
        # Wide part
        self.linear_3 = nn.Linear(wide_input_size, 256)
        
        # Final regression
        self.linear_4 = nn.Linear(512 + 256, 256)
        self.linear_5 = nn.Linear(256, 1)
        
        # Test layers
        self.linear_test = nn.Linear(256, 1)
        self.branch_test = nn.Linear(deep_input_size, 256)
        
        
    def forward(self, x, deep_input, wide_input, indices):
        def merge_deep_branch(h_out, indices):
            deep_output = Variable(torch.Tensor(deep_input[indices])).to(self.device)
            deep_output = F.relu(self.linear_1(deep_output))
            deep_output = F.relu(self.linear_2(deep_output))
            deep_output = F.relu(self.linear_2_1(deep_output))
            return torch.cat((h_out, deep_output), 1)
        
        def merge_wide_branch(final_input, indices):
            wide_output = Variable(torch.Tensor(wide_input[indices])).to(self.device)
            wide_output = F.relu(self.linear_3(wide_output))
            return torch.cat((final_input, wide_output), 1)
        
#        print(x.shape)
        preprocessed_lstm_data = F.relu(self.linear_preprocess_lstm(x))
        
        h_0 = Variable(torch.zeros(
            self.num_layers, x.size(0), self.hidden_size)).to(self.device)
        
        c_0 = Variable(torch.zeros(
            self.num_layers, x.size(0), self.hidden_size)).to(self.device)

        ula, (h_out, _) = self.lstm(preprocessed_lstm_data, (h_0, c_0))        
        h_out = h_out.view(-1, self.hidden_size)
        
        final_input = merge_deep_branch(h_out, indices)
        final_input = merge_wide_branch(final_input, indices)
        
        linear_out = F.relu(self.linear_4(final_input))
        linear_out = F.relu(self.linear_5(linear_out))
        return linear_out

In [106]:
num_epochs = 200
learning_rate = 0.000001

device = ""
gpu_ids = []
if torch.cuda.is_available():
    gpu_ids += [gpu_id for gpu_id in range(torch.cuda.device_count())]
    device = torch.device(f'cuda:{gpu_ids[0]}')
    torch.cuda.set_device(device)
else:
    device = torch.device('cpu')
print(device)
device = "cuda"

# Recurrent parameters
recurrent_input_size = np.array(batch_list[0][0][0]).shape[-1]
hidden_size = 256
num_layers = 1

# Deep parameters
deep_input_size = deep_input.shape[1]

# Wide parameters
wide_input_size = wide_input.shape[1]

wdr7 = WDR(recurrent_input_size, wide_input_size, deep_input_size,
          hidden_size, num_layers, 
          device)
wdr7 = wdr7.to(device)

criterion = torch.nn.L1Loss()
optimizer = torch.optim.Adam(wdr7.parameters(), lr=learning_rate)

cuda:0


In [107]:
ratio = 0.8
train_mask = {i for i in range(int(ratio*82200), 120000)} 
test_mask = {i for i in range(int(ratio*82200))}

In [113]:
wdr7.train()
for epoch in range(num_epochs):
    losses = []
    overall_loss, loss_test = 0, 0
    
    for i in tqdm(range(len(batch_list))):
        indices = np.array([batch_list[i][j][2] for j in range(len(batch_list[i]))])
        indices = np.array(list(set(indices) - train_mask))
        if len(indices) == 0:
            continue
        trainX = Variable(torch.Tensor(np.array([batch_list[i][j][0] for j in range(len(batch_list[i])) if batch_list[i][j][2] in indices])), requires_grad=True).to(device)
        trainY = Variable(torch.Tensor(np.array([batch_list[i][j][1] for j in range(len(batch_list[i])) if batch_list[i][j][2] in indices])), requires_grad=True).to(device)
        outputs = wdr7(x = trainX, deep_input = deep_input, wide_input = wide_input, indices = indices)
        optimizer.zero_grad()
#        print(torch.autograd.grad(outputs=outputs[0], inputs=wo, retain_graph=True)[0][0])
        # obtain the loss function
        loss = criterion(outputs, trainY.unsqueeze(-1))
        loss.backward(retain_graph=True)
        optimizer.step()
        losses.append(loss.item()/trainX.shape[0])
        overall_loss += loss.item()
    
    for i in tqdm(range(len(batch_list))):
        indices = np.array([batch_list[i][j][2] for j in range(len(batch_list[i]))])
        indices = np.array(list(set(indices) - test_mask))
        if len(indices) == 0:
            continue
        testX = Variable(torch.Tensor(np.array([batch_list[i][j][0] for j in range(len(batch_list[i])) if batch_list[i][j][2] in indices])), requires_grad=True).to(device)
        testY = Variable(torch.Tensor(np.array([batch_list[i][j][1] for j in range(len(batch_list[i])) if batch_list[i][j][2] in indices])), requires_grad=True).to(device)
        outputs = wdr7(x = testX, deep_input = deep_input, wide_input = wide_input, indices = indices)
        loss = criterion(outputs, testY.unsqueeze(-1))
        loss_test += loss.item()
        
    print(epoch, {"train_MAE": overall_loss/len(batch_list)}, {"test_MAE": loss_test/len(batch_list)})
    wandb.log({"train_MAE":  overall_loss/len(batch_list),
              "test_MAE": loss_test/len(batch_list)})

100%|██████████| 225/225 [00:25<00:00,  8.72it/s]
100%|██████████| 225/225 [00:13<00:00, 17.02it/s]


0 {'train_MAE': 489.5816154649523} {'test_MAE': 273.5765378994412}


100%|██████████| 225/225 [00:25<00:00,  8.93it/s]
100%|██████████| 225/225 [00:13<00:00, 16.42it/s] 


1 {'train_MAE': 328.7494549899631} {'test_MAE': 369.3745949130588}


100%|██████████| 225/225 [00:25<00:00,  8.93it/s]
100%|██████████| 225/225 [00:14<00:00, 15.19it/s] 


2 {'train_MAE': 461.64828269110785} {'test_MAE': 371.2703334384494}


100%|██████████| 225/225 [00:25<00:00,  8.97it/s]
100%|██████████| 225/225 [00:14<00:00, 15.94it/s] 


3 {'train_MAE': 511.4765318637424} {'test_MAE': 236.90262736002603}


100%|██████████| 225/225 [00:25<00:00,  8.71it/s]
100%|██████████| 225/225 [00:14<00:00, 15.26it/s] 


4 {'train_MAE': 340.24128941853843} {'test_MAE': 228.59674002753363}


100%|██████████| 225/225 [00:25<00:00,  8.92it/s]
100%|██████████| 225/225 [00:14<00:00, 15.86it/s]


5 {'train_MAE': 353.98450037638344} {'test_MAE': 263.56873708089194}


100%|██████████| 225/225 [00:25<00:00,  8.97it/s]
100%|██████████| 225/225 [00:13<00:00, 16.86it/s]


6 {'train_MAE': 361.34044445461694} {'test_MAE': 283.07451651679145}


100%|██████████| 225/225 [00:25<00:00,  8.99it/s]
100%|██████████| 225/225 [00:13<00:00, 16.51it/s] 


7 {'train_MAE': 335.2665064154731} {'test_MAE': 299.43805608113604}


100%|██████████| 225/225 [00:24<00:00,  9.12it/s]
100%|██████████| 225/225 [00:14<00:00, 15.83it/s] 


8 {'train_MAE': 340.5997260708279} {'test_MAE': 251.78134318033855}


100%|██████████| 225/225 [00:25<00:00,  8.69it/s]
100%|██████████| 225/225 [00:14<00:00, 15.93it/s] 


9 {'train_MAE': 328.959036187066} {'test_MAE': 324.33760209825306}


100%|██████████| 225/225 [00:25<00:00,  8.83it/s]
100%|██████████| 225/225 [00:14<00:00, 15.94it/s] 


10 {'train_MAE': 353.914564581977} {'test_MAE': 232.81061986287435}


100%|██████████| 225/225 [00:25<00:00,  8.84it/s]
100%|██████████| 225/225 [00:13<00:00, 16.26it/s]


11 {'train_MAE': 294.45272618611654} {'test_MAE': 219.3668361409505}


100%|██████████| 225/225 [00:25<00:00,  8.88it/s]
100%|██████████| 225/225 [00:13<00:00, 16.58it/s] 


12 {'train_MAE': 304.93553007337783} {'test_MAE': 228.75451799180772}


100%|██████████| 225/225 [00:26<00:00,  8.60it/s]
100%|██████████| 225/225 [00:14<00:00, 15.70it/s] 


13 {'train_MAE': 307.51138003879123} {'test_MAE': 297.37114045884874}


100%|██████████| 225/225 [00:24<00:00,  9.14it/s]
100%|██████████| 225/225 [00:13<00:00, 16.33it/s] 


14 {'train_MAE': 304.78695185343423} {'test_MAE': 222.19769912719727}


100%|██████████| 225/225 [00:24<00:00,  9.00it/s]
100%|██████████| 225/225 [00:13<00:00, 16.63it/s] 


15 {'train_MAE': 332.72744330512154} {'test_MAE': 246.77933568318684}


100%|██████████| 225/225 [00:24<00:00,  9.05it/s]
100%|██████████| 225/225 [00:14<00:00, 15.93it/s] 


16 {'train_MAE': 293.94775851779514} {'test_MAE': 302.81363018459746}


100%|██████████| 225/225 [00:24<00:00,  9.05it/s]
100%|██████████| 225/225 [00:13<00:00, 16.89it/s] 


17 {'train_MAE': 318.1692540147569} {'test_MAE': 239.70462290445963}


100%|██████████| 225/225 [00:25<00:00,  8.90it/s]
100%|██████████| 225/225 [00:13<00:00, 16.17it/s] 


18 {'train_MAE': 311.9480675930447} {'test_MAE': 244.06914877997505}


100%|██████████| 225/225 [00:25<00:00,  8.85it/s]
100%|██████████| 225/225 [00:14<00:00, 15.53it/s] 


19 {'train_MAE': 324.4009492323134} {'test_MAE': 225.89145670572915}


100%|██████████| 225/225 [00:25<00:00,  8.90it/s]
100%|██████████| 225/225 [00:13<00:00, 16.16it/s] 


20 {'train_MAE': 328.21813320583766} {'test_MAE': 320.77834262424045}


100%|██████████| 225/225 [00:25<00:00,  8.72it/s]
100%|██████████| 225/225 [00:13<00:00, 16.19it/s] 


21 {'train_MAE': 307.6688671535916} {'test_MAE': 213.2446683078342}


100%|██████████| 225/225 [00:25<00:00,  8.80it/s]
100%|██████████| 225/225 [00:14<00:00, 15.59it/s]


22 {'train_MAE': 323.4502533976237} {'test_MAE': 318.64207729763456}


100%|██████████| 225/225 [00:25<00:00,  8.87it/s]
100%|██████████| 225/225 [00:13<00:00, 16.29it/s] 


23 {'train_MAE': 309.2405639648438} {'test_MAE': 229.42875484890408}


100%|██████████| 225/225 [00:24<00:00,  9.01it/s]
100%|██████████| 225/225 [00:14<00:00, 15.19it/s] 


24 {'train_MAE': 327.16859144422745} {'test_MAE': 322.108306511773}


100%|██████████| 225/225 [00:25<00:00,  8.73it/s]
100%|██████████| 225/225 [00:14<00:00, 15.18it/s] 


25 {'train_MAE': 283.6801249186198} {'test_MAE': 240.96335042317708}


100%|██████████| 225/225 [00:25<00:00,  8.97it/s]
100%|██████████| 225/225 [00:15<00:00, 14.66it/s]


26 {'train_MAE': 278.49058064778643} {'test_MAE': 247.3914333258735}


100%|██████████| 225/225 [00:25<00:00,  8.83it/s]
100%|██████████| 225/225 [00:14<00:00, 15.97it/s] 


27 {'train_MAE': 326.8445289781358} {'test_MAE': 250.45934865315755}


100%|██████████| 225/225 [00:25<00:00,  8.86it/s]
100%|██████████| 225/225 [00:14<00:00, 15.62it/s] 


28 {'train_MAE': 283.78186313205293} {'test_MAE': 239.15640867445205}


100%|██████████| 225/225 [00:24<00:00,  9.11it/s]
100%|██████████| 225/225 [00:14<00:00, 15.73it/s]


29 {'train_MAE': 279.09865834554034} {'test_MAE': 264.404601304796}


100%|██████████| 225/225 [00:26<00:00,  8.63it/s]
100%|██████████| 225/225 [00:14<00:00, 15.60it/s] 


30 {'train_MAE': 294.33859649658206} {'test_MAE': 288.4171528116862}


100%|██████████| 225/225 [00:25<00:00,  8.72it/s]
100%|██████████| 225/225 [00:14<00:00, 15.43it/s] 


31 {'train_MAE': 308.43726145426433} {'test_MAE': 204.12507334391276}


100%|██████████| 225/225 [00:24<00:00,  9.12it/s]
100%|██████████| 225/225 [00:14<00:00, 15.71it/s] 


32 {'train_MAE': 285.44376692030164} {'test_MAE': 203.38324829101563}


100%|██████████| 225/225 [00:25<00:00,  8.87it/s]
100%|██████████| 225/225 [00:13<00:00, 16.41it/s] 


33 {'train_MAE': 318.2769142659505} {'test_MAE': 410.59354668511287}


100%|██████████| 225/225 [00:25<00:00,  8.93it/s]
100%|██████████| 225/225 [00:14<00:00, 15.83it/s] 


34 {'train_MAE': 278.4948094346788} {'test_MAE': 205.56681047227647}


100%|██████████| 225/225 [00:24<00:00,  9.16it/s]
100%|██████████| 225/225 [00:14<00:00, 15.56it/s]


35 {'train_MAE': 301.32797658284505} {'test_MAE': 221.19911031087238}


100%|██████████| 225/225 [00:25<00:00,  8.94it/s]
100%|██████████| 225/225 [00:14<00:00, 15.68it/s] 


36 {'train_MAE': 280.7252377319336} {'test_MAE': 243.88542805989584}


100%|██████████| 225/225 [00:25<00:00,  8.86it/s]
100%|██████████| 225/225 [00:14<00:00, 15.72it/s] 


37 {'train_MAE': 280.8797422281901} {'test_MAE': 233.7019489542643}


100%|██████████| 225/225 [00:25<00:00,  8.90it/s]
100%|██████████| 225/225 [00:13<00:00, 16.78it/s]


38 {'train_MAE': 317.8576113213433} {'test_MAE': 347.1805241902669}


100%|██████████| 225/225 [00:25<00:00,  8.95it/s]
100%|██████████| 225/225 [00:13<00:00, 16.11it/s]


39 {'train_MAE': 284.1077346462674} {'test_MAE': 210.06633541531033}


100%|██████████| 225/225 [00:25<00:00,  8.95it/s]
100%|██████████| 225/225 [00:14<00:00, 15.77it/s] 


40 {'train_MAE': 277.90894100613065} {'test_MAE': 200.48914608425565}


100%|██████████| 225/225 [00:24<00:00,  9.00it/s]
100%|██████████| 225/225 [00:14<00:00, 15.83it/s]


41 {'train_MAE': 284.7489400906033} {'test_MAE': 223.0327173529731}


100%|██████████| 225/225 [00:25<00:00,  8.83it/s]
100%|██████████| 225/225 [00:14<00:00, 15.58it/s] 


42 {'train_MAE': 290.5626956854926} {'test_MAE': 203.52756130642362}


100%|██████████| 225/225 [00:26<00:00,  8.64it/s]
100%|██████████| 225/225 [00:13<00:00, 16.15it/s] 


43 {'train_MAE': 307.18622331407335} {'test_MAE': 252.3339000108507}


100%|██████████| 225/225 [00:25<00:00,  8.66it/s]
100%|██████████| 225/225 [00:13<00:00, 16.28it/s] 


44 {'train_MAE': 285.8931240505642} {'test_MAE': 219.97893720838758}


100%|██████████| 225/225 [00:24<00:00,  9.08it/s]
100%|██████████| 225/225 [00:14<00:00, 16.02it/s]


45 {'train_MAE': 289.10981502956815} {'test_MAE': 207.63427117241753}


100%|██████████| 225/225 [00:25<00:00,  8.77it/s]
100%|██████████| 225/225 [00:13<00:00, 17.27it/s] 


46 {'train_MAE': 271.35394982231986} {'test_MAE': 203.21662055121527}


100%|██████████| 225/225 [00:25<00:00,  8.81it/s]
100%|██████████| 225/225 [00:13<00:00, 16.27it/s] 


47 {'train_MAE': 302.18707875569663} {'test_MAE': 259.4832622612847}


100%|██████████| 225/225 [00:25<00:00,  8.88it/s]
100%|██████████| 225/225 [00:13<00:00, 16.32it/s] 


48 {'train_MAE': 316.08882568359377} {'test_MAE': 340.8963451809353}


100%|██████████| 225/225 [00:24<00:00,  9.04it/s]
100%|██████████| 225/225 [00:14<00:00, 15.95it/s] 


49 {'train_MAE': 300.4181499226888} {'test_MAE': 221.2637011379666}


100%|██████████| 225/225 [00:25<00:00,  8.87it/s]
100%|██████████| 225/225 [00:13<00:00, 16.37it/s] 


50 {'train_MAE': 275.92505520290797} {'test_MAE': 222.49572330050998}


100%|██████████| 225/225 [00:25<00:00,  8.86it/s]
100%|██████████| 225/225 [00:13<00:00, 16.08it/s] 


51 {'train_MAE': 298.7178366088867} {'test_MAE': 214.60465671115452}


100%|██████████| 225/225 [00:25<00:00,  8.93it/s]
100%|██████████| 225/225 [00:13<00:00, 16.63it/s] 


52 {'train_MAE': 278.6909752400716} {'test_MAE': 196.6785625881619}


100%|██████████| 225/225 [00:25<00:00,  8.92it/s]
100%|██████████| 225/225 [00:13<00:00, 16.29it/s] 


53 {'train_MAE': 278.20135586208767} {'test_MAE': 210.77910512288412}


100%|██████████| 225/225 [00:25<00:00,  8.94it/s]
100%|██████████| 225/225 [00:14<00:00, 15.71it/s] 


54 {'train_MAE': 280.3000519137912} {'test_MAE': 197.22637068006728}


100%|██████████| 225/225 [00:24<00:00,  9.11it/s]
100%|██████████| 225/225 [00:14<00:00, 15.84it/s]


55 {'train_MAE': 271.76891462537975} {'test_MAE': 198.77767032199435}


100%|██████████| 225/225 [00:24<00:00,  9.04it/s]
100%|██████████| 225/225 [00:14<00:00, 15.72it/s] 


56 {'train_MAE': 317.22057067871094} {'test_MAE': 231.93814785427517}


100%|██████████| 225/225 [00:24<00:00,  9.03it/s]
100%|██████████| 225/225 [00:14<00:00, 15.88it/s] 


57 {'train_MAE': 278.30695288764105} {'test_MAE': 200.44143886990017}


100%|██████████| 225/225 [00:24<00:00,  9.02it/s]
100%|██████████| 225/225 [00:14<00:00, 15.89it/s] 


58 {'train_MAE': 294.02679609510636} {'test_MAE': 201.39539031982423}


100%|██████████| 225/225 [00:25<00:00,  8.96it/s]
100%|██████████| 225/225 [00:13<00:00, 16.08it/s] 


59 {'train_MAE': 306.518103773329} {'test_MAE': 225.73097042507595}


100%|██████████| 225/225 [00:25<00:00,  8.98it/s]
100%|██████████| 225/225 [00:14<00:00, 15.81it/s] 


60 {'train_MAE': 280.4470996432834} {'test_MAE': 199.00887647840713}


100%|██████████| 225/225 [00:25<00:00,  8.90it/s]
100%|██████████| 225/225 [00:13<00:00, 16.07it/s] 


61 {'train_MAE': 301.5739468383789} {'test_MAE': 208.29813303629558}


100%|██████████| 225/225 [00:24<00:00,  9.08it/s]
100%|██████████| 225/225 [00:13<00:00, 16.27it/s] 


62 {'train_MAE': 311.95487616644965} {'test_MAE': 383.32564710828996}


100%|██████████| 225/225 [00:25<00:00,  8.66it/s]
100%|██████████| 225/225 [00:14<00:00, 15.84it/s] 


63 {'train_MAE': 273.948157687717} {'test_MAE': 257.01770477294923}


100%|██████████| 225/225 [00:24<00:00,  9.04it/s]
100%|██████████| 225/225 [00:13<00:00, 16.17it/s]


64 {'train_MAE': 272.32744062635635} {'test_MAE': 211.94688327365452}


100%|██████████| 225/225 [00:25<00:00,  8.79it/s]
100%|██████████| 225/225 [00:13<00:00, 16.18it/s] 


65 {'train_MAE': 299.8479568481445} {'test_MAE': 292.09742970784504}


100%|██████████| 225/225 [00:25<00:00,  8.86it/s]
100%|██████████| 225/225 [00:14<00:00, 16.06it/s] 


66 {'train_MAE': 295.35138926188154} {'test_MAE': 216.56211703830294}


100%|██████████| 225/225 [00:26<00:00,  8.63it/s]
100%|██████████| 225/225 [00:14<00:00, 15.60it/s] 


67 {'train_MAE': 302.81642452663846} {'test_MAE': 364.262562628852}


100%|██████████| 225/225 [00:25<00:00,  8.84it/s]
100%|██████████| 225/225 [00:14<00:00, 15.67it/s] 


68 {'train_MAE': 294.0737998453776} {'test_MAE': 224.7222479248047}


100%|██████████| 225/225 [00:24<00:00,  9.06it/s]
100%|██████████| 225/225 [00:14<00:00, 15.73it/s] 


69 {'train_MAE': 277.58795369466145} {'test_MAE': 219.25431087917752}


100%|██████████| 225/225 [00:25<00:00,  8.87it/s]
100%|██████████| 225/225 [00:14<00:00, 16.03it/s]


70 {'train_MAE': 314.92255652533635} {'test_MAE': 373.90268669976126}


100%|██████████| 225/225 [00:24<00:00,  9.03it/s]
100%|██████████| 225/225 [00:13<00:00, 16.57it/s] 


71 {'train_MAE': 305.3325735812717} {'test_MAE': 284.4427411227756}


100%|██████████| 225/225 [00:24<00:00,  9.15it/s]
100%|██████████| 225/225 [00:13<00:00, 16.17it/s] 


72 {'train_MAE': 272.28545532226565} {'test_MAE': 205.32458567301433}


100%|██████████| 225/225 [00:25<00:00,  8.85it/s]
100%|██████████| 225/225 [00:14<00:00, 15.55it/s] 


73 {'train_MAE': 267.7525831434462} {'test_MAE': 242.12780398898656}


100%|██████████| 225/225 [00:25<00:00,  8.86it/s]
100%|██████████| 225/225 [00:14<00:00, 15.83it/s] 


74 {'train_MAE': 293.5881457519531} {'test_MAE': 198.48706678602431}


100%|██████████| 225/225 [00:25<00:00,  8.87it/s]
100%|██████████| 225/225 [00:13<00:00, 16.52it/s] 


75 {'train_MAE': 268.5878705851237} {'test_MAE': 217.67253865559897}


100%|██████████| 225/225 [00:24<00:00,  9.18it/s]
100%|██████████| 225/225 [00:14<00:00, 15.71it/s]


76 {'train_MAE': 273.1373442247179} {'test_MAE': 207.74800733778213}


100%|██████████| 225/225 [00:25<00:00,  8.87it/s]
100%|██████████| 225/225 [00:13<00:00, 16.25it/s] 


77 {'train_MAE': 265.30153493245444} {'test_MAE': 198.79931949191624}


100%|██████████| 225/225 [00:25<00:00,  8.83it/s]
100%|██████████| 225/225 [00:14<00:00, 15.39it/s] 


78 {'train_MAE': 312.8660141330295} {'test_MAE': 340.47065107557506}


100%|██████████| 225/225 [00:25<00:00,  8.86it/s]
100%|██████████| 225/225 [00:13<00:00, 16.77it/s] 


79 {'train_MAE': 266.16280554877386} {'test_MAE': 218.24446953667535}


100%|██████████| 225/225 [00:24<00:00,  9.06it/s]
100%|██████████| 225/225 [00:13<00:00, 16.21it/s] 


80 {'train_MAE': 270.5297294447157} {'test_MAE': 198.1953372870551}


100%|██████████| 225/225 [00:24<00:00,  9.08it/s]
100%|██████████| 225/225 [00:14<00:00, 15.56it/s] 


81 {'train_MAE': 274.6236801147461} {'test_MAE': 211.13214881049262}


100%|██████████| 225/225 [00:24<00:00,  9.05it/s]
100%|██████████| 225/225 [00:13<00:00, 16.15it/s] 


82 {'train_MAE': 288.1907590399848} {'test_MAE': 200.37100036621095}


100%|██████████| 225/225 [00:25<00:00,  8.93it/s]
100%|██████████| 225/225 [00:14<00:00, 15.73it/s]


83 {'train_MAE': 275.65390635172525} {'test_MAE': 200.9970127020942}


100%|██████████| 225/225 [00:25<00:00,  8.97it/s]
100%|██████████| 225/225 [00:13<00:00, 16.89it/s] 


84 {'train_MAE': 273.05305962456595} {'test_MAE': 201.54155673556858}


100%|██████████| 225/225 [00:25<00:00,  8.97it/s]
100%|██████████| 225/225 [00:14<00:00, 15.50it/s] 


85 {'train_MAE': 284.6399609035916} {'test_MAE': 269.3059218343099}


100%|██████████| 225/225 [00:25<00:00,  8.80it/s]
100%|██████████| 225/225 [00:13<00:00, 16.48it/s]


86 {'train_MAE': 309.5705756971571} {'test_MAE': 358.7368037584093}


100%|██████████| 225/225 [00:25<00:00,  8.86it/s]
100%|██████████| 225/225 [00:13<00:00, 16.66it/s]


87 {'train_MAE': 270.8382664320204} {'test_MAE': 198.44129977756077}


100%|██████████| 225/225 [00:25<00:00,  8.66it/s]
100%|██████████| 225/225 [00:14<00:00, 15.17it/s] 


88 {'train_MAE': 279.3189409722222} {'test_MAE': 199.3947605726454}


100%|██████████| 225/225 [00:25<00:00,  8.80it/s]
100%|██████████| 225/225 [00:14<00:00, 15.40it/s] 


89 {'train_MAE': 265.3221632215712} {'test_MAE': 208.30850419786242}


100%|██████████| 225/225 [00:25<00:00,  8.76it/s]
100%|██████████| 225/225 [00:13<00:00, 16.17it/s]


90 {'train_MAE': 267.59227261013456} {'test_MAE': 204.21823822021486}


100%|██████████| 225/225 [00:24<00:00,  9.08it/s]
100%|██████████| 225/225 [00:14<00:00, 15.82it/s] 


91 {'train_MAE': 280.38624793158635} {'test_MAE': 201.49207526312935}


100%|██████████| 225/225 [00:25<00:00,  8.90it/s]
100%|██████████| 225/225 [00:14<00:00, 16.06it/s] 


92 {'train_MAE': 278.2813579305013} {'test_MAE': 201.61972635904948}


100%|██████████| 225/225 [00:25<00:00,  8.73it/s]
100%|██████████| 225/225 [00:14<00:00, 16.05it/s] 


93 {'train_MAE': 283.43547471788196} {'test_MAE': 215.14662770589192}


100%|██████████| 225/225 [00:25<00:00,  8.80it/s]
100%|██████████| 225/225 [00:13<00:00, 16.17it/s] 


94 {'train_MAE': 276.39903934054905} {'test_MAE': 212.02626502143013}


100%|██████████| 225/225 [00:25<00:00,  8.90it/s]
100%|██████████| 225/225 [00:14<00:00, 15.83it/s] 


95 {'train_MAE': 290.053404168023} {'test_MAE': 253.7192600165473}


100%|██████████| 225/225 [00:25<00:00,  8.80it/s]
100%|██████████| 225/225 [00:13<00:00, 16.86it/s] 


96 {'train_MAE': 267.898236456977} {'test_MAE': 216.1929232449002}


100%|██████████| 225/225 [00:25<00:00,  8.83it/s]
100%|██████████| 225/225 [00:13<00:00, 16.52it/s]


97 {'train_MAE': 297.99747260199655} {'test_MAE': 221.4991197374132}


100%|██████████| 225/225 [00:24<00:00,  9.07it/s]
100%|██████████| 225/225 [00:13<00:00, 16.28it/s] 


98 {'train_MAE': 271.3532626003689} {'test_MAE': 203.8096442667643}


100%|██████████| 225/225 [00:24<00:00,  9.12it/s]
100%|██████████| 225/225 [00:14<00:00, 15.83it/s] 


99 {'train_MAE': 297.0006776936849} {'test_MAE': 200.14935248480901}


100%|██████████| 225/225 [00:25<00:00,  8.83it/s]
100%|██████████| 225/225 [00:14<00:00, 15.28it/s] 


100 {'train_MAE': 281.1434836154514} {'test_MAE': 197.58320675320095}


100%|██████████| 225/225 [00:25<00:00,  8.97it/s]
100%|██████████| 225/225 [00:14<00:00, 15.95it/s]


101 {'train_MAE': 264.76240898980035} {'test_MAE': 199.4275954522027}


100%|██████████| 225/225 [00:25<00:00,  8.87it/s]
100%|██████████| 225/225 [00:13<00:00, 16.28it/s] 


102 {'train_MAE': 261.57378121270074} {'test_MAE': 199.90829471164278}


100%|██████████| 225/225 [00:25<00:00,  8.83it/s]
100%|██████████| 225/225 [00:14<00:00, 15.79it/s]


103 {'train_MAE': 270.10451866997613} {'test_MAE': 211.57223276774087}


100%|██████████| 225/225 [00:25<00:00,  8.83it/s]
100%|██████████| 225/225 [00:14<00:00, 15.61it/s] 


104 {'train_MAE': 288.8960208808051} {'test_MAE': 206.41956427680122}


100%|██████████| 225/225 [00:24<00:00,  9.04it/s]
100%|██████████| 225/225 [00:13<00:00, 17.01it/s] 


105 {'train_MAE': 257.0113353474935} {'test_MAE': 229.71103068033855}


100%|██████████| 225/225 [00:25<00:00,  8.97it/s]
100%|██████████| 225/225 [00:14<00:00, 15.40it/s]


106 {'train_MAE': 261.8078715684679} {'test_MAE': 198.18929996066623}


100%|██████████| 225/225 [00:25<00:00,  8.81it/s]
100%|██████████| 225/225 [00:14<00:00, 15.79it/s]


107 {'train_MAE': 285.55171946207685} {'test_MAE': 200.40776340060765}


100%|██████████| 225/225 [00:25<00:00,  8.94it/s]
100%|██████████| 225/225 [00:13<00:00, 17.02it/s] 


108 {'train_MAE': 283.07551245795355} {'test_MAE': 200.93251705593534}


100%|██████████| 225/225 [00:25<00:00,  8.90it/s]
100%|██████████| 225/225 [00:14<00:00, 15.95it/s] 


109 {'train_MAE': 275.3559905158149} {'test_MAE': 207.41524149576824}


100%|██████████| 225/225 [00:25<00:00,  8.76it/s]
100%|██████████| 225/225 [00:14<00:00, 15.50it/s]


110 {'train_MAE': 288.44283718532984} {'test_MAE': 210.36558834499783}


100%|██████████| 225/225 [00:25<00:00,  8.97it/s]
100%|██████████| 225/225 [00:12<00:00, 17.51it/s] 


111 {'train_MAE': 273.3517238023546} {'test_MAE': 200.84831420898436}


100%|██████████| 225/225 [00:25<00:00,  8.99it/s]
100%|██████████| 225/225 [00:14<00:00, 15.98it/s] 


112 {'train_MAE': 262.4740932210286} {'test_MAE': 196.5543127102322}


100%|██████████| 225/225 [00:25<00:00,  8.87it/s]
100%|██████████| 225/225 [00:13<00:00, 16.52it/s] 


113 {'train_MAE': 297.08361050075956} {'test_MAE': 228.79043884277343}


100%|██████████| 225/225 [00:25<00:00,  8.90it/s]
100%|██████████| 225/225 [00:14<00:00, 15.59it/s] 


114 {'train_MAE': 269.6635187784831} {'test_MAE': 202.42422878689237}


100%|██████████| 225/225 [00:25<00:00,  8.76it/s]
100%|██████████| 225/225 [00:13<00:00, 16.89it/s] 


115 {'train_MAE': 263.49725321451825} {'test_MAE': 200.72567643907334}


100%|██████████| 225/225 [00:25<00:00,  8.76it/s]
100%|██████████| 225/225 [00:13<00:00, 16.91it/s] 


116 {'train_MAE': 302.38351186116535} {'test_MAE': 223.45718132866753}


100%|██████████| 225/225 [00:25<00:00,  8.94it/s]
100%|██████████| 225/225 [00:14<00:00, 15.94it/s] 


117 {'train_MAE': 297.86043528238935} {'test_MAE': 217.15735802544486}


100%|██████████| 225/225 [00:25<00:00,  8.94it/s]
100%|██████████| 225/225 [00:14<00:00, 15.92it/s]


118 {'train_MAE': 268.01881629096135} {'test_MAE': 199.44996673583984}


100%|██████████| 225/225 [00:25<00:00,  8.86it/s]
100%|██████████| 225/225 [00:14<00:00, 15.51it/s] 


119 {'train_MAE': 259.7237420654297} {'test_MAE': 219.48214982774522}


100%|██████████| 225/225 [00:25<00:00,  8.97it/s]
100%|██████████| 225/225 [00:13<00:00, 16.49it/s] 


120 {'train_MAE': 282.8946702745226} {'test_MAE': 201.30245012071398}


100%|██████████| 225/225 [00:24<00:00,  9.02it/s]
100%|██████████| 225/225 [00:13<00:00, 16.39it/s] 


121 {'train_MAE': 250.00180909898546} {'test_MAE': 219.03086917453342}


100%|██████████| 225/225 [00:25<00:00,  8.69it/s]
100%|██████████| 225/225 [00:14<00:00, 15.58it/s] 


122 {'train_MAE': 257.81803127712675} {'test_MAE': 206.2573614501953}


100%|██████████| 225/225 [00:25<00:00,  8.97it/s]
100%|██████████| 225/225 [00:13<00:00, 16.52it/s] 


123 {'train_MAE': 297.33383880615236} {'test_MAE': 262.9289310370551}


100%|██████████| 225/225 [00:24<00:00,  9.01it/s]
100%|██████████| 225/225 [00:14<00:00, 15.39it/s] 


124 {'train_MAE': 265.57005654229056} {'test_MAE': 198.26342702229817}


100%|██████████| 225/225 [00:26<00:00,  8.63it/s]
100%|██████████| 225/225 [00:13<00:00, 16.50it/s] 


125 {'train_MAE': 288.09320844862197} {'test_MAE': 212.8216801622179}


100%|██████████| 225/225 [00:25<00:00,  8.83it/s]
100%|██████████| 225/225 [00:13<00:00, 16.25it/s] 


126 {'train_MAE': 263.6280859375} {'test_MAE': 205.17627065022788}


100%|██████████| 225/225 [00:26<00:00,  8.62it/s]
100%|██████████| 225/225 [00:14<00:00, 15.95it/s]


127 {'train_MAE': 273.38031124538844} {'test_MAE': 220.86758783976236}


100%|██████████| 225/225 [00:25<00:00,  8.90it/s]
100%|██████████| 225/225 [00:14<00:00, 15.81it/s] 


128 {'train_MAE': 256.2297753567166} {'test_MAE': 211.90069458007812}


100%|██████████| 225/225 [00:25<00:00,  8.90it/s]
100%|██████████| 225/225 [00:14<00:00, 15.19it/s] 


129 {'train_MAE': 294.10393897162544} {'test_MAE': 216.53144409179689}


100%|██████████| 225/225 [00:25<00:00,  8.97it/s]
100%|██████████| 225/225 [00:13<00:00, 16.90it/s] 


130 {'train_MAE': 276.2338230726454} {'test_MAE': 198.23090308295355}


100%|██████████| 225/225 [00:25<00:00,  8.95it/s]
100%|██████████| 225/225 [00:13<00:00, 16.84it/s] 


131 {'train_MAE': 258.7536350165473} {'test_MAE': 219.45304173787434}


100%|██████████| 225/225 [00:25<00:00,  8.90it/s]
100%|██████████| 225/225 [00:14<00:00, 16.06it/s]


132 {'train_MAE': 302.7723286268446} {'test_MAE': 313.3778260464139}


100%|██████████| 225/225 [00:25<00:00,  8.86it/s]
100%|██████████| 225/225 [00:13<00:00, 16.17it/s] 


133 {'train_MAE': 281.19148667229547} {'test_MAE': 252.51919521755642}


100%|██████████| 225/225 [00:25<00:00,  8.86it/s]
100%|██████████| 225/225 [00:14<00:00, 15.84it/s]


134 {'train_MAE': 276.8020366753472} {'test_MAE': 197.79822987874348}


100%|██████████| 225/225 [00:25<00:00,  8.97it/s]
100%|██████████| 225/225 [00:13<00:00, 16.89it/s]


135 {'train_MAE': 269.3367320421007} {'test_MAE': 204.3418617078993}


100%|██████████| 225/225 [00:25<00:00,  8.87it/s]
100%|██████████| 225/225 [00:13<00:00, 16.51it/s] 


136 {'train_MAE': 295.78884575737845} {'test_MAE': 294.7175379096137}


100%|██████████| 225/225 [00:25<00:00,  9.00it/s]
100%|██████████| 225/225 [00:13<00:00, 16.19it/s] 


137 {'train_MAE': 272.31314980400936} {'test_MAE': 203.75600697835287}


100%|██████████| 225/225 [00:24<00:00,  9.00it/s]
100%|██████████| 225/225 [00:13<00:00, 16.66it/s] 


138 {'train_MAE': 276.6668099127875} {'test_MAE': 283.1207022942437}


100%|██████████| 225/225 [00:24<00:00,  9.07it/s]
100%|██████████| 225/225 [00:13<00:00, 16.27it/s] 


139 {'train_MAE': 291.5360624525282} {'test_MAE': 201.60333965725368}


100%|██████████| 225/225 [00:25<00:00,  8.73it/s]
100%|██████████| 225/225 [00:14<00:00, 15.83it/s]


140 {'train_MAE': 280.4484900580512} {'test_MAE': 197.06269066704644}


100%|██████████| 225/225 [00:25<00:00,  8.76it/s]
100%|██████████| 225/225 [00:13<00:00, 16.78it/s] 


141 {'train_MAE': 264.72023885091147} {'test_MAE': 197.26806611802843}


 11%|█         | 24/225 [00:04<00:34,  5.82it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 225/225 [00:24<00:00,  9.08it/s]
100%|██████████| 225/225 [00:14<00:00, 16.04it/s] 


186 {'train_MAE': 261.2826957872179} {'test_MAE': 226.17927100287542}


100%|██████████| 225/225 [00:25<00:00,  8.83it/s]
100%|██████████| 225/225 [00:14<00:00, 15.19it/s] 


187 {'train_MAE': 286.1049081759983} {'test_MAE': 284.28769460042315}


100%|██████████| 225/225 [00:25<00:00,  8.86it/s]
100%|██████████| 225/225 [00:13<00:00, 16.18it/s] 


188 {'train_MAE': 270.9381395975749} {'test_MAE': 195.79010711669923}


100%|██████████| 225/225 [00:25<00:00,  8.83it/s]
100%|██████████| 225/225 [00:14<00:00, 15.82it/s] 


189 {'train_MAE': 272.5763627794054} {'test_MAE': 200.40728680080838}


100%|██████████| 225/225 [00:24<00:00,  9.05it/s]
100%|██████████| 225/225 [00:13<00:00, 16.13it/s] 


190 {'train_MAE': 263.10090547349716} {'test_MAE': 202.66024944729276}


100%|██████████| 225/225 [00:25<00:00,  8.73it/s]
100%|██████████| 225/225 [00:13<00:00, 17.27it/s] 


191 {'train_MAE': 288.93268091837564} {'test_MAE': 299.4064279683431}


100%|██████████| 225/225 [00:24<00:00,  9.12it/s]
100%|██████████| 225/225 [00:14<00:00, 15.82it/s] 


192 {'train_MAE': 257.4378934902615} {'test_MAE': 206.2794294060601}


100%|██████████| 225/225 [00:25<00:00,  8.76it/s]
100%|██████████| 225/225 [00:13<00:00, 16.77it/s] 


193 {'train_MAE': 264.70812494913736} {'test_MAE': 205.23478931003146}


100%|██████████| 225/225 [00:25<00:00,  8.87it/s]
100%|██████████| 225/225 [00:13<00:00, 16.17it/s] 


194 {'train_MAE': 279.39620532565647} {'test_MAE': 204.4604758199056}


100%|██████████| 225/225 [00:25<00:00,  8.97it/s]
100%|██████████| 225/225 [00:14<00:00, 15.61it/s]


195 {'train_MAE': 260.983843299018} {'test_MAE': 195.6820169236925}


100%|██████████| 225/225 [00:25<00:00,  8.82it/s]
100%|██████████| 225/225 [00:12<00:00, 17.46it/s]


196 {'train_MAE': 323.66776809692385} {'test_MAE': 216.79281533135307}


100%|██████████| 225/225 [00:25<00:00,  8.70it/s]
100%|██████████| 225/225 [00:13<00:00, 16.17it/s] 


197 {'train_MAE': 306.85791870117185} {'test_MAE': 199.36454530504014}


100%|██████████| 225/225 [00:24<00:00,  9.16it/s]
100%|██████████| 225/225 [00:13<00:00, 16.40it/s] 


198 {'train_MAE': 261.31643498738606} {'test_MAE': 197.58060406155056}


 76%|███████▌  | 170/225 [00:23<00:05, 10.48it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

