In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from typing import Tuple, Dict, List, Any, Optional
from threading import Thread

def load_hf_model(model_name: str):
    """Load a causal language model and its tokenizer from a checkpoint name.

    Args:
        model_name: Identifier handed unchanged to both ``from_pretrained`` calls.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # "auto" for both options leaves dtype and device placement to the library.
    load_options = {"torch_dtype": "auto", "device_map": "auto"}
    model = AutoModelForCausalLM.from_pretrained(model_name, **load_options)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer


def query(model: AutoModelForCausalLM, tokenizer: AutoTokenizer, query: str, stream: bool = False) -> Dict[str, Any]:
    """Send a single-turn user message to a chat model and return its reply.

    The message is wrapped with the tokenizer's chat template (generation
    prompt appended), tokenized, moved to ``model.device`` and sampled with
    ``max_new_tokens=512``, ``temperature=0.7`` and ``top_p=0.9``.

    Args:
        model: Model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer exposing ``apply_chat_template``, ``__call__``
            and ``decode``.
        query: Text of the user message.
        stream: When False, block until generation finishes. When True, start
            generation on a background thread and return immediately.

    Returns:
        Always contains ``"query"`` and ``"messages"``.
        With ``stream=False`` it also contains ``"response"``: the decoded
        newly generated text only (the prompt is not echoed back).
        With ``stream=True`` it also contains ``"streamer"`` (iterate it to
        receive text chunks, prompt excluded) and ``"thread"`` (the running
        generation thread, which the caller may ``join()``).
    """
    messages = [{"role": "user", "content": query}]

    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # Sampling settings shared by the blocking and the streaming path.
    generation_kwargs = dict(
        **model_inputs,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
    )

    if not stream:
        with torch.no_grad():
            output = model.generate(**generation_kwargs)

        # generate() returns prompt + continuation; drop the prompt tokens so
        # "response" holds only what the model produced.
        prompt_length = model_inputs["input_ids"].shape[-1]
        response = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
        return {"query": query, "response": response, "messages": messages}

    # skip_prompt=True keeps the templated prompt out of the streamed text,
    # matching the non-streaming path.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs["streamer"] = streamer

    # generate() blocks until it is done, so it runs on its own thread while
    # the caller consumes the streamer.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    return {
        "query": query,
        "streamer": streamer,
        "thread": thread,
        "messages": messages,
    }

In [None]:
# Checkpoint to load. The resulting `model` and `tokenizer` globals are the
# ones the later cells pass to query() and incremental_dual_stream().
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B"
model, tokenizer = load_hf_model(model_name)

In [None]:
# Kick off generation on a background thread and print chunks as they arrive.
stream_result = query(model, tokenizer, "What is 27 * 32", stream=True)

for token in stream_result['streamer']:
    print(token, end="", flush=True)

# The streamer is exhausted once generation has signalled its end; join the
# worker so it has fully finished before the next cell touches the model.
stream_result['thread'].join()

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
import time
import queue
import re

# Device string that load_model() passes as `device_map` when loading the model.
device = "cuda"


def load_model(model_name: str):
    """Load a causal LM in float16 on the module-level ``device`` plus its tokenizer.

    The tokenizer's pad token is set to its EOS token before it is returned.

    Args:
        model_name: Identifier handed unchanged to both ``from_pretrained`` calls.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    half_precision_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,  # float16 chosen by the author for wider compatibility
        device_map=device,
    )

    tok = AutoTokenizer.from_pretrained(model_name)
    # Reuse EOS as the padding token.
    tok.pad_token = tok.eos_token

    return half_precision_model, tok


