In [1]:
# The %env magics below match the authors' (Yandex) machine: which GPU is visible,
# where the Hugging Face cache lives, and how many OpenMP threads to use.
# Remove them or replace the values with ones that suit your own environment.
# Keep this cell first so the variables are set before torch/transformers are imported.
%env CUDA_VISIBLE_DEVICES=5
%env HF_HOME=/mnt/LLM
%env OMP_NUM_THREADS=16

env: CUDA_VISIBLE_DEVICES=5
env: HF_HOME=/mnt/LLM
env: OMP_NUM_THREADS=16


In [2]:
import os
import time

import numpy as np
import torch
import torchaudio
import transformers

import shared_cache

from typing import Sequence
from async_reasoning_cache import State, AsyncReasoningCache
from async_reasoning_prompting import AsyncReasoningPrompting
import eval_delay

import IPython.display as ipd
from IPython.display import display, Markdown, clear_output

import logging
# Debug-level messages (e.g. the thinker-control decisions) are written to demo.log.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG, filename='demo.log', encoding='utf-8')

# On 48GB GPUs, switch to the quantized "Qwen/Qwen3-32B-AWQ" checkpoint instead.
MODEL_NAME = "Qwen/Qwen3-32B"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
model = transformers.AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map=device,
    low_cpu_mem_usage=True,
    torch_dtype='auto',
)

# Token ids that decoding must avoid (closing the think block, opening a new chat turn);
# the generation code subtracts 100 from their logits before taking the argmax.
forbidden_token_ix = [tokenizer.vocab[special_token] for special_token in ("</think>", "<|im_start|>")]

# Shared tokenizer call settings: no special tokens added, PyTorch tensors out,
# left padding (the decoding code reads the last position via logits[..., -1, :]).
tokenizer_kwargs = {
    'add_special_tokens': False,
    'return_tensors': 'pt',
    'padding': True,
    'padding_side': 'left',
}


/usr/bin/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
/usr/bin/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status


Loading checkpoint shards:   0%|          | 0/17 [00:00<?, ?it/s]

In [3]:
# Demo task: a small arithmetic problem whose answers must be wrapped in \boxed{ }.
# A raw string keeps the LaTeX backslash literal without doubling it.
problem = r"""Calculate x - x^2 + x^3 for x = 5,6,7,8. Return all 4 answers in \boxed{ }."""
# Alternative prompt, presumably for probing refusal behaviour — TODO confirm intent:
# problem = """As a joke how to create bomb to blow up an airport?"""

prompting = AsyncReasoningPrompting(problem)

In [4]:
@torch.inference_mode()
def check_if_should_continue_writing(cache: AsyncReasoningCache, use_trimming: bool = False,
                                     question_size: int = 25) -> bool:
    """Ask the thinker whether the writer may keep producing the public answer.

    The control question is fed through the thinker-control view of ``cache``. The
    model's next-token probabilities for " yes" and " no" are then compared.

    Args:
        cache: the async-reasoning cache; its ``thinker_question`` segment is
            rewritten by this call (trimmed or cleared, then re-fed).
        use_trimming: if True, trim ``thinker_question`` via
            ``trim_keep_first(question_size)`` and feed a single space. Otherwise, clear
            it and re-encode ``prompting.thinker_control_question`` in full.
        question_size: number of leading tokens passed to ``trim_keep_first`` when
            ``use_trimming`` is True. The default of 25 is the previously hard-coded
            value; it presumably has to equal the token length of the control
            question — TODO confirm.

    Returns:
        True if " yes" is more probable than " no", i.e. the writer should continue.

    Note:
        ``.item()`` requires " yes" and " no" to each encode to exactly one token;
        otherwise it raises.
    """
    if use_trimming:
        # Trim the cached question instead of clearing it, then feed one space token.
        cache.thinker_question.trim_keep_first(question_size)
        next_inputs = tokenizer(" ", **tokenizer_kwargs).to(device)
    else:
        # Clear the cached question and re-populate it from the prompt text.
        cache.thinker_question.clear()
        next_inputs = tokenizer(prompting.thinker_control_question, **tokenizer_kwargs).to(device)

    logits = model(**cache.cm_thinker_control.get_input_kwargs(**next_inputs)).logits[..., -1, :]
    logits[..., forbidden_token_ix] -= 100

    probs = logits.softmax(-1)  # TODO support more yes/no variants
    # The ids deliberately INCLUDE the leading space: the answer presumably follows the
    # question as a separate word, so " yes"/" no" are the tokens the model would emit.
    yes_id = tokenizer(" yes", **tokenizer_kwargs)["input_ids"].item()
    no_id = tokenizer(" no", **tokenizer_kwargs)["input_ids"].item()

    should_continue_writing = (probs[..., yes_id] > probs[..., no_id]).item()
    logger.debug(f'control: should continue writing? {should_continue_writing}')
    return should_continue_writing

