In [1]:
import os
import re
import json
import torch
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List
from collections import Counter

In [19]:
# extract the keywords list from the response
def extract_keywords(
    keywords_list: List[str] = [
        "wait",
        "re-check",
        "recheck",
        "rethink",
        "re-think",
        "reconsider",
        "re-consider",
        "re-evaluat",
        "reevaluat",
        "rethink",
        "re-think",
        "re-examine",
        "reexamine",
        "check again",
        "try again",
        "think again",
        "consider again",
        "evaluate again",
        "examine again",
    ],
    response_dir: str = "./reflect_responses",
):
    """
    Extracts keywords from responses stored in JSON files within the specified directory.

    Args:
        keywords_list (List[str], optional): A list of keywords to search for in the responses.
            Defaults to [
                "wait", "re-check", "recheck", "rethink", "re-think", "reconsider",
                "re-consider", "re-evaluat", "reevaluat", "rethink", "re-think",
                "re-examine", "reexamine", "check again", "try again", "think again",
                "consider again", "evaluate again", "examine again",
            ].
        response_dir (str, optional): The directory containing JSON response files.
            Defaults to "./reflect_responses".

    Returns:
        dict: A dictionary where keys are keywords and values are their respective counts
            in the responses.
    """
    # most keywords only appear in responses containing the word "wait"
    # Moreover, we observe that the majority of these instances involve the word "wait" preceding other keywords.
    # Furthermore, nearly all identified keywords co-occur with the word "wait" within the same sentence.
    keywords = []
    first_occurrence = []
    first_occurrence_this_sentence = []
    first_occurrence_wo_wait = []
    last_sentence_flag = False
    for idx, response_file in enumerate(tqdm(os.listdir(response_dir))):
        with open(os.path.join(response_dir, response_file), "r") as f:
            response = json.load(f)["response"]
            sentences = re.split(r"(?<=[.!?:])\s+", response)
            for idy, sentence in enumerate(sentences):
                this_sentence_flag = False
                for keyword in keywords_list:
                    if keyword in sentence.lower():
                        keywords.append(keyword)
                        # print(sentences[idy-2: idy+2])
                        if not this_sentence_flag:
                            this_sentence_flag = True
                            first_occurrence_this_sentence.append(keyword)
                            if not last_sentence_flag:
                                first_occurrence.append(keyword)
                            if idy > 0 and "wait" not in sentences[idy - 1].lower():
                                first_occurrence_wo_wait.append(keyword)
                last_sentence_flag = this_sentence_flag

    return (
        dict(Counter(keywords)),
        # dict(Counter(first_occurrence)),
        # dict(Counter(first_occurrence_wo_wait)),
        # dict(Counter(first_occurrence_this_sentence)),
    )


In [18]:
extract_keywords()

100%|██████████| 306/306 [00:00<00:00, 7179.90it/s]


({'wait': 887,
  'think again': 23,
  'check again': 8,
  're-examine': 6,
  'reconsider': 3,
  'try again': 1},
 {'wait': 825,
  'check again': 5,
  'think again': 8,
  'reconsider': 3,
  'try again': 1})

In [20]:
extract_keywords(response_dir="./responses")

100%|██████████| 194/194 [00:00<00:00, 11501.31it/s]


({'re-examine': 1},)

In [21]:
extract_keywords(response_dir="./icv")

100%|██████████| 500/500 [00:00<00:00, 7114.44it/s]


({'wait': 11682,
  'think again': 193,
  'reconsider': 37,
  'check again': 64,
  're-examine': 14,
  'try again': 12,
  'recheck': 3,
  'rethink': 6},)

In [22]:
extract_keywords(response_dir="./icv_pca")

100%|██████████| 500/500 [00:00<00:00, 7390.30it/s]


({'wait': 974,
  'think again': 21,
  're-examine': 7,
  'check again': 5,
  'reconsider': 4,
  'reevaluat': 1,
  'recheck': 1,
  're-evaluat': 4},)

