# Notebook: Distillation Dataset for SFT

## Purpose

This notebook implements Phase A (Distillation) of the GRPO verifiable-reward coding project. Its sole objective is to generate a small, high-quality SFT warmup dataset. It does this by prompting a strong teacher model (DeepSeek's reasoning model, called through the `deepseek-reasoner` API endpoint) to solve a subset of MBPP problems in a strict canonical format.

## Output Artifact
* `distilled_sft.jsonl` (~100 entries)

Each entry corresponds to one MBPP problem paired with a teacher-generated solution.

## Canonical Output Format

All teacher outputs must follow this format exactly:

```
<start_working_out>
...reasoning...
<end_working_out>
<SOLUTION>
...valid Python code only...
</SOLUTION>
```

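A response missing any of these tags is rejected automatically. Tag matching is case-insensitive, and the reasoning block may be closed with either `<end_working_out>` or `</end_working_out>`.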

## Acceptance Policy

A generated sample is kept if and only if all of the following hold:
* The response is non-empty
* All canonical tags are present
* The response is not truncated (`finish_reason != "length"`)
* The extracted `<SOLUTION>` block parses successfully with `ast.parse`

Malformed, partial, or non-parsing outputs are discarded without exception.

## End State

At the end of this notebook, we should have a clean `distilled_sft.jsonl` file ready for SFT warmup.

# Step 1: Mounting Google Drive and Importing Libraries

In [1]:
from google.colab import drive
drive.mount("/content/drive")
%cd /content/drive/MyDrive/grpo-verified-reasoner
!ls

Mounted at /content/drive
/content/drive/MyDrive/grpo-verified-reasoner
data			      LICENSE  notebooks  src
huggingface_tokenizers_cache  models   README.md  unsloth_compiled_cache


In [2]:
import os
import re
import json
import random
import ast
from tqdm import tqdm
from openai import OpenAI
from datasets import load_dataset
from google.colab import userdata

# Step 2: Setting Up the DeepSeek API Key

In [5]:
api_key = userdata.get("DEEPSEEK_API_KEY")

In [6]:
os.environ["DEEPSEEK_API_KEY"] = api_key

# Step 3: Loading the MBPP Dataset

In [7]:
mbpp = load_dataset("mbpp", split="train")


In [8]:
print(f"MBPP train split size: {len(mbpp)}")
print("Columns:", mbpp.column_names)

MBPP train split size: 374
Columns: ['task_id', 'text', 'code', 'test_list', 'test_setup_code', 'challenge_test_list']


In [9]:
# Randomly sample 200 MBPP problems (more than the 100 needed, to leave room for rejections)

SEED = 42
NUM_SAMPLES = 200
random.seed(SEED)

all_indices = list(range(len(mbpp)))
selected_indices = random.sample(all_indices, NUM_SAMPLES)

print(f"Selected {len(selected_indices)} MBPP problems.")

Selected 200 MBPP problems.


In [10]:
# Extract selected problems

selected_problems = []

for idx in selected_indices:
    sample = mbpp[idx]
    selected_problems.append({
        "mbpp_index": idx,
        "text": sample["text"],
        "test_list": sample["test_list"]
    })

print("Example selected problem:\n")
print(selected_problems[0]["text"])

Example selected problem:

Write a function to convert a date of yyyy-mm-dd format to dd-mm-yyyy format.


In [11]:
len(selected_problems)

200

In [12]:
# Save selected MBPP indices for reproducibility

DATA_DIR = "data"
INDICES_PATH = os.path.join(DATA_DIR, "mbpp_sft_indices.json")
os.makedirs(DATA_DIR, exist_ok=True)

with open(INDICES_PATH, "w") as f:
    json.dump(
        {
            "seed": SEED,
            "indices": selected_indices
        },
        f,
        indent=2
    )

print(f"Saved selected MBPP indices to {INDICES_PATH}")

Saved selected MBPP indices to data/mbpp_sft_indices.json
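Because the seed is saved alongside the indices, the selection can be re-checked later. The following is a minimal sketch, assuming the `{"seed": ..., "indices": [...]}` layout written above and the same Python version. It re-seeds a private `random.Random` instance and samples again, which should reproduce the stored list. `verify_indices` is a hypothetical helper, not part of the original notebook.

```python
import json
import random


def verify_indices(indices_path: str, population_size: int, num_samples: int) -> bool:
    """Return True if re-sampling with the stored seed reproduces the stored indices.

    Hypothetical helper; assumes the {"seed": ..., "indices": [...]} layout saved above.
    """
    with open(indices_path) as f:
        saved = json.load(f)

    # A private Random instance seeded with an int yields the same sequence as
    # random.seed(seed) followed by module-level calls, without touching global state.
    rng = random.Random(saved["seed"])
    resampled = rng.sample(list(range(population_size)), num_samples)
    return resampled == saved["indices"]
```

In this notebook the check would be `verify_indices(INDICES_PATH, len(mbpp), NUM_SAMPLES)`, which should return `True` as long as the MBPP train split still has 374 rows.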


# Step 4: Defining the Teacher Prompt

In [13]:
TEACHER_SYSTEM_PROMPT = """You are a code-generation engine that emits structured output.

Your task is to RENDER an answer strictly according to the schema below.
This is a hard constraint, not a suggestion.

You MUST:
- Emit ALL tags exactly as written
- Emit them in the correct order
- Emit NO text outside the tags
- Emit valid Python code inside <SOLUTION>

You MUST NOT:
- Add explanations outside the tags
- Change tag names
- Omit any tag
- Add commentary, apologies, or preambles

If you cannot comply, you must STILL output the schema with your best attempt.

REQUIRED OUTPUT SCHEMA:

<START_WORKING_OUT>
Concise reasoning steps necessary to derive the solution.
No commentary. No repetition.
</END_WORKING_OUT>
<SOLUTION>
Valid Python code only.
</SOLUTION>
"""

# Step 5: Generating the CoT Dataset

This section generates the **distilled SFT dataset** by querying a strong teacher model (**DeepSeek-Reasoner**) on a fixed subset of MBPP problems and applying strict acceptance criteria.

### Overview

For each selected MBPP problem, we:
1. Prompt the teacher model using a **strict schema-enforcing system prompt**
2. Collect the generated reasoning + solution
3. Validate structural correctness using **robust, case-insensitive tag checks**
4. Verify that the extracted solution is **syntactically valid Python**
5. Accept only fully compliant samples and reject all others

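The loop stops as soon as **100 valid samples** have been collected. The 200-problem pool leaves headroom for rejections; if the pool were exhausted first, fewer than 100 samples would be kept.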

---

### Teacher Model Setup

We initialize the DeepSeek client using the OpenAI-compatible API interface:

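- `deepseek-reasoner` is used as the **teacher model**
- Temperature is set to **0.0**. However, DeepSeek's API documentation lists `temperature` among the parameters that `deepseek-reasoner` accepts but ignores, so outputs are not guaranteed to be deterministic
- A `max_tokens` budget of 8192 is set, and any response that hits it (`finish_reason == "length"`) is rejected as truncated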
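Per DeepSeek's API documentation, `deepseek-reasoner` returns its native chain of thought in a separate `message.reasoning_content` field, next to `message.content`. This notebook stores only `content`. As a result, the reasoning in the dataset is the condensed `<START_WORKING_OUT>` block the model writes into its answer, not the native trace.

Below is a minimal sketch of reading both fields defensively. It assumes an OpenAI-SDK-style response object, faked here with `SimpleNamespace` so the block runs offline. `unpack_choice` is a hypothetical helper.

```python
from types import SimpleNamespace
from typing import Any


def unpack_choice(response: Any) -> dict:
    """Pull the fields the acceptance loop needs from a chat completion.

    `reasoning_content` is a DeepSeek-specific field, so it is read with getattr
    and defaults to None when the message object does not carry it.
    """
    choice = response.choices[0]
    message = choice.message
    return {
        "finish_reason": choice.finish_reason,
        "content": message.content,
        "reasoning_content": getattr(message, "reasoning_content", None),
    }


# Offline stand-in for a real API response (structure only; contents are made up).
fake = SimpleNamespace(
    choices=[
        SimpleNamespace(
            finish_reason="stop",
            message=SimpleNamespace(
                content="<START_WORKING_OUT>\n...\n</END_WORKING_OUT>\n<SOLUTION>\npass\n</SOLUTION>",
                reasoning_content="(native chain of thought)",
            ),
        )
    ]
)
print(unpack_choice(fake)["finish_reason"])  # stop
```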

---

### Robust Schema Validation

To ensure reliable downstream training, we enforce a strict output schema:

- A reasoning block enclosed in  
  `<start_working_out> ... </end_working_out>` (case-insensitive; `<end_working_out>` is also accepted as the closing tag)
- A solution block enclosed in  
  `<solution> ... </solution>` (case-insensitive)
- No reliance on exact casing or fragile string equality

Regular expressions are used to tolerate harmless formatting variations while still enforcing structure.
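The tolerance described above can be checked directly against a few strings. This small sketch copies the two patterns from the validation cell so it runs standalone. `has_schema` is a hypothetical wrapper, and the sample strings are made up for illustration.

```python
import re

# Same patterns as the notebook's validation cell.
RE_REASONING = re.compile(
    r"<start_working_out>\s*.*?\s*(?:<end_working_out>|</end_working_out>)",
    re.IGNORECASE | re.DOTALL,
)
RE_SOLUTION = re.compile(r"<solution>\s*.*?\s*</solution>", re.IGNORECASE | re.DOTALL)


def has_schema(text: str) -> bool:
    """True if both a reasoning block and a solution block are present."""
    return RE_REASONING.search(text) is not None and RE_SOLUTION.search(text) is not None


examples = {
    "upper case, slashed close": "<START_WORKING_OUT>\nr\n</END_WORKING_OUT>\n<SOLUTION>\nx = 1\n</SOLUTION>",
    "lower case, bare close": "<start_working_out>r<end_working_out><solution>x = 1</solution>",
    "missing solution": "<start_working_out>r</end_working_out>",
    "unclosed solution": "<start_working_out>r</end_working_out><solution>x = 1",
}
for name, text in examples.items():
    print(f"{name}: {has_schema(text)}")
```

The two searches are independent. As a result, they enforce neither tag order nor the absence of text outside the tags, even though the system prompt asks for both. A response with the solution block first would still pass.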

---

### Acceptance Criteria

A sample is **accepted if and only if**:

- The model response is not truncated
- The response is non-empty
- Both reasoning and solution tags are present
- The extracted `<solution>` block parses successfully via `ast.parse`

All invalid samples are **rejected and logged**, never repaired.
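These criteria can be folded into one pure function, which makes the policy unit-testable without calling the API. Below is a minimal sketch that applies the checks in the same order as the generation loop. `check_sample` is a hypothetical name. Unlike the loop, it catches only `SyntaxError` and `ValueError`, the errors `ast.parse` raises for invalid source (older Python versions raise `ValueError` for null bytes), instead of a bare `Exception`.

```python
import ast
import re
from typing import Optional, Tuple

RE_REASONING = re.compile(
    r"<start_working_out>\s*.*?\s*(?:<end_working_out>|</end_working_out>)",
    re.IGNORECASE | re.DOTALL,
)
# Capturing variant of the solution pattern, used for both presence and extraction.
RE_SOLUTION_BODY = re.compile(r"<solution>(.*?)</solution>", re.IGNORECASE | re.DOTALL)


def check_sample(text: Optional[str], finish_reason: Optional[str]) -> Tuple[bool, str]:
    """Apply the acceptance criteria; returns (accepted, reason), reason is "ok" if accepted."""
    if finish_reason == "length":
        return False, "truncated"
    if text is None or not text.strip():
        return False, "empty output"
    match = RE_SOLUTION_BODY.search(text)
    if RE_REASONING.search(text) is None or match is None:
        return False, "missing reasoning / solution tags"
    try:
        ast.parse(match.group(1))
    except (SyntaxError, ValueError):
        return False, "invalid python"
    return True, "ok"
```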

---

### Output

Accepted samples are stored as structured records containing:
- MBPP index
- Original problem prompt
- Full teacher response (reasoning + solution)

The final result is a clean **SFT warmup dataset**.
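Each record is written as one JSON object per line. The format works even though the teacher response spans many lines, because `json.dumps` escapes embedded newlines as `\n`. Below is a minimal round-trip sketch, assuming the three field names used in the generation loop (`mbpp_index`, `prompt`, `assistant`). The record is abbreviated from the first accepted sample shown later in the notebook, and an in-memory buffer stands in for the file.

```python
import io
import json

REQUIRED_KEYS = {"mbpp_index", "prompt", "assistant"}

record = {
    "mbpp_index": 327,
    "prompt": "Write a function to convert a date of yyyy-mm-dd format to dd-mm-yyyy format.",
    "assistant": "<START_WORKING_OUT>\n...\n</END_WORKING_OUT>\n<SOLUTION>\n...\n</SOLUTION>",
}

# Write: one JSON object per line; newlines inside strings are escaped by json.dumps.
buf = io.StringIO()
buf.write(json.dumps(record) + "\n")

# Read back: parse each non-empty line and check that the expected keys are present.
buf.seek(0)
loaded = [json.loads(line) for line in buf if line.strip()]
print(all(REQUIRED_KEYS <= set(r) for r in loaded))  # True
```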

In [14]:
client = OpenAI(
    api_key=os.environ["DEEPSEEK_API_KEY"],
    base_url="https://api.deepseek.com"
)
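The v1 `OpenAI` client already retries connection errors, 429s and 5xx responses a couple of times with short backoff (configurable via its `max_retries` argument). Any exception that still escapes, however, aborts the whole generation loop, which ran for about 82 minutes here.

One option, sketched here as an assumption rather than part of the original pipeline, is an outer retry helper with longer exponential backoff. `call_with_retries` and its parameters are hypothetical, and the caller decides which exceptions count as retryable.

```python
import time
from typing import Callable, Tuple, Type, TypeVar

T = TypeVar("T")


def call_with_retries(
    fn: Callable[[], T],
    max_retries: int = 3,
    base_delay: float = 2.0,
    retry_on: Tuple[Type[BaseException], ...] = (Exception,),
) -> T:
    """Call fn(), retrying on the given exceptions with exponential backoff.

    Sleeps base_delay, 2*base_delay, 4*base_delay, ... between attempts and
    re-raises the last exception once max_retries retries are used up.
    Exceptions not listed in retry_on propagate immediately.
    """
    for attempt in range(max_retries + 1):
        try:
            return fn()
        except retry_on:
            if attempt == max_retries:
                raise
            time.sleep(base_delay * (2 ** attempt))
    # Unreachable: the loop always returns or re-raises.
    raise AssertionError("unreachable")
```

In the generation loop, the request could be wrapped as `call_with_retries(lambda: client.chat.completions.create(...), retry_on=(openai.APIConnectionError, openai.RateLimitError))`, with the same arguments as the original call. That would require `import openai`, since the notebook only imports `OpenAI`.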

In [15]:
# Robust tag check: case-insensitive, accepts both <end_...> and </end_...>

RE_REASONING = re.compile(
    r"<start_working_out>\s*.*?\s*(?:<end_working_out>|</end_working_out>)",
    re.IGNORECASE | re.DOTALL
)

RE_SOLUTION = re.compile(
    r"<solution>\s*.*?\s*</solution>",
    re.IGNORECASE | re.DOTALL
)

In [16]:
MODEL_NAME = "deepseek-reasoner"  # DeepSeek's reasoning endpoint (served by DeepSeek-V3.2 at the time of this run)
MAX_ACCEPTED = 100
accepted_samples = []
attempts = 0

with tqdm(total=MAX_ACCEPTED, desc="Accepted samples") as pbar:
    for sample in selected_problems:
        if len(accepted_samples) >= MAX_ACCEPTED:
            break

        attempts += 1
        problem_text = sample["text"]

        # ---- CALL MODEL ----
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": TEACHER_SYSTEM_PROMPT},
                {"role": "user", "content": problem_text},
            ],
            temperature=0.0,
            max_tokens=8192,
        )

        choice = response.choices[0]
        text = choice.message.content

        # ---- HARD FAILS ----
        if choice.finish_reason == "length":
            print(f"\n--- Rejected sample {sample['mbpp_index']} ---")
            print("Reason: truncated")
            print("--- End rejection ---\n")
            continue

        if text is None or not text.strip():
            print(f"\n--- Rejected sample {sample['mbpp_index']} ---")
            print("Reason: empty output")
            print("--- End rejection ---\n")
            continue

        # ---- TAG CHECK ----
        if RE_REASONING.search(text) is None or RE_SOLUTION.search(text) is None:
            print(f"\n--- Rejected sample {sample['mbpp_index']} ---")
            print("Reason: missing reasoning / solution tags")
            print("Model output:")
            print(text)
            print("--- End rejection ---\n")
            continue

        # ---- PYTHON SYNTAX CHECK ----
        try:
            code = re.search(
                r"<solution>(.*?)</solution>",
                text,
                re.IGNORECASE | re.DOTALL
            ).group(1)
            ast.parse(code)
        except Exception:
            print(f"\n--- Rejected sample {sample['mbpp_index']} ---")
            print("Reason: invalid python")
            print(text)
            print("--- End rejection ---\n")
            continue

        # ---- ACCEPT ----
        accepted_samples.append({
            "mbpp_index": sample["mbpp_index"],
            "prompt": problem_text,
            "assistant": text,
        })

        pbar.update(1)

print(f"\nDone: {len(accepted_samples)} accepted in {attempts} attempts.")

Accepted samples:  22%|██▏       | 22/100 [21:52<1:44:07, 80.10s/it]


--- Rejected sample 101 ---
Reason: empty output
--- End rejection ---



Accepted samples:  46%|████▌     | 46/100 [40:16<26:27, 29.39s/it]


--- Rejected sample 235 ---
Reason: truncated
--- End rejection ---



Accepted samples:  76%|███████▌  | 76/100 [1:02:20<14:21, 35.91s/it]


--- Rejected sample 304 ---
Reason: empty output
--- End rejection ---



Accepted samples:  78%|███████▊  | 78/100 [1:03:38<13:56, 38.01s/it]


--- Rejected sample 138 ---
Reason: truncated
--- End rejection ---



Accepted samples:  96%|█████████▌| 96/100 [1:18:22<01:49, 27.30s/it]


--- Rejected sample 367 ---
Reason: empty output
--- End rejection ---



Accepted samples: 100%|██████████| 100/100 [1:22:01<00:00, 49.22s/it]


Done: 100 accepted in 105 attempts.
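The rejection messages above are printed inline and are easy to lose in a long progress log. Below is a small sketch of tallying the reasons instead. It assumes the loop had also appended `(mbpp_index, reason)` pairs to a hypothetical `rejections` list; the pairs shown are the five rejections from this run, transcribed from the log.

```python
from collections import Counter

# (mbpp_index, reason) pairs, transcribed from the rejection log above.
rejections = [
    (101, "empty output"),
    (235, "truncated"),
    (304, "empty output"),
    (138, "truncated"),
    (367, "empty output"),
]

accepted, attempts = 100, 105
tally = Counter(reason for _, reason in rejections)
print(tally.most_common())  # [('empty output', 3), ('truncated', 2)]
print(f"acceptance rate: {accepted / attempts:.1%}")  # acceptance rate: 95.2%
```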





In [17]:
accepted_samples

[{'mbpp_index': 327,
  'prompt': 'Write a function to convert a date of yyyy-mm-dd format to dd-mm-yyyy format.',
  'assistant': '<START_WORKING_OUT>\nSplit input string by \'-\' delimiter.\nAssign parts to year, month, day variables.\nReturn string in format day-month-year joined by \'-\'.\n</END_WORKING_OUT>\n<SOLUTION>\ndef convert_date(date_str):\n    year, month, day = date_str.split(\'-\')\n    return f"{day}-{month}-{year}"\n</SOLUTION>'},
 {'mbpp_index': 57,
  'prompt': 'Write a function to find the item with maximum occurrences in a given list.',
  'assistant': '<START_WORKING_OUT>\nCheck for empty list. Count occurrences using dictionary. Find item with maximum count using max() with key function.\n</END_WORKING_OUT>\n<SOLUTION>\ndef max_occurrences(lst):\n    if not lst:\n        return None\n    counts = {}\n    for item in lst:\n        counts[item] = counts.get(item, 0) + 1\n    return max(counts, key=counts.get)\n</SOLUTION>'},
 {'mbpp_index': 12,
  'prompt': 'Write a fu

In [18]:
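# Save the accepted samples as JSONL (one JSON object per line)

OUTPUT_PATH = os.path.join(DATA_DIR, "distilled_sft.jsonl")

with open(OUTPUT_PATH, "w", encoding="utf-8") as f:
    for sample in accepted_samples:
        f.write(json.dumps(sample) + "\n")

print(f"Saved {len(accepted_samples)} samples to {OUTPUT_PATH}")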
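The saved responses keep the tags exactly as the teacher emitted them: `<START_WORKING_OUT>`, `</END_WORKING_OUT>`, `<SOLUTION>`. This differs from the canonical format defined at the top of this notebook, which uses `<start_working_out>` and `<end_working_out>`. If the SFT or reward code matches tags case-sensitively, one option, an assumption and not part of the original pipeline, is to normalize tags when the file is loaded.

`canonicalize_tags` is a hypothetical helper. Its target spellings follow the Canonical Output Format section.

```python
import re

# Target spellings from the Canonical Output Format section.
_CANONICAL = [
    (re.compile(r"<start_working_out>", re.IGNORECASE), "<start_working_out>"),
    (re.compile(r"</?end_working_out>", re.IGNORECASE), "<end_working_out>"),
    (re.compile(r"<solution>", re.IGNORECASE), "<SOLUTION>"),
    (re.compile(r"</solution>", re.IGNORECASE), "</SOLUTION>"),
]


def canonicalize_tags(text: str) -> str:
    """Rewrite schema tags to their canonical spelling; all other text is unchanged."""
    for pattern, replacement in _CANONICAL:
        text = pattern.sub(replacement, text)
    return text
```

When reading the JSONL file, this would be applied per record, e.g. `record["assistant"] = canonicalize_tags(record["assistant"])`.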