In [1]:
from transformers import AutoTokenizer, AutoModelWithLMHead
from dataclasses import dataclass
from pathlib import Path
import re

In [2]:
# Root for local model copies; derived from the user's home directory instead of a
# hardcoded absolute path so the notebook runs for anyone.
root_dir = Path.home()
# Checkpoint to use (the gpt2-xl download is ~6.4 GB).
pretrained_weights = 'gpt2-xl'
# Local copies are saved under <pretrained_model_dir>/<pretrained_weights>.
pretrained_model_dir = root_dir / 'models' / 'transformers'

In [3]:
def load_or_download_pretrained(cls, pretrained_model_dir: Path, pretrained_weights: str, **kw_args):
    """Load a pretrained object from a local copy, downloading and saving it on a miss.

    Args:
        cls: A class exposing ``from_pretrained`` (e.g. ``AutoTokenizer`` or
            ``AutoModelWithLMHead``). The object it returns must have ``save_pretrained``.
        pretrained_model_dir: Directory that holds local copies, one sub-directory per
            checkpoint name. A ``str`` is accepted as well as a ``Path``.
        pretrained_weights: Checkpoint name, e.g. ``'gpt2-xl'``.
        **kw_args: Forwarded to every ``from_pretrained`` call.

    Returns:
        The loaded object.

    Raises:
        Any exception from the local load other than ``OSError``/``ValueError``
        (those trigger the download fallback), and any exception raised while
        downloading or saving.
    """
    local_dir = str(Path(pretrained_model_dir) / pretrained_weights)
    try:
        ret = cls.from_pretrained(local_dir, **kw_args)
    except (OSError, ValueError):
        # No usable local copy, so fetch by name and persist it for next time.
        # This was narrowed from a bare `except:` so that KeyboardInterrupt and
        # unrelated errors propagate instead of silently starting a fresh download.
        ret = cls.from_pretrained(pretrained_weights, **kw_args)
        ret.save_pretrained(local_dir)
    return ret

In [4]:
# Tokenizer for the configured checkpoint (loaded from, or saved to, pretrained_model_dir).
tokenizer = load_or_download_pretrained(
    AutoTokenizer,
    pretrained_model_dir=pretrained_model_dir,
    pretrained_weights=pretrained_weights,
)
tokenizer

<transformers.tokenization_gpt2.GPT2Tokenizer at 0x7fa480301550>

In [5]:
# Language-model weights; the first run downloads them, later runs load the local copy.
model = load_or_download_pretrained(
    AutoModelWithLMHead,
    pretrained_model_dir=pretrained_model_dir,
    pretrained_weights=pretrained_weights,
)



