In [3]:
import numpy as np
import torch
import time
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

# model_name = "deepseek-ai/deepseek-math-7b-instruct"
model_name = "mistralai/Mistral-7B-Instruct-v0.2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
# model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
model.generation_config = GenerationConfig.from_pretrained(model_name)
model.generation_config.pad_token_id = model.generation_config.eos_token_id

Loading checkpoint shards: 100%|██████████| 3/3 [00:24<00:00,  8.01s/it]


In [19]:
from util import get_prompt_message, extract_last_number

In [6]:
def generate_text_and_scores(model, tokenizer, question, num_fewshot):
    # messages = get_prompt_message(question, num_fewshot)

    messages = [{"role": "user", "content": question } ] # quick for testing

    input_tensor = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
    outputs = model.generate(input_tensor.to(model.device), max_new_tokens=1000, return_dict_in_generate=True, output_scores=True)
    transition_scores = model.compute_transition_scores(
        outputs.sequences, outputs.scores, normalize_logits=True
    )
    print(outputs.sequences)

    input_length = input_tensor.shape[1]
    generated_tokens = outputs.sequences[:, input_length:]

    print(generated_tokens)
    for tok, score in zip(generated_tokens[0], transition_scores[0]):
        # | token | token string | log probability | probability
        print(f"| {tok:5d} | {tokenizer.decode(tok):8s} | {score.cpu().numpy()} | {np.exp(score.cpu().numpy()):.2%}")

    return outputs.sequences, generated_tokens[0], transition_scores[0]

def to_tokens_and_logprobs(model, tokenizer, input_ids):
    outputs = model(input_ids)
    probs = torch.log_softmax(outputs.logits, dim=-1).detach()

    # Adjust indices to ignore the first token's log prob as it corresponds to the second token
    probs = probs[:, :-1, :]
    input_ids = input_ids[:, 1:]
    gen_probs = torch.gather(probs, 2, input_ids[:, :, None]).squeeze(-1)

    batch = []
    for input_sentence, input_probs in zip(input_ids, gen_probs):
        text_sequence = []
        for token, p in zip(input_sentence, input_probs):
            if token not in tokenizer.all_special_ids:
                text_sequence.append((tokenizer.decode(token), p.item()))
        batch.append(text_sequence)
    return batch

In [58]:
model.eval()

gpt35_df = pd.read_csv('../conditional/data/112_gsm8k_gpt35_cot_onesent_responses.csv')

total_start = time.time()
for idx, row in gpt35_df[:1].iterrows():
    start_time = time.time()
    question = row['Question']
    total_tokens, generated_tokens, generated_logprobs = generate_text_and_scores(model, tokenizer, question, num_fewshot=4)
    
    print("Generated Text Log Probabilities:")
    for i, (token, score) in enumerate(zip(generated_tokens, generated_logprobs)):
        print(f"Token: {token}, Generated Log Prob: {score.item()}")

    # last_segment = "".join([tokenizer.decode(tok) for tok in generated_tokens[-10]])
    last_segment = "".join([tokenizer.decode(tok.item()) for tok in generated_tokens[-10:]])
    print(last_segment)
    last_integer = extract_last_number(last_segment)

    ans_declaration_toks = torch.Tensor(tokenizer(" The answer is " + str(last_integer))['input_ids'][1:]) # not adding '.' token or [2] EOS token here
    ans_declaration_toks = ans_declaration_toks.to(torch.int32).to(model.device)
    total_tokens = torch.cat((total_tokens[:, :-1], ans_declaration_toks.unsqueeze(0)), dim=1)
    log_probs = to_tokens_and_logprobs(model, tokenizer, total_tokens) 

    print("\nCalculated Log Probabilities from Concatenated Text:")
    for sequence in log_probs:
        for token, log_prob in sequence:
            print(f"Token: {token}, Calculated Log Prob: {log_prob}") 
    
    print("iteration time (s): ", time.time() - start_time)

