In [11]:
import pickle
import torch
import json
import numpy as np
from sklearn.neighbors import NearestNeighbors
from nltk.corpus import stopwords
from nltk.tokenize import TweetTokenizer

with open('w2i_movie_names.json') as f:
    w2i_movies = json.load(f)

with open('i2w_movie_names.json') as f:
    i2w_movies = json.load(f)

In [12]:
with open('movie_data_separate.pkl', 'rb') as f:
    movie_data = pickle.load(f)
    
with open('neighbours.pkl', 'rb') as f:
    neighbours = pickle.load(f)

In [13]:
with open('w2i_review_comments_plot.json') as f:
        w2i_rpc = json.load(f)
with open('i2w_review_comments_plot.json') as f:
        i2w_rpc = json.load(f)

In [15]:
len(i2w_rpc.keys())

35195

In [16]:
stop_words = stopwords.words('english')
tknzr = TweetTokenizer()

In [17]:
def k_nearest_neighbors(k, embeddings):
    k += 1
    nbrs = NearestNeighbors(n_neighbors=k, algorithm='kd_tree')
    nbrs.fit(embeddings)
    distances, indices = nbrs.kneighbors(embeddings)
    return distances, indices

In [18]:
def get_average_utterance_embedding(utterance, embed_dim, stopwords, w2i, trained_word_embeddings):
    #obtain the average embedding for the whole utterance using the word embeddings
    #learned in the movie embedding training
    utterance_embedding = torch.zeros(embed_dim)
    
    count = 0
    for w in utterance:
        #skip stop words
        if w in stop_words or w in ['<SOS>', '<EOS>']:
            pass
        elif w in w2i: #word in dictionary
            #print('word',w)
            #TO DO REMOVE THIS IF
            if w2i[w] < trained_word_embeddings.shape[0]:
                word_em = trained_word_embeddings[w2i[w]]
                utterance_embedding += word_em
                count += 1
        else:
            #print('Unk',w)
            word_em = trained_word_embeddings[0] #unk
            utterance_embedding += word_em
            count += 1
            
    #print(utterance_embedding, utterance_embedding/count, count)
    avg_utterances_embedding = utterance_embedding/count
    
    return avg_utterances_embedding

In [41]:
def get_similar_movie_responses(movie_id, n, model, w2i, i2w, utterance, neighbours, stop_words):
    '''
    w2i and i2w are vocabularies for plot-review-comments
    utterance is a tokenized sentence of strings
    returning a list of tokenized sentences of strings
    '''
    similar_movie_data = neighbours[movie_id] 
    similar_movie_id = similar_movie_data.imdb_id
    print(similar_movie_id)
    similar_movie_chat = similar_movie_data.chat
    
    trained_word_embeddings = model['model_state_dict']['word_embedding.weight']
    #print(trained_word_embeddings.shape)
    embed_dim = trained_word_embeddings.shape[1]
    
    avg_utterance_embedding = get_average_utterance_embedding(utterance, embed_dim, stop_words, w2i, trained_word_embeddings)
    
    similar_responses = []  
    
    all_chat_reps = []    
    all_chat_reps.append(avg_utterance_embedding.numpy())
    all_chat_indices = [(-1,-1)]
    for c in range(len(similar_movie_chat)):
        chat = similar_movie_chat[c]
        enc = chat.encoder_chat
        dec = chat.decoder_chat
        
        for s in range(len(enc)):
            sent = enc[s]            
            sent = tknzr.tokenize(sent)
            #print(sent)
            sent_avg_embedding = get_average_utterance_embedding(sent, embed_dim, stop_words, w2i, trained_word_embeddings)
            
            all_chat_reps.append(sent_avg_embedding.numpy())
            #chat index and then sentence index for speaker 1
            #so that we can get the related speaker 2 utterance
            all_chat_indices.append((c,s))  
        
    #print(all_chat_reps[0])
    distances, indices = k_nearest_neighbors(n, all_chat_reps)
    print(indices)
    neighbours = indices[0]
    
    for n in neighbours:
        (c,s) = all_chat_indices[n]
        print(c,s)
        if c != -1:
            print(similar_movie_chat[c].encoder_chat[s])
            print(similar_movie_chat[c].decoder_chat[s])
            similar_responses.append(tknzr.tokenize(similar_movie_chat[c].decoder_chat[s]))
            
    return similar_responses

movie_id = 'tt0058150'
n = 5
utterance = ['which', 'is', 'your', 'favourite', 'character'] #speaker 1 utterance
model = torch.load('model_movie.pkl', map_location='cpu' )
similar_responses = get_similar_movie_responses(movie_id, n, model, w2i_rpc, i2w_rpc, utterance, neighbours,stopwords)

tt0061452
[[ 0 42  6 26 14 12]
 [37  1 32 16  3 22]
 [ 2 43 30 46 45  3]
 [ 3 46 45 34 40 30]
 [ 4 45 34 44 12  3]
 [ 5  3 46 40 22 34]
 [ 0 42  6 26 14 12]
 [ 7 30 34 45  3 39]
 [ 8 34 35 39 44 40]
 [ 9 40 39 44 30 43]
 [10 21 22 28 11  3]
 [11 34  3 46 40 30]
 [12 31  4 15 45  3]
 [13 40 34 46  3 39]
 [14 30 39 40 46 25]
 [15 31 12 45 34 13]
 [16  3 25 37  1 32]
 [17 15  5 39 31 45]
 [18 34 46 13  3 35]
 [19 23 34 35 13  3]
 [20 39 34 45 30 44]
 [21 10 28 22  3  1]
 [22 28  5  3 44 45]
 [23 34 44 19 45 40]
 [24 29 40 13 45 39]
 [25 46 40 16 44  3]
 [ 0 42  6 26 14 12]
 [27 45 12  5 43 31]
 [28 22  3 46 45 44]
 [29 43 46 30 45 44]
 [30 45 46 39  3 40]
 [31 15 12 45 40 34]
 [32 37  1 16  3  5]
 [33 43 45 40  3 35]
 [34 45 40 39 44 46]
 [35 43 34 45  8 44]
 [36 45 40 44 46 34]
 [37  1 32 16  3 22]
 [38 30 16 39  3  2]
 [39 34 45 44 40 30]
 [40 34 44 45 13 39]
 [41 35 13 46 39  8]
 [ 0 42  6 26 14 12]
 [43 35 45 44 46  4]
 [44 45 34 40  4 43]
 [45 44 34  4 46 40]
 [46  3 45 34 30 25]]
-1