In [3]:
import transformers 
from transformers import pipeline
import numpy as np
import pandas as pd
import tensorflow as tf
import pickle
import torch

In [5]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
device

'cuda:0'

In [9]:
named_entity_recognition_pipeline = pipeline("ner", aggregation_strategy="simple", model="Jean-Baptiste/roberta-large-ner-english", device=device)

In [23]:
zero_shot_classifier_pipeline = pipeline("zero-shot-classification", model="facebook/bart-large-mnli", device=device)

In [113]:
subject = "none"
def nlp_get_subject(text_input):
    global subject
    # first get all named entities from the text input
    entities = named_entity_recognition_pipeline(text_input)
    # get all location entities
    location_entities = [entity for entity in entities if entity['entity_group']=='LOC']
    subject = location_entities[0]['word'] if len(location_entities)!=0 else subject # replace subject is there is a new location entity
    return subject, entities

def nlp_get_major_category(text_input):
    # major categories, text info, videos, images, tweets, 
    major_categories = ["videos", "images", "tweets", "info"]
    classes = zero_shot_classifier_pipeline(text_input, candidate_labels=major_categories)
    major_category = classes['labels'][0]
    major_category_score = classes['scores'][0]
    return major_category, major_category_score
    
def nlp_get_info_category_strat_multi_point_category(text_input):
    history_categories = ["history"]
    tourist_categories = ["tourist", "eat"]
    general_info_categories = ["general"]
    info_categories = history_categories + tourist_categories + general_info_categories
    classes = zero_shot_classifier_pipeline(text_input, candidate_labels=info_categories)
    major_category = classes['labels'][0]
    major_category_score = classes['scores'][0]
    average_distance = sum([major_category_score - i for i in classes['scores'][1:]]) / (len(classes['scores']))
    if major_category in history_categories:
        major_category = "history"
    if major_category in tourist_categories:
        major_category = "tourist"
    if major_category in general_info_categories:
        major_category = "general"
    return major_category, major_category_score
    
def nlp_get_info_category_strat_simple(text_input):
    classes = ['history', 'tourist', 'general']    
    classes = zero_shot_classifier_pipeline(text_input, candidate_labels=classes)
    major_category = classes['labels'][0]
    major_category_score = classes['scores'][0]
    return major_category, major_category_score
    
def nlp_pipeline(text_input):
    global subject
    # first get all named entities from the text input
    subject, entities = nlp_get_subject(text_input)


    # next, we need to figure out what the user is asking for
    major_category, major_category_score = nlp_get_major_category(text_input)

    # next, if we have a request for info, we need to break down what info they are asking for,   
    if major_category == "info":
        major_category, major_category_score = nlp_get_info_category_strat_multi_point_category(text_input)
        

    print((subject, major_category, major_category_score, text_input))

In [114]:

nlp_pipeline("show me pics of downtown Dallas.")
nlp_pipeline("show me Dallas videos.")
print()

# info tests
# general info questions
print("general")
nlp_pipeline("tell me about Dallas.")
nlp_pipeline("general info about Dallas.")
nlp_pipeline("what is Dallas.")
# history questions
print("\nhistory")
nlp_pipeline("what is the history of Dallas.")
nlp_pipeline("when was Dallas Founded.")
nlp_pipeline("tell me about Dallas's past.")
nlp_pipeline("When was the last slave Freed in Texas?")
# tourist questions
print("\ntourist")
nlp_pipeline("Things to do in Dallas")
nlp_pipeline("site seeing Dallas")
nlp_pipeline("Places to go Dallas")
nlp_pipeline("Places to eat Dallas")
nlp_pipeline("Dallas nightlife")
nlp_pipeline("5 start hotels Dallas")
nlp_pipeline("Dallas housing")

(' Dallas', 'images', 0.8798865675926208, 'show me pics of downtown Dallas.')
(' Dallas', 'videos', 0.8130041360855103, 'show me Dallas videos.')

general
(' Dallas', 'general', 0.402066171169281, 'tell me about Dallas.')
(' Dallas', 'general', 0.976294994354248, 'general info about Dallas.')
(' Dallas', 'general', 0.6452187895774841, 'what is Dallas.')

history
(' Dallas', 'history', 0.871361494064331, 'what is the history of Dallas.')
(' Dallas', 'history', 0.6147159934043884, 'when was Dallas Founded.')
(' Dallas', 'history', 0.8361102342605591, "tell me about Dallas's past.")
(' Texas', 'history', 0.5355569124221802, 'When was the last slave Freed in Texas?')

tourist
(' Dallas', 'tourist', 0.43379834294319153, 'Things to do in Dallas')
(' Dallas', 'tourist', 0.9219280481338501, 'site seeing Dallas')
(' Dallas', 'tourist', 0.6388269662857056, 'Places to go Dallas')
(' Dallas', 'tourist', 0.8585844039916992, 'Places to eat Dallas')
(' Dallas', 'tourist', 0.39705610275268555, 'Dallas