# In [12]:
import json
import numpy as np
import re
import spacy

from collections import Counter
from sklearn.metrics.pairwise import cosine_similarity

# Ask spaCy to use a GPU for its pipelines (per spaCy docs, only if one is
# available). This runs at import time, before any TextProcessor loads a model.
spacy.prefer_gpu()


class DataLoader:

    """Read and write JSON files.

    Files are opened explicitly as UTF-8 so reading and writing does not depend
    on the platform's default locale encoding (the data contains non-ASCII
    text such as en dashes in article titles).
    """

    def __init__(self, path):
        # Default file path for load_data (and for save_data when no path is given).
        self.path = path


    def load_data(self):
        """Parse and return the JSON document stored at ``self.path``."""
        with open(self.path, 'r', encoding='utf-8') as file:
            return json.load(file)


    def save_data(self, data, save_path=None):
        """Serialize *data* as JSON to *save_path*.

        Args:
            data: any JSON-serializable object.
            save_path: destination file; defaults to ``self.path`` when None.
        """
        if save_path is None:
            save_path = self.path
        with open(save_path, 'w', encoding='utf-8') as file:
            json.dump(data, file)


class TextProcessor:

    """Turn raw text into a list of lemmas using a spaCy pipeline.

    Steps: lowercase, drop punctuation and stop words, lemmatize.
    """

    def __init__(self, model='en_core_web_sm'):
        # Name of the spaCy pipeline to load; defaults to the small English model.
        self.nlp = spacy.load(model)


    def _lowercase(self, text):
        """Lowercase *text* and run it through the pipeline, returning the parsed result."""
        return self.nlp(text.lower())


    def _rm_stop_punct(self, text):
        """Keep only tokens that are neither punctuation nor stop words."""
        return [t for t in text if not t.is_punct and not t.is_stop]


    def _lemmatizer(self, text):
        """Return the lemma of each token; tokens with an empty dependency label are skipped."""
        return [t.lemma_ for t in text if t.dep_]


    def preprocess_text(self, text):
        """Lowercase, remove stop words/punctuation and lemmatize *text*.

        The pipeline runs exactly once, on the lowercased text. Parsing the raw
        text first only to read back its text doubled the cost without changing
        the result.

        Returns:
            list of lemma strings.
        """
        lowered_doc = self._lowercase(text)
        content_tokens = self._rm_stop_punct(lowered_doc)
        return self._lemmatizer(content_tokens)


