In [1]:
import os
import gc
import time
import warnings

import pandas as pd
import polars as pl
import uuid
import json
from typing import List
from datetime import datetime
import psycopg2
import re
import random
from collections import Counter

In [2]:

# Load the s1K-1.1 train parquet shard from the Hugging Face Hub into a pandas DataFrame.
# Login using e.g. `huggingface-cli login` to access this dataset
s1_df = pd.read_parquet("hf://datasets/simplescaling/s1K-1.1/data/train-00000-of-00001.parquet")

In [3]:
# Use the first row's reasoning trace as the sample input for the tagging experiments.
test_trace = s1_df['deepseek_thinking_trajectory'].iloc[0]

# Regex-based keep tagging (heuristic baseline)

In [4]:
start_tag = '<keep>'
end_tag = '</keep>'
s = "Hi my name is Sam. I am a student at Stanford University. I am a student at Stanford University. I am a student at Stanford University."

def return_para_w_last_sentence_wrapped(para: str, last_n_sentences: int = 2):
    """Wrap the last ``last_n_sentences`` sentences of ``para`` in keep tags.

    A sentence boundary is detected as one of ``.``, ``!`` or ``?`` followed by
    a single whitespace character and an uppercase ASCII letter.

    Args:
        para: The paragraph to tag.
        last_n_sentences: How many trailing sentences to wrap (must be >= 1).
            If the paragraph contains that many sentences or fewer, the whole
            paragraph is wrapped.

    Returns:
        The paragraph with its tail enclosed in ``start_tag``/``end_tag`` and a
        trailing ``'\\n\\n'`` appended.

    Raises:
        ValueError: If ``last_n_sentences`` is smaller than 1.
    """
    if last_n_sentences < 1:
        raise ValueError(f"last_n_sentences must be >= 1, got {last_n_sentences}")

    # Match end punctuation, one whitespace character, and then a capital letter
    eos_regex = r'[.!?]\s[A-Z]'
    # A match starts on the punctuation mark; the next sentence begins two
    # characters later (punctuation + whitespace).
    offset = 2

    boundaries = [m.start() for m in re.finditer(eos_regex, para)]
    # N boundaries delimit N + 1 sentences, so asking for more sentences than
    # there are boundaries means "the whole paragraph".
    if last_n_sentences > len(boundaries):
        last_sentence_start = 0
    else:
        last_sentence_start = boundaries[-last_n_sentences] + offset

    last_sentence = para[last_sentence_start:]
    wrapped_sentence = f'{start_tag}{last_sentence}{end_tag}'
    wrapped_paragraph = f'{para[:last_sentence_start]}{wrapped_sentence}\n\n'
    return wrapped_paragraph


