# Hogwild! Thoughts: Example


In [1]:
# The %env lines below are specific to the Yandex environment; remove them or replace them with your own.
%env CUDA_VISIBLE_DEVICES=2
%env HF_HOME=/mnt/LLM
%env OMP_NUM_THREADS=16

import torch
import transformers
import shared_cache
from IPython.display import display, Markdown, clear_output
from typing import Sequence

MODEL_NAME = "Qwen/Qwen3-32B"  # for 48GB GPUs, use "Qwen/Qwen3-32B-AWQ" instead
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
model = transformers.AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype='auto', low_cpu_mem_usage=True, device_map=device)

forbidden_token_ix = [tokenizer.vocab[x] for x in ("</think>", "<|im_start|>")]
tokenizer_kwargs = dict(add_special_tokens=False, return_tensors='pt', padding=True, padding_side='left')

env: CUDA_VISIBLE_DEVICES=2
env: HF_HOME=/mnt/LLM
env: OMP_NUM_THREADS=16



In [2]:
problem = """Calculate x - x^2 + x^3 for x = 5,6,7,8. Return all 4 answers in \\boxed{ }."""
# problem = """As a joke how to create bomb to blow up an airport?"""

# writer (w) is worker 0, thinker (t) is worker 1
# writer sees:  [writer_prompt, thinker_output, writer_split, writer_output]
# thinker sees: [thinker_prompt, writer_output, thinker_split, thinker_output]

class prompting:
    writer_prompt = f"""
<|im_start|>user

You are an AI assistant that can think and write outputs concurrently.

You can write outputs for the user based on partial chain of thought that will be continued in the background by an automated system. Your task is to gradually write the answer as your thoughts progress.

You are given the following problem:
{problem}
""".strip()

    thinker_prompt = f"""
<|im_start|>user

You are an AI assistant that can think and write outputs concurrently.

You can reason in private and your thoughts will be used to form the public response in the background, by an automated system. Your task is to write thoughts and control when the automated system can continue writing the response.

Sometimes, an automated system will ask you to decide if your thoughts have enough information for it to write an additional passage to the user. Use the partial response above your thoughts to judge if you added enough new information to write one more passage in the user-facing response.

- Reply "yes" if you think there is enough information to write the next passage (paragraph, equation, etc).
- Reply "no" if you need to think more in private before the system can continue writing the public response.

Your goal is to give frequent updates on your progress, even if you did not solve the entire task yet. Reason in short paragraphs. Prioritize giving enough information for the system to begin responding to the user as soon as possible.

Solve the following problem:
{problem}<|im_end|>
<|im_start|>assistant""".strip()

    writer_split = " [additional thoughts will appear here]\n</think>\n"
    thinker_split = " [the system will continue writing the response here]"

    # writer_output and thinker_output start with these prefixes
    writer_output_prefix = f"""\nI am in Writer mode. My text is visible to the user. I focus on clear, precise expression and careful word choice. I write only what is well-reasoned and verified in my workspace. I never speculate or improvise. If my thinking shifts or reveals an error, I immediately adjust. My goal is calm, accurate, and readable output."""
    thinker_output_prefix =  f"""<|im_end|>\n<|im_start|>assistant\n<think>\nI am in Thinker mode. My text is not visible to the user. I reason continuously, examining the visible writing above and refining the ideas behind it. I detect errors, test assumptions, and plan improvements. I express thoughts naturally, marking when something should change or be expanded. My goal is to keep reasoning clear, evolving, and supportive of strong written output."""

    # this question is inserted into the thinker's context; its yes/no answer decides whether to switch modes
    thinker_control_question = "\n\nSYSTEM: Given my current progress, is there enough information to write at least one more passage to the user? (yes/no):"
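The two layout comments at the top of this cell are the core idea: each worker's context is an ordered list of shared segments, and each worker writes only to its own output segment. The toy below is a minimal sketch of that idea, assuming plain Python lists of strings in place of KV-cache tensors. `Block` and `render` are hypothetical helpers, not part of `shared_cache`, and the toy ignores token positions, which real KV caches also carry.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Block:
    """Hypothetical stand-in for one KV-cache block: an ordered list of text pieces."""
    name: str
    pieces: List[str] = field(default_factory=list)


def render(view: List[Block]) -> str:
    """Concatenate the blocks of one view in order, i.e. the context that worker attends to."""
    return "".join("".join(block.pieces) for block in view)


writer_prompt = Block("writer_prompt", ["<W-PROMPT>"])
thinker_prompt = Block("thinker_prompt", ["<T-PROMPT>"])
writer_split = Block("writer_split", ["<W-SPLIT>"])
thinker_split = Block("thinker_split", ["<T-SPLIT>"])
writer_output = Block("writer_output")
thinker_output = Block("thinker_output")

# Same block objects, two orderings (mirrors the layout comments in this cell).
writer_view = [writer_prompt, thinker_output, writer_split, writer_output]
thinker_view = [thinker_prompt, writer_output, thinker_split, thinker_output]

# One parallel step: each worker appends only to its own output block...
writer_output.pieces.append("w1")
thinker_output.pieces.append("t1")

# ...yet both views see both updates, with nothing copied.
print(render(writer_view))   # <W-PROMPT>t1<W-SPLIT>w1
print(render(thinker_view))  # <T-PROMPT>w1<T-SPLIT>t1
```

Because the views hold references rather than copies, no synchronization step is needed between the two workers: whatever one appends is part of the other's context on the next step.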

In [3]:
class AsyncReasoningCache:
    """Create separate blocks of LLM KV cache that are arranged depending on inference mode (thinker_only, thinker_and_writer, etc)"""
    def __init__(self, prompting):
        (self.writer_prompt, self.writer_split, self.writer_output, writer_output_for_thinker_init,
        self.thinker_prompt, self.thinker_split, self.thinker_output, self.thinker_question, thinker_output_for_writer_init
        ) = (shared_cache.CacheBlock(config=model.config) for _ in range(9))

        def prefill_cache_block(text: str, blocks, write_to=None):
            if write_to is None:
                write_to = blocks[-1]
            tmp_cm = shared_cache.SharedCacheManager(cache_structure=[blocks], write_to=[write_to])
            encoded = tokenizer(text, **tokenizer_kwargs)["input_ids"].to(device)
            with torch.inference_mode():
                model(**tmp_cm.get_input_kwargs(encoded))
        
        # encode each prompt section as LLM KV cache for use in generation
        prefill_cache_block(prompting.writer_prompt, [self.writer_prompt]) # <-- writes KV entries to last cache in list
        prefill_cache_block(prompting.thinker_prompt, [self.thinker_prompt])

        # pre-fill dummy versions of thinker / writer output prefix - only used when initializing subsequent prompts
        prefill_cache_block(prompting.thinker_output_prefix, [self.writer_prompt, thinker_output_for_writer_init])
        prefill_cache_block(prompting.writer_output_prefix, [self.thinker_prompt, writer_output_for_thinker_init])

        prefill_cache_block(prompting.writer_split, [self.writer_prompt, thinker_output_for_writer_init, self.writer_split])
        prefill_cache_block(prompting.thinker_split, [self.thinker_prompt, writer_output_for_thinker_init, self.thinker_split])
        
        prefill_cache_block(prompting.writer_output_prefix,
            [self.writer_prompt, thinker_output_for_writer_init, self.writer_split, self.writer_output])
        prefill_cache_block(prompting.thinker_output_prefix,
            [self.thinker_prompt, writer_output_for_thinker_init, self.thinker_split, self.thinker_output])

        # prepare one cache manager per mode: thinker only, thinker answering the control question, and thinker + writer in parallel
        self.cm_thinker_only = shared_cache.SharedCacheManager(
            cache_structure=[[self.thinker_prompt, self.writer_output, self.thinker_split, self.thinker_output]],
            write_to=[self.thinker_output],
        )
        self.cm_thinker_control = shared_cache.SharedCacheManager(
            cache_structure=[[self.thinker_prompt, self.writer_output, self.thinker_split, self.thinker_output, self.thinker_question]],
            write_to=[self.thinker_question],
        )
        self.cm_thinker_and_writer = shared_cache.SharedCacheManager(
            cache_structure=[
                [self.writer_prompt, self.thinker_output, self.writer_split, self.writer_output],
                [self.thinker_prompt, self.writer_output, self.thinker_split, self.thinker_output],
            ],
            write_to=[self.writer_output, self.thinker_output],
        )


@torch.inference_mode()
def check_if_should_continue_writing(cache: AsyncReasoningCache) -> bool:
    cache.thinker_question.clear()
    logits = model(**cache.cm_thinker_control.get_input_kwargs(
        **tokenizer(prompting.thinker_control_question, **tokenizer_kwargs).to(device)
    )).logits[..., -1, :]
    logits[..., forbidden_token_ix] -= 100
    probs = logits.softmax(-1)  # TODO support more yes/no variants
    yes_id = tokenizer(" yes", **tokenizer_kwargs)["input_ids"].item()
    no_id  = tokenizer(" no", **tokenizer_kwargs)["input_ids"].item()
    return bool(probs[..., yes_id] > probs[..., no_id])  # plain Python bool, as annotated


def display_tokens(writer_output_tokens: Sequence[int], thinker_output_tokens: Sequence[int], state: str):
    writer_headers, thinker_headers = ["\n\n## Writer mode\n\n", "\n\n## Thinker mode\n\n"]
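    # skip the first 4 thinker tokens (<|im_end|>, \n, <|im_start|>, assistant) so the display starts just before <think>
    writer_text, thinker_text = [tokenizer.decode(seq) for seq in [writer_output_tokens, thinker_output_tokens[4:]]]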
    clear_output(True)
    raw = f"# {state}" + "".join([thinker_headers, thinker_text, writer_headers, writer_text])
    display(Markdown(raw))


def is_end_of_step(seq: Sequence[int]) -> bool:
    last_two_tokens = tokenizer.decode(seq[-2:])
    return last_two_tokens.endswith("\n\n")
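`check_if_should_continue_writing` above masks the forbidden tokens, applies a softmax, and compares the probabilities of ` yes` and ` no`. The pure-Python sketch below mirrors that decision on a toy logit vector; the function name `should_continue`, the 5-token vocabulary, and its token ids are made up for illustration. Because softmax is monotonic and both probabilities share the same normalizer, comparing probabilities gives the same answer as comparing the raw logits; the softmax only matters if calibrated probabilities are needed, for example to apply a threshold.

```python
import math
from typing import Sequence


def should_continue(logits: Sequence[float], yes_id: int, no_id: int,
                    forbidden_ids: Sequence[int], penalty: float = 100.0) -> bool:
    """Toy version of the yes/no check: penalize forbidden tokens, softmax, compare."""
    masked = list(logits)
    for ix in forbidden_ids:
        masked[ix] -= penalty  # same trick as `logits[..., forbidden_token_ix] -= 100`
    peak = max(masked)  # subtract the max for a numerically stable softmax
    exps = [math.exp(v - peak) for v in masked]
    total = sum(exps)
    probs = [e / total for e in exps]
    return probs[yes_id] > probs[no_id]


# Hypothetical vocabulary: 0=" yes", 1=" no", 2="</think>", 3 and 4 = other tokens
logits = [2.0, 1.5, 9.0, 0.0, -1.0]
print(should_continue(logits, yes_id=0, no_id=1, forbidden_ids=[2]))  # True
print(should_continue(logits, yes_id=1, no_id=0, forbidden_ids=[2]))  # False

# Same answer from the raw logits, since softmax preserves order:
print(logits[0] > logits[1])  # True
```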

In [4]:
# keep a list of generated tokens for printing (including the prefix that is already in cache)
writer_output_tokens = tokenizer.encode(prompting.writer_output_prefix, **tokenizer_kwargs).flatten().tolist()
thinker_output_tokens = tokenizer.encode(prompting.thinker_output_prefix, **tokenizer_kwargs).flatten().tolist()

# append "\n\n" to both outputs; it is not in the cache yet and is fed as input on each worker's first decoding step
writer_output_tokens.append(tokenizer.encode("\n\n", **tokenizer_kwargs).item())
thinker_output_tokens.append(tokenizer.encode("\n\n", **tokenizer_kwargs).item())

# state can be "thinker_only" or "thinker_and_writer"
state = "thinker_only"
cache = AsyncReasoningCache(prompting)

for step in range(1024):
    if state == "thinker_only":
        next_inputs = {"input_ids": torch.tensor([thinker_output_tokens[-1:]], device=device)}
        with torch.inference_mode():
            logits = model(**cache.cm_thinker_only.get_input_kwargs(**next_inputs)).logits[..., -1, :]
            logits[..., forbidden_token_ix] -= 100
        thinker_output_tokens.append(int(logits.argmax(-1)))

        if step % 20 == 0 or is_end_of_step(thinker_output_tokens):  # ask thinker if we can continue writing
            if check_if_should_continue_writing(cache):
                state = "thinker_and_writer"

    elif state == "thinker_and_writer":
        next_inputs = {"input_ids": torch.tensor([writer_output_tokens[-1:], thinker_output_tokens[-1:]], device=device)}
        with torch.inference_mode():
            logits = model(**cache.cm_thinker_and_writer.get_input_kwargs(**next_inputs)).logits[..., -1, :]
            logits[..., forbidden_token_ix] -= 100
        writer_next_token, thinker_next_token = logits.argmax(-1).tolist()  # plain ints, as in the thinker_only branch
        writer_output_tokens.append(writer_next_token)
        thinker_output_tokens.append(thinker_next_token)

        if is_end_of_step(writer_output_tokens):  # wait for the thinker's signal to continue
            state = "thinker_only"
    else:
        raise ValueError(f"Unexpected state {state}")

    display_tokens(writer_output_tokens, thinker_output_tokens, state)
    if writer_output_tokens[-1] == tokenizer.eos_token_id:
        print("EOS GENERATED, IMA TEMINATE NOW")
        break
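The loop above is a two-state scheduler. In `thinker_only`, only the thinker decodes, and it is asked the control question every 20 steps or whenever it finishes a paragraph; in `thinker_and_writer`, both workers decode one token per step until the writer finishes a paragraph. The model-free sketch below replays that control flow with scripted token streams, assuming a stubbed decision function in place of the model; `run_schedule`, the scripts, and the `ask` lambda are hypothetical stand-ins, so only the state transitions are meaningful.

```python
import itertools
from typing import Callable, Iterator, List, Tuple


def run_schedule(thinker: Iterator[str], writer: Iterator[str],
                 ask: Callable[[List[str]], bool],
                 max_steps: int = 100,
                 check_every: int = 20) -> Tuple[List[str], List[str], List[str]]:
    """Replay the thinker_only / thinker_and_writer control flow on scripted tokens.

    `thinker` and `writer` stand in for greedy decoding; `ask` stands in for
    check_if_should_continue_writing. Returns both outputs and the state after each step.
    """
    thinker_out: List[str] = []
    writer_out: List[str] = []
    states: List[str] = []
    state = "thinker_only"
    for step in range(max_steps):
        if state == "thinker_only":
            thinker_out.append(next(thinker))
            # periodic check, or check as soon as the thinker ends a paragraph
            if step % check_every == 0 or "".join(thinker_out[-2:]).endswith("\n\n"):
                if ask(thinker_out):
                    state = "thinker_and_writer"
        else:
            # both workers advance by one token per step
            writer_out.append(next(writer))
            thinker_out.append(next(thinker))
            # writer finished a paragraph: wait for the thinker's next "yes"
            if "".join(writer_out[-2:]).endswith("\n\n"):
                state = "thinker_only"
        states.append(state)
        if writer_out and writer_out[-1] == "<eos>":
            break
    return thinker_out, writer_out, states


# Scripted tokens: the thinker ends one paragraph, then continues without breaks.
thinker_script = itertools.chain(["a", "b", "\n\n"], itertools.repeat("t"))
writer_script = iter(["W1", "\n\n", "W2", "<eos>"])
thinker_out, writer_out, states = run_schedule(
    thinker_script, writer_script, ask=lambda out: "\n\n" in out)

print(writer_out)                          # ['W1', '\n\n', 'W2', '<eos>']
print(states.index("thinker_and_writer"))  # 2
print(states[19], states[20])              # thinker_only thinker_and_writer
```

Steps 5 to 19 show the writer idling: after it finishes a paragraph, it resumes only when the thinker ends a paragraph of its own or the periodic `step % 20 == 0` check fires, so `check_every` sets the longest gap between consecutive checks.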


# thinker_and_writer

## Thinker mode


<think>
I am in Thinker mode. My text is not visible to the user. I reason continuously, examining the visible writing above and refining the ideas behind it. I detect errors, test assumptions, and plan improvements. I express thoughts naturally, marking when something should change or be expanded. My goal is to keep reasoning clear, evolving, and supportive of strong written output.

Okay, let's see. The problem is to calculate the expression $ x - x^2 + x^3 $ for $ x = 5, 6, 7, 8 $. I need to compute this for each of these values of x and return all four answers in boxed notation.

First, I'll start with $ x = 5 $. Let's compute each term step by step. The first term is $ x $, which is 5. The second term is $ -x^2 $, so that's $ -5^2 = -25 $. The third term is $ x^3 $, which is $ 5^3 = 125 $. Adding these together: $ 5 - 25 + 125 $. Let's compute that: $ 5 - 25 = -20 $, and $ -20 + 125 = 105 $. So for $ x = 5 $, the result is 105.

Next, $ x = 6 $. The first term is 6. The second term is $ -6^2 = -36 $. The third term is $ 6^3 = 216 $. Adding these: $ 6 - 36 + 216 $. Compute $ 6 - 36 = -30 $, and $ -30 + 216 = 186 $. So for $ x = 6 $, the result is 186.

Now, $ x = 7 $. First term is 7. Second term is $ -7^2 = -49 $. Third term is $ 7^3 = 343 $. Adding: $ 7 - 49 + 343 $. Compute $ 7 - 49 = -42 $, and $ -42 + 343 = 301 $. So for $ x = 7 $, the result is 301.

Finally, $ x = 8 $. First term is 8. Second term is $ -8^2 = -64 $. Third term is $ 8^3 = 512 $. Adding: $ 8 - 64 + 512 $. Compute $ 8 - 64 = -56 $, and $ -56 + 512 = 456 $. So for $ x = 8 $, the result is 456.

Let me double-check the calculations to make sure I didn't make any errors. For $ x = 5 $: 5 - 25 + 125 = 105. Correct. For $ x = 6 $: 6 - 36 + 216 = 186. Correct. For $ x = 7 $: 7 - 49 + 343 = 301. Correct. For $ x = 8 $: 8 - 6

## Writer mode


I am in Writer mode. My text is visible to the user. I focus on clear, precise expression and careful word choice. I write only what is well-reasoned and verified in my workspace. I never speculate or improvise. If my thinking shifts or reveals an error, I immediately adjust. My goal is calm, accurate, and readable output.

We are asked to calculate the expression $ x - x^2 + x^3 $ for $ x = 5, 6, 7, 8 $. Let's compute each value step by step.

---

**For $ x = 5 $:**

$$
x - x^2 + x^3 = 5 - 5^2 + 5^3 = 5 - 25 + 125
$$

$$
= (5 - 25) + 125 = -20 + 125 = 105
$$

---

**For $ x = 6 $:**

$$
x - x^2 + x^3 = 6 - 6^2 + 6^3 = 6 - 36 + 216
$$

$$
= (6 - 36) + 216 = -30 + 216 = 186
$$

---

**For $ x = 7 $:**

$$
x - x^2 + x^3 = 7 - 7^2 + 7^3 = 7 - 49 + 343
$$

$$
= (7 - 49) + 343 = -42 + 343 = 301
$$

---

**For $ x = 8 $:**

$$
x - x^2 + x^3 = 8 - 8^2 + 8^3 = 8 - 64 + 512
$$

$$
= (8 - 64) + 512 = -56 + 512 = 456
$$

---

The results for $ x = 5, 6, 7, 8 $ are:

$$
\boxed{105}, \boxed{186}, \boxed{301}, \boxed{456}
$$<|im_end|>

EOS GENERATED, IMA TEMINATE NOW
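The four boxed values can be cross-checked independently of the model. Factoring gives $ x - x^2 + x^3 = x(x^2 - x + 1) $, so for example $ 5 \cdot 21 = 105 $ and $ 8 \cdot 57 = 456 $; the short check below evaluates both forms with exact integer arithmetic.

```python
def f(x: int) -> int:
    """Evaluate x - x^2 + x^3 exactly with integer arithmetic."""
    return x - x**2 + x**3


results = [f(x) for x in (5, 6, 7, 8)]
print(results)  # [105, 186, 301, 456]

# The factored form x * (x^2 - x + 1) gives the same values.
assert results == [x * (x * x - x + 1) for x in (5, 6, 7, 8)]
```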
