This notebook follows along with the "Comparing open LLMs" example from the Mozilla AI Guide: https://ai-guide.future.mozilla.org/content/comparing-open-llms/

In [None]:
import torch
from transformers import pipeline, set_seed, AutoModelForQuestionAnswering, AutoTokenizer, AutoModelForSeq2SeqLM

In [None]:
class ChatBot:
    """Stateful multi-turn chat wrapper around a Hugging Face seq2seq model.

    Each call sends the whole conversation so far plus the new message to the
    model. It then records both the message and the reply in
    ``self.conversation``.

    Attributes:
        model_name: Hub id or local path the tokenizer and model came from.
        task: Pipeline task name; only used if ``pipeline`` is accessed.
        device: Device the model and encoded inputs are placed on.
        tokenizer: Tokenizer for ``model_name``, configured to truncate from
            the left.
        model: The loaded ``AutoModelForSeq2SeqLM`` (on ``device``).
        conversation: Alternating user/bot turns, oldest first.
    """

    # Placed between turns when the history is flattened into one input string.
    TURN_SEPARATOR = "</s> <s>"

    def __init__(self, model: str, task: str = 'conversational', device: str = 'mps') -> None:
        """Load the tokenizer and model.

        Args:
            model: Hugging Face Hub id or local path of a seq2seq chat model.
            task: Pipeline task. Kept for backward compatibility; a pipeline is
                only built (lazily) if ``self.pipeline`` is read.
            device: Torch device for inference, e.g. ``'mps'``, ``'cuda'``
                or ``'cpu'``.
        """
        # Keep the checkpoint name separately; `self.model` holds the loaded model.
        self.model_name = model
        self.task = task
        self.device = device
        # Truncate from the left: once the history exceeds the model's maximum
        # input length, the oldest turns are dropped instead of the newest message.
        self.tokenizer = AutoTokenizer.from_pretrained(model, truncation_side="left")
        # Put the model that actually generates replies on the requested device.
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model).to(device)
        self.conversation = []
        self._pipeline = None

    @property
    def pipeline(self):
        """Return a ``transformers`` pipeline for ``self.task``, built on first access.

        Chatting does not need it, and building it eagerly would load the
        weights a second time. The ``'conversational'`` task has also been
        removed from recent ``transformers`` releases, so constructing it
        unconditionally can make ``__init__`` fail. The pipeline reuses the
        already-loaded model and tokenizer.
        """
        if self._pipeline is None:
            # `pipeline` here resolves to the module-level transformers.pipeline;
            # class-level names are not visible inside method bodies.
            self._pipeline = pipeline(
                self.task,
                model=self.model,
                tokenizer=self.tokenizer,
                device=self.device,
            )
        return self._pipeline

    def __call__(self, text: str) -> str:
        """Send ``text`` to the bot and return its reply.

        The existing conversation plus ``text`` is joined with
        ``TURN_SEPARATOR``. It is encoded with left truncation to the
        tokenizer's maximum length and moved to ``self.device``. ``text`` and
        the reply are appended to ``self.conversation`` only after generation
        succeeds, so a failed call leaves the history unchanged.
        """
        history = [*self.conversation, text]
        inputs = self.tokenizer(
            [self.TURN_SEPARATOR.join(history)],
            return_tensors="pt",
            truncation=True,
        ).to(self.device)
        reply_ids = self.model.generate(**inputs)
        response = self.tokenizer.batch_decode(reply_ids, skip_special_tokens=True)[0]
        self.conversation.extend((text, response))
        return response

In [None]:
# Start a chat session backed by the facebook/blenderbot-400M-distill checkpoint.
chat = ChatBot(model="facebook/blenderbot-400M-distill")

In [None]:
# Ask the bot to review its own implementation. The prompt is an instruction
# followed by a fenced copy of the ChatBot source.
# NOTE(review): in this embedded copy the methods sit at the same indentation
# as `class ChatBot:`, so it does not match the real class definition.
# NOTE(review): this prompt is long. Small chat models such as
# blenderbot-400M-distill are believed to accept only a short input (the
# tokenizer's model_max_length, ~128 tokens — confirm). Most of the prompt may
# therefore be truncated or rejected, depending on how ChatBot encodes it.
chat("""I wrote this class to have our conversation, do you have any ideas for improvements?
    ```python
    class ChatBot:

    def __init__(self, model: str, task: str='conversational', device: str='mps') -> None:
        self.model = model
        self.pipeline = pipeline(task, model=model, device=device)
        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model)
        self.conversation = []
    
    def __call__(self, text: str) -> str:
        self.conversation.append(text)
        inputs = self.tokenizer(["</s> <s>".join(self.conversation)], return_tensors="pt")
        reply_ids = self.model.generate(**inputs)
        response = self.tokenizer.batch_decode(reply_ids, skip_special_tokens=True)[0]
        self.conversation.append(response)
        return response
     ```
     """)
# Show the accumulated history of alternating user/bot turns.
# `display` is not imported here; presumably the IPython/Jupyter builtin.
display(chat.conversation)