In [None]:

from datasets import load_dataset, concatenate_datasets, Dataset, DatasetDict
import os
import re
import json
from tqdm import tqdm
from itertools import chain
from collections import Counter
from datasets import load_dataset
import random
from collections import defaultdict
import numpy as np
import pandas as pd
import pickle

## Context

In [None]:
data_path = "Pavankalyan/stage0_c_all"
df = load_dataset(data_path, split="train")

# Hold out up to VAL_PER_ID rows per `id` for validation; all remaining rows go to train.
# NOTE(review): an id with <= VAL_PER_ID rows lands entirely in val and gets no
# train rows at all — confirm this is intended.
VAL_PER_ID = 100
SPLIT_SEED = 42
split_rng = random.Random(SPLIT_SEED)  # seeded so re-running the cell reproduces the same split

# Group row positions by id, reading only the "id" column instead of decoding every full row.
id_to_indices = defaultdict(list)
for idx, example_id in enumerate(df["id"]):
    id_to_indices[example_id].append(idx)

# Per-id row counts (ascending) for inspection, derived from the grouping above
# rather than from a second pass over the dataset.
id_counter = Counter({example_id: len(idxs) for example_id, idxs in id_to_indices.items()})
least_common = sorted(id_counter.items(), key=lambda x: x[1])

val_indices = []
for indices in tqdm(id_to_indices.values()):
    if len(indices) <= VAL_PER_ID:
        val_indices.extend(indices)
    else:
        val_indices.extend(split_rng.sample(indices, VAL_PER_ID))

val_indices_set = set(val_indices)
def is_val(example_idx):
    """Return True if row position `example_idx` was assigned to the val split."""
    return example_idx in val_indices_set

# Both selections keep the dataset's original row order.
val_dataset = df.select(sorted(val_indices_set))
train_dataset = df.select([i for i in range(len(df)) if i not in val_indices_set])
final_dataset = DatasetDict({
    "train": train_dataset,
    "val": val_dataset,
})

In [None]:
# Read the Hub token from the environment instead of hardcoding it (an empty
# string does not authenticate); None lets huggingface_hub use the cached login.
final_dataset.push_to_hub("Pavankalyan/stage0_c_all", token=os.environ.get("HF_TOKEN"))

## Instruct

In [None]:
# Stage-0 instruct dataset; only its "train" split is loaded here.
data_path = "Pavankalyan/stage0_instruct"
df = load_dataset(path=data_path, split="train")

In [None]:
# Hold out up to VAL_PER_ID rows per `id` for validation; all remaining rows go to train.
# NOTE(review): an id with <= VAL_PER_ID rows lands entirely in val and gets no
# train rows at all — confirm this is intended.
VAL_PER_ID = 100
SPLIT_SEED = 42
split_rng = random.Random(SPLIT_SEED)  # seeded so re-running the cell reproduces the same split

# Group row positions by id, reading only the "id" column instead of decoding every full row.
id_to_indices = defaultdict(list)
for idx, example_id in enumerate(df["id"]):
    id_to_indices[example_id].append(idx)

# Per-id row counts (ascending) for inspection, derived from the grouping above
# rather than from a second pass over the dataset.
id_counter = Counter({example_id: len(idxs) for example_id, idxs in id_to_indices.items()})
least_common = sorted(id_counter.items(), key=lambda x: x[1])

val_indices = []
for indices in tqdm(id_to_indices.values()):
    if len(indices) <= VAL_PER_ID:
        val_indices.extend(indices)
    else:
        val_indices.extend(split_rng.sample(indices, VAL_PER_ID))

val_indices_set = set(val_indices)
def is_val(example_idx):
    """Return True if row position `example_idx` was assigned to the val split."""
    return example_idx in val_indices_set

# Both selections keep the dataset's original row order.
val_dataset = df.select(sorted(val_indices_set))
train_dataset = df.select([i for i in range(len(df)) if i not in val_indices_set])
final_dataset = DatasetDict({
    "train": train_dataset,
    "val": val_dataset,
})

In [None]:
# Read the Hub token from the environment instead of hardcoding it (an empty
# string does not authenticate); None lets huggingface_hub use the cached login.
final_dataset.push_to_hub("Pavankalyan/stage0_instruct", token=os.environ.get("HF_TOKEN"))

## Test

In [None]:
# Developmental stage to process; later cells build repo names and paths from it.
stage = 0
ds = load_dataset(f"Pavankalyan/stage{stage}_instruct", split="val")

# Fields copied into each seed record (same keys the rating prompt's
# {instruction}/{response}/{stage}/{age_group} placeholders use).
seed_fields = ("instruction", "response", "stage", "age_group")
# One pass over the rows: indexing ds[i] once per field decoded each row four times.
results = [{key: row[key] for key in seed_fields} for row in ds]

with open(f"seed_stage{stage}_instruct.pkl", "wb") as f:
    pickle.dump(results, f)
