In [10]:
import pickle
import torch
import json
import numpy as np
from sklearn.neighbors import NearestNeighbors
from nltk.corpus import stopwords
from nltk.tokenize import TweetTokenizer

# Movie-name vocabularies, stored on disk as two JSON files
# (presumably word -> index and index -> word, going by the file names).
with open('w2i_movie_names.json') as w2i_file:
    w2i_movies = json.load(w2i_file)

with open('i2w_movie_names.json') as i2w_file:
    i2w_movies = json.load(i2w_file)

In [11]:
# NOTE(review): unpickling can run arbitrary code, so these two files must
# come from a trusted source.
with open('movie_data_separate.pkl', 'rb') as movie_data_file:
    movie_data = pickle.load(movie_data_file)

# Looked up by movie id further down to find a similar movie's chats.
with open('neighbours.pkl', 'rb') as neighbours_file:
    neighbours = pickle.load(neighbours_file)

In [12]:
# Vocabularies for the plot/review/comment text, one JSON file per direction.
with open('w2i_review_comments_fact.json') as w2i_file:
    w2i_rpc = json.load(w2i_file)

with open('i2w_review_comments_fact.json') as i2w_file:
    i2w_rpc = json.load(i2w_file)

In [13]:
# Number of entries in the review-comment vocabulary.
len(i2w_rpc)

20155

In [14]:
# Index of the 'unknown' token; its embedding stands in for out-of-vocabulary
# words when utterance embeddings are averaged.
w2i_rpc['unknown']

2

In [15]:
# Tokenizer used to split raw chat sentences into tokens.
tknzr = TweetTokenizer()
# English stop words that are skipped when averaging word embeddings.
stop_words = stopwords.words('english')

In [16]:
def k_nearest_neighbors(k, embeddings):
    """Find the nearest neighbours of every row of ``embeddings``.

    The model is fitted on and queried with the same rows, so a row normally
    comes back as its own closest match; one extra neighbour is requested to
    make up for that.

    Args:
        k: number of neighbours wanted per row, not counting the row itself.
        embeddings: 2-D array-like with one embedding per row.

    Returns:
        The ``(distances, indices)`` pair from ``NearestNeighbors.kneighbors``,
        each with ``min(k + 1, number of rows)`` columns.
    """
    # Asking for more neighbours than there are rows makes scikit-learn raise
    # a ValueError, so cap the request at the number of rows available.
    n_neighbors = min(k + 1, len(embeddings))
    nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm='kd_tree')
    nbrs.fit(embeddings)
    distances, indices = nbrs.kneighbors(embeddings)
    return distances, indices

In [17]:
def get_average_utterance_embedding(utterance, embed_dim, stopwords, w2i, trained_word_embeddings):
    """Average the word embeddings of the content words of an utterance.

    Stop words and the '<SOS>' / '<EOS>' markers are skipped. Words that are
    not in ``w2i`` contribute the embedding of the 'unknown' token instead.

    Args:
        utterance: iterable of token strings.
        embed_dim: length of a single word embedding vector.
        stopwords: collection of words to ignore. For backward compatibility
            an object exposing ``words(fileid)`` (such as the NLTK stopwords
            corpus reader) is also accepted; its English word list is used.
        w2i: mapping from word to row index in ``trained_word_embeddings``;
            must contain 'unknown' if the utterance has out-of-vocabulary words.
        trained_word_embeddings: tensor indexed by word id, yielding vectors
            of length ``embed_dim``.

    Returns:
        A 1-D tensor of length ``embed_dim`` holding the mean embedding, or
        all zeros when no word of the utterance was counted.
    """
    # Use the argument rather than a module-level global. Callers that hand
    # over the NLTK corpus reader instead of a word list are still supported.
    if hasattr(stopwords, 'words'):
        stopwords = stopwords.words('english')
    skipped = set(stopwords) if stopwords is not None else set()
    skipped.update(('<SOS>', '<EOS>'))

    utterance_embedding = torch.zeros(embed_dim)
    count = 0
    for w in utterance:
        if w in skipped:
            continue
        # Out-of-vocabulary words fall back to the 'unknown' token of the
        # vocabulary that was passed in.
        word_index = w2i[w] if w in w2i else w2i['unknown']
        utterance_embedding += trained_word_embeddings[word_index]
        count += 1

    # Nothing counted (empty or stop-word-only utterance): return the zero
    # vector instead of 0/0, which would yield NaNs and break the
    # nearest-neighbour search downstream.
    if count == 0:
        return utterance_embedding

    return utterance_embedding / count

