In [10]:
from modeling_nmf import NMF, NMFConfig
from data_utils import get_reddit_data
from collections import defaultdict
import torch
from transformers import AutoTokenizer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer


In [11]:
from tqdm import tqdm

def build_datastore_embedding(sentences, model, tokenizer, device, batch_size=64):
    """Embed ``sentences`` with ``model`` and return L2-normalised vectors.

    Every sentence is represented by the last hidden layer's vector at token
    position 0; no pooling over tokens is performed.

    Args:
        sentences: Sliceable sequence of strings (e.g. a list) to embed.
        model: Object exposing ``eval()`` that is called as
            ``model(**encoded, output_hidden_states=True)`` and returns an
            object with a ``hidden_states`` sequence.
        tokenizer: Callable invoked with ``padding=True, truncation=True,
            return_tensors='pt'``; its result must support ``.to(device)``
            and ``**`` unpacking.
        device: Device the tokenised batch is moved to before the forward pass.
        batch_size: Number of sentences encoded per forward pass.

    Returns:
        A 2-D CPU tensor with one unit-norm row per sentence, in input order.

    Raises:
        ValueError: If ``sentences`` is empty or ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError('batch_size must be a positive integer')
    if len(sentences) == 0:
        raise ValueError('sentences must not be empty')

    print('building datastore embeddings')
    model.eval()

    # One start offset per batch. len() of this range also counts the final
    # partial batch, so the progress bar total is exact (floor division
    # undercounted it and made the bar overflow).
    batch_starts = range(0, len(sentences), batch_size)
    batch_embeddings = []

    for start in tqdm(batch_starts, total=len(batch_starts)):
        batch_sentences = sentences[start:start + batch_size]

        # Tokenize the batch and move it to the target device.
        encoded_input = tokenizer(batch_sentences, padding=True, truncation=True, return_tensors='pt').to(device)

        with torch.no_grad():
            model_output = model(**encoded_input, output_hidden_states=True)

        # Instead of pooling over tokens, take the last hidden state at
        # position 0 as the sentence embedding; keep results on the CPU so
        # accelerator memory is released batch by batch.
        batch_embeddings.append(model_output.hidden_states[-1][:, 0, :].cpu())

    # Stack all batches and scale each row to unit L2 norm.
    post_embeddings = torch.cat(batch_embeddings, dim=0)
    return F.normalize(post_embeddings, p=2, dim=1)

In [12]:
class KNNLMForTuningHyperParams:
    """kNN retrieval datastore and NMF predictor built on a train/held-out split.

    The training portion returned by ``get_reddit_data_with_heldout`` becomes
    the datastore of post embeddings; the held-out portion is stored on
    ``validation_dataset`` so callers can tune retrieval hyper-parameters
    such as the number of neighbours.
    """

    def __init__(
        self,
        data_path,
        tokenizer_name_or_path = 'sentence-transformers/all-MiniLM-L6-v2',
        model_name_or_path = 'saved_models'
    ):
        """Load the model, split the data and embed every distinct training post.

        Args:
            data_path: Path handed to ``get_reddit_data_with_heldout`` (called
                with ``heldout_portion=0.2``).
            tokenizer_name_or_path: Name or path given to
                ``AutoTokenizer.from_pretrained``.
            model_name_or_path: Name or path given to ``NMF.from_pretrained``.
        """
        # Imported here so the class does not rely on a later notebook cell
        # having already pulled this helper into the global namespace.
        from data_utils import get_reddit_data_with_heldout

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Honour the constructor argument (it used to be silently ignored in
        # favour of a hard-coded checkpoint name).
        sentence_embedding_tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
        nmf = NMF.from_pretrained(model_name_or_path).to(device)
        nmf.eval()
        # The encoder wrapped by the NMF model doubles as the sentence embedder.
        sentence_embedding_model = nmf.model

        training_dataset, validation_dataset, movie_vocab = get_reddit_data_with_heldout(data_path, heldout_portion=0.2)
        movie_vocab = [m.split(' (')[0] for m in movie_vocab] # remove the year at end, e.g. " (2019)"

        # Distinct posts in sorted order, so post indices are deterministic.
        training_posts = sorted(set(training_dataset['context']))
        trainingpost2idx = {post: idx for idx, post in enumerate(training_posts)}

        # post index -> labels of every training row that shares that post.
        # Columns are fetched once instead of being re-indexed on every row.
        trainingpostidx2movies = defaultdict(list)
        contexts = training_dataset['context']
        labels = training_dataset['label']
        for context, label in zip(contexts, labels):
            trainingpostidx2movies[trainingpost2idx[context]].append(label)

        # Unlike KNNLMRecommender, the datastore is moved to ``device`` because
        # similarity search here is done with torch.
        post_embeddings = build_datastore_embedding(
            sentences = training_posts,
            model = sentence_embedding_model,
            tokenizer = sentence_embedding_tokenizer,
            device = device
        ).to(device)

        self.device = device
        self.sentence_embedding_tokenizer = sentence_embedding_tokenizer
        self.sentence_embedding_model = sentence_embedding_model
        self.movie_vocab = movie_vocab
        self.training_posts = training_posts
        self.post_embeddings = post_embeddings
        self.trainingpostidx2movies = trainingpostidx2movies
        self.nmf = nmf
        self.validation_dataset = validation_dataset

    def encode_sentences(self, batch_sentences):
        """Return the position-0 last-hidden-state vector for each input sentence.

        Args:
            batch_sentences: A string or list of strings accepted by the tokenizer.

        Returns:
            Tensor on ``self.device`` with one row per tokenised sentence.
        """
        with torch.no_grad():
            encoded_input = self.sentence_embedding_tokenizer(
                batch_sentences,
                padding=True,
                truncation=True,
                return_tensors='pt'
            ).to(self.device)
            model_output = self.sentence_embedding_model(**encoded_input, output_hidden_states=True)
            sentence_embeddings = model_output.hidden_states[-1][:,0,:]

        return sentence_embeddings

    def top_post_ids_retrieval(self, query):
        """Rank all datastore posts by cosine similarity to ``query``.

        Returns:
            ``(sorted_indices, similarities)``: post indices from most to least
            similar, and the matching similarity values, both numpy arrays.
        """
        # Only the first encoded row is used as the query vector.
        query_embedding = self.encode_sentences(query)[0].reshape(1, -1)
        cosine_similarities = torch.cosine_similarity(query_embedding, self.post_embeddings).cpu().numpy()
        sorted_indices = np.argsort(cosine_similarities)[::-1]
        return sorted_indices, cosine_similarities[sorted_indices]

    def count_based_probability(
        self,
        query,
        num_posts_to_consider=30,
        return_logits=False,
        distance_weighting=False,
        temperature = 1.0
    ):
        """Score every movie by voting over the posts nearest to ``query``.

        Args:
            query: Text to retrieve neighbours for.
            num_posts_to_consider: Number of top-ranked posts that vote.
            return_logits: If True, return the raw vote totals.
            distance_weighting: If True, each vote is worth
                ``similarity / temperature`` instead of 1.
            temperature: Divisor applied to similarities when
                ``distance_weighting`` is on.

        Returns:
            A numpy array of vote totals (one entry per ``movie_vocab`` item)
            when ``return_logits`` is True; otherwise a torch tensor holding the
            softmax of those totals.
        """
        relevant_post_ids, similarities = self.top_post_ids_retrieval(query)
        probas = np.zeros(len(self.movie_vocab))
        for rank, post_id in enumerate(relevant_post_ids[:num_posts_to_consider]):
            # Stored labels are used directly as positions in the score vector.
            for movieid in self.trainingpostidx2movies[post_id]:
                if distance_weighting:
                    probas[movieid] += similarities[rank]/temperature
                else:
                    probas[movieid] += 1
        if return_logits:
            return probas
        probas = torch.nn.functional.softmax(torch.from_numpy(probas), dim=-1)
        return probas

    def predictor_probability(
        self,
        query,
        return_logits = False
    ):
        """Score every movie with the NMF predictor head.

        Args:
            query: Text to score.
            return_logits: If True, return the logits as produced by the model.

        Returns:
            A numpy array: the un-indexed logits when ``return_logits`` is True,
            otherwise the softmax (over the last axis) of the first row.
        """
        with torch.no_grad():
            model_input = self.sentence_embedding_tokenizer(
                query,
                return_tensors='pt',
                max_length=368,
                truncation=True
            )

            # Use this instance's model; the previous code reached for a
            # notebook-level global, which fails (or scores with another
            # model) whenever that global is missing or points elsewhere.
            logits = self.nmf(
                            model_input['input_ids'].to(self.device),
                            model_input['token_type_ids'].to(self.device),
                            model_input['attention_mask'].to(self.device),
                            labels=None
                        ).logits
        if return_logits:
            return logits.cpu().numpy()
        probas = F.softmax(logits, dim=-1)[0]
        return  probas.cpu().numpy()

In [13]:
class KNNLMRecommender:
    """kNN retrieval datastore and NMF predictor built on the full training set.

    Every distinct training post is embedded once at construction time; the
    methods then score movies either by voting over the nearest posts
    (retrieval) or with the NMF predictor head.
    """

    def __init__(
        self,
        tokenizer_name_or_path = 'sentence-transformers/all-MiniLM-L6-v2',
        model_name_or_path = 'saved_models',
        data_path = 'reddit/reddit_large_train.csv'
    ):
        """Load the model and data, then embed every distinct training post.

        Args:
            tokenizer_name_or_path: Name or path given to
                ``AutoTokenizer.from_pretrained``.
            model_name_or_path: Name or path given to ``NMF.from_pretrained``.
            data_path: Path handed to ``get_reddit_data``.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Honour the constructor argument (it used to be silently ignored in
        # favour of a hard-coded checkpoint name).
        sentence_embedding_tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
        nmf = NMF.from_pretrained(model_name_or_path).to(device)
        nmf.eval()
        # The encoder wrapped by the NMF model doubles as the sentence embedder.
        sentence_embedding_model = nmf.model

        training_dataset, movie_vocab = get_reddit_data(data_path)
        movie_vocab = [m.split(' (')[0] for m in movie_vocab] # remove the year at end, e.g. " (2019)"

        # Distinct posts in sorted order, so post indices are deterministic.
        training_posts = sorted(set(training_dataset['context']))
        trainingpost2idx = {post: idx for idx, post in enumerate(training_posts)}

        # post index -> labels of every training row that shares that post.
        # Columns are fetched once instead of being re-indexed on every row.
        trainingpostidx2movies = defaultdict(list)
        contexts = training_dataset['context']
        labels = training_dataset['label']
        for context, label in zip(contexts, labels):
            trainingpostidx2movies[trainingpost2idx[context]].append(label)

        # The datastore stays on the CPU: similarity search below goes through
        # scikit-learn's cosine_similarity.
        post_embeddings = build_datastore_embedding(
            sentences = training_posts,
            model = sentence_embedding_model,
            tokenizer = sentence_embedding_tokenizer,
            device = device
        )

        self.device = device
        self.sentence_embedding_tokenizer = sentence_embedding_tokenizer
        self.sentence_embedding_model = sentence_embedding_model
        self.movie_vocab = movie_vocab
        self.training_posts = training_posts
        self.post_embeddings = post_embeddings
        self.trainingpostidx2movies = trainingpostidx2movies
        self.nmf = nmf

    def encode_sentences(self, batch_sentences):
        """Return the position-0 last-hidden-state vector for each input sentence.

        Args:
            batch_sentences: A string or list of strings accepted by the tokenizer.

        Returns:
            Tensor on ``self.device`` with one row per tokenised sentence.
        """
        with torch.no_grad():
            encoded_input = self.sentence_embedding_tokenizer(
                batch_sentences,
                padding=True,
                truncation=True,
                return_tensors='pt'
            ).to(self.device)
            model_output = self.sentence_embedding_model(**encoded_input, output_hidden_states=True)
            sentence_embeddings = model_output.hidden_states[-1][:,0,:]

        return sentence_embeddings

    def top_post_ids_retrieval(self, query):
        """Rank all datastore posts by cosine similarity to ``query``.

        Returns:
            ``(sorted_indices, similarities)``: post indices from most to least
            similar, and the matching similarity values, both numpy arrays.
        """
        # Only the first encoded row is used as the query vector; it is moved
        # to the CPU because the similarity is computed by scikit-learn.
        query_embedding = self.encode_sentences(query)[0].reshape(1, -1).cpu()
        cosine_similarities = cosine_similarity(query_embedding, self.post_embeddings)
        sorted_indices = np.argsort(cosine_similarities[0])[::-1]
        return sorted_indices, cosine_similarities[0][sorted_indices]

    def count_based_probability(
        self,
        query,
        num_posts_to_consider=30,
        return_logits=False,
        distance_weighting=False
    ):
        """Score every movie by voting over the posts nearest to ``query``.

        Args:
            query: Text to retrieve neighbours for.
            num_posts_to_consider: Number of top-ranked posts that vote.
            return_logits: If True, return the raw vote totals.
            distance_weighting: If True, each vote is worth the post's
                similarity to the query instead of 1.

        Returns:
            A numpy array of vote totals (one entry per ``movie_vocab`` item)
            when ``return_logits`` is True; otherwise a torch tensor holding the
            softmax of those totals.
        """
        relevant_post_ids, similarities = self.top_post_ids_retrieval(query)
        probas = np.zeros(len(self.movie_vocab))
        for rank, post_id in enumerate(relevant_post_ids[:num_posts_to_consider]):
            # Stored labels are used directly as positions in the score vector.
            for movieid in self.trainingpostidx2movies[post_id]:
                if distance_weighting:
                    probas[movieid] += similarities[rank]
                else:
                    probas[movieid] += 1
        if return_logits:
            return probas
        probas = torch.nn.functional.softmax(torch.from_numpy(probas), dim=-1)
        return probas

    def predictor_probability(
        self,
        query,
        return_logits = False
    ):
        """Score every movie with the NMF predictor head.

        Args:
            query: Text to score.
            return_logits: If True, return the logits as produced by the model.

        Returns:
            A numpy array: the un-indexed logits when ``return_logits`` is True,
            otherwise the softmax (over the last axis) of the first row.
        """
        with torch.no_grad():
            model_input = self.sentence_embedding_tokenizer(
                query,
                return_tensors='pt',
                max_length=368,
                truncation=True
            )

            # Use this instance's model; the previous code reached for a
            # notebook-level global, which fails (or scores with another
            # model) whenever that global is missing or points elsewhere.
            logits = self.nmf(
                            model_input['input_ids'].to(self.device),
                            model_input['token_type_ids'].to(self.device),
                            model_input['attention_mask'].to(self.device),
                            labels=None
                        ).logits
        if return_logits:
            return logits.cpu().numpy()
        probas = F.softmax(logits, dim=-1)[0]
        return  probas.cpu().numpy()

In [14]:
import json

def get_eligible_entities(resource_path=''):
    """Return readable names for the items listed in ``item_ids.json``.

    Two files are read, each from its first line only:
    ``resource_path + 'entity2id.json'`` (a mapping from entity string to id)
    and ``resource_path + 'item_ids.json'`` (a sequence of ids).

    Each listed id is mapped back to its entity string, which is reduced to
    the part after the last ``/``, cut at the first ``_(``, stripped of
    trailing ``>`` characters, and has underscores replaced by spaces.

    Args:
        resource_path: Prefix prepended verbatim to both file names, so a
            directory must include its trailing separator.

    Returns:
        List of names in the order the ids appear in ``item_ids.json``.

    Raises:
        ValueError or SyntaxError: If a first line is neither valid JSON nor a
            plain Python literal.
    """
    import ast  # only needed for the Python-literal fallback below

    def _parse_first_line(path):
        # Parse as data, never as code: JSON first, then a plain Python
        # literal for files that use single-quoted strings. Unlike eval(),
        # neither parser can execute anything contained in the file.
        with open(path, 'r') as handle:
            first_line = handle.readline()
        try:
            return json.loads(first_line)
        except json.JSONDecodeError:
            return ast.literal_eval(first_line)

    entity2id = _parse_first_line(resource_path+'entity2id.json')
    id2entity = {v:k for k,v in entity2id.items()}
    item_ids = _parse_first_line(resource_path+'item_ids.json')
    return [
        id2entity[idx].split('/')[-1].split('_(')[0].rstrip('>').replace('_',' ')
        for idx in item_ids
    ]

In [15]:
# Names returned by get_eligible_entities for the Inspired assets, stored as a
# set so the evaluation cell can filter ranked titles with fast membership tests.
inspired_eligible_entities = set(get_eligible_entities('./entity_assets/inspired/'))

In [16]:
# Names returned by get_eligible_entities for the Redial assets, de-duplicated
# into a set for membership tests.
redial_eligible_entities = set(get_eligible_entities('./entity_assets/redial/'))

## Inspired

In [17]:
from data_utils import get_reddit_data_with_heldout
import numpy as np
import pandas as pd
from scipy.stats import sem

def _rank_eligible_movies(scores, movie_vocab, eligible_entities):
    """Return vocabulary titles sorted by descending score, keeping eligible ones only."""
    ranked_titles = np.array(movie_vocab)[np.argsort(-scores)]
    return [title for title in ranked_titles if title in eligible_entities]


def _report_recall_at_k(title, score_fn, queries, targets, movie_vocab, eligible_entities, cutoffs):
    """Print ``title``, then recall@k and its standard error for every cutoff.

    ``score_fn(query)`` must return one score per ``movie_vocab`` entry. A
    ranking depends only on the query, so it is computed once and reused for
    every cutoff and for repeated queries.
    """
    print(title)
    ranking_cache = dict()
    for k in cutoffs:
        hits_at_k = []
        for idx in tqdm(range(len(queries)), total = len(queries) ):
            query = queries[idx]
            target = targets[idx]
            if query not in ranking_cache:
                ranking_cache[query] = _rank_eligible_movies(score_fn(query), movie_vocab, eligible_entities)
            hits_at_k.append(int(target in ranking_cache[query][:k]))
        print('r@'+str(k), np.mean(hits_at_k),'; se: ', sem(hits_at_k))


def _build_score_fns(recommender, num_posts_to_consider):
    """Return ``(section title, scoring function)`` pairs for the four evaluated variants."""

    def retrieval_scores(query):
        # Raw neighbour vote totals.
        return recommender.count_based_probability(
            query,
            num_posts_to_consider=num_posts_to_consider,
            return_logits=True
        )

    def predictor_scores(query):
        # Softmax output of the predictor head.
        return recommender.predictor_probability(
            query,
            return_logits=False
        )

    def rerank_scores(query):
        # Vote totals plus predictor probabilities.
        movie_counts = retrieval_scores(query)
        movie_probas = predictor_scores(query)
        return movie_counts + movie_probas

    def rerank_small_gamma_scores(query):
        # Softmaxed votes dominate; the tiny predictor weight only separates
        # movies whose retrieval scores are (nearly) equal.
        movie_counts = recommender.count_based_probability(
            query,
            num_posts_to_consider=num_posts_to_consider,
            return_logits=False
        )
        movie_probas = predictor_scores(query)
        return movie_counts*(1-(1e-10)) + movie_probas*1e-10

    return [
        ('-------------------- Retrieval -------------------------------', retrieval_scores),
        ('-------------------- Recommend -------------------------------', predictor_scores),
        ('-------------------- R+R (rerank) -------------------------------', rerank_scores),
        ('-------------------- R+R (rerank) with small gamma (cleaner solution for paper) -------------------------------', rerank_small_gamma_scores),
    ]


# ---- Tune the number of neighbours on a held-out slice of the training data ----
reddit_knnlm_recommender = KNNLMForTuningHyperParams(
    data_path= 'datasets/inspired/inspired_train.csv',
    model_name_or_path = 'models/inspired'
)

k=20
n_neighbors = [15, 30, 60, 90, 120, 150, 180]
validation_queries = reddit_knnlm_recommender.validation_dataset['context']
validation_labels = reddit_knnlm_recommender.validation_dataset['label']
recall = []
for num_posts_to_consider in n_neighbors:
    hits_at_k = []
    cache = dict()
    for idx in tqdm(range(len(validation_queries)), total=len(validation_queries)):
        query = validation_queries[idx]
        target = reddit_knnlm_recommender.movie_vocab[validation_labels[idx]]
        if query not in cache:
            movie_counts = reddit_knnlm_recommender.count_based_probability(
                query,
                num_posts_to_consider=num_posts_to_consider,
                return_logits=True
            )
            # Store the ranking: previously the cache was consulted but never
            # filled, so every repeated query was re-encoded and re-ranked.
            cache[query] = _rank_eligible_movies(
                movie_counts,
                reddit_knnlm_recommender.movie_vocab,
                inspired_eligible_entities
            )
        recommended_movies = cache[query]

        hits_at_k.append(int(target in recommended_movies[:k]))
    recall.append(np.mean(hits_at_k))


best_n_neighbors = n_neighbors[np.argmax(recall)]
print('the recommended number of neighbors to use is ', best_n_neighbors)

# ---- Evaluate on the test split with a datastore built from all training data ----
testset = pd.read_csv('datasets/inspired/inspired_test.csv')
test_inputs = testset['test_inputs']
test_groundtruths = testset['test_outputs']

reddit_knnlm_recommender = KNNLMRecommender(
    data_path= 'datasets/inspired/inspired_train.csv',
    model_name_or_path = 'models/inspired'
)

K = [1,5, 10, 20,50,100,300]

for section_title, score_fn in _build_score_fns(reddit_knnlm_recommender, best_n_neighbors):
    _report_recall_at_k(
        section_title,
        score_fn,
        test_inputs,
        test_groundtruths,
        reddit_knnlm_recommender.movie_vocab,
        inspired_eligible_entities,
        K
    )

interaction preservance rate at 20000 items
1.0
num items  1436
flattening posts into training data


100%|████████████████████████████████████████████████████████████████████| 584/584 [00:00<00:00, 930297.58it/s]
100%|███████████████████████████████████████████████████████████████████| 147/147 [00:00<00:00, 1076025.63it/s]


building datastore embeddings


10it [00:01,  8.09it/s]                                                                                        
100%|███████████████████████████████████████████████████████████████████████| 475/475 [00:02<00:00, 186.19it/s]
100%|███████████████████████████████████████████████████████████████████████| 475/475 [00:02<00:00, 187.14it/s]
100%|███████████████████████████████████████████████████████████████████████| 475/475 [00:02<00:00, 186.30it/s]
100%|███████████████████████████████████████████████████████████████████████| 475/475 [00:02<00:00, 185.17it/s]
100%|███████████████████████████████████████████████████████████████████████| 475/475 [00:02<00:00, 183.85it/s]
100%|███████████████████████████████████████████████████████████████████████| 475/475 [00:02<00:00, 182.72it/s]
100%|███████████████████████████████████████████████████████████████████████| 475/475 [00:02<00:00, 181.21it/s]


the recommended number of neighbors to use is  90
interaction preservance rate at 20000 items
1.0
num items  1436
flattening posts into training data


100%|███████████████████████████████████████████████████████████████████| 731/731 [00:00<00:00, 1052897.05it/s]


building datastore embeddings


12it [00:01,  7.64it/s]                                                                                        


-------------------- Retrieval -------------------------------


100%|███████████████████████████████████████████████████████████████████████| 211/211 [00:01<00:00, 183.48it/s]


r@1 0.02843601895734597 ; se:  0.01146992169674842


100%|████████████████████████████████████████████████████████████████████| 211/211 [00:00<00:00, 179294.60it/s]


r@5 0.08530805687203792 ; se:  0.019276261283512997


100%|████████████████████████████████████████████████████████████████████| 211/211 [00:00<00:00, 195579.70it/s]


r@10 0.13744075829383887 ; se:  0.023759789586756046


100%|████████████████████████████████████████████████████████████████████| 211/211 [00:00<00:00, 187340.84it/s]


r@20 0.20853080568720378 ; se:  0.02803447781764188


100%|████████████████████████████████████████████████████████████████████| 211/211 [00:00<00:00, 165978.65it/s]


r@50 0.26540284360189575 ; se:  0.03046967065084662


100%|████████████████████████████████████████████████████████████████████| 211/211 [00:00<00:00, 155399.15it/s]


r@100 0.33649289099526064 ; se:  0.03260626767859446


100%|████████████████████████████████████████████████████████████████████| 211/211 [00:00<00:00, 108097.98it/s]


r@300 0.4881516587677725 ; se:  0.03449358924634331
-------------------- Recommend -------------------------------


100%|███████████████████████████████████████████████████████████████████████| 211/211 [00:00<00:00, 213.16it/s]


r@1 0.0 ; se:  0.0


100%|████████████████████████████████████████████████████████████████████| 211/211 [00:00<00:00, 183457.33it/s]


r@5 0.0 ; se:  0.0


100%|████████████████████████████████████████████████████████████████████| 211/211 [00:00<00:00, 197456.08it/s]


r@10 0.014218009478672985 ; se:  0.008169588695462889


100%|████████████████████████████████████████████████████████████████████| 211/211 [00:00<00:00, 189223.46it/s]


r@20 0.023696682464454975 ; se:  0.01049606495529992


100%|████████████████████████████████████████████████████████████████████| 211/211 [00:00<00:00, 166478.21it/s]


r@50 0.07582938388625593 ; se:  0.018267755671268843


100%|████████████████████████████████████████████████████████████████████| 211/211 [00:00<00:00, 152533.29it/s]


r@100 0.17061611374407584 ; se:  0.025958415047638204


100%|████████████████████████████████████████████████████████████████████| 211/211 [00:00<00:00, 104043.99it/s]


r@300 0.3981042654028436 ; se:  0.033779203187119355
-------------------- R+R (rerank) -------------------------------


100%|███████████████████████████████████████████████████████████████████████| 211/211 [00:01<00:00, 108.38it/s]


r@1 0.023696682464454975 ; se:  0.010496064955299918


100%|████████████████████████████████████████████████████████████████████| 211/211 [00:00<00:00, 174315.18it/s]


r@5 0.08530805687203792 ; se:  0.019276261283512997


100%|████████████████████████████████████████████████████████████████████| 211/211 [00:00<00:00, 197500.14it/s]


r@10 0.12322274881516587 ; se:  0.022681952442793474


100%|████████████████████████████████████████████████████████████████████| 211/211 [00:00<00:00, 189344.92it/s]


r@20 0.1943127962085308 ; se:  0.02730387495943987


100%|████████████████████████████████████████████████████████████████████| 211/211 [00:00<00:00, 170355.75it/s]


r@50 0.25118483412322273 ; se:  0.029927771242945177


100%|████████████████████████████████████████████████████████████████████| 211/211 [00:00<00:00, 153405.81it/s]


r@100 0.3412322274881517 ; se:  0.032717608075019784


100%|████████████████████████████████████████████████████████████████████| 211/211 [00:00<00:00, 108309.65it/s]


r@300 0.46445497630331756 ; se:  0.03441598128848769
-------------------- R+R (rerank) with small gamma (cleaner solution for paper) -------------------------------


100%|███████████████████████████████████████████████████████████████████████| 211/211 [00:01<00:00, 106.46it/s]


r@1 0.023696682464454975 ; se:  0.010496064955299918


100%|████████████████████████████████████████████████████████████████████| 211/211 [00:00<00:00, 176964.24it/s]


r@5 0.08530805687203792 ; se:  0.019276261283512997


100%|████████████████████████████████████████████████████████████████████| 211/211 [00:00<00:00, 192600.25it/s]


r@10 0.12322274881516587 ; se:  0.022681952442793474


100%|████████████████████████████████████████████████████████████████████| 211/211 [00:00<00:00, 181129.38it/s]


r@20 0.1943127962085308 ; se:  0.02730387495943987


100%|████████████████████████████████████████████████████████████████████| 211/211 [00:00<00:00, 164070.85it/s]


r@50 0.25118483412322273 ; se:  0.029927771242945177


100%|████████████████████████████████████████████████████████████████████| 211/211 [00:00<00:00, 152717.54it/s]


r@100 0.3412322274881517 ; se:  0.032717608075019784


100%|████████████████████████████████████████████████████████████████████| 211/211 [00:00<00:00, 108283.14it/s]

r@300 0.46445497630331756 ; se:  0.03441598128848769





## Reddit

In [17]:
from data_utils import get_reddit_data_with_heldout
import numpy as np
import pandas as pd
from scipy.stats import sem

# reddit_knnlm_recommender = KNNLMForTuningHyperParams(
#     data_path= 'datasets/reddit/reddit_large_train.csv',
#     model_name_or_path = 'models/reddit'
# )

# k=20
# n_neighbors = [15, 30, 60, 90, 120, 150, 180]
# recall = []
# for num_posts_to_consider in n_neighbors:
#     hits_at_k = []
#     cache = dict()
#     for idx in tqdm(range(len(reddit_knnlm_recommender.validation_dataset['context'])), total=len(reddit_knnlm_recommender.validation_dataset['context'])):
#         query = reddit_knnlm_recommender.validation_dataset['context'][idx]
#         target = reddit_knnlm_recommender.movie_vocab[reddit_knnlm_recommender.validation_dataset['label'][idx]]
#         if query not in cache:
#             movie_counts = reddit_knnlm_recommender.count_based_probability(
#                 query, 
#                 num_posts_to_consider=num_posts_to_consider, 
#                 return_logits=True
#             )
#             recommended_movies = np.array(reddit_knnlm_recommender.movie_vocab)[np.argsort(-movie_counts)]
#         else:
#             recommended_movies = cache[query]
        
#         hits_at_k.append(int(target in recommended_movies[:k]))
#     recall.append(np.mean(hits_at_k))


# best_n_neighbors = n_neighbors[np.argmax(recall)]
# print('the recommended number of neighbors to use is ', best_n_neighbors)

# Fixed neighbour count for retrieval; uncomment the sweep above to re-tune
# it on a held-out split instead.
best_n_neighbors = 30 # uncomment above to tune n_neighbors

# Test split: 'test_inputs' is read as the query column and 'test_outputs' as
# the ground-truth column by the evaluation loop below.
testset = pd.read_csv('datasets/reddit/reddit_test.csv')
test_inputs = testset['test_inputs']
test_groundtruths = testset['test_outputs']

# NOTE: KNNLMRecommender.predictor_probability looks this global up by name,
# so the variable must keep the name `reddit_knnlm_recommender`.
reddit_knnlm_recommender = KNNLMRecommender(
    data_path= 'datasets/reddit/reddit_large_train.csv',
    model_name_or_path = 'models/reddit'
)

# Cutoffs k at which recall@k is reported.
K = [1,5, 10, 20,50,100,300]

# print('-------------------- Retrieval -------------------------------')

# cache = dict()
# for k in K:
#     hits_at_k = []
#     for idx in tqdm(range(len(test_inputs)), total = len(test_inputs) ):
#         query = test_inputs[idx]
#         target = test_groundtruths[idx]
        
#         if query not in cache:
#             movie_counts = reddit_knnlm_recommender.count_based_probability(
#                 query, 
#                 num_posts_to_consider=best_n_neighbors, 
#                 return_logits=True
#             )
#             recommended_movies = np.array(reddit_knnlm_recommender.movie_vocab)[np.argsort(-movie_counts)]
#             cache[query] = recommended_movies
#         else:
#             recommended_movies = cache[query]
#         hits_at_k.append(int(target in recommended_movies[:k]))
#     print('r@'+str(k), np.mean(hits_at_k),'; se: ', sem(hits_at_k))

# print('-------------------- Recommend -------------------------------')

# cache = dict()
# for k in K:
#     hits_at_k = []
#     for idx in tqdm(range(len(test_inputs)), total = len(test_inputs) ):
#         query = test_inputs[idx]
#         target = test_groundtruths[idx]
        
#         if query not in cache:
#             movie_probas = reddit_knnlm_recommender.predictor_probability(
#                 query, 
#                 return_logits=False
#             )
#             recommended_movies = np.array(reddit_knnlm_recommender.movie_vocab)[np.argsort(-movie_probas)]
#             cache[query] = recommended_movies
#         else:
#             recommended_movies = cache[query]
#         hits_at_k.append(int(target in recommended_movies[:k]))
#     print('r@'+str(k), np.mean(hits_at_k),'; se: ', sem(hits_at_k))


# print('-------------------- R+R (rerank) -------------------------------')

# cache = dict()
# for k in K:
#     hits_at_k = []
#     for idx in tqdm(range(len(test_inputs)), total = len(test_inputs) ):
#         query = test_inputs[idx]
#         target = test_groundtruths[idx]
        
#         if query not in cache:
#             movie_counts = reddit_knnlm_recommender.count_based_probability(
#                 query, 
#                 num_posts_to_consider=best_n_neighbors, 
#                 return_logits=True
#             )
#             movie_probas = reddit_knnlm_recommender.predictor_probability(
#                 query, 
#                 return_logits=False
#             )
#             movie_scores = movie_counts + movie_probas
#             recommended_movies = np.array(reddit_knnlm_recommender.movie_vocab)[np.argsort(-movie_scores)]
#             cache[query] = recommended_movies
#         else:
#             recommended_movies = cache[query]
#         hits_at_k.append(int(target in recommended_movies[:k]))
#     print('r@'+str(k), np.mean(hits_at_k),'; se: ', sem(hits_at_k))

print('-------------------- R+R (rerank) with small gamma (cleaner solution for paper) -------------------------------')


def _small_gamma_ranking(query):
    """Rank the whole movie vocabulary for ``query``, best first.

    Softmaxed retrieval votes carry weight (1 - 1e-10) and the predictor's
    probabilities carry weight 1e-10.
    """
    retrieval_probas = reddit_knnlm_recommender.count_based_probability(
        query,
        num_posts_to_consider=best_n_neighbors,
        return_logits=False,
    )
    predictor_probas = reddit_knnlm_recommender.predictor_probability(
        query,
        return_logits=False,
    )
    blended_scores = retrieval_probas*(1-(1e-10)) + predictor_probas*1e-10
    descending_order = np.argsort(-blended_scores)
    return np.array(reddit_knnlm_recommender.movie_vocab)[descending_order]


# A ranking depends only on the query, so it is computed on first sight and
# then shared by every cutoff in K.
cache = dict()
for k in K:
    hits_at_k = []
    for idx in tqdm(range(len(test_inputs)), total = len(test_inputs) ):
        query, target = test_inputs[idx], test_groundtruths[idx]

        if query in cache:
            recommended_movies = cache[query]
        else:
            recommended_movies = _small_gamma_ranking(query)
            cache[query] = recommended_movies

        is_hit = target in recommended_movies[:k]
        hits_at_k.append(int(is_hit))
    print('r@'+str(k), np.mean(hits_at_k),'; se: ', sem(hits_at_k))

interaction preservance rate at 20000 items
0.9965472793625889
num items  20000
flattening posts into training data


100%|█████████████████████████████████████████████████████████████████████████████████████████████| 39928/39928 [00:00<00:00, 390030.60it/s]


building datastore embeddings


621it [01:02,  9.96it/s]                                                                                                                    


-------------------- R+R (rerank) with small gamma (cleaner solution for paper) -------------------------------


100%|████████████████████████████████████████████████████████████████████████████████████████████████| 19438/19438 [02:09<00:00, 150.59it/s]


r@1 0.012604177384504579 ; se:  0.0008001810805301169


100%|█████████████████████████████████████████████████████████████████████████████████████████████| 19438/19438 [00:00<00:00, 122067.68it/s]


r@5 0.05957403025002572 ; se:  0.0016977595429411435


100%|█████████████████████████████████████████████████████████████████████████████████████████████| 19438/19438 [00:00<00:00, 122939.80it/s]


r@10 0.1020681140034983 ; se:  0.0021714614543734746


100%|█████████████████████████████████████████████████████████████████████████████████████████████| 19438/19438 [00:00<00:00, 122901.07it/s]


r@20 0.1558802345920362 ; se:  0.002601854118789058


100%|█████████████████████████████████████████████████████████████████████████████████████████████| 19438/19438 [00:00<00:00, 120070.40it/s]


r@50 0.2576396748636691 ; se:  0.003136888849981628


100%|█████████████████████████████████████████████████████████████████████████████████████████████| 19438/19438 [00:00<00:00, 118067.11it/s]


r@100 0.34556024282333575 ; se:  0.003411003031134365


100%|█████████████████████████████████████████████████████████████████████████████████████████████| 19438/19438 [00:00<00:00, 108624.50it/s]

r@300 0.47715814384195904 ; se:  0.0035826280908021206





## Redial

In [15]:
from data_utils import get_reddit_data_with_heldout
import numpy as np
import pandas as pd
from scipy.stats import sem

# reddit_knnlm_recommender = KNNLMForTuningHyperParams(
#     data_path= 'datasets/redial/redial_train.csv',
#     model_name_or_path = 'models/redial'
# )

# k=20
# n_neighbors = [15, 30, 60, 90, 120, 150, 180]
# recall = []
# for num_posts_to_consider in n_neighbors:
#     hits_at_k = []
#     cache = dict()
#     for idx in tqdm(range(len(reddit_knnlm_recommender.validation_dataset['context'])), total=len(reddit_knnlm_recommender.validation_dataset['context'])):
#         query = reddit_knnlm_recommender.validation_dataset['context'][idx]
#         target = reddit_knnlm_recommender.movie_vocab[reddit_knnlm_recommender.validation_dataset['label'][idx]]
#         if query not in cache:
#             movie_counts = reddit_knnlm_recommender.count_based_probability(
#                 query, 
#                 num_posts_to_consider=num_posts_to_consider, 
#                 return_logits=True
#             )
#             recommended_movies = np.array(reddit_knnlm_recommender.movie_vocab)[np.argsort(-movie_counts)]
#             recommended_movies = [e for e in recommended_movies if e in redial_eligible_entities]
#         else:
#             recommended_movies = cache[query]
        
#         hits_at_k.append(int(target in recommended_movies[:k]))
#     recall.append(np.mean(hits_at_k))


# best_n_neighbors = n_neighbors[np.argmax(recall)]
# print('the recommended number of neighbors to use is ', best_n_neighbors)

# Neighbourhood size passed as `num_posts_to_consider` to the retrieval scorer
# further down. It is hard-coded here; presumably it is the value selected by the
# commented-out validation sweep above (that sweep's output is not kept).
best_n_neighbors = 60

# Test split: `test_inputs` supplies the queries and `test_outputs` the
# ground-truth item each query is scored against.
testset = pd.read_csv('datasets/redial/redial_test.csv')
test_inputs, test_groundtruths = testset['test_inputs'], testset['test_outputs']

# NOTE: the variable keeps its historical "reddit_" prefix even though it is
# built from the ReDial training file and the ReDial model directory.
reddit_knnlm_recommender = KNNLMRecommender(
    model_name_or_path='models/redial',
    data_path='datasets/redial/redial_train.csv',
)

# Cut-offs at which recall@k is reported.
K = [1,5, 10, 20,50,100,300]

# Interpolation weight used by the "small gamma" variant: the final score is
# (1 - GAMMA) * retrieval score + GAMMA * predictor score.
# Presumably this small so the predictor mostly breaks ties between items with
# equal retrieval score -- TODO confirm against the paper.
GAMMA = 1e-10


def _rank_eligible_movies(recommender, movie_scores):
    """Return the eligible vocabulary entries ordered by descending score.

    Parameters
    ----------
    recommender : object exposing `movie_vocab`
        `movie_vocab` is indexed with the argsort of `movie_scores`, so the
        two are expected to be aligned position by position.
    movie_scores : numpy array
        One score per vocabulary entry; higher means ranked earlier.

    Returns
    -------
    list
        Vocabulary entries, best first, keeping only those contained in
        `redial_eligible_entities` (defined outside this cell).
    """
    ranked_vocab = np.array(recommender.movie_vocab)[np.argsort(-movie_scores)]
    return [e for e in ranked_vocab if e in redial_eligible_entities]


def _report_recall_at_k(title, score_query, recommender, queries, targets, cutoffs):
    """Print `title`, then recall@k and its standard error for every cut-off.

    Parameters
    ----------
    title : str
        Section banner printed before the metrics.
    score_query : callable
        Maps a query to a score vector over `recommender.movie_vocab`.
    recommender : object exposing `movie_vocab`
    queries, targets : indexable by 0..len(queries)-1
        Query and ground-truth item for each test example.
    cutoffs : iterable of int
        The k values to report.

    A query counts as a hit at k when its target is among the first k entries
    of the eligible ranking. Rankings are cached per distinct query, so the
    scorer runs only during the first cut-off; later cut-offs (and repeated
    queries) reuse the cached ranking.
    """
    print(title)

    ranking_cache = dict()
    for k in cutoffs:
        hits_at_k = []
        for idx in tqdm(range(len(queries)), total = len(queries) ):
            query = queries[idx]
            target = targets[idx]

            if query not in ranking_cache:
                ranking_cache[query] = _rank_eligible_movies(recommender, score_query(query))

            hits_at_k.append(int(target in ranking_cache[query][:k]))
        # `sem` is scipy.stats.sem: the standard error of the mean of the hit indicators.
        print('r@'+str(k), np.mean(hits_at_k),'; se: ', sem(hits_at_k))


def _retrieval_scores(query):
    """Score items with the count-based (nearest-neighbour) scorer only."""
    return reddit_knnlm_recommender.count_based_probability(
        query,
        num_posts_to_consider=best_n_neighbors,
        return_logits=True
    )


def _recommend_scores(query):
    """Score items with the trained predictor only."""
    return reddit_knnlm_recommender.predictor_probability(
        query,
        return_logits=False
    )


def _rerank_scores(query):
    """Sum of count-based scores (return_logits=True) and predictor probabilities."""
    movie_counts = reddit_knnlm_recommender.count_based_probability(
        query,
        num_posts_to_consider=best_n_neighbors,
        return_logits=True
    )
    movie_probas = reddit_knnlm_recommender.predictor_probability(
        query,
        return_logits=False
    )
    return movie_counts + movie_probas


def _rerank_small_gamma_scores(query):
    """Convex combination of count-based (return_logits=False) and predictor scores, weighted by GAMMA."""
    movie_counts = reddit_knnlm_recommender.count_based_probability(
        query,
        num_posts_to_consider=best_n_neighbors,
        return_logits=False
    )
    movie_probas = reddit_knnlm_recommender.predictor_probability(
        query,
        return_logits=False
    )
    return movie_counts*(1-GAMMA) + movie_probas*GAMMA


# The four evaluations differ only in how the per-item score vector is built;
# the evaluation protocol itself is shared.
_report_recall_at_k(
    '-------------------- Retrieval -------------------------------',
    _retrieval_scores,
    reddit_knnlm_recommender, test_inputs, test_groundtruths, K,
)

_report_recall_at_k(
    '-------------------- Recommend -------------------------------',
    _recommend_scores,
    reddit_knnlm_recommender, test_inputs, test_groundtruths, K,
)

_report_recall_at_k(
    '-------------------- R+R (rerank) -------------------------------',
    _rerank_scores,
    reddit_knnlm_recommender, test_inputs, test_groundtruths, K,
)

_report_recall_at_k(
    '-------------------- R+R (rerank) with small gamma (cleaner solution for paper) -------------------------------',
    _rerank_small_gamma_scores,
    reddit_knnlm_recommender, test_inputs, test_groundtruths, K,
)

interaction preservance rate at 20000 items
1.0
num items  5140
flattening posts into training data


100%|██████████████████████████████████████████████████████████████████████████████████████████████| 8929/8929 [00:00<00:00, 1403918.89it/s]


building datastore embeddings


136it [00:10, 13.37it/s]                                                                                                                    


-------------------- Retrieval -------------------------------


100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 4288/4288 [00:58<00:00, 73.43it/s]


r@1 0.014225746268656716 ; se:  0.0018086291367414823


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 4288/4288 [00:00<00:00, 186717.35it/s]


r@5 0.07882462686567164 ; se:  0.004115526282453156


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 4288/4288 [00:00<00:00, 203316.51it/s]


r@10 0.13036380597014927 ; se:  0.005142455913816717


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 4288/4288 [00:00<00:00, 197717.51it/s]


r@20 0.19496268656716417 ; se:  0.006050719409148609


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 4288/4288 [00:00<00:00, 169243.57it/s]


r@50 0.2716884328358209 ; se:  0.00679387319005625


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 4288/4288 [00:00<00:00, 143212.32it/s]


r@100 0.32532649253731344 ; se:  0.007155332211070106


100%|████████████████████████████████████████████████████████████████████████████████████████████████| 4288/4288 [00:00<00:00, 98362.42it/s]


r@300 0.3558768656716418 ; se:  0.007312360356578281
-------------------- Recommend -------------------------------


100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 4288/4288 [00:26<00:00, 160.34it/s]


r@1 0.01632462686567164 ; se:  0.001935400234206339


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 4288/4288 [00:00<00:00, 186423.17it/s]


r@5 0.06529850746268656 ; se:  0.003773213274540625


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 4288/4288 [00:00<00:00, 213031.40it/s]


r@10 0.11287313432835822 ; se:  0.004832940144924738


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 4288/4288 [00:00<00:00, 199076.58it/s]


r@20 0.17444029850746268 ; se:  0.005795900370660485


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 4288/4288 [00:00<00:00, 168794.06it/s]


r@50 0.27705223880597013 ; se:  0.006835299502280828


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 4288/4288 [00:00<00:00, 141657.94it/s]


r@100 0.3631063432835821 ; se:  0.007344692838232835


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 4288/4288 [00:00<00:00, 101766.97it/s]


r@300 0.5170242537313433 ; se:  0.0076320529990241736
-------------------- R+R (rerank) -------------------------------


100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 4288/4288 [01:11<00:00, 60.34it/s]


r@1 0.013292910447761194 ; se:  0.0017491514807566862


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 4288/4288 [00:00<00:00, 188326.45it/s]


r@5 0.078125 ; se:  0.0040987770162482185


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 4288/4288 [00:00<00:00, 211073.79it/s]


r@10 0.12896455223880596 ; se:  0.005118896520469577


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 4288/4288 [00:00<00:00, 198140.09it/s]


r@20 0.19776119402985073 ; se:  0.006083389571891694


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 4288/4288 [00:00<00:00, 172100.36it/s]


r@50 0.2896455223880597 ; se:  0.006927782253933873


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 4288/4288 [00:00<00:00, 142139.34it/s]


r@100 0.37569962686567165 ; se:  0.007396741432586565


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 4288/4288 [00:00<00:00, 101724.96it/s]


r@300 0.5298507462686567 ; se:  0.007622859390812175
-------------------- R+R (rerank) with small gamma (cleaner solution for paper) -------------------------------


100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 4288/4288 [01:11<00:00, 59.62it/s]


r@1 0.013292910447761194 ; se:  0.0017491514807566862


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 4288/4288 [00:00<00:00, 191014.65it/s]


r@5 0.07695895522388059 ; se:  0.004070646097800788


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 4288/4288 [00:00<00:00, 209246.74it/s]


r@10 0.12569962686567165 ; se:  0.005063147557043472


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 4288/4288 [00:00<00:00, 196821.72it/s]


r@20 0.1970615671641791 ; se:  0.006075266692605248


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 4288/4288 [00:00<00:00, 166043.57it/s]


r@50 0.2908115671641791 ; se:  0.006936013318055069


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 4288/4288 [00:00<00:00, 139619.11it/s]


r@100 0.3763992537313433 ; se:  0.0073994756987399045


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 4288/4288 [00:00<00:00, 102445.19it/s]

r@300 0.5293843283582089 ; se:  0.0076232820990675245



