In [1]:
import os
import re
import unicodedata
import pandas as pd
import numpy as np
from rapidfuzz import process, fuzz
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    pipeline,
    T5Tokenizer,
    T5ForConditionalGeneration
)

from transformers import logging as transformers_logging
from sentence_transformers import SentenceTransformer
transformers_logging.set_verbosity_error()
import json
import time
from functools import wraps
import sentencepiece

# Silencing TqdmWarning
import warnings
warnings.filterwarnings('ignore')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import nltk
from nltk.corpus import stopwords

nltk.download('stopwords')
stop_words_to_keep = ["what", "when", "where", "which", "while", "who", "whom", "why", "with", "how", "before", "after","same"]
stop_words = set([s for s in stopwords.words('english') if s not in stop_words_to_keep])

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\sandr\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [3]:
def measure_time(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.time()  # Record the start time
        result = func(*args, **kwargs)  # Call the function
        end_time = time.time()  # Record the end time
        elapsed_time = end_time - start_time  # Calculate elapsed time
        print(f"Execution time for {func.__name__}: {elapsed_time:.4f} seconds")
        return result
    return wrapper

In [4]:
class NERParser:
    def __init__(self, model_name="dslim/bert-base-NER", lowercase=False):
        self.lowercase = lowercase
        self.device = self.get_device()

        self.nlp_pipeline = pipeline(
            "ner", 
            model=AutoModelForTokenClassification.from_pretrained(model_name),
            tokenizer=AutoTokenizer.from_pretrained(model_name, do_lower_case=lowercase),
            device=self.device, 
            aggregation_strategy="simple"
        )

    def get_device(self):
        if torch.backends.mps.is_available():
            return torch.device("mps")
        elif torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    def parse_ner_results(self, ner_results):
        per_entities = [e['word'] for e in ner_results if e['entity_group'] == 'PER']
        misc_entities = [e['word'] for e in ner_results if e['entity_group'] == 'MISC']
        return per_entities, misc_entities

    def process_query(self, query):
        if self.lowercase:
            query = query.lower()
        return self.parse_ner_results(self.nlp_pipeline(query))



In [5]:
class DataBase:
    """Handles context data extraction for people and movies from a database with fuzzy matching support."""
    
    def __init__(self):
        self.db = pd.read_pickle(os.path.join(os.getcwd(), "exports/extended_graph_triples.pkl"))
        
        self.db['subject_id'] = self.db['subject_id'].astype(str).str.strip()
        self.db['predicate_label'] = self.db['predicate_label'].astype(str).str.strip()
        self.db['object_label'] = self.db['object_label'].astype(str).str.strip()
        
        self.db_pivot = self.db.pivot_table(
                index='subject_id',
                columns='predicate_label',
                values='object_label',
                aggfunc=lambda x: ' | '.join(x.astype(str))
            )
        
        with open('exports/movie_db.json') as f:
            self.movie_data = json.load(f)
            self.movie_ids = set(self.movie_data.keys())
        self.movie_db = pd.DataFrame(list(self.movie_data.items()), columns=["entity_id", "entity_label"])
        
        with open('exports/people_db.json') as f:
            self.people_data = json.load(f)
            self.people_ids = set(self.people_data.keys())
        self.people_db = pd.DataFrame(list(self.people_data.items()), columns=["entity_id", "entity_label"])
        
        self.entities = {**self.movie_data, **self.people_data}
        
        self.movie_names = self.movie_db["entity_label"].tolist()
        self.people_names = self.people_db["entity_label"].tolist()
        
        self.entity_list = self.movie_names + self.people_names
       
        self.people_movie_mapping = {}
        self.movie_people_mapping = {}
        
        self.map_people_movies()
        
        self.movie_recommender_db = self.filter_relevant_movies()
        
    def map_people_movies(self):
        id_triples = pd.read_pickle(os.path.join("exports/df_new_triples_only_ids.pkl"))
        
        id_triples['subject_id'] = id_triples['subject_id'].astype(str).str.strip()
        id_triples['object_id'] = id_triples['object_id'].astype(str).str.strip()
        
        for _, row in id_triples.iterrows():
            subject_id = row['subject_id']
            object_id = row['object_id']
            
            if subject_id in self.people_ids and object_id in self.movie_ids:
                if subject_id not in self.people_movie_mapping:
                    self.people_movie_mapping[subject_id] = []
                self.people_movie_mapping[subject_id].append(object_id)
                
                if object_id not in self.movie_people_mapping:
                    self.movie_people_mapping[object_id] = []
                self.movie_people_mapping[object_id].append(subject_id)
        
        for person, movies in self.people_movie_mapping.items():
            self.people_movie_mapping[person] = list(set(movies))
        
        for movie, people in self.movie_people_mapping.items():
            self.movie_people_mapping[movie] = list(set(people))

        
    @staticmethod
    def normalize_string(s):
        """Normalizes strings by removing non-ASCII characters, punctuation, and redundant spaces."""
        return ' '.join(re.sub(r'[^\w\s]', '', unicodedata.normalize('NFKD', s.lower())
                               .encode('ascii', 'ignore').decode('utf-8')).split())

    def fetch(self, entity_list, search_column):
        """Fetches relevant rows from the database where `search_column` matches values in `entity_list`."""
        relevant = self.db[self.db[search_column].isin(entity_list)].dropna(axis=1)
        
        if relevant.empty:
            return pd.DataFrame()

        return relevant.pivot_table(
            index='subject_id',
            columns='predicate_label',
            values='object_label',
            aggfunc=lambda x: ' | '.join(x.astype(str))
        ).reset_index()
    
    def filter_relevant_movies(self):

        clean_db = self.db[self.db["subject_id"].isin(self.movie_data.keys())]
        relevant_cols = [
            # "author", # only 99 movies hav an author
            # "cast member",
            "director", 
            "performer",
            "genre",
            # "narrative motif", # only 43 movies have a narrative motif
            "screenwriter",
            "subject_id",
            "node label" # Required for processing
        ]
        
        pv_db = clean_db.pivot_table(
                index='subject_id',
                columns='predicate_label',
                values='object_label',
                aggfunc=lambda x: ' | '.join(x.astype(str))
            )
        
        pv_db = pv_db[[col for col in relevant_cols if col in pv_db.columns]]

        return pv_db.reset_index()

In [6]:
class QueryEmbedderContextualized:
    def __init__(self, model_name='sentence-transformers/all-mpnet-base-v2'):
        """Initializes the QueryEmbedder with a SentenceTransformer model and device setup."""
        self.device = self.get_device()
        self.model = SentenceTransformer(model_name, device=self.device)
        self.cache = {}
    
    @staticmethod
    def get_device():
        """Determines the available hardware device (MPS, CUDA, or CPU)."""
        if torch.backends.mps.is_available():
            return torch.device("mps")
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    def embed_phrase(self, phrases):
        """
        Generates embeddings for given phrases using SentenceTransformer, with caching.

        Args:
            phrases (str or List[str]): Input phrase(s) to embed.

        Returns:
            np.ndarray: Embedding vector(s) for the input phrase(s).
        """
        if isinstance(phrases, str):
            phrases = [phrases]
        elif not isinstance(phrases, list):
            raise TypeError("Input must be a string or a list of strings.")
        
        phrases_to_compute = [p for p in phrases if p not in self.cache]
        cached_embeddings = [self.cache[p] for p in phrases if p in self.cache]

        if phrases_to_compute:
            new_embeddings = self.model.encode(
                phrases_to_compute, 
                show_progress_bar=False, 
                convert_to_numpy=True, 
                normalize_embeddings=True
            )
            
            for phrase, emb in zip(phrases_to_compute, new_embeddings):
                self.cache[phrase] = emb
            cached_embeddings.extend(new_embeddings)
        
        return cached_embeddings[0] if len(cached_embeddings) == 1 else np.array(cached_embeddings)


In [7]:
class QuestionAnsweringAgent:
    
    def __init__(self):
        self.qa_model = pipeline("question-answering", model="distilbert-base-uncased-distilled-squad", top_k=1)
    
    def query(self, query, context_df):
        
        top_columns = context_df.columns
        
        context = ""
        for index, row in context_df.iterrows():
            node_label = row.get("node label", "")
            
            row_context = f"This text is about \"{node_label}\":\n"
            
            for col in context_df[top_columns].columns:
                if col == "node label":
                    continue
                
                values = row[col]
                values_lst = str(values).split(",")
                
                if len(values_lst) > 5:
                    row_context += f"{col}: {', '.join(values_lst[:5])}"
                else:
                    row_context += f"{col}: {', '.join(values_lst)}"

            context += row_context + "\n\n"
        
        output = self.qa_model(question=query, context=context)
        
        answer_str = str()
        if isinstance(output, list) and output:
            answer_str = ", ".join([result['answer'] for result in output])
            
        elif isinstance(output, dict):
            answer_str = output['answer']
        
        if not answer_str:
            answer_str = "No answer found."
        
        return answer_str

In [8]:
class ConversationAgent:
    def __init__(self, model_name="google/flan-t5-large", max_length=150):
        self.device = self.get_device()
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.model = T5ForConditionalGeneration.from_pretrained(model_name).to(self.device)
        self.max_length = max_length

    @staticmethod
    def get_device():
        if torch.cuda.is_available():
            return torch.device("cuda")
        elif torch.backends.mps.is_available():
            return torch.device("mps")
        else:
            return torch.device("cpu")

    def generate_response(self, prompt):
        """
        Generates a response based on the given prompt.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        outputs = self.model.generate(
            **inputs,
            max_length=self.max_length,
            num_beams=5,
            early_stopping=True,
        )
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response.strip()


In [9]:
import csv
import rdflib
import numpy as np
import json
import pandas as pd
from sklearn.metrics import pairwise_distances

class GraphEmbeddings:
    
    def __init__(self, graph):
        
        self.RDFS = rdflib.namespace.RDFS
        self.WD = rdflib.Namespace('http://www.wikidata.org/entity/')
     
        self.entity_emb = np.load('exports/entity_embeds.npy')
        self.relation_emb = np.load('exports/relation_embeds.npy')
        
        with open('exports/entity_ids.del', 'r') as ifile:
            self.ent2id = {rdflib.term.URIRef(ent): int(idx) for idx, ent in csv.reader(ifile, delimiter='\t')}
            self.id2ent = {v: k for k, v in self.ent2id.items()}
            
        with open('exports/relation_ids.del', 'r') as ifile:
            self.rel2id = {rdflib.term.URIRef(rel): int(idx) for idx, rel in csv.reader(ifile, delimiter='\t')}
        
        with open("exports/predicate_db.json", encoding="utf-8") as f:
            self.predicates_db = json.load(f)
        
        self.ent2lbl = {rdflib.term.URIRef(row.subject_id): row.subject_label for index, row in graph.iterrows()}

    def answer_query_embedding(self, context, top_columns):
        STD_ERROR = "The embeddings did not provide a valid answer."
        try:
            context_id = context.subject_id.values[0]
            if not context_id:
                print("No context ID found.")
                return STD_ERROR

            e_id = self.ent2id.get(rdflib.term.URIRef(context_id))
            if e_id is None:
                return "2"

            head = self.entity_emb[e_id]
            wiki_predicate_id = ""

            column_mapping = {
                "movie cast": "cast member",
                "acted in": "notable work",
                "played in": "notable work",
                "appeared in": "notable work",
                "actors": "cast member",
                "players": "cast member"
            }

            for col in top_columns:
                
                col = column_mapping.get(col, col)
                
                if col == "node label":
                    continue
                    
                if not context[col].values[0]:
                    continue

                if col in self.predicates_db:
                    wiki_id = rdflib.term.URIRef(self.predicates_db[col])
                    if wiki_id in self.rel2id:
                        wiki_predicate_id = self.predicates_db[col]
                        break
                    else:
                        print(f"Predicate {col} not found in rel2id.")
                else:
                    print(f"Predicate {col} not found in predicate_db.")

            if not wiki_predicate_id:
                print("No valid predicate found.")
                return STD_ERROR

            r_id = self.rel2id[rdflib.term.URIRef(wiki_predicate_id)]
            pred = self.relation_emb[r_id]

            lhs = head + pred
            dist = pairwise_distances(lhs.reshape(1, -1), self.entity_emb).reshape(-1)
            most_likely = dist.argsort()

            num_results = min(3, len(most_likely))
            results_lst = [
                (self.id2ent[idx][len(self.WD):], self.ent2lbl.get(self.id2ent[idx], ""), dist[idx], rank + 1)
                for rank, idx in enumerate(most_likely[:num_results])
            ]
            results_df = pd.DataFrame(results_lst, columns=('Entity', 'Label', 'Score', 'Rank'))
            
            if results_df.empty:
                print("No results in embeddings found.")
                return STD_ERROR

            return f"Top answers from embeddings: {', '.join(results_df.Label.values)}"

        except Exception as e:
            print(f"Error during query answering: {e}")
            return STD_ERROR
        

In [10]:
def cosine_sim(vec1, vec2):
    norm_vec1 = np.linalg.norm(vec1)
    norm_vec2 = np.linalg.norm(vec2)
    
    if norm_vec1 == 0 or norm_vec2 == 0:
        return 0.0 
    
    return np.dot(vec1, vec2) / (norm_vec1 * norm_vec2)


def rescale_probabilities(similarities):
    """
    Rescales the similarity scores so that they sum to 1, turning them into a probability distribution.
    
    Args:
        similarities (List[float]): List of similarity scores.
        
    Returns:
        List[float]: Rescaled probabilities.
    """
    similarity_sum = sum(similarities)
    if similarity_sum == 0:
        return [0] * len(similarities)  # Avoid division by zero
    
    return [sim / similarity_sum for sim in similarities]

def find_closest_columns(query_embeddings, column_embeddings, high_threshold=0.4, top_n=10, rescaled_threshold=0.11):
    """
    Returns columns based on cosine similarity with a two-tiered strategy and rescaled probabilities.
    - If a column has similarity above 'high_threshold', return that column immediately.
    - Otherwise, return all columns with a similarity greater than 'low_threshold'.
    - Rescale the top N column similarities into probabilities and return columns with a rescaled probability greater than rescaled_threshold.
    
    Args:
        query_embeddings (List[np.ndarray]): Embeddings for query words.
        column_embeddings (Dict[str, np.ndarray]): Precomputed embeddings for columns.
        low_threshold (float): Minimum similarity threshold (default: 0.27).
        high_threshold (float): Confidence threshold to return immediately (default: 0.35).
        top_n (int): Number of top columns to consider for rescaling (default: 10).
        rescaled_threshold (float): Minimum rescaled probability threshold (default: 0.1).
    
    Returns:
        List[str]: The selected column names.
    """
    column_similarities = {}

    for col, col_vec in column_embeddings.items():
        similarities = [cosine_sim(col_vec, q_vec) for q_vec in query_embeddings if np.linalg.norm(q_vec) > 0]
        column_similarities[col] = np.mean(similarities) if similarities else -1

    sorted_columns = sorted(column_similarities.items(), key=lambda item: item[1], reverse=True)
    top_columns = sorted_columns[:top_n]
    
    column_names, similarities = zip(*top_columns)
    
    rescaled_probs = rescale_probabilities(similarities)
    
    selected_columns = []
    
    for col, sim in zip(column_names, similarities):
        if sim >= high_threshold:
            # print(f"High confidence match found: {col} with similarity {sim: .4f}")
            return [col]
    
    for col, rescaled_prob in zip(column_names, rescaled_probs):
        if rescaled_prob >= rescaled_threshold:
            # print(f"Column {col} has similarity {rescaled_prob: .4f}")
            selected_columns.append(col)
    
    return selected_columns

In [11]:
def filter_query(query, node_label):
                
    if not len(query):
        return []
    
    relevant = []
    for word in query.replace(". ", " ").lower().split(" "):
        cleaned_word = re.sub(r'[^A-Za-z]', '', word)
        if cleaned_word in stop_words or cleaned_word in node_label.lower().replace(" ", "") or cleaned_word == "":
            continue
        
        relevant.append(cleaned_word)
        
    return " ".join(relevant)

In [176]:
import torch
import numpy as np
import pandas as pd
import re
import unicodedata
import nltk
from nltk.corpus import stopwords
from rapidfuzz import process, fuzz
import time
from functools import wraps
import logging

# Suppress warnings
import warnings
warnings.filterwarnings('ignore')

# Download stopwords
nltk.download('stopwords', quiet=True)

stop_words_to_keep = [
    "what", "when", "where", "which", "while", "who", "whom", "why",
    "with", "how", "before", "after", "same"
]
stop_words = set([s for s in stopwords.words('english') if s not in stop_words_to_keep])

# Define colors for logging
class BColors:
    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKCYAN = "\033[96m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"
    ENDC = "\033[0m"
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"

# Custom formatter to color log levels
class ColoredFormatter(logging.Formatter):
    def __init__(self, fmt=None, datefmt=None):
        super().__init__(fmt, datefmt)
        self.COLORS = {
            'DEBUG': BColors.OKBLUE,
            'INFO': '',  # No color for INFO messages
            'WARNING': BColors.WARNING,
            'ERROR': BColors.FAIL,
            'CRITICAL': BColors.BOLD + BColors.FAIL,
        }

    def format(self, record):
        levelname = record.levelname.strip(BColors.ENDC)
        level_color = self.COLORS.get(levelname, '')
        record.levelname = level_color + record.levelname + BColors.ENDC
        return super().format(record)

# Set up logger
logger = logging.getLogger('factual_questions')
logger.setLevel(logging.INFO)

# Create console handler
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)

# Create formatter
formatter = ColoredFormatter('%(asctime)s | %(levelname)s | %(funcName)s | %(message)s')

# Add formatter to handler
ch.setFormatter(formatter)

# Add handler to logger if not already added
if not logger.hasHandlers():
    logger.addHandler(ch)

def get_device():
    """Determines the available hardware device (MPS, CUDA, or CPU)."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    elif torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

def measure_time(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.time()  # Record the start time
        result = func(*args, **kwargs)  # Call the function
        end_time = time.time()  # Record the end time
        elapsed_time = end_time - start_time  # Calculate elapsed time
        # Log the execution time under INFO level, with time in green color
        elapsed_time_str = f"{BColors.OKGREEN}{elapsed_time:.4f} seconds{BColors.ENDC}"
        logger.info(f"Execution time for {func.__name__}: {elapsed_time_str}")
        return result
    return wrapper

def cosine_sim(vec1, vec2):
    norm_vec1 = np.linalg.norm(vec1)
    norm_vec2 = np.linalg.norm(vec2)

    if norm_vec1 == 0 or norm_vec2 == 0:
        return 0.0

    return np.dot(vec1, vec2) / (norm_vec1 * norm_vec2)

def rescale_probabilities(similarities):
    """
    Rescales the similarity scores so that they sum to 1, turning them into a probability distribution.

    Args:
        similarities (List[float]): List of similarity scores.

    Returns:
        List[float]: Rescaled probabilities.
    """
    similarity_sum = sum(similarities)
    if similarity_sum == 0:
        return [0] * len(similarities)  # Avoid division by zero

    return [sim / similarity_sum for sim in similarities]

def find_closest_columns(query_embeddings, column_embeddings, high_threshold=0.4, top_n=10, rescaled_threshold=0.11):
    """
    Returns columns based on cosine similarity with a two-tiered strategy and rescaled probabilities.
    - If a column has similarity above 'high_threshold', return that column immediately.
    - Otherwise, return all columns with a rescaled probability greater than 'rescaled_threshold'.

    Args:
        query_embeddings (List[np.ndarray]): Embeddings for query words.
        column_embeddings (Dict[str, np.ndarray]): Precomputed embeddings for columns.
        high_threshold (float): Confidence threshold to return immediately (default: 0.4).
        top_n (int): Number of top columns to consider for rescaling (default: 10).
        rescaled_threshold (float): Minimum rescaled probability threshold (default: 0.11).

    Returns:
        List[str]: The selected column names.
    """
    column_similarities = {}

    for col, col_vec in column_embeddings.items():
        similarities = [cosine_sim(col_vec, q_vec) for q_vec in query_embeddings if np.linalg.norm(q_vec) > 0]
        column_similarities[col] = np.mean(similarities) if similarities else -1

    sorted_columns = sorted(column_similarities.items(), key=lambda item: item[1], reverse=True)
    top_columns = sorted_columns[:top_n]

    column_names, similarities = zip(*top_columns)

    rescaled_probs = rescale_probabilities(similarities)

    selected_columns = []

    for col, sim in zip(column_names, similarities):
        if sim >= high_threshold:
            logger.info(f"High confidence match found: {col} with similarity {sim:.4f}")
            return [col]

    for col, rescaled_prob in zip(column_names, rescaled_probs):
        if rescaled_prob >= rescaled_threshold:
            logger.info(f"Column {col} has rescaled similarity {rescaled_prob:.4f}")
            selected_columns.append(col)

    return selected_columns

def filter_query(query, node_label):
    if not query:
        return ''

    relevant = []
    node_label_cleaned = node_label.lower().replace(" ", "")
    for word in query.replace(". ", " ").lower().split():
        cleaned_word = re.sub(r'[^A-Za-z]', '', word)
        if cleaned_word in stop_words or cleaned_word in node_label_cleaned or not cleaned_word:
            continue
        relevant.append(cleaned_word)

    return " ".join(relevant)


def fuzzy_match(query_str, comparison_list, threshold=30, prioritize_exact=True):
    if not comparison_list or not query_str:
        return []

    # Preprocess db.entities to reverse-map names to IDs for faster lookup
    name_to_id = {v: k for k, v in db.entities.items()}

    # Perform fuzzy matching with optimized lookup
    matches = process.extract(query_str, comparison_list, scorer=fuzz.partial_ratio, limit=30)
    longest_matching_id = ""

    for match in matches:
        name = match[0]
        matched_id = name_to_id.get(name)
        if matched_id and name in query_str:
            if not longest_matching_id or len(name) > len(db.entities[longest_matching_id]):
                longest_matching_id = matched_id

    if longest_matching_id:
        print(f"Found exact match with Fuzzy: {db.entities[longest_matching_id]}")
        return [longest_matching_id]

    return []

In [13]:
def get_top_matches(df, normalized_query, top_n=2):
    concatenated_rows = df.apply(lambda row: ' '.join(row.astype(str)), axis=1).tolist()
    
    exact_matches = [i for i, row in enumerate(concatenated_rows) if normalized_query == row]
    
    if len(exact_matches) < top_n:
        remaining_slots = top_n - len(exact_matches)
        fuzzy_matches = process.extract(normalized_query, concatenated_rows, scorer=fuzz.partial_ratio, limit=remaining_slots)
        fuzzy_indices = [match[2] for match in fuzzy_matches]
    else:
        fuzzy_indices = []
    
    top_indices = exact_matches + fuzzy_indices
    
    return df.iloc[top_indices]

In [135]:
def recommend(node_label, entity_id, context, db):
    if "CURRENT MODE" == "RECOMMENDER" or True:
        node_label = db.normalize_string(node_label)
        if node_label in db.people_names:
            movie_ids = db.people_movie_mapping[entity_id]
            context = db.fetch(movie_ids, "subject_id")
            subject_labels = context["node label"].tolist()
            return sorted(subject_labels, key=len)[:3]

        context.dropna(axis=1, inplace=True)
        columns = [col for col in context.columns if col in db.movie_recommender_db.columns]
        red_db = db.movie_recommender_db[columns]
        context = context[columns]
        
        # drop the identified row already in the context
        subject_id_to_remove = context["subject_id"].values[0]
        red_db = red_db[red_db["subject_id"] != subject_id_to_remove]
        
        red_db.dropna(thresh=len(red_db.columns) - 0, inplace=True)
        red_db.reset_index(drop=True, inplace=True)
        
        COLUMN_WEIGHTS = {
            'director': 0.3,
            'performer': 0.1,
            'genre': 0.3,
            'screenwriter': 0.2,
            'cast member': 0.1
        }
        
        def calculate_similarity(i, row):
            similarities = []
            for col in context.columns:
                if pd.isna(row[col]):
                    continue
                    
                if col in ["node label", "subject_id"]:
                    continue
                
                set_context = set(context[col].iloc[0].split(","))
                set_row = set(row[col].split(","))
                similarity = len(set_context.intersection(set_row)) / len(set_context.union(set_row))
                similarities.append(similarity * COLUMN_WEIGHTS[col])
    
            return i, np.mean(similarities) if similarities else 0
        
        top_scores = []
        for i, row in red_db.iterrows():
            index, score = calculate_similarity(i, row)
            
            if len(top_scores) < 3:
                heapq.heappush(top_scores, (score, index))
            else:
                # Maintain top 3 scores
                heapq.heappushpop(top_scores, (score, index))
    
        # Extract indices from top 3 scores
        top_indices = [index for score, index in top_scores]
        top_rows = red_db.iloc[top_indices]
        
        subject_labels = top_rows['node label'].tolist()
        return subject_labels


In [14]:
db = DataBase()

ner_parser = NERParser(lowercase=False)

qe = QueryEmbedderContextualized()

qa = QuestionAnsweringAgent()

ca = ConversationAgent(model_name="google/flan-t5-xl") #### Could take some time, approx. 10 GB storage :)

ge = GraphEmbeddings(db.db)


Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  3.62it/s]


In [136]:
import heapq

@measure_time
def answer_query(query, correct_answer=""):
    normalized_query = db.normalize_string(query)
    
    # NER --------------------------
    ner_person, ner_movies = ner_parser.process_query(normalized_query)
    ner_matches = ner_person + ner_movies
    
    is_domain_specific = bool(ner_matches)
    # if not is_domain_specific:
    #     small_talk = ca.generate_response(query)
    #     print(f"Smalltalk: {small_talk}")
    #     return
    
    print(f"NER_Person: {ner_person}")
    print(f"NER_Movies: {ner_movies}")
    
    ner_people_ids = []
    for ner_p in ner_person:
        ner_p = db.normalize_string(ner_p)
        if ner_p in db.people_names:
            ner_people_ids = [key for key, value in db.people_data.items() if value == ner_p]
        if not ner_people_ids:
            ner_people_ids = [key for key, value in db.people_data.items() if ner_p in value]
        if not ner_people_ids:
            ner_people_ids = fuzzy_match(" ".join(ner_person), db.people_names, threshold=75)
    
    ner_movie_ids = []     
    for ner_m in ner_movies:
        ner_m = db.normalize_string(ner_m)
        if ner_m in db.movie_names:
            ner_people_ids = [key for key, value in db.movie_data.items() if value == ner_m]
        if not ner_people_ids:
            ner_people_ids = [key for key, value in db.movie_data.items() if ner_m in value]
        if not ner_people_ids:
            ner_people_ids = fuzzy_match(" ".join(ner_movies), db.movie_names, threshold=75)

    ner_ids = ner_movie_ids + ner_people_ids
    context = db.fetch(ner_ids, "subject_id")
    # Context NER --------------------------
    # Context Fuzzy --------------------------

    if context.empty:
        print("NER failed, proceeding with fuzzy matching.")
        fuzzy_person_matches = fuzzy_match(normalized_query, db.people_names, threshold=30)
        fuzzy_movie_matches = fuzzy_match(normalized_query, db.movie_names, threshold=30)
    
        fuzzy_movie_context = db.fetch(fuzzy_movie_matches, "subject_id") 
        fuzzy_person_context = db.fetch(fuzzy_person_matches, "subject_id")
        
        if fuzzy_person_context.empty and fuzzy_movie_context.empty:
            small_talk = ca.generate_response(query)
            print(f"Smalltalk: {small_talk}")
            return
        
        context = pd.concat([fuzzy_movie_context, fuzzy_person_context])
    
    
    context = get_top_matches(context, normalized_query, top_n=1)    

    
    node_label = ""
    if not context.empty and "node label" in context.columns and not context["node label"].isna().values[0]:
        node_label = context["node label"].values[0]  
        
    entity_id = ""
    if not context.empty and "subject_id" in context.columns and not context["subject_id"].isna().values[0]:
        entity_id = context["subject_id"].values[0]

    if "RECOMMENDER":
        movies = recommend(node_label, entity_id, context, db)
        formatted_answer = f"Based on your interest in {node_label.title()}, I recommend watching the following movies:\n- " + '\n- '.join(movies)

        print(formatted_answer)
        return
    
    elements_to_remove = ["image", "color", "sport"]
    context = context.drop(columns=elements_to_remove, errors='ignore')
    
    # Initial context for embeddings
    initial_context = context.copy()
    
    # EXPERIMENTAL - rename columns
    columns_to_rename = {
        "cast member":"movie cast",
        "notable work": "acted in"
    }
    
    columns_to_rename = {k: v for k, v in columns_to_rename.items() if k in context.columns}
    context = context.rename(columns=columns_to_rename)
    
    columns_to_duplicate = [("acted in", "played in"),
                            ("acted in", "appeared in"),
                            ("movie cast", "actors"),
                            ("movie cast", "players")]
    
    for col_to_duplicate, col in columns_to_duplicate: 
        try:
            context[col] = context[col_to_duplicate].copy()
        except KeyError:
            pass
        
    context.dropna(axis=1, inplace=True)
    
    query_filtered = filter_query(query, node_label)
        
    column_embeddings = {col: qe.embed_phrase(col) for col in context.columns}
    query_embeddings = [qe.embed_phrase(word) for word in query_filtered.split()]  
    top_columns_embeddings = find_closest_columns(query_embeddings, column_embeddings)
        
    # EXPERIMENTAL - always keep columns
    col_always_keep = ["node label"]
    
    combined_columns = set(top_columns_embeddings + col_always_keep)
    top_columns = [col for col in combined_columns if col in context.columns]
    filtered_context_df = context[top_columns]

    answer = qa.query(query, filtered_context_df)
    formatted_answer = ca.generate_response(f"Format the answer to the question into a sentence. If you think the answer is completely off, overrule with your own knowledge.Question: {query}\nAnswer: {answer}")
        
    print(ge.answer_query_embedding(initial_context, top_columns))
    print(formatted_answer)


In [177]:
answer_query("Jurassic Park?")

2024-11-06 13:06:53,406 | INFO[0m | fuzzy_match | Found exact match with Fuzzy: jurassic park


NER_Person: []
NER_Movies: []
NER failed, proceeding with fuzzy matching.
Based on your interest in Jurassic Park, I recommend watching the following movies:
- Indiana Jones and the Kingdom of the Crystal Skull
- The Lost World: Jurassic Park
- War of the Worlds
Execution time for answer_query: 1.9434 seconds


In [178]:
answer_query("The Grand Budapest Hotel", "Gus Van Sant")

NER_Person: []
NER_Movies: []
NER failed, proceeding with fuzzy matching.


2024-11-06 13:06:55,463 | INFO[0m | fuzzy_match | Found exact match with Fuzzy: the grand budapest hotel


Based on your interest in The Grand Budapest Hotel, I recommend watching the following movies:
- The Life Aquatic with Steve Zissou
- The Darjeeling Limited
- Hotel Chevalier
Execution time for answer_query: 1.5441 seconds


In [179]:
answer_query("Who directed The Bridge on the River Kwai?", "David Lean")

NER_Person: []
NER_Movies: []
NER failed, proceeding with fuzzy matching.


2024-11-06 13:06:57,159 | INFO[0m | fuzzy_match | Found exact match with Fuzzy: the bridge on the river kwai


Based on your interest in The Bridge On The River Kwai, I recommend watching the following movies:
- A Passage to India
- Doctor Zhivago
- Ryan's Daughter
Execution time for answer_query: 2.4445 seconds


In [180]:
answer_query("Who directed The Dark Knight?", "Christopher Nolan")

NER_Person: []
NER_Movies: []
NER failed, proceeding with fuzzy matching.


2024-11-06 13:06:59,514 | INFO[0m | fuzzy_match | Found exact match with Fuzzy: the dark knight


Based on your interest in The Dark Knight, I recommend watching the following movies:
- Interstellar
- The Dark Knight Rises
- Batman Begins
Execution time for answer_query: 0.3928 seconds


In [181]:
answer_query("Where was Angelina Jolie born?", "Los Angeles")

NER_Person: ['angel']
NER_Movies: []
Based on your interest in Angelina Jolie, I recommend watching the following movies:
- Salt
- Difret
- Wanted
Execution time for answer_query: 0.2147 seconds


In [182]:
answer_query("Who is the main actor in harry potter?")

NER_Person: []
NER_Movies: []
NER failed, proceeding with fuzzy matching.


2024-11-06 13:07:00,396 | INFO[0m | fuzzy_match | Found partial match with Fuzzy: actor in law


Based on your interest in Actor In Law, I recommend watching the following movies:
- The Benchwarmers
- What About Bob?
- L'emmerdeur
Execution time for answer_query: 1.5101 seconds


In [142]:
answer_query("Who is the main actor in harry potter and the philosopher's stone?")

NER_Person: []
NER_Movies: []
NER failed, proceeding with fuzzy matching.
Found exact match with Fuzzy: harry potter and the philosophers stone
Based on your interest in Harry Potter And The Philosopher'S Stone, I recommend watching the following movies:
- The Christmas Chronicles 2
- Percy Jackson & the Olympians: The Lightning Thief
- Harry Potter and the Chamber of Secrets
Execution time for answer_query: 2.1463 seconds


In [143]:
answer_query("Who was Brad Pitt married to?")

NER_Person: []
NER_Movies: []
NER failed, proceeding with fuzzy matching.
Found exact match with Fuzzy: brad pitt
Found exact match with Fuzzy: it
Based on your interest in Brad Pitt, I recommend watching the following movies:
- Hunk
- Fury
- Troy
Execution time for answer_query: 0.3703 seconds


In [144]:
answer_query("When was Inception released?")

NER_Person: []
NER_Movies: []
NER failed, proceeding with fuzzy matching.
Found exact match with Fuzzy: ti
Found exact match with Fuzzy: inception
Based on your interest in Inception, I recommend watching the following movies:
- Batman Begins
- Dunkirk
- Interstellar
Execution time for answer_query: 0.4099 seconds


In [146]:
answer_query("Who is the director of Star Wars?")

NER_Person: []
NER_Movies: []
NER failed, proceeding with fuzzy matching.
Found exact match with Fuzzy: star
Based on your interest in Star!, I recommend watching the following movies:
- Rooftops
- Somebody Up There Likes Me
- The Sound of Music
Execution time for answer_query: 1.9087 seconds


In [147]:
answer_query("When was the Godfather III published?")

NER_Person: []
NER_Movies: []
NER failed, proceeding with fuzzy matching.
Found exact match with Fuzzy: the godfather
Based on your interest in The Godfather, I recommend watching the following movies:
- Carlito's Way
- Bram Stoker's Dracula
- Rumble Fish
Execution time for answer_query: 0.4394 seconds


In [148]:
answer_query("Who is the director of Star Wars?")

NER_Person: []
NER_Movies: []
NER failed, proceeding with fuzzy matching.
Found exact match with Fuzzy: star
Based on your interest in Star!, I recommend watching the following movies:
- Rooftops
- Somebody Up There Likes Me
- The Sound of Music
Execution time for answer_query: 1.8353 seconds


In [149]:
answer_query("When was Inception released?")

NER_Person: []
NER_Movies: []
NER failed, proceeding with fuzzy matching.
Found exact match with Fuzzy: ti
Found exact match with Fuzzy: inception
Based on your interest in Inception, I recommend watching the following movies:
- Batman Begins
- Dunkirk
- Interstellar
Execution time for answer_query: 0.4089 seconds


In [150]:
answer_query("Who was Angelina Jolie married to?")

NER_Person: []
NER_Movies: []
NER failed, proceeding with fuzzy matching.
Found exact match with Fuzzy: angelina jolie
Found exact match with Fuzzy: angel
Based on your interest in Angelina Jolie, I recommend watching the following movies:
- Salt
- Difret
- Wanted
Execution time for answer_query: 0.4122 seconds


In [151]:
answer_query("Who was Brad Pitt married to?")

NER_Person: []
NER_Movies: []
NER failed, proceeding with fuzzy matching.
Found exact match with Fuzzy: brad pitt
Found exact match with Fuzzy: it
Based on your interest in Brad Pitt, I recommend watching the following movies:
- Hunk
- Fury
- Troy
Execution time for answer_query: 0.4185 seconds


In [152]:
answer_query("What is the religion of Tom Cruise?")

NER_Person: []
NER_Movies: []
NER failed, proceeding with fuzzy matching.
Found exact match with Fuzzy: tom cruise
Based on your interest in Tom Cruise, I recommend watching the following movies:
- Taps
- Legend
- Top Gun
Execution time for answer_query: 0.5702 seconds


In [153]:
answer_query("Who is the main actor in harry potter and the philosopher's stone?")

NER_Person: []
NER_Movies: []
NER failed, proceeding with fuzzy matching.
Found exact match with Fuzzy: harry potter and the philosophers stone
Based on your interest in Harry Potter And The Philosopher'S Stone, I recommend watching the following movies:
- The Christmas Chronicles 2
- Percy Jackson & the Olympians: The Lightning Thief
- Harry Potter and the Chamber of Secrets
Execution time for answer_query: 1.6883 seconds


In [154]:
answer_query("Who are the cast in Jurassic Park?")

NER_Person: []
NER_Movies: []
NER failed, proceeding with fuzzy matching.
Found exact match with Fuzzy: jurassic park
Based on your interest in Jurassic Park, I recommend watching the following movies:
- Indiana Jones and the Kingdom of the Crystal Skull
- The Lost World: Jurassic Park
- War of the Worlds
Execution time for answer_query: 1.9254 seconds


In [155]:
answer_query("Who acted in Jurassic Park?")

NER_Person: []
NER_Movies: []
NER failed, proceeding with fuzzy matching.
Found exact match with Fuzzy: jurassic park
Based on your interest in Jurassic Park, I recommend watching the following movies:
- Indiana Jones and the Kingdom of the Crystal Skull
- The Lost World: Jurassic Park
- War of the Worlds
Execution time for answer_query: 1.5711 seconds


In [156]:
answer_query("Who played in Jurassic Park?")

NER_Person: []
NER_Movies: []
NER failed, proceeding with fuzzy matching.
Found exact match with Fuzzy: jurassic park
Based on your interest in Jurassic Park, I recommend watching the following movies:
- Indiana Jones and the Kingdom of the Crystal Skull
- The Lost World: Jurassic Park
- War of the Worlds
Execution time for answer_query: 1.7797 seconds


In [157]:
answer_query("In which movie did Tom Cruise play?")

NER_Person: []
NER_Movies: []
NER failed, proceeding with fuzzy matching.
Found exact match with Fuzzy: tom cruise
Found exact match with Fuzzy: play
Based on your interest in Tom Cruise, I recommend watching the following movies:
- Taps
- Legend
- Top Gun
Execution time for answer_query: 0.4521 seconds


In [158]:
answer_query("In which movie did Rebel Wilson act?")

NER_Person: []
NER_Movies: []
NER failed, proceeding with fuzzy matching.
Found exact match with Fuzzy: rebel wilson
Found exact match with Fuzzy: wilson
Based on your interest in Rebel Wilson, I recommend watching the following movies:
- Cats
- Grimsby
- Fat Pizza
Execution time for answer_query: 0.4057 seconds


In [159]:
answer_query("In which movie did Liam Neeson play?")

NER_Person: ['l']
NER_Movies: []
Based on your interest in Larry Sanders, I recommend watching the following movies:
- Movie 43
- Indiana Jones and the Last Crusade
Execution time for answer_query: 10.5456 seconds


In [160]:
answer_query("Who is an actor in Taken 2?")

NER_Person: []
NER_Movies: []
NER failed, proceeding with fuzzy matching.
Found exact match with Fuzzy: taken 2
Based on your interest in Taken 2, I recommend watching the following movies:
- Raid
- Taken
- Taken 3
Execution time for answer_query: 0.3981 seconds


In [161]:
answer_query("What is the role of Vin Diesel in Fast and Furious?")

NER_Person: []
NER_Movies: []
NER failed, proceeding with fuzzy matching.
Found exact match with Fuzzy: vin diesel
Found exact match with Fuzzy: fast
Based on your interest in Vin Diesel, I recommend watching the following movies:
- Strays
- Riddick
- Furious 7
Execution time for answer_query: 0.5005 seconds


In [162]:
answer_query("For which movie did Leonardo DiCaprio win an Oscar?")

NER_Person: ['le']
NER_Movies: []
Based on your interest in Helena Christensen, I recommend watching the following movies:
- Allegro
- The Reunion 3
- Prêt-à-Porter
Execution time for answer_query: 2.9037 seconds