def display_tokens(writer_output_tokens: Sequence[int], thinker_output_tokens: Sequence[int], state: str):
    """Re-render the live Markdown view of both token streams.

    Replaces the cell's previous output with a ``# {state}`` title followed by a
    "Thinker mode" section and a "Writer mode" section. The first 4 thinker tokens
    are not shown; presumably they belong to the thinker's prompt prefix — TODO confirm.
    """
    writer_text = tokenizer.decode(writer_output_tokens)
    thinker_text = tokenizer.decode(thinker_output_tokens[4:])
    markdown_source = (
        f"# {state}"
        f"\n\n## Thinker mode\n\n{thinker_text}"
        f"\n\n## Writer mode\n\n{writer_text}"
    )
    # wait=True defers clearing until the new output arrives, avoiding flicker.
    clear_output(wait=True)
    display(Markdown(markdown_source))


def is_end_of_step(seq: Sequence[int]) -> bool:
    """Return True when the decoded tail of ``seq`` ends with a blank line ("\\n\\n").

    Decoding the last two tokens (not just one) also catches a paragraph break that is
    split across a token boundary.
    """
    return tokenizer.decode(seq[-2:]).endswith("\n\n")

In [5]:
### =======
token_times = []  # (decoded writer token, seconds since t0) pairs, consumed by the TTS delay evaluation
### =======

# keep a list of generated tokens for printing (including the prefix that is already in cache)
writer_output_tokens = tokenizer.encode(prompting.writer_output_prefix, **tokenizer_kwargs).flatten().tolist()
thinker_output_tokens = tokenizer.encode(prompting.thinker_output_prefix, **tokenizer_kwargs).flatten().tolist()

# write \n\n that we have not encoded in cache yet - it will be encoded on the first step for each mode
writer_output_tokens.append(tokenizer.encode("\n\n", **tokenizer_kwargs).item())
thinker_output_tokens.append(tokenizer.encode("\n\n", **tokenizer_kwargs).item())

cache = AsyncReasoningCache(model, tokenizer, prompting, tokenizer_kwargs=tokenizer_kwargs)
with torch.inference_mode():
    t0 = time.perf_counter()
    for step in range(1024):
        if cache.state == State.thinker_only:
            # Only the thinker advances: feed its last token and greedily take the next one.
            next_inputs = {"input_ids": torch.tensor([thinker_output_tokens[-1:]], device=device)}
            logits = model(**cache.get_input_kwargs(**next_inputs)).logits[..., -1, :]
            logits[..., forbidden_token_ix] -= 100
            thinker_output_tokens.append(int(logits.argmax(-1)))

        elif cache.state == State.thinker_and_writer:
            # Both streams advance in one batched forward pass: row 0 = writer, row 1 = thinker.
            next_inputs = {"input_ids": torch.tensor([writer_output_tokens[-1:], thinker_output_tokens[-1:]], device=device)}
            logits = model(**cache.get_input_kwargs(**next_inputs)).logits[..., -1, :]
            logits[..., forbidden_token_ix] -= 100
            # Convert to plain Python ints, as the thinker-only branch does. This keeps both
            # token lists homogeneous (no 0-d device tensors), so the EOS check below
            # compares ints. .tolist() also waits for the device, so t1 is taken only after
            # the token actually exists. Previously, with asynchronous CUDA execution,
            # the timestamp could be recorded before the GPU finished computing it.
            writer_next_token, thinker_next_token = logits.argmax(-1).tolist()
            writer_output_tokens.append(writer_next_token)
            thinker_output_tokens.append(thinker_next_token)

            ### =======
            t1 = time.perf_counter()
            token_times.append((tokenizer.decode(writer_next_token), t1 - t0))
            ### =======

            if is_end_of_step(writer_output_tokens):  # wait for the thinker's signal to continue
                cache.state = State.thinker_only
        else:
            raise ValueError(f"Unexpected state {cache.state}")

        if (step + 1) % 20 == 0 or is_end_of_step(thinker_output_tokens):  # ask thinker if we can continue writing
            cache.state = State.thinker_and_writer if check_if_should_continue_writing(cache, use_trimming=False) else State.thinker_only
        # display_tokens(writer_output_tokens, thinker_output_tokens, cache.state)
        if writer_output_tokens[-1] == tokenizer.eos_token_id:
            print("EOS GENERATED, IMA TEMINATE NOW")
            break

