## 分车型预测-使用LSTM建模预测

In [2]:
import os
import warnings
import numpy as np
import pandas as pd
from keras.layers import *
from keras.models import *
from keras.callbacks import *
from keras.optimizers import *
import matplotlib.pyplot as plt
from keras import backend as K
from keras.utils import plot_model
from keras.models import Sequential
from keras.engine.topology import Layer
from keras.layers import Dense, Dropout
from keras.initializers import glorot_uniform
from keras.layers import LSTM, RNN, GRU, SimpleRNN, Bidirectional, BatchNormalization
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras import initializers, regularizers, constraints, optimizers, layers
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split

from IPython.core.interactiveshell import InteractiveShell 
InteractiveShell.ast_node_interactivity = 'all'

np.random.seed(1904)
warnings.filterwarnings('ignore')

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


### 导入数据

In [3]:
data_path = './data/'
train_user_reply_data = pd.read_csv(data_path + 'train_user_reply_data.csv', encoding='utf-8')
train_search_data = pd.read_csv(data_path + 'train_search_data.csv', encoding='utf-8')
train_sales_data = pd.read_csv(data_path + 'train_sales_data.csv', encoding='utf-8')
evaluation_public = pd.read_csv(data_path + 'evaluation_public.csv', encoding='utf-8')

### 提取车型和省份数据

In [4]:
# cars = list(set(train_sales_data['model']))
# provinces = list(set(train_sales_data['province']))
# print(len(cars))
# print(len(provinces))
provinces = ['北京', '湖南', '江西', '山西', '河南', '四川', '山东', '湖北', '云南', '福建', '安徽', '河北', '浙江', '内蒙古', '陕西', '江苏', '重庆', '上海', '广东', '黑龙江', '辽宁', '广西']
cars = ['63065128401bb3ff', '02aab221aabc03b9', 'fde95ea242abd896', '3c974920a76ac9c1', '346393c2c6305fb1', '54fc07138d70374c', 'af6f4f548684e14d', 'bb9fbec9a2833839', 'feabbf46658382b9', 'a9a43d1a7ecbe75d', '3d7554f1f56dd664', 'd4efbebb087fd03f', 'a28bb927b6fcb33c', 'c6cd4e0e073f5ac2', '7023efdab9cedc03', 'b25c4e2e3856af22', '6858d6dfe680bdf7', '61e73e32ad101892', 'a432c483b5beb856', '3e21824be728cbec', 'fc32b1a017b34efe', 'ea489c253676aafc', '79de4e4b24c35b04', '7245e0ee27b195cd', 'd0f245b8781e3631', '8c915fe4632fb9fa', '0797526c057dcf5b', '5b1c11c3efed5312', '12f8b7e14947c34d', '7cf283430b3b5e38', '7aab7fca2470987e', 'a207df29ec9583f0', 'dff803b4024d261d', '936168bd4850913d', '2d0d2c3403909fdb', '212083a9246d2fd3', '4f79773e600518a6', '2a2ab41f8f6ff1cb', '37aa9169b575ef79', 'c6833cb891626c17', 'c06a2a387c0ee510', 'f270f6a489c6a9d7', '17bc272c93f19d56', '7a7885e2d7c00bcf', '04e66e578f653ab9', 'ef76a85c4b39f693', '17363f08d683d52b', '6155b214590c66e6', '5d7fb682edd0f937', 'f8a6975573af1b33', '4a103c30d593fbbe', 'b4be3a4917289c82', '28e29f2c03dcd84c', 'da457d15788fe8ee', '97f15de12cfabbd5', 'cd5841d44fd7625e', '9c1c7ee8ebdda299', 'cc21c7e91a3b5a0c', 'f5d69960089c3614', '06880909932890ca']

# Normalization
train_sales_data['salesVolume'] = train_sales_data['salesVolume'].map(lambda index: np.log2(index) + 1)
scaler = MinMaxScaler(feature_range=(0, 1))
train_sales_data[['salesVolume']] = scaler.fit_transform(train_sales_data[['salesVolume']])
print(train_sales_data[:3])

  province  adcode             model bodyType  regYear  regMonth  salesVolume
0       上海  310000  3c974920a76ac9c1      SUV     2016         1     0.557228
1       云南  530000  3c974920a76ac9c1      SUV     2016         1     0.609492
2      内蒙古  150000  3c974920a76ac9c1      SUV     2016         1     0.542952


### 构造数据

In [5]:
def generate_batch(data_set, look_back, gap_days):
    data_x, data_y = [], []
    for i in range(len(data_set) - look_back - gap_days + 1):
        a = data_set[i:(i + look_back), 0]
        data_x.append(a)
        data_y.append([data_set[i + look_back + gap_days - 1, 0]])
    return np.array(data_x), np.array(data_y), data_set[-look_back:, 0].reshape(1, look_back)

def generate_first_order_batch(data_set, look_back, gap_days):
    data_x = []
    for i in range(len(data_set) - look_back - gap_days + 1):
        a = data_set[i:(i + look_back), 0]
        data_x.append(a)
    return np.array(data_x), data_set[-look_back:, 0].reshape(1, look_back)

### 自定义评测指标

In [6]:
def metrics(data):
    data = scaler.inverse_transform(data)
    data = 2 ** (data - 1)
    print(data)
    
    a = data[:, 0]
    b = data[:, 1]
    res = np.sqrt(np.sum((a - b) ** 2) / len(a)) / np.mean(a)
    print('validate rmse:', 1 - res)
    return (1 - res)

### Adamw优化器&循环学习率

In [7]:
class AdamW(Optimizer):
    def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999, weight_decay=1e-4,  # decoupled weight decay (1/4)
                 epsilon=1e-8, decay=0., **kwargs):
        super(AdamW, self).__init__(**kwargs)
        with K.name_scope(self.__class__.__name__):
            self.iterations = K.variable(0, dtype='int64', name='iterations')
            self.lr = K.variable(lr, name='lr')
            self.beta_1 = K.variable(beta_1, name='beta_1')
            self.beta_2 = K.variable(beta_2, name='beta_2')
            self.decay = K.variable(decay, name='decay')
            # decoupled weight decay (2/4)
            self.wd = K.variable(weight_decay, name='weight_decay')
        self.epsilon = epsilon
        self.initial_decay = decay

    @interfaces.legacy_get_updates_support
    def get_updates(self, loss, params):
        grads = self.get_gradients(loss, params)
        self.updates = [K.update_add(self.iterations, 1)]
        wd = self.wd  # decoupled weight decay (3/4)

        lr = self.lr
        if self.initial_decay > 0:
            lr *= (1. / (1. + self.decay * K.cast(self.iterations,
                                                  K.dtype(self.decay))))

        t = K.cast(self.iterations, K.floatx()) + 1
        lr_t = lr * (K.sqrt(1. - K.pow(self.beta_2, t)) /
                     (1. - K.pow(self.beta_1, t)))

        ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        self.weights = [self.iterations] + ms + vs

        for p, g, m, v in zip(params, grads, ms, vs):
            m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
            v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)
            # decoupled weight decay (4/4)
            p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon) - lr * wd * p

            self.updates.append(K.update(m, m_t))
            self.updates.append(K.update(v, v_t))
            new_p = p_t

            # Apply constraints.
            if getattr(p, 'constraint', None) is not None:
                new_p = p.constraint(new_p)

            self.updates.append(K.update(p, new_p))
        return self.updates


