In [1]:
from transformers import AutoTokenizer, TFAutoModelWithLMHead
from dataclasses import dataclass
from pathlib import Path
import re

In [2]:
# Storage layout for pretrained files: <root_dir>/models/transformers/<pretrained_weights>.
# NOTE(review): root_dir is a machine-specific absolute path; adjust it when
# running this notebook on another machine.
pretrained_weights = 'gpt2-xl'  # checkpoint identifier passed to from_pretrained
root_dir = Path('/home/yyang/')
pretrained_model_dir = root_dir.joinpath('models', 'transformers')

In [3]:
def load_or_download_pretrained(cls, pretrained_model_dir: Path, pretrained_weights: str, **kw_args):
    """Load a pretrained object from a local directory, downloading it on a miss.

    The local copy is looked up in ``pretrained_model_dir / pretrained_weights``.
    If that directory is missing, or loading from it fails, the object is
    obtained with ``cls.from_pretrained(pretrained_weights)`` and then written
    to that directory with ``save_pretrained`` so the next call can load it
    locally.

    Args:
        cls: Any class exposing ``from_pretrained`` (whose result exposes
            ``save_pretrained``), e.g. ``AutoTokenizer``.
        pretrained_model_dir: Base directory for local copies (``Path`` or ``str``).
        pretrained_weights: Checkpoint identifier; also used as the
            sub-directory name.
        **kw_args: Forwarded unchanged to both ``from_pretrained`` calls.

    Returns:
        The object returned by ``cls.from_pretrained``.
    """
    cache_dir = Path(pretrained_model_dir) / pretrained_weights

    if cache_dir.is_dir():
        try:
            return cls.from_pretrained(str(cache_dir), **kw_args)
        except Exception:
            # The directory exists but does not (yet) hold usable files for
            # this class, e.g. only the tokenizer was saved so far. Fall
            # through to the download below. `Exception` rather than a bare
            # `except` so Ctrl-C / SystemExit during a slow local load is not
            # mistaken for a cache miss.
            pass

    ret = cls.from_pretrained(pretrained_weights, **kw_args)
    # Make sure the target exists; not every save_pretrained creates it.
    cache_dir.mkdir(parents=True, exist_ok=True)
    ret.save_pretrained(str(cache_dir))
    return ret

In [4]:
# Load the 'gpt2-xl' tokenizer from the local model directory; on a miss the
# helper downloads it and saves a copy there.
tokenizer = load_or_download_pretrained(AutoTokenizer, pretrained_model_dir, pretrained_weights)
tokenizer  # bare last expression: display the loaded tokenizer object

<transformers.tokenization_gpt2.GPT2Tokenizer at 0x7fe1ef05a9d0>

In [5]:
# Load the LM-head model the same way. It uses the same
# pretrained_model_dir / pretrained_weights directory as the tokenizer.
model = load_or_download_pretrained(TFAutoModelWithLMHead, pretrained_model_dir, pretrained_weights)

All model checkpoint weights were used when initializing TFGPT2LMHeadModel.

All the weights of TFGPT2LMHeadModel were initialized from the model checkpoint at /home/yyang/models/transformers/gpt2-xl.
If your task is similar to the task the model of the ckeckpoint was trained on, you can already use TFGPT2LMHeadModel for predictions without further training.


In [6]:
@dataclass
class TFChatbot:
    """Single-turn chatbot built on a tokenizer and a text-generating model.

    Each question is wrapped in a two-line dialogue prompt
    (``Me: <question>`` / ``<name>:``) and the model's sampled continuation
    of the second line is returned as the reply.
    """
    # Encodes the prompt and decodes the generated token ids.
    tokenizer: AutoTokenizer
    # Model whose ``generate`` method produces the continuation.
    model: TFAutoModelWithLMHead
    # Speaker label the model is asked to continue as.
    name : str

    def chat(self, question, max_len=100):
        """Return one sampled reply to ``question``.

        Args:
            question: Text placed after ``Me: `` in the prompt.
            max_len: Forwarded to ``generate`` as ``max_length``.
                NOTE(review): in transformers this limit presumably counts
                the prompt tokens too, so a long question leaves less room
                for the reply. Confirm against the installed version.

        Returns:
            The generated text up to the first newline or ``word:`` style
            speaker label, with non-breaking spaces replaced by ordinary
            spaces and surrounding whitespace removed. May be empty.
        """
        # Built on a single line so the source indentation does not leak
        # into the prompt text.
        prompt = f"Me: {question}\n{self.name}:"

        input_ids = self.tokenizer.encode(prompt, return_tensors='tf')
        sample_output = self.model.generate(
            input_ids,
            max_length=max_len,
            # Use this instance's tokenizer, not a notebook-level global.
            pad_token_id=self.tokenizer.eos_token_id,
            do_sample=True,
            top_k=50,
        )

        # Drop the prompt by token count rather than by character offset:
        # decoding may not reproduce the prompt text character-for-character.
        # Assumes the generated sequence starts with the prompt ids, the same
        # assumption the character-offset approach relied on.
        prompt_len = input_ids.shape[-1]
        generated = self.tokenizer.decode(
            sample_output[0][prompt_len:], skip_special_tokens=True
        )

        # Keep only the first turn: stop at a line break or the next speaker label.
        first_turn = re.split(r'\n|\S+:', generated.lstrip())[0]
        return first_turn.replace('\xa0', ' ').strip()

In [7]:
# `name` is the speaker label that chat() puts at the end of the prompt for the model to continue.
bot = TFChatbot(tokenizer, model, name='Mohammed')

In [13]:
%%time

# chat() samples (do_sample=True), so re-running this cell generally gives a different reply.
bot.chat('Who are you?')

CPU times: user 13.2 s, sys: 51.1 ms, total: 13.3 s
Wall time: 13.3 s


"I'm a person from America that wants to visit India."

In [14]:
# The bot's name is part of the prompt text, so the model sees it when answering.
bot.chat('What does God mean to you?')

'The highest value is God. For instance the best man is his closest and closest friend. It is the highest goal.'

In [15]:
# Each call is independent: the prompt is built from this question only, with no earlier turns.
bot.chat('Why are you here?')

'I want to talk to you about the Quran and how to handle Muslims.'

In [None]:
# The reply is cut at the first newline or `word:` label, so a single line comes back.
bot.chat('What is the purpose of life?')

In [None]:
# No max_len is given, so generate() receives the default max_length of 100.
bot.chat('Is science a lie?')

In [None]:
# No random seed is set in the cells shown here, so this sampled reply is not reproducible.
bot.chat('When will corona virus end?')