In [1]:
import pandas as pd
from collections import Counter
import tensorflow as tf
from tffm import TFFMRegressor
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
import numpy as np

In [2]:
buys = open('yoochoose-buys.dat', 'r')
clicks = open('yoochoose-clicks.dat', 'r')

In [3]:
initial_buys_df = pd.read_csv(buys, names=['Session ID', 'Timestamp', 'Item ID', 'Category', 'Quantity'],
                              dtype={'Session ID': 'float32', 'Timestamp': 'str', 'Item ID': 'float32',
                                     'Category': 'str'})

In [4]:
initial_buys_df.set_index('Session ID', inplace=True)

In [5]:
len(initial_buys_df)

1150753

In [6]:
initial_buys_df.query("Quantity == 0").head()

Unnamed: 0_level_0,Timestamp,Item ID,Category,Quantity
Session ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
281963.0,2014-04-07T20:33:54.828Z,214563136.0,0,0
210097.0,2014-04-07T17:45:33.522Z,214508944.0,0,0
351577.0,2014-04-07T17:45:02.666Z,214827024.0,0,0
351577.0,2014-04-07T17:45:02.740Z,214829744.0,0,0
419982.0,2014-04-07T15:02:06.999Z,214748304.0,0,0


In [7]:
initial_buys_df.Quantity.value_counts()[:20]

0     610030
1     435065
2      75486
3      11050
4       7505
6       4477
5       2670
10      1887
12       648
8        588
7        397
24       227
30       224
9        145
18       144
20       126
15        26
16        16
13         9
14         8
Name: Quantity, dtype: int64

In [8]:
initial_clicks_df = pd.read_csv(clicks, names=['Session ID', 'Timestamp', 'Item ID', 'Category'],
                                dtype={'Category': 'str'})

initial_clicks_df.set_index('Session ID', inplace=True)

In [9]:
initial_clicks_df.head(10)

Unnamed: 0_level_0,Timestamp,Item ID,Category
Session ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
1,2014-04-07T10:51:09.277Z,214536502,0
1,2014-04-07T10:54:09.868Z,214536500,0
1,2014-04-07T10:54:46.998Z,214536506,0
1,2014-04-07T10:57:00.306Z,214577561,0
2,2014-04-07T13:56:37.614Z,214662742,0
2,2014-04-07T13:57:19.373Z,214662742,0
2,2014-04-07T13:58:37.446Z,214825110,0
2,2014-04-07T13:59:50.710Z,214757390,0
2,2014-04-07T14:00:38.247Z,214757407,0
2,2014-04-07T14:02:36.889Z,214551617,0


In [10]:
len(initial_clicks_df)

33003944

In [11]:
initial_clicks_df.Category.value_counts()[:10]

0    16337653
S    10769610
1     1671754
2     1292249
3      789713
4      480569
5      471923
6      414696
7      389910
9      105282
Name: Category, dtype: int64

In [12]:
a = pd.DataFrame(initial_clicks_df.index).drop_duplicates().astype(int)
b = pd.DataFrame(initial_buys_df.index).drop_duplicates().astype(int)

In [13]:
c = a.merge(b, on=['Session ID'], how='inner')

len(c)

509696

In [14]:
initial_buys_df = initial_buys_df.drop('Timestamp', 1)
initial_clicks_df = initial_clicks_df.drop('Timestamp', 1)

In [15]:
from collections import Counter

x = Counter(initial_buys_df.index).most_common(10000)
top_k = dict(x).keys()

In [16]:
x