class TFIDFVectFromScratch:

    """
    Compute TF-IDF vectors from scratch for a corpus of tokenized documents.

    This class is based on the solution provided for the liveproject milestone.

    ``text_data`` is a list of dicts; the methods take the key (``term_text`` /
    ``text_var``) under which each dict stores its list of tokens.
    TF  = count of the term in the document / number of tokens in the document
    IDF = ln(number of documents / number of documents containing the term)
    """

    def __init__(self, text_data, processor=None):
        self.data = text_data
        # Optional object exposing ``preprocess_text(text) -> list[str]``, used
        # for queries. When None, a TextProcessor is created on the first query
        # and reused, so the spaCy model is not reloaded on every call.
        self._processor = processor


    def _build_flatten_vocab(self, text_var):
        """Return the unique tokens of all documents, sorted.

        Sorting fixes the position of every term in the TF-IDF vectors. The
        iteration order of a ``set`` of strings changes between interpreter runs
        (hash randomization), which would misalign vectors computed in
        different runs.
        """
        return sorted({token for doc in self.data for token in doc[text_var]})


    def _token_counter_within(self, text_var):
        """Return a Counter of token frequencies for each document."""
        return [Counter(doc[text_var]) for doc in self.data]


    def _token_counter_across(self, token_counts, vocab):
        """Return, for each vocab token, the number of documents that contain it."""
        doc_freq = Counter()
        for counts in token_counts:
            # The keys of each Counter are the distinct tokens of one document.
            doc_freq.update(counts.keys())
        return {token: doc_freq[token] for token in vocab}


    def _idf(self, vocabulary, tokens_across):
        """Return the IDF weight of each vocabulary token, in vocabulary order."""
        n_docs = len(self.data)
        return [np.log(n_docs / tokens_across[token]) for token in vocabulary]


    def _generate_tfidf(self, tokens_within, vocabulary, term_text, tokens_across):
        """Store each document's TF-IDF vector under ``'tf_idf'`` and return the data.

        Mutates ``self.data`` in place. A document with no tokens gets an
        all-zero vector instead of raising ZeroDivisionError.
        """
        ### Function based on solution provided via LiveProject ###
        # IDF depends only on the term, so compute it once, not once per document.
        idf = self._idf(vocabulary, tokens_across)
        for doc, counts in zip(self.data, tokens_within):
            doc_len = len(doc[term_text])
            if doc_len == 0:
                doc['tf_idf'] = [0.0] * len(vocabulary)
                continue
            doc['tf_idf'] = [counts[token] / doc_len * token_idf
                             for token, token_idf in zip(vocabulary, idf)]
        return self.data


    def _generate_tfidf_query(self, vocabulary, tokens_across, q_tokenized):
        """Return the TF-IDF vector of a raw query string over *vocabulary*.

        ``q_tokenized`` is the raw query text and is preprocessed here. TF
        divides by the number of preprocessed query tokens (not by the
        character length of the raw string), the same way document TF is
        computed.
        """
        ### Function based on solution provided via LiveProject ###
        if self._processor is None:
            self._processor = TextProcessor()
        q_tokens = self._processor.preprocess_text(q_tokenized)
        if not q_tokens:
            # Nothing survived preprocessing, so every term frequency is zero.
            return [0.0] * len(vocabulary)
        q_counted = Counter(q_tokens)
        q_len = len(q_tokens)
        idf = self._idf(vocabulary, tokens_across)
        return [q_counted[token] / q_len * token_idf
                for token, token_idf in zip(vocabulary, idf)]


    def tfidf_generator(self, term_text, q_tokens=None, is_query=False):
        """Compute TF-IDF for the corpus or, with ``is_query=True``, for a query.

        Args:
            term_text: key under which each document stores its token list.
            q_tokens: raw query text; required when ``is_query`` is True,
                ignored otherwise.
            is_query: select query mode.

        Returns:
            Either the corpus (``self.data``, each doc updated with ``'tf_idf'``)
            or the query's TF-IDF vector, a list aligned with the sorted
            vocabulary.

        Raises:
            ValueError: if ``is_query`` is True and ``q_tokens`` is None.
        """
        if is_query and q_tokens is None:
            raise ValueError('q_tokens is required when is_query=True')
        vocab = self._build_flatten_vocab(term_text)
        within = self._token_counter_within(term_text)
        across = self._token_counter_across(within, vocab)
        if not is_query:
            return self._generate_tfidf(within, vocab, term_text, across)
        return self._generate_tfidf_query(vocab, across, q_tokens)


class SimilaritySearch:

    """
    Rank documents by cosine similarity between their stored vectors and a
    query vector.

    Based on the solution provided for the liveproject milestone.
    """

    def __init__(self, text_data):
        # List of document dicts; each must provide 'title' and the vector key
        # passed to similarity_rankings.
        self.data = text_data


    def similarity_rankings(self, term_text, query_tfidf):
        """Return documents with positive cosine similarity to *query_tfidf*.

        Args:
            term_text: key under which each document stores its vector.
            query_tfidf: the query vector (any sequence accepted by np.array).

        Returns:
            list of {'title': ..., 'ranking': similarity} dicts, highest first.
        """
        query_row = np.array(query_tfidf).reshape(1, -1)

        matches = []
        for entry in self.data:
            entry_row = np.array(entry[term_text]).reshape(1, -1)
            score = cosine_similarity(query_row, entry_row)[0][0]
            # Written as "not > 0" so a NaN score is skipped, as before.
            if not score > 0:
                continue
            matches.append({'title': entry['title'], 'ranking': score})

        matches.sort(key=lambda item: item['ranking'], reverse=True)
        return matches


