In [103]:
import os
import re
import unicodedata
import pandas as pd
import numpy as np
from rapidfuzz import process, fuzz
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    pipeline,
    T5Tokenizer,
    T5ForConditionalGeneration
)

from transformers import logging as transformers_logging
from sentence_transformers import SentenceTransformer
transformers_logging.set_verbosity_error()
import json
import time
from functools import wraps
import sentencepiece

# Silencing TqdmWarning
import warnings
warnings.filterwarnings('ignore')

In [104]:
import nltk
from nltk.corpus import stopwords

nltk.download('stopwords')
stop_words_to_keep = ["what", "when", "where", "which", "while", "who", "whom", "why", "with", "how", "before", "after","same"]
stop_words = set([s for s in stopwords.words('english') if s not in stop_words_to_keep])

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\sandr\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [105]:
def measure_time(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.time()  # Record the start time
        result = func(*args, **kwargs)  # Call the function
        end_time = time.time()  # Record the end time
        elapsed_time = end_time - start_time  # Calculate elapsed time
        print(f"Execution time for {func.__name__}: {elapsed_time:.4f} seconds")
        return result
    return wrapper

In [106]:
class NERParser:
    def __init__(self, model_name="dslim/bert-base-NER", lowercase=False):
        self.lowercase = lowercase
        self.device = self.get_device()

        self.nlp_pipeline = pipeline(
            "ner", 
            model=AutoModelForTokenClassification.from_pretrained(model_name),
            tokenizer=AutoTokenizer.from_pretrained(model_name, do_lower_case=lowercase),
            device=self.device, 
            aggregation_strategy="simple"
        )

    def get_device(self):
        if torch.backends.mps.is_available():
            return torch.device("mps")
        elif torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    def parse_ner_results(self, ner_results):
        per_entities = [e['word'] for e in ner_results if e['entity_group'] == 'PER']
        misc_entities = [e['word'] for e in ner_results if e['entity_group'] == 'MISC']
        return per_entities, misc_entities

    def process_query(self, query):
        if self.lowercase:
            query = query.lower()
        return self.parse_ner_results(self.nlp_pipeline(query))



In [107]:
class DataBase:
    """Handles context data extraction for people and movies from a database with fuzzy matching support."""
    
    def __init__(self):
        self.db = pd.read_pickle(os.path.join(os.getcwd(), "exports/extended_graph_triples.pkl"))
        
        self.db['subject_id'] = self.db['subject_id'].astype(str).str.strip()
        self.db['predicate_label'] = self.db['predicate_label'].astype(str).str.strip()
        self.db['object_label'] = self.db['object_label'].astype(str).str.strip()
        
        self.db_pivot = self.db.pivot_table(
                index='subject_id',
                columns='predicate_label',
                values='object_label',
                aggfunc=lambda x: ' | '.join(x.astype(str))
            )
        
        with open('exports/movie_db.json') as f:
            self.movie_data = json.load(f)
            self.movie_ids = set(self.movie_data.keys())
        self.movie_db = pd.DataFrame(list(self.movie_data.items()), columns=["entity_id", "entity_label"])
        
        with open('exports/people_db.json') as f:
            self.people_data = json.load(f)
            self.people_ids = set(self.people_data.keys())
        self.people_db = pd.DataFrame(list(self.people_data.items()), columns=["entity_id", "entity_label"])
        
        self.entities = {**self.movie_data, **self.people_data}
        
        self.movie_names = self.movie_db["entity_label"].tolist()
        self.people_names = self.people_db["entity_label"].tolist()
        
        self.entity_list = self.movie_names + self.people_names
       
        self.people_movie_mapping = {}
        self.movie_people_mapping = {}
        
        self.map_people_movies()
        
        self.movie_recommender_db = self.filter_relevant_movies()
        
    def map_people_movies(self):
        id_triples = pd.read_pickle(os.path.join("exports/df_new_triples_only_ids.pkl"))
        
        id_triples['subject_id'] = id_triples['subject_id'].astype(str).str.strip()
        id_triples['object_id'] = id_triples['object_id'].astype(str).str.strip()
        
        for _, row in id_triples.iterrows():
            subject_id = row['subject_id']
            object_id = row['object_id']
            
            if subject_id in self.people_ids and object_id in self.movie_ids:
                if subject_id not in self.people_movie_mapping:
                    self.people_movie_mapping[subject_id] = []
                self.people_movie_mapping[subject_id].append(object_id)
                
                if object_id not in self.movie_people_mapping:
                    self.movie_people_mapping[object_id] = []
                self.movie_people_mapping[object_id].append(subject_id)
        
        for person, movies in self.people_movie_mapping.items():
            self.people_movie_mapping[person] = list(set(movies))
        
        for movie, people in self.movie_people_mapping.items():
            self.movie_people_mapping[movie] = list(set(people))

        
    @staticmethod
    def normalize_string(s):
        """Normalizes strings by removing non-ASCII characters, punctuation, and redundant spaces."""
        return ' '.join(re.sub(r'[^\w\s]', '', unicodedata.normalize('NFKD', s.lower())
                               .encode('ascii', 'ignore').decode('utf-8')).split())

    def fetch(self, entity_list, search_column):
        """Fetches relevant rows from the database where `search_column` matches values in `entity_list`."""
        relevant = self.db[self.db[search_column].isin(entity_list)].dropna(axis=1)
        
        if relevant.empty:
            return pd.DataFrame()

        return relevant.pivot_table(
            index='subject_id',
            columns='predicate_label',
            values='object_label',
            aggfunc=lambda x: ' | '.join(x.astype(str))
        ).reset_index()
    
    def filter_relevant_movies(self):

        clean_db = self.db[self.db["subject_id"].isin(self.movie_data.keys())]
        relevant_cols = [
            # "author", # only 99 movies hav an author
            # "cast member",
            "director", 
            "performer",
            "genre",
            # "narrative motif", # only 43 movies have a narrative motif
            "screenwriter",
            "subject_id",
            "node label" # Required for processing
        ]
        
        pv_db = clean_db.pivot_table(
                index='subject_id',
                columns='predicate_label',
                values='object_label',
                aggfunc=lambda x: ' | '.join(x.astype(str))
            )
        
        pv_db = pv_db[[col for col in relevant_cols if col in pv_db.columns]]

        return pv_db.reset_index()

In [108]:
class QueryEmbedderContextualized:
    def __init__(self, model_name='sentence-transformers/all-mpnet-base-v2'):
        """Initializes the QueryEmbedder with a SentenceTransformer model and device setup."""
        self.device = self.get_device()
        self.model = SentenceTransformer(model_name, device=self.device)
        self.cache = {}
    
    @staticmethod
    def get_device():
        """Determines the available hardware device (MPS, CUDA, or CPU)."""
        if torch.backends.mps.is_available():
            return torch.device("mps")
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    def embed_phrase(self, phrases):
        """
        Generates embeddings for given phrases using SentenceTransformer, with caching.

        Args:
            phrases (str or List[str]): Input phrase(s) to embed.

        Returns:
            np.ndarray: Embedding vector(s) for the input phrase(s).
        """
        if isinstance(phrases, str):
            phrases = [phrases]
        elif not isinstance(phrases, list):
            raise TypeError("Input must be a string or a list of strings.")
        
        phrases_to_compute = [p for p in phrases if p not in self.cache]
        cached_embeddings = [self.cache[p] for p in phrases if p in self.cache]

        if phrases_to_compute:
            new_embeddings = self.model.encode(
                phrases_to_compute, 
                show_progress_bar=False, 
                convert_to_numpy=True, 
                normalize_embeddings=True
            )
            
            for phrase, emb in zip(phrases_to_compute, new_embeddings):
                self.cache[phrase] = emb
            cached_embeddings.extend(new_embeddings)
        
        return cached_embeddings[0] if len(cached_embeddings) == 1 else np.array(cached_embeddings)


In [109]:
class QuestionAnsweringAgent:
    
    def __init__(self):
        self.qa_model = pipeline("question-answering", model="distilbert-base-uncased-distilled-squad", top_k=1)
    
    def query(self, query, context_df):
        
        top_columns = context_df.columns
        
        context = ""
        for index, row in context_df.iterrows():
            node_label = row.get("node label", "")
            
            row_context = f"This text is about \"{node_label}\":\n"
            
            for col in context_df[top_columns].columns:
                if col == "node label":
                    continue
                
                values = row[col]
                values_lst = str(values).split(",")
                
                if len(values_lst) > 5:
                    row_context += f"{col}: {', '.join(values_lst[:5])}"
                else:
                    row_context += f"{col}: {', '.join(values_lst)}"

            context += row_context + "\n\n"
        
        output = self.qa_model(question=query, context=context)
        
        answer_str = str()
        if isinstance(output, list) and output:
            answer_str = ", ".join([result['answer'] for result in output])
            
        elif isinstance(output, dict):
            answer_str = output['answer']
        
        if not answer_str:
            answer_str = "No answer found."
        
        return answer_str

In [110]:
class ConversationAgent:
    def __init__(self, model_name="google/flan-t5-large", max_length=150):
        self.device = self.get_device()
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.model = T5ForConditionalGeneration.from_pretrained(model_name).to(self.device)
        self.max_length = max_length

    @staticmethod
    def get_device():
        if torch.cuda.is_available():
            return torch.device("cuda")
        elif torch.backends.mps.is_available():
            return torch.device("mps")
        else:
            return torch.device("cpu")

    def generate_response(self, prompt):
        """
        Generates a response based on the given prompt.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        outputs = self.model.generate(
            **inputs,
            max_length=self.max_length,
            num_beams=5,
            early_stopping=True,
        )
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response.strip()


In [111]:
import csv
import rdflib
import numpy as np
import json
import pandas as pd
from sklearn.metrics import pairwise_distances

class GraphEmbeddings:
    
    def __init__(self, graph):
        
        self.RDFS = rdflib.namespace.RDFS
        self.WD = rdflib.Namespace('http://www.wikidata.org/entity/')
     
        self.entity_emb = np.load('exports/entity_embeds.npy')
        self.relation_emb = np.load('exports/relation_embeds.npy')
        
        with open('exports/entity_ids.del', 'r') as ifile:
            self.ent2id = {rdflib.term.URIRef(ent): int(idx) for idx, ent in csv.reader(ifile, delimiter='\t')}
            self.id2ent = {v: k for k, v in self.ent2id.items()}
            
        with open('exports/relation_ids.del', 'r') as ifile:
            self.rel2id = {rdflib.term.URIRef(rel): int(idx) for idx, rel in csv.reader(ifile, delimiter='\t')}
        
        with open("exports/predicate_db.json", encoding="utf-8") as f:
            self.predicates_db = json.load(f)
        
        self.ent2lbl = {rdflib.term.URIRef(row.subject_id): row.subject_label for index, row in graph.iterrows()}

    def answer_query_embedding(self, context, top_columns):
        STD_ERROR = "The embeddings did not provide a valid answer."
        try:
            context_id = context.subject_id.values[0]
            if not context_id:
                print("No context ID found.")
                return STD_ERROR

            e_id = self.ent2id.get(rdflib.term.URIRef(context_id))
            if e_id is None:
                return "2"

            head = self.entity_emb[e_id]
            wiki_predicate_id = ""

            column_mapping = {
                "movie cast": "cast member",
                "acted in": "notable work",
                "played in": "notable work",
                "appeared in": "notable work",
                "actors": "cast member",
                "players": "cast member"
            }

            for col in top_columns:
                
                col = column_mapping.get(col, col)
                
                if col == "node label":
                    continue
                    
                if not context[col].values[0]:
                    continue

                if col in self.predicates_db:
                    wiki_id = rdflib.term.URIRef(self.predicates_db[col])
                    if wiki_id in self.rel2id:
                        wiki_predicate_id = self.predicates_db[col]
                        break
                    else:
                        print(f"Predicate {col} not found in rel2id.")
                else:
                    print(f"Predicate {col} not found in predicate_db.")

            if not wiki_predicate_id:
                print("No valid predicate found.")
                return STD_ERROR

            r_id = self.rel2id[rdflib.term.URIRef(wiki_predicate_id)]
            pred = self.relation_emb[r_id]

            lhs = head + pred
            dist = pairwise_distances(lhs.reshape(1, -1), self.entity_emb).reshape(-1)
            most_likely = dist.argsort()

            num_results = min(3, len(most_likely))
            results_lst = [
                (self.id2ent[idx][len(self.WD):], self.ent2lbl.get(self.id2ent[idx], ""), dist[idx], rank + 1)
                for rank, idx in enumerate(most_likely[:num_results])
            ]
            results_df = pd.DataFrame(results_lst, columns=('Entity', 'Label', 'Score', 'Rank'))
            
            if results_df.empty:
                print("No results in embeddings found.")
                return STD_ERROR

            return f"Top answers from embeddings: {', '.join(results_df.Label.values)}"

        except Exception as e:
            print(f"Error during query answering: {e}")
            return STD_ERROR
        

In [112]:
def cosine_sim(vec1, vec2):
    norm_vec1 = np.linalg.norm(vec1)
    norm_vec2 = np.linalg.norm(vec2)
    
    if norm_vec1 == 0 or norm_vec2 == 0:
        return 0.0 
    
    return np.dot(vec1, vec2) / (norm_vec1 * norm_vec2)


def rescale_probabilities(similarities):
    """
    Rescales the similarity scores so that they sum to 1, turning them into a probability distribution.
    
    Args:
        similarities (List[float]): List of similarity scores.
        
    Returns:
        List[float]: Rescaled probabilities.
    """
    similarity_sum = sum(similarities)
    if similarity_sum == 0:
        return [0] * len(similarities)  # Avoid division by zero
    
    return [sim / similarity_sum for sim in similarities]

def find_closest_columns(query_embeddings, column_embeddings, high_threshold=0.4, top_n=10, rescaled_threshold=0.11):
    """
    Returns columns based on cosine similarity with a two-tiered strategy and rescaled probabilities.
    - If a column has similarity above 'high_threshold', return that column immediately.
    - Otherwise, return all columns with a similarity greater than 'low_threshold'.
    - Rescale the top N column similarities into probabilities and return columns with a rescaled probability greater than rescaled_threshold.
    
    Args:
        query_embeddings (List[np.ndarray]): Embeddings for query words.
        column_embeddings (Dict[str, np.ndarray]): Precomputed embeddings for columns.
        low_threshold (float): Minimum similarity threshold (default: 0.27).
        high_threshold (float): Confidence threshold to return immediately (default: 0.35).
        top_n (int): Number of top columns to consider for rescaling (default: 10).
        rescaled_threshold (float): Minimum rescaled probability threshold (default: 0.1).
    
    Returns:
        List[str]: The selected column names.
    """
    column_similarities = {}

    for col, col_vec in column_embeddings.items():
        similarities = [cosine_sim(col_vec, q_vec) for q_vec in query_embeddings if np.linalg.norm(q_vec) > 0]
        column_similarities[col] = np.mean(similarities) if similarities else -1

    sorted_columns = sorted(column_similarities.items(), key=lambda item: item[1], reverse=True)
    top_columns = sorted_columns[:top_n]
    
    column_names, similarities = zip(*top_columns)
    
    rescaled_probs = rescale_probabilities(similarities)
    
    selected_columns = []
    
    for col, sim in zip(column_names, similarities):
        if sim >= high_threshold:
            # print(f"High confidence match found: {col} with similarity {sim: .4f}")
            return [col]
    
    for col, rescaled_prob in zip(column_names, rescaled_probs):
        if rescaled_prob >= rescaled_threshold:
            # print(f"Column {col} has similarity {rescaled_prob: .4f}")
            selected_columns.append(col)
    
    return selected_columns

In [113]:
def filter_query(query, node_label):
                
    if not len(query):
        return []
    
    relevant = []
    for word in query.replace(". ", " ").lower().split(" "):
        cleaned_word = re.sub(r'[^A-Za-z]', '', word)
        if cleaned_word in stop_words or cleaned_word in node_label.lower().replace(" ", "") or cleaned_word == "":
            continue
        
        relevant.append(cleaned_word)
        
    return " ".join(relevant)

In [158]:
def fuzzy_match(query_str, comparison_list, threshold=30, prioritize_exact=True):
    if not comparison_list or not query_str:
        return []
    
    matches = process.extract(query_str, comparison_list, scorer=fuzz.partial_ratio, limit=50)
        
    id_name_score = []
    
    if prioritize_exact and query_str in comparison_list:
        matched_id = next(key for key, value in db.entities.items() if value == query_str)
        id_name_score.append((matched_id, query_str, 100))
    
    for match in matches:
        name = match[0]
        score = match[1]
        matched_id = next(key for key, value in db.entities.items() if value == name)
        
        length_diff = abs(len(name) - len(query_str)) / len(query_str)
        adjusted_score = score * (1 - length_diff)

        if db.entities[matched_id] in query_str:
            adjusted_score = 100
        print(db.entities[matched_id], adjusted_score)
        
        id_name_score.append((matched_id, name, adjusted_score))
    
    return [id for id, _, score in id_name_score if score >= threshold]

In [115]:
def get_top_matches(df, normalized_query, top_n=2):
    concatenated_rows = df.apply(lambda row: ' '.join(row.astype(str)), axis=1).tolist()
    
    exact_matches = [i for i, row in enumerate(concatenated_rows) if normalized_query == row]
    
    if len(exact_matches) < top_n:
        remaining_slots = top_n - len(exact_matches)
        fuzzy_matches = process.extract(normalized_query, concatenated_rows, scorer=fuzz.partial_ratio, limit=remaining_slots)
        fuzzy_indices = [match[2] for match in fuzzy_matches]
    else:
        fuzzy_indices = []
    
    top_indices = exact_matches + fuzzy_indices
    
    return df.iloc[top_indices]

In [116]:
db = DataBase()

ner_parser = NERParser(lowercase=False)

qe = QueryEmbedderContextualized()

qa = QuestionAnsweringAgent()

ca = ConversationAgent(model_name="google/flan-t5-xl") #### Could take some time, approx. 10 GB storage :)

ge = GraphEmbeddings(db.db)


Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.87s/it]


In [159]:
import heapq

@measure_time
def answer_query(query, correct_answer=""):
    normalized_query = db.normalize_string(query)
    
    # NER --------------------------
    ner_person, ner_movies = ner_parser.process_query(normalized_query)
    ner_matches = ner_person + ner_movies
    
    is_domain_specific = bool(ner_matches)
    # if not is_domain_specific:
    #     small_talk = ca.generate_response(query)
    #     print(f"Smalltalk: {small_talk}")
    #     return
    
    print(f"NER_Person: {ner_person}")
    print(f"NER_Movies: {ner_movies}")
    
    ner_people_ids = []
    for ner_p in ner_person:
        ner_p = db.normalize_string(ner_p)
        if ner_p in db.people_names:
            ner_people_ids = [key for key, value in db.people_data.items() if value == ner_p]
        if not ner_people_ids:
            ner_people_ids = [key for key, value in db.people_data.items() if ner_p in value]
        if not ner_people_ids:
            ner_people_ids = fuzzy_match(" ".join(ner_person), db.people_names, threshold=75)
    
    ner_movie_ids = []     
    for ner_m in ner_movies:
        ner_m = db.normalize_string(ner_m)
        if ner_m in db.movie_names:
            ner_people_ids = [key for key, value in db.movie_data.items() if value == ner_m]
        if not ner_people_ids:
            ner_people_ids = [key for key, value in db.movie_data.items() if ner_m in value]
        if not ner_people_ids:
            ner_people_ids = fuzzy_match(" ".join(ner_movies), db.movie_names, threshold=75)

    ner_ids = ner_movie_ids + ner_people_ids
    context = db.fetch(ner_ids, "subject_id")
    # Context NER --------------------------
    # Context Fuzzy --------------------------

    if context.empty:
        print("NER failed, proceeding with fuzzy matching.")
        fuzzy_person_matches = fuzzy_match(normalized_query, db.people_names, threshold=30)
        fuzzy_movie_matches = fuzzy_match(normalized_query, db.movie_names, threshold=30)
    
        fuzzy_movie_context = db.fetch(fuzzy_movie_matches, "subject_id") 
        fuzzy_person_context = db.fetch(fuzzy_person_matches, "subject_id")
        
        if fuzzy_person_context.empty and fuzzy_movie_context.empty:
            small_talk = ca.generate_response(query)
            print(f"Smalltalk: {small_talk}")
            return
        
        context = pd.concat([fuzzy_movie_context, fuzzy_person_context])
    
    
    context = get_top_matches(context, normalized_query, top_n=1)    

    
    node_label = ""
    if not context.empty and "node label" in context.columns and not context["node label"].isna().values[0]:
        node_label = context["node label"].values[0]  
        
    entity_id = ""
    if not context.empty and "subject_id" in context.columns and not context["subject_id"].isna().values[0]:
        entity_id = context["subject_id"].values[0]


    if "CURRENT MODE" == "RECOMMENDER":
        node_label = db.normalize_string(node_label)
        if node_label in db.people_names:
            movie_ids = db.people_movie_mapping[entity_id]
            context = db.fetch(movie_ids, "subject_id")
            subject_labels = context["node label"].tolist()
            print(", ".join(subject_labels))
            return
            
        context.dropna(axis=1, inplace=True)
        columns = [col for col in context.columns if col in db.movie_recommender_db.columns]
        red_db = db.movie_recommender_db[columns]
        context = context[columns]
        
        # drop the identified row already in the context
        subject_id_to_remove = context["subject_id"].values[0]
        red_db = red_db[red_db["subject_id"] != subject_id_to_remove]
        
        red_db.dropna(thresh=len(red_db.columns) - 0, inplace=True)
        red_db.reset_index(drop=True, inplace=True)
        
        def calculate_similarity(i, row):
            similarities = []
            for col in context.columns:
                if pd.isna(row[col]):
                    continue
                    
                if col in ["node label", "subject_id"]:
                    continue
                
                similarity = fuzz.ratio(row[col], context[col].iloc[0]) / 100.0
                similarities.append(similarity)
    
            return (i, np.mean(similarities) if similarities else 0)
        
        top_scores = []
        for i, row in red_db.iterrows():
            index, score = calculate_similarity(i, row)
            
            if len(top_scores) < 3:
                heapq.heappush(top_scores, (score, index))
            else:
                # Maintain top 3 scores
                heapq.heappushpop(top_scores, (score, index))
    
        # Extract indices from top 3 scores
        top_indices = [index for score, index in top_scores]
        top_rows = red_db.iloc[top_indices]
        subject_labels = top_rows['node label'].tolist() 
        print(", ".join(subject_labels))

    elements_to_remove = ["image", "color", "sport"]
    context = context.drop(columns=elements_to_remove, errors='ignore')
    
    # Initial context for embeddings
    initial_context = context.copy()
    
    # EXPERIMENTAL - rename columns
    columns_to_rename = {
        "cast member":"movie cast",
        "notable work": "acted in"
    }
    
    columns_to_rename = {k: v for k, v in columns_to_rename.items() if k in context.columns}
    context = context.rename(columns=columns_to_rename)
    
    columns_to_duplicate = [("acted in", "played in"),
                            ("acted in", "appeared in"),
                            ("movie cast", "actors"),
                            ("movie cast", "players")]
    
    for col_to_duplicate, col in columns_to_duplicate: 
        try:
            context[col] = context[col_to_duplicate].copy()
        except KeyError:
            pass
        
    context.dropna(axis=1, inplace=True)
    
    query_filtered = filter_query(query, node_label)
        
    column_embeddings = {col: qe.embed_phrase(col) for col in context.columns}
    query_embeddings = [qe.embed_phrase(word) for word in query_filtered.split()]  
    top_columns_embeddings = find_closest_columns(query_embeddings, column_embeddings)
        
    # EXPERIMENTAL - always keep columns
    col_always_keep = ["node label"]
    
    combined_columns = set(top_columns_embeddings + col_always_keep)
    top_columns = [col for col in combined_columns if col in context.columns]
    filtered_context_df = context[top_columns]

    answer = qa.query(query, filtered_context_df)
    formatted_answer = ca.generate_response(f"Format the answer to the question into a sentence. If you think the answer is completely off, overrule with your own knowledge.Question: {query}\nAnswer: {answer}")
        
    print(ge.answer_query_embedding(initial_context, top_columns))
    print(formatted_answer)


In [160]:
answer_query("In Which movies did Tom Holland have a role?")

NER_Person: []
NER_Movies: []
NER failed, proceeding with fuzzy matching.
tom holland 100
tom holland 100
tom hollander 25.58139534883721
todd holland 23.255813953488378
allan havey 20.930232558139533
sam rolfe 16.74418604651163
vera briole 20.46511627906977
ian white 16.74418604651163
tom nolan 16.279069767441865
tom villa 16.279069767441865
david tom 16.279069767441865
alan hale 16.279069767441865
alan hale 16.279069767441865
al cole 12.522361359570658
ja rule 12.522361359570658
lana 6.976744186046516
tom hillmann 20.93023255813954
tom wolf 13.953488372093023
mary holland 20.93023255813954
tom choi 13.953488372093023
tom volf 13.953488372093023
matt holland 20.93023255813954
matt holland 20.93023255813954
john holland 20.93023255813954
dave 6.976744186046516
tom hale 13.953488372093023
tom held 13.953488372093023
arno 6.976744186046516
ovid 6.976744186046516
zola 6.976744186046516
eva rose 13.953488372093023
tim holt 13.953488372093023
tim hill 13.953488372093023
tom bell 13.95348837

In [119]:
answer_query("Tell me a good joke")

Smalltalk: if you want to be a doctor, you have to be a dentist.
Execution time for answer_query: 6.7287 seconds


In [120]:
answer_query("Hello, how is life?")

Smalltalk: I'm fine, thanks.
Execution time for answer_query: 4.3098 seconds


In [121]:
answer_query("Hello, how are you?")

Smalltalk: I'm fine, thanks.
Execution time for answer_query: 2.9935 seconds


In [122]:
answer_query("Hello, what is the capital of switzerland?")

Smalltalk: Bern
Execution time for answer_query: 1.2812 seconds


In [123]:
answer_query("Who is the director of Good Will Hunting?", "Gus Van Sant")

Smalltalk: steven soderbergh
Execution time for answer_query: 3.5764 seconds


In [124]:
answer_query("Who directed The Bridge on the River Kwai?", "David Lean")

Smalltalk: david lean
Execution time for answer_query: 3.4225 seconds


In [125]:
answer_query("Who directed The Dark Knight?", "Christopher Nolan")

Smalltalk: christopher nolan
Execution time for answer_query: 3.2563 seconds


In [126]:
answer_query("Where was Angelina Jolie born?", "Los Angeles")

NER_Person: ['angel']
NER_Movies: []
Top answers from embeddings: new york city, los angeles, united states of america
Angelina Jolie was born in Los Angeles.
Execution time for answer_query: 8.0347 seconds


In [127]:
answer_query("Who is the main actor in harry potter and the philosopher's stone?")

Smalltalk: Rupert Grint
Execution time for answer_query: 2.6386 seconds


In [128]:
answer_query("Who was Brad Pitt married to?")

Smalltalk: jennifer aniston
Execution time for answer_query: 2.7212 seconds


In [129]:
answer_query("When was Inception released?")

Smalltalk: 2010
Execution time for answer_query: 0.7807 seconds


In [130]:
answer_query("Who is the director of Star Wars?")

Smalltalk: george lucas
Execution time for answer_query: 2.5941 seconds


In [131]:
answer_query("When was the Godfather III published?")

Smalltalk: 1972
Execution time for answer_query: 1.1505 seconds


In [132]:
answer_query("Who is the director of Star Wars?")

Smalltalk: george lucas
Execution time for answer_query: 2.9784 seconds


In [133]:
answer_query("When was Inception released?")

Smalltalk: 2010
Execution time for answer_query: 1.0705 seconds


In [134]:
answer_query("Who was Angelina Jolie married to?")

Smalltalk: bruce willis
Execution time for answer_query: 2.6550 seconds


In [135]:
answer_query("Who was Brad Pitt married to?")

Smalltalk: jennifer aniston
Execution time for answer_query: 2.3369 seconds


In [136]:
answer_query("What is the religion of Tom Cruise?")

Smalltalk: agnostic
Execution time for answer_query: 2.0710 seconds


In [137]:
answer_query("Who is the main actor in harry potter and the philosopher's stone?")

Smalltalk: Rupert Grint
Execution time for answer_query: 3.4976 seconds


In [138]:
answer_query("Who are the cast in Jurassic Park?")

Smalltalk: johnny depp johnny depp johnny depp
Execution time for answer_query: 5.6983 seconds


In [139]:
answer_query("Who acted in Jurassic Park?")

Smalltalk: michael j fox
Execution time for answer_query: 2.6528 seconds


In [140]:
answer_query("Who played in Jurassic Park?")

Smalltalk: edward g robinson
Execution time for answer_query: 3.0933 seconds


In [141]:
answer_query("In which movie did Tom Cruise play?")

Smalltalk: risky business
Execution time for answer_query: 1.2705 seconds


In [142]:
answer_query("In which movie did Rebel Wilson act?")

Smalltalk: saturday night fever
Execution time for answer_query: 2.1450 seconds


In [143]:
answer_query("In which movie did Liam Neeson play?")

NER_Person: ['l']
NER_Movies: []
Top answers from embeddings: larry sanders, larry black, tito ortiz
Liam Neeson played in 43.
Execution time for answer_query: 19.5011 seconds


In [144]:
answer_query("Who is an actor in Taken 2?")

Smalltalk: ryan reynolds
Execution time for answer_query: 2.4334 seconds


In [145]:
answer_query("What is the role of Vin Diesel in Fast and Furious?")

Smalltalk: dolph lundgren
Execution time for answer_query: 2.0516 seconds


In [146]:
answer_query("For which movie did Leonardo Di Caprio win an Oscar?")

Smalltalk: lord of rings
Execution time for answer_query: 2.1920 seconds