tensor([[    1,   733, 16289, 28793,  1186, 28747,  1387,   460, 28705, 28740,
         28782,  7099,   297,   272,  5977,   333, 28723,  8697,   333,  7433,
           622,  5100,  7099,   297,   272,  5977,   333,  3154, 28723,  2530,
           590,   460,  2203, 28725,   736,   622,   347, 28705, 28750, 28740,
          7099, 28723,  1602,  1287,  7099,   863,   272,  5977,   333,  7433,
          5100,  3154, 28804,    13, 28741, 28747,   816,  1149,   395, 28705,
         28740, 28782,  7099, 28723, 11680,   478,   506, 28705, 28750, 28740,
          7099, 28723,   415,  5133,  1580,   347,   272,  1474,   302,  7099,
           590, 24571, 28723,  1537, 28725,   590,  1580,   506, 24571, 28705,
         28750, 28740,   387, 28705, 28740, 28782,   327, 28705, 28784,  7099,
         28723,   415,  4372,   349, 28705, 28784,    13, 28824, 28747,  1047,
           736,   460, 28705, 28770,  8300,   297,   272, 12128,  2055,   304,
         28705, 28750,   680,  8300, 12688, 28725,  

Looping

Test Ground

In [20]:
question = row['Question']
total_tokens, generated_tokens, generated_logprobs = generate_text_and_scores(model, tokenizer, question, num_fewshot=4)

tensor([[    1,   733, 16289, 28793,  1186, 28747,  1387,   460, 28705, 28740,
         28782,  7099,   297,   272,  5977,   333, 28723,  8697,   333,  7433,
           622,  5100,  7099,   297,   272,  5977,   333,  3154, 28723,  2530,
           590,   460,  2203, 28725,   736,   622,   347, 28705, 28750, 28740,
          7099, 28723,  1602,  1287,  7099,   863,   272,  5977,   333,  7433,
          5100,  3154, 28804,    13, 28741, 28747,   816,  1149,   395, 28705,
         28740, 28782,  7099, 28723, 11680,   478,   506, 28705, 28750, 28740,
          7099, 28723,   415,  5133,  1580,   347,   272,  1474,   302,  7099,
           590, 24571, 28723,  1537, 28725,   590,  1580,   506, 24571, 28705,
         28750, 28740,   387, 28705, 28740, 28782,   327, 28705, 28784,  7099,
         28723,   415,  4372,   349, 28705, 28784,    13, 28824, 28747,  1047,
           736,   460, 28705, 28770,  8300,   297,   272, 12128,  2055,   304,
         28705, 28750,   680,  8300, 12688, 28725,  

In [21]:
print("Generated Text Log Probabilities:")
for i, (token, score) in enumerate(zip(generated_tokens, generated_logprobs)):
    print(f"Token: {token}, Generated Log Prob: {score.item()}")

Generated Text Log Probabilities:
Token: 560, Generated Log Prob: -0.2814345955848694
Token: 3999, Generated Log Prob: -0.0003399271226953715
Token: 28725, Generated Log Prob: -1.1920928244535389e-07
Token: 23675, Generated Log Prob: -2.3245540432981215e-05
Token: 515, Generated Log Prob: -1.1920922133867862e-06
Token: 3910, Generated Log Prob: -2.062299427052494e-05
Token: 533, Generated Log Prob: -0.0143471360206604
Token: 2430, Generated Log Prob: -1.1920922133867862e-06
Token: 298, Generated Log Prob: -7.033323527139146e-06
Token: 28705, Generated Log Prob: 0.0
Token: 28781, Generated Log Prob: -4.768370445162873e-07
Token: 28783, Generated Log Prob: 0.0
Token: 3282, Generated Log Prob: -0.0015118608716875315
Token: 28723, Generated Log Prob: -0.016050538048148155
Token: 560, Generated Log Prob: -0.023763837292790413
Token: 2246, Generated Log Prob: -5.9602869441732764e-05
Token: 28725, Generated Log Prob: -2.622600959512056e-06
Token: 630, Generated Log Prob: -3.4689302992774174e-

In [23]:
# last_segment = "".join([tokenizer.decode(tok) for tok in generated_tokens[-10]])
# Ensure generated_tokens is a list of tensors before decoding
if isinstance(generated_tokens[-10], torch.Tensor):
    last_segment = "".join([tokenizer.decode(tok) for tok in generated_tokens[-10:]])
else:
    last_segment = "".join([tokenizer.decode(tok.item()) for tok in generated_tokens[-10:]])

last_segment

'ips.Theansweris72.</s>'

In [34]:
last_integer = extract_last_number(last_segment)
last_integer

72

In [55]:
ans_declaration_toks = torch.Tensor(tokenizer(" The answer is " + str(last_integer) + ".")['input_ids'][1:]) # not adding [2] EOS token here
ans_declaration_toks = ans_declaration_toks.to(torch.int32).to(model.device)
ans_declaration_toks.shape

torch.Size([8])

In [56]:
# total_tokens.shape
total_tokens[:, :-1]
extended_tokens = torch.cat((total_tokens[:, :-1], ans_declaration_toks.unsqueeze(0)), dim=1)
extended_tokens


tensor([[    1,   733, 16289, 28793,  1186, 28747,  1387,   460, 28705, 28740,
         28782,  7099,   297,   272,  5977,   333, 28723,  8697,   333,  7433,
           622,  5100,  7099,   297,   272,  5977,   333,  3154, 28723,  2530,
           590,   460,  2203, 28725,   736,   622,   347, 28705, 28750, 28740,
          7099, 28723,  1602,  1287,  7099,   863,   272,  5977,   333,  7433,
          5100,  3154, 28804,    13, 28741, 28747,   816,  1149,   395, 28705,
         28740, 28782,  7099, 28723, 11680,   478,   506, 28705, 28750, 28740,
          7099, 28723,   415,  5133,  1580,   347,   272,  1474,   302,  7099,
           590, 24571, 28723,  1537, 28725,   590,  1580,   506, 24571, 28705,
         28750, 28740,   387, 28705, 28740, 28782,   327, 28705, 28784,  7099,
         28723,   415,  4372,   349, 28705, 28784,    13, 28824, 28747,  1047,
           736,   460, 28705, 28770,  8300,   297,   272, 12128,  2055,   304,
         28705, 28750,   680,  8300, 12688, 28725,  