def incremental_dual_stream(
    model, tokenizer, prompt: str, internal_buffer: int = 2, max_tokens: int = 100
):
    """Interleave a background "reasoning" stream with a foreground "answer" stream.

    A worker thread samples a continuation of ``prompt`` followed by a
    "Reasoning:" line and streams it through a ``TextIteratorStreamer``. A
    second thread copies the streamed chunks into a queue. The calling thread
    alternates between draining that queue and generating the external answer
    three tokens at a time, each time conditioning on the reasoning text
    consumed so far. Progress is printed as it happens.

    Args:
        model: Model exposing ``generate`` and ``device``. Note that
            ``generate`` is called from two threads at once.
        tokenizer: Tokenizer used for encoding, decoding and streaming; it is
            called with ``padding=True``.
        prompt: The user prompt shared by both streams.
        internal_buffer: Maximum lead, measured in whitespace-separated words,
            that consumed reasoning may have over the external answer before
            draining pauses.
        max_tokens: ``max_new_tokens`` for the reasoning stream, and the
            whitespace-separated word count at which the external answer stops.

    Returns:
        ``(internal_generated, external_generated)`` as strings. A "." is
        appended to a non-empty external answer that does not already end in
        ".", "!" or "?".
    """
    internal_generated = ""
    external_generated = ""

    internal_ctx = prompt + "\nReasoning:"

    # Internal generation setup
    internal_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    internal_token_queue = queue.Queue()

    # Internal generation function (runs on its own thread)
    def internal_generate():
        try:
            # Properly handle tokenization with attention mask
            inputs = tokenizer(internal_ctx, return_tensors='pt', padding=True)
            input_ids = inputs.input_ids.to(model.device)
            attention_mask = inputs.attention_mask.to(model.device)

            print("Starting internal reasoning generation...")
            model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                streamer=internal_streamer,
                max_new_tokens=max_tokens,
                do_sample=True,  # Set to True to use temperature
                temperature=0.7,
                top_p=0.9,
                top_k=40
            )
        except Exception:
            # generate() only closes the streamer on success. Close it here so
            # the collector thread is not left blocked forever, then re-raise
            # so the failure is still reported for this thread.
            internal_streamer.end()
            raise

    # Start internal generation thread
    internal_thread = Thread(target=internal_generate)
    internal_thread.daemon = True
    internal_thread.start()

    # Thread to collect tokens from the streamer and put them in a queue
    def collect_internal_tokens():
        try:
            for token in internal_streamer:
                internal_token_queue.put(token)
        except Exception as e:
            print(f"Error collecting internal tokens: {e}")
        finally:
            # Signal that generation is done (exactly one None sentinel)
            internal_token_queue.put(None)

    token_collector = Thread(target=collect_internal_tokens)
    token_collector.daemon = True
    token_collector.start()

    # For cleaner output, we'll track significant updates
    last_internal_update = ""
    last_external_update = ""
    update_counter = 0

    # Set once the None sentinel has been read from the queue. A dedicated
    # flag is required: using "last token is None" would also be true before
    # any token was ever read, ending the loop while reasoning is still
    # starting up.
    internal_done = False

    # Main processing loop
    try:
        print("\n" + "="*50)
        print("INCREMENTAL GENERATION PROGRESS:")
        print("="*50 + "\n")

        while True:
            update_happened = False

            # Fetch new internal tokens, but only while the consumed reasoning
            # is less than `internal_buffer` words ahead of the answer.
            # NOTE(review): if the external step keeps yielding empty text while
            # the reasoning is already `internal_buffer` words ahead, the
            # sentinel is never read and only the max_tokens check can end the
            # loop - confirm whether that needs a guard.
            while not internal_token_queue.empty() and len(internal_generated.split()) - len(external_generated.split()) < internal_buffer:
                token = internal_token_queue.get()
                if token is None:  # End of generation
                    internal_done = True
                    print("✓ Internal reasoning generation complete")
                    break

                internal_generated += token

                # Only show update if something meaningful changed (more than just whitespace)
                if internal_generated.strip() != last_internal_update.strip() and len(internal_generated.strip()) - len(last_internal_update.strip()) > 3:
                    last_internal_update = internal_generated
                    update_counter += 1
                    print(f"\n[Update {update_counter}] 🧠 INTERNAL REASONING:\n{internal_generated.strip()}")
                    update_happened = True

            # Generate external tokens if the reasoning is ahead, or if there is
            # nothing waiting to be consumed right now
            if len(internal_generated.split()) > len(external_generated.split()) or internal_token_queue.empty():
                # Update external context with available internal reasoning
                external_ctx = f"{prompt}\nReasoning: {internal_generated}\nFinal answer:"

                # Re-encode the full context plus the answer so far on every step
                inputs = tokenizer(external_ctx + external_generated, return_tensors="pt", padding=True)
                input_ids = inputs.input_ids.to(model.device)
                attention_mask = inputs.attention_mask.to(model.device)

                outputs = model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=3,  # Generate more tokens at once for efficiency
                    do_sample=True,    # Must be True if using temperature
                    temperature=0.3,
                    top_p=0.95,
                    pad_token_id=tokenizer.eos_token_id
                )

                # Get the new tokens (get everything after our input)
                new_text = tokenizer.decode(outputs[0, len(input_ids[0]):], skip_special_tokens=True)

                # If we got something new
                if new_text:
                    external_generated += new_text

                    # Only show update if something meaningful changed
                    if external_generated.strip() != last_external_update.strip():
                        last_external_update = external_generated
                        update_counter += 1
                        print(f"\n[Update {update_counter}] 🗣️ EXTERNAL ANSWER:\n{external_generated.strip()}")
                        update_happened = True

                # Stop conditions (counted in whitespace-separated words)
                if len(external_generated.split()) >= max_tokens:
                    print("\n✓ Reached max tokens")
                    break

            # Done only once the sentinel has actually been seen and nothing is left
            if internal_done and internal_token_queue.empty():
                print("\n✓ Generation complete")
                break

            # If no update happened, wait a bit before checking again
            if not update_happened:
                time.sleep(0.2)

    except Exception as e:
        print(f"\nError in main processing loop: {e}")
        import traceback
        traceback.print_exc()

    # Wait for threads to finish (with timeout); both are daemon threads
    token_collector.join(timeout=5)
    internal_thread.join(timeout=5)

    # Ensure the external answer looks complete - if it ends abruptly, add an ending
    if external_generated and not re.search(r'[.!?]$', external_generated.strip()):
        external_generated += "."

    return internal_generated, external_generated

In [None]:
# model_name = "Qwen/Qwen2.5-14B-Instruct"
# model, tokenizer = load_model(model_name)

In [None]:
# Sample prompt for testing; echoed so the output is self-describing.
prompt = "Solve incrementally: What is 18 x 23?"
print(f"\nUsing prompt: '{prompt}'\n")

# Run the dual stream generation with the `model` / `tokenizer` globals.
# The call prints its own progress and returns both texts as strings.
print("Starting dual stream generation...\n")
internal_reasoning, external_answer = incremental_dual_stream(
    model, tokenizer, prompt, internal_buffer=3, max_tokens=150
)

# Display final results between two 50-character rules.
print("\n" + "="*50)
print("FINAL RESULTS:")
print("\n🧠 INTERNAL REASONING:")
print(internal_reasoning.strip())

print("\n🗣️ EXTERNAL ANSWER:")
print(external_answer.strip())
print("="*50)