# Wide & Deep

In [66]:
import pandas as pd
import torch
import torch.nn as nn

pd.set_option('display.max_columns', None)


In [36]:
data = pd.read_csv('../dataset/ml-100k/data.csv')
data.head()

Unnamed: 0,user_id,item_id,rating,timestamp
0,196,242,3,881250949
1,186,302,3,891717742
2,22,377,1,878887116
3,244,51,2,880606923
4,166,346,1,886397596


In [37]:
user = pd.read_csv('../dataset/ml-100k/user.csv')
user.head()

Unnamed: 0,user_id,age,gender,occupation,zipcode
0,1,24,M,technician,85711
1,2,53,F,other,94043
2,3,23,M,writer,32067
3,4,24,M,technician,43537
4,5,33,F,other,15213


In [38]:
item = pd.read_csv('../dataset/ml-100k/item.csv')
item.head()

Unnamed: 0,movie_id,movie_title,release_date,video_release_date,IMDb_URL,unknown,Action,Adventure,Animation,Childrens,...,Fantasy,Film_Noir,Horror,Musical,Mystery,Romance,Sci_Fi,Thriller,War,Western
0,1,Toy Story (1995),01-Jan-1995,,http://us.imdb.com/M/title-exact?Toy%20Story%2...,0,0,0,1,1,...,0,0,0,0,0,0,0,0,0,0
1,2,GoldenEye (1995),01-Jan-1995,,http://us.imdb.com/M/title-exact?GoldenEye%20(...,0,1,1,0,0,...,0,0,0,0,0,0,0,1,0,0
2,3,Four Rooms (1995),01-Jan-1995,,http://us.imdb.com/M/title-exact?Four%20Rooms%...,0,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0
3,4,Get Shorty (1995),01-Jan-1995,,http://us.imdb.com/M/title-exact?Get%20Shorty%...,0,1,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,5,Copycat (1995),01-Jan-1995,,http://us.imdb.com/M/title-exact?Copycat%20(1995),0,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0


In [39]:
len(item)

1682

In [40]:
data.item_id.nunique()

1682

In [41]:
data
data['like'] = data["rating"].apply(lambda x: x>3.5).astype(int)
data

Unnamed: 0,user_id,item_id,rating,timestamp,like
0,196,242,3,881250949,0
1,186,302,3,891717742,0
2,22,377,1,878887116,0
3,244,51,2,880606923,0
4,166,346,1,886397596,0
...,...,...,...,...,...
99995,880,476,3,880175444,0
99996,716,204,5,879795543,1
99997,276,1090,1,874795795,0
99998,13,225,2,882399156,0


In [42]:
data

Unnamed: 0,user_id,item_id,rating,timestamp,like
0,196,242,3,881250949,0
1,186,302,3,891717742,0
2,22,377,1,878887116,0
3,244,51,2,880606923,0
4,166,346,1,886397596,0
...,...,...,...,...,...
99995,880,476,3,880175444,0
99996,716,204,5,879795543,1
99997,276,1090,1,874795795,0
99998,13,225,2,882399156,0


## add features

In [43]:
data = data.merge(user, on='user_id', how='left')
data = data.merge(item, left_on='item_id', right_on='movie_id', how='left')

In [44]:
data

Unnamed: 0,user_id,item_id,rating,timestamp,like,age,gender,occupation,zipcode,movie_id,...,Fantasy,Film_Noir,Horror,Musical,Mystery,Romance,Sci_Fi,Thriller,War,Western
0,196,242,3,881250949,0,49,M,writer,55105,242,...,0,0,0,0,0,0,0,0,0,0
1,186,302,3,891717742,0,39,F,executive,00000,302,...,0,1,0,0,1,0,0,1,0,0
2,22,377,1,878887116,0,25,M,writer,40206,377,...,0,0,0,0,0,0,0,0,0,0
3,244,51,2,880606923,0,28,M,technician,80525,51,...,0,0,0,0,0,1,0,0,1,1
4,166,346,1,886397596,0,47,M,educator,55113,346,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
99995,880,476,3,880175444,0,13,M,student,83702,476,...,0,0,0,0,0,0,0,0,0,0
99996,716,204,5,879795543,1,36,F,administrator,44265,204,...,0,0,0,0,0,0,1,0,0,0
99997,276,1090,1,874795795,0,21,M,student,95064,1090,...,0,0,0,0,0,0,0,1,0,0
99998,13,225,2,882399156,0,47,M,educator,29206,225,...,0,0,0,0,0,0,0,0,0,0