In [5]:
def insert_keep_text_at_eop(text, config: dict = None):
    """Heuristically tag the end of paragraphs ("eop") with keep tags.

    The text is split on blank lines (``'\\n\\n'``). Paragraphs longer than
    ``min_p_length`` characters and shorter ones are counted separately; in
    each group every ``k``-th paragraph (starting with the first) gets its tail
    wrapped: the last two sentences for long paragraphs, the last sentence for
    short ones.

    Args:
        text: The full reasoning trace.
        config: Optional overrides for the defaults ``{'k': 3,
            'min_p_length': 300}``. ``None`` or an empty dict uses the defaults.

    Returns:
        The tagged text. Every paragraph, including the last, is followed by
        ``'\\n\\n'``, so the result carries one extra trailing separator
        compared with the input.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    default_config = {
        'k': 3,
        'min_p_length': 300
    }
    config = {**default_config, **(config or {})}
    if config['k'] < 1:
        raise ValueError(f"config['k'] must be >= 1, got {config['k']}")

    pieces = []
    ct_long = 0   # paragraphs longer than min_p_length seen so far
    ct_short = 0  # all other paragraphs seen so far
    for paragraph in text.split('\n\n'):
        if len(paragraph) > config['min_p_length']:
            should_wrap = ct_long % config['k'] == 0
            n_sentences = 2
            ct_long += 1
        else:
            should_wrap = ct_short % config['k'] == 0
            n_sentences = 1
            ct_short += 1

        if should_wrap:
            # The helper already appends the '\n\n' paragraph separator.
            pieces.append(return_para_w_last_sentence_wrapped(paragraph, last_n_sentences=n_sentences))
        else:
            pieces.append(paragraph + '\n\n')

    # Join once instead of growing a string inside the loop.
    return ''.join(pieces)

def insert_keep_text(text, strategy: str ='eop', config: dict = None):
    """Dispatch to the keep-tagging strategy named by ``strategy``.

    Only ``'eop'`` is supported; any other value raises ``ValueError``.
    A ``None`` config is passed on as an empty dict.
    """
    if strategy != 'eop':
        raise ValueError(f"Invalid strategy: {strategy}")
    effective_config = {} if config is None else config
    return insert_keep_text_at_eop(text, effective_config)

# Preview the 'eop' heuristic on the sample trace using the default config.
print(insert_keep_text(test_trace, 'eop', {}))

Alright, so I need to figure out how many rational numbers between 0 and 1 have a product of the numerator and denominator (when written in lowest terms) equal to 20 factorial. <keep>Let me start by understanding the problem.</keep>

First, a rational number between 0 and 1 can be written as a fraction \(\frac{a}{b}\) where \(0 < a < b\) and \(a\) and \(b\) are coprime positive integers. The product \(a \times b\) is said to be equal to 20! which is a huge number. <keep>So I need to find all pairs \((a, b)\) such that \(a \times b = 20!\), \(0 < a < b\), and \(\gcd(a, b) = 1\). Then count how many such pairs exist.</keep>

Let me break down the problem.

Given that \(a \times b = 20!\) and \(\gcd(a, b) = 1\), it means that \(a\) and \(b\) are coprime factors of 20!. So they form a coprime pair whose product is 20!. Such pairs are related to the divisors of 20! where each pair consists of two coprime numbers.

But since the fraction is between 0 and 1, \(a < b\). So for each such coprim

In [6]:
# verify that the wrapped text without keep tags is the same as the original text
def remove_keep_text(text):
    """Return ``text`` with every literal ``</keep>`` and ``<keep>`` tag removed."""
    stripped = text
    # Closing tags are removed first, then opening tags.
    for tag in ('</keep>', '<keep>'):
        stripped = stripped.replace(tag, '')
    return stripped

# Round-trip check on the sample trace; both sides are stripped before comparing,
# so differences in leading/trailing whitespace are ignored.
assert remove_keep_text(insert_keep_text(test_trace, 'eop', {})).strip() == test_trace.strip()

# Prompting

In [7]:
# System prompt for the tagging requests: asks the model to list, by paragraph and
# sentence number, the exact spans of the trace to wrap in <keep> tags.
prompt = """
You are Deepseek, an expert reasoning Large Language Model. Your task is to assist in optimizing reasoning outputs by indicating which tokens are essential for future reasoning steps.

Currently, reasoning models often wrap their thought process within <think> and </think> tags. However, this approach can lead to unnecessarily long contexts, as many tokens included might not be crucial for subsequent reasoning.

To improve this, in addition to the existing <think> ... </think> tags (which denote the overall reasoning process), you will identify and mark specific tokens to preserve using <keep> and </keep> tags. These <keep> tags explicitly highlight tokens essential for the next steps of reasoning.

Your process:

1. Read the provided question and pre-existing reasoning trace.

2. Consider carefully: "If I continue this reasoning, which tokens must be retained?"

3. Clearly indicate the exact locations for inserting <keep> and </keep> tags around tokens necessary for further generation.

Important: Be selective in choosing which tokens to wrap with <keep> tags. Strive for a balanced approach—avoid marking too few tokens, which may omit critical information, and avoid marking too many, which could reduce efficiency.

Output Format:

- Provide your selections clearly by indicating the paragraph number and the sentence number for each token or token group to retain. For example:

    - Paragraph 2, Sentence 3: <keep>essential tokens</keep>