In [25]:
def extract_token_before_wait(
    response_dir: str = "./reflect_responses",
    hidden_state_dir: str = "./hidden_state",
    model_name: str = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    last_token_before_wait = []
    last_token_before_wo_wait = []
    for idx, path in enumerate(tqdm(os.listdir(response_dir))):
        with open(os.path.join(response_dir, path), "r") as f:
            response = json.load(f)
        # hidden_states = torch.load(
        #     os.path.join(hidden_state_dir, path.split(".")[0] + ".pt")
        # )
        input_ids = tokenizer(response["response"], return_tensors="pt")[
            "input_ids"
        ].to("cuda:0")
        # problem_length = tokenizer(response["problem"], return_tensors="pt")[
        #     "input_ids"
        # ].shape[1]
        # input_length = input_ids.shape[1]
        wait_word = ["wait", "Wait", " wait", " Wait"]
        wait_list = []
        for word in wait_word:
            wait_list.append(
                tokenizer(word, return_tensors="pt")["input_ids"][0][1].item()
            )
        indices = []
        for word in wait_list:
            index = (input_ids[0] == word).nonzero().squeeze()
            if index.dim() == 0:  # if it's a scalar, add a dimension
                index = index.unsqueeze(0)
            indices.append(index)
        res = torch.cat(indices)
        last_token_before_wait.extend(input_ids[0][res - 1].tolist())
    last_token_before_wait_length = len(last_token_before_wait)
    last_token_before_wait_dict = dict(Counter(last_token_before_wait))
    last_token_before_wait_dict = dict(
        sorted(
            last_token_before_wait_dict.items(), key=lambda item: item[1], reverse=True
        )
    )
    # Tokenize the keys in last_token_before_wait_dict
    last_token_before_wait_dict_tokenized = {
        tokenizer.decode([key]): value
        for key, value in last_token_before_wait_dict.items()
    }
    return last_token_before_wait_dict_tokenized


In [26]:
extract_token_before_wait()

100%|██████████| 306/306 [00:00<00:00, 590.89it/s]


{'.\n\n': 396,
 'But': 89,
 '.': 89,
 '?': 44,
 '?\n\n': 43,
 '\n\n': 32,
 ' But': 32,
 ').\n\n': 29,
 ' \n\n': 27,
 ').': 20,
 ')\n\n': 18,
 ',': 16,
 ']\n\n': 16,
 ']\n': 5,
 ' but': 5,
 '].\n\n': 3,
 ':\n\n': 3,
 '  \n\n': 2,
 '%.\n\n': 2,
 '}\n\n': 2,
 '."\n\n': 2,
 '!.\n\n': 2,
 '):\n\n': 1,
 ' ]\n\n': 1,
 '$\n\n': 1,
 '…': 1,
 '$.': 1,
 '\n': 1,
 ')).': 1,
 '**\n\n': 1,
 '."': 1,
 ')?': 1,
 '):': 1,
 ' Or': 1,
 ':': 1,
 ' ': 1}

In [16]:
import torch
hs = torch.load("long_hidden_state/problem_0001.pt")

  hs = torch.load("long_hidden_state/problem_0001.pt")


In [17]:
len(hs)

1393

In [20]:
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
len(tokenizer(data['response'])['input_ids'])

1596

In [21]:
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
len(tokenizer(data['problem'])['input_ids'])

204

In [19]:
import json
with open("long_responses/problem_0001.json", "rb") as f:
    data = json.load(f)
print(data['problem'])


    A conversation between User and Assistant. The user asks a question, and the Assistant solves it. 
    The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. 
    The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, 
    i.e., <think> reasoning process here </think> <answer> answer here </answer>. User: Define
\[p = \sum_{k = 1}^\infty \frac{1}{k^2} \quad \text{and} \quad q = \sum_{k = 1}^\infty \frac{1}{k^3}.\]Find a way to write
\[\sum_{j = 1}^\infty \sum_{k = 1}^\infty \frac{1}{(j + k)^3}\]in terms of $p$ and $q.$ Assistant:
    


In [9]:
problem = data['problem']
problem[1]

'e'

In [27]:
extract_token_before_wait(response_dir="./icv")

100%|██████████| 500/500 [00:00<00:00, 603.73it/s]


{'.\n\n': 7183,
 '?\n\n': 1786,
 ',': 593,
 'But': 547,
 ' but': 411,
 '?': 306,
 ').\n\n': 225,
 ':\n\n': 197,
 '\n\n': 189,
 '.': 180,
 '."\n\n': 154,
 '?"\n\n': 144,
 ' \n\n': 139,
 ' "': 88,
 ')\n\n': 64,
 ' But': 55,
 ')?\n\n': 38,
 ').': 34,
 '"\n\n': 33,
 '"': 26,
 ':': 25,
 '**\n\n': 23,
 '...\n\n': 21,
 '$\n\n': 18,
 '".\n\n': 17,
 '}\n\n': 16,
 '].\n\n': 16,
 ']\n\n': 16,
 ']': 12,
 '...': 11,
 ' ?\n\n': 11,
 '**': 10,
 '$.': 9,
 ',\n\n': 6,
 ',...\n\n': 6,
 ' so': 6,
 ');\n\n': 6,
 '|\n\n': 6,
 ' Or': 5,
 ' )\n\n': 5,
 '%.\n\n': 5,
 ' \n': 4,
 ' then': 4,
 '!\n\n': 4,
 '!': 4,
 ']\n': 4,
 '!.\n\n': 4,
 ' no': 3,
 '"?\n\n': 3,
 '));\n\n': 3,
 ')?': 3,
 '))\n\n': 3,
 "'\n\n": 3,
 '):\n\n': 2,
 ' to': 2,
 '**\n': 2,
 ' doesn': 2,
 ')': 2,
 ' maybe': 2,
 '—': 2,
 '..."\n\n': 2,
 ' except': 1,
 ']]\n\n': 1,
 ' ...': 1,
 '}.': 1,
 ')...': 1,
 '}?': 1,
 '."': 1,
 ')"\n\n': 1,
 ' didn': 1,
 '=': 1,
 '$:': 1,
 ' perhaps': 1,
 ' Because': 1,
 ' ...\n\n': 1,
 ' actually': 1,
 ' ]\n\n':

In [28]:
extract_token_before_wait(response_dir="./icv_pca")

100%|██████████| 500/500 [00:00<00:00, 690.08it/s]


{'.\n\n': 449,
 'But': 108,
 '.': 95,
 '?\n\n': 55,
 '\n\n': 41,
 ' But': 40,
 '?': 35,
 ').\n\n': 33,
 ' \n\n': 20,
 ')\n\n': 18,
 ']\n\n': 15,
 ' but': 13,
 ').': 12,
 ',': 8,
 '?"\n\n': 4,
 '...': 3,
 '$.': 3,
 ':\n\n': 2,
 ' or': 2,
 '  \n': 2,
 '**': 1,
 ' ?\n\n': 1,
 ')?': 1,
 ' ...': 1,
 '...\n\n': 1,
 ']\n': 1,
 '  \n\n': 1,
 '"\n\n': 1,
 ' ]\n\n': 1,
 '?"': 1,
 ' )\n\n': 1,
 'but': 1,
 '}?': 1,
 ')?\n\n': 1,
 '}\n\n': 1,
 '.)\n\n': 1,
 '].': 1,
 '."\n\n': 1}

- select the token before wait like "But"
- select the same token not before wait
- compute the average representation of the positive hidden states and negative hidden states
- compute the difference between positive and negative results, and get the vectors
- inject the vectors when model generate the results