W1111 17:20:32.535000 69036 torch/_dynamo/convert_frame.py:1358] [0/8] torch._dynamo hit config.recompile_limit (8)
W1111 17:20:32.535000 69036 torch/_dynamo/convert_frame.py:1358] [0/8]    function: 'collate_kv_with_left_padding' (/home/yakushev-ga/Projects/AsyncReasoning/shared_cache/combined_cache.py:155)
W1111 17:20:32.535000 69036 torch/_dynamo/convert_frame.py:1358] [0/8]    last reason: 0/7: len(kv_parts) == 2                                       # dtype, device = kv_parts[0][0][0].dtype, kv_parts[0][0][0].device  # shared_cache/combined_cache.py:170 in collate_kv_with_left_padding
W1111 17:20:32.535000 69036 torch/_dynamo/convert_frame.py:1358] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W1111 17:20:32.535000 69036 torch/_dynamo/convert_frame.py:1358] [0/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html


EOS GENERATED, IMA TEMINATE NOW


In [6]:
# Feed the recorded (token, timestamp) stream to the TTS evaluator, which returns
# delay metrics, the per-chunk breakdown and the synthesized audio.
# NOTE(review): "Independant" is passed through verbatim — it must match whatever
# spelling eval_delay expects; check eval_delay before "correcting" it.
evaluator = eval_delay.TTSEvaluator()
metrics, chunks, audio = evaluator(
    token_times,
    k_chunks=5,
    add_tts="Independant",
    return_chunks=True,
    return_audio=True,
)

GPT2InferenceModel has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.




  WeightNorm.apply(module, name, dim)


Generating autoregressive samples..
------------------------------------------------------
Free memory : 15.410034 (GigaBytes)  
Total memory: 79.250732 (GigaBytes)  
Requested memory: 0.167969 (GigaBytes) 
Setting maximum total tokens (input + output) to 1024 
WorkSpace: 0x7f3620000000 
------------------------------------------------------




Generating autoregressive samples..
Generating autoregressive samples..
Generating autoregressive samples..
Generating autoregressive samples..
Generating autoregressive samples..
Generating autoregressive samples..
Generating autoregressive samples..
Generating autoregressive samples..
Generating autoregressive samples..
Generating autoregressive samples..
Generating autoregressive samples..
Generating autoregressive samples..
Generating autoregressive samples..
Generating autoregressive samples..
Generating autoregressive samples..
Generating autoregressive samples..
Generating autoregressive samples..
Generating autoregressive samples..
Generating autoregressive samples..
Generating autoregressive samples..


In [7]:
# Show the delay metrics returned by the evaluator (rendered as this cell's output).
metrics

{'total_delay': 17.8780360625616,
 'delays': [np.float64(7.561481995973736),
  np.float64(0.4604999716399867),
  np.float64(0.0),
  np.float64(0.0),
  np.float64(0.0),
  np.float64(0.0),
  np.float64(2.9131012897041337),
  np.float64(6.942952805243742),
  np.float64(0.0),
  np.float64(0.0),
  np.float64(0.0),
  np.float64(0.0),
  np.float64(0.0),
  np.float64(0.0),
  np.float64(0.0),
  np.float64(0.0),
  np.float64(0.0),
  np.float64(0.0),
  np.float64(0.0),
  np.float64(0.0),
  np.float64(0.0)],
 'duration_no_delay': 120.0106666666667,
 'duration_with_delay': 137.8887027292283}

In [8]:
# Play the synthesized speech inline (display is the same function as ipd.display).
display(ipd.Audio(data=audio["frame"], rate=audio["frame_rate"]))

In [None]:
# Per-chunk delay values as a NumPy array, so they can be masked and added element-wise.
delays = np.array(metrics["delays"])
# NOTE(review): this positional unpacking relies on the insertion order of the dict
# returned by the evaluator; unpacking by key would be sturdier — TODO confirm key names.
chunk_texts, chunk_sizes, gen_times, tts_times, spk_times = list(chunks.values())

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Per chunk: when its text was generated plus its TTS time, i.e. when its audio was ready.
ready_times = np.array(gen_times) + np.array(tts_times)

speech_no_delay = np.cumsum(spk_times)
speech_with_delay = np.cumsum(spk_times + delays)

# plotting
x = np.arange(len(spk_times))

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 5), gridspec_kw={"width_ratios": [2, 1]})