# Rating prompt for the evaluator model, written to prompt.json.
# "user" carries {instruction}/{response}/{stage}/{age_group} placeholders that
# match the keys of each pickled seed record.
# NOTE(review): literal JSON braces are doubled ({{ }}) in both strings; they only
# collapse to single braces if the consumer calls str.format on them. "system" has
# no placeholders — confirm it is formatted too, otherwise the model sees "{{".
# NOTE(review): the "system" literal appears split across two physical lines after
# "4. "; a plain "..." string cannot span lines, so confirm this is only a display
# wrap and not a real line break in the cell.
prompt = {
    "system": "You are a developmental expert rating how well a child's response to a prompt demonstrates age-appropriate reasoning and language for a given developmental stage.\n\nYou will receive:\n- An **instruction** given to the child\n- The child's **response**\n- The child's **developmental stage** (0–9)\n- The child's **age group** (e.g., '0–5', '5–11', '11–14')\n\nYour job is to:\n1. **Rate the response on a scale from 1 to 5**, using the following criteria:\n   - **5 – Excellent:** The response fully addresses the instruction with clear, developmentally appropriate reasoning and language. It meets expectations for the stage with no major issues.\n   - **4 – Strong:** Mostly appropriate and coherent; minor gaps in clarity, depth, or completeness.\n   - **3 – Adequate:** A reasonable attempt that partially addresses the instruction; may be vague, brief, or contain small misunderstandings.\n   - **2 – Limited:** Weak or underdeveloped response; minimal reasoning or limited relevance to the instruction.\n   - **1 – Inadequate:** Response is off-topic, confusing, or clearly inappropriate for the stage.\n\n2. **Use stage-specific developmental expectations**:\n   - **Stage 0 (Age 5):** Very simple sentences, concrete ideas, focused on here and now\n   - **Stages 1–3 (Ages 6–8):** Simple reasoning, some past/future thinking, familiar examples\n   - **Stages 4–6 (Ages 9–11):** Logical structure, comparisons, abstract or hypothetical reasoning\n   - **Stages 7–9 (Ages 12–14):** Nuanced reasoning, multi-step thinking, advanced vocabulary\n\n3. **Evaluate:**\n   - Does the child’s response meaningfully address the instruction?\n   - Is the language and reasoning developmentally appropriate for the stage?\n   - Is the response authentic and logically consistent?\n\n4. 
**Output Format:**\nOnly return the following dictionary:\n```json\n{{\n    \"rating\": <integer from 1 to 5>,\n    \"explanation\": \"<2–3 sentence rationale>\"\n}}\n```\nDo not add any other text or formatting. Only return the JSON object.",
    "user": "Evaluate the child's response to the instruction below based on the developmental stage and age group. Return a numerical rating (1–5) and a short explanation.\n\nInstruction: {instruction}\nResponse: {response}\nStage: {stage}\nAge group: {age_group}\n\n**Output Format:**\nOnly return the following dictionary:\n```json\n{{\n    \"rating\": <integer from 1 to 5>,\n    \"explanation\": \"<2–3 sentence rationale>\"\n}}\n```\n"
}
# ensure_ascii keeps its default (True), so the en dashes and curly apostrophe are
# written as \uXXXX escapes; json.load restores the original characters.
with open("prompt.json", "w") as f:
    json.dump(prompt, f, indent=4)

In [None]:
# Load every parquet shard under outputs_stage{stage}/instruct; with a bare
# data_files pattern the rows are exposed as the "train" split.
res_path = "outputs_stage{}/instruct".format(stage)
hf_df = load_dataset("parquet", data_files=os.path.join(res_path, "*.parquet"), streaming=False)

def parse_json_string(text):
    """Parse a model answer expected to be a JSON object with "rating" and "explanation".

    Markdown code fences (```json ... ```) are stripped first. If the cleaned text
    is not a JSON object, an integer "rating" is salvaged with a regex and
    "explanation" is None.

    Args:
        text: raw answer string. Non-string values (e.g. None for a missing answer)
            yield an empty result instead of raising.

    Returns:
        dict with keys "rating" and "explanation"; either may be None.
    """
    if not isinstance(text, str):
        # Without this guard .strip() fails and there is no cleaned text for the
        # regex fallback to search.
        return {"rating": None, "explanation": None}

    cleaned = text.strip()
    cleaned = re.sub(r'^```json\s*', '', cleaned, flags=re.MULTILINE)
    cleaned = re.sub(r'^```', '', cleaned, flags=re.MULTILINE)
    cleaned = re.sub(r'```$', '', cleaned, flags=re.MULTILINE)

    try:
        parsed = json.loads(cleaned)
    except (ValueError, RecursionError):  # JSONDecodeError subclasses ValueError
        parsed = None

    if isinstance(parsed, dict):
        return {
            "rating": parsed.get("rating"),
            "explanation": parsed.get("explanation"),
        }

    # Malformed or truncated JSON (or a non-object value): keep the rating if present.
    match = re.search(r'"rating"\s*:\s*(\d+)', cleaned)
    return {
        "rating": int(match.group(1)) if match else None,
        "explanation": None,
    }