## data split

In [45]:
def data_split(df, save_to_csv=False):
    # add binary target "like"
    df = make_target(df)
    
    data_sorted = df.sort_values(by=['user_id', 'timestamp'], ascending=[True, False])
    test_set = data_sorted.groupby('user_id').head(5)
    test_indices = test_set.index
    train_set = data_sorted.drop(index=test_indices)
    train_set.reset_index(drop=True, inplace=True)
    test_set.reset_index(drop=True, inplace=True)

    if save_to_csv:
        train_set.to_csv('../dataset/ml-100k/train_test/train.csv', index=False)
        test_set.to_csv('../dataset/ml-100k/train_test/test.csv', index=False)
    return train_set, test_set


def make_target(df, threshold=3.5):
    df['like'] = df["rating"].apply(lambda x: x>threshold).astype(int)
    df.drop('rating', axis=1, inplace=True)
    
    return df


In [46]:
train_set, test_set = data_split(data)

In [47]:
train_set

Unnamed: 0,user_id,item_id,timestamp,like,age,gender,occupation,zipcode,movie_id,movie_title,...,Fantasy,Film_Noir,Horror,Musical,Mystery,Romance,Sci_Fi,Thriller,War,Western
0,1,111,889751711,1,24,M,technician,85711,111,"Truth About Cats & Dogs, The (1996)",...,0,0,0,0,0,1,0,0,0,0
1,1,242,889751633,1,24,M,technician,85711,242,Kolya (1996),...,0,0,0,0,0,0,0,0,0,0
2,1,189,888732928,0,24,M,technician,85711,189,"Grand Day Out, A (1992)",...,0,0,0,0,0,0,0,0,0,0
3,1,32,888732909,1,24,M,technician,85711,32,Crumb (1994),...,0,0,0,0,0,0,0,0,0,0
4,1,209,888732908,1,24,M,technician,85711,209,This Is Spinal Tap (1984),...,0,0,0,1,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95280,943,100,875501725,1,22,M,student,77841,100,Fargo (1996),...,0,0,0,0,0,0,0,1,0,0
95281,943,174,875410099,1,22,M,student,77841,174,Raiders of the Lost Ark (1981),...,0,0,0,0,0,0,0,0,0,0
95282,943,181,875409978,1,22,M,student,77841,181,Return of the Jedi (1983),...,0,0,0,0,0,1,1,0,1,0
95283,943,28,875409978,1,22,M,student,77841,28,Apollo 13 (1995),...,0,0,0,0,0,0,0,1,0,0


In [48]:
test_set.shape

(4715, 32)

In [49]:
set(test_set.movie_id.unique()) - set(train_set.movie_id.unique())

{1130, 1236, 1525, 1613, 1618, 1624, 1625, 1645, 1650, 1671, 1674}

In [50]:
# there are unseen items in the test set

In [51]:
sorted(train_set.movie_id.unique())[:5]

[1, 2, 3, 4, 5]

THUS, WE ASSIGN 0 TO THE UNSEEN ITEM

In [52]:
train_set.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 95285 entries, 0 to 95284
Data columns (total 32 columns):
 #   Column              Non-Null Count  Dtype  
---  ------              --------------  -----  
 0   user_id             95285 non-null  int64  
 1   item_id             95285 non-null  int64  
 2   timestamp           95285 non-null  int64  
 3   like                95285 non-null  int64  
 4   age                 95285 non-null  int64  
 5   gender              95285 non-null  object 
 6   occupation          95285 non-null  object 
 7   zipcode             95285 non-null  object 
 8   movie_id            95285 non-null  int64  
 9   movie_title         95285 non-null  object 
 10  release_date        95276 non-null  object 
 11  video_release_date  0 non-null      float64
 12  IMDb_URL            95272 non-null  object 
 13  unknown             95285 non-null  int64  
 14  Action              95285 non-null  int64  
 15  Adventure           95285 non-null  int64  
 16  Anim

## cross product transformation

In [53]:
user.groupby('gender').count().user_id

gender
F    273
M    670
Name: user_id, dtype: int64

In [54]:
user.groupby('occupation').count().user_id.sort_values()

