In [46]:
from pathlib import Path
import os
import sys
import gc
import re
import shutil
import json
import math
import jinja2
from collections import defaultdict
import numpy as np
import pandas as pd
import bitsandbytes
import accelerate
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from tqdm import tqdm
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple, NamedTuple, Callable, Iterable, Set, Optional, Any
import scml
print(f"accelerate={accelerate.__version__}, bitsandbytes={bitsandbytes.__version__}")

accelerate=0.30.1, bitsandbytes=0.43.1


In [23]:
model_dir = Path("huggingface/mistralai/Mixtral-8x7B-Instruct-v0.1")
model_max_length = 4096
seed = 422
environment = jinja2.Environment()

In [3]:
tim = scml.Timer()
tim.start()
percentiles=[.01, .05, .1, .2, .3, .4, .5, .6, .7, .8, .9, .95, .99]
os.environ["TOKENIZERS_PARALLELISM"] = "false"
pd.set_option("max_info_columns", 9999)
pd.set_option("display.max_columns", 9999)
pd.set_option("display.max_rows", 9999)
pd.set_option('max_colwidth', 9999)
tqdm.pandas()
scml.seed_everything(seed)

In [4]:
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
    for i in range(torch.cuda.device_count()):
        print(f"device={i}, {torch.cuda.get_device_name(i)}")
        print('Mem Allocated:', round(torch.cuda.memory_allocated(i)/1024**3,1), 'GB')
        print('Mem Cached:   ', round(torch.cuda.memory_reserved(i)/1024**3,1), 'GB')
else:
    print("cpu")

device=0, NVIDIA GeForce RTX 4070 Ti SUPER
Mem Allocated: 0.0 GB
Mem Cached:    0.0 GB
device=1, NVIDIA GeForce RTX 4070 Ti SUPER
Mem Allocated: 0.0 GB
Mem Cached:    0.0 GB


In [9]:
df = pd.read_csv("input/persuade20/persuade_2.0_human_scores_demo_id_github.csv", low_memory=False)
df["full_text_len"] = df["full_text"].str.len()
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 25996 entries, 0 to 25995
Data columns (total 15 columns):
 #   Column                      Non-Null Count  Dtype  
---  ------                      --------------  -----  
 0   essay_id_comp               25996 non-null  object 
 1   full_text                   25996 non-null  object 
 2   holistic_essay_score        25996 non-null  int64  
 3   word_count                  25996 non-null  int64  
 4   prompt_name                 25996 non-null  object 
 5   task                        25996 non-null  object 
 6   assignment                  25996 non-null  object 
 7   source_text                 12875 non-null  object 
 8   gender                      25996 non-null  object 
 9   grade_level                 24828 non-null  float64
 10  ell_status                  24787 non-null  object 
 11  race_ethnicity              25996 non-null  object 
 12  economically_disadvantaged  20759 non-null  object 
 13  student_disability_status   208

In [33]:
title_to_prompt = {}
title_to_exemplars = {}
for t in df.itertuples():
    prompt_name = str(getattr(t, "prompt_name"))
    assignment = str(getattr(t, "assignment"))
    title_to_prompt[prompt_name] = assignment
for title in title_to_prompt.keys():
    mask = (df["prompt_name"]==title) & (df["holistic_essay_score"]==6)
    tmp = df[mask].sort_values(["full_text_len"], ascending=[True])
    title_to_exemplars[title] = tmp["full_text"].tolist()
print(json.dumps(title_to_prompt, indent=2))

{
  "Phones and driving": "Today the majority of humans own and operate cell phones on a daily basis. In essay form, explain if drivers should or should not be able to use cell phones in any capacity while operating a vehicle.",
  "Car-free cities": "Write an explanatory essay to inform fellow citizens about the advantages of limiting car usage. Your essay must be based on ideas and information that can be found in the passage set. Manage your time carefully so that you can read the passages; plan your response; write your response; and revise and edit your response. Be sure to use evidence from multiple sources; and avoid overly relying on one source. Your response should be in the form of a multiparagraph essay. Write your essay in the space provided.",
  "Summer projects": "Some schools require students to complete summer projects to assure they continue learning during their break. Should these summer projects be teacher-designed or student-designed? Take a position on this questio

In [5]:
tokenizer = AutoTokenizer.from_pretrained(model_dir, model_max_length=model_max_length)
tokenizer.pad_token = tokenizer.eos_token
print(f"{repr(tokenizer)}\nmodel_input_names={tokenizer.model_input_names}")
print(f"pad_token_id={tokenizer.pad_token_id}")

LlamaTokenizerFast(name_or_path='huggingface/mistralai/Mixtral-8x7B-Instruct-v0.1', vocab_size=32000, model_max_length=4096, is_fast=True, padding_side='left', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '</s>'}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
	0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
model_input_names=['input_ids', 'attention_mask']
pad_token_id=2


In [6]:
%%time
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
    str(model_dir),
    device_map = "auto",
    quantization_config=quantization_config,
)