class InvertedIndexSearch:

    """Keyword search over precomputed TF-IDF vectors using an inverted index.

    Each document in ``text_data`` is read as a dict with ``'title'``,
    ``'tokenized_text'`` (assumed to be a list of token strings) and
    ``'tf_idf'``, where ``doc['tf_idf'][i]`` is the weight of ``vocab[i]``.
    """

    def __init__(self, text_data, vocab, processor=None):
        self.data = text_data
        self.vocabulary = vocab
        # Optional object exposing ``preprocess_text(text) -> list[str]``. When
        # None, a TextProcessor is created on the first search and reused, so
        # the spaCy model is loaded once per instance instead of once per query.
        self._processor = processor
        # Inverted index, built on the first search and reused afterwards.
        # NOTE: assumes data/vocab are not mutated after the first search; set
        # ``self._index = None`` to force a rebuild.
        self._index = None


    def _inverted_index(self):
        """Map each vocabulary word to ``[(title, tf_idf weight), ...]``.

        Postings are in document order. A word that occurs in no document maps
        to an empty list.
        """
        # Position of each word in the TF-IDF vectors. For a word listed twice in
        # the vocabulary, the last position wins.
        position = {word: idx for idx, word in enumerate(self.vocabulary)}
        inverted_idx = {word: [] for word in self.vocabulary}
        for doc in self.data:
            # Visit each distinct token once per document, so every word gets at
            # most one posting per document.
            for word in set(doc['tokenized_text']):
                idx = position.get(word)
                if idx is not None:
                    inverted_idx[word].append((doc['title'], doc['tf_idf'][idx]))
        return inverted_idx


    def search(self, query):
        """Return articles matching *query*, ranked by summed TF-IDF weight.

        Query tokens that are not in the vocabulary contribute nothing. A token
        repeated in the query counts once per occurrence.

        Returns:
            {'query': query, 'relevant_articles': [[title, score], ...]}, with
            scores in descending order.
        """
        if self._processor is None:
            self._processor = TextProcessor()
        query_tokens = self._processor.preprocess_text(query)
        if self._index is None:
            self._index = self._inverted_index()

        # Sum the weights per title, in the order postings are encountered.
        scores = {}
        for token in query_tokens:
            for title, weight in self._index.get(token, ()):
                scores[title] = scores.get(title, 0) + weight

        ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
        return {'query': query, 'relevant_articles': [[title, score] for title, score in ranked]}
        


if __name__ == '__main__':
    # Corpus with precomputed TF-IDF vectors, plus the vocabulary those vectors
    # are expected to be aligned with.
    corpus = DataLoader('../data_hub/processed_data_tfidf.json').load_data()
    vocabulary = DataLoader('../data_hub/vocab.json').load_data()

    searcher = InvertedIndexSearch(corpus, vocabulary)

    sample_query = 'black death'
    print(searcher.search(sample_query))

# Output:
# {'query': 'black death', 'relevant_articles': [['Pandemic', 0.04679883867323402], ['Cholera', 0.014518164813892178], ['Antonine Plague', 0.013132843743864298], ['Epidemiology of HIV/AIDS', 0.011824072374200845], ['Bills of mortality', 0.008868054280650635], ['Spanish flu', 0.008559216569384194], ['1929–1930 psittacosis pandemic', 0.008115106275689732], ['Pandemic Severity Assessment Framework', 0.007964826529843625], ['HIV/AIDS', 0.006800010001763728], ['COVID-19 pandemic', 0.0050011701466459975], ['Swine influenza', 0.004551329445624929]]}
