In [1]:
import argparse
import glob
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''
import sys
import numpy as np
import pandas as pd
from d3mds import D3MDS

In [2]:
from sklearn.metrics import mean_absolute_error
from sklearn.preprocessing import MinMaxScaler

In [3]:
import matplotlib.pyplot as plt
%matplotlib inline

In [4]:
import keras
from keras.models import Model, Sequential, load_model
from keras.layers import Dense, GRU, InputLayer, Lambda, Masking
from keras import backend as K
from keras.callbacks import EarlyStopping

Using TensorFlow backend.


In [5]:
from keras.layers import LSTM

In [6]:
def build_lstm_model(batch_size=1, num_time_steps=1, num_feat=1, neurons=5):
    model = Sequential()
    model.add(LSTM(neurons, batch_input_shape=(batch_size, num_time_steps, num_feat), stateful=True))
    model.add(Dense(1))
    opt = keras.optimizers.Adam(lr=1e-4)
    model.compile(loss='mae', optimizer=opt)
    return model

def train_lstm(model, train, batch_size, nb_epoch, last_ind):
    X, y = train[0:last_ind+1, 0:1], train[304:(304+last_ind+1), 0:1]    
    print(X.shape, y.shape)
    X = X.reshape(X.shape[0], 1, X.shape[1])
    #print(X.shape)
    for i in range(nb_epoch):
        model.fit(X, y, epochs=1, batch_size=batch_size, verbose=1, shuffle=False)
        model.reset_states()
    return model

In [7]:
# Fix random seed for reproducibility
np.random.seed(42)

In [8]:
dataset_root = '../../'
model_dir = './models'

## Load train data

In [9]:
print('Load DATA')    
data_path = glob.glob(os.path.join(dataset_root, "*_dataset"))[0]
problem_path = glob.glob(os.path.join(dataset_root, "*_problem"))[0]
d3mds = D3MDS(data_path, problem_path)

Load DATA




In [10]:
print('Load train data')
df_train = d3mds.get_train_data()
targets_train = d3mds.get_train_targets()

Load train data


In [11]:
X_train = df_train.copy()
X_train['count'] = targets_train
X_train = X_train.pivot_table(index=['species', 'sector'], columns='day', values='count')

In [12]:
X_train

Unnamed: 0_level_0,day,2,3,4,5,6,7,8,9,10,11,...,296,297,298,299,300,301,302,303,304,305
species,sector,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
cas9_CAD,S_0102,,,4580.0,4730.0,4860.0,4890.0,4920.0,,,4850.0,...,,,4500.0,4550.0,4500.0,4530.0,4530.0,,,
cas9_CAD,S_1102,,3910.0,4019.0,3940.0,3900.0,3960.0,,,4100.0,4880.0,...,,3900.0,3620.0,3650.0,3770.0,3900.0,,,3540.0,
cas9_CAD,S_2102,,3410.0,3800.0,3810.0,3630.0,,,3400.0,3420.0,3500.0,...,2750.0,2700.0,2700.0,2750.0,2700.0,,,,,2750.0
cas9_CAD,S_3102,2810.0,2820.0,2980.0,,,3110.0,3210.0,3520.0,3360.0,3440.0,...,3930.0,3930.0,3850.0,,,3750.0,3910.0,4000.0,4000.0,
cas9_CAD,S_4102,5080.0,5490.0,,,6070.0,6280.0,6630.0,6500.0,6400.0,,...,4550.0,4490.0,,,4600.0,4640.0,5000.0,5200.0,5480.0,
cas9_CAD,S_5102,5570.0,,,5290.0,5280.0,5390.0,5230.0,5190.0,,,...,6390.0,,,6430.0,6390.0,6130.0,6020.0,6200.0,,
cas9_CAD,S_6102,,,5920.0,5360.0,5210.0,5140.0,5160.0,,,5230.0,...,,,2510.0,2600.0,2610.0,2550.0,2530.0,,,2550.0
cas9_CAD,S_7002,,23800.0,23900.0,23630.0,,,23890.0,23900.0,23700.0,23830.0,...,35750.0,35300.0,36290.0,37500.0,,,37390.0,36390.0,35910.0,
cas9_CAD,S_7102,,2750.0,2750.0,2700.0,2650.0,,,2700.0,2700.0,2750.0,...,1500.0,1500.0,1500.0,1500.0,1500.0,,,1500.0,,
cas9_CAD,S_8002,26460.0,26160.0,25630.0,,,26240.0,25350.0,24890.0,25040.0,25490.0,...,8010.0,7820.0,7460.0,,,6800.0,7130.0,7220.0,7490.0,8390.0


In [13]:
last_day_info = {}

for i, (idx, df) in enumerate(X_train.groupby(level=[0,1])):
    print(i, (idx, df))
    last_ind = np.where(X_train.notnull().iloc[i,:]==True)[0][-1]
    last_day =  X_train.columns[last_ind]
    
    last_day_info[idx] = {}
    last_day_info[idx]['last_ind'] = last_ind
    last_day_info[idx]['last_day'] = last_day

