# Hogwild! Thoughts: Example


In [1]:
# The %env lines below are for the Yandex environment; remove or replace them with your own.
# They sit above the torch/transformers imports so the variables are already set when those libraries load.
%env CUDA_VISIBLE_DEVICES=3
%env HF_HOME=/mnt/LLM
%env OMP_NUM_THREADS=16

import torch
import transformers
import shared_cache  # NOTE(review): not referenced in the visible cells; presumably required by the async_reasoning_* modules — confirm or remove
from IPython.display import display, Markdown, clear_output
from typing import Sequence

import logging
logger = logging.getLogger(__name__)
# All DEBUG-level records (e.g. the continue-writing decisions logged later) go to demo.log, not the notebook output.
logging.basicConfig(filename='demo.log', encoding='utf-8', level=logging.DEBUG)

MODEL_NAME = "Qwen/Qwen3-32B"  # for 48GB gpus, use "Qwen/Qwen3-32B-AWQ" instead
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
# torch_dtype='auto' uses the dtype recorded in the checkpoint config; the whole model is placed on `device`.
model = transformers.AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype='auto', low_cpu_mem_usage=True, device_map=device)

# Ids of "</think>" and "<|im_start|>". The decoding cells subtract 100 from these logits before taking
# the argmax, so greedy decoding effectively never emits them.
forbidden_token_ix = [tokenizer.vocab[x] for x in ("</think>", "<|im_start|>")]
# Shared tokenizer-call options: no special tokens added, PyTorch tensors, padding applied on the left.
tokenizer_kwargs = dict(add_special_tokens=False, return_tensors='pt', padding=True, padding_side='left')

env: CUDA_VISIBLE_DEVICES=3
env: HF_HOME=/mnt/LLM
env: OMP_NUM_THREADS=16


Loading checkpoint shards:   0%|          | 0/17 [00:00<?, ?it/s]

In [2]:
from async_reasoning_prompting import AsyncReasoningPrompting

# Task given to the model. The backslash is escaped ("\\boxed") so the prompt contains a literal
# "\boxed" rather than a backspace character followed by "oxed".
problem = """Calculate x - x^2 + x^3 for x = 5,6,7,8. Return all 4 answers in \\boxed{ }."""

# Builds the Writer/Thinker prompt prefixes used by the cache and the decoding loop.
prompting = AsyncReasoningPrompting(problem)

In [None]:
from async_reasoning_cache import State, AsyncReasoningCache

def _single_token_id(text: str) -> int:
    """Return the token id of `text`, raising ValueError unless it encodes to exactly one token."""
    token_ids = tokenizer(text, **tokenizer_kwargs)["input_ids"].flatten().tolist()
    if len(token_ids) != 1:
        raise ValueError(f"{text!r} must encode to a single token, got {token_ids}")
    return token_ids[0]


@torch.inference_mode()
def check_if_should_continue_writing(cache: AsyncReasoningCache, num_kept_question_tokens: int = 25) -> bool:
    """Ask the Thinker whether the Writer may keep writing.

    The cache's `thinker_question` segment is trimmed back to its first
    `num_kept_question_tokens` tokens (presumably discarding whatever a previous
    control query appended — TODO confirm against AsyncReasoningCache). A single " "
    token is then fed through the `cm_thinker_control` view of the cache, and the
    next-token scores of "yes" and "no" are compared.

    Args:
        cache: the shared AsyncReasoningCache; its `thinker_question` segment is mutated.
        num_kept_question_tokens: how many leading `thinker_question` tokens to keep
            (default 25, the previously hard-coded value).

    Returns:
        True if "yes" scores strictly higher than "no", otherwise False.

    Raises:
        ValueError: if "yes" or "no" does not encode to exactly one token.
    """
    # Trim the cache instead of clearing it.
    cache.thinker_question.trim_keep_first(num_kept_question_tokens)
    next_inputs = tokenizer(" ", **tokenizer_kwargs).to(device)

    logits = model(**cache.cm_thinker_control.get_input_kwargs(**next_inputs)).logits[..., -1, :]

    # Bare answers without a leading space: the fed input is already a " " token.
    yes_id = _single_token_id("yes")  # TODO support more yes/no variants
    no_id = _single_token_id("no")

    # Compare raw logits instead of softmax probabilities. Softmax is monotonic, so the
    # answer is the same in exact arithmetic. In low precision (e.g. bf16), however, two
    # close probabilities can round to the same value and make the strict ">" False.
    # The forbidden-token penalty is irrelevant here because only yes/no are compared.
    should_continue_writing = (logits[..., yes_id] > logits[..., no_id]).item()
    logger.debug(f'control: should continue writing? {should_continue_writing}')
    return should_continue_writing

def display_tokens(writer_output_tokens: Sequence[int], thinker_output_tokens: Sequence[int], state: str):
    """Replace the cell output with a Markdown view of both transcripts.

    The rendered document is a `# <state>` title, then a "Thinker mode" section,
    then a "Writer mode" section. Each section holds the decoded tokens.
    """
    # NOTE(review): the first 4 thinker tokens are hidden; presumably part of the thinker prefix — confirm.
    thinker_text = tokenizer.decode(thinker_output_tokens[4:])
    writer_text = tokenizer.decode(writer_output_tokens)
    sections = (
        f"# {state}",
        "\n\n## Thinker mode\n\n",
        thinker_text,
        "\n\n## Writer mode\n\n",
        writer_text,
    )
    # wait=True: the old output is only removed once the new output is ready to be shown.
    clear_output(wait=True)
    display(Markdown("".join(sections)))