HBox(children=(FloatProgress(value=0.0, description='Downloading', max=6431878936.0, style=ProgressStyle(descr…




Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2-xl and are newly initialized: ['h.0.attn.masked_bias', 'h.1.attn.masked_bias', 'h.2.attn.masked_bias', 'h.3.attn.masked_bias', 'h.4.attn.masked_bias', 'h.5.attn.masked_bias', 'h.6.attn.masked_bias', 'h.7.attn.masked_bias', 'h.8.attn.masked_bias', 'h.9.attn.masked_bias', 'h.10.attn.masked_bias', 'h.11.attn.masked_bias', 'h.12.attn.masked_bias', 'h.13.attn.masked_bias', 'h.14.attn.masked_bias', 'h.15.attn.masked_bias', 'h.16.attn.masked_bias', 'h.17.attn.masked_bias', 'h.18.attn.masked_bias', 'h.19.attn.masked_bias', 'h.20.attn.masked_bias', 'h.21.attn.masked_bias', 'h.22.attn.masked_bias', 'h.23.attn.masked_bias', 'h.24.attn.masked_bias', 'h.25.attn.masked_bias', 'h.26.attn.masked_bias', 'h.27.attn.masked_bias', 'h.28.attn.masked_bias', 'h.29.attn.masked_bias', 'h.30.attn.masked_bias', 'h.31.attn.masked_bias', 'h.32.attn.masked_bias', 'h.33.attn.masked_bias', 'h.34.attn.masked_bias', 'h.35.attn.masked

In [6]:
@dataclass
class Chatbot:
    """Prompt-based chatbot on top of a causal language model.

    The model is prompted with a two-line transcript ("Me: <question>" then
    "<name>:"). The sampled continuation is trimmed down to the bot's first turn.
    """

    # Tokenizer/model pair, e.g. as returned by load_or_download_pretrained.
    # The annotations are quoted hints only: the Auto* classes are factories, and the
    # runtime objects are concrete classes such as GPT2Tokenizer.
    tokenizer: "AutoTokenizer"
    model: "AutoModelWithLMHead"
    # Speaker label used for the bot's turn in the prompt.
    name: str

    def chat(self, question, max_len=100):
        """Sample one reply from the bot.

        Args:
            question: The user's utterance.
            max_len: Passed through as ``max_length`` to ``model.generate``. For
                decoder-only models this presumably counts the prompt tokens too;
                see the ``generate`` docs.

        Returns:
            The generated text up to the first newline or the next ``Speaker:``
            label, with non-breaking spaces replaced by spaces and surrounding
            whitespace stripped.
        """
        # Build the prompt explicitly so the method's source indentation does not
        # leak into it (as it did with a multi-line triple-quoted f-string).
        input_sentence = f"Me: {question}\n{self.name}:"

        input_ids = self.tokenizer.encode(input_sentence, return_tensors='pt')
        sample_output = self.model.generate(
            input_ids,
            max_length=max_len,
            # Use this bot's own tokenizer, not a notebook-global one.
            pad_token_id=self.tokenizer.eos_token_id,
            do_sample=True,
            top_k=50,
        )

        # Decode only the newly generated tokens. Slicing the decoded text by the
        # prompt's character length breaks whenever decode() doesn't reproduce the
        # prompt exactly.
        prompt_len = len(input_ids[0])
        result = self.tokenizer.decode(sample_output[0][prompt_len:], skip_special_tokens=True)

        # Keep only the bot's first turn: stop at a newline or the next "Speaker:" label.
        response = re.split(r'\n|\S+:', result)[0]
        response = response.replace('\xa0', ' ').strip()
        return response

In [7]:
# A bot that answers as "Joe" in the prompt transcript.
bot = Chatbot(tokenizer=tokenizer, model=model, name='Joe')

In [8]:
%%time

bot.chat('Who are you?')

CPU times: user 9min 3s, sys: 7.38 s, total: 9min 11s
Wall time: 28.8 s


'Joe? Hi!'

In [9]:
# Replies are sampled (do_sample=True), so re-running gives different text.
bot.chat(question='What does God mean to you?')

'God means love and tolerance.'

In [10]:
bot.chat(question='Why are you here?')

"I'm here on a date, okay? I just wanted to ask. So I know you don't like it when people talk about us like that, but I'm having trouble getting her to let me touch her. Joe is a man who has been in a relationship with Lela Alabaster since he was a little kid, and he'd love to touch her, to feel her through her clothes"

In [11]:
bot.chat(question='What is the purpose of life?')

'What is the purpose of life?'

In [12]:
bot.chat(question='Is science a lie?')

"No, I don't think its lie, I think that what we think is true as scientists. In order for something to be called science, it has to make sense in scientific terms. We can use what is known about the natural world to understand the world in the future. Now, what we will also make use of is in order to make sense in this world now. We will come up with"

In [13]:
bot.chat(question='When will corona virus end?')

"Well, this is where we really have an advantage over Ebola. As you well know, we have already started a big, broad effort in our field with our very high powered field labs. We've sent several teams south east of the airport to try to find what's going on in the forests. It is getting quite out of hand. They are getting quite desperate, the soldiers and the villagers are getting"