[(5638444.0, 144),
 (10808253.0, 120),
 (9014734.0, 81),
 (428198.0, 72),
 (6832724.0, 72),
 (6149111.0, 64),
 (601904.0, 62),
 (2233614.0, 54),
 (7586548.0, 54),
 (10683806.0, 54),
 (2920884.0, 48),
 (4543178.0, 48),
 (8211827.0, 48),
 (9209529.0, 48),
 (953756.0, 45),
 (8301793.0, 44),
 (10449092.0, 42),
 (10966823.0, 42),
 (4630111.0, 41),
 (1753739.0, 40),
 (2561659.0, 40),
 (5015396.0, 40),
 (8749796.0, 40),
 (1655343.0, 39),
 (6958333.0, 39),
 (10129729.0, 39),
 (980883.0, 36),
 (3394936.0, 36),
 (4382259.0, 36),
 (5033466.0, 36),
 (6529877.0, 36),
 (7614712.0, 36),
 (8166894.0, 36),
 (9010472.0, 36),
 (9822698.0, 36),
 (10317602.0, 36),
 (1032524.0, 35),
 (2081373.0, 35),
 (6006624.0, 34),
 (10760937.0, 34),
 (216528.0, 33),
 (557899.0, 33),
 (5563564.0, 33),
 (8117182.0, 33),
 (8653737.0, 33),
 (10043431.0, 33),
 (10219991.0, 33),
 (3734474.0, 32),
 (4932653.0, 32),
 (5022813.0, 32),
 (8959164.0, 32),
 (10162424.0, 32),
 (10393571.0, 32),
 (3460796.0, 31),
 (312929.0, 30),
 (17

In [17]:
initial_buys_df = initial_buys_df[initial_buys_df.index.isin(top_k)]

In [18]:
initial_clicks_df = initial_clicks_df[initial_clicks_df.index.isin(top_k)]

In [19]:
initial_buys_df['_Session ID'] = initial_buys_df.index

In [20]:
initial_buys_df.head(10)

Unnamed: 0_level_0,Item ID,Category,Quantity,_Session ID
Session ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
420471.0,214717888.0,2092,1,420471.0
420471.0,214821024.0,1570,1,420471.0
420471.0,214829280.0,837,1,420471.0
420471.0,214819552.0,418,1,420471.0
420471.0,214746384.0,784,1,420471.0
420471.0,214821024.0,1570,1,420471.0
420471.0,214717888.0,2092,1,420471.0
420471.0,214819552.0,418,1,420471.0
420471.0,214829280.0,837,1,420471.0
420471.0,214573360.0,784,1,420471.0


In [21]:
len(initial_buys_df)

106956

In [22]:
initial_clicks_df.head(10)

Unnamed: 0_level_0,Item ID,Category
Session ID,Unnamed: 1_level_1,Unnamed: 2_level_1
932,214826906,0
932,214826906,0
932,214826906,0
932,214826955,0
932,214826955,0
932,214826627,0
932,214826627,0
932,214800262,0
932,214709741,0
932,214709741,0


In [23]:
transformed_buys = pd.get_dummies(initial_buys_df)
transformed_clicks = pd.get_dummies(initial_clicks_df)

In [24]:
transformed_buys.head(10)

Unnamed: 0_level_0,Item ID,Quantity,_Session ID,Category_0,Category_1024,Category_1036,Category_10367,Category_1037,Category_104,Category_1041,...,Category_931,Category_932,Category_936,Category_937,Category_941,Category_9424,Category_97,Category_99,Category_994,Category_9947
Session ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
420471.0,214717888.0,1,420471.0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
420471.0,214821024.0,1,420471.0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
420471.0,214829280.0,1,420471.0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
420471.0,214819552.0,1,420471.0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
420471.0,214746384.0,1,420471.0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
420471.0,214821024.0,1,420471.0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
420471.0,214717888.0,1,420471.0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
420471.0,214819552.0,1,420471.0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
420471.0,214829280.0,1,420471.0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
420471.0,214573360.0,1,420471.0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [25]:
transformed_clicks.head(10)

Unnamed: 0_level_0,Item ID,Category_0,Category_1,Category_10,Category_11,Category_12,Category_2,Category_2088904854,Category_2088919107,Category_2088937100,...,Category_2089677914,Category_2089796643,Category_3,Category_4,Category_5,Category_6,Category_7,Category_8,Category_9,Category_S
Session ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
932,214826906,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
932,214826906,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
932,214826906,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
932,214826955,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
932,214826955,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
932,214826627,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
932,214826627,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
932,214800262,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
932,214709741,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
932,214709741,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [26]:
pd.Series(transformed_buys.columns)

0              Item ID
1             Quantity
2          _Session ID
3           Category_0
4        Category_1024
5        Category_1036
6       Category_10367
7        Category_1037
8         Category_104
9        Category_1041
10       Category_1042
11       Category_1046
12       Category_1047
13      Category_10471
14     Category_106709
15       Category_1089
16      Category_10996
17        Category_114
18       Category_1140
19       Category_1141
20      Category_11414
21       Category_1146
22       Category_1151
23      Category_11624
24       Category_1182
25        Category_120
26       Category_1203
27      Category_12043
28       Category_1245
29       Category_1246
            ...       
326       Category_833
327     Category_83671
328       Category_837
329      Category_8375
330      Category_8377
331       Category_838
332       Category_867
333       Category_879
334       Category_884
335       Category_889
336        Category_89
337       Category_890
338      Ca

In [27]:
pd.Series(transformed_clicks.columns)

0                 Item ID
1              Category_0
2              Category_1
3             Category_10
4             Category_11
5             Category_12
6              Category_2
7     Category_2088904854
8     Category_2088919107
9     Category_2088937100
10    Category_2088942073
11    Category_2088962571
12    Category_2088973177
13    Category_2088999608
14    Category_2089045199
15    Category_2089046251
16    Category_2089046367
17    Category_2089074648
18    Category_2089156185
19    Category_2089221555
20    Category_2089246594
21    Category_2089267025
22    Category_2089282248
23    Category_2089282437
24    Category_2089286907
25    Category_2089287221
26    Category_2089300095
27    Category_2089318476
28    Category_2089318666
29    Category_2089322935
30    Category_2089358732
31    Category_2089404239
32    Category_2089422131
33    Category_2089426426
34    Category_2089437536
35    Category_2089440547
36    Category_2089502248
37    Category_2089509324
38    Catego

In [28]:
filtered_buys = transformed_buys.filter(regex="Item.*|Category.*")
filtered_clicks = transformed_clicks.filter(regex="Item.*|Category.*")

In [29]:
filtered_buys.columns

Index(['Item ID', 'Category_0', 'Category_1024', 'Category_1036',
       'Category_10367', 'Category_1037', 'Category_104', 'Category_1041',
       'Category_1042', 'Category_1046',
       ...
       'Category_931', 'Category_932', 'Category_936', 'Category_937',
       'Category_941', 'Category_9424', 'Category_97', 'Category_99',
       'Category_994', 'Category_9947'],
      dtype='object', length=354)

In [30]:
filtered_clicks.columns

Index(['Item ID', 'Category_0', 'Category_1', 'Category_10', 'Category_11',
       'Category_12', 'Category_2', 'Category_2088904854',
       'Category_2088919107', 'Category_2088937100', 'Category_2088942073',
       'Category_2088962571', 'Category_2088973177', 'Category_2088999608',
       'Category_2089045199', 'Category_2089046251', 'Category_2089046367',
       'Category_2089074648', 'Category_2089156185', 'Category_2089221555',
       'Category_2089246594', 'Category_2089267025', 'Category_2089282248',
       'Category_2089282437', 'Category_2089286907', 'Category_2089287221',
       'Category_2089300095', 'Category_2089318476', 'Category_2089318666',
       'Category_2089322935', 'Category_2089358732', 'Category_2089404239',
       'Category_2089422131', 'Category_2089426426', 'Category_2089437536',
       'Category_2089440547', 'Category_2089502248', 'Category_2089509324',
       'Category_2089515459', 'Category_2089531793', 'Category_2089538467',
       'Category_2089538518',

In [31]:
filtered_clicks.head(10)

Unnamed: 0_level_0,Item ID,Category_0,Category_1,Category_10,Category_11,Category_12,Category_2,Category_2088904854,Category_2088919107,Category_2088937100,...,Category_2089677914,Category_2089796643,Category_3,Category_4,Category_5,Category_6,Category_7,Category_8,Category_9,Category_S
Session ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
932,214826906,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
932,214826906,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
932,214826906,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
932,214826955,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
932,214826955,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
932,214826627,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
932,214826627,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
932,214800262,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
932,214709741,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
932,214709741,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [32]:
historical_buy_data = filtered_buys.groupby(filtered_buys.index).sum()
historical_buy_data = historical_buy_data.rename(columns=lambda column_name: 'buy history:' + column_name)

In [33]:
historical_buy_data.head(10)

Unnamed: 0_level_0,buy history:Item ID,buy history:Category_0,buy history:Category_1024,buy history:Category_1036,buy history:Category_10367,buy history:Category_1037,buy history:Category_104,buy history:Category_1041,buy history:Category_1042,buy history:Category_1046,...,buy history:Category_931,buy history:Category_932,buy history:Category_936,buy history:Category_937,buy history:Category_941,buy history:Category_9424,buy history:Category_97,buy history:Category_99,buy history:Category_994,buy history:Category_9947
Session ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
932.0,2148020000.0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3302.0,1716957000.0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3687.0,1717444000.0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,1,0,0,0,0,0
3889.0,2148006000.0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4451.0,1718581000.0,0,0,0,0,0,0,0,0,4,...,0,0,0,0,0,0,0,0,0,0
5274.0,1718390000.0,0,0,0,0,0,0,0,0,2,...,0,0,0,0,1,0,0,0,0,0
5582.0,1716857000.0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
5942.0,1718652000.0,0,0,0,0,0,0,0,0,0,...,2,0,0,0,0,0,0,0,0,0
8102.0,1932980000.0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
9702.0,1718105000.0,0,0,0,0,0,0,0,0,1,...,0,0,0,0,0,0,0,0,0,0


In [34]:
historical_click_data = filtered_clicks.groupby(filtered_clicks.index).sum()
historical_click_data = historical_click_data.rename(columns=lambda column_name: 'click history:' + column_name)

In [35]:
historical_click_data.head(10)

Unnamed: 0_level_0,click history:Item ID,click history:Category_0,click history:Category_1,click history:Category_10,click history:Category_11,click history:Category_12,click history:Category_2,click history:Category_2088904854,click history:Category_2088919107,click history:Category_2088937100,...,click history:Category_2089677914,click history:Category_2089796643,click history:Category_3,click history:Category_4,click history:Category_5,click history:Category_6,click history:Category_7,click history:Category_8,click history:Category_9,click history:Category_S
Session ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
932,2792466840,13,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3302,858478390,4,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3687,4938564761,23,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3889,5584725258,26,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4451,2363064042,11,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
5274,2362870184,11,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
5582,4722497033,22,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
5942,1288982174,6,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
8102,4510063283,21,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
9702,6443305192,30,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [36]:
merged1 = pd.merge(transformed_buys, historical_buy_data, left_index=True, right_index=True)

merged1.head(10)

Unnamed: 0_level_0,Item ID,Quantity,_Session ID,Category_0,Category_1024,Category_1036,Category_10367,Category_1037,Category_104,Category_1041,...,buy history:Category_931,buy history:Category_932,buy history:Category_936,buy history:Category_937,buy history:Category_941,buy history:Category_9424,buy history:Category_97,buy history:Category_99,buy history:Category_994,buy history:Category_9947
Session ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
932.0,214826960.0,2,932.0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
932.0,214826624.0,2,932.0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
932.0,214826912.0,1,932.0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
932.0,214709744.0,2,932.0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
932.0,214819744.0,1,932.0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
932.0,214826912.0,1,932.0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
932.0,214826960.0,2,932.0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
932.0,214826624.0,2,932.0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
932.0,214709744.0,2,932.0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
932.0,214819744.0,1,932.0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [37]:
merged2 = pd.merge(merged1, historical_click_data, left_index=True, right_index=True)

merged2.head(10)

Unnamed: 0_level_0,Item ID,Quantity,_Session ID,Category_0,Category_1024,Category_1036,Category_10367,Category_1037,Category_104,Category_1041,...,click history:Category_2089677914,click history:Category_2089796643,click history:Category_3,click history:Category_4,click history:Category_5,click history:Category_6,click history:Category_7,click history:Category_8,click history:Category_9,click history:Category_S
Session ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
932,214826960.0,2,932.0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
932,214826624.0,2,932.0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
932,214826912.0,1,932.0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
932,214709744.0,2,932.0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
932,214819744.0,1,932.0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
932,214826912.0,1,932.0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
932,214826960.0,2,932.0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
932,214826624.0,2,932.0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
932,214709744.0,2,932.0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
932,214819744.0,1,932.0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [38]:
pd.Series(merged2.columns)

0                                Item ID
1                               Quantity
2                            _Session ID
3                             Category_0
4                          Category_1024
5                          Category_1036
6                         Category_10367
7                          Category_1037
8                           Category_104
9                          Category_1041
10                         Category_1042
11                         Category_1046
12                         Category_1047
13                        Category_10471
14                       Category_106709
15                         Category_1089
16                        Category_10996
17                          Category_114
18                         Category_1140
19                         Category_1141
20                        Category_11414
21                         Category_1146
22                         Category_1151
23                        Category_11624
24              

In [39]:
model = TFFMRegressor(
    order=2,
    rank=7,
    optimizer=tf.train.AdamOptimizer(learning_rate=0.1),
    n_epochs=100,
    batch_size=-1,
    init_std=0.001,
    input_type='dense'
)

In [40]:
merged2.drop(['Item ID', '_Session ID', 'click history:Item ID', 'buy history:Item ID'], 1, inplace=True)

X = np.array(merged2)
X = np.nan_to_num(X)
y = np.array(merged2['Quantity'].as_matrix())

  """


In [41]:
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2)

In [42]:
X_te, X_te_cs, y_te, y_te_cs = train_test_split(X_te, y_te, test_size=0.5)

cold_start = pd.DataFrame(X_te_cs, columns=merged2.columns)

In [43]:
for column in cold_start.columns:
    if ('buy' in column or 'click' in column) and ('Category' not in column):
        print (column)

In [44]:
model.fit(X_tr, y_tr, show_progress=True)
predictions = model.predict(X_te)

#cold_start_predictions = model.predict(X_te_cold)
print('MSE: {}'.format(mean_squared_error(y_te, predictions)))
print('Cold-start MSE: {}'.format(mean_squared_error(y_te_cold, predictions)))
model.destroy()


Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Use tf.cast instead.


100%|██████████| 100/100 [01:23<00:00,  1.28epoch/s]


MSE: 0.6581038522696456


NameError: name 'y_te_cold' is not defined

In [None]:
predictions