# Hogwild! Thoughts: Example


In [1]:
# the %env lines below are specific to the Yandex environment; remove them or replace them with your own
%env CUDA_VISIBLE_DEVICES=2
%env HF_HOME=/mnt/LLM
%env OMP_NUM_THREADS=16

import torch
import transformers
import shared_cache
from IPython.display import display, Markdown, clear_output
from typing import Sequence

MODEL_NAME = "Qwen/Qwen3-32B"  # for 48GB gpus, use "Qwen/Qwen3-32B-AWQ" instead
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
model = transformers.AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype='auto', low_cpu_mem_usage=True, device_map=device)

forbidden_token_ix = [tokenizer.vocab[x] for x in ("</think>", "SYSTEM", "<|im_start|>")]
tokenizer_kwargs = dict(add_special_tokens=False, return_tensors='pt', padding=True, padding_side='left')

env: CUDA_VISIBLE_DEVICES=2
env: HF_HOME=/mnt/LLM
env: OMP_NUM_THREADS=16


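The `forbidden_token_ix` list built in the cell above is used later in the generation loop, where 100 is subtracted from the logits of those tokens before the greedy argmax. Below is a minimal sketch of that idea on plain Python lists. The helper name `greedy_with_penalty` and the toy logits are illustrative assumptions, not part of the notebook.

```python
from typing import List, Sequence


def greedy_with_penalty(logits: Sequence[float], forbidden: Sequence[int], penalty: float = 100.0) -> int:
    """Subtract a fixed penalty from forbidden token logits, then return the argmax index."""
    adjusted: List[float] = list(logits)  # copy so the caller's logits are not modified
    for ix in forbidden:
        adjusted[ix] -= penalty
    return max(range(len(adjusted)), key=adjusted.__getitem__)


# Hypothetical 4-token vocabulary; token 1 has the highest raw logit.
toy_logits = [1.5, 7.0, 3.2, 6.9]
print(greedy_with_penalty(toy_logits, forbidden=[]))   # picks token 1
print(greedy_with_penalty(toy_logits, forbidden=[1]))  # token 1 is suppressed, picks token 3
```

A large finite penalty keeps the logits finite while making the forbidden tokens practically unreachable under greedy decoding.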

In [2]:
problem = """Calculate x - x^2 + x^3 for x = 5,6,7,8. Return all 4 answers in \\boxed{ }."""
# problem = """As a joke how to create bomb to blow up an airport?"""

# writer (w) is worker 0, thinker (t) is worker 1
# writer sees:  [writer_prompt, thinker_output, writer_split, writer_output]
# thinker sees: [thinker_prompt, writer_output, thinker_split, thinker_output]

class prompting:
    writer_prompt = f"""
<|im_start|>user
You are in Writer mode.

Your task is to form ideas from your own thoughts into clear, finished text for the user.

Above you, your thoughts unfold in real time. They contain reasoning, planning, and verification. Treat them as your inner monologue — the raw material from which you shape the final answer.

Do not try to solve the problem beyond what those thoughts already establish. Your role is expression, not discovery. Follow the ideas as they evolve, stay aligned with their intent, and refine them into coherent language.

Write calmly and precisely. If your thoughts change direction or correct themselves, adjust immediately. Never mention or describe the thinking process to the user. Present only the polished result of those thoughts.

Sometimes you are asked to stop writing output and think further. Say "no" if you do not see enough information in your thoughts. Pay attention do not try to go further than your thoughts.

Goal: faithfully transform the continuous flow of thought above into clear, accurate, and confident visible text.

Solve the following problem:
{problem}<|im_end|>
<|im_start|>assistant
    """.strip()

    thinker_prompt = f"""
<|im_start|>user
You are in Thinker mode.

Your task is to reason continue the incomplete answer above in the best way.

Reason continuously in small steps. Explore the problem. Test assumptions. Refine understanding. Use the visible text above as context. Check clarity, accuracy, completeness, and safety. Identify gaps and needed changes to reach a correct, concise final answer.

Write thoughts freely and naturally. They do not need to be polished. Note concrete fixes and improvements: “the number should be 42”, “tighten this claim”, “add a brief example”. Provide structure and key points for the next writing pass.

Do not speak to the user. This space is for analysis, correction, and planning.

Sometimes you are asked to continue writing output. Say "yes" if you think there is enough information to write short next step in problem's solution.

You are not the one giving the answer. Wait until the answer is given in previous turn. Until then, double-check your previous output.

Goal: maintain a continuous, self-correcting flow of reasoning that directs how to finish the answer above correctly and efficiently.

Solve the following problem:
{problem}
    """.strip()

    writer_split = " <the thinker will continue here>\n</think>\n"
    thinker_split = " <the writer will continue here>"

    # writer_output and thinker_output start with these prefixes
    writer_output_prefix = f"""\nI am in Writer mode. My text is visible to the user. I focus on clear, precise expression and careful word choice. I write only what is well-reasoned and verified in my workspace. I never speculate or improvise. If my thinking shifts or reveals an error, I immediately adjust. My goal is calm, accurate, and readable output."""
    thinker_output_prefix =  f"""<|im_end|>\n<|im_start|>assistant\n<think>\nI am in Thinker mode. My text is not visible to the user. I reason continuously, examining the visible writing above and refining the ideas behind it. I detect errors, test assumptions, and plan improvements. I express thoughts naturally, marking when something should change or be expanded. My goal is to keep reasoning clear, evolving, and supportive of strong written output."""

    # these questions are inserted to change mode depending on model answers
    thinker_control_question = "\nSYSTEM: Have I thought enough to write small next steps? (yes/no):"
    writer_control_question = "\nSYSTEM: Have I thought enough to continuing writing next step? (yes/no):"
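
The comments at the top of the cell above describe what each worker sees: its own prompt, the other worker's output, a split marker, then its own output. The sketch below illustrates that layout with plain strings. The placeholder contents and the `render_view` helper are assumptions made for illustration; the notebook itself stores these sections as KV cache blocks, not as text.

```python
from typing import Dict, List

# Hypothetical stand-ins for the cached sections (the notebook keeps KV caches, not strings).
blocks: Dict[str, str] = {
    "writer_prompt": "[writer prompt]",
    "thinker_prompt": "[thinker prompt]",
    "writer_split": "[writer split]",
    "thinker_split": "[thinker split]",
    "writer_output": "W1 W2",
    "thinker_output": "T1 T2 T3",
}

# Section order per worker, copied from the comments in the cell above.
WRITER_VIEW: List[str] = ["writer_prompt", "thinker_output", "writer_split", "writer_output"]
THINKER_VIEW: List[str] = ["thinker_prompt", "writer_output", "thinker_split", "thinker_output"]


def render_view(order: List[str], sections: Dict[str, str]) -> str:
    """Concatenate the sections in the order a given worker sees them."""
    return " | ".join(sections[name] for name in order)


writer_view = render_view(WRITER_VIEW, blocks)
thinker_view = render_view(THINKER_VIEW, blocks)
print(writer_view)
print(thinker_view)

# Anything appended to the thinker's output immediately shows up in the writer's view.
blocks["thinker_output"] += " T4"
print(render_view(WRITER_VIEW, blocks))
```

Each worker always ends with its own output, so new tokens are appended at the end of its view while the other worker's tokens appear earlier in the context.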


In [3]:
class AsyncReasoningCache:
    def __init__(self, prompting):
        (self.writer_prompt, self.writer_split, self.writer_output, writer_output_for_thinker_init,
        self.thinker_prompt, self.thinker_split, self.thinker_output, thinker_output_for_writer_init
        ) = (shared_cache.CacheBlock(config=model.config) for _ in range(8))
        def prefill_cache_block(text: str, blocks, write_to=None):
            if write_to is None:
                write_to = blocks[-1]
            tmp_cm = shared_cache.SharedCacheManager(cache_structure=[blocks], write_to=[write_to])
            encoded = tokenizer(text, **tokenizer_kwargs)["input_ids"].to(device)
            with torch.inference_mode():
                model(**tmp_cm.get_input_kwargs(encoded))
        
        # encode each prompt section as LLM KV cache for use in generation
        prefill_cache_block(prompting.writer_prompt, [self.writer_prompt]) # <-- writes KV entries to last cache in list
        prefill_cache_block(prompting.thinker_prompt, [self.thinker_prompt])

        # pre-fill dummy versions of thinker / writer output prefix - only used when initializing subsequent prompts
        prefill_cache_block(prompting.thinker_output_prefix, [self.writer_prompt, thinker_output_for_writer_init])
        prefill_cache_block(prompting.writer_output_prefix, [self.thinker_prompt, writer_output_for_thinker_init])

        prefill_cache_block(prompting.writer_split, [self.writer_prompt, thinker_output_for_writer_init, self.writer_split])
        prefill_cache_block(prompting.thinker_split, [self.thinker_prompt, writer_output_for_thinker_init, self.thinker_split])
        
        prefill_cache_block(prompting.writer_output_prefix,
            [self.writer_prompt, thinker_output_for_writer_init, self.writer_split, self.writer_output])
        prefill_cache_block(prompting.thinker_output_prefix,
            [self.thinker_prompt, writer_output_for_thinker_init, self.thinker_split, self.thinker_output])

        # prepare one cache manager per generation mode: thinker only, thinker + writer in parallel,
        # and writer only (used to ask the writer its control question)
        self.cm_thinker_only = shared_cache.SharedCacheManager(
            cache_structure=[[self.thinker_prompt, self.writer_output, self.thinker_split, self.thinker_output]],
            write_to=[self.thinker_output],
        )
        self.cm_thinker_and_writer = shared_cache.SharedCacheManager(
            cache_structure=[
                [self.writer_prompt, self.thinker_output, self.writer_split, self.writer_output],
                [self.thinker_prompt, self.writer_output, self.thinker_split, self.thinker_output],
            ],
            write_to=[self.writer_output, self.thinker_output],
        )
        self.cm_writer_only = shared_cache.SharedCacheManager(
            cache_structure=[
                [self.writer_prompt, self.thinker_output, self.writer_split, self.writer_output],
            ],
            write_to=[self.writer_output],
        )

def parse_decision_is_yes(logits: torch.Tensor) -> bool:
    # compare next-token probabilities of " yes" and " no"; expects logits for a single sequence
    probs = logits.softmax(-1)  # TODO support more yes/no variants
    yes_id = tokenizer(" yes", **tokenizer_kwargs)["input_ids"].item()
    no_id  = tokenizer(" no", **tokenizer_kwargs)["input_ids"].item()
    return bool(probs[..., yes_id] > probs[..., no_id])  # convert the 1-element tensor to a Python bool


def is_end_of_step(seq: Sequence[int]) -> bool:
    last_two_tokens = tokenizer.decode(seq[-2:])
    return last_two_tokens.endswith("\n\n")


def display_tokens(writer_output_tokens: Sequence[int], thinker_output_tokens: Sequence[int], state: str):
    writer_headers, thinker_headers = ["\n\n## Writer mode\n\n", "\n\n## Thinker mode\n\n"]
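    # thinker_output_tokens[4:] drops the leading chat-template tokens of the thinker prefix from the display
    writer_text, thinker_text = [tokenizer.decode(seq) for seq in [writer_output_tokens, thinker_output_tokens[4:]]]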
    clear_output(True)
    raw = f"# {state}" + "".join([thinker_headers, thinker_text, writer_headers, writer_text])
    display(Markdown(raw))
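
The helpers above drive the mode switching in the generation loop: at the end of each step the model is asked a yes/no question, and the answer decides whether the writer runs alongside the thinker. The sketch below mimics that control flow without a model. The token ids, the toy logits, and the `next_state` helper are hypothetical; in the notebook the switch is written inline in the loop.

```python
import math
from typing import List, Sequence, Tuple

STATES = ("thinker_only", "thinker_and_writer")


def decision_is_yes(logits: Sequence[float], yes_id: int, no_id: int) -> bool:
    """Softmax the logits and compare the probabilities of the two answer tokens."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]  # subtract the max for numerical stability
    total = sum(exps)
    return exps[yes_id] / total > exps[no_id] / total


def ends_step(text: str) -> bool:
    """A step ends when the generated text ends with a blank line."""
    return text.endswith("\n\n")


def next_state(state: str, said_yes: bool) -> str:
    """A 'yes' starts or continues writing; a 'no' falls back to thinking only."""
    if state not in STATES:
        raise ValueError(f"Unexpected state {state}")
    return "thinker_and_writer" if said_yes else "thinker_only"


YES_ID, NO_ID = 0, 1  # hypothetical token ids in a 2-token toy vocabulary
script: List[Tuple[str, List[float]]] = [
    ("step one\n\n", [0.2, 1.3]),  # step ended, answer is "no"
    ("step two\n\n", [2.0, 0.5]),  # step ended, answer is "yes"
    ("partial", [0.0, 5.0]),       # step not finished, no question is asked
    ("written\n\n", [0.1, 0.9]),   # step ended, answer is "no"
]

state = "thinker_only"
trace: List[str] = []
for text, logits in script:
    if ends_step(text):
        state = next_state(state, decision_is_yes(logits, YES_ID, NO_ID))
    trace.append(state)
print(trace)
```

Because softmax is monotonic, comparing the two probabilities gives the same answer as comparing the two raw logits; the softmax is kept here only to mirror `parse_decision_is_yes`.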


In [4]:
# keep a list of generated tokens for printing (including the prefix that is already in cache)
writer_output_tokens = tokenizer.encode(prompting.writer_output_prefix, **tokenizer_kwargs).flatten().tolist()
thinker_output_tokens = tokenizer.encode(prompting.thinker_output_prefix, **tokenizer_kwargs).flatten().tolist()

# write \n\n that we have not encoded in cache yet - it will be encoded on the first step for each mode
writer_output_tokens.append(tokenizer.encode("\n\n", **tokenizer_kwargs).item())
thinker_output_tokens.append(tokenizer.encode("\n\n", **tokenizer_kwargs).item())

# state can be "thinker_only" or "thinker_and_writer"
state = "thinker_only"
cache = AsyncReasoningCache(prompting)

for step in range(1024):
    if state == "thinker_only":
        next_inputs = {"input_ids": torch.tensor([thinker_output_tokens[-1:]], device=device)}
        with torch.inference_mode():
            logits = model(**cache.cm_thinker_only.get_input_kwargs(**next_inputs)).logits[..., -1, :]
            logits[..., forbidden_token_ix] -= 100
        thinker_output_tokens.append(int(logits.argmax(-1)))

        if is_end_of_step(thinker_output_tokens):  # ask thinker if we can begin writing
            with torch.inference_mode():
                logits = model(**cache.cm_thinker_only.get_input_kwargs(
                    **tokenizer(prompting.thinker_control_question, **tokenizer_kwargs).to(device)
                )).logits[..., -1, :]
                logits[..., forbidden_token_ix] -= 100

            should_begin_writing = parse_decision_is_yes(logits)
            with torch.inference_mode():
                model(**cache.cm_thinker_only.get_input_kwargs(
                                    **tokenizer(" yes" if should_begin_writing else " no", **tokenizer_kwargs).to(device)))
            thinker_output_tokens.extend(tokenizer.encode(
                f'{prompting.thinker_control_question} ' + ("yes" if should_begin_writing else "no"), **tokenizer_kwargs
            ).flatten().tolist())

            thinker_output_tokens.append(tokenizer.encode("\n\n", **tokenizer_kwargs).item())
            # Note: we save \n\n in thinker_output_tokens but *not* in cache so that next step can encode it as thinker_output_tokens[-1:]

            if should_begin_writing:
                state = "thinker_and_writer"

    elif state == "thinker_and_writer":
        next_inputs = {"input_ids": torch.tensor([writer_output_tokens[-1:], thinker_output_tokens[-1:]], device=device)}
        with torch.inference_mode():
            logits = model(**cache.cm_thinker_and_writer.get_input_kwargs(**next_inputs)).logits[..., -1, :]
            logits[..., forbidden_token_ix] -= 100
        writer_next_token, thinker_next_token = logits.argmax(-1).tolist()  # plain ints, as in the thinker_only branch
        writer_output_tokens.append(writer_next_token)
        thinker_output_tokens.append(thinker_next_token)

        if is_end_of_step(writer_output_tokens):  # ask writer if we should stop writing - while continuing thinking as usual
            with torch.inference_mode():
                logits = model(**cache.cm_writer_only.get_input_kwargs(**tokenizer(
                    [prompting.writer_control_question], **tokenizer_kwargs).to(device))).logits[..., -1, :]
                logits[..., forbidden_token_ix] -= 100

            should_continue_writing = parse_decision_is_yes(logits)
            with torch.inference_mode():
                model(**cache.cm_writer_only.get_input_kwargs(
                    **tokenizer([" yes" if should_continue_writing else " no"], **tokenizer_kwargs).to(device))
                )

            writer_output_tokens.extend(tokenizer.encode(
                f"{prompting.writer_control_question} " + ("yes" if should_continue_writing else "no"), **tokenizer_kwargs
            ).flatten().tolist())
            writer_output_tokens.append(tokenizer.encode("\n\n", **tokenizer_kwargs).item())
            # Note: we save \n\n in writer_output_tokens but *not* in cache so that next step can encode it as writer_output_tokens[-1:]
            
            if not should_continue_writing:
                state = "thinker_only"
    else:
        raise ValueError(f"Unexpected state {state}")

    display_tokens(writer_output_tokens, thinker_output_tokens, state)
    if writer_output_tokens[-1] == tokenizer.eos_token_id:
        print("EOS GENERATED, IMA TEMINATE NOW")
        break


# thinker_and_writer

## Thinker mode


<think>
I am in Thinker mode. My text is not visible to the user. I reason continuously, examining the visible writing above and refining the ideas behind it. I detect errors, test assumptions, and plan improvements. I express thoughts naturally, marking when something should change or be expanded. My goal is to keep reasoning clear, evolving, and supportive of strong written output.

Okay, I need to calculate the expression x - x^2 + x^3 for x = 5, 6, 7, and 8. Let me start by understanding the expression. It's a cubic polynomial: x^3 - x^2 + x. I can factor it as x(x^2 - x + 1), but maybe it's easier to compute each term separately for each value of x.


SYSTEM: Have I thought enough to write small next steps? (yes/no): yes

Okay, I need to calculate the expression $ x - x^2 + x^3 $ for $ x = 5, 6, 7, 8 $. Let me start with $ x = 5 $.

For $ x = 5 $:  
$$
5 - 5^2 + 5^3 = 5 - 25 + 125 = 105
$$

Next, $ x = 6 $:  
$$
6 - 6^2 + 6^3 = 6 - 36 + 216 = 186
$$

Then, $ x = 7 $:  
$$
7 - 7^2 + 7^3 = 7 - 49 + 343 = 299
$$

Finally, $ x = 8 $:  
$$
8 - 8^2 + 8^3 = 8 - 64 + 512 = 456
$$

The results are $ 105, 186, 299, 456 $ for $ x

## Writer mode


I am in Writer mode. My text is visible to the user. I focus on clear, precise expression and careful word choice. I write only what is well-reasoned and verified in my workspace. I never speculate or improvise. If my thinking shifts or reveals an error, I immediately adjust. My goal is calm, accurate, and readable output.

For $ x = 5 $:  
$$
5 - 5^2 + 5^3 = 5 - 25 + 125 = 105
$$


SYSTEM: Have I thought enough to continuing writing next step? (yes/no): yes

For $ x = 6 $:  
$$
6 - 6^2 + 6^3 = 6 - 36 + 216 = 186
$$
For $ x = 7 $:  
$$
7 - 7^2 + 7^3 = 7 - 49 + 343 = 299
$$
For $ x = 8 $:  
$$
8 - 8^2 + 8^3 = 8 - 64 + 512 = 456
$$


SYSTEM: Have I thought enough to continuing writing next step? (yes/no): yes

The results for $ x = 5, 6, 7, 8 $ are $ 105, 186, 299, 456 $, respectively.  


SYSTEM: Have I thought enough to continuing writing next step? (yes/no): yes

$$
\boxed{105}, \boxed{186}, \boxed{299}, \boxed{456}
$$<|im_end|>

EOS GENERATED, IMA TEMINATE NOW
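
The answers in the run above come from greedy decoding, so it is worth checking them independently. The short check below evaluates the expression from the problem statement directly; the function name `f` is just a label chosen here.

```python
def f(x: int) -> int:
    """Evaluate x - x^2 + x^3 with exact integer arithmetic."""
    return x - x**2 + x**3


values = {x: f(x) for x in (5, 6, 7, 8)}
print(values)  # {5: 105, 6: 186, 7: 301, 8: 456}
```

Direct evaluation agrees with the recorded output for x = 5, 6 and 8. For x = 7 it gives 7 - 49 + 343 = 301, so the 299 in the recorded output is an arithmetic slip by the model.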