occupation
doctor             7
homemaker          7
none               9
salesman          12
lawyer            12
retired           14
healthcare        16
entertainment     18
marketing         26
technician        27
artist            28
scientist         31
executive         32
writer            45
librarian         51
programmer        66
engineer          67
administrator     79
educator          95
other            105
student          196
Name: user_id, dtype: int64

let's just try to cross-product genre and gender for now, based on intuition

Todo: change to better combination

In [55]:
import numpy as np
def cp_gender_genre(df):
    genres = ['unknown', 'Action', 'Adventure',
       'Animation', 'Childrens', 'Comedy', 'Crime', 'Documentary', 'Drama',
       'Fantasy', 'Film_Noir', 'Horror', 'Musical', 'Mystery', 'Romance',
       'Sci_Fi', 'Thriller', 'War', 'Western']
    crossed_cols = [['is_male', genre] for genre in genres]
    for a, b in crossed_cols:
        df[a+'-'+b] = np.where((df['gender'] == 'M') & (df[b]==1), 1, 0)
    return df

In [56]:
train_cp = cp_gender_genre(train_set)

In [57]:
sum(np.where(train_cp['gender'] == 'M', 1, 0)) / len(train_cp['gender'])

0.7441884871700687

In [58]:
train_cp['gender'][1000:1200]

1000    M
1001    M
1002    M
1003    M
1004    M
       ..
1195    M
1196    M
1197    M
1198    M
1199    M
Name: gender, Length: 200, dtype: object

In [59]:
train_cp.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 95285 entries, 0 to 95284
Data columns (total 51 columns):
 #   Column               Non-Null Count  Dtype  
---  ------               --------------  -----  
 0   user_id              95285 non-null  int64  
 1   item_id              95285 non-null  int64  
 2   timestamp            95285 non-null  int64  
 3   like                 95285 non-null  int64  
 4   age                  95285 non-null  int64  
 5   gender               95285 non-null  object 
 6   occupation           95285 non-null  object 
 7   zipcode              95285 non-null  object 
 8   movie_id             95285 non-null  int64  
 9   movie_title          95285 non-null  object 
 10  release_date         95276 non-null  object 
 11  video_release_date   0 non-null      float64
 12  IMDb_URL             95272 non-null  object 
 13  unknown              95285 non-null  int64  
 14  Action               95285 non-null  int64  
 15  Adventure            95285 non-null 

# data loader

In [60]:
from sklearn.preprocessing import OneHotEncoder, StandardScaler, LabelEncoder
def preprocess(train_df, test_df):
    # map unseen item_id in test set to item_id=0
    unique_item_ids = set(train_df['item_id'].unique())
    def map_unseen_to_zero(item_id):
        if item_id in unique_item_ids:
            return item_id
        else:
            return 0

    test_df['item_id'] = test_df['item_id'].apply(map_unseen_to_zero)

    # one-hot encoder
    onehot_features = ['gender', 'occupation']
    one_hot_encoder = OneHotEncoder(sparse_output=False, drop='if_binary')
    occupation_encoded_train = one_hot_encoder.fit_transform(train_df[onehot_features])
    occupation_encoded_test = one_hot_encoder.transform(test_df[onehot_features])

    # label encoder
    occupation_encoder = LabelEncoder()
    train_df['occupation_encoded'] = occupation_encoder.fit_transform(train_df['occupation'])
    test_df['occupation_encoded'] = occupation_encoder.fit_transform(test_df['occupation'])

    # scaler for numerical features
    scaler = StandardScaler()
    train_df['age_scaled'] = scaler.fit_transform(train_df[['age']])
    test_df['age_scaled'] = scaler.transform(test_df[['age']])

    # Concatenate the one-hot encoded occupation back to the dataframe
    onehot_train_df = pd.DataFrame(occupation_encoded_train, columns=one_hot_encoder.get_feature_names_out(onehot_features))
    onehot_test_df = pd.DataFrame(occupation_encoded_test, columns=one_hot_encoder.get_feature_names_out(onehot_features))

    train_df.reset_index(drop=True, inplace=True)
    test_df.reset_index(drop=True, inplace=True)

    train_df = pd.concat([train_df, onehot_train_df], axis=1)
    test_df = pd.concat([test_df, onehot_test_df], axis=1)

    #cross product of features
    train_df = cp_gender_genre(train_df)
    test_df = cp_gender_genre(test_df)


    return train_df, test_df


