<a href="https://colab.research.google.com/github/zelladoor/Machine-Learning-Projects/blob/master/gpt-neo_dungeon_untuned.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Instructions

Go through each cell in this notebook one by one, look at its options and descriptions, and then press the play button to its left. You can skip the optional cells (Basic sampling, History and Attention display), but don't skip any of the others. After you run the "Play" cell, a small form appears underneath it, which you use to actually play.

To reset the state of your game, run the "Setup" cell again. Closing the notebook loses your progress, so if you want to keep your story, use the "history" action and copy the story into a text editor. You can also copy your author's note and memory from the output of the "info" action.


In [None]:
#@title Setup
#@markdown Run this for setting up dependencies or resetting actions
# Install the finetuneanon fork of transformers (branch gpt-neo-dungeon-localattention1).
# The generate() calls in this notebook pass repetition_penalty_range and
# repetition_penalty_slope, which presumably only this fork accepts.
!pip install git+https://github.com/finetuneanon/transformers@gpt-neo-dungeon-localattention1
# NOTE(review): megatools is downloaded and installed here but never invoked
# anywhere in this notebook; likely a leftover - confirm before removing.
!wget -c http://ftp.us.debian.org/debian/pool/main/m/megatools/megatools_1.11.0~git20200404-1_amd64.deb -O megatools.deb
!dpkg -i megatools.deb
# Print the status of the GPU assigned to this runtime.
!nvidia-smi

import os

from transformers import GPTNeoForCausalLM, AutoTokenizer
import tarfile
import codecs
import torch
import subprocess

from IPython.display import HTML, display
import ipywidgets as widgets

def set_css(info=None):
  """Inject a <style> block that makes <pre> output wrap long lines.

  Registered below as an IPython ``pre_run_cell`` callback, so it runs before
  every cell. ``info`` receives the execution-info object IPython passes to
  ``pre_run_cell`` callbacks; it is unused, and optional so the function can
  still be called with no arguments.
  """
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))

# Register the CSS hook only once per kernel: on the first run `initialized`
# does not exist yet, so the NameError branch registers set_css; re-running
# this cell only bumps the counter instead of adding a duplicate hook.
try:
  initialized += 1
except NameError:
  get_ipython().events.register('pre_run_cell', set_css)
  initialized = 0

# Game state - re-running this cell resets the story.
# Story so far: a list of (text, token-id tensor of shape (1, n)) pairs.
actions = []
# Memory (text, tokens): placed at the very start of every model context.
memory = ("", torch.zeros((1, 0)).long())
# Last assembled model input [text, tokens]; a list so assemble() can update it in place.
lmi = ["", torch.zeros((1, 0)).long()]
# Author's note (text, tokens), inserted in front of the action that has
# `an_depth` actions after it.
an = ("", torch.zeros((1, 0)).long())
an_depth = 3
# None until the History cell is run; play() then flips it on every call.
history = None

In [None]:
#@title Model setup

model_name = "EleutherAI/gpt-neo-2.7B" #@param ["EleutherAI/gpt-neo-2.7B", "EleutherAI/gpt-neo-1.3B"]

# Drop references left by a previous run of this cell (presumably so an old
# model can be freed before a new one is loaded).
model = None
tokenizer = None
# NOTE(review): pipeline is not used anywhere else in this notebook.
pipeline = None
checkpoint = None

# NOTE(review): `if True:` is always taken; presumably a leftover toggle.
if True:
  # Resolve the weights file (WEIGHTS_NAME) of model_name and get a local copy.
  from transformers.file_utils import cached_path, WEIGHTS_NAME, hf_bucket_url
  archive_file = hf_bucket_url(model_name, filename=WEIGHTS_NAME)
  resolved_archive_file = cached_path(archive_file)
  # Load the raw state dict straight onto the GPU and cast every entry to fp16
  # before the model is built from it.
  # torch.load unpickles the file, so model_name should only name trusted repositories.
  # NOTE(review): .half() is applied to every entry, including any
  # non-floating-point buffers; confirm this is harmless for GPT-Neo.
  checkpoint = torch.load(resolved_archive_file, map_location="cuda:0")
  for k in checkpoint.keys():
    checkpoint[k] = checkpoint[k].half()
  model = GPTNeoForCausalLM.from_pretrained(model_name, state_dict=checkpoint).half().to("cuda").eval()
  # Drop every reference to the loaded state dict so its memory can be released.
  for k in list(checkpoint.keys()):
    del checkpoint[k]
  del checkpoint
# Both model sizes use the stock "gpt2" tokenizer (presumably because GPT-Neo
# shares GPT-2's vocabulary).
tokenizer = AutoTokenizer.from_pretrained("gpt2")

In [None]:
#@title Sampling settings
#@markdown You can modify sampling settings here. Don't forget to run the cell again after changing them. The number of generated tokens is subtracted from the context window size, so don't set it too high.
top_k = 60 #@param {type:"number"}
top_p = 0.9 #@param {type:"number"}
temperature =  0.6#@param {type:"number"}
number_generated_tokens =  40#@param {type:"integer"}
repetition_penalty = 1.25 #@param {type:"number"}
repetition_penalty_range = 300 #@param {type:"number"}
repetition_penalty_slope = 3.33 #@param {type:"number"}
number_show_last_actions = 15 #@param {type:"integer"}
# The first seven values are forwarded to every model.generate(...) call in this
# notebook (repetition_penalty_range/_slope are extra arguments of the forked
# transformers installed in Setup, presumably). number_generated_tokens is also
# reserved out of the 2048-token window when play() assembles its context.
# number_show_last_actions only sets how many recent actions play() reprints.

#@markdown Temperatures seem to give results different from those in AID, so play around with it. Even 0.5 can give good results.

In [None]:
#@title Basic sampling

#@markdown Use this cell if you just want to sample from the model in a free form way.

basic_prompt = "The rays of the evening sun falling in through the window bathed the room in a soft, warm light" #@param {type:"string"}

# Tokenize the prompt. An empty prompt yields zero tokens, so fall back to a
# single EOS token as the starting context.
ids = tokenizer(basic_prompt, return_tensors="pt").input_ids.to("cpu")
n_ids = ids.shape[1]
if n_ids < 1:
  n_ids = 1
  ids = torch.tensor([[tokenizer.eos_token_id]])
# min_length == max_length pins the output to exactly number_generated_tokens
# new tokens after the prompt.
max_length = n_ids + number_generated_tokens
# Release cached GPU memory blocks before and after sampling.
torch.cuda.empty_cache()
basic_output = model.generate(
    ids.long().cuda(),
    do_sample=True,
    min_length=max_length,
    max_length=max_length,
    temperature=temperature,
    top_k = top_k,
    top_p = top_p,
    repetition_penalty = repetition_penalty,
    repetition_penalty_range = repetition_penalty_range,
    repetition_penalty_slope = repetition_penalty_slope,
    use_cache=True,
    pad_token_id=tokenizer.eos_token_id
).long().to("cpu")
torch.cuda.empty_cache()

# Prompt and continuation together (an empty prompt shows up as the EOS token).
print(tokenizer.decode(basic_output[0]))

# Using gpt-neo dungeon's play function

If your prompt starts with a letter, try putting a space or newline in front.

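* **generate** adds your prompt as an action and generates more output
* **continue** generates more output
* **replace** replaces the last action (normally the last output) with your prompt and generates more; use this to edit
* **undo** removes the last action (your input or the model's output) and shows the story again without generating
* **retry** removes the last action and generates a new one in its place
* **info** outputs the LMI (the context most recently assembled for the model) and its token count, the memory, the author's note and its depth
* **history** outputs all actions so far
* **memory** sets memory to the text in the prompt field
* **authorsnote** sets author's note to the text in the prompt field
* **andepth** sets the depth of the author's note to the number in the prompt field
* **tokenize** tokenizes the text in the prompt field and outputs the number of tokens and the token ids

For **memory**, **authorsnote**, **andepth** and **tokenize**, escape sequences such as `\n` in the prompt field are interpreted.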

In [None]:
#@title Play

# Action play() runs when called without an explicit one; set_action keeps it
# in sync with the dropdown.
action_type = "generate"
# Current text of the input box; set_prompt keeps it in sync.
prompt = ""
# NOTE(review): assigned by play() but never read anywhere in this notebook.
need_refresh = True

# Actions offered by the dropdown. play() also skips the input box as story
# text when it contains exactly one of these names.
action_types = ["generate", "continue", "replace", "undo", "retry", "memory", "authorsnote", "andepth", "info", "history", "tokenize"]

def assemble():
  """Rebuild the model input (LMI) from memory, author's note and recent actions.

  Reads the globals ``memory``, ``an``, ``an_depth``, ``actions`` and
  ``number_generated_tokens`` and stores the result in the global ``lmi``
  list in place: ``lmi[0]`` is the context text and ``lmi[1]`` the matching
  (1, n) tensor of token ids.

  Layout: memory first, then as many of the most recent actions as fit in a
  2048-token window after reserving room for memory, the author's note and
  the tokens still to be generated. The author's note goes in front of the
  action that has ``an_depth`` actions after it; if fewer actions than that
  fit, in front of the first visible action; and when no action fits at all
  (including an empty story), straight after memory. A negative ``an_depth``
  leaves the author's note out.
  """
  # The +1 pairs with the strict ">" below, so memory + AN + actions +
  # generated tokens may add up to exactly 2048.
  remaining = (2048 - number_generated_tokens + 1) - memory[1].shape[1] - an[1].shape[1]
  n_actions = len(actions)
  # Walk back from the newest action; back_i ends at the oldest action that
  # still fits (n_actions if none does).
  n_ctx = 0
  back_i = n_actions
  for i_action in reversed(range(n_actions)):
    n_tok = actions[i_action][1].shape[1]
    if remaining > n_ctx + n_tok:
      n_ctx += n_tok
      back_i = i_action
    else:
      break
  lmi[0], lmi[1] = memory[0], memory[1]

  def add_an():
    # Append the author's note text and tokens to the context being built.
    lmi[0] += an[0]
    lmi[1] = torch.cat([lmi[1].cpu(), an[1].cpu()], 1).long()

  # Fewer visible actions than the requested depth: AN goes before the first one.
  start = n_actions - back_i - 1 < an_depth
  while back_i < n_actions:
    if start or n_actions - back_i - 1 == an_depth:
      add_an()
      start = False
    lmi[0] += actions[back_i][0]
    lmi[1] = torch.cat([lmi[1].cpu(), actions[back_i][1].cpu()], 1).long()
    back_i += 1
  # No action is in the window: keep the author's note right after memory
  # instead of silently dropping it.
  if start:
    add_an()

def clear_output():
  """Erase everything currently displayed in the main output widget ``out``."""
  with out:
    erase = IPython.display.clear_output
    erase()

def set_action(change):
  """Dropdown observer: remember the newly selected action.

  Only the ``new`` attribute of the change notification passed by
  ``observe`` is used; it is stored in the global ``action_type``.
  """
  global action_type
  selected = change.new
  action_type = selected

def set_prompt(change):
  """Input-box observer: keep the global ``prompt`` equal to the box's text.

  Only the ``new`` attribute of the change notification passed by
  ``observe`` is used.
  """
  global prompt
  text = change.new
  prompt = text

@torch.no_grad()
def play(do_action=None):
  global memory, need_refresh, an, an_depth, action_type, history
  an_updated = False
  memory_updated = False
  if do_action is not None:
    action = do_action
    action_type = do_action
  else:
    action = action_type
  with out:
    if prompt in action_types:
      action == prompt
    else:
      if action == "replace":
        if len(actions) > 0:
          actions.pop()
        need_refresh = True
        action = "generate"
      if action == "generate":
        text = prompt
        if len(text) > 0:
          for line in text.splitlines(True):
            tokens = tokenizer(line, return_tensors="pt").input_ids.to("cpu")
            actions.append((line, tokens))
        action = "continue"
      if action == "info":
        clear_output()
        print("LMI: " + lmi[0])
        print("LMI tokens: " + str(lmi[1].shape[1]))
        print("Memory: " + memory[0])
        print("Author's note: " + an[0])
        print("Author's note depth: " + str(an_depth))
        need_refresh = True
      if action == "history":
        clear_output()
        print("".join([action[0] for action in actions]), end="")
        need_refresh = False
      if action == "retry":
        if len(actions) > 0:
          actions.pop()
        need_refresh = True
        action = "continue"
      if action == "undo":
        if len(actions) > 0:
          actions.pop()
        assemble()
        clear_output()
        print("".join([action[0] for action in actions[-number_show_last_actions:]]), end="")
        need_refresh = False
      if action == "memory":
        if prompt == "":
          memory = ("", torch.zeros((1, 0)).long())
          text = ""
        else:
          text = codecs.decode(prompt + "\n", "unicode-escape")
          tokens = tokenizer(text, return_tensors="pt").input_ids.to("cpu")
          memory = (text, tokens)
        clear_output()
        print("Memory: " + text)
        memory_updated = True
      if action == "authorsnote":
        if prompt == "":
          an = ("", torch.zeros((1, 0)).long())
          text = ""
        else:
          text = "\n[Author's note: " + codecs.decode(prompt, "unicode-escape") + "]\n"
          tokens = tokenizer(text, return_tensors="pt").input_ids.to("cpu")
          an = (text, tokens)
        clear_output()
        print("Author's note: " + text)
        an_updated = True
      if action == "andepth":
        clear_output()
        try:
          an_depth = int(codecs.decode(prompt + "\n", "unicode-escape"))
        except:
          pass
        print("Author's note depth: " + str(an_depth))
        an_updated = True
      if action == "tokenize":
        text = codecs.decode(prompt, "unicode-escape")
        tokens = tokenizer(text, return_tensors="pt").input_ids.to("cpu")
        clear_output()
        print("Tokens: " + str(tokens.shape[1]))
        print(tokens[0])
        need_refresh = True
      if action == "continue":
        assemble()
        ids = lmi[1].cuda()
        n_ids = ids.shape[1]
        if n_ids < 1:
          n_ids = 1
          ids = torch.tensor([[tokenizer.eos_token_id]])
        max_length = number_generated_tokens + n_ids
        #ids[:, :] = 13
        torch.cuda.empty_cache()
        clear_output()
        gen_tokens = model.generate(
            ids.long().cuda(),
            do_sample=True,
            min_length=max_length,
            max_length=max_length,
            temperature=temperature,
            top_k = top_k,
            top_p = top_p,
            repetition_penalty = repetition_penalty,
            repetition_penalty_range = repetition_penalty_range,
            repetition_penalty_slope = repetition_penalty_slope,
            use_cache=True,
            pad_token_id=tokenizer.eos_token_id
        ).long()
        stop_tokens = [0, 13, 30, 526, 764, 1701, 2474, 5145, 5633]
        for i in reversed(range(len(gen_tokens[0]))):
          if i < n_ids:
            gen_tokens = gen_tokens[0]
            break
          if gen_tokens[0][i] in stop_tokens:
            gen_tokens = gen_tokens[0][:i+1]
            break
        gen_text = tokenizer.decode(gen_tokens[n_ids:])
        if len(gen_text) > 0:
          actions.append((gen_text, gen_tokens[n_ids:].unsqueeze(0).cpu()))
        print("".join([action[0] for action in actions[-number_show_last_actions:]]), end="")
        torch.cuda.empty_cache()
        need_refresh = False
    if history is not None:
      if history:
        with out_history:
          IPython.display.clear_output()
          print("".join([action[0] for action in actions]), end="")
        with out_history2:
          IPython.display.clear_output()
      else:
        with out_history2:
          IPython.display.clear_output()
          print("".join([action[0] for action in actions]), end="")
        with out_history:
          IPython.display.clear_output()
      if an_updated:
        with out_an:
          IPython.display.clear_output()
          if len(an[0]) > 0:
            print("AN depth: " + str(an_depth) + "\n" + an[0], end="")
      if memory_updated:
        with out_memory:
          IPython.display.clear_output()
          print(memory[0], end="")
      history = not history

import ipywidgets as widgets
import IPython.display

# Main story pane: play() prints into it and clear_output() wipes it.
out = widgets.Output(layout={'border': '1px solid black', "width": "1280px"})
# Action picker; set_action mirrors its value into the global action_type.
dropdown = widgets.Dropdown(options=action_types, value=action_type, description='Action:', disabled=False)
dropdown.observe(set_action, 'value')
# Runs whichever action is currently selected in the dropdown.
button = widgets.Button(description='[selected action]', disabled=False)
button.on_click(lambda _: play(dropdown.value))
# One-click shortcuts for the most common actions.
retry_button = widgets.Button(description='Retry', disabled=False)
retry_button.on_click(lambda _: play("retry"))
continue_button = widgets.Button(description='Continue', disabled=False)
continue_button.on_click(lambda _: play("continue"))
undo_button = widgets.Button(description='Undo', disabled=False)
undo_button.on_click(lambda _: play("undo"))
hbox = widgets.HBox([button, retry_button, continue_button, undo_button])
# Text box for story input and for the memory/authorsnote/andepth/tokenize
# arguments; set_prompt mirrors its value into the global prompt. It is named
# input_box (not input) so the builtin input() stays usable in this kernel.
input_box = widgets.Textarea(value='', placeholder='', description='Input:', disabled=False, rows=4, layout={"width": "1280px"})
input_box.observe(set_prompt, 'value')

display(out, dropdown, hbox, input_box)

In [None]:
#@title History
#@markdown Run this cell to have an auto-updating full listing of the current story.

# Enables the history panes: play() refreshes them only when history is not
# None, and flips this flag on every call to alternate between the two panes.
history = True
# Two panes for the full story; play() prints into one and clears the other.
out_history = widgets.Output(layout={'border': '1px solid black', "width": "1280px"})
out_history2 = widgets.Output(layout={'border': '1px solid black', "width": "1280px"})
# Refreshed with the memory text after a "memory" action.
out_memory = widgets.Output(layout={'border': '1px solid black', "width": "1280px"})
# Refreshed with the author's note and its depth after "authorsnote"/"andepth".
out_an = widgets.Output(layout={'border': '1px solid black', "width": "1280px"})
display(out_history, out_history2, out_memory, out_an)

In [None]:
#@title Attention display
#@markdown If you don't know what this is, just ignore it. attn_head_combination selects the operation used to combine layer and attention head results. Only global attention is included.

visualizer_prompt = "Before his eyes was an orange cat with stripes. Running his fingers through its soft fur, he admired" #@param {type:"string"}
max_attentions =  8#@param {type:"integer"}
attn_head_combination = "mean" #@param ["mean", "max"]
# max_attentions: how many most-attended positions are listed per generated token.

# Same sampling as the Basic sampling cell, but generate() also returns the
# attention weights of every step (return_dict_in_generate + output_attentions).
ids = tokenizer(visualizer_prompt, return_tensors="pt").input_ids.to("cpu")
n_ids = ids.shape[1]
if n_ids < 1:
  n_ids = 1
  ids = torch.tensor([[tokenizer.eos_token_id]])
max_length = n_ids + number_generated_tokens
torch.cuda.empty_cache()
basic_output = model.generate(
    ids.long().cuda(),
    do_sample=True,
    min_length=max_length,
    max_length=max_length,
    temperature=temperature,
    top_k = top_k,
    top_p = top_p,
    repetition_penalty = repetition_penalty,
    repetition_penalty_range = repetition_penalty_range,
    repetition_penalty_slope = repetition_penalty_slope,
    use_cache=True,
    pad_token_id=tokenizer.eos_token_id,
    return_dict_in_generate=True,
    output_attentions=True
)
torch.cuda.empty_cache()
# Split the result: per-step attentions, and the token ids (prompt followed by
# continuation) on the CPU. The name basic_output is reused for the ids.
attentions = basic_output["attentions"]
basic_output = basic_output["sequences"].cpu()

print("Prompt: " + visualizer_prompt)
print()

def combine(attentions):
  """Reduce attention weights over their first dimension.

  Takes the mean when the global ``attn_head_combination`` is "mean" and the
  element-wise maximum for any other value. The loop below uses it to merge
  heads, and then layers.
  """
  if attn_head_combination != "mean":
    return attentions.max(0).values
  return attentions.mean(0)

import torch.nn.functional as F
import numpy as np
#gen_tokens x layers x tensor (1 x heads x n_ids square)
torch.set_printoptions(sci_mode=False)
for i in range(number_generated_tokens):
  # One row per even-numbered layer (presumably the global-attention layers
  # the cell description refers to): batch 0, every head, last query row,
  # combined over heads.
  layer_attn = []
  for j in range(0,len(attentions[i]),2):
    layer_attn.append(combine(attentions[i][j][0,:,-1].float().cpu()))
  layer_attn = torch.stack(layer_attn)
  # Combine over layers -> one weight per attended position.
  token_attn = combine(layer_attn)
  # topk raises when k exceeds the number of attended positions, which
  # happens for prompts shorter than max_attentions tokens; clamp k.
  prob, topk = token_attn.topk(min(max_attentions, token_attn.shape[0]))
  top_tokens = []
  for top in topk:
    decoded = tokenizer.decode(torch.tensor([basic_output[0][top]]))
    top_tokens.append(f"{decoded!r}@{top} ({token_attn[top]:.4f})")
  print("Token: " + repr(tokenizer.decode(torch.tensor([basic_output[0][n_ids + i]]))))
  print(", ".join(top_tokens))
  print()

print("Output: " + tokenizer.decode(basic_output[0]))