In [18]:
def get_similar_movie_responses(movie_id, n, model, w2i, i2w, utterance, neighbours, stop_words, verbose=True):
    '''
    Retrieve replies to the utterances most similar to `utterance`, taken
    from the chats of the movie that `neighbours` associates with `movie_id`.

    Every speaker-1 (encoder) sentence of those chats is embedded as the
    average of its word embeddings; the decoder sentences paired with the
    nearest ones are returned.

    movie_id: key into `neighbours`
    n: maximum number of responses to return
    model: checkpoint dict holding model['model_state_dict']['word_embedding.weight']
    w2i and i2w are vocabularies for plot-review-comments (i2w is unused here)
    utterance is a tokenized sentence of strings
    neighbours: mapping from movie id to an object with `imdb_id` and `chat`
    stop_words: words to ignore when averaging embeddings
    verbose: print intermediate results (default True, as before)

    returning a list of at most n tokenized sentences of strings
    '''
    similar_movie_data = neighbours[movie_id]
    similar_movie_chat = similar_movie_data.chat
    if verbose:
        print(similar_movie_data.imdb_id)

    trained_word_embeddings = model['model_state_dict']['word_embedding.weight']
    embed_dim = trained_word_embeddings.shape[1]

    avg_utterance_embedding = get_average_utterance_embedding(
        utterance, embed_dim, stop_words, w2i, trained_word_embeddings)

    # Row 0 is the query utterance itself; (-1, -1) marks it as "not a chat
    # sentence" so it can be filtered out of the results below.
    all_chat_reps = [avg_utterance_embedding.numpy()]
    all_chat_indices = [(-1, -1)]
    for c, chat in enumerate(similar_movie_chat):
        for s, sent in enumerate(chat.encoder_chat):
            tokens = tknzr.tokenize(sent)
            sent_avg_embedding = get_average_utterance_embedding(
                tokens, embed_dim, stop_words, w2i, trained_word_embeddings)
            all_chat_reps.append(sent_avg_embedding.numpy())
            # chat index and then sentence index for speaker 1
            # so that we can get the related speaker 2 utterance
            all_chat_indices.append((c, s))

    # Without candidate sentences there is nothing to retrieve.
    n_candidates = len(all_chat_reps) - 1
    if n_candidates == 0:
        return []

    # Never request more neighbours than there are candidates; the search
    # would otherwise fail for movies with only a few chat sentences.
    _, indices = k_nearest_neighbors(min(n, n_candidates), all_chat_reps)
    if verbose:
        print(indices)

    # Neighbours of row 0, i.e. of the query utterance. Local names are used
    # so the `neighbours` and `n` arguments are not overwritten.
    nearest_rows = indices[0]

    similar_responses = []
    for row in nearest_rows:
        # The query may be missing from its own neighbour list when several
        # sentences share its embedding; stop once n replies are collected.
        if len(similar_responses) >= n:
            break
        c, s = all_chat_indices[row]
        if verbose:
            print(c, s)
        if c == -1:
            continue
        matched_chat = similar_movie_chat[c]
        # Skip speaker-1 sentences that have no paired speaker-2 reply.
        if s >= len(matched_chat.decoder_chat):
            continue
        if verbose:
            print(matched_chat.encoder_chat[s])
            print(matched_chat.decoder_chat[s])
        similar_responses.append(tknzr.tokenize(matched_chat.decoder_chat[s]))

    return similar_responses