In [61]:
train_df, test_df = preprocess(train_set, test_set)


In [62]:
train_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 95285 entries, 0 to 95284
Data columns (total 75 columns):
 #   Column                    Non-Null Count  Dtype  
---  ------                    --------------  -----  
 0   user_id                   95285 non-null  int64  
 1   item_id                   95285 non-null  int64  
 2   timestamp                 95285 non-null  int64  
 3   like                      95285 non-null  int64  
 4   age                       95285 non-null  int64  
 5   gender                    95285 non-null  object 
 6   occupation                95285 non-null  object 
 7   zipcode                   95285 non-null  object 
 8   movie_id                  95285 non-null  int64  
 9   movie_title               95285 non-null  object 
 10  release_date              95276 non-null  object 
 11  video_release_date        0 non-null      float64
 12  IMDb_URL                  95272 non-null  object 
 13  unknown                   95285 non-null  int64  
 14  Action

In [67]:
train_df

Unnamed: 0,user_id,item_id,timestamp,like,age,gender,occupation,zipcode,movie_id,movie_title,release_date,video_release_date,IMDb_URL,unknown,Action,Adventure,Animation,Childrens,Comedy,Crime,Documentary,Drama,Fantasy,Film_Noir,Horror,Musical,Mystery,Romance,Sci_Fi,Thriller,War,Western,is_male-unknown,is_male-Action,is_male-Adventure,is_male-Animation,is_male-Childrens,is_male-Comedy,is_male-Crime,is_male-Documentary,is_male-Drama,is_male-Fantasy,is_male-Film_Noir,is_male-Horror,is_male-Musical,is_male-Mystery,is_male-Romance,is_male-Sci_Fi,is_male-Thriller,is_male-War,is_male-Western,occupation_encoded,age_scaled,gender_M,occupation_administrator,occupation_artist,occupation_doctor,occupation_educator,occupation_engineer,occupation_entertainment,occupation_executive,occupation_healthcare,occupation_homemaker,occupation_lawyer,occupation_librarian,occupation_marketing,occupation_none,occupation_other,occupation_programmer,occupation_retired,occupation_salesman,occupation_scientist,occupation_student,occupation_technician,occupation_writer
0,1,111,889751711,1,24,M,technician,85711,111,"Truth About Cats & Dogs, The (1996)",26-Apr-1996,,http://us.imdb.com/M/title-exact?Truth%20About...,0,0,0,0,0,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,1,0,0,0,0,19,-0.773435,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
1,1,242,889751633,1,24,M,technician,85711,242,Kolya (1996),24-Jan-1997,,http://us.imdb.com/M/title-exact?Kolya%20(1996),0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,19,-0.773435,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
2,1,189,888732928,0,24,M,technician,85711,189,"Grand Day Out, A (1992)",01-Jan-1992,,http://us.imdb.com/M/title-exact?Grand%20Day%2...,0,0,0,1,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,19,-0.773435,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
3,1,32,888732909,1,24,M,technician,85711,32,Crumb (1994),01-Jan-1994,,http://us.imdb.com/M/title-exact?Crumb%20(1994),0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,19,-0.773435,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
4,1,209,888732908,1,24,M,technician,85711,209,This Is Spinal Tap (1984),01-Jan-1984,,http://us.imdb.com/M/title-exact?This%20Is%20S...,0,0,0,0,0,1,0,0,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,1,0,0,1,0,0,0,1,0,0,0,0,0,0,19,-0.773435,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95280,943,100,875501725,1,22,M,student,77841,100,Fargo (1996),14-Feb-1997,,http://us.imdb.com/M/title-exact?Fargo%20(1996),0,0,0,0,0,0,1,0,1,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,1,0,1,0,0,0,0,0,0,0,1,0,0,18,-0.946923,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
95281,943,174,875410099,1,22,M,student,77841,174,Raiders of the Lost Ark (1981),01-Jan-1981,,http://us.imdb.com/M/title-exact?Raiders%20of%...,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,18,-0.946923,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
95282,943,181,875409978,1,22,M,student,77841,181,Return of the Jedi (1983),14-Mar-1997,,http://us.imdb.com/M/title-exact?Return%20of%2...,0,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,0,1,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,0,1,0,18,-0.946923,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
95283,943,28,875409978,1,22,M,student,77841,28,Apollo 13 (1995),01-Jan-1995,,http://us.imdb.com/M/title-exact?Apollo%2013%2...,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,18,-0.946923,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0