def is_end_of_step(seq: Sequence[int]) -> bool:
    """Return True when the text of the last two tokens of `seq` ends with two newlines.

    Decoding two tokens rather than one catches a blank line that was split across
    a token boundary. An empty `seq` decodes to "" and yields False.
    """
    tail_text = tokenizer.decode(seq[-2:])
    return tail_text[-2:] == "\n\n"

In [4]:
# keep a list of generated tokens for printing (including the prefix that is already in cache)
writer_output_tokens = tokenizer.encode(prompting.writer_output_prefix, **tokenizer_kwargs).flatten().tolist()
thinker_output_tokens = tokenizer.encode(prompting.thinker_output_prefix, **tokenizer_kwargs).flatten().tolist()

# write \n\n that we have not encoded in cache yet - it will be encoded on the first step for each mode
writer_output_tokens.append(tokenizer.encode("\n\n", **tokenizer_kwargs).item())
thinker_output_tokens.append(tokenizer.encode("\n\n", **tokenizer_kwargs).item())

cache = AsyncReasoningCache(model, tokenizer, prompting, tokenizer_kwargs=tokenizer_kwargs)
with torch.inference_mode():
    for step in range(1024):
        if cache.state == State.thinker_only:
            next_inputs = {"input_ids": torch.tensor([thinker_output_tokens[-1:]], device=device)}
            logits = model(**cache.get_input_kwargs(**next_inputs)).logits[..., -1, :]
            logits[..., forbidden_token_ix] -= 100
            thinker_output_tokens.append(int(logits.argmax(-1)))

        elif cache.state == State.thinker_and_writer:
            next_inputs = {"input_ids": torch.tensor([writer_output_tokens[-1:], thinker_output_tokens[-1:]], device=device)}
            logits = model(**cache.get_input_kwargs(**next_inputs)).logits[..., -1, :]
            logits[..., forbidden_token_ix] -= 100
            writer_next_token, thinker_next_token = logits.argmax(-1)
            writer_output_tokens.append(writer_next_token)
            thinker_output_tokens.append(thinker_next_token)
            if is_end_of_step(writer_output_tokens):  # wait for the thinker's signal to continue
                cache.state = State.thinker_only
        else:
            raise ValueError(f"Unexpected state {cache.state}")

        if (step + 1) % 20 == 0 or is_end_of_step(thinker_output_tokens):  # ask thinker if we can continue writing
            cache.state = State.thinker_and_writer if check_if_should_continue_writing(cache) else State.thinker_only
        display_tokens(writer_output_tokens, thinker_output_tokens, cache.state)
        if writer_output_tokens[-1] == tokenizer.eos_token_id:
            print("EOS GENERATED, IMA TEMINATE NOW")
            break


# State.thinker_and_writer

## Thinker mode


<think>
I am in Thinker mode. My text is not visible to the user. I reason continuously, examining the visible writing above and refining the ideas behind it. I detect errors, test assumptions, and plan improvements. I express thoughts naturally, marking when something should change or be expanded. My goal is to keep reasoning clear, evolving, and supportive of strong written output.

Okay, let's see. The user wants me to calculate the expression x - x² + x³ for x = 5, 6, 7, and 8. Hmm, I need to make sure I understand the order of operations here. The expression is x minus x squared plus x cubed. So for each x, I'll compute each term separately and then combine them.

Let me start with x = 5. First, calculate x: that's 5. Then x squared is 5² = 25. Then x cubed is 5³ = 125. Now plug them into the expression: 5 - 25 + 125. Let's do the math: 5 - 25 is -20, and -20 + 125 is 105. So for x = 5, the result is 105.

Next, x = 6. x is 6, x squared is 36, x cubed is 216. The expression becomes 6 - 36 + 216. 6 - 36 is -30, and -30 + 216 is 186. So for x = 6, the result is 186.

Now x = 7. x is 7, x squared is 

## Writer mode


I am in Writer mode. My text is visible to the user. I focus on clear, precise expression and careful word choice. I write only what is well-reasoned and verified in my workspace. I never speculate or improvise. If my thinking shifts or reveals an error, I immediately adjust. My goal is calm, accurate, and readable output.

We are asked to evaluate the expression $ x - x^2 + x^3 $ for $ x = 5, 6, 7, 8 $. Let's proceed step by step for each value of $ x $.

---

**For $ x = 5 $:**

1. $ x = 5 $
2. $ x^2 = 5^2 = 25 $
3. $ x^3 = 5^3 = 125 $

Now compute the expression:
$$
5 - 25 + 125 = 105
$$

---

**For $ x = 6 $:**

1. $ x = 6 $
2. $ x^2 = 6^2 = 36 $
3. $ x^3

KeyboardInterrupt: 