<a href="https://colab.research.google.com/github/sean-halpin/dialoGPT_Virtual_Character/blob/main/dialog_with_persona.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# from google.colab import drive
# drive.mount('/content/drive/')
# import os
# os.chdir("/content/drive/My Drive/Colab Notebooks")


In [None]:
!pip -q install transformers
!pip install flair
!pip install word2number

# Chatbot Persona 

In [4]:
persona = {
    "age inquiry": "39"
}

# Load pre-trained dialog req/res model

In [None]:
from transformers import AutoModelWithLMHead, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
# Standard DialoGPT
model = AutoModelWithLMHead.from_pretrained("microsoft/DialoGPT-medium")
# Transfer Learned DialoGPT
# model = AutoModelWithLMHead.from_pretrained("output-small")

# Load a pre-trained zero-shot classifier  

In [None]:
from transformers import pipeline
classifier = pipeline("zero-shot-classification",
                      model="facebook/bart-large-mnli")

In [7]:
candidate_labels = ['age inquiry', 'job inquiry', 'name inquiry', 'statement', 'personal detail']

In [8]:
def classify(sequence_to_classify):
  class_result = classifier(sequence_to_classify, candidate_labels)
  return {
      'label':class_result['labels'][0], 
      'score':class_result['scores'][0]
      }

# Load pre-trained NER (Named Entity Recognition) Model

In [None]:
from flair.data import Sentence
from flair.models import SequenceTagger
tagger = SequenceTagger.load("flair/ner-english-ontonotes-fast")

In [10]:
from word2number import w2n

In [11]:
def tokens_to_list(tokens):
  tkns = []
  for t in tokens:
    tkns.append(t.text)
  return tkns

In [17]:

def replace_personal_info(response_str, personal_info_type, persona):
  # make example sentence
  sentence = Sentence(response_str)
  list_tkns = tokens_to_list(sentence)
  # predict NER tags
  tagger.predict(sentence)
  # iterate over entities and print
  for entity in sentence.get_spans('ner'):
    if personal_info_type == "age inquiry":
      # print(entity.labels)
      # print(entity.id_text)
      # print(entity.to_dict())
      # print(entity.tokens)
      for t in entity.tokens:
        for l in t.get_labels():
          print(l)
          if "DATE" in l.value or "CARDINAL" in l.value:
            try:
              if t.text.isnumeric() or w2n.word_to_num(t.text).isnumeric():
                # print(t.text)
                # print(t.idx)
                list_tkns[t.idx - 1] = persona['age inquiry']
                return list_tkns
            except Exception as e:
              print("Error: {} - {}".format(e, t))
  return list_tkns

# Test Replacement of Personal Token

In [18]:
print(" ".join(replace_personal_info("I am 88 years old", "age inquiry", persona)))
print(" ".join(replace_personal_info("I'm 19, what should I know?", "age inquiry", persona)))
print(" ".join(replace_personal_info("20 , you ??", "age inquiry", persona)))

B-DATE (0.7147)
I am 39 years old
S-DATE (0.991)
I 'm 39 , what should I know ?
S-CARDINAL (0.5947)
39 , you ? ?


# Chat

In [None]:
# Let's chat for 5 lines
for step in range(5):
    # encode the new user input, add the eos_token and return a tensor in Pytorch
    usr_input = input(">> User:")
    classification = classify(usr_input)
    new_user_input_ids = tokenizer.encode(usr_input + tokenizer.eos_token, return_tensors='pt')

    # append the new user input tokens to the chat history
    bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids

    num_return_seqs=1

    # generated a response while limiting the total chat history to 1000 tokens, 
    chat_history_ids = model.generate(
        bot_input_ids, 
        max_length=200,
        pad_token_id=tokenizer.eos_token_id,  
        no_repeat_ngram_size=3,       
        do_sample=True, 
        top_k=100, 
        top_p=0.7,
        temperature = 0.7,
        num_return_sequences=num_return_seqs
    )
    
    # pretty print last ouput tokens from bot
    botname = "Bot"
    for i in range(0,num_return_seqs):
      bot_output = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][i])
      if classification['label'] in persona:
        try:
          print("\t{}\n\t{}\n\t{}".format(bot_output, classification['label'], persona))
          bot_output = " ".join(replace_personal_info(bot_output, classification['label'], persona))
        except Exception as e:
          print(e)
      print("{}: {}".format(botname, bot_output, skip_special_tokens=True))

    # Reset Chat History
    # chat_history_ids = chat_history_ids[0].unsqueeze(0)

>> User:What age are you?
	I'm 21<|endoftext|>
	age inquiry
	{'age inquiry': '39'}
S-DATE (0.9433)
Bot: I 'm 39 <| endoftext |>
