Code for Okapi BM25: ranking movies by the similarity of their overviews

In [36]:
# Get relevant imports
import math
import os

import numpy as np
import pandas as pd

In [37]:
# BM25 base class: builds per-document term counts and corpus statistics;
# subclasses supply the IDF formula and the scoring function.
class BM25:
    """Base class for BM25-family rankers over a pre-tokenized corpus.

    Subclasses must implement ``_calc_idf`` and ``get_scores`` (and optionally
    ``get_batch_scores``).

    Attributes:
        data_size: number of documents indexed.
        avgdl: average document length in tokens (0 for an empty corpus).
        doc_freqs: one ``{term: count}`` dict per document, in corpus order.
        idf: ``{term: idf}``, filled in by the subclass's ``_calc_idf``.
        doc_len: token count of each document, in corpus order.
    """

    def __init__(self, data):
        """Index ``data``, an iterable of documents, each a sequence of tokens."""
        self.data_size = 0
        self.avgdl = 0
        self.doc_freqs = []
        self.idf = {}
        self.doc_len = []

        # Despite its name, cal_tf returns document frequencies (term -> #docs).
        doc_freq = self.cal_tf(data)
        self._calc_idf(doc_freq)

    def cal_tf(self, data):
        """Record per-document term counts and lengths; return document frequencies.

        Side effects: appends to ``self.doc_len`` and ``self.doc_freqs`` and
        updates ``self.data_size`` and ``self.avgdl``.

        Args:
            data: iterable of documents, each a sequence of tokens.

        Returns:
            dict mapping each term to the number of documents that contain it.
        """
        doc_freq = {}  # term -> number of documents containing the term
        total_tokens = 0
        for document in data:
            self.doc_len.append(len(document))
            total_tokens += len(document)

            frequencies = {}
            for term in document:
                frequencies[term] = frequencies.get(term, 0) + 1
            self.doc_freqs.append(frequencies)

            # Each distinct term counts once per document.
            for term in frequencies:
                doc_freq[term] = doc_freq.get(term, 0) + 1

            self.data_size += 1

        # An empty corpus has no meaningful average length; avoid dividing by zero.
        self.avgdl = total_tokens / self.data_size if self.data_size else 0
        return doc_freq

    def _calc_idf(self, tf):
        raise NotImplementedError()

    def get_scores(self, query):
        raise NotImplementedError()

    def get_batch_scores(self, query, doc_ids):
        raise NotImplementedError()

    def get_top_n(self, query, documents, n=5):
        """Return the ``n`` entries of ``documents`` scoring highest for ``query``.

        ``documents`` must be aligned with the indexed corpus (same length and
        order). Results are ordered best first; a negative ``n`` yields ``[]``.
        """
        assert self.data_size == len(documents), "The documents given don't match the index data!"

        scores = self.get_scores(query)
        # Clamp so a negative n does not slice "all but the last k" documents.
        top_n = np.argsort(scores)[::-1][:max(n, 0)]
        return [documents[i] for i in top_n]

# BM25Okapi: Okapi BM25 scoring built on the BM25 base class.
class BM25Okapi(BM25):
    """Okapi BM25 ranker.

    Args:
        data: iterable of documents, each a sequence of tokens.
        k1: term-frequency saturation parameter.
        b: document-length normalisation strength (0 = none, 1 = full).
        epsilon: multiplier on the average idf; the product replaces the idf of
            any term whose idf is negative (see ``_calc_idf``).
    """

    def __init__(self, data, k1 = 1.5, b = 0.75, epsilon=0):
        self.k1 = k1
        self.b = b
        self.epsilon = epsilon
        super().__init__(data)

    def _calc_idf(self, tf):
        """Fill ``self.idf`` and ``self.average_idf`` from document frequencies.

        ``tf`` maps term -> number of documents containing it (as returned by
        ``cal_tf``). The idf used here is ``log(N / df)`` with ``N`` the number
        of documents.

        NOTE(review): the classic Okapi idf is ``log((N - df + 0.5) / (df + 0.5))``,
        which is negative for terms in more than half of the documents. That case
        is what the epsilon floor below handles. With ``log(N / df)`` and
        ``df <= N`` the idf is never negative, so the floor never triggers and
        ``epsilon`` currently has no effect. The formula is kept to preserve
        existing scores; confirm which one is intended.
        """
        idf_sum = 0
        negative_idfs = []
        for term, doc_count in tf.items():
            idf = math.log(self.data_size) - math.log(doc_count)
            self.idf[term] = idf
            idf_sum += idf
            if idf < 0:
                negative_idfs.append(term)
        # An empty vocabulary (e.g. empty corpus) would otherwise divide by zero.
        self.average_idf = idf_sum / len(self.idf) if self.idf else 0

        eps = self.epsilon * self.average_idf
        for term in negative_idfs:
            self.idf[term] = eps

    def get_scores(self, query):
        """Return a numpy array with the BM25 score of every document for ``query``.

        ``query`` is a sequence of tokens. Repeated tokens contribute once per
        occurrence; tokens not in the vocabulary contribute 0.
        """
        score = np.zeros(self.data_size)
        doc_len = np.array(self.doc_len)
        # The length-normalisation term does not depend on the query token, so
        # compute it once. avgdl == 0 means every document is empty; treat each
        # as average length (ratio 1) instead of producing 0/0 = NaN.
        if self.avgdl:
            length_norm = self.k1 * (1 - self.b + self.b * doc_len / self.avgdl)
        else:
            length_norm = np.full(self.data_size, self.k1, dtype=float)
        for q in query:
            q_freq = np.array([doc.get(q, 0) for doc in self.doc_freqs])
            score += self.idf.get(q, 0) * (q_freq * (self.k1 + 1) / (q_freq + length_norm))
        return score

