## A simple nanogpt chatbot example
Built with minor changes to nanoGPT's `sample.py` script.

In [None]:
# if run on colab, first install requirements
# Clone nanoGPT so that its `model.py` (GPT, GPTConfig) can be imported in the next cell.
# NOTE(review): re-running this cell makes `git clone` report that the directory
# already exists; that error is harmless.
!git clone https://github.com/karpathy/nanoGPT.git
# tiktoken supplies the GPT-2 BPE tokenizer used below; transformers is presumably
# needed by GPT.from_pretrained to fetch the OpenAI GPT-2 weights -- confirm in model.py.
!pip install --quiet tiktoken transformers
# A PyTorch 2.x build is needed for torch.compile (see the `compile` setting).
# NOTE(review): this pins a nightly CUDA 11.7 wheel index; nightly indexes may be
# retired over time -- confirm it still resolves, or install a stable torch>=2.0.
!pip install --quiet --pre torch --force-reinstall --index-url https://download.pytorch.org/whl/nightly/cu117

In [None]:
import os
import pickle
from contextlib import nullcontext
import torch
import tiktoken

# The rest of the notebook must run from inside the cloned repo: `model` is
# imported from it, and relative paths ('out', 'data/...') are resolved against it.
# Skip the chdir when we are already there, so re-running this cell does not
# raise FileNotFoundError.
if os.path.basename(os.getcwd()) != 'nanoGPT':
    os.chdir('./nanoGPT')
from model import GPTConfig, GPT

## Settings

The default model is 'gpt2-medium' (350M params). You can change it to 'gpt2' (124M params), 'gpt2-large' (774M params) or 'gpt2-xl' (1558M params). The 'gpt2-xl' model exceeds the RAM limit of free Google Colab accounts.

`max_new_tokens` is set to 16. This makes the chat faster, at the cost of shorter replies.

`temperature` is lowered to 0.3. Higher values make the chatbot sound erratic; feel free to try them.

In [None]:
# Weights to load: 'resume' reads <out_dir>/ckpt.pt; any other value is treated
# as a GPT-2 variant name ('gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl').
init_from = 'gpt2-medium'
out_dir = 'out'          # only consulted when init_from == 'resume'

# Sampling knobs
max_new_tokens = 16      # tokens generated for each reply
temperature = 0.3        # < 1.0 sharpens the distribution, > 1.0 flattens it
top_k = 50               # sample only among the 50 most likely next tokens
seed = 1337              # RNG seed, for reproducible replies

# Runtime
device = 'cpu'           # e.g. 'cpu', 'cuda', 'cuda:0', 'cuda:1'
dtype = 'bfloat16'       # 'float32', 'bfloat16' or 'float16'; used for autocast on CUDA only
compile = True           # wrap the model with torch.compile (PyTorch 2.0+)

## Load model
This part is essentially unchanged from the original `sample.py`.

In [None]:
# Seed both the CPU and CUDA generators so sampling is reproducible.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Allow TF32 math in matmul and cuDNN kernels.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# torch.autocast takes a bare device type, so collapse e.g. 'cuda:1' to 'cuda'.
if 'cuda' in device:
    device_type = 'cuda'
else:
    device_type = 'cpu'

# Map the dtype setting to its torch dtype (an unknown name raises KeyError).
ptdtype = dict(float32=torch.float32, bfloat16=torch.bfloat16, float16=torch.float16)[dtype]

# Autocast is used only on non-CPU devices; on CPU the context is a no-op.
if device_type == 'cpu':
    ctx = nullcontext()
else:
    ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# model
if init_from == 'resume':
    # init from a checkpoint saved in out_dir/ckpt.pt
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
    checkpoint = torch.load(ckpt_path, map_location=device)
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    # torch.compile wraps the module as `_orig_mod`, so a state_dict saved from a
    # compiled model carries this prefix; strip it to match the plain GPT module.
    unwanted_prefix = '_orig_mod.'
    for k in list(state_dict):  # copy the keys: the dict is mutated below
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
elif init_from.startswith('gpt2'):
    # init from a given GPT-2 model, with dropout disabled for inference
    model = GPT.from_pretrained(init_from, dict(dropout=0.0))
