In [1]:
import pandas as pd
import numpy as np
import pytorch_widedeep as wd
from pytorch_widedeep.preprocessing import WidePreprocessor
from pytorch_widedeep.models import Wide, DeepDense, WideDeep

In [2]:
train_path = '../data/interim/train_0.75_0.25.csv'
test_path = '../data/interim/test_0.75_0.25.csv'
genre_path = '../data/raw/movies.csv'
tag_path = '../data/raw/tags.csv'

In [3]:
train, test, genre, tags = pd.read_csv(train_path), pd.read_csv(test_path), pd.read_csv(genre_path), pd.read_csv(tag_path)
train = train.merge(genre).merge(tags.drop('timestamp', axis=1))
test = test.merge(genre).merge(tags.drop('timestamp', axis=1))

In [4]:
from pytorch_widedeep.preprocessing import WidePreprocessor, TextPreprocessor, DensePreprocessor

In [29]:
wide_cols = ['genres', 'tag']
crossed_cols = [('genres', 'tag')]
embs = [('userId', 16), ('movieId', 16)]
target_cols = 'rating'
train_target = train[target_cols].values

In [30]:
preprocess_wide = WidePreprocessor(wide_cols=wide_cols, crossed_cols=crossed_cols)
train_wide = preprocess_wide.fit_transform(train)
test_wide = preprocess_wide.transform(train)

In [31]:
train_wide

array([[   1,  214, 3166],
       [   1,  215, 3167],
       [   1,  216, 3168],
       ...,
       [ 213, 3163, 8083],
       [  58, 3164, 8084],
       [  58, 3165, 8085]])

In [54]:
# DEEP
preprocess_deep = DensePreprocessor(embed_cols=embs)
train_id = preprocess_deep.fit_transform(train)
test_id = preprocess_deep.transform(test)

In [41]:
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import TruncatedSVD

In [50]:
cv = CountVectorizer()
tfidfv = TfidfVectorizer()
svd = TruncatedSVD(n_components=5)

In [52]:
train_genre = svd.fit_transform(cv.fit_transform(train['genres']))
test_genre = svd.fit_transform(cv.transform(test['genres']))

In [53]:
train_tag = svd.fit_transform(tfidfv.fit_transform(train['tag']))
test_tag = svd.fit_transform(tfidfv.transform(test['tag']))

In [56]:
train_deep = np.hstack((train_id, train_genre, train_tag))
test_deep = np.hstack((test_id, test_genre, test_tag))

In [57]:
wide = Wide(wide_dim=np.unique(train_wide).shape[0], pred_dim=1)

deepdense = DeepDense(hidden_layers=[64,32], 
                      deep_column_idx=preprocess_deep.deep_column_idx,
                      embed_input=preprocess_deep.embeddings_input)

model = WideDeep(wide=wide, deepdense=deepdense)

In [78]:
model.compile(method='regression')

In [80]:
model.fit(X_wide=train_wide, X_deep=train_deep, target=target, n_epochs=5, batch_size=256, val_split=0.2, 
          warm_up=True, warm_epochs=5, warm_max_lr=0.01)

  0%|          | 0/28 [00:00<?, ?it/s]

Warming up wide for 5 epochs


epoch 1: 100%|██████████| 28/28 [00:16<00:00,  1.75it/s, loss=18.8]
epoch 2: 100%|██████████| 28/28 [00:17<00:00,  1.64it/s, loss=14.3]
epoch 3: 100%|██████████| 28/28 [00:19<00:00,  1.43it/s, loss=11.3]
epoch 4: 100%|██████████| 28/28 [00:18<00:00,  1.49it/s, loss=9.56]
epoch 5: 100%|██████████| 28/28 [00:17<00:00,  1.64it/s, loss=8.6]
  0%|          | 0/28 [00:00<?, ?it/s]

Warming up deepdense for 5 epochs


epoch 1: 100%|██████████| 28/28 [00:17<00:00,  1.60it/s, loss=6.11]
epoch 2: 100%|██████████| 28/28 [00:17<00:00,  1.57it/s, loss=0.991]
epoch 3: 100%|██████████| 28/28 [00:17<00:00,  1.56it/s, loss=0.624]
epoch 4: 100%|██████████| 28/28 [00:18<00:00,  1.54it/s, loss=0.513]
epoch 5: 100%|██████████| 28/28 [00:17<00:00,  1.56it/s, loss=0.461]
  0%|          | 0/28 [00:00<?, ?it/s]

Training


epoch 1: 100%|██████████| 28/28 [00:18<00:00,  1.48it/s, loss=1.84]
valid: 100%|██████████| 7/7 [00:02<00:00,  2.56it/s, loss=1.74]
epoch 2: 100%|██████████| 28/28 [00:17<00:00,  1.57it/s, loss=1.35]
valid: 100%|██████████| 7/7 [00:02<00:00,  2.77it/s, loss=1.58]
epoch 3: 100%|██████████| 28/28 [00:18<00:00,  1.49it/s, loss=1.24]
valid: 100%|██████████| 7/7 [00:02<00:00,  2.94it/s, loss=1.52]
epoch 4: 100%|██████████| 28/28 [00:19<00:00,  1.42it/s, loss=1.17]
valid: 100%|██████████| 7/7 [00:02<00:00,  2.68it/s, loss=1.46]
epoch 5: 100%|██████████| 28/28 [00:16<00:00,  1.71it/s, loss=1.11]
valid: 100%|██████████| 7/7 [00:02<00:00,  2.89it/s, loss=1.41]
