In [21]:
import numpy as np
import pickle

In [22]:
def get_query_vector(query, model_weights):
    """Average the embeddings of the query keywords the model knows.

    The query is split on single spaces and each token is lower-cased.
    Tokens found in ``model_weights['kw_set']`` are echoed to stdout and
    their vectors from ``model_weights['kw_embedding']`` are summed into a
    64-element accumulator.

    Returns the mean of the matched vectors, or
    ``model_weights['kw_default_embedding']`` (the stored object itself, not
    a copy) when no token matches.
    """
    tokens = [piece.lower() for piece in query.split(' ')]
    total = np.zeros(64)
    matched = 0
    for token in tokens:
        if token not in model_weights['kw_set']:
            continue
        print(token)
        total += model_weights['kw_embedding'][token]
        matched += 1

    # No known keyword at all: fall back to the default embedding.
    if matched == 0:
        return model_weights['kw_default_embedding']
    return total / matched

def get_genre_vector(genres, model_weights):
    """Project a multi-hot genre indicator through the genre layer.

    Parameters
    ----------
    genres : iterable
        Genre names. Names missing from ``model_weights['genres_dict']`` are
        ignored.
    model_weights : dict
        Reads ``'genres_dict'`` (name -> row index of the indicator),
        ``'genres_weight'`` (2-D array with one column per genre) and
        ``'genres_bias'``.

    Returns
    -------
    ``genres_weight.dot(indicator)`` with the trailing axis squeezed, plus
    ``genres_bias``.
    """
    weight = model_weights['genres_weight']
    genre_index = model_weights['genres_dict']

    # Size the indicator from the weight matrix instead of a hard-coded 22;
    # the dot product below requires these to agree anyway.
    v = np.zeros((weight.shape[1], 1))
    for g in genres:
        if g in genre_index:
            v[genre_index[g]] = 1

    return weight.dot(v).squeeze(1) + model_weights['genres_bias']

In [23]:
def predict(model_weights, query, user_embedding, genres):
    """Score every movie for a query, returning relevance and rating scores.

    Parameters
    ----------
    model_weights : dict
        Reads ``'movie_embedding'`` (2-D array, one row per movie) and
        ``'fc_relevance'`` (sequence of dicts with ``'weight'`` and
        ``'bias'``), plus whatever ``get_query_vector`` and
        ``get_genre_vector`` read.
    query : str
        Free-text query forwarded to ``get_query_vector``.
    user_embedding : array
        Added element-wise to the genre and query embeddings.
    genres : iterable
        Genre names forwarded to ``get_genre_vector``.

    Returns
    -------
    (relevance_score, rating_score)
        ``relevance_score`` is the sigmoid output of the ``fc_relevance``
        stack with axis 1 squeezed; ``rating_score`` is the dot product of
        the combined user/genre/query embedding with every movie embedding.
    """
    query_embedding = get_query_vector(query, model_weights)
    # Use the `genres` argument. The previous code read a global named
    # `genres_q`, which ignored the parameter and raised NameError whenever
    # that global did not exist.
    genres_embedding = get_genre_vector(genres, model_weights)

    final_query_em = user_embedding + genres_embedding + query_embedding

    movie_embedding = model_weights['movie_embedding']
    nM = movie_embedding.shape[0]

    # Each movie row is extended with the same genre and query embeddings.
    # reshape(1, -1) lets NumPy infer the width instead of hard-coding 64.
    relevance_em = np.concatenate(
        (
            movie_embedding,
            genres_embedding.reshape(1, -1).repeat(nM, axis=0),
            query_embedding.reshape(1, -1).repeat(nM, axis=0),
        ),
        axis=1,
    )

    layers = model_weights['fc_relevance']
    last = len(layers) - 1
    for i, fc in enumerate(layers):
        relevance_em = relevance_em.dot(fc['weight'].T) + fc['bias']
        if i < last:
            # Hidden layers use tanh.
            relevance_em = np.tanh(relevance_em)
        else:
            # Final layer: logistic sigmoid.
            relevance_em = 1. / (1. + np.exp(-relevance_em))

    relevance_score = relevance_em.squeeze(axis=1)

    rating_score = final_query_em.dot(movie_embedding.T)

    return relevance_score, rating_score

In [24]:
if __name__=='__main__':

    # NOTE(review): pickle.load can execute arbitrary code; only load weight
    # files that come from a trusted source.
    with open('./../data/processed_data/model_weights_v04','rb') as f:
        model_weights = pickle.load(f)

    query = 'toy story'
    genres_q = ['family','animation']

    rel, rat = predict(
        model_weights,
        query,
        model_weights['ave_user_embedding'],
        genres_q,
    )

    # Keep the (up to) 100 movies with the highest relevance score ...
    ind_rel = np.flip(np.argsort(rel))[:100]
    # ... then order that shortlist by descending rating score.
    rating_order = np.argsort(rat[ind_rel])[::-1]
    ind = list(ind_rel[rating_order])

    print(ind)


toy
story
[3039, 2619, 3462, 3146, 2708, 1650, 3803, 2450, 2445, 1821, 3305, 2027, 2888, 3142, 2290, 929, 2279, 1742, 1390, 2722, 159, 2345, 1460, 3478, 1118, 2249, 3220, 2906, 82, 164, 1035, 160, 2606, 613, 1249, 2928, 1954, 444, 2373, 2270, 3289, 2971, 3402, 2828, 3804, 650, 2847, 3847, 2383, 3119, 856, 1002, 203, 2882, 722, 996, 353, 3380, 1507, 2674, 3759, 990, 3486, 3288, 1570, 765, 1734, 815, 1053, 3004, 2372, 1658, 1412, 2066, 1464, 1867, 2396, 3643, 3357, 791, 897, 1200, 3360, 1212, 1465, 1480, 315, 892, 3418, 1740, 3113, 2305, 962, 2210, 174, 699, 772, 380, 2942, 700]