else:
    # Without this, `model` would be undefined and the next line would fail
    # with an unhelpful NameError.
    raise ValueError(
        f"unknown init_from {init_from!r}: expected 'resume' or a GPT-2 "
        "variant such as 'gpt2', 'gpt2-medium', 'gpt2-large' or 'gpt2-xl'"
    )

model.eval()
model.to(device)
if compile:
    model = torch.compile(model)  # requires PyTorch 2.0 (optional)

# Look for the meta pickle in case it is available in the dataset folder: a
# checkpoint trained on such a dataset uses its own per-character stoi/itos vocab.
load_meta = False
if init_from == 'resume' and 'config' in checkpoint and 'dataset' in checkpoint['config']:  # older checkpoints might not have these...
    meta_path = os.path.join('data', checkpoint['config']['dataset'], 'meta.pkl')
    load_meta = os.path.exists(meta_path)
if load_meta:
    print(f"Loading meta from {meta_path}...")
    # NOTE: pickle.load can execute arbitrary code; only use meta.pkl files you trust.
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    # TODO want to make this more general to arbitrary encoder/decoder schemes
    stoi, itos = meta['stoi'], meta['itos']

    def encode(s):
        """Map each character of `s` to its id in the dataset vocabulary."""
        return [stoi[c] for c in s]

    def decode(l):
        """Map a sequence of ids back to the concatenated characters."""
        return ''.join([itos[i] for i in l])
else:
    # ok let's assume gpt-2 encodings by default
    print("No meta.pkl found, assuming GPT-2 encodings...")
    enc = tiktoken.get_encoding("gpt2")

    def encode(s):
        """Encode `s` with the GPT-2 BPE, allowing the <|endoftext|> special token."""
        return enc.encode(s, allowed_special={"<|endoftext|>"})

    def decode(l):
        """Decode a list of GPT-2 BPE token ids back to text."""
        return enc.decode(l)


# The chatbot code

Enter 'exit' to leave the chat.

Note that the chatbot has no memory of the ongoing conversation: each reply depends only on the latest prompt (a goldfish chatbot, so to speak).

In [None]:
def _clean_reply(text):
    """Turn the raw model continuation into a short, readable reply.

    Steps, in order:
    1. If the text contains a '.', keep only complete sentences. The fragment
       after the last '.' is dropped because the max_new_tokens budget often
       cuts it mid-sentence. Empty and repeated sentences are removed (first
       occurrence wins), and the rest are re-joined as "A. B.".
    2. If a ':' remains, drop everything up to and including the FIRST ':'.
       The model sometimes writes a "Speaker: ..." label; later colons, as in
       "10:30", are kept.
    3. Strip surrounding whitespace.

    Returns the cleaned reply, which may be ''.
    """
    # todo: add a repetition_penalty option in the generate method?
    if '.' in text:
        sentences = [s.strip() for s in text.split('.')[:-1]]
        # dict.fromkeys removes duplicates while keeping first-seen order;
        # empty entries (from '...' or a leading '.') are discarded.
        unique = [s for s in dict.fromkeys(sentences) if s]
        text = '. '.join(unique) + '.' if unique else ''
    if ':' in text:
        text = text.split(':', 1)[1]
    return text.strip()


with torch.no_grad():
    with ctx:
        while True:
            try:
                prompt = input('Human > ')
            except EOFError:
                # stdin was closed: end the chat instead of crashing
                break
            if prompt.strip() == "exit":
                break
            if not prompt.strip():
                # NOTE(review): an empty prompt encodes to zero tokens; generate()
                # presumably needs at least one token to condition on, so skip it.
                continue
            prompt_ids = encode(prompt)
            x = torch.tensor(prompt_ids, dtype=torch.long, device=device)[None, ...]
            y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
            # y starts with the prompt tokens, so decode only what follows them.
            # Splitting the decoded text on the prompt would raise on an empty
            # separator, and would truncate the reply if the model repeated the prompt.
            output = decode(y[0, len(prompt_ids):].tolist())
            print("Chatbot >", _clean_reply(output))
  