# Thinking Augmented Generation
Author: [Zain Hasan](https://x.com/ZainHasan6)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/togethercomputer/together-cookbook/blob/main/Agents/Thinking_Augmented_Generation.ipynb)

## Introduction

In this notebook, we explore how to improve the responses of smaller, specialized models by giving them the reasoning produced by a larger reasoning model.

Specifically, we will have `DeepSeek-R1` reason about a prompt and then pass its `thinking` tokens to a smaller model, `Mistral Small 3`, so that it can generate a better response.


In [None]:
!pip install -qU together

In [3]:
from together import Together

client = Together()  # reads the API key from the TOGETHER_API_KEY environment variable

First, let's ask the small model on its own:

In [4]:
question = "How many r's are in the word strawberry and burberry combined?"

answer = client.chat.completions.create(
  model="mistralai/Mistral-Small-24B-Instruct-2501",
  messages=[{"role": "user", "content": question}],
)

print(answer.choices[0].message.content)

Let's count the number of 'r's in each word:

- In the word "strawberry", there are 3 'r's.
- In the word "burberry", there are 2 'r's.

Combined, the total number of 'r's in "strawberry" and "burberry" is 3 + 2 = 5.
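As a quick sanity check, plain Python string counting gives the ground-truth answer the model should have produced:

```python
# Count the letter "r" in each word with str.count, then sum the counts
words = ["strawberry", "burberry"]
counts = {word: word.count("r") for word in words}
total = sum(counts.values())

print(counts)  # {'strawberry': 3, 'burberry': 3}
print(total)   # 6
```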


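The small model undercounts the r's in "burberry", which actually has three (b-u-r-b-e-r-r-y), so the correct total is 3 + 3 = 6, not 5. Let's have R1 reason about the question first and then provide its thinking tokens to `Mistral Small 3` to see whether it generates a better response.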


In [5]:
thought = client.chat.completions.create(
  model="deepseek-ai/DeepSeek-R1",
  messages=[{"role": "user", "content": question}],
  stop=["</think>"],  # Stop generation when </think> is encountered
)

print(thought.choices[0].message.content)

<think>
First, the question is: "How many r's are in the word strawberry and burberry combined?" I need to count the number of 'r's in both words together.

Let me start with "strawberry". I'll spell it out: S-T-R-A-W-B-E-R-R-Y. Now, I need to find all the 'r's.

- Position 3: R

- Position 8: R

- Position 9: R

In "strawberry", there are three 'r's. Let me confirm: S-T-R (that's one), then after W-B-E, there's R-R (that's two more), so yes, three 'r's.

Now, "burberry": B-U-R-B-E-R-R-Y.

Spelling it out: B-U-R (first R), then B-E-R (second R), then R (third R), and Y. So, positions:

- Position 3: R

- Position 6: R

- Position 7: R

"burberry" has three 'r's as well. B-U-R (R at 3), B-E-R (R at 6), and then another R at 7? Let's list the letters:

1. B

2. U

3. R (first R)

4. B

5. E

6. R (second R)

7. R (third R)? No, after R at position 6, it's R-Y? I think I'm confusing it.

"Burberry" is typically spelled B-U-R-B-E-R-R-Y. So:

- Letter 1: B

- Letter 2: U

- Letter 3: R (fir
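Passing `stop=["</think>"]` ends generation as soon as R1 finishes reasoning, so the returned content is just the thought process, starting with `<think>` as in the (truncated) output above. OpenAI-compatible APIs typically leave the stop sequence itself out of the returned text, which is presumably why the prompt template below appends `</think>` again. If you would rather let R1 finish its full answer and split out the reasoning afterwards, a helper along these lines could do it. `extract_thinking` is a hypothetical name, and the sketch assumes the reasoning is wrapped in a single `<think>...</think>` block:

```python
import re


# Hypothetical helper: pull the reasoning out of a reasoning-model response.
# Assumes the reasoning sits inside a single <think>...</think> block; if the
# closing tag is missing (e.g. generation was stopped at "</think>"), everything
# after the opening tag is treated as reasoning.
def extract_thinking(text: str) -> str:
    match = re.search(r"<think>(.*?)(?:</think>|$)", text, flags=re.DOTALL)
    if match is None:
        return ""  # no <think> tag found
    return match.group(1).strip()


full = "<think>\nCount the r's in each word.\n</think>\nThe answer is 6."
stopped = "<think>\nCount the r's in each word.\n"

print(extract_thinking(full))     # Count the r's in each word.
print(extract_thinking(stopped))  # Count the r's in each word.
```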

Next, define a prompt template that passes R1's thinking tokens to the smaller model:


In [6]:
PROMPT_TEMPLATE = """
Question: {question}
Thought process: {thinking_tokens} </think>
Answer:
"""

In [7]:
answer = client.chat.completions.create(
  model="mistralai/Mistral-Small-24B-Instruct-2501",
  messages=[{
      "role": "user",
      "content": PROMPT_TEMPLATE.format(
          question=question,
          thinking_tokens=thought.choices[0].message.content,
      ),
  }],
)

print(answer.choices[0].message.content)

The answer to the question "How many r's are in the word strawberry and burberry combined?" is 6.

Let's summarize the thought process one more time for clarity:
1. "Strawberry" contains 3 r's (positions 3, 8, and 9).
2. "Burberry" contains 3 r's (positions 3, 6, and 7).
3. When combined (whether summed or concatenated), the total number of r's remains 6.

Therefore, the final answer is 6.
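The two-step flow above can be wrapped into one reusable function. This is a minimal sketch, not part of the `together` SDK: `thinking_augmented_answer` is a hypothetical name, and it only uses the same `client.chat.completions.create` call (with `model`, `messages`, and `stop`) shown in the cells above. Taking the client and model names as parameters makes it easy to try other reasoner/responder pairs:

```python
from typing import Any

PROMPT_TEMPLATE = """
Question: {question}
Thought process: {thinking_tokens} </think>
Answer:
"""


# Hypothetical helper (not an SDK function): reasoner thinks, responder answers
def thinking_augmented_answer(
    client: Any,
    question: str,
    reasoner: str = "deepseek-ai/DeepSeek-R1",
    responder: str = "mistralai/Mistral-Small-24B-Instruct-2501",
) -> tuple[str, str]:
    """Return (thinking, answer) for the given question."""
    # Step 1: let the reasoning model think, stopping at the end of its reasoning
    thought = client.chat.completions.create(
        model=reasoner,
        messages=[{"role": "user", "content": question}],
        stop=["</think>"],
    )
    thinking = thought.choices[0].message.content

    # Step 2: give the question plus the reasoning to the smaller model
    prompt = PROMPT_TEMPLATE.format(question=question, thinking_tokens=thinking)
    answer = client.chat.completions.create(
        model=responder,
        messages=[{"role": "user", "content": prompt}],
    )
    return thinking, answer.choices[0].message.content


# Usage with the Together client from earlier (makes two API calls):
# thinking, answer = thinking_augmented_answer(client, question)
# print(answer)
```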
