<a href="https://colab.research.google.com/github/sid779/Deep_learning_with_pytorch/blob/master/r1_overthinker.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **🤔 DeepSeek R1 Overthinker**
Using this app, you can force DeepSeek R1 models to think more deeply, for as long as you wish. It works like this:
- Detects when the model tries to conclude its thoughts too early (by emitting the `</think>` token)
- Replaces that token with a prompt that encourages additional reasoning
- Continues until a minimum number of thinking tokens is reached

You decide how long the model should think. The result is more thorough and well-reasoned responses (hopefully).

<br>

Colab by [anzorq](https://twitter.com/hahahahohohe). If you like it, please consider supporting me:

[<a href="https://www.buymeacoffee.com/anzorq" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/v2/default-yellow.png" height="32px" width="108px" alt="Buy Me A Coffee"></a>](https://www.buymeacoffee.com/anzorq)
<br>
[![GitHub Repo stars](https://img.shields.io/github/stars/qunash/r1-overthinker?style=social)](https://github.com/qunash/r1-overthinker)

# Install dependencies

In [1]:
%%capture
!pip install uv
!uv pip install --system unsloth gradio gradio_log
!uv pip install --system --force-reinstall --no-deps git+https://github.com/unslothai/unsloth.git

#@markdown ### ⬅️ Run this cell

# Run the app

In [None]:
#@markdown ### ⬅️ Run this cell
#@markdown #### To open the app in a separate tab, click on the generated public URL that looks like this `https://*.gradio.live`
import torch
from huggingface_hub import HfApi
import gc
import sys
import os

class ModelManager:
    """Load, unload and track a single unsloth model, logging progress to a file.

    At most one model/tokenizer pair is held at a time. Everything printed
    while a model is being loaded is mirrored into ``self.log_file``.
    """

    # True only while load_model() has stdout/stderr teed into the log file.
    # Kept as a class-level default so it exists on every instance.
    _tee_active = False

    def __init__(self):
        self.current_model_name = None
        self.model = None
        self.tokenizer = None
        self.max_seq_length = 21848
        self.dtype = None  # Auto-detects bfloat16/float16
        self.load_in_4bit = True
        self.log_file = "loading.log"

        # Clear log file on initialization
        self._clear_log()

    class OutputTee:
        """File-like object that forwards every write to several streams."""

        def __init__(self, *outputs):
            self.outputs = outputs

        def write(self, data):
            for o in self.outputs:
                o.write(data)
                # Flush immediately so readers of the log file see progress live.
                o.flush()

        def flush(self):
            for o in self.outputs:
                o.flush()

    def _clear_log(self):
        """Truncate the log file (creating it if it does not exist)."""
        with open(self.log_file, 'w'):
            pass

    def log_print(self, text):
        """Print ``text`` and make sure it ends up in the log file exactly once."""
        # While load_model() has stdout teed into the log file, print() already
        # writes there; appending here as well would duplicate every line.
        if not self._tee_active:
            with open(self.log_file, 'a', encoding='utf-8') as f:
                f.write(text + "\n")
        print(text)

    def get_available_models(self):
        """Return the ids of unsloth's "DeepSeek-R1" models, sorted alphabetically.

        Returns an empty list (and logs a message) when the search finds nothing.
        """
        api = HfApi()
        models = api.list_models(
            search="DeepSeek-R1",
            author="unsloth",
            sort="downloads",
            direction=-1
        )
        model_options = sorted(model.id for model in models)
        if not model_options:
            self.log_print("No DeepSeek-R1 models found from unsloth")
            return []
        return model_options

    def unload_current_model(self):
        """Drop the current model/tokenizer and release cached GPU memory.

        Does nothing when no model is loaded. Errors are logged, not raised.
        """
        if self.model is None:
            return

        try:
            self.model.cpu()
        except Exception as e:
            self.log_print(f"Error during model unloading: {str(e)}")
        finally:
            # Drop the references even if moving to CPU failed; otherwise the
            # old weights stay alive while the next model is being loaded.
            self.model = None
            self.tokenizer = None
            self.current_model_name = None
            gc.collect()
            try:
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    torch.cuda.synchronize()
            except Exception as e:
                self.log_print(f"Error during model unloading: {str(e)}")

    def load_model(self, model_name):
        """Load ``model_name``, replacing any currently loaded model.

        This is a generator that yields exactly one ``(message, success)``
        pair. ``success`` is True only when a new model was loaded.
        """
        # Nothing selected: without this check `None == None` below would
        # wrongly report "Model already loaded" on a fresh manager.
        if not model_name:
            yield "No model selected", False
            return

        if model_name == self.current_model_name:
            yield "Model already loaded", False
            return

        # Clear the log file before starting new load
        self._clear_log()

        old_stdout = sys.stdout
        old_stderr = sys.stderr
        log_fd = open(self.log_file, 'a', encoding='utf-8')
        sys.stdout = ModelManager.OutputTee(old_stdout, log_fd)
        sys.stderr = ModelManager.OutputTee(old_stderr, log_fd)
        self._tee_active = True

        try:
            self.log_print(f"Loading model: {model_name}")
            self.unload_current_model()

            # Imported lazily, at load time, rather than when the class is defined.
            from unsloth import FastLanguageModel
            self.model, self.tokenizer = FastLanguageModel.from_pretrained(
                model_name=model_name,
                max_seq_length=self.max_seq_length,
                dtype=self.dtype,
                load_in_4bit=self.load_in_4bit
            )
            FastLanguageModel.for_inference(self.model)

            self.current_model_name = model_name
            self.log_print(f"Successfully loaded {model_name}")
            result = (f"Successfully loaded {model_name}", True)
        except Exception as e:
            self.log_print(f"Error loading model: {str(e)}")
            result = ("Error loading model. See logs for details.", False)
        finally:
            self._tee_active = False
            sys.stdout = old_stdout
            sys.stderr = old_stderr
            log_fd.close()

        # Yield only after the streams are restored, so an abandoned generator
        # cannot leave stdout/stderr redirected.
        yield result

    def get_current_model_info(self):
        """Return a one-line description of the currently loaded model."""
        if self.current_model_name is None:
            return "No model currently loaded"
        return f"Currently loaded model: {self.current_model_name}"

    def force_gpu_cleanup(self):
        """Unload the model and log how much allocated GPU memory was released."""
        self.unload_current_model()
        if torch.cuda.is_available():
            initial = torch.cuda.memory_allocated()
            gc.collect()
            torch.cuda.empty_cache()
            final = torch.cuda.memory_allocated()
            self.log_print(f"GPU Memory freed: {(initial - final) / 1024**2:.2f} MB")

    def set_context_length(self, length):
        """Set the max sequence length, clamped to at least 1024.

        Returns True on success, or False (leaving the current value
        untouched) when ``length`` cannot be converted to an integer.
        """
        try:
            length = int(length)
        except (TypeError, ValueError, OverflowError):
            return False
        self.max_seq_length = max(length, 1024)
        return True


import gradio as gr
import torch
import random
from gradio import ChatMessage
from gradio_log import Log
import logging

# Initialize model manager
# One shared instance used by every UI callback below. Constructing it
# truncates the loading log file (see ModelManager.__init__).
model_manager = ModelManager()

# Generation settings live in one module-level dict: the UI controls write
# into it and the generation loop reads from it.
params = dict(
    min_thinking_tokens=1024,
    max_output_tokens=2048,
    temperature=0.7,
    top_p=0.95,
    repetition_penalty=1.2,
    max_swaps=-1,  # -1 means "no limit on reasoning extensions"
    # One candidate prompt per line; one is picked at random for each swap.
    replacement_tokens="\n".join([
        "Hold on - what if we challenge our initial assumptions?",
        "Let me try a completely different approach to this problem",
        "What would be a strong counter-argument to my current reasoning?",
        "Let's validate each step of my previous logic",
        "Could there be edge cases we haven't considered?",
        "What if we work backwards from the expected result?",
        "Are there any hidden constraints we're missing?",
        "Let's try to disprove our current answer",
        "Let me break this down into smaller sub-problems",
        "Is there a more elegant solution we're overlooking?",
        "What patterns emerge if we look at extreme cases?",
        "Could we solve this using an entirely different domain?",
    ]),
)

def load_selected_model(model_name, context_length):
    """Load a model for the UI, streaming status updates as it goes.

    Each yielded list lines up with the click handler's outputs:
    [model_info text, current_tab index, load_button, model_dropdown,
    context_length]. The first update locks the three controls; every later
    update (one per message from the model manager) unlocks them again.
    """
    def control_updates(enabled):
        # One interactivity update each for the button, dropdown and number box.
        return [gr.update(interactive=enabled) for _ in range(3)]

    yield ["Loading model...", 0, *control_updates(False)]

    # The context length must be applied before the load starts.
    model_manager.set_context_length(context_length)

    for status_text, loaded_ok in model_manager.load_model(model_name):
        # Tab 1 is the chat tab; stay on tab 0 when loading did not succeed.
        target_tab = 1 if loaded_ok else 0
        yield [status_text, target_tab, *control_updates(True)]

def get_model_info():
    """Return the model manager's one-line status text for the loaded model."""
    status_text = model_manager.get_current_model_info()
    return status_text

def generate_with_replacements(message, history):
    """Stream a chat reply while enforcing a minimum amount of "thinking".

    Text is generated in small batches. Whenever the model emits ``</think>``
    or EOS before ``params["min_thinking_tokens"]`` thinking tokens have been
    counted (and ``params["max_swaps"]`` still allows it), that token and the
    rest of the batch are discarded and a random line from
    ``params["replacement_tokens"]`` is appended instead, so the model keeps
    reasoning from there.

    Args:
        message: Text of the new user message.
        history: Earlier conversation as a list of dicts with "role" and
            "content" keys.

    Yields:
        Lists of ``ChatMessage`` objects: the thinking message (once one
        exists), followed by the final-answer message once ``</think>`` has
        been accepted.
    """
    if model_manager.model is None:
        # This function is a generator, so the notice has to be yielded; a
        # value passed to `return` would never reach the chat window.
        yield [ChatMessage(role="assistant", content="Please load a model first in the 'Choose Model' tab")]
        return

    model = model_manager.model
    tokenizer = model_manager.tokenizer
    eos_token_id = model.config.eos_token_id

    # Convert history to messages format and add the current message
    conversation = [{"role": msg["role"], "content": msg["content"]} for msg in history]
    conversation.append({"role": "user", "content": message})

    # Initialize tokens with full conversation history
    current_tokens = tokenizer.apply_chat_template(
        conversation,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to("cuda")

    # Find think tags using current tokenizer. The first id returned by
    # encode() is discarded (presumably the BOS token - confirm per tokenizer).
    _, start_think_token, end_think_token = tokenizer.encode("<think></think>")

    thinking_msg = None
    final_msg = None
    thinking_content = ""
    final_content = ""
    n_thinking_tokens = 0
    swap_count = 0
    is_thinking = False
    total_new_tokens = 0
    max_batch_tokens = 20  # small batches keep the UI updating frequently

    while total_new_tokens < params["max_output_tokens"]:
        outputs = model.generate(
            input_ids=current_tokens,
            max_new_tokens=min(max_batch_tokens, params["max_output_tokens"] - total_new_tokens),
            temperature=params["temperature"],
            top_p=params["top_p"],
            repetition_penalty=params["repetition_penalty"],
            use_cache=True,
            return_dict_in_generate=True,
            output_scores=True
        )

        # Only the tokens produced by this call (the prompt prefix is sliced off).
        new_tokens = outputs.sequences[:, current_tokens.shape[1]:]
        total_new_tokens += new_tokens.shape[1]

        for i, token in enumerate(new_tokens[0]):
            token_id = token.item()

            # Check for token swapping before processing the token
            if (token_id in (end_think_token, eos_token_id)
                and n_thinking_tokens < params["min_thinking_tokens"]
                and (params["max_swaps"] == -1 or swap_count < params["max_swaps"])):

                candidates = [t.strip() for t in params["replacement_tokens"].split('\n') if t.strip()]

                # With no replacement prompts configured there is nothing to
                # swap in; fall through and accept the token as generated.
                if candidates:
                    replacement = "\n" + random.choice(candidates)
                    # No special tokens here: a BOS token must not be spliced
                    # into the middle of the running sequence.
                    replacement_tokens = tokenizer.encode(replacement, add_special_tokens=False)

                    n_thinking_tokens += len(replacement_tokens)
                    swap_count += 1
                    thinking_content += replacement

                    # Keep everything before the rejected token, then append
                    # the replacement prompt; the rest of the batch is dropped.
                    current_tokens = torch.cat([
                        outputs.sequences[:, :current_tokens.shape[1] + i],
                        torch.tensor([replacement_tokens], device=current_tokens.device)
                    ], dim=1)

                    # A swap can fire before any <think> token was seen, in
                    # which case no thinking message exists yet.
                    if thinking_msg is None:
                        thinking_msg = ChatMessage(
                            role="assistant",
                            content="",
                            metadata={"title": "🤔 Thinking process"}
                        )

                    thinking_msg.content = thinking_content
                    thinking_msg.metadata = {
                        "title": f"🤔 Thinking process (Extensions: {swap_count}, Tokens: {n_thinking_tokens})"
                    }
                    yield [thinking_msg]
                    # Leave the for-loop without running its `else`, so the
                    # spliced sequence (not outputs.sequences) is used next.
                    break

            token_str = tokenizer.decode([token_id])

            # Update thinking state
            if token_id == start_think_token:
                is_thinking = True
                thinking_msg = ChatMessage(
                    role="assistant",
                    content="",
                    metadata={"title": "🤔 Thinking process"}
                )
                continue
            elif token_id == end_think_token:
                is_thinking = False
                final_msg = ChatMessage(role="assistant", content="")
                continue

            # Add content and update messages
            if is_thinking:
                thinking_content += token_str
                n_thinking_tokens += 1
                if thinking_msg:
                    thinking_msg.content = thinking_content
                    thinking_msg.metadata = {
                        "title": f"🤔 Thinking process (Extensions: {swap_count}, Tokens: {n_thinking_tokens})"
                    }
                    yield [thinking_msg]
            else:
                final_content += token_str
                if final_msg:
                    shown = []
                    if thinking_msg:
                        shown.append(thinking_msg)
                    final_msg.content = final_content.strip()
                    shown.append(final_msg)
                    yield shown

            # An EOS that was not swapped out ends the reply.
            if token_id == eos_token_id:
                shown = []
                if thinking_msg:
                    thinking_msg.content = thinking_content
                    shown.append(thinking_msg)
                if final_msg:
                    final_msg.content = final_content.strip()
                    shown.append(final_msg)
                yield shown
                return

        else:
            # Batch consumed without a swap: continue from the full output.
            current_tokens = outputs.sequences

def update_global_params(min_tokens, max_tokens, max_swaps, replacements, temp, top_p, rep_pen):
    """Copy the current values of the settings controls into ``params``.

    Arguments arrive in the order of the UI controls; each one overwrites the
    matching entry of the module-level ``params`` dict. Returns nothing.
    """
    params["min_thinking_tokens"] = min_tokens
    params["max_output_tokens"] = max_tokens
    params["max_swaps"] = max_swaps
    params["replacement_tokens"] = replacements
    params["temperature"] = temp
    params["top_p"] = top_p
    params["repetition_penalty"] = rep_pen

# Create the interface
# Layout: an intro blurb, then two tabs - "1. Choose Model" (model picker,
# context length, load button, live loading log) and "2. Chat" (chat window
# plus generation settings) - followed by an HTML footer.
with gr.Blocks() as demo:
    gr.Markdown("""
        # 🤔 DeepSeek R1 Overthinker

        Using this app you can force DeepSeek R1 models to think more deeply, for as long as you wish. It works like this:
        - Detects when the model tries to conclude thoughts too early (`</think>` token)
        - Replaces those with prompts that encourage additional reasoning
        - Continues until a minimum threshold of thinking is reached

        You decide how long the model should think. The result is more thorough and well-reasoned responses (hopefully).
    """)

    # Tab to show after a load attempt; written by load_selected_model
    # (0 = stay on the model tab, 1 = switch to the chat tab).
    current_tab = gr.State(value=0)

    with gr.Tabs() as tabs:
        with gr.Tab("1. Choose Model"):
            # The model list is fetched once, while the interface is being built.
            model_dropdown = gr.Dropdown(
                choices=model_manager.get_available_models(),
                label="Select DeepSeek R1 Model",
                interactive=True
            )

            context_length = gr.Number(
                value=21848,
                label="Maximum context length (in tokens)",
                precision=0,
                minimum=1024,
                info="""Higher values require more VRAM. Reference values from Llama 3.1 testing:
3,000 tokens → ~8 GB VRAM
22,000 tokens → ~12 GB VRAM
41,000 tokens → ~16 GB VRAM
78,000 tokens → ~24 GB VRAM
154,000 tokens → ~40 GB VRAM

Actual VRAM usage depends on the model. Start with a lower value if you experience out-of-memory errors."""
            )

            load_button = gr.Button("Load Selected Model")
            model_info = gr.Markdown(get_model_info())

            # Create log component with auto-refresh
            # It displays the same file that ModelManager writes its loading output to.
            loading_log = Log(
                model_manager.log_file,
                dark=True,
                xterm_font_size=14,
                label="Loading Progress",
                every=0.5
            )

            # Update model info when loading model
            # The outputs below must stay in the same order as the lists
            # yielded by load_selected_model.
            load_button.click(
                fn=load_selected_model,
                inputs=[model_dropdown, context_length],
                outputs=[
                    model_info,           # Markdown
                    current_tab,          # State
                    load_button,          # Button
                    model_dropdown,       # Dropdown
                    context_length        # Number
                ],
                queue=True
            ).success(                    # chained step: select the tab stored in current_tab
                fn=lambda tab: gr.Tabs(selected=tab),
                inputs=[current_tab],
                outputs=[tabs]
            )

        # id=1 matches the value load_selected_model stores in current_tab on success.
        with gr.Tab("2. Chat", id=1):
            with gr.Row():
                # Left (wider) column: the chat window itself.
                with gr.Column(scale=3):
                    chat_interface = gr.ChatInterface(
                        generate_with_replacements,
                        chatbot=gr.Chatbot(
                            bubble_full_width=True,
                            show_copy_button=True,
                            # Delimiter pairs the chatbot should treat as LaTeX.
                            latex_delimiters=[
                                {"left": "$$", "right": "$$", "display": True},
                                {"left": "$", "right": "$", "display": False},
                                {"left": "\\(", "right": "\\)", "display": False},
                                {"left": "\\[", "right": "\\]", "display": True},
                                {"left": "\\begin{equation}", "right": "\\end{equation}", "display": True},
                                {"left": "\\begin{align}", "right": "\\end{align}", "display": True},
                                {"left": "\\begin{alignat}", "right": "\\end{alignat}", "display": True},
                                {"left": "\\begin{gather}", "right": "\\end{gather}", "display": True},
                                {"left": "\\begin{CD}", "right": "\\end{CD}", "display": True},
                                {"left": "\\boxed{", "right": "}", "display": False},
                                {"left": "\\frac{", "right": "}", "display": False},
                                {"left": "\\sqrt{", "right": "}", "display": False}
                            ],
                            type="messages",
                            editable='all',
                            min_height=550,
                        ),
                        fill_height=True,
                        type="messages"
                    )

                # Right (narrower) column: generation settings.
                # NOTE: the slider maximums read model_manager.max_seq_length once,
                # when the UI is built; they do not follow later context-length changes.
                with gr.Column(scale=1):
                    min_tokens = gr.Slider(minimum=0, maximum=model_manager.max_seq_length-512, value=1024, step=128, label="Minimum Thinking Tokens")
                    max_tokens = gr.Slider(minimum=256, maximum=model_manager.max_seq_length, value=2048, step=256, label="Maximum Output Tokens")
                    max_swaps = gr.Slider(minimum=-1, maximum=20, value=-1, step=1, label="Maximum Reasoning Extensions", info="Limit how many times to extend the model's reasoning (-1 for unlimited).")
                    replacements = gr.Textbox(value=params["replacement_tokens"], label="Replacement Tokens (one per line)", lines=5, max_lines=5)
                    temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
                    top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p")
                    rep_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.2, step=0.1, label="Repetition Penalty")

                    # Connect all parameter components to the update function
                    # Changing any one control re-sends the values of all seven,
                    # in the order update_global_params expects them.
                    for param in [min_tokens, max_tokens, max_swaps, replacements, temperature, top_p, rep_penalty]:
                        param.change(
                            fn=update_global_params,
                            inputs=[min_tokens, max_tokens, max_swaps, replacements, temperature, top_p, rep_penalty],
                            outputs=[]
                        )
    # Footer with author and support links.
    gr.HTML("""
    <div style="border-top: 1px solid #303030;">
      <br>
      <p>App by: <a href="https://twitter.com/hahahahohohe"><img src="https://img.shields.io/twitter/follow/hahahahohohe?label=%40anzorq&style=social" alt="Twitter Follow"></a></p><br>
      <p>Enjoying this app? Please consider <a href="https://www.buymeacoffee.com/anzorq">supporting me</a></p>
      <a href="https://www.buymeacoffee.com/anzorq" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/v2/default-yellow.png" alt="Buy Me A Coffee" style="height: 45px !important;width: 162px !important;" ></a><br><br>
      <a href="https://github.com/qunash/r1-overthinker" target="_blank"><img alt="GitHub Repo stars" src="https://img.shields.io/github/stars/qunash/r1-overthinker?style=social"></a>
    </div>
    """)

# Start the app with debug mode on and share=True (the notebook text above
# says a public *.gradio.live URL is generated).
demo.launch(debug=True, share=True)