In [38]:
# Load the preprocessed movie metadata (the filename indicates stopwords were removed).
# Forward slashes work on every OS; a backslash would also begin an escape sequence
# in a normal string literal.
corpus = pd.read_csv("movie_lens_dataset/movies_metadata_processed_no_stopwords.csv")

  corpus = pd.read_csv("movie_lens_dataset\movies_metadata_processed_no_stopwords.csv")


In [39]:
# Return a DataFrame of the n movies whose overviews score highest (BM25 RSV)
# against the query movie's overview.
def top_related_movies(query_title, corpus, n=10, bm25=None, output_dir="./output/bm25"):
    """Rank the movies whose overview best matches ``query_title``'s overview.

    The query movie's overview (tokenized on single spaces) is scored against
    every overview in ``corpus`` with Okapi BM25. The query movie itself is
    excluded from the result.

    Args:
        query_title: value of ``original_title`` identifying exactly one row.
        corpus: DataFrame with ``original_title``, ``overview``,
            ``release_date`` and ``poster_path`` columns.
        n: number of related movies to return.
        bm25: optional prebuilt index over ``corpus['overview']``, tokenized on
            single spaces and in the same row order. Pass one to avoid
            re-indexing the whole corpus on every call.
        output_dir: directory the result is also written to as
            ``<query_title>.csv`` (created if missing). ``None`` skips writing.

    Returns:
        DataFrame with ``original_title``, ``release_date``, ``poster_path`` and
        ``rsv_score`` columns, indexed by corpus row labels, best match first.

    Raises:
        ValueError: if ``query_title`` does not match exactly one row.
    """
    # Missing overviews would break str.split; treat them as empty documents.
    overviews = corpus["overview"].fillna("")
    is_query = (corpus["original_title"] == query_title).to_numpy()

    matches = overviews[is_query]
    if len(matches) != 1:
        raise ValueError(
            "Expected exactly one movie titled {!r}, found {}".format(query_title, len(matches))
        )
    tokenized_query = matches.iloc[0].split(" ")

    if bm25 is None:
        bm25 = BM25Okapi([doc.split(" ") for doc in overviews])
    scores = bm25.get_scores(tokenized_query)

    results = pd.DataFrame({
        "original_title": corpus["original_title"],
        "release_date": corpus["release_date"],
        "poster_path": corpus["poster_path"],
        "rsv_score": scores,
    })
    # Drop the query movie explicitly rather than assuming it ranks first.
    # Select by label order (not .iloc) so non-RangeIndex corpora work, and use
    # a stable sort so ties keep corpus order.
    output = (
        results[~is_query]
        .sort_values(by="rsv_score", ascending=False, kind="mergesort")
        .head(max(n, 0))
    )

    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)
        # Titles may contain path separators (e.g. "A/B"); keep the file inside output_dir.
        file_name = str(query_title).replace("/", "_").replace("\\", "_") + ".csv"
        output.to_csv(os.path.join(output_dir, file_name))

    return output

In [40]:
# Related movies for "Toy Story". As the cell's last expression, the DataFrame is
# rendered as a rich table instead of wrapped plain text.
top_related_movies("Toy Story", corpus)

               original_title release_date                       poster_path  \
2945              Toy Story 2   1999-10-30  /3CmK3XurcLeUyMifCR28ibzupbB.jpg   
14703             Toy Story 3   2010-06-16  /mMltbSxwEdNE4Cv8QYLpzkHWTDo.jpg   
23023               Small Fry   2011-11-23   /4deuIpnrXuzMAggQk19Jk81CST.jpg   
483                    Malice   1993-09-29  /dlTWhgLHdEgSmCa8nEFgN7Hfu05.jpg   
1056    Rebel Without a Cause   1955-10-27  /bEAoNvtqvO0c2lItNkKlKUqhPuw.jpg   
36254                   Ozzie   2006-09-20  /81WCWEhTobf3fBZ6cJjE5mDXjpE.jpg   
16397               Group Sex   2010-08-24  /4mNUWTVathZZtOIFrakwYiUzaVy.jpg   
11014  For Your Consideration   2006-11-22  /lLSeBuqgWYJ4MZt18prqtWRTlof.jpg   
6319   What's Up, Tiger Lily?   1966-11-02  /AcU8gPPmceM4uUTbAHLwvREDy0o.jpg   
1910                Condorman   1981-07-02  /abgxJPCU2AAxP4kSdvdWCU6qr6r.jpg   

       rsv_score  
2945   99.265152  
14703  82.377874  
23023  43.532924  
483    34.842891  
1056   31.868091  
36254