In [17]:
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
from transformers import logging as transformers_logging
transformers_logging.set_verbosity_error()


class NERParser:
    """
    Wrapper around a token-classification ("ner") pipeline that splits the
    recognised entities into persons ('PER') and miscellaneous ('MISC') names.
    """

    def __init__(self, model_name: str = "dslim/bert-base-NER", lowercase: bool = False):
        """
        Load the tokenizer and model for `model_name` and build the NER pipeline.

        model_name: identifier passed to the `from_pretrained` loaders.
        lowercase:  when True, `process_query` lowercases the text before NER.
        """
        self.model_name = model_name
        self.lowercase = lowercase
        # Device is chosen once at construction time and handed to the pipeline.
        self.device = self.get_device()

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForTokenClassification.from_pretrained(self.model_name)

        # Collect the pipeline options in one place before building it.
        pipeline_options = {
            "model": self.model,
            "tokenizer": self.tokenizer,
            "device": self.device,
            "aggregation_strategy": "simple",
        }
        self.nlp_pipeline = pipeline("ner", **pipeline_options)

    def get_device(self):
        """
        Pick the torch device to run on and print which backend was chosen.

        MPS is checked first, then CUDA; CPU is the fallback.
        """
        if torch.backends.mps.is_available():
            print("MPS device found, using MPS backend.\n")
            return torch.device("mps")

        if torch.cuda.is_available():
            print(f"CUDA device found, using CUDA backend. Device: {torch.cuda.get_device_name(0)}\n")
            return torch.device("cuda")

        print("Neither MPS nor CUDA found, using CPU.\n")
        return torch.device("cpu")

    def parse_ner_results(self, ner_results: list):
        """
        Split NER results into person names and miscellaneous names.

        Each item is read through its 'entity_group' and 'word' keys. Items
        labelled 'PER' go to the first list, items labelled 'MISC' (potential
        movie titles) go to the second; every other label is ignored.

        Returns a `(per_entities, misc_entities)` tuple of lists, in input order.
        """
        # One bucket per label of interest; unknown labels have no bucket.
        buckets = {'PER': [], 'MISC': []}

        for item in ner_results:
            target = buckets.get(item['entity_group'])
            if target is not None:
                target.append(item['word'])

        return buckets['PER'], buckets['MISC']

    def process_query(self, query: str):
        """
        Run NER on `query` and return `(per_entities, misc_entities)`.

        The text is lowercased first when the parser was created with
        `lowercase=True`.
        """
        text = query.lower() if self.lowercase else query
        return self.parse_ner_results(self.nlp_pipeline(text))


##################
### Example usage
##################
# Build the parser once; loading the tokenizer and model happens in the constructor.
ner_parser = NERParser(lowercase=False)

# Sample questions used to demonstrate the parser.
query_one = "Did Kate Winslet and Leonardo Di Caprio play in Titanic?"
query_two = "I like Steven Spielberg, can you recommend me similar directors?"
query_three = "Who played in the movie Inception?"
query_four = "Produced by Steven Spielberg, Kate Winslet and Angelina Jolie played in Titanic is this correct?"

# Every query is passed by its own name, so the demo does not depend on any
# variable left over from a previously executed cell.
for example_query in (query_one, query_two, query_three, query_four):
    actors, movies = ner_parser.process_query(example_query)
    print("Actors:", actors)
    print("Movies:", movies)
    print("\n")

MPS device found, using MPS backend.

Actors: ['Kate Winslet', 'Leonardo Di Caprio']
Movies: ['Titanic']


Actors: ['Steven Spielberg']
Movies: []


Actors: []
Movies: ['Inception']


Actors: ['Steven Spielberg', 'Kate Winslet', 'Angelina Jolie']
Movies: ['Titanic']