- Ensure the text between <keep> and </keep> tags exactly matches the original text. Do not alter or paraphrase any tokens.

- Ensure the text between <keep> and </keep> tags is at least one complete sentence or phrase. You may include multiple sentences or phrases if they are closely related.

This method simulates efficient reasoning by focusing on essential tokens, optimizing context usage.
"""

def get_prompt_messsages(question, deepseek_thinking_trajectory):
    """Build the chat messages (system + user) for one tagging request.

    The user message embeds the question and the reasoning trace and asks for
    a minimum number of kept blocks: one per five blank-line separators
    (``'\\n\\n'``) found in the trace, using integer division.
    """
    paragraph_breaks = deepseek_thinking_trajectory.count('\n\n')
    min_blocks = paragraph_breaks // 5
    user_prompt = f"""
    # Question:
    {question}

    # Deepseek Thinking Trajectory:
    {deepseek_thinking_trajectory}

    You must output at least {min_blocks} blocks of kept text.
    """

    system_message = {"role": "system", "content": prompt}
    user_message = {"role": "user", "content": user_prompt}
    return [system_message, user_message]

In [8]:
import os
from openai import OpenAI

# Synchronous OpenAI-compatible client pointed at the Nebius endpoint; the API key
# is read from the NEBIUS_API_KEY environment variable (None if it is unset).
client_nebius = OpenAI(
      base_url="https://api.studio.nebius.com/v1/",
      api_key=os.getenv('NEBIUS_API_KEY')
  )

def generate_response_nebius(client, messages, stream=False, verbose=False):
    """Request a chat completion from DeepSeek-R1 and return its text.

    Args:
        client: An object exposing ``chat.completions.create(...)``.
        messages: The chat messages to send.
        stream: If True, consume the response chunk by chunk and concatenate
            the content deltas; otherwise read the single response message.
        verbose: When streaming, echo each content delta to stdout as it
            arrives.

    Returns:
        The completion text. When not streaming this is whatever
        ``response.choices[0].message.content`` holds.
    """
    response = client.chat.completions.create(
        model="deepseek-ai/DeepSeek-R1",
        max_tokens=16384,
        temperature=0.9,
        top_p=0.95,
        messages=messages,
        stream=stream
    )

    if not stream:
        return response.choices[0].message.content

    parts = []
    for chunk in response:
        # Some chunks carry no choices at all; skip them instead of indexing
        # into an empty list.
        choices = getattr(chunk, 'choices', None)
        if not choices:
            continue
        # Chunks without text have content None (or ''), which cannot be
        # concatenated onto the completion.
        piece = getattr(choices[0].delta, 'content', None)
        if not piece:
            continue
        parts.append(piece)
        if verbose:
            print(piece, end='', flush=True)
    return ''.join(parts)

# Keep Text Parsing

In [9]:
import editdistance
from matplotlib import pyplot as plt
import numpy as np
# Read a saved example output (presumably a model response in the prompt's
# "Paragraph N, Sentence M: <keep>...</keep>" format — confirm) and a trace from local files.
# NOTE: this rebinds `test_trace`, which previously held the first dataset row.
with open('ex_keep_output.txt', 'r') as f:
    ex_keep_output = f.read()
with open('test_t.txt','r') as f:
    test_trace = f.read()

def extract_keep_text(ex_keep_output) -> list[str]:
    """Extract the text of every ``: <keep>...</keep>`` span from a model response.

    Only spans preceded by ``': '`` are matched, which is how the prompt's
    ``Paragraph N, Sentence M: <keep>...</keep>`` output lines look. Matching
    is non-greedy, and ``re.DOTALL`` lets a span run across line breaks so
    multi-line selections are not silently dropped.

    Returns:
        The captured span texts in order of appearance (empty list if none).
    """
    return re.findall(r': <keep>(.*?)</keep>', ex_keep_output, flags=re.DOTALL)

# Spans the example output asked to keep, in order of appearance.
keep_texts = extract_keep_text(ex_keep_output)

def check_keep_text(keep_text: list[str], test_trace: str) -> bool:
    """Check that (almost) all keep texts occur verbatim in the trace.

    Prints a ``missing / total = ratio`` summary line.

    Args:
        keep_text: Candidate spans to look up.
        test_trace: The text the spans must be substrings of.

    Returns:
        True if at most 10% of the spans are missing from ``test_trace``.
        An empty list has nothing missing and therefore passes.
    """
    total = len(keep_text)
    missing_texts = sum(1 for text in keep_text if text not in test_trace)
    # Guard the ratio: an empty list used to raise ZeroDivisionError here.
    missing_ratio = missing_texts / total if total else 0.0
    print(f'{missing_texts} / {total} = {missing_ratio} missing texts')
    # At most 10% of the texts can be missing
    return missing_texts <= total * 0.1

# Called for its printed summary of how many raw keep texts are not verbatim
# substrings of the trace; the boolean result is not used here.
check_keep_text(keep_texts, test_trace)

# Edit distance helper used to look for close (non-exact) substrings
def word_edit_distance(text1: str, text2: str) -> int:
    """Return ``editdistance.eval(text1, text2)``.

    NOTE(review): despite the name, both strings are passed through unsplit,
    so the distance is presumably character-level rather than word-level —
    confirm against the ``editdistance`` package.
    """
    distance = editdistance.eval(text1, text2)
    return distance


def filter_keep_texts(keep_texts: list[str], test_trace: str, min_std_devs: float = 9) -> tuple[list[str], float]:
    """Keep only the spans that can be located in the trace, repairing near misses.

    Spans that occur verbatim in ``test_trace`` are kept unchanged. For any
    other span, every window of ``len(span)`` characters of the trace (windows
    near the end are shorter) is scored with ``word_edit_distance``; the
    best-scoring window replaces the span only if its distance is an outlier,
    i.e. further than ``min_std_devs`` standard deviations from the mean
    distance over all windows. Otherwise the span is dropped.

    Each non-verbatim span costs one edit-distance evaluation per character
    of the trace.

    Args:
        keep_texts: Candidate spans extracted from the model response.
        test_trace: The original reasoning trace.
        min_std_devs: Outlier threshold, in standard deviations, for accepting
            the closest window.

    Returns:
        ``(valid_keep_texts, proportion_kept)``. All returned spans are
        substrings of ``test_trace``. ``proportion_kept`` is ``0.0`` when
        ``keep_texts`` is empty.
    """
    valid_keep_texts = []
    for text in keep_texts:
        if text in test_trace:
            valid_keep_texts.append(text)
            continue

        windows = [test_trace[i:i + len(text)] for i in range(len(test_trace))]
        if not windows:
            # Empty trace: there is nothing to match against.
            continue

        # Score every window once and reuse the scores for min/mean/std.
        distances = np.array([word_edit_distance(text, window) for window in windows])
        best = int(np.argmin(distances))  # first minimum, like min(..., key=...)
        deviation = abs(distances[best] - distances.mean())

        # Only accept the closest match if it stands out from the typical window.
        if deviation > min_std_devs * distances.std():
            valid_keep_texts.append(windows[best])

    proportion_kept = len(valid_keep_texts) / len(keep_texts) if keep_texts else 0.0
    return valid_keep_texts, proportion_kept

# Filter/repair the example keep texts against the trace and report the share that survived.
valid_keep_texts, proportion_kept = filter_keep_texts(keep_texts, test_trace)
print(f'Proportion of keep texts kept: {proportion_kept}')
# After filtering, the surviving texts must pass the substring check.
assert check_keep_text(valid_keep_texts, test_trace)

14 / 26 = 0.5384615384615384 missing texts
Proportion of keep texts kept: 0.8076923076923077
0 / 21 = 0.0 missing texts


In [10]:
# Insert Keep tags

def insert_keep_tags(text: str, keep_texts: list[str], max_extend: int = 6) -> str:
    """Wrap the first occurrence of each keep text in ``<keep>``/``</keep>`` tags.

    Keep texts are processed in order against the progressively tagged text.
    If a span starts or ends in the middle of a word, the span is widened by up
    to ``max_extend`` characters on that side to take in the rest of the word.
    Keep texts that are empty or cannot be found are skipped.

    Args:
        text: The text to tag.
        keep_texts: Spans to wrap; each should be a substring of ``text``.
        max_extend: Maximum number of characters a span may grow on each side.

    Returns:
        The text with tags inserted. Removing the tags yields the original text.
    """
    for keep_text in keep_texts:
        if not keep_text:
            continue
        # Find the first instance of the keep text in the text
        start_idx = text.find(keep_text)
        if start_idx == -1:
            # Not present (e.g. interrupted by an earlier tag): leave text untouched.
            continue
        end_idx = start_idx + len(keep_text)  # exclusive

        # Grow left while the boundary falls between two alphanumeric chars.
        # The `> 0` guard stops index -1 from wrapping to the end of the text.
        for _ in range(max_extend):
            if start_idx > 0 and text[start_idx - 1].isalnum() and text[start_idx].isalnum():
                start_idx -= 1
            else:
                break

        # Grow right likewise. end_idx is exclusive, so the word is split when
        # the last kept char and the first excluded char are both alphanumeric.
        for _ in range(max_extend):
            if end_idx < len(text) and text[end_idx - 1].isalnum() and text[end_idx].isalnum():
                end_idx += 1
            else:
                break

        text = text[:start_idx] + f'<keep>{text[start_idx:end_idx]}</keep>' + text[end_idx:]
    return text

# Tag the trace loaded from file with the filtered keep texts.
text_w_keep_tags = insert_keep_tags(test_trace, valid_keep_texts)
# Ensure that removing the tags gives back the original text (compared after stripping both sides).
assert remove_keep_text(text_w_keep_tags).strip() == test_trace.strip()

In [11]:
# The whole output processing

def process_keep_additions(response: str, test_trace: str) -> str:
    """Turn a model response into a keep-tagged copy of ``test_trace``.

    Extracts the keep spans from ``response``, filters them against
    ``test_trace``, prints the proportion that survived, asserts that the
    survivors pass ``check_keep_text``, and returns the trace with tags inserted.
    """
    candidates = extract_keep_text(response)
    kept, kept_ratio = filter_keep_texts(candidates, test_trace)
    print(f'Proportion of keep texts kept: {kept_ratio}')
    assert check_keep_text(kept, test_trace)
    return insert_keep_tags(test_trace, kept)



In [12]:
# Preview the dataset columns that are later carried into the SFT parquet file.
s1_df[['question', 'deepseek_thinking_trajectory', 'solution', 'cot_type', 'source_type']].head()

Unnamed: 0,question,deepseek_thinking_trajectory,solution,cot_type,source_type
0,"Given a rational number, write it as a fractio...","Alright, so I need to figure out how many rati...",128,math,qq8933/AIME_1983_2024
1,Let $ \mathcal{H}$ be an infinite-dimensiona...,"Okay, so I need to show that there exists a po...",1. **Consider a countable subset \( S_0 \subse...,math,AI-MO/NuminaMath-CoT/aops_forum
2,Find the remainder when $9 \times 99 \times 99...,"Alright, so I have this problem here: I need t...",109,math,qq8933/AIME_1983_2024
3,Compute the mean molecular speed v in the heav...,"Okay, I need to calculate the mean molecular s...",167.0,math,TIGER-Lab/TheoremQA/float
4,Two capacitors with capacitance values $C_{1}=...,"Okay, so I need to find the percentage error i...",1.3,math,daman1209arora/jeebench/phy


# Synchronous call

In [None]:
import os
from tqdm import tqdm

# Generate (or load cached) model responses for every question, one at a time.
responses_dir = 'responses'
questions = s1_df['question'].tolist()
deepseek_thinking_trajectories = s1_df['deepseek_thinking_trajectory'].tolist()

responses = []
for i, (question, deepseek_thinking_trajectory) in tqdm(enumerate(zip(questions, deepseek_thinking_trajectories)), total=len(questions)):
    os.makedirs(f'{responses_dir}/{i}/', exist_ok=True)
    filename = f'{responses_dir}/{i}/first_run.txt'
    # A response saved by an earlier run acts as a cache: reuse it, skip the API call.
    if os.path.exists(filename):
        with open(filename, 'r') as f:
            responses.append(f.read())
        continue
    messages = get_prompt_messsages(question, deepseek_thinking_trajectory)
    # Retries indefinitely, pausing one second after every failure.
    while True:
        try:
            response = generate_response_nebius(client_nebius, messages, stream=True)
            break
        except Exception as e:
            print(f'Error generating response: {e}')
            time.sleep(1)
    responses.append(response)
    with open(filename, 'w') as f:
        f.write(response)
    with open(f'{responses_dir}/{i}/original_ds_response.txt', 'w') as f:
        f.write(deepseek_thinking_trajectory)

# Tag each trajectory with the keep spans from its OWN response. Reusing the
# loop variable here would pair every response with the last trajectory.
tagged_responses = [
    process_keep_additions(response, trajectory)
    for response, trajectory in zip(responses, deepseek_thinking_trajectories)
]
assert len(tagged_responses) == len(s1_df)

s1_df['tagged_ds_response'] = tagged_responses
sft_ds = s1_df[['question', 'deepseek_thinking_trajectory', 'solution', 'cot_type', 'source_type', 'tagged_ds_response']]
sft_ds.to_parquet('sft_ds.parquet')


# Asynchronous call

In [None]:
from openai import AsyncOpenAI
import asyncio
from tqdm.asyncio import tqdm as atqdm


# Create the async client: same Nebius endpoint as the synchronous client, with the
# API key read from the NEBIUS_API_KEY environment variable.
client_nebius_async = AsyncOpenAI(
    base_url="https://api.studio.nebius.com/v1/",
    api_key=os.getenv('NEBIUS_API_KEY')
)

# Async version of generate_response_nebius
async def generate_response_nebius_async(client, messages):
    """Await a single (non-streaming) DeepSeek-R1 chat completion and return its text.

    Any exception is printed and then re-raised to the caller.
    """
    request = dict(
        model="deepseek-ai/DeepSeek-R1",
        max_tokens=16384,
        temperature=0.9,
        top_p=0.95,
        messages=messages,
    )
    try:
        completion = await client.chat.completions.create(**request)
        return completion.choices[0].message.content
    except Exception as e:
        print(f"Error in generate_response_nebius_async: {e}")
        raise

# Update process_question to use the async version
async def process_question(i, question, deepseek_thinking_trajectory, semaphore, responses_dir):
    """Fetch (or load from cache) the model response for question ``i``.

    Runs under ``semaphore``. If ``{responses_dir}/{i}/first_run.txt`` exists
    its contents are returned; otherwise the model is queried, retrying
    indefinitely with a one-second pause after each failure, and both the
    response and the original trajectory are written to that directory.

    Returns:
        ``(response, deepseek_thinking_trajectory)``.
    """
    async with semaphore:
        os.makedirs(f'{responses_dir}/{i}/', exist_ok=True)

        filename = f'{responses_dir}/{i}/first_run.txt'
        if os.path.exists(filename):
            with open(filename, 'r') as f:
                response = f.read()
        else:
            # The full prompt is deliberately not printed: it embeds the whole
            # trajectory and would flood the output for every uncached question.
            messages = get_prompt_messsages(question, deepseek_thinking_trajectory)
            while True:
                try:
                    print(f'Generating response for question {i}')
                    response = await generate_response_nebius_async(client_nebius_async, messages)
                    break
                except Exception as e:
                    print(f'Error generating response for question {i}: {e} - retrying...')
                    await asyncio.sleep(1)
            print(f'Generated response for question {i}')
            with open(filename, 'w') as f:
                f.write(response)
            with open(f'{responses_dir}/{i}/original_ds_response.txt', 'w') as f:
                f.write(deepseek_thinking_trajectory)

        print(f'Returning response for question {i}')
        return response, deepseek_thinking_trajectory

# Main async function
async def main():
    """Collect responses for all questions concurrently, tag the trajectories, save the SFT set.

    At most 20 ``process_question`` bodies run at a time (semaphore). The
    tagged trajectories are stored in ``s1_df['tagged_ds_response']`` and the
    selected columns are written to ``sft_ds.parquet``.
    """
    responses_dir = 'responses'
    questions = s1_df['question'].tolist()
    deepseek_thinking_trajectories = s1_df['deepseek_thinking_trajectory'].tolist()

    semaphore = asyncio.Semaphore(20)
    tasks = [
        asyncio.create_task(
            process_question(i, question, deepseek_thinking_trajectory, semaphore, responses_dir)
        )
        for i, (question, deepseek_thinking_trajectory) in enumerate(zip(questions, deepseek_thinking_trajectories))
    ]

    # Each result is (raw model response, trajectory); this code relies on the
    # results coming back in task order so they line up with the rows of s1_df.
    responses_and_trajectories = await atqdm.gather(*tasks)

    # process_keep_additions takes the model response first and the trace to
    # tag second; passing them the other way round searches the wrong text.
    tagged_responses = [
        process_keep_additions(response, trajectory)
        for response, trajectory in responses_and_trajectories
    ]

    s1_df['tagged_ds_response'] = tagged_responses
    sft_ds = s1_df[['question', 'deepseek_thinking_trajectory', 'solution',
                    'cot_type', 'source_type', 'tagged_ds_response']]
    sft_ds.to_parquet('sft_ds.parquet')

# Run the async code. NOTE(review): a bare top-level `await` presumably relies on
# the notebook's already-running event loop — confirm before reusing in a plain script.
await main()

  0%|          | 0/1000 [00:00<?, ?it/s]

Returning response for question 0
Returning response for question 1
Returning response for question 2
Returning response for question 3
Returning response for question 4
Returning response for question 5
Returning response for question 6
Returning response for question 7
Returning response for question 8
Returning response for question 9
Returning response for question 10
Returning response for question 11
Returning response for question 12
Returning response for question 13
Returning response for question 14
Returning response for question 15
Returning response for question 16
Returning response for question 17
Returning response for question 18
Returning response for question 19
Returning response for question 20
Returning response for question 21
Returning response for question 22
Returning response for question 23
Returning response for question 24
Returning response for question 25
Returning response for question 26
Returning response for question 27
Returning response for questio

 34%|███▎      | 336/1000 [01:39<03:17,  3.36it/s]

Generated response for question 339
Returning response for question 339
Generating response for question 355


 34%|███▎      | 337/1000 [02:13<04:51,  2.27it/s]

Generated response for question 337
Returning response for question 337
Generating response for question 356


 34%|███▍      | 338/1000 [02:35<06:15,  1.76it/s]

Generated response for question 352
Returning response for question 352
Generating response for question 357


 34%|███▍      | 339/1000 [02:45<07:05,  1.55it/s]

Generated response for question 342
Returning response for question 342
Generating response for question 358


 34%|███▍      | 340/1000 [02:50<07:43,  1.42it/s]

Generated response for question 348
Returning response for question 348
Generating response for question 359


 34%|███▍      | 341/1000 [02:51<07:49,  1.40it/s]

Generated response for question 341
Returning response for question 341
Generating response for question 360


 34%|███▍      | 342/1000 [03:06<11:17,  1.03s/it]

Generated response for question 347
Returning response for question 347
Generating response for question 361


 34%|███▍      | 343/1000 [03:47<25:33,  2.33s/it]

Generated response for question 346
Returning response for question 346
Generating response for question 362


 34%|███▍      | 344/1000 [04:17<39:16,  3.59s/it]

Generated response for question 344
Returning response for question 344
Generating response for question 363


 34%|███▍      | 345/1000 [04:36<49:20,  4.52s/it]

Generated response for question 335
Returning response for question 335
Generating response for question 364


 35%|███▍      | 346/1000 [04:59<1:04:30,  5.92s/it]

Generated response for question 357
Returning response for question 357
Generating response for question 365