In [63]:
feature_config = {}

feature_config['continuous_cols'] = ["age_scaled"]

feature_config['binary_cols'] = ['gender_M']

# a list of (column_name, unique_values, embedding_dim)
feature_config['embedding_inputs'] = [('occupation_encoded', 21, 3), ('item_id', 1682+1, 6)]
feature_config['embedding_cols'] = ['occupation_encoded', 'item_id']

# gender * genre
feature_config['crossed_cols'] = [col for col in train_df.columns if '-' in col]

# gender + occupation
feature_config['one_hot_cols'] = feature_config['binary_cols'] + [col for col in train_df.columns if 'occupation_' in col]

# genres
feature_config['multi_hot_cols'] = ['unknown', 'Action', 'Adventure',
                    'Animation', 'Childrens', 'Comedy', 
                    'Crime', 'Documentary', 'Drama',
                    'Fantasy', 'Film_Noir', 'Horror', 
                    'Musical', 'Mystery', 'Romance',
                    'Sci_Fi', 'Thriller', 'War', 'Western']

feature_config['wide_cols'] = feature_config['continuous_cols'] + \
                                feature_config['one_hot_cols'] + \
                                feature_config['multi_hot_cols']
# deep_cols = continuous_cols + embedding_inputs + multi_embeddings + ['gender_M']

feature_config['target'] = 'like'

In [68]:
from torch.utils.data import TensorDataset

def make_dataset(df, feature_config):
    X_wide = df[feature_config['wide_cols']]
    X_deep_dense = df[feature_config['continuous_cols'] + feature_config['binary_cols']]
    X_deep_embedding = df[feature_config['embedding_cols']]
    X_multi_hot = df[feature_config['multi_hot_cols']] 
    y = df[feature_config['target']]

    X_wide_tensor = torch.tensor(X_wide.values, dtype=torch.long)
    X_deep_dense_tensor = torch.tensor(X_deep_dense.values, dtype=torch.float32)
    X_deep_embedding_tensor = torch.tensor(X_deep_embedding.values, dtype=torch.float32)
    X_multi_hot_tensor = torch.tensor(X_multi_hot.values, dtype=torch.long)
    y_tensor = torch.tensor(y.values, dtype=torch.long).view(-1, 1)


    data_set = TensorDataset(X_wide_tensor, 
                               X_deep_dense_tensor, 
                               X_deep_embedding_tensor, 
                               X_multi_hot_tensor, 
                               y_tensor)
    return data_set


SyntaxError: incomplete input (1633751085.py, line 1)

In [64]:
wide_dim = len(wide_cols)
deep_dim = len(deep_cols)
num_genres = 19
genre_embedding_dim = 3
occupation_vocab_size, 
occupation_embedding_dim

NameError: name 'occupation_vocab_size' is not defined

In [None]:
import torch
from torch.utils.data import TensorDataset, DataLoader

# Convert to PyTorch tensors
X_wide_train_tensor = torch.tensor(X_wide_train.values, dtype=torch.float)
X_deep_train_tensor = torch.tensor(X_deep_train.values, dtype=torch.float)
y_train_tensor = torch.tensor(y_train.values, dtype=torch.float32).view(-1, 1)

X_wide_test_tensor = torch.tensor(X_wide_test.values, dtype=torch.float)
X_deep_test_tensor = torch.tensor(X_deep_test.values, dtype=torch.float)
y_test_tensor = torch.tensor(y_test.values, dtype=torch.float32).view(-1, 1)

# Create DataLoader instances
train_dataset = TensorDataset(X_wide_train_tensor, X_deep_train_tensor, y_train_tensor)
test_dataset = TensorDataset(X_wide_test_tensor, X_deep_test_tensor, y_test_tensor)

batch_size = 32
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)


# model