def flatten_parsed_fields(example):
    """`Dataset.map` callback: parse `example['answer']` with parse_json_string
    and return only its rating as a new "rating" column (None if unparseable)."""
    rating = parse_json_string(example['answer']).get("rating")
    return {"rating": rating}

hf_df = hf_df['train'].map(function=flatten_parsed_fields)

def extract_details(row):
    """Recover the instruction and response text from a formatted rating prompt.

    `row['user']` is presumably the user message built from the rating template
    ("...\\n\\nInstruction: {instruction}\\nResponse: {response}\\nStage: ...");
    the markers searched for below are that template's section labels.

    Returns:
        {'instruction': str, 'response': str} with surrounding whitespace
        stripped, or both values None if the markers are not found in order.
    """
    seed_data = row['user']
    f1 = '\n\nInstruction: '
    f2 = '\nResponse: '
    f3 = '\nStage: '
    s = seed_data.find(f1)
    # Search each marker only after the previous one, so earlier text cannot be
    # mistaken for a later section.
    e = seed_data.find(f2, s + len(f1)) if s != -1 else -1
    m = seed_data.find(f3, e + len(f2)) if e != -1 else -1
    if m == -1:
        # A missing marker would otherwise produce silently wrong slices.
        return {'instruction': None, 'response': None}
    ins = seed_data[s + len(f1):e].strip()
    resp = seed_data[e + len(f2):m].strip()
    return {
        'instruction': ins,
        'response': resp,
    }
    
# Drop generation metadata; keep the answer/prompt columns used below.
hf_df = hf_df.remove_columns(['batch_uuid', 'embeddings', 'generated_tokens', 'messages', 'metrics', 'num_generated_tokens', 'num_input_tokens', 'params', 'prompt', 'prompt_token_ids', 'request_id', 'system', 'time_taken_llm'])
hf_df = hf_df.map(extract_details)
# Convert under a new name so `ds` stays a Dataset and this cell can be re-run
# (`ds = ds.to_pandas()` fails on a second run, and `ds_orig` was otherwise undefined).
ds_orig = ds.to_pandas()
hf_df = hf_df.to_pandas()
# Inner join the ratings back onto the seed rows by exact text.
# NOTE(review): duplicate (instruction, response) pairs on either side multiply rows.
df_merged = ds_orig.merge(hf_df, on=['instruction', 'response'])
# Seed rows without a matching rating are dropped by the inner join; compare counts.
print(len(df_merged), len(ds_orig))

def select_top25(group, n=25, random_state=42):
    """Pick up to `n` top-rated rows from one group.

    If the group has at least `n` rating-5 rows, a random sample of `n` of them is
    returned. Otherwise every rating-5 row is kept and topped up with a random
    sample of rating-4 rows (as many as are available). Lower ratings are never
    chosen, so the result can hold fewer than `n` rows.

    Args:
        group: DataFrame with a 'rating' column (e.g. one groupby group).
        n: maximum number of rows to return (default 25).
        random_state: seed passed to DataFrame.sample for reproducibility.

    Returns:
        Subset of `group` with its original index preserved.
    """
    rating_5 = group[group['rating'] == 5]
    if len(rating_5) >= n:
        return rating_5.sample(n=n, random_state=random_state)

    # Not enough 5s: keep them all and fill the remainder from the 4s.
    rating_4 = group[group['rating'] == 4]
    needed = n - len(rating_5)
    rating_4_sample = rating_4.sample(n=min(needed, len(rating_4)), random_state=random_state)
    return pd.concat([rating_5, rating_4_sample])

# Rows are tagged by index below, so the merged frame's index must be unique
# (a fresh merge result has a RangeIndex).
assert df_merged.index.is_unique
# Up to 25 top-rated rows (5s, then 4s) per id form the test split.
selected_df = df_merged.groupby('id', group_keys=False).apply(select_top25).assign(split='test')
df = df_merged.assign(split='val')
df.loc[selected_df.index, 'split'] = 'test'
print(df['split'].value_counts())
# .copy() so later column drops act on independent frames, not slices of `df`
# (avoids SettingWithCopyWarning).
df_test = df[df['split'] == 'test'].copy()
print(df_test['rating'].value_counts())
df_val = df[df['split'] == 'val'].copy()
df_val['rating'].value_counts()

In [None]:
# Drop scoring/bookkeeping columns and publish the splits back to the Hub.
# Non-inplace drops into new names keep this cell re-runnable.
# NOTE: this replaces the Hub's "val" split with the rows not chosen for test.
helper_cols = ['answer', 'generated_text', 'user', 'split', 'rating']
hub_repo = f"Pavankalyan/stage{stage}_instruct"
hf_token = os.environ.get("HF_TOKEN")  # never hardcode the token; None uses the cached login
test_split = Dataset.from_pandas(df_test.drop(columns=helper_cols), preserve_index=False)
val_split = Dataset.from_pandas(df_val.drop(columns=helper_cols), preserve_index=False)
val_split.push_to_hub(hub_repo, split="val", token=hf_token)
test_split.push_to_hub(hub_repo, split="test", token=hf_token)