# Thinking Augmented Generation
Author: [Zain Hasan](https://x.com/ZainHasan6)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/togethercomputer/together-cookbook/blob/main/Agents/Thinking_Augmented_Generation.ipynb)

## Introduction

In this notebook, we explore how to improve the answers of a smaller, specialized model by giving it the reasoning produced by a larger reasoning model.

Specifically, we have `DeepSeek-R1` reason about a prompt and then pass its `thinking` tokens to a smaller model, `Mistral Small 3`, so that it can generate a better response.


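```python
import os

from together import Together

# Read the API key from the environment instead of hard-coding it
client = Together(api_key=os.environ.get("TOGETHER_API_KEY"))
```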

First, let's ask the small model on its own:

```python
question = "How many r's are in the word strawberry and burberry combined?"

answer = client.chat.completions.create(
    model="mistralai/Mistral-Small-24B-Instruct-2501",
    messages=[{"role": "user", "content": question}],
)

print(answer.choices[0].message.content)
```

```text
To determine the number of "r's" in the words "strawberry" and "burberry" combined, we need to count the "r's" in each word and then add them together.

- The word "strawberry" has 3 "r's".
- The word "burberry" has 2 "r's".

Adding these together:

3 (from strawberry) + 2 (from burberry) = 5

So, there are 5 "r's" in the words "strawberry" and "burberry" combined.
```
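A quick deterministic check with Python's built-in `str.count` gives a ground truth to compare the model's answer against. This check is an illustrative addition and does not call any model:

```python
# Count the letter "r" in each word directly, independent of any language model.
words = ["strawberry", "burberry"]

counts = {word: word.count("r") for word in words}
total = sum(counts.values())

for word, n in counts.items():
    print(f"{word}: {n}")
print(f"total: {total}")
# prints:
# strawberry: 3
# burberry: 3
# total: 6
```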


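Mistral Small 3 miscounts "burberry" (b-u-r-b-e-r-r-y), which has three r's rather than two, so its total of 5 is wrong. Let's have R1 think about the question first and then pass its thinking tokens to `Mistral Small 3` to generate a better response.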


```python
thought = client.chat.completions.create(
    model="deepseek-ai/DeepSeek-R1",
    messages=[{"role": "user", "content": question}],
    stop=["</think>"],  # stop generating once the reasoning block closes
)

print(thought.choices[0].message.content)
```

```text
<think>
Okay, let's see. The question is asking how many times the letter "r" appears in the words "strawberry" and "burberry" combined. Hmm, I need to make sure I count each "r" in both words correctly. Let me start by writing down each word separately and then check each letter one by one.

First, let's take "strawberry". Let me spell it out: S-T-R-A-W-B-E-R-R-Y. Wait, let me count the letters. S (1), T (2), R (3), A (4), W (5), B (6), E (7), R (8), R (9), Y (10). So "strawberry" has 10 letters. Now, looking for the letter "r". Let's go through each letter again:

1. S - no
2. T - no
3. R - yes, that's the first "r"
4. A - no
5. W - no
6. B - no
7. E - no
8. R - second "r"
9. R - third "r"
10. Y - no

Wait, hold on. So in "strawberry", there are three "r"s? Let me confirm. The spelling is S-T-R-A-W-B-E-R-R-Y. So after the B, it's E, then R, R, Y. So positions 8 and 9 are both R's. So that's two R's after the E. But the third letter is also R. So that's three R's total in "strawberry"
```
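The returned content begins with the opening `<think>` tag and, because of the stop sequence, has no closing tag. The cells that follow pass this raw content through unchanged. If you would rather hand another model only the bare reasoning text, a small helper can strip the tags. `extract_thinking` is a hypothetical name, and this is a minimal sketch assuming R1 wraps its reasoning as `<think> ... </think>`:

```python
def extract_thinking(text: str) -> str:
    """Return the reasoning text without the surrounding <think> tags.

    Assumes (hypothetically) the reasoning is wrapped as "<think> ... </think>"
    and that generation may have stopped before the closing tag was emitted.
    """
    start_tag, end_tag = "<think>", "</think>"
    body = text.strip()
    # Drop the opening tag if the response starts with it.
    if body.startswith(start_tag):
        body = body[len(start_tag):]
    # Drop the closing tag and anything after it, if present.
    end = body.find(end_tag)
    if end != -1:
        body = body[:end]
    return body.strip()
```

Usage: `thinking_tokens = extract_thinking(thought.choices[0].message.content)`.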

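Next, a prompt template that passes R1's thinking tokens to the smaller model. Because generation stopped at `</think>`, the returned text ends before that tag, so the template appends the closing tag itself before asking for the answer: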


```python
PROMPT_TEMPLATE = """
Question: {question}
Thought process: {thinking_tokens} </think>
Answer:
"""
```
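To see exactly what the smaller model receives, you can fill the template with a short placeholder reasoning string and print it. The reasoning below is made up for illustration and is not real R1 output; it starts with `<think>` to mimic the raw response content:

```python
# Same template as above: the reasoning is followed by an explicit closing
# </think> tag, since generation was stopped before that tag was produced.
PROMPT_TEMPLATE = """
Question: {question}
Thought process: {thinking_tokens} </think>
Answer:
"""

# Hypothetical placeholder for R1's raw output (not a real model response)
placeholder_thinking = "<think>\nS-T-R-A-W-B-E-R-R-Y has an r at positions 3, 8 and 9."

example_prompt = PROMPT_TEMPLATE.format(
    question="How many r's are in the word strawberry?",
    thinking_tokens=placeholder_thinking,
)
print(example_prompt)
```

The printed prompt contains one complete `<think> ... </think>` block followed by an empty `Answer:` line, which is presumably meant to cue the instruction-tuned model to give a final answer rather than continue reasoning.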

```python
answer = client.chat.completions.create(
    model="mistralai/Mistral-Small-24B-Instruct-2501",
    messages=[{
        "role": "user",
        "content": PROMPT_TEMPLATE.format(
            question=question,
            thinking_tokens=thought.choices[0].message.content,
        ),
    }],
)

print(answer.choices[0].message.content)
```

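```text
The total number of 'r's in the words "strawberry" and "burberry" combined is 6.
```

With R1's reasoning in its prompt, Mistral Small 3 now arrives at the correct total of 6.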
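For reuse, the two calls can be wrapped into a single function. This is a minimal sketch: `thinking_augmented_answer`, `together_complete`, and the `CompleteFn` callback type are hypothetical names, and the completion function is injected so the two-step logic can be exercised with a stub and without an API key:

```python
from typing import Callable, Dict, List, Optional

Message = Dict[str, str]
# Hypothetical callback interface: (model, messages, stop) -> generated text
CompleteFn = Callable[[str, List[Message], Optional[List[str]]], str]

PROMPT_TEMPLATE = """
Question: {question}
Thought process: {thinking_tokens} </think>
Answer:
"""


def thinking_augmented_answer(
    question: str,
    complete: CompleteFn,
    reasoner: str = "deepseek-ai/DeepSeek-R1",
    responder: str = "mistralai/Mistral-Small-24B-Instruct-2501",
) -> str:
    """Let `reasoner` think about `question`, then have `responder` answer it."""
    # Step 1: collect only the reasoning by stopping at the closing tag.
    thinking = complete(
        reasoner, [{"role": "user", "content": question}], ["</think>"]
    )
    # Step 2: give the smaller model the question plus the reasoning.
    prompt = PROMPT_TEMPLATE.format(question=question, thinking_tokens=thinking)
    return complete(responder, [{"role": "user", "content": prompt}], None)


def together_complete(client) -> CompleteFn:
    """Adapt a `together.Together` client to the hypothetical CompleteFn interface."""
    def complete(model: str, messages: List[Message], stop: Optional[List[str]]) -> str:
        # Only send `stop` when one is given.
        kwargs = {"stop": stop} if stop else {}
        response = client.chat.completions.create(
            model=model, messages=messages, **kwargs
        )
        return response.choices[0].message.content
    return complete
```

Usage: `print(thinking_augmented_answer(question, together_complete(client)))`. Keeping the API call behind a callback also makes it straightforward to try other reasoner or responder models without touching the pipeline logic.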