class CyclicLR(Callback):
    def __init__(self, base_lr=0.001, max_lr=0.006, step_size=2000., mode='triangular',
                 gamma=1., scale_fn=None, scale_mode='cycle'):
        super(CyclicLR, self).__init__()

        self.base_lr = base_lr
        self.max_lr = max_lr
        self.step_size = step_size
        self.mode = mode
        self.gamma = gamma
        if scale_fn == None:
            if self.mode == 'triangular':
                self.scale_fn = lambda x: 1.
                self.scale_mode = 'cycle'
            elif self.mode == 'triangular2':
                self.scale_fn = lambda x: 1/(2.**(x-1))
                self.scale_mode = 'cycle'
            elif self.mode == 'exp_range':
                self.scale_fn = lambda x: gamma**(x)
                self.scale_mode = 'iterations'
        else:
            self.scale_fn = scale_fn
            self.scale_mode = scale_mode
        self.clr_iterations = 0.
        self.trn_iterations = 0.
        self.history = {}

        self._reset()

    def _reset(self, new_base_lr=None, new_max_lr=None,
               new_step_size=None):
        """Resets cycle iterations.
        Optional boundary/step size adjustment.
        """
        if new_base_lr != None:
            self.base_lr = new_base_lr
        if new_max_lr != None:
            self.max_lr = new_max_lr
        if new_step_size != None:
            self.step_size = new_step_size
        self.clr_iterations = 0.
        
    def clr(self):
        cycle = np.floor(1+self.clr_iterations/(2*self.step_size))
        x = np.abs(self.clr_iterations/self.step_size - 2*cycle + 1)
        if self.scale_mode == 'cycle':
            return self.base_lr + (self.max_lr-self.base_lr)*np.maximum(0, (1-x))*self.scale_fn(cycle)
        else:
            return self.base_lr + (self.max_lr-self.base_lr)*np.maximum(0, (1-x))*self.scale_fn(self.clr_iterations)
        
    def on_train_begin(self, logs={}):
        logs = logs or {}

        if self.clr_iterations == 0:
            K.set_value(self.model.optimizer.lr, self.base_lr)
        else:
            K.set_value(self.model.optimizer.lr, self.clr())        

    def on_batch_end(self, epoch, logs=None):
        
        logs = logs or {}
        self.trn_iterations += 1
        self.clr_iterations += 1

        self.history.setdefault('lr', []).append(K.get_value(self.model.optimizer.lr))
        self.history.setdefault('iterations', []).append(self.trn_iterations)

        for k, v in logs.items():
            self.history.setdefault(k, []).append(v)
        
        K.set_value(self.model.optimizer.lr, self.clr())

### 双输入-LSTM时序模型

In [8]:
from keras.layers.advanced_activations import PReLU

def construct_lstm_model(x_train, x_test, y_train, y_test, test, car_name, batch_size):
    print('batch_size:', batch_size)

    input1_ = Input(shape=(1, look_back), name='input1', dtype='float32')
    input2_ = Input(shape=(1, look_back), name='input2', dtype='float32')

    seq1 = LSTM(32, return_sequences=True)(input1_)    # , input_shape=(1, look_back)   GRU   LSTM
    seq1 = LSTM(32, return_sequences=False)(seq1)

    seq2 = LSTM(32, return_sequences=True)(input2_)
    seq2 = LSTM(32, return_sequences=False)(seq2)

    merged = concatenate([seq1, seq2])
    merged = BatchNormalization()(merged)
    merged = Dropout(0.5)(merged)

    output_ = Dense(1)(merged)
    output_ = PReLU()(output_)
    model = Model(inputs=[input1_, input2_], outputs=[output_])
    model.compile(loss='mean_squared_error', optimizer='adam')              # AdamW(weight_decay=0.08)   'adam'
#     model.summary()
    early_stopping = EarlyStopping('val_loss', patience=20, verbose=1)
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=0.00001, verbose=0)
    best_model_path = './model/' + car_name + '.h5'
    checkpoint = ModelCheckpoint(best_model_path, monitor='val_loss', mode='min', save_best_only=True, verbose=0, save_weights_only=False)
    
#     clr = CyclicLR(base_lr=0.001, max_lr=0.01, step_size=300., mode='exp_range', gamma=0.99994)
    clr_fn = lambda x: 0.5*(1+np.sin(x*np.pi/2.))
    clr = CyclicLR(base_lr=0.0001, max_lr=0.001, step_size=300., scale_fn=clr_fn, scale_mode='cycle')
    
    callbacks = [early_stopping, reduce_lr]         # checkpoint   # reduce_lr
    model.fit([x_train[:, :, :look_back], x_train[:, :, look_back:]], y_train, epochs=500, batch_size=batch_size, verbose=0, shuffle=False, callbacks=callbacks, validation_data=([x_test[:, :, :look_back], x_test[:, :, look_back:]], y_test))
#     model.load_weights(best_model_path)     # load best model parameters   shuffle=False
    predict = model.predict([test[:, :, :look_back], test[:, :, look_back:]])
    predict = scaler.inverse_transform(predict)
    predict = 2 ** (predict - 1)
#     print(np.hstack((y_test, model.predict(x_test))))
    rmse = metrics(np.hstack((y_test, model.predict([x_test[:, :, :look_back], x_test[:, :, look_back:]]))))
    return predict, rmse

### 分车型&分月建模预测

In [9]:
file_name = os.listdir('./model/')

look_back = 9     # 10  9
label_length = 4
batch_size = 16
submit = pd.DataFrame()
flag = 0
total = []

for car in cars:
    labels = np.zeros([len(provinces), label_length])
    rmse_total = 0
    for k in range(1, 5, 1):
        print('{} car:{}'.format(flag+1, car))
        train = np.array([[0 for _ in range(look_back)]])
        train_feature = np.array([[0 for _ in range(look_back)]])
        label = np.array([[0]])
        test = np.array([[0 for _ in range(look_back)]])
        test_feature = np.array([[0 for _ in range(look_back)]])

        for province in provinces: 
            temp = train_sales_data[(train_sales_data['model'] == car) & (train_sales_data['province'] == province)]
    #         print(temp[['province', 'model', 'regYear', 'regMonth', 'salesVolume']])
            temp['first_order'] = temp['salesVolume'].diff(1)

            data = temp['salesVolume'].values.reshape(-1, 1).astype('float32')[1:]
            data_b = temp['first_order'].values.reshape(-1, 1).astype('float32')[1:]

            train_data, train_label, test_data = generate_batch(data, look_back, k)
            train_data_b, test_data_b = generate_first_order_batch(data_b, look_back, k)

            train = np.vstack((train, train_data))
            label = np.vstack((label, train_label))
            test = np.vstack((test, test_data))
            train_feature = np.vstack((train_feature, train_data_b))
            test_feature = np.vstack((test_feature, test_data_b))

        train = np.hstack((train, train_feature))
        test = np.hstack((test, test_feature))
        train = train[1:]
        label = label[1:]
        test = test[1:]
        train = np.reshape(train, (-1, 1, look_back * 2))
        test = np.reshape(test, (-1, 1, look_back * 2))
        x_train,x_test,y_train,y_test = train_test_split(train, label, test_size=0.1, random_state=1904)

        print('train shape:', train.shape)
        print('label shape:', label.shape)
        print('test shape:', test.shape)
        res, rmse = construct_lstm_model(x_train, x_test, y_train, y_test, test, car, int((train.shape[0] * 0.9) / 22))
        
        K.clear_session()    #  清除tf session

        labels[:, k-1] = res.reshape(1, -1)
        rmse_total += rmse
    print(labels)
    print('average rmse:', rmse_total / 4)
    flag += 1
    total.append(rmse_total / 4)

    index = 0
    submit_partial = pd.DataFrame()
    for province_2 in provinces: 
        temp = evaluation_public[(evaluation_public['model'] == car) & (evaluation_public['province'] == province_2)]
        temp['forecastVolum'] = labels[index]
        submit_partial = pd.concat([submit_partial, temp], axis=0)
        index += 1
    submit = pd.concat([submit, submit_partial], axis=0)

print('all validation rmse:', np.mean(total))  # 0.6555167883490066
car_rmse = pd.DataFrame({'cars': cars, 'rmse': total})
print(car_rmse)
# car_rmse.to_csv('car_rmse.csv', encoding='utf-8', index=None)