0 (('cas9_CAD', 'S_0102'), day              2    3       4       5       6       7       8    9    10   \
species  sector                                                               
cas9_CAD S_0102  NaN  NaN  4580.0  4730.0  4860.0  4890.0  4920.0  NaN  NaN   

day                 11  ...   296  297     298     299     300     301  \
species  sector         ...                                              
cas9_CAD S_0102  4850.0 ...   NaN  NaN  4500.0  4550.0  4500.0  4530.0   

day                 302  303  304  305  
species  sector                         
cas9_CAD S_0102  4530.0  NaN  NaN  NaN  

[1 rows x 304 columns])
1 (('cas9_CAD', 'S_1102'), day              2       3       4       5       6       7    8    9    \
species  sector                                                          
cas9_CAD S_1102  NaN  3910.0  4019.0  3940.0  3900.0  3960.0  NaN  NaN   

day                 10      11  ...   296     297     298     299     300  \
species  sector                 ...  

25 (('cas9_HNAF', 'S_1102'), day               2        3        4        5        6        7    8    9    \
species   sector                                                               
cas9_HNAF S_1102  NaN  16842.0  17078.0  17562.0  17562.0  17907.0  NaN  NaN   

day                   10       11  ...   296     297     298     299     300  \
species   sector                   ...                                         
cas9_HNAF S_1102  18213.0  18558.0 ...   NaN  7675.0  8129.0  7893.0  7902.0   

day                  301  302  303     304  305  
species   sector                                 
cas9_HNAF S_1102  7784.0  NaN  NaN  7449.0  NaN  

[1 rows x 304 columns])
26 (('cas9_HNAF', 'S_2102'), day               2       3       4       5       6    7    8       9    \
species   sector                                                          
cas9_HNAF S_2102  NaN  7182.0  7123.0  7123.0  6807.0  NaN  NaN  6659.0   

day                  10      11    ...       296     297   

40 (('cas9_JAC', 'S_5102'), day                  2    3    4        5        6        7        8    \
species  sector                                                          
cas9_JAC S_5102  30581.0  NaN  NaN  30109.0  29935.0  30013.0  30360.0   

day                  9    10   11  ...       296  297  298      299      300  \
species  sector                    ...                                         
cas9_JAC S_5102  29993.0  NaN  NaN ...   30640.0  NaN  NaN  30254.0  28092.0   

day                  301      302      303  304  305  
species  sector                                       
cas9_JAC S_5102  29191.0  29077.0  28748.0  NaN  NaN  

[1 rows x 304 columns])
41 (('cas9_JAC', 'S_6002'), day              2        3        4        5        6    7    8        9    \
species  sector                                                               
cas9_JAC S_6002  NaN  36614.0  37380.0  37461.0  38545.0  NaN  NaN  38672.0   

day                  10       11  ...       296     

56 (('cas9_MBI', 'S_1991'), day                  2        3        4    5    6        7        8    \
species  sector                                                          
cas9_MBI S_1991  23320.0  23397.0  23320.0  NaN  NaN  22928.0  22670.0   

day                  9        10       11  ...       296      297      298  \
species  sector                            ...                               
cas9_MBI S_1991  22230.0  22537.0  22488.0 ...   20483.0  20430.0  20383.0   

day              299  300      301      302      303      304  305  
species  sector                                                     
cas9_MBI S_1991  NaN  NaN  20430.0  20566.0  20851.0  20430.0  NaN  

[1 rows x 304 columns])
57 (('cas9_MBI', 'S_2002'), day                   2         3         4    5    6         7         8    \
species  sector                                                               
cas9_MBI S_2002  101440.0  103280.0  104870.0  NaN  NaN  103570.0  104140.0   

day             

71 (('cas9_MBI', 'S_4691'), day                 2       3    4    5       6       7       8       9    \
species  sector                                                             
cas9_MBI S_4691  5728.0  5652.0  NaN  NaN  5669.0  5627.0  5686.0  5861.0   

day                 10   11  ...      296     297  298  299     300     301  \
species  sector              ...                                              
cas9_MBI S_4691  5911.0  NaN ...   5861.0  5836.0  NaN  NaN  5753.0  5769.0   

day                 302     303     304  305  
species  sector                               
cas9_MBI S_4691  5803.0  5828.0  5820.0  NaN  

[1 rows x 304 columns])
72 (('cas9_MBI', 'S_4791'), day                  2        3        4    5    6        7        8    \
species  sector                                                          
cas9_MBI S_4791  12615.0  12433.0  11989.0  NaN  NaN  11807.0  11849.0   

day                  9        10       11  ...      296     297     298  299  \
speci

86 (('cas9_MBI', 'S_6991'), day                  2        3        4        5    6    7        8    \
species  sector                                                          
cas9_MBI S_6991  18905.0  18563.0  18069.0  18437.0  NaN  NaN  18535.0   

day                  9        10       11    ...         296      297  \
species  sector                              ...                        
cas9_MBI S_6991  18044.0  18145.0  18145.0   ...     26421.0  26888.0   

day                  298      299  300  301      302      303      304  \
species  sector                                                          
cas9_MBI S_6991  26494.0  26311.0  NaN  NaN  26494.0  26127.0  26285.0   

day                  305  
species  sector           
cas9_MBI S_6991  26837.0  

[1 rows x 304 columns])
87 (('cas9_MBI', 'S_7002'), day              2        3        4        5    6    7        8        9    \
species  sector                                                               
cas9_MBI S_700

101 (('cas9_MBI', 'S_9891'), day              2        3        4        5        6    7    8        9    \
species  sector                                                               
cas9_MBI S_9891  NaN  25166.0  25474.0  25427.0  25344.0  NaN  NaN  25365.0   

day                  10       11  ...       296      297      298      299  \
species  sector                   ...                                        
cas9_MBI S_9891  25193.0  25427.0 ...   21525.0  21415.0  21186.0  20949.0   

day                  300  301  302      303      304  305  
species  sector                                            
cas9_MBI S_9891  20767.0  NaN  NaN  20691.0  20851.0  NaN  

[1 rows x 304 columns])
102 (('cas9_MBI', 'S_9991'), day              2    3        4        5        6        7        8    9    \
species  sector                                                               
cas9_MBI S_9991  NaN  NaN  76138.0  78892.0  78525.0  79126.0  78033.0  NaN   

day              10       

117 (('cas9_VBBA', 'S_4102'), day                   2        3    4    5        6        7        8    \
species   sector                                                          
cas9_VBBA S_4102  45285.0  45563.0  NaN  NaN  43900.0  43988.0  43873.0   

day                   9        10   11  ...       296      297  298  299  \
species   sector                        ...                                
cas9_VBBA S_4102  44623.0  44343.0  NaN ...   53528.0  54222.0  NaN  NaN   

day                   300      301      302      303      304  305  
species   sector                                                    
cas9_VBBA S_4102  54398.0  54798.0  53943.0  55020.0  57071.0  NaN  

[1 rows x 304 columns])
118 (('cas9_VBBA', 'S_5102'), day                   2    3    4        5        6        7        8    \
species   sector                                                          
cas9_VBBA S_5102  59258.0  NaN  NaN  58140.0  57854.0  60193.0  60821.0   

day                   9    

132 (('cas9_YABE', 'S_3102'), day                   2        3        4    5    6        7        8    \
species   sector                                                          
cas9_YABE S_3102  22555.0  22075.0  22214.0  NaN  NaN  22521.0  22171.0   

day                   9        10       11  ...       296      297      298  \
species   sector                            ...                               
cas9_YABE S_3102  22205.0  22305.0  22601.0 ...   21772.0  22029.0  21730.0   

day               299  300      301      302      303      304  305  
species   sector                                                     
cas9_YABE S_3102  NaN  NaN  21873.0  22449.0  22189.0  22184.0  NaN  

[1 rows x 304 columns])
133 (('cas9_YABE', 'S_4002'), day                   2    3    4        5        6        7        8    \
species   sector                                                          
cas9_YABE S_4002  13258.0  NaN  NaN  13597.0  13712.0  13720.0  13561.0   

day            

In [14]:
last_day_info.keys()

dict_keys([('cas9_CAD', 'S_0102'), ('cas9_CAD', 'S_1102'), ('cas9_CAD', 'S_2102'), ('cas9_CAD', 'S_3102'), ('cas9_CAD', 'S_4102'), ('cas9_CAD', 'S_5102'), ('cas9_CAD', 'S_6102'), ('cas9_CAD', 'S_7002'), ('cas9_CAD', 'S_7102'), ('cas9_CAD', 'S_8002'), ('cas9_CAD', 'S_9002'), ('cas9_FAB', 'S_0102'), ('cas9_FAB', 'S_1102'), ('cas9_FAB', 'S_2102'), ('cas9_FAB', 'S_3102'), ('cas9_FAB', 'S_4102'), ('cas9_FAB', 'S_5002'), ('cas9_FAB', 'S_5102'), ('cas9_FAB', 'S_6002'), ('cas9_FAB', 'S_6102'), ('cas9_FAB', 'S_7002'), ('cas9_FAB', 'S_7102'), ('cas9_FAB', 'S_8002'), ('cas9_FAB', 'S_9002'), ('cas9_HNAF', 'S_0102'), ('cas9_HNAF', 'S_1102'), ('cas9_HNAF', 'S_2102'), ('cas9_HNAF', 'S_3102'), ('cas9_HNAF', 'S_4102'), ('cas9_HNAF', 'S_5102'), ('cas9_HNAF', 'S_6102'), ('cas9_HNAF', 'S_7102'), ('cas9_HNAF', 'S_8002'), ('cas9_HNAF', 'S_9002'), ('cas9_JAC', 'S_0102'), ('cas9_JAC', 'S_1102'), ('cas9_JAC', 'S_2102'), ('cas9_JAC', 'S_3102'), ('cas9_JAC', 'S_4102'), ('cas9_JAC', 'S_5002'), ('cas9_JAC', 'S_510

In [15]:
last_day_info[('cas9_CAD', 'S_0102')]

{'last_day': 302, 'last_ind': 300}

In [16]:
last_day_info[('cas9_CAD', 'S_1102')]

{'last_day': 304, 'last_ind': 302}

In [17]:
last_day_info[('cas9_CAD', 'S_2102')]

{'last_day': 305, 'last_ind': 303}

## Interpolate missing values for each time series

In [18]:
X_train.interpolate('index', axis=1, limit_direction='both', inplace=True)

In [19]:
X_train

Unnamed: 0_level_0,day,2,3,4,5,6,7,8,9,10,11,...,296,297,298,299,300,301,302,303,304,305
species,sector,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
cas9_CAD,S_0102,4580.0,4580.000000,4580.000000,4730.000000,4860.000000,4890.000000,4920.000000,4896.666667,4873.333333,4850.000000,...,4633.333333,4566.666667,4500.000000,4550.000000,4500.000000,4530.000000,4530.000000,4530.000000,4530.000000,4530.0
cas9_CAD,S_1102,3910.0,3910.000000,4019.000000,3940.000000,3900.000000,3960.000000,4006.666667,4053.333333,4100.000000,4880.000000,...,3920.000000,3900.000000,3620.000000,3650.000000,3770.000000,3900.000000,3780.000000,3660.000000,3540.000000,3540.0
cas9_CAD,S_2102,3410.0,3410.000000,3800.000000,3810.000000,3630.000000,3553.333333,3476.666667,3400.000000,3420.000000,3500.000000,...,2750.000000,2700.000000,2700.000000,2750.000000,2700.000000,2710.000000,2720.000000,2730.000000,2740.000000,2750.0
cas9_CAD,S_3102,2810.0,2820.000000,2980.000000,3023.333333,3066.666667,3110.000000,3210.000000,3520.000000,3360.000000,3440.000000,...,3930.000000,3930.000000,3850.000000,3816.666667,3783.333333,3750.000000,3910.000000,4000.000000,4000.000000,4000.0
cas9_CAD,S_4102,5080.0,5490.000000,5683.333333,5876.666667,6070.000000,6280.000000,6630.000000,6500.000000,6400.000000,6476.666667,...,4550.000000,4490.000000,4526.666667,4563.333333,4600.000000,4640.000000,5000.000000,5200.000000,5480.000000,5480.0
cas9_CAD,S_5102,5570.0,5476.666667,5383.333333,5290.000000,5280.000000,5390.000000,5230.000000,5190.000000,5186.666667,5183.333333,...,6390.000000,6403.333333,6416.666667,6430.000000,6390.000000,6130.000000,6020.000000,6200.000000,6200.000000,6200.0
cas9_CAD,S_6102,5920.0,5920.000000,5920.000000,5360.000000,5210.000000,5140.000000,5160.000000,5183.333333,5206.666667,5230.000000,...,2516.666667,2513.333333,2510.000000,2600.000000,2610.000000,2550.000000,2530.000000,2536.666667,2543.333333,2550.0
cas9_CAD,S_7002,23800.0,23800.000000,23900.000000,23630.000000,23716.666667,23803.333333,23890.000000,23900.000000,23700.000000,23830.000000,...,35750.000000,35300.000000,36290.000000,37500.000000,37463.333333,37426.666667,37390.000000,36390.000000,35910.000000,35910.0
cas9_CAD,S_7102,2750.0,2750.000000,2750.000000,2700.000000,2650.000000,2666.666667,2683.333333,2700.000000,2700.000000,2750.000000,...,1500.000000,1500.000000,1500.000000,1500.000000,1500.000000,1500.000000,1500.000000,1500.000000,1500.000000,1500.0
cas9_CAD,S_8002,26460.0,26160.000000,25630.000000,25833.333333,26036.666667,26240.000000,25350.000000,24890.000000,25040.000000,25490.000000,...,8010.000000,7820.000000,7460.000000,7240.000000,7020.000000,6800.000000,7130.000000,7220.000000,7490.000000,8390.0


## Create supervised version of dataset (predict one time point at a time)

### Take differences

In [21]:
# Create shifted form of each timeseries - (creates supervised version of timeseries problem)
def shift_df(df, lag=1):    
    shifted_df = df.shift(lag, axis=1)
    shifted_df.fillna(0, inplace=True)
    return shifted_df

In [22]:
def diff_df(df):
    df_diff = df.diff(periods=-1, axis=1)
    df_diff.fillna(0, inplace=True)
    return df_diff

In [23]:
X_train_diff = diff_df(X_train)

In [24]:
X_train_diff

Unnamed: 0_level_0,day,2,3,4,5,6,7,8,9,10,11,...,296,297,298,299,300,301,302,303,304,305
species,sector,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
cas9_CAD,S_0102,0.000000,0.000000,-150.000000,-130.000000,-30.000000,-30.000000,23.333333,23.333333,23.333333,80.000000,...,66.666667,66.666667,-50.000000,50.000000,-30.000000,0.000000,0.000000,0.000000,0.000000,0.0
cas9_CAD,S_1102,0.000000,-109.000000,79.000000,40.000000,-60.000000,-46.666667,-46.666667,-46.666667,-780.000000,240.000000,...,20.000000,280.000000,-30.000000,-120.000000,-130.000000,120.000000,120.000000,120.000000,0.000000,0.0
cas9_CAD,S_2102,0.000000,-390.000000,-10.000000,180.000000,76.666667,76.666667,76.666667,-20.000000,-80.000000,60.000000,...,50.000000,0.000000,-50.000000,50.000000,-10.000000,-10.000000,-10.000000,-10.000000,-10.000000,0.0
cas9_CAD,S_3102,-10.000000,-160.000000,-43.333333,-43.333333,-43.333333,-100.000000,-310.000000,160.000000,-80.000000,10.000000,...,0.000000,80.000000,33.333333,33.333333,33.333333,-160.000000,-90.000000,0.000000,0.000000,0.0
cas9_CAD,S_4102,-410.000000,-193.333333,-193.333333,-193.333333,-210.000000,-350.000000,130.000000,100.000000,-76.666667,-76.666667,...,60.000000,-36.666667,-36.666667,-36.666667,-40.000000,-360.000000,-200.000000,-280.000000,0.000000,0.0
cas9_CAD,S_5102,93.333333,93.333333,93.333333,10.000000,-110.000000,160.000000,40.000000,3.333333,3.333333,3.333333,...,-13.333333,-13.333333,-13.333333,40.000000,260.000000,110.000000,-180.000000,0.000000,0.000000,0.0
cas9_CAD,S_6102,0.000000,0.000000,560.000000,150.000000,70.000000,-20.000000,-23.333333,-23.333333,-23.333333,40.000000,...,3.333333,3.333333,-90.000000,-10.000000,60.000000,20.000000,-6.666667,-6.666667,-6.666667,0.0
cas9_CAD,S_7002,0.000000,-100.000000,270.000000,-86.666667,-86.666667,-86.666667,-10.000000,200.000000,-130.000000,-280.000000,...,450.000000,-990.000000,-1210.000000,36.666667,36.666667,36.666667,1000.000000,480.000000,0.000000,0.0
cas9_CAD,S_7102,0.000000,0.000000,50.000000,50.000000,-16.666667,-16.666667,-16.666667,0.000000,-50.000000,100.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0
cas9_CAD,S_8002,300.000000,530.000000,-203.333333,-203.333333,-203.333333,890.000000,460.000000,-150.000000,-450.000000,163.333333,...,190.000000,360.000000,220.000000,220.000000,220.000000,-330.000000,-90.000000,-270.000000,-900.000000,0.0


In [25]:
X_train_diff_shift = shift_df(X_train_diff)

In [26]:
X_train_diff_shift

Unnamed: 0_level_0,day,2,3,4,5,6,7,8,9,10,11,...,296,297,298,299,300,301,302,303,304,305
species,sector,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
cas9_CAD,S_0102,0.0,0.000000,0.000000,-150.000000,-130.000000,-30.000000,-30.000000,23.333333,23.333333,23.333333,...,66.666667,66.666667,66.666667,-50.000000,50.000000,-30.000000,0.000000,0.000000,0.000000,0.000000
cas9_CAD,S_1102,0.0,0.000000,-109.000000,79.000000,40.000000,-60.000000,-46.666667,-46.666667,-46.666667,-780.000000,...,20.000000,20.000000,280.000000,-30.000000,-120.000000,-130.000000,120.000000,120.000000,120.000000,0.000000
cas9_CAD,S_2102,0.0,0.000000,-390.000000,-10.000000,180.000000,76.666667,76.666667,76.666667,-20.000000,-80.000000,...,0.000000,50.000000,0.000000,-50.000000,50.000000,-10.000000,-10.000000,-10.000000,-10.000000,-10.000000
cas9_CAD,S_3102,0.0,-10.000000,-160.000000,-43.333333,-43.333333,-43.333333,-100.000000,-310.000000,160.000000,-80.000000,...,-30.000000,0.000000,80.000000,33.333333,33.333333,33.333333,-160.000000,-90.000000,0.000000,0.000000
cas9_CAD,S_4102,0.0,-410.000000,-193.333333,-193.333333,-193.333333,-210.000000,-350.000000,130.000000,100.000000,-76.666667,...,-60.000000,60.000000,-36.666667,-36.666667,-36.666667,-40.000000,-360.000000,-200.000000,-280.000000,0.000000
cas9_CAD,S_5102,0.0,93.333333,93.333333,93.333333,10.000000,-110.000000,160.000000,40.000000,3.333333,3.333333,...,10.000000,-13.333333,-13.333333,-13.333333,40.000000,260.000000,110.000000,-180.000000,0.000000,0.000000
cas9_CAD,S_6102,0.0,0.000000,0.000000,560.000000,150.000000,70.000000,-20.000000,-23.333333,-23.333333,-23.333333,...,3.333333,3.333333,3.333333,-90.000000,-10.000000,60.000000,20.000000,-6.666667,-6.666667,-6.666667
cas9_CAD,S_7002,0.0,0.000000,-100.000000,270.000000,-86.666667,-86.666667,-86.666667,-10.000000,200.000000,-130.000000,...,-610.000000,450.000000,-990.000000,-1210.000000,36.666667,36.666667,36.666667,1000.000000,480.000000,0.000000
cas9_CAD,S_7102,0.0,0.000000,0.000000,50.000000,50.000000,-16.666667,-16.666667,-16.666667,0.000000,-50.000000,...,-16.666667,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
cas9_CAD,S_8002,0.0,300.000000,530.000000,-203.333333,-203.333333,-203.333333,890.000000,460.000000,-150.000000,-450.000000,...,880.000000,190.000000,360.000000,220.000000,220.000000,220.000000,-330.000000,-90.000000,-270.000000,-900.000000


In [27]:
X_train_diff_full = pd.concat([X_train_diff_shift, X_train_diff], axis=1)
print(X_train_diff_full.shape)

(144, 608)


In [28]:
X_train_diff_full

Unnamed: 0_level_0,day,2,3,4,5,6,7,8,9,10,11,...,296,297,298,299,300,301,302,303,304,305
species,sector,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
cas9_CAD,S_0102,0.0,0.000000,0.000000,-150.000000,-130.000000,-30.000000,-30.000000,23.333333,23.333333,23.333333,...,66.666667,66.666667,-50.000000,50.000000,-30.000000,0.000000,0.000000,0.000000,0.000000,0.0
cas9_CAD,S_1102,0.0,0.000000,-109.000000,79.000000,40.000000,-60.000000,-46.666667,-46.666667,-46.666667,-780.000000,...,20.000000,280.000000,-30.000000,-120.000000,-130.000000,120.000000,120.000000,120.000000,0.000000,0.0
cas9_CAD,S_2102,0.0,0.000000,-390.000000,-10.000000,180.000000,76.666667,76.666667,76.666667,-20.000000,-80.000000,...,50.000000,0.000000,-50.000000,50.000000,-10.000000,-10.000000,-10.000000,-10.000000,-10.000000,0.0
cas9_CAD,S_3102,0.0,-10.000000,-160.000000,-43.333333,-43.333333,-43.333333,-100.000000,-310.000000,160.000000,-80.000000,...,0.000000,80.000000,33.333333,33.333333,33.333333,-160.000000,-90.000000,0.000000,0.000000,0.0
cas9_CAD,S_4102,0.0,-410.000000,-193.333333,-193.333333,-193.333333,-210.000000,-350.000000,130.000000,100.000000,-76.666667,...,60.000000,-36.666667,-36.666667,-36.666667,-40.000000,-360.000000,-200.000000,-280.000000,0.000000,0.0
cas9_CAD,S_5102,0.0,93.333333,93.333333,93.333333,10.000000,-110.000000,160.000000,40.000000,3.333333,3.333333,...,-13.333333,-13.333333,-13.333333,40.000000,260.000000,110.000000,-180.000000,0.000000,0.000000,0.0
cas9_CAD,S_6102,0.0,0.000000,0.000000,560.000000,150.000000,70.000000,-20.000000,-23.333333,-23.333333,-23.333333,...,3.333333,3.333333,-90.000000,-10.000000,60.000000,20.000000,-6.666667,-6.666667,-6.666667,0.0
cas9_CAD,S_7002,0.0,0.000000,-100.000000,270.000000,-86.666667,-86.666667,-86.666667,-10.000000,200.000000,-130.000000,...,450.000000,-990.000000,-1210.000000,36.666667,36.666667,36.666667,1000.000000,480.000000,0.000000,0.0
cas9_CAD,S_7102,0.0,0.000000,0.000000,50.000000,50.000000,-16.666667,-16.666667,-16.666667,0.000000,-50.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0
cas9_CAD,S_8002,0.0,300.000000,530.000000,-203.333333,-203.333333,-203.333333,890.000000,460.000000,-150.000000,-450.000000,...,190.000000,360.000000,220.000000,220.000000,220.000000,-330.000000,-90.000000,-270.000000,-900.000000,0.0


## Scale values to be in range [-1, 1]

In [29]:
def build_scalers(train_df):
    scalers = {}
    for i, (idx, df) in enumerate(train_df.groupby(level=[0,1])):
        #print(i, (idx, df))
        scaler = MinMaxScaler(feature_range=(-1, 1))
        scaler = scaler.fit(df.T)
        scalers[idx] = scaler

    return scalers

In [30]:
def apply_scalers(df_in, scalers):
    df_out = df_in.copy()
    for i, (idx, df) in enumerate(df_in.groupby(level=[0,1])):
        #print('before scale', df_in.loc[idx])
        df_out.loc[idx] = scalers[idx].transform(df_in.loc[idx].reshape(1,-1))
        #print('after scale', df_out.loc[idx])
        
    return df_out

In [31]:
scalers = build_scalers(X_train_diff_full)

In [32]:
scalers

{('cas9_CAD', 'S_0102'): MinMaxScaler(copy=True, feature_range=(-1, 1)),
 ('cas9_CAD', 'S_1102'): MinMaxScaler(copy=True, feature_range=(-1, 1)),
 ('cas9_CAD', 'S_2102'): MinMaxScaler(copy=True, feature_range=(-1, 1)),
 ('cas9_CAD', 'S_3102'): MinMaxScaler(copy=True, feature_range=(-1, 1)),
 ('cas9_CAD', 'S_4102'): MinMaxScaler(copy=True, feature_range=(-1, 1)),
 ('cas9_CAD', 'S_5102'): MinMaxScaler(copy=True, feature_range=(-1, 1)),
 ('cas9_CAD', 'S_6102'): MinMaxScaler(copy=True, feature_range=(-1, 1)),
 ('cas9_CAD', 'S_7002'): MinMaxScaler(copy=True, feature_range=(-1, 1)),
 ('cas9_CAD', 'S_7102'): MinMaxScaler(copy=True, feature_range=(-1, 1)),
 ('cas9_CAD', 'S_8002'): MinMaxScaler(copy=True, feature_range=(-1, 1)),
 ('cas9_CAD', 'S_9002'): MinMaxScaler(copy=True, feature_range=(-1, 1)),
 ('cas9_FAB', 'S_0102'): MinMaxScaler(copy=True, feature_range=(-1, 1)),
 ('cas9_FAB', 'S_1102'): MinMaxScaler(copy=True, feature_range=(-1, 1)),
 ('cas9_FAB', 'S_2102'): MinMaxScaler(copy=True, fe

In [33]:
X_train_diff_full

Unnamed: 0_level_0,day,2,3,4,5,6,7,8,9,10,11,...,296,297,298,299,300,301,302,303,304,305
species,sector,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
cas9_CAD,S_0102,0.0,0.000000,0.000000,-150.000000,-130.000000,-30.000000,-30.000000,23.333333,23.333333,23.333333,...,66.666667,66.666667,-50.000000,50.000000,-30.000000,0.000000,0.000000,0.000000,0.000000,0.0
cas9_CAD,S_1102,0.0,0.000000,-109.000000,79.000000,40.000000,-60.000000,-46.666667,-46.666667,-46.666667,-780.000000,...,20.000000,280.000000,-30.000000,-120.000000,-130.000000,120.000000,120.000000,120.000000,0.000000,0.0
cas9_CAD,S_2102,0.0,0.000000,-390.000000,-10.000000,180.000000,76.666667,76.666667,76.666667,-20.000000,-80.000000,...,50.000000,0.000000,-50.000000,50.000000,-10.000000,-10.000000,-10.000000,-10.000000,-10.000000,0.0
cas9_CAD,S_3102,0.0,-10.000000,-160.000000,-43.333333,-43.333333,-43.333333,-100.000000,-310.000000,160.000000,-80.000000,...,0.000000,80.000000,33.333333,33.333333,33.333333,-160.000000,-90.000000,0.000000,0.000000,0.0
cas9_CAD,S_4102,0.0,-410.000000,-193.333333,-193.333333,-193.333333,-210.000000,-350.000000,130.000000,100.000000,-76.666667,...,60.000000,-36.666667,-36.666667,-36.666667,-40.000000,-360.000000,-200.000000,-280.000000,0.000000,0.0
cas9_CAD,S_5102,0.0,93.333333,93.333333,93.333333,10.000000,-110.000000,160.000000,40.000000,3.333333,3.333333,...,-13.333333,-13.333333,-13.333333,40.000000,260.000000,110.000000,-180.000000,0.000000,0.000000,0.0
cas9_CAD,S_6102,0.0,0.000000,0.000000,560.000000,150.000000,70.000000,-20.000000,-23.333333,-23.333333,-23.333333,...,3.333333,3.333333,-90.000000,-10.000000,60.000000,20.000000,-6.666667,-6.666667,-6.666667,0.0
cas9_CAD,S_7002,0.0,0.000000,-100.000000,270.000000,-86.666667,-86.666667,-86.666667,-10.000000,200.000000,-130.000000,...,450.000000,-990.000000,-1210.000000,36.666667,36.666667,36.666667,1000.000000,480.000000,0.000000,0.0
cas9_CAD,S_7102,0.0,0.000000,0.000000,50.000000,50.000000,-16.666667,-16.666667,-16.666667,0.000000,-50.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0
cas9_CAD,S_8002,0.0,300.000000,530.000000,-203.333333,-203.333333,-203.333333,890.000000,460.000000,-150.000000,-450.000000,...,190.000000,360.000000,220.000000,220.000000,220.000000,-330.000000,-90.000000,-270.000000,-900.000000,0.0


In [34]:
X_train_scaled = apply_scalers(X_train_diff_full, scalers)

  """


In [60]:
X_train_scaled

Unnamed: 0_level_0,day,2,3,4,5,6,7,8,9,10,11,...,296,297,298,299,300,301,302,303,304,305
species,sector,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
cas9_CAD,S_0102,0.014925,0.014925,0.014925,-4.328358e-01,-0.373134,-0.074627,-0.074627,0.084577,0.084577,0.084577,...,0.213930,0.213930,-0.134328,0.164179,-0.074627,0.014925,0.014925,0.014925,0.014925,0.014925
cas9_CAD,S_1102,0.132867,0.132867,-0.019580,2.433566e-01,0.188811,0.048951,0.067599,0.067599,0.067599,-0.958042,...,0.160839,0.524476,0.090909,-0.034965,-0.048951,0.300699,0.300699,0.300699,0.132867,0.132867
cas9_CAD,S_2102,0.382716,0.382716,-0.580247,3.580247e-01,0.827160,0.572016,0.572016,0.572016,0.333333,0.185185,...,0.506173,0.382716,0.259259,0.506173,0.358025,0.358025,0.358025,0.358025,0.358025,0.382716
cas9_CAD,S_3102,0.247059,0.223529,-0.129412,1.450980e-01,0.145098,0.145098,0.011765,-0.482353,0.623529,0.058824,...,0.247059,0.435294,0.325490,0.325490,0.325490,-0.129412,0.035294,0.247059,0.247059,0.247059
cas9_CAD,S_4102,0.047619,-0.603175,-0.259259,-2.592593e-01,-0.259259,-0.285714,-0.507937,0.253968,0.206349,-0.074074,...,0.142857,-0.010582,-0.010582,-0.010582,-0.015873,-0.523810,-0.269841,-0.396825,0.047619,0.047619
cas9_CAD,S_5102,0.192982,0.302144,0.302144,3.021442e-01,0.204678,0.064327,0.380117,0.239766,0.196881,0.196881,...,0.177388,0.177388,0.177388,0.239766,0.497076,0.321637,-0.017544,0.192982,0.192982,0.192982
cas9_CAD,S_6102,0.151515,0.151515,0.151515,1.000000e+00,0.378788,0.257576,0.121212,0.116162,0.116162,0.116162,...,0.156566,0.156566,0.015152,0.136364,0.242424,0.181818,0.141414,0.141414,0.141414,0.151515
cas9_CAD,S_7002,-0.067460,-0.067460,-0.107143,3.968254e-02,-0.101852,-0.101852,-0.101852,-0.071429,0.011905,-0.119048,...,0.111111,-0.460317,-0.547619,-0.052910,-0.052910,-0.052910,0.329365,0.123016,-0.067460,-0.067460
cas9_CAD,S_7102,-0.222222,-0.222222,-0.222222,-1.111111e-01,-0.111111,-0.259259,-0.259259,-0.259259,-0.222222,-0.333333,...,-0.222222,-0.222222,-0.222222,-0.222222,-0.222222,-0.222222,-0.222222,-0.222222,-0.222222,-0.222222
cas9_CAD,S_8002,0.004386,0.135965,0.236842,-8.479532e-02,-0.084795,-0.084795,0.394737,0.206140,-0.061404,-0.192982,...,0.087719,0.162281,0.100877,0.100877,0.100877,-0.140351,-0.035088,-0.114035,-0.390351,0.004386


## Train LSTM model

In [37]:
lstm_model = build_lstm_model(neurons=5)

In [38]:
lstm_model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
lstm_1 (LSTM)                (1, 5)                    140       
_________________________________________________________________
dense_1 (Dense)              (1, 1)                    6         
Total params: 146
Trainable params: 146
Non-trainable params: 0
_________________________________________________________________


In [76]:
if os.path.exists('lstm_model_20_epochs.h5'):
    lstm_model.load_weights('lstm_model_20_epochs.h5')
else:
    num_epochs = 20
    for e in range(num_epochs):
        print('On Epoch %d' % e)
        for i, (idx, df) in enumerate(X_train_scaled.groupby(level=[0,1])):
            print('Processing series: ', idx)
            print(last_day_info[idx])
            last_ind = last_day_info[idx]['last_ind']
            train_lstm(lstm_model, X_train_scaled.loc[idx][:, None], batch_size=1, nb_epoch=1, last_ind=last_ind)
            
    lstm_model.save('lstm_model_20_epochs.h5')

## Forecast future points

In [77]:
def forecast_on_test_data(X_train_orig, X_train_scaled, idx, lstm_model, scalers, target_day, last_day_info):
    # forecast the entire training dataset to build up state for forecasting
    #T = 304 # length of timeseries
    last_ind, last_day = last_day_info[idx]['last_ind'], last_day_info[idx]['last_day']
    
    df = X_train.loc[idx][0:last_ind+1]
    df_scaled = X_train_scaled.loc[idx][0:last_ind+1]
    #print(df_scaled)
    scaler = scalers[idx]
    
    lstm_model.reset_states()
    res = lstm_model.predict(df_scaled[:, None, None], batch_size=1)   
    #res_inv_scale = scaler.inverse_transform(res)
    #pred_train = df[:, None] + res_inv_scale
    
    #print(X_train_scaled.loc[idx][0:last_ind+1])
    #print(res)
    #plt.plot(X_train_scaled.loc[idx][0:last_ind+1])
    #plt.plot(res)
    
    if target_day <= last_day:        
        print('target_day less than last_day')
        target_ind = np.where(X_train.columns == target_day)[0]
        return res[target_ind]
    
    target_ind = last_ind+(target_day-last_day)
    #print(target_ind)
    for i in range(last_ind+1, target_ind+1):
        #print(res[-1])
        val = lstm_model.predict(res[-1][:, None, None], batch_size=1)
        #print(val)
        res = np.append(res, val[0])[:, None]
    
    res_inv_scale = scaler.inverse_transform(res)
    pred_train = df[:, None] + res_inv_scale[0:last_ind+1]
    remainder = (df.values[-1] + np.cumsum(res_inv_scale[last_ind+1:]))[:,None]

    #print(res_inv_scale.shape)
    #print(df_scaled.shape)
    #print(remainder.shape)
    #plt.plot(res)
    #plt.plot(np.concatenate((pred_train, remainder)))
    
    #print('target_val: ', target_val)
    #print('pred_val: ', remainder[-1])
    
    #return remainder[-3] # becuase the time series all start at day 2?
    return remainder[-1]

In [78]:
forecast_on_test_data(X_train, X_train_scaled, ('cas9_CAD', 'S_0102'), lstm_model, scalers, 330, last_day_info)

array([ 5524.18554688], dtype=float32)

## Look at test data

In [79]:
print('Load test data')
df_test = d3mds.get_test_data()
targets_test = d3mds.get_test_targets()

Load test data


In [80]:
df_test

Unnamed: 0_level_0,species,sector,day
d3mIndex,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
209,cas9_VBBA,S_3102,330
210,cas9_VBBA,S_3102,360
422,cas9_VBBA,S_4102,328
423,cas9_VBBA,S_4102,360
634,cas9_VBBA,S_5102,329
635,cas9_VBBA,S_5102,362
846,cas9_VBBA,S_6102,333
847,cas9_VBBA,S_6102,363
1057,cas9_VBBA,S_7102,304
1058,cas9_VBBA,S_7102,306


In [81]:
targets_test

array([[ 42341],
       [ 46166],
       [ 60957],
       [ 60236],
       [ 56001],
       [ 54571],
       [ 57367],
       [ 60078],
       [ 90250],
       [ 91020],
       [  9074],
       [  9025],
       [  9612],
       [  9498],
       [  9040],
       [  8737],
       [  6321],
       [  6375],
       [  9047],
       [  9386],
       [  9753],
       [  8827],
       [ 10878],
       [ 11105],
       [ 13155],
       [ 12554],
       [ 10112],
       [ 10693],
       [ 12128],
       [ 12314],
       [ 13020],
       [ 13300],
       [ 13288],
       [ 13478],
       [ 14990],
       [ 15100],
       [ 34530],
       [ 36260],
       [ 46849],
       [ 51966],
       [ 48680],
       [ 42770],
       [ 27480],
       [ 27208],
       [ 34902],
       [ 39265],
       [ 43453],
       [ 46895],
       [ 39156],
       [ 39519],
       [ 32073],
       [ 36765],
       [ 30650],
       [ 29969],
       [ 29734],
       [ 31693],
       [ 29183],
       [ 30004],
       [ 28136

In [82]:
preds_test = []
for i in range(df_test.shape[0]):
    print('On test sample %d' % i)
    row_test = df_test.iloc[i,:]
    target_test = targets_test[i]
    species_test, sector_test, day_test = row_test['species'], row_test['sector'], row_test['day']
    print(species_test, sector_test, day_test)
    #print(target_test)
    
    pred_test = forecast_on_test_data(X_train, X_train_scaled, (species_test, sector_test), lstm_model, scalers, day_test, last_day_info)
    preds_test.append(pred_test)
    print(target_test, pred_test)

On test sample 0
cas9_VBBA S_3102 330
[42341] [ 44553.11328125]
On test sample 1
cas9_VBBA S_3102 360
[46166] [ 54817.2734375]
On test sample 2
cas9_VBBA S_4102 328
[60957] [ 46734.90625]
On test sample 3
cas9_VBBA S_4102 360
[60236] [ 45368.828125]
On test sample 4
cas9_VBBA S_5102 329
[56001] [ 45066.13671875]
On test sample 5
cas9_VBBA S_5102 362
[54571] [ 60473.81640625]
On test sample 6
cas9_VBBA S_6102 333
[57367] [ 71017.8125]
On test sample 7
cas9_VBBA S_6102 363
[60078] [ 102680.921875]
On test sample 8
cas9_VBBA S_7102 304
[90250] [ 90699.234375]
On test sample 9
cas9_VBBA S_7102 306
[91020] [ 90012.96875]
On test sample 10
cas9_FAB S_5002 332
[9074] [ 10632.08203125]
On test sample 11
cas9_FAB S_5002 362
[9025] [ 13476.20507812]
On test sample 12
cas9_FAB S_6002 332
[9612] [ 9844.49902344]
On test sample 13
cas9_FAB S_6002 361
[9498] [ 10563.65820312]
On test sample 14
cas9_FAB S_7002 332
[9040] [ 10302.68554688]
On test sample 15
cas9_FAB S_7002 358
[8737] [ 11970.24804688]

[6640] [ 28323.19921875]
On test sample 130
cas9_HNAF S_2102 333
[7330] [ 8385.94335938]
On test sample 131
cas9_HNAF S_2102 359
[6541] [ 11916.8046875]
On test sample 132
cas9_HNAF S_3102 330
[5021] [ 6097.54980469]
On test sample 133
cas9_HNAF S_3102 360
[5673] [ 8375.18945312]
On test sample 134
cas9_HNAF S_4102 328
[6156] [ 5858.89599609]
On test sample 135
cas9_HNAF S_4102 360
[6758] [ 6912.77197266]
On test sample 136
cas9_HNAF S_5102 329
[9086] [ 11783.98828125]
On test sample 137
cas9_HNAF S_5102 362
[9116] [ 23254.77539062]
On test sample 138
cas9_HNAF S_6102 333
[9047] [ 10193.00585938]
On test sample 139
cas9_HNAF S_6102 363
[8238] [ 14249.26367188]
On test sample 140
cas9_HNAF S_7102 304
[14880] [ 14901.97949219]
On test sample 141
cas9_HNAF S_7102 306
[14960] [ 15124.29199219]
On test sample 142
cas9_NIAG S_5002 332
[7580] [ 7885.34228516]
On test sample 143
cas9_NIAG S_5002 362
[7546] [ 9106.0546875]
On test sample 144
cas9_NIAG S_6002 332
[7928] [ 7551.79492188]
On test 

[69390] [ 112093.2109375]
On test sample 258
cas9_MBI S_6002 332
[76277] [ 74583.421875]
On test sample 259
cas9_MBI S_6002 361
[81081] [ 79301.8046875]
On test sample 260
cas9_MBI S_7002 332
[89664] [ 101110.15625]
On test sample 261
cas9_MBI S_7002 358
[93238] [ 114179.96875]
On test sample 262
cas9_MBI S_8002 329
[66709] [ 66467.1796875]
On test sample 263
cas9_MBI S_8002 361
[67911] [ 81773.46875]
On test sample 264
cas9_MBI S_9002 329
[106300] [ 111343.8671875]
On test sample 265
cas9_MBI S_9002 362
[110480] [ 134316.65625]
On test sample 266
cas9_MBI S_0102 330
[120140] [ 123039.375]
On test sample 267
cas9_MBI S_0102 362
[121660] [ 145367.921875]
On test sample 268
cas9_MBI S_1102 332
[154110] [ 169232.015625]
On test sample 269
cas9_MBI S_1102 362
[155600] [ 205557.078125]
On test sample 270
cas9_MBI S_2102 333
[165120] [ 195283.765625]
On test sample 271
cas9_MBI S_2102 359
[165470] [ 252767.46875]
On test sample 272
cas9_MBI S_3102 330
[155450] [ 174554.046875]
On test sample

In [83]:
mae_test = mean_absolute_error(targets_test, preds_test)
print('MAE TEST: ', mae_test)

MAE TEST:  8178.32995198


In [74]:
def write_predictions_csv_file(inds, preds, prediction_filename):
    df = pd.DataFrame(preds, index=inds, columns=['count'])
    df.to_csv(prediction_filename, index_label='d3mIndex')

def write_scores_csv_file(metric_dict, score_filename):
    metric_names = []
    metric_values = []
    for metric_name, metric_value in metric_dict.items():
        metric_names.append(metric_name)
        metric_values.append(metric_value)
    metric_names = np.array(metric_names)
    metric_values = np.array(metric_values)
    
    df = pd.DataFrame(np.concatenate((metric_names[:, None], metric_values[:, None]), axis=1), columns=['metric', 'value'])
    df.to_csv(score_filename, index=None)

In [67]:
print('Writing predictions to .csv file.')
output_dir = '../'
predictions_file = os.path.join(output_dir, 'predictions.csv')
write_predictions_csv_file(df_test.index, preds_test, predictions_file)

Writing predictions to .csv file.


In [75]:
print('Writing scores to .csv file.')
metric_dict = {'meanAbsoluteError': mae_test}
scores_file = os.path.join(output_dir, 'scores.csv')
write_scores_csv_file(metric_dict, scores_file)

Writing scores to .csv file.
