In [2]:
import torch
torch.backends.cuda.matmul.allow_tf32 = True
import random
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
import time
import re
from tqdm import tqdm
from collections import Counter

In [15]:
from modeling_mistral import MistralForCausalLM

In [4]:
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--batch_idx", type=int, default=0)
parser.add_argument("--baseline", action="store_true")
parser.add_argument("--device_batch_size", type=int, default=8)
parser.add_argument("--max_idx", type=int, default=128)
parser.add_argument("--n_votes", type=int, default=8)
parser.add_argument("--temp", type=float, default=0.9)
parser.add_argument("--start_final_answer_idx", type=int, default=384)
parser.add_argument("--answer_length", type=int, default=12)
parser.add_argument("--root_prefix", type=str, default="YOUR_ROOT_HERE")
parser.add_argument("--checkpoint", type=str, default="ezelikman/quietstar-8-ahead")
parser.add_argument("--final_answer_text", type=str, default="\nTherefore, the answer (arabic numerals) is")
parser.add_argument("--zero_shot_cot_prompt", type=str, default="\nA: Let's think step by step.")
parser.add_argument("--n_ahead", type=int, default=8)
args = parser.parse_args()
args

Namespace(batch_idx=0, baseline=False, device_batch_size=8, max_idx=128, n_votes=8, temp=0.9, start_final_answer_idx=384, answer_length=12, root_prefix='YOUR_ROOT_HERE', checkpoint='ezelikman/quietstar-8-ahead', final_answer_text='/home/wassname/.local/share/jupyter/runtime/kernel-v2-3161588ZqAxHVwx8UO2.json', zero_shot_cot_prompt="\nA: Let's think step by step.", n_ahead=8)

In [16]:
params = None
if params is None:
    params = {}
else:
    params = params.params
n_ahead = params.get("n_ahead", args.n_ahead if not args.baseline else 1)
n_ahead_talk = 1
use_start_thought_token = params.get("use_start_thought_token", True)
use_end_thought_token = params.get("use_end_thought_token", True)
include_policy_loss = params.get("include_policy_loss", True)
gumbel_detach = params.get("gumbel_detach", True)
merged_talk_heads = params.get("merged_talk_heads", True)
residual_think_head = params.get("residual_think_head", False)
optimize_lm_head_only_at_start = params.get("optimize_lm_head_only_at_start", False)
print("Loading model")
model = MistralForCausalLM.from_pretrained(
    args.checkpoint,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    device_map='auto',
    cache_dir=args.root_prefix + "cache",
    max_thoughts=n_ahead + n_ahead_talk + 1,
    merged_talk_heads=merged_talk_heads,
    merged_lm_and_talk_heads=False,
    merged_lm_and_think_heads=True,
    use_concat_talk_head=True,
    use_shallow_think=True,
    use_shallow_talk=False,
    use_complex_think_head=False,
    use_complex_talk_head=True,
    use_weighted_talk_head=True,
)
print("Loaded model")
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
tokenizer.padding_side = "right"
tokenizer.pad_token_id = tokenizer.eos_token_id
special_tokens_to_add = []
if model.use_start_thought_token:
    special_tokens_to_add.append("<|startthought|>")
if model.use_end_thought_token:
    special_tokens_to_add.append("<|endthought|>")
if special_tokens_to_add:
    tokenizer.add_special_tokens({"additional_special_tokens": special_tokens_to_add})
    model.resize_token_embeddings(len(tokenizer))
model.tokenizer = tokenizer
model.gumbel_detach = gumbel_detach
model.include_policy_loss = include_policy_loss
model.use_end_thought_token = use_end_thought_token
model.use_start_thought_token = use_start_thought_token
model.n_ahead = n_ahead
model.n_ahead_talk = n_ahead_talk
model.n_passes = 1
model.residual_think_head = residual_think_head
if args.baseline:
    model.skip_residual = True
    model.cumulative_residual = False
    model.clever_residual = False
    model.base_residual = False
model.optimize_lm_head_only_at_start = optimize_lm_head_only_at_start
model.use_policy_loss = False
model.rm_initialized = True
model.first_run = False
model.wandb_enabled = False
model.config_params = params
model.run_start = int(time.time())
model.eval_mode = True
model.eval()


Loading model




Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Loaded model


MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32002, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    )
    (norm): MistralRMSNorm(

Loaded model


tokenizer_config.json:   0%|          | 0.00/1.38k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/59.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/772 [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


AttributeError: 'MistralForCausalLM' object has no attribute 'use_start_thought_token'