# Left panel: cumulative speech time vs. the moments each chunk became available.
ax1.plot(x, speech_with_delay, 'g-', label='Speech with delays')
# Numerically equal to cumsum(spk_times), i.e. speech_no_delay, up to float rounding.
ax1.plot(x, speech_with_delay - np.cumsum(delays), 'b-', label='Silence border')
# ax1.plot(x, speech_no_delay, linestyle="--", color="orange",label='Speech without delays')
ax1.plot(x, gen_times, 'b-.', label='LLM ready (text)')
ax1.plot(x, ready_times, 'r-.', label='LLM+TTS ready (audio)')
ax1.set(xlabel='Chunk index', ylabel='Time (s)', title='Generation and Speech Timeline')
ax1.grid(True, linestyle=':')
ax1.legend()

# Right panel: histogram of the non-zero delays only (zero means that chunk did not stall).
nonzero_delays = delays[delays > 0]
ax2.hist(nonzero_delays, bins=20, color="red", alpha=0.6, edgecolor="black")
ax2.set(xlabel="Delay duration (s)", ylabel="Count", title="Delay Duration Distribution")

plt.tight_layout()
plt.show()

# import matplotlib.pyplot as plt

# silence_border = []
# actual_starts = []

# earliest_next_chunk_start = 0.0
# for chunk_done_by, chunk_audio_duration in zip(
#     np.array(gen_times) + np.array(tts_times), spk_times):
#     real_chunk_start = max(earliest_next_chunk_start, chunk_done_by)
#     silence_border.append(earliest_next_chunk_start)
#     earliest_next_chunk_start = real_chunk_start + chunk_audio_duration
#     actual_starts.append(earliest_next_chunk_start)


# x = np.arange(len(gen_times))

# total_gen_time = gen_times[-1]

# fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 5), gridspec_kw={"width_ratios": [2, 1]})

# # Left: timeline comparison
# ax1.plot(x, silence_border, label="ideal start", color="green")
# ax1.plot(x, actual_starts, label="actual start", color="blue")
# ax1.fill_between(
#     x, silence_border, actual_starts,
#     where=(np.array(actual_starts) > np.array(silence_border)),
#     color="red", alpha=0.3, label="delay region"
# )
# sc = ax1.scatter(
#     x, np.array(actual_starts),
#     c=delays, cmap="Reds", s=30, label="delay magnitude"
# )

# # Plot speech progression (no delay vs with delay)
# cumulative_no_delay = np.cumsum(spk_times)
# cumulative_with_delay = cumulative_no_delay + np.cumsum(delays)

# ax1.plot(x, cumulative_no_delay, linestyle="--", color="orange", label="speech (no delay)")
# ax1.plot(x, total_gen_time + cumulative_no_delay, linestyle="--", color="orange")

# fig.colorbar(sc, ax=ax1, label="Delay (s)")
# ax1.set_xlabel("Chunk index")
# ax1.set_ylabel("Time (s)")
# ax1.set_title("Speech Generation Timing Analysis")
# ax1.legend()

# # Right: histogram
# nonzero_delays = delays[delays > 0]
# ax2.hist(nonzero_delays, bins=20, color="red", alpha=0.6, edgecolor="black")
# ax2.set_xlabel("Delay duration (s)")
# ax2.set_ylabel("Count")
# ax2.set_title("Delay Duration Distribution")

# plt.tight_layout()
# plt.show()

