In [1]:
! pip install transformers accelerate bitsandbytes optimum


Collecting accelerate
  Downloading accelerate-0.29.3-py3-none-any.whl (297 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m297.6/297.6 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting bitsandbytes
  Downloading bitsandbytes-0.43.1-py3-none-manylinux_2_24_x86_64.whl (119.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m119.8/119.8 MB[0m [31m13.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting optimum
  Downloading optimum-1.19.1-py3-none-any.whl (417 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m417.0/417.0 kB[0m [31m44.5 MB/s[0m eta [36m0:00:00[0m
Collecting coloredlogs (from optimum)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.0/46.0 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
Collecting datasets (from optimum)
  Downloading datasets-2.19.0-py3-none-any.whl (542 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
import gc

import pandas as pd
import torch
import os
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np


from transformers.pipelines import pipeline
from tqdm.notebook import tqdm
from tqdm.contrib import tenumerate
from transformers import AutoTokenizer, BitsAndBytesConfig


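Helper functions: `flush` runs garbage collection and releases cached GPU memory; `find_xth_index` maps a character count to the index of the token that reaches it.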

In [3]:
def flush():
    """Free unreferenced Python objects and, when a GPU is present, cached CUDA memory.

    Runs the garbage collector first so that tensors which are no longer
    referenced can be released, then empties the CUDA caching allocator and
    resets the peak-memory counters. Returns None.
    """
    gc.collect()
    # The CUDA bookkeeping calls need a usable device; on a CPU-only runtime
    # resetting the peak-memory stats can raise, so skip the GPU part there.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

flush()


def find_xth_index(tokens, x):
  """Return the index of the first token at which the running length reaches `x`.

  Token lengths are accumulated in order; the first position whose cumulative
  length is greater than or equal to `int(x)` is returned. When the total
  length never reaches `x`, the index of the last token is returned
  (which is -1 for an empty `tokens`).
  """
  threshold = int(x)
  running_total = 0
  for position, piece in enumerate(tokens):
    running_total += len(piece)
    if running_total >= threshold:
      return position
  # Ran out of tokens before reaching the threshold: fall back to the last one.
  return len(tokens) - 1


Load the novels from the Project Dialogism Novel Corpus.
To keep the data available across notebook sessions, mount Google Drive and upload the corpus files there. It's also more convenient.

In [4]:
# Novels covered by this notebook and the folder holding their CSV files.
novel_names = ["PrideAndPrejudice", "Emma"]
data_folder = "."

# One context DataFrame per entry of `context_lengths`, stored in the same order.
p_contexts, e_contexts = [], []
context_lengths = [1, 2, 4, 8, 16]

# Quote tables, one per novel.
p_quotes = pd.read_csv(f"{data_folder}/PrideAndPrejudice.csv")
e_quotes = pd.read_csv(f"{data_folder}/Emma.csv")

# Context tables: for each length read the file for both novels.
for length in context_lengths:
  p_contexts += [pd.read_csv(f"{data_folder}/PrideAndPrejudice_context{length}.csv")]
  e_contexts += [pd.read_csv(f"{data_folder}/Emma_context{length}.csv")]



Define models to use from Hugging Face

In [5]:
# Hugging Face repository ids of the checkpoints under study.
models = [
    "meta-llama/Llama-2-7b-chat-hf",
    "meta-llama/Llama-2-13b-chat-hf",
    "mistralai/Mistral-7B-Instruct-v0.2",
]

# Short display names: the repository id with the organisation prefix removed.
model_names = [repo_id.split("/")[-1] for repo_id in models]


Define prompts to use

In [6]:
def llama_prompt(quote):
  """Build the context-free prompt asking for the speaker of `quote`.

  The quote is embedded in single quotes and followed by an instruction to
  answer with the speaker's name only.
  """
  prompt = f"OUTPUT THE NAME OF THE CHARACTER WHO SAID:\n'{quote}'\n\nOnly give me the speaker’s name and nothing else. Please do NOT include the quote in the response."
  return prompt

def mistral_prompt(quote):
  """Build the context-free prompt for `quote`, wrapped in [INST] ... [/INST] tags."""
  instruction = llama_prompt(quote)
  return f"[INST]\n{instruction}\n[/INST]"

def llama_context_prompt(quote, left, right):
  """Build the prompt that shows `quote` between its `left` and `right` context.

  The passage (left, quote, right joined by single spaces) is presented first,
  followed by the same speaker question that `llama_prompt` produces.
  """
  passage = f"{left} {quote} {right}"
  question = llama_prompt(quote)
  return f"CONTEXT:\n'{passage}'\n\nGIVEN CONTEXT, {question}"

def mistral_context_prompt(quote, left, right):
  """Build the contextual prompt for `quote`, wrapped in [INST] ... [/INST] tags."""
  instruction = llama_context_prompt(quote, left, right)
  return f"[INST]\n{instruction}\n[/INST]"


Get the pretrained models from Hugging Face and run the predictions on each novel

In [7]:
def _generate_rows(pipe, tokenizer, prompts, desc):
  """Run `pipe` once per prompt and return `[prompt_text, generated_text]` rows."""
  rows = []
  for prompt_text in tqdm(prompts, total=len(prompts), desc=desc):
    sequences = pipe(
        prompt_text,
        return_full_text=False,
        max_new_tokens=100,
        pad_token_id=tokenizer.eos_token_id,
        )
    # Concatenate the text of every returned sequence into one answer string.
    speaker_gen = "".join(out["generated_text"] for out in sequences)
    rows.append([prompt_text, speaker_gen])
  return rows


def _save_rows(rows, save_dir, novel):
  """Write the rows to `<save_dir>/<novel>.csv`, creating the directory if needed."""
  results_df = pd.DataFrame(rows, columns=["prompt_text", "inferred_speaker"])
  os.makedirs(save_dir, exist_ok=True)
  results_df.to_csv(f"{save_dir}/{novel}.csv", index=False)


def infer(novel, model_name, model, model_prompt, model_context_prompt, quotes, contexts):
  """Ask one model for the speaker of every quote, without and with context.

  Results are written as CSV files with columns `prompt_text` and
  `inferred_speaker` to `results/context0/<model_name>/<novel>.csv` (no
  context) and `results/context<length>/<model_name>/<novel>.csv` for each
  entry of the module-level `context_lengths`.

  Args:
    novel: Name used for the output file.
    model_name: Label used for the output directory.
    model: Identifier passed to `AutoTokenizer.from_pretrained` and `pipeline`.
    model_prompt: Callable `(quote_text) -> prompt` for the no-context run.
    model_context_prompt: Callable `(quote_text, left, right) -> prompt`.
    quotes: DataFrame with a `quoteText` column.
    contexts: Sequence of DataFrames, indexed in parallel with
      `context_lengths`, each with `left_context` and `right_context`
      columns whose rows line up positionally with `quotes`.

  Returns:
    None. The pipeline is released and `flush()` is called before returning,
    including when generation fails part-way through.
  """
  flush()
  tokenizer = AutoTokenizer.from_pretrained(model)
  tokenizer.pad_token = tokenizer.eos_token
  pipe = pipeline(
      task = "text-generation",
      model = model,
      tokenizer = tokenizer,
      device_map="auto",
      torch_dtype=torch.bfloat16,
      # NOTE(review): this allows code shipped in the model repository to run;
      # confirm it is actually required for the checkpoints used here.
      trust_remote_code=True,
      model_kwargs = {
          "low_cpu_mem_usage": True,
          },
      )

  try:
    # Infer on no context
    prompts = [model_prompt(quote_text) for quote_text in quotes["quoteText"]]
    rows = _generate_rows(pipe, tokenizer, prompts, "No context quotes")
    _save_rows(rows, f"results/context0/{model_name}", novel)

    # Infer on contexts
    for c_idx, context in enumerate(tqdm(context_lengths, desc="Context lengths")):
      context_df = contexts[c_idx]
      prompts = []
      # Pair quote and context rows by position: `.iloc` is positional, so the
      # DataFrame index labels of `quotes` must not be used to address it.
      for position, quote_text in enumerate(quotes["quoteText"]):
        context_row = context_df.iloc[position]
        prompts.append(
            model_context_prompt(quote_text, context_row["left_context"], context_row["right_context"])
            )
      rows = _generate_rows(pipe, tokenizer, prompts, "Quotes")
      _save_rows(rows, f"results/context{context}/{model_name}", novel)
  finally:
    # Drop the pipeline even when a run fails so the next model can be loaded.
    del pipe
    flush()




In [10]:
# Every (output label, checkpoint, prompt builder, context prompt builder)
# combination to evaluate, listed in the order the runs are executed.
runs = [
    # Mistral with INST
    ("Mistral 7b INST", "mistralai/Mistral-7B-Instruct-v0.2", mistral_prompt, mistral_context_prompt),
    # Mistral without INST
    ("Mistral 7b NO INST", "mistralai/Mistral-7B-Instruct-v0.2", llama_prompt, llama_context_prompt),
    # Llama 13
    ("Llama 13b", "meta-llama/Llama-2-13b-chat-hf", llama_prompt, llama_context_prompt),
    # Llama 7
    ("Llama 7b", "meta-llama/Llama-2-7b-chat-hf", llama_prompt, llama_context_prompt),
]

# Each novel with its quote table and its list of context tables.
novels = [
    ("PrideAndPrejudice", p_quotes, p_contexts),
    ("Emma", e_quotes, e_contexts),
]

for label, checkpoint, prompt_fn, context_prompt_fn in runs:
  for novel, novel_quotes, novel_contexts in novels:
    infer(novel, label, checkpoint, prompt_fn, context_prompt_fn, novel_quotes, novel_contexts)



Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/9.95G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/9.90G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/6.18G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/188 [00:00<?, ?B/s]

No context quotes:   0%|          | 0/1270 [00:00<?, ?it/s]

Context lengths:   0%|          | 0/5 [00:00<?, ?it/s]

Quotes:   0%|          | 0/1270 [00:00<?, ?it/s]

Quotes:   0%|          | 0/1270 [00:00<?, ?it/s]

Quotes:   0%|          | 0/1270 [00:00<?, ?it/s]

Quotes:   0%|          | 0/1270 [00:00<?, ?it/s]

Quotes:   0%|          | 0/1270 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

No context quotes:   0%|          | 0/1593 [00:00<?, ?it/s]

Context lengths:   0%|          | 0/5 [00:00<?, ?it/s]

Quotes:   0%|          | 0/1593 [00:00<?, ?it/s]

Quotes:   0%|          | 0/1593 [00:00<?, ?it/s]

Quotes:   0%|          | 0/1593 [00:00<?, ?it/s]

Quotes:   0%|          | 0/1593 [00:00<?, ?it/s]

Quotes:   0%|          | 0/1593 [00:00<?, ?it/s]

Download files to computer

In [11]:
# Recursively archive /content/results/ into /content/results.zip.
!zip -r /content/results.zip /content/results/

# Optional (Google Colab): uncomment to download the archive through the browser.
# from google.colab import files
# files.download("/content/results.zip")

updating: content/results/ (stored 0%)
updating: content/results/context8/ (stored 0%)
updating: content/results/context8/Llama 7b/ (stored 0%)
updating: content/results/context8/Llama 7b/Emma.csv (deflated 90%)
updating: content/results/context8/Llama 7b/PrideAndPrejudice.csv (deflated 90%)
updating: content/results/context4/ (stored 0%)
updating: content/results/context4/Llama 7b/ (stored 0%)
updating: content/results/context4/Llama 7b/Emma.csv (deflated 87%)
updating: content/results/context4/Llama 7b/PrideAndPrejudice.csv (deflated 87%)
updating: content/results/context1/ (stored 0%)
updating: content/results/context1/Llama 7b/ (stored 0%)
updating: content/results/context1/Llama 7b/Emma.csv (deflated 84%)
updating: content/results/context1/Llama 7b/PrideAndPrejudice.csv (deflated 84%)
updating: content/results/context2/ (stored 0%)
updating: content/results/context2/Llama 7b/ (stored 0%)
updating: content/results/context2/Llama 7b/Emma.csv (deflated 85%)
updating: content/results/