movie_id = 'tt0058150'
n = 5
utterance = ['which', 'is', 'your', 'favourite', 'character'] #speaker 1 utterance
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
model = torch.load('model_movie.pkl', map_location='cpu')
# Pass the stop-word list built earlier (stop_words), not the nltk
# `stopwords` corpus object, which is not a collection of words.
similar_responses = get_similar_movie_responses(movie_id, n, model, w2i_rpc, i2w_rpc, utterance, neighbours, stop_words)

tt0061452
[[ 0  6 42 26 30 35]
 [37  1 32 16  2 21]
 [ 2 34 39 28 43 40]
 [ 3 46 45 40  4 39]
 [ 4 45 43 36 34 44]
 [ 5 43 45 35 40 44]
 [ 0  6 42 26 30 35]
 [ 7 20 36 45 12 34]
 [ 8 35 40 44 45 43]
 [ 9 34 44 45 40 11]
 [10 21 29 28 22 45]
 [11 34 45 46 40  4]
 [12 45  4 36 34 46]
 [13 34 11 40 45 46]
 [14 30 43 28 45 40]
 [15 31 40 44 34 43]
 [16 32  1 37  3 39]
 [17 40 31 44 15 43]
 [18 45 39 43  8 44]
 [19 43 28 40 44 39]
 [20 45 36 35  4 34]
 [21 10 28 22 39 29]
 [22 40 35 28 44 45]
 [23 45 40 46 39 44]
 [24 45 39 43 34 40]
 [25 40 30 44 43 45]
 [ 0  6 42 26 30 35]
 [27 43 40 45 28  4]
 [28 40 22 43 45 29]
 [29 45 40 30 43 34]
 [30 40 45 35 44 43]
 [31 15 40 44 43 39]
 [32 37  1 16 28  3]
 [33 22 40 45 43 28]
 [34 11 40 45 43 35]
 [35 45 43 34 40  8]
 [36  4 45 34 43 20]
 [37  1 32 16  2 21]
 [38 30 40 22  8 34]
 [39 44 45 40 34 35]
 [40 43 45 44 34 28]
 [41 44 34 45 39  4]
 [ 0  6 42 26 30 35]
 [43 45 40  4 35 34]
 [44 45 39 40 43 34]
 [45 43  4 44 40 35]
 [46  3 45 11 40 34]]
-1

In [19]:
# Display the retrieved speaker-2 replies (each one a list of tokens).
similar_responses

[['my',
  'favorite',
  'character',
  'was',
  'peter',
  'sellers',
  'version',
  'of',
  'james',
  'bond',
  'he',
  'along',
  'with',
  'woody',
  'allen',
  'gave',
  'the',
  'funniest',
  'performances',
  'in',
  'the',
  'film'],
 ['my',
  'favorite',
  'character',
  'was',
  'bond',
  'because',
  'he',
  'is',
  'always',
  'dynamic'],
 ['my',
  'favorite',
  'character',
  'was',
  'peter',
  'sellers',
  'version',
  'of',
  'james',
  'bond',
  'he',
  'along',
  'with',
  'woody',
  'allen',
  'gave',
  'the',
  'funniest',
  'performances',
  'in',
  'the',
  'film'],
 ['its',
  'my',
  'favorite',
  'aside',
  'from',
  'goldeneye',
  'that',
  'first',
  'chase',
  'is',
  'the',
  'best',
  'action',
  'sequence',
  'in',
  'history',
  'im',
  'still',
  'blow',
  'away',
  'by',
  'it'],
 ['casino',
  'royale',
  'also',
  'has',
  'probably',
  'the',
  'best',
  'of',
  'bonds',
  'wit',
  ':',
  'everyone',
  'is',
  'going',
  'to',
  'know',
  'that',
  'y