[]
1 car:63065128401bb3ff
train shape: (308, 1, 18)
label shape: (308, 1)
test shape: (22, 1, 18)
batch_size: 12
Epoch 00031: early stopping
[[ 691.99995217 1933.38750864]
 [ 213.99998516  384.05053383]
 [2187.99985726 2094.61830788]
 [1218.00022926 1251.88118693]
 [1543.99992576 1804.32234767]
 [2400.99986102 2363.8903343 ]
 [ 861.00018704 1010.99140879]
 [1223.00025475 1638.30938908]
 [3813.99934853 2938.9302254 ]
 [1033.00008916 1033.06507027]
 [1030.99979211  752.03382252]
 [1307.00008176 1450.9150229 ]
 [4090.99959642 3164.60211109]
 [3307.00016327 3355.5433662 ]
 [1550.00023406  970.44302813]
 [4744.99898154 4027.92425944]
 [1398.99963257 1745.97256651]
 [2429.00015662 2313.69521389]
 [2449.99960778 2884.12198502]
 [1328.00003096 1601.07419807]
 [2268.99946119 2012.72733164]
 [ 972.99989301 2055.39241437]
 [ 821.99996833  648.09962745]
 [4376.99941314 3582.16366546]
 [ 923.00021537  406.66986251]
 [ 216.99998578 1357.69702151]
 [1474.00019371 1895.39881674]
 [1980.00015659 1432.7

Epoch 00066: early stopping
[[ 251.99997215  284.84079005]
 [ 669.99992742  744.88203936]
 [ 277.99995496  244.17446656]
 [ 547.99987753  523.56051206]
 [ 574.99993604  605.45222007]
 [ 346.00006825  267.26063032]
 [1428.00031926 1260.9128361 ]
 [ 137.00000586  144.77363356]
 [ 972.99989301  950.60647341]
 [ 335.00005295  293.99389454]
 [ 646.99983716  669.57789433]
 [ 254.99994193  247.46037559]
 [ 240.0000128   227.50913791]
 [ 275.00003008  265.72041638]
 [1360.99992584 1351.70453333]
 [ 544.00013778  543.75166906]
 [ 215.9999473   178.2691066 ]
 [1536.99996726 1339.04764118]
 [ 346.00006825  297.56627031]
 [ 203.99998059  226.43217985]
 [ 261.00001421  225.29557648]
 [ 627.00007424  532.00163685]
 [ 476.99993557  431.01870943]
 [ 164.99998246  133.68980997]
 [ 339.00003597  322.0977506 ]]
validate rmse: 0.8764571798971973
[[ 640.71533203  564.44598389  645.53973389  590.72338867]
 [ 282.45648193  273.07580566  295.69125366  278.11950684]
 [ 242.72335815  200.9198761   247.14892578 

Epoch 00077: early stopping
[[ 193.9999825   289.75193223]
 [ 798.99993297  676.48179251]
 [ 624.00011502  560.16536209]
 [ 175.99997209  174.58899632]
 [ 308.00004924  264.37560801]
 [ 209.00004177  281.66627623]
 [ 335.99999034  397.32101793]
 [ 449.0001188   357.06231618]
 [ 432.00000981  321.79260668]
 [ 280.00005231  288.94107012]
 [ 167.99999514  162.84769542]
 [ 297.00003849  272.63829735]
 [ 449.0001188   384.42168208]
 [ 145.99998818  190.4808422 ]
 [ 601.00014728  407.87198493]
 [ 297.00003849  226.94140712]
 [ 616.99999424  637.45984078]
 [2695.99958708 2351.05425293]
 [ 260.00003463  208.99837057]
 [ 275.00003008  324.38364045]
 [2958.00036127 1917.65309392]
 [ 235.99993931  135.22196533]
 [ 303.00000948  304.476846  ]
 [ 676.00017861  536.65387606]
 [ 308.99993282  243.20420351]
 [ 144.000015    156.20348807]
 [ 292.99992504  395.51949182]
 [ 310.99999849  230.60765462]
 [ 218.99996439  187.58298892]]
validate rmse: 0.5818793950825417
4 car:3c974920a76ac9c1
train shape: (2

train shape: (308, 1, 18)
label shape: (308, 1)
test shape: (22, 1, 18)
batch_size: 12
Epoch 00032: early stopping
[[ 213.00003262  171.45971402]
 [ 129.00000071  211.54110703]
 [1835.99952625 1117.61277495]
 [ 328.99999414  244.60438624]
 [1066.99973392  706.64612039]
 [ 623.00000063  437.01000811]
 [ 274.00001175  290.86292865]
 [ 139.00001451  141.69756732]
 [ 251.99997215  168.63350498]
 [ 485.00002132  314.7973096 ]
 [ 132.99998377  153.76160288]
 [ 299.00004546  311.72599857]
 [ 495.00003899  416.23291356]
 [ 545.00008749  401.36615646]
 [ 398.00000325  365.96608895]
 [ 439.99998925  459.46305504]
 [ 507.99999809  470.32096507]
 [ 614.99985624  500.60561197]
 [ 849.00009697  580.85310064]
 [ 341.0000874   279.81338782]
 [ 546.00012889  494.37373521]
 [ 297.00003849  221.02780161]
 [ 238.00000912  255.84917927]
 [ 420.99990696  447.99739673]
 [ 419.0000991   321.3446829 ]
 [ 533.99992926  486.46532542]
 [ 325.00000021  274.59703847]
 [ 185.00002976  122.72893082]
 [ 595.00010261  

Epoch 00037: early stopping
[[ 311.9999743   268.01522201]
 [ 503.00000204  460.12337003]
 [1084.00020992 1042.74638313]
 [1018.99977723 1231.2746993 ]
 [ 489.00005739  504.82692125]
 [1068.00014336  878.20301746]
 [1406.99999231 1429.93897997]
 [ 257.00000759  250.61436096]
 [ 537.00014281  566.68135715]
 [1019.99976803  970.37888229]
 [ 574.00001853  706.84466801]
 [ 413.00002531  304.3270707 ]
 [ 321.0000371   273.27824193]
 [ 893.00002009  826.56892342]
 [1023.99972849  897.19804811]
 [1018.00022413 1017.95626849]
 [ 745.99986478  658.4973865 ]
 [ 986.00020072  980.80508651]
 [ 923.00021537  748.12488701]
 [ 432.99993244  359.10173447]
 [ 311.9999743   302.47100021]
 [ 869.9997714   801.66344112]
 [ 398.00000325  409.51992858]
 [ 452.0000849   442.40980968]
 [1014.99986878  958.66529894]]
validate rmse: 0.8778918953094831
[[ 410.92861938  520.31829834  462.84655762  413.18505859]
 [ 974.03881836  965.16375732  875.09008789  842.24645996]
 [ 903.41699219  768.13220215  659.39788818 

train shape: (286, 1, 18)
label shape: (286, 1)
test shape: (22, 1, 18)
batch_size: 11
Epoch 00059: early stopping
[[ 417.99997219  574.28894296]
 [2478.99946444 1690.45569039]
 [1253.99981444 1078.99829757]
 [ 789.0000243   581.75330238]
 [1186.99996599 1465.51964288]
 [ 218.99996439  238.09543649]
 [1246.99971479 1060.11064774]
 [1788.00004729 1773.88565657]
 [ 212.00000596  598.77475182]
 [1042.99988864 1405.42878368]
 [ 474.99991823  557.4426862 ]
 [ 671.00001371  801.55148338]
 [1013.99991491  975.00511876]
 [ 197.99998904  390.72095724]
 [1610.00011814 1274.93565321]
 [ 607.99989456  641.9153764 ]
 [ 853.99997036  867.99392816]
 [1141.99976237  779.47381369]
 [ 969.99978424  667.420122  ]
 [ 361.99993072  492.8065804 ]
 [ 825.00002307  915.16314122]
 [ 624.00011502  388.81440176]
 [1737.00009576 1732.0606686 ]
 [1924.99979798 1610.6585353 ]
 [ 197.99998904  242.17414236]
 [ 356.00007678  427.87552854]
 [ 337.99999916  593.04851936]
 [ 766.99980157  629.71763726]
 [ 925.00002648  

train shape: (308, 1, 18)
label shape: (308, 1)
test shape: (22, 1, 18)
batch_size: 12
Epoch 00053: early stopping
[[  289.99994741   384.19448532]
 [  149.00001603   186.89537681]
 [13597.99903144  8338.26982556]
 [ 2205.00016687  1329.21184855]
 [ 9099.99805824  6876.33309156]
 [ 2708.99993565  1695.83686606]
 [ 2361.00008393  1915.58719893]
 [ 1699.00024579  1327.76289621]
 [  699.00013804   464.92908115]
 [  608.99995221   319.03870777]
 [  482.99993117   460.88633783]
 [  505.99999716   458.12719406]
 [ 1052.0001184    941.39814871]
 [ 1948.00034548  2051.83123873]
 [  368.9999325    422.62924222]
 [ 2159.9997634   2110.26412897]
 [  678.99997411   925.89987378]
 [  968.00020842   880.6295134 ]
 [ 1125.00026984  1167.79134167]
 [  706.99992815   708.29383103]
 [ 1228.99986986   984.94112924]
 [  435.00000158   515.6998531 ]
 [  699.00013804   694.41448353]
 [ 1018.00022413   968.73375064]
 [  344.00003006   342.57045217]
 [  368.00005663   449.16697832]
 [ 1847.0000175   1761.4031

Epoch 00045: early stopping
[[ 817.99985235  893.11142016]
 [ 968.00020842  488.63081987]
 [ 649.99982727 1014.97389753]
 [1335.00000223 1411.76630926]
 [ 952.00003677  839.19490783]
 [ 685.00012127  894.6290816 ]
 [1078.00022843 1200.34602967]
 [ 474.00002431  408.59818704]
 [ 639.00004583 1080.10380482]
 [ 691.99995217  711.43882801]
 [ 322.00006621  689.13115366]
 [ 318.99993356  266.65354015]
 [ 309.99992256  323.49655576]
 [ 658.99990312  894.97681469]
 [1078.00022843 1274.18555067]
 [1163.00028898 1140.67340283]
 [ 561.00000674  573.19155876]
 [2221.99989741 1811.08811666]
 [1447.99972332 1202.86210589]
 [1915.00012297  996.67872205]
 [ 485.00002132  287.22078723]
 [  98.9999945   145.94306704]
 [ 608.99995221  617.76353603]
 [1161.00012672  740.98657691]
 [ 410.00004689  537.50993309]]
validate rmse: 0.6489830946933528
[[ 324.53363037  448.61669922  554.07568359  789.59735107]
 [1415.48901367  855.24407959  735.43560791  773.79162598]
 [1059.16577148  687.44946289  584.14459229 

Epoch 00064: early stopping
[[  67.00000159  180.75288575]
 [1420.00014576  701.82733297]
 [ 337.00003805  357.14341117]
 [ 110.99999391  103.29027616]
 [ 466.00000584  390.46464263]
 [ 103.9999999   102.86316617]
 [ 225.99998218  270.60422072]
 [ 368.00005663  399.86654044]
 [  28.99999857  222.12499367]
 [ 328.99999414  329.73338426]
 [ 377.00008145  299.23922677]
 [ 231.99998868  318.5995426 ]
 [ 372.00002518  594.67888428]
 [ 125.99998606  234.19089215]
 [1684.99996742 1435.25958627]
 [ 292.99992504  353.03735627]
 [ 787.00009207  722.40244801]
 [ 480.99988398  182.1882774 ]
 [ 603.99996007  432.63446327]
 [  50.0000001    56.16826448]
 [ 373.99993532  410.79485264]
 [ 548.99986279  316.87932324]
 [ 297.00003849  353.17589614]
 [ 510.00001987  490.66704385]
 [ 222.99998655  245.81966186]
 [ 344.00003006  262.32527454]
 [  50.99999513  115.78228497]
 [ 361.99993072  335.54838862]
 [ 339.99995037  299.44491609]]
validate rmse: 0.5535037366958823
14 car:c6cd4e0e073f5ac2
train shape: (

train shape: (308, 1, 18)
label shape: (308, 1)
test shape: (22, 1, 18)
batch_size: 12
Epoch 00056: early stopping
[[ 229.00004815  272.21293572]
 [ 109.99999728  214.11592058]
 [1644.0003751  1295.73949224]
 [ 904.99994818  737.69994958]
 [ 947.00021953  705.52181947]
 [2703.0005379  1760.41657561]
 [ 271.99999635  237.98947892]
 [ 372.00002518  396.06257395]
 [ 398.99991883  423.13871466]
 [ 691.99995217  563.43030882]
 [ 503.00000204  372.4776521 ]
 [ 439.00004864  420.46140493]
 [ 667.99995115  493.15846735]
 [1068.00014336  974.66006307]
 [ 370.99993079  419.53292761]
 [ 952.00003677  897.85496267]
 [ 680.00008208  634.42201048]
 [ 699.9998515   553.8609758 ]
 [1774.9999469   949.64515254]
 [ 562.00004641  527.81357611]
 [ 813.00009098  734.4682264 ]
 [ 753.99996206  719.5008077 ]
 [ 593.00015171  433.02970961]
 [ 499.00008146  466.71847715]
 [ 197.99998904  211.48371591]
 [ 375.99997385  703.7271077 ]
 [1037.00017171  987.87474807]
 [ 159.99997894  239.99208072]
 [ 804.99984438  

Epoch 00050: early stopping
[[ 218.99996439  200.81925619]
 [ 667.99995115  588.57124025]
 [ 401.99997688  369.18153671]
 [ 498.00010356  408.35213036]
 [ 526.99993064  616.12637221]
 [ 379.99998495  321.7092487 ]
 [ 893.99978522  984.96160644]
 [ 157.00002043  147.84560698]
 [1003.99980564  966.73678465]
 [ 400.000001    337.67690986]
 [ 580.99988865  617.98191139]
 [ 228.00000256  228.31136217]
 [ 195.00001004  266.25894736]
 [ 372.00002518  327.12682189]
 [1448.99967552 1015.67264726]
 [ 510.00001987  529.24541682]
 [ 197.99998904  202.14075814]
 [1906.00012926 1513.93740013]
 [ 360.00008578  405.58456047]
 [ 185.99996299  200.54613729]
 [ 163.99999676  163.68879763]
 [ 503.00000204  449.76286354]
 [ 366.00003594  325.81165102]
 [ 185.00002976  196.51878664]
 [ 299.99997623  288.27074782]]
validate rmse: 0.7506618675449006
[[ 543.74456787  575.43664551  665.59729004  588.78234863]
 [ 365.62240601  374.46447754  400.82446289  372.16799927]
 [ 206.33703613  214.57106018  225.08279419 

train shape: (286, 1, 18)
label shape: (286, 1)
test shape: (22, 1, 18)
batch_size: 11
Epoch 00035: early stopping
[[100.00000022 146.39792662]
 [995.00014158 441.98432482]
 [432.99993244 310.28015511]
 [134.0000032  115.36629364]
 [358.99990848 408.16453011]
 [130.0000173  161.20502928]
 [197.00001339 171.32983604]
 [297.00003849 172.8462752 ]
 [ 44.00000474 146.46989864]
 [229.00004815 129.90275622]
 [192.00003569 114.95136797]
 [132.99998377 161.05454159]
 [187.00001748 221.06927995]
 [ 51.99999994 168.58941731]
 [599.00014511 446.02339245]
 [132.99998377 127.508875  ]
 [358.000029   385.40535801]
 [567.99998211 461.99843533]
 [225.99998218 202.15088754]
 [105.00001102  38.64432027]
 [466.99995201 379.48944723]
 [258.00000145 215.97461691]
 [267.99993502 249.39917824]
 [586.99986225 705.84082305]
 [ 60.99999469  67.76845284]
 [103.00001345  84.91464062]
 [ 41.99999877  76.71949592]
 [164.99998246 182.39642862]
 [195.00001004 169.23962916]]
validate rmse: 0.5249567606753738
19 car:a4

train shape: (308, 1, 18)
label shape: (308, 1)
test shape: (22, 1, 18)
batch_size: 12
Epoch 00033: early stopping
[[ 417.00000957 1018.80695708]
 [ 134.99998513  325.45336614]
 [ 768.000143    806.76592416]
 [ 253.00006598  267.60762948]
 [ 584.99998247  649.19817252]
 [ 556.00005821  761.66387387]
 [ 199.99994716  178.63021106]
 [ 197.99998904  213.46744471]
 [ 731.00005337  520.37373162]
 [ 434.00008731  547.6187845 ]
 [ 215.9999473   226.19608118]
 [ 379.00000058  449.08557569]
 [1136.99990914 1192.09830668]
 [ 557.00009262  614.09174618]
 [  79.00001047  198.02616705]
 [ 672.00015989  609.49133781]
 [ 814.00010927  823.41785945]
 [ 387.00007376  580.45596803]
 [2057.00041337 1635.8710729 ]
 [ 464.00010112  490.88759185]
 [ 472.99996588  648.5423486 ]
 [ 419.99993224  802.76077017]
 [ 195.00001004  211.85912604]
 [1508.99988323 1127.96134637]
 [ 159.99997894  122.91405706]
 [ 172.99998799  669.12794738]
 [ 394.99989484  402.88449717]
 [ 358.99990848  239.7783972 ]
 [ 411.99994411  

Epoch 00041: early stopping
[[ 614.00009249  648.88228427]
 [ 149.99998809  171.40712054]
 [ 409.00003513  698.55350717]
 [2557.99982451 2217.96323852]
 [ 100.00000022  399.84650388]
 [ 448.00002373  729.41421833]
 [ 992.99984555 1040.99909561]
 [ 221.99998786  241.37994681]
 [ 466.99995201  542.98928653]
 [ 812.00000266  677.7942705 ]
 [ 166.99998773  177.95056116]
 [1251.99998102  818.1415835 ]
 [ 328.99999414  451.01230093]
 [ 483.99997513  588.55304277]
 [ 432.99993244  506.85227594]
 [1386.00036504 1089.70765295]
 [ 235.99993931  525.58453142]
 [ 432.00000981  613.68104798]
 [5099.00050083 3380.89065502]
 [1141.99976237  619.58569116]
 [ 246.00004064  240.85051827]
 [ 238.00000912  240.19468666]
 [ 301.00004187  218.88290131]
 [ 581.99990012  236.40349369]
 [ 166.00000379  312.29017204]]
validate rmse: 0.4685497152231133
[[  99.35699463  133.28746033   90.24754333   85.14995575]
 [ 552.89208984  399.11514282  409.99227905  518.51397705]
 [ 374.09658813  350.62359619  305.1546936  

Epoch 00035: early stopping
[[ 246.00004064  356.28884305]
 [1012.99986616  727.53856119]
 [ 491.00012783  584.03840189]
 [ 291.9999764   281.83599708]
 [ 667.00006949  699.61374281]
 [ 195.00001004  222.89147874]
 [ 627.99991453  530.39888803]
 [ 912.00001051 1021.30231219]
 [  74.00000198  233.52232008]
 [ 512.99992205  758.90384586]
 [ 348.00004739  423.1454817 ]
 [ 478.99994766  747.72259671]
 [ 930.0001877  1096.79199745]
 [ 366.00003594  635.54480121]
 [2915.00000886 1962.7006102 ]
 [1023.0001789   937.82632791]
 [ 460.00000981  510.8102917 ]
 [ 867.00013626  517.73626143]
 [1091.99996689  739.35621688]
 [ 115.00000242   97.04472949]
 [ 821.00018808  741.73588298]
 [ 634.0000814   547.26800456]
 [ 880.99998307  681.34989529]
 [ 935.99984669  925.97045802]
 [  60.00000318   86.73166701]
 [ 315.99995776  296.85520707]
 [  79.99998946  207.51297476]
 [1069.99978428 1217.54775847]
 [ 676.00017861  490.98756439]]
validate rmse: 0.6345067948036089
24 car:7245e0ee27b195cd
train shape: (

train shape: (308, 1, 18)
label shape: (308, 1)
test shape: (22, 1, 18)
batch_size: 12
Epoch 00035: early stopping
[[ 772.00016853 1108.18411647]
 [ 270.00004226  518.31976536]
 [1467.00005272  909.20398054]
 [ 622.00016287  458.14868572]
 [ 853.99997036  898.5123582 ]
 [1502.000059   1252.57342105]
 [ 286.00003039  404.51250528]
 [ 306.00002747  446.05953426]
 [1487.00018598 1045.20954881]
 [ 756.99997368  664.00727168]
 [ 632.00008407  540.92217453]
 [ 573.00000713  763.66170394]
 [1714.00016605 1511.76477509]
 [ 930.0001877  1245.90269766]
 [ 126.99999948  152.90510615]
 [1065.00002226 1127.51587712]
 [1191.99981106 1097.52776669]
 [1413.00033037 1186.49639266]
 [2433.00064161 1756.4467648 ]
 [ 498.00010356  678.40915027]
 [1811.99973265 1614.42104713]
 [1014.99986878 1217.17331513]
 [ 757.99979923  616.81681892]
 [1627.99978487 1737.76232575]
 [  66.00000172   36.63034058]
 [ 101.00001138  431.38833278]
 [ 697.0001621   915.25827763]
 [ 627.99991453  494.83910213]
 [1503.99989588 1

Epoch 00043: early stopping
[[ 387.00007376  409.50530233]
 [ 411.00009365  556.01902759]
 [1502.99964483 1484.20304721]
 [2045.99981278 1911.65569539]
 [ 318.99993356  777.64351985]
 [1104.99983697 1431.9100418 ]
 [2915.00000886 2116.78192983]
 [1736.00034977 1230.30694006]
 [ 801.00004204  617.81721694]
 [ 742.99987713  593.66183437]
 [ 378.00002814  545.15844813]
 [2284.99996212 2280.98381808]
 [ 329.99996496  417.013792  ]
 [1785.99956441 1359.56270527]
 [8079.00062155 7506.86527931]
 [1082.99982332  978.9126435 ]
 [ 834.99982825  758.82091657]
 [1896.00009781 1520.35164211]
 [2518.00059452 2267.62824077]
 [ 389.99991618  390.89012119]
 [ 671.00001371  524.85657383]
 [ 287.99995327  302.90214249]
 [1475.99973044 1211.20689813]
 [2963.0000302  2542.96939881]
 [ 842.00003848  775.61159246]]
validate rmse: 0.7974282082532936
[[ 436.08224487  378.20922852  388.62216187  412.94012451]
 [1987.88928223 1555.34069824 1389.16882324 1563.11425781]
 [1629.54321289 1171.48461914 1087.29150391 

Epoch 00028: early stopping
[[ 159.99997894  298.3623676 ]
 [2274.99951385  861.41382789]
 [ 610.0000289   453.95552765]
 [ 369.99996096  342.02832675]
 [ 462.99991421  304.58496367]
 [ 175.99997209  198.5624153 ]
 [ 580.99988865  418.33657604]
 [ 830.99979565  494.63919119]
 [ 101.99999028  263.08654692]
 [ 398.00000325  312.89439469]
 [ 280.99994826  217.34857194]
 [ 254.00006671  250.62932429]
 [ 912.00001051  691.5791764 ]
 [  86.99998863  291.92791523]
 [1707.00005187 1471.04478292]
 [ 591.99985831  396.50514342]
 [1006.99985272  835.13960774]
 [ 246.00004064  139.98379483]
 [ 464.00010112  421.13817605]
 [ 126.99999948   82.71043463]
 [ 220.99999662  125.72794802]
 [ 250.99995133  200.96425748]
 [ 558.00014097  400.61586032]
 [ 880.99998307  612.4623542 ]
 [ 504.99999014  345.96318119]
 [ 257.00000759  208.93420681]
 [  86.00000749  123.45847621]
 [ 597.99993162  403.43489625]
 [ 789.0000243   647.73765723]]
validate rmse: 0.44276050766572617
29 car:12f8b7e14947c34d
train shape: 

train shape: (308, 1, 18)
label shape: (308, 1)
test shape: (22, 1, 18)
batch_size: 12
Epoch 00050: early stopping
[[105.00001102  63.33230625]
 [ 19.00000175  19.26488324]
 [971.99986347 620.6722987 ]
 [ 83.00000188  51.37472412]
 [864.99982555 652.91278427]
 [322.99993925 246.84298968]
 [162.99998896 127.97325517]
 [ 57.00000062  58.58147132]
 [ 57.00000062  39.71028671]
 [ 22.00000237  16.36574963]
 [ 25.00000005  28.58526007]
 [  8.99999973  17.1183105 ]
 [232.99994078 205.61472038]
 [304.99993311 273.49422311]
 [320.00004322 339.56614167]
 [284.00006671 255.82790368]
 [ 28.99999857  40.02906871]
 [155.00000257 166.31448866]
 [ 64.00000005  41.76239537]
 [ 30.00000159  28.51401905]
 [303.00000948 272.01493148]
 [ 60.99999469  30.9386239 ]
 [ 38.99999677  45.69808679]
 [218.00000576 193.61575493]
 [257.00000759 228.50763955]
 [423.99989897 359.03167816]
 [449.0001188  260.65185642]
 [ 57.99999715  48.67419078]
 [295.00004229 264.94121228]
 [271.0000524  217.21328511]
 [258.00000145 

Epoch 00043: early stopping
[[ 223.99995213  225.0810596 ]
 [ 727.00005764  724.48612714]
 [ 225.99998218  206.28563174]
 [ 451.00003943  407.10605477]
 [ 763.00005986  662.30192402]
 [ 246.00004064  204.80292051]
 [ 891.00004291  928.58239578]
 [ 134.0000032   117.75841718]
 [ 623.00000063  621.07708097]
 [ 254.99994193  248.24850561]
 [ 773.00002635  799.25425341]
 [ 247.0000225   266.34227735]
 [ 197.00001339  176.13747469]
 [ 246.00004064  193.0170187 ]
 [1100.99977689  848.55441914]
 [ 469.00003563  490.00285966]
 [ 213.99998516  130.62556584]
 [1063.00009171  826.15571951]
 [ 250.99995133  298.10052634]
 [ 229.99994357  238.84979323]
 [ 168.99999955  202.48619822]
 [ 523.99999792  533.24496565]
 [ 329.99996496  284.0274704 ]
 [ 144.000015    146.09232308]
 [ 281.99995734  325.07312014]]
validate rmse: 0.8174611559791098
[[ 663.25933838  589.69067383  760.64788818  725.25482178]
 [ 220.04231262  222.20739746  253.28773499  264.05923462]
 [ 174.016922    172.37060547  180.28884888 

Epoch 00050: early stopping
[[ 161.99999041  219.89792843]
 [1390.00033196  893.12855987]
 [ 191.00001793  491.10823899]
 [ 392.99996632  246.44959619]
 [ 815.9999226   603.73631655]
 [ 478.99994766  402.38804255]
 [ 108.00000242  323.60883955]
 [ 521.99988936  378.6801095 ]
 [ 254.99994193  447.16660467]
 [ 551.00002384  701.55427195]
 [ 316.99995616  338.26677152]
 [ 556.00005821  452.7389635 ]
 [ 272.99999164  669.22354918]
 [ 276.00001991  368.67177636]
 [1082.99982332  961.82106649]
 [ 842.00003848  493.49324261]
 [ 787.00009207  609.58491811]
 [1566.00037537 1784.20484805]
 [ 454.00005671  365.9668693 ]
 [ 676.00017861  430.15817402]
 [2734.99945903 1487.83036163]
 [ 397.00005391  245.75650814]
 [ 529.00002124  346.33296737]
 [1104.99983697  593.90778117]
 [ 223.99995213  211.89673758]
 [ 358.99990848  419.83585081]
 [ 411.00009365  309.22947465]
 [1066.00026157  526.11685838]
 [ 373.99993532  294.33653117]]
validate rmse: 0.49146996841647106
34 car:936168bd4850913d
train shape: 

train shape: (308, 1, 18)
label shape: (308, 1)
test shape: (22, 1, 18)
batch_size: 12
Epoch 00032: early stopping
[[ 312.99999516  596.36681993]
 [ 122.99998751  241.65044325]
 [1967.00026481 1541.58773887]
 [ 745.99986478  780.81626905]
 [ 614.00009249  780.64230167]
 [1204.99990653 1308.07279461]
 [ 122.99998751  192.76828046]
 [ 337.99999916  380.33903537]
 [2003.99969025 1260.7461501 ]
 [ 364.99992211  368.58079376]
 [ 283.00005537  246.81035835]
 [ 166.00000379  274.44806227]
 [ 544.00013778  603.7269833 ]
 [ 801.9998357   960.001075  ]
 [  60.00000318   84.56058464]
 [1121.99971461 1485.86786663]
 [ 379.99998495  449.74056651]
 [ 806.99991589  826.10639554]
 [1730.00011248 1338.05152282]
 [ 411.99994411  460.14225712]
 [ 551.00002384  656.17451046]
 [ 658.00016377 1160.38942906]
 [ 221.99998786  244.66320069]
 [ 756.99997368  879.21527483]
 [ 132.99998377  116.93052105]
 [ 105.00001102  782.2857154 ]
 [ 423.00001424  434.90145987]
 [ 654.99991058  598.03882414]
 [ 503.00000204  

Epoch 00037: early stopping
[[221.99998786 154.3168227 ]
 [ 19.00000175  42.34102136]
 [257.00000759 168.75917988]
 [242.00005203 211.53208578]
 [ 28.00000147  56.62385965]
 [262.00006876 188.87456929]
 [ 93.00000627 148.80451154]
 [ 22.99999739  12.45784147]
 [ 59.00000054  66.86071033]
 [137.00000586 132.54961503]
 [ 22.99999739  52.48211111]
 [ 76.00000704  48.01956194]
 [ 65.00000864  76.09525423]
 [242.99996579 153.47118157]
 [ 44.00000474  84.94846118]
 [ 88.00000949 132.92736715]
 [ 62.99999302  63.52019475]
 [ 53.00000147  64.6901978 ]
 [195.00001004 171.36153126]
 [276.00001991 194.87977663]
 [ 76.99999176  93.42407909]
 [ 32.00000002  31.13801735]
 [ 25.00000005  16.22368508]
 [ 19.00000175  20.90151664]
 [ 88.99999544 101.17827429]]
validate rmse: 0.6104625136622607
[[ 34.48209763  15.23377514   8.63243389   8.77971935]
 [442.43505859 236.99125671 144.32588196 129.77510071]
 [116.69728851  51.04667664  26.43666458  30.69802475]
 [ 79.80327606  36.75596237  23.58517265  20.90

test shape: (22, 1, 18)
batch_size: 11
Epoch 00045: early stopping
[[ 691.99995217 1142.29810128]
 [1047.00018148  629.96173036]
 [ 516.99988159  567.62222771]
 [1161.00012672  563.66733681]
 [1327.00010515 1032.56735318]
 [ 423.99989897  340.67304073]
 [2398.00032111 1927.37994374]
 [3201.00064747 1741.39932428]
 [ 271.0000524   714.63626518]
 [1284.00014882 1085.63388602]
 [ 355.00003633  347.43821663]
 [1232.00019733 1413.83893519]
 [2796.99935802 2087.22055624]
 [ 442.00011112 1137.05749106]
 [4023.99894562 3194.19600424]
 [2632.99982615 2065.04608784]
 [ 604.99988872  573.34374586]
 [1935.99990113 1383.30545854]
 [1616.00018315 1327.65177599]
 [ 469.99990502  296.41718609]
 [2248.99984105 1300.35694623]
 [ 353.99997444  334.6013398 ]
 [2400.99986102 1641.80913125]
 [ 836.99992008  648.89750427]
 [  86.00000749   66.20208598]
 [ 286.00003039  345.02462425]
 [ 341.0000874   317.73826877]
 [2655.9993544  1723.29748283]
 [ 981.99999406  950.50563595]]
validate rmse: 0.5993837235384591

train shape: (308, 1, 18)
label shape: (308, 1)
test shape: (22, 1, 18)
batch_size: 12
Epoch 00052: early stopping
[[ 164.99998246  130.89773364]
 [  64.00000005   85.13151568]
 [ 571.00003329  400.95834277]
 [  96.99999124  105.16372029]
 [1038.00012023  786.32325742]
 [ 532.99998864  492.23421714]
 [ 435.00000158  446.56201836]
 [ 248.99998537  253.60716364]
 [ 408.00006998  465.4863195 ]
 [  66.00000172   73.11151147]
 [ 153.00001371  141.66947075]
 [ 187.00001748  181.99210229]
 [ 655.99998724  561.06969135]
 [1275.99973464  981.77806236]
 [1132.00022183 1352.75119799]
 [1121.00011386  986.83417925]
 [ 308.00004924  307.20836795]
 [ 763.00005986  537.38816954]
 [ 321.0000371   250.75053339]
 [ 206.00002693  155.04475163]
 [1247.9998976   956.13845726]
 [ 118.00000109  104.91393149]
 [ 187.00001748  202.48144887]
 [ 579.00007907  612.6008018 ]
 [1424.99963864 1228.97169859]
 [1351.99999706  895.65502041]
 [ 807.99987609  569.69479063]
 [ 233.9999616   338.64548072]
 [1302.99999775  

Epoch 00073: early stopping
[[160.99999017 138.98815935]
 [ 69.9999944  127.11134864]
 [267.00003576 313.7990739 ]
 [374.99992058 433.50850291]
 [120.00000638 227.83610151]
 [374.99992058 346.91113361]
 [383.00007525 428.68327669]
 [ 19.99999736  60.03590075]
 [498.00010356 351.57528175]
 [225.99998218 135.98140352]
 [ 98.9999945   84.37381008]
 [172.000015   150.87020761]
 [438.0000456  581.39460403]
 [276.00001991 267.11634629]
 [756.99997368 625.53314153]
 [344.00003006 394.1131458 ]
 [ 89.99999743  99.88380338]
 [306.99996437 296.18799701]
 [516.99988159 419.64454078]
 [145.00001233 125.32608588]
 [461.99991302 393.13951795]
 [ 25.99999997  18.90804561]
 [101.00001138 115.29623649]
 [ 30.00000159  30.60109973]
 [403.00005947 406.09166965]]
validate rmse: 0.7497328524010565
[[131.3588562  101.04395294 115.27269745 110.7488327 ]
 [202.57595825 182.24134827 195.24249268 218.81504822]
 [ 57.71835709  52.71234512  53.57111359  53.240448  ]
 [ 98.74864197  96.73345184  84.74442291  83.95

Epoch 00034: early stopping
[[ 218.00000576  371.91634981]
 [2732.00060384 1161.07501653]
 [1883.99959018 1252.94138871]
 [ 334.00006454  246.93445936]
 [ 526.99993064  378.10341365]
 [ 262.00006876  304.27840553]
 [ 630.00001489  535.40659472]
 [1044.00005716  557.76668543]
 [ 523.00010272  785.36603264]
 [ 479.99989773  354.68773383]
 [ 404.00004566  275.80911169]
 [ 497.00001006  364.16421393]
 [1615.00034357 1227.76748968]
 [ 183.99997924  266.61004664]
 [2765.99967194 1481.60624082]
 [1533.0002394  1077.65146707]
 [1952.99955395 1401.02289425]
 [ 825.99983058  540.37285107]
 [ 483.99997513  271.31618   ]
 [ 739.00000483  457.21083003]
 [ 895.99980879  736.04896205]
 [ 213.00003262  200.78628676]
 [ 664.99983093  492.2919484 ]
 [2155.00020862 1618.55641623]
 [ 653.0001516   468.04296066]
 [ 605.99985753  496.27381079]
 [ 654.99991058  636.57659165]
 [1689.9997723   896.09295307]
 [ 417.99997219  318.80698295]]
validate rmse: 0.48945030755944663
44 car:7a7885e2d7c00bcf
train shape: 

train shape: (308, 1, 18)
label shape: (308, 1)
test shape: (22, 1, 18)
batch_size: 12
Epoch 00043: early stopping
[[  93.99999343   68.59958098]
 [  79.00001047   82.22746491]
 [ 593.00015171  482.67459008]
 [ 138.00000994  154.18464163]
 [ 764.00007196  520.33572921]
 [ 965.00008274  797.13186327]
 [ 648.00013456  411.17102349]
 [ 308.00004924  237.46191699]
 [ 434.00008731  509.33844591]
 [ 141.99999548  132.60297356]
 [ 175.00000947  177.69310348]
 [ 276.99995444  232.84311984]
 [ 438.0000456   388.81564538]
 [1084.99978541  835.75509376]
 [ 521.99988936  467.37725994]
 [1253.0001667  1312.76708969]
 [ 497.00001006  484.98890409]
 [ 292.99992504  201.58366988]
 [ 525.99991892  425.29042857]
 [ 267.00003576  242.203062  ]
 [ 478.99994766  387.65212221]
 [ 199.99994716  115.79725329]
 [ 304.99993311  315.80545472]
 [ 454.99996294  549.10669407]
 [ 639.00004583  592.64810351]
 [ 511.00012141  437.59233131]
 [1284.00014882  936.39261161]
 [ 277.99995496  343.42781393]
 [ 400.000001    

Epoch 00070: early stopping
[[ 86.00000749  33.38714936]
 [ 33.00000086  50.80874069]
 [114.00000126  92.35958875]
 [175.00000947 148.86953214]
 [ 47.99999611  74.39317366]
 [148.00000398  96.04624046]
 [201.00004198 100.65478249]
 [  8.          10.2197518 ]
 [129.00000071 218.6804355 ]
 [ 88.00000949  63.33360607]
 [ 25.99999997  64.13938316]
 [ 70.99999773  78.07129443]
 [ 16.00000001  22.03499638]
 [100.00000022  65.82579386]
 [ 62.99999302  35.41984107]
 [ 88.00000949  50.15616891]
 [ 91.00000462  67.31886476]
 [ 88.99999544 131.91260222]
 [209.00004177 119.93119503]
 [ 74.00000198 113.17946153]
 [  5.99999991  15.47272632]
 [  2.99999996   9.44978282]
 [ 28.00000147  20.63682603]
 [ 98.9999945   15.34691566]
 [158.00002097  96.15438487]]
validate rmse: 0.45914593091872036
[[ 50.17279053  66.76558685  45.07048416  20.45087624]
 [346.17529297 207.19100952 127.06642151  80.79761505]
 [218.69763184  82.32958221  76.83026123  53.54040146]
 [ 64.44799805   9.24969864  16.84269905  16.9

Epoch 00051: early stopping
[[108.00000242 154.02231475]
 [940.99977284 622.1401032 ]
 [584.00010854 561.39879138]
 [103.00001345  92.79775692]
 [187.00001748 178.59774266]
 [182.99996917 179.8546831 ]
 [315.99995776 268.44003438]
 [388.99995138 313.7545807 ]
 [190.00004309 335.19333364]
 [197.99998904 171.74594787]
 [126.99999948 132.46333274]
 [204.99996877 174.7015072 ]
 [253.00006598 257.37511622]
 [ 56.00000294  86.26528381]
 [253.00006598 266.88050368]
 [201.00004198 179.45713914]
 [677.99989133 565.47219206]
 [562.00004641 379.11255173]
 [ 91.00000462  83.85197818]
 [373.00003175 243.86799236]
 [531.00005987 460.71807189]
 [164.99998246 133.47593035]
 [297.99995267 258.77520637]
 [709.99988352 569.23640025]
 [297.00003849 257.8570055 ]
 [210.00002206 170.40613484]
 [289.00006896 256.67579354]
 [195.00001004 162.98039604]
 [ 86.00000749 100.98746723]]
validate rmse: 0.704390141438734
49 car:5d7fb682edd0f937
train shape: (264, 1, 18)
label shape: (264, 1)
test shape: (22, 1, 18)
b

train shape: (308, 1, 18)
label shape: (308, 1)
test shape: (22, 1, 18)
batch_size: 12
Epoch 00049: early stopping
[[ 401.00002467 1010.97470186]
 [ 238.99995737  361.76303577]
 [2170.99983785 1244.86437836]
 [ 946.00018406  585.91285773]
 [1234.99994918  972.84585569]
 [2458.00039525 1752.5615789 ]
 [1272.00027139  901.38505814]
 [ 690.00014234  707.5801943 ]
 [2187.00050178 1360.91504269]
 [ 713.99996921  581.18079241]
 [ 589.99992742  390.04648868]
 [ 556.00005821  599.2192346 ]
 [3344.00067035 2426.57353048]
 [2155.00020862 1968.47824838]
 [ 496.00007412  546.27554192]
 [3743.99938791 2690.44773964]
 [1536.99996726 1236.02608734]
 [1683.99962837 1214.61759122]
 [3589.99956751 2985.00813789]
 [ 563.000068    548.67042655]
 [2083.00008864 1693.9440139 ]
 [ 398.00000325  772.75859127]
 [ 655.99998724  456.05117719]
 [4164.00082321 3262.47061804]
 [ 615.99993438  457.22935383]
 [ 401.99997688 1376.48359257]
 [1416.99965775 1252.3310622 ]
 [1080.00016939  830.9652433 ]
 [1856.00040504 1

Epoch 00051: early stopping
[[ 89.99999743  64.15612199]
 [ 88.00000949 152.46891522]
 [ 74.99999403  57.25043928]
 [149.00001603 179.71945201]
 [137.00000586 202.2677356 ]
 [ 60.99999469  56.61903027]
 [487.00008622 298.9633879 ]
 [ 30.00000159  23.03893384]
 [289.99994741 259.53669104]
 [ 94.99999621  93.81312117]
 [148.00000398 146.8835636 ]
 [ 74.99999403  62.67724137]
 [ 33.99999953  31.64249544]
 [ 72.00000749  46.74409835]
 [231.00001805 196.69224081]
 [153.99998355 139.75358786]
 [ 39.99999472  38.82678559]
 [511.00012141 329.12945211]
 [139.99998882 114.87840892]
 [ 64.00000005  52.87262005]
 [ 32.00000002  28.94952165]
 [218.99996439 137.26529458]
 [ 64.00000005  72.27374631]
 [ 47.99999611  30.18853365]
 [ 66.00000172  78.94498769]]
validate rmse: 0.5589449200588954
[[188.48597717 103.19406128 109.27294159 111.9598465 ]
 [ 82.44002533  56.12828445  48.52261734  62.86833954]
 [ 55.86650467  36.91265869  32.88195419  36.71067429]
 [ 32.15263367  35.09800339  19.42809105  18.76

Epoch 00046: early stopping
[[ 166.00000379  283.59643382]
 [1308.00027686 1038.87863513]
 [1317.00003868  945.37960392]
 [ 193.9999825   196.20559479]
 [ 482.99993117  386.0384619 ]
 [ 193.9999825   275.54150383]
 [ 546.99996414  453.82462784]
 [ 959.00009105  528.68089526]
 [ 423.99989897  821.76556987]
 [ 420.99990696  416.22181947]
 [ 255.99993204  247.75960944]
 [ 342.00006709  377.33605294]
 [ 541.0000357   493.27809778]
 [  44.00000474  124.1144934 ]
 [ 646.00005078  530.32170448]
 [ 337.00003805  379.82439753]
 [ 932.99991951  937.52291648]
 [1202.99997605  782.51010413]
 [ 103.00001345  153.94810884]
 [1125.99983606  725.2183609 ]
 [1005.99973609 1089.97606162]
 [ 281.99995734  200.99982769]
 [ 435.00000158  495.10797515]
 [1678.00021347 1072.71773439]
 [ 455.99988364  497.37239015]
 [ 264.00000696  268.13469056]
 [ 919.99977455  690.53001249]
 [ 411.99994411  391.15693231]
 [ 199.99994716  198.26456719]]
validate rmse: 0.6299835182690063
54 car:da457d15788fe8ee
train shape: (

train shape: (308, 1, 18)
label shape: (308, 1)
test shape: (22, 1, 18)
batch_size: 12
Epoch 00044: early stopping
[[ 425.99995176  369.66855074]
 [ 318.99993356  420.38408413]
 [ 995.00014158  581.31309844]
 [ 325.00000021  278.19831163]
 [ 804.00016818  767.95347233]
 [ 746.99989529  528.26648823]
 [ 765.0001713   819.63976409]
 [ 314.99992343  350.93676966]
 [ 753.99996206  610.37377269]
 [ 442.00011112  349.67888004]
 [ 314.99992343  298.34725818]
 [ 417.99997219  425.10048539]
 [1805.99962308 1468.94937354]
 [1241.00021283 1094.4995428 ]
 [1047.00018148 1153.16228854]
 [1162.00008721 1100.14086239]
 [ 852.00013075  888.0671733 ]
 [1324.99968636 1045.14491808]
 [1082.99982332  678.61857519]
 [ 379.00000058  413.2435956 ]
 [1694.00000113 1365.9584317 ]
 [ 364.00001861  251.36000466]
 [ 469.99990502  487.33042114]
 [1624.99978615 1556.81289887]
 [ 995.9999418   998.34895808]
 [1263.00029131 1226.41016064]
 [ 646.00005078  679.19690902]
 [ 517.99990282  377.45635964]
 [1577.99962826 1

Epoch 00036: early stopping
[[ 122.99998751   76.8946346 ]
 [  10.99999971   59.5726177 ]
 [  50.0000001   125.21052686]
 [ 649.99982727  506.28276121]
 [  12.99999998  144.12941926]
 [  86.99998863  150.42899551]
 [ 351.00000732  177.827189  ]
 [ 134.0000032    83.29835334]
 [ 197.00001339   81.17601247]
 [ 153.00001371   54.11888949]
 [  41.99999877   90.1384937 ]
 [ 775.99993026  281.53687649]
 [  25.00000005  118.95833417]
 [  54.99999863   52.95075087]
 [ 470.9998974   337.47590048]
 [ 454.00005671  427.90061931]
 [  41.99999877  228.01203552]
 [ 222.99998655   62.63705976]
 [1481.99961955  784.21388435]
 [ 274.00001175  140.70365556]
 [  80.99999519   75.53658874]
 [  22.00000237   31.70029553]
 [  57.99999715   80.24746391]
 [ 461.00010567  184.53149103]
 [  38.99999677  110.56213125]]
validate rmse: 0.1915540208459544
[[ 140.8269043    54.72396088   19.65753937    8.38460255]
 [ 452.50942993  206.93449402   63.01054764   25.36350822]
 [ 320.26980591  165.10749817   64.34998322 

Epoch 00035: early stopping
[[ 47.99999611  43.01099642]
 [212.00000596 177.98499925]
 [109.99999728 108.23901954]
 [ 72.00000749  53.94599646]
 [157.00002043 124.44030873]
 [ 70.99999773  79.14442212]
 [231.00001805 167.21837295]
 [444.00009413 187.65219922]
 [232.99994078 377.76635616]
 [134.99998513 123.88759739]
 [ 67.00000159  63.75628976]
 [132.00000346 104.24664022]
 [193.00004207 181.86311667]
 [ 20.99999938  33.62499026]
 [356.00007678 154.47357207]
 [112.0000059   83.41268465]
 [177.00003437 153.22332743]
 [497.00001006 166.05762144]
 [ 50.99999513  30.6738759 ]
 [289.99994741 292.3487047 ]
 [309.99992256 233.66676815]
 [ 83.99999756  58.67327695]
 [241.00001318 185.54946442]
 [126.99999948 124.1049994 ]
 [251.99997215 166.29392113]
 [ 80.99999519  87.84175728]
 [577.99998394 341.16336472]
 [112.99999107  83.75804975]
 [ 64.00000005  45.75937055]]
validate rmse: 0.4413340360907759
59 car:f5d69960089c3614
train shape: (264, 1, 18)
label shape: (264, 1)
test shape: (22, 1, 18)


```
all validation rmse: 0.6182310141381996 

```

In [10]:
submit.sort_values(by=['id'], ascending=True)['forecastVolum'].mean()

568.2767967775012

In [11]:
submit = submit.sort_values(by=['id'], ascending=True)
submit['forecastVolum'] = submit['forecastVolum'].map(lambda index: int(np.round(index)))
print('res store over')

res store over


In [12]:
thresh = submit['forecastVolum'].mean() / 450
submit.groupby(['regMonth'])['forecastVolum'].mean()

regMonth
1    745.329545
2    534.420455
3    487.915152
4    505.452273
Name: forecastVolum, dtype: float64

In [13]:
sub = submit.copy()
sub['forecastVolum'] = sub['forecastVolum'] / thresh
sub['forecastVolum'] = sub['forecastVolum'].map(lambda index: int(np.round(index)))
sub.groupby(['regMonth'])['forecastVolum'].mean()
sub[['id', 'forecastVolum']].to_csv('./submit/chusai_lstm_model_2.csv', encoding='utf-8', index=None)

regMonth
1    590.193182
2    423.180303
3    386.362879
4    400.250758
Name: forecastVolum, dtype: float64