<a href="https://colab.research.google.com/github/sean-halpin/dialoGPT_Virtual_Character/blob/main/dialog_with_persona.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip -q install transformers
!pip install flair==0.10
!pip install word2number

# Chatbot Persona 

In [None]:
# Personal facts the bot should state about itself. Each key is one of the
# classifier labels; the value is the text substituted into the bot's reply.
persona = {"age inquiry": "39"}

# Load the pre-trained dialogue (request/response) model

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Hugging Face Hub id of the dialogue checkpoint. Tokenizer and model must be
# loaded from the same id, so it is named once here.
DIALOG_MODEL_NAME = "shalpin87/dialoGPT-homer-simpson"

tokenizer = AutoTokenizer.from_pretrained(DIALOG_MODEL_NAME)
# AutoModelWithLMHead is deprecated in transformers; AutoModelForCausalLM is
# its replacement for causal (left-to-right) language models.
# NOTE(review): the checkpoint name suggests a DialoGPT (GPT-2 style) model,
# which is a causal LM -- confirm against the model card.
model = AutoModelForCausalLM.from_pretrained(DIALOG_MODEL_NAME)

# Load a pre-trained zero-shot classifier  

In [None]:
from transformers import pipeline
# Zero-shot classification pipeline: scores a text against an arbitrary list of
# candidate labels supplied at call time.
classifier = pipeline(
    task="zero-shot-classification",
    model="facebook/bart-large-mnli",
)

In [None]:
# Labels the zero-shot classifier chooses between when tagging a user message.
candidate_labels = [
    'age inquiry',
    'job inquiry',
    'name inquiry',
    'statement',
    'personal detail',
]

In [None]:
def classify(sequence_to_classify):
  """Classify a text against the module-level `candidate_labels`.

  Args:
    sequence_to_classify: the text handed to the zero-shot `classifier`.

  Returns:
    A dict with two keys: 'label' and 'score', holding the first entry of the
    classifier's 'labels' and 'scores' lists respectively.
    NOTE(review): this presumes the pipeline orders its results best-first --
    confirm against the transformers documentation.
  """
  prediction = classifier(sequence_to_classify, candidate_labels)
  top_label = prediction['labels'][0]
  top_score = prediction['scores'][0]
  return dict(label=top_label, score=top_score)

# Load pre-trained NER (Named Entity Recognition) Model

In [None]:
from flair.data import Sentence
from flair.models import SequenceTagger
# Sequence tagger used for named-entity recognition on the bot's replies.
tagger = SequenceTagger.load(
    "flair/ner-english-ontonotes-fast"
)

In [None]:
from word2number import w2n

In [None]:
def tokens_to_list(tokens):
  """Collect the `.text` of every item in `tokens` into a new list.

  Args:
    tokens: any iterable whose items expose a `text` attribute.

  Returns:
    A list with one text per item, in iteration order.
  """
  return [token.text for token in tokens]

In [None]:

def _is_number_token(text):
  """Tell whether `text` is a digit string or a parseable English number word."""
  if text.isnumeric():
    return True
  try:
    # Only success/failure matters here. The parsed value itself may be 0
    # (e.g. "zero"), so it must not be used as a truth value.
    w2n.word_to_num(text)
  except ValueError:
    # Expected outcome for ordinary words: word_to_num rejects text that holds
    # no number words, so this is not worth reporting.
    return False
  except Exception as e:
    # Anything else is unexpected: report it, but stay best-effort.
    print("Error: {} - {}".format(e, text.isnumeric()))
    return False
  return True


def replace_personal_info(response_str, personal_info_type, persona):
  """Swap the number in a generated reply for the persona's own value.

  Only the "age inquiry" type is handled. The reply is tokenised, NER-tagged,
  and the first numeric token (digits or a number word) that carries a DATE or
  CARDINAL tag is replaced by `persona['age inquiry']`.

  Args:
    response_str: the reply text to rewrite.
    personal_info_type: the kind of personal detail being asked about; any
      value other than "age inquiry" leaves the reply untouched.
    persona: mapping of personal-info type to replacement text.

  Returns:
    The reply as a list of token texts, with at most one token replaced.
    The tokens are returned unchanged when the type is not handled, when
    `persona` has no 'age inquiry' entry, or when no numeric token is found.
  """
  sentence = Sentence(response_str)
  list_tkns = tokens_to_list(sentence)

  # Nothing can be replaced in these cases, so skip the NER pass entirely.
  if personal_info_type != "age inquiry" or "age inquiry" not in persona:
    return list_tkns

  # predict NER tags
  tagger.predict(sentence)

  for entity in sentence.get_spans('ner'):
    for t in entity.tokens:
      is_quantity = any(
          "DATE" in l.value or "CARDINAL" in l.value for l in t.labels
      )
      if is_quantity and _is_number_token(t.text):
        # Token positions are treated as 1-based here, hence the -1.
        # NOTE(review): confirm against flair's Token.idx for the pinned version.
        list_tkns[t.idx - 1] = persona['age inquiry']
        return list_tkns
  return list_tkns