In [199]:
class GenreEmbedding(nn.Module):
    def __init__(self, num_genres=19, embedding_dim=3):
        super(GenreEmbedding, self).__init__()
        self.num_genres = num_genres
        self.embedding_dim = embedding_dim
        self.embedding = nn.Embedding(self.num_genres, self.embedding_dim)
        # Creating an index tensor for all genres
        self.genre_indices = torch.arange(0, self.num_genres, dtype=torch.long)
    
    def forward(self, x):
        # x is expected to be a batch of multi-hot vectors of shape (batch_size, num_genres)
        
        # Embed all genres
        all_genres_embedded = self.embedding(self.genre_indices)
        
        # Use the multi-hot vectors to select and sum embeddings
        genre_embeddings = torch.matmul(x, all_genres_embedded)
        
        return genre_embeddings

In [None]:
class WideAndDeep(nn.Module):
    def __init__(
            self, 
            wide_dim, 
            deep_dim, 
            num_genres, 
            genre_embedding_dim, 
            # occupation_vocab_size, 
            # occupation_embedding_dim
            embedding_inputs, 
            ):
        
        super(WideAndDeep, self).__init__()
        self.wide = nn.Linear(wide_dim, 1)
        self.genre_embedding = GenreEmbedding(num_genres, genre_embedding_dim)
        # self.occupation_embedding = nn.Embedding(occupation_vocab_size, 
        #                                          occupation_embedding_dim)
        
        for col,val,dim in self.embedding_inputs:
            setattr(self, 'emb_layer_'+col, nn.Embedding(val, dim))
        
        total_deep_dim = deep_dim + genre_embedding_dim + occupation_embedding_dim
        self.deep = nn.Sequential(
            nn.Linear(total_deep_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )
        self.output = nn.Linear(2, 1)

    def forward(self, X_wide, X_deep, X_genres, X_occupation):

        emb = [getattr(self, 'emb_layer_'+col)(X_d[:,self.deep_column_idx[col]].long())
               for col,_,_ in self.embeddings_input]
        genres_embedded = self.genre_embedding(X_genres)
        occupation_embedded = self.occupation_embedding(X_occupation)
        X_deep_combined = torch.cat([X_deep, genres_embedded, occupation_embedded], 1)
        wide_out = self.wide(X_wide)
        deep_out = self.deep(X_deep_combined)
        combined = torch.cat([wide_out, deep_out], 1)
        out = torch.sigmoid(self.output(combined))
        return out

In [None]:
class WideAndDeep(nn.Module):

    def __init__(self, wide_dim, embeddings_input, continuous_cols, deep_column_idx, hidden_layers, n_class):

        super(WideAndDeep, self).__init__()
        self.wide_dim = wide_dim
        self.deep_column_idx = deep_column_idx
        self.embeddings_input = embeddings_input
        self.continuous_cols = continuous_cols
        self.hidden_layers = hidden_layers
        self.n_class = n_class

        for col,val,dim in self.embeddings_input:
            setattr(self, 'emb_layer_'+col, nn.Embedding(val, dim))

        input_emb_dim = np.sum([emb[2] for emb in self.embeddings_input])
        self.linear_1 = nn.Linear(input_emb_dim+len(continuous_cols), self.hidden_layers[0])
        for i,h in enumerate(self.hidden_layers[1:],1):
            setattr(self, 'linear_'+str(i+1), nn.Linear(self.hidden_layers[i-1], self.hidden_layers[i] ))

        self.output = nn.Linear(self.hidden_layers[-1]+self.wide_dim, n_class)

    def forward(self, X_w, X_d):

        emb = [getattr(self, 'emb_layer_'+col)(X_d[:,self.deep_column_idx[col]].long())
               for col,_,_ in self.embeddings_input]

        cont_idx = [self.deep_column_idx[col] for col in self.continuous_cols]
        cont = [X_d[:, cont_idx].float()]

        deep_inp = torch.cat(emb+cont, 1)

        x_deep = F.relu(self.linear_1(deep_inp))
        for i in range(1,len(self.hidden_layers)):
            x_deep = F.relu( getattr(self, 'linear_'+str(i+1))(x_deep) )

        wide_deep_input = torch.cat([x_deep, X_w.float()], 1)

        out = F.sigmoid(self.output(wide_deep_input))

        return out

In [170]:
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop
epochs = 10
for epoch in range(epochs):
    model.train()
    for X_w, X_d, y in train_loader:
        optimizer.zero_grad()
        y_pred = model(X_w, X_d)
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()
    print(f'Epoch {epoch+1}, Loss: {loss.item()}')

<enumerate at 0x7f9103da8860>