Loading checkpoint shards:   0%|          | 0/19 [00:00<?, ?it/s]

CPU times: user 4min 4s, sys: 3min 33s, total: 7min 38s
Wall time: 14min 12s


In [7]:
print(model)

MixtralForCausalLM(
  (model): MixtralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MixtralDecoderLayer(
        (self_attn): MixtralSdpaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MixtralRotaryEmbedding()
        )
        (block_sparse_moe): MixtralSparseMoeBlock(
          (gate): Linear4bit(in_features=4096, out_features=8, bias=False)
          (experts): ModuleList(
            (0-7): 8 x MixtralBlockSparseTop2MLP(
              (w1): Linear4bit(in_features=4096, out_features=14336, bias=False)
              (w2): Linear4bit(in_features=14336, out_features=4096, bias=False)
              (w3): Linear4bit(in_features=4096, out_

In [40]:
def generate(prompt: str) -> str:
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    outputs = model.generate(
        **inputs, 
        max_new_tokens=128,
        do_sample=True,
        temperature=1.0,
        top_p=0.95,
        top_k=100,
        repetition_penalty=1.1,
        pad_token_id=tokenizer.eos_token_id,
    )
    generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    generated_text = generated_text.split("[/INST]")[1]
    generated_text = generated_text.strip()
    return generated_text

In [44]:
template = environment.from_string(
"""[INST] You are a teacher setting an essay question for students in Grade 6-12.
Given the student's essay, output the essay question between opening tag [QUESTION] and closing tag [/QUESTION].
Output a short title of the essay question between opening tag [TITLE] and closing tag [/TITLE].
[ESSAY]{{ essay1 }}[/ESSAY]
[QUESTION]{{ question1 }}[/QUESTION]
[TITLE]{{ title1 }}[/TITLE]
[ESSAY]{{ essay2 }}[/ESSAY]
[QUESTION]{{ question2 }}[/QUESTION]
[TITLE]{{ title2 }}[/TITLE]
[ESSAY]{{ essay0 }}[/ESSAY]
[QUESTION][/QUESTION]
[TITLE][/TITLE][/INST]"""
)

In [45]:
%%time
prompt = template.render(
    essay0=title_to_exemplars["Cell phones at school"][0],
    essay1=title_to_exemplars["Exploring Venus"][0],
    question1=title_to_prompt["Exploring Venus"],
    title1="Exploring Venus",
    essay2=title_to_exemplars["Does the electoral college work?"][0],
    question2=title_to_prompt["Does the electoral college work?"],
    title2="Does the electoral college work?",
)
output = generate(prompt)
question = ""
title = ""
print(f"{prompt}\n\n=====  END OF PROMPT  =====\n\n{output}")

[INST] You are a teacher setting an essay question for students in Grade 6-12.
Given the student's essay, output the essay question between opening tag [QUESTION] and closing tag [/QUESTION].
Output a short title of the essay question between opening tag [TITLE] and closing tag [/TITLE].
[ESSAY]The author supports his arguement quite effectively, though not without any flaws. The author effectively lists off the benefits of exploring Venus, and proper proposals for doing so, however the complexity and unrealistic assumptions of the solutions as well as the negative descripton of Venus take away from its point rather than benefiting it.

The effective conveying of benefits helps to support its idea of risking danger for the reward of landing. The author describes how Venus is relatable with earth which creates a connection between the reader and Venus, as well as creating a mood of hospitality and familiarity. It also states that because Venus is the "nearest option for a planetary visi

In [57]:
pattern = re.compile(r"\[QUESTION\](.+)\[/QUESTION\]", re.IGNORECASE)
for m in pattern.finditer(output):
    if m[0] == "":
        continue
    question = m[1]
    break
pattern = re.compile(r"\[TITLE\](.+)\[/TITLE\]", re.IGNORECASE)
for m in pattern.finditer(output):
    if m[0] == "":
        continue
    title = m[1]
    break
print(f"title={title}\nquestion={question}")

title=The Use of Phones in School
question=Should students be allowed to use their phones during school hours? In an argumentative essay, use the excerpts from student essays to support your stance. Consider counterarguments and provide evidence to refute them.


In [32]:
tim.stop()
print(f"Total time taken {str(tim.elapsed)}")

RuntimeError: Not started