# Test Replacement of Personal Token

In [None]:
# print(" ".join(replace_personal_info("I am 88 years old", "age inquiry", persona)))
# print(" ".join(replace_personal_info("I'm 19, what should I know?", "age inquiry", persona)))
# print(" ".join(replace_personal_info("20 , you ??", "age inquiry", persona)))
# print(" ".join(replace_personal_info("I'm thirty-one.", "age inquiry", persona)))

# Chat

In [None]:
# Scripted prompts used to exercise the bot, ending with two age questions
# that should trigger the persona replacement.
questions = [
    "What is your name?",
    "Who are you?",
    "Where do you work?",
    "Who really killed Mr Burns?",
    "Have you ever stolen from the Kwik-E-Mart?",
    "Who was the worst member of the Be Sharps?",
    "Hey where did Barney Gumble go?",
    "What is your favorite bar to have a beer?",
    "What is the best beer in Springfield?",
    "Is Bart working for the Mob?",
    "I think there was an incident in sector 7 G",
    "Is Ned Flanders house okay?",
    "Oh my god it's Sideshow Bob",
    "What is a Flaming Moe?",
    "What is happening to Apu?",
    "Who quit the band?",
    "What age are you?",
    "How old are you?"
]

botname = "HomerBot"
# Number of candidate replies sampled (and printed) per question.
num_return_seqs = 1

# Let's chat. Every question is answered on its own: the prompt is only the
# current question, no earlier turns are fed back to the model.
for model_input in questions:
    print("***************************************")
    print("Q. {}".format(model_input))
    classification = classify(model_input)
    new_user_input_ids = tokenizer.encode(model_input + tokenizer.eos_token, return_tensors='pt')

    bot_input_ids = new_user_input_ids

    chat_history_ids = model.generate(
        bot_input_ids,
        max_length=200,
        pad_token_id=tokenizer.eos_token_id,
        no_repeat_ngram_size=3,
        do_sample=True,
        top_k=200,
        top_p=0.55,
        temperature=0.85,
        num_return_sequences=num_return_seqs
    )

    # generate() returns prompt + reply; keep only what follows the prompt.
    prompt_length = bot_input_ids.shape[-1]
    for i in range(num_return_seqs):
        # skip_special_tokens belongs to decode(): it drops the end-of-text
        # markers, so no manual string stripping is needed afterwards.
        bot_output = tokenizer.decode(chat_history_ids[i, prompt_length:], skip_special_tokens=True)
        if classification['label'] in persona:
            try:
                bot_output = " ".join(replace_personal_info(bot_output, classification['label'], persona))
            except Exception as e:
                # Best effort: fall back to the unmodified reply.
                print(e)
        print("{}: {}".format(botname, bot_output))

# Chat with User

In [None]:
# Five independent exchanges typed by the user. Each prompt holds only the
# current message, so nothing is remembered from one turn to the next.
sampling_options = dict(
    max_length=200,
    pad_token_id=tokenizer.eos_token_id,
    no_repeat_ngram_size=3,
    do_sample=True,
    top_k=200,
    top_p=0.55,
    temperature=0.85,
)

for step in range(5):
    user_text = input(">> User:")
    new_user_input_ids = tokenizer.encode(user_text + tokenizer.eos_token, return_tensors='pt')
    bot_input_ids = new_user_input_ids
    prompt_length = bot_input_ids.shape[-1]

    chat_history_ids = model.generate(bot_input_ids, **sampling_options)

    # The generated sequence starts with the prompt; print only the reply part.
    reply_ids = chat_history_ids[:, prompt_length:][0]
    reply_text = tokenizer.decode(reply_ids, skip_special_tokens=True)
    print("HomerBot: {}".format(reply_text))
    chat_history_ids = []